Skip to content

Commit

Permalink
feat(python): validate Enum categories (pola-rs#13356)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jan 2, 2024
1 parent 4d95c18 commit 2a19d22
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
19 changes: 17 additions & 2 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,16 +538,31 @@ class Enum(DataType):

categories: list[str]

def __init__(self, categories: list[str]):
def __init__(self, categories: Iterable[str]):
"""
A fixed set categorical encoding of a set of strings.
Parameters
----------
categories
Categories in the dataset.
Valid categories in the dataset.
"""
if not isinstance(categories, list):
categories = list(categories)

seen: set[str] = set()
for cat in categories:
if cat in seen:
raise ValueError(
f"Enum categories must be unique; found duplicate {cat!r}"
)
if not isinstance(cat, str):
raise TypeError(
f"Enum categories must be strings; found {cat!r} ({type(cat).__name__})"
)
seen.add(cat)

self.categories = categories

def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
Expand Down
24 changes: 23 additions & 1 deletion py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import operator
from datetime import date
from textwrap import dedent
from typing import Callable
from typing import Any, Callable

import pytest

Expand All @@ -15,6 +16,13 @@ def test_enum_creation() -> None:
assert s.len() == 3
assert s.dtype == pl.Enum(categories=["a", "b"])

# from iterables
e = pl.Enum(f"x{i}" for i in range(5))
assert e.categories == ["x0", "x1", "x2", "x3", "x4"]

e = pl.Enum("abcde")
assert e.categories == ["a", "b", "c", "d", "e"]


def test_enum_non_existent() -> None:
with pytest.raises(
Expand Down Expand Up @@ -303,3 +311,17 @@ def test_different_enum_comparison_order() -> None:
match="can only compare categoricals of the same type",
):
df_enum.filter(op(pl.col("a_cat"), pl.col("b_cat")))


@pytest.mark.parametrize(
"categories",
[[None], [date.today()], [-10, 10], ["x", "y", None]],
)
def test_valid_enum_category_types(categories: Any) -> None:
with pytest.raises(TypeError, match="Enum categories"):
pl.Enum(categories)


def test_enum_categories_unique() -> None:
with pytest.raises(ValueError, match="must be unique; found duplicate 'a'"):
pl.Enum(["a", "a", "b", "b", "b", "c"])

0 comments on commit 2a19d22

Please sign in to comment.