Coverage for src/polars_eval_metrics/result_formatter.py: 87% (350 statements)

from __future__ import annotations

from typing import Any, Mapping, Sequence

import polars as pl

from .ard import ARD
from .evaluation_context import EstimateCatalog, FormatterContext
from .utils import parse_json_tokens


def convert_to_ard(result_lf: pl.LazyFrame, context: FormatterContext) -> ARD:
    """Convert the evaluator output into canonical ARD columns lazily."""
    schema = result_lf.collect_schema()
    schema_names = set(schema.names())

    ard_frame = result_lf.with_columns(
        [
            _expr_groups(schema, context.group_by),
            _expr_subgroups(schema, context.subgroup_by),
            _expr_estimate(schema, context.estimate_catalog),
            _expr_metric_enum(context.metric_catalog.names),
            _expr_label_enum(context.metric_catalog.labels),
            _expr_stat_struct(schema),
            _expr_context_struct(schema, context),
        ]
    )

    warning_expr = (
        pl.col("_diagnostic_warning")
        if "_diagnostic_warning" in schema_names
        else pl.lit([], dtype=pl.List(pl.Utf8))
    )
    error_expr = (
        pl.col("_diagnostic_error")
        if "_diagnostic_error" in schema_names
        else pl.lit([], dtype=pl.List(pl.Utf8))
    )

    ard_frame = ard_frame.with_columns(
        [
            pl.col("stat")
            .map_elements(ARD._format_stat, return_dtype=pl.Utf8)
            .alias("stat_fmt"),
            warning_expr.alias("warning"),
            error_expr.alias("error"),
        ]
    )

    cleanup_cols = [
        "_value_kind",
        "_value_format",
        "_value_float",
        "_value_int",
        "_value_bool",
        "_value_str",
        "_value_struct",
        "_diagnostic_warning",
        "_diagnostic_error",
    ]
    drop_cols = [col for col in cleanup_cols if col in schema_names]
    if drop_cols:
        ard_frame = ard_frame.drop(drop_cols)

    return ARD(ard_frame)
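
# Usage sketch (hypothetical wiring; `result_lf` and `context` come from the
# evaluator, outside this module):
#
#     ard = convert_to_ard(result_lf, context)
#     verbose = format_verbose_frame(ard)   # defined later in this module
#
# The ARD frame carries the canonical columns built above: "groups",
# "subgroups", "estimate", "metric", "label", "stat", "stat_fmt", "context",
# "warning", and "error".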


def build_group_pivot(
    long_df: pl.DataFrame,
    context: FormatterContext,
    *,
    column_order_by: str,
    row_order_by: str,
) -> pl.DataFrame:
    """Pivot results with groups as rows and model × metric as columns."""
    group_labels = [
        label for label in context.group_by.values() if label in long_df.columns
    ]
    subgroup_present = (
        "subgroup_name" in long_df.columns and "subgroup_value" in long_df.columns
    )

    if row_order_by == "subgroup" and subgroup_present:
        index_cols = ["subgroup_name", "subgroup_value"] + group_labels
    else:
        index_cols = group_labels + (
            ["subgroup_name", "subgroup_value"] if subgroup_present else []
        )

    display_col = (
        "estimate_label" if "estimate_label" in long_df.columns else "estimate"
    )

    result, sections = _build_pivot_table(
        long_df,
        index_cols=index_cols,
        default_on=[display_col, "label"],
        scoped_on=[("global", ["label"]), ("group", ["label"])],
    )

    section_lookup = {name: cols for name, cols in sections}

    if result.is_empty():
        if index_cols:
            return pl.DataFrame({col: [] for col in index_cols})
        return pl.DataFrame()

    value_cols = [col for col in result.columns if col not in index_cols]
    default_cols = section_lookup.get("default", [])
    default_cols = [
        col for col in default_cols if col.startswith('{"') and col.endswith('"}')
    ]

    estimate_order_lookup: Mapping[str, int] = context.estimate_catalog.label_order
    metric_label_order_lookup: Mapping[str, int] = context.metric_catalog.label_order
    metric_name_order_lookup: Mapping[str, int] = context.metric_catalog.name_order

    def metric_order(label: str) -> int:
        if label in metric_label_order_lookup:
            return metric_label_order_lookup[label]
        return metric_name_order_lookup.get(label, len(metric_label_order_lookup))

    def estimate_order(label: str) -> int:
        return estimate_order_lookup.get(label, len(estimate_order_lookup))

    def sort_default(columns: list[str]) -> list[str]:
        def parse(column: str) -> tuple[str, str]:
            inner = column[2:-2]
            parts = inner.split('","')
            return (parts[0], parts[1]) if len(parts) == 2 else (column, "")

        if column_order_by == "metrics":
            return sorted(
                columns,
                key=lambda col: (
                    metric_order(parse(col)[1]),
                    estimate_order(parse(col)[0]),
                ),
            )
        return sorted(
            columns,
            key=lambda col: (
                estimate_order(parse(col)[0]),
                metric_order(parse(col)[1]),
            ),
        )

    ordered = (
        index_cols
        + section_lookup.get("global", [])
        + section_lookup.get("group", [])
        + sort_default(default_cols)
    )

    remaining = [col for col in value_cols if col not in ordered]
    ordered.extend(remaining)
    ordered = [col for col in ordered if col in result.columns]

    if "subgroup_value" in result.columns:
        if context.subgroup_categories and all(
            isinstance(cat, str) for cat in context.subgroup_categories
        ):
            result = result.with_columns(
                pl.col("subgroup_value").cast(
                    pl.Enum(list(context.subgroup_categories))
                )
            )
        else:
            result = result.with_columns(pl.col("subgroup_value").cast(pl.Utf8))

    sort_columns: list[str] = []
    temp_sort_columns: list[str] = []
    subgroup_order_map = {
        label: idx for idx, label in enumerate(context.subgroup_by.values())
    }

    if row_order_by == "group":
        sort_columns.extend([col for col in group_labels if col in result.columns])
        if "subgroup_name" in result.columns and context.subgroup_by:
            result = result.with_columns(
                pl.col("subgroup_name")
                .replace(subgroup_order_map)
                .fill_null(len(subgroup_order_map))
                .cast(pl.Int32)
                .alias("__subgroup_name_order")
            )
            temp_sort_columns.append("__subgroup_name_order")
            sort_columns.append("__subgroup_name_order")
        if "subgroup_value" in result.columns:
            sort_columns.append("subgroup_value")
    else:
        if "subgroup_value" in result.columns:
            sort_columns.append("subgroup_value")
        if "subgroup_name" in result.columns and context.subgroup_by:
            result = result.with_columns(
                pl.col("subgroup_name")
                .replace(subgroup_order_map)
                .fill_null(len(subgroup_order_map))
                .cast(pl.Int32)
                .alias("__subgroup_name_order")
            )
            temp_sort_columns.append("__subgroup_name_order")
            sort_columns.insert(0, "__subgroup_name_order")
        sort_columns.extend([col for col in group_labels if col in result.columns])

    if "estimate" in result.columns:
        sort_columns.append("estimate")

    if sort_columns:
        result = result.sort(sort_columns)
    if temp_sort_columns:
        result = result.drop(temp_sort_columns)

    seen: set[str] = set()
    deduped: list[str] = []
    for col in ordered:
        if col not in seen:
            deduped.append(col)
            seen.add(col)

    return result.select(deduped)
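
# Layout sketch (hypothetical data; real names depend on the catalogs). Groups
# become rows and estimate/metric pairs become columns:
#
#     Region | {"model_a","MAE"} | {"model_a","RMSE"} | {"model_b","MAE"}
#     -------+-------------------+--------------------+------------------
#     north  | 0.12              | 0.34               | 0.15
#     south  | 0.10              | 0.31               | 0.11
#
# The '{"...","..."}' names are polars' convention for pivots over multiple
# `on` columns, which is why sort_default() slices them apart above.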


def build_model_pivot(
    long_df: pl.DataFrame,
    context: FormatterContext,
    *,
    column_order_by: str,
    row_order_by: str,
) -> pl.DataFrame:
    """Pivot results with models as rows and group × metric as columns."""
    subgroup_present = (
        "subgroup_name" in long_df.columns and "subgroup_value" in long_df.columns
    )

    if row_order_by == "subgroup" and subgroup_present:
        index_cols = ["estimate", "subgroup_name", "subgroup_value"]
    else:
        index_cols = ["estimate"] + (
            ["subgroup_name", "subgroup_value"] if subgroup_present else []
        )

    if "estimate" in index_cols:
        estimate_series = (
            long_df.get_column("estimate") if "estimate" in long_df.columns else None
        )
        if estimate_series is None or estimate_series.is_null().all():
            index_cols = [col for col in index_cols if col != "estimate"]

    result, sections = _build_pivot_table(
        long_df,
        index_cols=index_cols,
        default_on=[*context.group_by.values(), "label"],
        scoped_on=[
            ("global", ["label"]),
            ("group", [*context.group_by.values(), "label"]),
        ],
    )

    section_lookup = {name: cols for name, cols in sections}

    group_labels = list(context.group_by.values())
    group_label_count: int = len(group_labels)
    group_value_orders: list[dict[Any, int]] = []

    if group_label_count:
        for label in group_labels:
            if label not in long_df.columns:
                group_value_orders.append({})
                continue

            series = long_df.get_column(label)
            dtype = series.dtype

            if isinstance(dtype, pl.Enum):
                categories = dtype.categories.to_list()
            else:
                categories = sorted(series.drop_nulls().unique().to_list())

            group_value_orders.append(
                {value: idx for idx, value in enumerate(categories)}
            )

    metric_label_order_lookup: Mapping[str, int] = context.metric_catalog.label_order
    metric_name_order_lookup: Mapping[str, int] = context.metric_catalog.name_order
    estimate_label_map = context.estimate_catalog.key_to_label

    def metric_order(label: str) -> int:
        if label in metric_label_order_lookup:
            return metric_label_order_lookup[label]
        return metric_name_order_lookup.get(label, len(metric_label_order_lookup))

    def group_order(tokens: tuple[str, ...]) -> tuple[int, ...]:
        if not group_label_count:
            return tuple()
        values = tokens[:group_label_count]
        order_positions: list[int] = []
        for idx, value in enumerate(values):
            mapping = group_value_orders[idx] if idx < len(group_value_orders) else {}
            order_positions.append(mapping.get(value, len(mapping)))
        return tuple(order_positions)

    def column_sort_key(column: str) -> tuple[Any, ...]:
        tokens = parse_json_tokens(column)
        if tokens is None:
            return (float("inf"), column)
        metric_label = tokens[-1] if tokens else ""
        metric_idx = metric_order(metric_label)
        group_idx = group_order(tokens)
        if column_order_by == "metrics":
            return (metric_idx, group_idx, tokens)
        return (group_idx, metric_idx, tokens)

    if "group" in section_lookup:
        section_lookup["group"] = sorted(section_lookup["group"], key=column_sort_key)

    if "default" in section_lookup:
        section_lookup["default"] = sorted(
            section_lookup["default"], key=column_sort_key
        )

    if "estimate" in result.columns:
        result = result.with_columns(
            pl.col("estimate")
            .map_elements(
                lambda val, mapping=estimate_label_map: mapping.get(val, val),
                return_dtype=pl.Utf8,
            )
            .alias("estimate_label")
        )

    if "subgroup_value" in result.columns:
        if context.subgroup_categories and all(
            isinstance(cat, str) for cat in context.subgroup_categories
        ):
            result = result.with_columns(
                pl.col("subgroup_value").cast(
                    pl.Enum(list(context.subgroup_categories))
                )
            )
        else:
            result = result.with_columns(pl.col("subgroup_value").cast(pl.Utf8))

    ordered = (
        [col for col in index_cols if col in result.columns]
        + [col for col in section_lookup.get("global", []) if col in result.columns]
        + [col for col in section_lookup.get("group", []) if col in result.columns]
        + [col for col in section_lookup.get("default", []) if col in result.columns]
    )

    remaining = [col for col in result.columns if col not in ordered]
    ordered.extend(remaining)

    result = result.select(ordered)

    sort_columns: list[str] = []
    if "subgroup_value" in result.columns:
        sort_columns.append("subgroup_value")
    if "estimate" in result.columns:
        sort_columns.append("estimate")
    for label in context.group_by.values():
        if label in result.columns:
            sort_columns.append(label)
    if sort_columns:
        result = result.sort(sort_columns)

    return result
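
# Layout sketch (hypothetical data): the transpose of build_group_pivot, with
# models as rows and group-value/metric pairs as columns:
#
#     estimate | {"north","MAE"} | {"north","RMSE"} | {"south","MAE"}
#     ---------+-----------------+------------------+----------------
#     model_a  | 0.12            | 0.34             | 0.10
#     model_b  | 0.15            | 0.38             | 0.11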


def _expr_groups(schema: pl.Schema, group_by: Mapping[str, str]) -> pl.Expr:
    group_cols = [col for col in group_by.keys() if col in schema.names()]
    if not group_cols:
        return pl.lit(None).alias("groups")

    dtype = pl.Struct([pl.Field(col, schema[col]) for col in group_cols])
    return (
        pl.when(pl.all_horizontal([pl.col(col).is_null() for col in group_cols]))
        .then(pl.lit(None, dtype=dtype))
        .otherwise(pl.struct([pl.col(col).alias(col) for col in group_cols]))
        .alias("groups")
    )
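
# Example (derived from the expression above): with group_by = {"region": ...}
# a row where region="north" yields groups = {"region": "north"}, while a row
# whose group columns are all null yields a null struct rather than
# {"region": null}.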


def _expr_subgroups(schema: pl.Schema, subgroup_by: Mapping[str, str]) -> pl.Expr:
    if (
        not subgroup_by
        or "subgroup_name" not in schema.names()
        or "subgroup_value" not in schema.names()
    ):
        return pl.lit(None).alias("subgroups")

    labels = list(subgroup_by.values())
    dtype = pl.Struct([pl.Field(label, pl.Utf8) for label in labels])
    fields = [
        pl.when(pl.col("subgroup_name") == pl.lit(label))
        .then(pl.col("subgroup_value").cast(pl.Utf8))
        .otherwise(pl.lit(None, dtype=pl.Utf8))
        .alias(label)
        for label in labels
    ]
    return (
        pl.when(pl.col("subgroup_name").is_null() | pl.col("subgroup_value").is_null())
        .then(pl.lit(None, dtype=dtype))
        .otherwise(pl.struct(fields))
        .alias("subgroups")
    )
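
# Example (derived from the expression above): with subgroup labels
# ["Sex", "Age"], a row holding subgroup_name="Sex", subgroup_value="F" maps
# to subgroups = {"Sex": "F", "Age": null}; rows with a null name or value map
# to a null struct.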


def _expr_estimate(schema: pl.Schema, estimate_catalog: EstimateCatalog) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    if "estimate" not in schema.names():
        return null_utf8.alias("estimate")

    estimate_names = list(estimate_catalog.keys)
    if estimate_names:
        return (
            pl.col("estimate")
            .cast(pl.Utf8)
            .replace({name: name for name in estimate_names})
            .cast(pl.Enum(estimate_names))
            .alias("estimate")
        )

    return pl.col("estimate").cast(pl.Utf8).alias("estimate")


def _expr_metric_enum(metric_names: Sequence[str]) -> pl.Expr:
    metric_categories = list(dict.fromkeys(metric_names))
    return (
        pl.col("metric")
        .cast(pl.Utf8)
        .replace({name: name for name in metric_categories})
        .cast(pl.Enum(metric_categories))
        .alias("metric")
    )


def _expr_label_enum(metric_labels: Sequence[str]) -> pl.Expr:
    unique_labels = list(dict.fromkeys(metric_labels))
    return pl.col("label").cast(pl.Enum(unique_labels)).alias("label")


def _expr_stat_struct(schema: pl.Schema) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    null_float = pl.lit(None, dtype=pl.Float64)
    null_int = pl.lit(None, dtype=pl.Int64)
    null_bool = pl.lit(None, dtype=pl.Boolean)
    null_struct_expr = pl.lit(None, dtype=pl.Struct([]))

    kind_expr = pl.col("_value_kind") if "_value_kind" in schema.names() else None
    format_col = (
        pl.col("_value_format") if "_value_format" in schema.names() else null_utf8
    )

    float_value = (
        pl.col("_value_float") if "_value_float" in schema.names() else null_float
    )
    int_value = pl.col("_value_int") if "_value_int" in schema.names() else null_int
    bool_value = pl.col("_value_bool") if "_value_bool" in schema.names() else null_bool
    string_value = pl.col("_value_str") if "_value_str" in schema.names() else null_utf8
    struct_value = (
        pl.col("_value_struct")
        if "_value_struct" in schema.names()
        else null_struct_expr
    )

    inferred_kind = "float"
    if kind_expr is None:
        value_dtype = schema.get("value") if "value" in schema.names() else None
        inferred_kind = _infer_value_kind_from_dtype(value_dtype)
        kind_expr = pl.lit(inferred_kind, dtype=pl.Utf8)

    type_label = pl.when(kind_expr.is_null()).then(null_utf8).otherwise(kind_expr)

    return pl.struct(
        [
            type_label.alias("type"),
            float_value.alias("value_float"),
            int_value.alias("value_int"),
            bool_value.alias("value_bool"),
            string_value.alias("value_str"),
            struct_value.alias("value_struct"),
            format_col.alias("format"),
        ]
    ).alias("stat")
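
# Illustrative stat struct (the "%.3f" format hint is hypothetical; upstream
# is expected to populate only the value_* column matching the kind tag):
#
#     {"type": "float", "value_float": 0.125, "value_int": None,
#      "value_bool": None, "value_str": None, "value_struct": None,
#      "format": "%.3f"}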


def _expr_context_struct(schema: pl.Schema, context: FormatterContext) -> pl.Expr:
    null_utf8 = pl.lit(None, dtype=pl.Utf8)
    fields = []
    for field in ("metric_type", "scope", "label"):
        if field in schema.names():
            fields.append(pl.col(field).cast(pl.Utf8).alias(field))
        else:
            fields.append(null_utf8.alias(field))
    if "estimate" in schema.names():
        label_map = context.estimate_catalog.key_to_label
        fields.append(
            pl.col("estimate")
            .cast(pl.Utf8)
            .map_elements(
                lambda val, mapping=label_map: mapping.get(val, val),
                return_dtype=pl.Utf8,
            )
            .alias("estimate_label")
        )
    else:
        fields.append(null_utf8.alias("estimate_label"))
    return pl.struct(fields).alias("context")


def _infer_value_kind_from_dtype(dtype: pl.DataType | None) -> str:
    if dtype is None or dtype == pl.Null:
        return "float"
    if dtype == pl.Struct:
        return "struct"
    if dtype == pl.Boolean:
        return "bool"
    if dtype == pl.Utf8:
        return "string"
    if hasattr(dtype, "is_numeric") and dtype.is_numeric():
        return "int" if dtype.is_integer() else "float"
    return "string"
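
# Resulting mapping, for reference:
#     None / pl.Null -> "float" (default)    pl.Struct  -> "struct"
#     pl.Boolean     -> "bool"               pl.Utf8    -> "string"
#     integer dtypes -> "int"                other numerics -> "float"
# Anything else (dates, lists, ...) falls through to "string".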


def _pivot_frame(
    df: pl.DataFrame,
    *,
    index_cols: Sequence[str],
    on_cols: Sequence[str],
) -> pl.DataFrame:
    if df.is_empty():
        if index_cols:
            return pl.DataFrame({col: [] for col in index_cols})
        return pl.DataFrame()

    if index_cols:
        return df.pivot(
            index=list(index_cols),
            on=list(on_cols),
            values="value",
            aggregate_function="first",
        )

    with_idx = df.with_row_index("_idx")
    return with_idx.pivot(
        index=["_idx"],
        on=list(on_cols),
        values="value",
        aggregate_function="first",
    ).drop("_idx")
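
# Note: with several `on` columns, polars names each pivoted column by joining
# the on-values into a JSON-like token, e.g. on=["estimate", "label"] and the
# value pair ("model_a", "MAE") produce the column '{"model_a","MAE"}'. The
# ordering helpers above (sort_default, parse_json_tokens) depend on this.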


def _merge_pivot_frames(
    base: pl.DataFrame,
    candidate: pl.DataFrame,
    index_cols: Sequence[str],
) -> pl.DataFrame:
    if base.is_empty():
        return candidate
    if candidate.is_empty():
        return base

    if not index_cols:
        return pl.concat([base, candidate], how="horizontal")

    if candidate.height == 1:
        broadcast_cols = [col for col in candidate.columns if col not in index_cols]
        if not broadcast_cols:
            return base
        row_values = candidate.row(0, named=True)
        return base.with_columns(
            [pl.lit(row_values[col]).alias(col) for col in broadcast_cols]
        )

    join_index_cols = list(index_cols)
    all_null_cols: list[str] = []
    for col in index_cols:
        if col in candidate.columns:
            column = candidate.get_column(col)
            if column.null_count() == candidate.height:
                all_null_cols.append(col)
    if all_null_cols:
        join_index_cols = [col for col in join_index_cols if col not in all_null_cols]
        candidate = candidate.drop(all_null_cols)

    if not join_index_cols:
        value_cols = [col for col in candidate.columns if col not in index_cols]
        if not value_cols:
            return base
        candidate_unique = candidate.select(value_cols).unique()
        if candidate_unique.height == 0:
            return base
        if candidate_unique.height > 1:
            candidate_unique = candidate_unique.head(1)
        row_values = candidate_unique.row(0, named=True)
        return base.with_columns(
            [pl.lit(row_values[col]).alias(col) for col in row_values]
        )

    return base.join(candidate, on=join_index_cols, how="left")
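
# Broadcast example (derived from the height == 1 branch above): merging a
# one-row candidate, e.g. a global-scope pivot {"n_total": 100}, appends its
# value columns as constants to every base row instead of joining on index
# columns.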


def _build_pivot_table(
    long_df: pl.DataFrame,
    *,
    index_cols: Sequence[str],
    default_on: Sequence[str],
    scoped_on: Sequence[tuple[str, Sequence[str]]],
) -> tuple[pl.DataFrame, list[tuple[str, list[str]]]]:
    default_df = long_df.filter(pl.col("scope").is_null())
    pivot = _pivot_frame(default_df, index_cols=index_cols, on_cols=default_on)
    sections: list[tuple[str, list[str]]] = [
        ("default", [col for col in pivot.columns if col not in index_cols])
    ]

    for scope_name, on_cols in scoped_on:
        scoped_df = long_df.filter(pl.col("scope") == scope_name)
        scoped_pivot = _pivot_frame(scoped_df, index_cols=index_cols, on_cols=on_cols)
        if scoped_pivot.is_empty():
            continue
        sections.append(
            (scope_name, [col for col in scoped_pivot.columns if col not in index_cols])
        )
        pivot = _merge_pivot_frames(pivot, scoped_pivot, index_cols)

    return pivot, sections
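
# Flow sketch: rows with a null "scope" feed the default pivot; each
# (scope_name, on_cols) pair then pivots only its matching rows, records the
# new columns under that section name, and merges into the running pivot. A
# "global" metric therefore pivots on ["label"] alone and is broadcast across
# all rows by _merge_pivot_frames above.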


def format_verbose_frame(ard: ARD) -> pl.DataFrame:
    """Render an ARD as a fully expanded DataFrame suitable for inspection."""
    long_df = ard.to_long()
    group_sort_cols = list(ard._group_fields)
    subgroup_struct_cols = list(ard._subgroup_fields)
    sort_cols: list[str] = []

    for col in group_sort_cols:
        if col in long_df.columns:
            sort_cols.append(col)

    for col in ("subgroup_name", "subgroup_value"):
        if col in long_df.columns:
            sort_cols.append(col)

    for col in subgroup_struct_cols:
        if col in long_df.columns and col not in sort_cols:
            sort_cols.append(col)

    for col in ("metric", "estimate"):
        if col in long_df.columns:
            sort_cols.append(col)

    if sort_cols:
        long_df = long_df.sort(sort_cols)

    preferred_order = [
        "id",
        "groups",
        "subgroups",
        "subgroup_name",
        "subgroup_value",
        "estimate",
        "metric",
        "label",
        "value",
        "stat",
        "stat_fmt",
        "context",
        "warning",
        "error",
    ]
    ordered_columns = [col for col in preferred_order if col in long_df.columns]
    remaining_columns = [col for col in long_df.columns if col not in ordered_columns]
    return long_df.select(ordered_columns + remaining_columns)


def format_compact_frame(ard: ARD) -> pl.DataFrame:
    """Render an ARD as a compact DataFrame with struct columns flattened."""
    verbose_df = format_verbose_frame(ard)
    flattened = _flatten_struct_columns(
        verbose_df,
        group_fields=ard._group_fields,
        subgroup_fields=ard._subgroup_fields,
    )
    detail_cols = [
        col
        for col in ("stat", "stat_fmt", "context", "warning", "error")
        if col in flattened.columns
    ]
    if detail_cols:
        flattened = flattened.drop(detail_cols)
    return flattened
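
# Illustrative contrast (hypothetical columns): where the verbose frame keeps
# struct columns such as groups = {"region": "north"} plus stat/context
# details, the compact view flattens groups/subgroups into plain columns and
# drops the detail columns, leaving roughly:
#
#     region | subgroup_name | subgroup_value | estimate | metric | label | value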


def _flatten_struct_columns(
    df: pl.DataFrame,
    *,
    group_fields: Sequence[str],
    subgroup_fields: Sequence[str],
) -> pl.DataFrame:
    """Flatten struct columns for a compact DataFrame view."""
    working = df
    nullable_candidates: list[str] = []

    if "groups" in working.columns and group_fields:
        group_exprs = [
            pl.col("groups").struct.field(field).alias(field) for field in group_fields
        ]
        working = working.with_columns(group_exprs)
        nullable_candidates.extend(group_fields)

    if "subgroups" in working.columns and subgroup_fields:
        subgroup_exprs = [
            pl.col("subgroups").struct.field(field).alias(field)
            for field in subgroup_fields
        ]
        working = working.with_columns(subgroup_exprs)
        nullable_candidates.extend(subgroup_fields)

    drop_cols = [col for col in ("groups", "subgroups") if col in working.columns]
    if drop_cols:
        working = working.drop(drop_cols)

    nullable_candidates.append("id")

    def drop_all_null(df: pl.DataFrame, columns: Sequence[str]) -> pl.DataFrame:
        existing = [col for col in dict.fromkeys(columns) if col in df.columns]
        if not existing:
            return df
        result = df.select(
            [pl.col(col).is_not_null().any().alias(col) for col in existing]
        )
        has_values = result.row(0, named=True)
        to_drop = [col for col, flag in has_values.items() if not flag]
        if to_drop:
            return df.drop(to_drop)
        return df

    working = drop_all_null(working, nullable_candidates)

    return working