Skip to content

Commit

Permalink
Make traversal API public
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Jun 20, 2021
1 parent 103a1f3 commit d5d1e01
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 73 deletions.
98 changes: 55 additions & 43 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1392,68 +1392,80 @@ defmodule Axon do
"""
def freeze(%Axon{} = model, fun \\ & &1) when is_function(fun, 1) do
parameters =
model
|> get_params([])
|> Enum.uniq()
tree_reduce(model, MapSet.new(), fn %Axon{params: params}, acc ->
Enum.reduce(params, acc, fn {_, param}, acc ->
MapSet.put(acc, param)
end)
end)

parameters_to_freeze = fun.(parameters)
do_freeze(model, parameters_to_freeze)
end
parameters_to_freeze = fun.(Enum.to_list(parameters))

tree_map(model, fn %Axon{params: params} = axon ->
frozen_params =
params
|> Map.new(fn {k, %{name: param_name} = v} ->
if Enum.any?(parameters_to_freeze, fn %{name: name} -> name == param_name end) do
{k, %{v | frozen: true}}
else
{k, v}
end
end)

defp get_params(model, acc) when is_tuple(model) do
model
|> Tuple.to_list()
|> Enum.reduce(acc, &get_params/2)
%{axon | params: frozen_params}
end)
end

defp get_params(%Axon{op: :input}, acc), do: acc
## Traversal

defp get_params(%Axon{parent: x}, acc) when is_list(x) do
Enum.reduce(x, acc, &get_params/2)
@doc """
Traverses a model tree applying `fun` to each layer.
"""
def tree_map(%Axon{op: :input} = axon, fun) when is_function(fun, 1) do
fun.(axon)
end

defp get_params(%Axon{parent: x, params: params, opts: opts}, acc) do
acc =
def tree_map(%Axon{parent: x} = axon, fun) when is_list(x) do
x = Enum.map(x, &tree_map(&1, fun))
%{fun.(axon) | parent: x}
end

def tree_map(%Axon{parent: x, opts: opts} = axon, fun) do
opts =
case opts[:hidden_state] do
state when is_tuple(state) ->
state
|> Tuple.to_list()
|> Enum.reduce(acc, &get_params/2)
%Axon{} = hidden_state ->
hidden_state = tree_map(hidden_state, fun)
Keyword.replace(opts, :hidden_state, hidden_state)

nil ->
acc
opts
end

get_params(x, Enum.reduce(Map.values(params), acc, fn x, ls -> [x | ls] end))
x = tree_map(x, fun)
%{fun.(axon) | parent: x, opts: opts}
end

defp do_freeze(%Axon{op: :input} = x, _), do: x

defp do_freeze(%Axon{parent: parent} = x, parameters_to_freeze) when is_list(parent) do
parent = Enum.map(parent, &do_freeze(&1, parameters_to_freeze))
%{x | parent: parent}
@doc """
Traverses a model applying `fun` with an accumulator.
"""
def tree_reduce(%Axon{op: :input} = axon, acc, fun) when is_function(fun, 2) do
fun.(axon, acc)
end

defp do_freeze(%Axon{parent: parent, params: params} = x, parameters_to_freeze) do
parent = do_freeze(parent, parameters_to_freeze)
def tree_reduce(%Axon{parent: x} = axon, acc, fun) when is_list(x) do
Enum.reduce(x, fun.(axon, acc), &tree_reduce(&1, &2, fun))
end

params =
params
|> Map.new(fn {k, %{name: param_name} = v} ->
if Enum.any?(parameters_to_freeze, fn %{name: name} -> name == param_name end) do
{k, %{v | frozen: true}}
else
{k, v}
end
end)
def tree_reduce(%Axon{parent: x, opts: opts} = axon, acc, fun) do
acc =
case opts[:hidden_state] do
%Axon{} = hidden_state ->
tree_reduce(hidden_state, acc, fun)

%{x | parent: parent, params: params}
end
nil ->
acc
end

defp do_freeze(x, parameters_to_freeze) when is_tuple(x) do
x
|> Tuple.to_list()
|> Enum.map(&do_freeze(&1, parameters_to_freeze))
tree_reduce(x, fun.(axon, acc), fun)
end

@doc """
Expand Down
5 changes: 2 additions & 3 deletions lib/axon/mixed_precision.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ defmodule Axon.MixedPrecision do
pass before casting the output back to `{:f, 32}`.
"""

import Axon.Shared
alias Axon.MixedPrecision.Policy

@doc """
Expand Down Expand Up @@ -87,11 +86,11 @@ defmodule Axon.MixedPrecision do
%Axon{op: :dense} -> true
%Axon{op: :batch_norm} -> false
%Axon{op: :conv} -> false
%Axon{op: _} -> true
%Axon{op: _} -> true
end)
"""
def apply_policy(%Axon{} = axon, %Policy{} = policy, filter) when is_function(filter) do
tree_map(axon, fn layer ->
Axon.tree_map(axon, fn layer ->
if filter.(layer) do
%{layer | policy: policy}
else
Expand Down
27 changes: 0 additions & 27 deletions lib/axon/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -116,33 +116,6 @@ defmodule Axon.Shared do
end
end

@doc """
Traverses a model tree applying `fun` to each layer.
"""
def tree_map(%Axon{op: :input} = axon, fun) when is_function(fun, 1) do
fun.(axon)
end

def tree_map(%Axon{parent: x} = axon, fun) when is_list(x) do
x = Enum.map(x, &tree_map(&1, fun))
%{fun.(axon) | parent: x}
end

def tree_map(%Axon{parent: x, opts: opts} = axon, fun) do
opts =
case opts[:hidden_state] do
%Axon{} = hidden_state ->
hidden_state = tree_map(hidden_state, fun)
Keyword.replace(opts, :hidden_state, hidden_state)

nil ->
opts
end

x = tree_map(x, fun)
%{fun.(axon) | parent: x, opts: opts}
end

## Numerical Helpers

# TODO: These should be contained somewhere else, like another library
Expand Down

0 comments on commit d5d1e01

Please sign in to comment.