Coverage for src/polars_eval_metrics/utils.py: 58%
66 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-29 15:04 +0000
1"""Shared utility helpers for metric formatting and parsing."""
3# pyre-strict
5from enum import Enum
6import re
7import textwrap
8from typing import Any, Sequence, Type
10import polars as pl
13def parse_enum_value(
14 value: object,
15 enum_cls: Type[Enum],
16 *,
17 field: str,
18 allow_none: bool = False,
19) -> Enum | None:
20 """Normalize arbitrary inputs into enum values."""
22 if value is None:
23 if allow_none:
24 return None
25 raise ValueError(f"{field} must not be None")
27 if isinstance(value, enum_cls):
28 return value
30 if isinstance(value, str):
31 try:
32 return enum_cls(value)
33 except ValueError:
34 normalized = value.lower().replace("-", "_")
35 for member in enum_cls:
36 if member.value == normalized or member.name.lower() == normalized:
37 return member
39 valid_options = ", ".join(member.value for member in enum_cls)
40 raise ValueError(
41 f"Invalid {field}: '{value}'. Valid options are: {valid_options}"
42 )
44 raise ValueError(
45 f"{field} must be a {enum_cls.__name__} enum or string, got {type(value)}"
46 )
def clean_polars_expr_string(expr_str: str) -> str:
    """Strip Polars repr noise from an expression string.

    Removes every ``dyn <word>:`` type tag (with any whitespace after it),
    drops one pair of enclosing square brackets when the whole string is
    wrapped in them, and finally unwraps brackets that sit directly around
    parentheses (``[(`` -> ``(`` and ``)]`` -> ``)``).
    """
    text = re.sub(r"dyn \w+:\s*", "", expr_str)

    is_wrapped = text.startswith("[") and text.endswith("]")
    if is_wrapped:
        text = text[1:-1]

    for noisy, plain in (("[(", "("), (")]", ")")):
        text = text.replace(noisy, plain)
    return text
def format_polars_expr(expr: object, *, max_width: int = 70) -> str:
    """Render one expression as cleaned text, wrapped to ``max_width`` columns.

    The string form of ``expr`` is passed through ``clean_polars_expr_string``.
    A ``pl.Expr`` whose cleaned text has no ``.alias(`` call is re-rendered as
    ``expr.alias("value")`` so an output name is always shown. Text longer than
    ``max_width`` is wrapped with ``textwrap.fill``. Long words and hyphenated
    tokens are never split, so that expression fragments stay intact.
    """
    rendered = clean_polars_expr_string(str(expr))

    if isinstance(expr, pl.Expr):
        if ".alias(" not in rendered:
            rendered = clean_polars_expr_string(str(expr.alias("value")))

    if len(rendered) > max_width:
        rendered = textwrap.fill(
            rendered,
            width=max_width,
            subsequent_indent=" ",
            break_long_words=False,
            break_on_hyphens=False,
        )
    return rendered
def format_polars_expr_list(
    exprs: Sequence[pl.Expr],
    *,
    indent: str = " ",
    max_width: int = 70,
) -> str:
    """Render a sequence of Polars expressions as an indented listing.

    An empty sequence renders as ``[]``. A single expression is returned bare
    (no brackets), formatted by ``format_polars_expr``. Otherwise the result
    is ``[`` on its own line, then each formatted expression (possibly spanning
    several lines) with a prefixed indent, a comma after every entry but the
    last (attached to that entry's final line), and a closing ``]`` prefixed
    by ``indent``.
    """
    items = list(exprs)
    count = len(items)
    if count == 0:
        return "[]"
    if count == 1:
        return format_polars_expr(items[0], max_width=max_width)

    # NOTE(review): the first-line and continuation prefixes are identical in
    # this copy of the file; confirm against the canonical source that the
    # literal whitespace was not collapsed during export.
    first_prefix = f"{indent} "
    rest_prefix = f"{indent} "

    out: list[str] = ["["]
    last_index = count - 1
    for position, item in enumerate(items):
        head, *tail = format_polars_expr(item, max_width=max_width).split("\n")
        entry = [first_prefix + head] + [rest_prefix + piece for piece in tail]
        if position != last_index:
            entry[-1] += ","
        out.extend(entry)
    out.append(f"{indent}]")
    return "\n".join(out)
107def parse_json_tokens(column: str) -> tuple[str, ...] | None:
108 """Parse a JSON-like column label produced by Struct json serialization."""
110 if column.startswith('{"') and column.endswith('"}') and '","' in column:
111 inner = column[2:-2]
112 return tuple(inner.split('","'))
113 return None
def parse_json_columns(columns: Sequence[str]) -> dict[str, tuple[str, ...]]:
    """Map each JSON-encoded column label to its token tuple.

    Each entry of ``columns`` is run through ``parse_json_tokens``. Labels for
    which it returns ``None`` are left out. Insertion order follows
    ``columns``.
    """
    return {
        column: tokens
        for column in columns
        if (tokens := parse_json_tokens(column)) is not None
    }