Add bilinear layer
seanmor5 committed Aug 28, 2021
1 parent 92502e4 commit 81eb013
Showing 5 changed files with 353 additions and 2 deletions.
74 changes: 74 additions & 0 deletions lib/axon.ex
@@ -276,6 +276,80 @@ defmodule Axon do
end
end

@doc """
Adds a bilinear layer to the network.
The bilinear layer implements:
output = activation(dot(dot(input1, kernel), input2) + bias)
where `activation` is given by the `:activation` option and both
`kernel` and `bias` are layer parameters. `units` specifies the
number of output units.
All dimensions but the last of `input1` and `input2` must match. The
batch sizes of both inputs must also match or at least one must be `nil`.
The inferred output batch size is the more specific of the two: a concrete input batch size takes precedence over `nil`.
Compiles to `Axon.Layers.bilinear/4`.
## Options
* `name` - Layer name.
* `kernel_initializer` - Initializer for `kernel` weights.
* `bias_initializer` - Initializer for `bias` weights.
* `activation` - Element-wise activation function.
* `use_bias` - Whether to add a trainable bias term. Defaults to `true`.
"""
@doc type: :linear
def bilinear(
%Axon{output_shape: parent1_shape} = input1,
%Axon{output_shape: parent2_shape} = input2,
units,
opts \\ []
)
when is_integer(units) and units > 0 do
activation = opts[:activation]
use_bias = Keyword.get(opts, :use_bias, true)

kernel_shape = Axon.Shape.bilinear_kernel(parent1_shape, parent2_shape, units)
bias_shape = Axon.Shape.bilinear_bias(parent1_shape, parent2_shape, units)
output_shape = Axon.Shape.bilinear(parent1_shape, parent2_shape, units)

kernel_initializer = opts[:kernel_initializer]
kernel_regularizer = opts[:kernel_regularizer]

kernel =
param("kernel", kernel_shape,
initializer: kernel_initializer,
regularizer: kernel_regularizer
)

params =
if use_bias do
bias_initializer = opts[:bias_initializer] || :zeros
bias_regularizer = opts[:bias_regularizer]

bias =
param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)

%{"kernel" => kernel, "bias" => bias}
else
%{"kernel" => kernel}
end

node =
layer([input1, input2], :bilinear, output_shape, params, opts[:name], use_bias: use_bias)

if activation do
node
|> activation(activation)
else
node
end
end

@doc """
Adds a convolution layer to the network.
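As an aside, here is a minimal usage sketch of the new layer, assembled from the public calls exercised in the tests further down (`Axon.input/1`, `Axon.compile/1`, `Nx.random_uniform/1`); the shapes, unit count, and `:relu` activation are illustrative choices rather than anything this commit prescribes:

input1 = Axon.input({nil, 32})
input2 = Axon.input({nil, 64})
model = Axon.bilinear(input1, input2, 128, name: "bilinear", activation: :relu)

{init_fn, predict_fn} = Axon.compile(model)
params = init_fn.()

# A layer with two parents receives its inputs as a tuple, in the order
# they were passed to Axon.bilinear/4.
x1 = Nx.random_uniform({8, 32})
x2 = Nx.random_uniform({8, 64})
predict_fn.(params, {x1, x2})
# => a {8, 128} tensor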
48 changes: 47 additions & 1 deletion lib/axon/compiler.ex
@@ -39,7 +39,21 @@ defmodule Axon.Compiler do
|> Enum.reduce(cache, fn x, acc -> to_init_fun(x, acc) end)
end

defp to_init_fun(%Axon{parent: parents}, cache) when is_list(parents) do
defp to_init_fun(%Axon{parent: parents, params: params, policy: %{params: dtype}}, cache)
when is_list(parents) do
cache =
Enum.reduce(params, cache, fn {_, %{name: name} = param}, cache ->
case cache do
%{^name => _} ->
cache

%{} ->
%{name: name, shape: shape, initializer: initializer} = param
fun = fn -> apply(Axon.Initializers, initializer, [[type: dtype, shape: shape]]) end
Map.put(cache, name, fun)
end
end)

Enum.reduce(parents, cache, &to_init_fun/2)
end

@@ -309,6 +323,38 @@ defmodule Axon.Compiler do
{fun, cache}
end

defp recur_predict_fun(
%Axon{
op: :bilinear,
parent: parents,
params: %{"kernel" => %{name: w, frozen: w_frz}} = layer_params,
policy: %{compute: compute, output: output},
opts: [use_bias: use_bias]
},
cache,
input_map
) do
{[fun1, fun2], cache} = Enum.map_reduce(parents, cache, &recur_predict_fun(&1, &2, input_map))

fun = fn params, inputs ->
input1 = Nx.as_type(fun1.(params, inputs), compute)
input2 = Nx.as_type(fun2.(params, inputs), compute)
w = Nx.as_type(maybe_freeze(params[w], w_frz), compute)

b =
if use_bias do
%{name: b, frozen: b_frz} = layer_params["bias"]
Nx.as_type(maybe_freeze(params[b], b_frz), compute)
else
Nx.tensor(0.0, type: compute)
end

Nx.as_type(apply(Axon.Layers, :bilinear, [input1, input2, w, b]), output)
end

{fun, cache}
end

## Sparse Layers

defp recur_predict_fun(
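The predict function above delegates the arithmetic to `Axon.Layers.bilinear/4`. For intuition, the docstring formula — `dot(dot(input1, kernel), input2) + bias`, with a `{units, in1_features, in2_features}` kernel — can be sketched with plain `Nx` calls; this is an illustration under those shape assumptions, not the library's implementation:

defmodule BilinearSketch do
  import Nx.Defn

  # input1: {batch, in1}, input2: {batch, in2},
  # kernel: {units, in1, in2}, bias: {units}
  defn bilinear(input1, input2, kernel, bias) do
    input1
    # contract input1's feature axis against the kernel's in1 axis -> {batch, units, in2}
    |> Nx.dot([1], kernel, [1])
    # contract the remaining in2 axis against input2's features, batching on axis 0 -> {batch, units}
    |> Nx.dot([2], [0], input2, [1], [0])
    |> Nx.add(bias)
  end
end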
2 changes: 1 addition & 1 deletion lib/axon/layers.ex
@@ -123,7 +123,7 @@ defmodule Axon.Layers do
## Output Shape
`{batch_size, output_features}`
`{batch_size, ..., output_features}`
## Examples
94 changes: 94 additions & 0 deletions lib/axon/shape.ex
@@ -142,6 +142,100 @@ defmodule Axon.Shape do
{elem(input_shape, 0), units}
end

@doc """
Calculates the shape of a bilinear kernel given both input
shapes and output units.
## Examples
iex> Axon.Shape.bilinear_kernel({nil, 32}, {nil, 64}, 128)
{128, 32, 64}
iex> Axon.Shape.bilinear_kernel({nil, 32, 64}, {nil, 16}, 32)
{32, 64, 16}
"""
def bilinear_kernel(parent1, parent2, units) do
parent1_features = elem(parent1, Nx.rank(parent1) - 1)
parent2_features = elem(parent2, Nx.rank(parent2) - 1)
{units, parent1_features, parent2_features}
end

@doc """
Calculates the shape of a bilinear bias given both input
shapes and output units.
## Examples
iex> Axon.Shape.bilinear_bias({nil, 32}, {nil, 64}, 128)
{128}
iex> Axon.Shape.bilinear_bias({nil, 32, 64}, {nil, 32, 16}, 32)
{32}
"""
def bilinear_bias(_parent1, _parent2, units) do
{units}
end

@doc """
Calculates the output shape of a bilinear layer given both input
shapes and output units.
## Examples
iex> Axon.Shape.bilinear({nil, 32}, {nil, 64}, 128)
{nil, 128}
iex> Axon.Shape.bilinear({nil, 32, 64}, {nil, 32, 16}, 32)
{nil, 32, 32}
iex> Axon.Shape.bilinear({nil, 32, 64}, {16, 32, 16}, 32)
{16, 32, 32}
### Errors
iex> Axon.Shape.bilinear({32, 32}, {16, 16}, 32)
** (ArgumentError) all input dimensions but the last must match, got 32 and 16 for shapes {32, 32} and {16, 16}
iex> Axon.Shape.bilinear({nil, 16, 32}, {nil, 16}, 32)
** (ArgumentError) input ranks must match, got 3 and 2
"""
def bilinear(parent1, parent2, units) do
unless Nx.rank(parent1) == Nx.rank(parent2) do
raise ArgumentError,
"input ranks must match, got #{inspect(Nx.rank(parent1))}" <>
" and #{inspect(Nx.rank(parent2))}"
end

parent1_without_features =
parent1
|> Tuple.delete_at(Nx.rank(parent1) - 1)
|> Tuple.to_list()

parent2_without_features =
parent2
|> Tuple.delete_at(Nx.rank(parent2) - 1)
|> Tuple.to_list()

output_shape_no_features =
parent1_without_features
|> Enum.zip_with(parent2_without_features, fn p1, p2 ->
unless is_nil(p1) or is_nil(p2) or p1 == p2 do
raise ArgumentError,
"all input dimensions but the last must match, got #{inspect(p1)}" <>
" and #{inspect(p2)} for shapes #{inspect(parent1)} and #{inspect(parent2)}"
end

if is_nil(p1) do
p2
else
p1
end
end)
|> List.to_tuple()

Tuple.append(output_shape_no_features, units)
end

## Sparse

@doc """
137 changes: 137 additions & 0 deletions test/compiler_test.exs
@@ -238,6 +238,143 @@ defmodule CompilerTest do
end
end

describe "bilinear" do
test "initializes in default case" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear")

assert {init_fn, _predict_fn} = Axon.compile(model)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = init_fn.()
assert Nx.shape(kernel) == {1, 1, 2}
assert Nx.type(kernel) == {:f, 32}
assert Nx.shape(bias) == {1}
assert Nx.type(bias) == {:f, 32}
end

test "initializes with custom initializers" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model1 = Axon.bilinear(input1, input2, 1, name: "bilinear", kernel_initializer: :zeros)

assert {init_fn, _predict_fn} = Axon.compile(model1)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = init_fn.()
assert kernel == Axon.Initializers.zeros(shape: {1, 1, 2})
assert Nx.shape(bias) == {1}
assert Nx.type(bias) == {:f, 32}

model2 = Axon.bilinear(input1, input2, 1, name: "bilinear", bias_initializer: :zeros)

assert {init_fn, _predict_fn} = Axon.compile(model2)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = init_fn.()
assert Nx.shape(kernel) == {1, 1, 2}
assert Nx.type(kernel) == {:f, 32}
assert bias == Axon.Initializers.zeros(shape: {1})
end

test "computes forward pass" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear")

input1 = Nx.iota({1, 1}, type: {:f, 32})
input2 = Nx.iota({1, 2}, type: {:f, 32})

assert {init_fn, predict_fn} = Axon.compile(model)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = params = init_fn.()

assert predict_fn.(params, {input1, input2}) ==
Axon.Layers.bilinear(input1, input2, kernel, bias)
end

test "computes forward pass with constant" do
input1 = Axon.input({nil, 1})
input2 = Axon.constant(Nx.iota({2, 1}))
model = Axon.bilinear(input1, input2, 1, name: "bilinear")

input1 = Nx.iota({2, 1}, type: {:f, 32})
input2 = Nx.iota({2, 1}, type: {:f, 32})

assert {init_fn, predict_fn} = Axon.compile(model)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = params = init_fn.()

assert predict_fn.(params, input1) == Axon.Layers.bilinear(input1, input2, kernel, bias)
end

test "returns zero gradient for frozen parameters" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear") |> Axon.freeze()

assert {init_fn, predict_fn} = Axon.compile(model)

backward = fn params, input ->
Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input)))
end

assert %{"bilinear_kernel" => kernel_grad, "bilinear_bias" => bias_grad} =
Nx.Defn.jit(backward, [
init_fn.(),
{Nx.random_uniform({1, 1}), Nx.random_uniform({1, 2})}
])

assert kernel_grad == Nx.broadcast(0.0, {1, 1, 2})
assert bias_grad == Nx.broadcast(0.0, {1})
end

test "initializes with parameter policy" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear")
policy = AMP.create_policy(params: {:bf, 16})
mp_model = AMP.apply_policy(model, policy)

assert {init_fn, _} = Axon.compile(mp_model)
assert %{"bilinear_kernel" => kernel, "bilinear_bias" => bias} = init_fn.()
assert Nx.type(kernel) == {:bf, 16}
assert Nx.type(bias) == {:bf, 16}
end

test "computes forward pass with output policy" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear")
policy = AMP.create_policy(output: {:bf, 16})
mp_model = AMP.apply_policy(model, policy)

assert {init_fn, predict_fn} = Axon.compile(mp_model)

assert Nx.type(
predict_fn.(init_fn.(), {Nx.random_uniform({1, 1}), Nx.random_uniform({1, 2})})
) == {:bf, 16}
end

test "initializes with use_bias false" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false)

assert {init_fn, _} = Axon.compile(model)
assert %{"bilinear_kernel" => _} = params = init_fn.()
assert Map.has_key?(params, "bilinear_bias") == false
end

test "computes forward pass with use_bias false" do
input1 = Axon.input({nil, 1})
input2 = Axon.input({nil, 2})
model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false)

inp1 = Nx.random_uniform({1, 1})
inp2 = Nx.random_uniform({1, 2})

assert {init_fn, predict_fn} = Axon.compile(model)
assert %{"bilinear_kernel" => k} = params = init_fn.()

assert predict_fn.(params, {inp1, inp2}) ==
Axon.Layers.bilinear(inp1, inp2, k, Nx.tensor(0.0))
end
end

describe "embedding" do
test "initializes in default case" do
model = Axon.input({nil, 1}) |> Axon.embedding(1, 1, name: "embedding")
