Coverage for src/polars_eval_metrics/result_formatter.py: 87%

350 statements  


from __future__ import annotations

from typing import Any, Mapping, Sequence

import polars as pl

from .ard import ARD
from .evaluation_context import EstimateCatalog, FormatterContext
from .utils import parse_json_tokens


def convert_to_ard(result_lf: pl.LazyFrame, context: FormatterContext) -> ARD:
    """Convert the evaluator output into canonical ARD columns lazily."""

    schema = result_lf.collect_schema()
    schema_names = set(schema.names())

    ard_frame = result_lf.with_columns(
        [
            _expr_groups(schema, context.group_by),
            _expr_subgroups(schema, context.subgroup_by),
            _expr_estimate(schema, context.estimate_catalog),
            _expr_metric_enum(context.metric_catalog.names),
            _expr_label_enum(context.metric_catalog.labels),
            _expr_stat_struct(schema),
            _expr_context_struct(schema, context),
        ]
    )

    warning_expr = (
        pl.col("_diagnostic_warning")
        if "_diagnostic_warning" in schema_names
        else pl.lit([], dtype=pl.List(pl.Utf8))
    )
    error_expr = (
        pl.col("_diagnostic_error")
        if "_diagnostic_error" in schema_names
        else pl.lit([], dtype=pl.List(pl.Utf8))
    )

    ard_frame = ard_frame.with_columns(
        [
            pl.col("stat")
            .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
            .alias("stat_fmt"),
            warning_expr.alias("warning"),
            error_expr.alias("error"),
        ]
    )

    cleanup_cols = [
        "_value_kind",
        "_value_format",
        "_value_float",
        "_value_int",
        "_value_bool",
        "_value_str",
        "_value_struct",
        "_diagnostic_warning",
        "_diagnostic_error",
    ]
    drop_cols = [col for col in cleanup_cols if col in schema_names]
    if drop_cols:
        ard_frame = ard_frame.drop(drop_cols)

    return ARD(ard_frame)

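# Usage sketch (hypothetical setup: `result_lf` and `context` are produced by
# the evaluator, not constructed here):
#
#     ard = convert_to_ard(result_lf, context)
#     verbose = format_verbose_frame(ard)   # fully expanded rows
#     compact = format_compact_frame(ard)   # struct columns flattened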

def build_group_pivot(
    long_df: pl.DataFrame,
    context: FormatterContext,
    *,
    column_order_by: str,
    row_order_by: str,
) -> pl.DataFrame:
    """Pivot results with groups as rows and model × metric as columns."""

    group_labels = [
        label for label in context.group_by.values() if label in long_df.columns
    ]
    subgroup_present = (
        "subgroup_name" in long_df.columns and "subgroup_value" in long_df.columns
    )

    if row_order_by == "subgroup" and subgroup_present:
        index_cols = ["subgroup_name", "subgroup_value"] + group_labels
    else:
        index_cols = group_labels + (
            ["subgroup_name", "subgroup_value"] if subgroup_present else []
        )

    display_col = (
        "estimate_label" if "estimate_label" in long_df.columns else "estimate"
    )

    result, sections = _build_pivot_table(
        long_df,
        index_cols=index_cols,
        default_on=[display_col, "label"],
        scoped_on=[("global", ["label"]), ("group", ["label"])],
    )

    section_lookup = {name: cols for name, cols in sections}

    if result.is_empty():
        if index_cols:
            return pl.DataFrame({col: [] for col in index_cols})
        return pl.DataFrame()

    value_cols = [col for col in result.columns if col not in index_cols]
    default_cols = section_lookup.get("default", [])
    default_cols = [
        col for col in default_cols if col.startswith('{"') and col.endswith('"}')
    ]

    estimate_order_lookup: Mapping[str, int] = context.estimate_catalog.label_order
    metric_label_order_lookup: Mapping[str, int] = context.metric_catalog.label_order
    metric_name_order_lookup: Mapping[str, int] = context.metric_catalog.name_order

    def metric_order(label: str) -> int:
        if label in metric_label_order_lookup:
            return metric_label_order_lookup[label]
        return metric_name_order_lookup.get(label, len(metric_label_order_lookup))

    def estimate_order(label: str) -> int:
        return estimate_order_lookup.get(label, len(estimate_order_lookup))

    def sort_default(columns: list[str]) -> list[str]:
        def parse(column: str) -> tuple[str, str]:
            inner = column[2:-2]
            parts = inner.split('","')
            return (parts[0], parts[1]) if len(parts) == 2 else (column, "")

        if column_order_by == "metrics":
            return sorted(
                columns,
                key=lambda col: (
                    metric_order(parse(col)[1]),
                    estimate_order(parse(col)[0]),
                ),
            )
        return sorted(
            columns,
            key=lambda col: (
                estimate_order(parse(col)[0]),
                metric_order(parse(col)[1]),
            ),
        )

    ordered = (
        index_cols
        + section_lookup.get("global", [])
        + section_lookup.get("group", [])
        + sort_default(default_cols)
    )

    remaining = [col for col in value_cols if col not in ordered]
    ordered.extend(remaining)
    ordered = [col for col in ordered if col in result.columns]

    if "subgroup_value" in result.columns:
        if context.subgroup_categories and all(
            isinstance(cat, str) for cat in context.subgroup_categories
        ):
            result = result.with_columns(
                pl.col("subgroup_value").cast(
                    pl.Enum(list(context.subgroup_categories))
                )
            )
        else:
            result = result.with_columns(pl.col("subgroup_value").cast(pl.Utf8))

    sort_columns: list[str] = []
    temp_sort_columns: list[str] = []
    subgroup_order_map = {
        label: idx for idx, label in enumerate(context.subgroup_by.values())
    }

    if row_order_by == "group":
        sort_columns.extend([col for col in group_labels if col in result.columns])
        if "subgroup_name" in result.columns and context.subgroup_by:
            result = result.with_columns(
                pl.col("subgroup_name")
                .replace(subgroup_order_map)
                .fill_null(len(subgroup_order_map))
                .cast(pl.Int32)
                .alias("__subgroup_name_order")
            )
            temp_sort_columns.append("__subgroup_name_order")
            sort_columns.append("__subgroup_name_order")
        if "subgroup_value" in result.columns:
            sort_columns.append("subgroup_value")
    else:
        if "subgroup_value" in result.columns:
            sort_columns.append("subgroup_value")
        if "subgroup_name" in result.columns and context.subgroup_by:
            result = result.with_columns(
                pl.col("subgroup_name")
                .replace(subgroup_order_map)
                .fill_null(len(subgroup_order_map))
                .cast(pl.Int32)
                .alias("__subgroup_name_order")
            )
            temp_sort_columns.append("__subgroup_name_order")
            sort_columns.insert(0, "__subgroup_name_order")
        sort_columns.extend([col for col in group_labels if col in result.columns])

    if "estimate" in result.columns:
        sort_columns.append("estimate")

    if sort_columns:
        result = result.sort(sort_columns)
    if temp_sort_columns:
        result = result.drop(temp_sort_columns)

    seen: set[str] = set()
    deduped: list[str] = []
    for col in ordered:
        if col not in seen:
            deduped.append(col)
            seen.add(col)

    return result.select(deduped)

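# Usage sketch (hypothetical `long_df`/`context`; any `column_order_by` other
# than "metrics" sorts estimate-major):
#
#     wide = build_group_pivot(
#         long_df, context, column_order_by="metrics", row_order_by="group"
#     )
#     # Default-scope columns arrive from the pivot named like
#     # '{"model_a","accuracy"}': estimate label first, metric label second.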

def build_model_pivot(
    long_df: pl.DataFrame,
    context: FormatterContext,
    *,
    column_order_by: str,
    row_order_by: str,
) -> pl.DataFrame:
    """Pivot results with models as rows and group × metric as columns."""

    subgroup_present = (
        "subgroup_name" in long_df.columns and "subgroup_value" in long_df.columns
    )

    if row_order_by == "subgroup" and subgroup_present:
        index_cols = ["estimate", "subgroup_name", "subgroup_value"]
    else:
        index_cols = ["estimate"] + (
            ["subgroup_name", "subgroup_value"] if subgroup_present else []
        )

    if "estimate" in index_cols:
        estimate_series = (
            long_df.get_column("estimate") if "estimate" in long_df.columns else None
        )
        if estimate_series is None or estimate_series.is_null().all():
            index_cols = [col for col in index_cols if col != "estimate"]

    result, sections = _build_pivot_table(
        long_df,
        index_cols=index_cols,
        default_on=[*context.group_by.values(), "label"],
        scoped_on=[
            ("global", ["label"]),
            ("group", [*context.group_by.values(), "label"]),
        ],
    )

    section_lookup = {name: cols for name, cols in sections}

    group_labels = list(context.group_by.values())
    group_label_count: int = len(group_labels)
    group_value_orders: list[dict[Any, int]] = []

    if group_label_count:
        for label in group_labels:
            if label not in long_df.columns:
                group_value_orders.append({})
                continue

            series = long_df.get_column(label)
            dtype = series.dtype

            if isinstance(dtype, pl.Enum):
                categories = dtype.categories.to_list()
            else:
                categories = sorted(series.drop_nulls().unique().to_list())

            group_value_orders.append(
                {value: idx for idx, value in enumerate(categories)}
            )

    metric_label_order_lookup: Mapping[str, int] = context.metric_catalog.label_order
    metric_name_order_lookup: Mapping[str, int] = context.metric_catalog.name_order
    estimate_label_map = context.estimate_catalog.key_to_label

    def metric_order(label: str) -> int:
        if label in metric_label_order_lookup:
            return metric_label_order_lookup[label]
        return metric_name_order_lookup.get(label, len(metric_label_order_lookup))

    def group_order(tokens: tuple[str, ...]) -> tuple[int, ...]:
        if not group_label_count:
            return tuple()
        values = tokens[:group_label_count]
        order_positions: list[int] = []
        for idx, value in enumerate(values):
            mapping = group_value_orders[idx] if idx < len(group_value_orders) else {}
            order_positions.append(mapping.get(value, len(mapping)))
        return tuple(order_positions)

    def column_sort_key(column: str) -> tuple[Any, ...]:
        tokens = parse_json_tokens(column)
        if tokens is None:
            return (float("inf"), column)
        metric_label = tokens[-1] if tokens else ""
        metric_idx = metric_order(metric_label)
        group_idx = group_order(tokens)
        if column_order_by == "metrics":
            return (metric_idx, group_idx, tokens)
        return (group_idx, metric_idx, tokens)

    if "group" in section_lookup:
        section_lookup["group"] = sorted(section_lookup["group"], key=column_sort_key)

    if "default" in section_lookup:
        section_lookup["default"] = sorted(
            section_lookup["default"], key=column_sort_key
        )

    if "estimate" in result.columns:
        result = result.with_columns(
            pl.col("estimate")
            .map_elements(
                lambda val, mapping=estimate_label_map: mapping.get(val, val),
                return_dtype=pl.Utf8,
            )
            .alias("estimate_label")
        )

    if "subgroup_value" in result.columns:
        if context.subgroup_categories and all(
            isinstance(cat, str) for cat in context.subgroup_categories
        ):
            result = result.with_columns(
                pl.col("subgroup_value").cast(
                    pl.Enum(list(context.subgroup_categories))
                )
            )
        else:
            result = result.with_columns(pl.col("subgroup_value").cast(pl.Utf8))

    ordered = (
        [col for col in index_cols if col in result.columns]
        + [col for col in section_lookup.get("global", []) if col in result.columns]
        + [col for col in section_lookup.get("group", []) if col in result.columns]
        + [col for col in section_lookup.get("default", []) if col in result.columns]
    )

    remaining = [col for col in result.columns if col not in ordered]
    ordered.extend(remaining)

    result = result.select(ordered)

    sort_columns: list[str] = []
    if "subgroup_value" in result.columns:
        sort_columns.append("subgroup_value")
    if "estimate" in result.columns:
        sort_columns.append("estimate")
    for label in context.group_by.values():
        if label in result.columns:
            sort_columns.append(label)
    if sort_columns:
        result = result.sort(sort_columns)

    return result

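# Usage sketch (same hypothetical inputs as above; rows are keyed by estimate
# and subgroup, and group values become column tokens). Any value other than
# "metrics" for `column_order_by` yields group-major column order:
#
#     wide = build_model_pivot(
#         long_df, context, column_order_by="metrics", row_order_by="subgroup"
#     )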

def _expr_groups(schema: pl.Schema, group_by: Mapping[str, str]) -> pl.Expr:
    group_cols = [col for col in group_by.keys() if col in schema.names()]
    if not group_cols:
        return pl.lit(None).alias("groups")

    dtype = pl.Struct([pl.Field(col, schema[col]) for col in group_cols])
    return (
        pl.when(pl.all_horizontal([pl.col(col).is_null() for col in group_cols]))
        .then(pl.lit(None, dtype=dtype))
        .otherwise(pl.struct([pl.col(col).alias(col) for col in group_cols]))
        .alias("groups")
    )

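# Minimal sketch of the all-null guard above, on hypothetical data: rows where
# every group column is null collapse to a null struct instead of {null, null}.
#
#     df = pl.DataFrame({"region": ["east", None], "site": ["a", None]})
#     dtype = pl.Struct([pl.Field("region", pl.Utf8), pl.Field("site", pl.Utf8)])
#     df.with_columns(
#         pl.when(
#             pl.all_horizontal(pl.col("region").is_null(), pl.col("site").is_null())
#         )
#         .then(pl.lit(None, dtype=dtype))
#         .otherwise(pl.struct("region", "site"))
#         .alias("groups")
#     )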

def _expr_subgroups(schema: pl.Schema, subgroup_by: Mapping[str, str]) -> pl.Expr:
    if (
        not subgroup_by
        or "subgroup_name" not in schema.names()
        or "subgroup_value" not in schema.names()
    ):
        return pl.lit(None).alias("subgroups")

    labels = list(subgroup_by.values())
    dtype = pl.Struct([pl.Field(label, pl.Utf8) for label in labels])
    fields = [
        pl.when(pl.col("subgroup_name") == pl.lit(label))
        .then(pl.col("subgroup_value").cast(pl.Utf8))
        .otherwise(pl.lit(None, dtype=pl.Utf8))
        .alias(label)
        for label in labels
    ]
    return (
        pl.when(pl.col("subgroup_name").is_null() | pl.col("subgroup_value").is_null())
        .then(pl.lit(None, dtype=dtype))
        .otherwise(pl.struct(fields))
        .alias("subgroups")
    )


def _expr_estimate(schema: pl.Schema, estimate_catalog: EstimateCatalog) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    if "estimate" not in schema.names():
        return null_utf8.alias("estimate")

    estimate_names = list(estimate_catalog.keys)
    if estimate_names:
        return (
            pl.col("estimate")
            .cast(pl.Utf8)
            .replace({name: name for name in estimate_names})
            .cast(pl.Enum(estimate_names))
            .alias("estimate")
        )

    return pl.col("estimate").cast(pl.Utf8).alias("estimate")


def _expr_metric_enum(metric_names: Sequence[str]) -> pl.Expr:
    metric_categories = list(dict.fromkeys(metric_names))
    return (
        pl.col("metric")
        .cast(pl.Utf8)
        .replace({name: name for name in metric_categories})
        .cast(pl.Enum(metric_categories))
        .alias("metric")
    )

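# Casting to `pl.Enum` pins category order, so later sorts follow the metric
# catalog rather than alphabetical order. Sketch with hypothetical names:
#
#     s = pl.Series(["mae", "rmse"]).cast(pl.Enum(["rmse", "mae"]))
#     s.sort()  # "rmse" first: declared category order, not alphabetical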

def _expr_label_enum(metric_labels: Sequence[str]) -> pl.Expr:
    unique_labels = list(dict.fromkeys(metric_labels))
    return pl.col("label").cast(pl.Enum(unique_labels)).alias("label")


def _expr_stat_struct(schema: pl.Schema) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    null_float = pl.lit(None, dtype=pl.Float64)
    null_int = pl.lit(None, dtype=pl.Int64)
    null_bool = pl.lit(None, dtype=pl.Boolean)
    null_struct_expr = pl.lit(None, dtype=pl.Struct([]))

    kind_expr = pl.col("_value_kind") if "_value_kind" in schema.names() else None
    format_col = (
        pl.col("_value_format") if "_value_format" in schema.names() else null_utf8
    )

    float_value = (
        pl.col("_value_float") if "_value_float" in schema.names() else null_float
    )
    int_value = pl.col("_value_int") if "_value_int" in schema.names() else null_int
    bool_value = pl.col("_value_bool") if "_value_bool" in schema.names() else null_bool
    string_value = pl.col("_value_str") if "_value_str" in schema.names() else null_utf8
    struct_value = (
        pl.col("_value_struct")
        if "_value_struct" in schema.names()
        else null_struct_expr
    )

    inferred_kind = "float"
    if kind_expr is None:
        value_dtype = schema.get("value") if "value" in schema.names() else None
        inferred_kind = _infer_value_kind_from_dtype(value_dtype)
        kind_expr = pl.lit(inferred_kind, dtype=pl.Utf8)

    type_label = pl.when(kind_expr.is_null()).then(null_utf8).otherwise(kind_expr)

    return pl.struct(
        [
            type_label.alias("type"),
            float_value.alias("value_float"),
            int_value.alias("value_int"),
            bool_value.alias("value_bool"),
            string_value.alias("value_str"),
            struct_value.alias("value_struct"),
            format_col.alias("format"),
        ]
    ).alias("stat")


def _expr_context_struct(schema: pl.Schema, context: FormatterContext) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    fields = []
    for field in ("metric_type", "scope", "label"):
        if field in schema.names():
            fields.append(pl.col(field).cast(pl.Utf8).alias(field))
        else:
            fields.append(null_utf8.alias(field))
    if "estimate" in schema.names():
        label_map = context.estimate_catalog.key_to_label
        fields.append(
            pl.col("estimate")
            .cast(pl.Utf8)
            .map_elements(
                lambda val, mapping=label_map: mapping.get(val, val),
                return_dtype=pl.Utf8,
            )
            .alias("estimate_label")
        )
    else:
        fields.append(null_utf8.alias("estimate_label"))
    return pl.struct(fields).alias("context")


def _infer_value_kind_from_dtype(dtype: pl.DataType | None) -> str:
    if dtype is None or dtype == pl.Null:
        return "float"
    if dtype == pl.Struct:
        return "struct"
    if dtype == pl.Boolean:
        return "bool"
    if dtype == pl.Utf8:
        return "string"
    if hasattr(dtype, "is_numeric") and dtype.is_numeric():
        return "int" if dtype.is_integer() else "float"
    return "string"

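# Spot checks, following the branches above:
#
#     _infer_value_kind_from_dtype(None)        # -> "float"
#     _infer_value_kind_from_dtype(pl.Boolean)  # -> "bool"
#     _infer_value_kind_from_dtype(pl.Int64)    # -> "int"
#     _infer_value_kind_from_dtype(pl.Float32)  # -> "float"
#     _infer_value_kind_from_dtype(pl.Utf8)     # -> "string"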

def _pivot_frame(
    df: pl.DataFrame,
    *,
    index_cols: Sequence[str],
    on_cols: Sequence[str],
) -> pl.DataFrame:
    if df.is_empty():
        if index_cols:
            return pl.DataFrame({col: [] for col in index_cols})
        return pl.DataFrame()

    if index_cols:
        return df.pivot(
            index=list(index_cols),
            on=list(on_cols),
            values="value",
            aggregate_function="first",
        )

    with_idx = df.with_row_index("_idx")
    return with_idx.pivot(
        index=["_idx"],
        on=list(on_cols),
        values="value",
        aggregate_function="first",
    ).drop("_idx")

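# Pivoting on multiple `on` columns makes polars emit JSON-style column names,
# which the '{"' prefix checks and `parse_json_tokens` rely on. Sketch with
# hypothetical data:
#
#     df = pl.DataFrame(
#         {"g": ["x", "x"], "est": ["m1", "m2"], "label": ["acc", "acc"],
#          "value": [0.9, 0.8]}
#     )
#     _pivot_frame(df, index_cols=["g"], on_cols=["est", "label"])
#     # -> columns: g, '{"m1","acc"}', '{"m2","acc"}'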

def _merge_pivot_frames(
    base: pl.DataFrame,
    candidate: pl.DataFrame,
    index_cols: Sequence[str],
) -> pl.DataFrame:
    if base.is_empty():
        return candidate
    if candidate.is_empty():
        return base

    if not index_cols:
        return pl.concat([base, candidate], how="horizontal")

    if candidate.height == 1:
        broadcast_cols = [col for col in candidate.columns if col not in index_cols]
        if not broadcast_cols:
            return base
        row_values = candidate.row(0, named=True)
        return base.with_columns(
            [pl.lit(row_values[col]).alias(col) for col in broadcast_cols]
        )

    join_index_cols = list(index_cols)
    all_null_cols: list[str] = []
    for col in index_cols:
        if col in candidate.columns:
            column = candidate.get_column(col)
            if column.null_count() == candidate.height:
                all_null_cols.append(col)
    if all_null_cols:
        join_index_cols = [col for col in join_index_cols if col not in all_null_cols]
        candidate = candidate.drop(all_null_cols)

    if not join_index_cols:
        value_cols = [col for col in candidate.columns if col not in index_cols]
        if not value_cols:
            return base
        candidate_unique = candidate.select(value_cols).unique()
        if candidate_unique.height == 0:
            return base
        if candidate_unique.height > 1:
            candidate_unique = candidate_unique.head(1)
        row_values = candidate_unique.row(0, named=True)
        return base.with_columns(
            [pl.lit(row_values[col]).alias(col) for col in row_values]
        )

    return base.join(candidate, on=join_index_cols, how="left")

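# Broadcast sketch: a one-row candidate (e.g. a global summary) is copied onto
# every base row instead of joined. Hypothetical frames:
#
#     base = pl.DataFrame({"g": ["x", "y"], "m": [1.0, 2.0]})
#     candidate = pl.DataFrame({"g": [None], "total": [3.0]})
#     _merge_pivot_frames(base, candidate, index_cols=["g"])
#     # -> both rows carry total = 3.0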

def _build_pivot_table(
    long_df: pl.DataFrame,
    *,
    index_cols: Sequence[str],
    default_on: Sequence[str],
    scoped_on: Sequence[tuple[str, Sequence[str]]],
) -> tuple[pl.DataFrame, list[tuple[str, list[str]]]]:
    default_df = long_df.filter(pl.col("scope").is_null())
    pivot = _pivot_frame(default_df, index_cols=index_cols, on_cols=default_on)
    sections: list[tuple[str, list[str]]] = [
        ("default", [col for col in pivot.columns if col not in index_cols])
    ]

    for scope_name, on_cols in scoped_on:
        scoped_df = long_df.filter(pl.col("scope") == scope_name)
        scoped_pivot = _pivot_frame(scoped_df, index_cols=index_cols, on_cols=on_cols)
        if scoped_pivot.is_empty():
            continue
        sections.append(
            (scope_name, [col for col in scoped_pivot.columns if col not in index_cols])
        )
        pivot = _merge_pivot_frames(pivot, scoped_pivot, index_cols)

    return pivot, sections

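# The returned `sections` records which columns each scope contributed, e.g.
# [("default", ['{"m1","acc"}']), ("global", ["n_obs"])] (hypothetical names);
# the pivot builders use it to keep each scope's columns contiguous.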

def format_verbose_frame(ard: ARD) -> pl.DataFrame:
    """Render an ARD as a fully expanded DataFrame suitable for inspection."""

    long_df = ard.to_long()
    group_sort_cols = list(ard._group_fields)
    subgroup_struct_cols = list(ard._subgroup_fields)
    sort_cols: list[str] = []

    for col in group_sort_cols:
        if col in long_df.columns:
            sort_cols.append(col)

    for col in ("subgroup_name", "subgroup_value"):
        if col in long_df.columns:
            sort_cols.append(col)

    for col in subgroup_struct_cols:
        if col in long_df.columns and col not in sort_cols:
            sort_cols.append(col)

    for col in ("metric", "estimate"):
        if col in long_df.columns:
            sort_cols.append(col)

    if sort_cols:
        long_df = long_df.sort(sort_cols)

    preferred_order = [
        "id",
        "groups",
        "subgroups",
        "subgroup_name",
        "subgroup_value",
        "estimate",
        "metric",
        "label",
        "value",
        "stat",
        "stat_fmt",
        "context",
        "warning",
        "error",
    ]
    ordered_columns = [col for col in preferred_order if col in long_df.columns]
    remaining_columns = [col for col in long_df.columns if col not in ordered_columns]
    return long_df.select(ordered_columns + remaining_columns)


def format_compact_frame(ard: ARD) -> pl.DataFrame:
    """Render an ARD as a compact DataFrame with struct columns flattened."""

    verbose_df = format_verbose_frame(ard)
    flattened = _flatten_struct_columns(
        verbose_df,
        group_fields=ard._group_fields,
        subgroup_fields=ard._subgroup_fields,
    )
    detail_cols = [
        col
        for col in ("stat", "stat_fmt", "context", "warning", "error")
        if col in flattened.columns
    ]
    if detail_cols:
        flattened = flattened.drop(detail_cols)
    return flattened


def _flatten_struct_columns(
    df: pl.DataFrame,
    *,
    group_fields: Sequence[str],
    subgroup_fields: Sequence[str],
) -> pl.DataFrame:
    """Flatten struct columns for a compact DataFrame view."""

    working = df
    nullable_candidates: list[str] = []

    if "groups" in working.columns and group_fields:
        group_exprs = [
            pl.col("groups").struct.field(field).alias(field) for field in group_fields
        ]
        working = working.with_columns(group_exprs)
        nullable_candidates.extend(group_fields)

    if "subgroups" in working.columns and subgroup_fields:
        subgroup_exprs = [
            pl.col("subgroups").struct.field(field).alias(field)
            for field in subgroup_fields
        ]
        working = working.with_columns(subgroup_exprs)
        nullable_candidates.extend(subgroup_fields)

    drop_cols = [col for col in ("groups", "subgroups") if col in working.columns]
    if drop_cols:
        working = working.drop(drop_cols)

    nullable_candidates.append("id")

    def drop_all_null(df: pl.DataFrame, columns: Sequence[str]) -> pl.DataFrame:
        existing = [col for col in dict.fromkeys(columns) if col in df.columns]
        if not existing:
            return df
        result = df.select(
            [pl.col(col).is_not_null().any().alias(col) for col in existing]
        )
        has_values = result.row(0, named=True)
        to_drop = [col for col, flag in has_values.items() if not flag]
        if to_drop:
            return df.drop(to_drop)
        return df

    working = drop_all_null(working, nullable_candidates)

    return working