Add option to copy a layer's input or output tensor to branches.

- Add an enumeration that specifies whether a layer's tensor is
  copied to its branches before convolution, after convolution,
  or after branch combination.

Change-Id: I642dd86089484dc805481ecf79182c7a0a339364
diff --git a/av1/common/cnn.c b/av1/common/cnn.c
index 854c292..cb4314c 100644
--- a/av1/common/cnn.c
+++ b/av1/common/cnn.c
@@ -223,6 +223,23 @@
   }
 }
 
+static void copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
+                                           const CNN_LAYER_CONFIG *layer_config,
+                                           int branch, int in_width,
+                                           int in_height,
+                                           TENSOR branch_output[]) {
+  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
+    if ((layer_config->input_to_branches & (1 << b)) && b != branch) {
+      // Copy the layer's active tensor to the output tensor of branch b if b
+      // is set in the mask. Branch b's next layer is a non-first layer, so it
+      // swaps this output tensor in as its input.
+      realloc_tensor(&branch_output[b], layer_config->in_channels, in_width,
+                     in_height);
+      copy_tensor(layer_active_tensor, 0, &branch_output[b]);
+    }
+  }
+}
+
 void av1_cnn_convolve_c(const float **input, int in_width, int in_height,
                         int in_stride, const CNN_LAYER_CONFIG *layer_config,
                         float **output, int out_stride) {
@@ -620,34 +637,12 @@
                        cnn_config->layer_config[layer].out_channels, o_width,
                        o_height);
       }
-      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
-        if ((cnn_config->layer_config[layer].input_to_branches & (1 << b)) &&
-            b != branch) {
-          // Copy layer's input tensor to output tensor of branch b if set in
-          // mask
-          realloc_tensor(&tensor2[b],
-                         cnn_config->layer_config[layer].in_channels, in_width,
-                         in_height);
-          copy_tensor(&tensor1[branch], 0, &tensor2[b]);
-        }
-      }
     } else {  // Non-first layer
       // Swap tensor1 and tensor2
       swap_tensor(&tensor1[branch], &tensor2[branch]);
 
       i_width = o_width;
       i_height = o_height;
-      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
-        if ((cnn_config->layer_config[layer].input_to_branches & (1 << b)) &&
-            b != branch) {
-          // Copy layer's input tensor to output tensor of branch b if set in
-          // mask
-          realloc_tensor(&tensor2[b],
-                         cnn_config->layer_config[layer].in_channels, in_width,
-                         in_height);
-          copy_tensor(&tensor1[branch], 0, &tensor2[b]);
-        }
-      }
       find_layer_output_size(i_width, i_height,
                              &cnn_config->layer_config[layer], &o_width,
                              &o_height);
@@ -663,6 +658,7 @@
                       o_height, out_stride);
       }
     }
+
     // If we are combining branches make sure that the branch to combine
     // is different from the current branch.
     assert(IMPLIES(
@@ -670,6 +666,11 @@
         !(cnn_config->layer_config[layer].branches_to_combine &
           (1 << branch))));
 
+    if (cnn_config->layer_config[layer].branch_copy_mode == COPY_INPUT) {
+      copy_active_tensor_to_branches(&tensor1[branch],
+                                     &cnn_config->layer_config[layer], branch,
+                                     in_width, in_height, tensor2);
+    }
     // Check consistency of input and output channels
     assert(tensor1[branch].channels ==
            cnn_config->layer_config[layer].in_channels);
@@ -689,6 +690,13 @@
                          &cnn_config->layer_config[layer], tensor2[branch].buf,
                          tensor2[branch].stride);
     }
+
+    if (cnn_config->layer_config[layer].branch_copy_mode == COPY_OUTPUT) {
+      copy_active_tensor_to_branches(&tensor1[branch],
+                                     &cnn_config->layer_config[layer], branch,
+                                     in_width, in_height, tensor2);
+    }
+
     // Add tensors from other branches if needed
     if (cnn_config->layer_config[layer].branch_combine_type == BRANCH_ADD) {
       for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
@@ -719,6 +727,12 @@
         }
       }
     }
