Skip to content

Commit

Permalink
Add aggregate metrics (elixir-nx#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Oct 20, 2021
1 parent 702b4af commit aa664f4
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 56 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10.exs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ defmodule Cifar do
end

defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
%State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
mode
) do
loss =
Expand Down
2 changes: 1 addition & 1 deletion examples/fashionmnist_autoencoder.exs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ defmodule Fashionmist do
end

defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
%State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
mode
) do
loss =
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist.exs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ defmodule Mnist do
end

defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
%State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
mode
) do
loss =
Expand Down
85 changes: 53 additions & 32 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,6 @@ defmodule Axon.Loop do
require Axon.Updates
require Logger

# TODO(seanmor5): Remove when running average is gone
import Nx.Defn

alias __MODULE__, as: Loop
alias Axon.Loop.State

Expand Down Expand Up @@ -338,7 +335,11 @@ defmodule Axon.Loop do
fn x -> elem(x, 1) end
)

new_loss = running_average(loss, batch_loss, i)
new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(batch_loss)
|> Nx.divide(Nx.add(i, 1))

{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)
Expand Down Expand Up @@ -572,11 +573,21 @@ defmodule Axon.Loop do
loop
|> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten
|> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used
By default, metrics keep a running average of the metric calculation. You can
override this behavior by changing `accumulate`:
loop
|> Axon.Loop.metric(:true_negatives, "tn", :running_sum)
Accumulation function can be one of the accumulation combinators in Axon.Metrics
or an arity-3 function of the form: `accumulate(acc, obs, i) :: new_acc`.
"""
def metric(
%Loop{metrics: metric_fns} = loop,
metric,
name \\ nil,
accumulate \\ :running_average,
transform_or_fields \\ [:y_true, :y_pred]
) do
name =
Expand All @@ -602,7 +613,7 @@ defmodule Axon.Loop do
:ok
end

metric_fn = build_metric_fn(metric, transform_or_fields)
metric_fn = build_metric_fn(metric, accumulate, transform_or_fields)
%Loop{loop | metrics: Map.put(metric_fns, name, metric_fn)}
end

Expand Down Expand Up @@ -963,7 +974,7 @@ defmodule Axon.Loop do
new_metrics =
metrics
|> Enum.zip_with(metric_fns, fn {k, avg}, {k, v} ->
{k, running_average(avg, v.(new_step_state), iter)}
{k, v.(avg, List.wrap(new_step_state), iter)}
end)
|> Map.new()

Expand Down Expand Up @@ -1082,7 +1093,7 @@ defmodule Axon.Loop do
# to extract from the step state, or a function which transforms the step
# state before it is passed to the metric function.
# TODO(seanmor5): Reconsider the form of output transform
defp build_metric_fn(metric, transform_or_fields) do
defp build_metric_fn(metric, accumulator, transform_or_fields) do
transform_fn =
case transform_or_fields do
[_ | _] = fields ->
Expand All @@ -1108,27 +1119,45 @@ defmodule Axon.Loop do
" applied to the step state"
end

case metric do
metric when is_atom(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(Axon.Metrics, metric, &1))
end
metric_fn =
case metric do
metric when is_atom(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(Axon.Metrics, metric, &1))
end

metric_fn when is_function(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(metric_fn, &1))
end
metric_fn when is_function(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(metric_fn, &1))
|> List.wrap()
end

invalid ->
raise ArgumentError,
"Invalid metric #{inspect(invalid)}, a valid metric" <>
" is an atom which matches the name of a function in" <>
" Axon.Metrics or a function which takes a transformed" <>
" step state and returns a value"
end

case accumulator do
acc_fun when acc_fun in [:running_average, :running_sum] ->
apply(Axon.Metrics, acc_fun, [metric_fn])

acc_fun when is_function(acc_fun, 3) ->
&acc_fun.(&1, metric_fn.(&2), &3)

invalid ->
raise ArgumentError,
"Invalid metric #{inspect(invalid)}, a valid metric" <>
" is an atom which matches the name of a function in" <>
" Axon.Metrics or a function which takes a transformed" <>
" step state and returns a value"
"Invalid accumulation function #{inspect(invalid)}, a valid" <>
" accumulation function is an atom which matches the name" <>
" of an accumulation function in Axon.Metrics, or an arity-3" <>
" function which takes current accumulator, observation, and" <>
" iteration and returns an updated accumulator"
end
end

Expand Down Expand Up @@ -1179,12 +1208,4 @@ defmodule Axon.Loop do
apply(fun, args)
end
end

# TODO(seanmor5): Move to metrics as a combinator
defnp running_average(avg, value, i) do
avg
|> Nx.multiply(i)
|> Nx.add(value)
|> Nx.divide(Nx.add(i, 1))
end
end
Loading

0 comments on commit aa664f4

Please sign in to comment.