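Split get_level_count_mag()
into get_level_count() and get_level_mag(), since the two actually
operate on different levels and get_level_mag() is hard to vectorize
with SIMD. get_base_ctx() now takes level_minus_1 directly, so its
callers pass i instead of i + 1.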
Change-Id: Iedb12a1d592cf09425e5a77e6bdc9990c271c872
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 71471da..9f040e9 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -46,6 +46,11 @@
/* clang-format on*/
};
+#define CONTEXT_MAG_POSITION_NUM 3
+static const int mag_ref_offset[CONTEXT_MAG_POSITION_NUM][2] = {
+ { 0, 1 }, { 1, 0 }, { 1, 1 }
+};
+
static INLINE void get_base_count_mag(int *mag, int *count,
const tran_low_t *tcoeffs, int bwl,
int height, int row, int col) {
@@ -84,12 +89,9 @@
return idx + TX_PAD_HOR * (idx >> bwl);
}
-static INLINE int get_level_count_mag(int *const mag,
- const uint8_t *const levels,
- const int stride, const int row,
- const int col, const int level,
- const int (*nb_offset)[2],
- const int nb_num) {
+static INLINE int get_level_count(const uint8_t *const levels, const int stride,
+ const int row, const int col, const int level,
+ const int (*nb_offset)[2], const int nb_num) {
int count = 0;
for (int idx = 0; idx < nb_num; ++idx) {
@@ -97,13 +99,20 @@
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
count += levels[pos] > level;
- if (nb_offset[idx][0] == 0 && nb_offset[idx][1] == 1) mag[0] = levels[pos];
- if (nb_offset[idx][0] == 1 && nb_offset[idx][1] == 0) mag[1] = levels[pos];
- if (nb_offset[idx][0] == 1 && nb_offset[idx][1] == 1) mag[2] = levels[pos];
}
return count;
}
+static INLINE void get_level_mag(const uint8_t *const levels, const int stride,
+ const int row, const int col, int *const mag) {
+ for (int idx = 0; idx < CONTEXT_MAG_POSITION_NUM; ++idx) {
+ const int ref_row = row + mag_ref_offset[idx][0];
+ const int ref_col = col + mag_ref_offset[idx][1];
+ const int pos = ref_row * stride + ref_col;
+ mag[idx] = levels[pos];
+ }
+}
+
static INLINE int get_base_ctx_from_count_mag(int row, int col, int count,
int sig_mag) {
const int ctx = base_level_count_to_index[count];
@@ -170,18 +179,18 @@
static INLINE int get_base_ctx(const uint8_t *const levels,
const int c, // raster order
- const int bwl, const int level) {
+ const int bwl, const int level_minus_1) {
const int row = c >> bwl;
const int col = c - (row << bwl);
const int stride = (1 << bwl) + TX_PAD_HOR;
- const int level_minus_1 = level - 1;
int mag_count = 0;
int nb_mag[3] = { 0 };
- const int count =
- get_level_count_mag(nb_mag, levels, stride, row, col, level_minus_1,
- base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+ const int count = get_level_count(levels, stride, row, col, level_minus_1,
+ base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+ get_level_mag(levels, stride, row, col, nb_mag);
- for (int idx = 0; idx < 3; ++idx) mag_count += nb_mag[idx] > level;
+ for (int idx = 0; idx < 3; ++idx)
+ mag_count += nb_mag[idx] > (level_minus_1 + 1);
const int ctx_idx =
get_base_ctx_from_count_mag(row, col, count, AOMMIN(2, mag_count));
return ctx_idx;
@@ -279,9 +288,9 @@
const int level_minus_1 = NUM_BASE_LEVELS;
int mag = 0;
int nb_mag[3] = { 0 };
- const int count =
- get_level_count_mag(nb_mag, levels, stride, row, col, level_minus_1,
- br_ref_offset, BR_CONTEXT_POSITION_NUM);
+ const int count = get_level_count(levels, stride, row, col, level_minus_1,
+ br_ref_offset, BR_CONTEXT_POSITION_NUM);
+ get_level_mag(levels, stride, row, col, nb_mag);
for (int idx = 0; idx < 3; ++idx) mag = AOMMAX(mag, nb_mag[idx]);
const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
return ctx;
@@ -322,7 +331,6 @@
: sig_ref_offset_horiz[idx][1]));
const int ref_row = row + row_offset;
const int ref_col = col + col_offset;
- if (ref_col >= (1 << bwl)) continue;
const int nb_pos = ref_row * stride + ref_col;
const int level = levels[nb_pos];
count += (level != 0);
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 70006f7..f9c8710 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -213,7 +213,7 @@
if (*level <= i) continue;
- ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(levels, scan[c], bwl, i);
if (av1_read_record_bin(
counts, r, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx], 2,
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 7549795..85d8630 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -389,7 +389,7 @@
if (level <= i) continue;
- ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(levels, scan[c], bwl, i);
if (level == i + 1) {
aom_write_bin(w, 1, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx],
@@ -1367,7 +1367,7 @@
txb_cache->base_ctx_arr[base_idx][nb_coeff_idx] =
base_ctx_table[nb_row != 0][nb_col != 0][mag > level][count];
// int ref_ctx = get_base_ctx(txb_info->qcoeff, nb_coeff_idx,
- // txb_info->bwl, level);
+ // txb_info->bwl, level - 1);
// if (ref_ctx != txb_cache->base_ctx_arr[base_idx][nb_coeff_idx]) {
// printf("base ctx %d ref_ctx %d\n",
// txb_cache->base_ctx_arr[base_idx][nb_coeff_idx], ref_ctx);
@@ -2103,7 +2103,7 @@
if (level <= i) continue;
- ctx = get_base_ctx(levels, scan[c], bwl, i + 1);
+ ctx = get_base_ctx(levels, scan[c], bwl, i);
if (level == i + 1) {
++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];