# Coverage: src/polars_eval_metrics/metric_evaluator.py: 89% (420 statements)
"""
Unified Metric Evaluation Pipeline

This module implements a simplified, unified evaluation pipeline for computing metrics
using Polars LazyFrames with comprehensive support for scopes, groups, and subgroups.
"""

from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Any, Sequence

# pyre-strict

import polars as pl

from .ard import ARD
from .evaluation_context import EstimateCatalog, FormatterContext, MetricCatalog
from .metric_define import MetricDefine, MetricScope, MetricType
from .metric_registry import MetricRegistry, MetricInfo
from .result_formatter import (
    build_group_pivot,
    build_model_pivot,
    format_compact_frame,
    format_verbose_frame,
    convert_to_ard as format_to_ard,
)


class MetricEvaluator:
    """Unified metric evaluation pipeline"""

    # Instance attributes with type annotations
    df_raw: pl.LazyFrame
    ground_truth: str
    group_by: dict[str, str]  # Maps group column names to display labels
    subgroup_by: dict[str, str]  # Maps subgroup column names to display labels
    filter_expr: pl.Expr | None
    error_params: dict[str, dict[str, Any]]
    df: pl.LazyFrame
    _evaluation_cache: dict[tuple[tuple[str, ...], tuple[str, ...]], ARD]
    _metric_catalog: MetricCatalog
    _estimate_catalog: EstimateCatalog
    _formatter_context: FormatterContext
    _subgroup_categories: list[str]

    def __init__(
        self,
        df: pl.DataFrame | pl.LazyFrame,
        metrics: MetricDefine | list[MetricDefine],
        ground_truth: str = "actual",
        estimates: str | list[str] | dict[str, str] | None = None,
        group_by: list[str] | dict[str, str] | None = None,
        subgroup_by: list[str] | dict[str, str] | None = None,
        filter_expr: pl.Expr | None = None,
        error_params: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Initialize evaluator with complete evaluation context

        Args:
            df: Input data as DataFrame or LazyFrame
            metrics: Metric definitions to evaluate
            ground_truth: Column name containing ground truth values
            estimates: Estimate column names. Can be:
                - str: Single column name
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            group_by: Columns to group by for analysis. Can be:
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            subgroup_by: Columns for subgroup analysis. Can be:
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            filter_expr: Optional filter expression
            error_params: Parameters for error calculations
        """
        # Store data as LazyFrame
        self.df_raw = df.lazy() if isinstance(df, pl.DataFrame) else df

        metric_list = [metrics] if isinstance(metrics, MetricDefine) else list(metrics)
        self._metric_catalog = MetricCatalog(tuple(metric_list))
        self.ground_truth = ground_truth

        # Process inputs using dedicated methods
        self._estimate_catalog = EstimateCatalog.build(
            self._process_estimates(estimates)
        )
        self.group_by = self._process_grouping(group_by)
        self.subgroup_by = self._process_grouping(subgroup_by)
        self._subgroup_categories = self._compute_subgroup_categories()
        self._formatter_context = FormatterContext(
            group_by=self.group_by,
            subgroup_by=self.subgroup_by,
            estimate_catalog=self._estimate_catalog,
            metric_catalog=self._metric_catalog,
            subgroup_categories=tuple(self._subgroup_categories),
        )
        self.filter_expr = filter_expr
        self.error_params = error_params or {}

        # Apply base filter once
        self.df = self._apply_base_filter()

        # Initialize evaluation cache
        self._evaluation_cache = {}

        # Validate configuration eagerly so errors surface early
        self._validate_inputs()
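
    # Usage sketch (hypothetical column names and metric definition; assumes a
    # built-in "mae" metric exists in the registry, which is not verified here):
    #
    #     import polars as pl
    #
    #     df = pl.DataFrame(
    #         {
    #             "actual": [1.0, 2.0, 3.0],
    #             "model_a": [1.1, 1.9, 3.2],
    #             "region": ["north", "south", "north"],
    #         }
    #     )
    #     evaluator = MetricEvaluator(
    #         df,
    #         metrics=MetricDefine(name="mae"),
    #         ground_truth="actual",
    #         estimates={"model_a": "Model A"},
    #         group_by=["region"],
    #     )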

    @property
    def metrics(self) -> tuple[MetricDefine, ...]:
        return self._metric_catalog.entries

    @property
    def estimates(self) -> Mapping[str, str]:
        return self._estimate_catalog.key_to_label

    def _apply_base_filter(self) -> pl.LazyFrame:
        """Apply initial filter if provided"""
        if self.filter_expr is not None:
            return self.df_raw.filter(self.filter_expr)
        return self.df_raw

    def _get_cache_key(
        self,
        metrics: MetricDefine | list[MetricDefine] | None,
        estimates: str | list[str] | None,
    ) -> tuple[tuple[str, ...], tuple[str, ...]]:
        """Generate cache key for evaluation parameters"""
        target_metrics = self._resolve_metrics(metrics)
        target_estimates = self._resolve_estimates(estimates)

        # Create hashable key from metric names and estimates
        metric_names = tuple(sorted(m.name for m in target_metrics))
        estimate_names = tuple(sorted(target_estimates))

        return (metric_names, estimate_names)
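
    # For example, metrics ["rmse", "mae"] and estimates ["model_a"] produce the
    # key (("mae", "rmse"), ("model_a",)): sorted, hashable, and order-insensitive.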

    def _get_cached_evaluation(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> ARD:
        """Return the cached evaluation result, computing and caching it on first use"""
        cache_key = self._get_cache_key(metrics, estimates)

        if cache_key not in self._evaluation_cache:
            if metrics is not None or estimates is not None:
                filtered_evaluator = self.filter(metrics=metrics, estimates=estimates)
                ard_result = filtered_evaluator._evaluate_ard(
                    metrics=metrics, estimates=estimates
                )
            else:
                ard_result = self._evaluate_ard(metrics=metrics, estimates=estimates)
            self._evaluation_cache[cache_key] = ard_result

        return self._evaluation_cache[cache_key]

    def clear_cache(self) -> None:
        """Clear the evaluation cache"""
        self._evaluation_cache.clear()

    def filter(
        self,
        *,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> "MetricEvaluator":
        """Return a new evaluator scoped to the requested metrics or estimates."""

        if metrics is not None:
            filtered_metrics = self._resolve_metrics(metrics)
        else:
            filtered_metrics = list(self.metrics)
        filtered_estimate_keys = (
            self._resolve_estimates(estimates)
            if estimates is not None
            else list(self.estimates.keys())
        )
        filtered_estimates = {
            key: self.estimates[key] for key in filtered_estimate_keys
        }

        # self.df already has the base filter applied, so don't re-apply it
        return MetricEvaluator(
            df=self.df,
            metrics=filtered_metrics,
            ground_truth=self.ground_truth,
            estimates=filtered_estimates,
            group_by=self.group_by,
            subgroup_by=self.subgroup_by,
            filter_expr=None,
            error_params=self.error_params,
        )
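
    # Usage sketch (hypothetical names; `mae` is one of the configured metrics):
    #
    #     mae_only = evaluator.filter(metrics=mae, estimates="model_a")
    #     mae_only.evaluate()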

    def _evaluate_ard(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> ARD:
        """Internal helper that returns evaluation results as ARD."""

        target_metrics = self._resolve_metrics(metrics)
        target_estimates = self._resolve_estimates(estimates)

        if not target_metrics or not target_estimates:
            raise ValueError("No metrics or estimates to evaluate")

        combined = self._vectorized_evaluate(target_metrics, target_estimates)
        formatted = self._format_result(combined)
        return self._convert_to_ard(formatted)

    def evaluate(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        *,
        collect: bool = True,
        verbose: bool = False,
    ) -> pl.LazyFrame | pl.DataFrame:
        """
        Unified evaluation method returning results in ARD format.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            collect: When False, return a ``LazyFrame`` rather than materialising results.
            verbose: When True, include struct columns (``id``, ``groups``, ``subgroups``)
                and diagnostic fields (``stat``, ``stat_fmt``, ``context``, ``warning``,
                ``error``). When False, struct values are flattened and diagnostics are
                dropped for a compact table view.

        Returns:
            ``polars.DataFrame`` when ``collect`` is True (verbose controls struct
            flattening), or a ``LazyFrame`` when ``collect`` is False.
        """
        ard = self._evaluate_ard(metrics=metrics, estimates=estimates)

        if not collect:
            return ard.lazy

        if verbose:
            return format_verbose_frame(ard)
        return format_compact_frame(ard)
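
    # Usage sketch (continuing the hypothetical evaluator above):
    #
    #     evaluator.evaluate()               # compact pl.DataFrame
    #     evaluator.evaluate(verbose=True)   # includes struct/diagnostic columns
    #     evaluator.evaluate(collect=False)  # pl.LazyFrame for deferred execution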

    def _convert_to_ard(self, result_lf: pl.LazyFrame) -> ARD:
        return format_to_ard(result_lf, self._formatter_context)

    def pivot_by_group(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        column_order_by: str = "metrics",
        row_order_by: str = "group",
    ) -> pl.DataFrame:
        """
        Pivot results with groups as rows and model x metric as columns.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            column_order_by: Column ordering strategy ("metrics" or "estimates")
            row_order_by: Row ordering strategy ("group" or "subgroup")

        Returns:
            DataFrame with group combinations as rows and metric columns
        """
        long_df = self._collect_long_dataframe(metrics=metrics, estimates=estimates)
        return build_group_pivot(
            long_df,
            self._formatter_context,
            column_order_by=column_order_by,
            row_order_by=row_order_by,
        )

    def pivot_by_model(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        column_order_by: str = "estimates",
        row_order_by: str = "group",
    ) -> pl.DataFrame:
        """
        Pivot results with models as rows and group x metric as columns.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            column_order_by: Column ordering strategy ("estimates" or "metrics")
            row_order_by: Row ordering strategy ("group" or "subgroup")

        Returns:
            DataFrame with model combinations as rows and group+metric columns
        """
        long_df = self._collect_long_dataframe(metrics=metrics, estimates=estimates)
        return build_model_pivot(
            long_df,
            self._formatter_context,
            column_order_by=column_order_by,
            row_order_by=row_order_by,
        )
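
    # Usage sketch (hypothetical evaluator; output shape depends on configured groups):
    #
    #     evaluator.pivot_by_group()  # rows: group combinations, cols: model x metric
    #     evaluator.pivot_by_model()  # rows: models, cols: group x metric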

    def _resolve_metrics(
        self, metrics: MetricDefine | list[MetricDefine] | None
    ) -> list[MetricDefine]:
        """Resolve which metrics to evaluate"""
        if metrics is None:
            return list(self.metrics)

        metrics_list = [metrics] if isinstance(metrics, MetricDefine) else metrics
        configured_names = {m.name for m in self.metrics}

        for m in metrics_list:
            if m.name not in configured_names:
                raise ValueError(f"Metric '{m.name}' not in configured metrics")

        return metrics_list

    def _resolve_estimates(self, estimates: str | list[str] | None) -> list[str]:
        """Resolve which estimates to evaluate"""
        if estimates is None:
            return list(self.estimates.keys())

        estimates_list = [estimates] if isinstance(estimates, str) else estimates

        for e in estimates_list:
            if e not in self.estimates:
                raise ValueError(
                    f"Estimate '{e}' not in configured estimates: {list(self.estimates.keys())}"
                )

        return estimates_list

    def _vectorized_evaluate(
        self, metrics: list[MetricDefine], estimates: list[str]
    ) -> pl.LazyFrame:
        """Vectorized evaluation using single Polars group_by operations"""

        # Step 1: Prepare data in long format with all estimates
        df_long = self._prepare_long_format_data(estimates)

        # Step 2: Generate all error columns for the melted data
        df_with_errors = self._add_error_columns_vectorized(df_long)

        # Step 3: Handle marginal subgroup analysis if needed
        if self.subgroup_by:
            return self._evaluate_with_marginal_subgroups(
                df_with_errors, metrics, estimates
            )
        else:
            return self._evaluate_without_subgroups(df_with_errors, metrics, estimates)

    def _evaluate_without_subgroups(
        self,
        df_with_errors: pl.LazyFrame,
        metrics: list[MetricDefine],
        estimates: list[str],
    ) -> pl.LazyFrame:
        """Evaluate metrics without subgroup analysis"""
        results = []
        for metric in metrics:
            metric_result = self._evaluate_metric_vectorized(
                df_with_errors, metric, estimates
            )
            results.append(metric_result)

        # Combine results (no schema harmonization needed with fixed evaluation structure)
        if results:
            return pl.concat(results, how="diagonal")
        else:
            return pl.DataFrame().lazy()

    def _evaluate_with_marginal_subgroups(
        self,
        df_with_errors: pl.LazyFrame,
        metrics: list[MetricDefine],
        estimates: list[str],
    ) -> pl.LazyFrame:
        """Evaluate metrics with marginal subgroup analysis using vectorized operations"""
        # Create all subgroup combinations using vectorized unpivot
        subgroup_data = self._prepare_subgroup_data_vectorized(
            df_with_errors, self.subgroup_by
        )

        # Evaluate all metrics across all subgroups in a vectorized manner
        results = []
        for metric in metrics:
            metric_result = self._evaluate_metric_vectorized(
                subgroup_data, metric, estimates
            )
            results.append(metric_result)

        # Combine results (no schema harmonization needed with fixed evaluation structure)
        return pl.concat(results, how="diagonal")

    def _prepare_subgroup_data_vectorized(
        self, df_with_errors: pl.LazyFrame, subgroup_by: dict[str, str]
    ) -> pl.LazyFrame:
        """Prepare subgroup data using vectorized unpivot operations"""
        schema_names = df_with_errors.collect_schema().names()
        subgroup_cols = list(subgroup_by.keys())
        id_vars = [col for col in schema_names if col not in subgroup_cols]

        # Use unpivot to create marginal subgroup analysis
        return df_with_errors.unpivot(
            index=id_vars,
            on=subgroup_cols,
            variable_name="subgroup_name",
            value_name="subgroup_value",
        ).with_columns(
            [
                # Replace subgroup column names with their display labels
                pl.col("subgroup_name").replace(subgroup_by)
            ]
        )
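
    # Sketch of the reshape: with subgroup_by={"sex": "Sex", "age": "Age"}
    # (hypothetical columns), each input row emits one row per subgroup column,
    # e.g. ("Sex", "F") and ("Age", "65+"), so one group_by over
    # (subgroup_name, subgroup_value) computes all marginal results at once.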

    def _prepare_long_format_data(self, estimates: list[str]) -> pl.LazyFrame:
        """Reshape data from wide to long format for vectorized processing"""

        # Add a row index to the original data to uniquely identify each sample.
        # This must be done BEFORE unpivoting to avoid double counting.
        df_with_index = self.df.with_row_index("sample_index")

        # Get all columns except estimates to preserve in melt
        schema_names = df_with_index.collect_schema().names()
        id_vars = [col for col in schema_names if col not in estimates]

        # Unpivot estimates into long format
        df_long = df_with_index.unpivot(
            index=id_vars,
            on=estimates,
            variable_name="estimate_name",
            value_name="estimate_value",
        )

        # Preserve canonical estimate key alongside optional display label
        label_mapping = self._estimate_catalog.key_to_label
        df_long = (
            df_long.rename({self.ground_truth: "ground_truth"})
            .with_columns(
                [
                    pl.col("estimate_name").alias("estimate"),
                    pl.col("estimate_name")
                    .cast(pl.Utf8)
                    .map_elements(
                        lambda val, mapping=label_mapping: mapping.get(val, val),
                        return_dtype=pl.Utf8,
                    )
                    .alias("estimate_label"),
                ]
            )
            .drop("estimate_name")
        )

        return df_long
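
    # Sketch of the reshape: estimates ["model_a", "model_b"] (hypothetical) turn
    # each sample into two rows carrying (sample_index, ground_truth, estimate,
    # estimate_label, estimate_value), e.g. estimate="model_a",
    # estimate_label="Model A", estimate_value=1.1.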

    def _add_error_columns_vectorized(self, df_long: pl.LazyFrame) -> pl.LazyFrame:
        """Add error columns for the long-format data"""

        # Generate error expressions for the vectorized format, using
        # 'estimate_value' as the estimate column and 'ground_truth' as the
        # renamed ground-truth column
        error_expressions = MetricRegistry.generate_error_columns(
            estimate="estimate_value",
            ground_truth="ground_truth",
            error_types=None,
            error_params=self.error_params,
        )

        return df_long.with_columns(error_expressions)

    def _evaluate_metric_vectorized(
        self, df_with_errors: pl.LazyFrame, metric: MetricDefine, estimates: list[str]
    ) -> pl.LazyFrame:
        """Evaluate a single metric using vectorized operations"""

        group_cols = self._get_vectorized_grouping_columns(metric, df_with_errors)
        within_infos, across_info = metric.compile_expressions()
        df_filtered = self._apply_metric_scope_filter(df_with_errors, metric, estimates)

        handlers = {
            MetricType.ACROSS_SAMPLE: self._evaluate_across_sample_metric,
            MetricType.WITHIN_SUBJECT: self._evaluate_within_entity_metric,
            MetricType.WITHIN_VISIT: self._evaluate_within_entity_metric,
            MetricType.ACROSS_SUBJECT: self._evaluate_two_stage_metric,
            MetricType.ACROSS_VISIT: self._evaluate_two_stage_metric,
        }

        handler = handlers.get(metric.type)
        if handler is None:
            raise ValueError(f"Unknown metric type: {metric.type}")

        try:
            result, result_info = handler(
                df_filtered,
                metric,
                group_cols,
                within_infos,
                across_info,
            )
            return self._add_metadata_vectorized(result, metric, result_info)
        except Exception as exc:
            fallback_info = self._fallback_metric_info(
                metric, within_infos, across_info
            )
            placeholder = self._prepare_error_lazyframe(group_cols)
            return self._add_metadata_vectorized(
                placeholder,
                metric,
                fallback_info,
                warnings_list=[],
                errors_list=[self._format_exception_message(exc, metric.name)],
            )

    def _evaluate_across_sample_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        _within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        if across_info is None:
            raise ValueError(f"ACROSS_SAMPLE metric {metric.name} requires across_expr")

        agg_exprs = self._metric_agg_expressions(across_info)
        result = self._aggregate_lazyframe(df, group_cols, agg_exprs)
        return result, across_info

    def _evaluate_within_entity_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        entity_groups = self._merge_group_columns(
            self._get_entity_grouping_columns(metric.type), group_cols
        )

        result_info = self._resolve_metric_info(
            metric,
            primary=within_infos,
            fallback=across_info,
            error_message=f"No valid expression for metric {metric.name}",
        )

        agg_exprs = self._metric_agg_expressions(result_info)
        result = self._aggregate_lazyframe(df, entity_groups, agg_exprs)
        return result, result_info

    def _evaluate_two_stage_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        entity_groups = self._merge_group_columns(
            self._get_entity_grouping_columns(metric.type), group_cols
        )

        base_info = self._resolve_metric_info(
            metric,
            primary=within_infos,
            fallback=across_info,
            error_message=(
                f"No valid expression for first level of metric {metric.name}"
            ),
        )

        intermediate = self._aggregate_lazyframe(
            df,
            entity_groups,
            self._metric_agg_expressions(base_info),
        )

        if across_info is not None and within_infos:
            result_info = across_info
            agg_exprs = self._metric_agg_expressions(result_info)
        else:
            result_info = base_info
            agg_exprs = [pl.col("value").mean().alias("value")]

        result = self._aggregate_lazyframe(intermediate, group_cols, agg_exprs)
        return result, result_info

    @staticmethod
    def _merge_group_columns(
        *column_groups: Sequence[str],
    ) -> list[str]:
        seen: set[str] = set()
        ordered: list[str] = []
        for columns in column_groups:
            for col in columns:
                if col not in seen:
                    seen.add(col)
                    ordered.append(col)
        return ordered
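
    # For example, _merge_group_columns(["subject_id"], ["estimate", "subject_id"])
    # returns ["subject_id", "estimate"]: first occurrence wins, order is preserved.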

    def _resolve_metric_info(
        self,
        metric: MetricDefine,
        *,
        primary: Sequence[MetricInfo] | None,
        fallback: MetricInfo | None,
        error_message: str,
    ) -> MetricInfo:
        if primary:
            return primary[0]
        if fallback is not None:
            return fallback
        raise ValueError(error_message)

    @staticmethod
    def _aggregate_lazyframe(
        df: pl.LazyFrame, group_cols: Sequence[str], agg_exprs: Sequence[pl.Expr]
    ) -> pl.LazyFrame:
        columns = [col for col in group_cols if col]
        if columns:
            return df.group_by(columns).agg(agg_exprs)
        return df.select(*agg_exprs)
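
    # With group_cols=["estimate"] this is df.group_by(["estimate"]).agg(agg_exprs);
    # with no grouping columns it degenerates to a single-row df.select(*agg_exprs).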

    def _get_vectorized_grouping_columns(
        self, metric: MetricDefine, df: pl.LazyFrame | None = None
    ) -> list[str]:
        """Get grouping columns for vectorized evaluation based on metric scope"""

        schema_names: set[str]
        if df is not None:
            schema_names = set(df.collect_schema().names())
        else:
            schema_names = set(self.df.collect_schema().names())

        using_vectorized_subgroups = {
            "subgroup_name",
            "subgroup_value",
        }.issubset(schema_names)

        def existing(columns: Iterable[str]) -> list[str]:
            return [col for col in columns if col in schema_names]

        # Subgroup columns are appended for every scope
        subgroup_cols = (
            ["subgroup_name", "subgroup_value"]
            if using_vectorized_subgroups
            else existing(self.subgroup_by.keys())
        )

        group_cols: list[str] = []

        if metric.scope == MetricScope.GLOBAL:
            group_cols.extend(subgroup_cols)
        elif metric.scope == MetricScope.MODEL:
            group_cols.extend(existing(["estimate"]) + subgroup_cols)
        elif metric.scope == MetricScope.GROUP:
            group_cols.extend(existing(self.group_by.keys()))
            group_cols.extend(subgroup_cols)
        else:
            group_cols.extend(existing(["estimate"]))
            group_cols.extend(existing(self.group_by.keys()))
            group_cols.extend(subgroup_cols)

        return self._merge_group_columns(group_cols)

    def _apply_metric_scope_filter(
        self, df: pl.LazyFrame, metric: MetricDefine, estimates: list[str]
    ) -> pl.LazyFrame:
        """Apply any scope-specific filtering"""
        # For now, no additional filtering is needed beyond grouping.
        # Future: could add estimate filtering for specific scopes.
        _ = metric, estimates  # Suppress unused parameter warnings
        return df

    def _metric_agg_expressions(self, info: MetricInfo) -> list[pl.Expr]:
        return [info.expr.alias("value")]

    def _get_entity_grouping_columns(self, metric_type: MetricType) -> list[str]:
        """Get entity-level grouping columns (subject_id, visit_id)"""
        if metric_type in [MetricType.WITHIN_SUBJECT, MetricType.ACROSS_SUBJECT]:
            return ["subject_id"]
        elif metric_type in [MetricType.WITHIN_VISIT, MetricType.ACROSS_VISIT]:
            return ["subject_id", "visit_id"]
        else:
            return []
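
    # e.g. WITHIN_SUBJECT aggregates per "subject_id" and WITHIN_VISIT per
    # ("subject_id", "visit_id"); both column names are assumed to exist in the
    # input data for these metric types.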

    def _add_metadata_vectorized(
        self,
        result: pl.LazyFrame,
        metric: MetricDefine,
        info: MetricInfo,
        *,
        warnings_list: Sequence[str] | None = None,
        errors_list: Sequence[str] | None = None,
    ) -> pl.LazyFrame:
        """Add metadata columns to vectorized result"""

        schema = result.collect_schema()
        value_dtype = schema.get("value") if "value" in schema.names() else None
        value_kind = (
            (info.value_kind or "").lower()
            if info.value_kind
            else self._infer_value_kind_from_dtype(value_dtype)
        )
        if not value_kind:
            value_kind = "float"

        metadata_columns = [
            pl.lit(metric.name).cast(pl.Utf8).alias("metric"),
            pl.lit(metric.label).cast(pl.Utf8).alias("label"),
            pl.lit(metric.type.value).cast(pl.Utf8).alias("metric_type"),
            pl.lit(metric.scope.value if metric.scope else None)
            .cast(pl.Utf8)
            .alias("scope"),
            pl.lit(value_kind).cast(pl.Utf8).alias("_value_kind"),
            pl.lit(info.format).cast(pl.Utf8).alias("_value_format"),
        ]

        result = result.with_columns(metadata_columns)

        diagnostics_columns = [
            pl.lit(list(warnings_list or []), dtype=pl.List(pl.Utf8)).alias(
                "_diagnostic_warning"
            ),
            pl.lit(list(errors_list or []), dtype=pl.List(pl.Utf8)).alias(
                "_diagnostic_error"
            ),
        ]

        result = result.with_columns(diagnostics_columns)

        # Attach entity identifiers for within-entity metrics
        result = self._attach_entity_identifier(result, metric)

        # Helper columns for stat struct construction
        helper_columns = [
            pl.lit(None, dtype=pl.Float64).alias("_value_float"),
            pl.lit(None, dtype=pl.Int64).alias("_value_int"),
            pl.lit(None, dtype=pl.Boolean).alias("_value_bool"),
            pl.lit(None, dtype=pl.Utf8).alias("_value_str"),
            pl.lit(None).alias("_value_struct"),
        ]
        result = result.with_columns(helper_columns)

        if value_kind == "int":
            result = result.with_columns(
                pl.col("value").cast(pl.Int64, strict=False).alias("_value_int"),
                pl.col("value").cast(pl.Float64, strict=False).alias("_value_float"),
                pl.col("value").cast(pl.Float64, strict=False),
            )
        elif value_kind == "float":
            result = result.with_columns(
                pl.col("value").cast(pl.Float64, strict=False).alias("_value_float"),
                pl.col("value").cast(pl.Float64, strict=False),
            )
        elif value_kind == "bool":
            result = result.with_columns(
                pl.col("value").cast(pl.Boolean, strict=False).alias("_value_bool"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        elif value_kind == "string":
            result = result.with_columns(
                pl.col("value").cast(pl.Utf8, strict=False).alias("_value_str"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        elif value_kind == "struct":
            result = result.with_columns(
                pl.col("value").alias("_value_struct"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        else:
            result = result.with_columns(
                pl.col("value").cast(pl.Utf8, strict=False).alias("_value_str"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )

        return result

    @staticmethod
    def _format_exception_message(exc: Exception, metric_name: str) -> str:
        return f"{metric_name}: {type(exc).__name__}: {exc}".strip()

    @staticmethod
    def _fallback_metric_info(
        metric: MetricDefine,
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> MetricInfo:
        if within_infos:
            return within_infos[0]
        if across_info is not None:
            return across_info
        return MetricInfo(expr=pl.lit(None).alias("value"), value_kind="float")

    @staticmethod
    def _prepare_error_lazyframe(group_cols: Sequence[str]) -> pl.LazyFrame:
        data: dict[str, list[Any]] = {}
        for col in group_cols:
            data[col] = [None]
        data["value"] = [None]
        return pl.DataFrame(data, strict=False).lazy()

    @staticmethod
    def _infer_value_kind_from_dtype(dtype: pl.DataType | None) -> str:
        """Map Polars dtypes to MetricInfo value_kind labels."""

        if dtype is None or dtype == pl.Null:
            return "float"
        if dtype == pl.Struct:
            return "struct"
        if dtype == pl.Boolean:
            return "bool"
        if dtype == pl.Utf8:
            return "string"
        if hasattr(dtype, "is_numeric") and dtype.is_numeric():
            return "int" if dtype.is_integer() else "float"
        return "string"

    def _attach_entity_identifier(
        self, result: pl.LazyFrame, metric: MetricDefine
    ) -> pl.LazyFrame:
        """Attach a canonical id struct for within-entity metrics."""

        entity_cols = self._get_entity_grouping_columns(metric.type)
        if not entity_cols:
            if "id" in result.collect_schema().names():
                return result
            return result.with_columns(pl.lit(None).alias("id"))

        schema = result.collect_schema()
        available = set(schema.names())
        present = [col for col in entity_cols if col in available]

        if not present:
            if "id" in available:
                return result
            return result.with_columns(pl.lit(None).alias("id"))

        id_struct = pl.struct([pl.col(col).alias(col) for col in present]).alias("id")
        cleaned = result.with_columns(id_struct)

        # Drop entity columns to avoid polluting downstream schema
        return cleaned.drop(present)

    def _format_result(self, combined: pl.LazyFrame) -> pl.LazyFrame:
        """Minimal formatting - ARD handles all presentation concerns"""
        return combined

    # ------------------------------------------------------------------
    # Result shaping helpers
    # ------------------------------------------------------------------

    def _collect_long_dataframe(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> pl.DataFrame:
        """Collect evaluation results as a flat DataFrame for pivoting."""

        ard = self._get_cached_evaluation(metrics=metrics, estimates=estimates)
        lf = ard.lazy
        schema = lf.collect_schema()

        exprs: list[pl.Expr] = []
        estimate_label_map = self._estimate_catalog.key_to_label

        # Group columns with display labels
        for col, label in self.group_by.items():
            if col in schema.names():
                exprs.append(pl.col(col).alias(label))

        # Subgroup columns
        if "subgroup_name" in schema.names():
            exprs.append(pl.col("subgroup_name"))
        if "subgroup_value" in schema.names():
            exprs.append(pl.col("subgroup_value"))

        # Estimate / metric / label columns
        if "estimate" in schema.names():
            exprs.append(pl.col("estimate").cast(pl.Utf8))
            exprs.append(
                pl.col("estimate")
                .cast(pl.Utf8)
                .map_elements(
                    lambda val, mapping=estimate_label_map: mapping.get(val, val),
                    return_dtype=pl.Utf8,
                )
                .alias("estimate_label")
            )
        if "metric" in schema.names():
            exprs.append(pl.col("metric").cast(pl.Utf8))
        if "label" in schema.names():
            exprs.append(pl.col("label").cast(pl.Utf8))
        else:
            exprs.append(pl.col("metric").cast(pl.Utf8).alias("label"))

        if "stat" in schema.names():
            exprs.append(pl.col("stat"))
        if "stat_fmt" in schema.names():
            exprs.append(pl.col("stat_fmt"))
            exprs.append(
                pl.when(pl.col("stat_fmt").is_null())
                .then(
                    pl.col("stat").map_elements(
                        ARD._format_stat, return_dtype=pl.Utf8
                    )
                )
                .otherwise(pl.col("stat_fmt"))
                .alias("value")
            )
        else:
            exprs.append(
                pl.col("stat")
                .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                .alias("value")
            )

        if "id" in schema.names():
            exprs.append(pl.col("id"))

        # Scope metadata
        if "metric_type" in schema.names():
            exprs.append(pl.col("metric_type").cast(pl.Utf8))
        if "scope" in schema.names():
            exprs.append(pl.col("scope").cast(pl.Utf8))

        return lf.select(exprs).collect()

    # ========================================
    # INPUT PROCESSING METHODS - Pure Logic
    # ========================================

    @staticmethod
    def _process_estimates(
        estimates: str | list[str] | dict[str, str] | None,
    ) -> dict[str, str]:
        """Pure transformation: normalize estimates to dict format"""
        if isinstance(estimates, str):
            return {estimates: estimates}
        elif isinstance(estimates, dict):
            return estimates
        elif isinstance(estimates, list):
            return {est: est for est in (estimates or [])}
        else:
            return {}

    @staticmethod
    def _process_grouping(
        grouping: list[str] | dict[str, str] | None,
    ) -> dict[str, str]:
        """Pure transformation: normalize grouping to dict format"""
        if isinstance(grouping, dict):
            return grouping
        elif isinstance(grouping, list):
            return {col: col for col in (grouping or [])}
        else:
            return {}
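
    # For example, _process_estimates("model_a") -> {"model_a": "model_a"} and
    # _process_grouping(["region"]) -> {"region": "region"}; dicts pass through
    # unchanged and None becomes {}.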

    def _compute_subgroup_categories(self) -> list[str]:
        if not self.subgroup_by:
            return []

        categories: list[Any] = []
        seen: set[Any] = set()
        schema = self.df_raw.collect_schema()

        for column in self.subgroup_by.keys():
            if column not in schema.names():
                continue

            dtype = schema[column]

            if isinstance(dtype, pl.Enum):
                for value in dtype.categories.to_list():
                    if value not in seen:
                        seen.add(value)
                        categories.append(value)
                continue

            ordered_values = self._collect_unique_subgroup_values(column, dtype)
            for value in ordered_values:
                if value in seen:
                    continue
                seen.add(value)
                categories.append(value)

        return categories

    def _collect_unique_subgroup_values(
        self, column: str, dtype: pl.DataType
    ) -> list[Any]:
        """Collect sorted unique subgroup values using lazy execution."""

        expr = pl.col(column).drop_nulls()

        if not dtype.is_numeric():
            expr = expr.cast(pl.Utf8)

        lazy_unique = (
            self.df_raw.select(expr.alias(column)).unique(subset=[column]).sort(column)
        )

        df_unique = lazy_unique.collect(engine="streaming")
        values = df_unique[column].to_list()

        if dtype.is_numeric():
            return values

        return [str(value) for value in values]

    # ========================================
    # VALIDATION METHODS - Centralized Logic
    # ========================================

    def _validate_inputs(self) -> None:
        """Validate all inputs after processing"""
        if not self.estimates:
            raise ValueError("No estimates provided")

        if not self.metrics:
            raise ValueError("No metrics provided")

        # Validate that required columns exist
        schema_names = self.df_raw.collect_schema().names()

        if self.ground_truth not in schema_names:
            raise ValueError(
                f"Ground truth column '{self.ground_truth}' not found in data"
            )

        missing_estimates = [
            est for est in self.estimates.keys() if est not in schema_names
        ]
        if missing_estimates:
            raise ValueError(f"Estimate columns not found in data: {missing_estimates}")

        missing_groups = [
            col for col in self.group_by.keys() if col not in schema_names
        ]
        if missing_groups:
            raise ValueError(f"Group columns not found in data: {missing_groups}")

        missing_subgroups = [
            col for col in self.subgroup_by.keys() if col not in schema_names
        ]
        if missing_subgroups:
            raise ValueError(f"Subgroup columns not found in data: {missing_subgroups}")

        overlap = set(self.group_by.keys()) & set(self.subgroup_by.keys())
        if overlap:
            raise ValueError(
                "Group and subgroup columns must be distinct; found duplicates: "
                f"{sorted(overlap)}"
            )