Skip to content

Commit

Permalink
attempt to fix autotune + inplace update issue
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 10, 2023
1 parent 2671a69 commit 3d1e555
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions lion_pytorch/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@
print('triton is not installed, please install by running `pip install triton -U --pre`')
exit()

# clone param and exp_avg before autotuning takes place
# as those are updated in-place

def clone_inplace_updated_params(nargs):
nargs['p_ptr'] = nargs['p_ptr'].clone()
nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()

# triton cuda kernel

@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
@triton.jit
def update_fn_kernel(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'lion-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.7',
version = '0.1.2',
license='MIT',
description = 'Lion Optimizer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 3d1e555

Please sign in to comment.