Add test code for optimize_txb()
Change-Id: Ieae4c1a1c932d375b4577c7e42a9764e5f9cd16a
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 590670d..ba0ae1c 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -957,11 +957,11 @@
int cost_diff = 0;
cost_diff -= get_low_coeff_cost(coeff_idx, txb_cache, txb_probs, txb_info);
// int coeff_cost =
- // av1_get_coeff_cost(qc, scan_idx, txb_info, txb_probs, txb_cache);
+ // get_coeff_cost(qc, scan_idx, txb_info, txb_probs);
// if (-cost_diff != coeff_cost) {
// printf("-cost_diff %d coeff_cost %d\n", -cost_diff, coeff_cost);
// get_low_coeff_cost(coeff_idx, txb_cache, txb_probs, txb_info);
- // av1_get_coeff_cost(qc, scan_idx, txb_info, txb_probs, txb_cache);
+ // get_coeff_cost(qc, scan_idx, txb_info, txb_probs);
// }
for (int si = scan_idx - 1; si >= 0; --si) {
const int ci = scan[si];
@@ -1103,7 +1103,7 @@
}
static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
- TxbProbs *txb_probs) {
+ const TxbProbs *txb_probs) {
const TXB_CTX *txb_ctx = txb_info->txb_ctx;
const int is_nz = (qc != 0);
const tran_low_t abs_qc = abs(qc);
@@ -1146,14 +1146,88 @@
}
#if TEST_OPTIMIZE_TXB
-static void test_level_down(int coeff_idx, TxbCache *txb_cache,
- TxbProbs *txb_probs, TxbInfo *txb_info) {
+#define ALL_REF_OFFSET_NUM 17
+// try_level_down_ref() visits (row, col) - all_ref_offset[i] for each entry;
+// presumably these are the positions whose coding context can include
+// (row, col) -- TODO confirm against the context model. Entry 0 is the
+// coefficient itself. The table is read-only, hence const.
+static const int all_ref_offset[ALL_REF_OFFSET_NUM][2] = {
+  { 0, 0 },  { -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 },
+  { -1, 0 }, { -1, 1 },  { 0, -2 }, { 0, -1 }, { 1, -2 },  { 1, -1 },
+  { 1, 0 },  { 2, 0 },   { 0, 1 },  { 0, 2 },  { 1, 1 },
+};
+
+// Returns the sum of get_coeff_cost() over the coefficients at
+// (row, col) - all_ref_offset[i] that lie inside the block and precede eob
+// in scan order. If cost_map is non-NULL, sign * cost of each visited
+// coefficient is added to cost_map at its offset from (row, col) plus
+// COST_MAP_OFFSET.
+static int sum_ref_neighbor_cost(int row, int col, const TxbProbs *txb_probs,
+                                 TxbInfo *txb_info,
+                                 int (*cost_map)[COST_MAP_SIZE], int sign) {
+  int total_cost = 0;
+  for (int i = 0; i < ALL_REF_OFFSET_NUM; ++i) {
+    const int nb_row = row - all_ref_offset[i][0];
+    const int nb_col = col - all_ref_offset[i][1];
+    // Bounds-check BEFORE indexing: near the block edge nb_coeff_idx can be
+    // negative or past the end of iscan[]/qcoeff[].
+    // NOTE(review): rows are bounded by stride, which assumes a square
+    // block; confirm for rectangular transform sizes.
+    if (nb_row < 0 || nb_col < 0 || nb_row >= txb_info->stride ||
+        nb_col >= txb_info->stride) {
+      continue;
+    }
+    const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
+    const int nb_scan_idx = txb_info->scan_order->iscan[nb_coeff_idx];
+    if (nb_scan_idx >= txb_info->eob) continue;
+    const int cost = get_coeff_cost(txb_info->qcoeff[nb_coeff_idx],
+                                    nb_scan_idx, txb_info, txb_probs);
+    if (cost_map) {
+      cost_map[nb_row - row + COST_MAP_OFFSET]
+              [nb_col - col + COST_MAP_OFFSET] += sign * cost;
+    }
+    total_cost += cost;
+  }
+  return total_cost;
+}
+
+// Brute-force reference for try_level_down(), compared against it by
+// test_level_down(). Returns the change in summed get_coeff_cost() over the
+// neighborhood of coeff_idx when its value is replaced by
+// get_lower_coeff(qc), or 0 if the coefficient is already 0.
+// txb_info->qcoeff[coeff_idx] is modified temporarily and restored before
+// returning. If cost_map is non-NULL it is zeroed and then receives the
+// per-position cost change (new - original), indexed by offset from
+// coeff_idx plus COST_MAP_OFFSET.
+static int try_level_down_ref(int coeff_idx, const TxbProbs *txb_probs,
+                              TxbInfo *txb_info,
+                              int (*cost_map)[COST_MAP_SIZE]) {
+  if (cost_map) {
+    for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]);
+  }
+  const tran_low_t qc = txb_info->qcoeff[coeff_idx];
+  if (qc == 0) return 0;
+  const int row = coeff_idx >> txb_info->bwl;
+  const int col = coeff_idx - (row << txb_info->bwl);
+  // Neighborhood cost before the change (subtracted from cost_map).
+  const int org_cost =
+      sum_ref_neighbor_cost(row, col, txb_probs, txb_info, cost_map, -1);
+  txb_info->qcoeff[coeff_idx] = get_lower_coeff(qc);
+  // Neighborhood cost after lowering the coefficient (added to cost_map).
+  const int new_cost =
+      sum_ref_neighbor_cost(row, col, txb_probs, txb_info, cost_map, 1);
+  txb_info->qcoeff[coeff_idx] = qc;  // Restore the caller's value.
+  return new_cost - org_cost;
+}
+
+static void test_level_down(int coeff_idx, const TxbCache *txb_cache,
+ const TxbProbs *txb_probs, TxbInfo *txb_info) {
int cost_map[COST_MAP_SIZE][COST_MAP_SIZE];
int ref_cost_map[COST_MAP_SIZE][COST_MAP_SIZE];
const int cost_diff =
try_level_down(coeff_idx, txb_cache, txb_probs, txb_info, cost_map);
- const int cost_diff_ref = try_level_down_ref(coeff_idx, txb_cache, txb_probs,
- txb_info, ref_cost_map);
+ const int cost_diff_ref =
+ try_level_down_ref(coeff_idx, txb_probs, txb_info, ref_cost_map);
if (cost_diff != cost_diff_ref) {
printf("qc %d cost_diff %d cost_diff_ref %d\n", txb_info->qcoeff[coeff_idx],
cost_diff, cost_diff_ref);
@@ -1168,7 +1222,7 @@
#endif
// TODO(angiebird): make this static once it's called
-int get_txb_cost(TxbInfo *txb_info, TxbProbs *txb_probs) {
+int get_txb_cost(TxbInfo *txb_info, const TxbProbs *txb_probs) {
int cost = 0;
int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx;
const int16_t *scan = txb_info->scan_order->scan;
@@ -1185,6 +1239,43 @@
return cost;
}
+#if TEST_OPTIMIZE_TXB
+// Cross-checks try_change_eob() against a brute-force recomputation with
+// get_txb_cost(). Only acts when the coefficient at scan position eob - 1
+// has magnitude 1: it lowers that coefficient with get_lower_coeff(), sets
+// eob to the new_eob reported by try_change_eob(), recomputes the block cost,
+// then restores qcoeff and eob. A mismatch is reported with printf().
+void test_try_change_eob(TxbInfo *txb_info, TxbProbs *txb_probs,
+                         TxbCache *txb_cache) {
+  const int eob = txb_info->eob;
+  if (eob <= 0) return;
+  const int16_t *scan = txb_info->scan_order->scan;
+  const int last_ci = scan[eob - 1];
+  // tran_low_t (qcoeff's element type), as in try_level_down_ref().
+  const tran_low_t last_coeff = txb_info->qcoeff[last_ci];
+  if (abs(last_coeff) != 1) return;
+
+  int new_eob = eob;  // Defensive default; try_change_eob() overwrites it.
+  const int cost_diff =
+      try_change_eob(&new_eob, last_ci, txb_cache, txb_probs, txb_info);
+  const int org_eob = txb_info->eob;
+  const int org_cost = get_txb_cost(txb_info, txb_probs);
+
+  // Apply the candidate change, measure it, then undo it.
+  txb_info->qcoeff[last_ci] = get_lower_coeff(last_coeff);
+  set_eob(txb_info, new_eob);
+  const int new_cost = get_txb_cost(txb_info, txb_probs);
+  set_eob(txb_info, org_eob);
+  txb_info->qcoeff[last_ci] = last_coeff;
+
+  const int ref_cost_diff = new_cost - org_cost;
+  if (cost_diff != ref_cost_diff) {
+    printf("org_eob %d new_eob %d cost_diff %d ref_cost_diff %d\n", org_eob,
+           new_eob, cost_diff, ref_cost_diff);
+  }
+}
+#endif
+
static INLINE int64_t get_coeff_dist(tran_low_t tcoeff, tran_low_t dqcoeff,
int shift) {
const int64_t diff = (tcoeff - dqcoeff) * (1 << shift);
@@ -1258,7 +1343,7 @@
int64_t org_dist =
av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) *
(1 << (2 * txb_info->shift));
- int org_cost = get_txb_cost(txb_info, txb_probs, txb_cache);
+ int org_cost = get_txb_cost(txb_info, txb_probs);
#endif
tran_low_t *org_qcoeff = txb_info->qcoeff;
@@ -1319,7 +1404,7 @@
int64_t new_dist =
av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) *
(1 << (2 * txb_info->shift));
- int new_cost = get_txb_cost(txb_info, txb_probs, txb_cache);
+ int new_cost = get_txb_cost(txb_info, txb_probs);
int64_t ref_dist_diff = new_dist - org_dist;
int ref_cost_diff = new_cost - org_cost;
if (cost_diff != ref_cost_diff || dist_diff != ref_dist_diff)