Add high-bitdepth Hadamard and SATD support for the TPL model

Add C and AVX2 versions of aom_highbd_hadamard_8x8/16x16/32x32 that keep
32-bit intermediates for the wider dynamic range of high-bitdepth
residuals, update aom_satd_avx2 to read 32-bit tran_low_t coefficients
directly (the SSE2 specialization of aom_satd is dropped), and have
wht_fwd_txfm() in tpl_model.c select the high-bitdepth path via
is_cur_buf_hbd().

Change-Id: Iabc7109881bce73bf6bbe0c4e603eb5d9f1c141f
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index b654a9a..06214ab 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -880,8 +880,17 @@
add_proto qw/void aom_hadamard_32x32/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
specialize qw/aom_hadamard_32x32 avx2 sse2/;
+ add_proto qw/void aom_highbd_hadamard_8x8/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
+ specialize qw/aom_highbd_hadamard_8x8 avx2/;
+
+ add_proto qw/void aom_highbd_hadamard_16x16/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
+ specialize qw/aom_highbd_hadamard_16x16 avx2/;
+
+ add_proto qw/void aom_highbd_hadamard_32x32/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
+ specialize qw/aom_highbd_hadamard_32x32 avx2/;
+
add_proto qw/int aom_satd/, "const tran_low_t *coeff, int length";
- specialize qw/aom_satd avx2 sse2/;
+ specialize qw/aom_satd avx2/;
#
# Structured Similarity (SSIM)
diff --git a/aom_dsp/avg.c b/aom_dsp/avg.c
index 43d2760..fcffc73 100644
--- a/aom_dsp/avg.c
+++ b/aom_dsp/avg.c
@@ -170,6 +170,162 @@
}
}
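+// src_diff: 13 bit, dynamic range [-4095, 4095]
+// coeff: 16 bit, dynamic range [-32760, 32760]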
+static void hadamard_highbd_col8_first_pass(const int16_t *src_diff,
+ ptrdiff_t src_stride,
+ int16_t *coeff) {
+ int16_t b0 = src_diff[0 * src_stride] + src_diff[1 * src_stride];
+ int16_t b1 = src_diff[0 * src_stride] - src_diff[1 * src_stride];
+ int16_t b2 = src_diff[2 * src_stride] + src_diff[3 * src_stride];
+ int16_t b3 = src_diff[2 * src_stride] - src_diff[3 * src_stride];
+ int16_t b4 = src_diff[4 * src_stride] + src_diff[5 * src_stride];
+ int16_t b5 = src_diff[4 * src_stride] - src_diff[5 * src_stride];
+ int16_t b6 = src_diff[6 * src_stride] + src_diff[7 * src_stride];
+ int16_t b7 = src_diff[6 * src_stride] - src_diff[7 * src_stride];
+
+ int16_t c0 = b0 + b2;
+ int16_t c1 = b1 + b3;
+ int16_t c2 = b0 - b2;
+ int16_t c3 = b1 - b3;
+ int16_t c4 = b4 + b6;
+ int16_t c5 = b5 + b7;
+ int16_t c6 = b4 - b6;
+ int16_t c7 = b5 - b7;
+
+ coeff[0] = c0 + c4;
+ coeff[7] = c1 + c5;
+ coeff[3] = c2 + c6;
+ coeff[4] = c3 + c7;
+ coeff[2] = c0 - c4;
+ coeff[6] = c1 - c5;
+ coeff[1] = c2 - c6;
+ coeff[5] = c3 - c7;
+}
+
+// src_diff: 16 bit, dynamic range [-32760, 32760]
+// coeff: 19 bit
+static void hadamard_highbd_col8_second_pass(const int16_t *src_diff,
+ ptrdiff_t src_stride,
+ int32_t *coeff) {
+ int32_t b0 = src_diff[0 * src_stride] + src_diff[1 * src_stride];
+ int32_t b1 = src_diff[0 * src_stride] - src_diff[1 * src_stride];
+ int32_t b2 = src_diff[2 * src_stride] + src_diff[3 * src_stride];
+ int32_t b3 = src_diff[2 * src_stride] - src_diff[3 * src_stride];
+ int32_t b4 = src_diff[4 * src_stride] + src_diff[5 * src_stride];
+ int32_t b5 = src_diff[4 * src_stride] - src_diff[5 * src_stride];
+ int32_t b6 = src_diff[6 * src_stride] + src_diff[7 * src_stride];
+ int32_t b7 = src_diff[6 * src_stride] - src_diff[7 * src_stride];
+
+ int32_t c0 = b0 + b2;
+ int32_t c1 = b1 + b3;
+ int32_t c2 = b0 - b2;
+ int32_t c3 = b1 - b3;
+ int32_t c4 = b4 + b6;
+ int32_t c5 = b5 + b7;
+ int32_t c6 = b4 - b6;
+ int32_t c7 = b5 - b7;
+
+ coeff[0] = c0 + c4;
+ coeff[7] = c1 + c5;
+ coeff[3] = c2 + c6;
+ coeff[4] = c3 + c7;
+ coeff[2] = c0 - c4;
+ coeff[6] = c1 - c5;
+ coeff[1] = c2 - c6;
+ coeff[5] = c3 - c7;
+}
+
+// The order of the output coefficients of the Hadamard transform is not
+// important, so for optimization purposes the final transpose may be skipped.
+void aom_highbd_hadamard_8x8_c(const int16_t *src_diff, ptrdiff_t src_stride,
+ tran_low_t *coeff) {
+ int idx;
+ int16_t buffer[64];
+ int32_t buffer2[64];
+ int16_t *tmp_buf = &buffer[0];
+ for (idx = 0; idx < 8; ++idx) {
+ // src_diff: 13 bit
+ // buffer: 16 bit, dynamic range [-32760, 32760]
+ hadamard_highbd_col8_first_pass(src_diff, src_stride, tmp_buf);
+ tmp_buf += 8;
+ ++src_diff;
+ }
+
+ tmp_buf = &buffer[0];
+ for (idx = 0; idx < 8; ++idx) {
+ // buffer: 16 bit
+ // buffer2: 19 bit, dynamic range [-262080, 262080]
+ hadamard_highbd_col8_second_pass(tmp_buf, 8, buffer2 + 8 * idx);
+ ++tmp_buf;
+ }
+
+ for (idx = 0; idx < 64; ++idx) coeff[idx] = (tran_low_t)buffer2[idx];
+}
+
+// In place 16x16 2D Hadamard transform
+void aom_highbd_hadamard_16x16_c(const int16_t *src_diff, ptrdiff_t src_stride,
+ tran_low_t *coeff) {
+ int idx;
+ for (idx = 0; idx < 4; ++idx) {
+ // src_diff: 13 bit, dynamic range [-4095, 4095]
+ const int16_t *src_ptr =
+ src_diff + (idx >> 1) * 8 * src_stride + (idx & 0x01) * 8;
+ aom_highbd_hadamard_8x8_c(src_ptr, src_stride, coeff + idx * 64);
+ }
+
+ // coeff: 19 bit, dynamic range [-262080, 262080]
+ for (idx = 0; idx < 64; ++idx) {
+ tran_low_t a0 = coeff[0];
+ tran_low_t a1 = coeff[64];
+ tran_low_t a2 = coeff[128];
+ tran_low_t a3 = coeff[192];
+
+ tran_low_t b0 = (a0 + a1) >> 1;
+ tran_low_t b1 = (a0 - a1) >> 1;
+ tran_low_t b2 = (a2 + a3) >> 1;
+ tran_low_t b3 = (a2 - a3) >> 1;
+
+ // new coeff dynamic range: 20 bit
+ coeff[0] = b0 + b2;
+ coeff[64] = b1 + b3;
+ coeff[128] = b0 - b2;
+ coeff[192] = b1 - b3;
+
+ ++coeff;
+ }
+}
+
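+// In place 32x32 2D Hadamard transform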
+void aom_highbd_hadamard_32x32_c(const int16_t *src_diff, ptrdiff_t src_stride,
+ tran_low_t *coeff) {
+ int idx;
+ for (idx = 0; idx < 4; ++idx) {
+ // src_diff: 13 bit, dynamic range [-4095, 4095]
+ const int16_t *src_ptr =
+ src_diff + (idx >> 1) * 16 * src_stride + (idx & 0x01) * 16;
+ aom_highbd_hadamard_16x16_c(src_ptr, src_stride, coeff + idx * 256);
+ }
+
+ // coeff: 20 bit
+ for (idx = 0; idx < 256; ++idx) {
+ tran_low_t a0 = coeff[0];
+ tran_low_t a1 = coeff[256];
+ tran_low_t a2 = coeff[512];
+ tran_low_t a3 = coeff[768];
+
+ tran_low_t b0 = (a0 + a1) >> 2;
+ tran_low_t b1 = (a0 - a1) >> 2;
+ tran_low_t b2 = (a2 + a3) >> 2;
+ tran_low_t b3 = (a2 - a3) >> 2;
+
+ // new coeff dynamic range: 20 bit
+ coeff[0] = b0 + b2;
+ coeff[256] = b1 + b3;
+ coeff[512] = b0 - b2;
+ coeff[768] = b1 - b3;
+
+ ++coeff;
+ }
+}
+
// coeff: 16 bits, dynamic range [-32640, 32640].
// length: value range {16, 64, 256, 1024}.
int aom_satd_c(const tran_low_t *coeff, int length) {
diff --git a/aom_dsp/x86/avg_intrin_avx2.c b/aom_dsp/x86/avg_intrin_avx2.c
index e0ba8d5..f9bcdf2 100644
--- a/aom_dsp/x86/avg_intrin_avx2.c
+++ b/aom_dsp/x86/avg_intrin_avx2.c
@@ -224,17 +224,215 @@
}
}
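+// Vertical 8-point Hadamard butterfly applied across the eight input rows,
+// one 32-bit lane per column. When iter == 0 the result is also transposed
+// so a second call performs the horizontal pass; the final transpose is
+// skipped, as in the C version.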
+static void highbd_hadamard_col8_avx2(__m256i *in, int iter) {
+ __m256i a0 = in[0];
+ __m256i a1 = in[1];
+ __m256i a2 = in[2];
+ __m256i a3 = in[3];
+ __m256i a4 = in[4];
+ __m256i a5 = in[5];
+ __m256i a6 = in[6];
+ __m256i a7 = in[7];
+
+ __m256i b0 = _mm256_add_epi32(a0, a1);
+ __m256i b1 = _mm256_sub_epi32(a0, a1);
+ __m256i b2 = _mm256_add_epi32(a2, a3);
+ __m256i b3 = _mm256_sub_epi32(a2, a3);
+ __m256i b4 = _mm256_add_epi32(a4, a5);
+ __m256i b5 = _mm256_sub_epi32(a4, a5);
+ __m256i b6 = _mm256_add_epi32(a6, a7);
+ __m256i b7 = _mm256_sub_epi32(a6, a7);
+
+ a0 = _mm256_add_epi32(b0, b2);
+ a1 = _mm256_add_epi32(b1, b3);
+ a2 = _mm256_sub_epi32(b0, b2);
+ a3 = _mm256_sub_epi32(b1, b3);
+ a4 = _mm256_add_epi32(b4, b6);
+ a5 = _mm256_add_epi32(b5, b7);
+ a6 = _mm256_sub_epi32(b4, b6);
+ a7 = _mm256_sub_epi32(b5, b7);
+
+ if (iter == 0) {
+ b0 = _mm256_add_epi32(a0, a4);
+ b7 = _mm256_add_epi32(a1, a5);
+ b3 = _mm256_add_epi32(a2, a6);
+ b4 = _mm256_add_epi32(a3, a7);
+ b2 = _mm256_sub_epi32(a0, a4);
+ b6 = _mm256_sub_epi32(a1, a5);
+ b1 = _mm256_sub_epi32(a2, a6);
+ b5 = _mm256_sub_epi32(a3, a7);
+
+ a0 = _mm256_unpacklo_epi32(b0, b1);
+ a1 = _mm256_unpacklo_epi32(b2, b3);
+ a2 = _mm256_unpackhi_epi32(b0, b1);
+ a3 = _mm256_unpackhi_epi32(b2, b3);
+ a4 = _mm256_unpacklo_epi32(b4, b5);
+ a5 = _mm256_unpacklo_epi32(b6, b7);
+ a6 = _mm256_unpackhi_epi32(b4, b5);
+ a7 = _mm256_unpackhi_epi32(b6, b7);
+
+ b0 = _mm256_unpacklo_epi64(a0, a1);
+ b1 = _mm256_unpacklo_epi64(a4, a5);
+ b2 = _mm256_unpackhi_epi64(a0, a1);
+ b3 = _mm256_unpackhi_epi64(a4, a5);
+ b4 = _mm256_unpacklo_epi64(a2, a3);
+ b5 = _mm256_unpacklo_epi64(a6, a7);
+ b6 = _mm256_unpackhi_epi64(a2, a3);
+ b7 = _mm256_unpackhi_epi64(a6, a7);
+
+ in[0] = _mm256_permute2x128_si256(b0, b1, 0x20);
+ in[1] = _mm256_permute2x128_si256(b0, b1, 0x31);
+ in[2] = _mm256_permute2x128_si256(b2, b3, 0x20);
+ in[3] = _mm256_permute2x128_si256(b2, b3, 0x31);
+ in[4] = _mm256_permute2x128_si256(b4, b5, 0x20);
+ in[5] = _mm256_permute2x128_si256(b4, b5, 0x31);
+ in[6] = _mm256_permute2x128_si256(b6, b7, 0x20);
+ in[7] = _mm256_permute2x128_si256(b6, b7, 0x31);
+ } else {
+ in[0] = _mm256_add_epi32(a0, a4);
+ in[7] = _mm256_add_epi32(a1, a5);
+ in[3] = _mm256_add_epi32(a2, a6);
+ in[4] = _mm256_add_epi32(a3, a7);
+ in[2] = _mm256_sub_epi32(a0, a4);
+ in[6] = _mm256_sub_epi32(a1, a5);
+ in[1] = _mm256_sub_epi32(a2, a6);
+ in[5] = _mm256_sub_epi32(a3, a7);
+ }
+}
+
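+// Load eight rows of 16-bit residuals, sign-extend them to 32 bits, and run
+// the vertical and horizontal Hadamard passes.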
+void aom_highbd_hadamard_8x8_avx2(const int16_t *src_diff,
+                                  ptrdiff_t src_stride, tran_low_t *coeff) {
+ __m128i src16[8];
+ __m256i src32[8];
+
+ src16[0] = _mm_loadu_si128((const __m128i *)src_diff);
+ src16[1] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[2] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[3] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[4] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[5] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[6] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+ src16[7] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
+
+ src32[0] = _mm256_cvtepi16_epi32(src16[0]);
+ src32[1] = _mm256_cvtepi16_epi32(src16[1]);
+ src32[2] = _mm256_cvtepi16_epi32(src16[2]);
+ src32[3] = _mm256_cvtepi16_epi32(src16[3]);
+ src32[4] = _mm256_cvtepi16_epi32(src16[4]);
+ src32[5] = _mm256_cvtepi16_epi32(src16[5]);
+ src32[6] = _mm256_cvtepi16_epi32(src16[6]);
+ src32[7] = _mm256_cvtepi16_epi32(src16[7]);
+
+ highbd_hadamard_col8_avx2(src32, 0);
+ highbd_hadamard_col8_avx2(src32, 1);
+
+ _mm256_storeu_si256((__m256i *)coeff, src32[0]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[1]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[2]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[3]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[4]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[5]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[6]);
+ coeff += 8;
+ _mm256_storeu_si256((__m256i *)coeff, src32[7]);
+}
+
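+// Build the 16x16 transform from four 8x8 Hadamard blocks; the >> 1
+// normalization keeps the output within the 20-bit dynamic range noted in
+// aom_highbd_hadamard_16x16_c.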
+void aom_highbd_hadamard_16x16_avx2(const int16_t *src_diff,
+ ptrdiff_t src_stride, tran_low_t *coeff) {
+ int idx;
+ tran_low_t *t_coeff = coeff;
+ for (idx = 0; idx < 4; ++idx) {
+ const int16_t *src_ptr =
+ src_diff + (idx >> 1) * 8 * src_stride + (idx & 0x01) * 8;
+ aom_highbd_hadamard_8x8_avx2(src_ptr, src_stride, t_coeff + idx * 64);
+ }
+
+ for (idx = 0; idx < 64; idx += 8) {
+ __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
+ __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 64));
+ __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 128));
+ __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 192));
+
+ __m256i b0 = _mm256_add_epi32(coeff0, coeff1);
+ __m256i b1 = _mm256_sub_epi32(coeff0, coeff1);
+ __m256i b2 = _mm256_add_epi32(coeff2, coeff3);
+ __m256i b3 = _mm256_sub_epi32(coeff2, coeff3);
+
+ b0 = _mm256_srai_epi32(b0, 1);
+ b1 = _mm256_srai_epi32(b1, 1);
+ b2 = _mm256_srai_epi32(b2, 1);
+ b3 = _mm256_srai_epi32(b3, 1);
+
+ coeff0 = _mm256_add_epi32(b0, b2);
+ coeff1 = _mm256_add_epi32(b1, b3);
+ coeff2 = _mm256_sub_epi32(b0, b2);
+ coeff3 = _mm256_sub_epi32(b1, b3);
+
+ _mm256_storeu_si256((__m256i *)coeff, coeff0);
+ _mm256_storeu_si256((__m256i *)(coeff + 64), coeff1);
+ _mm256_storeu_si256((__m256i *)(coeff + 128), coeff2);
+ _mm256_storeu_si256((__m256i *)(coeff + 192), coeff3);
+
+ coeff += 8;
+ t_coeff += 8;
+ }
+}
+
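+// Build the 32x32 transform from four 16x16 blocks, with >> 2 normalization,
+// mirroring aom_highbd_hadamard_32x32_c.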
+void aom_highbd_hadamard_32x32_avx2(const int16_t *src_diff,
+ ptrdiff_t src_stride, tran_low_t *coeff) {
+ int idx;
+ tran_low_t *t_coeff = coeff;
+ for (idx = 0; idx < 4; ++idx) {
+ const int16_t *src_ptr =
+ src_diff + (idx >> 1) * 16 * src_stride + (idx & 0x01) * 16;
+ aom_highbd_hadamard_16x16_avx2(src_ptr, src_stride, t_coeff + idx * 256);
+ }
+
+ for (idx = 0; idx < 256; idx += 8) {
+ __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
+ __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 256));
+ __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 512));
+ __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 768));
+
+ __m256i b0 = _mm256_add_epi32(coeff0, coeff1);
+ __m256i b1 = _mm256_sub_epi32(coeff0, coeff1);
+ __m256i b2 = _mm256_add_epi32(coeff2, coeff3);
+ __m256i b3 = _mm256_sub_epi32(coeff2, coeff3);
+
+ b0 = _mm256_srai_epi32(b0, 2);
+ b1 = _mm256_srai_epi32(b1, 2);
+ b2 = _mm256_srai_epi32(b2, 2);
+ b3 = _mm256_srai_epi32(b3, 2);
+
+ coeff0 = _mm256_add_epi32(b0, b2);
+ coeff1 = _mm256_add_epi32(b1, b3);
+ coeff2 = _mm256_sub_epi32(b0, b2);
+ coeff3 = _mm256_sub_epi32(b1, b3);
+
+ _mm256_storeu_si256((__m256i *)coeff, coeff0);
+ _mm256_storeu_si256((__m256i *)(coeff + 256), coeff1);
+ _mm256_storeu_si256((__m256i *)(coeff + 512), coeff2);
+ _mm256_storeu_si256((__m256i *)(coeff + 768), coeff3);
+
+ coeff += 8;
+ t_coeff += 8;
+ }
+}
+
int aom_satd_avx2(const tran_low_t *coeff, int length) {
- const __m256i one = _mm256_set1_epi16(1);
__m256i accum = _mm256_setzero_si256();
int i;
- for (i = 0; i < length; i += 16) {
- const __m256i src_line = load_tran_low(coeff);
- const __m256i abs = _mm256_abs_epi16(src_line);
- const __m256i sum = _mm256_madd_epi16(abs, one);
- accum = _mm256_add_epi32(accum, sum);
- coeff += 16;
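+  // coeff now holds 32-bit tran_low_t values: load 8 per iteration and
+  // accumulate their absolute values directly in 32-bit lanes.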
+ for (i = 0; i < length; i += 8, coeff += 8) {
+ const __m256i src_line = _mm256_loadu_si256((const __m256i *)coeff);
+ const __m256i abs = _mm256_abs_epi32(src_line);
+ accum = _mm256_add_epi32(accum, abs);
}
{ // 32 bit horizontal add
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index ddd5f09..a482496 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -28,12 +28,21 @@
#define MC_FLOW_NUM_PELS (MC_FLOW_BSIZE * MC_FLOW_BSIZE)
static void wht_fwd_txfm(int16_t *src_diff, int bw, tran_low_t *coeff,
- TX_SIZE tx_size) {
- switch (tx_size) {
- case TX_8X8: aom_hadamard_8x8(src_diff, bw, coeff); break;
- case TX_16X16: aom_hadamard_16x16(src_diff, bw, coeff); break;
- case TX_32X32: aom_hadamard_32x32(src_diff, bw, coeff); break;
- default: assert(0);
+ TX_SIZE tx_size, int is_hbd) {
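+  // High-bitdepth residuals have a wider dynamic range, so use the Hadamard
+  // variants that keep 32-bit intermediates.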
+ if (is_hbd) {
+ switch (tx_size) {
+ case TX_8X8: aom_highbd_hadamard_8x8(src_diff, bw, coeff); break;
+ case TX_16X16: aom_highbd_hadamard_16x16(src_diff, bw, coeff); break;
+ case TX_32X32: aom_highbd_hadamard_32x32(src_diff, bw, coeff); break;
+ default: assert(0);
+ }
+ } else {
+ switch (tx_size) {
+ case TX_8X8: aom_hadamard_8x8(src_diff, bw, coeff); break;
+ case TX_16X16: aom_hadamard_16x16(src_diff, bw, coeff); break;
+ case TX_32X32: aom_hadamard_32x32(src_diff, bw, coeff); break;
+ default: assert(0);
+ }
}
}
@@ -156,7 +165,7 @@
aom_subtract_block(bh, bw, src_diff, bw, src, src_stride, dst,
dst_stride);
}
- wht_fwd_txfm(src_diff, bw, coeff, tx_size);
+ wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
intra_cost = aom_satd(coeff, pix_num);
} else {
int64_t sse;
@@ -233,7 +242,8 @@
xd->cur_buf->y_buffer + mb_y_offset,
xd->cur_buf->y_stride, &predictor[0], bw);
}
- wht_fwd_txfm(src_diff, bw, coeff, tx_size);
+ wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
+
inter_cost = aom_satd(coeff, pix_num);
} else {
int64_t sse;
@@ -766,8 +776,7 @@
aom_subtract_block(bh, bw, src_diff, bw, src_buf, src_stride,
dst_buf, dst_stride);
}
-
- wht_fwd_txfm(src_diff, bw, coeff, tx_size);
+ wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
intra_cost = aom_satd(coeff, pix_num);
} else {
@@ -809,7 +818,7 @@
aom_subtract_block(bh, bw, src_diff, bw, src->y_buffer + mb_y_offset,
src->y_stride, &predictor[0], bw);
}
- wht_fwd_txfm(src_diff, bw, coeff, tx_size);
+ wht_fwd_txfm(src_diff, bw, coeff, tx_size, is_cur_buf_hbd(xd));
inter_cost = aom_satd(coeff, pix_num);
} else {
int64_t sse;