diff --git a/examples/cifar10.exs b/examples/cifar10.exs
index 24123c52..85da235c 100644
--- a/examples/cifar10.exs
+++ b/examples/cifar10.exs
@@ -45,7 +45,7 @@ defmodule Cifar do
   end
 
   defp log_metrics(
-         %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
+         %State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
          mode
        ) do
     loss =
diff --git a/examples/fashionmnist_autoencoder.exs b/examples/fashionmnist_autoencoder.exs
index f9967078..d2407b2c 100644
--- a/examples/fashionmnist_autoencoder.exs
+++ b/examples/fashionmnist_autoencoder.exs
@@ -41,7 +41,7 @@ defmodule Fashionmist do
   end
 
   defp log_metrics(
-         %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
+         %State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
          mode
        ) do
     loss =
diff --git a/examples/mnist.exs b/examples/mnist.exs
index 4cf9aa39..ed40716a 100644
--- a/examples/mnist.exs
+++ b/examples/mnist.exs
@@ -41,7 +41,7 @@ defmodule Mnist do
   end
 
   defp log_metrics(
-         %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
+         %State{epoch: epoch, iteration: iter, metrics: metrics, step_state: pstate} = state,
          mode
        ) do
     loss =
diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex
index 7b07f03b..a5761ddf 100644
--- a/lib/axon/loop.ex
+++ b/lib/axon/loop.ex
@@ -208,9 +208,6 @@ defmodule Axon.Loop do
   require Axon.Updates
   require Logger
 
-  # TODO(seanmor5): Remove when running average is gone
-  import Nx.Defn
-
   alias __MODULE__, as: Loop
   alias Axon.Loop.State
 
@@ -338,7 +335,11 @@ defmodule Axon.Loop do
             fn x -> elem(x, 1) end
           )
 
-        new_loss = running_average(loss, batch_loss, i)
+        new_loss =
+          loss
+          |> Nx.multiply(i)
+          |> Nx.add(batch_loss)
+          |> Nx.divide(Nx.add(i, 1))
 
         {updates, new_optimizer_state} =
           update_optimizer_fn.(gradients, optimizer_state, model_state)
@@ -572,11 +573,21 @@ defmodule Axon.Loop do
       loop
       |> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten
       |> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used
+
+  By default, metrics keep a running average of the metric calculation. You can
+  override this behavior by changing `accumulate`:
+
+      loop
+      |> Axon.Loop.metric(:true_negatives, "tn", :running_sum)
+
+  The accumulation function can be one of the accumulation combinators in
+  `Axon.Metrics` or an arity-3 function of the form `accumulate(acc, obs, i) :: new_acc`.
   """
   def metric(
         %Loop{metrics: metric_fns} = loop,
         metric,
         name \\ nil,
+        accumulate \\ :running_average,
         transform_or_fields \\ [:y_true, :y_pred]
       ) do
     name =
@@ -602,7 +613,7 @@ defmodule Axon.Loop do
         :ok
     end
 
-    metric_fn = build_metric_fn(metric, transform_or_fields)
+    metric_fn = build_metric_fn(metric, accumulate, transform_or_fields)
     %Loop{loop | metrics: Map.put(metric_fns, name, metric_fn)}
   end
 
@@ -963,7 +974,7 @@ defmodule Axon.Loop do
     new_metrics =
       metrics
       |> Enum.zip_with(metric_fns, fn {k, avg}, {k, v} ->
-        {k, running_average(avg, v.(new_step_state), iter)}
+        {k, v.(avg, List.wrap(new_step_state), iter)}
       end)
       |> Map.new()
 
@@ -1082,7 +1093,7 @@ defmodule Axon.Loop do
   # to extract from the step state, or a function which transforms the step
   # state before it is passed to the metric function.
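+  # The `accumulator` argument is either the name of one of the accumulation
+  # combinators in Axon.Metrics (:running_average or :running_sum), or an
+  # arity-3 function `(acc, obs, i) -> new_acc` invoked on each iteration.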
   # TODO(seanmor5): Reconsider the form of output transform
-  defp build_metric_fn(metric, transform_or_fields) do
+  defp build_metric_fn(metric, accumulator, transform_or_fields) do
     transform_fn =
       case transform_or_fields do
         [_ | _] = fields ->
@@ -1108,27 +1119,45 @@ defmodule Axon.Loop do
               " applied to the step state"
       end
 
-    case metric do
-      metric when is_atom(metric) ->
-        fn output ->
-          output
-          |> transform_fn.()
-          |> then(&apply(Axon.Metrics, metric, &1))
-        end
+    metric_fn =
+      case metric do
+        metric when is_atom(metric) ->
+          fn output ->
+            output
+            |> transform_fn.()
+            |> then(&apply(Axon.Metrics, metric, &1))
+          end
 
-      metric_fn when is_function(metric) ->
-        fn output ->
-          output
-          |> transform_fn.()
-          |> then(&apply(metric_fn, &1))
-        end
+        metric_fn when is_function(metric) ->
+          fn output ->
+            output
+            |> transform_fn.()
+            |> then(&apply(metric_fn, &1))
+          end
+
+        invalid ->
+          raise ArgumentError,
+                "Invalid metric #{inspect(invalid)}, a valid metric" <>
+                  " is an atom which matches the name of a function in" <>
+                  " Axon.Metrics or a function which takes a transformed" <>
+                  " step state and returns a value"
+      end
+
+    case accumulator do
+      acc_fun when acc_fun in [:running_average, :running_sum] ->
+        apply(Axon.Metrics, acc_fun, [metric_fn])
+
+      acc_fun when is_function(acc_fun, 3) ->
+        &acc_fun.(&1, apply(metric_fn, &2), &3)
 
       invalid ->
         raise ArgumentError,
-              "Invalid metric #{inspect(invalid)}, a valid metric" <>
-                " is an atom which matches the name of a function in" <>
-                " Axon.Metrics or a function which takes a transformed" <>
-                " step state and returns a value"
+              "Invalid accumulation function #{inspect(invalid)}, a valid" <>
+                " accumulation function is an atom which matches the name" <>
+                " of an accumulation function in Axon.Metrics, or an arity-3" <>
+                " function which takes the current accumulator, observation," <>
+                " and iteration and returns an updated accumulator"
     end
   end
 
@@ -1179,12 +1208,4 @@ defmodule Axon.Loop do
       apply(fun, args)
     end
   end
-
-  # TODO(seanmor5): Move to metrics as a combinator
-  defnp running_average(avg, value, i) do
-    avg
-    |> Nx.multiply(i)
-    |> Nx.add(value)
-    |> Nx.divide(Nx.add(i, 1))
-  end
 end
diff --git a/lib/axon/metrics.ex b/lib/axon/metrics.ex
index 390a0cbb..cf4dc81d 100644
--- a/lib/axon/metrics.ex
+++ b/lib/axon/metrics.ex
@@ -25,6 +25,8 @@ defmodule Axon.Metrics do
   import Nx.Defn
   import Axon.Shared
 
+  # Standard Metrics
+
   @doc ~S"""
   Computes the accuracy of the given predictions, assuming both targets and
   predictions are one-hot encoded.
@@ -81,17 +83,11 @@ defmodule Axon.Metrics do
   defn precision(y_true, y_pred, opts \\ []) do
     assert_shape!(y_true, y_pred)
 
-    opts = keyword!(opts, threshold: 0.5)
-
-    thresholded_preds =
-      y_pred
-      |> Nx.greater(opts[:threshold])
+    true_positives = true_positives(y_true, y_pred, opts)
+    false_positives = false_positives(y_true, y_pred, opts)
 
-    thresholded_preds
-    |> Nx.equal(y_true)
-    |> Nx.logical_and(Nx.equal(thresholded_preds, 1))
-    |> Nx.sum()
-    |> Nx.divide(Nx.sum(thresholded_preds) + 1.0e-16)
+    true_positives
+    |> Nx.divide(true_positives + false_positives + 1.0e-16)
   end
 
   @doc ~S"""
@@ -120,25 +116,146 @@ defmodule Axon.Metrics do
   defn recall(y_true, y_pred, opts \\ []) do
     assert_shape!(y_true, y_pred)
 
+    true_positives = true_positives(y_true, y_pred, opts)
+    false_negatives = false_negatives(y_true, y_pred, opts)
+
+    Nx.divide(true_positives, false_negatives + true_positives + 1.0e-16)
+  end
+
+  @doc """
+  Computes the number of true positive predictions with respect
+  to the given targets.
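+  A prediction is counted as positive when it is strictly greater
+  than `:threshold`.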
+
+  ## Options
+
+    * `:threshold` - threshold for truth value of predictions.
+      Defaults to `0.5`.
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
+      iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
+      iex> Axon.Metrics.true_positives(y_true, y_pred)
+      #Nx.Tensor<
+        u64
+        1
+      >
+  """
+  defn true_positives(y_true, y_pred, opts \\ []) do
+    assert_shape!(y_true, y_pred)
+
     opts = keyword!(opts, threshold: 0.5)
 
     thresholded_preds =
       y_pred
       |> Nx.greater(opts[:threshold])
 
-    true_positives =
-      thresholded_preds
-      |> Nx.equal(y_true)
-      |> Nx.logical_and(Nx.equal(thresholded_preds, 1))
-      |> Nx.sum()
+    thresholded_preds
+    |> Nx.equal(y_true)
+    |> Nx.logical_and(Nx.equal(thresholded_preds, 1))
+    |> Nx.sum()
+  end
 
-    false_negatives =
-      thresholded_preds
-      |> Nx.not_equal(y_true)
-      |> Nx.logical_and(Nx.equal(thresholded_preds, 0))
-      |> Nx.sum()
+  @doc """
+  Computes the number of false negative predictions with respect
+  to the given targets.
 
-    Nx.divide(true_positives, false_negatives + true_positives + 1.0e-16)
+  ## Options
+
+    * `:threshold` - threshold for truth value of predictions.
+      Defaults to `0.5`.
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
+      iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
+      iex> Axon.Metrics.false_negatives(y_true, y_pred)
+      #Nx.Tensor<
+        u64
+        3
+      >
+  """
+  defn false_negatives(y_true, y_pred, opts \\ []) do
+    assert_shape!(y_true, y_pred)
+
+    opts = keyword!(opts, threshold: 0.5)
+
+    thresholded_preds =
+      y_pred
+      |> Nx.greater(opts[:threshold])
+
+    thresholded_preds
+    |> Nx.not_equal(y_true)
+    |> Nx.logical_and(Nx.equal(thresholded_preds, 0))
+    |> Nx.sum()
+  end
+
+  @doc """
+  Computes the number of true negative predictions with respect
+  to the given targets.
+
+  ## Options
+
+    * `:threshold` - threshold for truth value of predictions.
+      Defaults to `0.5`.
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
+      iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
+      iex> Axon.Metrics.true_negatives(y_true, y_pred)
+      #Nx.Tensor<
+        u64
+        1
+      >
+  """
+  defn true_negatives(y_true, y_pred, opts \\ []) do
+    assert_shape!(y_true, y_pred)
+
+    opts = keyword!(opts, threshold: 0.5)
+
+    thresholded_preds =
+      y_pred
+      |> Nx.greater(opts[:threshold])
+
+    thresholded_preds
+    |> Nx.equal(y_true)
+    |> Nx.logical_and(Nx.equal(thresholded_preds, 0))
+    |> Nx.sum()
+  end
+
+  @doc """
+  Computes the number of false positive predictions with respect
+  to the given targets.
+
+  ## Options
+
+    * `:threshold` - threshold for truth value of predictions.
+      Defaults to `0.5`.
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
+      iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
+      iex> Axon.Metrics.false_positives(y_true, y_pred)
+      #Nx.Tensor<
+        u64
+        2
+      >
+  """
+  defn false_positives(y_true, y_pred, opts \\ []) do
+    assert_shape!(y_true, y_pred)
+
+    opts = keyword!(opts, threshold: 0.5)
+
+    thresholded_preds =
+      y_pred
+      |> Nx.greater(opts[:threshold])
+
+    thresholded_preds
+    |> Nx.not_equal(y_true)
+    |> Nx.logical_and(Nx.equal(thresholded_preds, 1))
+    |> Nx.sum()
   end
 
   @doc ~S"""
@@ -246,4 +363,59 @@ defmodule Axon.Metrics do
     |> Nx.abs()
     |> Nx.mean()
   end
+
+  # Combinators
+
+  @doc """
+  Returns a function which computes a running average given current average,
+  new observation, and current iteration.
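+
+  The returned function has the form `(avg, args, i) -> new_avg`, where `args`
+  is a list of arguments applied to `metric`, and the average is updated as
+  `(avg * i + obs) / (i + 1)`.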
+
+  ## Examples
+
+      iex> cur_avg = 0.5
+      iex> iteration = 1
+      iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
+      iex> y_pred = Nx.tensor([[0, 1], [1, 0], [1, 0]])
+      iex> avg_acc = Axon.Metrics.running_average(&Axon.Metrics.accuracy/2)
+      iex> avg_acc.(cur_avg, [y_true, y_pred], iteration)
+      #Nx.Tensor<
+        f32
+        0.75
+      >
+  """
+  def running_average(metric) do
+    &running_average_impl(&1, apply(metric, &2), &3)
+  end
+
+  defnp running_average_impl(avg, obs, i) do
+    avg
+    |> Nx.multiply(i)
+    |> Nx.add(obs)
+    |> Nx.divide(Nx.add(i, 1))
+  end
+
+  @doc """
+  Returns a function which computes a running sum given current sum,
+  new observation, and current iteration.
+
+  ## Examples
+
+      iex> cur_sum = 12
+      iex> iteration = 2
+      iex> y_true = Nx.tensor([0, 1, 0, 1])
+      iex> y_pred = Nx.tensor([1, 1, 0, 1])
+      iex> fps = Axon.Metrics.running_sum(&Axon.Metrics.false_positives/2)
+      iex> fps.(cur_sum, [y_true, y_pred], iteration)
+      #Nx.Tensor<
+        s64
+        13
+      >
+  """
+  def running_sum(metric) do
+    &running_sum_impl(&1, apply(metric, &2), &3)
+  end
+
+  defnp running_sum_impl(sum, obs, _) do
+    Nx.add(sum, obs)
+  end
 end
diff --git a/test/loop_test.exs b/test/loop_test.exs
index 5198c705..090f4f7c 100644
--- a/test/loop_test.exs
+++ b/test/loop_test.exs
@@ -210,6 +210,40 @@ defmodule Axon.LoopTest do
              |> Axon.Loop.metric(:accuracy)
            end) =~ "Metric accuracy declared twice in loop."
     end
+
+    test "computes running average by default with supervised output transform" do
+      step_fn = fn _, _ -> 1 end
+
+      loop =
+        step_fn
+        |> Loop.loop()
+        |> Loop.metric(:accuracy)
+
+      assert %Loop{metrics: %{"accuracy" => avg_acc_fun}} = loop
+
+      output = %{foo: 1, y_true: Nx.tensor([1, 0, 1]), y_pred: Nx.tensor([0.8, 0.2, 0.8])}
+      cur_avg_acc = 0.5
+      i = 1
+
+      assert avg_acc_fun.(cur_avg_acc, List.wrap(output), i) == Nx.tensor(0.75)
+    end
+
+    test "computes a running sum with custom output transform" do
+      step_fn = fn _, _ -> 1 end
+
+      loop =
+        step_fn
+        |> Loop.loop()
+        |> Loop.metric(:true_positives, "tp", :running_sum, &Tuple.to_list/1)
+
+      assert %Loop{metrics: %{"tp" => sum_tp_fun}} = loop
+
+      output = {Nx.tensor([1, 0, 1]), Nx.tensor([0, 1, 1])}
+      cur_sum = 25
+      i = 10
+
+      assert sum_tp_fun.(cur_sum, List.wrap(output), i) == Nx.tensor(26)
+    end
   end
 
   describe "looping" do
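# A minimal sketch of the custom arity-3 accumulator path (hypothetical names,
# not part of this commit; assumes the `apply(metric_fn, &2)` calling
# convention above): the accumulated function receives the current
# accumulator, the wrapped step state, and the iteration, and applies the
# metric internally. Here the accumulator keeps a running maximum of accuracy.
#
#     step_fn = fn _, _ -> 1 end
#     running_max = fn acc, obs, _i -> Nx.max(acc, obs) end
#
#     loop =
#       step_fn
#       |> Axon.Loop.loop()
#       |> Axon.Loop.metric(:accuracy, "max_acc", running_max)
#
#     %Axon.Loop{metrics: %{"max_acc" => max_acc_fun}} = loop
#     output = %{y_true: Nx.tensor([1, 0, 1]), y_pred: Nx.tensor([0.8, 0.2, 0.8])}
#     max_acc_fun.(0.5, List.wrap(output), 0)
#     #=> #Nx.Tensor<
#     #     f32
#     #     1.0
#     #   >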