Update and add tests for branching CNNs.
- Update old addition and concatenation test to use
new framework.
- Add new branching test that merges and splits various branches.
Change-Id: Idbb384f5e063415db90f7bebe85023ef6d209ef0
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index 52034a6..d7eb116 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -989,41 +989,39 @@
image_width, MSE_FLOAT_TOL);
}
-TEST_F(CNNTest, TestSkipTensorAdd) {
- int filter_width = 3;
+TEST_F(CNNTest, TestBranchTensorAdd) {
+ int filter_width = 2;
int filter_height = 3;
- int image_width = 6;
- int image_height = 6;
+ int image_width = 4;
+ int image_height = 4;
float input[] = {
- -2, -2, -3, 2, 0, -1, -1, 0, 1, -2, 1, -1, -2, 0, 1, 2, -2, 2,
- -3, 2, -1, -3, -3, 0, 0, 2, -2, 0, 2, -1, -1, 0, 2, -1, -2, -2,
+ -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
};
float weights[] = {
- 1, -1, -1, -3, -3, -1, 1, -3, 1, -2, 0, 2, -1, -1, -2, 0, 2, 0,
- -3, -1, -1, -3, -1, 2, -2, -1, 1, 0, -3, 1, -2, 0, -2, -2, -3, -2,
- -2, -1, -1, 2, -2, 2, -1, -2, -1, -1, 2, 0, -2, 0, -2, -2, 1, -2,
- 2, -3, -1, -2, 0, -2, -1, 1, 2, 0, 0, -1, 1, 0, 0, -3, 2, 0,
- 2, 0, 0, -3, 2, -2, -1, -3, 0, 2, 2, 0, 1, -3, -2, -2, 1, -1,
- 2, -2, -2, 1, 0, -1, 0, 0, 1, 0, 0, -1, 2, -3, -3, 0, 0, -2,
+ -3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
+ 2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
+ -5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
+ 3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
+ 4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
+ 3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
+ 0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
};
float bias[] = {
- 1, -1, 1, -1, -3, -1, 1,
+ 3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
};
float expected[] = {
- -2971, -3879, -4326, -7119, -835, -521, -1401, -4918, -9286,
- -6980, -1696, -611, -157, -5482, -9842, -7264, -1820, -3158,
- 489, -6367, -7028, -4763, -1629, -1987, 1095, -3985, -4128,
- -2136, -2574, -4071, 880, -2678, -1462, -1052, -377, -1370,
+ -11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
+ 4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
};
int channels = 2;
- CNN_CONFIG cnn_config = { .num_layers = 4,
+ CNN_CONFIG cnn_config = { .num_layers = 6,
.is_residue = 0,
.ext_width = 0,
.ext_height = 0,
@@ -1065,6 +1063,42 @@
.branches_to_combine = 0,
},
{
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
.branch = 0,
.deconvolve = 0,
.in_channels = channels,
@@ -1109,42 +1143,39 @@
image_width, MSE_INT_TOL);
}
-TEST_F(CNNTest, TestSkipTensorConcatenation) {
- int filter_width = 3;
+TEST_F(CNNTest, TestBranchTensorConcatenation) {
+ int filter_width = 2;
int filter_height = 3;
- int image_width = 6;
- int image_height = 6;
+ int image_width = 4;
+ int image_height = 4;
float input[] = {
- 0, 1, -1, 0, -2, -3, 0, 2, 1, 1, -2, -2, 1, 1, 0, 1, -1, 0,
- 2, -3, 2, -2, 0, -2, -3, 2, -3, -1, 1, 2, -2, -3, -3, -2, 1, -2,
+ -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
};
float weights[] = {
- -2, 0, 2, -2, 1, 1, -3, -1, -1, -1, 1, 1, -1, -3, -1, 0, 0, 0,
- 0, 0, 2, 2, 0, 0, -1, -2, -1, 1, 1, 1, 2, 1, -1, 2, -2, 2,
- -1, -3, -2, -2, 0, -2, 1, 1, 0, 0, 0, -1, 1, 0, 2, -3, 1, -1,
- -1, -1, 1, -2, -2, -2, -3, 2, 2, 1, -1, -3, 0, -2, 1, -2, -2, -3,
- 2, 1, -3, -2, -2, 1, 1, 2, 2, -1, -1, -1, 2, 2, -2, -2, 2, 1,
- 1, -2, 2, 2, 2, -1, 2, 2, 2, 0, 2, -1, 0, -2, -3, 2, -3, 1,
- 1, -3, 0, 2, 2, 2, -1, -1, -2, -2, -3, -2, 2, 1, -3, -3, 1, 2,
+ 3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
+ -2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
+ 4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
+ -4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
+ -1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
+ 2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
+ -1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
};
float bias[] = {
- -2, 0, 1, 1, -1, 0, 2,
+ -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
};
float expected[] = {
- 4032, 3724, -21, -2743, -1811, 684, 5068, 9168, 7067,
- 1217, -4681, -1320, 2373, 6161, 7928, 6649, -1183, -1995,
- -1976, -2717, 2881, 5825, 2324, -187, -5585, -4958, -4390,
- 1053, 2734, 987, 881, -409, -1173, -1899, 340, -1015,
+ -33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
+ -22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
};
int channels = 2;
- CNN_CONFIG cnn_config = { .num_layers = 4,
+ CNN_CONFIG cnn_config = { .num_layers = 6,
.is_residue = 0,
.ext_width = 0,
.ext_height = 0,
@@ -1186,6 +1217,42 @@
.branches_to_combine = 0,
},
{
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
.branch = 0,
.deconvolve = 0,
.in_channels = channels,
@@ -1229,3 +1296,236 @@
RunCNNTest(image_width, image_height, input, expected, cnn_config,
image_width, MSE_INT_TOL);
}
+
+TEST_F(CNNTest, TestBranchCombinations) {
+ int filter_width = 2;
+ int filter_height = 3;
+
+ int image_width = 4;
+ int image_height = 4;
+
+ float input[] = {
+ -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
+ };
+
+ float weights[] = {
+ 0, 4, -4, -1, -4, -3, -5, 4, -2, -2, -1, 2, 2, 0, -1, 0, 2, -4,
+ 3, 4, 2, 0, -5, -4, -4, 4, -3, -2, -5, 3, -3, 0, 0, 1, 1, -5,
+ -2, 4, -5, 3, -2, -2, -4, 4, -2, -4, -3, 1, 3, -4, 2, -1, 0, -1,
+ -2, -4, 4, 0, 0, -3, -3, -3, -5, 2, 1, 4, -1, 3, -3, -2, 3, -5,
+ 2, -1, -4, -4, 3, 1, -4, -3, -4, 0, -5, -2, 1, 0, -1, -2, -2, 1,
+ 4, 2, 2, 1, -1, 4, 2, 0, 4, -3, 1, -1, 0, -1, 1, -5, -4, -5,
+ 0, -1, 2, -5, -2, 1, -3, -4, -2, 0, -2, -5, -3, -4, -5, 1, 0, 3,
+ -4, 3, -4, -2, -1, -1, 2, 3, -3, -1, 0, 3, 0, 3, 0, -2, 2, 2,
+ 1, -4, 1, -4, 0, -2, -1, 3, -2, -1, 4, 4, 1, -2, 4, 3, 1, -1,
+ -5, -5, 0, -4, 4, 4, 4, 1, 0, -4, -1, 3, -5, 0, 3, -1, 2, 4,
+ -2, 4, 4, 1, 4, 3, -4, 2, 0, 4, 2, 1, -5, 2, 4, -5, -2, 2,
+ -4, -1, -3, -3, 2, -2, 1, 0, -5, 3, -3, 1, 3, 4, -5, 2, 0, 4,
+ };
+
+ float bias[] = {
+ 4, -4, 0, -1, 1, 1, 3, -4, -5, 3, 2, -3, -5, -3, -1, 3, 4, -5, -3,
+ };
+
+ float expected[] = {
+ 316397, 106874, 26726, 1971, 355397, 6848, -20952, 18023,
+ -10736, -52466, -22737, -1496, -220644, -55007, 15175, -7343,
+ };
+
+ int channels = 2;
+
+ CNN_CONFIG cnn_config = { .num_layers = 10,
+ .is_residue = 0,
+ .ext_width = 0,
+ .ext_height = 0,
+ .strict_bounds = 0,
+ {
+ {
+ .branch = 0,
+ .deconvolve = 0,
+ .in_channels = 1,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = weights,
+ .bias = bias,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 0,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0x06,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 2,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 2,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0x08,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 3,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 2,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_ADD,
+ .branches_to_combine = 0x08,
+ },
+ {
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ {
+ .branch = 1,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_ADD,
+ .branches_to_combine = 0x0C,
+ },
+ {
+ .branch = 0,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = channels,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_ADD,
+ .branches_to_combine = 0x02,
+ },
+ {
+ .branch = 0,
+ .deconvolve = 0,
+ .in_channels = channels,
+ .filter_width = filter_width,
+ .filter_height = filter_height,
+ .out_channels = 1,
+ .skip_width = 1,
+ .skip_height = 1,
+ .maxpool = 0,
+ .weights = nullptr,
+ .bias = nullptr,
+ .pad = PADDING_SAME_ZERO,
+ .activation = NONE,
+ .input_to_branches = 0,
+ .branch_combine_type = BRANCH_NOC,
+ .branches_to_combine = 0,
+ },
+ } };
+
+ // Weights and biases need to be specified separately because
+ // of the offset.
+ AssignLayerWeightsBiases(&cnn_config, weights, bias);
+
+ RunCNNTest(image_width, image_height, input, expected, cnn_config,
+ image_width, MSE_INT_TOL);
+}