+
+    if (cnn_config->layer_config[layer].branch_copy_mode == COPY_COMBINED) {
+      copy_active_tensor_to_branches(&tensor1[branch],
+                                     &cnn_config->layer_config[layer], branch,
+                                     in_width, in_height, tensor2);
+    }
   }
 
   for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
diff --git a/av1/common/cnn.h b/av1/common/cnn.h
index 70fe270..6ebd5e4 100644
--- a/av1/common/cnn.h
+++ b/av1/common/cnn.h
@@ -37,6 +37,20 @@
 
 // enum { NONE, RELU, SOFTSIGN } UENUM1BYTE(ACTIVATION);
 
+// Points at which the active tensor is copied to branches in input_to_branches.
+// COPY_NONE: doesn't copy any tensor.
+// COPY_INPUT: copies the input tensor to branches.
+// COPY_OUTPUT: copies the convolved tensor to branches.
+// COPY_COMBINED: copies the combined (after convolving and branch combining)
+//   tensor. If no combinations happen at this layer, then this option
+//   has the same effect as COPY_OUTPUT.
+enum {
+  COPY_NONE,
+  COPY_INPUT,
+  COPY_OUTPUT,
+  COPY_COMBINED
+} UENUM1BYTE(COPY_TYPE);
+
 // Types of combining branches with output of current layer:
 // BRANCH_NOC: no branch combining
 // BRANCH_ADD: Add previously stored branch tensor to output of layer
@@ -65,7 +79,10 @@
   float *bias;     // array of length out_channels
   PADDING_TYPE pad;       // padding type
   ACTIVATION activation;  // the activation function to use after convolution
-  int input_to_branches;  // If nonzero, copy the input tensor to the current
+  COPY_TYPE branch_copy_mode;
+  // branch_copy_mode selects which tensor of this layer (input, convolved
+  // output, or combined output) is copied to the branches in input_to_branches.
+  int input_to_branches;  // If nonzero, copy the active tensor to the current
                           // layer and store for future use in branches
                           // specified in the field as a binary mask. For
                           // example, if input_to_branch = 0x06, it means the
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index d7eb116..e5811e1 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -203,6 +203,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -221,6 +222,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -239,6 +241,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -317,6 +320,7 @@
                                 .bias = bias,
                                 .pad = PADDING_SAME_ZERO,
                                 .activation = RELU,
+                                .branch_copy_mode = COPY_NONE,
                                 .input_to_branches = 0,
                                 .branch_combine_type = BRANCH_NOC,
                                 .branches_to_combine = 0,
@@ -364,6 +368,7 @@
                                     .bias = bias,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -476,6 +481,7 @@
                                     .bias = bias,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -553,6 +559,7 @@
                                 .bias = bias_1,
                                 .pad = PADDING_SAME_ZERO,
                                 .activation = NONE,
+                                .branch_copy_mode = COPY_NONE,
                                 .input_to_branches = 0,
                                 .branch_combine_type = BRANCH_NOC,
                                 .branches_to_combine = 0,
@@ -822,6 +829,7 @@
                                 .bias = bias_10x11,
                                 .pad = PADDING_SAME_ZERO,
                                 .activation = NONE,
+                                .branch_copy_mode = COPY_NONE,
                                 .input_to_branches = 0,
                                 .branch_combine_type = BRANCH_NOC,
                                 .branches_to_combine = 0,
@@ -970,6 +978,7 @@
                                 .bias = bias,
                                 .pad = PADDING_SAME_ZERO,
                                 .activation = SOFTSIGN,
+                                .branch_copy_mode = COPY_NONE,
                                 .input_to_branches = 0,
                                 .branch_combine_type = BRANCH_NOC,
                                 .branches_to_combine = 0,
