Coverage for src/polars_eval_metrics/ard.py: 72%

260 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

1"""Analysis Results Data (ARD) container.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6import json 

7from typing import Any, Iterable, Mapping 

8 

9import polars as pl 

10 

11 

@dataclass
class ARD:
    """Fixed-schema container for metric evaluation output."""

    # Underlying lazy frame holding the canonical ARD columns
    # (id, groups, subgroups, estimate, metric, label, stat, stat_fmt,
    # warning, error, context).
    _lf: pl.LazyFrame
    # Field names of the "groups" struct column; empty tuple when absent.
    _group_fields: tuple[str, ...]
    # Field names of the "subgroups" struct column; empty tuple when absent.
    _subgroup_fields: tuple[str, ...]
    # Field names of the "context" struct column; empty tuple when absent.
    _context_fields: tuple[str, ...]
    # Field names of the "id" struct column; empty tuple when absent.
    _id_fields: tuple[str, ...]

22 def __init__(self, data: pl.DataFrame | pl.LazyFrame | None = None) -> None: 

23 if data is None: 

24 self._lf = self._empty_frame() 

25 elif isinstance(data, pl.DataFrame): 

26 self._validate_schema(data) 

27 self._lf = data.lazy() 

28 elif isinstance(data, pl.LazyFrame): 

29 self._lf = data 

30 else: 

31 raise TypeError(f"Unsupported data type: {type(data)}") 

32 

33 schema = self._lf.collect_schema() 

34 self._id_fields = self._extract_struct_fields(schema, "id") 

35 self._group_fields = self._extract_struct_fields(schema, "groups") 

36 self._subgroup_fields = self._extract_struct_fields(schema, "subgroups") 

37 self._context_fields = self._extract_struct_fields(schema, "context") 

38 

39 # --------------------------------------------------------------------- 

40 # Construction helpers 

41 # --------------------------------------------------------------------- 

42 

43 @staticmethod 

44 def _empty_frame() -> pl.LazyFrame: 

45 """Return an empty ARD frame with the canonical schema.""" 

46 stat_dtype = pl.Struct( 

47 [ 

48 pl.Field("type", pl.Utf8), 

49 pl.Field("value_float", pl.Float64), 

50 pl.Field("value_int", pl.Int64), 

51 pl.Field("value_bool", pl.Boolean), 

52 pl.Field("value_str", pl.Utf8), 

53 pl.Field("value_struct", pl.Struct([])), 

54 pl.Field("format", pl.Utf8), 

55 ] 

56 ) 

57 frame = pl.DataFrame( 

58 { 

59 "id": pl.Series([], dtype=pl.Null), 

60 "groups": pl.Series([], dtype=pl.Struct([])), 

61 "subgroups": pl.Series([], dtype=pl.Struct([])), 

62 "estimate": pl.Series([], dtype=pl.Utf8), 

63 "metric": pl.Series([], dtype=pl.Utf8), 

64 "label": pl.Series([], dtype=pl.Utf8), 

65 "stat": pl.Series([], dtype=stat_dtype), 

66 "stat_fmt": pl.Series([], dtype=pl.Utf8), 

67 "warning": pl.Series([], dtype=pl.List(pl.Utf8)), 

68 "error": pl.Series([], dtype=pl.List(pl.Utf8)), 

69 "context": pl.Series([], dtype=pl.Struct([])), 

70 } 

71 ) 

72 return frame.lazy() 

73 

74 @staticmethod 

75 def _extract_struct_fields( 

76 schema: Mapping[str, pl.DataType], column: str 

77 ) -> tuple[str, ...]: 

78 """Return field names for struct columns, or an empty tuple when not present.""" 

79 dtype = schema.get(column) 

80 if isinstance(dtype, pl.Struct): 

81 return tuple(field.name for field in dtype.fields) 

82 return tuple() 

83 

84 @staticmethod 

85 def _validate_schema(df: pl.DataFrame) -> None: 

86 """Guard against constructing ARD from frames missing required columns.""" 

87 required = {"groups", "subgroups", "estimate", "metric", "stat", "context"} 

88 missing = required - set(df.columns) 

89 if missing: 

90 raise ValueError(f"Missing required ARD columns: {missing}") 

91 

92 # ------------------------------------------------------------------ 

93 # Basic API 

94 # ------------------------------------------------------------------ 

95 

    @property
    def lazy(self) -> pl.LazyFrame:
        """Underlying lazy frame; accessing it triggers no computation."""
        return self._lf

99 

100 def collect(self) -> pl.DataFrame: 

101 """Collect the lazy evaluation while keeping the canonical columns when available.""" 

102 # Keep core columns for backward compatibility when eagerly collecting 

103 available = self._lf.collect_schema().names() 

104 desired = [ 

105 col 

106 for col in [ 

107 "id", 

108 "groups", 

109 "subgroups", 

110 "subgroup_name", 

111 "subgroup_value", 

112 "estimate", 

113 "metric", 

114 "label", 

115 "stat", 

116 "stat_fmt", 

117 "warning", 

118 "error", 

119 "context", 

120 ] 

121 if col in available 

122 ] 

123 return self._lf.select(desired).collect() 

124 

    def __len__(self) -> int:
        """Number of rows; note this triggers a full collect."""
        return self.collect().height

127 

128 @property 

129 def shape(self) -> tuple[int, int]: 

130 collected = self.collect() 

131 return collected.shape 

132 

133 @property 

134 def columns(self) -> list[str]: 

135 return list(self.schema.keys()) 

136 

137 @property 

138 def schema(self) -> dict[str, pl.DataType]: 

139 """Expose the ARD schema for compatibility with tests/utilities.""" 

140 collected = self._lf.collect_schema() 

141 return dict(zip(collected.names(), collected.dtypes())) 

142 

143 def __getitem__(self, key: str) -> pl.Series: 

144 """Allow DataFrame-like column access for compatibility with tests.""" 

145 collected = self.collect() 

146 if key in collected.columns: 

147 return collected[key] 

148 schema_names = self._lf.collect_schema().names() 

149 if key in schema_names: 

150 return self._lf.select(pl.col(key)).collect()[key] 

151 raise KeyError(key) 

152 

153 def iter_rows(self, *args: Any, **kwargs: Any) -> Iterable[tuple[Any, ...]]: 

154 """Iterate over rows of the eagerly collected DataFrame.""" 

155 return self.collect().iter_rows(*args, **kwargs) 

156 

157 def sort(self, *args: Any, **kwargs: Any) -> ARD: 

158 """Return a sorted ARD (collecting lazily).""" 

159 return ARD(self._lf.sort(*args, **kwargs)) 

160 

161 # ------------------------------------------------------------------ 

162 # Formatting utilities 

163 # ------------------------------------------------------------------ 

164 

165 @staticmethod 

166 def _stat_value(stat: Mapping[str, Any] | None) -> Any: 

167 """Extract the native value stored in a stat struct regardless of channel used.""" 

168 if stat is None: 

169 return None 

170 

171 type_label = (stat.get("type") or "").lower() 

172 if type_label == "float": 

173 return stat.get("value_float") 

174 if type_label == "int": 

175 return stat.get("value_int") 

176 if type_label == "bool": 

177 return stat.get("value_bool") 

178 if type_label == "string": 

179 return stat.get("value_str") 

180 if type_label == "struct": 

181 return stat.get("value_struct") 

182 

183 for field in [ 

184 "value_float", 

185 "value_int", 

186 "value_bool", 

187 "value_str", 

188 "value_struct", 

189 ]: 

190 candidate = stat.get(field) 

191 if candidate is not None: 

192 if field == "value_struct": 

193 return candidate 

194 return candidate 

195 

196 return None 

197 

198 @staticmethod 

199 def _format_stat(stat: Mapping[str, Any] | None) -> str: 

200 """Render a stat struct into a string while respecting explicit formatting hints.""" 

201 if stat is None: 

202 return "NULL" 

203 

204 value = ARD._stat_value(stat) 

205 type_label = (stat.get("type") or "").lower() 

206 fmt = stat.get("format") 

207 if fmt and value is not None: 

208 try: 

209 rendered = fmt.format(value) 

210 except Exception: 

211 rendered = str(value) 

212 elif isinstance(value, float): 

213 rendered = f"{value:.1f}" 

214 elif isinstance(value, int): 

215 rendered = f"{value:,}" 

216 elif isinstance(value, (dict, list, tuple)): 

217 rendered = json.dumps(value) 

218 else: 

219 rendered = None if value is None else str(value) 

220 

221 return rendered 

222 

223 def __repr__(self) -> str: 

224 summary = self.summary() 

225 return f"ARD(summary={summary})" 

226 

227 # ------------------------------------------------------------------ 

228 # Null / empty handling 

229 # ------------------------------------------------------------------ 

230 

231 def with_empty_as_null(self) -> ARD: 

232 """Collapse empty structs or blank strings to null for easier downstream filtering.""" 

233 

234 def _collapse(column: str, fields: tuple[str, ...]) -> pl.Expr: 

235 if not fields: 

236 return pl.col(column) 

237 empty = pl.all_horizontal( 

238 [pl.col(column).struct.field(field).is_null() for field in fields] 

239 ) 

240 return ( 

241 pl.when(pl.col(column).is_null() | empty) 

242 .then(None) 

243 .otherwise(pl.col(column)) 

244 .alias(column) 

245 ) 

246 

247 lf = self._lf.with_columns( 

248 [ 

249 _collapse("id", self._id_fields), 

250 _collapse("groups", self._group_fields), 

251 _collapse("subgroups", self._subgroup_fields), 

252 _collapse("context", self._context_fields), 

253 pl.when(pl.col("estimate") == "") 

254 .then(None) 

255 .otherwise(pl.col("estimate")) 

256 .alias("estimate"), 

257 ] 

258 ) 

259 return ARD(lf) 

260 

261 def with_null_as_empty(self) -> ARD: 

262 """Fill null structs or estimates with empty shells to simplify presentation.""" 

263 

264 def _expand(column: str, fields: tuple[str, ...]) -> pl.Expr: 

265 if not fields: 

266 return pl.col(column) 

267 placeholders = [pl.lit(None).alias(name) for name in fields] 

268 return ( 

269 pl.when(pl.col(column).is_null()) 

270 .then(pl.struct(placeholders)) 

271 .otherwise(pl.col(column)) 

272 .alias(column) 

273 ) 

274 

275 lf = self._lf.with_columns( 

276 [ 

277 _expand("id", self._id_fields), 

278 _expand("groups", self._group_fields), 

279 _expand("subgroups", self._subgroup_fields), 

280 _expand("context", self._context_fields), 

281 pl.col("estimate").fill_null(""), 

282 ] 

283 ) 

284 return ARD(lf) 

285 

286 # ------------------------------------------------------------------ 

287 # Transformations 

288 # ------------------------------------------------------------------ 

289 

290 def unnest(self, columns: list[str] | None = None) -> pl.DataFrame: 

291 """Expand selected struct columns into top-level fields for inspection or exports.""" 

292 columns = columns or ["groups", "subgroups"] 

293 lf = self._lf 

294 schema = lf.collect_schema() 

295 for column in columns: 

296 if column not in {"id", "groups", "subgroups", "context", "stat"}: 

297 continue 

298 if column not in schema.names(): 

299 continue 

300 dtype = schema.get(column) 

301 if not isinstance(dtype, pl.Struct): 

302 continue 

303 struct_fields = {field.name for field in dtype.fields} 

304 existing_fields = set(schema.names()) 

305 if struct_fields & existing_fields: 

306 continue 

307 has_values = lf.select(pl.col(column).is_not_null().any()).collect().item() 

308 if has_values: 

309 lf = lf.unnest(column) 

310 schema = lf.collect_schema() 

311 return lf.collect() 

312 

    def to_wide(
        self,
        index: list[str] | None = None,
        columns: list[str] | None = None,
        values: str = "stat",
        aggregate: str = "first",
    ) -> pl.DataFrame:
        """Pivot the ARD into a wide grid, formatting stats unless a value column is provided.

        Args:
            index: Row-identifying columns; defaults to every flattened column
                that is not a pivot column, the value column, or ``stat``.
            columns: Columns to spread across the header; defaults to
                ``["estimate", "metric"]`` when more than one non-null
                estimate exists, else ``["metric"]``.
            values: Value column; the default ``"stat"`` is rendered to a
                string via ``stat_fmt`` / ``_format_stat`` first.
            aggregate: Polars pivot aggregate function name (default "first").

        Returns:
            The pivoted DataFrame with helper columns removed.
        """
        # Flatten struct columns so pivot keys are plain top-level columns.
        df = self.unnest(["groups", "subgroups", "context"])

        if columns is None:
            # Spread by estimate only when there is more than one of them.
            has_estimates = (
                df.filter(pl.col("estimate").is_not_null())["estimate"].n_unique() > 1
            )
            columns = ["estimate", "metric"] if has_estimates else ["metric"]

        if index is None:
            index = [col for col in df.columns if col not in columns + [values, "stat"]]

        if values == "stat":
            # Prefer the precomputed stat_fmt string; fall back to rendering
            # the stat struct row-by-row with _format_stat.
            if "stat_fmt" in df.columns:
                formatted_expr = (
                    pl.when(pl.col("stat_fmt").is_null())
                    .then(
                        pl.col("stat").map_elements(
                            ARD._format_stat, return_dtype=pl.Utf8
                        )
                    )
                    .otherwise(pl.col("stat_fmt"))
                    .alias("_value")
                )
            else:
                formatted_expr = (
                    pl.col("stat")
                    .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                    .alias("_value")
                )

            df = df.with_columns(formatted_expr)
            # Pivot on the rendered helper column instead of the struct.
            values = "_value"

        # With no usable index (none inferred, or all-null columns), key rows
        # by a synthetic row index so the pivot still succeeds.
        if not index or all(df[col].null_count() == len(df) for col in index):
            df = df.with_row_index("_idx")
            index = ["_idx"]

        pivoted = df.pivot(
            index=index, on=columns, values=values, aggregate_function=aggregate
        )

        # Drop helper columns introduced above before returning.
        if "_idx" in pivoted.columns:
            pivoted = pivoted.drop("_idx")
        if "_value" in pivoted.columns:
            pivoted = pivoted.drop("_value")
        return pivoted

367 

    def to_long(self) -> pl.DataFrame:
        """Convert ARD to long format with flattened columns for direct Polars operations.

        Unnests ``groups``/``subgroups``/``context`` (skipping any whose
        fields would collide with existing columns, and any that are entirely
        null) and adds a string ``value`` column rendered from ``stat`` /
        ``stat_fmt``.  The ``stat`` struct column itself is retained.
        """
        # Start with a copy of the lazy frame
        lf = self._lf
        schema = lf.collect_schema()

        # Check for potential conflicts with context unnesting
        context_conflicts = False
        if "context" in schema.names():
            context_dtype = schema.get("context")
            if isinstance(context_dtype, pl.Struct):
                context_fields = {field.name for field in context_dtype.fields}
                existing_fields = set(schema.names())
                context_conflicts = bool(context_fields & existing_fields)

        # Unnest struct columns, checking for conflicts
        current_schema = lf.collect_schema()
        for column in ["groups", "subgroups"]:
            if column in current_schema.names():
                # Entirely-null struct columns are left as-is.
                has_values = (
                    lf.select(pl.col(column).is_not_null().any()).collect().item()
                )
                if has_values:
                    # Check for column conflicts before unnesting
                    struct_dtype = current_schema.get(column)
                    if isinstance(struct_dtype, pl.Struct):
                        struct_fields = {field.name for field in struct_dtype.fields}
                        existing_fields = set(current_schema.names())
                        conflicts = struct_fields & existing_fields

                        if not conflicts:
                            # Safe to unnest
                            lf = lf.unnest(column)
                        # If there are conflicts, skip unnesting (top-level columns already exist)

        # Only unnest context if no conflicts
        if "context" in schema.names() and not context_conflicts:
            has_values = (
                lf.select(pl.col("context").is_not_null().any()).collect().item()
            )
            if has_values:
                lf = lf.unnest("context")

        # Handle stat column specially to extract value
        schema_names = lf.collect_schema().names()
        if "stat" in schema_names:
            # Prefer the precomputed stat_fmt string; fall back to rendering
            # the stat struct row-by-row with _format_stat.
            if "stat_fmt" in schema_names:
                value_expr = (
                    pl.when(pl.col("stat_fmt").is_null())
                    .then(
                        pl.col("stat").map_elements(
                            ARD._format_stat, return_dtype=pl.Utf8
                        )
                    )
                    .otherwise(pl.col("stat_fmt"))
                    .alias("value")
                )
            else:
                value_expr = (
                    pl.col("stat")
                    .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
                    .alias("value")
                )

            lf = lf.with_columns(value_expr)

        return lf.collect()

435 

436 def pivot( 

437 self, 

438 on: str | list[str], 

439 index: str | list[str] | None = None, 

440 values: str = "stat", 

441 aggregate_function: str = "first", 

442 ) -> pl.DataFrame: 

443 """Pivot ARD data using flattened column access.""" 

444 # First flatten the ARD to get columns directly accessible 

445 df = self.to_long() 

446 

447 # Add value column if using stat 

448 if values == "stat": 

449 df = df.with_columns( 

450 pl.col("stat") 

451 .map_elements(ARD._stat_value, return_dtype=pl.Float64) 

452 .alias("value") 

453 ) 

454 values = "value" 

455 

456 # Set default index if not provided 

457 if index is None: 

458 # Use all remaining columns except the pivot columns and values 

459 on_list = [on] if isinstance(on, str) else on 

460 index = [col for col in df.columns if col not in on_list + [values]] 

461 

462 # Ensure index is a list 

463 if isinstance(index, str): 

464 index = [index] 

465 

466 return df.pivot( 

467 on=on, index=index, values=values, aggregate_function=aggregate_function 

468 ) 

469 

470 def get_stats(self, include_metadata: bool = False) -> pl.DataFrame: 

471 """Return a DataFrame of metric values with optional stat metadata columns.""" 

472 select_cols = ["metric", "stat"] 

473 schema_names = self._lf.collect_schema().names() 

474 if "stat_fmt" in schema_names: 

475 select_cols.append("stat_fmt") 

476 df = self._lf.select(select_cols).collect() 

477 

478 values = [ARD._stat_value(stat) for stat in df["stat"]] 

479 

480 if include_metadata: 

481 types = [stat.get("type") if stat else None for stat in df["stat"]] 

482 formats = [stat.get("format") if stat else None for stat in df["stat"]] 

483 if "stat_fmt" in df.columns: 

484 formatted = df["stat_fmt"].to_list() 

485 else: 

486 formatted = [None] * len(df) 

487 return pl.DataFrame( 

488 { 

489 "metric": df["metric"], 

490 "value": values, 

491 "type": types, 

492 "format": formats, 

493 "formatted": formatted, 

494 }, 

495 strict=False, 

496 ) 

497 

498 return pl.DataFrame({"metric": df["metric"], "value": values}, strict=False) 

499 

500 # ------------------------------------------------------------------ 

501 # Summaries 

502 # ------------------------------------------------------------------ 

503 

504 def summary(self) -> dict[str, Any]: 

505 """Summarise key counts and distinct values present in the collected ARD.""" 

506 df = self.collect() 

507 return { 

508 "n_rows": len(df), 

509 "n_metrics": df["metric"].n_unique(), 

510 "n_estimates": df["estimate"].n_unique(), 

511 "n_groups": df.filter(pl.col("groups").is_not_null())["groups"].n_unique(), 

512 "n_subgroups": df.filter(pl.col("subgroups").is_not_null())[ 

513 "subgroups" 

514 ].n_unique(), 

515 "metrics": df["metric"].unique().to_list(), 

516 "estimates": df["estimate"].unique().to_list(), 

517 } 

518 

519 def describe(self) -> None: 

520 """Print a simple console summary and preview of the ARD contents.""" 

521 summary = self.summary() 

522 print("=" * 50) 

523 print(f"ARD Summary: {summary['n_rows']} results") 

524 print("=" * 50) 

525 print("\nMetrics:") 

526 for metric in summary["metrics"]: 

527 print(f" - {metric}") 

528 if summary["n_estimates"]: 

529 print("\nEstimates:") 

530 for estimate in summary["estimates"]: 

531 if estimate: 

532 print(f" - {estimate}") 

533 if summary["n_groups"]: 

534 print(f"\nGroup combinations: {summary['n_groups']}") 

535 if summary["n_subgroups"]: 

536 print(f"Subgroup combinations: {summary['n_subgroups']}") 

537 print("\nPreview:") 

538 print(self._lf.limit(5).collect())