Use tx_class instead of tx_type

These changes reduce tx_type_to_class array lookup

Change-Id: Icb17eef582df1c9b01b29cabb7a6cafbe5346b54
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index b056f81..cc2e016 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -237,7 +237,7 @@
   # End av1_high encoder functions
 
   # txb
-  add_proto qw/void av1_get_nz_map_contexts/, "const uint8_t *const levels, const int16_t *const scan, const uint16_t eob, const TX_SIZE tx_size, const TX_TYPE tx_type, int8_t *const coeff_contexts";
+  add_proto qw/void av1_get_nz_map_contexts/, "const uint8_t *const levels, const int16_t *const scan, const uint16_t eob, const TX_SIZE tx_size, const TX_CLASS tx_class, int8_t *const coeff_contexts";
   specialize qw/av1_get_nz_map_contexts sse2/;
   add_proto qw/void av1_txb_init_levels/, "const tran_low_t *const coeff, const int width, const int height, uint8_t *const levels";
   specialize qw/av1_txb_init_levels sse4_1/;
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 6d4334f..4d98c9f 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -364,11 +364,10 @@
 
 static INLINE int get_br_ctx(const uint8_t *const levels,
                              const int c,  // raster order
-                             const int bwl, const TX_TYPE tx_type) {
+                             const int bwl, const TX_CLASS tx_class) {
   const int row = c >> bwl;
   const int col = c - (row << bwl);
   const int stride = (1 << bwl) + TX_PAD_HOR;
-  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const int pos = row * stride + col;
   int mag = levels[pos + 1];
   mag += levels[pos + stride];
@@ -572,8 +571,7 @@
 }
 static INLINE int get_lower_levels_ctx(const uint8_t *levels, int coeff_idx,
                                        int bwl, TX_SIZE tx_size,
-                                       TX_TYPE tx_type) {
-  const TX_CLASS tx_class = tx_type_to_class[tx_type];
+                                       TX_CLASS tx_class) {
   const int stats =
       get_nz_mag(levels + get_padded_idx(coeff_idx, bwl), bwl, tx_class);
   return get_nz_map_ctx_from_stats(stats, coeff_idx, bwl, tx_size, tx_class);
@@ -583,14 +581,14 @@
                                                int bwl, int height,
                                                const uint8_t *levels,
                                                int coeff_idx, TX_SIZE tx_size,
-                                               TX_TYPE tx_type) {
+                                               TX_CLASS tx_class) {
   if (is_last) {
     if (scan_idx == 0) return 0;
     if (scan_idx <= (height << bwl) >> 3) return 1;
     if (scan_idx <= (height << bwl) >> 2) return 2;
     return 3;
   }
-  return get_lower_levels_ctx(levels, coeff_idx, bwl, tx_size, tx_type);
+  return get_lower_levels_ctx(levels, coeff_idx, bwl, tx_size, tx_class);
 }
 
 static INLINE void set_dc_sign(int *cul_level, int dc_val) {
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 6c1962f..27c2728 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -84,18 +84,18 @@
 }
 
 static INLINE void read_coeffs_reverse(aom_reader *r, TX_SIZE tx_size,
-                                       TX_TYPE tx_type, int start_si,
+                                       TX_CLASS tx_class, int start_si,
                                        int end_si, const int16_t *scan, int bwl,
                                        uint8_t *levels, base_cdf_arr base_cdf,
                                        br_cdf_arr br_cdf) {
   for (int c = end_si; c >= start_si; --c) {
     const int pos = scan[c];
     const int coeff_ctx =
-        get_lower_levels_ctx(levels, pos, bwl, tx_size, tx_type);
+        get_lower_levels_ctx(levels, pos, bwl, tx_size, tx_class);
     const int nsymbs = 4;
     int level = aom_read_symbol(r, base_cdf[coeff_ctx], nsymbs, ACCT_STR);
     if (level > NUM_BASE_LEVELS) {
-      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_class);
       aom_cdf_prob *cdf = br_cdf[br_ctx];
       for (int idx = 0; idx < COEFF_BASE_RANGE; idx += BR_CDF_SIZE - 1) {
         const int k = aom_read_symbol(r, cdf, BR_CDF_SIZE, ACCT_STR);
@@ -222,6 +222,7 @@
   }
   *eob = rec_eob_pos(eob_pt, eob_extra);
 
+  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   {
     // Read the non-zero coefficient with scan index eob-1
     // TODO(angiebird): Put this into a function
@@ -233,7 +234,7 @@
         ec_ctx->coeff_base_eob_cdf[txs_ctx][plane_type][coeff_ctx];
     int level = aom_read_symbol(r, cdf, nsymbs, ACCT_STR) + 1;
     if (level > NUM_BASE_LEVELS) {
-      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_class);
       for (int idx = 0; idx < COEFF_BASE_RANGE; idx += BR_CDF_SIZE - 1) {
         const int k = aom_read_symbol(
             r,
@@ -249,14 +250,13 @@
     base_cdf_arr base_cdf = ec_ctx->coeff_base_cdf[txs_ctx][plane_type];
     br_cdf_arr br_cdf =
         ec_ctx->coeff_br_cdf[AOMMIN(txs_ctx, TX_32X32)][plane_type];
-    const TX_CLASS tx_class = tx_type_to_class[tx_type];
     if (tx_class == TX_CLASS_2D) {
       read_coeffs_reverse_2d(r, tx_size, 1, *eob - 1 - 1, scan, bwl, levels,
                              base_cdf, br_cdf);
-      read_coeffs_reverse(r, tx_size, tx_type, 0, 0, scan, bwl, levels,
+      read_coeffs_reverse(r, tx_size, tx_class, 0, 0, scan, bwl, levels,
                           base_cdf, br_cdf);
     } else {
-      read_coeffs_reverse(r, tx_size, tx_type, 0, *eob - 1 - 1, scan, bwl,
+      read_coeffs_reverse(r, tx_size, tx_class, 0, *eob - 1 - 1, scan, bwl,
                           levels, base_cdf, br_cdf);
     }
   }
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 3c447d5..4520a85 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -135,11 +135,11 @@
 
 #if CONFIG_ENTROPY_STATS
 void av1_update_eob_context(int cdf_idx, int eob, TX_SIZE tx_size,
-                            TX_TYPE tx_type, PLANE_TYPE plane,
+                            TX_CLASS tx_class, PLANE_TYPE plane,
                             FRAME_CONTEXT *ec_ctx, FRAME_COUNTS *counts,
                             uint8_t allow_update_cdf) {
 #else
-void av1_update_eob_context(int eob, TX_SIZE tx_size, TX_TYPE tx_type,
+void av1_update_eob_context(int eob, TX_SIZE tx_size, TX_CLASS tx_class,
                             PLANE_TYPE plane, FRAME_CONTEXT *ec_ctx,
                             uint8_t allow_update_cdf) {
 #endif
@@ -148,7 +148,7 @@
   TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
 
   const int eob_multi_size = txsize_log2_minus4[tx_size];
-  const int eob_multi_ctx = (tx_type_to_class[tx_type] == TX_CLASS_2D) ? 0 : 1;
+  const int eob_multi_ctx = (tx_class == TX_CLASS_2D) ? 0 : 1;
 
   switch (eob_multi_size) {
     case 0:
@@ -223,11 +223,11 @@
 }
 
 static int get_eob_cost(int eob, const LV_MAP_EOB_COST *txb_eob_costs,
-                        const LV_MAP_COEFF_COST *txb_costs, TX_TYPE tx_type) {
+                        const LV_MAP_COEFF_COST *txb_costs, TX_CLASS tx_class) {
   int eob_extra;
   const int eob_pt = get_eob_pos_token(eob, &eob_extra);
   int eob_cost = 0;
-  const int eob_multi_ctx = (tx_type_to_class[tx_type] == TX_CLASS_2D) ? 0 : 1;
+  const int eob_multi_ctx = (tx_class == TX_CLASS_2D) ? 0 : 1;
   eob_cost = txb_eob_costs->eob_cost[eob_multi_ctx][eob_pt - 1];
 
   if (k_eob_offset_bits[eob_pt] > 0) {
@@ -277,7 +277,7 @@
 static int get_coeff_cost(const tran_low_t qc, const int scan_idx,
                           const int is_eob, const TxbInfo *const txb_info,
                           const LV_MAP_COEFF_COST *const txb_costs,
-                          const int coeff_ctx) {
+                          const int coeff_ctx, const TX_CLASS tx_class) {
   const TXB_CTX *txb_ctx = txb_info->txb_ctx;
   const int is_nz = (qc != 0);
   const tran_low_t abs_qc = abs(qc);
@@ -296,7 +296,7 @@
 
     if (abs_qc > NUM_BASE_LEVELS) {
       const int ctx =
-          get_br_ctx(txb_info->levels, pos, txb_info->bwl, txb_info->tx_type);
+          get_br_ctx(txb_info->levels, pos, txb_info->bwl, tx_class);
       cost += get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
       cost += get_golomb_cost(abs_qc);
     }
@@ -308,14 +308,13 @@
                                  const int coeff_idx, const int bwl,
                                  const int height, const int scan_idx,
                                  const int is_eob, const TX_SIZE tx_size,
-                                 const TX_TYPE tx_type) {
+                                 const TX_CLASS tx_class) {
   if (is_eob) {
     if (scan_idx == 0) return 0;
     if (scan_idx <= (height << bwl) / 8) return 1;
     if (scan_idx <= (height << bwl) / 4) return 2;
     return 3;
   }
-  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const int stats =
       get_nz_mag(levels + get_padded_idx(coeff_idx, bwl), bwl, tx_class);
   return get_nz_map_ctx_from_stats(stats, coeff_idx, bwl, tx_size, tx_class);
@@ -324,7 +323,8 @@
 static void get_dist_cost_stats(LevelDownStats *const stats, const int scan_idx,
                                 const int is_eob,
                                 const LV_MAP_COEFF_COST *const txb_costs,
-                                const TxbInfo *const txb_info) {
+                                const TxbInfo *const txb_info,
+                                const TX_CLASS tx_class) {
   const int16_t *const scan = txb_info->scan_order->scan;
   const int coeff_idx = scan[scan_idx];
   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
@@ -340,12 +340,11 @@
 
   const tran_low_t tqc = txb_info->tcoeff[coeff_idx];
   const int dqv = txb_info->dequant[coeff_idx != 0];
-
   const int coeff_ctx =
       get_nz_map_ctx(levels, coeff_idx, txb_info->bwl, txb_info->height,
-                     scan_idx, is_eob, txb_info->tx_size, txb_info->tx_type);
-  const int qc_cost =
-      get_coeff_cost(qc, scan_idx, is_eob, txb_info, txb_costs, coeff_ctx);
+                     scan_idx, is_eob, txb_info->tx_size, tx_class);
+  const int qc_cost = get_coeff_cost(qc, scan_idx, is_eob, txb_info, txb_costs,
+                                     coeff_ctx, tx_class);
   assert(qc != 0);
   const tran_low_t dqc = qcoeff_to_dqcoeff(qc, coeff_idx, dqv, txb_info->shift,
                                            txb_info->iqmatrix);
@@ -375,8 +374,9 @@
           get_coeff_dist(tqc, stats->low_dqc, txb_info->shift);
       stats->dist_low = low_dqc_dist - stats->dist0;
     }
-    const int low_qc_cost = get_coeff_cost(stats->low_qc, scan_idx, is_eob,
-                                           txb_info, txb_costs, coeff_ctx);
+    const int low_qc_cost =
+        get_coeff_cost(stats->low_qc, scan_idx, is_eob, txb_info, txb_costs,
+                       coeff_ctx, tx_class);
     stats->rate_low = low_qc_cost;
     stats->rd_low = RDCOST(txb_info->rdmult, stats->rate_low, stats->dist_low);
   }
@@ -384,22 +384,24 @@
 
 static void get_dist_cost_stats_with_eob(
     LevelDownStats *const stats, const int scan_idx,
-    const LV_MAP_COEFF_COST *const txb_costs, const TxbInfo *const txb_info) {
+    const LV_MAP_COEFF_COST *const txb_costs, const TxbInfo *const txb_info,
+    const TX_CLASS tx_class) {
   const int is_eob = 0;
-  get_dist_cost_stats(stats, scan_idx, is_eob, txb_costs, txb_info);
+  get_dist_cost_stats(stats, scan_idx, is_eob, txb_costs, txb_info, tx_class);
 
   const int16_t *const scan = txb_info->scan_order->scan;
   const int coeff_idx = scan[scan_idx];
   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
   const int coeff_ctx_temp = get_nz_map_ctx(
       txb_info->levels, coeff_idx, txb_info->bwl, txb_info->height, scan_idx, 1,
-      txb_info->tx_size, txb_info->tx_type);
-  const int qc_eob_cost =
-      get_coeff_cost(qc, scan_idx, 1, txb_info, txb_costs, coeff_ctx_temp);
+      txb_info->tx_size, tx_class);
+  const int qc_eob_cost = get_coeff_cost(qc, scan_idx, 1, txb_info, txb_costs,
+                                         coeff_ctx_temp, tx_class);
   int64_t rd_eob = RDCOST(txb_info->rdmult, qc_eob_cost, stats->dist);
   if (stats->low_qc != 0) {
-    const int low_qc_eob_cost = get_coeff_cost(
-        stats->low_qc, scan_idx, 1, txb_info, txb_costs, coeff_ctx_temp);
+    const int low_qc_eob_cost =
+        get_coeff_cost(stats->low_qc, scan_idx, 1, txb_info, txb_costs,
+                       coeff_ctx_temp, tx_class);
     int64_t rd_eob_low =
         RDCOST(txb_info->rdmult, low_qc_eob_cost, stats->dist_low);
     rd_eob = (rd_eob > rd_eob_low) ? rd_eob_low : rd_eob;
@@ -445,14 +447,14 @@
 
 void av1_get_nz_map_contexts_c(const uint8_t *const levels,
                                const int16_t *const scan, const uint16_t eob,
-                               const TX_SIZE tx_size, const TX_TYPE tx_type,
+                               const TX_SIZE tx_size, const TX_CLASS tx_class,
                                int8_t *const coeff_contexts) {
   const int bwl = get_txb_bwl(tx_size);
   const int height = get_txb_high(tx_size);
   for (int i = 0; i < eob; ++i) {
     const int pos = scan[i];
     coeff_contexts[pos] = get_nz_map_ctx(levels, pos, bwl, height, i,
-                                         i == eob - 1, tx_size, tx_type);
+                                         i == eob - 1, tx_size, tx_class);
   }
 }
 
@@ -464,6 +466,7 @@
   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
   const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
                                           tx_size, cm->reduced_tx_set_used);
+  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
   const int16_t *const scan = scan_order->scan;
   int c;
@@ -489,7 +492,7 @@
   int eob_extra;
   const int eob_pt = get_eob_pos_token(eob, &eob_extra);
   const int eob_multi_size = txsize_log2_minus4[tx_size];
-  const int eob_multi_ctx = (tx_type_to_class[tx_type] == TX_CLASS_2D) ? 0 : 1;
+  const int eob_multi_ctx = (tx_class == TX_CLASS_2D) ? 0 : 1;
   switch (eob_multi_size) {
     case 0:
       aom_write_symbol(w, eob_pt - 1,
@@ -533,7 +536,7 @@
     }
   }
 
-  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_type, coeff_contexts);
+  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_class, coeff_contexts);
 
   for (c = eob - 1; c >= 0; --c) {
     const int pos = scan[c];
@@ -553,7 +556,7 @@
     if (level > NUM_BASE_LEVELS) {
       // level is above 1.
       const int base_range = level - 1 - NUM_BASE_LEVELS;
-      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_class);
       for (int idx = 0; idx < COEFF_BASE_RANGE; idx += BR_CDF_SIZE - 1) {
         const int k = AOMMIN(base_range - idx, BR_CDF_SIZE - 1);
         aom_write_symbol(
@@ -668,6 +671,7 @@
   const MACROBLOCKD *const xd = &x->e_mbd;
   const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
                                           tx_size, cm->reduced_tx_set_used);
+  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const tran_low_t *const qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   const int txb_skip_ctx = txb_ctx->txb_skip_ctx;
   const int bwl = get_txb_bwl(tx_size);
@@ -687,9 +691,9 @@
 
   cost += av1_tx_type_cost(cm, x, xd, plane, tx_size, tx_type);
 
-  cost += get_eob_cost(eob, eob_costs, coeff_costs, tx_type);
+  cost += get_eob_cost(eob, eob_costs, coeff_costs, tx_class);
 
-  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_type, coeff_contexts);
+  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_class, coeff_contexts);
 
   for (int c = eob - 1; c >= 0; --c) {
     const int pos = scan[c];
@@ -714,7 +718,7 @@
         cost += av1_cost_literal(1);
       }
       if (level > NUM_BASE_LEVELS) {
-        const int ctx = get_br_ctx(levels, pos, bwl, tx_type);
+        const int ctx = get_br_ctx(levels, pos, bwl, tx_class);
         const int base_range = level - 1 - NUM_BASE_LEVELS;
         if (base_range < COEFF_BASE_RANGE) {
           cost += coeff_costs->lps_cost[ctx][base_range];
@@ -738,8 +742,9 @@
   const int16_t *const scan = txb_info->scan_order->scan;
   // forward optimize the nz_map`
   const int init_eob = txb_info->eob;
+  const TX_CLASS tx_class = tx_type_to_class[txb_info->tx_type];
   const int eob_cost =
-      get_eob_cost(init_eob, txb_eob_costs, txb_costs, txb_info->tx_type);
+      get_eob_cost(init_eob, txb_eob_costs, txb_costs, tx_class);
 
   // backward optimize the level-k map
   int accu_rate = eob_cost;
@@ -751,7 +756,8 @@
     const int si = init_eob - 1;
     const int coeff_idx = scan[si];
     LevelDownStats stats;
-    get_dist_cost_stats(&stats, si, si == init_eob - 1, txb_costs, txb_info);
+    get_dist_cost_stats(&stats, si, si == init_eob - 1, txb_costs, txb_info,
+                        tx_class);
     if ((stats.rd_low < stats.rd) && (stats.low_qc != 0)) {
       update = 1;
       update_coeff(coeff_idx, stats.low_qc, txb_info);
@@ -774,14 +780,14 @@
     if (qc == 0) {
       const int coeff_ctx =
           get_lower_levels_ctx(txb_info->levels, coeff_idx, txb_info->bwl,
-                               txb_info->tx_size, txb_info->tx_type);
+                               txb_info->tx_size, tx_class);
       accu_rate += txb_costs->base_cost[coeff_ctx][0];
     } else {
       LevelDownStats stats;
-      get_dist_cost_stats_with_eob(&stats, si, txb_costs, txb_info);
+      get_dist_cost_stats_with_eob(&stats, si, txb_costs, txb_info, tx_class);
       // check if it is better to make this the last significant coefficient
       int cur_eob_rate =
-          get_eob_cost(si + 1, txb_eob_costs, txb_costs, txb_info->tx_type);
+          get_eob_cost(si + 1, txb_eob_costs, txb_costs, tx_class);
       cur_eob_rd_cost = RDCOST(txb_info->rdmult, cur_eob_rate, 0);
       prev_eob_rd_cost =
           RDCOST(txb_info->rdmult, accu_rate, accu_dist) + stats.nz_rd;
@@ -796,7 +802,7 @@
         // rerun cost calculation due to change of eob
         accu_rate = cur_eob_rate;
         accu_dist = 0;
-        get_dist_cost_stats(&stats, si, 1, txb_costs, txb_info);
+        get_dist_cost_stats(&stats, si, 1, txb_costs, txb_info, tx_class);
         if ((stats.rd_low < stats.rd) && (stats.low_qc != 0)) {
           update = 1;
           update_coeff(coeff_idx, stats.low_qc, txb_info);
@@ -841,11 +847,11 @@
     if (qc == 0) {
       const int coeff_ctx =
           get_lower_levels_ctx(txb_info->levels, coeff_idx, txb_info->bwl,
-                               txb_info->tx_size, txb_info->tx_type);
+                               txb_info->tx_size, tx_class);
       accu_rate += txb_costs->base_cost[coeff_ctx][0];
     } else {
       LevelDownStats stats;
-      get_dist_cost_stats(&stats, si, 0, txb_costs, txb_info);
+      get_dist_cost_stats(&stats, si, 0, txb_costs, txb_info, tx_class);
 
       int bUpdCoeff = 0;
       if (stats.rd_low < stats.rd) {
@@ -1171,7 +1177,7 @@
 static INLINE int get_coeff_cost_simple(int ci, tran_low_t abs_qc,
                                         int coeff_ctx,
                                         const LV_MAP_COEFF_COST *txb_costs,
-                                        int bwl, TX_TYPE tx_type,
+                                        int bwl, TX_CLASS tx_class,
                                         const uint8_t *levels) {
   // this simple version assumes the coeff's scan_idx is not DC (scan_idx != 0)
   // and not the last (scan_idx != eob - 1)
@@ -1180,7 +1186,7 @@
   if (abs_qc) {
     cost += av1_cost_literal(1);
     if (abs_qc > NUM_BASE_LEVELS) {
-      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_class);
       cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
       cost += get_golomb_cost(abs_qc);
     }
@@ -1192,7 +1198,7 @@
                                          int sign, int coeff_ctx,
                                          int dc_sign_ctx,
                                          const LV_MAP_COEFF_COST *txb_costs,
-                                         int bwl, TX_TYPE tx_type,
+                                         int bwl, TX_CLASS tx_class,
                                          const uint8_t *levels) {
   int cost = 0;
   if (is_last) {
@@ -1207,7 +1213,7 @@
       cost += av1_cost_literal(1);
     }
     if (abs_qc > NUM_BASE_LEVELS) {
-      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, ci, bwl, tx_class);
       cost += get_br_cost(abs_qc, br_ctx, txb_costs->lps_cost[br_ctx]);
       cost += get_golomb_cost(abs_qc);
     }
@@ -1228,7 +1234,7 @@
 
 static INLINE void update_coeff_general(
     int *accu_rate, int64_t *accu_dist, int si, int eob, TX_SIZE tx_size,
-    TX_TYPE tx_type, int bwl, int height, int64_t rdmult, int shift,
+    TX_CLASS tx_class, int bwl, int height, int64_t rdmult, int shift,
     int dc_sign_ctx, const int16_t *dequant, const int16_t *scan,
     const LV_MAP_COEFF_COST *txb_costs, const tran_low_t *tcoeff,
     tran_low_t *qcoeff, tran_low_t *dqcoeff, uint8_t *levels) {
@@ -1237,7 +1243,7 @@
   const tran_low_t qc = qcoeff[ci];
   const int is_last = si == (eob - 1);
   const int coeff_ctx = get_lower_levels_ctx_general(
-      is_last, si, bwl, height, levels, ci, tx_size, tx_type);
+      is_last, si, bwl, height, levels, ci, tx_size, tx_class);
   if (qc == 0) {
     *accu_rate += txb_costs->base_cost[coeff_ctx][0];
   } else {
@@ -1249,7 +1255,7 @@
     const int64_t dist0 = get_coeff_dist(tqc, 0, shift);
     const int rate =
         get_coeff_cost_general(is_last, ci, abs_qc, sign, coeff_ctx,
-                               dc_sign_ctx, txb_costs, bwl, tx_type, levels);
+                               dc_sign_ctx, txb_costs, bwl, tx_class, levels);
     const int64_t rd = RDCOST(rdmult, rate, dist);
 
     tran_low_t qc_low, dqc_low;
@@ -1258,7 +1264,7 @@
     const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift);
     const int rate_low =
         get_coeff_cost_general(is_last, ci, abs_qc_low, sign, coeff_ctx,
-                               dc_sign_ctx, txb_costs, bwl, tx_type, levels);
+                               dc_sign_ctx, txb_costs, bwl, tx_class, levels);
     const int64_t rd_low = RDCOST(rdmult, rate_low, dist_low);
     if (rd_low < rd) {
       qcoeff[ci] = qc_low;
@@ -1274,10 +1280,11 @@
 }
 
 static INLINE void update_coeff_simple(
-    int *accu_rate, int si, int eob, TX_SIZE tx_size, TX_TYPE tx_type, int bwl,
-    int64_t rdmult, int shift, const int16_t *dequant, const int16_t *scan,
-    const LV_MAP_COEFF_COST *txb_costs, const tran_low_t *tcoeff,
-    tran_low_t *qcoeff, tran_low_t *dqcoeff, uint8_t *levels) {
+    int *accu_rate, int si, int eob, TX_SIZE tx_size, TX_CLASS tx_class,
+    int bwl, int64_t rdmult, int shift, const int16_t *dequant,
+    const int16_t *scan, const LV_MAP_COEFF_COST *txb_costs,
+    const tran_low_t *tcoeff, tran_low_t *qcoeff, tran_low_t *dqcoeff,
+    uint8_t *levels) {
   const int dqv = dequant[1];
   (void)eob;
   // this simple version assumes the coeff's scan_idx is not DC (scan_idx != 0)
@@ -1286,7 +1293,8 @@
   assert(si > 0);
   const int ci = scan[si];
   const tran_low_t qc = qcoeff[ci];
-  const int coeff_ctx = get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_type);
+  const int coeff_ctx =
+      get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_class);
   if (qc == 0) {
     *accu_rate += txb_costs->base_cost[coeff_ctx][0];
   } else {
@@ -1295,7 +1303,7 @@
     const tran_low_t dqc = dqcoeff[ci];
     const int64_t dist = get_coeff_dist(tqc, dqc, shift);
     const int rate = get_coeff_cost_simple(ci, abs_qc, coeff_ctx, txb_costs,
-                                           bwl, tx_type, levels);
+                                           bwl, tx_class, levels);
     const int64_t rd = RDCOST(rdmult, rate, dist);
 
     const int sign = (qc < 0) ? 1 : 0;
@@ -1303,8 +1311,8 @@
     get_qc_dqc_low(abs_qc, sign, dqv, shift, &qc_low, &dqc_low);
     const tran_low_t abs_qc_low = abs_qc - 1;
     const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift);
-    const int rate_low = get_coeff_cost_simple(ci, abs_qc_low, coeff_ctx,
-                                               txb_costs, bwl, tx_type, levels);
+    const int rate_low = get_coeff_cost_simple(
+        ci, abs_qc_low, coeff_ctx, txb_costs, bwl, tx_class, levels);
     const int64_t rd_low = RDCOST(rdmult, rate_low, dist_low);
     if (rd_low < rd) {
       qcoeff[ci] = qc_low;
@@ -1319,7 +1327,7 @@
 
 static INLINE void update_coeff_eob(
     int *accu_rate, int64_t *accu_dist, int *eob, int *nz_num, int *nz_ci,
-    int si, TX_SIZE tx_size, TX_TYPE tx_type, int bwl, int height,
+    int si, TX_SIZE tx_size, TX_CLASS tx_class, int bwl, int height,
     int dc_sign_ctx, int64_t rdmult, int shift, const int16_t *dequant,
     const int16_t *scan, const LV_MAP_EOB_COST *txb_eob_costs,
     const LV_MAP_COEFF_COST *txb_costs, const tran_low_t *tcoeff,
@@ -1328,7 +1336,8 @@
   assert(si != *eob - 1);
   const int ci = scan[si];
   const tran_low_t qc = qcoeff[ci];
-  const int coeff_ctx = get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_type);
+  const int coeff_ctx =
+      get_lower_levels_ctx(levels, ci, bwl, tx_size, tx_class);
   if (qc == 0) {
     *accu_rate += txb_costs->base_cost[coeff_ctx][0];
   } else {
@@ -1341,7 +1350,7 @@
     int64_t dist = get_coeff_dist(tqc, dqc, shift) - dist0;
     int rate =
         get_coeff_cost_general(0, ci, abs_qc, sign, coeff_ctx, dc_sign_ctx,
-                               txb_costs, bwl, tx_type, levels);
+                               txb_costs, bwl, tx_class, levels);
     int64_t rd = RDCOST(rdmult, *accu_rate + rate, *accu_dist + dist);
 
     tran_low_t qc_low, dqc_low;
@@ -1350,7 +1359,7 @@
     const int64_t dist_low = get_coeff_dist(tqc, dqc_low, shift) - dist0;
     const int rate_low =
         get_coeff_cost_general(0, ci, abs_qc_low, sign, coeff_ctx, dc_sign_ctx,
-                               txb_costs, bwl, tx_type, levels);
+                               txb_costs, bwl, tx_class, levels);
     const int64_t rd_low =
         RDCOST(rdmult, *accu_rate + rate_low, *accu_dist + dist_low);
 
@@ -1364,13 +1373,13 @@
     }
 
     const int coeff_ctx_new_eob = get_lower_levels_ctx_general(
-        1, si, bwl, height, levels, ci, tx_size, tx_type);
+        1, si, bwl, height, levels, ci, tx_size, tx_class);
     const int new_eob_cost =
-        get_eob_cost(new_eob, txb_eob_costs, txb_costs, tx_type);
+        get_eob_cost(new_eob, txb_eob_costs, txb_costs, tx_class);
     int rate_coeff_eob =
         new_eob_cost + get_coeff_cost_general(1, ci, abs_qc, sign,
                                               coeff_ctx_new_eob, dc_sign_ctx,
-                                              txb_costs, bwl, tx_type, levels);
+                                              txb_costs, bwl, tx_class, levels);
     int64_t dist_new_eob = dist;
     int64_t rd_new_eob = RDCOST(rdmult, rate_coeff_eob, dist_new_eob);
 
@@ -1378,7 +1387,7 @@
       const int rate_coeff_eob_low =
           new_eob_cost +
           get_coeff_cost_general(1, ci, abs_qc_low, sign, coeff_ctx_new_eob,
-                                 dc_sign_ctx, txb_costs, bwl, tx_type, levels);
+                                 dc_sign_ctx, txb_costs, bwl, tx_class, levels);
       const int64_t dist_new_eob_low = dist_low;
       const int64_t rd_new_eob_low =
           RDCOST(rdmult, rate_coeff_eob_low, dist_new_eob_low);
@@ -1458,6 +1467,7 @@
   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
   const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
                                           tx_size, cm->reduced_tx_set_used);
+  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const MB_MODE_INFO *mbmi = xd->mi[0];
   const struct macroblock_plane *p = &x->plane[plane];
   struct macroblockd_plane *pd = &xd->plane[plane];
@@ -1492,7 +1502,7 @@
   const int non_skip_cost = txb_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][0];
   const int skip_cost = txb_costs->txb_skip_cost[txb_ctx->txb_skip_ctx][1];
   int eob = p->eobs[block];
-  const int eob_cost = get_eob_cost(eob, txb_eob_costs, txb_costs, tx_type);
+  const int eob_cost = get_eob_cost(eob, txb_eob_costs, txb_costs, tx_class);
   int accu_rate = eob_cost;
   int64_t accu_dist = 0;
   int si = eob - 1;
@@ -1504,17 +1514,18 @@
   int nz_num = 1;
   int nz_ci[3] = { ci, 0, 0 };
   if (abs_qc >= 2) {
-    update_coeff_general(&accu_rate, &accu_dist, si, eob, tx_size, tx_type, bwl,
-                         height, rdmult, shift, txb_ctx->dc_sign_ctx, dequant,
-                         scan, txb_costs, tcoeff, qcoeff, dqcoeff, levels);
+    update_coeff_general(&accu_rate, &accu_dist, si, eob, tx_size, tx_class,
+                         bwl, height, rdmult, shift, txb_ctx->dc_sign_ctx,
+                         dequant, scan, txb_costs, tcoeff, qcoeff, dqcoeff,
+                         levels);
     --si;
   } else {
     assert(abs_qc == 1);
     const int coeff_ctx = get_lower_levels_ctx_general(
-        1, si, bwl, height, levels, ci, tx_size, tx_type);
+        1, si, bwl, height, levels, ci, tx_size, tx_class);
     accu_rate += get_coeff_cost_general(1, ci, abs_qc, sign, coeff_ctx,
                                         txb_ctx->dc_sign_ctx, txb_costs, bwl,
-                                        tx_type, levels);
+                                        tx_class, levels);
     const tran_low_t tqc = tcoeff[ci];
     const tran_low_t dqc = dqcoeff[ci];
     const int64_t dist = get_coeff_dist(tqc, dqc, shift);
@@ -1525,7 +1536,7 @@
 
   for (; si >= 0 && nz_num <= max_nz_num; --si) {
     update_coeff_eob(&accu_rate, &accu_dist, &eob, &nz_num, nz_ci, si, tx_size,
-                     tx_type, bwl, height, txb_ctx->dc_sign_ctx, rdmult, shift,
+                     tx_class, bwl, height, txb_ctx->dc_sign_ctx, rdmult, shift,
                      dequant, scan, txb_eob_costs, txb_costs, tcoeff, qcoeff,
                      dqcoeff, levels);
   }
@@ -1536,7 +1547,7 @@
   }
 
   for (; si >= 1; --si) {
-    update_coeff_simple(&accu_rate, si, eob, tx_size, tx_type, bwl, rdmult,
+    update_coeff_simple(&accu_rate, si, eob, tx_size, tx_class, bwl, rdmult,
                         shift, dequant, scan, txb_costs, tcoeff, qcoeff,
                         dqcoeff, levels);
   }
@@ -1545,7 +1556,7 @@
   if (si == 0) {
     // no need to update accu_dist because it's not used after this point
     int64_t dummy_dist = 0;
-    update_coeff_general(&accu_rate, &dummy_dist, si, eob, tx_size, tx_type,
+    update_coeff_general(&accu_rate, &dummy_dist, si, eob, tx_size, tx_class,
                          bwl, height, rdmult, shift, txb_ctx->dc_sign_ctx,
                          dequant, scan, txb_costs, tcoeff, qcoeff, dqcoeff,
                          levels);
@@ -1734,18 +1745,19 @@
   const PLANE_TYPE plane_type = pd->plane_type;
   const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
                                           tx_size, cm->reduced_tx_set_used);
+  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
   const int16_t *const scan = scan_order->scan;
 #if CONFIG_ENTROPY_STATS
-  av1_update_eob_context(cdf_idx, eob, tx_size, tx_type, plane_type, ec_ctx,
+  av1_update_eob_context(cdf_idx, eob, tx_size, tx_class, plane_type, ec_ctx,
                          td->counts, allow_update_cdf);
 #else
-  av1_update_eob_context(eob, tx_size, tx_type, plane_type, ec_ctx,
+  av1_update_eob_context(eob, tx_size, tx_class, plane_type, ec_ctx,
                          allow_update_cdf);
 #endif
 
   DECLARE_ALIGNED(16, int8_t, coeff_contexts[MAX_TX_SQUARE]);
-  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_type, coeff_contexts);
+  av1_get_nz_map_contexts(levels, scan, eob, tx_size, tx_class, coeff_contexts);
 
   for (int c = eob - 1; c >= 0; --c) {
     const int pos = scan[c];
@@ -1778,7 +1790,7 @@
     }
     if (level > NUM_BASE_LEVELS) {
       const int base_range = level - 1 - NUM_BASE_LEVELS;
-      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_type);
+      const int br_ctx = get_br_ctx(levels, pos, bwl, tx_class);
       for (int idx = 0; idx < COEFF_BASE_RANGE; idx += BR_CDF_SIZE - 1) {
         const int k = AOMMIN(base_range - idx, BR_CDF_SIZE - 1);
         if (allow_update_cdf) {
diff --git a/av1/encoder/x86/encodetxb_sse2.c b/av1/encoder/x86/encodetxb_sse2.c
index 52f4baa..dedb4d0 100644
--- a/av1/encoder/x86/encodetxb_sse2.c
+++ b/av1/encoder/x86/encodetxb_sse2.c
@@ -433,7 +433,8 @@
 // Note: levels[] must be in the range [0, 127], inclusive.
 void av1_get_nz_map_contexts_sse2(const uint8_t *const levels,
                                   const int16_t *const scan, const uint16_t eob,
-                                  const TX_SIZE tx_size, const TX_TYPE tx_type,
+                                  const TX_SIZE tx_size,
+                                  const TX_CLASS tx_class,
                                   int8_t *const coeff_contexts) {
   const int last_idx = eob - 1;
   if (!last_idx) {
@@ -446,7 +447,6 @@
   const int width = get_txb_wide(tx_size);
   const int height = get_txb_high(tx_size);
   const int stride = width + TX_PAD_HOR;
-  const TX_CLASS tx_class = tx_type_to_class[tx_type];
   ptrdiff_t offsets[3];
 
   /* coeff_contexts must be 16 byte aligned. */
diff --git a/test/encodetxb_test.cc b/test/encodetxb_test.cc
index f1bd52c..6a4f3cf 100644
--- a/test/encodetxb_test.cc
+++ b/test/encodetxb_test.cc
@@ -34,7 +34,7 @@
 typedef void (*GetNzMapContextsFunc)(const uint8_t *const levels,
                                      const int16_t *const scan,
                                      const uint16_t eob, const TX_SIZE tx_size,
-                                     const TX_TYPE tx_type,
+                                     const TX_CLASS tx_class,
                                      int8_t *const coeff_contexts);
 
 class EncodeTxbTest : public ::testing::TestWithParam<GetNzMapContextsFunc> {
@@ -64,6 +64,7 @@
 
     for (int is_inter = 0; is_inter < 2; ++is_inter) {
       for (int tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+        const TX_CLASS tx_class = tx_type_to_class[tx_type];
         for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
           const int bwl = get_txb_bwl((TX_SIZE)tx_size);
           const int width = get_txb_wide((TX_SIZE)tx_size);
@@ -78,15 +79,15 @@
               InitDataWithEob(scan, bwl, eob);
 
               av1_get_nz_map_contexts_c(levels_, scan, eob, (TX_SIZE)tx_size,
-                                        (TX_TYPE)tx_type, coeff_contexts_ref_);
+                                        tx_class, coeff_contexts_ref_);
               get_nz_map_contexts_func_(levels_, scan, eob, (TX_SIZE)tx_size,
-                                        (TX_TYPE)tx_type, coeff_contexts_);
+                                        tx_class, coeff_contexts_);
 
               result = Compare(scan, eob);
 
               EXPECT_EQ(result, 0)
-                  << " tx_class " << tx_type_to_class[tx_type] << " width "
-                  << real_width << " height " << real_height << " eob " << eob;
+                  << " tx_class " << tx_class << " width " << real_width
+                  << " height " << real_height << " eob " << eob;
             }
           }
         }
@@ -106,6 +107,7 @@
       const int real_width = tx_size_wide[tx_size];
       const int real_height = tx_size_high[tx_size];
       const TX_TYPE tx_type = DCT_DCT;
+      const TX_CLASS tx_class = tx_type_to_class[tx_type];
       const int16_t *const scan = av1_scan_orders[tx_size][tx_type].scan;
       const int eob = width * height;
       const int numTests = kNumTests / (width * height);
@@ -115,8 +117,8 @@
 
       aom_usec_timer_start(&timer);
       for (int i = 0; i < numTests; ++i) {
-        get_nz_map_contexts_func_(levels_, scan, eob, (TX_SIZE)tx_size, tx_type,
-                                  coeff_contexts_);
+        get_nz_map_contexts_func_(levels_, scan, eob, (TX_SIZE)tx_size,
+                                  tx_class, coeff_contexts_);
       }
       aom_usec_timer_mark(&timer);