Skip to content

Commit

Permalink
Merge pull request karpathy#50 from karpathy/memmap
Browse files Browse the repository at this point in the history
candidate memmap implementation
  • Loading branch information
karpathy authored Jul 25, 2023
2 parents d18e9ef + a1f6b46 commit 133ad3f
Showing 1 changed file with 45 additions and 61 deletions.
106 changes: 45 additions & 61 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ Then run with:
#include <time.h>
#include <math.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>

// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management
Expand Down Expand Up @@ -104,68 +107,39 @@ void free_run_state(RunState* s) {
free(s->value_cache);
}

/*
 * Allocate a zero-initialised float buffer for every weight tensor in `w`,
 * sized from the hyperparameters in `p`.
 *
 * Element counts are computed in size_t. Multiplying the Config fields
 * directly (assuming they are plain ints -- TODO confirm against the struct)
 * can overflow signed arithmetic for large models before the value ever
 * reaches calloc, which is undefined behaviour and yields a bogus size.
 *
 * Does not return on failure: if any allocation comes back NULL it prints
 * "malloc failed!" and terminates the process with exit status 1.
 */
void malloc_weights(TransformerWeights* w, Config* p) {
    const size_t dim        = (size_t)p->dim;
    const size_t hidden_dim = (size_t)p->hidden_dim;
    const size_t n_layers   = (size_t)p->n_layers;
    const size_t vocab_size = (size_t)p->vocab_size;
    const size_t seq_len    = (size_t)p->seq_len;

    // element counts shared by several tensors
    const size_t per_layer_vec = n_layers * dim;              // rms weights
    const size_t per_layer_sq  = n_layers * dim * dim;        // wq, wk, wv, wo
    const size_t per_layer_ffn = n_layers * hidden_dim * dim; // w1, w2, w3
    // NOTE(review): kept at seq_len * dim / 2 as in the original; the
    // fread-based loader visible in this file only fills
    // seq_len * head_size / 2 elements, so this looks like a harmless
    // over-allocation -- confirm before shrinking.
    const size_t freq_cis      = seq_len * dim / 2;

    // we calloc instead of malloc to keep valgrind happy
    w->token_embedding_table = calloc(vocab_size * dim, sizeof(float));
    w->rms_att_weight = calloc(per_layer_vec, sizeof(float));
    w->rms_ffn_weight = calloc(per_layer_vec, sizeof(float));
    w->wq = calloc(per_layer_sq, sizeof(float));
    w->wk = calloc(per_layer_sq, sizeof(float));
    w->wv = calloc(per_layer_sq, sizeof(float));
    w->wo = calloc(per_layer_sq, sizeof(float));
    w->w1 = calloc(per_layer_ffn, sizeof(float));
    w->w2 = calloc(per_layer_ffn, sizeof(float));
    w->w3 = calloc(per_layer_ffn, sizeof(float));
    w->rms_final_weight = calloc(dim, sizeof(float));
    w->freq_cis_real = calloc(freq_cis, sizeof(float));
    w->freq_cis_imag = calloc(freq_cis, sizeof(float));

    // ensure all allocations went fine
    if (!w->token_embedding_table || !w->rms_att_weight || !w->rms_ffn_weight
        || !w->wq || !w->wk || !w->wv || !w->wo
        || !w->w1 || !w->w2 || !w->w3
        || !w->rms_final_weight || !w->freq_cis_real || !w->freq_cis_imag) {
        printf("malloc failed!\n");
        exit(1);
    }
}

/*
 * Release every weight buffer referenced by `w`.
 *
 * Each of the thirteen tensor pointers is handed to free() exactly once, in
 * the same order as the fields are listed below. The pointers stored in `w`
 * are not modified, so they must not be used again after this call.
 */
