Commit

fix recompute bug
karpathy committed May 31, 2024
1 parent 0bc24c3 commit 9c8edaf
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train_gpt2.cu
@@ -2633,17 +2633,17 @@ void gpt2_backward(GPT2 *model, int* inputs) {
         floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
         floatX* dl_fcprojb = grads.fcprojb + l * C;
         // get the pointers of the activations for this layer
-        floatX* l_ln1 = (model->recompute == 0) ? acts.ln1 + l * B * T * C : acts.ln1;
+        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.ln1;
         floatX* l_ln1_mean = acts.ln1_mean + l * B * T;
         floatX* l_ln1_rstd = acts.ln1_rstd + l * B * T;
         floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
         floatX* l_atty = acts.atty + l * B * T * C;
         floatX* l_residual2 = acts.residual2 + l * B * T * C;
-        floatX* l_ln2 = (model->recompute == 0) ? acts.ln2 + l * B * T * C : acts.ln2;
+        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.ln2;
         floatX* l_ln2_mean = acts.ln2_mean + l * B * T;
         floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T;
         floatX* l_fch = acts.fch + l * B * T * 4*C;
-        floatX* l_fch_gelu = (model->recompute == 0) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
+        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
         // get the pointers of the gradients of the activations for this layer
         // notice that there is no l *, because we just have a single copy, and keep
         // re-using this memory in every Transformer block as we calculate backward pass
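For context: the recompute flag trades activation memory for extra compute. Judging from the diff alone, level 1 recomputes the GELU output during the backward pass while level 2 additionally recomputes the LayerNorm outputs; only activations that the forward pass actually stored per layer should get the l * B * T * C style offset, and recomputed activations share a single scratch buffer. The old guards applied the per-layer offset only when recompute == 0, so at recompute == 1 the backward pass would read acts.ln1 / acts.ln2 without the layer offset even though those activations are still stored for every layer. Below is a minimal sketch of that pointer-selection pattern; the helper name and the level semantics are assumptions read from this diff, not code from the repository.

// Hypothetical helper (not in train_gpt2.cu) illustrating the pattern above:
// activations stored for every layer get a per-layer offset, activations that
// will be recomputed share one buffer for all layers.
static floatX* layer_act_ptr(floatX* base, int layer, size_t stride, int stored_per_layer) {
    return stored_per_layer ? base + (size_t)layer * stride : base;
}

// Usage matching the fixed lines (assumed semantics: recompute >= 1 recomputes
// GELU, recompute == 2 also recomputes LayerNorm):
// floatX* l_ln1      = layer_act_ptr(acts.ln1,      l, (size_t)B * T * C,     model->recompute < 2);
// floatX* l_fch_gelu = layer_act_ptr(acts.fch_gelu, l, (size_t)B * T * 4 * C, model->recompute < 1);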
