Add support for multiple CNN outputs

Intra-frame partitioning uses a multi-resolution approach, so the CNN
model needs to output one segmentation map for each block size (bsize)
from BLOCK_64X64 down to BLOCK_8X8.

Change-Id: I6bdf28ef5613741917ac39a4e71b4d7b93035085
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index caacb14..45cc476 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -46,6 +46,8 @@
 typedef struct CNN_THREAD_DATA CNN_THREAD_DATA;
 struct CNN_BRANCH_CONFIG;
 typedef struct CNN_BRANCH_CONFIG CNN_BRANCH_CONFIG;
+struct CNN_MULTI_OUT;
+typedef struct CNN_MULTI_OUT CNN_MULTI_OUT;
 
 /* Function pointers return by CfL functions */
 typedef void (*cfl_subsample_lbd_fn)(const uint8_t *input, int input_stride,
@@ -329,7 +331,7 @@
 
 add_proto qw/void av1_cnn_activate/, " float **input, int channels, int width, int height, int stride, ACTIVATION layer_activation";
 add_proto qw/void av1_cnn_add/, " float **input, int channels, int width, int height, int stride, const float **add";
-add_proto qw/void av1_cnn_predict/, " const float **input, int in_width, int in_height, int in_stride, const CNN_CONFIG *cnn_config, const CNN_THREAD_DATA *thread_data, float **output, int out_stride";
+add_proto qw/void av1_cnn_predict/, " const float **input, int in_width, int in_height, int in_stride, const CNN_CONFIG *cnn_config, const CNN_THREAD_DATA *thread_data, CNN_MULTI_OUT *output_struct";
 add_proto qw/void av1_cnn_convolve/, " const float **input, int in_width, int in_height, int in_stride, const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride, int start_idx, int step";
 add_proto qw/void av1_cnn_deconvolve/, " const float **input, int in_width, int in_height, int in_stride, const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride";
 add_proto qw/void av1_cnn_batchnorm/, "float **image, int channels, int width, int height, int stride, const float *gamma, const float *beta, const float *mean, const float *std";
diff --git a/av1/encoder/cnn.c b/av1/encoder/cnn.c
index bbd37de..da33837 100644
--- a/av1/encoder/cnn.c
+++ b/av1/encoder/cnn.c
@@ -217,24 +217,72 @@
   }
 }
 
+#if CONFIG_DEBUG
+static INLINE int cnn_has_at_least_one_output(const CNN_CONFIG *cnn_config) {
+  const int num_layers = cnn_config->num_layers;
+  const CNN_LAYER_CONFIG *layer_configs = cnn_config->layer_config;
+
+  for (int idx = 0; idx < num_layers; idx++) {
+    if (layer_configs[idx].output_num != -1) {
+      return 1;
+    }
+  }
+  return 0;
+}
+#endif
+
 void av1_find_cnn_output_size(int in_width, int in_height,
                               const CNN_CONFIG *cnn_config, int *out_width,
                               int *out_height, int *out_channels) {
-  int i_width = in_width + cnn_config->ext_width * 2;
-  int i_height = in_height + cnn_config->ext_height * 2;
   int channels_per_branch[CNN_MAX_BRANCHES] = { 0 };
-  for (int i = 0; i < cnn_config->num_layers; ++i) {
-    int o_width = 0, o_height = 0;
-    find_layer_output_size(i_width, i_height, &cnn_config->layer_config[i],
-                           &o_width, &o_height);
-    i_width = o_width;
-    i_height = o_height;
+  int i_width[CNN_MAX_BRANCHES] = { 0 };
+  int i_height[CNN_MAX_BRANCHES] = { 0 };
+  i_width[0] = in_width + cnn_config->ext_width * 2;
+  i_height[0] = in_height + cnn_config->ext_height * 2;
 
-    find_cnn_out_channels(&cnn_config->layer_config[i], channels_per_branch);
+#if CONFIG_DEBUG
+  assert(cnn_has_at_least_one_output(cnn_config));
+#endif
+
+  for (int i = 0; i < cnn_config->num_layers; ++i) {
+    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[i];
+    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
+    const int branch = layer_config->branch;
+    int o_width = 0, o_height = 0;
+
+    if (layer_config->branch_copy_type == BRANCH_INPUT) {
+      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
+        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
+          assert(i_width[branch] > 0 && i_height[branch] > 0);
+          i_width[b] = i_width[branch];
+          i_height[b] = i_height[branch];
+        }
+      }
+    }
+
+    find_layer_output_size(i_width[branch], i_height[branch], layer_config,
+                           &o_width, &o_height);
+    i_width[branch] = o_width;
+    i_height[branch] = o_height;
+
+    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
+      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
+        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
+          i_width[b] = o_width;
+          i_height[b] = o_height;
+        }
+      }
+    }
+
+    find_cnn_out_channels(layer_config, channels_per_branch);
+
+    const int output_num = layer_config->output_num;
+    if (output_num != -1) {  // Current layer is an output layer
+      out_width[output_num] = o_width;
+      out_height[output_num] = o_height;
+      out_channels[output_num] = channels_per_branch[layer_config->branch];
+    }
   }
