Fix bugs in copy_active_tensor_to_branches()

CNN lib bug fixes: copy_tensor() now takes an explicit copy_channels
count asserted against the source; copy_active_tensor_to_branches()
sizes the branch output from the active tensor's own width/height/
channels instead of the stale in_width/in_height, and the COPY_OUTPUT
and COPY_COMBINED cases now copy the post-layer tensor (tensor2)
rather than the layer input (tensor1). Test vectors updated to match.
Change-Id: I58562e85fdaa2841b065f525bad0724a695f8656
diff --git a/av1/common/cnn.c b/av1/common/cnn.c
index 95c5917..cb83161 100644
--- a/av1/common/cnn.c
+++ b/av1/common/cnn.c
@@ -60,15 +60,17 @@
tensor->buf[c] = &tensor->buf[0][c * width * height];
}
-static void copy_tensor(const TENSOR *src, int dst_offset, TENSOR *dst) {
+static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
+ TENSOR *dst) {
assert(src->width == dst->width);
assert(src->height == dst->height);
+ assert(copy_channels <= src->channels);
if (src->stride == dst->width && dst->stride == dst->width) {
memcpy(dst->buf[dst_offset], src->buf[0],
sizeof(*dst->buf[dst_offset]) * src->width * src->height *
- src->channels);
+ copy_channels);
} else {
- for (int c = 0; c < src->channels; ++c) {
+ for (int c = 0; c < copy_channels; ++c) {
for (int r = 0; r < dst->height; ++r) {
memcpy(&dst->buf[dst_offset + c][r * dst->stride],
&src->buf[c][r * src->stride],
@@ -112,7 +114,7 @@
init_tensor(&t);
// allocate new buffers and copy first the dst channels
realloc_tensor(&t, channels, dst->width, dst->height);
- copy_tensor(dst, 0, &t);
+ copy_tensor(dst, dst->channels, 0, &t);
// Swap the tensors and free the old buffers
swap_tensor(dst, &t);
free_tensor(&t);
@@ -120,7 +122,7 @@
for (int c = 1; c < channels; ++c)
dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
// Copy the channels in src after the first dst_channels channels.
- copy_tensor(src, dst_channels, dst);
+ copy_tensor(src, src->channels, dst_channels, dst);
}
int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
@@ -225,17 +227,16 @@
static void copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
const CNN_LAYER_CONFIG *layer_config,
- int branch, int in_width,
- int in_height,
- TENSOR branch_output[]) {
+ int branch, TENSOR branch_output[]) {
for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
if ((layer_config->input_to_branches & (1 << b)) && b != branch) {
// Copy layer's active tensor to output tensor of branch b if set in
// mask. The output becomes the input of the first layer of the branch
// because the layer of the branch is not the first layer.
- realloc_tensor(&branch_output[b], layer_config->in_channels, in_width,
- in_height);
- copy_tensor(layer_active_tensor, 0, &branch_output[b]);
+ int copy_channels = layer_active_tensor->channels;
+ realloc_tensor(&branch_output[b], copy_channels,
+ layer_active_tensor->width, layer_active_tensor->height);
+ copy_tensor(layer_active_tensor, copy_channels, 0, &branch_output[b]);
}
}
}
@@ -667,9 +668,8 @@
(1 << branch))));
if (cnn_config->layer_config[layer].branch_copy_mode == COPY_INPUT) {
- copy_active_tensor_to_branches(&tensor1[branch],
- &cnn_config->layer_config[layer], branch,
- in_width, in_height, tensor2);
+ copy_active_tensor_to_branches(
+ &tensor1[branch], &cnn_config->layer_config[layer], branch, tensor2);
}
// Check consistency of input and output channels
assert(tensor1[branch].channels ==
@@ -692,9 +692,8 @@
}
if (cnn_config->layer_config[layer].branch_copy_mode == COPY_OUTPUT) {
- copy_active_tensor_to_branches(&tensor1[branch],
- &cnn_config->layer_config[layer], branch,
- in_width, in_height, tensor2);
+ copy_active_tensor_to_branches(
+ &tensor2[branch], &cnn_config->layer_config[layer], branch, tensor2);
}
// Add tensors from other branches if needed
@@ -729,9 +728,8 @@
}
if (cnn_config->layer_config[layer].branch_copy_mode == COPY_COMBINED) {
- copy_active_tensor_to_branches(&tensor1[branch],
- &cnn_config->layer_config[layer], branch,
- in_width, in_height, tensor2);
+ copy_active_tensor_to_branches(
+ &tensor2[branch], &cnn_config->layer_config[layer], branch, tensor2);
}
}
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index e5811e1..85fc4de 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -1328,31 +1328,31 @@
int image_height = 4;
float input[] = {
- -1, -5, 1, 2, -2, 0, -1, 4, -3, -4, 0, -1, 1, 0, -4, 3,
+ 3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
};
float weights[] = {
- 2, 4, -2, -3, 2, -1, 1, -2, 4, 3, -3, 4, 2, 4, -5, -1, 0, -2,
- 2, -2, -2, 0, -2, -1, -5, -3, -5, 3, -1, 0, -2, -5, 3, 2, 2, 4,
- -2, -2, -4, -4, -3, 1, 2, -1, -1, -2, -3, -3, -5, 1, 3, -4, 4, 4,
- 0, 1, -4, -5, 4, -2, -1, -5, 3, -1, -1, 2, 3, -1, 0, 2, -2, -1,
- -3, 4, -5, -3, 1, 0, -3, 4, 4, -1, 3, 1, 2, -3, 0, -4, -5, -4,
- 3, -5, 4, 0, 4, -2, 1, -3, 4, 3, 3, -1, -3, 2, 4, 3, 2, -5,
- 1, -2, 1, -3, 1, 0, 1, 2, 0, -2, -2, 1, -1, 3, 1, -2, 2, 2,
- -5, -3, 0, -3, 4, -1, 1, -3, -2, 3, 4, -4, 1, -3, -2, 3, 1, -2,
- -3, -1, 2, -5, -2, -2, -4, 1, 2, 1, 2, -2, -1, 0, -5, 0, 0, -2,
- -5, -2, -5, 4, 3, -1, -1, 0, -5, -2, 1, 2, 4, 1, 0, -3, -1, 2,
- -4, 4, 4, 4, 3, 0, 3, 2, 2, -4, 2, 0, -1, -1, 4, -1, -2, -5,
- -3, -2, -2, 1, -4, -3, 0, 4, 4, -5, -2, 0, 1, -4, -2, -1, -3, 0,
+ 2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
+ 2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
+ 4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
+ -1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
+ -2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
+ -2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
+ -5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
+ -5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
+ -2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
+ -1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
+ 3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
+ -1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
};
float bias[] = {
- 0, 2, -3, 1, -3, -1, 3, 4, -2, -3, -2, 0, 4, -5, -3, -5, -4, 4, -5,
+ -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
};
float expected[] = {
- -186549, 23839, 70503, 11452, -224977, -58609, 26347, 28867,
- 158422, -192903, 35079, 6881, 189062, -73870, -22263, -8869,
+ 149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
+ -53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
};
int channels = 2;