Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)

* input grad checks out

* adding clamp gamma

* Both old and proposed implementations check out

* 2 tests not yet passing due to numerical issues

* mem_eff works

* fast-layer-norm done

* Moving mem-eff to templates

* Relax tolerance for memory efficient backward

* Fix backward api of python
RuiWang1998 authored Sep 28, 2023
1 parent 741bdf5 commit 6ff4548
Showing 10 changed files with 560 additions and 336 deletions.
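
The core of the change: instead of saving the LayerNorm input x and per-row mean mu for the backward pass, the forward pass saves the output z (plus gamma, beta and the per-row inverse standard deviation rsigma), and the backward pass reconstructs the normalized activation from the output. Since the output usually feeds the next layer and is kept alive by autograd anyway, the input no longer has to be retained solely for the norm's backward. Below is a minimal pure-PyTorch sketch of the idea, not the fused CUDA implementation from this commit; the class name and the unfused formulas are illustrative only, and note the division by gamma, which is presumably why the commit log mentions clamping gamma.

```python
import torch

class MemEffLayerNormSketch(torch.autograd.Function):
    """Illustrative stand-in for the fused kernels: save the output y instead of
    the input x, then recover the normalized activation as (y - beta) / gamma."""

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        mu = x.mean(-1, keepdim=True)
        rsigma = torch.rsqrt(x.var(-1, unbiased=False, keepdim=True) + eps)
        x_hat = (x - mu) * rsigma                      # normalized input
        y = x_hat * gamma + beta
        ctx.save_for_backward(y, gamma, beta, rsigma)  # no x, no mu
        return y

    @staticmethod
    def backward(ctx, dy):
        y, gamma, beta, rsigma = ctx.saved_tensors
        x_hat = (y - beta) / gamma                     # recompute x_hat from the output
        wdy = dy * gamma
        # standard LayerNorm input gradient, written in terms of x_hat and rsigma
        dx = rsigma * (wdy - wdy.mean(-1, keepdim=True)
                       - x_hat * (wdy * x_hat).mean(-1, keepdim=True))
        dgamma = (dy * x_hat).reshape(-1, x_hat.shape[-1]).sum(0)
        dbeta = dy.reshape(-1, dy.shape[-1]).sum(0)
        return dx, dgamma, dbeta, None
```

In this sketch the only tensor retained specifically for backward is rsigma (one float per row); the saved output is the same tensor the next layer already holds, which is where the memory saving comes from.
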
33 changes: 11 additions & 22 deletions apex/contrib/csrc/layer_norm/ln.h
@@ -10,7 +10,7 @@ namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Params>
struct LaunchParams{

size_t workspace_bytes;
@@ -26,17 +26,20 @@ struct LaunchParams{

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ParamsBase {
ParamsBase()
struct FwdParams{
FwdParams()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, z(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, beta(nullptr)
, workspace(nullptr)
, barrier(nullptr)
, epsilon(0.f)
{
}

@@ -49,41 +52,27 @@ struct ParamsBase {

// Common data pointers.
void *x;
void *z;
void *mu;
void *rs;
void *gamma;
void *beta;

// Multi-CTA workspace in gmem.
void *workspace;

// Multi-CTA sync barriers in gmem.
int *barrier;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}

// Output of LN FWD.
void *z;
void *beta;
float epsilon;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct BwdParams : public ParamsBase {
struct BwdParams : public FwdParams{
BwdParams()
: ParamsBase()
: FwdParams()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
@@ -92,7 +81,6 @@ struct BwdParams : public ParamsBase {
, dgamma(nullptr)
{
}

// Input: gradient wrt. LN FWD output.
void *dz;

@@ -200,3 +188,4 @@ struct BwdRegistrar{
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace layer_norm

53 changes: 32 additions & 21 deletions apex/contrib/csrc/layer_norm/ln_api.cpp
@@ -130,12 +130,12 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.z = z.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.x = x.data_ptr();
params.epsilon = epsilon;

if( launch_params.barrier_size > 0 ) {
@@ -153,48 +153,54 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x_or_z, // BxSxhidden_size
c10::optional<const at::Tensor> &mu_, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor>&beta_, // hidden_size
bool memory_efficient
) {

auto itype = x.scalar_type();
auto itype = x_or_z.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;

TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
if (mu_.has_value()) {
TORCH_CHECK(mu_.value().dtype() == ctype);
}

TORCH_CHECK(x.is_cuda());
TORCH_CHECK(x_or_z.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
if (beta_.has_value()) {
TORCH_CHECK(beta_.value().is_cuda());
TORCH_CHECK(beta_.value().dtype() == wtype);
}

TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(x_or_z.is_contiguous());
TORCH_CHECK(dz.is_contiguous());

auto sizes = x.sizes();
auto sizes = x_or_z.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];

auto hidden_size = gamma.numel();

TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());

TORCH_CHECK(gamma.numel() == cols);
if (beta_.has_value()) {
TORCH_CHECK(beta_.value().numel() == cols);
}

auto options = x.options();
auto options = x_or_z.options();

auto dx = torch::empty_like(x);
auto dx = torch::empty_like(x_or_z);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);

@@ -213,8 +219,13 @@ std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
if (memory_efficient) {
params.z = x_or_z.data_ptr();
params.beta = beta_.value().data_ptr();
} else {
params.x = x_or_z.data_ptr();
params.mu = mu_.value().data_ptr();
}
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
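
With this API change, ln_bwd accepts either the saved input or the saved output as its second argument: mu and beta become optional and a trailing memory_efficient flag selects the path. A hedged sketch of the two calling conventions, mirroring how layer_norm.py and the tests invoke the extension (it assumes the fast_layer_norm extension is built; variable names are illustrative):

```python
import fast_layer_norm  # compiled from apex/contrib/csrc/layer_norm

# Conventional path: pass the saved input x and the per-row mean mu.
dx, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dz, x, mu, rsigma, gamma, beta, False)

# Memory-efficient path: pass the saved output z instead; mu is not needed,
# but beta must be provided so the kernel can undo the affine transform.
dx, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dz, z, None, rsigma, gamma, beta, True)
```
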
23 changes: 16 additions & 7 deletions apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh
@@ -57,26 +57,34 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {

constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
if (params.z != nullptr) {
beta[it].load_from(params.beta, idx);
}
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t mu_r = params.z == nullptr ? static_cast<const compute_t *>(params.mu)[row] : 0.f;
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ivec x_or_z[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
if (params.z != nullptr) {
x_or_z[it].load_from(params.z, idx);
} else {
x_or_z[it].load_from(params.x, idx);
}
idx += Ktraits::VEC_COLS_PER_LDG;
}

@@ -89,10 +97,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]);
compute_t beta_tmp = compute_t(beta[it].data.elt[jt]);
compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]);
compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r);
compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp;
compute_t dz_tmp = dz[it].data.elt[jt];

mdy_local += dy_tmp;
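
The kernel change above rests on a simple identity: the normalized value needed by the backward pass can be computed either from the saved input and statistics or from the saved output by inverting the affine transform. In the kernel, rs_r is the per-row inverse standard deviation, and the ternary on the y_tmp line is exactly this:

```latex
\[
  \hat{x}_i \;=\; \frac{x_i - \mu}{\sigma} \;=\; \frac{z_i - \beta_i}{\gamma_i}
  \qquad\text{because}\qquad
  z_i \;=\; \gamma_i\,\frac{x_i - \mu}{\sigma} + \beta_i ,
\]
```

valid whenever gamma_i is nonzero, which is presumably why the commit history mentions clamping gamma.
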
32 changes: 18 additions & 14 deletions apex/contrib/layer_norm/layer_norm.py
@@ -7,40 +7,44 @@

class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, epsilon):
def forward(ctx, x, gamma, beta, epsilon, memory_efficient):
ctx.x_shape = x.shape
ctx.memory_efficient = memory_efficient

x = x.contiguous()
gamma = gamma.contiguous()
beta = beta.contiguous()
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
ctx.save_for_backward(x, gamma, mu, rsigma)
if ctx.memory_efficient:
ctx.save_for_backward(ymat, gamma, None, rsigma, beta)
else:
ctx.save_for_backward(xmat, gamma, mu, rsigma, None)
return ymat.view(x.shape)

@staticmethod
def backward(ctx, dy):
# assert dy.is_contiguous()
dy = dy.contiguous() # this happens!
x, gamma, mu, rsigma = ctx.saved_tensors

hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None
x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors
dymat = dy.view(x_or_y_mat.shape)
dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient)
dx = dxmat.view(ctx.x_shape)
return dx, dgamma, dbeta, None, None


def _fast_layer_norm(x, weight, bias, epsilon):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon)
def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient)
with torch.cuda.amp.autocast(enabled=False):
return FastLayerNormFN.apply(*args)


class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
def __init__(self, hidden_size, eps=1e-5, memory_efficient=False):
super().__init__()
self.epsilon = eps
self.memory_efficient = memory_efficient
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.bias = torch.nn.Parameter(torch.empty(hidden_size))
self.reset_parameters()
@@ -50,4 +54,4 @@ def reset_parameters(self):
init.zeros_(self.bias)

def forward(self, x):
return _fast_layer_norm(x, self.weight, self.bias, self.epsilon)
return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient)
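
For reference, a hedged usage sketch of the new module-level flag (it assumes apex was installed with the fast_layer_norm contrib extension; shapes and the hidden size are illustrative):

```python
import torch
from apex.contrib.layer_norm import FastLayerNorm

ln = FastLayerNorm(hidden_size=1024, memory_efficient=True).cuda().half()
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

y = ln(x)            # forward saves y (plus gamma, beta, rsigma) instead of x and mu
y.sum().backward()   # backward reconstructs the normalized input from y, gamma, beta
```
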
26 changes: 15 additions & 11 deletions apex/contrib/test/layer_norm/test_fast_layer_norm.py
@@ -1,3 +1,4 @@
import itertools
import unittest

import torch
@@ -106,7 +107,7 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100):

timer.start()
for r in range(runs):
dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, z, mu, rsigma, gamma, beta, True)
timer.stop()
timer.sync()

@@ -126,15 +127,15 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
)


def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32):
def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32, mem_eff=False):

seed = 1243
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

otype = wtype
print("========================================================")
print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype} Mem_Eff={mem_eff}")
print("--------------------------------------------------------")

x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
@@ -165,7 +166,10 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32):
dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)

z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)
if mem_eff:
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, z, mu, rs, gamma, beta, True)
else:
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma, beta, False)

re_z, mse_z = metrics(z_ref, z)
re_mu, mse_mu = metrics(mu_ref, mu)
@@ -184,7 +188,7 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32):
print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")

def check_err(x, relerr):
tol = 1e-3 if x.dtype == torch.float16 else 5e-6
tol = 2e-2 if x.dtype in (torch.float16, torch.bfloat16) else 5e-6
return relerr < tol

return [
@@ -233,13 +237,13 @@ def test_all_configs(self):
65536,
]

for h in hidden_sizes:
for (h, mem_eff) in itertools.product(hidden_sizes, (True, False)):
with self.subTest(f"hidden_size={h}"):
self.assertAll(_test_impl(256, 2, h, fp32, fp32))
self.assertAll(_test_impl(256, 2, h, fp16, fp16))
self.assertAll(_test_impl(256, 2, h, fp32, fp16))
self.assertAll(_test_impl(256, 2, h, bf16, bf16))
self.assertAll(_test_impl(256, 2, h, fp32, bf16))
self.assertAll(_test_impl(256, 2, h, fp32, fp32, mem_eff=mem_eff))
self.assertAll(_test_impl(256, 2, h, fp16, fp16, mem_eff=mem_eff))
self.assertAll(_test_impl(256, 2, h, fp32, fp16, mem_eff=mem_eff))
self.assertAll(_test_impl(256, 2, h, bf16, bf16, mem_eff=mem_eff))
self.assertAll(_test_impl(256, 2, h, fp32, bf16, mem_eff=mem_eff))

def test_run_benchmark(self):
for (S, B, hidden_size, runs) in (