Commit 55c3553

Merge pull request #325 from kozistr/update/codes

[Feature] Implement `TAM`, `AdaTAM` optimizers

kozistr authored Jan 19, 2025
2 parents a9fb8a2 + 59e8736 commit 55c3553
Showing 25 changed files with 309 additions and 39 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -199,6 +199,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |

## Supported LR Scheduler

2 changes: 2 additions & 0 deletions docs/changelogs/v3.3.4.md
@@ -10,3 +10,5 @@
* `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())`
* Implement `SPAM` optimizer. (#324)
* [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842)
* Implement `TAM`, and `AdaTAM` optimizers. (#325)
* [Torque-Aware Momentum](https://arxiv.org/abs/2412.18790)
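
A minimal usage sketch for the two new optimizers referenced in this changelog entry. The constructor arguments shown here (just `lr`, everything else left at defaults) are an assumption based on the library's usual optimizer signatures, not something this diff specifies:

```python
# Hedged sketch: assumes TAM and AdaTAM follow the library's usual
# (params, lr=...) constructor convention; defaults for momentum/betas
# are not taken from this diff.
import torch
from pytorch_optimizer import TAM, AdaTAM

model = torch.nn.Linear(10, 2)
optimizer = TAM(model.parameters(), lr=1e-1)       # SGD-style, torque-aware momentum
# optimizer = AdaTAM(model.parameters(), lr=1e-3)  # Adam-style variant

x, y = torch.randn(8, 10), torch.randn(8, 2)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
```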
3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -199,6 +199,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |

## Supported LR Scheduler

8 changes: 8 additions & 0 deletions docs/optimizer.md
@@ -380,6 +380,14 @@
:docstring:
:members:

::: pytorch_optimizer.TAM
:docstring:
:members:

::: pytorch_optimizer.AdaTAM
:docstring:
:members:

::: pytorch_optimizer.Tiger
:docstring:
:members:
4 changes: 4 additions & 0 deletions docs/qa.md
@@ -7,3 +7,7 @@
## Q2) Memory leak happens when using SophiaH, AdaHessian optimizers.

`torch.autograd.grad` with complex gradient flows sometimes leads to memory leak issues, and you might encounter an OOM issue. [related issue](https://github.com/kozistr/pytorch_optimizer/issues/278)

## Q3) How to run visualizations?

Run `python3 -m examples.visualize_optimizers` from the project root.
32 changes: 32 additions & 0 deletions docs/visualization.md
@@ -82,6 +82,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaSmooth.png)

### AdaTAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaTAM.png)

### AdEMAMix

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdEMAMix.png)
@@ -254,6 +258,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Ranger21.png)

### Ranger25

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Ranger25.png)

### ScalableShampoo

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScalableShampoo.png)
@@ -306,6 +314,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SophiaH.png)

### SPAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SPAM.png)

### SRMM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SRMM.png)
@@ -318,6 +330,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SWATS.png)

### TAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_TAM.png)

### Tiger

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Tiger.png)
@@ -408,6 +424,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaSmooth.png)

### AdaTAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaTAM.png)

### AdEMAMix

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdEMAMix.png)
@@ -580,6 +600,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Ranger21.png)

### Ranger25

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Ranger25.png)

### ScalableShampoo

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScalableShampoo.png)
@@ -632,6 +656,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SophiaH.png)

### SPAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SPAM.png)

### SRMM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SRMM.png)
@@ -644,6 +672,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SWATS.png)

### TAM

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_TAM.png)

### Tiger

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Tiger.png)
Binary file added docs/visualizations/rastrigin_AdaTAM.png
Binary file added docs/visualizations/rastrigin_Ranger25.png
Binary file added docs/visualizations/rastrigin_SPAM.png
Binary file added docs/visualizations/rastrigin_TAM.png
Binary file added docs/visualizations/rosenbrock_AdaTAM.png
Binary file added docs/visualizations/rosenbrock_Ranger25.png
Binary file added docs/visualizations/rosenbrock_SPAM.png
Binary file added docs/visualizations/rosenbrock_TAM.png
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "3.3.3"
version = "3.3.4"
description = "optimizer & lr scheduler & objective function collections in PyTorch"
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
@@ -18,8 +18,8 @@ keywords = [
"Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
"NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
"Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "Tiger",
"TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM",
"Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
"Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -64,6 +64,7 @@
SPAM,
SRMM,
SWATS,
TAM,
TRAC,
WSAM,
A2Grad,
@@ -88,6 +89,7 @@
AdaPNM,
AdaShift,
AdaSmooth,
AdaTAM,
AdEMAMix,
AggMo,
Aida,
3 changes: 3 additions & 0 deletions pytorch_optimizer/optimizer/__init__.py
@@ -86,6 +86,7 @@
from pytorch_optimizer.optimizer.spam import SPAM
from pytorch_optimizer.optimizer.srmm import SRMM
from pytorch_optimizer.optimizer.swats import SWATS
from pytorch_optimizer.optimizer.tam import TAM, AdaTAM
from pytorch_optimizer.optimizer.tiger import Tiger
from pytorch_optimizer.optimizer.trac import TRAC
from pytorch_optimizer.optimizer.yogi import Yogi
@@ -252,6 +253,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
SRMM,
AvaGrad,
AdaShift,
TAM,
AdaTAM,
AdaDelta,
Amos,
AdaHessian,
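
Because both classes are appended to the `load_optimizer` registry above, they should also be resolvable by name. A short sketch, assuming the registry keys are the lower-cased class names, as with the existing entries:

```python
import torch
from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(4, 1)

# Assumption: optimizer names are looked up case-insensitively ('tam' / 'adatam').
opt_class = load_optimizer('tam')
optimizer = opt_class(model.parameters(), lr=1e-1)
```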
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/adafactor.py
@@ -13,8 +13,8 @@ class AdaFactor(BaseOptimizer):
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared
hessian trace. if beta1 is None, first momentum will be skipped.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
if beta1 is None, first momentum will be skipped.
:param decay_rate: float. coefficient used to compute running averages of square gradient.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
7 changes: 4 additions & 3 deletions pytorch_optimizer/optimizer/adamp.py
@@ -141,6 +141,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
inv_de_nom = exp_avg_sq.rsqrt().add_(group['eps']).mul_(bias_correction2_sq)

perturb = exp_avg.clone()

if self.cautious:
self.apply_cautious(perturb, grad)

if group['nesterov']:
perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)
else:
@@ -173,9 +177,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1=bias_correction1,
)

if self.cautious:
self.apply_cautious(perturb, grad)

p.add_(perturb, alpha=-step_size)

return loss
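
For reference, the `cautious` option masks out components of the perturbation whose sign disagrees with the current gradient, which is why this diff applies it directly to the cloned `exp_avg` before the Nesterov/projection steps. The helper below is only a sketch of that behavior; the real `apply_cautious` lives on `BaseOptimizer` and its exact renormalization may differ:

```python
import torch

def apply_cautious_sketch(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Zero components of `update` that point against `grad`, then rescale
    the surviving components so the overall update scale is roughly kept."""
    mask = (update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))  # renormalization; exact form is an assumption
    update.mul_(mask)

u, g = torch.tensor([0.5, -0.2, 0.1]), torch.tensor([1.0, 1.0, -1.0])
apply_cautious_sketch(u, g)  # only the first component survives (rescaled)
```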
3 changes: 1 addition & 2 deletions pytorch_optimizer/optimizer/adamw.py
@@ -107,8 +107,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg.lerp_(grad, weight=beta1_comp)
exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

rms = self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)
lr = group['lr'] / rms
lr: float = group['lr'] / self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)

self.apply_weight_decay(
p,
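
For context, `get_stable_adamw_rms` implements the update clipping from StableAdamW: the learning rate is divided by an RMS statistic of the gradient relative to its second-moment estimate. A standalone sketch of that quantity, assuming the formulation from the StableAdamW paper (the library helper may clamp or reduce slightly differently):

```python
import torch

def stable_adamw_rms_sketch(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
    """RMS of the gradient normalized by sqrt(v); dividing lr by max(1, rms)
    shrinks the step only when gradients are unusually large."""
    rms = (grad.pow(2) / exp_avg_sq.clamp(min=eps ** 2)).mean().sqrt().item()
    return max(1.0, rms)

# lr_t = group['lr'] / stable_adamw_rms_sketch(grad, exp_avg_sq, eps=eps_p2)
```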
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/sgd.py
@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param momentum: float. coefficients used for computing running averages of gradient.
:param momentum: float. coefficients used for computing running averages of gradient.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param eps: float. term added to the denominator to improve numerical stability.
@@ -423,7 +423,7 @@ def __init__(
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_range(momentum, 'beta', 0.0, 1.0)
self.validate_range(momentum, 'momentum', 0.0, 1.0)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

39 changes: 15 additions & 24 deletions pytorch_optimizer/optimizer/spam.py
@@ -81,7 +81,7 @@ def __init__(
self.validate_non_negative(density, 'density')
self.validate_non_negative(threshold, 'threshold')
self.validate_non_negative(grad_accu_steps, 'grad_accu_steps')
self.validate_non_negative(update_proj_gap, 'update_proj_gap')
self.validate_positive(update_proj_gap, 'update_proj_gap')
self.validate_non_negative(eps, 'eps')

self.density = density
@@ -91,41 +91,32 @@
self.update_proj_gap = update_proj_gap
self.warmup = CosineDecay(0.99, warmup_epoch)

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'eps': eps,
**kwargs,
}
defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
super().__init__(params, defaults)

self.init_masks()

self.state['total_step'] = 0
self.state['current_step'] = warmup_epoch + 1
self.state['current_step'] = self.warmup_epoch + 1

@staticmethod
def initialize_random_rank_boolean_tensor(m: int, n: int, density: float) -> torch.Tensor:
def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
r"""Create an (m x n) boolean tensor with `density` fraction of True entries.
:param m: int. number of rows.
:param n: int. number of columns.
:param density: float. fraction of True entries. 1.0 means all True.
:param device: torch.device. device.
"""
total_elements: int = m * n
non_zero_count: int = int(density * total_elements)

tensor = torch.zeros((m, n), dtype=torch.bool)
tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)

if non_zero_count == 0:
return tensor
if non_zero_count > 0:
tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True

indices = torch.randperm(total_elements)[:non_zero_count]
rows, cols = indices // n, indices % n
tensor[rows, cols] = True

return tensor
return tensor.view(m, n)

def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
r"""Update a random mask.
@@ -164,9 +155,8 @@ def update_masks(self) -> None:
for p in group['params']:
state = self.state[p]
if 'mask' in state:
new_mask = self.update_mask_random(self.density, p, state['mask'])
state['mask'] = new_mask
p.mask = new_mask
state['mask'] = self.update_mask_random(self.density, p, state['mask'])
p.mask = state['mask']

def init_masks(self) -> None:
r"""Initialize random masks for each parameter group that has 'density'."""
@@ -175,10 +165,11 @@ def init_masks(self) -> None:
state = self.state[p]
if p.dim() == 2 and 'mask' not in state:
state['mask'] = self.initialize_random_rank_boolean_tensor(
p.shape[0],
p.shape[1],
m=p.shape[0],
n=p.shape[1],
density=self.density,
).to(p.device)
device=p.device,
)

def __str__(self) -> str:
return 'SPAM'
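
The refactored `initialize_random_rank_boolean_tensor` now builds the mask as a flat boolean tensor directly on the parameter's device and reshapes it at the end. A standalone sketch mirroring that logic, plus a quick check that the mask has exactly `int(density * m * n)` True entries:

```python
import torch

def random_bool_mask(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
    """Flat boolean tensor with `int(density * m * n)` True entries at random
    positions, reshaped to (m, n) — mirrors the refactored helper above."""
    total = m * n
    k = int(density * total)
    mask = torch.zeros(total, dtype=torch.bool, device=device)
    if k > 0:
        mask[torch.randperm(total, device=device)[:k]] = True
    return mask.view(m, n)

mask = random_bool_mask(128, 256, density=0.05, device=torch.device('cpu'))
assert mask.sum().item() == int(0.05 * 128 * 256)
```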