Skip to content

Commit

Permalink
Fix LpPool bug
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Sep 28, 2021
1 parent a8dc91b commit be03048
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
22 changes: 17 additions & 5 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -982,11 +982,23 @@ defmodule Axon do
strides = list_or_duplicate(:strides, strides, inner_rank)
output_shape = Axon.Shape.pool(parent_shape, kernel_size, strides, padding)

layer(x, pool, output_shape, %{}, opts[:name],
kernel_size: kernel_size,
strides: strides,
padding: padding
)
name = opts[:name]

opts =
if pool == :lp_pool do
norm = opts[:norm] || 2

[
kernel_size: kernel_size,
strides: strides,
padding: padding,
norm: norm
]
else
[kernel_size: kernel_size, strides: strides, padding: padding]
end

layer(x, pool, output_shape, %{}, name, opts)
end

## Adaptive Pooling
Expand Down
8 changes: 8 additions & 0 deletions test/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,14 @@ defmodule CompilerTest do
end
end

test "lp_pool computes forward pass with custom norm" do
model = Axon.input({nil, 1, 32}) |> Axon.lp_pool(norm: 3)
input = Nx.random_uniform({1, 1, 32}, type: {:f, 32})

assert {_, predict_fn} = Axon.compile(model)
assert predict_fn.(%{}, input) == Axon.Layers.lp_pool(input, kernel_size: {1}, norm: 3)
end

test "computes forward pass with output policy" do
for pool <- @pooling_layers do
model = apply(Axon, pool, [Axon.input({nil, 1, 32})])
Expand Down

0 comments on commit be03048

Please sign in to comment.