Support rectangular tx_size in level map coding

Map the rectangular transform block size into the bigger square
transform block size as the context for level map probability
model.

Change-Id: I20cf2b16daec16172855a78a201b670ff0547bf5
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index bea162d..9286b88 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -24,6 +24,10 @@
   int dc_sign_ctx;
 } TXB_CTX;
 
+static INLINE TX_SIZE get_txsize_context(TX_SIZE tx_size) {
+  return txsize_sqr_up_map[tx_size];
+}
+
 #define BASE_CONTEXT_POSITION_NUM 12
 static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
   /* clang-format off*/
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 90685a1..66e986b 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -47,9 +47,10 @@
                             int16_t *max_scan_line, int *eob) {
   FRAME_COUNTS *counts = xd->counts;
   TX_SIZE tx_size = get_tx_size(plane, xd);
+  TX_SIZE txs_ctx = get_txsize_context(tx_size);
   PLANE_TYPE plane_type = get_plane_type(plane);
-  aom_prob *nz_map = cm->fc->nz_map[tx_size][plane_type];
-  aom_prob *eob_flag = cm->fc->eob_flag[tx_size][plane_type];
+  aom_prob *nz_map = cm->fc->nz_map[txs_ctx][plane_type];
+  aom_prob *eob_flag = cm->fc->eob_flag[txs_ctx][plane_type];
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const int seg_eob = tx_size_2d[tx_size];
   int c = 0;
@@ -61,14 +62,14 @@
   unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2];
   uint8_t txb_mask[32 * 32] = { 0 };
 
-  nz_map_count = (counts) ? &counts->nz_map[tx_size][plane_type] : NULL;
+  nz_map_count = (counts) ? &counts->nz_map[txs_ctx][plane_type] : NULL;
 
   memset(tcoeffs, 0, sizeof(*tcoeffs) * seg_eob);
 
   int all_zero =
-      aom_read(r, cm->fc->txb_skip[tx_size][txb_ctx->txb_skip_ctx], ACCT_STR);
+      aom_read(r, cm->fc->txb_skip[txs_ctx][txb_ctx->txb_skip_ctx], ACCT_STR);
   if (xd->counts)
