Skip to content

Commit

Permalink
Support summary containing multiple where with the same selector (hol…
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 authored Aug 16, 2023
1 parent ea163e9 commit 765e82e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
21 changes: 17 additions & 4 deletions datashader/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def make_append(bases, cols, calls, glyph, antialias):
subscript = None
prev_local_cuda_mutex = False
categorical_args = {} # Reuse categorical arguments if used in more than one reduction
where_selectors = {} # Reuse where.selector if used more than once in a summary reduction

def get_cuda_mutex_call(lock: bool) -> str:
func = "cuda_mutex_lock" if lock else "cuda_mutex_unlock"
Expand Down Expand Up @@ -379,17 +380,29 @@ def get_cuda_mutex_call(lock: bool) -> str:
# Avoid unnecessary mutex unlock and lock cycle
body.pop()

where_reduction = len(bases) == 1 and bases[0].is_where()
if where_reduction:
update_index_arg_name = next(names)
is_where = len(bases) == 1 and bases[0].is_where()
if is_where:
where_reduction = bases[0]
if isinstance(where_reduction, by):
where_reduction = where_reduction.reduction

selector_hash = hash(where_reduction.selector)
update_index_arg_name = where_selectors.get(selector_hash, None)
new_selector = update_index_arg_name is None
if new_selector:
update_index_arg_name = next(names)
where_selectors[selector_hash] = update_index_arg_name
args.append(update_index_arg_name)

# where reduction needs access to the return of the contained
# reduction, which is the preceding one here.
prev_body = body.pop()
if local_cuda_mutex and not prev_local_cuda_mutex:
body.append(get_cuda_mutex_call(True))
body.append(f'{update_index_arg_name} = {prev_body}')
if new_selector:
body.append(f'{update_index_arg_name} = {prev_body}')
else:
body.append(prev_body)

# If nan_check_column is defined then need to check if value of
# correct row in that column is NaN and if so do nothing. This
Expand Down
9 changes: 9 additions & 0 deletions datashader/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,15 @@ def test_summary_where_n(df):
assert_eq_ndarray(agg['min_n'].data, sol_min_n_rowindex)
assert_eq_ndarray(agg['max_n'].data, sol_max_n_reverse)

# Issue #1270: Support summary reduction containing multiple where
# reductions that use the same selector.
agg = c.points(df, 'x', 'y', ds.summary(
max1=ds.where(ds.max_n('plusminus', 5)),
max2=ds.where(ds.max_n('plusminus', 5), 'reverse'),
))
assert_eq_ndarray(agg['max1'].data, sol_max_n_rowindex)
assert_eq_ndarray(agg['max2'].data, sol_max_n_reverse)


@pytest.mark.parametrize('df', dfs)
def test_summary_different_n(df):
Expand Down

0 comments on commit 765e82e

Please sign in to comment.