Skip to content

Commit

Permalink
Fix conv transpose (elixir-nx#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Oct 20, 2021
1 parent d98c201 commit bfe3b25
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 4 deletions.
5 changes: 4 additions & 1 deletion lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,12 @@ defmodule Axon.Layers do
end
)

ones = transform(Nx.rank(input), &List.duplicate(1, &1 - 2))

conv(input, weight, bias,
strides: opts[:strides],
strides: ones,
padding: padding,
input_dilation: strides,
kernel_dilation: opts[:kernel_dilation]
)
end
Expand Down
6 changes: 3 additions & 3 deletions lib/axon/shape.ex
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ defmodule Axon.Shape do
def conv_transpose(parent_shape, kernel_shape, strides, padding, kernel_dilation) do
permutation = for i <- 0..(Nx.rank(parent_shape) - 1), do: i
names = List.duplicate(nil, Nx.rank(parent_shape))
input_dilation = List.duplicate(1, Nx.rank(parent_shape) - 2)

input_dilation = strides
one = List.duplicate(1, Nx.rank(parent_shape) - 2)
padding = conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding)

input_shape =
Expand All @@ -440,7 +440,7 @@ defmodule Axon.Shape do
names,
kernel_shape,
names,
strides,
one,
padding,
1,
1,
Expand Down
130 changes: 130 additions & 0 deletions test/layers_test.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,134 @@
defmodule Axon.LayersTest do
use ExUnit.Case, async: true
doctest Axon.Layers

describe "conv_transpose" do
test "correct valid padding, no strides" do
inp = Nx.iota({1, 1, 4}, type: {:f, 32})
kernel = Nx.iota({3, 1, 2}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: :valid) ==
Nx.tensor([
[
[0.0, 1.0, 2.0, 3.0, 0.0],
[0.0, 3.0, 8.0, 13.0, 6.0],
[0.0, 5.0, 14.0, 23.0, 12.0]
]
])
end

test "correct with valid padding, strides" do
inp = Nx.iota({1, 1, 2, 2}, type: {:f, 32})
kernel = Nx.iota({3, 1, 2, 2}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: :valid, strides: [2, 1]) ==
Nx.tensor([
[
[[0.0, 3.0, 2.0], [0.0, 1.0, 0.0], [6.0, 13.0, 6.0], [2.0, 3.0, 0.0]],
[[0.0, 7.0, 6.0], [0.0, 5.0, 4.0], [14.0, 33.0, 18.0], [10.0, 23.0, 12.0]],
[[0.0, 11.0, 10.0], [0.0, 9.0, 8.0], [22.0, 53.0, 30.0], [18.0, 43.0, 24.0]]
]
])
end

test "correct with 3 spatial dimensions" do
inp = Nx.iota({1, 1, 2, 2, 1}, type: {:f, 32})
kernel = Nx.iota({3, 1, 2, 2, 1}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: :valid, strides: [1, 1, 2]) ==
Nx.tensor([
[
[
[[0.0, 0.0], [3.0, 0.0], [2.0, 0.0]],
[[6.0, 0.0], [14.0, 0.0], [6.0, 0.0]],
[[2.0, 0.0], [3.0, 0.0], [0.0, 0.0]]
],
[
[[0.0, 0.0], [7.0, 0.0], [6.0, 0.0]],
[[14.0, 0.0], [38.0, 0.0], [22.0, 0.0]],
[[10.0, 0.0], [23.0, 0.0], [12.0, 0.0]]
],
[
[[0.0, 0.0], [11.0, 0.0], [10.0, 0.0]],
[[22.0, 0.0], [62.0, 0.0], [38.0, 0.0]],
[[18.0, 0.0], [43.0, 0.0], [24.0, 0.0]]
]
]
])
end

test "correct with same padding, no strides" do
inp = Nx.iota({3, 1, 2, 2}, type: {:f, 32})
kernel = Nx.iota({1, 1, 2, 2}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: :same) ==
Nx.tensor([
[[[0.0, 3.0], [6.0, 14.0]]],
[[[12.0, 23.0], [22.0, 38.0]]],
[[[24.0, 43.0], [38.0, 62.0]]]
])
end

test "correct with same padding, strides" do
inp = Nx.iota({1, 3, 2, 2}, type: {:f, 32})
kernel = Nx.iota({4, 3, 1, 2}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: :same, strides: [2, 1]) ==
Nx.tensor([
[
[[52.0, 101.0], [0.0, 0.0], [70.0, 131.0], [0.0, 0.0]],
[[124.0, 263.0], [0.0, 0.0], [178.0, 365.0], [0.0, 0.0]],
[[196.0, 425.0], [0.0, 0.0], [286.0, 599.0], [0.0, 0.0]],
[[268.0, 587.0], [0.0, 0.0], [394.0, 833.0], [0.0, 0.0]]
]
])
end

test "correct with custom padding, no strides" do
inp = Nx.iota({1, 1, 2, 2}, type: {:f, 32})
kernel = Nx.iota({1, 1, 2, 1}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias, padding: [{0, 1}, {1, 2}]) ==
Nx.tensor([[[[0.0, 2.0, 3.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]]])
end

test "correct with custom padding, strides" do
inp = Nx.iota({1, 1, 2, 2}, type: {:f, 32})
kernel = Nx.iota({1, 1, 2, 1}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias,
padding: [{0, 1}, {1, 2}],
strides: [2, 1]
) ==
Nx.tensor([
[
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 2.0, 3.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
])
end

test "correct with kernel dilation" do
inp = Nx.iota({1, 1, 2, 4}, type: {:f, 32})
kernel = Nx.iota({1, 1, 2, 3}, type: {:f, 32})
bias = 0.0

assert Axon.Layers.conv_transpose(inp, kernel, bias,
kernel_dilation: [2, 1],
padding: [{0, 1}, {1, 2}],
strides: [2, 1]
) ==
Nx.tensor([[[[43.0, 67.0, 82.0, 49.0, 21.0], [0.0, 0.0, 0.0, 0.0, 0.0]]]])
end
end
end

0 comments on commit bfe3b25

Please sign in to comment.