remove duplication for residuals. requires memset so we remain correct
ngc92 committed Apr 18, 2024
1 parent fab549b commit ecf072f
Showing 1 changed file with 5 additions and 5 deletions.
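This change replaces the per-layer residual3 gradient slices with a single shared B*T*C buffer that is reused by every layer during the backward pass. Because the backward kernels accumulate into their output gradients with +=, the shared buffer has to be cleared with cudaMemset each time before it is reused, which is the memset the commit title refers to. Below is a minimal, self-contained sketch of that pattern; the toy add_into kernel and the small sizes are placeholders for illustration, not the actual llm.c kernels.

#include <cuda_runtime.h>
#include <stdio.h>

// toy "backward" kernel that accumulates into dout, mimicking the += style of the real kernels
__global__ void add_into(float* dout, const float* src, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { dout[i] += src[i]; }
}

int main(void) {
    const int L = 4, B = 2, T = 8, C = 16;      // hypothetical small sizes
    const int n = B * T * C;
    const int grid = (n + 255) / 256;
    float *dresidual, *scratch;
    cudaMalloc(&dresidual, n * sizeof(float));  // ONE shared buffer instead of L per-layer slices
    cudaMalloc(&scratch, n * sizeof(float));
    cudaMemset(dresidual, 0, n * sizeof(float));
    cudaMemset(scratch, 0, n * sizeof(float));

    for (int l = L - 1; l >= 0; l--) {
        // 1) the layer first consumes dresidual as its incoming gradient
        add_into<<<grid, 256>>>(scratch, dresidual, n);
        // 2) clear the shared buffer before reuse; without this memset the next
        //    accumulation would add on top of the previous layer's values
        cudaMemset(dresidual, 0, n * sizeof(float));
        // 3) later kernels in the same layer accumulate the gradient for the
        //    previous layer's residual into the very same buffer
        add_into<<<grid, 256>>>(dresidual, scratch, n);
    }
    cudaDeviceSynchronize();
    printf("backward loop done: %d gradient floats shared across %d layers\n", n, L);
    cudaFree(dresidual);
    cudaFree(scratch);
    return 0;
}

In the diff below, the same idea shows up as dresidual and dl_residual3 both pointing at grads_acts.residual3, with the new cudaMemset(dl_residual3, ...) placed after residual_backward has consumed the incoming gradient and before the buffer is filled again with the gradient for the previous layer's residual.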
10 changes: 5 additions & 5 deletions train_gpt2.cu
@@ -1493,8 +1493,7 @@ void gpt2_backward(GPT2 *model) {
bw_act_sizes[18] = 0; // lnf_mean
bw_act_sizes[19] = 0; // lnf_rstd
bw_act_sizes[21] = 0; // probs
- // residual3 is tricky. For now, just allocate per layer
- bw_act_sizes[16] = model->config.num_layers * model->batch_size * model->seq_len * model->config.channels;

model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, bw_act_sizes);
model->num_grad_acts = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
@@ -1528,13 +1527,13 @@ void gpt2_backward(GPT2 *model) {
matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, V);
// backward the final layernorm
float* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
- float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // and its gradient
+ float* dresidual = grads_acts.residual3; // and its gradient
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);

// now backward all the layers
for (int l = L-1; l >= 0; l--) {
residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
- dresidual = l == 0 ? grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C;
+ dresidual = l == 0 ? grads_acts.encoded : grads_acts.residual3;

// get the pointers of the weights for this layer
float* l_ln1w = params.ln1w + l * C;
@@ -1586,11 +1585,12 @@ void gpt2_backward(GPT2 *model) {
float* dl_fch = grads_acts.fch;
float* dl_fch_gelu = grads_acts.fch_gelu;
float* dl_fcproj = grads_acts.fcproj;
- float* dl_residual3 = grads_acts.residual3 + l * B * T * C;
+ float* dl_residual3 = grads_acts.residual3;

// backprop this layer
cudaCheck(cudaMemset(dl_residual2, 0, B * T * C * sizeof(float)));
residual_backward(dl_residual2, dl_fcproj, dl_residual3, B*T*C);
+ cudaCheck(cudaMemset(dl_residual3, 0, B * T * C * sizeof(float)));
matmul_backward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C);
gelu_backward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C);
matmul_backward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C);
