Skip to content

Commit

Permalink
Release 4.0.2 (spotify#159)
Browse files Browse the repository at this point in the history
* Fixed categorical_order_by used with array_like (spotify#157)

* Fix categorical_order_by check for scatter plot

* Fix categorical_order_by check for _construct_source

* Refactor category sorting in _construct_source

* Add tests for categorical_order_by

* Correct scatter test (spotify#158)

* Update version in init

* Update HISTORY.rst

---------

Co-authored-by: Quoc Duong Bui <[email protected]>
  • Loading branch information
iampelle and vanHekthor authored Mar 30, 2023
1 parent ca828ab commit 364361f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 36 deletions.
9 changes: 9 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
History
=======

4.0.2 (2023-03-30)
------------------

* Fix categorical_order_by check for scatter plot
* Fix categorical_order_by check for _construct_source
* Refactor category sorting in _construct_source
* Add tests for categorical_order_by
* Fix scatter plot tests that used line plots

4.0.1 (2023-03-24)
------------------

Expand Down
2 changes: 1 addition & 1 deletion chartify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

__author__ = """Chris Halpert"""
__email__ = "[email protected]"
__version__ = "4.0.1"
__version__ = "4.0.2"

_IPYTHON_INSTANCE = False

Expand Down
79 changes: 48 additions & 31 deletions chartify/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,49 @@ def _get_bar_width(factors):
else:
return 0.9

@staticmethod
def _sort_categories_by_value(source, categorical_columns, categorical_order_ascending):
# Recursively sort values within each level of the index.
row_totals = source.sum(axis=1, numeric_only=True)
row_totals.name = "sum"
old_index = row_totals.index
row_totals = row_totals.reset_index()
row_totals.columns = ["_%s" % col for col in row_totals.columns]
row_totals.index = old_index

hierarchical_sort_cols = categorical_columns[:]
for i, _ in enumerate(hierarchical_sort_cols):
row_totals["level_%s" % i] = row_totals.groupby(hierarchical_sort_cols[: i + 1])["_sum"].transform(
"sum"
)
row_totals = row_totals.sort_values(
by=["level_%s" % i for i, _ in enumerate(hierarchical_sort_cols)],
ascending=categorical_order_ascending,
)
return source.reindex(row_totals.index)

@staticmethod
def _sort_categories(
source,
categorical_columns,
categorical_order_by,
categorical_order_ascending
):

is_string = isinstance(categorical_order_by, str)
order_length = getattr(categorical_order_by, "__len__", None)
# Sort the categories
if is_string and categorical_order_by == "values":
return PlotMixedTypeXY._sort_categories_by_value(
source, categorical_columns, categorical_order_ascending)
elif is_string and categorical_order_by == "labels":
return source.sort_index(axis=0, ascending=categorical_order_ascending)
# Manual sort
elif not is_string and order_length is not None:
return source.reindex(categorical_order_by, axis="index")

raise ValueError("""Must be 'values', 'labels', or a list of values.""")

def _construct_source(
self,
data_frame,
Expand Down Expand Up @@ -1014,34 +1057,7 @@ def _construct_source(
if normalize:
source = source.div(source.sum(axis=1), axis=0)

order_length = getattr(categorical_order_by, "__len__", None)
# Sort the categories
if categorical_order_by == "values":
# Recursively sort values within each level of the index.
row_totals = source.sum(axis=1, numeric_only=True)
row_totals.name = "sum"
old_index = row_totals.index
row_totals = row_totals.reset_index()
row_totals.columns = ["_%s" % col for col in row_totals.columns]
row_totals.index = old_index

heirarchical_sort_cols = categorical_columns[:]
for i, _ in enumerate(heirarchical_sort_cols):
row_totals["level_%s" % i] = row_totals.groupby(heirarchical_sort_cols[: i + 1])["_sum"].transform(
"sum"
)
row_totals = row_totals.sort_values(
by=["level_%s" % i for i, _ in enumerate(heirarchical_sort_cols)],
ascending=categorical_order_ascending,
)
source = source.reindex(row_totals.index)
elif categorical_order_by == "labels":
source = source.sort_index(axis=0, ascending=categorical_order_ascending)
# Manual sort
elif order_length is not None:
source = source.reindex(categorical_order_by, axis="index")
else:
raise ValueError("""Must be 'values', 'labels', or a list of values.""")
source = self._sort_categories(source, categorical_columns, categorical_order_by, categorical_order_ascending)

# Cast all categorical columns to strings
# Plotting functions will break with non-str types.
Expand Down Expand Up @@ -2003,13 +2019,14 @@ def scatter(

axis_factors = data_frame.groupby(categorical_columns).size()

is_string = isinstance(categorical_order_by, str)
order_length = getattr(categorical_order_by, "__len__", None)
if categorical_order_by == "labels":
if is_string and categorical_order_by == "labels":
axis_factors = axis_factors.sort_index(ascending=categorical_order_ascending).index
elif categorical_order_by == "count":
elif is_string and categorical_order_by == "count":
axis_factors = axis_factors.sort_values(ascending=categorical_order_ascending).index
# User-specified order.
elif order_length is not None:
elif not is_string and order_length is not None:
axis_factors = categorical_order_by
else:
raise ValueError("""Must be 'count', 'labels', or a list of values.""")
Expand Down
73 changes: 69 additions & 4 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def setup_method(self):
})

def test_single_numeric_scatter(self):
"""Single line test"""
"""Single scatter test"""
single_scatter = self.data[self.data['category1'] == 'a']
ch = chartify.Chart()
ch.plot.line(single_scatter, x_column='number1', y_column='number2')
ch.plot.scatter(single_scatter, x_column='number1', y_column='number2')
assert (np.array_equal(chart_data(ch, '')['number1'], [1., 2., 3.]))
assert (np.array_equal(chart_data(ch, '')['number2'], [5, 10, 0]))

def test_multi_numeric_scatter(self):
"""Single line test"""
"""Multi scatter test"""
ch = chartify.Chart()
ch.plot.scatter(
self.data,
Expand All @@ -151,7 +151,7 @@ def test_multi_numeric_scatter(self):
assert (np.array_equal(chart_data(ch, 'b')['number2'], [4, -3, -10]))

def test_single_datetime_scatter(self):
"""Single line test"""
"""Single datetime scatter test"""
data = pd.DataFrame({
'number': [1, 10, -10, 0],
'datetimes':
Expand Down Expand Up @@ -794,6 +794,71 @@ def test_grouped_histogram(self):
assert (np.array_equal(chart_data(ch, 'b')['min_edge'], [2., 6.]))


class TestCategoricalOrderBy:
def _assert_order_by_array_like(self, chart):
assert (np.array_equal(chart.figure.x_range.factors, ['b', 'd', 'a', 'c']))
# check bar data
assert (np.array_equal(chart_data(chart, '')['factors'], ['b', 'd', 'a', 'c']))
assert (np.array_equal(chart_data(chart, '')['number1'], [3, 1, 4, 2]))
# check scatter data
assert (np.array_equal(chart_data(chart, 'number1')['factors'], ['a', 'b', 'c', 'd']))
assert (np.array_equal(chart_data(chart, 'number1')['number1'], [4, 3, 2, 1]))

def setup_method(self):
self.data1 = pd.DataFrame({
'category1': ['a', 'b', 'c', 'd'],
'number1': [4, 3, 2, 1],
})

self.data2 = pd.DataFrame({
'category2': ['b', 'a', 'b', 'b', 'a', 'c'],
'number2': [1, 2, 3, 4, 5, 6]
})

def test_order_by_labels(self):
ch = chartify.Chart(x_axis_type='categorical')

ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by='labels')
assert (np.array_equal(ch.figure.x_range.factors, ['d', 'c', 'b', 'a']))
assert (np.array_equal(chart_data(ch, '')['factors'], ['d', 'c', 'b', 'a']))
assert (np.array_equal(chart_data(ch, '')['number1'], [1, 2, 3, 4]))

ch.plot.scatter(self.data1, ['category1'], 'number1', categorical_order_by='labels')
assert (np.array_equal(ch.figure.x_range.factors, ['d', 'c', 'b', 'a']))
assert (np.array_equal(chart_data(ch, 'number1')['factors'], ['a', 'b', 'c', 'd']))
assert (np.array_equal(chart_data(ch, 'number1')['number1'], [4, 3, 2, 1]))

def test_order_by_values(self):
ch = chartify.Chart(x_axis_type='categorical')
ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by='values')
assert (np.array_equal(chart_data(ch, '')['factors'], ['a', 'b', 'c', 'd']))
assert (np.array_equal(chart_data(ch, '')['number1'], [4, 3, 2, 1]))

def test_order_by_count(self):
ch = chartify.Chart(x_axis_type='categorical')
ch.plot.scatter(self.data2, ['category2'], 'number2', categorical_order_by='count')

assert (np.array_equal(ch.figure.x_range.factors, ['b', 'a', 'c']))
assert (np.array_equal(chart_data(ch, 'number2')['factors'], ['b', 'a', 'b', 'b', 'a', 'c']))
assert (np.array_equal(chart_data(ch, 'number2')['number2'], [1, 2, 3, 4, 5, 6]))

@pytest.mark.parametrize(
'array_like', [['b', 'd', 'a', 'c'], np.array(['b', 'd', 'a', 'c']), pd.Series(['b', 'd', 'a', 'c'])])
def test_order_by_array_like(self, array_like):
ch = chartify.Chart(x_axis_type='categorical')
ch.plot.scatter(self.data1, ['category1'], 'number1', categorical_order_by=array_like)
ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by=array_like)

self._assert_order_by_array_like(ch)

@pytest.mark.parametrize('plot_method,categorical_order_by', [('bar', 'count'), ('scatter', 'values')])
def test_error(self, plot_method, categorical_order_by):
ch = chartify.Chart(x_axis_type='categorical', y_axis_type='linear')
with pytest.raises(ValueError):
plot_method = getattr(ch.plot, plot_method)
plot_method(self.data1, ['category1'], 'number1', categorical_order_by=categorical_order_by)


def test_categorical_axis_type_casting():
"""Categorical axis plotting breaks for non-str types.
Test that type casting is performed correctly"""
Expand Down

0 comments on commit 364361f

Please sign in to comment.