Add dual and quad intrinsic support for CDEF MSE computation

This CL adds aom_mse_16xh_16bit, which evaluates the 16-bit MSE for two
8x8 or four 4x4 blocks at a time. It also avoids some of the unpack and
32-bit register operations that can instead be handled within 16-bit
registers.
The overall encode time reduction for the RT preset is listed below:

 cpu   Testset   Encode time reduction (%)
  7      rtc                0.765
  8      rtc                0.803
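
Below is an illustrative sketch of the idea (the helper name and setup are
made up for illustration and are not code from this CL): the 16-bit absolute
differences are squared and pairwise-added with a single madd, accumulated in
32-bit lanes (sufficient for the small CDEF block sizes involved), and widened
to 64 bits only once after the loop, instead of unpacking every element to
32/64 bits inside the loop.

  #include <immintrin.h>
  #include <stdint.h>

  // Sum of squared differences of n 16-bit samples (n a multiple of 16).
  static uint64_t sse_16bit_sketch(const int16_t *a, const int16_t *b, int n) {
    __m256i acc32 = _mm256_setzero_si256();
    for (int i = 0; i < n; i += 16) {
      const __m256i va = _mm256_loadu_si256((const __m256i *)(a + i));
      const __m256i vb = _mm256_loadu_si256((const __m256i *)(b + i));
      const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(va, vb));
      // d0*d0 + d1*d1, ...: eight 32-bit partial sums, no 16->32 unpacks.
      acc32 = _mm256_add_epi32(acc32, _mm256_madd_epi16(diff, diff));
    }
    // Widen the eight 32-bit lanes to 64 bits and reduce once at the end.
    const __m256i zero = _mm256_setzero_si256();
    const __m256i acc64 =
        _mm256_add_epi64(_mm256_unpacklo_epi32(acc32, zero),
                         _mm256_unpackhi_epi32(acc32, zero));
    __m128i s = _mm_add_epi64(_mm256_castsi256_si128(acc64),
                              _mm256_extracti128_si256(acc64, 1));
    s = _mm_add_epi64(s, _mm_srli_si128(s, 8));
    return (uint64_t)_mm_cvtsi128_si64(s);
  }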

Change-Id: I385c7ba00764575a620d563f0f6fed330ff5096d
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index aeaf9f1..d7ab4d9 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1319,6 +1319,9 @@
   add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
   specialize qw/aom_mse_wxh_16bit  sse2 avx2/;
 
+  add_proto qw/uint64_t/, "aom_mse_16xh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int w, int h";
+  specialize qw/aom_mse_16xh_16bit avx2/;
+
   foreach (@encoder_block_sizes) {
     ($w, $h) = @$_;
     add_proto qw/unsigned int/, "aom_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
diff --git a/aom_dsp/variance.c b/aom_dsp/variance.c
index d764160..a37f732 100644
--- a/aom_dsp/variance.c
+++ b/aom_dsp/variance.c
@@ -1240,6 +1240,20 @@
   return sum;
 }
 
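+// Computes the mse of 16 / w consecutive blocks (two 8x8 or four 4x4). In
+// src, each block is stored contiguously (w * h elements back to back); in
+// dst, the blocks are horizontally adjacent in the frame at stride dstride.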
+uint64_t aom_mse_16xh_16bit_c(uint8_t *dst, int dstride, uint16_t *src, int w,
+                              int h) {
+  uint16_t *src_temp = src;
+  uint8_t *dst_temp = dst;
+  const int num_blks = 16 / w;
+  int64_t sum = 0;
+  for (int i = 0; i < num_blks; i++) {
+    sum += aom_mse_wxh_16bit_c(dst_temp, dstride, src_temp, w, w, h);
+    dst_temp += w;
+    src_temp += (w * h);
+  }
+  return sum;
+}
+
 uint64_t aom_mse_wxh_16bit_highbd_c(uint16_t *dst, int dstride, uint16_t *src,
                                     int sstride, int w, int h) {
   uint64_t sum = 0;
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index a7203ec..d5eb253 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -535,7 +535,7 @@
   __m128i dst0_4x8, dst1_4x8, dst2_4x8, dst3_4x8, dst_16x8;
   __m128i src0_4x16, src1_4x16, src2_4x16, src3_4x16;
   __m256i src0_8x16, src1_8x16, dst_16x16, src_16x16;
-  __m256i res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+  __m256i res0_4x64, res1_4x64;
   __m256i sub_result;
   const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
   __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
@@ -558,30 +558,121 @@
         _mm256_castsi128_si256(_mm_unpacklo_epi64(src2_4x16, src3_4x16));
     src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
 
+    // r15 r14 r13------------r1 r0  - 16 bit
     sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
 
-    src_16x16 = _mm256_unpacklo_epi16(sub_result, zeros);
-    dst_16x16 = _mm256_unpackhi_epi16(sub_result, zeros);
+    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
+    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
 
-    src_16x16 = _mm256_madd_epi16(src_16x16, src_16x16);  // 32bit store
-    dst_16x16 = _mm256_madd_epi16(dst_16x16, dst_16x16);  // 32bit store
-
-    res0_4x64 = _mm256_unpacklo_epi32(src_16x16, zeros);
-    res1_4x64 = _mm256_unpackhi_epi32(src_16x16, zeros);
-    res2_4x64 = _mm256_unpacklo_epi32(dst_16x16, zeros);
-    res3_4x64 = _mm256_unpackhi_epi32(dst_16x16, zeros);
-
-    square_result = _mm256_add_epi64(
-        square_result,
-        _mm256_add_epi64(
-            _mm256_add_epi64(_mm256_add_epi64(res0_4x64, res1_4x64), res2_4x64),
-            res3_4x64));
+    // accumulation of result
+    square_result = _mm256_add_epi32(square_result, src_16x16);
   }
-  const __m128i sum_2x64 =
-      _mm_add_epi64(_mm256_castsi256_si128(square_result),
-                    _mm256_extracti128_si256(square_result, 1));
-  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
-  xx_storel_64(&sum, sum_1x64);
+
+  // s5 s4 s1 s0  - 64bit
+  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+  // s7 s6 s3 s2 - 64bit
+  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+  // r3 r2 r1 r0 - 64bit
+  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+  // r1+r3 r2+r0 - 64bit
+  const __m128i sum_1x64 =
+      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+                    _mm256_extracti128_si256(res0_4x64, 1));
+  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
+  return sum;
+}
+
+// Computes the mse of four consecutive 4x4 blocks.
+// In the src buffer, each 4x4 block of a 32x32 filter block is stored
+// sequentially, so src_blk_stride equals the block area (w * h). The dst
+// buffer is a frame buffer, so dstride is a frame-level stride.
+uint64_t aom_mse_4xh_quad_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+                                     int src_blk_stride, int h) {
+  uint64_t sum = 0;
+  __m128i dst0_16x8, dst1_16x8, dst2_16x8, dst3_16x8;
+  __m256i dst0_16x16, dst1_16x16, dst2_16x16, dst3_16x16;
+  __m256i res0_4x64, res1_4x64;
+  __m256i sub_result_0, sub_result_1, sub_result_2, sub_result_3;
+  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
+  __m256i square_result = zeros;
+  uint16_t *src_temp = src;
+
+  for (int i = 0; i < h; i += 4) {
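+    // Each dst row holds 4 pixels from each of the four horizontally adjacent
+    // 4x4 blocks, i.e. 16 pixels in total.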
+    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
+    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
+    dst2_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 2) * dstride]));
+    dst3_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 3) * dstride]));
+
+    // row0 of 1st, 2nd, 3rd and 4th 4x4 blocks - d00 d10 d20 d30
+    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
+    // row1 of 1st, 2nd, 3rd and 4th 4x4 blocks - d01 d11 d21 d31
+    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
+    // row2 of 1st, 2nd, 3rd and 4th 4x4 blocks - d02 d12 d22 d32
+    dst2_16x16 = _mm256_cvtepu8_epi16(dst2_16x8);
+    // row3 of 1st, 2nd, 3rd and 4th 4x4 blocks - d03 d13 d23 d33
+    dst3_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);
+
+    // All rows of 1st 4x4 block - r00 r01 r02 r03
+    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
+    // All rows of 2nd 4x4 block - r10 r11 r12 r13
+    __m256i src1_16x16 =
+        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
+    // All rows of 3rd 4x4 block - r20 r21 r22 r23
+    __m256i src2_16x16 =
+        _mm256_loadu_si256((__m256i const *)(&src_temp[2 * src_blk_stride]));
+    // All rows of 4th 4x4 block - r30 r31 r32 r33
+    __m256i src3_16x16 =
+        _mm256_loadu_si256((__m256i const *)(&src_temp[3 * src_blk_stride]));
+
+    // r00 r10 r02 r12
+    __m256i tmp0_16x16 = _mm256_unpacklo_epi64(src0_16x16, src1_16x16);
+    // r01 r11 r03 r13
+    __m256i tmp1_16x16 = _mm256_unpackhi_epi64(src0_16x16, src1_16x16);
+    // r20 r30 r22 r32
+    __m256i tmp2_16x16 = _mm256_unpacklo_epi64(src2_16x16, src3_16x16);
+    // r21 r31 r23 r33
+    __m256i tmp3_16x16 = _mm256_unpackhi_epi64(src2_16x16, src3_16x16);
+
+    // r00 r10 r20 r30
+    src0_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x20);
+    // r01 r11 r21 r31
+    src1_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x20);
+    // r02 r12 r22 r32
+    src2_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x31);
+    // r03 r13 r23 r33
+    src3_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x31);
+
+    // r15 r14 r13------------r1 r0  - 16 bit
+    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(src0_16x16, dst0_16x16));
+    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(src1_16x16, dst1_16x16));
+    sub_result_2 = _mm256_abs_epi16(_mm256_sub_epi16(src2_16x16, dst2_16x16));
+    sub_result_3 = _mm256_abs_epi16(_mm256_sub_epi16(src3_16x16, dst3_16x16));
+
+    // s7 s6 s5 s4 s3 s2 s1 s0    - 32bit
+    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
+    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
+    src2_16x16 = _mm256_madd_epi16(sub_result_2, sub_result_2);
+    src3_16x16 = _mm256_madd_epi16(sub_result_3, sub_result_3);
+
+    // accumulation of result
+    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
+    src2_16x16 = _mm256_add_epi32(src2_16x16, src3_16x16);
+    const __m256i square_result_0 = _mm256_add_epi32(src0_16x16, src2_16x16);
+    square_result = _mm256_add_epi32(square_result, square_result_0);
+    src_temp += 16;
+  }
+
+  // s5 s4 s1 s0  - 64bit
+  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+  // s7  s6  s3  s2 - 64bit
+  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+  // r3 r2 r1 r0 - 64bit
+  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+  // r1+r3 r2+r0 - 64bit
+  const __m128i sum_1x64 =
+      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+                    _mm256_extracti128_si256(res0_4x64, 1));
+  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
   return sum;
 }
 
@@ -590,7 +681,7 @@
   uint64_t sum = 0;
   __m128i dst0_8x8, dst1_8x8, dst3_16x8;
   __m256i src0_8x16, src1_8x16, src_16x16, dst_16x16;
-  __m256i res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+  __m256i res0_4x64, res1_4x64;
   __m256i sub_result;
   const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
   __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
@@ -607,38 +698,98 @@
         _mm_loadu_si128((__m128i *)&src[(i + 1) * sstride]));
     src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
 
+    // r15 r14 r13 - - - r1 r0 - 16 bit
     sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
 
-    src_16x16 = _mm256_unpacklo_epi16(sub_result, zeros);
-    dst_16x16 = _mm256_unpackhi_epi16(sub_result, zeros);
+    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
+    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
 
-    src_16x16 = _mm256_madd_epi16(src_16x16, src_16x16);
-    dst_16x16 = _mm256_madd_epi16(dst_16x16, dst_16x16);
-
-    res0_4x64 = _mm256_unpacklo_epi32(src_16x16, zeros);
-    res1_4x64 = _mm256_unpackhi_epi32(src_16x16, zeros);
-    res2_4x64 = _mm256_unpacklo_epi32(dst_16x16, zeros);
-    res3_4x64 = _mm256_unpackhi_epi32(dst_16x16, zeros);
-
-    square_result = _mm256_add_epi64(
-        square_result,
-        _mm256_add_epi64(
-            _mm256_add_epi64(_mm256_add_epi64(res0_4x64, res1_4x64), res2_4x64),
-            res3_4x64));
+    // accumulation of result
+    square_result = _mm256_add_epi32(square_result, src_16x16);
   }
 
-  const __m128i sum_2x64 =
-      _mm_add_epi64(_mm256_castsi256_si128(square_result),
-                    _mm256_extracti128_si256(square_result, 1));
-  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
-  xx_storel_64(&sum, sum_1x64);
+  // s5 s4 s1 s0  - 64bit
+  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+  // s7 s6 s3 s2 - 64bit
+  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+  // r3 r2 r1 r0 - 64bit
+  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+  // r1+r3 r2+r0 - 64bit
+  const __m128i sum_1x64 =
+      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+                    _mm256_extracti128_si256(res0_4x64, 1));
+  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
+  return sum;
+}
+
+// Computes the mse of two consecutive 8x8 blocks.
+// In the src buffer, each 8x8 block of a 64x64 filter block is stored
+// sequentially, so src_blk_stride equals the block area (w * h). The dst
+// buffer is a frame buffer, so dstride is a frame-level stride.
+uint64_t aom_mse_8xh_dual_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+                                     int src_blk_stride, int h) {
+  uint64_t sum = 0;
+  __m128i dst0_16x8, dst1_16x8;
+  __m256i dst0_16x16, dst1_16x16;
+  __m256i res0_4x64, res1_4x64;
+  __m256i sub_result_0, sub_result_1;
+  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
+  __m256i square_result = zeros;
+  uint16_t *src_temp = src;
+
+  for (int i = 0; i < h; i += 2) {
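+    // Each dst row holds 8 pixels from each of the two horizontally adjacent
+    // 8x8 blocks, i.e. 16 pixels in total.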
+    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
+    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
+
+    // row0 of 1st and 2nd 8x8 block - d00 d10
+    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
+    // row1 of 1st and 2nd 8x8 block - d01 d11
+    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
+
+    // 2 rows of 1st 8x8 block - r00 r01
+    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
+    // 2 rows of 2nd 8x8 block - r10 r11
+    __m256i src1_16x16 =
+        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
+    // r00 r10 - 128bit
+    __m256i tmp0_16x16 =
+        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x20);
+    // r01 r11 - 128bit
+    __m256i tmp1_16x16 =
+        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x31);
+
+    // r15 r14 r13------------r1 r0 - 16 bit
+    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(tmp0_16x16, dst0_16x16));
+    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(tmp1_16x16, dst1_16x16));
+
+    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit each
+    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
+    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
+
+    // accumulation of result
+    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
+    square_result = _mm256_add_epi32(square_result, src0_16x16);
+    src_temp += 16;
+  }
+
+  // s5 s4 s1 s0  - 64bit
+  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
+  // s7 s6 s3 s2 - 64bit
+  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
+  // r3 r2 r1 r0 - 64bit
+  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
+  // r1+r3 r2+r0 - 64bit
+  const __m128i sum_1x64 =
+      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
+                    _mm256_extracti128_si256(res0_4x64, 1));
+  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
   return sum;
 }
 
 uint64_t aom_mse_wxh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
                                 int sstride, int w, int h) {
   assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
-         "w=8/4 and h=8/4 must satisfy");
+         "w=8/4 and h=8/4 must be satisfied");
   switch (w) {
     case 4: return aom_mse_4xh_16bit_avx2(dst, dstride, src, sstride, h);
     case 8: return aom_mse_8xh_16bit_avx2(dst, dstride, src, sstride, h);
@@ -646,6 +797,21 @@
   }
 }
 
+// Computes the mse of two 8x8 or four 4x4 consecutive blocks. The luma plane
+// uses 8x8 blocks and the chroma planes use 4x4 blocks. In the src buffer,
+// each block of a filter block is stored sequentially, so the block stride
+// equals the block area (w * h). The dst buffer is a frame buffer, so dstride
+// is a frame-level stride.
+uint64_t aom_mse_16xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
+                                 int w, int h) {
+  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
+         "w=8/4 and h=8/4 must be satisfied");
+  switch (w) {
+    case 4: return aom_mse_4xh_quad_16bit_avx2(dst, dstride, src, w * h, h);
+    case 8: return aom_mse_8xh_dual_16bit_avx2(dst, dstride, src, w * h, h);
+    default: assert(0 && "unsupported width"); return -1;
+  }
+}
+
 static INLINE void sum_final_256bit_avx2(__m256i sum_8x16[2], int *const sum) {
   const __m256i sum_result_0 = _mm256_hadd_epi16(sum_8x16[0], sum_8x16[1]);
   const __m256i sum_result_1 =
diff --git a/av1/encoder/pickcdef.c b/av1/encoder/pickcdef.c
index 3659650..ece0ba3 100644
--- a/av1/encoder/pickcdef.c
+++ b/av1/encoder/pickcdef.c
@@ -267,6 +267,22 @@
   return sum >> 2 * coeff_shift;
 }
 #endif
+
+// Checks whether dual (block width 8) or quad (block width 4) block
+// processing is applicable.
+static INLINE int is_dual_or_quad_applicable(cdef_list *dlist, int width,
+                                             int cdef_count, int bi, int iter) {
+  assert(width == 8 || width == 4);
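+  // For width 8, pair with the next block (dual); for width 4, group the next
+  // three blocks (quad).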
+  const int blk_offset = (width == 8) ? 1 : 3;
+  if ((iter + blk_offset) >= cdef_count) return 0;
+
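+  // The first and last blocks of the candidate group must be on the same row
+  // and blk_offset columns apart, which makes the group horizontally
+  // contiguous since dlist is in raster order.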
+  if (dlist[bi].by == dlist[bi + blk_offset].by &&
+      dlist[bi].bx + blk_offset == dlist[bi + blk_offset].bx)
+    return 1;
+
+  return 0;
+}
+
 static uint64_t compute_cdef_dist(void *dst, int dstride, uint16_t *src,
                                   cdef_list *dlist, int cdef_count,
                                   BLOCK_SIZE bsize, int coeff_shift, int row,
@@ -275,18 +291,34 @@
          bsize == BLOCK_8X8);
   uint64_t sum = 0;
   int bi, bx, by;
+  int iter = 0;
+  int inc = 1;
   uint8_t *dst8 = (uint8_t *)dst;
   uint8_t *dst_buff = &dst8[row * dstride + col];
   int src_stride, width, height, width_log2, height_log2;
   init_src_params(&src_stride, &width, &height, &width_log2, &height_log2,
                   bsize);
-  for (bi = 0; bi < cdef_count; bi++) {
+
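+  // Number of blocks that can be processed together: 2 for 8x8, 4 for 4x4.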
+  const int num_blks = 16 / width;
+  for (bi = 0; bi < cdef_count; bi += inc) {
     by = dlist[bi].by;
     bx = dlist[bi].bx;
-    sum += aom_mse_wxh_16bit(
-        &dst_buff[(by << height_log2) * dstride + (bx << width_log2)], dstride,
-        &src[bi << (height_log2 + width_log2)], src_stride, width, height);
+    uint16_t *src_tmp = &src[bi << (height_log2 + width_log2)];
+    uint8_t *dst_tmp =
+        &dst_buff[(by << height_log2) * dstride + (bx << width_log2)];
+
+    if (is_dual_or_quad_applicable(dlist, width, cdef_count, bi, iter)) {
+      sum += aom_mse_16xh_16bit(dst_tmp, dstride, src_tmp, width, height);
+      iter += num_blks;
+      inc = num_blks;
+    } else {
+      sum += aom_mse_wxh_16bit(dst_tmp, dstride, src_tmp, src_stride, width,
+                               height);
+      iter += 1;
+      inc = 1;
+    }
   }
+
   return sum >> 2 * coeff_shift;
 }
 
diff --git a/test/variance_test.cc b/test/variance_test.cc
index 8e6abf6..6c0180f 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -26,11 +26,14 @@
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/aom_timer.h"
 #include "aom_ports/mem.h"
+#include "av1/common/cdef_block.h"
 
 namespace {
 
 typedef uint64_t (*MseWxH16bitFunc)(uint8_t *dst, int dstride, uint16_t *src,
                                     int sstride, int w, int h);
+typedef uint64_t (*Mse16xH16bitFunc)(uint8_t *dst, int dstride, uint16_t *src,
+                                     int w, int h);
 typedef unsigned int (*VarianceMxNFunc)(const uint8_t *a, int a_stride,
                                         const uint8_t *b, int b_stride,
                                         unsigned int *sse);
@@ -513,6 +516,139 @@
   }
 }
 
+template <typename FunctionType>
+class Mse16xHTestClass
+    : public ::testing::TestWithParam<TestParams<FunctionType> > {
+ public:
+  // Buffer size needed to compute the mse of two 8x8 or four 4x4 blocks,
+  // i.e. a maximum width of 16 and a maximum height of 8.
+  int mem_size = 16 * 8;
+  virtual void SetUp() {
+    params_ = this->GetParam();
+    rnd_.Reset(ACMRandom::DeterministicSeed());
+    src_ = reinterpret_cast<uint16_t *>(
+        aom_memalign(16, mem_size * sizeof(*src_)));
+    dst_ =
+        reinterpret_cast<uint8_t *>(aom_memalign(16, mem_size * sizeof(*dst_)));
+    ASSERT_NE(src_, nullptr);
+    ASSERT_NE(dst_, nullptr);
+  }
+
+  virtual void TearDown() {
+    aom_free(src_);
+    aom_free(dst_);
+    src_ = nullptr;
+    dst_ = nullptr;
+  }
+
+  uint8_t RandBool() {
+    const uint32_t value = rnd_.Rand8();
+    return (value & 0x1);
+  }
+
+ protected:
+  void RefMatchExtremeTestMse();
+  void RefMatchTestMse();
+  void SpeedTest();
+
+ protected:
+  ACMRandom rnd_;
+  uint8_t *dst_;
+  uint16_t *src_;
+  TestParams<FunctionType> params_;
+
+  // some relay helpers
+  int width() const { return params_.width; }
+  int height() const { return params_.height; }
+  int d_stride() const { return params_.width; }
+};
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::SpeedTest() {
+  aom_usec_timer ref_timer, test_timer;
+  double elapsed_time_c = 0.0;
+  double elapsed_time_simd = 0.0;
+  const int loop_count = 10000000;
+  const int w = width();
+  const int h = height();
+  const int dstride = d_stride();
+
+  for (int k = 0; k < mem_size; ++k) {
+    dst_[k] = rnd_.Rand8();
+    // Right shift by 6 to increase the fraction of inputs that fall in the
+    // [0, 255] range rather than being replaced by CDEF_VERY_LARGE.
+    int rnd_i10 = rnd_.Rand16() >> 6;
+    src_[k] = (rnd_i10 < 256) ? rnd_i10 : CDEF_VERY_LARGE;
+  }
+
+  aom_usec_timer_start(&ref_timer);
+  for (int i = 0; i < loop_count; i++) {
+    aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h);
+  }
+  aom_usec_timer_mark(&ref_timer);
+  elapsed_time_c = static_cast<double>(aom_usec_timer_elapsed(&ref_timer));
+
+  aom_usec_timer_start(&test_timer);
+  for (int i = 0; i < loop_count; i++) {
+    params_.func(dst_, dstride, src_, w, h);
+  }
+  aom_usec_timer_mark(&test_timer);
+  elapsed_time_simd = static_cast<double>(aom_usec_timer_elapsed(&test_timer));
+
+  printf("%dx%d\tc_time=%lf \t simd_time=%lf \t gain=%.3f\n", width(),
+         height(), elapsed_time_c, elapsed_time_simd,
+         (elapsed_time_c / elapsed_time_simd));
+}
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::RefMatchTestMse() {
+  uint64_t mse_ref = 0;
+  uint64_t mse_mod = 0;
+  const int w = width();
+  const int h = height();
+  const int dstride = d_stride();
+
+  for (int i = 0; i < 10; i++) {
+    for (int k = 0; k < mem_size; ++k) {
+      dst_[k] = rnd_.Rand8();
+      // Right shift by 6 to increase the fraction of inputs that fall in the
+      // [0, 255] range rather than being replaced by CDEF_VERY_LARGE.
+      int rnd_i10 = rnd_.Rand16() >> 6;
+      src_[k] = (rnd_i10 < 256) ? rnd_i10 : CDEF_VERY_LARGE;
+    }
+
+    API_REGISTER_STATE_CHECK(
+        mse_ref = aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h));
+    API_REGISTER_STATE_CHECK(mse_mod = params_.func(dst_, dstride, src_, w, h));
+    EXPECT_EQ(mse_ref, mse_mod)
+        << "ref mse: " << mse_ref << " mod mse: " << mse_mod;
+  }
+}
+
+template <typename Mse16xHFunctionType>
+void Mse16xHTestClass<Mse16xHFunctionType>::RefMatchExtremeTestMse() {
+  uint64_t mse_ref = 0;
+  uint64_t mse_mod = 0;
+  const int w = width();
+  const int h = height();
+  const int dstride = d_stride();
+  const int iter = 10;
+
+  // Fill the buffers with extreme values
+  for (int i = 0; i < iter; i++) {
+    for (int k = 0; k < mem_size; ++k) {
+      dst_[k] = static_cast<uint8_t>(RandBool() ? 0 : 255);
+      src_[k] = static_cast<uint16_t>(RandBool() ? 0 : CDEF_VERY_LARGE);
+    }
+
+    API_REGISTER_STATE_CHECK(
+        mse_ref = aom_mse_16xh_16bit_c(dst_, dstride, src_, w, h));
+    API_REGISTER_STATE_CHECK(mse_mod = params_.func(dst_, dstride, src_, w, h));
+    EXPECT_EQ(mse_ref, mse_mod)
+        << "ref mse: " << mse_ref << " mod mse: " << mse_mod;
+  }
+}
+
 // Main class for testing a function type
 template <typename FunctionType>
 class MainTestClass
@@ -1327,6 +1463,7 @@
 #endif  // !CONFIG_REALTIME_ONLY
 
 typedef MseWxHTestClass<MseWxH16bitFunc> MseWxHTest;
+typedef Mse16xHTestClass<Mse16xH16bitFunc> Mse16xHTest;
 typedef MainTestClass<Get4x4SseFunc> AvxSseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxMseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxVarianceTest;
@@ -1339,11 +1476,15 @@
 typedef ObmcVarianceTest<ObmcSubpelVarFunc> AvxObmcSubpelVarianceTest;
 #endif
 typedef TestParams<MseWxH16bitFunc> MseWxHParams;
+typedef TestParams<Mse16xH16bitFunc> Mse16xHParams;
 
 TEST_P(AvxSseTest, RefSse) { RefTestSse(); }
 TEST_P(AvxSseTest, MaxSse) { MaxTestSse(); }
 TEST_P(MseWxHTest, RefMse) { RefMatchTestMse(); }
 TEST_P(MseWxHTest, DISABLED_SpeedMse) { SpeedTest(); }
+TEST_P(Mse16xHTest, RefMse) { RefMatchTestMse(); }
+TEST_P(Mse16xHTest, RefMseExtreme) { RefMatchExtremeTestMse(); }
+TEST_P(Mse16xHTest, DISABLED_SpeedMse) { SpeedTest(); }
 TEST_P(AvxMseTest, RefMse) { RefTestMse(); }
 TEST_P(AvxMseTest, MaxMse) { MaxTestMse(); }
 TEST_P(AvxVarianceTest, Zero) { ZeroTest(); }
@@ -1375,6 +1516,13 @@
                       MseWxHParams(2, 3, &aom_mse_wxh_16bit_c, 8),
                       MseWxHParams(2, 2, &aom_mse_wxh_16bit_c, 8)));
 
+INSTANTIATE_TEST_SUITE_P(
+    C, Mse16xHTest,
+    ::testing::Values(Mse16xHParams(3, 3, &aom_mse_16xh_16bit_c, 8),
+                      Mse16xHParams(3, 2, &aom_mse_16xh_16bit_c, 8),
+                      Mse16xHParams(2, 3, &aom_mse_16xh_16bit_c, 8),
+                      Mse16xHParams(2, 2, &aom_mse_16xh_16bit_c, 8)));
+
 INSTANTIATE_TEST_SUITE_P(C, SumOfSquaresTest,
                          ::testing::Values(aom_get_mb_ss_c));
 
@@ -2740,6 +2888,13 @@
                       MseWxHParams(2, 3, &aom_mse_wxh_16bit_avx2, 8),
                       MseWxHParams(2, 2, &aom_mse_wxh_16bit_avx2, 8)));
 
+INSTANTIATE_TEST_SUITE_P(
+    AVX2, Mse16xHTest,
+    ::testing::Values(Mse16xHParams(3, 3, &aom_mse_16xh_16bit_avx2, 8),
+                      Mse16xHParams(3, 2, &aom_mse_16xh_16bit_avx2, 8),
+                      Mse16xHParams(2, 3, &aom_mse_16xh_16bit_avx2, 8),
+                      Mse16xHParams(2, 2, &aom_mse_16xh_16bit_avx2, 8)));
+
 INSTANTIATE_TEST_SUITE_P(AVX2, AvxMseTest,
                          ::testing::Values(MseParams(4, 4,
                                                      &aom_mse16x16_avx2)));