-  *out_width = i_width;
-  *out_height = i_height;
-  *out_channels = channels_per_branch[0];
 }
 
 activation_fn get_activation(ACTIVATION layer_activation) {
@@ -780,11 +828,18 @@
 
 void av1_cnn_predict_c(const float **input, int in_width, int in_height,
                        int in_stride, const CNN_CONFIG *cnn_config,
-                       const CNN_THREAD_DATA *thread_data, float **output,
-                       int out_stride) {
+                       const CNN_THREAD_DATA *thread_data,
+                       CNN_MULTI_OUT *output_struct) {
   TENSOR tensor1[CNN_MAX_BRANCHES] = { 0 };
   TENSOR tensor2[CNN_MAX_BRANCHES] = { 0 };
 
+  float **output[CNN_MAX_BRANCHES];
+  const int *out_chs = output_struct->output_channels;
+  output[0] = output_struct->output_buffer;
+  for (int out_idx = 1; out_idx < output_struct->num_outputs; out_idx++) {
+    output[out_idx] = output[out_idx - 1] + out_chs[out_idx - 1];
+  }
+
   int i_width = in_width;
   int i_height = in_height;
   int o_width = 0, o_height = 0;
@@ -793,6 +848,7 @@
     init_tensor(&tensor2[b]);
   }
 
+  const int *out_stride = output_struct->output_strides;
   for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
     const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
     const int branch = layer_config->branch;
@@ -807,21 +863,22 @@
       // Swap tensor1 and tensor2
       swap_tensor(&tensor1[branch], &tensor2[branch]);
 
-      i_width = o_width;
-      i_height = o_height;
+      i_width = tensor1[branch].width;
+      i_height = tensor1[branch].height;
     }
 
     // Allocate output tensor
     find_layer_output_size(i_width, i_height, layer_config, &o_width,
                            &o_height);
-    if (layer < cnn_config->num_layers - 1) {  // Non-last layer
+    const int output_num = layer_config->output_num;
+    if (output_num == -1) {  // Non-output layer
       realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
                      o_height);
-    } else {                // Last layer
-      assert(branch == 0);  // Last layer must be primary branch
+    } else {  // Output layer
       free_tensor(&tensor2[branch]);
-      assign_tensor(&tensor2[branch], output, layer_config->out_channels,
-                    o_width, o_height, out_stride);
+      assign_tensor(&tensor2[branch], output[output_num],
+                    layer_config->out_channels, o_width, o_height,
+                    out_stride[output_num]);
     }
 
     // If we are combining branches make sure that the branch to combine
