refactor gguf load
JohannesGaessler committed Sep 11, 2024
1 parent 035f0d7 commit 397f617
Showing 2 changed files with 54 additions and 38 deletions.
examples/mnist/mnist-common.cpp (53 additions, 38 deletions)
@@ -156,24 +156,69 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag
     return result;
 }
 
+// Temporary util function for loading data from GGUF to a backend != CPU until GGML itself provides this functionality:
+bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) {
+    FILE * f = ggml_fopen(fname, "rb");
+    if (!f) {
+        return false;
+    }
+
+    const size_t buf_size = 4*1024*1024;
+    void * buf = malloc(buf_size);
+
+    const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+    for (int i = 0; i < n_tensors; i++) {
+        const char * name = gguf_get_tensor_name(ctx_gguf, i);
+
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx_ggml, name);
+        if (!tensor) {
+            continue;
+        }
+
+        const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+
+        if (fseek(f, offs, SEEK_SET) != 0) {
+            fclose(f);
+            free(buf);
+            return false;
+        }
+
+        const size_t nbytes = ggml_nbytes(tensor);
+        for (size_t pos = 0; pos < nbytes; pos += buf_size) {
+            const size_t nbytes_cpy = buf_size < nbytes - pos ? buf_size : nbytes - pos;
+
+            if (fread(buf, 1, nbytes_cpy, f) != nbytes_cpy) {
+                fclose(f);
+                free(buf);
+                return false;
+            }
+
+            ggml_backend_tensor_set(tensor, buf, pos, nbytes_cpy);
+        }
+    }
+
+    fclose(f);
+    free(buf);
+    return true;
+}
+
 mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend) {
     mnist_model model(backend);
     fprintf(stderr, "%s: loading model weights from '%s'\n", __func__, fname.c_str());
 
-    struct gguf_context * ctx_be; // be == backend
+    struct gguf_context * ctx;
     {
         struct gguf_init_params params = {
             /*.no_alloc =*/ true,
             /*.ctx      =*/ &model.ctx_weight,
         };
-        ctx_be = gguf_init_from_file(fname.c_str(), params);
-        if (!ctx_be) {
+        ctx = gguf_init_from_file(fname.c_str(), params);
+        if (!ctx) {
             fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
             exit(1);
         }
     }
-    model.arch = gguf_get_val_str(ctx_be, gguf_find_key(ctx_be, "general.architecture"));
+    model.arch = gguf_get_val_str(ctx, gguf_find_key(ctx, "general.architecture"));
     fprintf(stderr, "%s: model arch is %s\n", __func__, model.arch.c_str());
 
     if (model.arch == "mnist-fc") {
@@ -247,41 +292,11 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
     }
     model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
 
-    void * buf_tmp = malloc(model.size_weight);
-    struct ggml_context * ctx_ggml_tmp;
-    {
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ model.size_weight,
-            /*.mem_buffer =*/ buf_tmp,
-            /*.no_alloc   =*/ false,
-        };
-        ctx_ggml_tmp = ggml_init(params);
-    }
-    struct gguf_context * ctx_gguf_tmp;
-    {
-        struct gguf_init_params params = {
-            /*.no_alloc =*/ false,
-            /*.ctx      =*/ &ctx_ggml_tmp,
-        };
-        ctx_gguf_tmp = gguf_init_from_file(fname.c_str(), params);
-        if (!ctx_gguf_tmp) {
-            fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
-            exit(1);
-        }
-    }
-    for (const std::string & s : {"fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"}) {
-        const struct ggml_tensor * src = ggml_get_tensor(ctx_ggml_tmp, s.c_str());
-        struct ggml_tensor * dst = ggml_get_tensor(model.ctx_weight, s.c_str());
-        GGML_ASSERT(ggml_nbytes(src) == ggml_nbytes(dst));
-        ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(dst));
-    }
+    if (!load_from_gguf(fname.c_str(), model.ctx_weight, ctx)) {
+        fprintf(stderr, "%s: loading weights from %s failed\n", __func__, fname.c_str());
+        exit(1);
+    }
 
-    gguf_free(ctx_gguf_tmp);
-    ggml_free(ctx_ggml_tmp);
-    free(buf_tmp);
-
-    gguf_free(ctx_be);
+    gguf_free(ctx);
 
     fprintf(stderr, "%s: successfully loaded weights from %s\n", __func__, fname.c_str());
     return model;
 }
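
For readers following the change: the new helper replaces the old approach of loading the whole file into a temporary CPU context and copying tensors over; instead it reads each tensor's data from disk in fixed-size 4 MiB chunks and writes each chunk directly into the backend tensor at the matching offset. Below is a minimal sketch of how a caller might drive load_from_gguf() end to end. This sketch is not part of the commit: my_load_weights is a hypothetical wrapper, and cleanup of the weight context and backend buffer on the success path is elided; the gguf/ggml calls mirror the ones in mnist_model_init_from_file above.

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Defined in mnist-common.cpp above:
bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf);

static bool my_load_weights(const char * fname, ggml_backend_t backend) {
    struct ggml_context * ctx_weight = nullptr;

    // Read only the GGUF metadata; with no_alloc=true the tensors created in
    // ctx_weight get their shapes and names but no data.
    struct gguf_init_params params = {
        /*.no_alloc =*/ true,
        /*.ctx      =*/ &ctx_weight,
    };
    struct gguf_context * ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        return false;
    }

    // Allocate backing memory for all tensors on the target backend (e.g. CUDA).
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_weight, backend);
    if (!buf) {
        gguf_free(ctx);
        return false;
    }

    // Stream the tensor data from the file into the backend buffers in
    // fixed-size chunks, as implemented by the new helper.
    const bool ok = load_from_gguf(fname, ctx_weight, ctx);

    // ctx_weight and buf own the tensor metadata and data; they must outlive
    // any use of the weights, so only the GGUF context is freed here.
    gguf_free(ctx);
    return ok;
}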
src/ggml-cuda/out-prod.cu (1 addition, 0 deletions)
@@ -1,3 +1,4 @@
 #include "out-prod.cuh"
+#include "opt-step-adam.cuh"
 #include "vendors/cuda.h"

