Add av1_get_br_level_counts_sse2()

Change-Id: I6ce7aea19e3bdeef24d3fe66ac6eba7b8d585f9a
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 6ea0cfd..804941a 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -254,7 +254,8 @@
 
 set(AOM_AV1_COMMON_INTRIN_SSE2
     ${AOM_AV1_COMMON_INTRIN_SSE2}
-    "${AOM_ROOT}/av1/common/cdef_block_sse2.c")
+    "${AOM_ROOT}/av1/common/cdef_block_sse2.c"
+    "${AOM_ROOT}/av1/common/x86/mem_sse2.h")
 
 set(AOM_AV1_COMMON_INTRIN_SSSE3
     ${AOM_AV1_COMMON_INTRIN_SSSE3}
@@ -384,6 +385,10 @@
       "${AOM_ROOT}/av1/common/txb_common.c"
       "${AOM_ROOT}/av1/common/txb_common.h")
 
+  set(AOM_AV1_COMMON_INTRIN_SSE2
+      ${AOM_AV1_COMMON_INTRIN_SSE2}
+      "${AOM_ROOT}/av1/common/x86/txb_sse2.c")
+
   set(AOM_AV1_DECODER_SOURCES
       ${AOM_AV1_DECODER_SOURCES}
       "${AOM_ROOT}/av1/decoder/decodetxb.c"
diff --git a/av1/av1_common.mk b/av1/av1_common.mk
index f0c65c2..ed5a7ec 100644
--- a/av1/av1_common.mk
+++ b/av1/av1_common.mk
@@ -76,6 +76,7 @@
 AV1_COMMON_SRCS-yes += common/av1_inv_txfm2d.c
 AV1_COMMON_SRCS-yes += common/av1_inv_txfm1d_cfg.h
 AV1_COMMON_SRCS-$(HAVE_AVX2) += common/x86/convolve_avx2.c
+AV1_COMMON_SRCS-$(HAVE_SSE2) += common/x86/mem_sse2.h
 AV1_COMMON_SRCS-$(HAVE_SSSE3) += common/x86/av1_convolve_ssse3.c
 ifeq ($(CONFIG_CONVOLVE_ROUND)x$(CONFIG_COMPOUND_ROUND),yesx)
 AV1_COMMON_SRCS-$(HAVE_SSE4_1) += common/x86/av1_convolve_scale_sse4.c
@@ -167,6 +168,10 @@
 endif
 endif
 
+ifeq ($(CONFIG_LV_MAP),yes)
+AV1_COMMON_SRCS-$(HAVE_SSE2) += common/x86/txb_sse2.c
+endif
+
 ifeq ($(CONFIG_Q_ADAPT_PROBS),yes)
 AV1_COMMON_SRCS-yes += common/token_cdfs.h
 endif
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index ab0fa48..6b2f1ed 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -67,6 +67,14 @@
 }
 
 #
+# txb
+#
+if (aom_config("CONFIG_LV_MAP") eq "yes") {
+  add_proto qw/void av1_get_br_level_counts/, "const uint8_t *const levels, const int width, const int height, uint8_t *const level_counts";
+  specialize qw/av1_get_br_level_counts sse2/;
+}
+
+#
 # Inverse dct
 #
 add_proto qw/void av1_iht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, const struct txfm_param *param";
diff --git a/av1/common/txb_common.c b/av1/common/txb_common.c
index 38ac984..4607428 100644
--- a/av1/common/txb_common.c
+++ b/av1/common/txb_common.c
@@ -333,8 +333,8 @@
   }
 }
 
