Coverage for src/polars_eval_metrics/ard.py: 72%
260 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
1"""Analysis Results Data (ARD) container."""
3from __future__ import annotations
5from dataclasses import dataclass
6import json
7from typing import Any, Iterable, Mapping
9import polars as pl
@dataclass
class ARD:
    """Fixed-schema container for metric evaluation output.

    Wraps a Polars ``LazyFrame`` holding the canonical ARD columns
    (``id``, ``groups``, ``subgroups``, ``estimate``, ``metric``, ``label``,
    ``stat``, ``stat_fmt``, ``warning``, ``error``, ``context``) and provides
    DataFrame-like accessors plus long/wide reshaping helpers.
    """

    _lf: pl.LazyFrame
    _group_fields: tuple[str, ...]
    _subgroup_fields: tuple[str, ...]
    _context_fields: tuple[str, ...]
    _id_fields: tuple[str, ...]

    def __init__(self, data: pl.DataFrame | pl.LazyFrame | None = None) -> None:
        """Build an ARD from a DataFrame/LazyFrame, or an empty canonical frame.

        Args:
            data: Source frame. ``None`` yields an empty ARD with the canonical
                schema. Eager DataFrames are schema-validated; LazyFrames are
                accepted as-is to avoid forcing a premature collect.

        Raises:
            TypeError: if ``data`` is not a DataFrame, LazyFrame, or None.
            ValueError: if an eager DataFrame lacks required ARD columns.
        """
        if data is None:
            self._lf = self._empty_frame()
        elif isinstance(data, pl.DataFrame):
            self._validate_schema(data)
            self._lf = data.lazy()
        elif isinstance(data, pl.LazyFrame):
            self._lf = data
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

        # Cache struct-field names once so null/empty handling does not need
        # to re-resolve the schema on every call.
        schema = self._lf.collect_schema()
        self._id_fields = self._extract_struct_fields(schema, "id")
        self._group_fields = self._extract_struct_fields(schema, "groups")
        self._subgroup_fields = self._extract_struct_fields(schema, "subgroups")
        self._context_fields = self._extract_struct_fields(schema, "context")

    # ---------------------------------------------------------------------
    # Construction helpers
    # ---------------------------------------------------------------------

    @staticmethod
    def _empty_frame() -> pl.LazyFrame:
        """Return an empty ARD frame with the canonical schema."""
        # The stat struct carries one typed channel per supported value kind;
        # "type" records which channel is populated and "format" an optional
        # format string applied when rendering.
        stat_dtype = pl.Struct(
            [
                pl.Field("type", pl.Utf8),
                pl.Field("value_float", pl.Float64),
                pl.Field("value_int", pl.Int64),
                pl.Field("value_bool", pl.Boolean),
                pl.Field("value_str", pl.Utf8),
                pl.Field("value_struct", pl.Struct([])),
                pl.Field("format", pl.Utf8),
            ]
        )
        frame = pl.DataFrame(
            {
                "id": pl.Series([], dtype=pl.Null),
                "groups": pl.Series([], dtype=pl.Struct([])),
                "subgroups": pl.Series([], dtype=pl.Struct([])),
                "estimate": pl.Series([], dtype=pl.Utf8),
                "metric": pl.Series([], dtype=pl.Utf8),
                "label": pl.Series([], dtype=pl.Utf8),
                "stat": pl.Series([], dtype=stat_dtype),
                "stat_fmt": pl.Series([], dtype=pl.Utf8),
                "warning": pl.Series([], dtype=pl.List(pl.Utf8)),
                "error": pl.Series([], dtype=pl.List(pl.Utf8)),
                "context": pl.Series([], dtype=pl.Struct([])),
            }
        )
        return frame.lazy()

    @staticmethod
    def _extract_struct_fields(
        schema: Mapping[str, pl.DataType], column: str
    ) -> tuple[str, ...]:
        """Return field names for struct columns, or an empty tuple when not present."""
        dtype = schema.get(column)
        if isinstance(dtype, pl.Struct):
            return tuple(field.name for field in dtype.fields)
        return tuple()

    @staticmethod
    def _validate_schema(df: pl.DataFrame) -> None:
        """Guard against constructing ARD from frames missing required columns.

        Raises:
            ValueError: listing the missing required column names.
        """
        required = {"groups", "subgroups", "estimate", "metric", "stat", "context"}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"Missing required ARD columns: {missing}")

    # ------------------------------------------------------------------
    # Basic API
    # ------------------------------------------------------------------

    @property
    def lazy(self) -> pl.LazyFrame:
        """The underlying LazyFrame (no computation triggered)."""
        return self._lf

    def collect(self) -> pl.DataFrame:
        """Collect the lazy frame, keeping canonical columns when available.

        Columns are selected in canonical order for backward compatibility;
        columns absent from the current schema are simply skipped.
        """
        available = self._lf.collect_schema().names()
        desired = [
            col
            for col in [
                "id",
                "groups",
                "subgroups",
                "subgroup_name",
                "subgroup_value",
                "estimate",
                "metric",
                "label",
                "stat",
                "stat_fmt",
                "warning",
                "error",
                "context",
            ]
            if col in available
        ]
        return self._lf.select(desired).collect()

    def __len__(self) -> int:
        """Number of rows (triggers a collect)."""
        return self.collect().height

    @property
    def shape(self) -> tuple[int, int]:
        """(rows, columns) of the collected frame (triggers a collect)."""
        collected = self.collect()
        return collected.shape

    @property
    def columns(self) -> list[str]:
        """Column names in schema order."""
        return list(self.schema.keys())

    @property
    def schema(self) -> dict[str, pl.DataType]:
        """Expose the ARD schema for compatibility with tests/utilities."""
        collected = self._lf.collect_schema()
        return dict(zip(collected.names(), collected.dtypes()))

    def __getitem__(self, key: str) -> pl.Series:
        """Allow DataFrame-like column access for compatibility with tests.

        Raises:
            KeyError: when ``key`` is not a column of the ARD.
        """
        collected = self.collect()
        if key in collected.columns:
            return collected[key]
        # Fall back to the raw lazy schema for columns that collect() filters
        # out of its canonical selection.
        schema_names = self._lf.collect_schema().names()
        if key in schema_names:
            return self._lf.select(pl.col(key)).collect()[key]
        raise KeyError(key)

    def iter_rows(self, *args: Any, **kwargs: Any) -> Iterable[tuple[Any, ...]]:
        """Iterate over rows of the eagerly collected DataFrame."""
        return self.collect().iter_rows(*args, **kwargs)

    def sort(self, *args: Any, **kwargs: Any) -> ARD:
        """Return a sorted ARD (the sort stays lazy)."""
        return ARD(self._lf.sort(*args, **kwargs))

    # ------------------------------------------------------------------
    # Formatting utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _stat_value(stat: Mapping[str, Any] | None) -> Any:
        """Extract the native value stored in a stat struct regardless of channel used.

        The ``type`` label selects the channel when present; otherwise the
        first non-null channel wins. Returns ``None`` for a null stat or when
        every channel is null.
        """
        if stat is None:
            return None

        type_label = (stat.get("type") or "").lower()
        if type_label == "float":
            return stat.get("value_float")
        if type_label == "int":
            return stat.get("value_int")
        if type_label == "bool":
            return stat.get("value_bool")
        if type_label == "string":
            return stat.get("value_str")
        if type_label == "struct":
            return stat.get("value_struct")

        # Unknown/absent type label: probe channels in priority order.
        for field in [
            "value_float",
            "value_int",
            "value_bool",
            "value_str",
            "value_struct",
        ]:
            candidate = stat.get(field)
            if candidate is not None:
                return candidate

        return None

    @staticmethod
    def _format_stat(stat: Mapping[str, Any] | None) -> str | None:
        """Render a stat struct into a string while respecting explicit formatting hints.

        Returns ``"NULL"`` for a null stat, ``None`` when the stat exists but
        holds no value (so map_elements yields a null), and a formatted string
        otherwise. An explicit ``format`` hint takes precedence; a failing
        format string falls back to ``str(value)``.
        """
        if stat is None:
            return "NULL"

        value = ARD._stat_value(stat)
        type_label = (stat.get("type") or "").lower()
        fmt = stat.get("format")
        if fmt and value is not None:
            try:
                rendered = fmt.format(value)
            except Exception:
                # Malformed format string: degrade to plain rendering rather
                # than failing the whole frame.
                rendered = str(value)
        elif isinstance(value, float):
            rendered = f"{value:.1f}"
        elif isinstance(value, int):
            rendered = f"{value:,}"
        elif isinstance(value, (dict, list, tuple)):
            rendered = json.dumps(value)
        else:
            rendered = None if value is None else str(value)

        return rendered

    def __repr__(self) -> str:
        summary = self.summary()
        return f"ARD(summary={summary})"

    # ------------------------------------------------------------------
    # Null / empty handling
    # ------------------------------------------------------------------

    def with_empty_as_null(self) -> ARD:
        """Collapse empty structs or blank strings to null for easier downstream filtering."""

        def _collapse(column: str, fields: tuple[str, ...]) -> pl.Expr:
            # A struct whose fields are all null is treated as "empty".
            if not fields:
                return pl.col(column)
            empty = pl.all_horizontal(
                [pl.col(column).struct.field(field).is_null() for field in fields]
            )
            return (
                pl.when(pl.col(column).is_null() | empty)
                .then(None)
                .otherwise(pl.col(column))
                .alias(column)
            )

        lf = self._lf.with_columns(
            [
                _collapse("id", self._id_fields),
                _collapse("groups", self._group_fields),
                _collapse("subgroups", self._subgroup_fields),
                _collapse("context", self._context_fields),
                pl.when(pl.col("estimate") == "")
                .then(None)
                .otherwise(pl.col("estimate"))
                .alias("estimate"),
            ]
        )
        return ARD(lf)

    def with_null_as_empty(self) -> ARD:
        """Fill null structs or estimates with empty shells to simplify presentation."""

        def _expand(column: str, fields: tuple[str, ...]) -> pl.Expr:
            # Replace a null struct with a struct of all-null fields so the
            # column keeps its dtype and presentation code can unnest safely.
            if not fields:
                return pl.col(column)
            placeholders = [pl.lit(None).alias(name) for name in fields]
            return (
                pl.when(pl.col(column).is_null())
                .then(pl.struct(placeholders))
                .otherwise(pl.col(column))
                .alias(column)
            )

        lf = self._lf.with_columns(
            [
                _expand("id", self._id_fields),
                _expand("groups", self._group_fields),
                _expand("subgroups", self._subgroup_fields),
                _expand("context", self._context_fields),
                pl.col("estimate").fill_null(""),
            ]
        )
        return ARD(lf)

    # ------------------------------------------------------------------
    # Transformations
    # ------------------------------------------------------------------

    def unnest(self, columns: list[str] | None = None) -> pl.DataFrame:
        """Expand selected struct columns into top-level fields for inspection or exports.

        Only the known struct columns are eligible; a column is skipped when it
        is absent, not a struct, entirely null, or when unnesting would collide
        with an existing top-level column name.
        """
        columns = columns or ["groups", "subgroups"]
        lf = self._lf
        schema = lf.collect_schema()
        for column in columns:
            if column not in {"id", "groups", "subgroups", "context", "stat"}:
                continue
            if column not in schema.names():
                continue
            dtype = schema.get(column)
            if not isinstance(dtype, pl.Struct):
                continue
            struct_fields = {field.name for field in dtype.fields}
            existing_fields = set(schema.names())
            if struct_fields & existing_fields:
                # Name collision: the fields already exist at top level.
                continue
            has_values = lf.select(pl.col(column).is_not_null().any()).collect().item()
            if has_values:
                lf = lf.unnest(column)
                # Refresh the schema so later iterations see the new columns.
                schema = lf.collect_schema()
        return lf.collect()

    def to_wide(
        self,
        index: list[str] | None = None,
        columns: list[str] | None = None,
        values: str = "stat",
        aggregate: str = "first",
    ) -> pl.DataFrame:
        """Pivot the ARD into a wide grid, formatting stats unless a value column is provided.

        Args:
            index: Row identifier columns; defaults to every flattened column
                not used for pivoting or values.
            columns: Pivot columns; defaults to ``["estimate", "metric"]`` when
                multiple estimates exist, else ``["metric"]``.
            values: Value column; the default ``"stat"`` is rendered to strings
                (preferring precomputed ``stat_fmt`` where non-null).
            aggregate: Polars aggregate function name for duplicate cells.
        """
        df = self.unnest(["groups", "subgroups", "context"])

        if columns is None:
            has_estimates = (
                df.filter(pl.col("estimate").is_not_null())["estimate"].n_unique() > 1
            )
            columns = ["estimate", "metric"] if has_estimates else ["metric"]

        if index is None:
            index = [col for col in df.columns if col not in columns + [values, "stat"]]

        if values == "stat":
            # Prefer the precomputed stat_fmt string; fall back to rendering
            # the raw stat struct row by row.
            if "stat_fmt" in df.columns:
                formatted_expr = (
                    pl.when(pl.col("stat_fmt").is_null())
                    .then(
                        pl.col("stat").map_elements(
                            ARD._format_stat, return_dtype=pl.Utf8
                        )
                    )
                    .otherwise(pl.col("stat_fmt"))
                    .alias("_value")
                )
            else:
                formatted_expr = (
                    pl.col("stat")
                    .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                    .alias("_value")
                )

            df = df.with_columns(formatted_expr)
            values = "_value"

        # Pivot needs at least one usable index column; synthesize a row index
        # when none exists or every candidate is entirely null.
        if not index or all(df[col].null_count() == len(df) for col in index):
            df = df.with_row_index("_idx")
            index = ["_idx"]

        pivoted = df.pivot(
            index=index, on=columns, values=values, aggregate_function=aggregate
        )

        # Drop the synthetic helper columns before returning.
        if "_idx" in pivoted.columns:
            pivoted = pivoted.drop("_idx")
        if "_value" in pivoted.columns:
            pivoted = pivoted.drop("_value")
        return pivoted

    def to_long(self) -> pl.DataFrame:
        """Convert ARD to long format with flattened columns for direct Polars operations.

        Struct columns are unnested only when they hold values and would not
        collide with existing top-level names; a rendered ``value`` string
        column is added from ``stat``/``stat_fmt``.
        """
        lf = self._lf
        schema = lf.collect_schema()

        # Pre-compute whether unnesting context would collide with existing
        # top-level column names.
        context_conflicts = False
        if "context" in schema.names():
            context_dtype = schema.get("context")
            if isinstance(context_dtype, pl.Struct):
                context_fields = {field.name for field in context_dtype.fields}
                existing_fields = set(schema.names())
                context_conflicts = bool(context_fields & existing_fields)

        # Unnest groups/subgroups, skipping empty or conflicting columns.
        current_schema = lf.collect_schema()
        for column in ["groups", "subgroups"]:
            if column in current_schema.names():
                has_values = (
                    lf.select(pl.col(column).is_not_null().any()).collect().item()
                )
                if has_values:
                    struct_dtype = current_schema.get(column)
                    if isinstance(struct_dtype, pl.Struct):
                        struct_fields = {field.name for field in struct_dtype.fields}
                        existing_fields = set(current_schema.names())
                        conflicts = struct_fields & existing_fields

                        if not conflicts:
                            lf = lf.unnest(column)
                        # On conflict, skip: top-level columns already exist.

        if "context" in schema.names() and not context_conflicts:
            has_values = (
                lf.select(pl.col("context").is_not_null().any()).collect().item()
            )
            if has_values:
                lf = lf.unnest("context")

        # Render the stat struct into a "value" string column, preferring a
        # precomputed stat_fmt where present.
        schema_names = lf.collect_schema().names()
        if "stat" in schema_names:
            if "stat_fmt" in schema_names:
                value_expr = (
                    pl.when(pl.col("stat_fmt").is_null())
                    .then(
                        pl.col("stat").map_elements(
                            ARD._format_stat, return_dtype=pl.Utf8
                        )
                    )
                    .otherwise(pl.col("stat_fmt"))
                    .alias("value")
                )
            else:
                value_expr = (
                    pl.col("stat")
                    .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                    .alias("value")
                )

            lf = lf.with_columns(value_expr)

        return lf.collect()

    def pivot(
        self,
        on: str | list[str],
        index: str | list[str] | None = None,
        values: str = "stat",
        aggregate_function: str = "first",
    ) -> pl.DataFrame:
        """Pivot ARD data using flattened column access.

        Unlike :meth:`to_wide`, the default ``values="stat"`` extracts the
        *native* value (as Float64) rather than a formatted string.
        """
        # Flatten first so pivot keys are directly accessible.
        df = self.to_long()

        if values == "stat":
            # NOTE: return_dtype is Float64, so non-numeric stat channels
            # would be coerced/null here.
            df = df.with_columns(
                pl.col("stat")
                .map_elements(ARD._stat_value, return_dtype=pl.Float64)
                .alias("value")
            )
            values = "value"

        if index is None:
            # Default index: everything that is neither a pivot key nor the
            # value column.
            on_list = [on] if isinstance(on, str) else on
            index = [col for col in df.columns if col not in on_list + [values]]

        if isinstance(index, str):
            index = [index]

        return df.pivot(
            on=on, index=index, values=values, aggregate_function=aggregate_function
        )

    def get_stats(self, include_metadata: bool = False) -> pl.DataFrame:
        """Return a DataFrame of metric values with optional stat metadata columns.

        Args:
            include_metadata: when True, also include the stat's ``type``,
                ``format`` hint, and any precomputed ``formatted`` string.
        """
        select_cols = ["metric", "stat"]
        schema_names = self._lf.collect_schema().names()
        if "stat_fmt" in schema_names:
            select_cols.append("stat_fmt")
        df = self._lf.select(select_cols).collect()

        values = [ARD._stat_value(stat) for stat in df["stat"]]

        if include_metadata:
            types = [stat.get("type") if stat else None for stat in df["stat"]]
            formats = [stat.get("format") if stat else None for stat in df["stat"]]
            if "stat_fmt" in df.columns:
                formatted = df["stat_fmt"].to_list()
            else:
                formatted = [None] * len(df)
            return pl.DataFrame(
                {
                    "metric": df["metric"],
                    "value": values,
                    "type": types,
                    "format": formats,
                    "formatted": formatted,
                },
                strict=False,
            )

        return pl.DataFrame({"metric": df["metric"], "value": values}, strict=False)

    # ------------------------------------------------------------------
    # Summaries
    # ------------------------------------------------------------------

    def summary(self) -> dict[str, Any]:
        """Summarise key counts and distinct values present in the collected ARD."""
        df = self.collect()
        return {
            "n_rows": len(df),
            "n_metrics": df["metric"].n_unique(),
            "n_estimates": df["estimate"].n_unique(),
            "n_groups": df.filter(pl.col("groups").is_not_null())["groups"].n_unique(),
            "n_subgroups": df.filter(pl.col("subgroups").is_not_null())[
                "subgroups"
            ].n_unique(),
            "metrics": df["metric"].unique().to_list(),
            "estimates": df["estimate"].unique().to_list(),
        }

    def describe(self) -> None:
        """Print a simple console summary and preview of the ARD contents."""
        summary = self.summary()
        print("=" * 50)
        print(f"ARD Summary: {summary['n_rows']} results")
        print("=" * 50)
        print("\nMetrics:")
        for metric in summary["metrics"]:
            print(f"  - {metric}")
        if summary["n_estimates"]:
            print("\nEstimates:")
            for estimate in summary["estimates"]:
                if estimate:
                    print(f"  - {estimate}")
        if summary["n_groups"]:
            print(f"\nGroup combinations: {summary['n_groups']}")
        if summary["n_subgroups"]:
            print(f"Subgroup combinations: {summary['n_subgroups']}")
        print("\nPreview:")
        print(self._lf.limit(5).collect())