-    ++xd->counts->txb_skip[tx_size][txb_ctx->txb_skip_ctx][all_zero];
+    ++xd->counts->txb_skip[txs_ctx][txb_ctx->txb_skip_ctx][all_zero];
 
   *eob = 0;
   if (all_zero) {
@@ -106,7 +107,7 @@
 
     if (is_nz) {
       int is_eob = aom_read(r, eob_flag[eob_ctx], tx_size);
-      if (counts) ++counts->eob_flag[tx_size][plane_type][eob_ctx][is_eob];
+      if (counts) ++counts->eob_flag[txs_ctx][plane_type][eob_ctx][is_eob];
       if (is_eob) break;
     }
     txb_mask[scan[c]] = 1;
@@ -117,7 +118,7 @@
 
   int i;
   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
-    aom_prob *coeff_base = cm->fc->coeff_base[tx_size][plane_type][i];
+    aom_prob *coeff_base = cm->fc->coeff_base[txs_ctx][plane_type][i];
 
     update_eob = 0;
     for (c = *eob - 1; c >= 0; --c) {
@@ -133,7 +134,7 @@
         *v = i + 1;
         cul_level += i + 1;
 
-        if (counts) ++counts->coeff_base[tx_size][plane_type][i][ctx][1];
+        if (counts) ++counts->coeff_base[txs_ctx][plane_type][i][ctx][1];
 
         if (c == 0) {
           int dc_sign_ctx = txb_ctx->dc_sign_ctx;
@@ -146,7 +147,7 @@
         continue;
       }
       *v = i + 2;
-      if (counts) ++counts->coeff_base[tx_size][plane_type][i][ctx][0];
+      if (counts) ++counts->coeff_base[txs_ctx][plane_type][i][ctx][0];
 
       // update the eob flag for coefficients with magnitude above 1.
       update_eob = AOMMAX(update_eob, c);
@@ -171,18 +172,18 @@
 
     ctx = get_br_ctx(tcoeffs, scan[c], bwl);
 
-    if (cm->fc->coeff_lps[tx_size][plane_type][ctx] == 0) exit(0);
+    if (cm->fc->coeff_lps[txs_ctx][plane_type][ctx] == 0) exit(0);
 
     for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
-      if (aom_read(r, cm->fc->coeff_lps[tx_size][plane_type][ctx], tx_size)) {
+      if (aom_read(r, cm->fc->coeff_lps[txs_ctx][plane_type][ctx], tx_size)) {
         *v = (idx + 1 + NUM_BASE_LEVELS);
         if (sign) *v = -(*v);
         cul_level += abs(*v);
 
-        if (counts) ++counts->coeff_lps[tx_size][plane_type][ctx][1];
+        if (counts) ++counts->coeff_lps[txs_ctx][plane_type][ctx][1];
         break;
       }
-      if (counts) ++counts->coeff_lps[tx_size][plane_type][ctx][0];
+      if (counts) ++counts->coeff_lps[txs_ctx][plane_type][ctx][0];
     }
     if (idx < COEFF_BASE_RANGE) continue;
 
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index acd915f..c2a99fb 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -467,6 +467,7 @@
   assert((mb->qindex == 0) ^ (xd->lossless[xd->mi[0]->mbmi.segment_id] == 0));
   if (eob == 0) return eob;
   if (xd->lossless[xd->mi[0]->mbmi.segment_id]) return eob;
+
 #if CONFIG_PVQ
   (void)cm;
   (void)tx_size;
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 9f1e729..a40731e 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -78,6 +78,7 @@
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_SIZE tx_size = get_tx_size(plane, xd);
+  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
   const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, is_inter_block(mbmi));
@@ -89,15 +90,15 @@
   uint8_t txb_mask[32 * 32] = { 0 };
   uint16_t update_eob = 0;
 
-  aom_write(w, eob == 0, cm->fc->txb_skip[tx_size][txb_ctx->txb_skip_ctx]);
+  aom_write(w, eob == 0, cm->fc->txb_skip[txs_ctx][txb_ctx->txb_skip_ctx]);
 
   if (eob == 0) return;
 #if CONFIG_TXK_SEL
   av1_write_tx_type(cm, xd, block, plane, w);
 #endif
 
-  nz_map = cm->fc->nz_map[tx_size][plane_type];
-  eob_flag = cm->fc->eob_flag[tx_size][plane_type];
+  nz_map = cm->fc->nz_map[txs_ctx][plane_type];
+  eob_flag = cm->fc->eob_flag[txs_ctx][plane_type];
 
   for (c = 0; c < eob; ++c) {
     int coeff_ctx = get_nz_map_ctx(tcoeff, txb_mask, scan[c], bwl);
@@ -118,7 +119,7 @@
 
   int i;
   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
-    aom_prob *coeff_base = cm->fc->coeff_base[tx_size][plane_type][i];
+    aom_prob *coeff_base = cm->fc->coeff_base[txs_ctx][plane_type][i];
 
     update_eob = 0;
     for (c = eob - 1; c >= 0; --c) {
@@ -164,10 +165,10 @@
     ctx = get_br_ctx(tcoeff, scan[c], bwl);
     for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
       if (level == (idx + 1 + NUM_BASE_LEVELS)) {
-        aom_write(w, 1, cm->fc->coeff_lps[tx_size][plane_type][ctx]);
+        aom_write(w, 1, cm->fc->coeff_lps[txs_ctx][plane_type][ctx]);
         break;
       }
-      aom_write(w, 0, cm->fc->coeff_lps[tx_size][plane_type][ctx]);
+      aom_write(w, 0, cm->fc->coeff_lps[txs_ctx][plane_type][ctx]);
     }
     if (idx < COEFF_BASE_RANGE) continue;
 
@@ -289,6 +290,7 @@
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   const TX_SIZE tx_size = get_tx_size(plane, xd);
+  TX_SIZE txs_ctx = get_txsize_context(tx_size);
   const PLANE_TYPE plane_type = get_plane_type(plane);
   const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
@@ -298,14 +300,14 @@
   int c, cost;
   const int seg_eob = AOMMIN(eob, tx_size_2d[tx_size] - 1);
   int txb_skip_ctx = txb_ctx->txb_skip_ctx;
-  aom_prob *nz_map = xd->fc->nz_map[tx_size][plane_type];
+  aom_prob *nz_map = xd->fc->nz_map[txs_ctx][plane_type];
 
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   // txb_mask is only initialized for once here. After that, it will be set when
   // coding zero map and then reset when coding level 1 info.
   uint8_t txb_mask[32 * 32] = { 0 };
   aom_prob(*coeff_base)[COEFF_BASE_CONTEXTS] =
-      xd->fc->coeff_base[tx_size][plane_type];
+      xd->fc->coeff_base[txs_ctx][plane_type];
 
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, is_inter_block(mbmi));
@@ -314,11 +316,11 @@
   cost = 0;
 
   if (eob == 0) {
-    cost = av1_cost_bit(xd->fc->txb_skip[tx_size][txb_skip_ctx], 1);
+    cost = av1_cost_bit(xd->fc->txb_skip[txs_ctx][txb_skip_ctx], 1);
     return cost;
   }
 
-  cost = av1_cost_bit(xd->fc->txb_skip[tx_size][txb_skip_ctx], 0);
+  cost = av1_cost_bit(xd->fc->txb_skip[txs_ctx][txb_skip_ctx], 0);
 
 #if CONFIG_TXK_SEL
   cost += av1_tx_type_cost(cpi, xd, mbmi->sb_type, plane, tx_size, tx_type);
@@ -369,10 +371,10 @@
         for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
           if (level == (idx + 1 + NUM_BASE_LEVELS)) {
             cost +=
-                av1_cost_bit(xd->fc->coeff_lps[tx_size][plane_type][ctx], 1);
+                av1_cost_bit(xd->fc->coeff_lps[txs_ctx][plane_type][ctx], 1);
             break;
           }
-          cost += av1_cost_bit(xd->fc->coeff_lps[tx_size][plane_type][ctx], 0);
+          cost += av1_cost_bit(xd->fc->coeff_lps[txs_ctx][plane_type][ctx], 0);
         }
 
         if (idx >= COEFF_BASE_RANGE) {
@@ -395,7 +397,7 @@
 
       if (c < seg_eob) {
         int eob_ctx = get_eob_ctx(qcoeff, scan[c], bwl);
-        cost += av1_cost_bit(xd->fc->eob_flag[tx_size][plane_type][eob_ctx],
+        cost += av1_cost_bit(xd->fc->eob_flag[txs_ctx][plane_type][eob_ctx],
                              c == (eob - 1));
       }
     }
@@ -1440,6 +1442,7 @@
                      TX_SIZE tx_size, TXB_CTX *txb_ctx) {
   MACROBLOCKD *const xd = &x->e_mbd;
   const PLANE_TYPE plane_type = get_plane_type(plane);
+  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
   const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
   const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   const struct macroblock_plane *p = &x->plane[plane];
@@ -1450,14 +1453,14 @@
   const tran_low_t *tcoeff = BLOCK_OFFSET(p->coeff, block);
   const int16_t *dequant = pd->dequant;
   const int seg_eob = AOMMIN(eob, tx_size_2d[tx_size] - 1);
-  const aom_prob *nz_map = xd->fc->nz_map[tx_size][plane_type];
+  const aom_prob *nz_map = xd->fc->nz_map[txs_ctx][plane_type];
 
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   const int stride = 1 << bwl;
   aom_prob(*coeff_base)[COEFF_BASE_CONTEXTS] =
-      xd->fc->coeff_base[tx_size][plane_type];
+      xd->fc->coeff_base[txs_ctx][plane_type];
 
-  const aom_prob *coeff_lps = xd->fc->coeff_lps[tx_size][plane_type];
+  const aom_prob *coeff_lps = xd->fc->coeff_lps[txs_ctx][plane_type];
 
   const int is_inter = is_inter_block(mbmi);
   const SCAN_ORDER *const scan_order =
@@ -1467,8 +1470,8 @@
                                nz_map,
                                coeff_base,
                                coeff_lps,
-                               xd->fc->eob_flag[tx_size][plane_type],
-                               xd->fc->txb_skip[tx_size] };
+                               xd->fc->eob_flag[txs_ctx][plane_type],
+                               xd->fc->txb_skip[txs_ctx] };
 
   const int shift = av1_get_tx_scale(tx_size);
   const int64_t rdmult =
