Jaerin Lee* · Bong Gyun Kang* · Kihoon Kim · Kyoung Mu Lee
*Denotes equal contribution.
tl;dr: We accelerate the grokking phenomenon by amplifying the low-frequency components of the parameter gradients with an augmented optimizer.
Abstract:
One puzzling artifact in machine learning, dubbed grokking, is delayed generalization: the model generalizes only after tens of times more iterations than it takes to nearly perfectly overfit the training data.
Focusing on this long delay itself, on behalf of machine learning practitioners, our goal is to accelerate the generalization of a model under the grokking phenomenon.
By regarding a series of gradients of a parameter over training iterations as a random signal over time, we can spectrally decompose the parameter trajectories under gradient descent into two components: the fast-varying, overfitting-yielding component and the slow-varying, generalization-inducing component.
This analysis allows us to accelerate the grokking phenomenon more than 50× with only a few lines of code that amplify the slow-varying components of the gradients.
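Concretely, the low-pass filter used by the EMA variant (`gradfilter_ema` below) keeps an exponential moving average of each parameter's gradient and adds it back, amplified, before the optimizer step. Writing $g_t$ for the gradient at iteration $t$:

$$\hat{g}_t = \alpha\,\hat{g}_{t-1} + (1 - \alpha)\,g_t, \qquad g_t \leftarrow g_t + \lambda\,\hat{g}_t$$

Here $\alpha$ (`alpha`) sets the cutoff of the low-pass filter and $\lambda$ (`lamb`) sets how strongly the slow-varying component is amplified.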
Grokfast doesn't require any additional packages besides PyTorch. The file `requirements.txt` is only needed to reproduce the experiments in the article, as described in the Reproduction section below.
Grokfast can be applied by inserting a single line before the optimizer call.
- Download a single file `grokfast.py` from our repository.

```bash
wget https://raw.githubusercontent.com/ironjr/grokfast/main/grokfast.py
```
- Import the helper functions.

```python
from grokfast import gradfilter_ma, gradfilter_ema
```
- Insert the following line before the training loop.

```python
grads = None
```
- Between `loss.backward()` and `optimizer.step()`, insert one of the following lines. Make sure `model` is of type `nn.Module` and `grads` is initialized properly before the training loop:
```python
# ... in the optimization loop.
loss.backward()  # Calculate the gradients.

### Option 1: Grokfast (has arguments alpha, lamb)
grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)
### Option 2: Grokfast-MA (has arguments window_size, lamb)
grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb)

optimizer.step()  # Call the optimizer.
# ... logging & other codes.
```
Done!
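For reference, here is a minimal end-to-end sketch of a training loop with Grokfast inserted. The `model`, `loader`, and `criterion` below are placeholders for your own setup and are not part of `grokfast.py`:

```python
import torch
import torch.nn as nn
from grokfast import gradfilter_ema

# Placeholder setup: substitute your own model, data, and loss here.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
loader = [(torch.randn(32, 128), torch.randint(0, 10, (32,))) for _ in range(100)]
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

grads = None  # Running memory for the gradient filter.
for epoch in range(10):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()  # Calculate the raw gradients.
        # Low-pass filter the gradients and amplify the slow-varying component.
        grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)
        optimizer.step()  # Update the parameters with the filtered gradients.
```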
(2-1) ...or, copy and paste the method directly into your code!
```python
### Imports
from collections import deque
from typing import Dict, Optional, Literal

import torch
import torch.nn as nn


### Grokfast
def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)  # Update the EMA.
            p.grad.data = p.grad.data + grads[n] * lamb  # Amplify the slow-varying component.

    return grads


### Grokfast-MA
def gradfilter_ma(
    m: nn.Module,
    grads: Optional[Dict[str, deque]] = None,
    window_size: int = 100,
    lamb: float = 5.0,
    filter_type: Literal['mean', 'sum'] = 'mean',
    warmup: bool = True,
    trigger: bool = False,
) -> Dict[str, deque]:
    if grads is None:
        grads = {n: deque(maxlen=window_size) for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            grads[n].append(p.grad.data.detach())

            # Apply the filter only after the window is filled (unless warmup is disabled).
            if not warmup or len(grads[n]) == window_size and not trigger:
                if filter_type == "mean":
                    avg = sum(grads[n]) / len(grads[n])
                elif filter_type == "sum":
                    avg = sum(grads[n])
                else:
                    raise ValueError(f"Unrecognized filter_type {filter_type}")
                p.grad.data = p.grad.data + avg * lamb

    return grads
```
- Grokfast (`gradfilter_ema`)
  - `m: nn.Module`: Model that contains every trainable parameter.
  - `grads: Optional[Dict[str, torch.Tensor]] = None`: Running memory (EMA). Initialize by setting it to `None`, then feed the output of the method back into this argument on every subsequent call.
  - `alpha: float = 0.98`: Momentum hyperparameter of the EMA.
  - `lamb: float = 2.0`: Amplifying factor hyperparameter of the filter.
- Grokfast-MA (`gradfilter_ma`)
  - `m: nn.Module`: Model that contains every trainable parameter.
  - `grads: Optional[Dict[str, deque]] = None`: Running memory (queue for the windowed moving average). Initialize by setting it to `None`, then feed the output of the method back into this argument on every subsequent call.
  - `window_size: int = 100`: The width of the filter window. The additional memory requirement increases linearly with the window size.
  - `lamb: float = 5.0`: Amplifying factor hyperparameter of the filter.
  - `filter_type: Literal['mean', 'sum'] = 'mean'`: Aggregation method for the running queue.
  - `warmup: bool = True`: If true, the filter is not applied until the queue is filled.
  - `trigger: bool = False`: For ablation study only. If true, the filter is simply not applied.
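As a quick intuition check (a standalone toy example, not from the paper), the snippet below applies the same EMA recursion as `gradfilter_ema` to a synthetic gradient signal made of a slow drift plus fast noise. The filtered signal tracks the slow component far more closely than the raw signal does, and this slow component is exactly what Grokfast amplifies:

```python
import torch

torch.manual_seed(0)

# Synthetic per-parameter "gradient" signal over 2,000 iterations:
# a slow drift (period much longer than the EMA time constant) plus fast noise.
t = torch.arange(2000, dtype=torch.float32)
slow = torch.sin(2 * torch.pi * t / 4000)   # Slow-varying, generalization-like component.
noise = 0.5 * torch.randn_like(t)           # Fast-varying, overfitting-like component.
g = slow + noise

# Same recursion as in gradfilter_ema: ema_t = alpha * ema_{t-1} + (1 - alpha) * g_t.
alpha = 0.98
ema = torch.zeros_like(g)
ema[0] = g[0]
for i in range(1, len(g)):
    ema[i] = alpha * ema[i - 1] + (1 - alpha) * g[i]

# The EMA stays much closer to the slow component than the raw signal does.
print("raw vs. slow | mean abs error:", (g - slow).abs().mean().item())
print("ema vs. slow | mean abs error:", (ema - slow).abs().mean().item())
```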
We also note the additional computational resources required for each run. Time & memory costs are measured with a single GTX 1080 Ti GPU.
This will install the additional packages needed to preprocess each dataset and to summarize the results.

```bash
conda create -n grok python=3.10 && conda activate grok
git clone https://github.com/ironjr/grokfast
cd grokfast
pip install -r requirements.txt
```
| Run | Iterations to Reach 95% Val. Acc. | Wall Clock Time to Reach 95% Val. Acc. (s) | VRAM Requirements (MB) | Latency Per Iteration (s) |
|---|---|---|---|---|
| Baseline | | | | |
| Grokfast-MA | | | | |
```bash
# python main.py --label test # Baseline.
python main.py --label test --filter ma --window_size 100 --lamb 5.0 --weight_decay 0.01
```
| Run | Iterations to Reach 95% Val. Acc. | Wall Clock Time to Reach 95% Val. Acc. (s) | VRAM Requirements (MB) | Latency Per Iteration (s) |
|---|---|---|---|---|
| Baseline | | | | |
| Grokfast | | | | |
```bash
# python main.py --label test # Baseline.
python main.py --label test --filter ema --alpha 0.98 --lamb 2.0 --weight_decay 0.005
```
| Run | Iterations to Reach 95% Val. Acc. | Wall Clock Time to Reach 95% Val. Acc. (s) | VRAM Requirements (MB) | Latency Per Iteration (ms) |
|---|---|---|---|---|
| Baseline | | | | |
| Grokfast | | | | |
```bash
# python main_mnist.py --label test # Baseline.
python main_mnist.py --label test --alpha 0.8 --lamb 0.1 --weight_decay 2.0
```
| Run | Best Validation Acc. | Minimum Validation Loss | VRAM Requirements (MB) | Latency Per Iteration (ms) |
|---|---|---|---|---|
| Baseline | | | | |
| Grokfast | | | | |
- Before training, download the IMDb dataset from Google Drive.
```bash
# python main_imdb.py --label test # Baseline.
python main_imdb.py --label test --alpha 0.98 --lamb 2.0 --weight_decay 10.0
```
| Run | Minimum Validation Loss | VRAM Requirements (MB) | Latency Per Iteration (ms) |
|---|---|---|---|
| Baseline | | | |
| Grokfast | | | |
```bash
# python main_qm9.py --label test # Baseline.
python main_qm9.py --label test --alpha 0.9 --lamb 1.0 --weight_decay 0.01
```
Our code is heavily based on the following projects:
- Ziming Liu et al., "Omnigrok: Grokking Beyond Algorithmic Data," ICLR 2023. [arXiv] [code]
- Alethea Power et al., "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets," arXiv preprint arXiv:2201.02177. [arXiv] [code]
- @danielmamay's Re-implementation of Grokking. [code]
Thank you all for providing useful references!
Please cite us if you find our project useful!
```bibtex
@article{lee2024grokfast,
    title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
    author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
    journal={arXiv preprint arXiv:2405.20233},
    year={2024}
}
```
If you have any questions, please e-mail [email protected].