Make level map coding system support rectangular tx_size
This commit makes the level map coding system support the transform
coefficients from rectangular transform block sizes.
Change-Id: I5cd6c71d12e41938f942adc98cc1e1f286336f12
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 9a51a13..d3c0664 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -37,14 +37,14 @@
};
static INLINE int get_level_count(const tran_low_t *tcoeffs, int stride,
- int row, int col, int level,
+ int height, int row, int col, int level,
int (*nb_offset)[2], int nb_num) {
int count = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
@@ -53,14 +53,15 @@
}
static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
- int row, int col, int (*nb_offset)[2], int nb_num) {
+ int height, int row, int col, int (*nb_offset)[2],
+ int nb_num) {
mag[0] = 0;
mag[1] = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
if (nb_offset[idx][0] >= 0 && nb_offset[idx][1] >= 0) {
@@ -74,15 +75,16 @@
}
}
static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
- int stride, int row, int col, int level,
- int (*nb_offset)[2], int nb_num) {
+ int stride, int height, int row, int col,
+ int level, int (*nb_offset)[2],
+ int nb_num) {
int count = 0;
*mag = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
@@ -115,15 +117,16 @@
static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
int c, // raster order
- const int bwl, const int level) {
+ const int bwl, const int height,
+ const int level) {
const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = level - 1;
int mag;
- int count =
- get_level_count_mag(&mag, tcoeffs, stride, row, col, level_minus_1,
- base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+ int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
+ level_minus_1, base_ref_offset,
+ BASE_CONTEXT_POSITION_NUM);
int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag, level);
return ctx_idx;
}
@@ -173,15 +176,15 @@
static INLINE int get_br_ctx(const tran_low_t *tcoeffs,
const int c, // raster order
- const int bwl) {
+ const int bwl, const int height) {
const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = NUM_BASE_LEVELS;
int mag;
- const int count =
- get_level_count_mag(&mag, tcoeffs, stride, row, col, level_minus_1,
- br_ref_offset, BR_CONTEXT_POSITION_NUM);
+ const int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
+ level_minus_1, br_ref_offset,
+ BR_CONTEXT_POSITION_NUM);
const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
return ctx;
}
@@ -195,7 +198,7 @@
static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
const uint8_t *txb_mask,
const int coeff_idx, // raster order
- const int bwl) {
+ const int bwl, const int height) {
const int row = coeff_idx >> bwl;
const int col = coeff_idx - (row << bwl);
int ctx = 0;
@@ -228,7 +231,7 @@
int ref_col = col + sig_ref_offset[idx][1];
int pos;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
pos = (ref_row << bwl) + ref_col;
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 74693b5..98e4db1 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -58,6 +58,7 @@
const int16_t *const dequant = xd->plane[plane].seg_dequant[mbmi->segment_id];
const int shift = av1_get_tx_scale(tx_size);
const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
+ const int height = tx_size_high[tx_size];
int cul_level = 0;
unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2];
uint8_t txb_mask[32 * 32] = { 0 };
@@ -87,7 +88,7 @@
for (c = 0; c < seg_eob; ++c) {
int is_nz;
- int coeff_ctx = get_nz_map_ctx(tcoeffs, txb_mask, scan[c], bwl);
+ int coeff_ctx = get_nz_map_ctx(tcoeffs, txb_mask, scan[c], bwl, height);
int eob_ctx = get_eob_ctx(tcoeffs, scan[c], txs_ctx);
if (c < seg_eob - 1)
@@ -128,7 +129,7 @@
if (*v <= i) continue;
- ctx = get_base_ctx(tcoeffs, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(tcoeffs, scan[c], bwl, height, i + 1);
if (aom_read(r, coeff_base[ctx], tx_size)) {
*v = i + 1;
@@ -170,7 +171,7 @@
sign = aom_read_bit(r, ACCT_STR);
}
- ctx = get_br_ctx(tcoeffs, scan[c], bwl);
+ ctx = get_br_ctx(tcoeffs, scan[c], bwl, height);
if (cm->fc->coeff_lps[txs_ctx][plane_type][ctx] == 0) exit(0);
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index c21b2d0..1fd462a 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -86,6 +86,7 @@
int c;
int is_nz;
const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
+ const int height = tx_size_high[tx_size];
const int seg_eob = tx_size_2d[tx_size];
uint8_t txb_mask[32 * 32] = { 0 };
uint16_t update_eob = 0;
@@ -101,7 +102,7 @@
eob_flag = cm->fc->eob_flag[txs_ctx][plane_type];
for (c = 0; c < eob; ++c) {
- int coeff_ctx = get_nz_map_ctx(tcoeff, txb_mask, scan[c], bwl);
+ int coeff_ctx = get_nz_map_ctx(tcoeff, txb_mask, scan[c], bwl, height);
int eob_ctx = get_eob_ctx(tcoeff, scan[c], txs_ctx);
tran_low_t v = tcoeff[scan[c]];
@@ -130,7 +131,7 @@
if (level <= i) continue;
- ctx = get_base_ctx(tcoeff, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);
if (level == i + 1) {
aom_write(w, 1, coeff_base[ctx]);
@@ -162,7 +163,7 @@
}
// level is above 1.
- ctx = get_br_ctx(tcoeff, scan[c], bwl);
+ ctx = get_br_ctx(tcoeff, scan[c], bwl, height);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
aom_write(w, 1, cm->fc->coeff_lps[txs_ctx][plane_type][ctx]);
@@ -217,7 +218,7 @@
static INLINE void get_base_ctx_set(const tran_low_t *tcoeffs,
int c, // raster order
- const int bwl,
+ const int bwl, const int height,
int ctx_set[NUM_BASE_LEVELS]) {
const int row = c >> bwl;
const int col = c - (row << bwl);
@@ -232,7 +233,7 @@
int ref_col = col + base_ref_offset[idx][1];
int pos = (ref_row << bwl) + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
abs_coeff = abs(tcoeffs[pos]);
@@ -303,6 +304,8 @@
aom_prob *nz_map = xd->fc->nz_map[txs_ctx][plane_type];
const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
+ const int height = tx_size_high[tx_size];
+
// txb_mask is only initialized for once here. After that, it will be set when
// coding zero map and then reset when coding level 1 info.
uint8_t txb_mask[32 * 32] = { 0 };
@@ -332,7 +335,7 @@
int level = abs(v);
if (c < seg_eob) {
- int coeff_ctx = get_nz_map_ctx(qcoeff, txb_mask, scan[c], bwl);
+ int coeff_ctx = get_nz_map_ctx(qcoeff, txb_mask, scan[c], bwl, height);
cost += av1_cost_bit(nz_map[coeff_ctx], is_nz);
}
@@ -349,7 +352,7 @@
cost += av1_cost_bit(128, sign);
}
- get_base_ctx_set(qcoeff, scan[c], bwl, ctx_ls);
+ get_base_ctx_set(qcoeff, scan[c], bwl, height, ctx_ls);
int i;
for (i = 0; i < NUM_BASE_LEVELS; ++i) {
@@ -366,7 +369,7 @@
int idx;
int ctx;
- ctx = get_br_ctx(qcoeff, scan[c], bwl);
+ ctx = get_br_ctx(qcoeff, scan[c], bwl, height);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
@@ -416,20 +419,20 @@
static void gen_base_count_mag_arr(int (*base_count_arr)[MAX_TX_SQUARE],
int (*base_mag_arr)[2],
const tran_low_t *qcoeff, int stride,
- int eob, const int16_t *scan) {
+ int height, int eob, const int16_t *scan) {
for (int c = 0; c < eob; ++c) {
const int coeff_idx = scan[c]; // raster order
if (!has_base(qcoeff[coeff_idx], 0)) continue;
const int row = coeff_idx / stride;
const int col = coeff_idx % stride;
int *mag = base_mag_arr[coeff_idx];
- get_mag(mag, qcoeff, stride, row, col, base_ref_offset,
+ get_mag(mag, qcoeff, stride, height, row, col, base_ref_offset,
BASE_CONTEXT_POSITION_NUM);
for (int i = 0; i < NUM_BASE_LEVELS; ++i) {
if (!has_base(qcoeff[coeff_idx], i)) continue;
int *count = base_count_arr[i] + coeff_idx;
- *count = get_level_count(qcoeff, stride, row, col, i, base_ref_offset,
- BASE_CONTEXT_POSITION_NUM);
+ *count = get_level_count(qcoeff, stride, height, row, col, i,
+ base_ref_offset, BASE_CONTEXT_POSITION_NUM);
}
}
}
@@ -486,8 +489,8 @@
}
static void gen_br_count_mag_arr(int *br_count_arr, int (*br_mag_arr)[2],
- const tran_low_t *qcoeff, int stride, int eob,
- const int16_t *scan) {
+ const tran_low_t *qcoeff, int stride,
+ int height, int eob, const int16_t *scan) {
for (int c = 0; c < eob; ++c) {
const int coeff_idx = scan[c]; // raster order
if (!has_br(qcoeff[coeff_idx])) continue;
@@ -495,9 +498,9 @@
const int col = coeff_idx % stride;
int *count = br_count_arr + coeff_idx;
int *mag = br_mag_arr[coeff_idx];
- *count = get_level_count(qcoeff, stride, row, col, NUM_BASE_LEVELS,
+ *count = get_level_count(qcoeff, stride, height, row, col, NUM_BASE_LEVELS,
br_ref_offset, BR_CONTEXT_POSITION_NUM);
- get_mag(mag, qcoeff, stride, row, col, br_ref_offset,
+ get_mag(mag, qcoeff, stride, height, row, col, br_ref_offset,
BR_CONTEXT_POSITION_NUM);
}
}
@@ -556,13 +559,14 @@
txb_info->qcoeff, txb_info->bwl, txb_info->eob,
txb_info->scan_order);
gen_base_count_mag_arr(txb_cache->base_count_arr, txb_cache->base_mag_arr,
- txb_info->qcoeff, txb_info->stride, txb_info->eob,
- scan);
+ txb_info->qcoeff, txb_info->stride, txb_info->height,
+ txb_info->eob, scan);
gen_base_ctx_arr(txb_cache->base_ctx_arr, txb_cache->base_count_arr,
txb_cache->base_mag_arr, txb_info->qcoeff, txb_info->stride,
txb_info->eob, scan);
gen_br_count_mag_arr(txb_cache->br_count_arr, txb_cache->br_mag_arr,
- txb_info->qcoeff, txb_info->stride, txb_info->eob, scan);
+ txb_info->qcoeff, txb_info->stride, txb_info->height,
+ txb_info->eob, scan);
gen_br_ctx_arr(txb_cache->br_ctx_arr, txb_cache->br_count_arr,
txb_cache->br_mag_arr, txb_info->qcoeff, txb_info->stride,
txb_info->eob, scan);
@@ -1131,7 +1135,8 @@
txb_ctx->dc_sign_ctx);
int ctx_ls[NUM_BASE_LEVELS] = { 0 };
- get_base_ctx_set(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, ctx_ls);
+ get_base_ctx_set(txb_info->qcoeff, scan[scan_idx], txb_info->bwl,
+ txb_info->height, ctx_ls);
int i;
for (i = 0; i < NUM_BASE_LEVELS; ++i) {
@@ -1139,7 +1144,8 @@
}
if (abs_qc > NUM_BASE_LEVELS) {
- int ctx = get_br_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
+ int ctx = get_br_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl,
+ txb_info->height);
cost += get_br_cost(abs_qc, ctx, txb_probs->coeff_lps);
cost += get_golomb_cost(abs_qc);
}
@@ -1557,6 +1563,7 @@
get_txb_ctx(plane_bsize, tx_size, plane, pd->above_context + blk_col,
pd->left_context + blk_row, &txb_ctx);
const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
+ const int height = tx_size_high[tx_size];
int cul_level = 0;
unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2];
uint8_t txb_mask[32 * 32] = { 0 };
@@ -1585,7 +1592,7 @@
for (c = 0; c < eob; ++c) {
tran_low_t v = qcoeff[scan[c]];
int is_nz = (v != 0);
- int coeff_ctx = get_nz_map_ctx(tcoeff, txb_mask, scan[c], bwl);
+ int coeff_ctx = get_nz_map_ctx(tcoeff, txb_mask, scan[c], bwl, height);
int eob_ctx = get_eob_ctx(tcoeff, scan[c], txsize_ctx);
if (c == seg_eob - 1) break;
@@ -1608,7 +1615,7 @@
if (level <= i) continue;
- ctx = get_base_ctx(tcoeff, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);
if (level == i + 1) {
++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];
@@ -1643,7 +1650,7 @@
}
// level is above 1.
- ctx = get_br_ctx(tcoeff, scan[c], bwl);
+ ctx = get_br_ctx(tcoeff, scan[c], bwl, height);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][1];