JNT_COMP: unit tests for aom_jnt_sub_pixel_avg_variance

Both c function and ssse3 have passed unit tests.

Change-Id: I48cff97ebf2735b43256b83f3b41ce7ccdf27393
diff --git a/test/variance_test.cc b/test/variance_test.cc
index 2f5c222..9e1240b 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -41,6 +41,12 @@
 typedef unsigned int (*Get4x4SseFunc)(const uint8_t *a, int a_stride,
                                       const uint8_t *b, int b_stride);
 typedef unsigned int (*SumOfSquaresFunction)(const int16_t *src);
+#if CONFIG_JNT_COMP
+typedef unsigned int (*JntSubpixAvgVarMxNFunc)(
+    const uint8_t *a, int a_stride, int xoffset, int yoffset, const uint8_t *b,
+    int b_stride, uint32_t *sse, const uint8_t *second_pred,
+    const JNT_COMP_PARAMS *jcp_param);
+#endif
 
 using libaom_test::ACMRandom;
 
@@ -212,6 +218,68 @@
   return static_cast<uint32_t>(sse - ((se * se) >> (l2w + l2h)));
 }
 
+#if CONFIG_JNT_COMP
+static uint32_t jnt_subpel_avg_variance_ref(
+    const uint8_t *ref, const uint8_t *src, const uint8_t *second_pred, int l2w,
+    int l2h, int xoff, int yoff, uint32_t *sse_ptr, bool use_high_bit_depth,
+    aom_bit_depth_t bit_depth, JNT_COMP_PARAMS *jcp_param) {
+  int64_t se = 0;
+  uint64_t sse = 0;
+  const int w = 1 << l2w;
+  const int h = 1 << l2h;
+
+  xoff <<= 1;
+  yoff <<= 1;
+
+  for (int y = 0; y < h; y++) {
+    for (int x = 0; x < w; x++) {
+      // bilinear interpolation at a 16th pel step
+      if (!use_high_bit_depth) {
+        const int a1 = ref[(w + 1) * (y + 0) + x + 0];
+        const int a2 = ref[(w + 1) * (y + 0) + x + 1];
+        const int b1 = ref[(w + 1) * (y + 1) + x + 0];
+        const int b2 = ref[(w + 1) * (y + 1) + x + 1];
+        const int a = a1 + (((a2 - a1) * xoff + 8) >> 4);
+        const int b = b1 + (((b2 - b1) * xoff + 8) >> 4);
+        const int r = a + (((b - a) * yoff + 8) >> 4);
+        const int avg = ROUND_POWER_OF_TWO(
+            r * jcp_param->fwd_offset +
+                second_pred[w * y + x] * jcp_param->bck_offset,
+            DIST_PRECISION_BITS);
+        const int diff = avg - src[w * y + x];
+
+        se += diff;
+        sse += diff * diff;
+#if CONFIG_HIGHBITDEPTH
+      } else {
+        const uint16_t *ref16 = CONVERT_TO_SHORTPTR(ref);
+        const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
+        const uint16_t *sec16 = CONVERT_TO_SHORTPTR(second_pred);
+        const int a1 = ref16[(w + 1) * (y + 0) + x + 0];
+        const int a2 = ref16[(w + 1) * (y + 0) + x + 1];
+        const int b1 = ref16[(w + 1) * (y + 1) + x + 0];
+        const int b2 = ref16[(w + 1) * (y + 1) + x + 1];
+        const int a = a1 + (((a2 - a1) * xoff + 8) >> 4);
+        const int b = b1 + (((b2 - b1) * xoff + 8) >> 4);
+        const int r = a + (((b - a) * yoff + 8) >> 4);
+        const int avg =
+            ROUND_POWER_OF_TWO(r * jcp_param->fwd_offset +
+                                   sec16[w * y + x] * jcp_param->bck_offset,
+                               DIST_PRECISION_BITS);
+        const int diff = avg - src16[w * y + x];
+
+        se += diff;
+        sse += diff * diff;
+#endif  // CONFIG_HIGHBITDEPTH
+      }
+    }
+  }
+  RoundHighBitDepth(bit_depth, &se, &sse);
+  *sse_ptr = static_cast<uint32_t>(sse);
+  return static_cast<uint32_t>(sse - ((se * se) >> (l2w + l2h)));
+}
+#endif  // CONFIG_JNT_COMP
+
 ////////////////////////////////////////////////////////////////////////////////
 
 class SumOfSquaresTest : public ::testing::TestWithParam<SumOfSquaresFunction> {
@@ -582,6 +650,9 @@
   uint8_t *ref_;
   uint8_t *sec_;
   TestParams<FunctionType> params_;
+#if CONFIG_JNT_COMP
+  JNT_COMP_PARAMS jcp_param_;
+#endif
 
   // some relay helpers
   bool use_high_bit_depth() const { return params_.use_high_bit_depth; }
@@ -697,11 +768,59 @@
   }
 }
 
