Commit

macro-based alternative
ngc92 committed Jun 27, 2024
1 parent d38c614 commit 98875dd
Showing 1 changed file with 55 additions and 87 deletions.
142 changes: 55 additions & 87 deletions train_gpt2.cu
@@ -206,7 +206,13 @@ typedef struct {
     float* ln1_mean; // (L, B, T)
     float* ln1_rstd; // (L, B, T)
     floatX* atty; // (L, B, T, C)
-    floatX* att; // (L, B, NH, T, T) (smaller with cuDNN)
+    // cuDNN saves only some statistics information
+#if ENABLE_CUDNN
+    float* att; // (L, B, NH, T)
+#else
+    floatX* att; // (L, B, NH, T, T)
+#endif
+
     floatX* attproj; // (L, B, T, C)
     floatX* residual2; // (L, B, T, C)
     floatX* ln2; // (L, B, T, C)
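
The #if branch above shrinks the att activation considerably: with cuDNN flash attention, only a small FP32 statistics tensor of shape (L, B, NH, T) is kept for the backward pass instead of the full (L, B, NH, T, T) attention matrix. A rough, self-contained comparison using illustrative GPT-2-small-style dimensions (L=12, NH=12, B=4, T=1024, floatX assumed to be 2-byte BF16; these numbers are for illustration only, not taken from this commit):

#include <cstdio>
#include <cstddef>

int main() {
    // illustrative dimensions only (roughly GPT-2 124M with batch size 4)
    const size_t L = 12, B = 4, NH = 12, T = 1024;
    const size_t sizeof_floatX = 2;                                    // assuming BF16 activations
    size_t full_att_bytes   = L * B * NH * T * T * sizeof_floatX;      // non-cuDNN path
    size_t cudnn_stats_bytes = L * B * NH * T * sizeof(float);         // cuDNN stats path
    printf("full att:    %8.1f MiB\n", full_att_bytes / (1024.0 * 1024.0));    // ~1152 MiB
    printf("cuDNN stats: %8.1f MiB\n", cudnn_stats_bytes / (1024.0 * 1024.0)); // ~2.2 MiB
    return 0;
}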
@@ -252,93 +258,62 @@ size_t sizeof_dtype(DType type) {
     }
 }
 
-// Translating from types to the corresponding data type enum.
-// Technically, we are defining variable templates, i.e., variables
-// where there is one instance for each argument type.
-// The better way to think of this might be that dtype_of is a function,
-// but because you cannot pass types as arguments, your "function call"
-// uses angle brackets, i.e., it looks like dtype_of<float> instead of dtype_of(float).
-
-// declare the base template, but don't define it.
-// only types that explicitly opt-in below are supported
-template<class T>
-DType dtype_of;
-
-// explicitly specify the enum value for each data type.
-template<>
-constexpr const DType dtype_of<float> = DType::FP32;
-template<>
-constexpr const DType dtype_of<half> = DType::FP16;
-template<>
-constexpr const DType dtype_of<nv_bfloat16> = DType::BF16;
-
-void fill_in_activation_sizes(size_t* act_sizes, DType* act_dtypes, size_t B, size_t T, GPT2Config config, int recompute) {
+DType dtype_of(float* f) { return DType::FP32; }
+DType dtype_of(nv_bfloat16 * f) { return DType::BF16; }
+DType dtype_of(half * f) { return DType::FP16; }
+
+struct TensorSpec {
+    void** ptr;
+    size_t size;
+    DType type;
+};
+
+
+#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)};
+
+void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) {
     size_t Vp = config.padded_vocab_size;
     size_t L = config.num_layers;
     size_t NH = config.num_heads;
     size_t C = config.channels;
-    act_sizes[0] = B * T * C; // encoded
+    tensors[0] = TENSOR_SPEC(data->encoded, B * T * C);
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
-    act_sizes[1] = (recompute < 2) ? L * B * T * C : 0; // ln1
-    act_sizes[2] = L * B * T; // ln1_mean
-    act_sizes[3] = L * B * T; // ln1_rstd
-    act_sizes[4] = L * B * T * C; // atty
+    tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0);
+    tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T);
+    tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T);
+    tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C);
 #ifdef ENABLE_CUDNN
     // FP32 stats tensor for cuDNN to be passed to backward pass
-    act_sizes[5] = L * B * NH * T * (sizeof(float) / sizeof(floatX));
+    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T);
 #else
-    act_sizes[5] = L * B * NH * T * T; // att
+    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T);
 #endif
-    act_sizes[6] = L * B * T * C; // attproj
-    act_sizes[7] = L * B * T * C; // residual2
+    tensors[6] = TENSOR_SPEC(data->attproj, L * B * T * C);
+    tensors[7] = TENSOR_SPEC(data->residual2, L * B * T * C);
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
-    act_sizes[8] = (recompute < 2) ? L * B * T * C : 0; // ln2
-    act_sizes[9] = L * B * T; // ln2_mean
-    act_sizes[10] = L * B * T; // ln2_rstd
-    act_sizes[11] = L * B * T * 4*C; // fch
+    tensors[8] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0);
+    tensors[9] = TENSOR_SPEC(data->ln2_mean, L * B * T);
+    tensors[10] = TENSOR_SPEC(data->ln2_rstd, L * B * T);
+    tensors[11] = TENSOR_SPEC(data->fch, L * B * T * 4*C);
     // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer
-    act_sizes[12] = (recompute < 1) ? L * B * T * 4*C : B * T * 4*C;
-    act_sizes[13] = L * B * T * C; // fcproj
-    act_sizes[14] = L * B * T * C; // residual3
-    act_sizes[15] = B * T * C; // lnf
-    act_sizes[16] = B * T; // lnf_mean
-    act_sizes[17] = B * T; // lnf_rstd
-    act_sizes[18] = B * T; // losses
-    act_sizes[19] = L * B * T * 3*C; // qkvr
-    act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch
-
-    act_sizes[21] = B * T * 4 * C; // scratch_bt4c
-    act_sizes[22] = B * T * C; // scratch_btc
-}
-
-// Given a list of pointers, fills in target with corresponding void* pointers, and dtypes with their dtypes.
-template<class... T>
-void fill_in_types_and_pointers(DType* dtypes, void*** target, T**... pointers) {
-    constexpr const int n = sizeof...(T);
-    // we cannot iterate over a parameter pack directly, so we extract the quantities we want (i.e., the data type, and
-    // the pointer cast to void), and put those into local arrays, which *can* be filled from a paramter pack.
-    // Good ol' for then just copies the data into the arrays that we actually want.
-    DType dtype_helper[n] = {dtype_of<T>...};
-    void** ptr_helper[n] = {(void**)pointers...};
-    for(int i = 0; i < n; ++i) {
-        dtypes[i] = dtype_helper[i];
-        target[i] = ptr_helper[i];
-    }
+    tensors[12] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C);
+    tensors[13] = TENSOR_SPEC(data->fcproj, L * B * T * C);
+    tensors[14] = TENSOR_SPEC(data->residual3, L * B * T * C);
+    tensors[15] = TENSOR_SPEC(data->lnf, B * T * C);
+    tensors[16] = TENSOR_SPEC(data->lnf_mean, B * T);
+    tensors[17] = TENSOR_SPEC(data->lnf_rstd, B * T);
+    tensors[18] = TENSOR_SPEC(data->losses, B * T);
+    tensors[19] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C);
+    tensors[20] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp)));
+
+    tensors[21] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C);
+    tensors[22] = TENSOR_SPEC(data->scratch_btc, B * T * C);
 }
 
-void* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes, DType* act_types) {
-    void** ptrs[NUM_ACTIVATION_TENSORS];
-    fill_in_types_and_pointers(act_types, ptrs,
-        &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty,
-        &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
-        &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,
-        &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output,
-        &acts->scratch_bt4c, &acts->scratch_btc
-    );
-
+void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) {
     size_t bytes = 0;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
-        bytes += act_sizes[i] * sizeof_dtype(act_types[i]);
+        bytes += tensors[i].size * sizeof_dtype(tensors[i].type);
     }
 
     printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024)));
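
The heart of the change is visible in this hunk: the variable template dtype_of<T> is replaced by ordinary function overloads, so the TENSOR_SPEC macro can derive a tensor's DType directly from the member pointer it is given, and each activation is now described by a single TensorSpec (target pointer, element count, dtype) instead of parallel size and type arrays. A minimal, self-contained sketch of the mechanism (the MyTensors struct and the element counts are made up for illustration):

#include <cstdio>
#include <cstddef>

// stand-ins for the types used in train_gpt2.cu
enum class DType { FP32, FP16, BF16 };
struct bf16_t { unsigned short bits; };   // placeholder for nv_bfloat16

// plain overloads: the argument's pointee type selects the DType tag
DType dtype_of(float*)  { return DType::FP32; }
DType dtype_of(bf16_t*) { return DType::BF16; }

struct TensorSpec {
    void** ptr;    // in the real code, malloc_and_point_activations writes the allocated pointer here
    size_t size;   // number of elements
    DType  type;   // element type, used to compute bytes
};

// same shape as the macro in the diff (minus its trailing semicolon):
// take the member's address for .ptr, and let overload resolution on the
// member itself pick the DType
#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}

struct MyTensors { float* ln1_mean; bf16_t* atty; };   // hypothetical, for illustration

int main() {
    MyTensors t{};
    TensorSpec specs[2];
    specs[0] = TENSOR_SPEC(t.ln1_mean, 64);
    specs[1] = TENSOR_SPEC(t.atty, 128);
    printf("specs[1] is BF16: %d\n", specs[1].type == DType::BF16);   // prints 1
    return 0;
}

One small quirk visible in the diff: the committed #define ends with a trailing semicolon, so a use like tensors[0] = TENSOR_SPEC(...); expands to an extra empty statement. That is harmless in statement position, which is the only way the macro is used here.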
@@ -348,11 +323,11 @@ void* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_si
     char* acts_memory_iterator = (char*)acts_memory;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
         // extra protection so we don't accidentally use an empty buffer
-        if(act_sizes[i] == 0) {
-            *(ptrs[i]) = NULL;
+        if(tensors[i].size == 0) {
+            *(tensors[i].ptr) = NULL;
         }else {
-            *(ptrs[i]) = acts_memory_iterator;
-            acts_memory_iterator += act_sizes[i] * sizeof_dtype(act_types[i]);
+            *(tensors[i].ptr) = acts_memory_iterator;
+            acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type);
         }
     }
     return acts_memory;
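
The hunk above keeps the existing single-allocation strategy; it only reads sizes and target pointers out of each TensorSpec instead of the old parallel arrays. A host-only sketch of that allocate-then-carve pattern, with plain malloc standing in for cudaMalloc and two hypothetical tensors:

#include <cstdio>
#include <cstdlib>
#include <cstddef>

// reduced stand-ins for the types in this file
enum class DType { FP32, BF16 };
size_t sizeof_dtype(DType t) { return t == DType::FP32 ? 4 : 2; }
struct TensorSpec { void** ptr; size_t size; DType type; };

int main() {
    float* a = NULL;                        // pretend activation tensors
    float* b = NULL;
    TensorSpec specs[2] = {
        { (void**)&a, 100, DType::FP32 },
        { (void**)&b,   0, DType::FP32 },   // size 0 stays NULL, like the recompute buffers
    };

    // pass 1: total bytes; then carve one block by advancing a byte pointer
    size_t bytes = 0;
    for (const TensorSpec& s : specs) bytes += s.size * sizeof_dtype(s.type);
    char* block = (char*)malloc(bytes);     // the real code does a single cudaMalloc here
    char* it = block;
    for (const TensorSpec& s : specs) {
        if (s.size == 0) { *(s.ptr) = NULL; }
        else { *(s.ptr) = it; it += s.size * sizeof_dtype(s.type); }
    }
    printf("a=%p b=%p\n", (void*)a, (void*)b);   // b prints as a null pointer
    free(block);
    return 0;
}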
@@ -376,10 +351,8 @@ typedef struct {
     float* master_weights; // is NULL unless fp32 weights is enabled.
     // the activations of the model, and their sizes
     ActivationTensors acts;
-    size_t act_sizes[NUM_ACTIVATION_TENSORS];
-    DType act_types[NUM_ACTIVATION_TENSORS];
+    TensorSpec acts_specs[NUM_ACTIVATION_TENSORS];
     void* acts_memory;
-    size_t num_activations;
     // other run state configuration
     int batch_size; // the batch size (B) of current forward pass
     int seq_len; // the sequence length (T) of current forward pass
Expand Down Expand Up @@ -630,13 +603,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
model->batch_size = B;
model->seq_len = T;
// allocate the space
fill_in_activation_sizes(model->act_sizes, model->act_types, B, T, model->config, model->recompute);
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
num_activations += model->act_sizes[i];
}
model->num_activations = num_activations;
model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes, model->act_types);
fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute);
model->acts_memory = malloc_and_point_activations(model->acts_specs);
// also create memory for caching inputs and targets
cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));
cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));
