Coverage for src/polars_eval_metrics/metric_registry.py: 89%

126 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

1""" 

2Unified Expression Registry System 

3 

4This module provides an extensible registry system for all types of expressions: 

5- Error expressions (for data preparation) 

6- Metric expressions (for aggregation) 

7- Summary expressions (for second-level aggregation) 

8 

9Supports both global (class-level) and local (instance-level) registries. 

10""" 

11 

12from dataclasses import dataclass 

13from typing import Any, Callable, ClassVar, cast 

14 

15# pyre-strict 

16 

17import polars as pl 

18 

19 

class MetricNotFoundError(ValueError):
    """Raised when a requested metric, error type, or summary is not registered.

    Attributes:
        name: The name that was looked up.
        available: Names currently registered for this expression type.
        expr_type: Which registry was searched ("error", "metric", "summary",
            or the generic default "expression").
    """

    def __init__(
        self, name: str, available: list[str], expr_type: str = "expression"
    ) -> None:
        self.name = name
        self.available = available
        self.expr_type = expr_type
        # Build the human-readable message in named steps for clarity.
        listing = ", ".join(available) if available else "none"
        message = (
            f"{expr_type.capitalize()} '{name}' not found. "
            f"Available {expr_type}s: {listing}"
        )
        super().__init__(message)

33 

34 

@dataclass(frozen=True)
class MetricInfo:
    """Immutable descriptor pairing a metric expression with rendering hints."""

    # Polars expression that computes the aggregated metric value.
    expr: pl.Expr
    # Type hint for the resulting value ("float", "int", "struct", ...).
    value_kind: str = "float"
    # Optional format string applied when rendering the value; None = raw.
    format: str | None = None

40 

41 

