Update txb context calculation code

Split coefficients into signs (0 or 1) and levels (0 to 255),
so that they both can be fit in 1-byte.

Change-Id: I0f486368b7b819a77aaddda4710e83189e53fc55
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 3bf8f8c..aa06495 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -32,7 +32,7 @@
   return txsize_sqr_up_map[tx_size];
 }
 
-static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
+static const int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
   /* clang-format off*/
   { -2, 0 }, { -1, -1 }, { -1, 0 }, { -1, 1 }, { 0, -2 }, { 0, -1 }, { 0, 1 },
   { 0, 2 },  { 1, -1 },  { 1, 0 },  { 1, 1 },  { 2, 0 }
@@ -110,10 +110,30 @@
   }
 }
 
-static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
-                                      int bwl, int height, int row, int col,
-                                      int level, int (*nb_offset)[2],
-                                      int nb_num) {
+static INLINE int get_level_count_mag(
+    int *const mag, const uint8_t *const levels, const int bwl,
+    const int height, const int row, const int col, const int level,
+    const int (*nb_offset)[2], const int nb_num) {
+  const int stride = 1 << bwl;
+  int count = 0;
+  *mag = 0;
+  for (int idx = 0; idx < nb_num; ++idx) {
+    const int ref_row = row + nb_offset[idx][0];
+    const int ref_col = col + nb_offset[idx][1];
+    if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
+      continue;
+    const int pos = (ref_row << bwl) + ref_col;
+    count += levels[pos] > level;
+    if (nb_offset[idx][0] >= 0 && nb_offset[idx][1] >= 0)
+      *mag = AOMMAX(*mag, levels[pos]);
+  }
+  return count;
+}
+
+static INLINE int get_level_count_mag_coeff(
+    int *const mag, const tran_low_t *const tcoeffs, const int bwl,
+    const int height, const int row, const int col, const int level,
+    const int (*nb_offset)[2], const int nb_num) {
   const int stride = 1 << bwl;
   int count = 0;
   *mag = 0;
@@ -154,23 +174,23 @@
   return ctx_idx;
 }
 
-static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
-                               int c,  // raster order
+static INLINE int get_base_ctx(const uint8_t *const levels,
+                               const int c,  // raster order
                                const int bwl, const int height,
                                const int level) {
   const int row = c >> bwl;
   const int col = c - (row << bwl);
   const int level_minus_1 = level - 1;
   int mag;
-  int count =
-      get_level_count_mag(&mag, tcoeffs, bwl, height, row, col, level_minus_1,
+  const int count =
+      get_level_count_mag(&mag, levels, bwl, height, row, col, level_minus_1,
                           base_ref_offset, BASE_CONTEXT_POSITION_NUM);
-  int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag > level);
+  const int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag > level);
   return ctx_idx;
 }
 
 #define BR_CONTEXT_POSITION_NUM 8  // Base range coefficient context
