Coverage for src/polars_eval_metrics/utils.py: 58%

66 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 15:04 +0000

1"""Shared utility helpers for metric formatting and parsing.""" 

2 

3# pyre-strict 

4 

5from enum import Enum 

6import re 

7import textwrap 

8from typing import Any, Sequence, Type 

9 

10import polars as pl 

11 

12 

13def parse_enum_value( 

14 value: object, 

15 enum_cls: Type[Enum], 

16 *, 

17 field: str, 

18 allow_none: bool = False, 

19) -> Enum | None: 

20 """Normalize arbitrary inputs into enum values.""" 

21 

22 if value is None: 

23 if allow_none: 

24 return None 

25 raise ValueError(f"{field} must not be None") 

26 

27 if isinstance(value, enum_cls): 

28 return value 

29 

30 if isinstance(value, str): 

31 try: 

32 return enum_cls(value) 

33 except ValueError: 

34 normalized = value.lower().replace("-", "_") 

35 for member in enum_cls: 

36 if member.value == normalized or member.name.lower() == normalized: 

37 return member 

38 

39 valid_options = ", ".join(member.value for member in enum_cls) 

40 raise ValueError( 

41 f"Invalid {field}: '{value}'. Valid options are: {valid_options}" 

42 ) 

43 

44 raise ValueError( 

45 f"{field} must be a {enum_cls.__name__} enum or string, got {type(value)}" 

46 ) 

47 

48 

def clean_polars_expr_string(expr_str: str) -> str:
    """Strip Polars-internal noise from the ``str()`` form of an expression.

    Three cleanups are applied, in this order:

    1. every ``dyn <word>:`` tag, together with the whitespace after it, is
       removed;
    2. one enclosing pair of square brackets, if present, is dropped;
    3. any remaining ``[(`` and ``)]`` sequences are collapsed to ``(`` and
       ``)`` respectively.
    """
    without_tags = re.sub(r"dyn \w+:\s*", "", expr_str)
    is_bracketed = without_tags.startswith("[") and without_tags.endswith("]")
    body = without_tags[1:-1] if is_bracketed else without_tags
    for noisy, plain in (("[(", "("), (")]", ")")):
        body = body.replace(noisy, plain)
    return body

56 

57 

def format_polars_expr(expr: object, *, max_width: int = 70) -> str:
    """Render one Polars expression as a readable, width-limited string.

    The ``str()`` form of *expr* is cleaned with ``clean_polars_expr_string``.
    A ``pl.Expr`` whose cleaned text contains no ``.alias(`` is instead
    rendered from ``expr.alias("value")``. Text longer than *max_width*
    characters is wrapped with ``textwrap.fill``; continuation lines get a
    hanging indent, and long words and hyphenated tokens are never split.
    """
    rendered = clean_polars_expr_string(str(expr))
    if ".alias(" not in rendered and isinstance(expr, pl.Expr):
        rendered = clean_polars_expr_string(str(expr.alias("value")))

    if len(rendered) > max_width:
        rendered = textwrap.fill(
            rendered,
            width=max_width,
            subsequent_indent=" ",
            break_long_words=False,
            break_on_hyphens=False,
        )
    return rendered

75 

76 

def format_polars_expr_list(
    exprs: Sequence[pl.Expr],
    *,
    indent: str = " ",
    max_width: int = 70,
) -> str:
    """Render several Polars expressions as an indented, bracketed list.

    An empty input yields ``"[]"``. A single expression is returned exactly as
    ``format_polars_expr`` renders it, without brackets. Otherwise the output
    consists of:

    - an opening ``[`` line;
    - every line of every entry, each prefixed by *indent* plus a space, with
      entries separated by a trailing comma on their last line;
    - a closing ``]`` line prefixed by *indent*.
    """
    items = list(exprs)
    if not items:
        return "[]"
    if len(items) == 1:
        return format_polars_expr(items[0], max_width=max_width)

    last_position = len(items) - 1
    out: list[str] = ["["]
    for position, item in enumerate(items):
        rendered = format_polars_expr(item, max_width=max_width)
        # All physical lines of a (possibly wrapped) entry share one prefix;
        # the separating comma goes on the entry's final line only.
        out.extend(f"{indent} {piece}" for piece in rendered.split("\n"))
        if position != last_position:
            out[-1] += ","
    out.append(f"{indent}]")
    return "\n".join(out)

105 

106 

107def parse_json_tokens(column: str) -> tuple[str, ...] | None: 

108 """Parse a JSON-like column label produced by Struct json serialization.""" 

109 

110 if column.startswith('{"') and column.endswith('"}') and '","' in column: 

111 inner = column[2:-2] 

112 return tuple(inner.split('","')) 

113 return None 

114 

115 

def parse_json_columns(columns: Sequence[str]) -> dict[str, tuple[str, ...]]:
    """Map every JSON-encoded column label to its token tuple.

    Labels for which ``parse_json_tokens`` returns ``None`` are omitted. The
    resulting dict preserves the order of first appearance in *columns*.
    """
    return {
        name: tokens
        for name in columns
        if (tokens := parse_json_tokens(name)) is not None
    }