Comp_mask_upsampled_pred unit tests for widths>32

Change-Id: Ide765c68208a9195ec8d4cc34b055fcb38d00657
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index ad155b2..449a00f 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -363,21 +363,6 @@
   }
 }
 
-static AOM_INLINE void diffwtd_mask(uint8_t *mask, int which_inverse,
-                                    int mask_base, const uint8_t *src0,
-                                    int src0_stride, const uint8_t *src1,
-                                    int src1_stride, int h, int w) {
-  int i, j, m, diff;
-  for (i = 0; i < h; ++i) {
-    for (j = 0; j < w; ++j) {
-      diff =
-          abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
-      m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
-      mask[i * w + j] = which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
-    }
-  }
-}
-
 void av1_build_compound_diffwtd_mask_c(uint8_t *mask,
                                        DIFFWTD_MASK_TYPE mask_type,
                                        const uint8_t *src0, int src0_stride,
@@ -385,90 +370,29 @@
                                        int h, int w) {
   switch (mask_type) {
     case DIFFWTD_38:
-      diffwtd_mask(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w);
+      av1_diffwtd_mask(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w);
       break;
     case DIFFWTD_38_INV:
-      diffwtd_mask(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w);
+      av1_diffwtd_mask(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w);
       break;
     default: assert(0);
   }
 }
 
