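Make get_nz_count support rectangular tx_size

get_nz_count() bounds-checked the neighbor row against the stride,
which is only valid when the transform block's height equals its
stride. Pass the block height explicitly and check ref_row against it.
The new argument is threaded through get_nz_map_ctx2(),
gen_nz_count_arr() and a new TxbInfo.height field.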

Change-Id: I44bea2cda7c57d82a79a906f52c18e188f1fedea
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index cbd0adf..9a51a13 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -257,14 +257,15 @@
   return 14 + ctx;
 }
 
-static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, int row,
-                               int col, const int16_t *iscan) {
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride,
+                               int height, int row, int col,
+                               const int16_t *iscan) {
   int count = 0;
   const int pos = row * stride + col;
   for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
     const int ref_row = row + sig_ref_offset[idx][0];
     const int ref_col = col + sig_ref_offset[idx][1];
-    if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+    if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
       continue;
     const int nb_pos = ref_row * stride + ref_col;
     if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0);
@@ -328,11 +329,12 @@
 // testing
 static INLINE int get_nz_map_ctx2(const tran_low_t *tcoeffs,
                                   const int coeff_idx,  // raster order
-                                  const int bwl, const int16_t *iscan) {
+                                  const int bwl, const int height,
+                                  const int16_t *iscan) {
   int stride = 1 << bwl;
   const int row = coeff_idx >> bwl;
   const int col = coeff_idx - (row << bwl);
-  int count = get_nz_count(tcoeffs, stride, row, col, iscan);
+  int count = get_nz_count(tcoeffs, stride, height, row, col, iscan);
   return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan);
 }
 
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index bb334fa..c21b2d0 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -435,7 +435,7 @@
 }
 
 static void gen_nz_count_arr(int(*nz_count_arr), const tran_low_t *qcoeff,
-                             int stride, int eob,
+                             int stride, int height, int eob,
                              const SCAN_ORDER *scan_order) {
   const int16_t *scan = scan_order->scan;
   const int16_t *iscan = scan_order->iscan;
@@ -443,7 +443,8 @@
     const int coeff_idx = scan[c];  // raster order
     const int row = coeff_idx / stride;
     const int col = coeff_idx % stride;
-    nz_count_arr[coeff_idx] = get_nz_count(qcoeff, stride, row, col, iscan);
+    nz_count_arr[coeff_idx] =
+        get_nz_count(qcoeff, stride, height, row, col, iscan);
   }
 }
 
@@ -550,7 +551,7 @@
 void gen_txb_cache(TxbCache *txb_cache, TxbInfo *txb_info) {
   const int16_t *scan = txb_info->scan_order->scan;
   gen_nz_count_arr(txb_cache->nz_count_arr, txb_info->qcoeff, txb_info->stride,
-                   txb_info->eob, txb_info->scan_order);
+                   txb_info->height, txb_info->eob, txb_info->scan_order);
   gen_nz_ctx_arr(txb_cache->nz_ctx_arr, txb_cache->nz_count_arr,
                  txb_info->qcoeff, txb_info->bwl, txb_info->eob,
                  txb_info->scan_order);
@@ -1120,8 +1121,8 @@
   const int16_t *iscan = txb_info->scan_order->iscan;
 
   if (scan_idx < txb_info->seg_eob) {
-    int coeff_ctx =
-        get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, iscan);
+    int coeff_ctx = get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx],
+                                    txb_info->bwl, txb_info->height, iscan);
     cost += av1_cost_bit(txb_probs->nz_map[coeff_ctx], is_nz);
   }
 
@@ -1458,6 +1459,7 @@
 
   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
   const int stride = 1 << bwl;
+  const int height = tx_size_high[tx_size];
   aom_prob(*coeff_base)[COEFF_BASE_CONTEXTS] =
       xd->fc->coeff_base[txs_ctx][plane_type];
 
@@ -1479,9 +1481,9 @@
       (x->rdmult * plane_rd_mult[is_inter][plane_type] + 2) >> 2;
   const int64_t rddiv = x->rddiv;
 
-  TxbInfo txb_info = { qcoeff,  dqcoeff,    tcoeff,  dequant, shift,
-                       tx_size, txs_ctx,    bwl,     stride,  eob,
-                       seg_eob, scan_order, txb_ctx, rdmult,  rddiv };
+  TxbInfo txb_info = { qcoeff,     dqcoeff, tcoeff, dequant, shift, tx_size,
+                       txs_ctx,    bwl,     stride, height,  eob,   seg_eob,
+                       scan_order, txb_ctx, rdmult, rddiv };
   TxbCache txb_cache;
   gen_txb_cache(&txb_cache, &txb_info);
 
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index 66d86e5..57c0a61 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -33,6 +33,7 @@
   TX_SIZE txs_ctx;
   int bwl;
   int stride;
+  int height;
   int eob;
   int seg_eob;
   const SCAN_ORDER *scan_order;