Coverage for src/polars_eval_metrics/metric_evaluator.py: 89%

420 statements  

coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

# pyre-strict
"""
Unified Metric Evaluation Pipeline

This module implements a simplified, unified evaluation pipeline for computing metrics
using Polars LazyFrames with comprehensive support for scopes, groups, and subgroups.
"""

from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from typing import Any

import polars as pl

from .ard import ARD
from .evaluation_context import EstimateCatalog, FormatterContext, MetricCatalog
from .metric_define import MetricDefine, MetricScope, MetricType
from .metric_registry import MetricRegistry, MetricInfo
from .result_formatter import (
    build_group_pivot,
    build_model_pivot,
    format_compact_frame,
    format_verbose_frame,
    convert_to_ard as format_to_ard,
)


class MetricEvaluator:
    """Unified metric evaluation pipeline"""

    # Instance attributes with type annotations
    df_raw: pl.LazyFrame
    ground_truth: str
    group_by: dict[str, str]  # Maps group column names to display labels
    subgroup_by: dict[str, str]  # Maps subgroup column names to display labels
    filter_expr: pl.Expr | None
    error_params: dict[str, dict[str, Any]]
    df: pl.LazyFrame
    _evaluation_cache: dict[tuple[tuple[str, ...], tuple[str, ...]], ARD]
    _metric_catalog: MetricCatalog
    _estimate_catalog: EstimateCatalog
    _formatter_context: FormatterContext
    _subgroup_categories: list[str]

    def __init__(
        self,
        df: pl.DataFrame | pl.LazyFrame,
        metrics: MetricDefine | list[MetricDefine],
        ground_truth: str = "actual",
        estimates: str | list[str] | dict[str, str] | None = None,
        group_by: list[str] | dict[str, str] | None = None,
        subgroup_by: list[str] | dict[str, str] | None = None,
        filter_expr: pl.Expr | None = None,
        error_params: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Initialize the evaluator with the complete evaluation context.

        Args:
            df: Input data as DataFrame or LazyFrame
            metrics: Metric definitions to evaluate
            ground_truth: Column name containing ground truth values
            estimates: Estimate column names. Can be:
                - str: Single column name
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            group_by: Columns to group by for analysis. Can be:
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            subgroup_by: Columns for subgroup analysis. Can be:
                - list[str]: List of column names
                - dict[str, str]: Mapping from column names to display labels
            filter_expr: Optional filter expression
            error_params: Parameters for error calculations
        """
        # Store data as LazyFrame
        self.df_raw = df.lazy() if isinstance(df, pl.DataFrame) else df

        metric_list = [metrics] if isinstance(metrics, MetricDefine) else list(metrics)
        self._metric_catalog = MetricCatalog(tuple(metric_list))
        self.ground_truth = ground_truth

        # Process inputs using dedicated methods
        self._estimate_catalog = EstimateCatalog.build(
            self._process_estimates(estimates)
        )
        self.group_by = self._process_grouping(group_by)
        self.subgroup_by = self._process_grouping(subgroup_by)
        self._subgroup_categories = self._compute_subgroup_categories()
        self._formatter_context = FormatterContext(
            group_by=self.group_by,
            subgroup_by=self.subgroup_by,
            estimate_catalog=self._estimate_catalog,
            metric_catalog=self._metric_catalog,
            subgroup_categories=tuple(self._subgroup_categories),
        )
        self.filter_expr = filter_expr
        self.error_params = error_params or {}

        # Apply base filter once
        self.df = self._apply_base_filter()

        # Initialize evaluation cache
        self._evaluation_cache = {}

        # Validate configuration eagerly so errors surface early
        self._validate_inputs()
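
    # Illustrative construction (a hypothetical sketch: the data, the column
    # names, and the MetricDefine(...) arguments below are assumptions for
    # demonstration, not part of this module's API):
    #
    #     df = pl.DataFrame(
    #         {
    #             "actual": [1.0, 2.0, 3.0],
    #             "model_a": [1.1, 1.9, 3.2],
    #             "site": ["A", "A", "B"],
    #         }
    #     )
    #     evaluator = MetricEvaluator(
    #         df=df,
    #         metrics=MetricDefine(name="mae"),
    #         ground_truth="actual",
    #         estimates={"model_a": "Model A"},
    #         group_by={"site": "Site"},
    #     )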

    @property
    def metrics(self) -> tuple[MetricDefine, ...]:
        return self._metric_catalog.entries

    @property
    def estimates(self) -> Mapping[str, str]:
        return self._estimate_catalog.key_to_label

    def _apply_base_filter(self) -> pl.LazyFrame:
        """Apply initial filter if provided"""
        if self.filter_expr is not None:
            return self.df_raw.filter(self.filter_expr)
        return self.df_raw

    def _get_cache_key(
        self,
        metrics: MetricDefine | list[MetricDefine] | None,
        estimates: str | list[str] | None,
    ) -> tuple[tuple[str, ...], tuple[str, ...]]:
        """Generate cache key for evaluation parameters"""
        target_metrics = self._resolve_metrics(metrics)
        target_estimates = self._resolve_estimates(estimates)

        # Create hashable key from metric names and estimates
        metric_names = tuple(sorted(m.name for m in target_metrics))
        estimate_names = tuple(sorted(target_estimates))

        return (metric_names, estimate_names)
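
    # For example, two metrics named "mae" and "rmse" evaluated against
    # estimates "model_b" and "model_a" (hypothetical names) would yield the
    # key (("mae", "rmse"), ("model_a", "model_b")); sorting makes the key
    # independent of the order in which callers list metrics and estimates.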

    def _get_cached_evaluation(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> ARD:
        """Get the cached evaluation result, computing and caching it on first use."""
        cache_key = self._get_cache_key(metrics, estimates)

        if cache_key not in self._evaluation_cache:
            if metrics is not None or estimates is not None:
                filtered_evaluator = self.filter(metrics=metrics, estimates=estimates)
                ard_result = filtered_evaluator._evaluate_ard(
                    metrics=metrics, estimates=estimates
                )
            else:
                ard_result = self._evaluate_ard(metrics=metrics, estimates=estimates)
            self._evaluation_cache[cache_key] = ard_result

        return self._evaluation_cache[cache_key]

    def clear_cache(self) -> None:
        """Clear the evaluation cache"""
        self._evaluation_cache.clear()

    def filter(
        self,
        *,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> "MetricEvaluator":
        """Return a new evaluator scoped to the requested metrics or estimates."""

        if metrics is not None:
            filtered_metrics = self._resolve_metrics(metrics)
        else:
            filtered_metrics = list(self.metrics)
        filtered_estimate_keys = (
            self._resolve_estimates(estimates)
            if estimates is not None
            else list(self.estimates.keys())
        )
        filtered_estimates = {
            key: self.estimates[key] for key in filtered_estimate_keys
        }

        return MetricEvaluator(
            df=self.df,
            metrics=filtered_metrics,
            ground_truth=self.ground_truth,
            estimates=filtered_estimates,
            group_by=self.group_by,
            subgroup_by=self.subgroup_by,
            filter_expr=None,
            error_params=self.error_params,
        )
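
    # A minimal scoping sketch (the estimate name is hypothetical and must
    # already be configured on this evaluator):
    #
    #     scoped = evaluator.filter(estimates=["model_a"])
    #     scoped.evaluate()  # evaluates only the "model_a" estimate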

    def _evaluate_ard(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> ARD:
        """Internal helper that returns evaluation results as ARD."""

        target_metrics = self._resolve_metrics(metrics)
        target_estimates = self._resolve_estimates(estimates)

        if not target_metrics or not target_estimates:
            raise ValueError("No metrics or estimates to evaluate")

        combined = self._vectorized_evaluate(target_metrics, target_estimates)
        formatted = self._format_result(combined)
        return self._convert_to_ard(formatted)

    def evaluate(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        *,
        collect: bool = True,
        verbose: bool = False,
    ) -> pl.LazyFrame | pl.DataFrame:
        """
        Unified evaluation method returning results in ARD format.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            collect: When False, return a ``LazyFrame`` rather than materialising results.
            verbose: When True, include struct columns (``id``, ``groups``, ``subgroups``)
                and diagnostic fields (``stat``, ``stat_fmt``, ``context``, ``warning``,
                ``error``). When False, struct values are flattened and diagnostics are
                dropped for a compact table view.

        Returns:
            ``polars.DataFrame`` when ``collect`` is True (verbose controls struct
            flattening), or a ``LazyFrame`` when ``collect`` is False.
        """

        ard = self._evaluate_ard(metrics=metrics, estimates=estimates)

        if not collect:
            return ard.lazy

        if verbose:
            return format_verbose_frame(ard)
        return format_compact_frame(ard)
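
    # Typical calls (a sketch; the exact columns depend on configuration):
    #
    #     evaluator.evaluate()                      # compact DataFrame
    #     evaluator.evaluate(verbose=True)          # structs + diagnostics
    #     lazy = evaluator.evaluate(collect=False)  # LazyFrame in ARD layout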

    def _convert_to_ard(self, result_lf: pl.LazyFrame) -> ARD:
        return format_to_ard(result_lf, self._formatter_context)

    def pivot_by_group(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        column_order_by: str = "metrics",
        row_order_by: str = "group",
    ) -> pl.DataFrame:
        """
        Pivot results with groups as rows and model x metric as columns.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            column_order_by: Column ordering strategy ("metrics" or "estimates")
            row_order_by: Row ordering strategy ("group" or "subgroup")

        Returns:
            DataFrame with group combinations as rows and metric columns
        """
        long_df = self._collect_long_dataframe(metrics=metrics, estimates=estimates)
        return build_group_pivot(
            long_df,
            self._formatter_context,
            column_order_by=column_order_by,
            row_order_by=row_order_by,
        )

    def pivot_by_model(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
        column_order_by: str = "estimates",
        row_order_by: str = "group",
    ) -> pl.DataFrame:
        """
        Pivot results with models as rows and group x metric as columns.

        Args:
            metrics: Subset of metrics to evaluate (None = use all configured)
            estimates: Subset of estimates to evaluate (None = use all configured)
            column_order_by: Column ordering strategy ("estimates" or "metrics")
            row_order_by: Row ordering strategy ("group" or "subgroup")

        Returns:
            DataFrame with model combinations as rows and group+metric columns
        """
        long_df = self._collect_long_dataframe(metrics=metrics, estimates=estimates)
        return build_model_pivot(
            long_df,
            self._formatter_context,
            column_order_by=column_order_by,
            row_order_by=row_order_by,
        )
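
    # Orientation of the two pivots, sketched for one metric "mae" and two
    # models "model_a"/"model_b" grouped by "Site" (all names hypothetical;
    # the exact column labels are produced by result_formatter):
    #
    #     pivot_by_group(): one row per group combination, one column per
    #         model x metric pair (e.g. "Model A mae", "Model B mae").
    #     pivot_by_model(): one row per model, one column per group x metric
    #         pair (e.g. "Site A mae", "Site B mae").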

    def _resolve_metrics(
        self, metrics: MetricDefine | list[MetricDefine] | None
    ) -> list[MetricDefine]:
        """Resolve which metrics to evaluate"""
        if metrics is None:
            return list(self.metrics)

        metrics_list = [metrics] if isinstance(metrics, MetricDefine) else metrics
        configured_names = {m.name for m in self.metrics}

        for m in metrics_list:
            if m.name not in configured_names:
                raise ValueError(f"Metric '{m.name}' not in configured metrics")

        return metrics_list

    def _resolve_estimates(self, estimates: str | list[str] | None) -> list[str]:
        """Resolve which estimates to evaluate"""
        if estimates is None:
            return list(self.estimates.keys())

        estimates_list = [estimates] if isinstance(estimates, str) else estimates

        for e in estimates_list:
            if e not in self.estimates:
                raise ValueError(
                    f"Estimate '{e}' not in configured estimates: {list(self.estimates.keys())}"
                )

        return estimates_list

    def _vectorized_evaluate(
        self, metrics: list[MetricDefine], estimates: list[str]
    ) -> pl.LazyFrame:
        """Vectorized evaluation using single Polars group_by operations"""

        # Step 1: Prepare data in long format with all estimates
        df_long = self._prepare_long_format_data(estimates)

        # Step 2: Generate all error columns for the melted data
        df_with_errors = self._add_error_columns_vectorized(df_long)

        # Step 3: Handle marginal subgroup analysis if needed
        if self.subgroup_by:
            return self._evaluate_with_marginal_subgroups(
                df_with_errors, metrics, estimates
            )
        else:
            return self._evaluate_without_subgroups(df_with_errors, metrics, estimates)

    def _evaluate_without_subgroups(
        self,
        df_with_errors: pl.LazyFrame,
        metrics: list[MetricDefine],
        estimates: list[str],
    ) -> pl.LazyFrame:
        """Evaluate metrics without subgroup analysis"""
        results = []
        for metric in metrics:
            metric_result = self._evaluate_metric_vectorized(
                df_with_errors, metric, estimates
            )
            results.append(metric_result)

        # Combine results (no schema harmonization needed with fixed evaluation structure)
        if results:
            return pl.concat(results, how="diagonal")
        else:
            return pl.DataFrame().lazy()

    def _evaluate_with_marginal_subgroups(
        self,
        df_with_errors: pl.LazyFrame,
        metrics: list[MetricDefine],
        estimates: list[str],
    ) -> pl.LazyFrame:
        """Evaluate metrics with marginal subgroup analysis using vectorized operations"""
        # Create all subgroup combinations using vectorized unpivot
        subgroup_data = self._prepare_subgroup_data_vectorized(
            df_with_errors, self.subgroup_by
        )

        # Evaluate all metrics across all subgroups in a vectorized manner
        results = []
        for metric in metrics:
            metric_result = self._evaluate_metric_vectorized(
                subgroup_data, metric, estimates
            )
            results.append(metric_result)

        # Combine results (no schema harmonization needed with fixed evaluation structure)
        return pl.concat(results, how="diagonal")

    def _prepare_subgroup_data_vectorized(
        self, df_with_errors: pl.LazyFrame, subgroup_by: dict[str, str]
    ) -> pl.LazyFrame:
        """Prepare subgroup data using vectorized unpivot operations"""
        schema_names = df_with_errors.collect_schema().names()
        subgroup_cols = list(subgroup_by.keys())
        id_vars = [col for col in schema_names if col not in subgroup_cols]

        # Use unpivot to create marginal subgroup analysis
        return df_with_errors.unpivot(
            index=id_vars,
            on=subgroup_cols,
            variable_name="subgroup_name",
            value_name="subgroup_value",
        ).with_columns(
            [
                # Replace subgroup column names with their display labels
                pl.col("subgroup_name").replace(subgroup_by)
            ]
        )
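
    # Marginal layout sketch: with subgroup columns "sex" and "age_band"
    # (hypothetical), a row with sex="F" and age_band="65+" becomes two rows,
    # one with subgroup_name="sex"/subgroup_value="F" and one with
    # subgroup_name="age_band"/subgroup_value="65+", so each subgroup margin
    # can be aggregated independently in a single group_by pass.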

    def _prepare_long_format_data(self, estimates: list[str]) -> pl.LazyFrame:
        """Reshape data from wide to long format for vectorized processing"""

        # Add a row index to the original data to uniquely identify each sample.
        # This must be done BEFORE unpivoting to avoid double counting
        df_with_index = self.df.with_row_index("sample_index")

        # Get all columns except estimates to preserve in melt
        schema_names = df_with_index.collect_schema().names()
        id_vars = [col for col in schema_names if col not in estimates]

        # Unpivot estimates into long format
        df_long = df_with_index.unpivot(
            index=id_vars,
            on=estimates,
            variable_name="estimate_name",
            value_name="estimate_value",
        )

        # Preserve canonical estimate key alongside optional display label
        label_mapping = self._estimate_catalog.key_to_label
        df_long = (
            df_long.rename({self.ground_truth: "ground_truth"})
            .with_columns(
                [
                    pl.col("estimate_name").alias("estimate"),
                    pl.col("estimate_name")
                    .cast(pl.Utf8)
                    .map_elements(
                        lambda val, mapping=label_mapping: mapping.get(val, val),
                        return_dtype=pl.Utf8,
                    )
                    .alias("estimate_label"),
                ]
            )
            .drop("estimate_name")
        )

        return df_long
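
    # Shape sketch (hypothetical columns): a wide row
    #     {"actual": 2.0, "model_a": 1.9, "model_b": 2.2}
    # becomes two long rows sharing one sample_index:
    #     {"ground_truth": 2.0, "estimate": "model_a", "estimate_value": 1.9, ...}
    #     {"ground_truth": 2.0, "estimate": "model_b", "estimate_value": 2.2, ...}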

    def _add_error_columns_vectorized(self, df_long: pl.LazyFrame) -> pl.LazyFrame:
        """Add error columns for the long-format data"""

        # Generate error expressions for the vectorized format, using
        # 'estimate_value' as the estimate column and 'ground_truth' as the
        # renamed ground truth column
        error_expressions = MetricRegistry.generate_error_columns(
            estimate="estimate_value",
            ground_truth="ground_truth",
            error_types=None,
            error_params=self.error_params,
        )

        return df_long.with_columns(error_expressions)

    def _evaluate_metric_vectorized(
        self, df_with_errors: pl.LazyFrame, metric: MetricDefine, estimates: list[str]
    ) -> pl.LazyFrame:
        """Evaluate a single metric using vectorized operations"""

        group_cols = self._get_vectorized_grouping_columns(metric, df_with_errors)
        within_infos, across_info = metric.compile_expressions()
        df_filtered = self._apply_metric_scope_filter(df_with_errors, metric, estimates)

        handlers = {
            MetricType.ACROSS_SAMPLE: self._evaluate_across_sample_metric,
            MetricType.WITHIN_SUBJECT: self._evaluate_within_entity_metric,
            MetricType.WITHIN_VISIT: self._evaluate_within_entity_metric,
            MetricType.ACROSS_SUBJECT: self._evaluate_two_stage_metric,
            MetricType.ACROSS_VISIT: self._evaluate_two_stage_metric,
        }

        handler = handlers.get(metric.type)
        if handler is None:
            raise ValueError(f"Unknown metric type: {metric.type}")

        try:
            result, result_info = handler(
                df_filtered,
                metric,
                group_cols,
                within_infos,
                across_info,
            )
            return self._add_metadata_vectorized(result, metric, result_info)
        except Exception as exc:
            fallback_info = self._fallback_metric_info(
                metric, within_infos, across_info
            )
            placeholder = self._prepare_error_lazyframe(group_cols)
            return self._add_metadata_vectorized(
                placeholder,
                metric,
                fallback_info,
                warnings_list=[],
                errors_list=[self._format_exception_message(exc, metric.name)],
            )

    def _evaluate_across_sample_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        _within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        if across_info is None:
            raise ValueError(f"ACROSS_SAMPLE metric {metric.name} requires across_expr")

        agg_exprs = self._metric_agg_expressions(across_info)
        result = self._aggregate_lazyframe(df, group_cols, agg_exprs)
        return result, across_info

    def _evaluate_within_entity_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        entity_groups = self._merge_group_columns(
            self._get_entity_grouping_columns(metric.type), group_cols
        )

        result_info = self._resolve_metric_info(
            metric,
            primary=within_infos,
            fallback=across_info,
            error_message=f"No valid expression for metric {metric.name}",
        )

        agg_exprs = self._metric_agg_expressions(result_info)
        result = self._aggregate_lazyframe(df, entity_groups, agg_exprs)
        return result, result_info

    def _evaluate_two_stage_metric(
        self,
        df: pl.LazyFrame,
        metric: MetricDefine,
        group_cols: list[str],
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> tuple[pl.LazyFrame, MetricInfo]:
        entity_groups = self._merge_group_columns(
            self._get_entity_grouping_columns(metric.type), group_cols
        )

        base_info = self._resolve_metric_info(
            metric,
            primary=within_infos,
            fallback=across_info,
            error_message=(
                f"No valid expression for first level of metric {metric.name}"
            ),
        )

        intermediate = self._aggregate_lazyframe(
            df,
            entity_groups,
            self._metric_agg_expressions(base_info),
        )

        if across_info is not None and within_infos:
            result_info = across_info
            agg_exprs = self._metric_agg_expressions(result_info)
        else:
            result_info = base_info
            agg_exprs = [pl.col("value").mean().alias("value")]

        result = self._aggregate_lazyframe(intermediate, group_cols, agg_exprs)
        return result, result_info
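
    # Two-stage sketch for an ACROSS_SUBJECT metric: stage one aggregates the
    # within expression per subject_id (plus any scope grouping columns), and
    # stage two reduces those per-subject values with the across expression,
    # falling back to a plain mean when no across expression is configured.
    # A "mean per-subject MAE" metric (illustrative) would therefore compute
    # MAE per subject first, then average those values within each group.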

    @staticmethod
    def _merge_group_columns(
        *column_groups: Sequence[str],
    ) -> list[str]:
        seen: set[str] = set()
        ordered: list[str] = []
        for columns in column_groups:
            for col in columns:
                if col not in seen:
                    seen.add(col)
                    ordered.append(col)
        return ordered
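
    # For example (illustrative values):
    #     _merge_group_columns(["estimate", "site"], ["site", "subgroup_name"])
    #     -> ["estimate", "site", "subgroup_name"]
    # The first occurrence wins, keeping grouping order stable across scopes.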

    def _resolve_metric_info(
        self,
        metric: MetricDefine,
        *,
        primary: Sequence[MetricInfo] | None,
        fallback: MetricInfo | None,
        error_message: str,
    ) -> MetricInfo:
        if primary:
            return primary[0]
        if fallback is not None:
            return fallback
        raise ValueError(error_message)

    @staticmethod
    def _aggregate_lazyframe(
        df: pl.LazyFrame, group_cols: Sequence[str], agg_exprs: Sequence[pl.Expr]
    ) -> pl.LazyFrame:
        columns = [col for col in group_cols if col]
        if columns:
            return df.group_by(columns).agg(agg_exprs)
        return df.select(*agg_exprs)
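
    # With group columns this is df.group_by(cols).agg(exprs); with none
    # (e.g. a GLOBAL-scope metric and no subgroups) it collapses to a
    # single-row df.select(exprs), so callers never special-case empty
    # grouping.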

    def _get_vectorized_grouping_columns(
        self, metric: MetricDefine, df: pl.LazyFrame | None = None
    ) -> list[str]:
        """Get grouping columns for vectorized evaluation based on metric scope"""

        source = df if df is not None else self.df
        schema_names: set[str] = set(source.collect_schema().names())

        using_vectorized_subgroups = {
            "subgroup_name",
            "subgroup_value",
        }.issubset(schema_names)

        def existing(columns: Iterable[str]) -> list[str]:
            return [col for col in columns if col in schema_names]

        # Every scope appends the same subgroup columns, so compute them once
        subgroup_cols = (
            ["subgroup_name", "subgroup_value"]
            if using_vectorized_subgroups
            else existing(self.subgroup_by.keys())
        )

        group_cols: list[str] = []

        if metric.scope == MetricScope.GLOBAL:
            pass  # No estimate or group columns; subgroups only
        elif metric.scope == MetricScope.MODEL:
            group_cols.extend(existing(["estimate"]))
        elif metric.scope == MetricScope.GROUP:
            group_cols.extend(existing(self.group_by.keys()))
        else:
            group_cols.extend(existing(["estimate"]))
            group_cols.extend(existing(self.group_by.keys()))
        group_cols.extend(subgroup_cols)

        return self._merge_group_columns(group_cols)

    def _apply_metric_scope_filter(
        self, df: pl.LazyFrame, metric: MetricDefine, estimates: list[str]
    ) -> pl.LazyFrame:
        """Apply any scope-specific filtering"""
        # For now, no additional filtering needed beyond grouping
        # Future: could add estimate filtering for specific scopes
        _ = metric, estimates  # Suppress unused parameter warnings
        return df

    def _metric_agg_expressions(self, info: MetricInfo) -> list[pl.Expr]:
        return [info.expr.alias("value")]

    def _get_entity_grouping_columns(self, metric_type: MetricType) -> list[str]:
        """Get entity-level grouping columns (subject_id, visit_id)"""
        if metric_type in [MetricType.WITHIN_SUBJECT, MetricType.ACROSS_SUBJECT]:
            return ["subject_id"]
        elif metric_type in [MetricType.WITHIN_VISIT, MetricType.ACROSS_VISIT]:
            return ["subject_id", "visit_id"]
        else:
            return []
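
    # Summary of the mapping (all other metric types return no entity
    # grouping):
    #     WITHIN_SUBJECT / ACROSS_SUBJECT -> ["subject_id"]
    #     WITHIN_VISIT / ACROSS_VISIT     -> ["subject_id", "visit_id"]
    # The "subject_id" and "visit_id" column names are a fixed convention of
    # this pipeline and are expected to exist in the input data.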

    def _add_metadata_vectorized(
        self,
        result: pl.LazyFrame,
        metric: MetricDefine,
        info: MetricInfo,
        *,
        warnings_list: Sequence[str] | None = None,
        errors_list: Sequence[str] | None = None,
    ) -> pl.LazyFrame:
        """Add metadata columns to vectorized result"""

        schema = result.collect_schema()
        value_dtype = schema.get("value") if "value" in schema.names() else None
        value_kind = (
            (info.value_kind or "").lower()
            if info.value_kind
            else self._infer_value_kind_from_dtype(value_dtype)
        )
        if not value_kind:
            value_kind = "float"

        metadata_columns = [
            pl.lit(metric.name).cast(pl.Utf8).alias("metric"),
            pl.lit(metric.label).cast(pl.Utf8).alias("label"),
            pl.lit(metric.type.value).cast(pl.Utf8).alias("metric_type"),
            pl.lit(metric.scope.value if metric.scope else None)
            .cast(pl.Utf8)
            .alias("scope"),
            pl.lit(value_kind).cast(pl.Utf8).alias("_value_kind"),
            pl.lit(info.format).cast(pl.Utf8).alias("_value_format"),
        ]

        result = result.with_columns(metadata_columns)

        diagnostics_columns = [
            pl.lit(list(warnings_list or []), dtype=pl.List(pl.Utf8)).alias(
                "_diagnostic_warning"
            ),
            pl.lit(list(errors_list or []), dtype=pl.List(pl.Utf8)).alias(
                "_diagnostic_error"
            ),
        ]

        result = result.with_columns(diagnostics_columns)

        # Attach entity identifiers for within-entity metrics
        result = self._attach_entity_identifier(result, metric)

        # Helper columns for stat struct construction
        helper_columns = [
            pl.lit(None, dtype=pl.Float64).alias("_value_float"),
            pl.lit(None, dtype=pl.Int64).alias("_value_int"),
            pl.lit(None, dtype=pl.Boolean).alias("_value_bool"),
            pl.lit(None, dtype=pl.Utf8).alias("_value_str"),
            pl.lit(None).alias("_value_struct"),
        ]
        result = result.with_columns(helper_columns)

        if value_kind == "int":
            result = result.with_columns(
                pl.col("value").cast(pl.Int64, strict=False).alias("_value_int"),
                pl.col("value").cast(pl.Float64, strict=False).alias("_value_float"),
                pl.col("value").cast(pl.Float64, strict=False),
            )
        elif value_kind == "float":
            result = result.with_columns(
                pl.col("value").cast(pl.Float64, strict=False).alias("_value_float"),
                pl.col("value").cast(pl.Float64, strict=False),
            )
        elif value_kind == "bool":
            result = result.with_columns(
                pl.col("value").cast(pl.Boolean, strict=False).alias("_value_bool"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        elif value_kind == "string":
            result = result.with_columns(
                pl.col("value").cast(pl.Utf8, strict=False).alias("_value_str"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        elif value_kind == "struct":
            result = result.with_columns(
                pl.col("value").alias("_value_struct"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )
        else:
            result = result.with_columns(
                pl.col("value").cast(pl.Utf8, strict=False).alias("_value_str"),
                pl.lit(None, dtype=pl.Float64).alias("value"),
            )

        return result

    @staticmethod
    def _format_exception_message(exc: Exception, metric_name: str) -> str:
        return f"{metric_name}: {type(exc).__name__}: {exc}".strip()

    @staticmethod
    def _fallback_metric_info(
        metric: MetricDefine,
        within_infos: Sequence[MetricInfo] | None,
        across_info: MetricInfo | None,
    ) -> MetricInfo:
        if within_infos:
            return within_infos[0]
        if across_info is not None:
            return across_info
        return MetricInfo(expr=pl.lit(None).alias("value"), value_kind="float")

    @staticmethod
    def _prepare_error_lazyframe(group_cols: Sequence[str]) -> pl.LazyFrame:
        data: dict[str, list[Any]] = {}
        for col in group_cols:
            data[col] = [None]
        data["value"] = [None]
        return pl.DataFrame(data, strict=False).lazy()

    @staticmethod
    def _infer_value_kind_from_dtype(dtype: pl.DataType | None) -> str:
        """Map Polars dtypes to MetricInfo value_kind labels."""

        if dtype is None or dtype == pl.Null:
            return "float"
        if dtype == pl.Struct:
            return "struct"
        if dtype == pl.Boolean:
            return "bool"
        if dtype == pl.Utf8:
            return "string"
        if hasattr(dtype, "is_numeric") and dtype.is_numeric():
            return "int" if dtype.is_integer() else "float"
        return "string"
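
    # Sketch of the mapping implemented above:
    #     None or pl.Null          -> "float"   (no information; assume numeric)
    #     pl.Struct                -> "struct"
    #     pl.Boolean               -> "bool"
    #     pl.Utf8                  -> "string"
    #     pl.Int32 / pl.Int64      -> "int"
    #     pl.Float32 / pl.Float64  -> "float"
    #     anything else            -> "string"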

    def _attach_entity_identifier(
        self, result: pl.LazyFrame, metric: MetricDefine
    ) -> pl.LazyFrame:
        """Attach a canonical id struct for within-entity metrics."""

        entity_cols = self._get_entity_grouping_columns(metric.type)
        if not entity_cols:
            if "id" in result.collect_schema().names():
                return result
            return result.with_columns(pl.lit(None).alias("id"))

        schema = result.collect_schema()
        available = set(schema.names())
        present = [col for col in entity_cols if col in available]

        if not present:
            if "id" in available:
                return result
            return result.with_columns(pl.lit(None).alias("id"))

        id_struct = pl.struct([pl.col(col).alias(col) for col in present]).alias("id")
        cleaned = result.with_columns(id_struct)

        # Drop entity columns to avoid polluting downstream schema
        return cleaned.drop(present)

    def _format_result(self, combined: pl.LazyFrame) -> pl.LazyFrame:
        """Minimal formatting - ARD handles all presentation concerns"""
        return combined

    # ------------------------------------------------------------------
    # Result shaping helpers
    # ------------------------------------------------------------------

    def _collect_long_dataframe(
        self,
        metrics: MetricDefine | list[MetricDefine] | None = None,
        estimates: str | list[str] | None = None,
    ) -> pl.DataFrame:
        """Collect evaluation results as a flat DataFrame for pivoting."""

        ard = self._get_cached_evaluation(metrics=metrics, estimates=estimates)
        lf = ard.lazy
        schema = lf.collect_schema()

        exprs: list[pl.Expr] = []
        estimate_label_map = self._estimate_catalog.key_to_label

        # Group columns with display labels
        for col, label in self.group_by.items():
            if col in schema.names():
                exprs.append(pl.col(col).alias(label))

        # Subgroup columns
        if "subgroup_name" in schema.names():
            exprs.append(pl.col("subgroup_name"))
        if "subgroup_value" in schema.names():
            exprs.append(pl.col("subgroup_value"))

        # Estimate / metric / label columns
        if "estimate" in schema.names():
            exprs.append(pl.col("estimate").cast(pl.Utf8))
            exprs.append(
                pl.col("estimate")
                .cast(pl.Utf8)
                .map_elements(
                    lambda val, mapping=estimate_label_map: mapping.get(val, val),
                    return_dtype=pl.Utf8,
                )
                .alias("estimate_label")
            )
        if "metric" in schema.names():
            exprs.append(pl.col("metric").cast(pl.Utf8))
        if "label" in schema.names():
            exprs.append(pl.col("label").cast(pl.Utf8))
        else:
            exprs.append(pl.col("metric").cast(pl.Utf8).alias("label"))

        if "stat" in schema.names():
            exprs.append(pl.col("stat"))
        if "stat_fmt" in schema.names():
            exprs.append(pl.col("stat_fmt"))
            exprs.append(
                pl.when(pl.col("stat_fmt").is_null())
                .then(
                    pl.col("stat").map_elements(
                        ARD._format_stat, return_dtype=pl.Utf8
                    )
                )
                .otherwise(pl.col("stat_fmt"))
                .alias("value")
            )
        else:
            exprs.append(
                pl.col("stat")
                .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                .alias("value")
            )

        if "id" in schema.names():
            exprs.append(pl.col("id"))

        # Scope metadata
        if "metric_type" in schema.names():
            exprs.append(pl.col("metric_type").cast(pl.Utf8))
        if "scope" in schema.names():
            exprs.append(pl.col("scope").cast(pl.Utf8))

        return lf.select(exprs).collect()

    # ========================================
    # INPUT PROCESSING METHODS - Pure Logic
    # ========================================

    @staticmethod
    def _process_estimates(
        estimates: str | list[str] | dict[str, str] | None,
    ) -> dict[str, str]:
        """Pure transformation: normalize estimates to dict format"""
        if isinstance(estimates, str):
            return {estimates: estimates}
        elif isinstance(estimates, dict):
            return estimates
        elif isinstance(estimates, list):
            return {est: est for est in estimates}
        else:
            return {}

    @staticmethod
    def _process_grouping(
        grouping: list[str] | dict[str, str] | None,
    ) -> dict[str, str]:
        """Pure transformation: normalize grouping to dict format"""
        if isinstance(grouping, dict):
            return grouping
        elif isinstance(grouping, list):
            return {col: col for col in grouping}
        else:
            return {}
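
    # Both normalizers map every accepted input shape onto dict[str, str],
    # for example (illustrative values):
    #     _process_estimates("model_a")         -> {"model_a": "model_a"}
    #     _process_estimates(["a", "b"])        -> {"a": "a", "b": "b"}
    #     _process_estimates({"a": "Model A"})  -> {"a": "Model A"}
    #     _process_grouping(None)               -> {}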

    def _compute_subgroup_categories(self) -> list[str]:
        if not self.subgroup_by:
            return []

        categories: list[Any] = []
        seen: set[Any] = set()
        schema = self.df_raw.collect_schema()

        for column in self.subgroup_by.keys():
            if column not in schema.names():
                continue

            dtype = schema[column]

            if isinstance(dtype, pl.Enum):
                for value in dtype.categories.to_list():
                    if value not in seen:
                        seen.add(value)
                        categories.append(value)
                continue

            ordered_values = self._collect_unique_subgroup_values(column, dtype)
            for value in ordered_values:
                if value in seen:
                    continue
                seen.add(value)
                categories.append(value)

        return categories

    def _collect_unique_subgroup_values(
        self, column: str, dtype: pl.DataType
    ) -> list[Any]:
        """Collect sorted unique subgroup values using lazy execution."""

        expr = pl.col(column).drop_nulls()

        if not dtype.is_numeric():
            expr = expr.cast(pl.Utf8)

        lazy_unique = (
            self.df_raw.select(expr.alias(column)).unique(subset=[column]).sort(column)
        )

        df_unique = lazy_unique.collect(engine="streaming")
        values = df_unique[column].to_list()

        if dtype.is_numeric():
            return values

        return [str(value) for value in values]

    # ========================================
    # VALIDATION METHODS - Centralized Logic
    # ========================================

    def _validate_inputs(self) -> None:
        """Validate all inputs after processing"""
        if not self.estimates:
            raise ValueError("No estimates provided")

        if not self.metrics:
            raise ValueError("No metrics provided")

        # Validate that required columns exist
        schema_names = self.df_raw.collect_schema().names()

        if self.ground_truth not in schema_names:
            raise ValueError(
                f"Ground truth column '{self.ground_truth}' not found in data"
            )

        missing_estimates = [
            est for est in self.estimates.keys() if est not in schema_names
        ]
        if missing_estimates:
            raise ValueError(f"Estimate columns not found in data: {missing_estimates}")

        missing_groups = [
            col for col in self.group_by.keys() if col not in schema_names
        ]
        if missing_groups:
            raise ValueError(f"Group columns not found in data: {missing_groups}")

        missing_subgroups = [
            col for col in self.subgroup_by.keys() if col not in schema_names
        ]
        if missing_subgroups:
            raise ValueError(f"Subgroup columns not found in data: {missing_subgroups}")

        overlap = set(self.group_by.keys()) & set(self.subgroup_by.keys())
        if overlap:
            raise ValueError(
                "Group and subgroup columns must be distinct; found duplicates: "
                f"{sorted(overlap)}"
            )