Make get_nz_count support rectangular tx_size Change-Id: I44bea2cda7c57d82a79a906f52c18e188f1fedea
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h index cbd0adf..9a51a13 100644 --- a/av1/common/txb_common.h +++ b/av1/common/txb_common.h
@@ -257,14 +257,15 @@ return 14 + ctx; } -static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, int row, - int col, const int16_t *iscan) { +static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, + int height, int row, int col, + const int16_t *iscan) { int count = 0; const int pos = row * stride + col; for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) { const int ref_row = row + sig_ref_offset[idx][0]; const int ref_col = col + sig_ref_offset[idx][1]; - if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride) + if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride) continue; const int nb_pos = ref_row * stride + ref_col; if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0); @@ -328,11 +329,12 @@ // testing static INLINE int get_nz_map_ctx2(const tran_low_t *tcoeffs, const int coeff_idx, // raster order - const int bwl, const int16_t *iscan) { + const int bwl, const int height, + const int16_t *iscan) { int stride = 1 << bwl; const int row = coeff_idx >> bwl; const int col = coeff_idx - (row << bwl); - int count = get_nz_count(tcoeffs, stride, row, col, iscan); + int count = get_nz_count(tcoeffs, stride, height, row, col, iscan); return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan); }
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c index bb334fa..c21b2d0 100644 --- a/av1/encoder/encodetxb.c +++ b/av1/encoder/encodetxb.c
@@ -435,7 +435,7 @@ } static void gen_nz_count_arr(int(*nz_count_arr), const tran_low_t *qcoeff, - int stride, int eob, + int stride, int height, int eob, const SCAN_ORDER *scan_order) { const int16_t *scan = scan_order->scan; const int16_t *iscan = scan_order->iscan; @@ -443,7 +443,8 @@ const int coeff_idx = scan[c]; // raster order const int row = coeff_idx / stride; const int col = coeff_idx % stride; - nz_count_arr[coeff_idx] = get_nz_count(qcoeff, stride, row, col, iscan); + nz_count_arr[coeff_idx] = + get_nz_count(qcoeff, stride, height, row, col, iscan); } } @@ -550,7 +551,7 @@ void gen_txb_cache(TxbCache *txb_cache, TxbInfo *txb_info) { const int16_t *scan = txb_info->scan_order->scan; gen_nz_count_arr(txb_cache->nz_count_arr, txb_info->qcoeff, txb_info->stride, - txb_info->eob, txb_info->scan_order); + txb_info->height, txb_info->eob, txb_info->scan_order); gen_nz_ctx_arr(txb_cache->nz_ctx_arr, txb_cache->nz_count_arr, txb_info->qcoeff, txb_info->bwl, txb_info->eob, txb_info->scan_order); @@ -1120,8 +1121,8 @@ const int16_t *iscan = txb_info->scan_order->iscan; if (scan_idx < txb_info->seg_eob) { - int coeff_ctx = - get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, iscan); + int coeff_ctx = get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], + txb_info->bwl, txb_info->height, iscan); cost += av1_cost_bit(txb_probs->nz_map[coeff_ctx], is_nz); } @@ -1458,6 +1459,7 @@ const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2; const int stride = 1 << bwl; + const int height = tx_size_high[tx_size]; aom_prob(*coeff_base)[COEFF_BASE_CONTEXTS] = xd->fc->coeff_base[txs_ctx][plane_type]; @@ -1479,9 +1481,9 @@ (x->rdmult * plane_rd_mult[is_inter][plane_type] + 2) >> 2; const int64_t rddiv = x->rddiv; - TxbInfo txb_info = { qcoeff, dqcoeff, tcoeff, dequant, shift, - tx_size, txs_ctx, bwl, stride, eob, - seg_eob, scan_order, txb_ctx, rdmult, rddiv }; + TxbInfo txb_info = { qcoeff, dqcoeff, tcoeff, dequant, shift, tx_size, + txs_ctx, bwl, stride, height, eob, seg_eob, + scan_order, txb_ctx, rdmult, rddiv }; TxbCache txb_cache; gen_txb_cache(&txb_cache, &txb_info);
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h index 66d86e5..57c0a61 100644 --- a/av1/encoder/encodetxb.h +++ b/av1/encoder/encodetxb.h
@@ -33,6 +33,7 @@ TX_SIZE txs_ctx; int bwl; int stride; + int height; int eob; int seg_eob; const SCAN_ORDER *scan_order;