Coverage for src/polars_eval_metrics/table_formatter.py: 21%

77 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

1"""Table formatting utilities for polars-eval-metrics.""" 

2 

3from typing import Any, Callable 

4 

5import polars as pl 

6from great_tables import GT, html 

7from polars import selectors as cs 

8 

9from .ard import ARD 

10from .utils import parse_json_columns 

11 

12# pyre-strict 

13 

# Mapping from a wide-table column name to the tuple of string tokens that
# ``parse_json_columns`` produced for it. The formatting helpers in this
# module select the column label and the spanner label from this tuple by
# position (``column_label_index`` / ``spanner_index``).
ParsedColumns = dict[str, tuple[str, ...]]

15 

16 

def pivot_to_gt(df: pl.DataFrame, decimals: int = 1) -> GT:
    """Render a MetricEvaluator pivot table as a great_tables ``GT`` object.

    Columns recognised by ``parse_json_columns`` are labelled with their first
    token and grouped under a spanner named by their second token. When both
    ``subgroup_name`` and ``subgroup_value`` columns are present, they are
    used as the table stub. All other columns keep their own name as label.

    Args:
        df: Wide pivot table to render.
        decimals: Number of decimal places used for numeric columns.

    Returns:
        GT: The formatted table.
    """
    return _gt_from_wide(
        df,
        decimals=decimals,
        column_label_index=0,  # first token -> column label
        spanner_index=1,  # second token -> spanner label
        non_json_formatter=_pivot_non_json_label,
        stub_builder=_pivot_stub_builder(df),
    )

29 

30 

def ard_to_wide(
    ard: ARD,
    index: list[str] | None = None,
    columns: list[str] | None = None,
    values: str = "stat",
    aggregate_fn: str = "first",
) -> pl.DataFrame:
    """Pivot ``ard`` to a wide DataFrame by delegating to ``ARD.to_wide``.

    Kept for backwards compatibility; new code can call ``ard.to_wide``
    directly. Note the parameter rename: ``aggregate_fn`` is forwarded to
    ``ARD.to_wide`` as its ``aggregate`` keyword.

    Args:
        ard: The analysis results dataset to pivot.
        index: Forwarded unchanged as ``index``.
        columns: Forwarded unchanged as ``columns``.
        values: Forwarded unchanged as ``values``.
        aggregate_fn: Forwarded as ``aggregate``.

    Returns:
        Whatever ``ard.to_wide`` returns for these arguments.
    """
    forwarded = {
        "index": index,
        "columns": columns,
        "values": values,
        "aggregate": aggregate_fn,
    }
    return ard.to_wide(**forwarded)

46 

47 

def ard_to_gt(ard: ARD, decimals: int = 1) -> GT:
    """
    Convert an ARD into a great_tables ``GT`` object for HTML display.

    The ARD is first pivoted to wide form with ``ard_to_wide``. Columns
    recognised by ``parse_json_columns`` are labelled with their first token
    and grouped under a spanner named by their second token. If no column is
    recognised, any wide column whose name equals a value of the ARD's
    ``metric`` column is labelled with that raw metric name.

    Args:
        ard: ARD data structure
        decimals: Number of decimal places for formatting numbers

    Returns:
        GT: Great Tables object for display
    """
    wide = ard_to_wide(ard)

    parsed = parse_json_columns(wide.columns)
    stub = _ard_stub_builder(ard, wide)

    labels: dict[str, Any] = {}
    if not parsed:
        # No spanner structure: keep metric-named columns labelled verbatim
        # (these labels override the title-casing done by _ard_non_json_label).
        metric_values = (
            ard.lazy.select("metric").unique().collect()["metric"].to_list()
        )
        wide_columns = wide.columns
        labels = {
            name: html(name) for name in metric_values if name in wide_columns
        }

    return _gt_from_wide(
        wide,
        decimals=decimals,
        column_label_index=0,
        spanner_index=1,
        non_json_formatter=_ard_non_json_label,
        stub_builder=stub,
        parsed_columns=parsed,
        extra_labels=labels or None,
    )

84 

85 

def _gt_from_wide(
    df: pl.DataFrame,
    *,
    decimals: int,
    column_label_index: int,
    spanner_index: int,
    non_json_formatter: Callable[[str], Any | None],
    stub_builder: Callable[[GT], GT] | None = None,
    parsed_columns: ParsedColumns | None = None,
    extra_labels: dict[str, Any] | None = None,
) -> GT:
    """Build a ``GT`` table from a wide DataFrame and apply shared formatting.

    Args:
        df: Wide table to render.
        decimals: Decimal places for numeric columns.
        column_label_index: Position in each parsed column's token tuple whose
            token becomes that column's label.
        spanner_index: Position of the token used as the spanner label.
        non_json_formatter: Maps a column that is not in ``parsed_columns`` to
            a label, or returns None to leave that column unrenamed.
        stub_builder: Optional callback applied to the freshly created ``GT``
            (used to configure a row stub).
        parsed_columns: Precomputed ``parse_json_columns(df.columns)`` result.
            Computed here only when None; an explicitly passed empty mapping
            is used as-is instead of being re-parsed.
        extra_labels: Labels applied last, overriding any other labels.

    Returns:
        GT: The formatted table.
    """
    # ``is None`` rather than truthiness: an empty mapping from the caller is
    # a valid, already-computed answer and should not trigger a re-parse.
    parsed = (
        parsed_columns
        if parsed_columns is not None
        else parse_json_columns(df.columns)
    )
    gt_table = GT(df)
    if stub_builder is not None:
        gt_table = stub_builder(gt_table)
    return _apply_gt_formatting(
        gt_table,
        df,
        parsed_columns=parsed,
        column_label_index=column_label_index,
        spanner_index=spanner_index,
        non_json_formatter=non_json_formatter,
        decimals=decimals,
        extra_labels=extra_labels,
    )