-static int br_ref_offset[BR_CONTEXT_POSITION_NUM][2] = {
+static const int br_ref_offset[BR_CONTEXT_POSITION_NUM][2] = {
   /* clang-format off*/
   { -1, -1 }, { -1, 0 }, { -1, 1 }, { 0, -1 },
   { 0, 1 },   { 1, -1 }, { 1, 0 },  { 1, 1 },
@@ -251,7 +271,7 @@
   return 8 + ctx;
 }
 
-static INLINE int get_br_ctx(const tran_low_t *tcoeffs,
+static INLINE int get_br_ctx(const uint8_t *const levels,
                              const int c,  // raster order
                              const int bwl, const int height) {
   const int row = c >> bwl;
@@ -259,12 +279,26 @@
   const int level_minus_1 = NUM_BASE_LEVELS;
   int mag;
   const int count =
-      get_level_count_mag(&mag, tcoeffs, bwl, height, row, col, level_minus_1,
+      get_level_count_mag(&mag, levels, bwl, height, row, col, level_minus_1,
                           br_ref_offset, BR_CONTEXT_POSITION_NUM);
   const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
   return ctx;
 }
 
+static INLINE int get_br_ctx_coeff(const tran_low_t *const tcoeffs,
+                                   const int c,  // raster order
+                                   const int bwl, const int height) {
+  const int row = c >> bwl;
+  const int col = c - (row << bwl);
+  const int level_minus_1 = NUM_BASE_LEVELS;
+  int mag;
+  const int count = get_level_count_mag_coeff(&mag, tcoeffs, bwl, height, row,
+                                              col, level_minus_1, br_ref_offset,
+                                              BR_CONTEXT_POSITION_NUM);
+  const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
+  return ctx;
+}
+
 #define SIG_REF_OFFSET_NUM 7
 static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
   { -2, -1 }, { -2, 0 }, { -1, -2 }, { -1, -1 },
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 13f944b..df8bc11 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -289,6 +289,9 @@
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   const int height = tx_size_high[tx_size];
   int cul_level = 0;
+  uint8_t levels[64 * 64];
+  int8_t signs[64 * 64];
+
   memset(tcoeffs, 0, sizeof(*tcoeffs) * seg_eob);
 
 #if LV_MAP_PROB
@@ -311,6 +314,8 @@
     return 0;
   }
 
+  memset(signs, 0, sizeof(signs[0]) * seg_eob);
+
   (void)blk_row;
   (void)blk_col;
 #if CONFIG_TXK_SEL
@@ -357,19 +362,23 @@
   *max_scan_line = *eob;
 
   int i;
+  for (i = 0; i < seg_eob; i++) {
+    levels[i] = (uint8_t)tcoeffs[i];
+  }
+
   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
 #if !LV_MAP_PROB
     aom_prob *coeff_base = ec_ctx->coeff_base[txs_ctx][plane_type][i];
 #endif
     update_eob = 0;
     for (c = *eob - 1; c >= 0; --c) {
-      tran_low_t *v = &tcoeffs[scan[c]];
-      int sign;
+      uint8_t *const level = &levels[scan[c]];
+      int8_t *const sign = &signs[scan[c]];
       int ctx;
 
-      if (*v <= i) continue;
+      if (*level <= i) continue;
 
-      ctx = get_base_ctx(tcoeffs, scan[c], bwl, height, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, height, i + 1);
 
 #if LV_MAP_PROB
       if (av1_read_record_bin(
@@ -379,7 +388,7 @@
       if (aom_read(r, coeff_base[ctx], ACCT_STR))
 #endif
       {
-        *v = i + 1;
+        assert(*level == i + 1);
         cul_level += i + 1;
 
         if (counts) ++counts->coeff_base[txs_ctx][plane_type][i][ctx][1];
@@ -387,21 +396,20 @@
         if (c == 0) {
           int dc_sign_ctx = txb_ctx->dc_sign_ctx;
 #if LV_MAP_PROB
-          sign = av1_read_record_bin(
+          *sign = av1_read_record_bin(
               counts, r, ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], 2,
               ACCT_STR);
 #else
-          sign =
+          *sign =
               aom_read(r, ec_ctx->dc_sign[plane_type][dc_sign_ctx], ACCT_STR);
 #endif
-          if (counts) ++counts->dc_sign[plane_type][dc_sign_ctx][sign];
+          if (counts) ++counts->dc_sign[plane_type][dc_sign_ctx][*sign];
         } else {
-          sign = av1_read_record_bit(counts, r, ACCT_STR);
+          *sign = av1_read_record_bit(counts, r, ACCT_STR);
         }
-        if (sign) *v = -(*v);
         continue;
       }
-      *v = i + 2;
+      *level = i + 2;
       if (counts) ++counts->coeff_base[txs_ctx][plane_type][i][ctx][0];
 
       // update the eob flag for coefficients with magnitude above 1.
@@ -410,27 +418,27 @@
   }
 
   for (c = update_eob; c >= 0; --c) {
-    tran_low_t *v = &tcoeffs[scan[c]];
-    int sign;
+    uint8_t *const level = &levels[scan[c]];
+    int8_t *const sign = &signs[scan[c]];
     int idx;
     int ctx;
 
-    if (*v <= NUM_BASE_LEVELS) continue;
+    if (*level <= NUM_BASE_LEVELS) continue;
 
     if (c == 0) {
       int dc_sign_ctx = txb_ctx->dc_sign_ctx;
 #if LV_MAP_PROB
-      sign = av1_read_record_bin(
+      *sign = av1_read_record_bin(
           counts, r, ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], 2, ACCT_STR);
 #else
-      sign = aom_read(r, ec_ctx->dc_sign[plane_type][dc_sign_ctx], ACCT_STR);
+      *sign = aom_read(r, ec_ctx->dc_sign[plane_type][dc_sign_ctx], ACCT_STR);
 #endif
-      if (counts) ++counts->dc_sign[plane_type][dc_sign_ctx][sign];
+      if (counts) ++counts->dc_sign[plane_type][dc_sign_ctx][*sign];
     } else {
-      sign = av1_read_record_bit(counts, r, ACCT_STR);
+      *sign = av1_read_record_bit(counts, r, ACCT_STR);
     }
 
-    ctx = get_br_ctx(tcoeffs, scan[c], bwl, height);
+    ctx = get_br_ctx(levels, scan[c], bwl, height);
 
 #if BR_NODE
     for (idx = 0; idx < BASE_RANGE_SETS; ++idx) {
@@ -468,9 +476,8 @@
 
         int br_base = br_index_to_coeff[idx];
 
-        *v = NUM_BASE_LEVELS + 1 + br_base + br_offset;
-        cul_level += *v;
-        if (sign) *v = -(*v);
+        *level = NUM_BASE_LEVELS + 1 + br_base + br_offset;
+        cul_level += *level;
         break;
       }
       if (counts) ++counts->coeff_br[txs_ctx][plane_type][idx][ctx][0];
@@ -487,9 +494,8 @@
       if (aom_read(r, ec_ctx->coeff_lps[txs_ctx][plane_type][ctx], ACCT_STR))
 #endif
       {
-        *v = (idx + 1 + NUM_BASE_LEVELS);
-        if (sign) *v = -(*v);
-        cul_level += abs(*v);
+        *level = idx + 1 + NUM_BASE_LEVELS;
+        cul_level += *level;
 
         if (counts) ++counts->coeff_lps[txs_ctx][plane_type][ctx][1];
         break;
@@ -500,20 +506,19 @@
 #endif
 
     // decode 0-th order Golomb code
-    *v = read_golomb(xd, r, counts) + COEFF_BASE_RANGE + 1 + NUM_BASE_LEVELS;
-    if (sign) *v = -(*v);
-    cul_level += abs(*v);
+    *level =
+        read_golomb(xd, r, counts) + COEFF_BASE_RANGE + 1 + NUM_BASE_LEVELS;
+    cul_level += *level;
   }
 
   for (c = 0; c < *eob; ++c) {
-    int16_t dqv = (c == 0) ? dequant[0] : dequant[1];
-    tran_low_t *v = &tcoeffs[scan[c]];
+    const int16_t dqv = (c == 0) ? dequant[0] : dequant[1];
+    const int level = levels[scan[c]];
+    const int16_t t = (level * dqv) >> shift;
 #if CONFIG_SYMBOLRATE
-    av1_record_coeff(counts, abs(*v));
+    av1_record_coeff(counts, level);
 #endif
-    int sign = (*v) < 0;
-    *v = (abs(*v) * dqv) >> shift;
-    if (sign) *v = -(*v);
+    tcoeffs[scan[c]] = signs[scan[c]] ? -t : t;
   }
 
   cul_level = AOMMIN(63, cul_level);
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index e7af3bf..87707ad 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -272,11 +272,15 @@
       av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
   const int16_t *scan = scan_order->scan;
+  const int seg_eob = tx_size_2d[tx_size];
   int c;
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   const int height = tx_size_high[tx_size];
   uint16_t update_eob = 0;
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
+  uint8_t levels[64 * 64];
+  int8_t signs[64 * 64];
+  int i;
 
   (void)blk_row;
   (void)blk_col;
@@ -289,6 +293,12 @@
 #endif
 
   if (eob == 0) return;
+
+  for (i = 0; i < seg_eob; i++) {
+    levels[i] = (uint8_t)abs(tcoeff[i]);
+    signs[i] = (int8_t)(tcoeff[i] < 0);
+  }
+
 #if CONFIG_TXK_SEL
   av1_write_tx_type(cm, xd, blk_row, blk_col, block, plane,
                     get_min_tx_size(tx_size), w);
@@ -325,21 +335,19 @@
   write_nz_map(w, tcoeff, eob, plane, scan, tx_size, tx_type, ec_ctx);
 #endif  // CONFIG_CTX1D
 
-  int i;
   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
 #if !LV_MAP_PROB
     aom_prob *coeff_base = ec_ctx->coeff_base[txs_ctx][plane_type][i];
 #endif
     update_eob = 0;
     for (c = eob - 1; c >= 0; --c) {
-      tran_low_t v = tcoeff[scan[c]];
-      tran_low_t level = abs(v);
-      int sign = (v < 0) ? 1 : 0;
+      const int level = levels[scan[c]];
+      const int sign = signs[scan[c]];
       int ctx;
 
       if (level <= i) continue;
 
-      ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, height, i + 1);
 
       if (level == i + 1) {
 #if LV_MAP_PROB
@@ -373,9 +381,8 @@
   }
 
   for (c = update_eob; c >= 0; --c) {
-    tran_low_t v = tcoeff[scan[c]];
-    tran_low_t level = abs(v);
-    int sign = (v < 0) ? 1 : 0;
+    const int level = levels[scan[c]];
+    const int sign = signs[scan[c]];
     int idx;
     int ctx;
 
@@ -393,7 +400,7 @@
     }
 
     // level is above 1.
-    ctx = get_br_ctx(tcoeff, scan[c], bwl, height);
+    ctx = get_br_ctx(levels, scan[c], bwl, height);
 
 #if BR_NODE
     int base_range = level - 1 - NUM_BASE_LEVELS;
@@ -770,7 +777,7 @@
 
       if (level > NUM_BASE_LEVELS) {
         int ctx;
-        ctx = get_br_ctx(qcoeff, scan[c], bwl, height);
+        ctx = get_br_ctx_coeff(qcoeff, scan[c], bwl, height);
 #if BR_NODE
         int base_range = level - 1 - NUM_BASE_LEVELS;
         if (base_range < COEFF_BASE_RANGE) {
@@ -1289,7 +1296,7 @@
       ref_num = BASE_CONTEXT_POSITION_NUM;
     }
 #else
-    int(*ref_offset)[2] = base_ref_offset;
+    const int(*ref_offset)[2] = base_ref_offset;
     int ref_num = BASE_CONTEXT_POSITION_NUM;
 #endif
     for (int i = 0; i < ref_num; ++i) {
@@ -1325,7 +1332,7 @@
       ref_num = BR_CONTEXT_POSITION_NUM;
     }
 #else
-    int(*ref_offset)[2] = br_ref_offset;
+    const int(*ref_offset)[2] = br_ref_offset;
     const int ref_num = BR_CONTEXT_POSITION_NUM;
 #endif
     for (int i = 0; i < ref_num; ++i) {
@@ -1594,8 +1601,8 @@
     }
 
     if (abs_qc > NUM_BASE_LEVELS) {
-      int ctx = get_br_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl,
-                           txb_info->height);
+      int ctx = get_br_ctx_coeff(txb_info->qcoeff, scan[scan_idx],
+                                 txb_info->bwl, txb_info->height);
       cost += get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
       cost += get_golomb_cost(abs_qc);
     }
@@ -2176,6 +2183,8 @@
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   const int height = tx_size_high[tx_size];
   int cul_level = 0;
+  uint8_t levels[64 * 64];
+  int8_t signs[64 * 64];
 
   TX_SIZE txsize_ctx = get_txsize_context(tx_size);
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
@@ -2196,6 +2205,11 @@
     return;
   }
 
+  for (i = 0; i < seg_eob; i++) {
+    levels[i] = (uint8_t)abs(tcoeff[i]);
+    signs[i] = (int8_t)(tcoeff[i] < 0);
+  }
+
 #if CONFIG_TXK_SEL
   av1_update_tx_type_count(cm, xd, blk_row, blk_col, block, plane,
                            mbmi->sb_type, get_min_tx_size(tx_size), td->counts);
@@ -2239,13 +2253,13 @@
   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
     update_eob = 0;
     for (c = eob - 1; c >= 0; --c) {
-      tran_low_t v = qcoeff[scan[c]];
-      tran_low_t level = abs(v);
+      const int level = levels[scan[c]];
+      const int sign = signs[scan[c]];
       int ctx;
 
       if (level <= i) continue;
 
-      ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);
+      ctx = get_base_ctx(levels, scan[c], bwl, height, i + 1);
 
       if (level == i + 1) {
         ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];
@@ -2256,9 +2270,9 @@
         if (c == 0) {
           int dc_sign_ctx = txb_ctx.dc_sign_ctx;
 
-          ++td->counts->dc_sign[plane_type][dc_sign_ctx][v < 0];
+          ++td->counts->dc_sign[plane_type][dc_sign_ctx][sign];
 #if LV_MAP_PROB
-          update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], v < 0, 2);
+          update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], sign, 2);
 #endif
           x->mbmi_ext->dc_sign_ctx[plane][block] = dc_sign_ctx;
         }
@@ -2274,8 +2288,8 @@
   }
 
   for (c = update_eob; c >= 0; --c) {
-    tran_low_t v = qcoeff[scan[c]];
-    tran_low_t level = abs(v);
+    const int level = levels[scan[c]];
+    const int sign = signs[scan[c]];
     int idx;
     int ctx;
 
@@ -2285,15 +2299,15 @@
     if (c == 0) {
       int dc_sign_ctx = txb_ctx.dc_sign_ctx;
 
-      ++td->counts->dc_sign[plane_type][dc_sign_ctx][v < 0];
+      ++td->counts->dc_sign[plane_type][dc_sign_ctx][sign];
 #if LV_MAP_PROB
-      update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], v < 0, 2);
+      update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], sign, 2);
 #endif
       x->mbmi_ext->dc_sign_ctx[plane][block] = dc_sign_ctx;
     }
 
     // level is above 1.
-    ctx = get_br_ctx(tcoeff, scan[c], bwl, height);
+    ctx = get_br_ctx(levels, scan[c], bwl, height);
 
 #if BR_NODE
     int base_range = level - 1 - NUM_BASE_LEVELS;