Coverage for src/polars_eval_metrics/table_formatter.py: 21%
77 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
1"""Table formatting utilities for polars-eval-metrics."""
3from typing import Any, Callable
5import polars as pl
6from great_tables import GT, html
7from polars import selectors as cs
9from .ard import ARD
10from .utils import parse_json_columns
12# pyre-strict
# Alias for a mapping from a column name to the tuple of string tokens parsed
# from that name (presumably the output of parse_json_columns -- confirm in .utils).
14ParsedColumns = dict[str, tuple[str, ...]]
def pivot_to_gt(df: pl.DataFrame, decimals: int = 1) -> GT:
    """Render a MetricEvaluator pivot table as a great_tables ``GT`` object.

    Args:
        df: Wide pivot table to render.
        decimals: Number of decimal places used when formatting numeric columns.

    Returns:
        The formatted ``GT`` table.
    """
    # Token 0 of each parsed column name is used as the column label and
    # token 1 as the spanner label.
    formatting_options: dict[str, Any] = {
        "decimals": decimals,
        "column_label_index": 0,
        "spanner_index": 1,
        "non_json_formatter": _pivot_non_json_label,
        "stub_builder": _pivot_stub_builder(df),
    }
    return _gt_from_wide(df, **formatting_options)
def ard_to_wide(
    ard: ARD,
    index: list[str] | None = None,
    columns: list[str] | None = None,
    values: str = "stat",
    aggregate_fn: str = "first",
) -> pl.DataFrame:
    """Pivot an ARD to wide format by delegating to ``ARD.to_wide``.

    Retained as a thin wrapper for backwards compatibility.

    Args:
        ard: Object whose ``to_wide`` method performs the pivot.
        index: Forwarded unchanged as ``index``.
        columns: Forwarded unchanged as ``columns``.
        values: Forwarded unchanged as ``values``.
        aggregate_fn: Forwarded to ``ARD.to_wide`` under the keyword ``aggregate``.

    Returns:
        The result of ``ard.to_wide(...)``.
    """
    # Note the keyword rename: this wrapper's ``aggregate_fn`` is ``aggregate`` on ARD.
    forwarded = {
        "index": index,
        "columns": columns,
        "values": values,
        "aggregate": aggregate_fn,
    }
    wide = ard.to_wide(**forwarded)
    return wide
def ard_to_gt(ard: ARD, decimals: int = 1) -> GT:
    """Convert an ARD into a Great Tables object for HTML display.

    The ARD is first pivoted to wide format. If none of the wide column names
    parse as JSON columns, every wide column named after a metric present in
    the ARD is given an explicit HTML label.

    Args:
        ard: ARD data structure.
        decimals: Number of decimal places for formatting numbers.

    Returns:
        GT: Great Tables object for display.
    """
    wide = ard_to_wide(ard)
    json_columns = parse_json_columns(wide.columns)
    stub = _ard_stub_builder(ard, wide)

    metric_labels: dict[str, Any] = {}
    if not json_columns:
        distinct_metrics = ard.lazy.select("metric").unique().collect()
        metric_labels = {
            name: html(name)
            for name in distinct_metrics["metric"].to_list()
            if name in wide.columns
        }

    return _gt_from_wide(
        wide,
        decimals=decimals,
        column_label_index=0,
        spanner_index=1,
        non_json_formatter=_ard_non_json_label,
        stub_builder=stub,
        parsed_columns=json_columns,
        # An empty label mapping is passed on as None.
        extra_labels=metric_labels or None,
    )
def _gt_from_wide(
    df: pl.DataFrame,
    *,
    decimals: int,
    column_label_index: int,
    spanner_index: int,
    non_json_formatter: Callable[[str], Any | None],
    stub_builder: Callable[[GT], GT] | None = None,
    parsed_columns: ParsedColumns | None = None,
    extra_labels: dict[str, Any] | None = None,
) -> GT:
    """Build a ``GT`` table from a wide frame and apply the shared formatting.

    Args:
        df: Wide data frame to render.
        decimals: Number of decimal places for numeric columns.
        column_label_index: Token position used for column labels.
        spanner_index: Token position used for spanner labels.
        non_json_formatter: Produces a label (or None) for columns that are
            not in the parsed-column mapping.
        stub_builder: Optional callable applied to the fresh ``GT`` table
            before formatting.
        parsed_columns: Pre-parsed column tokens. When None or empty, the
            column names of ``df`` are parsed here instead.
        extra_labels: Optional additional column labels, passed through.

    Returns:
        The formatted ``GT`` table.
    """
    # A falsy mapping (None or empty) triggers a parse of df's column names.
    if parsed_columns:
        column_tokens = parsed_columns
    else:
        column_tokens = parse_json_columns(df.columns)

    table = GT(df)
    if stub_builder is not None:
        table = stub_builder(table)

    return _apply_gt_formatting(
        table,
        df,
        parsed_columns=column_tokens,
        column_label_index=column_label_index,
        spanner_index=spanner_index,
        non_json_formatter=non_json_formatter,
        decimals=decimals,
        extra_labels=extra_labels,
    )
