From 2a19d22732348d43ade9275ba450b29fe7d89f0f Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 2 Jan 2024 16:06:20 +0400 Subject: [PATCH] feat(python): validate Enum categories (#13356) --- py-polars/polars/datatypes/classes.py | 19 ++++++++++++++-- py-polars/tests/unit/datatypes/test_enum.py | 24 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 39647be83202..4332602facae 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -538,16 +538,31 @@ class Enum(DataType): categories: list[str] - def __init__(self, categories: list[str]): + def __init__(self, categories: Iterable[str]): """ A fixed set categorical encoding of a set of strings. Parameters ---------- categories - Categories in the dataset. + Valid categories in the dataset. """ + if not isinstance(categories, list): + categories = list(categories) + + seen: set[str] = set() + for cat in categories: + if cat in seen: + raise ValueError( + f"Enum categories must be unique; found duplicate {cat!r}" + ) + if not isinstance(cat, str): + raise TypeError( + f"Enum categories must be strings; found {cat!r} ({type(cat).__name__})" + ) + seen.add(cat) + self.categories = categories def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 213e6d5a8271..cf65b323e89d 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -1,6 +1,7 @@ import operator +from datetime import date from textwrap import dedent -from typing import Callable +from typing import Any, Callable import pytest @@ -15,6 +16,13 @@ def test_enum_creation() -> None: assert s.len() == 3 assert s.dtype == pl.Enum(categories=["a", "b"]) + # from iterables + e = pl.Enum(f"x{i}" for i in range(5)) + assert e.categories == ["x0", "x1", "x2", "x3", "x4"] + + e = pl.Enum("abcde") + assert e.categories == ["a", "b", "c", "d", "e"] + def test_enum_non_existent() -> None: with pytest.raises( @@ -303,3 +311,17 @@ def test_different_enum_comparison_order() -> None: match="can only compare categoricals of the same type", ): df_enum.filter(op(pl.col("a_cat"), pl.col("b_cat"))) + + +@pytest.mark.parametrize( + "categories", + [[None], [date.today()], [-10, 10], ["x", "y", None]], +) +def test_valid_enum_category_types(categories: Any) -> None: + with pytest.raises(TypeError, match="Enum categories"): + pl.Enum(categories) + + +def test_enum_categories_unique() -> None: + with pytest.raises(ValueError, match="must be unique; found duplicate 'a'"): + pl.Enum(["a", "a", "b", "b", "b", "c"])