@@ -890,7 +947,7 @@
 
     // Concatenate tensors
     if (layer_config->branch_combine_type == BRANCH_CAT) {
-      if (layer < cnn_config->num_layers - 1) {  // Non-last layer
+      if (output_num == -1) {  // Non-output layer
         for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
           if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
             assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
@@ -898,17 +955,27 @@
             concat_tensor(&tensor2[b], &tensor2[branch]);
           }
         }
-      } else {  // Last layer
+      } else {  // Output layer
+        const int existing_channels = tensor2[branch].channels;
+        int num_chs = existing_channels;
         for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
           if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
             assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
-            const int existing_channels = tensor2[branch].channels;
             // Needed only to assign the new channel buffers
-            assign_tensor(&tensor2[branch], output,
-                          existing_channels + tensor2[b].channels, o_width,
-                          o_height, out_stride);
-            copy_tensor(&tensor2[b], tensor2[b].channels, existing_channels,
+            num_chs += tensor2[b].channels;
+          }
+        }
+        assign_tensor(&tensor2[branch], output[output_num], num_chs, o_width,
+                      o_height, out_stride[output_num]);
+
+        num_chs = existing_channels;
+        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
+          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
+            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
+            // num_chs = channels gathered so far (presumably the dst offset)
+            copy_tensor(&tensor2[b], tensor2[b].channels, num_chs,
                         &tensor2[branch]);
+            num_chs += tensor2[b].channels;
           }
         }
       }
@@ -928,20 +995,15 @@
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
-                         const CNN_CONFIG *cnn_config,
-                         const CNN_THREAD_DATA *thread_data, float **output,
-                         int out_stride) {
+void av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
+                                   int stride, const CNN_CONFIG *cnn_config,
+                                   const CNN_THREAD_DATA *thread_data,
+                                   CNN_MULTI_OUT *output) {
   const float max_val = 255.0;
-  int out_width = 0;
-  int out_height = 0;
-  int out_channels = 0;
-  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
-                           &out_channels);
 
-  int in_width = width + 2 * cnn_config->ext_width;
-  int in_height = height + 2 * cnn_config->ext_height;
-  int in_channels = cnn_config->layer_config[0].in_channels;
+  const int in_width = width + 2 * cnn_config->ext_width;
+  const int in_height = height + 2 * cnn_config->ext_height;
+  const int in_channels = cnn_config->layer_config[0].in_channels;
   float *inputs[CNN_MAX_CHANNELS];
   float *input_ =
       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
@@ -980,27 +1042,24 @@
     }
   }
   av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
-                  cnn_config, thread_data, output, out_stride);
+                  cnn_config, thread_data, output);
 
   aom_free(input_);
 }
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
-                                int stride, const CNN_CONFIG *cnn_config,
-                                const CNN_THREAD_DATA *thread_data,
-                                int bit_depth, float **output, int out_stride) {
+void av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
+                                          int stride,
+                                          const CNN_CONFIG *cnn_config,
+                                          const CNN_THREAD_DATA *thread_data,
+                                          int bit_depth,
+                                          CNN_MULTI_OUT *output) {
   const float max_val = (float)((1 << bit_depth) - 1);
-  int out_width = 0;
-  int out_height = 0;
-  int out_channels = 0;
-  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
-                           &out_channels);
 
-  int in_width = width + 2 * cnn_config->ext_width;
-  int in_height = height + 2 * cnn_config->ext_height;
-  int in_channels = cnn_config->layer_config[0].in_channels;
+  const int in_width = width + 2 * cnn_config->ext_width;
+  const int in_height = height + 2 * cnn_config->ext_height;
+  const int in_channels = cnn_config->layer_config[0].in_channels;
   float *inputs[CNN_MAX_CHANNELS];
   float *input_ =
       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
@@ -1038,7 +1097,45 @@
           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
     }
   }
-  av1_cnn_predict((const float **)inputs, width, height, in_stride, cnn_config,
-                  thread_data, output, out_stride);
+
+  av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
+                  cnn_config, thread_data, output);
+
   aom_free(input_);
 }