@@ -1555,11 +1558,13 @@
   unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2];
   uint8_t txb_mask[32 * 32] = { 0 };
 
-  nz_map_count = &td->counts->nz_map[tx_size][plane_type];
+  TX_SIZE txsize_ctx = get_txsize_context(tx_size);
+
+  nz_map_count = &td->counts->nz_map[txsize_ctx][plane_type];
 
   memcpy(tcoeff, qcoeff, sizeof(*tcoeff) * seg_eob);
 
-  ++td->counts->txb_skip[tx_size][txb_ctx.txb_skip_ctx][eob == 0];
+  ++td->counts->txb_skip[txsize_ctx][txb_ctx.txb_skip_ctx][eob == 0];
   x->mbmi_ext->txb_skip_ctx[plane][block] = txb_ctx.txb_skip_ctx;
 
   x->mbmi_ext->eobs[plane][block] = eob;
@@ -1585,7 +1590,7 @@
     ++(*nz_map_count)[coeff_ctx][is_nz];
 
     if (is_nz) {
-      ++td->counts->eob_flag[tx_size][plane_type][eob_ctx][c == (eob - 1)];
+      ++td->counts->eob_flag[txsize_ctx][plane_type][eob_ctx][c == (eob - 1)];
     }
     txb_mask[scan[c]] = 1;
   }
@@ -1603,7 +1608,7 @@
       ctx = get_base_ctx(tcoeff, scan[c], bwl, i + 1);
 
       if (level == i + 1) {
-        ++td->counts->coeff_base[tx_size][plane_type][i][ctx][1];
+        ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];
         if (c == 0) {
           int dc_sign_ctx = txb_ctx.dc_sign_ctx;
 
@@ -1613,7 +1618,7 @@
         cul_level += level;
         continue;
       }
-      ++td->counts->coeff_base[tx_size][plane_type][i][ctx][0];
+      ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][0];
       update_eob = AOMMAX(update_eob, c);
     }
   }
@@ -1638,10 +1643,10 @@
     ctx = get_br_ctx(tcoeff, scan[c], bwl);
     for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
       if (level == (idx + 1 + NUM_BASE_LEVELS)) {
-        ++td->counts->coeff_lps[tx_size][plane_type][ctx][1];
+        ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][1];
         break;
       }
-      ++td->counts->coeff_lps[tx_size][plane_type][ctx][0];
+      ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][0];
     }
     if (idx < COEFF_BASE_RANGE) continue;