-void av1_get_br_level_counts(const uint8_t *const levels, const int width,
-                             const int height, uint8_t *const level_counts) {
+void av1_get_br_level_counts_c(const uint8_t *const levels, const int width,
+                               const int height, uint8_t *const level_counts) {
   const int stride = width + TX_PAD_HOR;
   const int level_minus_1 = NUM_BASE_LEVELS;
 
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index c609f29..edae7f7 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -391,7 +391,7 @@
   const int width = 1 << bwl;
 
   int ctx = 0;
-  int tx_class = get_tx_class(tx_type);
+  const TX_CLASS tx_class = get_tx_class(tx_type);
   int offset;
   if (tx_class == TX_CLASS_2D)
     offset = 0;
@@ -455,7 +455,7 @@
   const int row = coeff_idx >> bwl;
   const int col = coeff_idx - (row << bwl);
 
-  int tx_class = get_tx_class(tx_type);
+  const TX_CLASS tx_class = get_tx_class(tx_type);
 #if USE_CAUSAL_BASE_CTX
   int mag = 0;
   int count = get_nz_count_mag(levels, bwl, row, col, tx_class, &mag);
@@ -578,7 +578,4 @@
                                const int level_minus_1, const int width,
                                const int height, uint8_t *const level_counts);
 
-void av1_get_br_level_counts(const uint8_t *const levels, const int width,
-                             const int height, uint8_t *const level_counts);
-
 #endif  // AV1_COMMON_TXB_COMMON_H_
diff --git a/av1/common/x86/mem_sse2.h b/av1/common/x86/mem_sse2.h
new file mode 100644
index 0000000..432fd99
--- /dev/null
+++ b/av1/common/x86/mem_sse2.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AV1_COMMON_X86_MEM_SSE2_H_
+#define AV1_COMMON_X86_MEM_SSE2_H_
+
+#include <emmintrin.h>  // SSE2
+
+#include "./aom_config.h"
+#include "aom/aom_integer.h"
+
+static INLINE __m128i loadh_epi64(const void *const src, const __m128i s) {
+  return _mm_castps_si128(
+      _mm_loadh_pi(_mm_castsi128_ps(s), (const __m64 *)src));
+}
+
+#endif  // AV1_COMMON_X86_MEM_SSE2_H_
diff --git a/av1/common/x86/txb_sse2.c b/av1/common/x86/txb_sse2.c
new file mode 100644
index 0000000..ebe4e8c
--- /dev/null
+++ b/av1/common/x86/txb_sse2.c
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <assert.h>
+#include <emmintrin.h>  // SSE2
+
+#include "aom/aom_integer.h"
+#include "av1/common/onyxc_int.h"
+#include "av1/common/txb_common.h"
+#include "av1/common/x86/mem_sse2.h"
+
+static INLINE __m128i load_8bit_4x4_sse2(const void *const s0,
+                                         const void *const s1,
+                                         const void *const s2,
+                                         const void *const s3) {
+  return _mm_setr_epi32(*(const uint32_t *)s0, *(const uint32_t *)s1,
+                        *(const uint32_t *)s2, *(const uint32_t *)s3);
+}
+
+static INLINE void load_levels_4x4x3_sse2(const uint8_t *const s0,
+                                          const uint8_t *const s1,
+                                          const uint8_t *const s2,
+                                          const uint8_t *const s3,
+                                          __m128i *const level) {
+  level[0] = load_8bit_4x4_sse2(s0 - 1, s1 - 1, s2 - 1, s3 - 1);
+  level[1] = load_8bit_4x4_sse2(s0 + 0, s1 + 0, s2 + 0, s3 + 0);
+  level[2] = load_8bit_4x4_sse2(s0 + 1, s1 + 1, s2 + 1, s3 + 1);
+}
+
+static INLINE __m128i get_level_counts_kernel_sse2(__m128i *const level) {
+  const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+  __m128i count;
+
+  level[6] = _mm_cmpgt_epi8(level[6], level_minus_1);
+  level[7] = _mm_cmpgt_epi8(level[7], level_minus_1);
+  level[8] = _mm_cmpgt_epi8(level[8], level_minus_1);
+  count = _mm_setzero_si128();
+  count = _mm_sub_epi8(count, level[0]);
+  count = _mm_sub_epi8(count, level[1]);
+  count = _mm_sub_epi8(count, level[2]);
+  count = _mm_sub_epi8(count, level[3]);
+  count = _mm_sub_epi8(count, level[5]);
+  count = _mm_sub_epi8(count, level[6]);
+  count = _mm_sub_epi8(count, level[7]);
+  count = _mm_sub_epi8(count, level[8]);
+  level[0] = level[3];
+  level[1] = level[4];
+  level[2] = level[5];
+  level[3] = level[6];
+  level[4] = level[7];
+  level[5] = level[8];
+
+  return count;
+}
+
+static INLINE void get_4_level_counts_sse2(const uint8_t *const levels,
+                                           const int height,
+                                           uint8_t *const level_counts) {
+  const int stride = 4 + TX_PAD_HOR;
+  const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+  int row = height;
+  __m128i count;
+  __m128i level[9];
+
+  /* level_counts must be 16 byte aligned. */
+  assert(!((intptr_t)level_counts & 0xf));
+  assert(!(height % 4));
+
+  if (height == 4) {
+    load_levels_4x4x3_sse2(levels - 1 * stride, levels + 0 * stride,
+                           levels + 1 * stride, levels + 2 * stride, &level[0]);
+    load_levels_4x4x3_sse2(levels + 0 * stride, levels + 1 * stride,
+                           levels + 2 * stride, levels + 3 * stride, &level[3]);
+    load_levels_4x4x3_sse2(levels + 1 * stride, levels + 2 * stride,
+                           levels + 3 * stride, levels + 4 * stride, &level[6]);
+    level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+    level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+    level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+    level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+    level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+    level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+    count = get_level_counts_kernel_sse2(level);
+    _mm_store_si128((__m128i *)level_counts, count);
+  } else {
+    const uint8_t *ls[4];
+    uint8_t *lcs[4];
+
+    ls[0] = levels + 0 * stride * height / 4;
+    ls[1] = levels + 1 * stride * height / 4;
+    ls[2] = levels + 2 * stride * height / 4;
+    ls[3] = levels + 3 * stride * height / 4;
+    lcs[0] = level_counts + 0 * 4 * height / 4;
+    lcs[1] = level_counts + 1 * 4 * height / 4;
+    lcs[2] = level_counts + 2 * 4 * height / 4;
+    lcs[3] = level_counts + 3 * 4 * height / 4;
+
+    load_levels_4x4x3_sse2(ls[0] - 1 * stride, ls[1] - 1 * stride,
+                           ls[2] - 1 * stride, ls[3] - 1 * stride, &level[0]);
+    load_levels_4x4x3_sse2(ls[0] + 0 * stride, ls[1] + 0 * stride,
+                           ls[2] + 0 * stride, ls[3] + 0 * stride, &level[3]);
+    level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+    level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+    level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+    level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+    level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+    level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+    do {
+      load_levels_4x4x3_sse2(ls[0] + 1 * stride, ls[1] + 1 * stride,
+                             ls[2] + 1 * stride, ls[3] + 1 * stride, &level[6]);
+
+      count = get_level_counts_kernel_sse2(level);
+      *(int *)(lcs[0]) = _mm_cvtsi128_si32(count);
+      *(int *)(lcs[1]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 4));
+      *(int *)(lcs[2]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 8));
+      *(int *)(lcs[3]) = _mm_cvtsi128_si32(_mm_srli_si128(count, 12));
+      ls[0] += stride;
+      ls[1] += stride;
+      ls[2] += stride;
+      ls[3] += stride;
+      lcs[0] += 4;
+      lcs[1] += 4;
+      lcs[2] += 4;
+      lcs[3] += 4;
+      row -= 4;
+    } while (row);
+  }
+}
+
+static INLINE void get_8_level_counts_sse2(const uint8_t *const levels,
+                                           const int height,
+                                           uint8_t *const level_counts) {
+  const int stride = 8 + TX_PAD_HOR;
+  const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+  int row = height;
+  __m128i count;
+  __m128i level[9];
+  const uint8_t *ls[2];
+  uint8_t *lcs[2];
+
+  assert(!(height % 2));
+
+  ls[0] = levels;
+  ls[1] = levels + stride * height / 2;
+  lcs[0] = level_counts;
+  lcs[1] = level_counts + 8 * height / 2;
+
+  level[0] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride - 1));
+  level[1] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride + 0));
+  level[2] = _mm_loadl_epi64((__m128i *)(ls[0] - 1 * stride + 1));
+  level[3] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride - 1));
+  level[4] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride + 0));
+  level[5] = _mm_loadl_epi64((__m128i *)(ls[0] + 0 * stride + 1));
+  level[0] = loadh_epi64(ls[1] - 1 * stride - 1, level[0]);
+  level[1] = loadh_epi64(ls[1] - 1 * stride + 0, level[1]);
+  level[2] = loadh_epi64(ls[1] - 1 * stride + 1, level[2]);
+  level[3] = loadh_epi64(ls[1] + 0 * stride - 1, level[3]);
+  level[4] = loadh_epi64(ls[1] + 0 * stride + 0, level[4]);
+  level[5] = loadh_epi64(ls[1] + 0 * stride + 1, level[5]);
+  level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+  level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+  level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+  level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+  level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+  level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+  do {
+    level[6] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride - 1));
+    level[7] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride + 0));
+    level[8] = _mm_loadl_epi64((__m128i *)(ls[0] + 1 * stride + 1));
+    level[6] = loadh_epi64(ls[1] + 1 * stride - 1, level[6]);
+    level[7] = loadh_epi64(ls[1] + 1 * stride + 0, level[7]);
+    level[8] = loadh_epi64(ls[1] + 1 * stride + 1, level[8]);
+
+    count = get_level_counts_kernel_sse2(level);
+    _mm_storel_epi64((__m128i *)(lcs[0]), count);
+    _mm_storeh_pi((__m64 *)(lcs[1]), _mm_castsi128_ps(count));
+    ls[0] += stride;
+    ls[1] += stride;
+    lcs[0] += 8;
+    lcs[1] += 8;
+    row -= 2;
+  } while (row);
+}
+
+static INLINE void get_16x_level_counts_sse2(const uint8_t *levels,
+                                             const int width, const int height,
+                                             uint8_t *level_counts) {
+  const int stride = width + TX_PAD_HOR;
+  const __m128i level_minus_1 = _mm_set1_epi8(NUM_BASE_LEVELS);
+  __m128i count;
+  __m128i level[9];
+
+  /* level_counts must be 16 byte aligned. */
+  assert(!((intptr_t)level_counts & 0xf));
+  assert(!(width % 16));
+
+  for (int i = 0; i < width; i += 16) {
+    int row = height;
+
+    level[0] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride - 1));
+    level[1] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride + 0));
+    level[2] = _mm_loadu_si128((__m128i *)(levels + i - 1 * stride + 1));
+    level[3] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride - 1));
+    level[4] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride + 0));
+    level[5] = _mm_loadu_si128((__m128i *)(levels + i + 0 * stride + 1));
+    level[0] = _mm_cmpgt_epi8(level[0], level_minus_1);
+    level[1] = _mm_cmpgt_epi8(level[1], level_minus_1);
+    level[2] = _mm_cmpgt_epi8(level[2], level_minus_1);
+    level[3] = _mm_cmpgt_epi8(level[3], level_minus_1);
+    level[4] = _mm_cmpgt_epi8(level[4], level_minus_1);
+    level[5] = _mm_cmpgt_epi8(level[5], level_minus_1);
+
+    do {
+      level[6] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride - 1));
+      level[7] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride + 0));
+      level[8] = _mm_loadu_si128((__m128i *)(levels + i + 1 * stride + 1));
+
+      count = get_level_counts_kernel_sse2(level);
+      _mm_store_si128((__m128i *)(level_counts + i), count);
+      levels += stride;
+      level_counts += width;
+    } while (--row);
+
+    levels -= stride * height;
+    level_counts -= width * height;
+  }
+}
+
+void av1_get_br_level_counts_sse2(const uint8_t *const levels, const int width,
+                                  const int height,
+                                  uint8_t *const level_counts) {
+  if (width == 4) {
+    get_4_level_counts_sse2(levels, height, level_counts);
+  } else if (width == 8) {
+    get_8_level_counts_sse2(levels, height, level_counts);
+  } else {
+    get_16x_level_counts_sse2(levels, width, height, level_counts);
+  }
+}
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index edcb0aa..ff19776 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -9,6 +9,7 @@
  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
  */
 
