Add try_level_down()

This function computes the overall cost difference (for the coefficient
itself and its neighbors) caused by lowering a coefficient by one level
at a specific location.

Change-Id: I1b7b6acfe06ed06b9a2ff48b5bb11527646d1aa8
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index bac23e9..3748798 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -589,10 +589,161 @@
   return qc > 0 ? qc - 1 : qc + 1;
 }
 
-// TODO(angiebird): add static to this function once it's called
-int try_self_level_down(tran_low_t *low_coeff, int coeff_idx,
-                        TxbCache *txb_cache, TxbProbs *txb_probs,
-                        TxbInfo *txb_info) {
+// Maps a {magnitude, count} pair to one context magnitude: mag_arr[0] when
+// mag_arr[1] > 0, else mag_arr[0] - 1 when mag_arr[0] > 0, else 0 (which is
+// expected only when both entries are 0, i.e. there is no neighbor).
+static INLINE int get_mag_from_mag_arr(const int *mag_arr) {
+  const int mag_val = mag_arr[0];
+  const int mag_cnt = mag_arr[1];
+  if (mag_cnt > 0) return mag_val;
+  if (mag_val > 0) return mag_val - 1;
+
+  // no neighbor
+  assert(mag_val == 0 && mag_cnt == 0);
+  return 0;
+}
+
+// Predicts how the neighbor count and magnitude feeding coeff_idx's context
+// change if the neighbor at nb_coeff_idx drops one level. Writes the results
+// to *new_count / *new_mag and returns 1 if either differs, 0 otherwise.
+static int neighbor_level_down_update(int *new_count, int *new_mag, int count,
+                                      const int *mag, int coeff_idx,
+                                      tran_low_t abs_nb_coeff, int nb_coeff_idx,
+                                      int level, const TxbInfo *txb_info) {
+  const int bwl = txb_info->bwl;
+  const int row = coeff_idx >> bwl;
+  const int col = coeff_idx - (row << bwl);
+  const int nb_row = nb_coeff_idx >> bwl;
+  const int nb_col = nb_coeff_idx - (nb_row << bwl);
+  int changed = 0;
+  *new_count = count;
+  *new_mag = get_mag_from_mag_arr(mag);
+
+  // A neighbor whose magnitude equals `level` lowers the count by one.
+  if (abs_nb_coeff == level) {
+    *new_count -= 1;
+    assert(*new_count >= 0);
+    changed = 1;
+  }
+  // mag drops only if the nb is the only qc with max mag (and row/col >= ours).
+  if (nb_row >= row && nb_col >= col && abs_nb_coeff == mag[0]) {
+    assert(mag[1] > 0);
+    if (mag[1] == 1) {
+      *new_mag -= 1;
+      assert(*new_mag >= 0);
+      changed = 1;
+    }
+  }
+  return changed;
+}
+
+// Returns the change in the br (coeff_lps) coding cost of the coefficient at
+// coeff_idx if its neighbor at nb_coeff_idx is lowered by one level. The
+// result is 0 when the coefficient is below NUM_BASE_LEVELS + 1 or when the
+// neighbor change leaves its br count/magnitude untouched.
+static int try_neighbor_level_down_br(int coeff_idx, int nb_coeff_idx,
+                                      const TxbCache *txb_cache,
+                                      const TxbProbs *txb_probs,
+                                      const TxbInfo *txb_info) {
+  const int level = NUM_BASE_LEVELS + 1;
+  const tran_low_t abs_qc = abs(txb_info->qcoeff[coeff_idx]);
+  if (abs_qc < level) return 0;
+
+  const tran_low_t abs_nb_coeff = abs(txb_info->qcoeff[nb_coeff_idx]);
+  int new_count;
+  int new_mag;
+  if (!neighbor_level_down_update(
+          &new_count, &new_mag, txb_cache->br_count_arr[coeff_idx],
+          txb_cache->br_mag_arr[coeff_idx], coeff_idx, abs_nb_coeff,
+          nb_coeff_idx, level, txb_info)) {
+    return 0;
+  }
+
+  // Cost of coding abs_qc under the cached context vs. the context rebuilt
+  // from the updated neighbor count and magnitude.
+  const int row = coeff_idx >> txb_info->bwl;
+  const int col = coeff_idx - (row << txb_info->bwl);
+  const int org_ctx = txb_cache->br_ctx_arr[coeff_idx][0];
+  const int org_cost = get_br_cost(abs_qc, org_ctx, txb_probs->coeff_lps);
+  const int new_ctx = get_br_ctx_from_count_mag(row, col, new_count, new_mag);
+  const int new_cost = get_br_cost(abs_qc, new_ctx, txb_probs->coeff_lps);
+  return new_cost - org_cost;
+}
+
+// Returns the summed change in base-level coding cost of the coefficient at
+// coeff_idx if its neighbor at nb_coeff_idx is lowered by one level. Each of
+// the NUM_BASE_LEVELS base levels that abs_qc reaches is evaluated on its own.
+static int try_neighbor_level_down_base(int coeff_idx, int nb_coeff_idx,
+                                        const TxbCache *txb_cache,
+                                        const TxbProbs *txb_probs,
+                                        const TxbInfo *txb_info) {
+  const tran_low_t abs_qc = abs(txb_info->qcoeff[coeff_idx]);
+  const tran_low_t abs_nb_coeff = abs(txb_info->qcoeff[nb_coeff_idx]);
+  const int row = coeff_idx >> txb_info->bwl;
+  const int col = coeff_idx - (row << txb_info->bwl);
+  int total_diff = 0;
+
+  for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) {
+    const int level = base_idx + 1;
+    if (abs_qc < level) continue;
+
+    int new_count;
+    int new_mag;
+    if (!neighbor_level_down_update(
+            &new_count, &new_mag,
+            txb_cache->base_count_arr[base_idx][coeff_idx],
+            txb_cache->base_mag_arr[coeff_idx], coeff_idx, abs_nb_coeff,
+            nb_coeff_idx, level, txb_info)) {
+      continue;
+    }
+    // Cost under the cached context vs. the context rebuilt from new stats.
+    const int org_ctx = txb_cache->base_ctx_arr[base_idx][coeff_idx][0];
+    const int org_cost =
+        get_base_cost(abs_qc, org_ctx, txb_probs->coeff_base, base_idx);
+    const int new_ctx =
+        get_base_ctx_from_count_mag(row, col, new_count, new_mag, level);
+    const int new_cost =
+        get_base_cost(abs_qc, new_ctx, txb_probs->coeff_base, base_idx);
+    total_diff += new_cost - org_cost;
+  }
+  return total_diff;
+}
+
+// Returns the change in the nz_map coding cost of the coefficient at coeff_idx
+// if its neighbor at nb_coeff_idx, whose magnitude must be 1, is lowered by
+// one level. Only a neighbor that precedes coeff_idx in scan order matters.
+static int try_neighbor_level_down_nz(int coeff_idx, int nb_coeff_idx,
+                                      const TxbCache *txb_cache,
+                                      const TxbProbs *txb_probs,
+                                      TxbInfo *txb_info) {
+  // assume eob doesn't change
+  const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx];
+  if (abs(nb_coeff) != 1) return 0;
+
+  const int16_t *iscan = txb_info->scan_order->iscan;
+  const int scan_idx = iscan[coeff_idx];
+  if (scan_idx == txb_info->seg_eob) return 0;
+  if (iscan[nb_coeff_idx] >= scan_idx) return 0;
+
+  const int count = txb_cache->nz_count_arr[coeff_idx];
+  assert(count > 0);
+  // Lower the neighbor in place just long enough to derive the new context,
+  // then put the original value back.
+  txb_info->qcoeff[nb_coeff_idx] = get_lower_coeff(nb_coeff);
+  const int new_ctx = get_nz_map_ctx_from_count(
+      count - 1, txb_info->qcoeff, coeff_idx, txb_info->bwl, iscan);
+  txb_info->qcoeff[nb_coeff_idx] = nb_coeff;
+
+  const int org_ctx = txb_cache->nz_ctx_arr[coeff_idx][0];
+  const int is_nz = abs(txb_info->qcoeff[coeff_idx]) > 0;
+  const int org_cost = av1_cost_bit(txb_probs->nz_map[org_ctx], is_nz);
+  const int new_cost = av1_cost_bit(txb_probs->nz_map[new_ctx], is_nz);
+  return new_cost - org_cost;
+}
+
+static int try_self_level_down(tran_low_t *low_coeff, int coeff_idx,
+                               const TxbCache *txb_cache,
+                               const TxbProbs *txb_probs, TxbInfo *txb_info) {
   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
   if (qc == 0) {
     *low_coeff = 0;
@@ -603,9 +754,9 @@
   int cost_diff;
   if (*low_coeff == 0) {
     const int scan_idx = txb_info->scan_order->iscan[coeff_idx];
-    aom_prob level_prob =
+    const aom_prob level_prob =
         get_level_prob(abs_qc, coeff_idx, txb_cache, txb_probs);
-    aom_prob low_level_prob =
+    const aom_prob low_level_prob =
         get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_probs);
     if (scan_idx < txb_info->seg_eob) {
       // When level-0, we code the binary of abs_qc > level
@@ -649,6 +800,102 @@
   return cost_diff;
 }
 
+// Side length of cost_map and the index of its center (the coefficient itself).
+#define COST_MAP_SIZE 5
+#define COST_MAP_OFFSET 2
+// Magnitude predicates telling try_level_down() which neighbor context costs
+// (nz / base / br) to re-evaluate for the coefficient qc.
+static INLINE int check_nz_neighbor(tran_low_t qc) { return 1 == abs(qc); }
+static INLINE int check_base_neighbor(tran_low_t qc) {
+  return NUM_BASE_LEVELS + 1 >= abs(qc);
+}
+static INLINE int check_br_neighbor(tran_low_t qc) {
+  return BR_MAG_OFFSET < abs(qc);
+}
+
+// Returns the total cost change of lowering the coefficient at coeff_idx by
+// one level: its own cost plus the context cost of each in-block nz/base/br
+// neighbor, counting only positions before eob. A non-NULL cost_map is zeroed
+// and receives the per-position changes (self at the COST_MAP_OFFSET center).
+// TODO(angiebird): add static to this function once it's called
+int try_level_down(int coeff_idx, const TxbCache *txb_cache,
+                   const TxbProbs *txb_probs, TxbInfo *txb_info,
+                   int (*cost_map)[COST_MAP_SIZE]) {
+  if (cost_map) {
+    for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]);
+  }
+  tran_low_t qc = txb_info->qcoeff[coeff_idx];
+  tran_low_t low_coeff;
+  if (qc == 0) return 0;
+  int accu_cost_diff = 0;
+  const int16_t *iscan = txb_info->scan_order->iscan;
+  const int eob = txb_info->eob;
+  const int scan_idx = iscan[coeff_idx];
+  if (scan_idx < eob) {
+    const int cost_diff = try_self_level_down(&low_coeff, coeff_idx, txb_cache,
+                                              txb_probs, txb_info);
+    if (cost_map)
+      cost_map[0 + COST_MAP_OFFSET][0 + COST_MAP_OFFSET] = cost_diff;
+    accu_cost_diff += cost_diff;
+  }
+
+  const int row = coeff_idx >> txb_info->bwl;
+  const int col = coeff_idx - (row << txb_info->bwl);
+  if (check_nz_neighbor(qc)) {
+    for (int i = 0; i < SIG_REF_OFFSET_NUM; ++i) {
+      const int nb_row = row - sig_ref_offset[i][0];
+      const int nb_col = col - sig_ref_offset[i][1];
+      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
+      // Bounds first: iscan[] must not be indexed by an out-of-block position.
+      if (nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->stride &&
+          nb_col < txb_info->stride && iscan[nb_coeff_idx] < eob) {
+        const int cost_diff = try_neighbor_level_down_nz(
+            nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info);
+        if (cost_map)
+          cost_map[nb_row - row + COST_MAP_OFFSET]
+                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
+        accu_cost_diff += cost_diff;
+      }
+    }
+  }
+
+  if (check_base_neighbor(qc)) {
+    for (int i = 0; i < BASE_CONTEXT_POSITION_NUM; ++i) {
+      const int nb_row = row - base_ref_offset[i][0];
+      const int nb_col = col - base_ref_offset[i][1];
+      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
+      if (nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->stride &&
+          nb_col < txb_info->stride && iscan[nb_coeff_idx] < eob) {
+        const int cost_diff = try_neighbor_level_down_base(
+            nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info);
+        if (cost_map)
+          cost_map[nb_row - row + COST_MAP_OFFSET]
+                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
+        accu_cost_diff += cost_diff;
+      }
+    }
+  }
+
+  if (check_br_neighbor(qc)) {
+    for (int i = 0; i < BR_CONTEXT_POSITION_NUM; ++i) {
+      const int nb_row = row - br_ref_offset[i][0];
+      const int nb_col = col - br_ref_offset[i][1];
+      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
+      if (nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->stride &&
+          nb_col < txb_info->stride && iscan[nb_coeff_idx] < eob) {
+        const int cost_diff = try_neighbor_level_down_br(
+            nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info);
+        if (cost_map)
+          cost_map[nb_row - row + COST_MAP_OFFSET]
+                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
+        accu_cost_diff += cost_diff;
+      }
+    }
+  }
+
+  return accu_cost_diff;
+}
+
 static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
                           TxbProbs *txb_probs) {
   const TXB_CTX *txb_ctx = txb_info->txb_ctx;