Comp_mask_upsampled_pred unit tests for widths>32
Change-Id: Ide765c68208a9195ec8d4cc34b055fcb38d00657
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index ad155b2..449a00f 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -363,21 +363,6 @@
}
}
-static AOM_INLINE void diffwtd_mask(uint8_t *mask, int which_inverse,
- int mask_base, const uint8_t *src0,
- int src0_stride, const uint8_t *src1,
- int src1_stride, int h, int w) {
- int i, j, m, diff;
- for (i = 0; i < h; ++i) {
- for (j = 0; j < w; ++j) {
- diff =
- abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
- m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
- mask[i * w + j] = which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
- }
- }
-}
-
void av1_build_compound_diffwtd_mask_c(uint8_t *mask,
DIFFWTD_MASK_TYPE mask_type,
const uint8_t *src0, int src0_stride,
@@ -385,90 +370,29 @@
int h, int w) {
switch (mask_type) {
case DIFFWTD_38:
- diffwtd_mask(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w);
+ av1_diffwtd_mask(mask, 0, 38, src0, src0_stride, src1, src1_stride, h, w);
break;
case DIFFWTD_38_INV:
- diffwtd_mask(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w);
+ av1_diffwtd_mask(mask, 1, 38, src0, src0_stride, src1, src1_stride, h, w);
break;
default: assert(0);
}
}
-static AOM_FORCE_INLINE void diffwtd_mask_highbd(
- uint8_t *mask, int which_inverse, int mask_base, const uint16_t *src0,
- int src0_stride, const uint16_t *src1, int src1_stride, int h, int w,
- const unsigned int bd) {
- assert(bd >= 8);
- if (bd == 8) {
- if (which_inverse) {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
- unsigned int m = negative_to_zero(mask_base + diff);
- m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
- mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
- }
- src0 += src0_stride;
- src1 += src1_stride;
- mask += w;
- }
- } else {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
- unsigned int m = negative_to_zero(mask_base + diff);
- m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
- mask[j] = m;
- }
- src0 += src0_stride;
- src1 += src1_stride;
- mask += w;
- }
- }
- } else {
- const unsigned int bd_shift = bd - 8;
- if (which_inverse) {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- int diff =
- (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
- unsigned int m = negative_to_zero(mask_base + diff);
- m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
- mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
- }
- src0 += src0_stride;
- src1 += src1_stride;
- mask += w;
- }
- } else {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- int diff =
- (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
- unsigned int m = negative_to_zero(mask_base + diff);
- m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
- mask[j] = m;
- }
- src0 += src0_stride;
- src1 += src1_stride;
- mask += w;
- }
- }
- }
-}
-
void av1_build_compound_diffwtd_mask_highbd_c(
uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
int bd) {
switch (mask_type) {
case DIFFWTD_38:
- diffwtd_mask_highbd(mask, 0, 38, CONVERT_TO_SHORTPTR(src0), src0_stride,
- CONVERT_TO_SHORTPTR(src1), src1_stride, h, w, bd);
+ av1_diffwtd_mask_highbd(mask, 0, 38, CONVERT_TO_SHORTPTR(src0),
+ src0_stride, CONVERT_TO_SHORTPTR(src1),
+ src1_stride, h, w, bd);
break;
case DIFFWTD_38_INV:
- diffwtd_mask_highbd(mask, 1, 38, CONVERT_TO_SHORTPTR(src0), src0_stride,
- CONVERT_TO_SHORTPTR(src1), src1_stride, h, w, bd);
+ av1_diffwtd_mask_highbd(mask, 1, 38, CONVERT_TO_SHORTPTR(src0),
+ src0_stride, CONVERT_TO_SHORTPTR(src1),
+ src1_stride, h, w, bd);
break;
default: assert(0);
}
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index c869616..46d7d2f 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -17,6 +17,7 @@
#include "av1/common/filter.h"
#include "av1/common/warped_motion.h"
#include "aom/aom_integer.h"
+#include "aom_dsp/blend.h"
// Work out how many pixels off the edge of a reference frame we're allowed
// to go when forming an inter prediction.
@@ -361,6 +362,86 @@
void av1_init_wedge_masks();
+static INLINE void av1_diffwtd_mask(uint8_t *mask, int which_inverse,
+ int mask_base, const uint8_t *src0,
+ int src0_stride, const uint8_t *src1,
+ int src1_stride, int h, int w) {
+ int i, j, m, diff;
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
+ diff =
+ abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
+ m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+ mask[i * w + j] = which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+ }
+ }
+}
+
+static INLINE void av1_diffwtd_mask_highbd(uint8_t *mask, int which_inverse,
+ int mask_base, const uint16_t *src0,
+ int src0_stride,
+ const uint16_t *src1,
+ int src1_stride, int h, int w,
+ const unsigned int bd) {
+ assert(bd >= 8);
+ if (bd == 8) {
+ if (which_inverse) {
+ for (int i = 0; i < h; ++i) {
+ for (int j = 0; j < w; ++j) {
+ int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
+ unsigned int m = negative_to_zero(mask_base + diff);
+ m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
+ mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
+ }
+ src0 += src0_stride;
+ src1 += src1_stride;
+ mask += w;
+ }
+ } else {
+ for (int i = 0; i < h; ++i) {
+ for (int j = 0; j < w; ++j) {
+ int diff = abs((int)src0[j] - (int)src1[j]) / DIFF_FACTOR;
+ unsigned int m = negative_to_zero(mask_base + diff);
+ m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
+ mask[j] = m;
+ }
+ src0 += src0_stride;
+ src1 += src1_stride;
+ mask += w;
+ }
+ }
+ } else {
+ const unsigned int bd_shift = bd - 8;
+ if (which_inverse) {
+ for (int i = 0; i < h; ++i) {
+ for (int j = 0; j < w; ++j) {
+ int diff =
+ (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
+ unsigned int m = negative_to_zero(mask_base + diff);
+ m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
+ mask[j] = AOM_BLEND_A64_MAX_ALPHA - m;
+ }
+ src0 += src0_stride;
+ src1 += src1_stride;
+ mask += w;
+ }
+ } else {
+ for (int i = 0; i < h; ++i) {
+ for (int j = 0; j < w; ++j) {
+ int diff =
+ (abs((int)src0[j] - (int)src1[j]) >> bd_shift) / DIFF_FACTOR;
+ unsigned int m = negative_to_zero(mask_base + diff);
+ m = AOMMIN(m, AOM_BLEND_A64_MAX_ALPHA);
+ mask[j] = m;
+ }
+ src0 += src0_stride;
+ src1 += src1_stride;
+ mask += w;
+ }
+ }
+ }
+}
+
static INLINE const uint8_t *av1_get_contiguous_soft_mask(int8_t wedge_index,
int8_t wedge_sign,
BLOCK_SIZE sb_type) {
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index b8ee74c..1e099b3 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -35,14 +35,12 @@
int ref_stride, const uint8_t *mask,
int mask_stride, int invert_mask);
-#if HAVE_SSSE3 || HAVE_SSE2 || HAVE_AVX2
const BLOCK_SIZE kValidBlockSize[] = {
BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32, BLOCK_32X64,
BLOCK_64X32, BLOCK_64X64, BLOCK_64X128, BLOCK_128X64, BLOCK_128X128,
BLOCK_16X64, BLOCK_64X16
};
-#endif
typedef std::tuple<comp_mask_pred_func, BLOCK_SIZE> CompMaskPredParam;
class AV1CompMaskVarianceTest
@@ -55,6 +53,8 @@
protected:
void RunCheckOutput(comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv);
+ void RunCheckDiffMask(comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
+ int inv);
void RunSpeedTest(comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
bool CheckResult(int width, int height) {
for (int y = 0; y < height; ++y) {
@@ -122,6 +122,23 @@
}
}
+void AV1CompMaskVarianceTest::RunCheckDiffMask(comp_mask_pred_func test_impl,
+ BLOCK_SIZE bsize, int inv) {
+ const int w = block_size_wide[bsize];
+ const int h = block_size_high[bsize];
+ static uint8_t *mask;
+ mask = (uint8_t *)malloc(64 * w * h);
+ av1_diffwtd_mask(mask, inv, 38, pred_, w, ref_, MAX_SB_SIZE, h, w);
+
+ aom_comp_mask_pred_c(comp_pred1_, pred_, w, h, ref_, MAX_SB_SIZE, mask, w,
+ inv);
+ test_impl(comp_pred2_, pred_, w, h, ref_, MAX_SB_SIZE, mask, w, inv);
+
+ ASSERT_EQ(CheckResult(w, h), true) << " Diffwtd "
+ << " inv " << inv;
+ free(mask);
+}
+
void AV1CompMaskVarianceTest::RunSpeedTest(comp_mask_pred_func test_impl,
BLOCK_SIZE bsize) {
const int w = block_size_wide[bsize];
@@ -153,6 +170,8 @@
// inv = 0, 1
RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 0);
RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 1);
+ RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 0);
+ RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 1);
}
TEST_P(AV1CompMaskVarianceTest, DISABLED_Speed) {
@@ -299,6 +318,8 @@
protected:
void RunCheckOutput(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
int inv);
+ void RunCheckDiffMask(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize,
+ int inv);
void RunSpeedTest(highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
bool CheckResult(int width, int height) {
for (int y = 0; y < height; ++y) {
@@ -333,9 +354,9 @@
(uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*comp_pred1_));
comp_pred2_ =
(uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*comp_pred2_));
- pred_ = (uint16_t *)aom_memalign(16, MAX_SB_SQUARE * sizeof(*pred_));
+ pred_ = (uint16_t *)aom_memalign(16, 4 * MAX_SB_SQUARE * sizeof(*pred_));
ref_buffer_ = (uint16_t *)aom_memalign(
- 16, (MAX_SB_SQUARE + (8 * MAX_SB_SIZE)) * sizeof(*ref_buffer_));
+ 16, (4 * MAX_SB_SQUARE + (8 * MAX_SB_SIZE)) * sizeof(*ref_buffer_));
ref_ = ref_buffer_ + (8 * MAX_SB_SIZE);
}
@@ -376,6 +397,35 @@
}
}
+void AV1HighbdCompMaskVarianceTest::RunCheckDiffMask(
+ highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv) {
+ int bd_ = GET_PARAM(2);
+ const int w = block_size_wide[bsize];
+ const int h = block_size_high[bsize];
+
+ for (int i = 0; i < MAX_SB_SQUARE; ++i) {
+ pred_[i] = rnd_.Rand16() & ((1 << bd_) - 1);
+ }
+ for (int i = 0; i < MAX_SB_SQUARE + (8 * MAX_SB_SIZE); ++i) {
+ ref_buffer_[i] = rnd_.Rand16() & ((1 << bd_) - 1);
+ }
+ static uint8_t *mask;
+ mask = (uint8_t *)malloc(64 * w * h);
+ av1_diffwtd_mask_highbd(mask, inv, 38, pred_, w, ref_, MAX_SB_SIZE, h, w,
+ bd_);
+
+ aom_highbd_comp_mask_pred_c(
+ CONVERT_TO_BYTEPTR(comp_pred1_), CONVERT_TO_BYTEPTR(pred_), w, h,
+ CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
+
+ test_impl(CONVERT_TO_BYTEPTR(comp_pred2_), CONVERT_TO_BYTEPTR(pred_), w, h,
+ CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
+
+ ASSERT_EQ(CheckResult(w, h), true) << " Diffwtd "
+ << " inv " << inv;
+ free(mask);
+}
+
void AV1HighbdCompMaskVarianceTest::RunSpeedTest(
highbd_comp_mask_pred_func test_impl, BLOCK_SIZE bsize) {
int bd_ = GET_PARAM(2);
@@ -419,6 +469,8 @@
// inv = 0, 1
RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 0);
RunCheckOutput(GET_PARAM(0), GET_PARAM(1), 1);
+ RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 0);
+ RunCheckDiffMask(GET_PARAM(0), GET_PARAM(1), 1);
}
TEST_P(AV1HighbdCompMaskVarianceTest, DISABLED_Speed) {