Use parameter map (elixir-nx#68)
seanmor5 authored May 26, 2021
1 parent 29141b3 commit d1bd802
Showing 12 changed files with 432 additions and 730 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -19,4 +19,4 @@ jobs:
       - name: Check formatting
         run: mix format --check-formatted
       - name: Run tests
-        run: mix test
+        run: MIX_ENV=test mix do compile --warnings-as-errors, test
2 changes: 1 addition & 1 deletion examples/cifar10.exs
@@ -48,7 +48,7 @@ IO.inspect model
 
 final_training_state =
   model
-  |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.01), metrics: [:accuracy])
+  |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01), metrics: [:accuracy])
   |> Axon.Training.train(train_images, train_labels, epochs: 20, compiler: EXLA)
   |> Nx.backend_transfer()
   |> IO.inspect()
1 change: 0 additions & 1 deletion examples/mnist.exs
@@ -25,7 +25,6 @@ IO.inspect train_images |> hd() |> Nx.slice_axis(0, 1, 0) |> Nx.reshape({1, 28,
 model =
   Axon.input({nil, 784})
   |> Axon.dense(128, activation: :relu)
-  |> Axon.layer_norm()
   |> Axon.dropout()
   |> Axon.dense(10, activation: :softmax)
 
4 changes: 2 additions & 2 deletions examples/xor.exs
@@ -23,9 +23,9 @@ targets =
     Nx.logical_xor(x1, x2)
   end
 
-{params, _} =
+final_training_state =
   model
   |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.sgd(0.01))
   |> Axon.Training.train(data, targets, epochs: 10)
 
-IO.inspect Axon.predict(model, params, {Nx.tensor([[0]]), Nx.tensor([[1]])})
+IO.inspect Axon.predict(model, final_training_state[:params], {Nx.tensor([[0]]), Nx.tensor([[1]])})
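
For readers skimming the example changes: `Axon.Training.train` now returns a training-state map instead of a `{params, _}` tuple, and the trained parameters are read back out under `:params` before being handed to `Axon.predict/3`. Below is a compact, self-contained sketch of that flow; it is not part of the commit, and the single-input model, one-batch dataset, and learning rate are illustrative stand-ins for the example's real data streams.

    model =
      Axon.input({nil, 2})
      |> Axon.dense(8, activation: :tanh)
      |> Axon.dense(1, activation: :sigmoid)

    # One illustrative batch of XOR pairs and labels.
    inputs = [Nx.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])]
    targets = [Nx.tensor([[0], [1], [1], [0]])]

    final_training_state =
      model
      |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.sgd(0.05))
      |> Axon.Training.train(inputs, targets, epochs: 10)

    # Trained parameters live under :params in the returned state map.
    IO.inspect(Axon.predict(model, final_training_state[:params], Nx.tensor([[0, 1]])))
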
124 changes: 81 additions & 43 deletions lib/axon.ex
@@ -84,35 +84,39 @@ defmodule Axon do
   Custom Axon layer with given parent.
 
   Applies `op` on `parent` with parameters `parameters`. `parameters`
-  is a list of trainable `parameters` created using `Axon.param`. Assumes
+  is a map of trainable `parameters` created using `Axon.param`. Assumes
   `op` is a function of the following form:
 
-      op = fn input, w1, w2, ... wn -> ... end
+      op = fn input, params -> ... end
 
   If `opts` is not empty, it is treated as input options to the layer
   method:
 
-      op = fn input, w1, w2, ... wn, opts -> ... end
+      op = fn input, params, opts -> ... end
 
-  Parameters *must* be declared in the order of their usage:
+  Parameters are accessed using the same key referenced in the `parameters`
+  map passed to `Axon.layer`:
 
       w1 = Axon.param("weight", {})
      b1 = Axon.param("bias", {})
 
-      op = fn input, w1, b1 -> w1 * input + b1 end
+      op = fn input, params -> params["weight"] * input + params["bias"] end
 
-      Axon.layer(parent, op, {}, [w1, b1])
+      Axon.layer(parent, op, {}, %{"weight" => w1, "bias" => b1})
   """
   @doc type: :special
   def layer(parent, op, output_shape, parameters, name \\ nil, opts \\ [])
-      when is_atom(op) or is_function(op) do
+      when is_atom(op) or (is_function(op) and is_map(parameters)) do
     op_name = if is_atom(op), do: op, else: :layer
 
     {id, name} = unique_identifiers(op_name, name)
 
     parameters =
       parameters
-      |> Enum.map(fn %{name: p_name} = param -> %{param | name: name <> "_" <> p_name} end)
+      |> Enum.map(fn {k, %{name: p_name} = param} ->
+        {k, %{param | name: name <> "_" <> p_name}}
+      end)
+      |> Map.new()
 
     %Axon{
       id: id,
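
To make the new contract concrete, here is a minimal custom-layer sketch written against the API described in the updated docstring above. It is not part of the commit: the layer name, shapes, and scalar parameters are illustrative, and `Nx.multiply`/`Nx.add` are used instead of bare operators so the op also works outside `defn`.

    # Declare trainable parameters once; each is registered under a string key.
    weight = Axon.param("weight", {})
    bias = Axon.param("bias", {})

    # The op now receives the whole parameter map and looks entries up by key,
    # instead of taking one positional argument per parameter.
    scale_and_shift = fn input, params ->
      Nx.add(Nx.multiply(input, params["weight"]), params["bias"])
    end

    model =
      Axon.input({nil, 8})
      |> Axon.layer(scale_and_shift, {nil, 8}, %{"weight" => weight, "bias" => bias}, "scale_and_shift")

Under the previous API the same layer would have been written with positional parameters (`fn input, w, b -> ... end`) and the list `[w, b]`, which is exactly the shape of change applied to every built-in layer below.
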
@@ -171,7 +175,7 @@ defmodule Axon do
   @doc type: :special
   def input(input_shape, opts \\ []) do
     output_shape = Axon.Shape.input(input_shape)
-    layer(nil, :input, output_shape, [], opts[:name], opts)
+    layer(nil, :input, output_shape, %{}, opts[:name], opts)
   end
 
   @doc """
@@ -217,7 +221,7 @@ defmodule Axon do
     bias_regularizer = opts[:bias_regularizer]
     bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)
 
-    node = layer(x, :dense, output_shape, [kernel, bias], opts[:name])
+    node = layer(x, :dense, output_shape, %{"kernel" => kernel, "bias" => bias}, opts[:name])
 
     if activation do
       node
@@ -293,7 +297,7 @@ defmodule Axon do
     bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)
 
     node =
-      layer(x, :conv, output_shape, [kernel, bias], opts[:name],
+      layer(x, :conv, output_shape, %{"kernel" => kernel, "bias" => bias}, opts[:name],
         strides: strides,
         padding: padding,
         input_dilation: input_dilation,
@@ -367,7 +371,7 @@ defmodule Axon do
       )
 
     node =
-      layer(x, :conv_transpose, output_shape, [kernel, bias], opts[:name],
+      layer(x, :conv_transpose, output_shape, %{"kernel" => kernel, "bias" => bias}, opts[:name],
         strides: strides,
         padding: padding,
         kernel_dilation: kernel_dilation
@@ -452,7 +456,7 @@ defmodule Axon do
     bias = param("bias", bias_shape, initializer: bias_initializer, regularizer: bias_regularizer)
 
     node =
-      layer(x, :depthwise_conv, output_shape, [kernel, bias], opts[:name],
+      layer(x, :depthwise_conv, output_shape, %{"kernel" => kernel, "bias" => bias}, opts[:name],
        strides: strides,
        padding: padding,
        input_dilation: input_dilation,
@@ -541,7 +545,12 @@ defmodule Axon do
     b2 = param("bias_2", b2_shape, initializer: bias_initializer, regularizer: bias_regularizer)
 
     node =
-      layer(x, :separable_conv2d, output_shape, [k1, b1, k2, b2], opts[:name],
+      layer(
+        x,
+        :separable_conv2d,
+        output_shape,
+        %{"k1" => k1, "b1" => b1, "k2" => k2, "b2" => b2},
+        opts[:name],
         strides: strides,
         padding: padding,
         input_dilation: input_dilation,
@@ -638,7 +647,12 @@ defmodule Axon do
     b3 = param("bias_3", b3_shape, initializer: bias_initializer, regularizer: bias_regularizer)
 
     node =
-      layer(x, :separable_conv3d, output_shape, [k1, b1, k2, b2, k3, b3], opts[:name],
+      layer(
+        x,
+        :separable_conv3d,
+        output_shape,
+        %{"k1" => k1, "b1" => b1, "k2" => k2, "b2" => b2, "k3" => k3, "b3" => b3},
+        opts[:name],
         strides: strides,
         padding: padding,
         input_dilation: input_dilation,
@@ -691,12 +705,12 @@ defmodule Axon do
   def activation(x, activation, opts \\ [])
 
   def activation(%Axon{output_shape: shape} = x, activation, opts) when is_atom(activation) do
-    layer(x, activation, shape, [], opts[:name], opts)
+    layer(x, activation, shape, %{}, opts[:name], opts)
   end
 
   def activation(%Axon{output_shape: shape} = x, activation, opts)
       when is_function(activation, 1) do
-    layer(x, activation, shape, [], opts[:name], opts)
+    layer(x, activation, shape, %{}, opts[:name], opts)
   end
 
   ## Activation
@@ -747,7 +761,7 @@ defmodule Axon do
 
   defp dropout(%Axon{output_shape: parent_shape} = x, dropout, opts) do
     rate = opts[:rate] || 0.5
-    layer(x, dropout, parent_shape, [], opts[:name], rate: rate)
+    layer(x, dropout, parent_shape, %{}, opts[:name], rate: rate)
   end
 
   ## Pooling
@@ -787,7 +801,7 @@ defmodule Axon do
     strides = list_or_duplicate(:strides, strides, inner_rank)
     output_shape = Axon.Shape.pool(parent_shape, kernel_size, strides, padding)
 
-    layer(x, pool, output_shape, [], opts[:name],
+    layer(x, pool, output_shape, %{}, opts[:name],
       kernel_size: kernel_size,
       strides: strides,
       padding: padding
@@ -825,7 +839,7 @@ defmodule Axon do
     output_size = tuple_or_duplicate(:output_size, opts[:output_size], inner_rank)
     output_shape = Axon.Shape.adaptive_pool(parent_shape, output_size)
 
-    layer(x, pool, output_shape, [], opts[:name], output_size: output_size)
+    layer(x, pool, output_shape, %{}, opts[:name], output_size: output_size)
   end
 
   ## Normalization
@@ -875,7 +889,7 @@ defmodule Axon do
     beta_regularizer = opts[:beta_regularizer]
     beta = param("beta", beta_shape, initializer: beta_initializer, regularizer: beta_regularizer)
 
-    layer(x, norm, shape, [gamma, beta], opts[:name],
+    layer(x, norm, shape, %{"gamma" => gamma, "beta" => beta}, opts[:name],
       epsilon: epsilon,
       channel_index: channel_index
     )
@@ -915,7 +929,7 @@ defmodule Axon do
     beta_regularizer = opts[:beta_regularizer]
     beta = param("beta", beta_shape, initializer: beta_initializer, regularizer: beta_regularizer)
 
-    layer(x, :group_norm, shape, [gamma, beta], opts[:name],
+    layer(x, :group_norm, shape, %{"gamma" => gamma, "beta" => beta}, opts[:name],
       epsilon: epsilon,
       channel_index: channel_index,
       group_size: group_size
@@ -941,7 +955,7 @@ defmodule Axon do
     expr = Nx.Defn.jit(fun, [param], compiler: Axon.Defn)
     output_shape = Tuple.insert_at(expr.shape, 0, batch_size)
 
-    layer(x, fun, output_shape, [], opts[:name])
+    layer(x, fun, output_shape, %{}, opts[:name])
   end
 
   @doc """
@@ -959,7 +973,7 @@ defmodule Axon do
   @doc type: :shape
   def flatten(%Axon{output_shape: shape} = x, opts \\ []) do
     output_shape = Axon.Shape.flatten(shape)
-    layer(x, :flatten, output_shape, [], opts[:name])
+    layer(x, :flatten, output_shape, %{}, opts[:name])
   end
 
   @doc """
@@ -975,7 +989,7 @@ defmodule Axon do
   @doc type: :shape
   def reshape(%Axon{output_shape: shape} = x, new_shape, opts \\ []) do
     output_shape = Axon.Shape.reshape(shape, new_shape)
-    layer(x, :reshape, output_shape, [], opts[:name])
+    layer(x, :reshape, output_shape, %{}, opts[:name])
   end
 
   @doc """
@@ -990,7 +1004,7 @@ defmodule Axon do
   @doc type: :shape
   def transpose(%Axon{output_shape: shape} = x, permutation, opts \\ []) do
     output_shape = Axon.Shape.transpose(shape, permutation)
-    layer(x, :transpose, output_shape, [], opts[:name], permutation: permutation)
+    layer(x, :transpose, output_shape, %{}, opts[:name], permutation: permutation)
   end
 
   @doc """
@@ -1008,7 +1022,7 @@ defmodule Axon do
   def pad(%Axon{output_shape: shape} = x, config, value \\ 0.0, opts \\ [])
       when is_list(config) and is_number(value) do
     output_shape = Axon.Shape.pad(shape, config)
-    layer(x, :pad, output_shape, [], opts[:name], padding_config: config, value: value)
+    layer(x, :pad, output_shape, %{}, opts[:name], padding_config: config, value: value)
   end
 
   @doc """
@@ -1029,7 +1043,7 @@ defmodule Axon do
     axis = opts[:axis] || Nx.rank(x_shape) - 1
     output_shape = Axon.Shape.concatenate([x_shape, y_shape], axis)
 
-    layer([x, y], :concatenate, output_shape, [], opts[:name], axis: axis)
+    layer([x, y], :concatenate, output_shape, %{}, opts[:name], axis: axis)
   end
 
   @doc type: :composition
@@ -1039,7 +1053,7 @@ defmodule Axon do
     input_shapes = inputs |> Enum.map(fn %Axon{output_shape: shape} -> shape end)
     output_shape = Axon.Shape.concatenate(input_shapes, axis)
 
-    layer(inputs, :concatenate, output_shape, [], opts[:name], axis: axis)
+    layer(inputs, :concatenate, output_shape, %{}, opts[:name], axis: axis)
   end
 
   @doc false
@@ -1064,7 +1078,7 @@ defmodule Axon do
   """
   @doc type: :composition
   def unquote(op)(%Axon{output_shape: shape} = x, %Axon{output_shape: shape} = y, opts) do
-    Axon.layer([x, y], unquote(op), shape, [], opts[:name])
+    Axon.layer([x, y], unquote(op), shape, %{}, opts[:name])
   end
 
   @doc """
@@ -1090,7 +1104,7 @@ defmodule Axon do
         shape
       end)
 
-    layer(inputs, unquote(op), output_shape, [], [], opts[:name])
+    layer(inputs, unquote(op), output_shape, %{}, [], opts[:name])
   end
 
   @doc false
@@ -1161,7 +1175,20 @@ defmodule Axon do
         x,
         :lstm,
         {{hidden_state_shape, hidden_state_shape}, output_shape},
-        [wii, wif, wig, wio, whi, whf, whg, who, bi, bf, bg, bo],
+        %{
+          "wii" => wii,
+          "wif" => wif,
+          "wig" => wig,
+          "wio" => wio,
+          "whi" => whi,
+          "whf" => whf,
+          "whg" => whg,
+          "who" => who,
+          "bi" => bi,
+          "bf" => bf,
+          "bg" => bg,
+          "bo" => bo
+        },
         opts[:name],
         activation: activation,
         gate: gate,
@@ -1171,9 +1198,9 @@ defmodule Axon do
         unroll: unroll
       )
 
-    new_c = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
-    new_h = layer(output, fn x -> elem(elem(x, 0), 1) end, hidden_state_shape, [])
-    output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])
+    new_c = layer(output, fn x, _ -> elem(elem(x, 0), 0) end, hidden_state_shape, %{})
+    new_h = layer(output, fn x, _ -> elem(elem(x, 0), 1) end, hidden_state_shape, %{})
+    output_sequence = layer(output, fn x, _ -> elem(x, 1) end, output_shape, %{})
 
     {{new_c, new_h}, output_sequence}
   end
@@ -1232,7 +1259,18 @@ defmodule Axon do
         x,
         :gru,
         {{hidden_state_shape}, output_shape},
-        [wir, wiz, win, whr, whz, whn, br, bz, bin, bhn],
+        %{
+          "wir" => wir,
+          "wiz" => wiz,
+          "win" => win,
+          "whr" => whr,
+          "whz" => whz,
+          "whn" => whn,
+          "br" => br,
+          "bz" => bz,
+          "bin" => bin,
+          "bhn" => bhn
+        },
         opts[:name],
         activation: activation,
         gate: gate,
@@ -1242,8 +1280,8 @@ defmodule Axon do
         unroll: unroll
       )
 
-    new_h = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
-    output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])
+    new_h = layer(output, fn x, _ -> elem(elem(x, 0), 0) end, hidden_state_shape, %{})
+    output_sequence = layer(output, fn x, _ -> elem(x, 1) end, output_shape, %{})
 
     {{new_h}, output_sequence}
   end
@@ -1300,7 +1338,7 @@ defmodule Axon do
         x,
         :conv_lstm,
         {{hidden_state_shape, hidden_state_shape}, output_shape},
-        [wi, wh, b],
+        %{"wi" => wi, "wh" => wh, "b" => b},
        opts[:name],
         strides: strides,
         padding: padding,
@@ -1310,9 +1348,9 @@ defmodule Axon do
         unroll: unroll
       )
 
-    new_c = layer(output, fn x -> elem(elem(x, 0), 0) end, hidden_state_shape, [])
-    new_h = layer(output, fn x -> elem(elem(x, 0), 1) end, hidden_state_shape, [])
-    output_sequence = layer(output, fn x -> elem(x, 1) end, output_shape, [])
+    new_c = layer(output, fn x, _ -> elem(elem(x, 0), 0) end, hidden_state_shape, %{})
+    new_h = layer(output, fn x, _ -> elem(elem(x, 0), 1) end, hidden_state_shape, %{})
+    output_sequence = layer(output, fn x, _ -> elem(x, 1) end, output_shape, %{})
 
     {{new_c, new_h}, output_sequence}
   end
@@ -1455,7 +1493,7 @@ defmodule Axon do
 
     num_params =
       params
-      |> Enum.reduce(0, fn %Axon.Parameter{shape: shape}, acc -> acc + Nx.size(shape) end)
+      |> Enum.reduce(0, fn {_, %Axon.Parameter{shape: shape}}, acc -> acc + Nx.size(shape) end)
 
     row = [name <> " ( #{Atom.to_string(op)} )", "#{inspect(shape)}", "#{num_params}"]
     {row, cache}
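
The reducer change in this last hunk follows directly from the container change: enumerating a map yields `{key, value}` tuples rather than bare parameters, so the anonymous function now destructures the tuple and ignores the key. A tiny standalone illustration, using plain maps with made-up shapes in place of `%Axon.Parameter{}` structs:

    params = %{
      "dense_1_kernel" => %{shape: {784, 128}},
      "dense_1_bias" => %{shape: {128}}
    }

    num_params =
      Enum.reduce(params, 0, fn {_name, %{shape: shape}}, acc ->
        acc + Nx.size(shape)
      end)

    IO.inspect(num_params)
    # => 100480  (784 * 128 + 128)
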