Split get_level_count_mag()

to get_level_count() and get_level_mag() since they actually
calculate in different levels and get_level_mag() is hard to be SIMDed.

Change-Id: Iedb12a1d592cf09425e5a77e6bdc9990c271c872
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 71471da..9f040e9 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -46,6 +46,11 @@
   /* clang-format on*/
 };
 
+#define CONTEXT_MAG_POSITION_NUM 3
+static const int mag_ref_offset[CONTEXT_MAG_POSITION_NUM][2] = {
+  { 0, 1 }, { 1, 0 }, { 1, 1 }
+};
+
 static INLINE void get_base_count_mag(int *mag, int *count,
                                       const tran_low_t *tcoeffs, int bwl,
                                       int height, int row, int col) {
@@ -84,12 +89,9 @@
   return idx + TX_PAD_HOR * (idx >> bwl);
 }
 
-static INLINE int get_level_count_mag(int *const mag,
-                                      const uint8_t *const levels,
-                                      const int stride, const int row,
-                                      const int col, const int level,
-                                      const int (*nb_offset)[2],
-                                      const int nb_num) {
+static INLINE int get_level_count(const uint8_t *const levels, const int stride,
+                                  const int row, const int col, const int level,
+                                  const int (*nb_offset)[2], const int nb_num) {
   int count = 0;
 
   for (int idx = 0; idx < nb_num; ++idx) {
@@ -97,13 +99,20 @@
     const int ref_col = col + nb_offset[idx][1];
     const int pos = ref_row * stride + ref_col;
     count += levels[pos] > level;
-    if (nb_offset[idx][0] == 0 && nb_offset[idx][1] == 1) mag[0] = levels[pos];
-    if (nb_offset[idx][0] == 1 && nb_offset[idx][1] == 0) mag[1] = levels[pos];
-    if (nb_offset[idx][0] == 1 && nb_offset[idx][1] == 1) mag[2] = levels[pos];
   }
   return count;
 }
 
+static INLINE void get_level_mag(const uint8_t *const levels, const int stride,
+                                 const int row, const int col, int *const mag) {
+  for (int idx = 0; idx < CONTEXT_MAG_POSITION_NUM; ++idx) {
+    const int ref_row = row + mag_ref_offset[idx][0];
+    const int ref_col = col + mag_ref_offset[idx][1];
+    const int pos = ref_row * stride + ref_col;
+    mag[idx] = levels[pos];
+  }
+}
+
 static INLINE int get_base_ctx_from_count_mag(int row, int col, int count,
                                               int sig_mag) {
   const int ctx = base_level_count_to_index[count];
@@ -170,18 +179,18 @@
 
 static INLINE int get_base_ctx(const uint8_t *const levels,
                                const int c,  // raster order
-                               const int bwl, const int level) {
+                               const int bwl, const int level_minus_1) {
   const int row = c >> bwl;
   const int col = c - (row << bwl);
   const int stride = (1 << bwl) + TX_PAD_HOR;
-  const int level_minus_1 = level - 1;
   int mag_count = 0;
   int nb_mag[3] = { 0 };
-  const int count =
-      get_level_count_mag(nb_mag, levels, stride, row, col, level_minus_1,
-                          base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+  const int count = get_level_count(levels, stride, row, col, level_minus_1,
+                                    base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+  get_level_mag(levels, stride, row, col, nb_mag);
 
-  for (int idx = 0; idx < 3; ++idx) mag_count += nb_mag[idx] > level;
+  for (int idx = 0; idx < 3; ++idx)
+    mag_count += nb_mag[idx] > (level_minus_1 + 1);
   const int ctx_idx =
       get_base_ctx_from_count_mag(row, col, count, AOMMIN(2, mag_count));
   return ctx_idx;
@@ -279,9 +288,9 @@
   const int level_minus_1 = NUM_BASE_LEVELS;
   int mag = 0;
   int nb_mag[3] = { 0 };
-  const int count =
-      get_level_count_mag(nb_mag, levels, stride, row, col, level_minus_1,
-                          br_ref_offset, BR_CONTEXT_POSITION_NUM);
+  const int count = get_level_count(levels, stride, row, col, level_minus_1,
+                                    br_ref_offset, BR_CONTEXT_POSITION_NUM);
+  get_level_mag(levels, stride, row, col, nb_mag);
   for (int idx = 0; idx < 3; ++idx) mag = AOMMAX(mag, nb_mag[idx]);
   const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
   return ctx;
@@ -322,7 +331,6 @@
                                             : sig_ref_offset_horiz[idx][1]));
     const int ref_row = row + row_offset;
     const int ref_col = col + col_offset;
-    if (ref_col >= (1 << bwl)) continue;
     const int nb_pos = ref_row * stride + ref_col;
     const int level = levels[nb_pos];
     count += (level != 0);
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 70006f7..f9c8710 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -213,7 +213,7 @@
 
       if (*level <= i) continue;
 
-      ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, i);
 
       if (av1_read_record_bin(
               counts, r, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx], 2,
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 7549795..85d8630 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -389,7 +389,7 @@
 
       if (level <= i) continue;
 
-      ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, i);
 
       if (level == i + 1) {
         aom_write_bin(w, 1, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx],
@@ -1367,7 +1367,7 @@
         txb_cache->base_ctx_arr[base_idx][nb_coeff_idx] =
             base_ctx_table[nb_row != 0][nb_col != 0][mag > level][count];
         // int ref_ctx = get_base_ctx(txb_info->qcoeff, nb_coeff_idx,
-        // txb_info->bwl, level);
+        // txb_info->bwl, level - 1);
         // if (ref_ctx != txb_cache->base_ctx_arr[base_idx][nb_coeff_idx]) {
         //   printf("base ctx %d ref_ctx %d\n",
         //   txb_cache->base_ctx_arr[base_idx][nb_coeff_idx], ref_ctx);
@@ -2103,7 +2103,7 @@
 
       if (level <= i) continue;
 
-      ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, i);
 
       if (level == i + 1) {
         ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];