Coverage for src/polars_eval_metrics/metric_define.py: 69%
250 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
1"""
2Metric definition and expression preparation
4This module combines metric configuration with expression compilation,
5providing a single class for defining metrics and preparing Polars expressions.
6"""
8# pyre-strict
9from enum import Enum
10from typing import Self
12import polars as pl
13from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
15from .metric_registry import MetricRegistry, MetricInfo
16from .utils import format_polars_expr, format_polars_expr_list, parse_enum_value
class MetricType(Enum):
    """Aggregation level at which a metric is computed (see ``MetricDefine.type``)."""

    # Aggregate directly over all samples.
    ACROSS_SAMPLE = "across_sample"
    # Aggregate within each subject, then summarise across subjects.
    ACROSS_SUBJECT = "across_subject"
    # Aggregate within each subject; the per-subject result is final.
    WITHIN_SUBJECT = "within_subject"
    # Aggregate within each subject/visit, then summarise across visits.
    ACROSS_VISIT = "across_visit"
    # Aggregate within each subject/visit; the per-visit result is final.
    WITHIN_VISIT = "within_visit"
class MetricScope(Enum):
    """Level at which a metric is computed, independent of its aggregation type."""

    # One value for the whole dataset.
    GLOBAL = "global"
    # One value per model; groups are ignored.
    MODEL = "model"
    # One value per group; models are ignored.
    GROUP = "group"