+#if CONFIG_JNT_COMP
+template <>
+void SubpelVarianceTest<JntSubpixAvgVarMxNFunc>::RefTest() {
+  for (int x = 0; x < 8; ++x) {
+    for (int y = 0; y < 8; ++y) {
+      if (!use_high_bit_depth()) {
+        for (int j = 0; j < block_size(); j++) {
+          src_[j] = rnd_.Rand8();
+          sec_[j] = rnd_.Rand8();
+        }
+        for (int j = 0; j < block_size() + width() + height() + 1; j++) {
+          ref_[j] = rnd_.Rand8();
+        }
+#if CONFIG_HIGHBITDEPTH
+      } else {
+        for (int j = 0; j < block_size(); j++) {
+          CONVERT_TO_SHORTPTR(src_)[j] = rnd_.Rand16() & mask();
+          CONVERT_TO_SHORTPTR(sec_)[j] = rnd_.Rand16() & mask();
+        }
+        for (int j = 0; j < block_size() + width() + height() + 1; j++) {
+          CONVERT_TO_SHORTPTR(ref_)[j] = rnd_.Rand16() & mask();
+        }
+#endif  // CONFIG_HIGHBITDEPTH
+      }
+      for (int x0 = 0; x0 < 2; ++x0) {
+        for (int y0 = 0; y0 < 4; ++y0) {
+          uint32_t sse1, sse2;
+          uint32_t var1, var2;
+          jcp_param_.fwd_offset = quant_dist_lookup_table[x0][y0][0];
+          jcp_param_.bck_offset = quant_dist_lookup_table[x0][y0][1];
+          ASM_REGISTER_STATE_CHECK(var1 = params_.func(ref_, width() + 1, x, y,
+                                                       src_, width(), &sse1,
+                                                       sec_, &jcp_param_));
+          var2 = jnt_subpel_avg_variance_ref(
+              ref_, src_, sec_, params_.log2width, params_.log2height, x, y,
+              &sse2, use_high_bit_depth(), params_.bit_depth, &jcp_param_);
+          EXPECT_EQ(sse1, sse2) << "at position " << x << ", " << y;
+          EXPECT_EQ(var1, var2) << "at position " << x << ", " << y;
+        }
+      }
+    }
+  }
+}
+#endif  // CONFIF_JNT_COMP
+
 typedef MainTestClass<Get4x4SseFunc> AvxSseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxMseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxVarianceTest;
 typedef SubpelVarianceTest<SubpixVarMxNFunc> AvxSubpelVarianceTest;
 typedef SubpelVarianceTest<SubpixAvgVarMxNFunc> AvxSubpelAvgVarianceTest;
+#if CONFIG_JNT_COMP
+typedef SubpelVarianceTest<JntSubpixAvgVarMxNFunc> AvxJntSubpelAvgVarianceTest;
+#endif
 
 TEST_P(AvxSseTest, RefSse) { RefTestSse(); }
 TEST_P(AvxSseTest, MaxSse) { MaxTestSse(); }
@@ -716,6 +835,9 @@
 TEST_P(AvxSubpelVarianceTest, Ref) { RefTest(); }
 TEST_P(AvxSubpelVarianceTest, ExtremeRef) { ExtremeRefTest(); }
 TEST_P(AvxSubpelAvgVarianceTest, Ref) { RefTest(); }
+#if CONFIG_JNT_COMP
+TEST_P(AvxJntSubpelAvgVarianceTest, Ref) { RefTest(); }
+#endif
 
 INSTANTIATE_TEST_CASE_P(C, SumOfSquaresTest,
                         ::testing::Values(aom_get_mb_ss_c));
