Fix padding issue
seanmor5 committed Jun 21, 2021
1 parent d5d1e01 commit b15be1f
Showing 2 changed files with 31 additions and 4 deletions.
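In short: the pooling functions in lib/axon/layers.ex previously forwarded the user-supplied :padding option (a list of {low, high} pairs covering only the spatial dimensions) straight to the Nx window operations, which expect one pair per input dimension. The new transform prepends {0, 0} pairs for the leading batch and channel dimensions. A minimal sketch of the expansion, with illustrative values (the iex output below is mine, not part of the commit):

    # A rank-4 input such as {batch, channels, height, width} with
    # user-supplied spatial padding [{1, 1}, {1, 1}]:
    rank = 4
    padding = [{1, 1}, {1, 1}]

    # Prepend {0, 0} for the two leading (non-spatial) dimensions:
    List.duplicate({0, 0}, rank - 2) ++ padding
    #=> [{0, 0}, {0, 0}, {1, 1}, {1, 1}]

The :same and :valid atoms pass through unchanged, since Nx resolves those symbolically.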
33 changes: 30 additions & 3 deletions lib/axon/layers.ex
@@ -656,12 +656,21 @@ defmodule Axon.Layers do
       end
     )

+    padding =
+      transform({Nx.rank(input), opts[:padding]},
+        fn
+          {_, :same} -> :same
+          {_, :valid} -> :valid
+          {rank, padding} ->
+            List.duplicate({0, 0}, rank - 2) ++ padding
+        end)
+
     opts = transform(opts, &Keyword.delete(&1, :kernel_size))

     input
     |> Nx.window_max(window_dimensions,
       strides: strides,
-      padding: opts[:padding],
+      padding: padding,
       window_dilations: opts[:window_dilations]
     )
   end
@@ -720,12 +729,21 @@ defmodule Axon.Layers do
       end
     )

+    padding =
+      transform({Nx.rank(input), opts[:padding]},
+        fn
+          {_, :same} -> :same
+          {_, :valid} -> :valid
+          {rank, padding} ->
+            List.duplicate({0, 0}, rank - 2) ++ padding
+        end)
+
     opts = transform(opts, &Keyword.delete(&1, :kernel_size))

     input
     |> Nx.window_mean(window_dimensions,
       strides: strides,
-      padding: opts[:padding],
+      padding: padding,
       window_dilations: opts[:window_dilations]
     )
   end
@@ -804,6 +822,15 @@ defmodule Axon.Layers do
       end
     )

+    padding =
+      transform({Nx.rank(input), opts[:padding]},
+        fn
+          {_, :same} -> :same
+          {_, :valid} -> :valid
+          {rank, padding} ->
+            List.duplicate({0, 0}, rank - 2) ++ padding
+        end)
+
     norm = opts[:norm]

     opts =
@@ -815,7 +842,7 @@ defmodule Axon.Layers do
     |> Nx.power(norm)
     |> Nx.window_sum(window_dimensions,
       strides: strides,
-      padding: opts[:padding],
+      padding: padding,
       window_dilations: opts[:window_dilations]
     )
     |> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
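For context, a hedged usage sketch of the max pooling entrypoint. The option names come from the hunks above; the concrete shapes, values, and the assumption that Axon.Layers.max_pool/2 accepts them like this are mine:

    input = Nx.iota({1, 1, 4, 4}, type: {:f, 32})

    # Spatial-only padding config, as a user would write it:
    Axon.Layers.max_pool(input, kernel_size: {2, 2}, padding: [{1, 1}, {1, 1}])

    # Before this commit, the two-element list above was handed directly to
    # Nx.window_max/3, which expects a {low, high} pair for all four
    # dimensions of the rank-4 input.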
2 changes: 1 addition & 1 deletion lib/axon/training.ex
@@ -78,7 +78,7 @@ defmodule Axon.Training do

     step_fn = fn train_state, input, target ->
       {{preds, batch_loss}, gradients} =
-        Nx.Defn.Kernel.value_and_grad(
+        Nx.Defn.value_and_grad(
           train_state[:params],
           &objective_fn.(&1, input, target),
           fn x -> elem(x, 1) end
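The training.ex change only swaps the qualified call from Nx.Defn.Kernel to Nx.Defn. A standalone sketch of the call shape, mirroring the hunk above (params, objective_fn, x, and y are illustrative placeholders):

    # The objective returns {predictions, loss}; the third argument selects
    # which element of that tuple to differentiate, here the loss:
    {{preds, loss}, gradients} =
      Nx.Defn.value_and_grad(
        params,
        fn p -> objective_fn.(p, x, y) end,
        fn x -> elem(x, 1) end
      )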