@@ -1040,6 +1049,7 @@
                                   .bias = bias,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1058,6 +1068,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_INPUT,
                                   .input_to_branches = 0x02,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1076,6 +1087,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1094,6 +1106,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1112,6 +1125,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_ADD,
                                   .branches_to_combine = 0x02,
@@ -1130,6 +1144,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1194,6 +1209,7 @@
                                   .bias = bias,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1212,6 +1228,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_INPUT,
                                   .input_to_branches = 0x02,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1230,6 +1247,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1248,6 +1266,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1266,6 +1285,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_CAT,
                                   .branches_to_combine = 0x02,
@@ -1284,6 +1304,7 @@
                                   .bias = nullptr,
                                   .pad = PADDING_SAME_ZERO,
                                   .activation = NONE,
+                                  .branch_copy_mode = COPY_NONE,
                                   .input_to_branches = 0,
                                   .branch_combine_type = BRANCH_NOC,
                                   .branches_to_combine = 0,
@@ -1297,6 +1318,8 @@
              image_width, MSE_INT_TOL);
 }
 
+// TODO(logangw): Add test to test all combinations of branch_copy_mode.
+
 TEST_F(CNNTest, TestBranchCombinations) {
   int filter_width = 2;
   int filter_height = 3;
@@ -1305,31 +1328,31 @@
   int image_height = 4;
 
   float input[] = {
-    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
+    -1, -5, 1, 2, -2, 0, -1, 4, -3, -4, 0, -1, 1, 0, -4, 3,
   };
 
   float weights[] = {
-    0,  4,  -4, -1, -4, -3, -5, 4,  -2, -2, -1, 2,  2,  0,  -1, 0,  2,  -4,
-    3,  4,  2,  0,  -5, -4, -4, 4,  -3, -2, -5, 3,  -3, 0,  0,  1,  1,  -5,
-    -2, 4,  -5, 3,  -2, -2, -4, 4,  -2, -4, -3, 1,  3,  -4, 2,  -1, 0,  -1,
-    -2, -4, 4,  0,  0,  -3, -3, -3, -5, 2,  1,  4,  -1, 3,  -3, -2, 3,  -5,
-    2,  -1, -4, -4, 3,  1,  -4, -3, -4, 0,  -5, -2, 1,  0,  -1, -2, -2, 1,
-    4,  2,  2,  1,  -1, 4,  2,  0,  4,  -3, 1,  -1, 0,  -1, 1,  -5, -4, -5,
-    0,  -1, 2,  -5, -2, 1,  -3, -4, -2, 0,  -2, -5, -3, -4, -5, 1,  0,  3,
-    -4, 3,  -4, -2, -1, -1, 2,  3,  -3, -1, 0,  3,  0,  3,  0,  -2, 2,  2,
-    1,  -4, 1,  -4, 0,  -2, -1, 3,  -2, -1, 4,  4,  1,  -2, 4,  3,  1,  -1,
-    -5, -5, 0,  -4, 4,  4,  4,  1,  0,  -4, -1, 3,  -5, 0,  3,  -1, 2,  4,
-    -2, 4,  4,  1,  4,  3,  -4, 2,  0,  4,  2,  1,  -5, 2,  4,  -5, -2, 2,
-    -4, -1, -3, -3, 2,  -2, 1,  0,  -5, 3,  -3, 1,  3,  4,  -5, 2,  0,  4,
+    2,  4,  -2, -3, 2,  -1, 1,  -2, 4,  3,  -3, 4,  2,  4,  -5, -1, 0,  -2,
+    2,  -2, -2, 0,  -2, -1, -5, -3, -5, 3,  -1, 0,  -2, -5, 3,  2,  2,  4,
+    -2, -2, -4, -4, -3, 1,  2,  -1, -1, -2, -3, -3, -5, 1,  3,  -4, 4,  4,
+    0,  1,  -4, -5, 4,  -2, -1, -5, 3,  -1, -1, 2,  3,  -1, 0,  2,  -2, -1,
+    -3, 4,  -5, -3, 1,  0,  -3, 4,  4,  -1, 3,  1,  2,  -3, 0,  -4, -5, -4,
+    3,  -5, 4,  0,  4,  -2, 1,  -3, 4,  3,  3,  -1, -3, 2,  4,  3,  2,  -5,
+    1,  -2, 1,  -3, 1,  0,  1,  2,  0,  -2, -2, 1,  -1, 3,  1,  -2, 2,  2,
+    -5, -3, 0,  -3, 4,  -1, 1,  -3, -2, 3,  4,  -4, 1,  -3, -2, 3,  1,  -2,
+    -3, -1, 2,  -5, -2, -2, -4, 1,  2,  1,  2,  -2, -1, 0,  -5, 0,  0,  -2,
+    -5, -2, -5, 4,  3,  -1, -1, 0,  -5, -2, 1,  2,  4,  1,  0,  -3, -1, 2,
+    -4, 4,  4,  4,  3,  0,  3,  2,  2,  -4, 2,  0,  -1, -1, 4,  -1, -2, -5,
+    -3, -2, -2, 1,  -4, -3, 0,  4,  4,  -5, -2, 0,  1,  -4, -2, -1, -3, 0,
   };
 
   float bias[] = {
-    4, -4, 0, -1, 1, 1, 3, -4, -5, 3, 2, -3, -5, -3, -1, 3, 4, -5, -3,
+    0, 2, -3, 1, -3, -1, 3, 4, -2, -3, -2, 0, 4, -5, -3, -5, -4, 4, -5,
   };
 
   float expected[] = {
-    316397, 106874, 26726,  1971,  355397,  6848,   -20952, 18023,
-    -10736, -52466, -22737, -1496, -220644, -55007, 15175,  -7343,
+    -186549, 23839,   70503, 11452, -224977, -58609, 26347,  28867,
+    158422,  -192903, 35079, 6881,  189062,  -73870, -22263, -8869,
   };
 
   int channels = 2;
@@ -1354,6 +1377,7 @@
                                     .bias = bias,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -1372,6 +1396,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_INPUT,
                                     .input_to_branches = 0x06,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -1390,24 +1415,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
-                                    .input_to_branches = 0,
-                                    .branch_combine_type = BRANCH_NOC,
-                                    .branches_to_combine = 0,
-                                },
-                                {
-                                    .branch = 2,
-                                    .deconvolve = 0,
-                                    .in_channels = channels,
-                                    .filter_width = filter_width,
-                                    .filter_height = filter_height,
-                                    .out_channels = channels,
-                                    .skip_width = 1,
-                                    .skip_height = 1,
-                                    .maxpool = 0,
-                                    .weights = nullptr,
-                                    .bias = nullptr,
-                                    .pad = PADDING_SAME_ZERO,
-                                    .activation = NONE,
+                                    .branch_copy_mode = COPY_OUTPUT,
                                     .input_to_branches = 0x08,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -1426,6 +1434,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -1444,11 +1453,31 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
-                                    .input_to_branches = 0,
+                                    .branch_copy_mode = COPY_NONE,
+                                    .input_to_branches = 0x00,
                                     .branch_combine_type = BRANCH_ADD,
                                     .branches_to_combine = 0x08,
                                 },
                                 {
+                                    .branch = 2,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0x00,
+                                },
+                                {
                                     .branch = 1,
                                     .deconvolve = 0,
                                     .in_channels = channels,
@@ -1462,6 +1491,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,
@@ -1480,6 +1510,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_ADD,
                                     .branches_to_combine = 0x0C,
@@ -1498,6 +1529,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_ADD,
                                     .branches_to_combine = 0x02,
@@ -1516,6 +1548,7 @@
                                     .bias = nullptr,
                                     .pad = PADDING_SAME_ZERO,
                                     .activation = NONE,
+                                    .branch_copy_mode = COPY_NONE,
                                     .input_to_branches = 0,
                                     .branch_combine_type = BRANCH_NOC,
                                     .branches_to_combine = 0,