@@ -785,6 +907,39 @@
         SubpelAvgVarianceParams(2, 3, &aom_sub_pixel_avg_variance4x8_c, 0),
         SubpelAvgVarianceParams(2, 2, &aom_sub_pixel_avg_variance4x4_c, 0)));
 
+#if CONFIG_JNT_COMP
+typedef TestParams<JntSubpixAvgVarMxNFunc> JntSubpelAvgVarianceParams;
+INSTANTIATE_TEST_CASE_P(
+    C, AvxJntSubpelAvgVarianceTest,
+    ::testing::Values(
+        JntSubpelAvgVarianceParams(6, 6, &aom_jnt_sub_pixel_avg_variance64x64_c,
+                                   0),
+        JntSubpelAvgVarianceParams(6, 5, &aom_jnt_sub_pixel_avg_variance64x32_c,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 6, &aom_jnt_sub_pixel_avg_variance32x64_c,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 5, &aom_jnt_sub_pixel_avg_variance32x32_c,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 4, &aom_jnt_sub_pixel_avg_variance32x16_c,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 5, &aom_jnt_sub_pixel_avg_variance16x32_c,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 4, &aom_jnt_sub_pixel_avg_variance16x16_c,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 3, &aom_jnt_sub_pixel_avg_variance16x8_c,
+                                   0),
+        JntSubpelAvgVarianceParams(3, 4, &aom_jnt_sub_pixel_avg_variance8x16_c,
+                                   0),
+        JntSubpelAvgVarianceParams(3, 3, &aom_jnt_sub_pixel_avg_variance8x8_c,
+                                   0),
+        JntSubpelAvgVarianceParams(3, 2, &aom_jnt_sub_pixel_avg_variance8x4_c,
+                                   0),
+        JntSubpelAvgVarianceParams(2, 3, &aom_jnt_sub_pixel_avg_variance4x8_c,
+                                   0),
+        JntSubpelAvgVarianceParams(2, 2, &aom_jnt_sub_pixel_avg_variance4x4_c,
+                                   0)));
+#endif  // CONFIG_JNT_COMP
+
 #if CONFIG_HIGHBITDEPTH
 typedef MainTestClass<VarianceMxNFunc> AvxHBDMseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxHBDVarianceTest;
@@ -1323,6 +1478,48 @@
         SubpelAvgVarianceParams(2, 3, &aom_sub_pixel_avg_variance4x8_ssse3, 0),
         SubpelAvgVarianceParams(2, 2, &aom_sub_pixel_avg_variance4x4_ssse3,
                                 0)));
+
+#if CONFIG_JNT_COMP
+INSTANTIATE_TEST_CASE_P(
+    SSSE3, AvxJntSubpelAvgVarianceTest,
+    ::testing::Values(
+        JntSubpelAvgVarianceParams(6, 6,
+                                   &aom_jnt_sub_pixel_avg_variance64x64_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(6, 5,
+                                   &aom_jnt_sub_pixel_avg_variance64x32_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 6,
+                                   &aom_jnt_sub_pixel_avg_variance32x64_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 5,
+                                   &aom_jnt_sub_pixel_avg_variance32x32_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(5, 4,
+                                   &aom_jnt_sub_pixel_avg_variance32x16_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 5,
+                                   &aom_jnt_sub_pixel_avg_variance16x32_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 4,
+                                   &aom_jnt_sub_pixel_avg_variance16x16_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(4, 3,
+                                   &aom_jnt_sub_pixel_avg_variance16x8_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(3, 4,
+                                   &aom_jnt_sub_pixel_avg_variance8x16_ssse3,
+                                   0),
+        JntSubpelAvgVarianceParams(3, 3,
+                                   &aom_jnt_sub_pixel_avg_variance8x8_ssse3, 0),
+        JntSubpelAvgVarianceParams(3, 2,
+                                   &aom_jnt_sub_pixel_avg_variance8x4_ssse3, 0),
+        JntSubpelAvgVarianceParams(2, 3,
+                                   &aom_jnt_sub_pixel_avg_variance4x8_ssse3, 0),
+        JntSubpelAvgVarianceParams(2, 2,
+                                   &aom_jnt_sub_pixel_avg_variance4x4_ssse3,
+                                   0)));
+#endif  // CONFIG_JNT_COMP
 #endif  // HAVE_SSSE3
 
 #if HAVE_AVX2