+
+// Assume output already has proper allocation
+// Assume input image buffers all have same resolution and strides
+void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
+                         const CNN_CONFIG *cnn_config,
+                         const CNN_THREAD_DATA *thread_data, float **output,
+                         int out_stride) {
+  int out_width = 0, out_height = 0, out_channels = 0;
+  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
+                           &out_channels);
+  const int output_chs[1] = { out_channels };
+  const int output_strides[1] = { out_stride };
+  CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
+                                  .output_strides = output_strides,
+                                  .output_buffer = output };
+  av1_cnn_predict_img_multi_out(dgd, width, height, stride, cnn_config,
+                                thread_data, &output_struct);
+}
+
+// Assume output already has proper allocation
+// Assume input image buffers all have same resolution and strides
+void av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
+                                int stride, const CNN_CONFIG *cnn_config,
+                                const CNN_THREAD_DATA *thread_data,
+                                int bit_depth, float **output, int out_stride) {
+  int out_width = 0, out_height = 0, out_channels = 0;
+  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height,
+                           &out_channels);
+  const int output_chs[1] = { out_channels };
+  const int output_strides[1] = { out_stride };
+  CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
+                                  .output_strides = output_strides,
+                                  .output_buffer = output };
+  av1_cnn_predict_img_multi_out_highbd(dgd, width, height, stride, cnn_config,
+                                       thread_data, bit_depth, &output_struct);
+}
diff --git a/av1/encoder/cnn.h b/av1/encoder/cnn.h
index dc800a0..351b103 100644
--- a/av1/encoder/cnn.h
+++ b/av1/encoder/cnn.h
@@ -107,7 +107,7 @@
   int maxpool;            // whether to use maxpool or not (only effective when
                           // skip width or skip_height are > 1)
   const float *weights;   // array of length filter_height x filter_width x
-                          // in_channels // x out_channels where the inner-most
+                          // in_channels x out_channels where the inner-most
                           // scan is out_channels and the outer most scan is
                           // filter_height.
   const float *bias;      // array of length out_channels
@@ -124,8 +124,14 @@
   BRANCH_COMBINE branch_combine_type;
   struct CNN_BRANCH_CONFIG branch_config;
   struct CNN_BATCHNORM_PARAMS
-      bn_params;  // A struct that contains the parameters
-                  // used for batch normalization.
+      bn_params;   // A struct that contains the parameters
+                   // used for batch normalization.
+  int output_num;  // The output buffer idx to which the layer output is
+                   // written. Set to -1 to disable writing it to the output. In
+                   // the case that branch_combine_type is BRANCH_CAT, all
+                   // concatenated channels will be written to output. In the
+                   // case of BRANCH_ADD, the output will be the result of
+                   // summation.
 };
 
 struct CNN_CONFIG {
@@ -144,12 +150,27 @@
   AVxWorker *workers;
 };
 
+struct CNN_MULTI_OUT {
+  int num_outputs;
+  const int *output_channels;
+  const int *output_strides;
+  float **output_buffer;
+};
+
 // Function to return size of output
 void av1_find_cnn_output_size(int in_width, int in_height,
                               const CNN_CONFIG *cnn_config, int *out_width,
                               int *out_height, int *out_channels);
 
-// Prediction functions from set of input image buffers
+// Prediction function from a set of input image buffers. This function
+// supports CNNs with multiple outputs.
+void av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
+                                   int stride, const CNN_CONFIG *cnn_config,
+                                   const CNN_THREAD_DATA *thread_data,
+                                   struct CNN_MULTI_OUT *output);
+
+// Prediction function from a set of input image buffers. This function only
+// supports a single output.
 void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
                          const CNN_CONFIG *cnn_config,
                          const CNN_THREAD_DATA *thread_data, float **output,
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index 5a884aa..4410493 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -32,9 +32,10 @@
 
 class CNNTest : public ::testing::Test {
  protected:
-  static void RunCNNTest(int image_width, int image_height, float *input,
-                         float *expected, CNN_CONFIG *cnn_config, int in_stride,
-                         CNN_THREAD_DATA *thread_data, double tolerance) {
+  static void RunCNNTest(int image_width, int image_height, const float *input,
+                         const float *expected, const CNN_CONFIG *cnn_config,
+                         int in_stride, CNN_THREAD_DATA *thread_data,
+                         double tolerance) {
     int out_width, out_height, out_channels;
     av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
                              &out_height, &out_channels);
@@ -48,26 +49,68 @@
     for (int channel = 0; channel < out_channels; ++channel) {
       output[channel] = output_ + (channel * out_size);
     }
+    const int num_outputs = 1;
+    const int output_chs[1] = { out_channels };
+    const int output_strides[1] = { out_stride };
+    CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
+                                    output };
 
-    av1_cnn_predict((const float **)&input, image_width, image_height,
-                    in_stride, cnn_config, thread_data, output, out_stride);
-
-    double mse = 0;
-    for (int channel = 0; channel < out_channels; ++channel) {
-      for (int i = 0; i < out_size; ++i) {
-        int index = channel * out_size + i;
-        EXPECT_NEAR(expected[index], output[channel][i], PIXELWISE_FLOAT_TOL)
-            << index << ": " << expected[index] << "/" << output[channel][i]
-            << std::endl;
-        mse += SQR(expected[index] - output[channel][i]);
-      }
-    }
-    mse /= (out_size * out_channels);
-    EXPECT_LE(mse, tolerance);
+    RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
+                       thread_data, &output_struct, &expected, tolerance);
 
     aom_free(output_);
   }
 