+#include "aom_ports/mem.h"
 #include "av1/common/scan.h"
 #include "av1/common/idct.h"
 #include "av1/common/txb_common.h"
@@ -74,7 +75,7 @@
   int cul_level = 0;
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
-  uint8_t level_counts[MAX_TX_SQUARE];
+  DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
   int8_t signs[MAX_TX_SQUARE];
 
   memset(tcoeffs, 0, sizeof(*tcoeffs) * seg_eob);
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index d459f89..21d4a1d 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -9,6 +9,7 @@
  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
  */
 
+#include "aom_ports/mem.h"
 #include "av1/common/scan.h"
 #include "av1/common/blockd.h"
 #include "av1/common/idct.h"
@@ -321,7 +322,7 @@
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
-  uint8_t level_counts[MAX_TX_SQUARE];
+  DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
 
   (void)blk_row;
   (void)blk_col;
@@ -635,7 +636,7 @@
   const int16_t *scan = scan_order->scan;
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
-  uint8_t level_counts[MAX_TX_SQUARE];
+  DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
 
   LV_MAP_COEFF_COST *coeff_costs = &x->coeff_costs[txs_ctx][plane_type];
 
@@ -2144,7 +2145,7 @@
   const int height = tx_size_high[tx_size];
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
-  uint8_t level_counts[MAX_TX_SQUARE];
+  DECLARE_ALIGNED(16, uint8_t, level_counts[MAX_TX_SQUARE]);
   const uint8_t allow_update_cdf = args->allow_update_cdf;
 
   TX_SIZE txsize_ctx = get_txsize_context(tx_size);