111 

112 

def _apply_gt_formatting(
    gt_table: GT,
    df: pl.DataFrame,
    *,
    parsed_columns: ParsedColumns,
    column_label_index: int,
    spanner_index: int,
    non_json_formatter: Callable[[str], Any | None],
    decimals: int,
    extra_labels: dict[str, Any] | None,
) -> GT:
    """Add spanners, column labels and numeric formatting to ``gt_table``.

    Args:
        gt_table: Table to decorate (returned as a new ``GT``, not mutated).
        df: The DataFrame ``gt_table`` was built from; its column order drives
            labelling of non-parsed columns.
        parsed_columns: Column name -> token tuple for recognised columns.
        column_label_index: Token position used as each parsed column's label.
        spanner_index: Token position used to group parsed columns under a
            spanner; columns with too few tokens get no spanner.
        non_json_formatter: Label for each column not in ``parsed_columns``,
            or None to leave it unrenamed.
        decimals: Decimal places applied to every numeric column.
        extra_labels: Labels applied last, overriding all other labels.

    Returns:
        GT: The table with spanners, labels, number formatting and centred
        numeric columns.
    """
    # Group parsed columns by spanner token in one pass; within each group,
    # columns keep the iteration order of ``parsed_columns``.
    columns_by_metric: dict[str, list[str]] = {}
    for col, tokens in parsed_columns.items():
        if len(tokens) > spanner_index:
            columns_by_metric.setdefault(tokens[spanner_index], []).append(col)

    # Spanners are added in sorted metric order. The order matters because
    # ``tab_spanner`` gathers spanned columns together by default.
    for metric in sorted(columns_by_metric):
        gt_table = gt_table.tab_spanner(
            label=html(metric), columns=columns_by_metric[metric]
        )

    column_renames: dict[str, Any] = {
        col: html(tokens[column_label_index])
        for col, tokens in parsed_columns.items()
        if len(tokens) > column_label_index
    }

    for col in df.columns:
        if col in parsed_columns:
            continue
        label = non_json_formatter(col)
        if label is not None:
            column_renames[col] = label

    if extra_labels:
        column_renames.update(extra_labels)

    if column_renames:
        gt_table = gt_table.cols_label(**column_renames)

    return gt_table.fmt_number(columns=cs.numeric(), decimals=decimals).cols_align(
        align="center", columns=cs.numeric()
    )

161 

162 

163def _pivot_stub_builder(df: pl.DataFrame) -> Callable[[GT], GT] | None: 

164 if {"subgroup_name", "subgroup_value"}.issubset(df.columns): 

165 return lambda gt: gt.tab_stub( 

166 rowname_col="subgroup_value", groupname_col="subgroup_name" 

167 ) 

168 return None 

169 

170 

171def _pivot_non_json_label(column: str) -> Any | None: 

172 if column in {"subgroup_name", "subgroup_value"}: 

173 return None 

174 return html(column) 

175 

176 

177def _ard_stub_builder(ard: ARD, df: pl.DataFrame) -> Callable[[GT], GT] | None: 

178 if "subgroups" not in df.columns or df["subgroups"].is_null().all(): 

179 return None 

180 

181 unnested = ard.unnest(["subgroups"]) 

182 subgroup_cols = [col for col in unnested.columns if col.startswith("subgroups.")] 

183 

184 if len(subgroup_cols) == 1: 

185 subgroup_col = subgroup_cols[0].replace("subgroups.", "") 

186 

187 if subgroup_col in df.columns: 

188 return lambda gt: gt.tab_stub(rowname_col=subgroup_col) 

189 

190 elif len(subgroup_cols) > 1: 

191 group_col = subgroup_cols[0].replace("subgroups.", "") 

192 row_col = subgroup_cols[1].replace("subgroups.", "") 

193 

194 if group_col in df.columns and row_col in df.columns: 

195 return lambda gt: gt.tab_stub(rowname_col=row_col, groupname_col=group_col) 

196 

197 return None 

198 

199 

200def _ard_non_json_label(column: str) -> Any | None: 

201 if column == "subgroups": 

202 return None 

203 return html(column.replace(".", " ").title())