Skip to content

Commit

Permalink
yolo : add backend support (#924)
Browse files Browse the repository at this point in the history
* yolo : add backend support

* metal : add sub and sqrt kernels

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
rgerganov and ggerganov committed Aug 19, 2024
1 parent 9ad0906 commit 46e22f5
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 62 deletions.
9 changes: 8 additions & 1 deletion examples/yolo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@ $ ./convert-yolov3-tiny.py yolov3-tiny.weights
yolov3-tiny.weights converted to yolov3-tiny.gguf
```

Alternatively, you can download the converted model from [Hugging Face](https://huggingface.co/rgerganov/yolo-gguf/resolve/main/yolov3-tiny.gguf).

Object detection:

```bash
$ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
$ ./yolov3-tiny -m yolov3-tiny.gguf -i dog.jpg
load_model: using CUDA backend
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes
Layer 0 output shape: 416 x 416 x 16 x 1
Layer 1 output shape: 208 x 208 x 16 x 1
Layer 2 output shape: 208 x 208 x 32 x 1
Expand All @@ -48,5 +55,5 @@ car: 52%
truck: 56%
car: 62%
bicycle: 59%
Detected objects saved in 'predictions.jpg' (time: 0.357000 sec.)
Detected objects saved in 'predictions.jpg' (time: 0.057000 sec.)
```
204 changes: 144 additions & 60 deletions examples/yolo/yolov3-tiny.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#include "yolo-image.h"

#include <cmath>
Expand Down Expand Up @@ -29,22 +40,29 @@ struct yolo_model {
int width = 416;
int height = 416;
std::vector<conv2d_layer> conv2d_layers;
ggml_backend_t backend = NULL;
ggml_backend_buffer_t buffer;
struct ggml_context * ctx;
};

struct yolo_layer {
int classes = 80;
std::vector<int> mask;
std::vector<float> anchors;
struct ggml_tensor * predictions;

yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * predictions)
: classes(classes), mask(mask), anchors(anchors), predictions(predictions)
{ }
std::vector<float> predictions;
int w;
int h;

yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * prev_layer)
: classes(classes), mask(mask), anchors(anchors)
{
w = prev_layer->ne[0];
h = prev_layer->ne[1];
predictions.resize(ggml_nbytes(prev_layer)/sizeof(float));
ggml_backend_tensor_get(prev_layer, predictions.data(), 0, ggml_nbytes(prev_layer));
}

int entry_index(int location, int entry) const {
int w = predictions->ne[0];
int h = predictions->ne[1];
int n = location / (w*h);
int loc = location % (w*h);
return n*w*h*(4+classes+1) + entry*w*h + loc;
Expand All @@ -62,15 +80,60 @@ struct detection {
};

static bool load_model(const std::string & fname, yolo_model & model) {
struct gguf_init_params params = {
// initialize the backend
#ifdef GGML_USE_CUDA
fprintf(stderr, "%s: using CUDA backend\n", __func__);
model.backend = ggml_backend_cuda_init(0); // init device 0
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
#endif

#ifdef GGML_USE_METAL
fprintf(stderr, "%s: using Metal backend\n", __func__);
model.backend = ggml_backend_metal_init();
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
}
#endif

// if no GPU backend is available, fall back to the CPU backend
if (!model.backend) {
model.backend = ggml_backend_cpu_init();
}
struct ggml_context * tmp_ctx = nullptr;
struct gguf_init_params gguf_params = {
/*.no_alloc =*/ false,
/*.ctx =*/ &model.ctx,
/*.ctx =*/ &tmp_ctx,
};
gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
if (!ctx) {
gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), gguf_params);
if (!gguf_ctx) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return false;
}

int num_tensors = gguf_get_n_tensors(gguf_ctx);
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead() * num_tensors,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
model.ctx = ggml_init(params);
for (int i = 0; i < num_tensors; i++) {
const char * name = gguf_get_tensor_name(gguf_ctx, i);
struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, name);
struct ggml_tensor * dst = ggml_dup_tensor(model.ctx, src);
ggml_set_name(dst, name);
}
model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);
// copy tensors from main memory to backend
for (struct ggml_tensor * cur = ggml_get_first_tensor(model.ctx); cur != NULL; cur = ggml_get_next_tensor(model.ctx, cur)) {
struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, ggml_get_name(cur));
size_t n_size = ggml_nbytes(src);
ggml_backend_tensor_set(cur, ggml_get_data(src), 0, n_size);
}
gguf_free(gguf_ctx);

model.width = 416;
model.height = 416;
model.conv2d_layers.resize(13);
Expand Down Expand Up @@ -155,10 +218,10 @@ static void activate_array(float * x, const int n)

static void apply_yolo(yolo_layer & layer)
{
int w = layer.predictions->ne[0];
int h = layer.predictions->ne[1];
int w = layer.w;
int h = layer.h;
int N = layer.mask.size();
float * data = ggml_get_data_f32(layer.predictions);
float * data = layer.predictions.data();
for (int n = 0; n < N; n++) {
int index = layer.entry_index(n*w*h, 0);
activate_array(data + index, 2*w*h);
Expand All @@ -169,7 +232,7 @@ static void apply_yolo(yolo_layer & layer)

static box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
{
float * predictions = ggml_get_data_f32(layer.predictions);
const float * predictions = layer.predictions.data();
box b;
b.x = (i + predictions[index + 0*stride]) / lw;
b.y = (j + predictions[index + 1*stride]) / lh;
Expand Down Expand Up @@ -197,10 +260,10 @@ static void correct_yolo_box(box & b, int im_w, int im_h, int net_w, int net_h)

static void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)
{
int w = layer.predictions->ne[0];
int h = layer.predictions->ne[1];
int w = layer.w;
int h = layer.h;
int N = layer.mask.size();
float * predictions = ggml_get_data_f32(layer.predictions);
const float * predictions = layer.predictions.data();
std::vector<detection> result;
for (int i = 0; i < w*h; i++) {
for (int n = 0; n < N; n++) {
Expand Down Expand Up @@ -353,88 +416,92 @@ static void print_shape(int layer, const ggml_tensor * t)
printf("Layer %2d output shape: %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);
}

void detect(yolo_image & img, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
{
static size_t buf_size = 20000000 * sizeof(float) * 4;
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
std::vector<detection> detections;
static struct ggml_cgraph * build_graph(struct ggml_context * ctx_cgraph, const yolo_model & model) {
struct ggml_cgraph * gf = ggml_new_graph(ctx_cgraph);

yolo_image sized = letterbox_image(img, model.width, model.height);
struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 3, 1);
std::memcpy(input->data, sized.data.data(), ggml_nbytes(input));
struct ggml_tensor * input = ggml_new_tensor_4d(ctx_cgraph, GGML_TYPE_F32, model.width, model.height, 3, 1);
ggml_set_name(input, "input");

struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
struct ggml_tensor * result = apply_conv2d(ctx_cgraph, input, model.conv2d_layers[0]);
print_shape(0, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(1, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[1]);
print_shape(2, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(3, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[2]);
print_shape(4, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(5, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[3]);
print_shape(6, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(7, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[4]);
struct ggml_tensor * layer_8 = result;
print_shape(8, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(9, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[5]);
print_shape(10, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
print_shape(11, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[6]);
print_shape(12, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[7]);
struct ggml_tensor * layer_13 = result;
print_shape(13, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[8]);
print_shape(14, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[9]);
struct ggml_tensor * layer_15 = result;
ggml_set_output(layer_15);
ggml_set_name(layer_15, "layer_15");

print_shape(15, result);
result = apply_conv2d(ctx0, layer_13, model.conv2d_layers[10]);
result = apply_conv2d(ctx_cgraph, layer_13, model.conv2d_layers[10]);
print_shape(18, result);
result = ggml_upscale(ctx0, result, 2);
result = ggml_upscale(ctx_cgraph, result, 2);
print_shape(19, result);
result = ggml_concat(ctx0, result, layer_8, 2);
result = ggml_concat(ctx_cgraph, result, layer_8, 2);
print_shape(20, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[11]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[11]);
print_shape(21, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[12]);
result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[12]);
struct ggml_tensor * layer_22 = result;
ggml_set_output(layer_22);
ggml_set_name(layer_22, "layer_22");
print_shape(22, result);

ggml_build_forward_expand(gf, layer_15);
ggml_build_forward_expand(gf, layer_22);
ggml_graph_compute_with_ctx(ctx0, gf, 1);
return gf;
}

// Runs object detection on `img` using the pre-built graph `gf` and draws the
// resulting boxes onto `img`.
//
//   img      - image to analyse; detections are drawn onto it
//   gf       - graph containing tensors named "input", "layer_15" and "layer_22"
//   model    - supplies the network input size and the backend used for compute
//   thresh   - threshold forwarded to get_yolo_detections() and draw_detections()
//   labels   - class labels forwarded to draw_detections()
//   alphabet - glyph images forwarded to draw_detections()
//
// The image is letterboxed to model.width x model.height and uploaded into the
// "input" tensor, then the graph is computed on model.backend. If the compute
// fails, an error is printed to stderr and the function returns without drawing.
//
// The graph and the context it lives in are owned by the caller; this function
// does not free them.
void detect(yolo_image & img, struct ggml_cgraph * gf, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
{
    std::vector<detection> detections;
    yolo_image sized = letterbox_image(img, model.width, model.height);
    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
    ggml_backend_tensor_set(input, sized.data.data(), 0, ggml_nbytes(input));

    if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {
        fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__);
        return;
    }

    // first head: anchors selected by mask {3, 4, 5}
    struct ggml_tensor * layer_15 = ggml_graph_get_tensor(gf, "layer_15");
    yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_15};
    apply_yolo(yolo16);
    get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);

    // second head: anchors selected by mask {0, 1, 2}
    struct ggml_tensor * layer_22 = ggml_graph_get_tensor(gf, "layer_22");
    yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_22};
    apply_yolo(yolo23);
    get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);

    do_nms_sort(detections, yolo23.classes, .45);
    draw_detections(img, detections, thresh, labels, alphabet);
}

struct yolo_params {
Expand Down Expand Up @@ -512,14 +579,31 @@ int main(int argc, char *argv[])
fprintf(stderr, "%s: failed to load alphabet\n", __func__);
return 1;
}

struct ggml_init_params params0 = {
/*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
struct ggml_context * ctx_cgraph = ggml_init(params0);
struct ggml_cgraph * gf = build_graph(ctx_cgraph, model);

ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
ggml_gallocr_alloc_graph(allocr, gf);

const int64_t t_start_ms = ggml_time_ms();
detect(img, model, params.thresh, labels, alphabet);
detect(img, gf, model, params.thresh, labels, alphabet);
const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;
if (!save_image(img, params.fname_out.c_str(), 80)) {
fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str());
return 1;
}
printf("Detected objects saved in '%s' (time: %f sec.)\n", params.fname_out.c_str(), t_detect_ms / 1000.0f);

ggml_free(ctx_cgraph);
ggml_gallocr_free(allocr);
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);
ggml_backend_free(model.backend);
return 0;
}
4 changes: 4 additions & 0 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD:
ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_OP_ACC:
ggml_cuda_op_acc(ctx, dst);
break;
Expand Down Expand Up @@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
Expand Down
8 changes: 8 additions & 0 deletions src/ggml-cuda/binbcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
return a + b;
}

// Device-side binary operator: returns a - b.
// Passed as the template argument of bin_bcast_cuda by ggml_cuda_op_sub().
static __device__ __forceinline__ float op_sub(const float a, const float b) {
return a - b;
}

// Device-side binary operator: returns a * b.
// Passed as the template argument of bin_bcast_cuda by ggml_cuda_op_mul().
static __device__ __forceinline__ float op_mul(const float a, const float b) {
return a * b;
}
Expand Down Expand Up @@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

// CUDA implementation of the SUB op for `dst`.
// Forwards dst->src[0] and dst->src[1] (and their data pointers) to
// ggml_cuda_op_bin_bcast instantiated with op_sub; the result is written to
// dst->data and the work is issued on ctx.stream().
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

// CUDA implementation of the MUL op for `dst`.
// Forwards dst->src[0] and dst->src[1] (and their data pointers) to
// ggml_cuda_op_bin_bcast instantiated with op_mul; the result is written to
// dst->data and the work is issued on ctx.stream().
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
Expand Down
1 change: 1 addition & 0 deletions src/ggml-cuda/binbcast.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Loading

0 comments on commit 46e22f5

Please sign in to comment.