Skip to content

Commit

Permalink
Fix column naming for DataFrames with MultiIndex columns (scikit-lear…
Browse files Browse the repository at this point in the history
  • Loading branch information
kristofve authored and dukebody committed Aug 15, 2018
1 parent 757cc33 commit 7fdc39a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
.tox/
build/
dist/
.cache/
.cache/
.idea/
.pytest_cache/
2 changes: 1 addition & 1 deletion sklearn_pandas/dataframe_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_names(self, columns, transformer, x, alias=None):
if alias is not None:
name = alias
elif isinstance(columns, list):
name = '_'.join(columns)
name = '_'.join(map(str, columns))
else:
name = columns
num_cols = x.shape[1] if len(x.shape) > 1 else 1
Expand Down
50 changes: 50 additions & 0 deletions tests/test_dataframe_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,29 @@ def complex_dataframe():
'feat2': [1, 2, 3, 2, 3, 4]})


@pytest.fixture
def multiindex_dataframe():
"""Example MultiIndex DataFrame, taken from pandas documentation
"""
iterables = [['bar', 'baz', 'foo', 'qux'], ['one', 'two']]
index = pd.MultiIndex.from_product(iterables, names=['first', 'second'])
df = pd.DataFrame(np.random.randn(10, 8), columns=index)
return df


@pytest.fixture
def multiindex_dataframe_incomplete(multiindex_dataframe):
"""Example MultiIndex DataFrame with missing entries
"""
df = multiindex_dataframe
mask_array = np.zeros(df.size)
mask_array[:20] = 1
np.random.shuffle(mask_array)
mask = mask_array.reshape(df.shape).astype(bool)
df.mask(mask, inplace=True)
return df


def test_transformed_names_simple(simple_dataframe):
"""
Get transformed names of features in `transformed_names` attribute
Expand Down Expand Up @@ -234,6 +257,33 @@ def test_complex_df(complex_dataframe):
assert len(transformed[c]) == len(df[c])


def test_numeric_column_names(complex_dataframe):
"""
Get a dataframe from a complex mapped dataframe with numeric column names
"""
df = complex_dataframe
df.columns = [0, 1, 2]
mapper = DataFrameMapper(
[(0, None), (1, None), (2, None)], df_out=True)
transformed = mapper.fit_transform(df)
assert len(transformed) == len(complex_dataframe)
for c in df.columns:
assert len(transformed[c]) == len(df[c])


def test_multiindex_df(multiindex_dataframe_incomplete):
"""
Get a dataframe from a multiindex dataframe with missing data
"""
df = multiindex_dataframe_incomplete
mapper = DataFrameMapper([([c], Imputer()) for c in df.columns],
df_out=True)
transformed = mapper.fit_transform(df)
assert len(transformed) == len(multiindex_dataframe_incomplete)
for c in df.columns:
assert len(transformed[str(c)]) == len(df[c])


def test_binarizer_df():
"""
Check level names from LabelBinarizer
Expand Down

0 comments on commit 7fdc39a

Please sign in to comment.