Facilitate skip rd based gating in compound type rd

This CL facilitates skip rd based gating in compound type rd by
computing the luma skip rd for the best mode in the mode loop.

Change-Id: I7fada3127466454d794c56f966894f6834651183
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 57710d8..1bf7b81 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -320,7 +320,7 @@
 
 typedef struct InterModeSearchState {
   int64_t best_rd;
-  int64_t best_skip_rd;
+  int64_t best_skip_rd[2];
   MB_MODE_INFO best_mbmode;
   int best_rate_y;
   int best_rate_uv;
@@ -610,7 +610,8 @@
   }
 }
 
-static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
+static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x,
+                       int64_t *sse_y) {
   const AV1_COMMON *cm = &cpi->common;
   const int num_planes = av1_num_planes(cm);
   const MACROBLOCKD *xd = &x->e_mbd;
@@ -627,6 +628,7 @@
     cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                        &sse);
     total_sse += sse;
+    if (!plane && sse_y) *sse_y = sse;
   }
   total_sse <<= 4;
   return total_sse;
@@ -1216,27 +1218,6 @@
   }
 }
 
-static INLINE int check_txfm_eval(MACROBLOCK *const x, BLOCK_SIZE bsize,
-                                  int64_t best_skip_rd, int64_t skip_rd,
-                                  int level) {
-  int eval_txfm = 1;
-  // Derive aggressiveness factor for gating the transform search
-  // Lower value indicates more aggresiveness. Be more conservative (high value)
-  // for (i) low quantizers (ii) regions where prediction is poor
-  const int scale[3] = { INT_MAX, 3, 2 };
-  int aggr_factor =
-      AOMMAX(1, ((MAXQ - x->qindex) * 2 + QINDEX_RANGE / 2) >> QINDEX_BITS);
-  if (best_skip_rd >
-      (x->source_variance << (num_pels_log2_lookup[bsize] + RDDIV_BITS)))
-    aggr_factor *= scale[level];
-
-  int64_t rd_thresh = (best_skip_rd == INT64_MAX)
-                          ? best_skip_rd
-                          : (int64_t)(best_skip_rd * aggr_factor);
-  if (skip_rd > rd_thresh) eval_txfm = 0;
-  return eval_txfm;
-}
-
 // TODO(afergs): Refactor the MBMI references in here - there's four
 // TODO(afergs): Refactor optional args - add them to a struct or remove
 static int64_t motion_mode_rd(
@@ -1490,11 +1471,15 @@
 
     if (!do_tx_search) {
       int64_t curr_sse = -1;
+      int64_t sse_y = -1;
       int est_residue_cost = 0;
       int64_t est_dist = 0;
       int64_t est_rd = 0;
       if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
-        curr_sse = get_sse(cpi, x);
+        curr_sse = get_sse(cpi, x, &sse_y);
+        // Scale luma SSE as per bit depth so as to be consistent with
+        // model_rd_sb_fn and compound type rd
+        sse_y = ROUND_POWER_OF_TWO(sse_y, (xd->bd - 8) * 2);
         const int has_est_rd = get_est_rate_dist(tile_data, bsize, curr_sse,
                                                  &est_residue_cost, &est_dist);
         (void)has_est_rd;
@@ -1504,6 +1489,7 @@
         model_rd_sb_fn[MODELRD_TYPE_MOTION_MODE_RD](
             cpi, bsize, x, xd, 0, num_planes - 1, &est_residue_cost, &est_dist,
             NULL, &curr_sse, NULL, NULL, NULL);
+        sse_y = x->pred_sse[xd->mi[0]->ref_frame[0]];
       }
       est_rd = RDCOST(x->rdmult, rd_stats->rate + est_residue_cost, est_dist);
       if (est_rd * 0.80 > *best_est_rd) {
@@ -1514,7 +1500,14 @@
       rd_stats->rate += est_residue_cost;
       rd_stats->dist = est_dist;
       rd_stats->rdcost = est_rd;
-      *best_est_rd = AOMMIN(*best_est_rd, rd_stats->rdcost);
+      if (rd_stats->rdcost < *best_est_rd) {
+        *best_est_rd = rd_stats->rdcost;
+        int64_t skip_rdy = INT64_MAX;
+        if (cpi->sf.inter_sf.txfm_rd_gate_level) {
+          skip_rdy = RDCOST(x->rdmult, mode_rate, sse_y);
+        }
+        ref_skip_rd[1] = skip_rdy;
+      }
       if (cm->current_frame.reference_mode == SINGLE_REFERENCE) {
         if (!is_comp_pred) {
           assert(curr_sse >= 0);
@@ -1531,11 +1524,17 @@
       mbmi->skip = 0;
     } else {
       int64_t skip_rd = INT64_MAX;
+      int64_t skip_rdy = INT64_MAX;
       if (cpi->sf.inter_sf.txfm_rd_gate_level) {
         // Check if the mode is good enough based on skip RD
-        int64_t curr_sse = get_sse(cpi, x);
+        int64_t sse_y = INT64_MAX;
+        int64_t curr_sse = get_sse(cpi, x, &sse_y);
+        // Scale luma SSE as per bit depth so as to be consistent with
+        // model_rd_sb_fn and compound type rd
+        sse_y = ROUND_POWER_OF_TWO(sse_y, (xd->bd - 8) * 2);
         skip_rd = RDCOST(x->rdmult, rd_stats->rate, curr_sse);
-        int eval_txfm = check_txfm_eval(x, bsize, *ref_skip_rd, skip_rd,
+        skip_rdy = RDCOST(x->rdmult, rd_stats->rate, sse_y);
+        int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd[0], skip_rd,
                                         cpi->sf.inter_sf.txfm_rd_gate_level);
         if (!eval_txfm) continue;
       }
@@ -1555,7 +1554,8 @@
       const int64_t curr_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
       if (curr_rd < ref_best_rd) {
         ref_best_rd = curr_rd;
-        *ref_skip_rd = skip_rd;
+        ref_skip_rd[0] = skip_rd;
+        ref_skip_rd[1] = skip_rdy;
       }
       *disable_skip = 0;
       if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
@@ -3537,7 +3537,8 @@
   init_intra_mode_search_state(&search_state->intra_search_state);
 
   search_state->best_rd = best_rd_so_far;
-  search_state->best_skip_rd = INT64_MAX;
+  search_state->best_skip_rd[0] = INT64_MAX;
+  search_state->best_skip_rd[1] = INT64_MAX;
 
   av1_zero(search_state->best_mbmode);
 
@@ -4132,10 +4133,11 @@
       if (is_comp_pred) xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
     }
 
-    int64_t skip_rd = search_state->best_skip_rd;
+    int64_t skip_rd[2] = { search_state->best_skip_rd[0],
+                           search_state->best_skip_rd[1] };
     int64_t ret_value = motion_mode_rd(
         cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
-        &disable_skip, args, search_state->best_rd, &skip_rd, &rate_mv,
+        &disable_skip, args, search_state->best_rd, skip_rd, &rate_mv,
         &orig_dst, best_est_rd, do_tx_search, inter_modes_info, 1);
 
     if (ret_value != INT64_MAX) {
@@ -4150,7 +4152,7 @@
       if (rd_stats.rdcost < search_state->best_rd) {
         update_search_state(search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
                             &rd_stats_uv, mode_enum, x, do_tx_search);
-        if (do_tx_search) search_state->best_skip_rd = skip_rd;
+        if (do_tx_search) search_state->best_skip_rd[0] = skip_rd[0];
       }
     }
   }
@@ -4464,12 +4466,13 @@
       args.simple_rd_state = x->simple_rd_state[mode_enum];
     }
 
-    int64_t skip_rd = search_state.best_skip_rd;
+    int64_t skip_rd[2] = { search_state.best_skip_rd[0],
+                           search_state.best_skip_rd[1] };
     int64_t this_rd = handle_inter_mode(
         cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
         &disable_skip, &args, ref_best_rd, tmp_buf, &x->comp_rd_buffer,
         &best_est_rd, do_tx_search, inter_modes_info, &motion_mode_cand,
-        &skip_rd, &inter_cost_info_from_tpl);
+        skip_rd, &inter_cost_info_from_tpl);
 
     if (sf->inter_sf.prune_comp_search_by_single_result > 0 &&
         is_inter_singleref_mode(this_mode) && args.single_ref_first_pass) {
@@ -4495,7 +4498,8 @@
       search_state.best_pred_sse = x->pred_sse[ref_frame];
       update_search_state(&search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
                           &rd_stats_uv, mode_enum, x, do_tx_search);
-      if (do_tx_search) search_state.best_skip_rd = skip_rd;
+      if (do_tx_search) search_state.best_skip_rd[0] = skip_rd[0];
+      search_state.best_skip_rd[1] = skip_rd[1];
     }
     if (cpi->sf.winner_mode_sf.motion_mode_for_winner_cand) {
       const int num_motion_mode_cand =
@@ -4628,7 +4632,7 @@
         int64_t curr_sse = inter_modes_info->sse_arr[data_idx];
         skip_rd = RDCOST(x->rdmult, mode_rate, curr_sse);
         int eval_txfm =
-            check_txfm_eval(x, bsize, search_state.best_skip_rd, skip_rd,
+            check_txfm_eval(x, bsize, search_state.best_skip_rd[0], skip_rd,
                             cpi->sf.inter_sf.txfm_rd_gate_level);
         if (!eval_txfm) continue;
       }
@@ -4658,7 +4662,7 @@
       if (rd_stats.rdcost < search_state.best_rd) {
         update_search_state(&search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
                             &rd_stats_uv, mode_enum, x, txfm_search_done);
-        search_state.best_skip_rd = skip_rd;
+        search_state.best_skip_rd[0] = skip_rd;
       }
     }
   }
