Skip to content

Commit

Permalink
mnist : minor fixes and adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Apr 13, 2023
1 parent 7399915 commit c0ca5a3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 19 deletions.
42 changes: 41 additions & 1 deletion examples/mnist/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use this to save a pytorch model to be converted to ggml format.

## GGML Format Conversion

GGML "format" is whatever you choose for efficient loading. In our case, we just save the hyperparameters used
GGML "format" is whatever you choose for efficient loading. In our case, we just save the hyperparameters used
plus the model weights and biases. Run convert-h5-to-ggml.py to convert your pytorch model. The output format is:

- magic constant (int32)
Expand Down Expand Up @@ -45,3 +45,43 @@ make -j4 mnist
./bin/mnist ../examples/mnist/models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte

For more information, check out the corresponding programs in the [examples](examples) folder.

## Sample output


```
$ ./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
mnist_model_load: loading model from './models/mnist/ggml-model-f32.bin'
mnist_model_load: ggml ctx size = 1.52 MB
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * * * _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ * * * _ _ _ * _ * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ * * _ _ _ _ _ * _ * _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ * * _ _ _ _ _ _ * * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ * * _ _ _ _ _ _ _ * * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ * * _ _ _ _ _ * * _ _ * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ * * _ _ _ _ * * _ _ _ _ * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ * * * * * * _ _ _ _ _ _ * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ * _ _ _ _ _ _ _ _ _ * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * _ _ * * * _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * * * * _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ggml_graph_dump_dot: dot -Tpng mnist.dot -o mnist.dot.png && open mnist.dot.png
Predicted digit is 9
```
31 changes: 16 additions & 15 deletions examples/mnist/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <unistd.h>
#include <time.h>

// default hparams
// default hparams
struct mnist_hparams {
int32_t n_input = 784;
int32_t n_hidden = 500;
Expand Down Expand Up @@ -84,14 +84,15 @@ bool mnist_model_load(const std::string & fname, mnist_model & model) {
struct ggml_init_params params = {
.mem_size = ctx_size + 1024*1024,
.mem_buffer = NULL,
.no_alloc = false,
};

model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
}

// Read FC1 layer 1
{
Expand Down Expand Up @@ -141,8 +142,8 @@ bool mnist_model_load(const std::string & fname, mnist_model & model) {
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne_bias[i]), sizeof(ne_bias[i]));
}
model.fc2_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_classes);
fin.read(reinterpret_cast<char *>(model.fc2_bias->data), ggml_nbytes(model.fc2_bias));
model.fc2_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_classes);
fin.read(reinterpret_cast<char *>(model.fc2_bias->data), ggml_nbytes(model.fc2_bias));
}
fin.close();

Expand All @@ -162,9 +163,9 @@ int mnist_eval(
) {

const auto & hparams = model.hparams;

static size_t buf_size = hparams.n_input * sizeof(float) * 4;
static void * buf = malloc(buf_size);
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
.mem_size = buf_size,
Expand All @@ -180,18 +181,18 @@ int mnist_eval(
// fc1 MLP = Ax + b
ggml_tensor * fc1 = ggml_add(ctx0, ggml_mul_mat(ctx0, model.fc1_weight, input), model.fc1_bias);
ggml_tensor * fc2 = ggml_add(ctx0, ggml_mul_mat(ctx0, model.fc2_weight, ggml_relu(ctx0, fc1)), model.fc2_bias);

// soft max
ggml_tensor * final = ggml_soft_max(ctx0, fc2);

// run the computation
ggml_build_forward_expand(&gf, final);
ggml_graph_compute (ctx0, &gf);

ggml_graph_print (&gf);
//ggml_graph_print (&gf);
ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
float* finalData = ggml_get_data_f32(final);

int prediction = std::max_element(finalData, finalData + 10) - finalData;
ggml_free(ctx0);
return prediction;
Expand Down Expand Up @@ -223,19 +224,19 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
return 1;
}

unsigned char buf[784];
srand(time(NULL));
// Seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
fin.seekg(16 + 784 * (rand() % 10000));
fin.read((char *) &buf, sizeof(buf));
digit.resize(sizeof(buf));

// render the digit in ASCII
for(int row = 0; row < 28; row++) {
for (int col = 0; col < 28; col++) {
fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
digit[row*28+col]=((float)buf[row*28+col]);
fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
digit[row*28+col]=((float)buf[row*28+col]);

}
fprintf(stderr, "\n");
Expand All @@ -245,7 +246,7 @@ int main(int argc, char ** argv) {
t_load_us = ggml_time_us() - t_start_us;
}


fprintf(stdout, "Predicted digit is %d\n", mnist_eval(model, 1, digit));
ggml_free(model.ctx);

Expand Down
8 changes: 5 additions & 3 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3054,9 +3054,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
return NULL;
}

const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);

*ctx = (struct ggml_context) {
/*.mem_size =*/ params.mem_size,
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(params.mem_size),
/*.mem_size =*/ mem_size,
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
/*.no_alloc =*/ params.no_alloc,
/*.n_objects =*/ 0,
Expand All @@ -3066,7 +3068,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
/*.scratch_save =*/ { 0, 0, NULL, },
};

GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
GGML_ASSERT(ctx->mem_buffer != NULL);

ggml_assert_aligned(ctx->mem_buffer);

Expand Down

0 comments on commit c0ca5a3

Please sign in to comment.