Properly count the rate cost in base range coding

Properly count the base range coefficient coding in the rate
distortion optimization and soft quantization process.

Change-Id: I860001f51c4a9d0021d08b85b8ccdb097121b287
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 451f3be..318d849 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -74,7 +74,7 @@
   int eob_cost[EOB_COEF_CONTEXTS][2];
   int dc_sign_cost[DC_SIGN_CONTEXTS][2];
   int base_cost[NUM_BASE_LEVELS][COEFF_BASE_CONTEXTS][2];
-  int lps_cost[LEVEL_CONTEXTS][2];
+  int lps_cost[LEVEL_CONTEXTS][COEFF_BASE_RANGE + 1];
 #if BR_NODE
   int br_cost[BASE_RANGE_SETS][LEVEL_CONTEXTS][2];
 #endif
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 1bd7016..21c4cd6 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -364,17 +364,15 @@
 }
 
 static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
-                              const int coeff_lps[2]) {
+                              const int coeff_lps[COEFF_BASE_RANGE + 1]) {
   const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
   const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
   (void)ctx;
   if (abs_qc >= min_level) {
-    const int cost0 = coeff_lps[0];
-    const int cost1 = coeff_lps[1];
     if (abs_qc >= max_level)
-      return COEFF_BASE_RANGE * cost0;
+      return coeff_lps[COEFF_BASE_RANGE];  // COEFF_BASE_RANGE * cost0;
     else
-      return (abs_qc - min_level) * cost0 + cost1;
+      return coeff_lps[(abs_qc - min_level)];  //  * cost0 + cost1;
   } else {
     return 0;
   }
@@ -467,29 +465,12 @@
         ctx = get_br_ctx(qcoeff, scan[c], bwl, height);
 #if BR_NODE
         int base_range = level - 1 - NUM_BASE_LEVELS;
-        int br_set_idx = base_range < COEFF_BASE_RANGE
-                             ? coeff_to_br_index[base_range]
-                             : BASE_RANGE_SETS;
-
-        for (idx = 0; idx < BASE_RANGE_SETS; ++idx) {
-          if (br_set_idx == idx) {
-            int br_base = br_index_to_coeff[br_set_idx];
-            int br_offset = base_range - br_base;
-            int extra_bits = (1 << br_extra_bits[idx]) - 1;
-            cost += coeff_costs->br_cost[idx][ctx][1];
-            for (int tok = 0; tok < extra_bits; ++tok) {
-              if (tok == br_offset) {
-                cost += coeff_costs->lps_cost[ctx][1];
-                break;
-              }
-              cost += coeff_costs->lps_cost[ctx][0];
-            }
-            //            cost += extra_bits * av1_cost_bit(128, 1);
-            break;
-          }
-          cost += coeff_costs->br_cost[idx][ctx][0];
+        if (base_range < COEFF_BASE_RANGE) {
+          cost += coeff_costs->lps_cost[ctx][base_range];
+          continue;
+        } else {
+          cost += coeff_costs->lps_cost[ctx][COEFF_BASE_RANGE];
         }
-        if (idx < BASE_RANGE_SETS) continue;
 
         idx = COEFF_BASE_RANGE;
 #else
@@ -841,18 +822,31 @@
     const int sign_cost = get_sign_bit_cost(
         qc, coeff_idx, txb_costs->dc_sign_cost, txb_info->txb_ctx->dc_sign_ctx);
     cost_diff -= sign_cost;
+  } else if (abs_qc <= NUM_BASE_LEVELS) {
+    const int *level_cost =
+        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
+    const int *low_level_cost =
+        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
+    cost_diff = -level_cost[1] + low_level_cost[1] - low_level_cost[0];
+  } else if (abs_qc == NUM_BASE_LEVELS + 1) {
+    const int *level_cost =
+        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
+    const int *low_level_cost =
+        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
+    cost_diff = -level_cost[0] + low_level_cost[1] - low_level_cost[0];
   } else if (abs_qc < 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
     const int *level_cost =
         get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
     const int *low_level_cost =
         get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
 
-    cost_diff = -level_cost[1] + low_level_cost[1] - low_level_cost[0];
+    cost_diff = -level_cost[abs_qc - 1 - NUM_BASE_LEVELS] +
+                low_level_cost[abs(*low_coeff) - 1 - NUM_BASE_LEVELS];
   } else if (abs_qc == 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
     const int *low_level_cost =
         get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
-    cost_diff =
-        -get_golomb_cost(abs_qc) + low_level_cost[1] - low_level_cost[0];
+    cost_diff = -get_golomb_cost(abs_qc) - low_level_cost[COEFF_BASE_RANGE] +
+                low_level_cost[COEFF_BASE_RANGE - 1];
   } else {
     assert(abs_qc > 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE);
     const tran_low_t abs_low_coeff = abs(*low_coeff);
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index a37eb4d..7894e4f 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -36,6 +36,9 @@
 #include "av1/encoder/encodemb.h"
 #include "av1/encoder/encodemv.h"
 #include "av1/encoder/encoder.h"
+#if CONFIG_LV_MAP
+#include "av1/encoder/encodetxb.h"
+#endif
 #include "av1/encoder/mcomp.h"
 #include "av1/encoder/ratectrl.h"
 #include "av1/encoder/rd.h"
@@ -508,9 +511,42 @@
                                    NULL);
 #endif  // BR_NODE
 
-      for (int ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx)
-        av1_cost_tokens_from_cdf(pcost->lps_cost[ctx],
+      for (int ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) {
+        int lps_rate[2];
+        av1_cost_tokens_from_cdf(lps_rate,
                                  fc->coeff_lps_cdf[tx_size][plane][ctx], NULL);
+
+        for (int base_range = 0; base_range < COEFF_BASE_RANGE + 1;
+             ++base_range) {
+          int br_set_idx = base_range < COEFF_BASE_RANGE
+                               ? coeff_to_br_index[base_range]
+                               : BASE_RANGE_SETS;
+
+          pcost->lps_cost[ctx][base_range] = 0;
+
+          for (int idx = 0; idx < BASE_RANGE_SETS; ++idx) {
+            if (idx == br_set_idx) {
+              pcost->lps_cost[ctx][base_range] += pcost->br_cost[idx][ctx][1];
+
+              int br_base = br_index_to_coeff[br_set_idx];
+              int br_offset = base_range - br_base;
+              int extra_bits = (1 << br_extra_bits[idx]) - 1;
+              for (int tok = 0; tok < extra_bits; ++tok) {
+                if (tok == br_offset) {
+                  pcost->lps_cost[ctx][base_range] += lps_rate[1];
+                  break;
+                } else {
+                  pcost->lps_cost[ctx][base_range] += lps_rate[0];
+                }
+              }
+              break;
+            } else {
+              pcost->lps_cost[ctx][base_range] += pcost->br_cost[idx][ctx][0];
+            }
+          }
+          // load the base range cost
+        }
+      }
 #else   // LV_MAP_PROB
       for (int ctx = 0; ctx < TXB_SKIP_CONTEXTS; ++ctx)
         get_rate_cost(fc->txb_skip[tx_size][ctx], pcost->txb_skip_cost[ctx]);