-static AOM_FORCE_INLINE void diffwtd_mask_highbd(
-    uint8_t *mask, int which_inverse, int mask_base, const uint16_t *src0,
-    int src0_stride, const uint16_t *src1, int src1_stride, int h, int w,
-    const unsigned int bd) {
-  assert(bd >= 8);
-  if (bd == 8) {
-    if (which_inverse) {
-      for (int i = 0; i < h; ++i) {
-        for (int j = 0; j < w; ++j) {
-          int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
-          unsigned int m = negative_to_zero(mask_base + diff);
-          m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
-          mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
-        }
-        src0 += src0_stride;
-        src1 += src1_stride;
-        mask += w;
-      }
-    } else {
-      for (int i = 0; i < h; ++i) {
-        for (int j = 0; j < w; ++j) {
-          int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
-          unsigned int m = negative_to_zero(mask_base + diff);
-          m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
-          mask[j] = m;
-        }
-        src0 += src0_stride;
-        src1 += src1_stride;
-        mask += w;
-      }
-    }
-  } else {
-    const unsigned int bd_shift = bd - 8;
-    if (which_inverse) {
-      for (int i = 0; i < h; ++i) {
-        for (int j = 0; j < w; ++j) {
-          int diff =
-              (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
-          unsigned int m = negative_to_zero(mask_base + diff);
-          m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
-          mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
-        }
-        src0 += src0_stride;
-        src1 += src1_stride;
-        mask += w;
-      }
-    } else {
-      for (int i = 0; i < h; ++i) {
-        for (int j = 0; j < w; ++j) {
-          int diff =
-              (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
-          unsigned int m = negative_to_zero(mask_base + diff);
-          m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
-          mask[j] = m;
-        }
-        src0 += src0_stride;
-        src1 += src1_stride;
-        mask += w;
-      }
-    }
-  }
-}
-
 void av1_build_compound_diffwtd_mask_highbd_c(
     uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
     int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
     int bd) {
   switch (mask_type) {
     case DIFFWTD_38:
-      diffwtd_mask_highbd(mask, 0, 38, CONVERT_TO_SHORTPTR(src0), src0_stride,
-                          CONVERT_TO_SHORTPTR(src1), src1_stride, h, w, bd);
+      av1_diffwtd_mask_highbd(mask, 0, 38, CONVERT_TO_SHORTPTR(src0),
+                              src0_stride, CONVERT_TO_SHORTPTR(src1),
+                              src1_stride, h, w, bd);
       break;
     case DIFFWTD_38_INV:
-      diffwtd_mask_highbd(mask, 1, 38, CONVERT_TO_SHORTPTR(src0), src0_stride,
-                          CONVERT_TO_SHORTPTR(src1), src1_stride, h, w, bd);
+      av1_diffwtd_mask_highbd(mask, 1, 38, CONVERT_TO_SHORTPTR(src0),
+                              src0_stride, CONVERT_TO_SHORTPTR(src1),
+                              src1_stride, h, w, bd);
       break;
     default: assert(0);
   }
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index c869616..46d7d2f 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -17,6 +17,7 @@
 #include "av1/common/filter.h"
 #include "av1/common/warped_motion.h"
 #include "aom/aom_integer.h"
+#include "aom_dsp/blend.h"
 
 // Work out how many pixels off the edge of a reference frame we're allowed
 // to go when forming an inter prediction.
@@ -361,6 +362,86 @@
 
 void av1_init_wedge_masks();
 
+// Per-pixel blend mask from |src0 - src1|; mask rows are packed w bytes apart.
+static INLINE void av1_diffwtd_mask(uint8_t *mask, int which_inverse,
+                                    int mask_base, const uint8_t *src0,
+                                    int src0_stride, const uint8_t *src1,
+                                    int src1_stride, int h, int w) {
+  for (int row = 0; row < h; ++row) {
+    for (int col = 0; col < w; ++col) {
+      const int diff = abs((int)src0[row * src0_stride + col] -
+                           (int)src1[row * src1_stride + col]);
+      int m = clamp(mask_base + diff / DIFF_FACTOR, 0, AOM_BLEND_A64_MAX_ALPHA);
+      mask[row * w + col] = which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+    }
+  }
+}
+
+// High bit-depth counterpart of av1_diffwtd_mask(): src0 and src1 are read as
+// 16-bit samples of bit depth bd. Each |src0 - src1| is shifted right by
+// (bd - 8), divided by DIFF_FACTOR and added to mask_base; the sum goes
+// through negative_to_zero() and AOMMIN(., AOM_BLEND_A64_MAX_ALPHA). When
+// which_inverse is set, AOM_BLEND_A64_MAX_ALPHA minus that value is stored.
+// The mask is written as h rows of w bytes with no padding between rows.
+// One loop nest exists per (bd == 8, which_inverse) combination, so the
+// per-pixel loops contain no mode checks.
+static INLINE void av1_diffwtd_mask_highbd(uint8_t *mask, int which_inverse,
+                                           int mask_base, const uint16_t *src0,
+                                           int src0_stride,
+                                           const uint16_t *src1,
+                                           int src1_stride, int h, int w,
+                                           const unsigned int bd) {
+  assert(bd >= 8);
+  if (bd == 8) {
+    if (which_inverse) {
+      for (int row = 0; row < h;
+           ++row, src0 += src0_stride, src1 += src1_stride, mask += w) {
+        for (int col = 0; col < w; ++col) {
+          int delta = abs((int)src0[col] - (int)src1[col]) / DIFF_FACTOR;
+          unsigned int alpha = negative_to_zero(mask_base + delta);
+          alpha = AOMMIN(alpha, AOM_BLEND_A64_MAX_ALPHA);
+          mask[col] = AOM_BLEND_A64_MAX_ALPHA - alpha;
+        }
+      }
+    } else {
+      for (int row = 0; row < h;
+           ++row, src0 += src0_stride, src1 += src1_stride, mask += w) {
+        for (int col = 0; col < w; ++col) {
+          int delta = abs((int)src0[col] - (int)src1[col]) / DIFF_FACTOR;
+          unsigned int alpha = negative_to_zero(mask_base + delta);
+          alpha = AOMMIN(alpha, AOM_BLEND_A64_MAX_ALPHA);
+          mask[col] = alpha;
+        }
+      }
+    }
+  } else {
+    const unsigned int bd_shift = bd - 8;
+    if (which_inverse) {
+      for (int row = 0; row < h;
+           ++row, src0 += src0_stride, src1 += src1_stride, mask += w) {
+        for (int col = 0; col < w; ++col) {
+          int delta =
+              (abs((int)src0[col] - (int)src1[col]) >> bd_shift) / DIFF_FACTOR;
+          unsigned int alpha = negative_to_zero(mask_base + delta);
+          alpha = AOMMIN(alpha, AOM_BLEND_A64_MAX_ALPHA);
+          mask[col] = AOM_BLEND_A64_MAX_ALPHA - alpha;
+        }
+      }
+    } else {
+      for (int row = 0; row < h;
+           ++row, src0 += src0_stride, src1 += src1_stride, mask += w) {
+        for (int col = 0; col < w; ++col) {
+          int delta =
+              (abs((int)src0[col] - (int)src1[col]) >> bd_shift) / DIFF_FACTOR;
+          unsigned int alpha = negative_to_zero(mask_base + delta);
+          alpha = AOMMIN(alpha, AOM_BLEND_A64_MAX_ALPHA);
+          mask[col] = alpha;
+        }
+      }
+    }
+  }
+}
+
 static INLINE const uint8_t *av1_get_contiguous_soft_mask(int8_t wedge_index,
                                                           int8_t wedge_sign,
                                                           BLOCK_SIZE sb_type) {
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index b8ee74c..1e099b3 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -35,14 +35,12 @@
                                     int ref_stride, const uint8_t *mask,
                                     int mask_stride, int invert_mask);
 
-#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
 const BLOCK_SIZE kValidBlockSize[] = {
   BLOCK_8X8,   BLOCK_8X16,  BLOCK_8X32,   BLOCK_16X8,   BLOCK_16X16,
   BLOCK_16X32, BLOCK_32X8,  BLOCK_32X16,  BLOCK_32X32,  BLOCK_32X64,
   BLOCK_64X32, BLOCK_64X64, BLOCK_64X128, BLOCK_128X64, BLOCK_128X128,
   BLOCK_16X64, BLOCK_64X16
 };
-#endif
 typedef std::tuple<comp_mask_pred_func, BLOCK_SIZE> CompMaskPredParam;
 
 class AV1CompMaskVarianceTest
@@ -55,6 +53,8 @@
 
  protected:
   void RunCheckOutput(comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv);
+  void RunCheckDiffMask(comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
+                        int inv);
   void RunSpeedTest(comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
   bool CheckResult(int width, int height) {
     for (int y = 0; y < height; ++y) {
@@ -122,6 +122,23 @@
   }
 }
 
+// Checks test_impl against the C reference using a diffwtd-generated mask.
+void AV1CompMaskVarianceTest::RunCheckDiffMask(comp_mask_pred_func test_impl,
+                                               BLOCK_SIZE bsize, int inv) {
+  const int w = block_size_wide[bsize];
+  const int h = block_size_high[bsize];
+  uint8_t *const mask = (uint8_t *)malloc(64 * w * h);
+  ASSERT_TRUE(mask != NULL);
+  av1_diffwtd_mask(mask, inv, 38, pred_, w, ref_, MAX_SB_SIZE, h, w);
+  aom_comp_mask_pred_c(comp_pred1_, pred_, w, h, ref_, MAX_SB_SIZE, mask, w,
+                       inv);
+  test_impl(comp_pred2_, pred_, w, h, ref_, MAX_SB_SIZE, mask, w, inv);
+  // Free before asserting: a failing ASSERT_* returns and would leak the mask.
+  const bool match = CheckResult(w, h);
+  free(mask);
+  ASSERT_EQ(match, true) << " Diffwtd " << " inv " << inv;
+}
+
 void AV1CompMaskVarianceTest::RunSpeedTest(comp_mask_pred_func test_impl,
                                            BLOCK_SIZE bsize) {
   const int w = block_size_wide[bsize];
@@ -153,6 +170,8 @@
   // inv = 0, 1
   RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 0);
   RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 1);
+  RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 0);
+  RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 1);
 }
 
 TEST_P(AV1CompMaskVarianceTest, DISABLED_Speed) {
@@ -299,6 +318,8 @@
  protected:
   void RunCheckOutput(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
                       int inv);
+  void RunCheckDiffMask(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
+                        int inv);
   void RunSpeedTest(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
   bool CheckResult(int width, int height) {
     for (int y = 0; y < height; ++y) {
@@ -333,9 +354,9 @@
       (uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*comp_pred1_));
   comp_pred2_ =
       (uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*comp_pred2_));
-  pred_ = (uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*pred_));
+  pred_ = (uint16_t *)aom_memalign(16, 4 * MAX_SB_SQUARE * sizeof(*pred_));
   ref_buffer_ = (uint16_t *)aom_memalign(
-      16, (MAX_SB_SQUARE + (8 * MAX_SB_SIZE)) * sizeof(*ref_buffer_));
+      16, (4 * MAX_SB_SQUARE + (8 * MAX_SB_SIZE)) * sizeof(*ref_buffer_));
   ref_ = ref_buffer_ + (8 * MAX_SB_SIZE);
 }
 
