Skip to content

Commit 409f091

Browse files
authored
chore: Increase categorical test coverage (#20514)
1 parent 1517599 commit 409f091

15 files changed

+114
-71
lines changed

py-polars/tests/unit/conftest.py

+22
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
if TYPE_CHECKING:
1818
from collections.abc import Generator
19+
from typing import Any
20+
21+
FixtureRequest = Any
1922

2023
load_profile(
2124
profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type]
@@ -229,3 +232,22 @@ def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
229232
yield MemoryUsage()
230233
finally:
231234
tracemalloc.stop()
235+
236+
237+
@pytest.fixture(params=[True, False])
def test_global_and_local(
    request: FixtureRequest,
) -> Generator[Any, Any, Any]:
    """
    Run the requesting test once inside a global string cache and once outside.

    The fixture is parametrized over ``[True, False]``. For the ``True`` param
    the test executes inside a ``pl.StringCache()`` context; for the ``False``
    param this fixture applies no context manager at all.

    Usage: @pytest.mark.usefixtures("test_global_and_local")
    """
    if not request.param:
        # Local variant: hand control to the test without enabling the cache.
        yield
        return

    with pl.StringCache():
        # Pre-fill some global items to ensure physical repr isn't 0..n.
        pl.Series(["eapioejf", "2m4lmv", "3v3v9dlf"], dtype=pl.Categorical)
        yield

py-polars/tests/unit/constructors/test_any_value_fallbacks.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -398,16 +398,16 @@ def test_fallback_with_dtype_strict_failure_decimal_precision() -> None:
398398
PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True)
399399

400400

401+
@pytest.mark.usefixtures("test_global_and_local")
401402
def test_categorical_lit_18874() -> None:
402-
with pl.StringCache():
403-
assert_frame_equal(
404-
pl.DataFrame(
405-
{"a": [1, 2, 3]},
406-
).with_columns(b=pl.lit("foo").cast(pl.Categorical)),
407-
pl.DataFrame(
408-
[
409-
pl.Series("a", [1, 2, 3]),
410-
pl.Series("b", ["foo"] * 3, pl.Categorical),
411-
]
412-
),
413-
)
403+
assert_frame_equal(
404+
pl.DataFrame(
405+
{"a": [1, 2, 3]},
406+
).with_columns(b=pl.lit("foo").cast(pl.Categorical)),
407+
pl.DataFrame(
408+
[
409+
pl.Series("a", [1, 2, 3]),
410+
pl.Series("b", ["foo"] * 3, pl.Categorical),
411+
]
412+
),
413+
)

py-polars/tests/unit/datatypes/test_categorical.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_categorical_full_outer_join() -> None:
7272
assert df["key_right"].cast(pl.String).to_list() == ["bar", "baz", None]
7373

7474

75+
@pytest.mark.usefixtures("test_global_and_local")
7576
def test_read_csv_categorical() -> None:
7677
f = io.BytesIO()
7778
f.write(b"col1,col2,col3,col4,col5,col6\n'foo',2,3,4,5,6\n'bar',8,9,10,11,12")
@@ -80,6 +81,7 @@ def test_read_csv_categorical() -> None:
8081
assert df["col1"].dtype == pl.Categorical
8182

8283

84+
@pytest.mark.usefixtures("test_global_and_local")
8385
def test_cat_to_dummies() -> None:
8486
df = pl.DataFrame({"foo": [1, 2, 3, 4], "bar": ["a", "b", "a", "c"]})
8587
df = df.with_columns(pl.col("bar").cast(pl.Categorical))
@@ -94,7 +96,7 @@ def test_cat_to_dummies() -> None:
9496
}
9597

9698

97-
@StringCache()
99+
@pytest.mark.usefixtures("test_global_and_local")
98100
def test_categorical_is_in_list() -> None:
99101
# this requires type coercion to cast.
100102
# we should not cast within the function as this would be expensive within a
@@ -110,7 +112,7 @@ def test_categorical_is_in_list() -> None:
110112
}
111113

112114

