
Commit

allocate_state utility function
ngc92 committed Jul 18, 2024
1 parent 5c89416 commit cde79ff
Showing 3 changed files with 56 additions and 47 deletions.
profile_gpt2.cu (1 change: 1 addition & 0 deletions)
@@ -57,6 +57,7 @@ int main(int argc, char *argv[]) {
     model.config.num_layers = 1;
     set_zero_configs(&multi_gpu_config, 0, model.num_parameters);
 
+    gpt2_allocate_state(&model, B, T);
     // do a training step
     gpt2_forward(&model, x, B, T);
     gpt2_backward_and_reduce(&model, x, y, 1, 0);
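The new gpt2_allocate_state(&model, B, T) call is now required before any forward/backward work. A minimal sketch of the intended call order, assuming the usual llm.c setup (B, T, x, y and multi_gpu_config already exist; the checkpoint path is illustrative):

// Sketch only, not a standalone program: the training-state allocation that used to
// happen lazily inside gpt2_forward() / gpt2_update() now has to be requested explicitly.
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");   // illustrative checkpoint path
gpt2_allocate_state(&model, B, T);   // one-time allocation of grads, activations, AdamW m/v, master weights
gpt2_forward(&model, x, B, T);       // forward no longer allocates activations on first use
gpt2_backward_and_reduce(&model, x, y, 1, 0);
// ... optimizer update and further steps as before ...
gpt2_free(&model);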
test_gpt2.cu (3 changes: 3 additions & 0 deletions)
@@ -166,6 +166,8 @@ int main(int argc, char *argv[]) {
     // overall OK signal for the test
     int allok = 1;
 
+    gpt2_allocate_state(&model, B, T);
+
     // First, do target-free forward pass to validate logits
     gpt2_forward(&model, x, B, T);
     // at this point, target should be equal to expected_logits, let's compare
@@ -344,6 +346,7 @@ int main(int argc, char *argv[]) {
     gpt2_free(&model);
     gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt");
     int ld_step;
+    gpt2_allocate_state(&model, B, T);
     load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt");
     for (int step = 0; step < 10; step++) {
         dataloader_next_batch(&loader);
train_gpt2.cu (99 changes: 52 additions & 47 deletions)
@@ -311,7 +311,8 @@ typedef struct {
     float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps
     float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
     int use_master_weights; // keep master weights copy in float for optim update? 0|1
+    bool init_state;        // set to true if master weights need to be initialized
     int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward)
     int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2
     // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
@@ -344,6 +345,7 @@ void gpt2_init_common(GPT2 *model, GPT2Config config) {
     // other default settings
     model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding
     model->use_master_weights = 1; // safe default: do keep master weights in fp32
+    model->init_state = false;
     model->recompute = 1; // good default: recompute gelu but not layernorm
     model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
 
@@ -358,7 +360,42 @@
     // create memory for model parameters on the device
     assert(model->params_memory == nullptr && "Old model needs to be freed before loading from checkpoint again");
     model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);
 }
 
+void gpt2_allocate_state(GPT2 *model, int B, int T) {
+    printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
+    model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);
+
+    // record the current B,T as well
+    model->batch_size = B;
+    model->seq_len = T;
+
+    // allocate the space
+    fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute);
+    model->acts_memory = malloc_and_point_activations(model->acts_specs);
+    // also create memory for caching inputs and targets
+    cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));
+    cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));
+    cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
+    cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
+
+    // initialise cpu scratch buffers for encoder backward
+    size_t num_c_groups = CEIL_DIV(model->config.channels, (WARP_SIZE * x128::size));
+    assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?)
+    model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
+    model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
+
+    size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
+    printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
+    printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
+    cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
+    cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
+
+    if (model->use_master_weights == 1) {
+        printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
+        cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float)));
+        model->init_state = true;
+    }
+}
 
 void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
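To get a feel for what those printf0 messages report, here is a small, self-contained estimate of the buffers gpt2_allocate_state reserves. This is a sketch under two assumptions that are not stated in the diff: a single GPU, so the optimizer shard covers every parameter, and a 2-byte floatX (e.g. BF16). Activations are left out because they depend on B, T and the recompute setting.

#include <stdio.h>
#include <stddef.h>

// Rough, standalone estimate of the persistent training-state buffers
// that gpt2_allocate_state reserves (activations excluded).
int main(void) {
    size_t num_parameters = 124475904;               // GPT-2 124M (padded vocab), used here just for illustration
    size_t grads  = num_parameters * 2;              // floatX gradients, assuming 2 bytes per element
    size_t m      = num_parameters * sizeof(float);  // AdamW first moment (fp32)
    size_t v      = num_parameters * sizeof(float);  // AdamW second moment (fp32)
    size_t master = num_parameters * sizeof(float);  // fp32 master weights (only if use_master_weights == 1)
    printf("grads:  %zu MiB\n", grads  >> 20);
    printf("m:      %zu MiB\n", m      >> 20);
    printf("v:      %zu MiB\n", v      >> 20);
    printf("master: %zu MiB\n", master >> 20);
    printf("total:  %zu MiB\n", (grads + m + v + master) >> 20);
    return 0;
}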
@@ -586,33 +623,11 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     const size_t NH = model->config.num_heads;
     const size_t C = model->config.channels;
 
-    // allocate space for all the activations if needed (done here, lazily)
-    if(model->acts_memory == NULL) {
-        NvtxRange rng("InitActs");
-        // record the current B,T as well
-        model->batch_size = B;
-        model->seq_len = T;
-        // allocate the space
-        fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute);
-        model->acts_memory = malloc_and_point_activations(model->acts_specs);
-        // also create memory for caching inputs and targets
-        cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));
-        cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));
-        cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
-        cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
-
-        // initialise cpu scratch buffers for encoder backward
-        size_t num_c_groups = CEIL_DIV(model->config.channels, (WARP_SIZE * x128::size));
-        assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?)
-        model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
-        model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
-    } else {
-        // validate B,T are not larger than the values used at initialisation
-        // (smaller B,T are okay for inference only)
-        if (B > model->batch_size || T > model->seq_len) {
-            printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
-            exit(EXIT_FAILURE);
-        }
-    }
+    // validate B,T are not larger than the values used at initialisation
+    // (smaller B,T are okay for inference only)
+    if (B > model->batch_size || T > model->seq_len) {
+        printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
+        exit(EXIT_FAILURE);
+    }
 
     // copy inputs/targets to the model
@@ -983,26 +998,13 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
     // selectively weight decay some, but not all tensors :(
     // TODO: revisit and probably refactor this entire function
     NVTX_RANGE_FN();
-    size_t shard_num_parameters = multi_gpu_config->shard_num_parameters; // num parameters we are responsible for
-
-    // lazily allocate m,v memory and master weights (usually on the first iteration)
-    if (model->m_memory == NULL) {
+    bool init_state = model->init_state;
+    if(init_state) {
+        model->init_state = false;
         NvtxRange rng("InitOpt");
-        printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
-        printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
-        cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
-        cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
-        cudaCheck(cudaMemset(model->m_memory, 0, shard_num_parameters * sizeof(float)));
-        cudaCheck(cudaMemset(model->v_memory, 0, shard_num_parameters * sizeof(float)));
-    }
-
-    bool init_master_weights = false;
-    if (model->use_master_weights == 1 && model->master_weights == NULL) {
-        printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
-        cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
-        init_master_weights = true;
+        cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
+        cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
     }
-
     // AdamW update
     // handle adamw for all the transformer blocks
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
@@ -1032,7 +1034,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
         float* v_ptr = model->v_memory + opt_state_offset;
         float* master_ptr = NULL;
         if (model->master_weights != NULL) { master_ptr = model->master_weights + opt_state_offset; }
-        if(init_master_weights) {
+        if(init_state) {
             size_t grid_size = CEIL_DIV(shard.size, 512);
             copy_and_cast_kernel<<<dim3(grid_size, num_layers), 512, 0, main_stream>>>(master_ptr, param_ptr, shard.size,
                                                                                        shard.size, tensor.size);
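Taken together, the update path now uses a flag-gated one-time initialization instead of NULL checks: gpt2_allocate_state sets init_state, gpt2_update consumes it (zeroing m/v and casting params into the master copy via copy_and_cast_kernel), and load_state clears it because those buffers arrive pre-filled from a checkpoint. Below is a toy, CPU-only illustration of the same pattern; the Toy* names are made up for this sketch and are not llm.c APIs.

#include <stdbool.h>
#include <stdio.h>
#include <string.h>

// Toy stand-in for the flag-gated one-time init used above; not llm.c code.
typedef struct {
    float m[8], v[8];   // stand-ins for the AdamW moment buffers
    bool init_state;    // true => buffers are allocated but not yet initialized
} ToyOptimizer;

static void toy_allocate_state(ToyOptimizer* opt) {
    opt->init_state = true;             // like gpt2_allocate_state(): fresh buffers need a first touch
}

static void toy_load_state(ToyOptimizer* opt) {
    for (int i = 0; i < 8; i++) { opt->m[i] = 0.5f; opt->v[i] = 0.25f; }  // stand-in for reading a checkpoint
    opt->init_state = false;            // like load_state(): state came from a file, skip first-touch init
}

static void toy_update(ToyOptimizer* opt) {
    if (opt->init_state) {              // like gpt2_update(): initialize exactly once, then clear the flag
        opt->init_state = false;
        memset(opt->m, 0, sizeof(opt->m));
        memset(opt->v, 0, sizeof(opt->v));
    }
    // ... AdamW math would go here ...
}

int main(void) {
    ToyOptimizer fresh;                 // fresh training run: m/v get zeroed on the first update
    toy_allocate_state(&fresh);
    toy_update(&fresh);
    printf("fresh run:   m[0]=%.2f v[0]=%.2f\n", fresh.m[0], fresh.v[0]);

    ToyOptimizer resumed;               // resumed run: m/v come from "disk" and are left untouched
    toy_allocate_state(&resumed);
    toy_load_state(&resumed);
    toy_update(&resumed);
    printf("resumed run: m[0]=%.2f v[0]=%.2f\n", resumed.m[0], resumed.v[0]);
    return 0;
}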
@@ -1230,6 +1232,8 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
         file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
     }
 
+    model->init_state = false; // we just got the state from file, no need to do first-touch init
+
     // revive the DataLoader object and its state
     loader->should_shuffle = should_shuffle;
     if (should_shuffle == 1) {
@@ -1613,6 +1617,7 @@ int main(int argc, char *argv[]) {
 
     // if we found a checkpoint to resume from, load the optimization state
     int step = 0;
+    gpt2_allocate_state(&model, B, T);
     if (resuming == 1) {
         snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank);
         load_state(&step, &model, &train_loader, filename_buffer);
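In main() the ordering matters: the state buffers are always allocated first, and load_state only fills them when resuming. A compressed sketch of that path, reusing the variables from the surrounding code (fragment, not a standalone program):

int step = 0;
gpt2_allocate_state(&model, B, T);   // buffers exist before any training step
if (resuming == 1) {
    // load_state() overwrites m/v/master from the checkpoint and clears model->init_state,
    // so the first gpt2_update() will not re-zero or re-cast them.
    load_state(&step, &model, &train_loader, filename_buffer);
}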
