diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index e025f89..f96fec6 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -33,7 +34,8 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False
+        use_triton: bool = False,
+        triton_block_size: int = 1024
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -52,7 +54,7 @@ def __init__(
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = triton_update_fn
+            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
 
     @torch.no_grad()
     def step(
@@ -65,11 +67,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
-        # address an issue with autotune and in-place updates with triton
-        # on the first .step call, simply do not update parameters in-place, if using triton
-
-        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
-
         # update all parameters
 
         for group in self.param_groups:
@@ -91,11 +88,7 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2,
-                    **update_kwargs
+                    beta2
                 )
 
-        if not self.took_first_step:
-            self.took_first_step = True
-
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index ca59800..0615dbc 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -8,12 +8,18 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
+# helper functions
+
+def calc_num_warps(block_size):
+    num_warps = 4
+    if block_size >= 2048:
+        num_warps = 8
+    if block_size >= 4096:
+        num_warps = 16
+    return num_warps
+
 # triton cuda kernel
 
-@triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
-], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -81,25 +87,20 @@ def update_fn(
     wd: float,
     beta1: float,
     beta2: float,
-    inplace: bool = True
+    inplace: bool = True,
+    BLOCK_SIZE: int = 1024
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-    n_elements = p.numel()
-
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-
-    # address autotune and in-place update issue
-
-    if not inplace:
-        orig_p = p
-        orig_exp_avg = exp_avg
+    n_elements = p.numel()
 
-        p = p.clone()
-        exp_avg = exp_avg.clone()
+    block_size = triton.next_power_of_2(BLOCK_SIZE)
+    num_warps = calc_num_warps(block_size)
+    n_rows = triton.cdiv(n_elements, block_size)
 
     # call triton cuda kernel
 
-    update_fn_kernel[grid](
+    update_fn_kernel[(n_rows,)](
         p,
         grad,
         exp_avg,
@@ -107,11 +108,7 @@ def update_fn(
         wd,
         beta1,
         beta2,
-        n_elements
+        n_elements,
+        num_warps = num_warps,
+        BLOCK_SIZE = BLOCK_SIZE
     )
-
-    # update if not in-place call
-
-    if not inplace:
-        orig_p.copy_(p)
-        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index 1eace6d..35eb081 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.1.0',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
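
A minimal usage sketch of the new triton_block_size option introduced above. It assumes Lion is exported at the package root and that a CUDA device with triton installed is available, since the fused update path asserts all tensors are on GPU; none of this code is part of the diff itself.

import torch
from lion_pytorch import Lion  # assumed package-root export

# toy model on GPU - the triton update_fn asserts p, grad and exp_avg are CUDA tensors
model = torch.nn.Linear(512, 512).cuda()

opt = Lion(
    model.parameters(),
    lr = 1e-4,
    betas = (0.9, 0.99),
    weight_decay = 1e-2,
    use_triton = True,          # route the update through the triton kernel
    triton_block_size = 1024    # fixed BLOCK_SIZE, replacing the removed autotune configs
)

# standard training step
loss = model(torch.randn(8, 512).cuda()).sum()
loss.backward()
opt.step()
opt.zero_grad()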