class MetricRegistry:
    """
    Unified registry for all expression types used in metric evaluation.

    This is a singleton-style class that provides global registration
    and retrieval of error types, metrics, and summary expressions.
    """

    # Class-level registries
    # NOTE(review): these dicts are class attributes, so every registration is
    # process-wide and shared by all users of this class.
    _errors: dict[str, Callable[..., pl.Expr]] = {}
    _metrics: dict[str, MetricInfo | Callable[[], MetricInfo]] = {}
    _summaries: dict[str, pl.Expr | Callable[[], pl.Expr]] = {}
    # Maps a registry "kind" keyword to the class attribute holding its store.
    _REGISTRY_ATTRS: ClassVar[dict[str, str]] = {
        "error": "_errors",
        "metric": "_metrics",
        "summary": "_summaries",
    }

    @classmethod
    def _registry_store(cls, kind: str) -> dict[str, Any]:
        """Return the backing dict for *kind* ('error', 'metric', or 'summary').

        Raises:
            ValueError: If *kind* is not a known registry kind.
        """
        try:
            attr_name = cls._REGISTRY_ATTRS[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown registry kind: {kind}") from exc
        return getattr(cls, attr_name)

    @classmethod
    def _registry_names(cls, kind: str) -> list[str]:
        """Return all registered names for *kind* in insertion order."""
        return list(cls._registry_store(kind).keys())

    @classmethod
    def _registry_contains(cls, kind: str, name: str) -> bool:
        """Return True if *name* is registered under *kind*."""
        return name in cls._registry_store(kind)

    # ============ Error Expression Methods ============

    @classmethod
    def register_error(cls, name: str, func: Callable[..., pl.Expr]) -> None:
        """
        Register a custom error expression function.

        Re-registering an existing name silently overwrites the previous entry.

        Args:
            name: Name of the error type (e.g., 'absolute_error', 'buffer_error')
            func: Function that takes (estimate, ground_truth, *args) and returns pl.Expr

        Example:
            def buffer_error(estimate: str, ground_truth: str, threshold: float = 0.5):
                return (pl.col(estimate) - pl.col(ground_truth)).abs() <= threshold

            MetricRegistry.register_error('buffer_error', buffer_error)
        """
        cls._errors[name] = func

    @classmethod
    def get_error(
        cls,
        name: str,
        estimate: str,
        ground_truth: str,
        **params: Any,
    ) -> pl.Expr:
        """
        Get an error expression by name.

        Args:
            name: Name of the error type
            estimate: Estimate column name
            ground_truth: Ground truth column name
            **params: Additional parameters for parameterized error functions

        Returns:
            Polars expression that computes the error

        Raises:
            MetricNotFoundError: If the error type is not registered
        """
        if not cls._registry_contains("error", name):
            raise MetricNotFoundError(name, cls.list_errors(), "error")

        func = cls._registry_store("error")[name]
        return func(estimate, ground_truth, **params)

    @classmethod
    def generate_error_columns(
        cls,
        estimate: str,
        ground_truth: str,
        error_types: list[str] | None = None,
        error_params: dict[str, dict[str, Any]] | None = None,
    ) -> list[pl.Expr]:
        """Generate error column expressions for specified error types.

        Args:
            estimate: Estimate column name
            ground_truth: Ground truth column name
            error_types: Error types to generate; defaults to every registered type
            error_params: Optional per-error-type keyword parameters

        Returns:
            One expression per error type, aliased to the error type's name

        Raises:
            MetricNotFoundError: If any requested error type is not registered
        """
        error_types = error_types or cls.list_errors()
        error_params = error_params or {}

        return [
            cls.get_error(
                error_type, estimate, ground_truth, **error_params.get(error_type, {})
            ).alias(error_type)
            for error_type in error_types
        ]

    @classmethod
    def list_errors(cls) -> list[str]:
        """List all available error types."""
        return cls._registry_names("error")

    @classmethod
    def list_metrics(cls) -> list[str]:
        """List all available metrics."""
        return cls._registry_names("metric")

    @classmethod
    def list_summaries(cls) -> list[str]:
        """List all available summaries."""
        return cls._registry_names("summary")

    @classmethod
    def has_metric(cls, name: str) -> bool:
        """Check if a metric exists in the registry."""
        return cls._registry_contains("metric", name)

    @classmethod
    def has_summary(cls, name: str) -> bool:
        """Check if a summary/selector exists in the registry."""
        return cls._registry_contains("summary", name)

    @classmethod
    def has_error(cls, name: str) -> bool:
        """Check if an error type exists in the registry."""
        return cls._registry_contains("error", name)

    # ============ Metric Expression Methods ============

    @classmethod
    def register_metric(
        cls,
        name: str,
        expr: pl.Expr | Callable[[], pl.Expr] | MetricInfo | Callable[[], MetricInfo],
        *,
        value_kind: str | None = None,
        format: str | None = None,
    ) -> None:
        """
        Register a custom metric expression.

        Args:
            name: Name of the metric (e.g., 'mae', 'custom_accuracy')
            expr: Polars expression/MetricInfo or callable returning one.
                Expressions reference error columns and define the aggregated result.
            value_kind: Optional type hint (``float``, ``int``, ``struct`` ...) used to
                populate the ``stat`` struct. Defaults to ``float``. Ignored when
                *expr* is (or resolves to) a MetricInfo, whose own settings win.
            format: Optional string formatter applied when rendering ``value``.
                Ignored under the same condition as ``value_kind``.

        Example:
            MetricRegistry.register_metric('mae', pl.col('absolute_error').mean().alias('value'))
        """
        # A ready-made MetricInfo is stored as-is.
        if isinstance(expr, MetricInfo):
            cls._metrics[name] = expr
            return

        if callable(expr):
            callable_expr: Callable[[], MetricInfo | pl.Expr] = cast(
                Callable[[], MetricInfo | pl.Expr], expr
            )

            # Wrap the callable so retrieval always yields a MetricInfo,
            # applying value_kind/format only when the callable returns a
            # bare expression.
            def factory() -> MetricInfo:
                result = callable_expr()
                if isinstance(result, MetricInfo):
                    return result
                return MetricInfo(
                    expr=result,
                    value_kind=value_kind or "float",
                    format=format,
                )

            cls._metrics[name] = factory
            return

        # Bare expression: wrap it in a MetricInfo carrying the given hints.
        info = MetricInfo(
            expr=expr,
            value_kind=value_kind or "float",
            format=format,
        )
        cls._metrics[name] = info

    @classmethod
    def register_summary(cls, name: str, expr: pl.Expr | Callable[[], pl.Expr]) -> None:
        """
        Register a custom summary expression.

        Args:
            name: Name of the summary (e.g., 'mean', 'p90')
            expr: Polars expression or callable that returns a Polars expression
                The expression should typically operate on 'value' column

        Example:
            MetricRegistry.register_summary('p90', pl.col('value').quantile(0.9, interpolation="linear"))
        """
        cls._summaries[name] = expr

    @classmethod
    def get_metric(cls, name: str) -> MetricInfo:
        """
        Get a metric by name.

        Args:
            name: Name of the metric

        Returns:
            MetricInfo describing the metric. Registered callables are resolved
            here; a bare expression result is wrapped in a default MetricInfo.

        Raises:
            MetricNotFoundError: If the metric is not registered
        """
        if not cls._registry_contains("metric", name):
            raise MetricNotFoundError(name, cls.list_metrics(), "metric")

        expr = cls._registry_store("metric")[name]
        # If it's a callable, call it to get the expression/info
        if callable(expr):
            expr = expr()
        if not isinstance(expr, MetricInfo):
            expr = MetricInfo(expr=expr)
        return expr

    @classmethod
    def get_summary(cls, name: str) -> pl.Expr:
        """
        Get a summary expression by name.

        Args:
            name: Name of the summary

        Returns:
            Polars expression for the summary

        Raises:
            MetricNotFoundError: If the summary is not registered
            TypeError: If a registered callable resolves to a non-expression
        """
        if not cls._registry_contains("summary", name):
            raise MetricNotFoundError(name, cls.list_summaries(), "summary")

        summary_entry = cls._registry_store("summary")[name]
        result = summary_entry() if callable(summary_entry) else summary_entry
        if not isinstance(result, pl.Expr):
            raise TypeError(
                "Summary registry returned a non-expression object; "
                "ensure summaries resolve to pl.Expr."
            )
        return result

289 

290 # ============ Built-in Error Expression Functions ============ 

291 

292 

def _error(estimate: str, ground_truth: str) -> pl.Expr:
    """Signed error: estimate - ground_truth."""
    est, truth = pl.col(estimate), pl.col(ground_truth)
    return est - truth

296 

297 

def _absolute_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Absolute error: |estimate - ground_truth|."""
    return (pl.col(estimate) - pl.col(ground_truth)).abs()

302 

303 

def _squared_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Squared error: (estimate - ground_truth)^2."""
    return (pl.col(estimate) - pl.col(ground_truth)) ** 2

308 

309 

def _percent_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Percent error: (estimate - ground_truth) / ground_truth * 100.

    Yields null where ground_truth is zero to avoid division by zero.
    """
    truth = pl.col(ground_truth)
    relative = (pl.col(estimate) - truth) / truth * 100
    return pl.when(truth != 0).then(relative).otherwise(None)

318 

319 

def _absolute_percent_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Absolute percent error: |(estimate - ground_truth) / ground_truth| * 100.

    Yields null where ground_truth is zero to avoid division by zero.
    """
    truth = pl.col(ground_truth)
    abs_relative = ((pl.col(estimate) - truth) / truth * 100).abs()
    return pl.when(truth != 0).then(abs_relative).otherwise(None)

328 

329 

330# ============ Register All Built-in Expressions Globally ============ 

331 

# Built-in error expression functions, registered globally at import time.
_DEFAULT_ERRORS: dict[str, Callable[..., pl.Expr]] = {
    "error": _error,
    "absolute_error": _absolute_error,
    "squared_error": _squared_error,
    "percent_error": _percent_error,
    "absolute_percent_error": _absolute_percent_error,
}

for _error_name, _error_func in _DEFAULT_ERRORS.items():
    MetricRegistry.register_error(_error_name, _error_func)

342 

# Each entry: (metric name, aggregation expression over prepared error
# columns, keyword options forwarded to MetricRegistry.register_metric).
_DEFAULT_METRICS: tuple[tuple[str, pl.Expr, dict[str, Any]], ...] = (
    # Mean-error family, aggregating the generated error columns.
    ("me", pl.col("error").mean(), {}),
    ("mae", pl.col("absolute_error").mean(), {}),
    ("mse", pl.col("squared_error").mean(), {}),
    ("rmse", pl.col("squared_error").mean().sqrt(), {}),
    ("mpe", pl.col("percent_error").mean(), {}),
    ("mape", pl.col("absolute_percent_error").mean(), {}),
    # Count metrics: unique subjects / (subject, visit) pairs / samples.
    ("n_subject", pl.col("subject_id").n_unique(), {"value_kind": "int"}),
    (
        "n_visit",
        pl.struct(["subject_id", "visit_id"]).n_unique(),
        {"value_kind": "int"},
    ),
    (
        "n_sample",
        pl.col("sample_index").n_unique(),
        {"value_kind": "int"},
    ),
    # Coverage metrics: counts and percentages restricted to rows where the
    # "error" column is non-null (i.e. rows that actually have data).
    (
        "n_subject_with_data",
        pl.col("subject_id").filter(pl.col("error").is_not_null()).n_unique(),
        {"value_kind": "int"},
    ),
    (
        "pct_subject_with_data",
        (
            pl.col("subject_id").filter(pl.col("error").is_not_null()).n_unique()
            / pl.col("subject_id").n_unique()
            * 100
        ).alias("value"),
        {},
    ),
    (
        "n_visit_with_data",
        pl.struct(["subject_id", "visit_id"])
        .filter(pl.col("error").is_not_null())
        .n_unique(),
        {"value_kind": "int"},
    ),
    (
        "pct_visit_with_data",
        (
            pl.struct(["subject_id", "visit_id"])
            .filter(pl.col("error").is_not_null())
            .n_unique()
            / pl.struct(["subject_id", "visit_id"]).n_unique()
            * 100
        ).alias("value"),
        {},
    ),
    (
        "n_sample_with_data",
        pl.col("error").is_not_null().sum(),
        {"value_kind": "int"},
    ),
    (
        "pct_sample_with_data",
        (pl.col("error").is_not_null().mean() * 100).alias("value"),
        {},
    ),
)

# Register the built-in metrics globally at import time.
for _name, _expr, _options in _DEFAULT_METRICS:
    MetricRegistry.register_metric(_name, _expr, **_options)

407 

# Built-in summary expressions over the first-level "value" column,
# registered globally at import time.
_DEFAULT_SUMMARIES: dict[str, pl.Expr] = {
    "mean": pl.col("value").mean(),
    "median": pl.col("value").median(),
    "std": pl.col("value").std(),
    "min": pl.col("value").min(),
    "max": pl.col("value").max(),
    "sum": pl.col("value").sum(),
    "sqrt": pl.col("value").sqrt(),
    # Percentile selectors: p1, p5, p25, p75, p90, p95, p99.
    **{
        f"p{_level}": pl.col("value").quantile(_level / 100, interpolation="linear")
        for _level in (1, 5, 25, 75, 90, 95, 99)
    },
}

for _summary_name, _summary_expr in _DEFAULT_SUMMARIES.items():
    MetricRegistry.register_summary(_summary_name, _summary_expr)