Coverage for src/polars_eval_metrics/metric_registry.py: 89%
126 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
1"""
2Unified Expression Registry System
4This module provides an extensible registry system for all types of expressions:
5- Error expressions (for data preparation)
6- Metric expressions (for aggregation)
7- Summary expressions (for second-level aggregation)
9Supports both global (class-level) and local (instance-level) registries.
10"""
12from dataclasses import dataclass
13from typing import Any, Callable, ClassVar, cast
15# pyre-strict
17import polars as pl
class MetricNotFoundError(ValueError):
    """Raised when a requested metric/error/summary name is not registered."""

    def __init__(
        self, name: str, available: list[str], expr_type: str = "expression"
    ) -> None:
        # Keep the lookup context on the exception for programmatic handling.
        self.name = name
        self.available = available
        self.expr_type = expr_type
        options = ", ".join(available) if available else "none"
        message = (
            f"{expr_type.capitalize()} '{name}' not found. "
            f"Available {expr_type}s: {options}"
        )
        super().__init__(message)
@dataclass(frozen=True)
class MetricInfo:
    """Immutable bundle describing a registered metric expression."""

    # Polars aggregation expression that computes the metric's value.
    expr: pl.Expr
    # Result-type hint ("float", "int", "struct", ...); defaults to "float".
    value_kind: str = "float"
    # Optional format string applied when rendering the value (None = raw).
    # NOTE: intentionally named after the builtin to match the register_metric API.
    format: str | None = None