diff --git a/test/test.cmake b/test/test.cmake
index df6ed85..7a0d09a 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -169,6 +169,12 @@
       endif ()
     endif ()
 
+    if (CONFIG_LV_MAP)
+      set(AOM_UNIT_TEST_COMMON_SOURCES
+          ${AOM_UNIT_TEST_COMMON_SOURCES}
+          "${AOM_ROOT}/test/txb_test.cc")
+    endif ()
+
     set(AOM_UNIT_TEST_COMMON_INTRIN_NEON
         ${AOM_UNIT_TEST_COMMON_INTRIN_NEON}
         "${AOM_ROOT}/test/simd_cmp_neon.cc")
diff --git a/test/test.mk b/test/test.mk
index 17bdb81..c2191b3 100644
--- a/test/test.mk
+++ b/test/test.mk
@@ -154,6 +154,7 @@
 LIBAOM_TEST_SRCS-$(HAVE_AVX2)          += simd_avx2_test.cc
 LIBAOM_TEST_SRCS-$(HAVE_NEON)          += simd_neon_test.cc
 LIBAOM_TEST_SRCS-yes                   += intrapred_test.cc
+LIBAOM_TEST_SRCS-$(CONFIG_LV_MAP)      += txb_test.cc
 LIBAOM_TEST_SRCS-$(CONFIG_INTRABC)     += intrabc_test.cc
 #LIBAOM_TEST_SRCS-$(CONFIG_AV1_DECODER) += av1_thread_test.cc
 LIBAOM_TEST_SRCS-$(CONFIG_AV1_ENCODER) += dct16x16_test.cc
