Add av1_get_br_level_counts_sse2()
Change-Id: I6ce7aea19e3bdeef24d3fe66ac6eba7b8d585f9a
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 6ea0cfd..804941a 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -254,7 +254,8 @@
set(AOM_AV1_COMMON_INTRIN_SSE2
${AOM_AV1_COMMON_INTRIN_SSE2}
- "${AOM_ROOT}/av1/common/cdef_block_sse2.c")
+ "${AOM_ROOT}/av1/common/cdef_block_sse2.c"
+ "${AOM_ROOT}/av1/common/x86/mem_sse2.h")
set(AOM_AV1_COMMON_INTRIN_SSSE3
${AOM_AV1_COMMON_INTRIN_SSSE3}
@@ -384,6 +385,10 @@
"${AOM_ROOT}/av1/common/txb_common.c"
"${AOM_ROOT}/av1/common/txb_common.h")
+ set(AOM_AV1_COMMON_INTRIN_SSE2
+ ${AOM_AV1_COMMON_INTRIN_SSE2}
+ "${AOM_ROOT}/av1/common/x86/txb_sse2.c")
+
set(AOM_AV1_DECODER_SOURCES
${AOM_AV1_DECODER_SOURCES}
"${AOM_ROOT}/av1/decoder/decodetxb.c"
diff --git a/av1/av1_common.mk b/av1/av1_common.mk
index f0c65c2..ed5a7ec 100644
--- a/av1/av1_common.mk
+++ b/av1/av1_common.mk
@@ -76,6 +76,7 @@
AV1_COMMON_SRCS-yes += common/av1_inv_txfm2d.c
AV1_COMMON_SRCS-yes += common/av1_inv_txfm1d_cfg.h
AV1_COMMON_SRCS-$(HAVE_AVX2) += common/x86/convolve_avx2.c
+AV1_COMMON_SRCS-$(HAVE_SSE2) += common/x86/mem_sse2.h
AV1_COMMON_SRCS-$(HAVE_SSSE3) += common/x86/av1_convolve_ssse3.c
ifeq ($(CONFIG_CONVOLVE_ROUND)x$(CONFIG_COMPOUND_ROUND),yesx)
AV1_COMMON_SRCS-$(HAVE_SSE4_1) += common/x86/av1_convolve_scale_sse4.c
@@ -167,6 +168,10 @@
endif
endif
+ifeq ($(CONFIG_LV_MAP),yes)
+AV1_COMMON_SRCS-$(HAVE_SSE2) += common/x86/txb_sse2.c
+endif
+
ifeq ($(CONFIG_Q_ADAPT_PROBS),yes)
AV1_COMMON_SRCS-yes += common/token_cdfs.h
endif
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index ab0fa48..6b2f1ed 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -67,6 +67,14 @@
}
#
+# txb
+#
+if (aom_config("CONFIG_LV_MAP") eq "yes") {
+ add_proto qw/void av1_get_br_level_counts/, "const uint8_t *const levels, const int width, const int height, uint8_t *const level_counts";
+ specialize qw/av1_get_br_level_counts sse2/;
+}
+
+#
# Inverse dct
#
add_proto qw/void av1_iht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, const struct txfm_param *param";
diff --git a/av1/common/txb_common.c b/av1/common/txb_common.c
index 38ac984..4607428 100644
--- a/av1/common/txb_common.c
+++ b/av1/common/txb_common.c
@@ -333,8 +333,8 @@
}
}
-void av1_get_br_level_counts(const uint8_t *const levels, const int width,
- const int height, uint8_t *const level_counts) {
+void av1_get_br_level_counts_c(const uint8_t *const levels, const int width,
+ const int height, uint8_t *const level_counts) {
const int stride = width + TX_PAD_HOR;
const int level_minus_1 = NUM_BASE_LEVELS;
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index c609f29..edae7f7 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -391,7 +391,7 @@
const int width = 1 << bwl;
int ctx = 0;
- int tx_class = get_tx_class(tx_type);
+ const TX_CLASS tx_class = get_tx_class(tx_type);
int offset;
if (tx_class == TX_CLASS_2D)
offset = 0;
@@ -455,7 +455,7 @@
const int row = coeff_idx >> bwl;
const int col = coeff_idx - (row << bwl);
- int tx_class = get_tx_class(tx_type);
+ const TX_CLASS tx_class = get_tx_class(tx_type);
#if USE_CAUSAL_BASE_CTX
int mag = 0;
int count = get_nz_count_mag(levels, bwl, row, col, tx_class, &mag);
@@ -578,7 +578,4 @@
const int level_minus_1, const int width,
const int height, uint8_t *const level_counts);
-void av1_get_br_level_counts(const uint8_t *const levels, const int width,
- const int height, uint8_t *const level_counts);
-
#endif // AV1_COMMON_TXB_COMMON_H_
diff --git a/av1/common/x86/mem_sse2.h b/av1/common/x86/mem_sse2.h
new file mode 100644
index 0000000..432fd99
--- /dev/null
+++ b/av1/common/x86/mem_sse2.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AV1_COMMON_X86_MEM_SSE2_H_
+#define AV1_COMMON_X86_MEM_SSE2_H_
+
+#include <emmintrin.h> // SSE2
+
+#include "./aom_config.h"
+#include "aom/aom_integer.h"
+
+static INLINE __m128i loadh_epi64(const void *const src, const __m128i s) {
+ return _mm_castps_si128(
+ _mm_loadh_pi(_mm_castsi128_ps(s), (const __m64 *)src));
+}
+
+#endif // AV1_COMMON_X86_MEM_SSE2_H_
diff --git a/av1/common/x86/txb_sse2.c b/av1/common/x86/txb_sse2.c
new file mode 100644
index 0000000..ebe4e8c
--- /dev/null
+++ b/av1/common/x86/txb_sse2.c
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <assert.h>
+#include <emmintrin.h> // SSE2
+
+#include "aom/aom_integer.h"
+#include "av1/common/onyxc_int.h"
+#include "av1/common/txb_common.h"
+#include "av1/common/x86/mem_sse2.h"
+
+static INLINE __m128i load_8bit_4x4_sse2(const void *const s0,
+ const void *const s1,
+ const void *const s2,
+ const void *const s3) {
+ return _mm_setr_epi32(*(const uint32_t *)s0, *(const uint32_t *)s1,
+ *(const uint32_t *)s2, *(const uint32_t *)s3);
+}
+
+static INLINE void load_levels_4x4x3_sse2(const uint8_t *const s0,
+ const uint8_t *const s1,
+ const uint8_t *const s2,
+ const uint8_t *const s3,
+ __m128i *const level) {
+ level[0] = load_8bit_4x4_sse2(s0 - 1, s1 - 1, s2 - 1, s3 - 1);
+ level[1] = load_8bit_4x4_sse2(s0 + 0, s1 + 0, s2 + 0, s3 + 0);
+ level[2] = load_8bit_4x4_sse2(s0 + 1, s1 + 1, s2 + 1, s3 + 1);
+}
+
+static INLINE __m128i get_level_counts_kernel_sse2(__m128i *const level) {
+ const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+ __m128i count;
+
+ level[6] = _mm_cmpgt_epi8(level[6], level_minus_1);
+ level[7] = _mm_cmpgt_epi8(level[7], level_minus_1);
+ level[8] = _mm_cmpgt_epi8(level[8], level_minus_1);
+ count = _mm_setzero_si128();
+ count = _mm_sub_epi8(count, level[0]);
+ count = _mm_sub_epi8(count, level[1]);
+ count = _mm_sub_epi8(count, level[2]);
+ count = _mm_sub_epi8(count, level[3]);
+ count = _mm_sub_epi8(count, level[5]);
+ count = _mm_sub_epi8(count, level[6]);
+ count = _mm_sub_epi8(count, level[7]);
+ count = _mm_sub_epi8(count, level[8]);
+ level[0] = level[3];
+ level[1] = level[4];
+ level[2] = level[5];
+ level[3] = level[6];
+ level[4] = level[7];
+ level[5] = level[8];
+
+ return count;
+}
+
+static INLINE void get_4_level_counts_sse2(const uint8_t *const levels,
+ const int height,
+ uint8_t *const level_counts) {
+ const int stride = 4 + TX_PAD_HOR;
+ const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+ int row = height;
+ __m128i count;
+ __m128i level[9];
+
+ /* level_counts must be 16 byte aligned. */
+ assert(!((intptr_t)level_counts & 0xf));
+ assert(!(height % 4));
+
+ if (height == 4) {
+ load_levels_4x4x3_sse2(levels - 1 * stride, levels + 0 * stride,
+ levels + 1 * stride, levels + 2 * stride, &level[0]);
+ load_levels_4x4x3_sse2(levels + 0 * stride, levels + 1 * stride,
+ levels + 2 * stride, levels + 3 * stride, &level[3]);
+ load_levels_4x4x3_sse2(levels + 1 * stride, levels + 2 * stride,
+ levels + 3 * stride, levels + 4 * stride, &level[6]);
+ level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+ level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+ level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+ level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+ level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+ level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+ count = get_level_counts_kernel_sse2(level);
+ _mm_store_si128((__m128i *)level_counts, count);
+ } else {
+ const uint8_t *ls[4];
+ uint8_t *lcs[4];
+
+ ls[0] = levels + 0 * stride * height / 4;
+ ls[1] = levels + 1 * stride * height / 4;
+ ls[2] = levels + 2 * stride * height / 4;
+ ls[3] = levels + 3 * stride * height / 4;
+ lcs[0] = level_counts + 0 * 4 * height / 4;
+ lcs[1] = level_counts + 1 * 4 * height / 4;
+ lcs[2] = level_counts + 2 * 4 * height / 4;
+ lcs[3] = level_counts + 3 * 4 * height / 4;
+
+ load_levels_4x4x3_sse2(ls[0] - 1 * stride, ls[1] - 1 * stride,
+ ls[2] - 1 * stride, ls[3] - 1 * stride, &level[0]);
+ load_levels_4x4x3_sse2(ls[0] + 0 * stride, ls[1] + 0 * stride,
+ ls[2] + 0 * stride, ls[3] + 0 * stride, &level[3]);
+ level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+ level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+ level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+ level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+ level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+ level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+ do {
+ load_levels_4x4x3_sse2(ls[0] + 1 * stride, ls[1] + 1 * stride,
+ ls[2] + 1 * stride, ls[3] + 1 * stride, &level[6]);
+
+ count = get_level_counts_kernel_sse2(level);
+ *(int *)(lcs[0]) = _mm_cvtsi128_si32(count);
+ *(int *)(lcs[1]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 4));
+ *(int *)(lcs[2]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 8));
+ *(int *)(lcs[3]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 12));
+ ls[0] += stride;
+ ls[1] += stride;
+ ls[2] += stride;
+ ls[3] += stride;
+ lcs[0] += 4;
+ lcs[1] += 4;
+ lcs[2] += 4;
+ lcs[3] += 4;
+ row -= 4;
+ } while (row);
+ }
+}
+
+static INLINE void get_8_level_counts_sse2(const uint8_t *const levels,
+ const int height,
+ uint8_t *const level_counts) {
+ const int stride = 8 + TX_PAD_HOR;
+ const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+ int row = height;
+ __m128i count;
+ __m128i level[9];
+ const uint8_t *ls[2];
+ uint8_t *lcs[2];
+
+ assert(!(height % 2));
+
+ ls[0] = levels;
+ ls[1] = levels + stride * height / 2;
+ lcs[0] = level_counts;
+ lcs[1] = level_counts + 8 * height / 2;
+
+ level[0] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride - 1));
+ level[1] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride + 0));
+ level[2] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride + 1));
+ level[3] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride - 1));
+ level[4] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride + 0));
+ level[5] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride + 1));
+ level[0] = loadh_epi64(ls[1] - 1 * stride - 1, level[0]);
+ level[1] = loadh_epi64(ls[1] - 1 * stride + 0, level[1]);
+ level[2] = loadh_epi64(ls[1] - 1 * stride + 1, level[2]);
+ level[3] = loadh_epi64(ls[1] + 0 * stride - 1, level[3]);
+ level[4] = loadh_epi64(ls[1] + 0 * stride + 0, level[4]);
+ level[5] = loadh_epi64(ls[1] + 0 * stride + 1, level[5]);
+ level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+ level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+ level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+ level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+ level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+ level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+ do {
+ level[6] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride - 1));
+ level[7] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride + 0));
+ level[8] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride + 1));
+ level[6] = loadh_epi64(ls[1] + 1 * stride - 1, level[6]);
+ level[7] = loadh_epi64(ls[1] + 1 * stride + 0, level[7]);
+ level[8] = loadh_epi64(ls[1] + 1 * stride + 1, level[8]);
+
+ count = get_level_counts_kernel_sse2(level);
+ _mm_storel_epi64((__m128i *)(lcs[0]), count);
+ _mm_storeh_pi((__m64 *)(lcs[1]), _mm_castsi128_ps(count));
+ ls[0] += stride;
+ ls[1] += stride;
+ lcs[0] += 8;
+ lcs[1] += 8;
+ row -= 2;
+ } while (row);
+}
+
+static INLINE void get_16x_level_counts_sse2(const uint8_t *levels,
+ const int width, const int height,
+ uint8_t *level_counts) {
+ const int stride = width + TX_PAD_HOR;
+ const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+ __m128i count;
+ __m128i level[9];
+
+ /* level_counts must be 16 byte aligned. */
+ assert(!((intptr_t)level_counts & 0xf));
+ assert(!(width % 16));
+
+ for (int i = 0; i < width; i += 16) {
+ int row = height;
+
+ level[0] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride - 1));
+ level[1] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride + 0));
+ level[2] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride + 1));
+ level[3] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride - 1));
+ level[4] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride + 0));
+ level[5] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride + 1));
+ level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+ level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+ level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+ level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+ level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+ level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+ do {
+ level[6] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride - 1));
+ level[7] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride + 0));
+ level[8] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride + 1));
+
+ count = get_level_counts_kernel_sse2(level);
+ _mm_store_si128((__m128i *)(level_counts + i), count);
+ levels += stride;
+ level_counts += width;
+ } while (--row);
+
+ levels -= stride * height;
+ level_counts -= width * height;
+ }
+}
+
+void av1_get_br_level_counts_sse2(const uint8_t *const levels, const int width,
+ const int height,
+ uint8_t *const level_counts) {
+ if (width == 4) {
+ get_4_level_counts_sse2(levels, height, level_counts);
+ } else if (width == 8) {
+ get_8_level_counts_sse2(levels, height, level_counts);
+ } else {
+ get_16x_level_counts_sse2(levels, width, height, level_counts);
+ }
+}
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index edcb0aa..ff19776 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -9,6 +9,7 @@
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
+#include "aom_ports/mem.h"
#include "av1/common/scan.h"
#include "av1/common/idct.h"
#include "av1/common/txb_common.h"
@@ -74,7 +75,7 @@
int cul_level = 0;
uint8_t levels_buf[TX_PAD_2D];
uint8_t *const levels = set_levels(levels_buf, width);
- uint8_t level_counts[MAX_TX_SQUARE];
+ DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
int8_t signs[MAX_TX_SQUARE];
memset(tcoeffs, 0, sizeof(*tcoeffs) * seg_eob);
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index d459f89..21d4a1d 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -9,6 +9,7 @@
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
+#include "aom_ports/mem.h"
#include "av1/common/scan.h"
#include "av1/common/blockd.h"
#include "av1/common/idct.h"
@@ -321,7 +322,7 @@
FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
uint8_t levels_buf[TX_PAD_2D];
uint8_t *const levels = set_levels(levels_buf, width);
- uint8_t level_counts[MAX_TX_SQUARE];
+ DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
(void)blk_row;
(void)blk_col;
@@ -635,7 +636,7 @@
const int16_t *scan = scan_order->scan;
uint8_t levels_buf[TX_PAD_2D];
uint8_t *const levels = set_levels(levels_buf, width);
- uint8_t level_counts[MAX_TX_SQUARE];
+ DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
LV_MAP_COEFF_COST *coeff_costs = &x->coeff_costs[txs_ctx][plane_type];
@@ -2144,7 +2145,7 @@
const int height = tx_size_high[tx_size];
uint8_t levels_buf[TX_PAD_2D];
uint8_t *const levels = set_levels(levels_buf, width);
- uint8_t level_counts[MAX_TX_SQUARE];
+ DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
const uint8_t allow_update_cdf = args->allow_update_cdf;
TX_SIZE txsize_ctx = get_txsize_context(tx_size);
diff --git a/test/test.cmake b/test/test.cmake
index df6ed85..7a0d09a 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -169,6 +169,12 @@
endif ()
endif ()
+ if (CONFIG_LV_MAP)
+ set(AOM_UNIT_TEST_COMMON_SOURCES
+ ${AOM_UNIT_TEST_COMMON_SOURCES}
+ "${AOM_ROOT}/test/txb_test.cc")
+ endif ()
+
set(AOM_UNIT_TEST_COMMON_INTRIN_NEON
${AOM_UNIT_TEST_COMMON_INTRIN_NEON}
"${AOM_ROOT}/test/simd_cmp_neon.cc")
diff --git a/test/test.mk b/test/test.mk
index 17bdb81..c2191b3 100644
--- a/test/test.mk
+++ b/test/test.mk
@@ -154,6 +154,7 @@
LIBAOM_TEST_SRCS-$(HAVE_AVX2) += simd_avx2_test.cc
LIBAOM_TEST_SRCS-$(HAVE_NEON) += simd_neon_test.cc
LIBAOM_TEST_SRCS-yes += intrapred_test.cc
+LIBAOM_TEST_SRCS-$(CONFIG_LV_MAP) += txb_test.cc
LIBAOM_TEST_SRCS-$(CONFIG_INTRABC) += intrabc_test.cc
#LIBAOM_TEST_SRCS-$(CONFIG_AV1_DECODER) += av1_thread_test.cc
LIBAOM_TEST_SRCS-$(CONFIG_AV1_ENCODER) += dct16x16_test.cc
diff --git a/test/txb_test.cc b/test/txb_test.cc
new file mode 100644
index 0000000..83bba51
--- /dev/null
+++ b/test/txb_test.cc
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+#include "./aom_config.h"
+#include "./av1_rtcd.h"
+#include "aom_ports/aom_timer.h"
+#include "aom_ports/mem.h"
+#include "av1/common/onyxc_int.h"
+#include "av1/common/txb_common.h"
+#include "test/acm_random.h"
+#include "test/clear_system_state.h"
+#include "test/register_state_check.h"
+#include "test/util.h"
+
+namespace {
+using libaom_test::ACMRandom;
+
+typedef void (*GetLevelCountsFunc)(const uint8_t *const levels, const int width,
+ const int height,
+ uint8_t *const level_counts);
+
+class TxbTest : public ::testing::TestWithParam<GetLevelCountsFunc> {
+ public:
+ TxbTest() : get_br_level_counts_func_(GetParam()) {}
+
+ virtual ~TxbTest() {}
+
+ virtual void SetUp() {
+ level_counts_ref_ = reinterpret_cast<uint8_t *>(
+ aom_memalign(16, sizeof(*level_counts_ref_) * MAX_TX_SQUARE));
+ ASSERT_TRUE(level_counts_ref_ != NULL);
+ level_counts_ = reinterpret_cast<uint8_t *>(
+ aom_memalign(16, sizeof(*level_counts_) * MAX_TX_SQUARE));
+ ASSERT_TRUE(level_counts_ != NULL);
+ }
+
+ virtual void TearDown() {
+ aom_free(level_counts_ref_);
+ aom_free(level_counts_);
+ libaom_test::ClearSystemState();
+ }
+
+ void GetLevelCountsRun() {
+ const int kNumTests = 10000;
+ int result = 0;
+
+ for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
+ const int width = tx_size_wide[tx_size];
+ const int height = tx_size_high[tx_size];
+ levels_ = set_levels(levels_buf_, width);
+ memset(levels_buf_, 0, sizeof(*levels_buf_) * MAX_TX_SQUARE);
+
+ for (int i = 0; i < kNumTests && !result; ++i) {
+ InitLevels(width, height);
+
+ av1_get_br_level_counts_c(levels_, width, height, level_counts_ref_);
+ get_br_level_counts_func_(levels_, width, height, level_counts_);
+
+ PrintDiff(width, height);
+
+ result = memcmp(level_counts_, level_counts_ref_,
+ sizeof(*level_counts_ref_) * MAX_TX_SQUARE);
+
+ EXPECT_EQ(result, 0) << " width " << width << " height " << height;
+ }
+ }
+ }
+
+ void SpeedTestGetLevelCountsRun() {
+ const int kNumTests = 10000000;
+ aom_usec_timer timer;
+
+ for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
+ const int width = tx_size_wide[tx_size];
+ const int height = tx_size_high[tx_size];
+ levels_ = set_levels(levels_buf_, width);
+ memset(levels_buf_, 0, sizeof(*levels_buf_) * MAX_TX_SQUARE);
+ InitLevels(width, height);
+
+ aom_usec_timer_start(&timer);
+ for (int i = 0; i < kNumTests; ++i) {
+ get_br_level_counts_func_(levels_, width, height, level_counts_);
+ }
+ aom_usec_timer_mark(&timer);
+
+ const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
+ printf("get_br_level_counts_%2dx%2d: %7.1f ms\n", width, height,
+ elapsed_time / 1000.0);
+ }
+ }
+
+ void InitLevels(const int width, const int height) {
+ const int stride = width + TX_PAD_HOR;
+
+ for (int i = 0; i < height; ++i) {
+ for (int j = 0; j < width; ++j) {
+ levels_[i * stride + j] = rnd_.Rand8() % (NUM_BASE_LEVELS + 2);
+ }
+ }
+ for (int i = 0; i < MAX_TX_SQUARE; ++i) {
+ level_counts_ref_[i] = level_counts_[i] = 255;
+ }
+ }
+
+ void PrintDiff(const int width, const int height) const {
+ if (memcmp(level_counts_, level_counts_ref_,
+ sizeof(*level_counts_ref_) * MAX_TX_SQUARE)) {
+ for (int y = 0; y < height; y++) {
+ for (int x = 0; x < width; x++) {
+ if (level_counts_ref_[y * width + x] !=
+ level_counts_[y * width + x]) {
+ printf("count[%d][%d] diff:%6d (ref),%6d (opt)\n", y, x,
+ level_counts_ref_[y * width + x],
+ level_counts_[y * width + x]);
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ GetLevelCountsFunc get_br_level_counts_func_;
+ ACMRandom rnd_;
+ tran_low_t *coeff_;
+ uint8_t levels_buf_[TX_PAD_2D];
+ uint8_t *levels_;
+ uint8_t *level_counts_ref_;
+ uint8_t *level_counts_;
+};
+
+TEST_P(TxbTest, BitExact) { GetLevelCountsRun(); }
+
+TEST_P(TxbTest, DISABLED_Speed) { SpeedTestGetLevelCountsRun(); }
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(SSE2, TxbTest,
+ ::testing::Values(av1_get_br_level_counts_sse2));
+#endif
+} // namespace