Add recurrent layers (elixir-nx#61)
seanmor5 authored May 1, 2021
1 parent 0e97042 commit 275198d
Showing 5 changed files with 694 additions and 23 deletions.
196 changes: 175 additions & 21 deletions lib/axon.ex
@@ -138,7 +138,7 @@ defmodule Axon do
@doc type: :layer
def input(input_shape, opts \\ []) do
output_shape = Axon.Shape.input(input_shape)
Axon.layer(nil, :input, output_shape, [], opts[:name], opts)
layer(nil, :input, output_shape, [], opts[:name], opts)
end

@doc """
@@ -184,7 +184,7 @@ defmodule Axon do
bias_regularizer = opts[:bias_regularizer]
bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)

node = Axon.layer(x, :dense, output_shape, [kernel, bias], opts[:name])
node = layer(x, :dense, output_shape, [kernel, bias], opts[:name])

if activation do
node
@@ -260,7 +260,7 @@ defmodule Axon do
bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)

node =
Axon.layer(x, :conv, output_shape, [kernel, bias], opts[:name],
layer(x, :conv, output_shape, [kernel, bias], opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
@@ -328,7 +328,7 @@ defmodule Axon do
)

node =
Axon.layer(x, :conv_transpose, output_shape, [kernel, bias], opts[:name],
layer(x, :conv_transpose, output_shape, [kernel, bias], opts[:name],
strides: strides,
padding: padding,
kernel_dilation: kernel_dilation
@@ -413,7 +413,7 @@ defmodule Axon do
bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)

node =
Axon.layer(x, :depthwise_conv, output_shape, [kernel, bias], opts[:name],
layer(x, :depthwise_conv, output_shape, [kernel, bias], opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
@@ -502,7 +502,7 @@ defmodule Axon do
b2 = param("bias_2", b2_shape, initializer: bias_initializer, regularizer: bias_regularizer)

node =
Axon.layer(x, :separable_conv2d, output_shape, [k1, b1, k2, b2], opts[:name],
layer(x, :separable_conv2d, output_shape, [k1, b1, k2, b2], opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
@@ -599,7 +599,7 @@ defmodule Axon do
b3 = param("bias_3", b3_shape, initializer: bias_initializer, regularizer: bias_regularizer)

node =
Axon.layer(x, :separable_conv3d, output_shape, [k1, b1, k2, b2, k3, b3], opts[:name],
layer(x, :separable_conv3d, output_shape, [k1, b1, k2, b2, k3, b3], opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
@@ -634,7 +634,7 @@ defmodule Axon do

def activation(%Axon{output_shape: shape} = x, activation, opts)
when is_atom(activation) and activation in @activation_layers do
Axon.layer(x, activation, shape, [], opts[:name], opts)
layer(x, activation, shape, [], opts[:name], opts)
end

def activation(_, activation, _) do
@@ -686,7 +686,7 @@ defmodule Axon do

defp dropout(%Axon{output_shape: parent_shape} = x, dropout, opts) do
rate = opts[:rate] || 0.5
Axon.layer(x, dropout, parent_shape, [], opts[:name], rate: rate)
layer(x, dropout, parent_shape, [], opts[:name], rate: rate)
end

## Pooling
@@ -722,7 +722,7 @@ defmodule Axon do
strides = list_or_duplicate(:strides, strides, inner_rank)
output_shape = Axon.Shape.pool(parent_shape, kernel_size, strides, padding)

Axon.layer(x, pool, output_shape, [], opts[:name],
layer(x, pool, output_shape, [], opts[:name],
kernel_size: kernel_size,
strides: strides,
padding: padding
@@ -757,7 +757,7 @@ defmodule Axon do
output_size = tuple_or_duplicate(:output_size, opts[:output_size], inner_rank)
output_shape = Axon.Shape.adaptive_pool(parent_shape, output_size)

Axon.layer(x, pool, output_shape, [], opts[:name], output_size: output_size)
layer(x, pool, output_shape, [], opts[:name], output_size: output_size)
end

## Normalization
@@ -803,7 +803,7 @@ defmodule Axon do
beta_regularizer = opts[:beta_regularizer]
beta = param("beta", beta_shape, initializer: beta_initializer, regularizer: beta_regularizer)

Axon.layer(x, norm, shape, [gamma, beta], opts[:name],
layer(x, norm, shape, [gamma, beta], opts[:name],
epsilon: epsilon,
channel_index: channel_index
)
@@ -843,7 +843,7 @@ defmodule Axon do
beta_regularizer = opts[:beta_regularizer]
beta = param("beta", beta_shape, initializer: beta_initializer, regularizer: beta_regularizer)

Axon.layer(x, :group_norm, shape, [gamma, beta], opts[:name],
layer(x, :group_norm, shape, [gamma, beta], opts[:name],
epsilon: epsilon,
channel_index: channel_index,
group_size: group_size
@@ -869,7 +869,7 @@ defmodule Axon do
expr = Nx.Defn.jit(fun, [param], compiler: Axon.Defn)
output_shape = Tuple.insert_at(expr.shape, 0, batch_size)

Axon.layer(x, fun, output_shape, [], opts[:name])
layer(x, fun, output_shape, [], opts[:name])
end

@doc """
@@ -887,7 +887,7 @@ defmodule Axon do
@doc type: :composition
def flatten(%Axon{output_shape: shape} = x, opts \\ []) do
output_shape = Axon.Shape.flatten(shape)
Axon.layer(x, :flatten, output_shape, [], opts[:name])
layer(x, :flatten, output_shape, [], opts[:name])
end

@doc """
@@ -903,7 +903,7 @@ defmodule Axon do
@doc type: :composition
def reshape(%Axon{output_shape: shape} = x, new_shape, opts \\ []) do
output_shape = Axon.Shape.reshape(shape, new_shape)
Axon.layer(x, :reshape, output_shape, [], opts[:name])
layer(x, :reshape, output_shape, [], opts[:name])
end

@doc """
@@ -918,7 +918,7 @@ defmodule Axon do
@doc type: :composition
def transpose(%Axon{output_shape: shape} = x, permutation, opts \\ []) do
output_shape = Axon.Shape.transpose(shape, permutation)
Axon.layer(x, :transpose, output_shape, [], opts[:name], permutation: permutation)
layer(x, :transpose, output_shape, [], opts[:name], permutation: permutation)
end

@doc """
@@ -936,7 +936,7 @@ defmodule Axon do
def pad(%Axon{output_shape: shape} = x, config, value \\ 0.0, opts \\ [])
when is_list(config) and is_number(value) do
output_shape = Axon.Shape.pad(shape, config)
Axon.layer(x, :pad, output_shape, [], opts[:name], padding_config: config, value: value)
layer(x, :pad, output_shape, [], opts[:name], padding_config: config, value: value)
end

@doc """
@@ -957,7 +957,7 @@ defmodule Axon do
axis = opts[:axis] || Nx.rank(x_shape) - 1
output_shape = Axon.Shape.concatenate([x_shape, y_shape], axis)

Axon.layer([x, y], :concatenate, output_shape, [], opts[:name], axis: axis)
layer([x, y], :concatenate, output_shape, [], opts[:name], axis: axis)
end

@doc type: :composition
@@ -967,7 +967,7 @@ defmodule Axon do
input_shapes = inputs |> Enum.map(fn %Axon{output_shape: shape} -> shape end)
output_shape = Axon.Shape.concatenate(input_shapes, axis)

Axon.layer(inputs, :concatenate, output_shape, [], opts[:name], axis: axis)
layer(inputs, :concatenate, output_shape, [], opts[:name], axis: axis)
end

@doc false
@@ -1007,7 +1007,7 @@ defmodule Axon do
shape
end)

Axon.layer(inputs, unquote(op), output_shape, [], opts[:name])
layer(inputs, unquote(op), output_shape, [], opts[:name])
end

def unquote(op)(%Axon{output_shape: shape} = x, %Axon{output_shape: shape} = y) do
@@ -1017,6 +1017,160 @@ defmodule Axon do
def unquote(op)([%Axon{} | _] = inputs), do: unquote(op)(inputs, [])
end

@doc """
LSTM Layer.
"""
def lstm(%Axon{output_shape: shape} = x, units, opts \\ [])
when is_integer(units) and units > 0 do
activation = opts[:activation] || :tanh
gate = opts[:gate] || :sigmoid
hidden_state = opts[:hidden_state]

output_shape = Axon.Shape.rnn(shape, units, "LSTM")
input_kernel_shape = Axon.Shape.rnn_input_kernel(shape, units, "LSTM")
hidden_kernel_shape = Axon.Shape.rnn_hidden_kernel(shape, units, "LSTM")
bias_shape = Axon.Shape.rnn_bias(shape, units, "LSTM")
hidden_state_shape = Axon.Shape.rnn_hidden_state(shape, units, "LSTM")

kernel_initializer = opts[:kernel_initializer] || :glorot_uniform
recurrent_initializer = opts[:recurrent_initializer] || :glorot_uniform
bias_initializer = opts[:bias_initializer] || :zeros

# Parameters
wii = param("wii", input_kernel_shape, initializer: kernel_initializer)
wif = param("wif", input_kernel_shape, initializer: kernel_initializer)
wig = param("wig", input_kernel_shape, initializer: kernel_initializer)
wio = param("wio", input_kernel_shape, initializer: kernel_initializer)

whi = param("whi", hidden_kernel_shape, initializer: kernel_initializer)
whf = param("whf", hidden_kernel_shape, initializer: kernel_initializer)
whg = param("whg", hidden_kernel_shape, initializer: kernel_initializer)
who = param("who", hidden_kernel_shape, initializer: kernel_initializer)

bi = param("bi", bias_shape, initializer: bias_initializer)
bf = param("bf", bias_shape, initializer: bias_initializer)
bg = param("bg", bias_shape, initializer: bias_initializer)
bo = param("bo", bias_shape, initializer: bias_initializer)

output =
layer(
x,
:lstm,
{{hidden_state_shape, hidden_state_shape}, output_shape},
[wii, wif, wig, wio, whi, whf, whg, who, bi, bf, bg, bo],
opts[:name],
activation: activation,
gate: gate,
hidden_state: hidden_state,
hidden_state_shape: hidden_state_shape,
recurrent_initializer: recurrent_initializer
)

new_c = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
new_h = layer(output, fn x -> elem(elem(x, 0), 1) end, hidden_state_shape, [])
output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])

{{new_c, new_h}, output_sequence}
end
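
Editor's note, not part of the diff: a minimal usage sketch of the new lstm/3 layer. The {nil, 8, 2} input shape (assumed to mean {batch, sequence, features}) and the layer name are illustrative assumptions; the option names and the return structure come from the code above.

# Editor's sketch (not in the commit): lstm/3 returns the final carry and the output sequence.
input = Axon.input({nil, 8, 2})
{{_new_cell, _new_hidden}, sequence} = Axon.lstm(input, 64, name: "lstm", activation: :tanh)

model =
  sequence
  |> Axon.flatten()
  |> Axon.dense(1, activation: :sigmoid)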

@doc """
GRU Layer.
"""
def gru(%Axon{output_shape: shape} = x, units, opts \\ [])
when is_integer(units) and units > 0 do
activation = opts[:activation] || :tanh
gate = opts[:gate] || :sigmoid
hidden_state = opts[:hidden_state]

output_shape = Axon.Shape.rnn(shape, units, "GRU")
input_kernel_shape = Axon.Shape.rnn_input_kernel(shape, units, "GRU")
hidden_kernel_shape = Axon.Shape.rnn_hidden_kernel(shape, units, "GRU")
bias_shape = Axon.Shape.rnn_bias(shape, units, "GRU")
hidden_state_shape = Axon.Shape.rnn_hidden_state(shape, units, "GRU")

kernel_initializer = opts[:kernel_initializer] || :glorot_uniform
recurrent_initializer = opts[:recurrent_initializer] || :glorot_uniform
bias_initializer = opts[:bias_initializer] || :zeros

wir = param("wir", input_kernel_shape, initializer: kernel_initializer)
wiz = param("wiz", input_kernel_shape, initializer: kernel_initializer)
win = param("win", input_kernel_shape, initializer: kernel_initializer)
whr = param("whr", hidden_kernel_shape, initializer: kernel_initializer)
whz = param("whz", hidden_kernel_shape, initializer: kernel_initializer)
whn = param("whn", hidden_kernel_shape, initializer: kernel_initializer)
br = param("br", bias_shape, initializer: bias_initializer)
bz = param("bz", bias_shape, initializer: bias_initializer)
bin = param("bin", bias_shape, initializer: bias_initializer)
bhn = param("bhn", bias_shape, initializer: bias_initializer)

output =
layer(
x,
:gru,
{{hidden_state_shape}, output_shape},
[wir, wiz, win, whr, whz, whn, br, bz, bin, bhn],
opts[:name],
activation: activation,
gate: gate,
hidden_state: hidden_state,
hidden_state_shape: hidden_state_shape,
recurrent_initializer: recurrent_initializer
)

new_h = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])

{{new_h}, output_sequence}
end
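
Editor's note, not part of the diff: the analogous sketch for gru/3, which carries a single hidden state rather than a cell/hidden pair (input shape assumed as in the lstm sketch above).

# Editor's sketch (not in the commit): gru/3 returns {{new_hidden}, output_sequence}.
input = Axon.input({nil, 8, 2})
{{_new_hidden}, sequence} = Axon.gru(input, 32, name: "gru", gate: :sigmoid)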

@doc """
ConvLSTM Layer.
"""
def conv_lstm(%Axon{output_shape: shape} = x, units, opts \\ []) do
padding = opts[:padding] || :same
kernel_size = opts[:kernel_size] || 1
strides = opts[:strides] || 1
hidden_state = opts[:hidden_state]

kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, 1)
strides = list_or_duplicate(:strides, strides, 1)

hidden_state_shape = Axon.Shape.rnn_hidden_state(shape, units, "ConvLSTM")
input_kernel_shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size)
hidden_kernel_shape = Axon.Shape.conv_kernel(hidden_state_shape, 4 * units, kernel_size)
bias_shape = Axon.Shape.conv_bias(shape, 4 * units, kernel_size)

output_shape = Axon.Shape.rnn(shape, units, "ConvLSTM")

kernel_initializer = opts[:kernel_initializer] || :glorot_uniform
recurrent_initializer = opts[:recurrent_initializer] || :glorot_uniform
bias_initializer = opts[:bias_initializer] || :zeros

wi = param("wi", input_kernel_shape, initializer: kernel_initializer)
wh = param("wh", hidden_kernel_shape, initializer: kernel_initializer)
b = param("b", bias_shape, initializer: bias_initializer)

output =
layer(
x,
:conv_lstm,
{{hidden_state_shape, hidden_state_shape}, output_shape},
[wi, wh, b],
opts[:name],
hidden_state: hidden_state,
strides: strides,
padding: padding,
hidden_state_shape: hidden_state_shape,
recurrent_initializer: recurrent_initializer
)

new_c = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
new_h = layer(output, fn x -> elem(elem(x, 0), 1) end, hidden_state_shape, [])
output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])

{{new_c, new_h}, output_sequence}
end
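
Editor's note, not part of the diff: a hedged sketch of conv_lstm/3. The input layout below is an assumption (this hunk does not pin it down); only the option names and the {{cell, hidden}, sequence} return structure are taken from the code above.

# Editor's sketch (not in the commit); the input layout is assumed, not confirmed by this diff.
frames = Axon.input({nil, 10, 1, 28, 28})
{{_cell, _hidden}, sequence} =
  Axon.conv_lstm(frames, 16, kernel_size: 3, padding: :same, name: "conv_lstm")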

@doc """
Compiles the given model to `{init_fn, predict_fn}`.
"""