+  static void RunMultiOutCNNTest(const float **input, int image_width,
+                                 int image_height, int in_stride,
+                                 const CNN_CONFIG *cnn_config,
+                                 CNN_THREAD_DATA *thread_data,
+                                 CNN_MULTI_OUT *output, const float **expected,
+                                 double tolerance) {
+    const int num_outputs = output->num_outputs;
+    const int *output_chs = output->output_channels;
+
+    int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
+    int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
+    int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
+
+    av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
+                             out_heights, not_used);
+    av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
+                    thread_data, output);
+
+    int channel_offset = 0;
+    for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+      const float *expected_out = expected[output_idx];
+      const int curr_output_chs = output_chs[output_idx];
+      const int out_size = out_widths[output_idx] * out_heights[output_idx];
+
+      double mse = 0;
+      int expected_ite = 0;
+      for (int channel = 0; channel < curr_output_chs; ++channel) {
+        const float *buf_out = output->output_buffer[channel_offset];
+
+        for (int i = 0; i < out_size; ++i) {
+          EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
+                      PIXELWISE_FLOAT_TOL)
+              << " output " << output_idx << " channel " << channel << " pixel "
+              << expected_ite % out_size << ": " << expected_out[expected_ite]
+              << "/" << buf_out[i] << std::endl;
+          mse += SQR(expected_out[expected_ite] - buf_out[i]);
+          expected_ite++;
+        }
+
+        channel_offset++;
+      }
+      mse /= (out_size * curr_output_chs);
+      EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
+    }
+
+    aom_free(out_widths);
+    aom_free(out_heights);
+    aom_free(not_used);
+  }
+
   static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
                                        float *bias) {
     size_t weight_offset = 0;
@@ -219,6 +262,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     3,
@@ -238,6 +282,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     3,
@@ -257,6 +302,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    0,
                                 },
                             } };
 
@@ -338,6 +384,7 @@
                                 BRANCH_NOC,
                                 {},
                                 {},
+                                0,
                             } } };
 
   CNN_THREAD_DATA thread_data = { 1, NULL };
@@ -388,6 +435,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    0,
                                 },
                             } };
 
@@ -502,6 +550,7 @@
                                 BRANCH_NOC,
                                 {},
                                 {},
+                                0,
                             } } };
 
   CNN_THREAD_DATA thread_data = { 1, NULL };
@@ -581,6 +630,7 @@
                                 BRANCH_NOC,
                                 {},
                                 {},
+                                0,
                             } } };
 
   CNN_THREAD_DATA thread_data = { 1, NULL };
@@ -853,6 +903,7 @@
                                 BRANCH_NOC,
                                 {},
                                 {},
+                                0,
                             } } };
 
   int image_height = 10;
@@ -1004,6 +1055,7 @@
                                 BRANCH_NOC,
                                 {},
                                 {},
+                                0,
                             } } };
 
   CNN_THREAD_DATA thread_data = { 1, NULL };
@@ -1077,6 +1129,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1100,6 +1153,7 @@
                                       0x00,
                                   },
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1119,6 +1173,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1138,6 +1193,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1161,6 +1217,7 @@
                                       0x02,
                                   },
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1180,6 +1237,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  0,
                               } } };
 
   // Weights and biases need to be specified separately because
@@ -1247,6 +1305,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1270,6 +1329,7 @@
                                       0x00,
                                   },
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1289,6 +1349,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1308,6 +1369,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  -1,
                               },
                               {
                                   channels,
@@ -1331,6 +1393,7 @@
                                       0x02,
                                   },
                                   {},
