From 6ff45486f432f91eb86937a0def5eb5f2cf792ae Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 28 Sep 2023 09:23:39 +0800 Subject: [PATCH] Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715) * input grad checks out * adding clamp gamma * Both old and proposed implementation checks out * 2 tests not yet passed due to numerical issues * mem_eff works * fast-layer-norm done * Moving mem-eff to templates * Relax tolerance for memory efficient backward * Fix backward api of python --- apex/contrib/csrc/layer_norm/ln.h | 33 +- apex/contrib/csrc/layer_norm/ln_api.cpp | 53 +-- .../csrc/layer_norm/ln_bwd_kernels.cuh | 23 +- apex/contrib/layer_norm/layer_norm.py | 32 +- .../test/layer_norm/test_fast_layer_norm.py | 26 +- apex/normalization/fused_layer_norm.py | 142 +++++--- csrc/layer_norm_cuda.cpp | 93 +++-- csrc/layer_norm_cuda_kernel.cu | 342 ++++++++++++------ csrc/static_switch.h | 25 ++ .../test_fused_layer_norm.py | 127 ++++--- 10 files changed, 560 insertions(+), 336 deletions(-) create mode 100644 csrc/static_switch.h diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 6ab709b09..cf0355c07 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -10,7 +10,7 @@ namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct LaunchParams{ size_t workspace_bytes; @@ -26,17 +26,20 @@ struct LaunchParams{ //////////////////////////////////////////////////////////////////////////////////////////////////// -struct ParamsBase { - ParamsBase() +struct FwdParams{ + FwdParams() : ctas_per_col(0) , rows(0) , cols(0) , x(nullptr) + , z(nullptr) , mu(nullptr) , rs(nullptr) , gamma(nullptr) + , beta(nullptr) , workspace(nullptr) , barrier(nullptr) + , epsilon(0.f) { } @@ -49,9 +52,11 @@ struct ParamsBase { // Common data pointers. void *x; + void *z; void *mu; void *rs; void *gamma; + void *beta; // Multi-CTA workspace in gmem. void *workspace; @@ -59,31 +64,15 @@ struct ParamsBase { // Multi-CTA sync barriers in gmem. int *barrier; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , beta(nullptr) - , epsilon(0.f) - { - } - // Output of LN FWD. - void *z; - void *beta; float epsilon; - }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct BwdParams : public ParamsBase { +struct BwdParams : public FwdParams{ BwdParams() - : ParamsBase() + : FwdParams() , dz(nullptr) , dbeta_part(nullptr) , dgamma_part(nullptr) @@ -92,7 +81,6 @@ struct BwdParams : public ParamsBase { , dgamma(nullptr) { } - // Input: gradient wrt. LN FWD output. 
void *dz; @@ -200,3 +188,4 @@ struct BwdRegistrar{ //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm + diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 30e4a5fec..54f04e201 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -130,12 +130,12 @@ std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size layer_norm::FwdParams ¶ms = launch_params.params; params.rows = rows; params.cols = cols; - params.x = x.data_ptr(); + params.z = z.data_ptr(); params.mu = mu.data_ptr(); params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.beta = beta.data_ptr(); - params.z = z.data_ptr(); + params.x = x.data_ptr(); params.epsilon = epsilon; if( launch_params.barrier_size > 0 ) { @@ -153,33 +153,39 @@ std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size } //////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! - const at::Tensor &gamma // hidden_size +std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size + const at::Tensor &x_or_z, // BxSxhidden_size + c10::optional &mu_, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! + const at::Tensor &gamma, // hidden_size + c10::optional&beta_, // hidden_size + bool memory_efficient ) { - auto itype = x.scalar_type(); + auto itype = x_or_z.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; auto ctype = torch::kFloat32; TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(mu.dtype() == ctype); TORCH_CHECK(rsigma.dtype() == ctype); + if (mu_.has_value()) { + TORCH_CHECK(mu_.value().dtype() == ctype); + } - TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x_or_z.is_cuda()); TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(mu.is_cuda()); TORCH_CHECK(rsigma.is_cuda()); TORCH_CHECK(gamma.is_cuda()); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().is_cuda()); + TORCH_CHECK(beta_.value().dtype() == wtype); + } - TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(x_or_z.is_contiguous()); TORCH_CHECK(dz.is_contiguous()); - auto sizes = x.sizes(); + auto sizes = x_or_z.sizes(); TORCH_CHECK(sizes.size() == 2); TORCH_CHECK(dz.sizes() == sizes); auto rows = sizes[0]; @@ -187,14 +193,14 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size auto hidden_size = gamma.numel(); - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - TORCH_CHECK(gamma.numel() == cols); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().numel() == cols); + } - auto options = x.options(); + auto options = x_or_z.options(); - auto dx = torch::empty_like(x); + auto dx = torch::empty_like(x_or_z); auto dgamma = torch::empty_like(gamma); auto dbeta = torch::empty_like(gamma); @@ -213,8 +219,13 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size layer_norm::BwdParams ¶ms = launch_params.params; params.rows = rows; params.cols = cols; - params.x = x.data_ptr(); - params.mu = mu.data_ptr(); + if (memory_efficient) { + params.z = x_or_z.data_ptr(); + params.beta = beta_.value().data_ptr(); + } else { + params.x = x_or_z.data_ptr(); + params.mu = mu_.value().data_ptr(); + } params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.dz = dz.data_ptr(); diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh 
b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 8595f5ed4..019764a38 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -57,10 +57,14 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { constexpr float rn = 1.f / float(COLS); Wvec gamma[LDGS]; + Wvec beta[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { gamma[it].load_from(params.gamma, idx); + if (params.z != nullptr) { + beta[it].load_from(params.beta, idx); + } idx += Ktraits::VEC_COLS_PER_LDG; } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the @@ -68,15 +72,19 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { // grid stride over rows #pragma unroll 1 for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t mu_r = params.z == nullptr ? static_cast(params.mu)[row] : 0.f; const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x[LDGS]; + Ivec x_or_z[LDGS]; Ovec dz[LDGS]; index_t idx = row * Ktraits::VEC_COLS + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz[it].load_from(params.dz, idx); - x[it].load_from(params.x, idx); + if (params.z != nullptr) { + x_or_z[it].load_from(params.z, idx); + } else { + x_or_z[it].load_from(params.x, idx); + } idx += Ktraits::VEC_COLS_PER_LDG; } @@ -89,10 +97,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { for( int it = 0; it < LDGS; it++ ) { #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x[it].data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - mu_r); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); - dy_tmp *= compute_t(dz[it].data.elt[jt]); + compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]); + compute_t beta_tmp = compute_t(beta[it].data.elt[jt]); + compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]); + compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r); + compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp; compute_t dz_tmp = dz[it].data.elt[jt]; mdy_local += dy_tmp; diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index b084b1ace..1d79c561b 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -7,40 +7,44 @@ class FastLayerNormFN(torch.autograd.Function): @staticmethod - def forward(ctx, x, gamma, beta, epsilon): + def forward(ctx, x, gamma, beta, epsilon, memory_efficient): + ctx.x_shape = x.shape + ctx.memory_efficient = memory_efficient + x = x.contiguous() gamma = gamma.contiguous() beta = beta.contiguous() hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) - ctx.save_for_backward(x, gamma, mu, rsigma) + if ctx.memory_efficient: + ctx.save_for_backward(ymat, gamma, None, rsigma, beta) + else: + ctx.save_for_backward(xmat, gamma, mu, rsigma, None) return ymat.view(x.shape) @staticmethod def backward(ctx, dy): # assert dy.is_contiguous() dy = dy.contiguous() # this happens! 
- x, gamma, mu, rsigma = ctx.saved_tensors - - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dymat = dy.view(xmat.shape) - dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) - dx = dxmat.view(x.shape) - return dx, dgamma, dbeta, None + x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors + dymat = dy.view(x_or_y_mat.shape) + dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient) + dx = dxmat.view(ctx.x_shape) + return dx, dgamma, dbeta, None, None -def _fast_layer_norm(x, weight, bias, epsilon): - args = _cast_if_autocast_enabled(x, weight, bias, epsilon) +def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient): + args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FastLayerNormFN.apply(*args) class FastLayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): + def __init__(self, hidden_size, eps=1e-5, memory_efficient=False): super().__init__() self.epsilon = eps + self.memory_efficient = memory_efficient self.weight = torch.nn.Parameter(torch.empty(hidden_size)) self.bias = torch.nn.Parameter(torch.empty(hidden_size)) self.reset_parameters() @@ -50,4 +54,4 @@ def reset_parameters(self): init.zeros_(self.bias) def forward(self, x): - return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) + return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient) diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py index 9f6ee7980..fede67e90 100644 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ b/apex/contrib/test/layer_norm/test_fast_layer_norm.py @@ -1,3 +1,4 @@ +import itertools import unittest import torch @@ -106,7 +107,7 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100): timer.start() for r in range(runs): - dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma) + dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, z, mu, rsigma, gamma, beta, True) timer.stop() timer.sync() @@ -126,7 +127,7 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100): ) -def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): +def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32, mem_eff=False): seed = 1243 torch.manual_seed(seed) @@ -134,7 +135,7 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): otype = wtype print("========================================================") - print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}") + print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype} Mem_Eff={mem_eff}") print("--------------------------------------------------------") x = torch.randn(S * B, hidden_size, dtype=itype, device=device) @@ -165,7 +166,10 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma) z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon) - dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma) + if mem_eff: + dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, z, mu, rs, gamma, beta, True) + else: + dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma, beta, False) re_z, mse_z = metrics(z_ref, z) re_mu, mse_mu = metrics(mu_ref, mu) @@ -184,7 +188,7 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}") def check_err(x, relerr): - tol = 1e-3 if x.dtype == torch.float16 else 
5e-6 + tol = 2e-2 if x.dtype in (torch.float16, torch.bfloat16) else 5e-6 return relerr < tol return [ @@ -233,13 +237,13 @@ def test_all_configs(self): 65536, ] - for h in hidden_sizes: + for (h, mem_eff) in itertools.product(hidden_sizes, (True, False)): with self.subTest(f"hidden_size={h}"): - self.assertAll(_test_impl(256, 2, h, fp32, fp32)) - self.assertAll(_test_impl(256, 2, h, fp16, fp16)) - self.assertAll(_test_impl(256, 2, h, fp32, fp16)) - self.assertAll(_test_impl(256, 2, h, bf16, bf16)) - self.assertAll(_test_impl(256, 2, h, fp32, bf16)) + self.assertAll(_test_impl(256, 2, h, fp32, fp32, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp16, fp16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp32, fp16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, bf16, bf16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp32, bf16, mem_eff=mem_eff)) def test_run_benchmark(self): for (S, B, hidden_size, runs) in ( diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index d99e232ae..571b8b456 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -31,172 +31,198 @@ def manual_rms_norm(input, normalized_shape, weight, eps): class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors + input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, weight_, bias_, ctx.eps, ctx.memory_efficient ) - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None class FusedRMSNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine( input_, ctx.normalized_shape, weight_, ctx.eps) - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + 
ctx.save_for_backward(input_, weight_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, invvar = ctx.saved_tensors + input_or_output, weight_, invvar = ctx.saved_tensors grad_input = grad_weight = None grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient ) - return grad_input, grad_weight, None, None + return grad_input, grad_weight, None, None, None class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, ctx.eps ) - - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) return output class FusedLayerNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, None, invvar) + else: + ctx.save_for_backward(input_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, mean, invvar = ctx.saved_tensors - grad_input = None + input_or_output, mean, invvar = ctx.saved_tensors grad_input = fused_layer_norm_cuda.backward( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient ) - return grad_input, None, None + return grad_input, None, None, None 
class FusedRMSNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, invvar) + else: + ctx.save_for_backward(input_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, invvar = ctx.saved_tensors + input_or_output, invvar = ctx.saved_tensors grad_input = None grad_input = fused_layer_norm_cuda.rms_backward( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient ) - return grad_input, None, None + return grad_input, None, None, None -def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) +def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormAffineFunction.apply(*args) -def fused_layer_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) +def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormFunction.apply(*args) -def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) +def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormAffineMixedDtypesFunction.apply(*args) -def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) +def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormAffineFunction.apply(*args) -def fused_rms_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) +def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormFunction.apply(*args) -def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) +def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = 
_cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormAffineMixedDtypesFunction.apply(*args) @@ -261,7 +287,7 @@ class FusedLayerNorm(torch.nn.Module): .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -272,6 +298,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: self.weight = Parameter(torch.empty(*normalized_shape)) self.bias = Parameter(torch.empty(*normalized_shape)) @@ -289,9 +316,11 @@ def forward(self, input): if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) if self.elementwise_affine: - return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_layer_norm(input, self.normalized_shape, self.eps) + return fused_layer_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -357,7 +386,7 @@ class FusedRMSNorm(torch.nn.Module): .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -368,6 +397,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: self.weight = Parameter(torch.empty(*normalized_shape)) else: @@ -383,9 +413,11 @@ def forward(self, input): return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) if self.elementwise_affine: - return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_rms_norm(input, self.normalized_shape, self.eps) + return fused_rms_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -397,7 +429,7 @@ def extra_repr(self): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedLayerNorm(FusedLayerNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") @@ -405,13 +437,16 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not 
elementwise_affine: raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return mixed_dtype_fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) # MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype @@ -419,7 +454,7 @@ def forward(self, input: torch.Tensor): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedRMSNorm(FusedRMSNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument") @@ -427,11 +462,14 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not elementwise_affine: raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. 
# TODO Manual RMS Norm Implementation Here if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) - return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return mixed_dtype_fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index 005906103..588375f6f 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -214,7 +214,7 @@ void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -227,38 +227,45 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta + at::Tensor* grad_beta, + bool memory_efficient ); at::Tensor layer_norm_gradient( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,NULL,NULL,epsilon, - &grad_input,NULL,NULL); + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } return grad_input; } std::vector layer_norm_gradient_affine( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else @@ -266,21 +273,28 @@ std::vector layer_norm_gradient_affine( #endif at::Tensor gamma, at::Tensor beta, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); CHECK_INPUT(beta); int n1,n2; - check_args(input,normalized_shape,gamma,beta,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,beta,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); at::Tensor grad_beta = at::empty_like(beta); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,&gamma,&beta,epsilon, - &grad_input,&grad_gamma,&grad_beta); +// at::Tensor *mean = mean_.has_value() ? 
&mean_.value() : NULL; + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } return {grad_input, grad_gamma, grad_beta}; } @@ -364,7 +378,7 @@ std::vector rms_norm_affine_mixed_dtypes( void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -375,52 +389,55 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma); + at::Tensor* grad_gamma, + bool memory_efficient); at::Tensor rms_norm_gradient( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,NULL,epsilon, - &grad_input,NULL); + &grad_input,NULL,memory_efficient); return grad_input; } std::vector rms_norm_gradient_affine( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif at::Tensor gamma, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); int n1,n2; - check_args(input,normalized_shape,gamma,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,&gamma,epsilon, - &grad_input,&grad_gamma); + &grad_input,&grad_gamma,memory_efficient); return {grad_input, grad_gamma}; } diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu index 21366772c..4e80e057a 100644 --- a/csrc/layer_norm_cuda_kernel.cu +++ b/csrc/layer_norm_cuda_kernel.cu @@ -7,6 +7,7 @@ #include #include "type_shim.h" +#include "static_switch.h" template __device__ void cuWelfordOnlineSum( @@ -437,7 +438,28 @@ void cuApplyRMSNorm( cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); } -template __device__ + +template __device__ +V clamp_by_magnitude(V curr_gamma, double eps) +{ + const V kMinGamma = V(eps); + if (curr_gamma >= 0) { + if (curr_gamma < kMinGamma) { + return kMinGamma; + } else { + return curr_gamma; + } + } else { + if (curr_gamma > -kMinGamma) { + return -kMinGamma; + } else { + return curr_gamma; + } + } +} + + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -446,34 +468,41 @@ void 
cuLoadWriteStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + U curr_beta = static_cast(beta[i2]); + warp_buf2[write_idx] = curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] = curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h) * invvar[i1]; + } } } else { if (!rms_only) { @@ -493,7 +522,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -502,34 +531,41 @@ void cuLoadAddStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { + U curr_beta = static_cast(beta[i2]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h) * invvar[i1]; + } } } } @@ -537,17 +573,20 @@ void cuLoadAddStridedInputs( } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, U* part_grad_gamma, U* part_grad_beta, + const double eps, bool rms_only) { const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); 
@@ -565,9 +604,9 @@ void cuComputePartGradGammaBeta( U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); } __syncthreads(); // inter-warp reductions @@ -675,78 +714,108 @@ void cuComputeGradGammaBeta( } -template __global__ +template __global__ void cuComputeGradInput( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, const V* gamma, + const V* beta, T* grad_input, + const double eps, bool rms_only) { for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1*n2; + const T* k_h = input_or_output + i1*n2; const V* k_dout = dout + i1*n2; + const U c_invvar = invvar[i1]; + const U c_mean = !MemoryEfficient ? 
mean[i1] : 0.; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL) { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss * gamma[l+k]; - sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l+k]); + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l]); + } else { + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + } } - } } else { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } @@ -801,28 +870,46 @@ void cuComputeGradInput( T* k_grad_input = grad_input + i1*n2; if (gamma != NULL) { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; + const U k_gamma = static_cast(clamp_by_magnitude(gamma[l], eps)); + U f_grad_input = fH * c_loss * k_gamma; if (!rms_only) { + const U k_beta = beta[l]; f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= (c_h - k_beta) / k_gamma * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h / k_gamma * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } } else { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = 
static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss; if (!rms_only) { f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); @@ -947,7 +1034,7 @@ void HostLayerNormGradient( const V* dout, const U* mean, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, @@ -955,7 +1042,8 @@ void HostLayerNormGradient( double epsilon, T* grad_input, V* grad_gamma, - V* grad_beta + V* grad_beta, + bool memory_efficient ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -971,21 +1059,27 @@ void HostLayerNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + epsilon, + false); + }); const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1008,29 +1102,35 @@ void HostLayerNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input, - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + grad_input, + epsilon, + false); + }); } template void HostRMSNormGradient( const V* dout, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, double epsilon, T* grad_input, - V* grad_gamma) + V* grad_gamma, + bool memory_efficient) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1044,20 +1144,27 @@ void HostRMSNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? 
at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_gamma.DATA_PTR(), /* unused */ - true); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + part_grad_gamma.DATA_PTR(), + part_grad_gamma.DATA_PTR(), /* unused */ + epsilon, + true); + }); + const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1080,23 +1187,28 @@ void HostRMSNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + grad_input, + epsilon, + true); + }); } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1109,18 +1221,19 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta) + at::Tensor* grad_beta, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", + input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", using accscalar_t = at::acc_type; HostLayerNormGradient( dout->DATA_PTR(), - mean->DATA_PTR(), + mean != NULL ? mean->DATA_PTR() : NULL, invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. @@ -1129,14 +1242,15 @@ void cuda_layer_norm_gradient( epsilon, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); + gamma != NULL ? grad_beta->DATA_PTR() : NULL, + memory_efficient); ) } void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1147,24 +1261,26 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma) + at::Tensor* grad_gamma, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", + input_or_output->scalar_type(), gamma == NULL ? 
input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", using accscalar_t = at::acc_type; HostRMSNormGradient( dout->DATA_PTR(), invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. gamma != NULL ? gamma->DATA_PTR() : NULL, epsilon, grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL); + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + memory_efficient); ) } diff --git a/csrc/static_switch.h b/csrc/static_switch.h new file mode 100644 index 000000000..1ba09857b --- /dev/null +++ b/csrc/static_switch.h @@ -0,0 +1,25 @@ +// From +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py index 13dee874b..94c30057f 100644 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py @@ -21,7 +21,7 @@ def _prep_inputs(batch_size, normalized_shape, dtype): class TestFusedLayerNorm(common_utils.TestCase): def _test_fused_layer_norm( - self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) ): @@ -29,15 +29,19 @@ def _test_fused_layer_norm( if not mixed_fused: module_cpu_ = FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).cpu() + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() module_cuda_ = FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) else: assert elementwise_affine module_cpu_ = MixedFusedLayerNorm( - normalized_shape=normalized_shape).cpu() + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).cpu() module_cuda_ = MixedFusedLayerNorm( - normalized_shape=normalized_shape).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) torch.cuda.manual_seed(42) if contiguous: @@ -70,7 +74,7 @@ def _test_fused_layer_norm( input_.grad.to(device="cuda", dtype=dtype), input_cuda_.grad, **bwd_thresholds) def _test_fused_rms_norm( - self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) ): @@ -78,9 +82,11 @@ def _test_fused_rms_norm( if not mixed_fused: module_cpu_ = FusedRMSNorm( - normalized_shape=normalized_shape, 
elementwise_affine=elementwise_affine).cpu() + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() module_cuda_ = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) else: assert elementwise_affine module_cpu_ = MixedFusedRMSNorm( @@ -123,87 +129,87 @@ def _test_fused_rms_norm( # layer norm tests @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (False,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False))) ) - def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) ) - def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (True,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False))) ) - def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16,), (True, False), (True,), (False,), (torch.half,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) ) - def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, 
+        self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             fwd_thresholds=dict(rtol=1e-3, atol=1e-3), bwd_thresholds=dict(rtol=1e-3, atol=1e-3))
 
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False)))
     )
-    def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype,
+    def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3))
 
     # rms norm tests
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16, 65536), (True, False), (False,), (False,), (torch.float,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False)))
     )
-    def test_rms_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype)
+    def test_rms_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient)
 
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16, 65536), (True, False), (True,), (False,), (torch.float,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False)))
     )
-    def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype,
+    def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             bwd_thresholds=dict(rtol=2e-3, atol=2e-4))
 
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16, 65536), (True, False), (True,), (True,), (torch.float,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False)))
     )
-    def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype,
+    def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             bwd_thresholds=dict(rtol=2e-3, atol=2e-4))
 
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16,), (True, False), (True,), (False,), (torch.half,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False)))
     )
-    def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype,
+    def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3))
 
     @common_utils.parametrize(
-        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype",
-        list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,)))
+        "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient",
+        list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False)))
     )
-    def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype):
-        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype,
+    def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient):
+        self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient,
             fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-2))
 
     @common_utils.parametrize(
-        "dtype, elementwise_affine",
-        list(product(autocast_dtypes, (True, False)))
+        "dtype, elementwise_affine, memory_efficient",
+        list(product(autocast_dtypes, (True, False), (True, False)))
     )
-    def test_autocast_fused_layer_norm(self, dtype, elementwise_affine):
+    def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, memory_efficient):
         bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
         bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
         batch_size = 16
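The parametrized cases above only check numerics against the unfused reference. The memory saving that motivates this patch is easiest to see with a quick manual measurement; the helper below is an illustrative sketch (it is not part of the test suite, and the exact numbers depend on shape and dtype):

    import torch
    from apex.normalization import FusedLayerNorm

    def extra_bytes_after_forward(memory_efficient):
        ln = FusedLayerNorm(normalized_shape=1024,
                            memory_efficient=memory_efficient).cuda()
        x = torch.randn(64, 512, 1024, device="cuda", requires_grad=True)
        w = torch.randn(1024, 1024, device="cuda", requires_grad=True)
        base = torch.cuda.memory_allocated()
        # x @ w is an intermediate activation: without memory_efficient it is
        # kept alive because the norm saves its input for backward; with the
        # flag enabled the output is saved instead and the intermediate is freed.
        out = ln(x @ w)
        return torch.cuda.memory_allocated() - base

    # The difference is roughly the size of one intermediate activation.
    print(extra_bytes_after_forward(False) - extra_bytes_after_forward(True))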
@@ -212,7 +218,7 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine):
             normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
         ).to(device="cuda", dtype=dtype)
         fused = FusedLayerNorm(
-            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient
         ).cuda()
 
         native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype)
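In the autocast tests, only the fused-module construction (the hunk above) and the backward tolerance selection (the hunk below) change. A minimal sketch of exercising the fused module under autocast, independent of the test harness; the torch.autocast usage here is illustrative rather than copied from the test:

    import torch
    from apex.normalization import FusedLayerNorm

    fused = FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True,
                           memory_efficient=True).cuda()  # float32 parameters
    x = torch.randn(16, 32, 16, device="cuda", dtype=torch.half,
                    requires_grad=True)

    with torch.autocast(device_type="cuda", dtype=torch.half):
        out = fused(x)
    out.float().sum().backward()  # gradient w.r.t. the half-precision input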
@@ -230,22 +236,27 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine):
         expected.backward(g_native)
         actual.backward(g_fused)
-        tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_bwd_thresholds
+        if dtype != torch.half:
+            tols = bf16_bwd_thresholds
+        elif memory_efficient:
+            tols = {'rtol': 1e-3, 'atol': 1e-4}
+        else:
+            tols = {'rtol': None, 'atol': None}
         torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False)
 
     @common_utils.parametrize(
-        "dtype, elementwise_affine",
-        list(product(autocast_dtypes, (True, False)))
+        "dtype, elementwise_affine, memory_efficient",
+        list(product(autocast_dtypes, (True, False), (True, False)))
     )
-    def test_autocast_fused_rms_norm(self, dtype, elementwise_affine):
+    def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memory_efficient):
         bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
         bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
         batch_size = 16
         normalized_shape = [32, 16]
         native = FusedRMSNorm(
-            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient,
         ).to(dtype=dtype)
         fused = FusedRMSNorm(
-            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient,
         ).cuda()
 
         native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype)