Skip to content

Commit

Permalink
Change pooling defaults (elixir-nx#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Oct 7, 2021
1 parent 3dc79a0 commit d42cc99
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
10 changes: 6 additions & 4 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ defmodule Axon do
## Options
* `:name` - Layer name.
* `:kernel_size` - Pooling kernel size.
* `:strides` - Pooling strides.
* `name` - Layer name.
* `kernel_size` - Pooling kernel size. Defaults to `1`.
* `padding` - Padding to apply to input of pooling operation.
* `strides` - Pooling strides. Defaults to size of kernel.
"""
@doc type: :pooling
Expand All @@ -974,11 +975,12 @@ defmodule Axon do

defp pool(%Axon{output_shape: parent_shape} = x, pool, opts) do
kernel_size = opts[:kernel_size] || 1
strides = opts[:strides] || 1
strides = opts[:strides]
padding = opts[:padding] || :valid
inner_rank = Nx.rank(parent_shape) - 2

kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = if strides, do: strides, else: Tuple.to_list(kernel_size)
strides = list_or_duplicate(:strides, strides, inner_rank)
output_shape = Axon.Shape.pool(parent_shape, kernel_size, strides, padding)

Expand Down
47 changes: 25 additions & 22 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ defmodule Axon.Layers do
* `:strides` - kernel strides. Can be a scalar or a list
whose length matches the number of spatial dimensions in
the input tensor. Defaults to 1.
the input tensor. Defaults to size of kernel.
* `:padding` - zero padding on the input. Can be one of
`:valid`, `:same` or a general padding configuration
Expand All @@ -598,12 +598,12 @@ defmodule Axon.Layers do
...> [0.7099999785423279, 0.7282999753952026, -0.18639999628067017]]], type: {:f, 32})
iex> Axon.Layers.max_pool(t, kernel_size: 2)
#Nx.Tensor<
f32[1][3][2]
f32[1][3][1]
[
[
[0.051500000059604645, -0.32899999618530273],
[1.6191999912261963, 1.6191999912261963],
[0.7282999753952026, 0.7282999753952026]
[0.051500000059604645],
[1.6191999912261963],
[0.7282999753952026]
]
]
>
Expand All @@ -613,7 +613,7 @@ defmodule Axon.Layers do
opts =
keyword!(
opts,
[:kernel_size, strides: 1, padding: :valid, window_dilations: 1]
[:kernel_size, strides: nil, padding: :valid, window_dilations: 1]
)

window_dimensions =
Expand All @@ -626,10 +626,11 @@ defmodule Axon.Layers do

strides =
transform(
{Nx.rank(input), opts[:strides]},
{Nx.rank(input), opts[:strides], window_dimensions},
fn
{_, [_ | _] = strides} -> [1, 1 | strides]
{rank, strides} -> [1, 1 | List.duplicate(strides, rank - 2)]
{_, nil, dims} -> Tuple.to_list(dims)
{_, [_ | _] = strides, _} -> [1, 1 | strides]
{rank, strides, _} -> [1, 1 | List.duplicate(strides, rank - 2)]
end
)

Expand Down Expand Up @@ -692,7 +693,7 @@ defmodule Axon.Layers do
opts =
keyword!(
opts,
[:kernel_size, strides: 1, padding: :valid, window_dilations: 1]
[:kernel_size, strides: nil, padding: :valid, window_dilations: 1]
)

window_dimensions =
Expand All @@ -705,10 +706,11 @@ defmodule Axon.Layers do

strides =
transform(
{Nx.rank(input), opts[:strides]},
{Nx.rank(input), opts[:strides], window_dimensions},
fn
{_, [_ | _] = strides} -> [1, 1 | strides]
{rank, strides} -> [1, 1 | List.duplicate(strides, rank - 2)]
{_, nil, dims} -> Tuple.to_list(dims)
{_, [_ | _] = strides, _} -> [1, 1 | strides]
{rank, strides, _} -> [1, 1 | List.duplicate(strides, rank - 2)]
end
)

Expand Down Expand Up @@ -759,7 +761,7 @@ defmodule Axon.Layers do
* `:strides` - kernel strides. Can be a scalar or a list
whose length matches the number of spatial dimensions in
the input tensor. Defaults to 1.
the input tensor. Defaults to size of kernel.
* `:padding` - zero padding on the input. Can be one of
`:valid`, `:same` or a general padding configuration
Expand All @@ -778,12 +780,12 @@ defmodule Axon.Layers do
iex> t = Nx.tensor([[[0.9450, 0.4684, 1.8146], [1.2663, 0.4354, -0.0781], [-0.4759, 0.3251, 0.8742]]], type: {:f, 32})
iex> Axon.Layers.lp_pool(t, kernel_size: 2, norm: 2)
#Nx.Tensor<
f32[1][3][2]
f32[1][3][1]
[
[
[1.0547149181365967, 1.8740788698196411],
[1.3390626907348633, 0.4423491656780243],
[0.5763426423072815, 0.9326926469802856]
[1.0547149181365967],
[1.3390626907348633],
[0.5763426423072815]
]
]
>
Expand All @@ -793,7 +795,7 @@ defmodule Axon.Layers do
opts =
keyword!(
opts,
[:kernel_size, strides: 1, padding: :valid, window_dilations: 1, norm: 2]
[:kernel_size, strides: nil, padding: :valid, window_dilations: 1, norm: 2]
)

window_dimensions =
Expand All @@ -806,10 +808,11 @@ defmodule Axon.Layers do

strides =
transform(
{Nx.rank(input), opts[:strides]},
{Nx.rank(input), opts[:strides], window_dimensions},
fn
{_, [_ | _] = strides} -> [1, 1 | strides]
{rank, strides} -> [1, 1 | List.duplicate(strides, rank - 2)]
{_, nil, dims} -> Tuple.to_list(dims)
{_, [_ | _] = strides, _} -> [1, 1 | strides]
{rank, strides, _} -> [1, 1 | List.duplicate(strides, rank - 2)]
end
)

Expand Down

0 comments on commit d42cc99

Please sign in to comment.