blob: 3012451de8de8bd103f3763892bbc68b1adc1c83 [file] [log] [blame]
/*
* Copyright (c) 2019, Alliance for Open Media. All rights reserved.
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <vector>
#include "gtest/gtest.h"
#include "config/av1_rtcd.h"
#include "aom_ports/aom_timer.h"
#include "av1/encoder/cnn.h"
#include "av1/encoder/partition_cnn_weights.h"
#include "test/acm_random.h"
#include "test/function_equivalence_test.h"
#include "test/util.h"
// Squares |x|. |x| is expanded twice, so it must not have side effects.
#define SQR(x) ((x) * (x))
// Maximum absolute per-pixel difference allowed between an expected and an
// actual CNN output value (used with EXPECT_NEAR).
// Best possible pixelwise guaranteed precision given each float has at most
// 3 specified decimals.
#define PIXELWISE_FLOAT_TOL 1E-2
// Mean-squared-error tolerance; presumably intended for tests whose reference
// outputs are non-integer floats (TODO confirm against its users).
#define MSE_FLOAT_TOL 1E-6
// Mean-squared-error tolerance for tests whose reference outputs are exact
// integers: the check is mse <= 0, so any nonzero error fails.
#define MSE_INT_TOL 0
// CNN convolve pixelwise error threshold for functional equivalence.
#define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
namespace {
// Fixture with helpers that run a CNN_CONFIG through av1_cnn_predict() and
// compare the produced outputs against reference data.
class CNNTest : public ::testing::Test {
 protected:
  // Runs |cnn_config| on the single input plane |input| (|image_width| x
  // |image_height|, stride |in_stride|) and checks the network's single output
  // against |expected|.
  //
  // |expected| holds all output channels back to back, out_width * out_height
  // values per channel, with the output dimensions taken from
  // av1_find_cnn_output_size(). Every value must match within
  // PIXELWISE_FLOAT_TOL and the mean squared error must not exceed
  // |tolerance|.
  static void RunCNNTest(int image_width, int image_height, const float *input,
                         const float *expected, const CNN_CONFIG *cnn_config,
                         int in_stride, CNN_THREAD_DATA *thread_data,
                         double tolerance) {
    int out_width, out_height, out_channels;
    av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
                             &out_height, &out_channels);
    // |output| below is a fixed-size table of per-channel plane pointers;
    // reject configurations that would overflow it.
    ASSERT_LE(out_channels, CNN_MAX_CHANNELS);
    const int out_size = out_width * out_height;
    const int out_stride = out_width;
    float *output_ =
        (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
    ASSERT_NE(output_, nullptr);
    // Channel planes are laid out consecutively inside the single allocation.
    float *output[CNN_MAX_CHANNELS] = { nullptr };
    for (int channel = 0; channel < out_channels; ++channel) {
      output[channel] = output_ + (channel * out_size);
    }
    const int num_outputs = 1;
    const int output_chs[1] = { out_channels };
    const int output_strides[1] = { out_stride };
    CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
                                    output };
    // A fatal failure inside the helper only returns from the helper, so the
    // buffer is still released here.
    RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
                       thread_data, &output_struct, &expected, tolerance);
    aom_free(output_);
  }

  // Runs |cnn_config| on the planes in |input| and compares every output
  // described by |output| against the matching |expected| array.
  //
  // Output k owns output->output_channels[k] channel planes, taken
  // consecutively from output->output_buffer; expected[k] stores those
  // channels back to back. Every value must match within PIXELWISE_FLOAT_TOL
  // and each output's mean squared error must not exceed |tolerance|.
  static void RunMultiOutCNNTest(const float **input, int image_width,
                                 int image_height, int in_stride,
                                 const CNN_CONFIG *cnn_config,
                                 CNN_THREAD_DATA *thread_data,
                                 CNN_MULTI_OUT *output, const float **expected,
                                 double tolerance) {
    const int num_outputs = output->num_outputs;
    const int *output_chs = output->output_channels;
    // std::vector releases these even when an ASSERT_* below returns early
    // (e.g. when av1_cnn_predict() fails); raw aom_calloc() buffers leaked.
    std::vector<int> out_widths(static_cast<size_t>(num_outputs));
    std::vector<int> out_heights(static_cast<size_t>(num_outputs));
    std::vector<int> not_used(static_cast<size_t>(num_outputs));
    av1_find_cnn_output_size(image_width, image_height, cnn_config,
                             out_widths.data(), out_heights.data(),
                             not_used.data());
    ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
                                cnn_config, thread_data, output));
    int channel_offset = 0;
    for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
      const float *expected_out = expected[output_idx];
      const int curr_output_chs = output_chs[output_idx];
      const int out_size = out_widths[output_idx] * out_heights[output_idx];
      double mse = 0;
      int expected_ite = 0;
      for (int channel = 0; channel < curr_output_chs; ++channel) {
        const float *buf_out = output->output_buffer[channel_offset];
        for (int i = 0; i < out_size; ++i) {
          EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
                      PIXELWISE_FLOAT_TOL)
              << " output " << output_idx << " channel " << channel
              << " pixel " << i << ": " << expected_out[expected_ite] << "/"
              << buf_out[i];
          mse += SQR(expected_out[expected_ite] - buf_out[i]);
          expected_ite++;
        }
        channel_offset++;
      }
      // |expected_ite| now equals out_size * curr_output_chs. Skip the
      // division for an empty output: 0/0 would yield NaN and make the
      // EXPECT_LE below fail spuriously.
      if (expected_ite > 0) mse /= expected_ite;
      EXPECT_LE(mse, tolerance) << " output " << output_idx;
    }
  }

  // Points each layer of |cnn_config| at its slice of the packed |weights| and
  // |bias| arrays. Layers are stored back to back: each layer consumes
  // filter_width * filter_height * in_channels * out_channels weights and
  // out_channels biases.
  static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
                                       float *bias) {
    // Validate the source buffers up front; pointers derived from them by
    // offsetting can never compare equal to nullptr.
    ASSERT_NE(weights, nullptr);
    ASSERT_NE(bias, nullptr);
    size_t weight_offset = 0;
    size_t bias_offset = 0;
    for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
      CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
      layer_config->weights = weights + weight_offset;
      layer_config->bias = bias + bias_offset;
      weight_offset += layer_config->filter_width *
                       layer_config->filter_height * layer_config->in_channels *
                       layer_config->out_channels;
      bias_offset += layer_config->out_channels;
    }
  }
};
}  // namespace
// Three chained 4x5 (width x height) convolutions mapping 1 -> 3 -> 3 -> 1
// channels on a 16x16 image, verified with zero, replicate and valid padding.
TEST_F(CNNTest, TestMultilayerConvolution) {
  const int image_height = 16;
  const int image_width = 16;
  const int filter_height = 5;
  const int filter_width = 4;
  const float input[] = {
    -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
    -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
    2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
    -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
    0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
    2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
    -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
    -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
    -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
    1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
    1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
    -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
    -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
    -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
  };
  // Weights and biases of all three layers, packed layer after layer.
  float weights[] = {
    -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
    2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
    2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
    -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
    -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
    1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
    -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
    -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
    -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
    2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
    -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
    -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
    -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
    2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
    1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
    0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
  };
  float bias[] = {
    1, -1, 0, 1, 1, 1, -2,
  };
  const float expected_same[] = {
    -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
    2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
    4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
    9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
    4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
    5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
    10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
    3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
    5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
    3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
    14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
    2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
    455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
    12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
    1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
    5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
    13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
    1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
    985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
    9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
    3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
    -1577, 176,   955,
  };
  const float expected_replicate[] = {
    13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
    1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
    6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
    14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
    6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
    6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
    12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
    5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
    5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
    4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
    14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
    2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
    6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
    12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
    6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
    5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
    15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
    -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
    -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
    21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
    5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
    -2021, 504,   2347,
  };
  const float expected_valid[] = {
    2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
    9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
    12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
  };
  // One row per layer. Weight/bias pointers start out null and are filled in
  // from the packed arrays by AssignLayerWeightsBiases(). The trailing value
  // is -1 for the two hidden layers and 0 for the last one.
  CNN_CONFIG cnn_config = {
    3,
    0,
    0,
    0,
    0,
    {
        { 1, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    },
  };
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  // Runs the current configuration and compares it with |expected|.
  const auto run = [&](const float *expected) {
    RunCNNTest(image_width, image_height, input, expected, &cnn_config,
               image_width, &thread_data, MSE_INT_TOL);
  };

  run(expected_same);

  for (int layer = 0; layer < cnn_config.num_layers; ++layer)
    cnn_config.layer_config[layer].pad = PADDING_SAME_REPLICATE;
  run(expected_replicate);

  for (int layer = 0; layer < cnn_config.num_layers; ++layer)
    cnn_config.layer_config[layer].pad = PADDING_VALID;
  run(expected_valid);
}
// One 4x5 (width x height) convolution followed by RELU on an 8x8 image,
// verified with zero, replicate and valid padding in turn.
TEST_F(CNNTest, TestRELUSingleLayer) {
  const int image_width = 8;
  const int image_height = 8;
  const int filter_height = 5;
  const int filter_width = 4;
  const float input[] = {
    0,  -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
    1,  -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
    2,  -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
    0,  -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
  };
  const float expected_same[] = {
    9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0,  5, 0,
    0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
    12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0,  0, 0,
    0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
  };
  const float expected_replicate[] = {
    18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
    0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
    3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
    0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
  };
  const float expected_valid[] = {
    18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
  };
  float weights[] = {
    -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
  };
  float bias[] = { -3 };
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, filter_width, filter_height, 1, 1, 1, 0, weights, bias,
        PADDING_SAME_ZERO, RELU, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
        0 } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  auto &layer = cnn_config.layer_config[0];
  // Runs the current configuration and compares it with |expected|.
  const auto run = [&](const float *expected) {
    RunCNNTest(image_width, image_height, input, expected, &cnn_config,
               image_width, &thread_data, MSE_INT_TOL);
  };

  run(expected_same);

  layer.pad = PADDING_SAME_REPLICATE;
  run(expected_replicate);

  layer.pad = PADDING_VALID;
  run(expected_valid);
}
// One 4x11 (width x height) filter on a 17x24 image, run with several
// combinations of horizontal and vertical skip values.
TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
  float weights[] = {
    1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
    -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
    -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
  };
  float bias[] = { 2 };
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, 4, 11, 1, 7, 6, 0, weights, bias, PADDING_SAME_ZERO, NONE, 0, 0,
        BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } },
  };
  const int image_height = 24;
  const int image_width = 17;
  const float input[] = {
    -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
    1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
    2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
    -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
    -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
    -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
    -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
    4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
    -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
    0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
    -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
    3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
    2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
    4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
    -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
    -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
    -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
    4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
    -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
    -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
    -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
    -3, -2, 3,  2,  1,  -4, -3, -5, 1,
  };
  const float expected_1[] = {
    41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
  };
  const float expected_2[] = {
    21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
  };
  const float expected_3[] = {
    -26, -21, -35, 69, 49,  4,  -51, -43, -56,
    -41, 15,  -44, 40, -62, 63, 38,  27,  47,
  };
  const float expected_4[] = {
    21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
  };

  // Each case overwrites the layer's skip_width / skip_height before running.
  // The first case repeats the values the layer was initialized with above.
  struct StrideCase {
    int skip_width;
    int skip_height;
    const float *expected;
  };
  const StrideCase cases[] = {
    { 7, 6, expected_1 },
    { 6, 7, expected_2 },
    { 3, 10, expected_3 },
    { 10, 3, expected_4 },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  for (const StrideCase &test_case : cases) {
    cnn_config.layer_config[0].skip_width = test_case.skip_width;
    cnn_config.layer_config[0].skip_height = test_case.skip_height;
    RunCNNTest(image_width, image_height, input, test_case.expected,
               &cnn_config, image_width, &thread_data, MSE_INT_TOL);
  }
}
// One 3x3 convolution on an 8x8 image with zero padding, max pooling enabled
// (the 1 following the two skip values below) and |pool_stride| used as both
// skip values.
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int pool_stride = 3;
  const float input[] = {
    1,  -4, -4, 8,  0,  7,  -5, -2, 8, 2, 2,  8,  5,  -1, -1, 9,
    -3, 0,  -2, 0,  6,  3,  -4, 8,  7, 8, 7,  -1, 4,  -1, 0,  2,
    -5, -2, 8,  5,  5,  4,  2,  7,  4, 6, 2,  8,  8,  -4, -3, -4,
    -3, -1, 2,  3,  3,  6,  -5, 8,  9, 5, 0,  -2, -1, 6,  5,  7,
  };
  const float expected[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };
  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };
  float bias[] = {
    -3,
  };
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    { { 1, 3, 3, 1, pool_stride, pool_stride, 1, weights, bias,
        PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
        0 } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
// Single-layer, single-kernel test with the layer's deconvolve flag set (the
// 1 after the activation in the config below) on a 4x7 image. Two kernels
// (5x2, then 4x3; width x height) are each checked over several skip
// settings, with both zero ("same") and valid padding.
TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
  int image_width = 4;
  int image_height = 7;
  const float input[] = {
    9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
    33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
  };
  const float expected_1_same[] = {
    15,   -30,  36,   -525,  377,  -193, 558, 531,  6,    -24,  -15,  124,
    166,  -561, -356, -754,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    433,  -311, 711,  381,   247,  -317, 453, 129,  215,  -627, -409, -885,
    17,   -255, -55,  -647,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    133,  -719, 633,  -225,  785,  191,  463, 79,   65,   9,    77,   -853,
    -365, -949, -15,  -667,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    355,  -866, 990,  207,   747,  12,   520, -116, 176,  -312, -133, -1370,
    -426, -802, 143,  -771,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    65,   -79,  127,  -59,   135,  -90,  195, 114,  31,   -91,  -57,  -133,
    17,   -176, -72,  -276,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    457,  -302, 733,  58,    470,  -475, 829, 490,  227,  -670, -440, -790,
    153,  -588, -294, -1150, -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
    157,  -251, 349,  -185,  409,  -293, 587, 251,  77,   -187, -107, -369,
    7,    -481, -135, -827,  -3,   -3,   -3,  -3,   -3,   -3,   -3,   -3,
  };
  const float expected_1_valid[] = {
    -30,  15,   -30,  36,   -525, 377,  -193, 558,  531,  24,   24,   6,
    6,    -24,  -15,  124,  166,  -561, -356, -754, -21,  -39,  -3,   -3,
    -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -657, 433,  -311,
    711,  381,  247,  -317, 453,  129,  321,  321,  215,  215,  -627, -409,
    -885, 17,   -255, -55,  -647, -219, -435, -3,   -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -3,   -3,   -3,   -207, 133,  -719, 633,  -225, 785,
    191,  463,  79,   381,  381,  65,   65,   9,    77,   -853, -365, -949,
    -15,  -667, -259, -515, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -540, 355,  -866, 990,  207,  747,  12,   520,  -116,
    633,  633,  176,  176,  -312, -133, -1370, -426, -802, 143,  -771, -427,
    -851, -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,
    -105, 65,   -79,  127,  -59,  135,  -90,  195,  114,  78,   78,   31,
    31,   -91,  -57,  -133, 17,   -176, -72,  -276, -57,  -111, -3,   -3,
    -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -3,   -693, 457,  -302,
    733,  58,   470,  -475, 829,  490,  336,  336,  227,  227,  -670, -440,
    -790, 153,  -588, -294, -1150, -229, -455, -3,  -3,   -3,   -3,   -3,
    -3,   -3,   -3,   -3,   -3,   -3,   -243, 157,  -251, 349,  -185, 409,
    -293, 587,  251,  333,  333,  77,   77,   -187, -107, -369, 7,    -481,
    -135, -827, -227, -451,
  };
  float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
  float bias_1[] = { -3 };
  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1,
                                5,
                                2,
                                1,
                                2,
                                3,
                                0,
                                weights_1,
                                bias_1,
                                PADDING_SAME_ZERO,
                                NONE,
                                1,
                                0,
                                BRANCH_NO_COPY,
                                BRANCH_NOC,
                                {},
                                {},
                                0,
                            } } };
  CNN_THREAD_DATA thread_data = { 1, nullptr };
  // Kernel 1 (5x2) with the initial skip_width = 2, skip_height = 3.
  RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  const float expected_12_same[] = {
    15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
    6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
    433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
    215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
    133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
    65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
    355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
    176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
    65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
    31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
    457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
    227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
    157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
    77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
  };
  const float expected_12_valid[] = {
    -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
    24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
    -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
    -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
    17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
    633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
    -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
    -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
    633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
    77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
    -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
    17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
    733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
    -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
    -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
    333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
    -379, -339, -227, -451,
  };
  // Change skip_width, skip_height from {2, 3} to {3, 2}.
  cnn_config.layer_config[0].skip_width = 3;
  cnn_config.layer_config[0].skip_height = 2;
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
  RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Kernel 2: switch to a 4x3 filter with its own weights and bias, using
  // skip_width = 5, skip_height = 2.
  cnn_config.layer_config[0].filter_width = 4;
  cnn_config.layer_config[0].filter_height = 3;
  float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
  float bias_2[] = { -4 };
  cnn_config.layer_config[0].weights = weights_2;
  cnn_config.layer_config[0].bias = bias_2;
  cnn_config.layer_config[0].skip_width = 5;
  cnn_config.layer_config[0].skip_height = 2;
  const float expected_2_same[] = {
    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
    -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
    14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
    -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
    -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
    -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
    -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
    -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
    860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
    21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
    33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
    -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
    420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
    866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
    -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
    -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
    -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
    -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
    88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
    -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
    220,  -228, 444,  -4,
  };
  const float expected_2_valid[] = {
    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
    -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
    -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
    459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
    868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
    -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
    -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
    132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
    276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
    -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
    -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
    62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
    -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
    -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
    -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
    -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
    -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
    -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
    -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
    -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
    362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
  };
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
  RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Change skip_width, skip_height from {5, 2} to {2, 5}.
  cnn_config.layer_config[0].skip_width = 2;
  cnn_config.layer_config[0].skip_height = 5;
  const float expected_21_same[] = {
    -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
    382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
    432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
    116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
    -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
    -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
    974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
    64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
    52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
    -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
    -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
    464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    -4,   -4,   -4,    -4,
  };
  const float expected_21_valid[] = {
    -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,   -31,  -4,   14,
    -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,    -4,   23,   32,
    545,  20,   204,  720,  5,    32,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -222, -658, -252, -748, -114, -334, -192,  -568, -112, -328,
    -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220,  428,  650,  -4,
    304,  868,  266,  116,  400,  316,  104,  428,  -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290,  -862, -202, -598,
    -132, -388, -4,   132,  -140, 700,  -436, 1000, -144,  532,  -260, 508,
    200,  -4,   712,  268,  422,  860,  450,  276,  124,   508,  -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411,  -1225, -265, -787,
    -249, -739, -216, -640, -4,   354,  -362, 1168, -460,  974,  -70,  552,
    -428, 844,  533,  -4,   859,  712,  323,  908,  665,   128,  208,  844,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,   -106, -52,  -148,
    -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,   160,  -32,  148,
    -100, 242,  -58,  104,  98,   -4,   72,   132,  154,   52,   125,  188,
    23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -234, -694,
    -257, -763, -229, -679, -319, -949, -117, -343, -4,    456,  -464, 962,
    -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,   916,  625,  88,
    537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,   -4,
    -84,  -244, -140, -412, -182, -538, -238, -706, -116,  -340, -4,   156,
    -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,   -4,   244,  316,
    418,  220,  454,  484,  108,  444,
  };
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
  RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestLargeKernelsAndStrides) {
float input_10x11[] = {
4, 4, 2, 4, 2, -5, -2, 3, -1, 0, 0, 1, 2, 0, -5, -2, -5, 1, -3,
-1, 4, -3, 2, -2, 1, 0, 1, -3, -3, -4, -2, -2, 1, -4, -1, 4, 1, -4,
-4, -4, 3, 2, -5, 3, -5, 1, 2, -4, 1, -1, 3, 4, -2, 3, -3, 3, 0,
2, -4, -5, -5, -2, -1, -2, 1, 1, 1, -2, 4, -5, 4, -1, -1, 2, 3, -4,
2, 2, 3, 0, 0, 1, 0, 3, 2, 3, 1, -2, 3, -4, 3, 2, 4, -2, 0,
4, -4, 1, -3, -3, -3, -5, 1, -3, -5, 0, 4, -1, -3, 2,
};
float weights_10x11[] = {
-3, 4, -4, -3, -5, 1, -2, 3, 1, -4, -4, 0, -1, 0, 3, 1, -3, -2, 0,
-1, 1, 3, -4, -4, -3, -3, -2, 4, 3, -5, 4, 2, -3, 4, -2, -1, 2, -1,
-5, 0, -3, 0, 3, -5, -5, 3, -4, -1, -5, 3, 4, 0, 4, -5, 2, -1, 2,
-1, -1, -1, -5, 0, -4, 3, -1, 1, 1, -1, 3, 2, -5, -4, 0, -4, 4, -5,
-3, 4, -5, 2, -5, -4, -4, -1, 3, 3, 0, 2, -4, 1, -2, 1, 1, 0, 3,
-2, 0, 1, 2, 4, -3, -1, -5, -5, 2, -4, 1, 1, 2, -4, -2, -2, 2, 1,
3, 4, -5, 1, -1, -3, -3, -1, -2, -5, 1, -1, 0, 1, 4, 4, 0, 0, 4,
-3, -1, -5, -3, 0, 1, 1, 1, -5, 3, 4, 3, -5, 3, -2, -2, 0, -4, 0,
0, -2, 1, -4, -1, 0, -5, -2, -2, -5, -3, -3, 1, 1, -3, 2, 4, 2, 4,
-4, -3, 3, 1, 1, 3, -4, 4, -2, -3, -3, -3, -3, -4, -2, 3, -5, 2, 4,
-1, -4, -4, 4, -2, -1, 3, -3, -4, -4, -2, 4, 1, 0, 2, -1, 4, -3, 1,
4, -3, 4, 4, 0, -4, 3, -2, -3, 2, 3, -1, -3, 2, 1, 4, -2, -3, 1,
4, -2, 2, -2, -5, -2, 1, 4, -1, -4, 4, -5, 2, -5, -4, -1, -2, 3, 1,
2, 1, -5, 1, -5, -4, -1, -2, 2, -2, -4, -3, -2, -2, 4, -1, 2, 2, -4,
2, -2, 4, -4, -2, -2, 1, -1, 1, 1, 1, -4, -5, -2, 3, -4, -1, 3, -2,
3, 2, -5, -4, 0, 3, -2, -4, -5, 3, -2, -4, 2, -2, 1, -4, 0, 2, -5,
1, -4, -1, -1, 4, -5, -4, 0, -5, -4, -3, -5, -4, 0, 2, 0, -4, 2, -2,
1, 1, -3, 2, 0, -4, 0, -4, 1, 0, -5, -1, -1, -1, -5, 4, 2, 2, -4,
3, -2, -2, 2, -3, -2, -1, 2, -4, -5, 2, -2, -4, -5, -5, -1, 2, -1, 0,
-5, -2, -2, -5, 0, 1, -1, -5, 0, 3, 2, 3, 0, -3, -2, 0, -5, -1, -2,
2, -4, -1, 2, 2, -5, 2, -4, 0, 3, -3, 1, 0, 0, 1, -5, -3, 1, -1,
0, -4, -3, 2, -4, -4, 4, -1, 0, 1, 2, -4, -5, 4, -2, 1, -4, -4, -3,
-1, -1, 1, -1, -4, -1, -4, -3, 2, -1, -2, -4, 1, 1, 0, -2, 0, -4, 3,
-3, 0, -4, -1, -4, 2, -1, -2, -5, -1, -2, -3, 3, -1, 0, -3, 0, 1, -5,
1, -5, 0, 1,
};
float bias_10x11[] = { 3 };
float expected_10x11[] = {
118,
};
CNN_CONFIG cnn_config = { 1,
0,
0,
0,
0,
{ {
1,
23,
20,
1,
15,
20,
0,
weights_10x11,
bias_10x11,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0,
} } };
int image_height = 10;
int image_width = 11;
CNN_THREAD_DATA thread_data = { 1, nullptr };
RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
&cnn_config, image_width, &thread_data, MSE_INT_TOL);
float input_11x10[] = {
-2, -2, 3, -5, -1, -3, 1, 3, 2, 1, 1, -5, 4, 1, 3, -5, 3, -3, -5,
0, -1, -3, -3, 1, 1, -5, -1, -5, -5, -3, 0, 1, -3, -1, -3, -3, 0, 3,
4, -4, -1, 3, -3, -1, -3, 1, -3, -2, -1, -4, -3, 2, -4, 1, -4, -1, -3,
-5, -1, 2, 3, 0, 2, 2, -5, 4, 1, 2, -1, -4, 4, -4, -4, 0, -1, 1,
-1, 1, -3, -3, -2, 1, 2, 4, 4, 4, -3, -3, 0, 1, 0, 1, 4, 1, 3,
4, -3, -2, -4, 4, 2, 0, 3, 4, -1, 2, -2, 1, -3, -2,
};
float weights_11x10[] = {
4, -1, 1, -1, 2, 4, 3, 3, -4, 3, -5, 1, -1, -1, -2, -2, 0, 2, -3,
-2, 3, -5, -1, 0, -1, -2, -2, -1, 2, 4, 3, 1, 0, 0, -3, 3, -4, -1,
-5, 4, -2, -2, 1, 2, -1, -3, 1, 2, -5, 1, -3, 3, 3, 0, -4, -4, -5,
-3, -4, -4, 4, -2, 4, 4, -2, 2, -5, -1, -2, -5, -1, 4, -3, 3, -2, 0,
-4, -3, 0, -1, -2, 4, 2, 0, -2, -5, -4, 1, 4, -4, -2, 2, -2, 1, 1,
-4, 1, -4, -4, -2, 4, 2, -1, -5, -5, 1, -3, -3, 3, -3, -5, -3, 4, -1,
-1, -3, 0, -4, 3, -1, 0, -2, 0, -5, -2, -5, 2, 0, -5, 2, 3, -2, 2,
4, -1, 1, -3, 2, 3, 2, 0, -5, -4, -5, 2, 1, 1, -1, -2, 3, 4, 2,
-2, 4, -2, 3, 1, -4, -3, -1, 4, 4, -3, -5, -2, 2, 0, 3, -2, 3, -1,
-4, 0, -2, 0, 3, 4, -2, -3, -2, 0, 3, 4, 2, -4, 0, 1, 2, 2, -1,
-1, 4, 1, 4, -2, -1, -1, -5, 1, -3, 3, 3, -1, -4, 3, -5, 0, 0, -1,
-4, -1, -2, 4, -2, 3, 3, -3, 1, -1, 2, -1, 4, 4, -2, -2, 4, -2, 0,
3, -3, -5, -1, -2, 4, -4, 2, -4, 0, -2, 3, -3, 2, 2, -2, -5, -1, 4,
3, -2, -1, 3, 3, -1, 3, 0, -3, 0, 4, 2, 0, -1, 4, 1, 1, 2, 1,
3, 1, 1, 1, -3, -5, -4, 4, -4, 2, 0, 0, -4, 1, 4, -5, 4, 4, 0,
1, 0, -2, -4, -4, -3, 0, 1, -5, 4, 0, -3, -2, -4, 2, 4, 1, -5, 1,
-4, 1, 0, -3, -3, 0, 2, -5, 4, 3, -2, -5, 3, 1, -1, 0, 3, -2, -2,
3, -2, -5, 4, 1, -2, 2, -1, 0, 4, 0, -5, 3, -2, 1, 2, 1, -5, -3,
-2, -5, 4, -4, 0, 3, 2, -1, -4, -1, 2, 1, -2, 3, -1, -4, 2, 0, -3,
1, -1, 2, -5, -4, -1, -5, 1, 4, 3, 4, 2, -3, 1, -5, -1, 3, 0, -1,
-4, 3, 4, -5, 4, 4, -3, 2, -3, -1, -3, -5, -3, 2, -3, -2, 1, 1, 0,
-5, 3, 2, 1, -5, 1, 1, 1, 3, 4, -4, -1, -2, 0, -5, -3, -5, -2, -4,
3, 3, 3, 4, 0, -4, -1, -5, 0, -3, 1, 4, 4, -4, 4, -5, -5, -1, -2,
-5, 3, -4, 4, 3, 0, -3, 2, -2, 0, 0, 4, 4, 0, -2, 1, -1, -3, 2,
-1, 1, -3, -5,
};
float bias_11x10[] = {
-5,
};
float expected_11x10[] = {
36, -84, 95, 45, 18, 46, 77, -54, -99, -149, 66, 49, 161, 11,
39, 61, -66, 61, 4, -3, 34, -44, -23, 31, 64, 29, 47, 72,
-27, -27, 121, -3, 100, 1, 30, -78, -12, -89, -59, 8, -16, 112,
91, -102, -26, -4, 30, 54, 4, -84, -24, -58, 27, -53, -33, 5,
53, -26, 63, 50, -103, -130, -23, 6, -104, -207, 73, 23, 77, 132,
38, 32, -130, -44, -60, 7, 27, 176, 45, -32, -2, 99, -97, 63,
69, 126, 47, 63, 136, -57, 5, 16, -40, -157, 8, 38, -44, -10,
91, 7, 122, 140, 30, -105, 4, -1, 113, 64, 180, 141,
};
cnn_config.layer_config[0].weights = weights_11x10;
cnn_config.layer_config[0].bias = bias_11x10;
cnn_config.layer_config[0].filter_width = 20;
cnn_config.layer_config[0].filter_height = 23;
cnn_config.layer_config[0].skip_width = 1;
cnn_config.layer_config[0].skip_height = 1;
image_height = 11;
image_width = 10;
RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
&cnn_config, image_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestSoftsignSingleLayer) {
  // One 4-wide x 5-tall convolution followed by SOFTSIGN, applied to an 8x8
  // single-channel image. The same network is checked once per padding mode
  // (SAME_ZERO, SAME_REPLICATE, VALID) against precomputed outputs.
  const int filter_width = 4;
  const int filter_height = 5;
  // 20 filter taps (4 * 5) and a single bias for the single output channel.
  float weights[] = {
    0.6210f, 0.3710f, -0.2770f, -0.7230f, -0.2450f, 0.6770f, 0.3080f,
    -0.9880f, -0.080f, 0.7190f, -0.6760f, -0.0170f, -0.8970f, 0.8260f,
    0.7390f, -0.4550f, -0.4260f, -0.6330f, 0.0880f, -0.9390f,
  };
  float bias[] = {
    0.750f,
  };

  const int image_width = 8;
  const int image_height = 8;
  const float input[] = {
    -0.5220f, 0.8410f, -0.8990f, -0.0090f, 0.6710f, -0.9470f, -0.8240f,
    -0.0870f, 0.5380f, 0.4750f, 0.570f, -0.3760f, -0.6960f, -0.5940f,
    -0.3830f, 0.080f, -0.0980f, -0.4940f, -0.4030f, 0.9460f, -0.6020f,
    0.4220f, 0.6190f, 0.6640f, -0.9210f, -0.1470f, -0.2480f, -0.1120f,
    -0.580f, -0.0650f, 0.3330f, 0.9860f, -0.7430f, 0.7610f, 0.4840f,
    0.1030f, 0.9570f, 0.6120f, -0.5240f, -0.1220f, -0.5850f, -0.270f,
    0.7840f, -0.9790f, 0.7290f, -0.30f, -0.6460f, 0.0780f, 0.4750f,
    -0.0510f, 0.4550f, 0.3850f, -0.7230f, 0.4460f, -0.6260f, -0.810f,
    0.8720f, -0.2120f, -0.580f, -0.9510f, -0.8430f, -0.1340f, -0.0850f,
    0.9190f,
  };

  // Expected 8x8 output with zero padding outside the image.
  const float expected_same[] = {
    0.430f, 0.660f, 0.5510f, -0.610f, 0.450f, -0.1610f, 0.0520f, 0.3240f,
    0.6820f, 0.3820f, 0.6360f, 0.7480f, 0.3080f, 0.090f, 0.3910f, 0.1730f,
    0.340f, 0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.4670f, 0.6150f,
    -0.3880f, 0.7590f, 0.4190f, 0.7350f, 0.5310f, -0.5160f, -0.1760f, 0.6790f,
    -0.6780f, 0.5470f, 0.5750f, -0.6420f, 0.7210f, -0.4620f, 0.5430f, 0.770f,
    -0.1990f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.6430f, 0.4510f,
    -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f, 0.5870f, 0.4720f,
    0.4040f, 0.3630f, 0.670f, 0.2360f, 0.410f, 0.6980f, -0.5350f, 0.3940f,
  };
  // Expected 8x8 output with edge-replicated padding.
  const float expected_replicate[] = {
    0.540f, 0.7230f, -0.3530f, -0.2130f, 0.7440f, -0.4470f, -0.6260f,
    -0.2050f, 0.7230f, 0.4630f, 0.5920f, 0.7440f, 0.6080f, 0.3130f,
    -0.5670f, -0.4720f, 0.5480f, 0.6660f, -0.4990f, 0.4280f, 0.1540f,
    0.120f, 0.3390f, 0.6090f, 0.4160f, 0.7590f, 0.4190f, 0.7350f,
    0.5310f, -0.5160f, -0.490f, 0.4450f, -0.610f, 0.5470f, 0.5750f,
    -0.6420f, 0.7210f, -0.4620f, 0.3150f, 0.7370f, -0.5820f, 0.3950f,
    0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.7430f, -0.5340f, -0.6270f,
    0.4430f, 0.4730f, 0.4570f, 0.7450f, 0.630f, 0.2620f, 0.3140f,
    -0.1840f, 0.1810f, 0.7210f, 0.2760f, 0.6430f, 0.6720f, -0.4390f,
    0.2040f,
  };
  // Expected VALID output: (8 - 4 + 1) x (8 - 5 + 1) = 5 x 4 = 20 values.
  const float expected_valid[] = {
    0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.7590f, 0.4190f,
    0.7350f, 0.5310f, -0.5160f, 0.5470f, 0.5750f, -0.6420f, 0.7210f,
    -0.4620f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f,
  };

  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1,
                                filter_width,
                                filter_height,
                                1,
                                1,
                                1,
                                0,
                                weights,
                                bias,
                                PADDING_SAME_ZERO,
                                SOFTSIGN,
                                0,
                                0,
                                BRANCH_NO_COPY,
                                BRANCH_NOC,
                                {},
                                {},
                                0,
                            } } };
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Input, network and thread data are shared by every padding mode; only
  // the padding field (mutated below) and the expected output change. The
  // lambda captures by reference so it observes those mutations.
  const auto check_output = [&](const float *expected) {
    RunCNNTest(image_width, image_height, input, expected, &cnn_config,
               image_width, &thread_data, MSE_FLOAT_TOL);
  };

  // The initializer above already selects PADDING_SAME_ZERO.
  check_output(expected_same);

  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
  check_output(expected_replicate);

  cnn_config.layer_config[0].pad = PADDING_VALID;
  check_output(expected_valid);
}
// Exercises BRANCH_ADD: a side branch is split off, processed separately and
// summed back into the main branch before the final layer. Inputs, weights
// and biases are small integers, so the output is checked with MSE_INT_TOL
// (defined as 0 at the top of this file).
TEST_F(CNNTest, TestBranchTensorAdd) {
int filter_width = 2;
int filter_height = 3;
int image_width = 4;
int image_height = 4;
float input[] = {
-3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
};
// Weights of all six layers, concatenated in layer order. The 120 values
// match in_channels * out_channels * filter_width * filter_height per layer:
// 12 (layer 0: 1 -> 2) + 4 * 24 (layers 1-4: 2 -> 2) + 12 (layer 5: 2 -> 1).
float weights[] = {
-3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
-5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
};
// Biases of all six layers in layer order. The 11 values are 2 * 5 + 1,
// presumably one per output channel.
float bias[] = {
3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
};
// Expected output: 16 values for the 4x4 image (single output channel).
float expected[] = {
-11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
};
int channels = 2;
// Six-layer network.
// NOTE(review): the per-layer reading below assumes this positional field
// order for CNN_LAYER_CONFIG in av1/encoder/cnn.h: in_channels, filter_width,
// filter_height, out_channels, skip_width, skip_height, maxpool, weights,
// bias, pad, activation, deconvolve, branch, branch_copy_type,
// branch_combine_type, branch_config, bn_params, output_num. Confirm against
// the header.
//   layer 0:    1 -> 2 channels on branch 0.
//   layer 1:    BRANCH_INPUT with mask 0x02, presumably copying this layer's
//               input into branch 1.
//   layers 2-3: run on branch 1.
//   layer 4:    BRANCH_ADD with mask 0x02, presumably adding branch 1's
//               output into branch 0.
//   layer 5:    2 -> 1 channel. Its trailing 0 (-1 on the other layers)
//               presumably marks it as output 0.
// Layers 1-5 start with nullptr weights/bias; they are assigned below.
CNN_CONFIG cnn_config = { 6,
0,
0,
0,
0,
{ {
1,
filter_width,
filter_height,
channels,
1,
1,
0,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_INPUT,
BRANCH_NOC,
{
0x02,
0,
0x00,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0x00,
0,
0x02,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
1,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0,
} } };
// Weights and biases need to be specified separately because
// of the offset.
// AssignLayerWeightsBiases is defined elsewhere in this file. It presumably
// points each layer at its slice of the flat `weights` / `bias` arrays.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
// NOTE(review): { 1, nullptr } presumably means one worker and no external
// thread pool. Confirm against CNN_THREAD_DATA in cnn.h.
CNN_THREAD_DATA thread_data = { 1, nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
// Exercises BRANCH_CAT: a side branch is split off, processed separately and
// concatenated channel-wise with the main branch. The final layer therefore
// consumes channels + channels inputs. All values are small integers, so the
// output is checked with MSE_INT_TOL (defined as 0 at the top of this file).
TEST_F(CNNTest, TestBranchTensorConcatenation) {
int filter_width = 2;
int filter_height = 3;
int image_width = 4;
int image_height = 4;
float input[] = {
-3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
};
// Weights of all six layers, concatenated in layer order. The 132 values
// match in_channels * out_channels * filter_width * filter_height per layer:
// 12 (layer 0: 1 -> 2) + 4 * 24 (layers 1-4: 2 -> 2) + 24 (layer 5: 4 -> 1).
float weights[] = {
3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
-2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
-4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
-1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
-1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
};
// Biases of all six layers in layer order. The 11 values are 2 * 5 + 1,
// presumably one per output channel.
float bias[] = {
-3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
};
// Expected output: 16 values for the 4x4 image (single output channel).
float expected[] = {
-33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
-22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
};
int channels = 2;
// Six-layer network.
// NOTE(review): the per-layer reading below assumes this positional field
// order for CNN_LAYER_CONFIG in av1/encoder/cnn.h: in_channels, filter_width,
// filter_height, out_channels, skip_width, skip_height, maxpool, weights,
// bias, pad, activation, deconvolve, branch, branch_copy_type,
// branch_combine_type, branch_config, bn_params, output_num. Confirm against
// the header.
//   layer 0:    1 -> 2 channels on branch 0.
//   layer 1:    BRANCH_INPUT with mask 0x02, presumably copying this layer's
//               input into branch 1.
//   layers 2-3: run on branch 1.
//   layer 4:    BRANCH_CAT with mask 0x02, presumably appending branch 1's
//               channels to branch 0's output.
//   layer 5:    (channels + channels) -> 1 channel. Its trailing 0 (-1 on the
//               other layers) presumably marks it as output 0.
// Layers 1-5 start with nullptr weights/bias; they are assigned below.
CNN_CONFIG cnn_config = { 6,
0,
0,
0,
0,
{ {
1,
filter_width,
filter_height,
channels,
1,
1,
0,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_INPUT,
BRANCH_NOC,
{
0x02,
0,
0x00,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_CAT,
{
0x00,
0,
0x02,
},
{},
-1,
},
{
channels + channels,
filter_width,
filter_height,
1,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0,
} } };
// Weights and biases need to be specified separately because
// of the offset.
// AssignLayerWeightsBiases is defined elsewhere in this file. It presumably
// points each layer at its slice of the flat `weights` / `bias` arrays.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
// NOTE(review): { 1, nullptr } presumably means one worker and no external
// thread pool. Confirm against CNN_THREAD_DATA in cnn.h.
CNN_THREAD_DATA thread_data = { 1, nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
// TODO(logangw): Add test to test all combinations of branch_copy_type.
// Exercises nested branching: BRANCH_INPUT and BRANCH_OUTPUT copies feed
// several side branches, which are merged back with BRANCH_ADD in stages.
// All values are small integers, so the output is checked with MSE_INT_TOL
// (defined as 0 at the top of this file).
TEST_F(CNNTest, TestBranchCombinations) {
int filter_width = 2;
int filter_height = 3;
int image_width = 4;
int image_height = 4;
float input[] = {
3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
};
// Weights of all ten layers, concatenated in layer order. The 216 values
// match in_channels * out_channels * filter_width * filter_height per layer:
// 12 (layer 0: 1 -> 2) + 8 * 24 (layers 1-8: 2 -> 2) + 12 (layer 9: 2 -> 1).
float weights[] = {
2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
-1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
-2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
-2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
-5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
-5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
-2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
-1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
-1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
};
// Biases of all ten layers in layer order. The 19 values are 2 * 9 + 1,
// presumably one per output channel.
float bias[] = {
-1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
};
// Expected output: 16 values for the 4x4 image (single output channel).
float expected[] = {
149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
-53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
};
int channels = 2;
// Ten-layer network.
// NOTE(review): the per-layer reading below assumes this positional field
// order for CNN_LAYER_CONFIG in av1/encoder/cnn.h: in_channels, filter_width,
// filter_height, out_channels, skip_width, skip_height, maxpool, weights,
// bias, pad, activation, deconvolve, branch, branch_copy_type,
// branch_combine_type, branch_config, bn_params, output_num. Confirm against
// the header. Masks are read as one bit per branch.
//   layer 0: branch 0, 1 -> 2 channels.
//   layer 1: branch 0, BRANCH_INPUT mask 0x06 (branches 1 and 2); presumably
//            copies this layer's input into both.
//   layer 2: branch 2, BRANCH_OUTPUT mask 0x08 (branch 3); presumably copies
//            this layer's output into branch 3.
//   layer 3: branch 3.
//   layer 4: branch 2, BRANCH_ADD mask 0x08; presumably adds branch 3.
//   layer 5: branch 2.
//   layer 6: branch 1.
//   layer 7: branch 1, BRANCH_ADD mask 0x0C; presumably adds branches 2
//            and 3.
//   layer 8: branch 0, BRANCH_ADD mask 0x02; presumably adds branch 1.
//   layer 9: branch 0, 2 -> 1 channel. Its trailing 0 (-1 on the other
//            layers) presumably marks it as output 0.
// Layers 1-9 start with nullptr weights/bias; they are assigned below.
CNN_CONFIG cnn_config = { 10,
0,
0,
0,
0,
{
{
1,
filter_width,
filter_height,
channels,
1,
1,
0,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_INPUT,
BRANCH_NOC,
{
0x06,
0,
0x00,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
2,
BRANCH_OUTPUT,
BRANCH_NOC,
{
0x08,
0,
0x00,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
3,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
2,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0x00,
0,
0x08,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
2,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
1,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0x00,
0,
0x0C,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
channels,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0x00,
0,
0x02,
},
{},
-1,
},
{
channels,
filter_width,
filter_height,
1,
1,
1,
0,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0,
0,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0,
},
} };
// Weights and biases need to be specified separately because
// of the offset.
// AssignLayerWeightsBiases is defined elsewhere in this file. It presumably
// points each layer at its slice of the flat `weights` / `bias` arrays.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
// NOTE(review): { 1, nullptr } presumably means one worker and no external
// thread pool. Confirm against CNN_THREAD_DATA in cnn.h.
CNN_THREAD_DATA thread_data = { 1, nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}