diff --git a/test/txb_test.cc b/test/txb_test.cc
new file mode 100644
index 0000000..83bba51
--- /dev/null
+++ b/test/txb_test.cc
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+#include "./aom_config.h"
+#include "./av1_rtcd.h"
+#include "aom_ports/aom_timer.h"
+#include "aom_ports/mem.h"
+#include "av1/common/onyxc_int.h"
+#include "av1/common/txb_common.h"
+#include "test/acm_random.h"
+#include "test/clear_system_state.h"
+#include "test/register_state_check.h"
+#include "test/util.h"
+
+namespace {
+using libaom_test::ACMRandom;
+
+typedef void (*GetLevelCountsFunc)(const uint8_t *const levels, const int width,
+                                   const int height,
+                                   uint8_t *const level_counts);
+
+class TxbTest : public ::testing::TestWithParam<GetLevelCountsFunc> {
+ public:
+  TxbTest() : get_br_level_counts_func_(GetParam()) {}
+
+  virtual ~TxbTest() {}
+
+  virtual void SetUp() {
+    level_counts_ref_ = reinterpret_cast<uint8_t *>(
+        aom_memalign(16, sizeof(*level_counts_ref_) * MAX_TX_SQUARE));
+    ASSERT_TRUE(level_counts_ref_ != NULL);
+    level_counts_ = reinterpret_cast<uint8_t *>(
+        aom_memalign(16, sizeof(*level_counts_) * MAX_TX_SQUARE));
+    ASSERT_TRUE(level_counts_ != NULL);
+  }
+
+  virtual void TearDown() {
+    aom_free(level_counts_ref_);
+    aom_free(level_counts_);
+    libaom_test::ClearSystemState();
+  }
+
+  void GetLevelCountsRun() {
+    const int kNumTests = 10000;
+    int result = 0;
+
+    for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
+      const int width = tx_size_wide[tx_size];
+      const int height = tx_size_high[tx_size];
+      levels_ = set_levels(levels_buf_, width);
+      memset(levels_buf_, 0, sizeof(*levels_buf_) * MAX_TX_SQUARE);
+
+      for (int i = 0; i < kNumTests && !result; ++i) {
+        InitLevels(width, height);
+
+        av1_get_br_level_counts_c(levels_, width, height, level_counts_ref_);
+        get_br_level_counts_func_(levels_, width, height, level_counts_);
+
+        PrintDiff(width, height);
+
+        result = memcmp(level_counts_, level_counts_ref_,
+                        sizeof(*level_counts_ref_) * MAX_TX_SQUARE);
+
+        EXPECT_EQ(result, 0) << " width " << width << " height " << height;
+      }
+    }
+  }
+
+  void SpeedTestGetLevelCountsRun() {
+    const int kNumTests = 10000000;
+    aom_usec_timer timer;
+
+    for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
+      const int width = tx_size_wide[tx_size];
+      const int height = tx_size_high[tx_size];
+      levels_ = set_levels(levels_buf_, width);
+      memset(levels_buf_, 0, sizeof(*levels_buf_) * MAX_TX_SQUARE);
+      InitLevels(width, height);
+
+      aom_usec_timer_start(&timer);
+      for (int i = 0; i < kNumTests; ++i) {
+        get_br_level_counts_func_(levels_, width, height, level_counts_);
+      }
+      aom_usec_timer_mark(&timer);
+
+      const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
+      printf("get_br_level_counts_%2dx%2d: %7.1f ms\n", width, height,
+             elapsed_time / 1000.0);
+    }
+  }
+
+  void InitLevels(const int width, const int height) {
+    const int stride = width + TX_PAD_HOR;
+
+    for (int i = 0; i < height; ++i) {
+      for (int j = 0; j < width; ++j) {
+        levels_[i * stride + j] = rnd_.Rand8() % (NUM_BASE_LEVELS + 2);
+      }
+    }
+    for (int i = 0; i < MAX_TX_SQUARE; ++i) {
+      level_counts_ref_[i] = level_counts_[i] = 255;
+    }
+  }
+
+  void PrintDiff(const int width, const int height) const {
+    if (memcmp(level_counts_, level_counts_ref_,
+               sizeof(*level_counts_ref_) * MAX_TX_SQUARE)) {
+      for (int y = 0; y < height; y++) {
+        for (int x = 0; x < width; x++) {
+          if (level_counts_ref_[y * width + x] !=
+              level_counts_[y * width + x]) {
+            printf("count[%d][%d] diff:%6d (ref),%6d (opt)\n", y, x,
+                   level_counts_ref_[y * width + x],
+                   level_counts_[y * width + x]);
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  GetLevelCountsFunc get_br_level_counts_func_;
+  ACMRandom rnd_;
+  tran_low_t *coeff_;
+  uint8_t levels_buf_[TX_PAD_2D];
+  uint8_t *levels_;
+  uint8_t *level_counts_ref_;
+  uint8_t *level_counts_;
+};
+
+TEST_P(TxbTest, BitExact) { GetLevelCountsRun(); }
+
+TEST_P(TxbTest, DISABLED_Speed) { SpeedTestGetLevelCountsRun(); }
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(SSE2, TxbTest,
+                        ::testing::Values(av1_get_br_level_counts_sse2));
+#endif
+}  // namespace