examples: add MNIST training + missing ops

JohannesGaessler committed Aug 22, 2024
1 parent 46e22f5 commit 879dcb8
Showing 24 changed files with 1,819 additions and 1,680 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ CMakeSettings.json
.clangd

.venv/
ggml_env/
.exrc
.cache
.DS_Store
3 changes: 3 additions & 0 deletions examples/mnist/.gitignore
@@ -0,0 +1,3 @@
data/
*.gguf
*.ggml
42 changes: 11 additions & 31 deletions examples/mnist/CMakeLists.txt
@@ -1,40 +1,20 @@
#
# mnist
# mnist-common

set(TEST_TARGET mnist)
add_executable(${TEST_TARGET} main.cpp)
set(TEST_TARGET mnist-common)
add_library(${TEST_TARGET} mnist-common.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)

#
# mnist-cnn
# mnist-eval

set(TEST_TARGET mnist-cnn)
add_executable(${TEST_TARGET} main-cnn.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
set(TEST_TARGET mnist-eval)
add_executable(${TEST_TARGET} mnist-eval.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)

#
# mnist-cpu

set(TEST_TARGET mnist-cpu)
add_executable(${TEST_TARGET} main-cpu.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)

if (APPLE)
#
# mnist-mtl

find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
# mnist-train

set(TEST_TARGET mnist-mtl)
add_executable(${TEST_TARGET} main-mtl.cpp main-mtl.h main-mtl.m)
target_link_libraries(${TEST_TARGET} PRIVATE
ggml
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()
set(TEST_TARGET mnist-train)
add_executable(${TEST_TARGET} mnist-train.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)
228 changes: 148 additions & 80 deletions examples/mnist/README.md
@@ -1,119 +1,187 @@
# MNIST Examples for GGML

These are simple examples of how to use GGML for inferencing.
The first example uses convolutional neural network (CNN), the second one uses fully connected neural network.
This directory contains simple examples of how to use GGML for training and inference using the [MNIST dataset](https://yann.lecun.com/exdb/mnist/).
All commands listed in this README assume the working directory to be `examples/mnist`.
Please note that training in GGML is a work-in-progress and not production ready.

## MNIST with CNN
## Obtaining the data

This implementation achieves ~99% accuracy on the MNIST test set.
The data can either be downloaded [here](https://yann.lecun.com/exdb/mnist/) or will be downloaded automatically when running `mnist-train-fc.py`.
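
For reference, a minimal sketch of how the automatic download typically works, assuming `mnist-train-fc.py` uses `torchvision` (which would explain the `data/MNIST/raw/` paths used in the commands below):

```python
from torchvision import datasets

# Hedged sketch: with root="data", torchvision places the raw idx files
# under data/MNIST/raw/, matching the paths passed to mnist-eval below.
datasets.MNIST(root="data", train=True,  download=True)
datasets.MNIST(root="data", train=False, download=True)
```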

### Training the model
## Fully connected network

Setup the Python environemt and build the examples according to the main README.
Use the `mnist-cnn.py` script to train the model and convert it to GGUF format:
For our first example we will train a fully connected network.
To train a fully connected model in PyTorch and save it as a GGUF file, run:

```bash
$ python3 ../examples/mnist/mnist-cnn.py train mnist-cnn-model
$ python3 mnist-train-fc.py mnist-fc-f32.gguf

...
Keras model saved to 'mnist-cnn-model'
```

Convert the model to GGUF format:
Test loss: 0.069983+-0.009196, Test accuracy: 97.94+-0.14%

```bash
$ python3 ../examples/mnist/mnist-cnn.py convert mnist-cnn-model
...
Model converted and saved to 'mnist-cnn-model.gguf'
Model tensors saved to mnist-fc-f32.gguf:
fc1.weight (500, 784)
fc1.bias (500,)
fc2.weight (10, 500)
fc2.bias (10,)
```
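
The exact architecture is defined in `mnist-train-fc.py` and is not shown in this diff, but the tensor shapes above pin down a two-layer network. A minimal PyTorch sketch (the 500-unit hidden layer and the 10-way output are read off the shapes; the ReLU activation is an assumption):

```python
import torch
import torch.nn as nn

class MnistFC(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 500)  # -> fc1.weight (500, 784), fc1.bias (500,)
        self.fc2 = nn.Linear(500, 10)       # -> fc2.weight (10, 500),  fc2.bias (10,)

    def forward(self, x):
        x = x.view(x.size(0), -1)           # flatten the 28x28 images
        x = torch.relu(self.fc1(x))
        return self.fc2(x)                  # logits, e.g. for nn.CrossEntropyLoss
```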

### Running the example
The training script includes an evaluation of the model on the test set.
To evaluate the model using GGML, run:

```bash
$ ./bin/mnist-cnn mnist-cnn-model.gguf ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
main: loaded model in 5.17 ms
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ * * * * * _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ * * * * * * * * _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * * * * _ _ _ * * _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * * _ _ _ _ * * * * * _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ * * * * * * * * * _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ * * * * * * * * * * _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ * * * * * * _ _ * * * _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ * * * _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ * * _ _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ * * * _ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ * * * _ _ _ _ _ _ * * * _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ * * * * * * * * * * _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ * * * * * * _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ggml_graph_dump_dot: dot -Tpng mnist-cnn.dot -o mnist-cnn.dot.png && open mnist-cnn.dot.png
main: predicted digit is 8
$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte

________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________######__________________
____________________________########____________________
________________________########________________________
____________________########________________##__________
__________________######____________________##__________
________________######______________________####________
______________######________________________####________
____________######__________________________####________
____________####____________________________####________
__________####______________________________####________
__________####______________________________####________
__________##________________________________####________
__________##______________________________####__________
__________##____________________________######__________
__________##__________________________######____________
____________##____________________########______________
____________##########################__________________
______________##################________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf
ggml_graph_import: invalid magic number, got 46554747
mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf
mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf'
mnist_model_init_from_file: model arch is mnist-fc
mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf
main: loaded model in 1.52 ms
mnist_model_eval: model evaluation on 10000 images took 26.65 ms, 2.66 us/image
main: predicted digit is 0
main: test_loss=0.069983+-0.009196
main: test_acc=97.94+-0.14%
```
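
The `t10k-images-idx3-ubyte` and `t10k-labels-idx1-ubyte` arguments are the raw MNIST idx files, which `mnist-eval` parses directly: a big-endian header of 32-bit integers followed by raw unsigned bytes. A Python sketch of that layout (the helper names are illustrative and not part of this example):

```python
import numpy as np

def load_idx_images(path: str) -> np.ndarray:
    with open(path, "rb") as f:
        # 16-byte header: magic, image count, rows, cols (big-endian uint32)
        magic, count, rows, cols = np.frombuffer(f.read(16), dtype=">u4")
        assert magic == 2051, "not an idx3 image file"
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(count, rows, cols)

def load_idx_labels(path: str) -> np.ndarray:
    with open(path, "rb") as f:
        # 8-byte header: magic, label count (big-endian uint32)
        magic, count = np.frombuffer(f.read(8), dtype=">u4")
        assert magic == 2049, "not an idx1 label file"
        return np.frombuffer(f.read(), dtype=np.uint8)

# e.g.: images = load_idx_images("data/MNIST/raw/t10k-images-idx3-ubyte")
```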

Computation graph:
In addition to evaluating the model on the test set, the GGML evaluation also prints a random image from the test set together with the model's prediction for it.
To train a fully connected model using GGML, run:

![mnist dot](https://user-images.githubusercontent.com/1991296/263763842-3b679b45-7ca1-4ee9-b19a-82e34396624f.png)

## MNIST with fully connected network

A fully connected layer + relu, followed by a fully connected layer + softmax.

### Training the Model
``` bash
$ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
```

A Google Colab notebook for training a simple two-layer network to recognize digits is located here. You can
use this to save a pytorch model to be converted to ggml format.
It can then be evaluated with the same binary as above.
When training a model with GGML, the computation graph for the forward pass is also exported to `mnist-fc-f32.ggml`.
Compared to the GGUF file (which contains only the weights), this file also contains the model architecture.
As long as the input and output tensors are well-defined, an exported GGML graph is fully agnostic with respect to the model architecture.
It can be evaluated with the `mnist-eval` binary by passing it in place of the GGUF file argument.
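
To see that the GGUF file really does hold only the weights (plus key/value metadata), it can be inspected with the `gguf` Python package from the ggml/llama.cpp ecosystem; a sketch, assuming that package is installed:

```python
from gguf import GGUFReader

# Hedged sketch: list the tensors stored in the trained model file.
# The exported mnist-fc-f32.ggml graph, by contrast, also encodes the
# forward pass itself, which is what makes it architecture-agnostic.
reader = GGUFReader("mnist-fc-f32.gguf")
for tensor in reader.tensors:
    print(tensor.name, tensor.shape)
```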

[Colab](https://colab.research.google.com/drive/12n_8VNJnolBnX5dVS0HNWubnOjyEaFSb?usp=sharing)
## Convolutional network

GGML "format" is whatever you choose for efficient loading. In our case, we just save the hyperparameters used
plus the model weights and biases. Run convert-h5-to-ggml.py to convert your pytorch model. The output format is:
To train a convolutional network using TensorFlow, run:

- magic constant (int32)
- repeated list of tensors
- number of dimensions of tensor (int32)
- tensor dimension (int32 repeated)
- values of tensor (int32)
```bash
$ python3 mnist-train-cnn.py mnist-cnn-f32.gguf

Run ```convert-h5-to-ggml.py mnist_model.state_dict``` where `mnist_model.state_dict` is the saved pytorch model from the Google Colab. For
quickstart, it is included in the mnist/models directory.
...

```bash
mkdir -p models/mnist
python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict
Test loss: 0.046456
Test accuracy: 98.40%
GGUF model saved to 'mnist-cnn-f32.gguf'
```
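
The precise architecture is defined in `mnist-train-cnn.py`. A hedged Keras sketch of a network in that spirit (the filter counts and kernel sizes here are assumptions, not values taken from the script):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, kernel_size=3, activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(16, kernel_size=3, activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```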

### Running the example
The saved model can be evaluated using the `mnist-eval` binary:

```bash
./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
$ ../../build/bin/mnist-eval mnist-cnn-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte

________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________####____________________________
__________________________##____________________________
__________________________##____________________________
__________________________##____________________________
__________________________##____________________________
__________________________##____________________________
____________________________##__________________________
____________________________##__________________________
____________________________##__________________________
______________________________##________________________
______________________________##________________________
______________________________####______________________
________________________________##______________________
________________________________##______________________
________________________________####____________________
__________________________________##____________________
________________________________##______________________
________________________________________________________
________________________________________________________
________________________________________________________
________________________________________________________
mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf
ggml_graph_import: invalid magic number, got 46554747
mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf
mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf'
mnist_model_init_from_file: model arch is mnist-cnn
mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf
main: loaded model in 5.45 ms
mnist_model_eval: model evaluation on 10000 images took 605.60 ms, 60.56 us/image
main: predicted digit is 1
main: test_loss=0.046456+-0.007354
main: test_acc=98.40+-0.13%
```

Computation graph:
As with the fully connected network, the convolutional network can also be trained using GGML:

![mnist dot](https://user-images.githubusercontent.com/1991296/231882071-84e29d53-b226-4d73-bdc2-5bd6dcb7efd1.png)
``` bash
$ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
```

As always, the evaluation is done using `mnist-eval`; as with the fully connected network, the GGML graph is exported to `mnist-cnn-f32.ggml`.

## Web demo

The example can be compiled with Emscripten like this:
The evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (you may need to log in again after installation so that `$PATH` is updated).
First, copy the GGUF file of either of the trained models to `examples/mnist` and name it `mnist-f32.gguf`.
Copy the test set to `examples/mnist` and name it `t10k-images-idx3-ubyte`.
Symlinking these files will *not* work!
Compile the code like so:

```bash
cd examples/mnist
emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c main.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file models/mnist
$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte
```

The compilation output is in `examples/mnist/web`.
To run it, you need an HTTP server.
For example:

``` bash
$ cd web
$ python3 -m http.server

Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
```

Online demo: https://mnist.ggerganov.com
The web demo can then be accessed via the link printed to the console.
Simply draw a digit on the canvas and the model will try to predict what it is supposed to be.
Alternatively, click the "Random" button to retrieve a random digit from the test set.
Be aware that, like all neural networks, the one we trained is susceptible to distributional shift:
if the digits you draw look different from the ones in the training set
(e.g. because they are not centered), the model will perform comparatively worse.
An online demo can be accessed [here](https://mnist.ggerganov.com).
62 changes: 0 additions & 62 deletions examples/mnist/convert-h5-to-ggml.py

This file was deleted.
