Revert "address an issue with triton auto-tuner and in-place calls. m…
Browse files Browse the repository at this point in the history
…ake the assumption that after the first optimizer.step call, things are properly cached"

This reverts commit 6ab873a.
lucidrains committed May 10, 2023
1 parent 0781eb1 commit 2671a69
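
Background on the bug that was being worked around (and is now un-worked-around): @triton.autotune benchmarks each candidate triton.Config by actually invoking the kernel, so a kernel that writes its result in place applies the parameter update once per benchmark run on the very first call. A minimal sketch of that hazard, with inplace_update and autotuned_call as hypothetical stand-ins for the kernel and the autotuner:

    import torch

    def inplace_update(p, grad, lr):
        # stand-in for the Triton kernel: mutates `p` in place
        p.add_(grad, alpha = -lr)

    def autotuned_call(kernel, *args):
        # stand-in for @triton.autotune: each candidate config runs the kernel once
        for _ in range(3):  # pretend there are 3 configs to benchmark
            kernel(*args)

    p = torch.ones(4)
    autotuned_call(inplace_update, p, torch.ones(4), 0.1)
    print(p)  # tensor([0.7, 0.7, 0.7, 0.7]): the step landed 3 times instead of once

The reverted commit avoided this by running the first step on clones of the tensors and copying the results back; the diffs below remove that machinery.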
Showing 3 changed files with 6 additions and 39 deletions.
lion_pytorch/lion_pytorch.py (1 addition, 14 deletions)
@@ -47,8 +47,6 @@ def __init__(
         super().__init__(params, defaults)
 
         self.update_fn = update_fn
-        self.use_triton = use_triton
-        self.took_first_step = False
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
@@ -65,13 +63,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
-        # address an issue with autotune and in-place updates with triton
-        # on the first .step call, simply do not update parameters in-place, if using triton
-
-        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
-
-        # update all parameters
-
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
@@ -91,11 +82,7 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2,
-                    **update_kwargs
+                    beta2
                 )
 
-        if not self.took_first_step:
-            self.took_first_step = True
-
         return loss
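
With the revert applied, step has no Triton-specific first-call branch, and usage is unchanged. A sketch along the lines of the project README (the keyword names here are assumptions; requires CUDA and an installed triton):

    import torch
    from lion_pytorch import Lion

    model = torch.nn.Linear(10, 2).cuda()
    opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

    loss = model(torch.randn(8, 10).cuda()).sum()
    loss.backward()
    opt.step()       # the first step now takes the same in-place path as every later step
    opt.zero_grad()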
lion_pytorch/triton.py (4 additions, 24 deletions)
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor
 
 try:
     import triton
@@ -8,7 +7,6 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# triton cuda kernel
 
 @triton.autotune(configs = [
     triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
@@ -74,31 +72,19 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
     n_elements = p.numel()
 
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
-    # address autotune and in-place update issue
-
-    if not inplace:
-        orig_p = p
-        orig_exp_avg = exp_avg
-
-        p = p.clone()
-        exp_avg = exp_avg.clone()
-
-    # call triton cuda kernel
-
     update_fn_kernel[grid](
         p,
         grad,
@@ -109,9 +95,3 @@ def update_fn(
         beta2,
         n_elements
     )
-
-    # update if not in-place call
-
-    if not inplace:
-        orig_p.copy_(p)
-        orig_exp_avg.copy_(exp_avg)
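
For reference, the deleted workaround amounted to the pattern below: when inplace = False, the update ran on clones and the results were copied back, keeping autotune's benchmark runs away from the live tensors. A condensed, Triton-free restatement, where the inline Lion math is a stand-in for the real update_fn_kernel:

    import torch

    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2, inplace = True):
        if not inplace:
            orig_p, orig_exp_avg = p, exp_avg
            p, exp_avg = p.clone(), exp_avg.clone()  # kernel mutates only the clones

        # stand-in for update_fn_kernel[grid](...), which updates p and exp_avg in place
        update = exp_avg.mul(beta1).add(grad, alpha = 1 - beta1).sign_()
        p.mul_(1 - lr * wd).add_(update, alpha = -lr)
        exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)

        if not inplace:
            orig_p.copy_(p)             # commit the update to the real tensors
            orig_exp_avg.copy_(exp_avg)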
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.7',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
