Skip to content

Commit

Permalink
mirror nx for cumulative and window functions (elixir-explorer#200)
Browse files Browse the repository at this point in the history
* cum_* -> cumulative_*

* rolling_* -> window_*

* max reverse? an option in cumulative series functions

* update changelog

* don't use question mark for option

Co-authored-by: José Valim <[email protected]>

Co-authored-by: José Valim <[email protected]>
  • Loading branch information
cigrainger and josevalim authored May 2, 2022
1 parent 32ecd2d commit 0f4febc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 63 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Now uses `polars`'s "performant" feature
- `Explorer.default_backend/0` is now `Explorer.Backend.get/0`
- `Explorer.default_backend/1` is now `Explorer.Backend.put/1`
- `Series.cum_*` functions are now `Series.cumulative_*` to mirror `Nx`
- `Series.rolling_*` functions are now `Series.window_*` to mirror `Nx`
- `reverse?` is now an option instead of an argument in `Series.cumulative_*` functions

## [v0.1.1] - 2022-04-27

Expand Down
16 changes: 8 additions & 8 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ defmodule Explorer.Backend.Series do

# Cumulative

@callback cum_max(s, reverse? :: boolean()) :: s
@callback cum_min(s, reverse? :: boolean()) :: s
@callback cum_sum(s, reverse? :: boolean()) :: s
@callback cumulative_max(s, reverse? :: boolean()) :: s
@callback cumulative_min(s, reverse? :: boolean()) :: s
@callback cumulative_sum(s, reverse? :: boolean()) :: s

# Local minima/maxima

Expand Down Expand Up @@ -93,15 +93,15 @@ defmodule Explorer.Backend.Series do

# Rolling

@type rolling_option ::
@type window_option ::
{:weights, [float()] | nil}
| {:min_periods, integer() | nil}
| {:center, boolean()}

@callback rolling_sum(s, window_size :: integer(), [rolling_option()]) :: s
@callback rolling_min(s, window_size :: integer(), [rolling_option()]) :: s
@callback rolling_max(s, window_size :: integer(), [rolling_option()]) :: s
@callback rolling_mean(s, window_size :: integer(), [rolling_option()]) :: s
@callback window_sum(s, window_size :: integer(), [window_option()]) :: s
@callback window_min(s, window_size :: integer(), [window_option()]) :: s
@callback window_max(s, window_size :: integer(), [window_option()]) :: s
@callback window_mean(s, window_size :: integer(), [window_option()]) :: s

# Nulls

Expand Down
19 changes: 11 additions & 8 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,16 @@ defmodule Explorer.PolarsBackend.Series do
# Cumulative

@impl true
def cum_max(series, reverse?), do: Shared.apply_native(series, :s_cum_max, [reverse?])
def cumulative_max(series, reverse?),
do: Shared.apply_native(series, :s_cum_max, [reverse?])

@impl true
def cum_min(series, reverse?), do: Shared.apply_native(series, :s_cum_min, [reverse?])
def cumulative_min(series, reverse?),
do: Shared.apply_native(series, :s_cum_min, [reverse?])

@impl true
def cum_sum(series, reverse?), do: Shared.apply_native(series, :s_cum_sum, [reverse?])
def cumulative_sum(series, reverse?),
do: Shared.apply_native(series, :s_cum_sum, [reverse?])

# Local minima/maxima

Expand Down Expand Up @@ -265,10 +268,10 @@ defmodule Explorer.PolarsBackend.Series do
|> DataFrame.rename(["values", "counts"])
|> DataFrame.mutate(counts: &Series.cast(&1["counts"], :integer))

# Rolling
# Window

@impl true
def rolling_max(series, window_size, opts) do
def window_max(series, window_size, opts) do
weights = Keyword.fetch!(opts, :weights)
min_periods = Keyword.fetch!(opts, :min_periods)
center = Keyword.fetch!(opts, :center)
Expand All @@ -277,7 +280,7 @@ defmodule Explorer.PolarsBackend.Series do
end

@impl true
def rolling_mean(series, window_size, opts) do
def window_mean(series, window_size, opts) do
weights = Keyword.fetch!(opts, :weights)
min_periods = Keyword.fetch!(opts, :min_periods)
center = Keyword.fetch!(opts, :center)
Expand All @@ -286,7 +289,7 @@ defmodule Explorer.PolarsBackend.Series do
end

@impl true
def rolling_min(series, window_size, opts) do
def window_min(series, window_size, opts) do
weights = Keyword.fetch!(opts, :weights)
min_periods = Keyword.fetch!(opts, :min_periods)
center = Keyword.fetch!(opts, :center)
Expand All @@ -295,7 +298,7 @@ defmodule Explorer.PolarsBackend.Series do
end

@impl true
def rolling_sum(series, window_size, opts) do
def window_sum(series, window_size, opts) do
weights = Keyword.fetch!(opts, :weights)
min_periods = Keyword.fetch!(opts, :min_periods)
center = Keyword.fetch!(opts, :center)
Expand Down
100 changes: 53 additions & 47 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -924,28 +924,30 @@ defmodule Explorer.Series do
## Examples
iex> s = [1, 2, 3, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_max(s)
iex> Explorer.Series.cumulative_max(s)
#Explorer.Series<
integer[4]
[1, 2, 3, 4]
>
iex> s = [1, 2, nil, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_max(s)
iex> Explorer.Series.cumulative_max(s)
#Explorer.Series<
integer[4]
[1, 2, nil, 4]
>
"""
@spec cum_max(series :: Series.t(), reverse? :: boolean()) :: Series.t()
def cum_max(series, reverse? \\ false)
@spec cumulative_max(series :: Series.t(), opts :: Keyword.t()) :: Series.t()
def cumulative_max(series, opts \\ [])

def cum_max(%Series{dtype: dtype} = series, reverse?)
when dtype in [:integer, :float, :date, :datetime],
do: apply_impl(series, :cum_max, [reverse?])
def cumulative_max(%Series{dtype: dtype} = series, opts)
when dtype in [:integer, :float, :date, :datetime] do
opts = Keyword.validate!(opts, reverse: false)
apply_impl(series, :cumulative_max, [opts[:reverse]])
end

def cum_max(%Series{dtype: dtype}, _),
do: dtype_error("cum_max/2", dtype, [:integer, :float, :date, :datetime])
def cumulative_max(%Series{dtype: dtype}, _),
do: dtype_error("cumulative_max/2", dtype, [:integer, :float, :date, :datetime])

@doc """
Calculates the cumulative minimum of the series.
Expand All @@ -964,28 +966,30 @@ defmodule Explorer.Series do
## Examples
iex> s = [1, 2, 3, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_min(s)
iex> Explorer.Series.cumulative_min(s)
#Explorer.Series<
integer[4]
[1, 1, 1, 1]
>
iex> s = [1, 2, nil, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_min(s)
iex> Explorer.Series.cumulative_min(s)
#Explorer.Series<
integer[4]
[1, 1, nil, 1]
>
"""
@spec cum_min(series :: Series.t(), reverse? :: boolean()) :: Series.t()
def cum_min(series, reverse? \\ false)
@spec cumulative_min(series :: Series.t(), opts :: Keyword.t()) :: Series.t()
def cumulative_min(series, opts \\ [])

def cum_min(%Series{dtype: dtype} = series, reverse?)
when dtype in [:integer, :float, :date, :datetime],
do: apply_impl(series, :cum_min, [reverse?])
def cumulative_min(%Series{dtype: dtype} = series, opts)
when dtype in [:integer, :float, :date, :datetime] do
opts = Keyword.validate!(opts, reverse: false)
apply_impl(series, :cumulative_min, [opts[:reverse]])
end

def cum_min(%Series{dtype: dtype}, _),
do: dtype_error("cum_min/2", dtype, [:integer, :float, :date, :datetime])
def cumulative_min(%Series{dtype: dtype}, _),
do: dtype_error("cumulative_min/2", dtype, [:integer, :float, :date, :datetime])

@doc """
Calculates the cumulative sum of the series.
Expand All @@ -1003,28 +1007,30 @@ defmodule Explorer.Series do
## Examples
iex> s = [1, 2, 3, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_sum(s)
iex> Explorer.Series.cumulative_sum(s)
#Explorer.Series<
integer[4]
[1, 3, 6, 10]
>
iex> s = [1, 2, nil, 4] |> Explorer.Series.from_list()
iex> Explorer.Series.cum_sum(s)
iex> Explorer.Series.cumulative_sum(s)
#Explorer.Series<
integer[4]
[1, 3, nil, 7]
>
"""
@spec cum_sum(series :: Series.t(), reverse? :: boolean()) :: Series.t()
def cum_sum(series, reverse? \\ false)
@spec cumulative_sum(series :: Series.t(), opts :: Keyword.t()) :: Series.t()
def cumulative_sum(series, opts \\ [])

def cum_sum(%Series{dtype: dtype} = series, reverse?)
when dtype in [:integer, :float, :boolean],
do: apply_impl(series, :cum_sum, [reverse?])
def cumulative_sum(%Series{dtype: dtype} = series, opts)
when dtype in [:integer, :float] do
opts = Keyword.validate!(opts, reverse: false)
apply_impl(series, :cumulative_sum, [opts[:reverse]])
end

def cum_sum(%Series{dtype: dtype}, _),
do: dtype_error("cum_sum/2", dtype, [:integer, :float])
def cumulative_sum(%Series{dtype: dtype}, _),
do: dtype_error("cumulative_sum/2", dtype, [:integer, :float])

# Local minima/maxima

Expand Down Expand Up @@ -1663,12 +1669,12 @@ defmodule Explorer.Series do
>
"""
def sort(series, reverse? \\ false), do: apply_impl(series, :sort, [reverse?])
def sort(series, reverse \\ false), do: apply_impl(series, :sort, [reverse])

@doc """
Returns the indices that would sort the series.
"""
def argsort(series, reverse? \\ false), do: apply_impl(series, :argsort, [reverse?])
def argsort(series, reverse \\ false), do: apply_impl(series, :argsort, [reverse])

@doc """
Reverses the series order.
Expand Down Expand Up @@ -1738,7 +1744,7 @@ defmodule Explorer.Series do
"""
def count(series), do: apply_impl(series, :count)

# Rolling
# Window

@doc """
Calculate the rolling sum, given a window size and optional list of weights.
Expand All @@ -1756,21 +1762,21 @@ defmodule Explorer.Series do
## Examples
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_sum(s, 4)
iex> Explorer.Series.window_sum(s, 4)
#Explorer.Series<
integer[10]
[1, 3, 6, 10, 14, 18, 22, 26, 30, 34]
>
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_sum(s, 2, weights: [1.0, 2.0])
iex> Explorer.Series.window_sum(s, 2, weights: [1.0, 2.0])
#Explorer.Series<
float[10]
[1.0, 5.0, 8.0, 11.0, 14.0, 17.0, 20.0, 23.0, 26.0, 29.0]
>
"""
def rolling_sum(series, window_size, opts \\ []),
do: apply_impl(series, :rolling_sum, [window_size, rolling_opts_with_defaults(opts)])
def window_sum(series, window_size, opts \\ []),
do: apply_impl(series, :window_sum, [window_size, window_opts_with_defaults(opts)])

@doc """
Calculate the rolling mean, given a window size and optional list of weights.
Expand All @@ -1788,21 +1794,21 @@ defmodule Explorer.Series do
## Examples
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_mean(s, 4)
iex> Explorer.Series.window_mean(s, 4)
#Explorer.Series<
float[10]
[1.0, 1.5, 2.0, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5]
>
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_mean(s, 2, weights: [1.0, 2.0])
iex> Explorer.Series.window_mean(s, 2, weights: [1.0, 2.0])
#Explorer.Series<
float[10]
[1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0, 11.5, 13.0, 14.5]
>
"""
def rolling_mean(series, window_size, opts \\ []),
do: apply_impl(series, :rolling_mean, [window_size, rolling_opts_with_defaults(opts)])
def window_mean(series, window_size, opts \\ []),
do: apply_impl(series, :window_mean, [window_size, window_opts_with_defaults(opts)])

@doc """
Calculate the rolling min, given a window size and optional list of weights.
Expand All @@ -1820,21 +1826,21 @@ defmodule Explorer.Series do
## Examples
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_min(s, 4)
iex> Explorer.Series.window_min(s, 4)
#Explorer.Series<
integer[10]
[1, 1, 1, 1, 2, 3, 4, 5, 6, 7]
>
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_min(s, 2, weights: [1.0, 2.0])
iex> Explorer.Series.window_min(s, 2, weights: [1.0, 2.0])
#Explorer.Series<
float[10]
[1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
>
"""
def rolling_min(series, window_size, opts \\ []),
do: apply_impl(series, :rolling_min, [window_size, rolling_opts_with_defaults(opts)])
def window_min(series, window_size, opts \\ []),
do: apply_impl(series, :window_min, [window_size, window_opts_with_defaults(opts)])

@doc """
Calculate the rolling max, given a window size and optional list of weights.
Expand All @@ -1852,23 +1858,23 @@ defmodule Explorer.Series do
## Examples
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_max(s, 4)
iex> Explorer.Series.window_max(s, 4)
#Explorer.Series<
integer[10]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>
iex> s = 1..10 |> Enum.to_list() |> Explorer.Series.from_list()
iex> Explorer.Series.rolling_max(s, 2, weights: [1.0, 2.0])
iex> Explorer.Series.window_max(s, 2, weights: [1.0, 2.0])
#Explorer.Series<
float[10]
[1.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0]
>
"""
def rolling_max(series, window_size, opts \\ []),
do: apply_impl(series, :rolling_max, [window_size, rolling_opts_with_defaults(opts)])
def window_max(series, window_size, opts \\ []),
do: apply_impl(series, :window_max, [window_size, window_opts_with_defaults(opts)])

defp rolling_opts_with_defaults(opts) do
defp window_opts_with_defaults(opts) do
defaults = [weights: nil, min_periods: 1, center: false]

Keyword.merge(defaults, opts, fn _key, _left, right -> right end)
Expand Down

0 comments on commit 0f4febc

Please sign in to comment.