fallback to memory allocation of m,v,master_weights on host automatically in case of OOM. will run slower but won't OOM
karpathy committed Aug 16, 2024
1 parent 2882ec6 commit e6856bc
Showing 3 changed files with 25 additions and 17 deletions.
2 changes: 1 addition & 1 deletion llmc/cuda_common.h
@@ -48,7 +48,7 @@ constexpr std::bool_constant<false> False;
 // ----------------------------------------------------------------------------
 // Error checking
 
-// CUDA error checking
+// CUDA error checking. Underscore added so this function can be called directly not just via macro
 inline void cudaCheck_(cudaError_t error, const char *file, int line) {
     if (error != cudaSuccess) {
         printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
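A side note on the comment change above: the macro it refers to is not part of this hunk. As a sketch of the usual file/line-capturing pattern (not quoted from this diff), cudaCheck expands along these lines:

    // sketch: wrapper macro that records the call site; cudaCheck_ is the direct-call form
    #define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))

Exposing the underscore variant lets a helper such as cudaMallocConditionallyManaged forward its caller's file and line instead of reporting its own, which is exactly how it is used in the next file.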
19 changes: 8 additions & 11 deletions llmc/cuda_utils.cuh
@@ -208,23 +208,20 @@ void global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) {
 // ----------------------------------------------------------------------------
 // memory management
 
-// allocate memory, preferrably on the
-void cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
-    // try to allocate `bytes` on device
+// allocate memory, preferrably on the device
+// returns a status code. 0 = OK, 1 = fell back to managed memory
+int cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
+    // try to allocate
     cudaError_t err = cudaMalloc(out, bytes);
     if(err == cudaErrorMemoryAllocation) {
-        // if that fails, fallback to a managed allocation. It will be slower, but at least
-        // it won't crash.
-        fprintf(stderr, "[WARN] Not enough space to allocate %zu MiB on device.\n"
-                        "       Falling back to managed allocation.\n"
-                        "       Speed may be negatively affected.\n",
-                bytes / 1024 / 1024);
-        // reset the error before the next API call
-        cudaGetLastError();
+        // if we OOM, fallback to a managed allocation. slower but at least won't crash.
+        cudaGetLastError(); // reset the error before the next API call
         cudaCheck_(cudaMallocManaged(out, bytes), file, line);
         cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line);
+        return 1;
     } else {
         cudaCheck_(err, file, line);
+        return 0;
     }
 }

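To see the fallback path in isolation, here is a minimal standalone sketch (a hypothetical test program, not part of the repo) that mimics the logic above by requesting more memory than the device has free:

    // sketch: exercise the cudaMalloc -> cudaMallocManaged fallback on its own
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        size_t free_bytes, total_bytes;
        cudaMemGetInfo(&free_bytes, &total_bytes);
        size_t bytes = free_bytes + (1024ULL << 20); // ask for ~1 GiB more than is free
        void* ptr = nullptr;
        cudaError_t err = cudaMalloc(&ptr, bytes);
        if (err == cudaErrorMemoryAllocation) {
            cudaGetLastError(); // clear the sticky error, as the patch does
            if (cudaMallocManaged(&ptr, bytes) == cudaSuccess) {
                // keep the pages host-resident; the GPU reads them over PCIe/NVLink
                cudaMemAdvise(ptr, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId);
                printf("fell back to managed memory\n");
            }
        } else if (err == cudaSuccess) {
            printf("device allocation succeeded\n");
        }
        cudaFree(ptr);
        return 0;
    }

The cudaMemAdvise call is what makes the fallback predictable: with the preferred location set to the CPU, the allocation stays in host memory rather than bouncing pages in and out of device memory, which matches the "slower but won't OOM" trade-off in the commit message.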
21 changes: 16 additions & 5 deletions train_gpt2.cu
@@ -37,7 +37,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/cuda_common.h"
// defines:
// Packed128, f128, x128
// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged
#include "llmc/cuda_utils.cuh"
// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
// defines: cublas_compute, cublaslt_handle, cublas_handle
@@ -388,24 +388,35 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) {
     model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
     model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
 
+    // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device
+    // and returns a status code of 1 if it had to fall back, in that case we want to print warning.
+    int memory_status = 0;
+
     // we will now init the optimizer states and master weights
     // this is usually a substantial amount of memory allocation right here.
     size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
     printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20);
     printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
     assert(model->m_memory == nullptr);
     assert(model->v_memory == nullptr);
-    cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
-    cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
+    memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));
 
     if (model->use_master_weights == 1) {
         assert(model->master_weights == nullptr);
         printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
-        cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));
+        memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));
     }
 
+    // report on mixed memory allocation status
+    if (memory_status == 1) {
+        printf0("WARNING: fell back to cudaMallocManaged when initializing m,v,master_weights.\n");
+        printf0("         Prevents an OOM, but code may run much slower due to device <-> host memory movement\n");
+    }
     // report on device memory usage
     size_t free, total;
     cudaCheck(cudaMemGetInfo(&free, &total));
     printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024);
 
     // give an estimate of the maximum batch size
     size_t bytes_per_sequence = 0;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
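For a sense of scale (an illustrative estimate, not output from this commit): m, v, and master_weights are each shard_num_parameters floats, so a single-GPU run of the ~124M-parameter GPT-2 (an assumed model size for this arithmetic) pays roughly:

    // sketch: back-of-envelope for the allocations logged above
    #include <cstdio>

    int main() {
        size_t shard_num_parameters = 124000000;                  // ~GPT-2 124M, one GPU, no sharding
        size_t per_buffer = shard_num_parameters * sizeof(float); // m, v, or master_weights
        printf("per buffer: %zu MiB\n", per_buffer >> 20);        // ~473 MiB
        printf("all three:  %zu MiB\n", (3 * per_buffer) >> 20);  // ~1419 MiB
        return 0;
    }

When the optimizer state is sharded across N GPUs (llm.c's ZeRO-1 option), shard_num_parameters shrinks by roughly a factor of N, so the managed-memory fallback matters most for single-GPU runs of the larger models.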
