Coverage for src/polars_eval_metrics/metric_define.py: 69%

250 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

1""" 

2Metric definition and expression preparation 

3 

4This module combines metric configuration with expression compilation, 

5providing a single class for defining metrics and preparing Polars expressions. 

6""" 

7 

8# pyre-strict 

9from enum import Enum 

10from typing import Self 

11 

12import polars as pl 

13from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator 

14 

15from .metric_registry import MetricRegistry, MetricInfo 

16from .utils import format_polars_expr, format_polars_expr_list, parse_enum_value 

17 

18 

class MetricType(Enum):
    """Metric aggregation types.

    Selects how rows are grouped before a metric is reduced. The string values
    are what string inputs to ``MetricDefine.type`` are converted from (via
    ``parse_enum_value``; NOTE(review): presumably matched by value — confirm).
    """

    # Reduce directly over all rows, with no subject/visit grouping step.
    ACROSS_SAMPLE = "across_sample"
    # Reduce within each subject, then summarize those results across subjects.
    ACROSS_SUBJECT = "across_subject"
    # Reduce within each subject; the per-subject result is final.
    WITHIN_SUBJECT = "within_subject"
    # Reduce within each (subject, visit), then summarize across them.
    ACROSS_VISIT = "across_visit"
    # Reduce within each (subject, visit); the per-visit result is final.
    WITHIN_VISIT = "within_visit"

27 

28 

class MetricScope(Enum):
    """Scope for metric calculation - determines at what level the metric is computed.

    Orthogonal to ``MetricType``: the type controls subject/visit aggregation,
    while the scope controls which model/group stratification is applied.
    """

    GLOBAL = "global"  # Calculate once for the entire dataset
    MODEL = "model"  # Calculate per model only, ignoring groups
    GROUP = "group"  # Calculate per group only, ignoring models

35 

36 