def _apply_gt_formatting(
    gt_table: GT,
    df: pl.DataFrame,
    *,
    parsed_columns: ParsedColumns,
    column_label_index: int,
    spanner_index: int,
    non_json_formatter: Callable[[str], Any | None],
    decimals: int,
    extra_labels: dict[str, Any] | None,
) -> GT:
    """Add spanners, column labels and numeric formatting to a ``GT`` table.

    Args:
        gt_table: Table to decorate.
        df: Frame backing the table; its column names drive the labelling of
            columns that are not in ``parsed_columns``.
        parsed_columns: Mapping of column name to its parsed tokens.
        column_label_index: Token position used as the column label.
        spanner_index: Token position used as the spanner label.
        non_json_formatter: Returns a label for a column missing from
            ``parsed_columns``, or None to leave that column unlabelled.
        decimals: Number of decimal places for numeric columns.
        extra_labels: Optional labels applied last, overriding earlier ones.

    Returns:
        The decorated ``GT`` table.
    """
    # Group column names by their spanner token in a single pass; columns
    # whose token tuple is too short to have one are left out of every spanner.
    spanner_groups: dict[str, list[str]] = {}
    for name, tokens in parsed_columns.items():
        if len(tokens) > spanner_index:
            spanner_groups.setdefault(tokens[spanner_index], []).append(name)

    # Spanners are added in sorted label order.
    for spanner_label in sorted(spanner_groups):
        gt_table = gt_table.tab_spanner(
            label=html(spanner_label), columns=spanner_groups[spanner_label]
        )

    labels: dict[str, Any] = {
        name: html(tokens[column_label_index])
        for name, tokens in parsed_columns.items()
        if len(tokens) > column_label_index
    }

    for name in df.columns:
        if name not in parsed_columns:
            fallback = non_json_formatter(name)
            if fallback is not None:
                labels[name] = fallback

    if extra_labels:
        labels.update(extra_labels)

    if labels:
        gt_table = gt_table.cols_label(**labels)

    formatted = gt_table.fmt_number(columns=cs.numeric(), decimals=decimals)
    return formatted.cols_align(align="center", columns=cs.numeric())
def _pivot_stub_builder(df: pl.DataFrame) -> Callable[[GT], GT] | None:
    """Return a stub-configuring callable for pivot tables, if applicable.

    Args:
        df: Pivot table whose column names are inspected.

    Returns:
        A callable that sets ``subgroup_value`` as the row-name column and
        ``subgroup_name`` as the group-name column, or None when ``df`` does
        not contain both columns.
    """
    required = {"subgroup_name", "subgroup_value"}
    if not required.issubset(df.columns):
        return None

    def _attach_stub(gt: GT) -> GT:
        return gt.tab_stub(
            rowname_col="subgroup_value", groupname_col="subgroup_name"
        )

    return _attach_stub
def _pivot_non_json_label(column: str) -> Any | None:
    """Return the label for a pivot-table column that is not a parsed column.

    Args:
        column: Column name.

    Returns:
        None for ``subgroup_name`` and ``subgroup_value`` (no label is set for
        them), otherwise the column name wrapped with ``html``.
    """
    is_subgroup_column = column in {"subgroup_name", "subgroup_value"}
    return None if is_subgroup_column else html(column)
def _ard_stub_builder(ard: ARD, df: pl.DataFrame) -> Callable[[GT], GT] | None:
    """Return a stub-configuring callable derived from the ARD's subgroups.

    Args:
        ard: ARD whose ``subgroups`` field is unnested to discover the
            subgroup column names.
        df: Wide frame that will back the table; a stub is only configured
            from columns that exist in it.

    Returns:
        None when ``df`` has no ``subgroups`` column, when that column is
        entirely null, or when the discovered subgroup columns are missing
        from ``df``. With exactly one subgroup column, a callable that uses it
        as the row-name column. With two or more, a callable that uses the
        first as the group-name column and the second as the row-name column;
        any further subgroup columns are ignored.
    """
    if "subgroups" not in df.columns or df["subgroups"].is_null().all():
        return None

    prefix = "subgroups."
    available = set(df.columns)

    # Strip only the leading prefix. str.replace would also delete any later
    # occurrence of "subgroups." inside the name and so corrupt a column such
    # as "subgroups.my_subgroups.x".
    subgroup_names = [
        col.removeprefix(prefix)
        for col in ard.unnest(["subgroups"]).columns
        if col.startswith(prefix)
    ]

    if len(subgroup_names) == 1:
        (row_col,) = subgroup_names
        if row_col in available:
            return lambda gt: gt.tab_stub(rowname_col=row_col)
    elif len(subgroup_names) > 1:
        group_col, row_col = subgroup_names[:2]
        if group_col in available and row_col in available:
            return lambda gt: gt.tab_stub(
                rowname_col=row_col, groupname_col=group_col
            )

    return None
def _ard_non_json_label(column: str) -> Any | None:
    """Return the label for an ARD wide-table column that is not a parsed column.

    Args:
        column: Column name.

    Returns:
        None for the ``subgroups`` column; otherwise the name with dots
        replaced by spaces and title-cased, wrapped with ``html``.
    """
    if column != "subgroups":
        spaced = column.replace(".", " ")
        return html(spaced.title())
    return None