Update and add tests for branching CNNs.

- Update old addition and concatenation test to use
  new framework.
- Add new branching test that merges and splits various branches.

Change-Id: Idbb384f5e063415db90f7bebe85023ef6d209ef0
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index 52034a6..d7eb116 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -989,41 +989,39 @@
              image_width, MSE_FLOAT_TOL);
 }
 
-TEST_F(CNNTest, TestSkipTensorAdd) {
-  int filter_width = 3;
+TEST_F(CNNTest, TestBranchTensorAdd) {
+  int filter_width = 2;
   int filter_height = 3;
 
-  int image_width = 6;
-  int image_height = 6;
+  int image_width = 4;
+  int image_height = 4;
 
   float input[] = {
-    -2, -2, -3, 2,  0,  -1, -1, 0, 1,  -2, 1, -1, -2, 0, 1, 2,  -2, 2,
-    -3, 2,  -1, -3, -3, 0,  0,  2, -2, 0,  2, -1, -1, 0, 2, -1, -2, -2,
+    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
   };
 
   float weights[] = {
-    1,  -1, -1, -3, -3, -1, 1,  -3, 1,  -2, 0,  2,  -1, -1, -2, 0,  2,  0,
-    -3, -1, -1, -3, -1, 2,  -2, -1, 1,  0,  -3, 1,  -2, 0,  -2, -2, -3, -2,
-    -2, -1, -1, 2,  -2, 2,  -1, -2, -1, -1, 2,  0,  -2, 0,  -2, -2, 1,  -2,
-    2,  -3, -1, -2, 0,  -2, -1, 1,  2,  0,  0,  -1, 1,  0,  0,  -3, 2,  0,
-    2,  0,  0,  -3, 2,  -2, -1, -3, 0,  2,  2,  0,  1,  -3, -2, -2, 1,  -1,
-    2,  -2, -2, 1,  0,  -1, 0,  0,  1,  0,  0,  -1, 2,  -3, -3, 0,  0,  -2,
+    -3, -1, 4,  -1, -3, 3,  3,  0,  2,  0,  3,  2,  4,  4, 4,  -5, 1, -4,
+    2,  -4, 1,  -3, 0,  4,  -5, 4,  0,  -4, -3, -1, 0,  0, -2, 0,  0, 2,
+    -5, -1, 1,  -3, 3,  4,  3,  0,  1,  -1, 1,  1,  2,  4, -2, -5, 2, -2,
+    3,  -2, 4,  -1, 0,  2,  3,  2,  -2, -1, -3, 1,  3,  4, -1, -3, 0, -4,
+    4,  2,  -3, -3, -1, 0,  1,  0,  3,  3,  -3, 0,  3,  2, -5, -3, 4, -5,
+    3,  -1, -1, -3, 0,  1,  -1, -4, 2,  4,  -1, 4,  -1, 1, 3,  4,  4, 4,
+    0,  -1, -3, -3, -3, -3, 2,  -3, -2, 2,  3,  -3,
   };
 
   float bias[] = {
-    1, -1, 1, -1, -3, -1, 1,
+    3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
   };
 
   float expected[] = {
-    -2971, -3879, -4326, -7119, -835,  -521,  -1401, -4918, -9286,
-    -6980, -1696, -611,  -157,  -5482, -9842, -7264, -1820, -3158,
-    489,   -6367, -7028, -4763, -1629, -1987, 1095,  -3985, -4128,
-    -2136, -2574, -4071, 880,   -2678, -1462, -1052, -377,  -1370,
+    -11502, -4101, -3424, 668,   -17950, -5470, -5504, 626,
+    4835,   446,   1779,  -3483, 3679,   -4214, 4578,  -105,
   };
 
   int channels = 2;
 
-  CNN_CONFIG cnn_config = { .num_layers = 4,
+  CNN_CONFIG cnn_config = { .num_layers = 6,
                             .is_residue = 0,
                             .ext_width = 0,
                             .ext_height = 0,
@@ -1065,6 +1063,42 @@
                                   .branches_to_combine = 0,
                               },
                               {
+                                  .branch = 1,
+                                  .deconvolve = 0,
+                                  .in_channels = channels,
+                                  .filter_width = filter_width,
+                                  .filter_height = filter_height,
+                                  .out_channels = channels,
+                                  .skip_width = 1,
+                                  .skip_height = 1,
+                                  .maxpool = 0,
+                                  .weights = nullptr,
+                                  .bias = nullptr,
+                                  .pad = PADDING_SAME_ZERO,
+                                  .activation = NONE,
+                                  .input_to_branches = 0,
+                                  .branch_combine_type = BRANCH_NOC,
+                                  .branches_to_combine = 0,
+                              },
+                              {
+                                  .branch = 1,
+                                  .deconvolve = 0,
+                                  .in_channels = channels,
+                                  .filter_width = filter_width,
+                                  .filter_height = filter_height,
+                                  .out_channels = channels,
+                                  .skip_width = 1,
+                                  .skip_height = 1,
+                                  .maxpool = 0,
+                                  .weights = nullptr,
+                                  .bias = nullptr,
+                                  .pad = PADDING_SAME_ZERO,
+                                  .activation = NONE,
+                                  .input_to_branches = 0,
+                                  .branch_combine_type = BRANCH_NOC,
+                                  .branches_to_combine = 0,
+                              },
+                              {
                                   .branch = 0,
                                   .deconvolve = 0,
                                   .in_channels = channels,
@@ -1109,42 +1143,39 @@
              image_width, MSE_INT_TOL);
 }
 
-TEST_F(CNNTest, TestSkipTensorConcatenation) {
-  int filter_width = 3;
+TEST_F(CNNTest, TestBranchTensorConcatenation) {
+  int filter_width = 2;
   int filter_height = 3;
 
-  int image_width = 6;
-  int image_height = 6;
+  int image_width = 4;
+  int image_height = 4;
 
   float input[] = {
-    0, 1,  -1, 0,  -2, -3, 0,  2, 1,  1,  -2, -2, 1,  1,  0,  1,  -1, 0,
-    2, -3, 2,  -2, 0,  -2, -3, 2, -3, -1, 1,  2,  -2, -3, -3, -2, 1,  -2,
+    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
   };
 
   float weights[] = {
-    -2, 0,  2,  -2, 1,  1,  -3, -1, -1, -1, 1,  1,  -1, -3, -1, 0,  0,  0,
-    0,  0,  2,  2,  0,  0,  -1, -2, -1, 1,  1,  1,  2,  1,  -1, 2,  -2, 2,
-    -1, -3, -2, -2, 0,  -2, 1,  1,  0,  0,  0,  -1, 1,  0,  2,  -3, 1,  -1,
-    -1, -1, 1,  -2, -2, -2, -3, 2,  2,  1,  -1, -3, 0,  -2, 1,  -2, -2, -3,
-    2,  1,  -3, -2, -2, 1,  1,  2,  2,  -1, -1, -1, 2,  2,  -2, -2, 2,  1,
-    1,  -2, 2,  2,  2,  -1, 2,  2,  2,  0,  2,  -1, 0,  -2, -3, 2,  -3, 1,
-    1,  -3, 0,  2,  2,  2,  -1, -1, -2, -2, -3, -2, 2,  1,  -3, -3, 1,  2,
+    3,  0,  2,  0,  2,  3,  1,  -3, 1,  -5, -3, 0,  -4, 4,  0,  -5, 0,  -5, -1,
+    -2, -5, 0,  -3, 2,  -4, 2,  0,  2,  -1, 0,  -4, 3,  0,  0,  -1, -5, 2,  -1,
+    4,  -4, -2, -3, -3, 3,  4,  -2, -1, -4, -1, 4,  4,  -1, 4,  3,  -4, 2,  -2,
+    -4, -3, -2, 3,  -3, -5, -1, 3,  -2, 4,  1,  -4, -3, -5, -5, -3, 4,  -2, -2,
+    -1, -5, -5, 0,  -1, -2, -3, 3,  -4, -5, 2,  -3, 1,  0,  -5, 2,  2,  -2, 0,
+    2,  2,  -2, 4,  2,  2,  0,  1,  -5, -3, 0,  2,  -2, 1,  2,  -5, 2,  3,  3,
+    -1, 3,  0,  -3, 3,  -4, -4, 3,  3,  -4, -2, 2,  -2, 2,  -2, -1, 3,  0,
   };
 
   float bias[] = {
-    -2, 0, 1, 1, -1, 0, 2,
+    -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
   };
 
   float expected[] = {
-    4032,  3724,  -21,   -2743, -1811, 684,   5068,  9168,  7067,
-    1217,  -4681, -1320, 2373,  6161,  7928,  6649,  -1183, -1995,
-    -1976, -2717, 2881,  5825,  2324,  -187,  -5585, -4958, -4390,
-    1053,  2734,  987,   881,   -409,  -1173, -1899, 340,   -1015,
+    -33533, -32087, -6741,  -2124, 39979, 41453, 14034, 689,
+    -22611, -42203, -14882, -239,  15781, 15963, 9524,  837,
   };
 
   int channels = 2;
 
-  CNN_CONFIG cnn_config = { .num_layers = 4,
+  CNN_CONFIG cnn_config = { .num_layers = 6,
                             .is_residue = 0,
                             .ext_width = 0,
                             .ext_height = 0,
@@ -1186,6 +1217,42 @@
                                   .branches_to_combine = 0,
                               },
                               {
+                                  .branch = 1,
+                                  .deconvolve = 0,
+                                  .in_channels = channels,
+                                  .filter_width = filter_width,
+                                  .filter_height = filter_height,
+                                  .out_channels = channels,
+                                  .skip_width = 1,
+                                  .skip_height = 1,
+                                  .maxpool = 0,
+                                  .weights = nullptr,
+                                  .bias = nullptr,
+                                  .pad = PADDING_SAME_ZERO,
+                                  .activation = NONE,
+                                  .input_to_branches = 0,
+                                  .branch_combine_type = BRANCH_NOC,
+                                  .branches_to_combine = 0,
+                              },
+                              {
+                                  .branch = 1,
+                                  .deconvolve = 0,
+                                  .in_channels = channels,
+                                  .filter_width = filter_width,
+                                  .filter_height = filter_height,
+                                  .out_channels = channels,
+                                  .skip_width = 1,
+                                  .skip_height = 1,
+                                  .maxpool = 0,
+                                  .weights = nullptr,
+                                  .bias = nullptr,
+                                  .pad = PADDING_SAME_ZERO,
+                                  .activation = NONE,
+                                  .input_to_branches = 0,
+                                  .branch_combine_type = BRANCH_NOC,
+                                  .branches_to_combine = 0,
+                              },
+                              {
                                   .branch = 0,
                                   .deconvolve = 0,
                                   .in_channels = channels,
@@ -1229,3 +1296,236 @@
   RunCNNTest(image_width, image_height, input, expected, cnn_config,
              image_width, MSE_INT_TOL);
 }
+
+TEST_F(CNNTest, TestBranchCombinations) {
+  int filter_width = 2;
+  int filter_height = 3;
+
+  int image_width = 4;
+  int image_height = 4;
+
+  float input[] = {
+    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
+  };
+
+  float weights[] = {
+    0,  4,  -4, -1, -4, -3, -5, 4,  -2, -2, -1, 2,  2,  0,  -1, 0,  2,  -4,
+    3,  4,  2,  0,  -5, -4, -4, 4,  -3, -2, -5, 3,  -3, 0,  0,  1,  1,  -5,
+    -2, 4,  -5, 3,  -2, -2, -4, 4,  -2, -4, -3, 1,  3,  -4, 2,  -1, 0,  -1,
+    -2, -4, 4,  0,  0,  -3, -3, -3, -5, 2,  1,  4,  -1, 3,  -3, -2, 3,  -5,
+    2,  -1, -4, -4, 3,  1,  -4, -3, -4, 0,  -5, -2, 1,  0,  -1, -2, -2, 1,
+    4,  2,  2,  1,  -1, 4,  2,  0,  4,  -3, 1,  -1, 0,  -1, 1,  -5, -4, -5,
+    0,  -1, 2,  -5, -2, 1,  -3, -4, -2, 0,  -2, -5, -3, -4, -5, 1,  0,  3,
+    -4, 3,  -4, -2, -1, -1, 2,  3,  -3, -1, 0,  3,  0,  3,  0,  -2, 2,  2,
+    1,  -4, 1,  -4, 0,  -2, -1, 3,  -2, -1, 4,  4,  1,  -2, 4,  3,  1,  -1,
+    -5, -5, 0,  -4, 4,  4,  4,  1,  0,  -4, -1, 3,  -5, 0,  3,  -1, 2,  4,
+    -2, 4,  4,  1,  4,  3,  -4, 2,  0,  4,  2,  1,  -5, 2,  4,  -5, -2, 2,
+    -4, -1, -3, -3, 2,  -2, 1,  0,  -5, 3,  -3, 1,  3,  4,  -5, 2,  0,  4,
+  };
+
+  float bias[] = {
+    4, -4, 0, -1, 1, 1, 3, -4, -5, 3, 2, -3, -5, -3, -1, 3, 4, -5, -3,
+  };
+
+  float expected[] = {
+    316397, 106874, 26726,  1971,  355397,  6848,   -20952, 18023,
+    -10736, -52466, -22737, -1496, -220644, -55007, 15175,  -7343,
+  };
+
+  int channels = 2;
+
+  CNN_CONFIG cnn_config = { .num_layers = 10,
+                            .is_residue = 0,
+                            .ext_width = 0,
+                            .ext_height = 0,
+                            .strict_bounds = 0,
+                            {
+                                {
+                                    .branch = 0,
+                                    .deconvolve = 0,
+                                    .in_channels = 1,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = weights,
+                                    .bias = bias,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 0,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0x06,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 2,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 2,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0x08,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 3,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 2,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_ADD,
+                                    .branches_to_combine = 0x08,
+                                },
+                                {
+                                    .branch = 1,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                                {
+                                    .branch = 1,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_ADD,
+                                    .branches_to_combine = 0x0C,
+                                },
+                                {
+                                    .branch = 0,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = channels,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_ADD,
+                                    .branches_to_combine = 0x02,
+                                },
+                                {
+                                    .branch = 0,
+                                    .deconvolve = 0,
+                                    .in_channels = channels,
+                                    .filter_width = filter_width,
+                                    .filter_height = filter_height,
+                                    .out_channels = 1,
+                                    .skip_width = 1,
+                                    .skip_height = 1,
+                                    .maxpool = 0,
+                                    .weights = nullptr,
+                                    .bias = nullptr,
+                                    .pad = PADDING_SAME_ZERO,
+                                    .activation = NONE,
+                                    .input_to_branches = 0,
+                                    .branch_combine_type = BRANCH_NOC,
+                                    .branches_to_combine = 0,
+                                },
+                            } };
+
+  // Weights and biases need to be specified separately because
+  // of the offset.
+  AssignLayerWeightsBiases(&cnn_config, weights, bias);
+
+  RunCNNTest(image_width, image_height, input, expected, cnn_config,
+             image_width, MSE_INT_TOL);
+}