class MetricDefine(BaseModel):
    """
    Metric definition with hierarchical expression support.

    This class defines metrics with support for two-level aggregation patterns:
    - within_expr: Expressions for within-entity aggregation (e.g., within subject/visit)
    - across_expr: Expressions for across-entity aggregation or final metric computation

    A metric is "custom" when ``within_expr`` or ``across_expr`` is set. Otherwise
    it is "built-in", and ``name`` is resolved through ``MetricRegistry`` using the
    ``"<metric>"`` or ``"<metric>:<selector>"`` format.

    Attributes:
        name: Metric identifier (whitespace-stripped, must be non-empty)
        label: Display name for the metric (defaults to ``name``)
        type: Aggregation type (ACROSS_SAMPLE, WITHIN_SUBJECT, ACROSS_SUBJECT, etc.)
        scope: Calculation scope (GLOBAL, MODEL, GROUP) - orthogonal to aggregation type
        within_expr: Expression(s) for within-entity aggregation:
            - Used in WITHIN_SUBJECT, ACROSS_SUBJECT, WITHIN_VISIT, ACROSS_VISIT
            - For ACROSS_SAMPLE, a single within_expr with no across_expr is
              promoted to the final (selection) expression
        across_expr: Expression for across-entity aggregation or final computation:
            - For ACROSS_SAMPLE: Applied directly to error columns
            - For ACROSS_SUBJECT/VISIT: Summarizes within_expr results across entities
            - For WITHIN_SUBJECT/VISIT: Normally omitted (within_expr is final)

    Note: within_expr and across_expr are distinct from group_by/subgroup_by which control
    analysis stratification (e.g., by treatment, age, sex) and apply to ALL metric types.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="Metric identifier")
    label: str | None = None
    type: MetricType = MetricType.ACROSS_SAMPLE
    scope: MetricScope | None = None
    within_expr: list[str | pl.Expr | MetricInfo] | None = None
    across_expr: str | pl.Expr | MetricInfo | None = None

    def __init__(self, **kwargs: object) -> None:
        """Initialize the metric, defaulting ``label`` to ``name`` when absent or None.

        A ``registry`` keyword argument, if supplied, is discarded before pydantic
        validation because it is not a model field.
        """
        if kwargs.get("label") is None:
            kwargs["label"] = kwargs.get("name", "Unknown Metric")
        kwargs.pop("registry", None)
        super().__init__(**kwargs)

    @field_validator("name")  # pyre-ignore[56]
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Strip the metric name and reject blank names."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("Metric name cannot be empty")
        return stripped

    @field_validator("label")  # pyre-ignore[56]
    @classmethod
    def validate_label(cls, v: str | None) -> str | None:
        """Strip the label and reject blank labels (None is allowed)."""
        if v is None:
            return None
        stripped = v.strip()
        if not stripped:
            raise ValueError("Metric label cannot be empty")
        return stripped

    @field_validator("type", mode="before")  # pyre-ignore[56]
    @classmethod
    def validate_type(cls, v: object) -> MetricType:
        """Convert string to MetricType enum if needed."""
        result = parse_enum_value(v, MetricType, field="metric type")
        # Explicit check instead of ``assert`` so it is not stripped under ``-O``.
        if not isinstance(result, MetricType):
            raise ValueError(f"Invalid metric type: {v!r}")
        return result

    @field_validator("scope", mode="before")  # pyre-ignore[56]
    @classmethod
    def validate_scope(cls, v: object) -> MetricScope | None:
        """Convert string to MetricScope enum if needed (None is allowed)."""
        result = parse_enum_value(v, MetricScope, field="metric scope", allow_none=True)
        # Explicit check instead of ``assert`` so it is not stripped under ``-O``.
        if result is not None and not isinstance(result, MetricScope):
            raise ValueError(f"Invalid metric scope: {v!r}")
        return result

    @field_validator("within_expr", mode="before")  # pyre-ignore[56]
    @classmethod
    def normalize_within_expr(cls, v: object) -> object:
        """Wrap a single string, Polars expression, or MetricInfo in a list.

        None and values that are already lists (or anything else) pass through
        unchanged for the regular field validation to handle.
        """
        if isinstance(v, (str, pl.Expr, MetricInfo)):
            return [v]
        return v

    @field_validator("within_expr")  # pyre-ignore[56]
    @classmethod
    def validate_within_expr(
        cls, v: list[str | pl.Expr | MetricInfo] | None
    ) -> list[str | pl.Expr | MetricInfo] | None:
        """Validate within-entity expressions and strip built-in metric names.

        Raises:
            ValueError: If the list is empty, a built-in name is blank, or an item
                is not a string, Polars expression, or MetricInfo.
        """
        if v is None:
            return None

        if not isinstance(v, list):
            raise ValueError(
                f"within_expr must be a list after normalization, got {type(v)}"
            )

        if not v:
            raise ValueError("within_expr cannot be an empty list")

        cleaned: list[str | pl.Expr | MetricInfo] = []
        for i, item in enumerate(v):
            if isinstance(item, str):
                stripped = item.strip()
                if not stripped:
                    raise ValueError(
                        f"within_expr[{i}]: Built-in metric name cannot be empty"
                    )
                # Stripped for consistency with name/label/across_expr handling.
                cleaned.append(stripped)
            elif isinstance(item, (pl.Expr, MetricInfo)):
                cleaned.append(item)
            else:
                raise ValueError(
                    f"within_expr[{i}] must be a string (built-in name), Polars expression, or MetricInfo"
                )

        return cleaned

    @field_validator("across_expr")  # pyre-ignore[56]
    @classmethod
    def validate_across_expr(
        cls, v: str | pl.Expr | MetricInfo | None
    ) -> str | pl.Expr | MetricInfo | None:
        """Validate the across-entity expression.

        Accepts a built-in selector name (stripped, must be non-empty; whether it
        exists is only checked at compile time), a Polars expression, or a MetricInfo.
        """
        if v is None:
            return None

        if isinstance(v, str):
            if not v.strip():
                raise ValueError("Built-in selector name cannot be empty")
            return v.strip()

        if not isinstance(v, (pl.Expr, MetricInfo)):
            raise ValueError(
                "across_expr must be a string (built-in selector), Polars expression, or MetricInfo"
            )
        return v

    @model_validator(mode="after")  # pyre-ignore[56]
    def validate_expressions(self) -> Self:
        """Validate cross-field combinations.

        Custom metrics: ACROSS_SUBJECT/ACROSS_VISIT with more than one within_expr
        need an across_expr to combine them. Built-in metrics: a name containing
        ':' must have a non-empty part on both sides of the first ':'.
        """
        is_custom = self.within_expr is not None or self.across_expr is not None

        if is_custom:
            if (
                self.type in (MetricType.ACROSS_SUBJECT, MetricType.ACROSS_VISIT)
                and self.within_expr
                and len(self.within_expr) > 1
                and self.across_expr is None
            ):
                raise ValueError(
                    f"across_expr required for multiple within expressions in {self.type.value}"
                )
        elif ":" in self.name:
            base_name, _, selector_name = self.name.partition(":")
            if not base_name or not selector_name:
                raise ValueError(
                    f"Invalid built-in metric name format: {self.name}"
                )

        return self

    def compile_expressions(self) -> tuple[list[MetricInfo], MetricInfo | None]:
        """
        Resolve this metric's expressions into ``MetricInfo`` objects.

        Custom metrics are resolved from ``within_expr``/``across_expr``. Built-in
        metrics are resolved from ``name`` via ``MetricRegistry``.

        Returns:
            Tuple of (aggregation_infos, selection_info). For ACROSS_SAMPLE, a lone
            aggregation with no selection is returned as the selection instead.

        Raises:
            ValueError: If a built-in metric or selector name is unknown. Registry
                lookup errors for custom string expressions propagate unchanged.
        """
        if self.within_expr is not None or self.across_expr is not None:
            agg_infos, sel_info = self._compile_custom_expressions()
        else:
            agg_infos, sel_info = self._compile_builtin_expressions()

        # ACROSS_SAMPLE has no grouping step, so a single aggregation is itself the
        # final expression. Promote only when there is exactly one: taking
        # agg_infos[0] from a longer list would silently drop the rest.
        if (
            self.type == MetricType.ACROSS_SAMPLE
            and sel_info is None
            and len(agg_infos) == 1
        ):
            return [], agg_infos[0]

        return agg_infos, sel_info

    def _compile_custom_expressions(
        self,
    ) -> tuple[list[MetricInfo], MetricInfo | None]:
        """Resolve custom expressions (inputs are already validated).

        When only ``across_expr`` is given, it is returned as the single
        aggregation with no selection.
        """
        within_exprs = self._resolve_within_expressions()
        across_expr = self._resolve_across_expression()

        if not within_exprs and across_expr is not None:
            return [across_expr], None

        return within_exprs, across_expr

    def _resolve_within_expressions(self) -> list[MetricInfo]:
        """Resolve each ``within_expr`` item to a MetricInfo ([] when unset)."""
        if self.within_expr is None:
            return []

        return [self._ensure_metric_info(item) for item in self.within_expr]

    def _resolve_across_expression(self) -> MetricInfo | None:
        """Resolve ``across_expr`` to a MetricInfo (None when unset).

        String values are looked up as summaries via ``MetricRegistry.get_summary``.
        """
        if self.across_expr is None:
            return None

        expr = (
            MetricRegistry.get_summary(self.across_expr)
            if isinstance(self.across_expr, str)
            else self.across_expr
        )
        return self._ensure_metric_info(expr)

    def _compile_builtin_expressions(
        self,
    ) -> tuple[list[MetricInfo], MetricInfo | None]:
        """Resolve a built-in ``"<metric>"`` or ``"<metric>:<selector>"`` name.

        Returns:
            ``([], metric_info)`` when no selector is given, otherwise
            ``([metric_info], selector_info)``.

        Raises:
            ValueError: If the metric or selector name is not registered.
        """
        # Split on the first ':' only, consistent with validate_expressions/__str__.
        agg_name, _, select_name = self.name.partition(":")

        try:
            agg_info = MetricRegistry.get_metric(agg_name)
        except ValueError as err:
            raise ValueError(f"Unknown built-in metric: {agg_name}") from err

        if not select_name:
            # No selector: the metric itself is the final (selection) expression.
            return [], agg_info

        try:
            select_expr = MetricRegistry.get_summary(select_name)
        except ValueError as err:
            raise ValueError(f"Unknown built-in selector: {select_name}") from err
        return [agg_info], self._ensure_metric_info(select_expr)

    def _ensure_metric_info(self, value: object) -> MetricInfo:
        """Coerce ``value`` to a MetricInfo.

        MetricInfo is returned as-is, strings are looked up with
        ``MetricRegistry.get_metric``, and Polars expressions are wrapped.

        Raises:
            TypeError: For any other type.
        """
        if isinstance(value, MetricInfo):
            return value
        if isinstance(value, str):
            return MetricRegistry.get_metric(value)
        if isinstance(value, pl.Expr):
            return MetricInfo(expr=value)
        raise TypeError(f"Unsupported metric expression type: {type(value)!r}")

    def get_pl_chain(self) -> str:
        """
        Get a string representation of the Polars LazyFrame chain for this metric.

        The grouping column names shown ('subject_id', 'visit_id') are hard-coded
        in this rendering and are for illustration only.

        Returns:
            String showing the LazyFrame operations that would be executed
        """
        agg_infos, select_info = self.compile_expressions()
        agg_exprs = [info.expr for info in agg_infos]
        select_expr = select_info.expr if select_info is not None else None
        return self._render_chain(agg_exprs, select_expr)

    def _render_chain(
        self, agg_exprs: list[pl.Expr], select_expr: pl.Expr | None
    ) -> str:
        """Render the LazyFrame chain text for already-compiled expressions."""
        subject_keys = "'subject_id'"
        visit_keys = "['subject_id', 'visit_id']"
        chain_lines = ["(", " pl.LazyFrame"]

        if self.type == MetricType.ACROSS_SAMPLE:
            # Single level with no grouping: the selection wins over aggregations.
            final = [select_expr] if select_expr is not None else agg_exprs
            if final:
                chain_lines.extend(self._format_chain_call("select", final))

        elif self.type in (MetricType.WITHIN_SUBJECT, MetricType.WITHIN_VISIT):
            # Single level per entity: group, then aggregate once.
            keys = subject_keys if self.type == MetricType.WITHIN_SUBJECT else visit_keys
            chain_lines.append(f" .group_by({keys})")
            final = [select_expr] if select_expr is not None else agg_exprs
            if final:
                chain_lines.extend(self._format_chain_call("agg", final))

        elif self.type in (MetricType.ACROSS_SUBJECT, MetricType.ACROSS_VISIT):
            # Two levels: aggregate per entity, then summarize across entities.
            keys = subject_keys if self.type == MetricType.ACROSS_SUBJECT else visit_keys
            if agg_exprs:
                chain_lines.append(f" .group_by({keys})")
                chain_lines.extend(self._format_chain_call("agg", agg_exprs))
            if select_expr is not None:
                chain_lines.extend(self._format_chain_call("select", [select_expr]))

        chain_lines.append(")")
        return "\n".join(chain_lines)

    @staticmethod
    def _format_chain_call(method: str, exprs: list[pl.Expr]) -> list[str]:
        """Render ``.method(...)`` chain line(s) for one or more expressions."""
        if len(exprs) == 1:
            return [f" .{method}({format_polars_expr(exprs[0])})"]
        return [f" .{method}(", format_polars_expr_list(exprs), " )"]

    @staticmethod
    def _source_label(item: object) -> str:
        """Label a user-supplied expression: its name if a string, else "custom"."""
        return item if isinstance(item, str) else "custom"

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary of the metric.

        Shows name, type, label, the optional scope, the compiled within/across
        expressions with their source, and the LazyFrame chain. Compilation errors
        are reported inline instead of raised.
        """
        lines = [f"MetricDefine(name='{self.name}', type={self.type.value})"]
        lines.append(f" Label: '{self.label}'")
        if self.scope is not None:
            lines.append(f" Scope: {self.scope.value}")

        try:
            agg_infos, select_info = self.compile_expressions()
            agg_exprs = [info.expr for info in agg_infos]
            select_expr = select_info.expr if select_info is not None else None

            base_name, _, selector_name = self.name.partition(":")

            if agg_exprs:
                lines.append(" Within-entity expressions:")
                for i, expr in enumerate(agg_exprs):
                    if self.within_expr is not None and i < len(self.within_expr):
                        source = self._source_label(self.within_expr[i])
                    elif self.across_expr is not None:
                        # Across-only custom metric: across_expr is the aggregation.
                        source = self._source_label(self.across_expr)
                    else:
                        source = base_name
                    lines.append(f" - [{source}] {expr}")

            if select_expr is not None:
                lines.append(" Across-entity expression:")
                if self.across_expr is not None:
                    source = self._source_label(self.across_expr)
                elif self.within_expr is not None:
                    # ACROSS_SAMPLE promotes a lone within_expr to the selection.
                    source = self._source_label(self.within_expr[0])
                elif selector_name:
                    source = selector_name
                else:
                    source = base_name
                lines.append(f" - [{source}] {select_expr}")

            lines.append("")
            lines.append(self._render_chain(agg_exprs, select_expr))

        except Exception as e:
            # Best-effort display: printing a metric should never raise.
            lines.append(f" Error compiling expressions: {e}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        """Representation for interactive display (same as ``__str__``)."""
        return self.__str__()