113-
@StringCache()
115+
@pytest.mark.usefixtures("test_global_and_local")
114116
def test_unset_sorted_on_append() -> None:
115117
df1 = pl.DataFrame(
116118
[
@@ -137,6 +139,7 @@ def test_unset_sorted_on_append() -> None:
137139
(pl.Series.eq_missing, pl.Series([True, True, True, False, False, False])),
138140
],
139141
)
142+
@pytest.mark.usefixtures("test_global_and_local")
140143
def test_categorical_equality(
141144
op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series
142145
) -> None:
@@ -272,6 +275,7 @@ def test_categorical_global_ordering_broadcast_lhs(
272275
(operator.gt, pl.Series([False, False, False, True, False, False])),
273276
],
274277
)
278+
@pytest.mark.usefixtures("test_global_and_local")
275279
def test_categorical_ordering(
276280
op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series
277281
) -> None:
@@ -289,6 +293,7 @@ def test_categorical_ordering(
289293
(operator.gt, pl.Series([None, False, False, False, False, False])),
290294
],
291295
)
296+
@pytest.mark.usefixtures("test_global_and_local")
292297
def test_compare_categorical(
293298
op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series
294299
) -> None:
@@ -311,6 +316,7 @@ def test_compare_categorical(
311316
(pl.Series.ne_missing, pl.Series([True, True, False, True, False, True])),
312317
],
313318
)
319+
@pytest.mark.usefixtures("test_global_and_local")
314320
def test_compare_categorical_single(
315321
op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series
316322
) -> None:
@@ -400,6 +406,7 @@ def test_categorical_error_on_local_cmp() -> None:
400406
df_cat.filter(pl.col("a_cat") == pl.col("b_cat"))
401407

402408

409+
@pytest.mark.usefixtures("test_global_and_local")
403410
def test_cast_null_to_categorical() -> None:
404411
assert pl.DataFrame().with_columns(
405412
pl.lit(None).cast(pl.Categorical).alias("nullable_enum")
@@ -454,6 +461,7 @@ def create_lazy(data: dict) -> pl.LazyFrame: # type: ignore[type-arg]
454461
assert pl.using_string_cache() is False
455462

456463

464+
@pytest.mark.usefixtures("test_global_and_local")
457465
def test_categorical_in_struct_nulls() -> None:
458466
s = pl.Series(
459467
"job", ["doctor", "waiter", None, None, None, "doctor"], pl.Categorical
@@ -466,6 +474,7 @@ def test_categorical_in_struct_nulls() -> None:
466474
assert s[2] == {"job": "waiter", "count": 1}
467475

468476

477+
@pytest.mark.usefixtures("test_global_and_local")
469478
def test_cast_inner_categorical() -> None:
470479
dtype = pl.List(pl.Categorical)
471480
out = pl.Series("foo", [["a"], ["a", "b"]]).cast(dtype)
@@ -501,6 +510,7 @@ def test_stringcache() -> None:
501510
(pl.Categorical("lexical"), ["bar", "baz", "foo"]),
502511
],
503512
)
513+
@pytest.mark.usefixtures("test_global_and_local")
504514
def test_categorical_sort_order_by_parameter(
505515
dtype: PolarsDataType, outcome: list[str]
506516
) -> None:
@@ -557,12 +567,14 @@ def test_err_on_categorical_asof_join_by_arg() -> None:
557567
df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat")
558568

559569

570+
@pytest.mark.usefixtures("test_global_and_local")
560571
def test_categorical_list_get_item() -> None:
561572
out = pl.Series([["a"]]).cast(pl.List(pl.Categorical)).item()
562573
assert isinstance(out, pl.Series)
563574
assert out.dtype == pl.Categorical
564575

565576

577+
@pytest.mark.usefixtures("test_global_and_local")
566578
def test_nested_categorical_aggregation_7848() -> None:
567579
# a double categorical aggregation
568580
assert pl.DataFrame(
@@ -580,6 +592,7 @@ def test_nested_categorical_aggregation_7848() -> None:
580592
}
581593

582594

595+
@pytest.mark.usefixtures("test_global_and_local")
583596
def test_nested_categorical_cast() -> None:
584597
values = [["x"], ["y"], ["x"]]
585598
dtype = pl.List(pl.Categorical)
@@ -588,6 +601,7 @@ def test_nested_categorical_cast() -> None:
588601
assert s.to_list() == values
589602

590603