diff --git a/av1/encoder/rdopt_utils.h b/av1/encoder/rdopt_utils.h
index b7fcb87..4a4344e 100644
--- a/av1/encoder/rdopt_utils.h
+++ b/av1/encoder/rdopt_utils.h
@@ -327,6 +327,27 @@
   return num_blk;
 }
 
+static INLINE int check_txfm_eval(MACROBLOCK *const x, BLOCK_SIZE bsize,
+                                  int64_t best_skip_rd, int64_t skip_rd,
+                                  int level) {
+  int eval_txfm = 1;
+  // Derive aggressiveness factor for gating the transform search
+  // Lower value indicates more aggresiveness. Be more conservative (high value)
+  // for (i) low quantizers (ii) regions where prediction is poor
+  const int scale[3] = { INT_MAX, 3, 2 };
+  int aggr_factor =
+      AOMMAX(1, ((MAXQ - x->qindex) * 2 + QINDEX_RANGE / 2) >> QINDEX_BITS);
+  if (best_skip_rd >
+      (x->source_variance << (num_pels_log2_lookup[bsize] + RDDIV_BITS)))
+    aggr_factor *= scale[level];
+
+  int64_t rd_thresh = (best_skip_rd == INT64_MAX)
+                          ? best_skip_rd
+                          : (int64_t)(best_skip_rd * aggr_factor);
+  if (skip_rd > rd_thresh) eval_txfm = 0;
+  return eval_txfm;
+}
+
 static TX_MODE select_tx_mode(
     const AV1_COMP *cpi, const TX_SIZE_SEARCH_METHOD tx_size_search_method) {
   if (cpi->common.coded_lossless) return ONLY_4X4;