Add ability to copy the input or output tensor of a layer.
- Add an enumeration to specify whether a layer's tensor is used
as input to a branch before convolution, after convolution,
or after branch combination.
Change-Id: I642dd86089484dc805481ecf79182c7a0a339364
diff --git a/av1/common/cnn.c b/av1/common/cnn.c
index 854c292..cb4314c 100644
--- a/av1/common/cnn.c
+++ b/av1/common/cnn.c
@@ -223,6 +223,23 @@
}
}
+static void copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
+ const CNN_LAYER_CONFIG *layer_config,
+ int branch, int in_width,
+ int in_height,
+ TENSOR branch_output[]) {
+ for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
+ if ((layer_config->input_to_branches & (1 << b)) && b != branch) {
+ // Copy the layer's active tensor to the output tensor of branch b if b
+ // is set in the mask. This output becomes the input of branch b's first
+ // layer, since that layer is not the first layer of the network.
+ realloc_tensor(&branch_output[b], layer_config->in_channels, in_width,
+ in_height);
+ copy_tensor(layer_active_tensor, 0, &branch_output[b]);
+ }
+ }
+}
+
void av1_cnn_convolve_c(const float **input, int in_width, int in_height,
int in_stride, const CNN_LAYER_CONFIG *layer_config,
float **output, int out_stride) {
@@ -620,34 +637,12 @@
cnn_config->layer_config[layer].out_channels, o_width,
o_height);
}
- for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
- if ((cnn_config->layer_config[layer].input_to_branches & (1 << b)) &&
- b != branch) {
- // Copy layer's input tensor to output tensor of branch b if set in
- // mask
- realloc_tensor(&tensor2[b],
- cnn_config->layer_config[layer].in_channels, in_width,
- in_height);
- copy_tensor(&tensor1[branch], 0, &tensor2[b]);
- }
- }
} else { // Non-first layer
// Swap tensor1 and tensor2
swap_tensor(&tensor1[branch], &tensor2[branch]);
i_width = o_width;
i_height = o_height;
- for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
- if ((cnn_config->layer_config[layer].input_to_branches & (1 << b)) &&
- b != branch) {
- // Copy layer's input tensor to output tensor of branch b if set in
- // mask
- realloc_tensor(&tensor2[b],
- cnn_config->layer_config[layer].in_channels, in_width,
- in_height);
- copy_tensor(&tensor1[branch], 0, &tensor2[b]);
- }
- }
find_layer_output_size(i_width, i_height,
&cnn_config->layer_config[layer], &o_width,
&o_height);
@@ -663,6 +658,7 @@
o_height, out_stride);
}
}
+
// If we are combining branches make sure that the branch to combine
// is different from the current branch.
assert(IMPLIES(
@@ -670,6 +666,11 @@
!(cnn_config->layer_config[layer].branches_to_combine &
(1 << branch))));
+ if (cnn_config->layer_config[layer].branch_copy_mode == COPY_INPUT) {
+ copy_active_tensor_to_branches(&tensor1[branch],
+ &cnn_config->layer_config[layer], branch,
+ in_width, in_height, tensor2);
+ }
// Check consistency of input and output channels
assert(tensor1[branch].channels ==
cnn_config->layer_config[layer].in_channels);
@@ -689,6 +690,13 @@
&cnn_config->layer_config[layer], tensor2[branch].buf,
tensor2[branch].stride);
}
+
+ if (cnn_config->layer_config[layer].branch_copy_mode == COPY_OUTPUT) {
+ copy_active_tensor_to_branches(&tensor1[branch],
+ &cnn_config->layer_config[layer], branch,
+ in_width, in_height, tensor2);
+ }
+
// Add tensors from other branches if needed
if (cnn_config->layer_config[layer].branch_combine_type == BRANCH_ADD) {
for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
@@ -719,6 +727,12 @@
}
}
}
+
+ if (cnn_config->layer_config[layer].branch_copy_mode == COPY_COMBINED) {
+ copy_active_tensor_to_branches(&tensor1[branch],
+ &cnn_config->layer_config[layer], branch,
+ in_width, in_height, tensor2);
+ }
}
for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
diff --git a/av1/common/cnn.h b/av1/common/cnn.h
index 70fe270..6ebd5e4 100644
--- a/av1/common/cnn.h
+++ b/av1/common/cnn.h
@@ -37,6 +37,20 @@
// enum { NONE, RELU, SOFTSIGN } UENUM1BYTE(ACTIVATION);
+// Times when the layer's active tensor may be copied to branches given in input_to_branches.
+// COPY_NONE: doesn't copy any tensor.
+// COPY_INPUT: copies the input tensor to branches.
+// COPY_OUTPUT: copies the convolved tensor to branches.
+// COPY_COMBINED: copies the combined (after convolving and branch combining)
+// tensor. If no combinations happen at this layer, then this option
+// has the same effect as COPY_OUTPUT.
+enum {
+ COPY_NONE,
+ COPY_INPUT,
+ COPY_OUTPUT,
+ COPY_COMBINED
+} UENUM1BYTE(COPY_TYPE);
+
// Types of combining branches with output of current layer:
// BRANCH_NOC: no branch combining
// BRANCH_ADD: Add previously stored branch tensor to output of layer
@@ -65,7 +79,10 @@
float *bias; // array of length out_channels
PADDING_TYPE pad; // padding type
ACTIVATION activation; // the activation function to use after convolution
- int input_to_branches; // If nonzero, copy the input tensor to the current
+ COPY_TYPE branch_copy_mode;
+ // When, within the layer, a copy of the active tensor is sent as input
+ // to the branches given in input_to_branches.
+ int input_to_branches; // If nonzero, copy the active tensor to the current
// layer and store for future use in branches
// specified in the field as a binary mask. For
// example, if input_to_branch = 0x06, it means the
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index d7eb116..e5811e1 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -203,6 +203,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -221,6 +222,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -239,6 +241,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -317,6 +320,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = RELU,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -364,6 +368,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -476,6 +481,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -553,6 +559,7 @@
.bias = bias_1,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -822,6 +829,7 @@
.bias = bias_10x11,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -970,6 +978,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = SOFTSIGN,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1040,6 +1049,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1058,6 +1068,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_INPUT,
.input_to_branches = 0x02,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1076,6 +1087,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1094,6 +1106,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1112,6 +1125,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_ADD,
.branches_to_combine = 0x02,
@@ -1130,6 +1144,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1194,6 +1209,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1212,6 +1228,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_INPUT,
.input_to_branches = 0x02,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1230,6 +1247,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1248,6 +1266,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1266,6 +1285,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_CAT,
.branches_to_combine = 0x02,
@@ -1284,6 +1304,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1297,6 +1318,8 @@
image_width, MSE_INT_TOL);
}
+// TODO(logangw): Add tests covering all combinations of branch_copy_mode.
+
TEST_F(CNNTest, TestBranchCombinations) {
int filter_width = 2;
int filter_height = 3;
@@ -1305,31 +1328,31 @@
int image_height = 4;
float input[] = {
- -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
+ -1, -5, 1, 2, -2, 0, -1, 4, -3, -4, 0, -1, 1, 0, -4, 3,
};
float weights[] = {
- 0, 4, -4, -1, -4, -3, -5, 4, -2, -2, -1, 2, 2, 0, -1, 0, 2, -4,
- 3, 4, 2, 0, -5, -4, -4, 4, -3, -2, -5, 3, -3, 0, 0, 1, 1, -5,
- -2, 4, -5, 3, -2, -2, -4, 4, -2, -4, -3, 1, 3, -4, 2, -1, 0, -1,
- -2, -4, 4, 0, 0, -3, -3, -3, -5, 2, 1, 4, -1, 3, -3, -2, 3, -5,
- 2, -1, -4, -4, 3, 1, -4, -3, -4, 0, -5, -2, 1, 0, -1, -2, -2, 1,
- 4, 2, 2, 1, -1, 4, 2, 0, 4, -3, 1, -1, 0, -1, 1, -5, -4, -5,
- 0, -1, 2, -5, -2, 1, -3, -4, -2, 0, -2, -5, -3, -4, -5, 1, 0, 3,
- -4, 3, -4, -2, -1, -1, 2, 3, -3, -1, 0, 3, 0, 3, 0, -2, 2, 2,
- 1, -4, 1, -4, 0, -2, -1, 3, -2, -1, 4, 4, 1, -2, 4, 3, 1, -1,
- -5, -5, 0, -4, 4, 4, 4, 1, 0, -4, -1, 3, -5, 0, 3, -1, 2, 4,
- -2, 4, 4, 1, 4, 3, -4, 2, 0, 4, 2, 1, -5, 2, 4, -5, -2, 2,
- -4, -1, -3, -3, 2, -2, 1, 0, -5, 3, -3, 1, 3, 4, -5, 2, 0, 4,
+ 2, 4, -2, -3, 2, -1, 1, -2, 4, 3, -3, 4, 2, 4, -5, -1, 0, -2,
+ 2, -2, -2, 0, -2, -1, -5, -3, -5, 3, -1, 0, -2, -5, 3, 2, 2, 4,
+ -2, -2, -4, -4, -3, 1, 2, -1, -1, -2, -3, -3, -5, 1, 3, -4, 4, 4,
+ 0, 1, -4, -5, 4, -2, -1, -5, 3, -1, -1, 2, 3, -1, 0, 2, -2, -1,
+ -3, 4, -5, -3, 1, 0, -3, 4, 4, -1, 3, 1, 2, -3, 0, -4, -5, -4,
+ 3, -5, 4, 0, 4, -2, 1, -3, 4, 3, 3, -1, -3, 2, 4, 3, 2, -5,
+ 1, -2, 1, -3, 1, 0, 1, 2, 0, -2, -2, 1, -1, 3, 1, -2, 2, 2,
+ -5, -3, 0, -3, 4, -1, 1, -3, -2, 3, 4, -4, 1, -3, -2, 3, 1, -2,
+ -3, -1, 2, -5, -2, -2, -4, 1, 2, 1, 2, -2, -1, 0, -5, 0, 0, -2,
+ -5, -2, -5, 4, 3, -1, -1, 0, -5, -2, 1, 2, 4, 1, 0, -3, -1, 2,
+ -4, 4, 4, 4, 3, 0, 3, 2, 2, -4, 2, 0, -1, -1, 4, -1, -2, -5,
+ -3, -2, -2, 1, -4, -3, 0, 4, 4, -5, -2, 0, 1, -4, -2, -1, -3, 0,
};
float bias[] = {
- 4, -4, 0, -1, 1, 1, 3, -4, -5, 3, 2, -3, -5, -3, -1, 3, 4, -5, -3,
+ 0, 2, -3, 1, -3, -1, 3, 4, -2, -3, -2, 0, 4, -5, -3, -5, -4, 4, -5,
};
float expected[] = {
- 316397, 106874, 26726, 1971, 355397, 6848, -20952, 18023,
- -10736, -52466, -22737, -1496, -220644, -55007, 15175, -7343,
+ -186549, 23839, 70503, 11452, -224977, -58609, 26347, 28867,
+ 158422, -192903, 35079, 6881, 189062, -73870, -22263, -8869,
};
int channels = 2;
@@ -1354,6 +1377,7 @@
.bias = bias,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1372,6 +1396,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_INPUT,
.input_to_branches = 0x06,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1390,24 +1415,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
- .input_to_branches = 0,
- .branch_combine_type = BRANCH_NOC,
- .branches_to_combine = 0,
- },
- {
- .branch = 2,
- .deconvolve = 0,
- .in_channels = channels,
- .filter_width = filter_width,
- .filter_height = filter_height,
- .out_channels = channels,
- .skip_width = 1,
- .skip_height = 1,
- .maxpool = 0,
- .weights = nullptr,
- .bias = nullptr,
- .pad = PADDING_SAME_ZERO,
- .activation = NONE,
+ .branch_copy_mode = COPY_OUTPUT,
.input_to_branches = 0x08,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1426,6 +1434,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1444,11 +1453,31 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
- .input_to_branches = 0,
+ .branch_copy_mode = COPY_NONE,
+ .input_to_branches = 0x00,
.branch_combine_type = BRANCH_ADD,
.branches_to_combine = 0x08,
},
{
+ .branch = 2,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .branch_copy_mode = COPY_NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0x00,
+ },
+ {
.branch = 1,
.deconvolve = 0,
.in_channels = channels,
@@ -1462,6 +1491,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,
@@ -1480,6 +1510,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_ADD,
.branches_to_combine = 0x0C,
@@ -1498,6 +1529,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_ADD,
.branches_to_combine = 0x02,
@@ -1516,6 +1548,7 @@
.bias = nullptr,
.pad = PADDING_SAME_ZERO,
.activation = NONE,
+ .branch_copy_mode = COPY_NONE,
.input_to_branches = 0,
.branch_combine_type = BRANCH_NOC,
.branches_to_combine = 0,