+                                  -1,
                               },
                               {
                                   channels + channels,
@@ -1350,6 +1413,7 @@
                                   BRANCH_NOC,
                                   {},
                                   {},
+                                  0,
                               } } };
 
   // Weights and biases need to be specified separately because
@@ -1425,6 +1489,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1448,6 +1513,7 @@
                                         0x00,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1471,6 +1537,7 @@
                                         0x00,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1490,6 +1557,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1513,6 +1581,7 @@
                                         0x08,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1532,6 +1601,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1551,6 +1621,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1574,6 +1645,7 @@
                                         0x0C,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1597,6 +1669,7 @@
                                         0x02,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     channels,
@@ -1616,6 +1689,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    0,
                                 },
                             } };
 
@@ -1686,6 +1760,7 @@
                                         0x00,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     4,
@@ -1709,6 +1784,7 @@
                                         0x02,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     4,
@@ -1728,6 +1804,7 @@
                                     BRANCH_NOC,
                                     {},
                                     {},
+                                    0,
                                 },
                             } };
 
@@ -1787,6 +1864,7 @@
                                         0x00,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     1,
@@ -1810,6 +1888,7 @@
                                         0x03,
                                     },
                                     {},
+                                    -1,
                                 },
                                 {
                                     2,
@@ -1833,6 +1912,7 @@
                                         0x04,
                                     },
                                     {},
+                                    0,
                                 },
                             } };
 
@@ -2065,6 +2145,7 @@
             BRANCH_NOC,
             {},
             bn_params,
+            0,
         },
     },
   };
@@ -2129,6 +2210,7 @@
             BRANCH_NOC,
             {},
             {},
+            0,
         },
     },
   };
@@ -2154,3 +2236,261 @@
     winterface->end(&workers[i]);
   }
 }