class MetricRegistry:
    """
    Unified registry for all expression types used in metric evaluation.

    This is a singleton-style class that provides global registration
    and retrieval of error types, metrics, and summary expressions.
    All state lives at class level; instances are never required.
    """

    # Class-level registries, keyed by registered name (insertion order kept).
    _errors: ClassVar[dict[str, Callable[..., pl.Expr]]] = {}
    _metrics: ClassVar[dict[str, MetricInfo | Callable[[], MetricInfo]]] = {}
    _summaries: ClassVar[dict[str, pl.Expr | Callable[[], pl.Expr]]] = {}
    # Maps a registry "kind" keyword ("error"/"metric"/"summary") to the
    # class attribute that stores the corresponding registry.
    _REGISTRY_ATTRS: ClassVar[dict[str, str]] = {
        "error": "_errors",
        "metric": "_metrics",
        "summary": "_summaries",
    }

    @classmethod
    def _registry_store(cls, kind: str) -> dict[str, Any]:
        """Return the registry dict for *kind*; raise ValueError if unknown."""
        try:
            attr_name = cls._REGISTRY_ATTRS[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown registry kind: {kind}") from exc
        return getattr(cls, attr_name)

    @classmethod
    def _registry_names(cls, kind: str) -> list[str]:
        """Return the registered names for *kind* in registration order."""
        return list(cls._registry_store(kind).keys())

    @classmethod
    def _registry_contains(cls, kind: str, name: str) -> bool:
        """Return True if *name* is registered under *kind*."""
        return name in cls._registry_store(kind)

    # ============ Error Expression Methods ============

    @classmethod
    def register_error(cls, name: str, func: Callable[..., pl.Expr]) -> None:
        """
        Register a custom error expression function.

        Re-registering an existing name silently overwrites the previous entry.

        Args:
            name: Name of the error type (e.g., 'absolute_error', 'buffer_error')
            func: Function that takes (estimate, ground_truth, *args) and returns pl.Expr

        Example:
            def buffer_error(estimate: str, ground_truth: str, threshold: float = 0.5):
                return (pl.col(estimate) - pl.col(ground_truth)).abs() <= threshold

            MetricRegistry.register_error('buffer_error', buffer_error)
        """
        cls._errors[name] = func

    @classmethod
    def get_error(
        cls,
        name: str,
        estimate: str,
        ground_truth: str,
        **params: Any,
    ) -> pl.Expr:
        """
        Get an error expression by name.

        Args:
            name: Name of the error type
            estimate: Estimate column name
            ground_truth: Ground truth column name
            **params: Additional parameters for parameterized error functions

        Returns:
            Polars expression that computes the error

        Raises:
            MetricNotFoundError: If the error type is not registered
        """
        if not cls._registry_contains("error", name):
            raise MetricNotFoundError(name, cls.list_errors(), "error")

        func = cls._registry_store("error")[name]
        return func(estimate, ground_truth, **params)

    @classmethod
    def generate_error_columns(
        cls,
        estimate: str,
        ground_truth: str,
        error_types: list[str] | None = None,
        error_params: dict[str, dict[str, Any]] | None = None,
    ) -> list[pl.Expr]:
        """Generate error column expressions for specified error types.

        Each expression is aliased to its error-type name so metric
        expressions can reference the resulting columns. When *error_types*
        is None (or empty), every registered error type is generated.
        """
        error_types = error_types or cls.list_errors()
        error_params = error_params or {}

        return [
            cls.get_error(
                error_type, estimate, ground_truth, **error_params.get(error_type, {})
            ).alias(error_type)
            for error_type in error_types
        ]

    @classmethod
    def list_errors(cls) -> list[str]:
        """List all available error types."""
        return cls._registry_names("error")

    @classmethod
    def list_metrics(cls) -> list[str]:
        """List all available metrics."""
        return cls._registry_names("metric")

    @classmethod
    def list_summaries(cls) -> list[str]:
        """List all available summaries."""
        return cls._registry_names("summary")

    @classmethod
    def has_metric(cls, name: str) -> bool:
        """Check if a metric exists in the registry."""
        return cls._registry_contains("metric", name)

    @classmethod
    def has_summary(cls, name: str) -> bool:
        """Check if a summary/selector exists in the registry."""
        return cls._registry_contains("summary", name)

    @classmethod
    def has_error(cls, name: str) -> bool:
        """Check if an error type exists in the registry."""
        return cls._registry_contains("error", name)

    # ============ Metric Expression Methods ============

    @classmethod
    def register_metric(
        cls,
        name: str,
        expr: pl.Expr | Callable[[], pl.Expr] | MetricInfo | Callable[[], MetricInfo],
        *,
        value_kind: str | None = None,
        format: str | None = None,
    ) -> None:
        """
        Register a custom metric expression.

        Args:
            name: Name of the metric (e.g., 'mae', 'custom_accuracy')
            expr: Polars expression/MetricInfo or callable returning one.
                Expressions reference error columns and define the aggregated result.
            value_kind: Optional type hint (``float``, ``int``, ``struct`` ...) used to
                populate the ``stat`` struct. Defaults to ``float``.
            format: Optional string formatter applied when rendering ``value``.

        Note:
            ``value_kind`` and ``format`` are ignored when ``expr`` is already
            a ``MetricInfo`` (or a callable returning one); the MetricInfo's
            own fields take precedence.

        Example:
            MetricRegistry.register_metric('mae', pl.col('absolute_error').mean().alias('value'))
        """
        if isinstance(expr, MetricInfo):
            cls._metrics[name] = expr
            return

        if callable(expr):
            callable_expr: Callable[[], MetricInfo | pl.Expr] = cast(
                Callable[[], MetricInfo | pl.Expr], expr
            )

            # Defer evaluation: the callable is re-invoked on every get_metric call.
            def factory() -> MetricInfo:
                result = callable_expr()
                if isinstance(result, MetricInfo):
                    return result
                return MetricInfo(
                    expr=result,
                    value_kind=value_kind or "float",
                    format=format,
                )

            cls._metrics[name] = factory
            return

        info = MetricInfo(
            expr=expr,
            value_kind=value_kind or "float",
            format=format,
        )
        cls._metrics[name] = info

    @classmethod
    def register_summary(cls, name: str, expr: pl.Expr | Callable[[], pl.Expr]) -> None:
        """
        Register a custom summary expression.

        Re-registering an existing name silently overwrites the previous entry.

        Args:
            name: Name of the summary (e.g., 'mean', 'p90')
            expr: Polars expression or callable that returns a Polars expression
                The expression should typically operate on 'value' column

        Example:
            MetricRegistry.register_summary('p90', pl.col('value').quantile(0.9, interpolation="linear"))
        """
        cls._summaries[name] = expr

    @classmethod
    def get_metric(cls, name: str) -> MetricInfo:
        """
        Get a metric expression by name.

        Args:
            name: Name of the metric

        Returns:
            MetricInfo describing the metric (expression, value kind, format).
            Bare expressions returned by callables are wrapped with defaults.

        Raises:
            MetricNotFoundError: If the metric is not registered
        """
        if not cls._registry_contains("metric", name):
            raise MetricNotFoundError(name, cls.list_metrics(), "metric")

        expr = cls._registry_store("metric")[name]
        # If it's a callable, call it to get the expression/info
        if callable(expr):
            expr = expr()
        if not isinstance(expr, MetricInfo):
            expr = MetricInfo(expr=expr)
        return expr

    @classmethod
    def get_summary(cls, name: str) -> pl.Expr:
        """
        Get a summary expression by name.

        Args:
            name: Name of the summary

        Returns:
            Polars expression for the summary

        Raises:
            MetricNotFoundError: If the summary is not registered
            TypeError: If a registered callable does not resolve to pl.Expr
        """
        if not cls._registry_contains("summary", name):
            raise MetricNotFoundError(name, cls.list_summaries(), "summary")

        summary_entry = cls._registry_store("summary")[name]
        result = summary_entry() if callable(summary_entry) else summary_entry
        if not isinstance(result, pl.Expr):
            raise TypeError(
                "Summary registry returned a non-expression object; "
                "ensure summaries resolve to pl.Expr."
            )
        return result
290 # ============ Built-in Error Expression Functions ============
def _error(estimate: str, ground_truth: str) -> pl.Expr:
    """Signed error: estimate - ground_truth."""
    return pl.col(estimate) - pl.col(ground_truth)
def _absolute_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Absolute error: |estimate - ground_truth|."""
    return (pl.col(estimate) - pl.col(ground_truth)).abs()
def _squared_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Squared error: (estimate - ground_truth)^2."""
    diff = pl.col(estimate) - pl.col(ground_truth)
    return diff**2
def _percent_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Percent error: (estimate - ground_truth) / ground_truth * 100.

    Yields null where ground truth is zero to avoid division by zero.
    """
    truth = pl.col(ground_truth)
    diff = pl.col(estimate) - truth
    return pl.when(truth != 0).then(diff / truth * 100).otherwise(None)
def _absolute_percent_error(estimate: str, ground_truth: str) -> pl.Expr:
    """Absolute percent error: |(estimate - ground_truth) / ground_truth| * 100.

    Yields null where ground truth is zero to avoid division by zero.
    """
    truth = pl.col(ground_truth)
    diff = pl.col(estimate) - truth
    return pl.when(truth != 0).then((diff / truth * 100).abs()).otherwise(None)
# ============ Register All Built-in Expressions Globally ============

# Built-in error expressions, registered in this order.
_DEFAULT_ERRORS: dict[str, Callable[..., pl.Expr]] = {
    "error": _error,
    "absolute_error": _absolute_error,
    "squared_error": _squared_error,
    "percent_error": _percent_error,
    "absolute_percent_error": _absolute_percent_error,
}

for _error_name, _error_func in _DEFAULT_ERRORS.items():
    MetricRegistry.register_error(_error_name, _error_func)
# Built-in metric definitions: (name, aggregation expression, extra options
# forwarded to register_metric). Registration order determines list_metrics().
_DEFAULT_METRICS: tuple[tuple[str, pl.Expr, dict[str, Any]], ...] = (
    # Mean aggregations over the generated error columns (see
    # MetricRegistry.generate_error_columns, which aliases each error type).
    ("me", pl.col("error").mean(), {}),
    ("mae", pl.col("absolute_error").mean(), {}),
    ("mse", pl.col("squared_error").mean(), {}),
    ("rmse", pl.col("squared_error").mean().sqrt(), {}),
    ("mpe", pl.col("percent_error").mean(), {}),
    ("mape", pl.col("absolute_percent_error").mean(), {}),
    # Counting/coverage metrics. NOTE(review): these assume the input frame
    # has subject_id / visit_id / sample_index columns — verify against callers.
    ("n_subject", pl.col("subject_id").n_unique(), {"value_kind": "int"}),
    (
        # A (subject, visit) pair identifies a visit.
        "n_visit",
        pl.struct(["subject_id", "visit_id"]).n_unique(),
        {"value_kind": "int"},
    ),
    (
        "n_sample",
        pl.col("sample_index").n_unique(),
        {"value_kind": "int"},
    ),
    (
        # Subjects having at least one non-null error value.
        "n_subject_with_data",
        pl.col("subject_id").filter(pl.col("error").is_not_null()).n_unique(),
        {"value_kind": "int"},
    ),
    (
        "pct_subject_with_data",
        (
            pl.col("subject_id").filter(pl.col("error").is_not_null()).n_unique()
            / pl.col("subject_id").n_unique()
            * 100
        ).alias("value"),
        {},
    ),
    (
        "n_visit_with_data",
        pl.struct(["subject_id", "visit_id"])
        .filter(pl.col("error").is_not_null())
        .n_unique(),
        {"value_kind": "int"},
    ),
    (
        "pct_visit_with_data",
        (
            pl.struct(["subject_id", "visit_id"])
            .filter(pl.col("error").is_not_null())
            .n_unique()
            / pl.struct(["subject_id", "visit_id"]).n_unique()
            * 100
        ).alias("value"),
        {},
    ),
    (
        "n_sample_with_data",
        pl.col("error").is_not_null().sum(),
        {"value_kind": "int"},
    ),
    (
        # Mean of a boolean column == fraction true; scaled to percent.
        "pct_sample_with_data",
        (pl.col("error").is_not_null().mean() * 100).alias("value"),
        {},
    ),
)

for _name, _expr, _options in _DEFAULT_METRICS:
    MetricRegistry.register_metric(_name, _expr, **_options)
# Built-in summary expressions; all operate on the 'value' column.
_DEFAULT_SUMMARIES: dict[str, pl.Expr] = {
    "mean": pl.col("value").mean(),
    "median": pl.col("value").median(),
    "std": pl.col("value").std(),
    "min": pl.col("value").min(),
    "max": pl.col("value").max(),
    "sum": pl.col("value").sum(),
    "sqrt": pl.col("value").sqrt(),
}

# Percentile summaries p1, p5, p25, p75, p90, p95, p99.
for _pct in (1, 5, 25, 75, 90, 95, 99):
    _DEFAULT_SUMMARIES[f"p{_pct}"] = pl.col("value").quantile(
        _pct / 100, interpolation="linear"
    )

for _summary_name, _summary_expr in _DEFAULT_SUMMARIES.items():
    MetricRegistry.register_summary(_summary_name, _summary_expr)