Skip to content

Commit

Permalink
Broadcast GGML_OP_ADD in the backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Aug 20, 2024
1 parent d8f7847 commit bada316
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 48 deletions.
50 changes: 7 additions & 43 deletions examples/mnist/mnist-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ mnist_model mnist_model_init_random(const std::string & arch) {
init_tensors.push_back(model.fc2_bias);
} else if (model.arch == "mnist-cnn") {
model.conv1_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB);
model.conv1_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, MNIST_HW, MNIST_HW, MNIST_CNN_NCB);
model.conv1_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB);
model.conv2_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2);
model.conv2_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, MNIST_HW/2, MNIST_HW/2, MNIST_CNN_NCB*2);
model.conv2_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB*2);
model.dense_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES);
model.dense_bias = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32, MNIST_NCLASSES);

Expand Down Expand Up @@ -320,25 +320,12 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_input(model.images);
ggml_set_name(model.images, "images");

ggml_tensor * fc1_bias = model.fc1_bias;
if (model.nbatch > 1) {
fc1_bias = ggml_repeat(model.ctx_compute,
model.fc1_bias,
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NHIDDEN, model.nbatch));
}
ggml_tensor * fc2_bias = model.fc2_bias;
if (model.nbatch > 1) {
fc2_bias = ggml_repeat(model.ctx_compute,
model.fc2_bias,
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch));
}

ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images),
fc1_bias));
model.fc1_bias));
model.logits = ggml_add(model.ctx_compute,
ggml_mul_mat(model.ctx_compute, model.fc2_weight, fc1),
fc2_bias);
model.fc2_bias);
} else if (model.arch == "mnist-cnn") {
ggml_set_param(model.ctx_compute, model.conv1_kernel);
ggml_set_param(model.ctx_compute, model.conv1_bias);
Expand All @@ -351,17 +338,9 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
ggml_set_input(model.images);
ggml_set_name(model.images, "images");

struct ggml_tensor * conv2d_1_bias = model.conv1_bias;
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, conv2d_1_bias->ne, sizeof(ne));
ne[3] = model.nbatch;
conv2d_1_bias = ggml_repeat(model.ctx_compute, conv2d_1_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}

struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_conv_2d(model.ctx_compute, model.conv1_kernel, model.images, 1, 1, 1, 1, 1, 1),
conv2d_1_bias));
model.conv1_bias));
GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[2] == MNIST_CNN_NCB);
Expand All @@ -373,17 +352,9 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
GGML_ASSERT(conv2_in->ne[2] == MNIST_CNN_NCB);
GGML_ASSERT(conv2_in->ne[3] == model.nbatch);

struct ggml_tensor * conv2d_2_bias = model.conv2_bias;
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, conv2d_2_bias->ne, sizeof(ne));
ne[3] = model.nbatch;
conv2d_2_bias = ggml_repeat(model.ctx_compute, conv2d_2_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}

struct ggml_tensor * conv2_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
ggml_conv_2d(model.ctx_compute, model.conv2_kernel, conv2_in, 1, 1, 1, 1, 1, 1),
conv2d_2_bias));
model.conv2_bias));
GGML_ASSERT(conv2_out->ne[0] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[1] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[2] == MNIST_CNN_NCB*2);
Expand All @@ -403,14 +374,7 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
GGML_ASSERT(dense_in->ne[2] == 1);
GGML_ASSERT(dense_in->ne[3] == 1);

struct ggml_tensor * dense_bias = model.dense_bias;
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, dense_bias->ne, sizeof(ne));
ne[1] = model.nbatch;
dense_bias = ggml_repeat(model.ctx_compute, dense_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}
model.logits = ggml_add(model.ctx_compute, ggml_mul_mat(model.ctx_compute, model.dense_weight, dense_in), dense_bias);
model.logits = ggml_add(model.ctx_compute, ggml_mul_mat(model.ctx_compute, model.dense_weight, dense_in), model.dense_bias);
} else {
GGML_ASSERT(false);
}
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist/mnist-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ struct mnist_model {
struct ggml_tensor * conv1_bias = nullptr;
struct ggml_tensor * conv2_kernel = nullptr;
struct ggml_tensor * conv2_bias = nullptr;
struct ggml_tensor * dense_weight = nullptr;
struct ggml_tensor * dense_bias = nullptr;
struct ggml_tensor * dense_weight = nullptr;
struct ggml_tensor * dense_bias = nullptr;

static const size_t size_weight = 100 * 1024*1024;
static const size_t size_compute = 1 * 1024*1024*1024;
Expand Down
8 changes: 5 additions & 3 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4498,8 +4498,6 @@ static struct ggml_tensor * ggml_add_impl(
bool is_node = false;

if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true;
}

Expand Down Expand Up @@ -17771,7 +17769,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
if (ggml_are_same_shape(src0, src1)) {
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
} else {
src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
}
}
} break;
case GGML_OP_ADD1:
Expand Down

0 comments on commit bada316

Please sign in to comment.