Pass tx_type into get_nz_map_ctx()

Change-Id: I7b0e998182b522274768e4b587080d8e88f3a223
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index c386011..b0d5c8f 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -326,7 +326,8 @@
 static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
                                  const int coeff_idx,  // raster order
                                  const int bwl, const int height,
-                                 const int16_t *iscan) {
+                                 const int16_t *iscan, TX_TYPE tx_type) {
+  (void)tx_type;
   const int row = coeff_idx >> bwl;
   const int col = coeff_idx - (row << bwl);
   int count = get_nz_count(tcoeffs, bwl, height, row, col, iscan);
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 2ddb0f2..29c4e81 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -100,7 +100,8 @@
 
   for (c = 0; c < seg_eob; ++c) {
     int is_nz;
-    int coeff_ctx = get_nz_map_ctx(tcoeffs, scan[c], bwl, height, iscan);
+    int coeff_ctx =
+        get_nz_map_ctx(tcoeffs, scan[c], bwl, height, iscan, tx_type);
     int eob_ctx = get_eob_ctx(tcoeffs, scan[c], txs_ctx);
 
     if (c < seg_eob - 1) {
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 864672e..41f57ad 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -114,7 +114,8 @@
 #endif
 
   for (c = 0; c < eob; ++c) {
-    int coeff_ctx = get_nz_map_ctx(tcoeff, scan[c], bwl, height, iscan);
+    int coeff_ctx =
+        get_nz_map_ctx(tcoeff, scan[c], bwl, height, iscan, tx_type);
     int eob_ctx = get_eob_ctx(tcoeff, scan[c], txs_ctx);
 
     tran_low_t v = tcoeff[scan[c]];
@@ -383,7 +384,8 @@
     int level = abs(v);
 
     if (c < seg_eob) {
-      int coeff_ctx = get_nz_map_ctx(qcoeff, scan[c], bwl, height, iscan);
+      int coeff_ctx =
+          get_nz_map_ctx(qcoeff, scan[c], bwl, height, iscan, tx_type);
       cost += coeff_costs->nz_map_cost[coeff_ctx][is_nz];
     }
 
@@ -1024,7 +1026,7 @@
         txb_cache->nz_ctx_arr[nb_coeff_idx] = get_nz_map_ctx_from_count(
             count, txb_info->qcoeff, nb_coeff_idx, txb_info->bwl, iscan);
         // int ref_ctx = get_nz_map_ctx(txb_info->qcoeff, nb_coeff_idx,
-        // txb_info->bwl, iscan);
+        // txb_info->bwl, iscan, tx_type);
         // if (ref_ctx != txb_cache->nz_ctx_arr[nb_coeff_idx])
         //   printf("nz ctx %d ref_ctx %d\n",
         //   txb_cache->nz_ctx_arr[nb_coeff_idx], ref_ctx);
@@ -1115,8 +1117,9 @@
   const int16_t *iscan = txb_info->scan_order->iscan;
 
   if (scan_idx < txb_info->seg_eob) {
-    int coeff_ctx = get_nz_map_ctx(txb_info->qcoeff, scan[scan_idx],
-                                   txb_info->bwl, txb_info->height, iscan);
+    int coeff_ctx =
+        get_nz_map_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl,
+                       txb_info->height, iscan, txb_info->tx_type);
     cost += txb_costs->nz_map_cost[coeff_ctx][is_nz];
   }
 
@@ -1467,10 +1470,23 @@
   const int64_t rdmult =
       (x->rdmult * plane_rd_mult[is_inter][plane_type] + 2) >> 2;
 
-  TxbInfo txb_info = { qcoeff,     dqcoeff, tcoeff,  dequant,
-                       shift,      tx_size, txs_ctx, bwl,
-                       stride,     height,  eob,     seg_eob,
-                       scan_order, txb_ctx, rdmult,  &cm->coeff_ctx_table };
+  TxbInfo txb_info = { qcoeff,
+                       dqcoeff,
+                       tcoeff,
+                       dequant,
+                       shift,
+                       tx_size,
+                       txs_ctx,
+                       tx_type,
+                       bwl,
+                       stride,
+                       height,
+                       eob,
+                       seg_eob,
+                       scan_order,
+                       txb_ctx,
+                       rdmult,
+                       &cm->coeff_ctx_table };
 
   TxbCache txb_cache;
   gen_txb_cache(&txb_cache, &txb_info);
@@ -1579,7 +1595,8 @@
   for (c = 0; c < eob; ++c) {
     tran_low_t v = qcoeff[scan[c]];
     int is_nz = (v != 0);
-    int coeff_ctx = get_nz_map_ctx(tcoeff, scan[c], bwl, height, iscan);
+    int coeff_ctx =
+        get_nz_map_ctx(tcoeff, scan[c], bwl, height, iscan, tx_type);
     int eob_ctx = get_eob_ctx(tcoeff, scan[c], txsize_ctx);
 
     if (c == seg_eob - 1) break;
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index 6e549d7..57c2f1f 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -31,6 +31,7 @@
   int shift;
   TX_SIZE tx_size;
   TX_SIZE txs_ctx;
+  TX_TYPE tx_type;
   int bwl;
   int stride;
   int height;