+
+// Builds a 4-layer CNN whose layers carry output_num 0..3 and passes the four
+// outputs plus their expected values to RunMultiOutCNNTest. The last layer is
+// on branch 1 with BRANCH_CAT; its output is declared with 2 * num_filters.
+TEST_F(CNNTest, TestMultiOutput) {
+  const int image_dim = 8;
+  const int image_ch = 3;
+  const int filter_dim = 2;
+  const int stride = 2;
+  const int num_filters = 2;
+
+  const float input_[] = {
+    1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
+    0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
+    -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
+    0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
+    0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
+    -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
+    -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
+    -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
+    -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
+    -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
+    -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
+    0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
+    0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
+    -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
+    0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
+    0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
+    -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
+    -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
+    -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
+    0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
+    0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
+    -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
+    0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
+    0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
+    1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
+    -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
+    0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
+    -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
+    0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
+    0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
+    -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
+    0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
+    -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
+    1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
+    -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
+    -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
+    -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
+    0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
+    0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
+    -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
+    1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
+    0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
+    0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
+    -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
+    1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
+    -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
+    1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
+    1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
+  };
+  const float *input[3] = { input_, &input_[image_dim * image_dim],
+                            &input_[2 * image_dim * image_dim] };
+
+  const float bias[] = { 0.0f, 0.0f };
+
+  const float weights_1[] = {
+    -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
+    0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
+    -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
+    0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
+    -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
+    -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
+  };
+
+  const float weights_2[] = {
+    0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
+    0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
+    -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
+    0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
+  };
+
+  const float weights_3[] = {
+    -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
+    0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
+    0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
+    0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
+  };
+
+  const float weights_4[] = {
+    0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
+    0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
+    0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
+    -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
+    -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
+    0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
+    0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
+    -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
+    0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
+    0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
+    -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
+    -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
+    -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
+    0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
+    -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
+    0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
+  };
+
+  const float expected_0[] = {
+    -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
+    -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
+    -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
+    0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
+    -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
+    -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
+    -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
+    -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
+  };
+
+  const float expected_1[] = {
+    0.0f, 0.0f,           0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
+    0.0f, 1.22013465602f,
+  };
+
+  const float expected_2[] = { 0.156119444687f, 0.517385299817f };
+
+  const float expected_3[] = {
+    0.224177852984f, 0.503384419034f, 0.156119444687f, 0.517385299817f,
+  };
+
+  const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
+
+  CNN_CONFIG cnn_config = {
+    4,  // num_layers
+    0,  // is_residue
+    0,  // ext_width
+    0,  // ext_height
+    0,  // strict_bounds
+    {
+        // layer_config
+        {
+            image_ch,           // in_channels
+            filter_dim,         // filter_width
+            filter_dim,         // filter_height
+            num_filters,        // out_channels
+            stride,             // skip_width
+            stride,             // skip_height
+            0,                  // max_pool
+            weights_1,          // weights
+            bias,               // bias
+            PADDING_SAME_ZERO,  // pad
+            NONE,               // activation
+            0,                  // deconvolve
+            0,                  // branch
+            BRANCH_OUTPUT,      // branch_copy_type
+            BRANCH_NOC,         // branch_combine_type
+            { 2, 0, 0 },        // branch_config
+            {},                 // bn_params
+            0,                  // output_num
+        },
+        {
+            num_filters,        // in_channels
+            filter_dim,         // filter_width
+            filter_dim,         // filter_height
+            num_filters,        // out_channels
+            stride,             // skip_width
+            stride,             // skip_height
+            0,                  // max_pool
+            weights_2,          // weights
+            bias,               // bias
+            PADDING_SAME_ZERO,  // pad
+            RELU,               // activation
+            0,                  // deconvolve
+            0,                  // branch
+            BRANCH_NO_COPY,     // branch_copy_type
+            BRANCH_NOC,         // branch_combine_type
+            {},                 // branch_config
+            {},                 // bn_params
+            1,                  // output_num
+        },
+        {
+            num_filters,        // in_channels
+            filter_dim,         // filter_width
+            filter_dim,         // filter_height
+            num_filters,        // out_channels
+            stride,             // skip_width
+            stride,             // skip_height
+            0,                  // max_pool
+            weights_3,          // weights
+            bias,               // bias
+            PADDING_SAME_ZERO,  // pad
+            RELU,               // activation
+            0,                  // deconvolve
+            0,                  // branch
+            BRANCH_NO_COPY,     // branch_copy_type
+            BRANCH_NOC,         // branch_combine_type
+            {},                 // branch_config
+            {},                 // bn_params
+            2,                  // output_num
+        },
+        {
+            num_filters,     // in_channels
+            2 * filter_dim,  // filter_width
+            2 * filter_dim,  // filter_height
+            num_filters,     // out_channels
+            2 * stride,      // skip_width
+            2 * stride,      // skip_height
+            0,               // max_pool
+            weights_4,       // weights
+            bias,            // bias
+            PADDING_VALID,   // pad
+            RELU,            // activation
+            0,               // deconvolve
+            1,               // branch
+            BRANCH_NO_COPY,  // branch_copy_type
+            BRANCH_CAT,      // branch_combine_type
+            { 0, 0, 1 },     // branch_config
+            {},              // bn_params
+            3,               // output_num
+        },
+    },
+  };
+
+  CNN_THREAD_DATA thread_data = { 1, NULL };
+
+  const int num_outputs = 4;
+  const int output_chs[4] = { num_filters, num_filters, num_filters,
+                              2 * num_filters };
+  const int output_dims[4] = { 4, 2, 1, 1 };  // square: width == height
+  const int output_sizes[4] = {
+    output_chs[0] * output_dims[0] * output_dims[0],
+    output_chs[1] * output_dims[1] * output_dims[1],
+    output_chs[2] * output_dims[2] * output_dims[2],
+    output_chs[3] * output_dims[3] * output_dims[3],
+  };
+  float *const output_ = (float *)aom_malloc(
+      sizeof(*output_) *
+      (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
+  ASSERT_NE(output_, nullptr);
+  // Carve the single allocation into per-channel planes, laid out output by
+  // output, so that output[] lists one plane pointer per channel.
+  float *output[CNN_MAX_CHANNELS] = { nullptr };
+  int ch_ite = 0;
+  float *output_ite = output_;
+  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+    for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
+      output[ch_ite++] = output_ite;
+      output_ite += output_dims[output_idx] * output_dims[output_idx];
+    }
+  }
+  CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
+                                  output };
+
+  RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
+                     &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
+
+  aom_free(output_);
+}