@@ -376,6 +397,35 @@
   }
 }
 
+// Compares test_impl with aom_highbd_comp_mask_pred_c on random bd_-bit input,
+// using a mask produced by av1_diffwtd_mask_highbd.
+void AV1HighbdCompMaskVarianceTest::RunCheckDiffMask(
+    highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv) {
+  int bd_ = GET_PARAM(2);
+  const int w = block_size_wide[bsize];
+  const int h = block_size_high[bsize];
+
+  for (int i = 0; i < MAX_SB_SQUARE; ++i) {
+    pred_[i] = rnd_.Rand16() & ((1 << bd_) - 1);
+  }
+  for (int i = 0; i < MAX_SB_SQUARE + (8 * MAX_SB_SIZE); ++i) {
+    ref_buffer_[i] = rnd_.Rand16() & ((1 << bd_) - 1);
+  }
+  uint8_t *const mask = (uint8_t *)malloc(64 * w * h);
+  ASSERT_TRUE(mask != NULL);
+  av1_diffwtd_mask_highbd(mask, inv, 38, pred_, w, ref_, MAX_SB_SIZE, h, w,
+                          bd_);
+  aom_highbd_comp_mask_pred_c(
+      CONVERT_TO_BYTEPTR(comp_pred1_), CONVERT_TO_BYTEPTR(pred_), w, h,
+      CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
+  test_impl(CONVERT_TO_BYTEPTR(comp_pred2_), CONVERT_TO_BYTEPTR(pred_), w, h,
+            CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
+  // Free before asserting: a failing ASSERT_* returns and would leak the mask.
+  const bool match = CheckResult(w, h);
+  free(mask);
+  ASSERT_EQ(match, true) << " Diffwtd " << " inv " << inv;
+}
+
 void AV1HighbdCompMaskVarianceTest::RunSpeedTest(
     highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize) {
   int bd_ = GET_PARAM(2);
@@ -419,6 +469,8 @@
   // inv = 0, 1
   RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 0);
   RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 1);
+  RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 0);
+  RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 1);
 }
 
 TEST_P(AV1HighbdCompMaskVarianceTest, DISABLED_Speed) {