class MetricDefine(BaseModel):
    """
    Metric definition with hierarchical expression support.

    This class defines metrics with support for two-level aggregation patterns:
    - within_expr: Expressions for within-entity aggregation (e.g., within subject/visit)
    - across_expr: Expressions for across-entity aggregation or final metric computation

    Attributes:
        name: Metric identifier. For built-in metrics (no within_expr/across_expr)
            this is either ``"<metric>"`` or ``"<metric>:<selector>"``; everything
            after the first ``":"`` is the selector name.
        label: Display name for the metric (defaults to ``name``).
        type: Aggregation type (ACROSS_SAMPLE, WITHIN_SUBJECT, ACROSS_SUBJECT, etc.)
        scope: Calculation scope (GLOBAL, MODEL, GROUP) - orthogonal to aggregation type
        within_expr: Expression(s) for within-entity aggregation:
            - Used in WITHIN_SUBJECT, ACROSS_SUBJECT, WITHIN_VISIT, ACROSS_VISIT
            - Not used in ACROSS_SAMPLE (which operates directly on samples)
        across_expr: Expression for across-entity aggregation or final computation:
            - For ACROSS_SAMPLE: Applied directly to error columns
            - For ACROSS_SUBJECT/VISIT: Summarizes within_expr results across entities
            - For WITHIN_SUBJECT/VISIT: Not used (within_expr is final)

    Note: within_expr and across_expr are distinct from group_by/subgroup_by which control
    analysis stratification (e.g., by treatment, age, sex) and apply to ALL metric types.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="Metric identifier")
    label: str | None = None
    type: MetricType = MetricType.ACROSS_SAMPLE
    scope: MetricScope | None = None
    within_expr: list[str | pl.Expr | MetricInfo] | None = None
    across_expr: str | pl.Expr | MetricInfo | None = None

    def __init__(self, **kwargs: object) -> None:
        """Initialize the model, defaulting ``label`` to ``name`` when missing or None.

        A ``registry`` keyword, if supplied, is discarded before validation.
        NOTE(review): presumably kept so older call sites that still pass a
        registry keep working - confirm before removing.
        """
        if kwargs.get("label") is None:
            kwargs["label"] = kwargs.get("name", "Unknown Metric")
        kwargs.pop("registry", None)
        super().__init__(**kwargs)

    @field_validator("name")  # pyre-ignore[56]
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Reject blank names and strip surrounding whitespace."""
        if not v or not v.strip():
            raise ValueError("Metric name cannot be empty")
        return v.strip()

    @field_validator("label")  # pyre-ignore[56]
    @classmethod
    def validate_label(cls, v: str | None) -> str | None:
        """Reject blank labels (None is allowed) and strip surrounding whitespace."""
        if v is None:
            return None
        if not v.strip():
            raise ValueError("Metric label cannot be empty")
        return v.strip()

    @field_validator("type", mode="before")  # pyre-ignore[56]
    @classmethod
    def validate_type(cls, v: object) -> MetricType:
        """Convert a string (or enum) to MetricType via ``parse_enum_value``."""
        result = parse_enum_value(v, MetricType, field="metric type")
        # Narrowing for the type checker; parse_enum_value is expected to
        # return a MetricType here.
        assert isinstance(result, MetricType)
        return result

    @field_validator("scope", mode="before")  # pyre-ignore[56]
    @classmethod
    def validate_scope(cls, v: object) -> MetricScope | None:
        """Convert a string (or enum) to MetricScope; None is allowed."""
        result = parse_enum_value(v, MetricScope, field="metric scope", allow_none=True)
        # Narrowing for the type checker.
        assert result is None or isinstance(result, MetricScope)
        return result

    @field_validator("within_expr", mode="before")  # pyre-ignore[56]
    @classmethod
    def normalize_within_expr(cls, v: object) -> object:
        """Wrap a single string / Polars expression / MetricInfo in a list.

        Lists (and anything else) pass through unchanged so that the regular
        field validation reports type errors.
        """
        if v is None:
            return None
        if isinstance(v, (str, pl.Expr, MetricInfo)):
            return [v]
        return v

    @field_validator("within_expr")  # pyre-ignore[56]
    @classmethod
    def validate_within_expr(
        cls, v: list[str | pl.Expr | MetricInfo] | None
    ) -> list[str | pl.Expr | MetricInfo] | None:
        """Validate the within-entity expression list.

        String items are built-in metric names. They are stripped here, the
        same way ``across_expr`` strings are, so that registry lookups do not
        fail on stray whitespace.

        Raises:
            ValueError: if the value is not a list, is empty, contains a blank
                name, or contains an unsupported item type.
        """
        if v is None:
            return None

        if not isinstance(v, list):
            raise ValueError(
                f"within_expr must be a list after normalization, got {type(v)}"
            )

        if not v:
            raise ValueError("within_expr cannot be an empty list")

        normalized: list[str | pl.Expr | MetricInfo] = []
        for i, item in enumerate(v):
            if isinstance(item, str):
                stripped = item.strip()
                if not stripped:
                    raise ValueError(
                        f"within_expr[{i}]: Built-in metric name cannot be empty"
                    )
                normalized.append(stripped)
            elif isinstance(item, (pl.Expr, MetricInfo)):
                normalized.append(item)
            else:
                raise ValueError(
                    f"within_expr[{i}] must be a string (built-in name), Polars expression, or MetricInfo"
                )

        return normalized

    @field_validator("across_expr")  # pyre-ignore[56]
    @classmethod
    def validate_across_expr(
        cls, v: str | pl.Expr | MetricInfo | None
    ) -> str | pl.Expr | MetricInfo | None:
        """Validate across-entity expression - built-in selector, Polars expression, or MetricInfo"""
        if v is None:
            return None

        # Strings are built-in selector names; whether the name exists is
        # checked later, at compile time.
        if isinstance(v, str):
            if not v.strip():
                raise ValueError("Built-in selector name cannot be empty")
            return v.strip()

        if not isinstance(v, (pl.Expr, MetricInfo)):
            raise ValueError(
                "across_expr must be a string (built-in selector), Polars expression, or MetricInfo"
            )
        return v

    @model_validator(mode="after")  # pyre-ignore[56]
    def validate_expressions(self) -> Self:
        """Cross-field validation of the expression configuration.

        - Custom metrics (within_expr and/or across_expr set) of type
          ACROSS_SUBJECT / ACROSS_VISIT with more than one within_expr must
          supply across_expr to combine them.
        - Built-in metrics (no expressions) using ``"<metric>:<selector>"``
          naming must have non-empty parts on both sides of the first ``":"``.
        """
        is_custom = self.within_expr is not None or self.across_expr is not None

        if is_custom:
            if (
                self.type in (MetricType.ACROSS_SUBJECT, MetricType.ACROSS_VISIT)
                and self.within_expr is not None
                and len(self.within_expr) > 1
                and self.across_expr is None
            ):
                raise ValueError(
                    f"across_expr required for multiple within expressions in {self.type.value}"
                )
        else:
            base_name, sep, selector_name = self.name.partition(":")
            if sep and (not base_name or not selector_name):
                raise ValueError(
                    f"Invalid built-in metric name format: {self.name}"
                )

        return self

    def compile_expressions(self) -> tuple[list[MetricInfo], MetricInfo | None]:
        """
        Compile this metric's expressions to MetricInfo objects.

        Returns:
            Tuple of (aggregation_expressions, selection_expression). For
            ACROSS_SAMPLE metrics that compile to exactly one aggregation and
            no selection, that aggregation is returned as the selection
            (``([], info)``).
        """
        if self.within_expr is not None or self.across_expr is not None:
            agg_infos, sel_info = self._compile_custom_expressions()
        else:
            agg_infos, sel_info = self._compile_builtin_expressions()

        # ACROSS_SAMPLE operates directly on samples, so a lone aggregation is
        # the final metric. Only a *single* expression is moved. With several
        # expressions, taking agg_infos[0] would silently drop the others.
        if (
            self.type == MetricType.ACROSS_SAMPLE
            and sel_info is None
            and len(agg_infos) == 1
        ):
            return [], agg_infos[0]

        return agg_infos, sel_info

    def _compile_custom_expressions(
        self,
    ) -> tuple[list[MetricInfo], MetricInfo | None]:
        """Compile custom metric expressions - assumes all inputs are validated"""
        within_exprs = self._resolve_within_expressions()
        across_expr = self._resolve_across_expression()

        # If only across_expr is provided (no within_expr), use it as the
        # single aggregation.
        if not within_exprs and across_expr is not None:
            return [across_expr], None

        return within_exprs, across_expr

    def _resolve_within_expressions(self) -> list[MetricInfo]:
        """Resolve within_expr items to MetricInfo (strings via the metric registry)."""
        if self.within_expr is None:
            return []

        return [self._ensure_metric_info(item) for item in self.within_expr]

    def _resolve_across_expression(self) -> MetricInfo | None:
        """Resolve across_expr to MetricInfo (strings via the summary registry)."""
        if self.across_expr is None:
            return None

        expr = (
            MetricRegistry.get_summary(self.across_expr)
            if isinstance(self.across_expr, str)
            else self.across_expr
        )
        return self._ensure_metric_info(expr)

    def _compile_builtin_expressions(
        self,
    ) -> tuple[list[MetricInfo], MetricInfo | None]:
        """Compile a built-in metric named ``"<metric>"`` or ``"<metric>:<selector>"``.

        The name is split at the first ``":"``, matching ``validate_expressions``
        and ``__str__``. Previously ``"a:b:c"`` silently compiled as ``a:b``.

        Raises:
            ValueError: if the metric or selector is not registered.
        """
        agg_name, _, select_name = self.name.partition(":")

        try:
            agg_info = MetricRegistry.get_metric(agg_name)
        except ValueError as err:
            raise ValueError(f"Unknown built-in metric: {agg_name}") from err

        if not select_name:
            # No selector: the metric itself is the selection.
            return [], agg_info

        try:
            select_expr = MetricRegistry.get_summary(select_name)
        except ValueError as err:
            raise ValueError(f"Unknown built-in selector: {select_name}") from err

        # With a selector: aggregate with the metric, then apply the selector.
        return [agg_info], self._ensure_metric_info(select_expr)

    def _ensure_metric_info(self, value: object) -> MetricInfo:
        """Coerce a MetricInfo, registry metric name, or Polars expression to MetricInfo.

        Raises:
            TypeError: for any other value type.
        """
        if isinstance(value, MetricInfo):
            return value
        if isinstance(value, str):
            return MetricRegistry.get_metric(value)
        if isinstance(value, pl.Expr):
            return MetricInfo(expr=value)
        raise TypeError(f"Unsupported metric expression type: {type(value)!r}")

    @staticmethod
    def _format_chain_call(method: str, exprs: list[pl.Expr]) -> list[str]:
        """Render one ``.method(...)`` call of the LazyFrame chain as display lines."""
        if len(exprs) == 1:
            return [f" .{method}({format_polars_expr(exprs[0])})"]
        return [f" .{method}(", format_polars_expr_list(exprs), " )"]

    def get_pl_chain(self) -> str:
        """
        Get a string representation of the Polars LazyFrame chain for this metric.

        Returns:
            String showing the LazyFrame operations that would be executed
        """
        agg_infos, select_info = self.compile_expressions()
        agg_exprs = [info.expr for info in agg_infos]
        select_expr = select_info.expr if select_info is not None else None

        # Grouping keys shown for each grouped metric type.
        group_keys = {
            MetricType.WITHIN_SUBJECT: "'subject_id'",
            MetricType.ACROSS_SUBJECT: "'subject_id'",
            MetricType.WITHIN_VISIT: "['subject_id', 'visit_id']",
            MetricType.ACROSS_VISIT: "['subject_id', 'visit_id']",
        }

        chain_lines = ["(", " pl.LazyFrame"]

        if self.type == MetricType.ACROSS_SAMPLE:
            # Single-level aggregation over all samples.
            if select_expr is not None:
                chain_lines.extend(self._format_chain_call("select", [select_expr]))
            elif agg_exprs:
                chain_lines.extend(self._format_chain_call("select", agg_exprs))

        elif self.type in (MetricType.WITHIN_SUBJECT, MetricType.WITHIN_VISIT):
            # Group by entity; the per-group aggregation is the final result.
            chain_lines.append(f" .group_by({group_keys[self.type]})")
            if select_expr is not None:
                chain_lines.extend(self._format_chain_call("agg", [select_expr]))
            elif agg_exprs:
                chain_lines.extend(self._format_chain_call("agg", agg_exprs))

        elif self.type in (MetricType.ACROSS_SUBJECT, MetricType.ACROSS_VISIT):
            # Two-level: aggregate per entity, then summarise across entities.
            if agg_exprs:
                chain_lines.append(f" .group_by({group_keys[self.type]})")
                chain_lines.extend(self._format_chain_call("agg", agg_exprs))
            if select_expr is not None:
                chain_lines.extend(self._format_chain_call("select", [select_expr]))

        chain_lines.append(")")

        return "\n".join(chain_lines)

    def __str__(self) -> str:
        """Multi-line summary: name/type, label, scope, compiled expressions and chain."""
        lines = [f"MetricDefine(name='{self.name}', type={self.type.value})"]
        lines.append(f" Label: '{self.label}'")
        if self.scope is not None:
            lines.append(f" Scope: {self.scope.value}")

        try:
            agg_infos, select_info = self.compile_expressions()
            agg_exprs = [info.expr for info in agg_infos]
            select_expr = select_info.expr if select_info is not None else None

            # Split at the first ":" (same as compilation and validation).
            base_name, _, selector_name = self.name.partition(":")

            # Show within-entity expressions (only if they exist).
            if agg_exprs:
                lines.append(" Within-entity expressions:")
                for i, expr in enumerate(agg_exprs):
                    if self.within_expr is not None and i < len(self.within_expr):
                        item = self.within_expr[i]
                        source = item if isinstance(item, str) else "custom"
                    else:
                        source = base_name  # From metric name
                    lines.append(f" - [{source}] {expr}")

            # Show across-entity expression (only if it exists).
            if select_expr is not None:
                lines.append(" Across-entity expression:")
                if isinstance(self.across_expr, str):
                    source = self.across_expr  # Built-in selector name
                elif self.across_expr is not None:
                    source = "custom"  # Custom expression
                elif selector_name:
                    source = selector_name  # Selector part of the metric name
                else:
                    source = base_name  # From metric name
                lines.append(f" - [{source}] {select_expr}")

            lines.append("")
            lines.append(self.get_pl_chain())

        except Exception as e:
            # Display boundary: report compilation problems instead of letting
            # str()/repr() raise.
            lines.append(f" Error compiling expressions: {str(e)}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        """Representation for interactive display (same as ``__str__``)."""
        return self.__str__()