fatmax and logsumexp for infinities (pytorch#1999)
Summary:
Pull Request resolved: pytorch#1999

This commit improves the robustness of `fatmax` and `logsumexp` for inputs with infinities.
- In contrast to `torch.logsumexp`, `logsumexp` does not give rise to `NaN`s in its backward pass even if infinities are present.
- `fatmax` is updated to exhibit the same behavior in the presence of infinities, and now allows for the specification of an `alpha` parameter, which controls the asymptotic power decay of the fat-tailed approximation.

In addition, the commit introduces helper functions derived from `logsumexp` and `fatmax` (e.g. `logplusexp`, `fatminimum`, and `fatmaximum`), fixes a similar infinity issue with `logdiffexp`, and improves the associated test suite.
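For illustration, a minimal sketch of the failure mode in stock `torch.logsumexp` that motivates this change (outputs are approximate and may vary across PyTorch versions):

```python
import torch

# Case 1: an element equal to +inf receives a NaN gradient.
x = torch.tensor([0.0, float("inf")], requires_grad=True)
torch.logsumexp(x, dim=-1).backward()
print(x.grad)  # e.g. tensor([0., nan])

# Case 2: a slice containing only -inf values yields NaN gradients everywhere.
y = torch.full((3,), -float("inf"), requires_grad=True)
torch.logsumexp(y, dim=-1).backward()
print(y.grad)  # e.g. tensor([nan, nan, nan])
```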

Reviewed By: Balandat

Differential Revision: D48878020

fbshipit-source-id: 46561efb10c921b77c1ed483ab383b30e8ac7e20
SebastianAment authored and facebook-github-bot committed Sep 1, 2023
1 parent 748b46a commit 9649b1c
Showing 2 changed files with 325 additions and 37 deletions.
211 changes: 195 additions & 16 deletions botorch/utils/safe_math.py
@@ -17,7 +17,7 @@

import math

from typing import Tuple, Union
from typing import Callable, Tuple, Union

import torch
from botorch.exceptions import UnsupportedError
@@ -28,6 +28,9 @@
_log2 = math.log(2)
_inv_sqrt_3 = math.sqrt(1 / 3)

TAU = 1.0 # default temperature parameter for smooth approximations to non-linearities
ALPHA = 2.0 # default alpha parameter for the asymptotic power decay of _pareto


# Unary ops
def exp(x: Tensor, **kwargs) -> Tensor:
@@ -96,6 +99,12 @@ def logexpit(X: Tensor) -> Tensor:
return -log1pexp(-X)


def logplusexp(a: Tensor, b: Tensor) -> Tensor:
"""Computes log(exp(a) + exp(b)) similar to logsumexp."""
ab = torch.stack(torch.broadcast_tensors(a, b), dim=-1)
return logsumexp(ab, dim=-1)
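A hedged usage sketch of the new helper, assuming it is imported from `botorch.utils.safe_math` (the module this diff modifies):

```python
import torch
from botorch.utils.safe_math import logplusexp

log_a = torch.tensor([-1000.0, 0.0])
log_b = torch.tensor([-1000.0, 700.0])
# Computes log(exp(log_a) + exp(log_b)) without overflowing/underflowing exp.
print(logplusexp(log_a, log_b))  # ≈ tensor([-999.3069, 700.0000])
```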


def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
"""Computes log(b - a) accurately given log(a) and log(b).
Assumes log_b > log_a, i.e. b > a > 0.
@@ -107,7 +116,90 @@ def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
Returns:
A Tensor of values corresponding to log(b - a).
"""
return log_b + log1mexp(log_a - log_b)
log_a, log_b = torch.broadcast_tensors(log_a, log_b)
is_inf = log_b == -torch.inf # implies log_a == -torch.inf by assumption
return log_b + log1mexp(log_a - log_b.masked_fill(is_inf, 0.0))
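To illustrate the masking above (a hedged sketch): in the corner case `log_a == log_b == -inf`, i.e. `a = b = 0`, the unmasked difference `log_a - log_b` would be `NaN`; with the mask, the result is the limit `log(0) = -inf`.

```python
import torch
from botorch.utils.safe_math import logdiffexp

neg_inf = torch.full((2,), -float("inf"))
print(logdiffexp(log_a=neg_inf, log_b=neg_inf))  # tensor([-inf, -inf]) rather than NaN
```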


def logsumexp(
x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False
) -> Tensor:
"""Version of logsumexp that has a well-behaved backward pass when
x contains infinities.
In particular, the gradient of the standard torch version becomes NaN
1) for any element that is positive infinity, and 2) for any slice that
only contains negative infinities.
This version returns a gradient of 1 for any positive infinities in case 1, and
for all elements of the slice in case 2, in agreement with the asymptotic behavior
of the function.
Args:
x: The Tensor to which to apply `logsumexp`.
dim: An integer or a tuple of integers, representing the dimensions to reduce.
keepdim: Whether to keep the reduced dimensions. Defaults to False.
Returns:
A Tensor representing the log of the summed exponentials of `x`.
"""
return _inf_max_helper(torch.logsumexp, x=x, dim=dim, keepdim=keepdim)
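A hedged sketch of the gradients described in the docstring, assuming the patched function is imported from `botorch.utils.safe_math`:

```python
import torch
from botorch.utils.safe_math import logsumexp

# Case 1: the +inf element carries the result and gets a gradient of 1;
# finite elements in the same slice get 0.
x = torch.tensor([0.0, float("inf")], requires_grad=True)
logsumexp(x, dim=-1).backward()
print(x.grad)  # tensor([0., 1.])

# Case 2: a slice of only -inf values yields a gradient of 1 for every element.
y = torch.full((3,), -float("inf"), requires_grad=True)
logsumexp(y, dim=-1).backward()
print(y.grad)  # tensor([1., 1., 1.])
```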


def _inf_max_helper(
max_fun: Callable[[Tensor], Tensor],
x: Tensor,
dim: Union[int, Tuple[int, ...]],
keepdim: bool,
) -> Tensor:
"""Helper function that generalizes the treatment of infinities for approximations
to the maximum operator, i.e., `max(X, dim, keepdim)`. At the point of writing of
this function, it is used to define `logsumexp` and `fatmax`.
Args:
max_fun: The function that is used to smoothly penalize the difference of an
element to the true maximum.
x: The Tensor on which to compute the smooth approximation to the maximum.
dim: The dimension(s) to reduce over.
keepdim: Whether to keep the reduced dimension. Defaults to False.
Returns:
The Tensor representing the smooth approximation to the maximum over the
specified dimensions.
"""
M = x.amax(dim=dim, keepdim=True)
is_inf_max = torch.logical_and(*torch.broadcast_tensors(M.isinf(), x == M))
has_inf_max = _any(is_inf_max, dim=dim, keepdim=True)

y_inf = x.masked_fill(~is_inf_max, 0.0)
M_no_inf = M.masked_fill(M.isinf(), 0.0)
y_no_inf = x.masked_fill(has_inf_max, 0.0) - M_no_inf

res = torch.where(
has_inf_max,
y_inf.sum(dim=dim, keepdim=True),
M_no_inf + max_fun(y_no_inf, dim=dim, keepdim=True),
)
return res if keepdim else res.squeeze(dim)


def _any(x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False) -> Tensor:
"""Extension of torch.any, which supports reducing over tuples of dimensions.
Args:
x: The Tensor to reduce over.
dim: An integer or a tuple of integers, representing the dimensions to reduce.
keepdim: Whether to keep the reduced dimensions. Defaults to False.
Returns:
The Tensor corresponding to `any` over the specified dimensions.
"""
if isinstance(dim, Tuple):
for d in dim:
x = x.any(dim=d, keepdim=True)
else:
x = x.any(dim, keepdim=True)
return x if keepdim else x.squeeze(dim)


def logmeanexp(
@@ -124,10 +216,10 @@ def logmeanexp(
A Tensor of values corresponding to `log(mean(exp(X), dim=dim))`.
"""
n = X.shape[dim] if isinstance(dim, int) else math.prod(X.shape[i] for i in dim)
return torch.logsumexp(X, dim=dim, keepdim=keepdim) - math.log(n)
return logsumexp(X, dim=dim, keepdim=keepdim) - math.log(n)


def log_softplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
def log_softplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
"""Computes the logarithm of the softplus function with high numerical accuracy.
Args:
@@ -151,7 +243,12 @@ def log_softplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
)


def smooth_amax(X: Tensor, tau: Union[float, Tensor] = 1e-3, dim: int = -1) -> Tensor:
def smooth_amax(
X: Tensor,
dim: Union[int, Tuple[int, ...]] = -1,
keepdim: bool = False,
tau: Union[float, Tensor] = 1.0,
) -> Tensor:
"""Computes a smooth approximation to `max(X, dim=dim)`, i.e the maximum value of
`X` over dimension `dim`, using the logarithm of the `l_(1/tau)` norm of `exp(X)`.
Note that when `X = log(U)` is the *logarithm* of an acquisition utility `U`,
@@ -160,14 +257,16 @@ def smooth_amax(X: Tensor, tau: Union[float, Tensor] = 1e-3, dim: int = -1) -> T
Args:
X: A Tensor from which to compute the smoothed amax.
dim: The dimensions to reduce over.
keepdim: If True, keeps the reduced dimensions.
tau: Temperature parameter controlling the smooth approximation
to max operator, becomes tighter as tau goes to 0. Needs to be positive.
Returns:
A Tensor of smooth approximations to `max(X, dim=dim)`.
"""
# consider normalizing by log_n = math.log(X.shape[dim]) to reduce error
return torch.logsumexp(X / tau, dim=dim) * tau # ~ X.amax(dim=dim)
return logsumexp(X / tau, dim=dim, keepdim=keepdim) * tau # ~ X.amax(dim=dim)


def check_dtype_float32_or_float64(X: Tensor) -> None:
@@ -177,7 +276,7 @@ def check_dtype_float32_or_float64(X: Tensor) -> None:
)


def log_fatplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
def log_fatplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
"""Computes the logarithm of the fat-tailed softplus.
NOTE: Separated out in case the complexity of the `log` implementation increases
@@ -186,14 +285,15 @@ def log_fatplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
return fatplus(x, tau=tau).log()


def fatplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
def fatplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
"""Computes a fat-tailed approximation to `ReLU(x) = max(x, 0)` by linearly
combining a regular softplus function and the density function of a Cauchy
distribution. The coefficient `alpha` of the Cauchy density is chosen to guarantee
monotonicity and convexity.
Args:
x: A Tensor on whose values to compute the smoothed function.
tau: Temperature parameter controlling the smoothness of the approximation.
Returns:
A Tensor of values of the fat-tailed softplus.
@@ -206,25 +306,77 @@ def _fatplus(x: Tensor) -> Tensor:
return tau * _fatplus(x / tau)


def fatmax(X: Tensor, dim: int, tau: Union[float, Tensor] = 1.0) -> Tensor:
def fatmax(
x: Tensor,
dim: Union[int, Tuple[int, ...]],
keepdim: bool = False,
tau: Union[float, Tensor] = TAU,
alpha: float = ALPHA,
) -> Tensor:
"""Computes a smooth approximation to amax(X, dim=dim) with a fat tail.
Args:
X: A Tensor from which to compute the smoothed amax.
dim: The dimensions to reduce over.
keepdim: If True, keeps the reduced dimensions.
tau: Temperature parameter controlling the smooth approximation
to max operator, becomes tighter as tau goes to 0. Needs to be positive.
standardize: Toggles the temperature standardization of the smoothed function.
alpha: The exponent of the asymptotic power decay of the approximation. The
default value is 2. Higher alpha parameters make the function behave more
similarly to the standard logsumexp approximation to the max, so it is
recommended to keep this value low or moderate, e.g. < 10.
Returns:
A Tensor of smooth approximations to `max(X, dim=dim)` with a fat tail.
"""
if X.shape[dim] == 1:
return X.squeeze(dim)

M = X.amax(dim=dim, keepdim=True)
Y = (X - M) / tau # NOTE: this would cause NaNs when X has Infs.
M = M.squeeze(dim)
return M + tau * cauchy(Y).sum(dim=dim).log() # could change to mean
def max_fun(
x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False
) -> Tensor:
return tau * _pareto(-x / tau, alpha=alpha).sum(dim=dim, keepdim=keepdim).log()

return _inf_max_helper(max_fun=max_fun, x=x, dim=dim, keepdim=keepdim)
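A hedged usage sketch of the updated `fatmax`, imported from `botorch.utils.safe_math`; the printed values are approximate:

```python
import torch
from botorch.utils.safe_math import fatmax

x = torch.tensor([[1.0, 2.0, 3.0], [0.0, -float("inf"), float("inf")]])
# Smooth approximation to amax over the last dim; rows containing +inf
# return +inf instead of NaN, and smaller tau tightens the approximation.
print(fatmax(x, dim=-1, tau=0.1))             # ≈ tensor([3.0021, inf])
print(fatmax(x, dim=-1, tau=0.1, alpha=4.0))  # thinner tails, closer to logsumexp
```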


def fatmaximum(
a: Tensor, b: Tensor, tau: Union[float, Tensor] = TAU, alpha: float = ALPHA
) -> Tensor:
"""Computes a smooth approximation to torch.maximum(a, b) with a fat tail.
Args:
a: The first Tensor from which to compute the smoothed component-wise maximum.
b: The second Tensor from which to compute the smoothed component-wise maximum.
tau: Temperature parameter controlling the smoothness of the approximation. A
smaller tau corresponds to a tighter approximation that leads to a sharper
objective landscape that might be more difficult to optimize.
Returns:
A smooth approximation of torch.maximum(a, b).
"""
return fatmax(
torch.stack(torch.broadcast_tensors(a, b), dim=-1),
dim=-1,
keepdim=False,
tau=tau,
)


def fatminimum(
a: Tensor, b: Tensor, tau: Union[float, Tensor] = TAU, alpha: float = ALPHA
) -> Tensor:
"""Computes a smooth approximation to torch.minimum(a, b) with a fat tail.
Args:
a: The first Tensor from which to compute the smoothed component-wise minimum.
b: The second Tensor from which to compute the smoothed component-wise minimum.
tau: Temperature parameter controlling the smoothness of the approximation. A
smaller tau corresponds to a tighter approximation that leads to a sharper
objective landscape that might be more difficult to optimize.
Returns:
A smooth approximation of torch.minimum(a, b).
"""
return -fatmaximum(-a, -b, tau=tau, alpha=alpha)
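And a hedged sketch of the component-wise helpers, with default `alpha` and approximate outputs:

```python
import torch
from botorch.utils.safe_math import fatmaximum, fatminimum

a = torch.tensor([0.0, 2.0, -1.0])
b = torch.tensor([1.0, 1.0, -1.0])
# Smooth, fat-tailed surrogates for torch.maximum / torch.minimum;
# the smoothing error shrinks as tau decreases.
print(fatmaximum(a, b, tau=0.01))  # ≈ tensor([1., 2., -1.]) up to small smoothing error
print(fatminimum(a, b, tau=0.01))  # ≈ tensor([0., 1., -1.]) up to small smoothing error
```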


def log_fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
@@ -259,6 +411,33 @@ def cauchy(x: Tensor) -> Tensor:
return 1 / (1 + x.square())


def _pareto(x: Tensor, alpha: float, check: bool = True) -> Tensor:
"""Computes a rational polynomial that is
1) monotonically decreasing for `x > 0`,
2) is equal to 1 at `x = 0`,
3) has a first and second derivative of 1 at `x = 0`, and
4) has an asymptotic decay of `O(1 / x^alpha)`.
These properties make it possible to use the function to define a smooth and
fat-tailed approximation to the maximum, which enables better gradient propagation,
see `fatmax` for details.
Args:
x: The input tensor.
alpha: The exponent of the asymptotic decay.
check: Whether to check if the input tensor only contains non-negative values.
Returns:
The tensor corresponding to the rational polynomial with the stated properties.
"""
if check and (x < 0).any():
raise ValueError("Argument `x` must be non-negative.")
alpha = alpha / 2 # so that alpha stands for the power decay
# choosing beta_0, beta_1 so that first and second derivatives at x = 0 are 1.
beta_1 = 2 * alpha
beta_0 = alpha * beta_1
return (beta_0 / (beta_0 + beta_1 * x + x.square())).pow(alpha)
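A small hedged check of the stated properties of this private helper (value 1 at zero, `O(1 / x^alpha)` decay), with approximate printed values:

```python
import torch
from botorch.utils.safe_math import _pareto

x = torch.tensor([0.0, 10.0, 100.0, 1000.0])
p = _pareto(x, alpha=2.0)
print(p[0])         # 1.0 exactly at x = 0
print(p[2] / p[3])  # ≈ 98, approaching (1000 / 100)^2 = 100 per the O(1/x^2) decay
```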


def sigmoid(X: Tensor, log: bool = False, fat: bool = False) -> Tensor:
"""A sigmoid function with an optional fat tail and evaluation in log space for
better numerical behavior. Notably, the fat-tailed sigmoid can be used to remedy