void free_weights(TransformerWeights* w) {
    // gather the buffers into a table so the release is a single loop
    void *buffers[] = {
        w->token_embedding_table,
        w->rms_att_weight,
        w->rms_ffn_weight,
        w->wq,
        w->wk,
        w->wv,
        w->wo,
        w->w1,
        w->w2,
        w->w3,
        w->rms_final_weight,
        w->freq_cis_real,
        w->freq_cis_imag,
    };
    const size_t count = sizeof buffers / sizeof buffers[0];

    for (size_t i = 0; i < count; i++) {
        free(buffers[i]);
    }
}

// ----------------------------------------------------------------------------
// initialization: read from checkpoint

// NOTE(review): this span comes from a diff view whose +/- markers were lost,
// so two versions of checkpoint_init_weights are interleaved here: there are
// two function headers but only one closing brace. The fread-based lines
// (int return, FILE* f) appear to be the removed version and the
// pointer-assignment lines (void return, float* f) the added one. Only one
// set can belong in a compilable run.c -- confirm against the repository.
//
// Both versions walk the tensors in the same order: token_embedding_table,
// rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3,
// rms_final_weight, freq_cis_real, freq_cis_imag.

// -- presumably removed version: fills each buffer in w by reading floats from
//    the stream f; returns 1 as soon as an fread delivers fewer elements than
//    requested, and 0 once every tensor has been read.
int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != p->vocab_size * p->dim) return 1;
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1;
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
if (fread(w->wk, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
if (fread(w->wv, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
if (fread(w->wo, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
if (fread(w->rms_ffn_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1;
if (fread(w->w1, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1;
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != p->n_layers * p->hidden_dim * p->dim) return 1;
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1;
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != p->dim) return 1;
// -- presumably added version: nothing is copied. Each field of w is pointed
//    at its slice of the float buffer f, and ptr is advanced by that tensor's
//    element count so the next field starts right after it.
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
float* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += p->n_layers * p->dim;
w->wq = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wk = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wv = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wo = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->rms_ffn_weight = ptr;
ptr += p->n_layers * p->dim;
w->w1 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->w2 = ptr;
ptr += p->n_layers * p->hidden_dim * p->dim;
w->w3 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->rms_final_weight = ptr;
ptr += p->dim;
w->freq_cis_real = ptr;
// per-head width; both versions size the freq_cis tables as
// seq_len * head_size / 2 elements
int head_size = p->dim / p->n_heads;
// -- presumably removed: read the real and imaginary freq_cis tables from the
//    stream, then report success
if (fread(w->freq_cis_real, sizeof(float), p->seq_len * head_size / 2, f) != p->seq_len * head_size / 2) return 1;
if (fread(w->freq_cis_imag, sizeof(float), p->seq_len * head_size / 2, f) != p->seq_len * head_size / 2) return 1;
return 0;
// -- presumably added: skip past the real table so freq_cis_imag starts
//    immediately after it in the buffer
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
}


// ----------------------------------------------------------------------------
// neural net blocks

Expand Down Expand Up @@ -410,6 +384,9 @@ int main(int argc, char *argv[]) {
// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0;
float* data = NULL;
long file_size;
{
FILE *file = fopen(checkpoint, "rb");
if (!file) {
Expand All @@ -418,10 +395,16 @@ int main(int argc, char *argv[]) {
}
// read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// read in the Transformer weights
malloc_weights(&weights, &config);
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
// figure out the file size
fseek(file, 0, SEEK_END); // move file pointer to end of file
file_size = ftell(file); // get the file size, in bytes
fclose(file);
// memory map the Transformer weights into the data pointer
fd = open(checkpoint, O_RDONLY); // open in read only mode
if (fd == -1) { printf("open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float));
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
Expand Down Expand Up @@ -484,10 +467,11 @@ int main(int argc, char *argv[]) {
long end = time_in_ms();
printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);

// memory cleanup
// memory and file handles cleanup
free_run_state(&state);
free_weights(&weights);
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
free(vocab);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);
return 0;
}

0 comments on commit 133ad3f

Please sign in to comment.