604+
@pytest.mark.usefixtures("test_global_and_local")
591605
def test_struct_categorical_nesting() -> None:
592606
# this triggers a lot of materialization
593607
df = pl.DataFrame(
@@ -610,7 +624,7 @@ def test_categorical_fill_null_existing_category() -> None:
610624
assert result.to_dict(as_series=False) == expected
611625

612626

613-
@StringCache()
627+
@pytest.mark.usefixtures("test_global_and_local")
614628
def test_categorical_fill_null_stringcache() -> None:
615629
df = pl.LazyFrame(
616630
{"index": [1, 2, 3], "cat": ["a", "b", None]},
@@ -622,6 +636,7 @@ def test_categorical_fill_null_stringcache() -> None:
622636
assert a.dtypes == [pl.Categorical]
623637

624638

639+
@pytest.mark.usefixtures("test_global_and_local")
625640
def test_fast_unique_flag_from_arrow() -> None:
626641
df = pl.DataFrame(
627642
{
@@ -633,6 +648,7 @@ def test_fast_unique_flag_from_arrow() -> None:
633648
assert pl.from_arrow(filtered).select(pl.col("colB").n_unique()).item() == 4 # type: ignore[union-attr]
634649

635650

651+
@pytest.mark.usefixtures("test_global_and_local")
636652
def test_construct_with_null() -> None:
637653
# Example from https://github.com/pola-rs/polars/issues/7188
638654
df = pl.from_dicts([{"A": None}, {"A": "foo"}], schema={"A": pl.Categorical})
@@ -663,6 +679,7 @@ def test_list_builder_different_categorical_rev_maps() -> None:
663679
}
664680

665681

682+
@pytest.mark.usefixtures("test_global_and_local")
666683
def test_categorical_collect_11408() -> None:
667684
df = pl.DataFrame(
668685
data={"groups": ["a", "b", "c"], "cats": ["a", "b", "c"], "amount": [1, 2, 3]},
@@ -677,6 +694,7 @@ def test_categorical_collect_11408() -> None:
677694
}
678695

679696

697+
@pytest.mark.usefixtures("test_global_and_local")
680698
def test_categorical_nested_cast_unchecked() -> None:
681699
s = pl.Series("cat", [["cat"]]).cast(pl.List(pl.Categorical))
682700
assert pl.Series([s]).to_list() == [[["cat"]]]
@@ -751,6 +769,7 @@ def test_categorical_vstack_with_local_different_rev_map() -> None:
751769
assert df3.get_column("a").cast(pl.UInt32).to_list() == [0, 1, 2, 3, 4, 5]
752770

753771

772+
@pytest.mark.usefixtures("test_global_and_local")
754773
def test_shift_over_13041() -> None:
755774
df = pl.DataFrame(
756775
{
@@ -768,6 +787,7 @@ def test_shift_over_13041() -> None:
768787

769788
@pytest.mark.parametrize("context", [pl.StringCache(), contextlib.nullcontext()])
770789
@pytest.mark.parametrize("ordering", ["physical", "lexical"])
790+
@pytest.mark.usefixtures("test_global_and_local")
771791
def test_sort_categorical_retain_none(
772792
context: contextlib.AbstractContextManager, # type: ignore[type-arg]
773793
ordering: Literal["physical", "lexical"],
@@ -799,6 +819,7 @@ def test_sort_categorical_retain_none(
799819
]
800820

801821

822+
@pytest.mark.usefixtures("test_global_and_local")
802823
def test_cast_from_cat_to_numeric() -> None:
803824
cat_series = pl.Series(
804825
"cat_series",
@@ -811,12 +832,14 @@ def test_cast_from_cat_to_numeric() -> None:
811832
assert s.cast(pl.UInt8).sum() == 6
812833

813834

835+
@pytest.mark.usefixtures("test_global_and_local")
814836
def test_cat_preserve_lexical_ordering_on_clear() -> None:
815837
s = pl.Series("a", ["a", "b"], dtype=pl.Categorical(ordering="lexical"))
816838
s2 = s.clear()
817839
assert s.dtype == s2.dtype
818840

819841

842+
@pytest.mark.usefixtures("test_global_and_local")
820843
def test_cat_preserve_lexical_ordering_on_concat() -> None:
821844
dtype = pl.Categorical(ordering="lexical")
822845

@@ -827,6 +850,7 @@ def test_cat_preserve_lexical_ordering_on_concat() -> None:
827850

828851
# TODO: Bug see: https://github.com/pola-rs/polars/issues/20440
829852
@pytest.mark.may_fail_auto_streaming
853+
@pytest.mark.usefixtures("test_global_and_local")
830854
def test_cat_append_lexical_sorted_flag() -> None:
831855
df = pl.DataFrame({"x": [0, 1, 1], "y": ["B", "B", "A"]}).with_columns(
832856
pl.col("y").cast(pl.Categorical(ordering="lexical"))
@@ -845,6 +869,7 @@ def test_cat_append_lexical_sorted_flag() -> None:
845869
assert not (s1.is_sorted())
846870

847871

872+
@pytest.mark.usefixtures("test_global_and_local")
848873
def test_get_cat_categories_multiple_chunks() -> None:
849874
df = pl.DataFrame(
850875
[
@@ -877,6 +902,7 @@ def test_nested_categorical_concat(
877902
pl.concat([a, b])
878903

879904

905+
@pytest.mark.usefixtures("test_global_and_local")
880906
def test_perfect_group_by_19452() -> None:
881907
n = 40
882908
df2 = pl.DataFrame(
@@ -889,6 +915,7 @@ def test_perfect_group_by_19452() -> None:
889915
assert df2.with_columns(a=(pl.col("b")).over(pl.col("a")))["a"].is_sorted()
890916

891917

918+
@pytest.mark.usefixtures("test_global_and_local")
892919
def test_perfect_group_by_19950() -> None:
893920
dtype = pl.Enum(categories=["a", "b", "c"])
894921

@@ -900,14 +927,14 @@ def test_perfect_group_by_19950() -> None:
900927
}
901928

902929

903-
@StringCache()
930+
@pytest.mark.usefixtures("test_global_and_local")
904931
def test_categorical_unique() -> None:
905932
s = pl.Series(["a", "b", None], dtype=pl.Categorical)
906933
assert s.n_unique() == 3
907934
assert s.unique().sort().to_list() == [None, "a", "b"]
908935

909936

910-
@StringCache()
937+
@pytest.mark.usefixtures("test_global_and_local")
911938
def test_categorical_unique_20539() -> None:
912939
df = pl.DataFrame({"number": [1, 1, 2, 2, 3], "letter": ["a", "b", "b", "c", "c"]})
913940

@@ -927,13 +954,10 @@ def test_categorical_unique_20539() -> None:
927954
}
928955

929956

930-
@StringCache()
931957
@pytest.mark.may_fail_auto_streaming
958+
@pytest.mark.usefixtures("test_global_and_local")
932959
def test_categorical_prefill() -> None:
933960
# https://github.com/pola-rs/polars/pull/20547#issuecomment-2569473443
934-
# prefill cache
935-
pl.Series(["aaa", "bbb", "ccc"], dtype=pl.Categorical) # pre-fill cache
936-
937961
# test_compare_categorical_single
938962
assert (pl.Series(["a"], dtype=pl.Categorical) < "a").to_list() == [False]
939963

py-polars/tests/unit/datatypes/test_list.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def test_dtype() -> None:
6565
]
6666

6767

68+
@pytest.mark.usefixtures("test_global_and_local")
6869
def test_categorical() -> None:
6970
# https://github.com/pola-rs/polars/issues/2038
7071
df = pl.DataFrame(

py-polars/tests/unit/functions/as_datatype/test_concat_list.py

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def test_list_concat_supertype() -> None:
9191
].to_list() == [[1, 10000], [2, 20000]]
9292

9393

94+
@pytest.mark.usefixtures("test_global_and_local")
9495
def test_categorical_list_concat_4762() -> None:
9596
df = pl.DataFrame({"x": "a"})
9697
expected = {"x": [["a", "a"]]}

py-polars/tests/unit/interchange/test_from_dataframe.py

+1
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def test_string_column_to_series_no_offsets() -> None:
334334
_string_column_to_series(col, allow_copy=True)
335335

336336

337+
@pytest.mark.usefixtures("test_global_and_local")
337338
def test_categorical_column_to_series_non_dictionary() -> None:
338339
s = pl.Series(["a", "b", None, "a"], dtype=pl.Categorical)
339340

py-polars/tests/unit/io/test_delta.py

+1
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def test_unsupported_dtypes(tmp_path: Path) -> None:
474474
reason="upstream bug in delta-rs causing categorical to be written as categorical in parquet"
475475
)
476476
@pytest.mark.write_disk
477+
@pytest.mark.usefixtures("test_global_and_local")
477478
def test_categorical_becomes_string(tmp_path: Path) -> None:
478479
df = pl.DataFrame({"a": ["A", "B", "A"]}, schema={"a": pl.Categorical})
479480
df.write_delta(tmp_path)

py-polars/tests/unit/io/test_lazy_parquet.py

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_row_index_len_16543(foods_parquet_path: Path) -> None:
6666

6767

6868
@pytest.mark.write_disk
69+
@pytest.mark.usefixtures("test_global_and_local")
6970
def test_categorical_parquet_statistics(tmp_path: Path) -> None:
7071
tmp_path.mkdir(exist_ok=True)
7172

@@ -281,6 +282,7 @@ def test_parquet_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> Non
281282

282283

283284
@pytest.mark.write_disk
285+
@pytest.mark.usefixtures("test_global_and_local")
284286
def test_categorical(tmp_path: Path) -> None:
285287
tmp_path.mkdir(exist_ok=True)
286288

py-polars/tests/unit/io/test_other.py

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def test_copy() -> None:
8484
assert_series_equal(copy.deepcopy(a), a)
8585

8686

87+
@pytest.mark.usefixtures("test_global_and_local")
8788
def test_categorical_round_trip() -> None:
8889
df = pl.DataFrame({"ints": [1, 2, 3], "cat": ["a", "b", "c"]})
8990
df = df.with_columns(pl.col("cat").cast(pl.Categorical))

py-polars/tests/unit/io/test_parquet.py

+1
Original file line numberDiff line numberDiff line change
@@ -2433,6 +2433,7 @@ def test_dict_masked(
24332433
)
24342434

24352435

2436+
@pytest.mark.usefixtures("test_global_and_local")
24362437
def test_categorical_sliced_20017() -> None:
24372438
f = io.BytesIO()
24382439
df = (

0 commit comments

Comments
 (0)