Facilitate multiwinner mode processing framework for inter frames

This patch will facilitate multi-winner mode processing support
for Inter frames.

Change-Id: Icfa93bd5a607bcfe1c70a2c936094f18a3da7c21
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index da80a30..34df73b 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -35,11 +35,16 @@
 #define MC_FLOW_BSIZE_1D 16
 #define MC_FLOW_NUM_PELS (MC_FLOW_BSIZE_1D * MC_FLOW_BSIZE_1D)
 #define MAX_MC_FLOW_BLK_IN_SB (MAX_SB_SIZE / MC_FLOW_BSIZE_1D)
-#define MAX_WINNER_MODE_COUNT 3
+#define MAX_WINNER_MODE_COUNT_INTRA 3
+#define MAX_WINNER_MODE_COUNT_INTER 1
 typedef struct {
   MB_MODE_INFO mbmi;
+  RD_STATS rd_cost;
   int64_t rd;
+  int rate_y;
+  int rate_uv;
   uint8_t color_index_map[64 * 64];
+  THR_MODES mode_index;
 } WinnerModeStats;
 
 typedef struct {
@@ -235,7 +240,8 @@
   MB_MODE_INFO_EXT *mbmi_ext;
   MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame;
   // Array of mode stats for winner mode processing
-  WinnerModeStats winner_mode_stats[MAX_WINNER_MODE_COUNT];
+  WinnerModeStats winner_mode_stats[AOMMAX(MAX_WINNER_MODE_COUNT_INTRA,
+                                           MAX_WINNER_MODE_COUNT_INTER)];
   int winner_mode_count;
   int skip_block;
   int qindex;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5005047..42e72ff 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4558,36 +4558,63 @@
 }
 
 // Store best mode stats for winner mode processing
-static void store_winner_mode_stats(MACROBLOCK *x, MB_MODE_INFO *mbmi,
+static void store_winner_mode_stats(const AV1_COMMON *const cm, MACROBLOCK *x,
+                                    MB_MODE_INFO *mbmi, RD_STATS *rd_cost,
+                                    RD_STATS *rd_cost_y, RD_STATS *rd_cost_uv,
+                                    THR_MODES mode_index, uint8_t *color_map,
+                                    BLOCK_SIZE bsize, int64_t this_rd,
                                     int enable_multiwinner_mode_process,
-                                    uint8_t *color_map, BLOCK_SIZE bsize,
-                                    int64_t this_rd) {
+                                    int txfm_search_done) {
   WinnerModeStats *winner_mode_stats = x->winner_mode_stats;
   int mode_idx = 0;
+  int is_palette_mode = mbmi->palette_mode_info.palette_size[PLANE_TYPE_Y] > 0;
   // Mode stat is not required when multiwinner mode processing is disabled
   if (!enable_multiwinner_mode_process) return;
+  // TODO(any): Winner mode processing is currently not applicable for palette
+  // mode in Inter frames. Clean-up the following code, once support is added
+  if (!frame_is_intra_only(cm) && is_palette_mode) return;
 
+  const int max_winner_mode_count = frame_is_intra_only(cm)
+                                        ? MAX_WINNER_MODE_COUNT_INTRA
+                                        : MAX_WINNER_MODE_COUNT_INTER;
   assert(x->winner_mode_count >= 0 &&
-         x->winner_mode_count <= MAX_WINNER_MODE_COUNT);
+         x->winner_mode_count <= max_winner_mode_count);
 
   if (x->winner_mode_count) {
     // Find the mode which has higher rd cost than this_rd
     for (mode_idx = 0; mode_idx < x->winner_mode_count; mode_idx++)
       if (winner_mode_stats[mode_idx].rd > this_rd) break;
 
-    if (mode_idx == MAX_WINNER_MODE_COUNT) {
+    if (mode_idx == max_winner_mode_count) {
       // No mode has higher rd cost than this_rd
       return;
-    } else if (mode_idx < MAX_WINNER_MODE_COUNT - 1) {
+    } else if (mode_idx < max_winner_mode_count - 1) {
       // Create a slot for current mode and move others to the next slot
       memmove(
           &winner_mode_stats[mode_idx + 1], &winner_mode_stats[mode_idx],
-          (MAX_WINNER_MODE_COUNT - mode_idx - 1) * sizeof(*winner_mode_stats));
+          (max_winner_mode_count - mode_idx - 1) * sizeof(*winner_mode_stats));
     }
   }
   // Add a mode stat for winner mode processing
   winner_mode_stats[mode_idx].mbmi = *mbmi;
   winner_mode_stats[mode_idx].rd = this_rd;
+  winner_mode_stats[mode_idx].mode_index = mode_index;
+
+  // Update rd stats required for inter frame
+  if (!frame_is_intra_only(cm) && rd_cost && rd_cost_y && rd_cost_uv) {
+    const MACROBLOCKD *xd = &x->e_mbd;
+    const int skip_ctx = av1_get_skip_context(xd);
+    const int is_intra_mode = av1_mode_defs[mode_index].mode < INTRA_MODE_END;
+    const int skip = mbmi->skip && !is_intra_mode;
+
+    winner_mode_stats[mode_idx].rd_cost = *rd_cost;
+    if (txfm_search_done) {
+      winner_mode_stats[mode_idx].rate_y =
+          rd_cost_y->rate + x->skip_cost[skip_ctx][rd_cost->skip || skip];
+      winner_mode_stats[mode_idx].rate_uv = rd_cost_uv->rate;
+    }
+  }
+
   if (color_map) {
     // Store color_index_map for palette mode
     const MACROBLOCKD *const xd = &x->e_mbd;
@@ -4599,7 +4626,7 @@
   }
 
   x->winner_mode_count =
-      AOMMIN(x->winner_mode_count + 1, MAX_WINNER_MODE_COUNT);
+      AOMMIN(x->winner_mode_count + 1, max_winner_mode_count);
 }
 
 // Given the base colors as specified in centroids[], calculate the RD cost
@@ -4652,8 +4679,10 @@
     tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
   }
   // Collect mode stats for multiwinner mode processing
-  store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
-                          color_map, bsize, this_rd);
+  const int txfm_search_done = 1;
+  store_winner_mode_stats(
+      &cpi->common, x, mbmi, NULL, NULL, NULL, THR_DC, color_map, bsize,
+      this_rd, cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
   if (this_rd < *best_rd) {
     *best_rd = this_rd;
     // Setting beat_best_rd flag because current mode rd is better than best_rd.
@@ -4917,8 +4946,10 @@
     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
 
     // Collect mode stats for multiwinner mode processing
-    store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
-                            NULL, bsize, this_rd);
+    const int txfm_search_done = 1;
+    store_winner_mode_stats(
+        &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
+        cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
     if (this_rd < *best_rd) {
       *best_rd = this_rd;
       best_tx_size = mbmi->tx_size;
@@ -5298,10 +5329,13 @@
   set_mode_eval_params(cpi, x, MODE_EVAL);
 
   MB_MODE_INFO best_mbmi = *mbmi;
+  av1_zero(x->winner_mode_stats);
   x->winner_mode_count = 0;
   // Initialize best mode stats for winner mode processing
-  store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
-                          NULL, bsize, best_rd);
+  const int txfm_search_done = 1;
+  store_winner_mode_stats(
+      &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, best_rd,
+      cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
   /* Y Search for intra prediction mode */
   for (int mode_idx = INTRA_MODE_START; mode_idx < INTRA_MODE_END; ++mode_idx) {
     RD_STATS this_rd_stats;
@@ -5350,8 +5384,9 @@
         intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
     this_rd = RDCOST(x->rdmult, this_rate, this_distortion);
     // Collect mode stats for multiwinner mode processing
-    store_winner_mode_stats(x, mbmi, cpi->sf.enable_multiwinner_mode_process,
-                            NULL, bsize, this_rd);
+    store_winner_mode_stats(
+        &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
+        cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
     if (this_rd < best_rd) {
       best_mbmi = *mbmi;
       best_rd = this_rd;
@@ -11118,6 +11153,13 @@
             }
           }
           if (skip) {
+            const THR_MODES mode_enum = get_prediction_mode_idx(
+                best_mbmi.mode, best_mbmi.ref_frame[0], best_mbmi.ref_frame[1]);
+            // Collect mode stats for multiwinner mode processing
+            store_winner_mode_stats(
+                &cpi->common, x, &best_mbmi, &best_rd_stats, &best_rd_stats_y,
+                &best_rd_stats_uv, mode_enum, NULL, bsize, best_rd,
+                cpi->sf.enable_multiwinner_mode_process, do_tx_search);
             args->modelled_rd[this_mode][ref_mv_idx][refs[0]] =
                 args->modelled_rd[this_mode][i][refs[0]];
             args->simple_rd[this_mode][ref_mv_idx][refs[0]] =
@@ -11278,6 +11320,13 @@
     if (ret_val != INT64_MAX) {
       int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
       mode_info[ref_mv_idx].rd = tmp_rd;
+      const THR_MODES mode_enum = get_prediction_mode_idx(
+          mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+      // Collect mode stats for multiwinner mode processing
+      store_winner_mode_stats(&cpi->common, x, mbmi, rd_stats, rd_stats_y,
+                              rd_stats_uv, mode_enum, NULL, bsize, tmp_rd,
+                              cpi->sf.enable_multiwinner_mode_process,
+                              do_tx_search);
       if (tmp_rd < best_rd) {
         best_rd_stats = *rd_stats;
         best_rd_stats_y = *rd_stats_y;
@@ -11776,6 +11825,33 @@
   }
 }
 
+// Get winner mode stats of given mode index
+static AOM_INLINE MB_MODE_INFO *get_winner_mode_stats(
+    MACROBLOCK *x, MB_MODE_INFO *best_mbmode, RD_STATS *best_rd_cost,
+    int best_rate_y, int best_rate_uv, THR_MODES *best_mode_index,
+    RD_STATS **winner_rd_cost, int *winner_rate_y, int *winner_rate_uv,
+    THR_MODES *winner_mode_index, int enable_multiwinner_mode_process,
+    int mode_idx) {
+  MB_MODE_INFO *winner_mbmi;
+  if (enable_multiwinner_mode_process) {
+    assert(mode_idx >= 0 && mode_idx < x->winner_mode_count);
+    WinnerModeStats *winner_mode_stat = &x->winner_mode_stats[mode_idx];
+    winner_mbmi = &winner_mode_stat->mbmi;
+
+    *winner_rd_cost = &winner_mode_stat->rd_cost;
+    *winner_rate_y = winner_mode_stat->rate_y;
+    *winner_rate_uv = winner_mode_stat->rate_uv;
+    *winner_mode_index = winner_mode_stat->mode_index;
+  } else {
+    winner_mbmi = best_mbmode;
+    *winner_rd_cost = best_rd_cost;
+    *winner_rate_y = best_rate_y;
+    *winner_rate_uv = best_rate_uv;
+    *winner_mode_index = *best_mode_index;
+  }
+  return winner_mbmi;
+}
+
 // speed feature: fast intra/inter transform type search
 // Used for speed >= 2
 // When this speed feature is on, in rd mode search, only DCT is used.
@@ -11783,24 +11859,47 @@
 // transform types and get accurate rdcost.
 static AOM_INLINE void refine_winner_mode_tx(
     const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *rd_cost, BLOCK_SIZE bsize,
-    PICK_MODE_CONTEXT *ctx, THR_MODES best_mode_index,
+    PICK_MODE_CONTEXT *ctx, THR_MODES *best_mode_index,
     MB_MODE_INFO *best_mbmode, struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE],
-    int best_rate_y, int best_rate_uv, int *best_skip2) {
+    int best_rate_y, int best_rate_uv, int *best_skip2, int winner_mode_count) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
+  int64_t best_rd;
   const int num_planes = av1_num_planes(cm);
 
-  if (is_winner_mode_processing_enabled(cpi, mbmi, best_mbmode->mode)) {
-    // Set params for winner mode evaluation
-    set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
+  if (!is_winner_mode_processing_enabled(cpi, best_mbmode, best_mbmode->mode))
+    return;
 
-    if (xd->lossless[mbmi->segment_id] == 0 && best_mode_index != THR_INVALID) {
+  // Set params for winner mode evaluation
+  set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
+
+  // No best mode identified so far
+  if (*best_mode_index == THR_INVALID) return;
+
+  best_rd = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
+  for (int mode_idx = 0; mode_idx < winner_mode_count; mode_idx++) {
+    RD_STATS *winner_rd_stats = NULL;
+    int winner_rate_y = 0, winner_rate_uv = 0;
+    THR_MODES winner_mode_index = 0;
+
+    // TODO(any): Combine best mode and multi-winner mode processing paths
+    // Get winner mode stats for current mode index
+    MB_MODE_INFO *winner_mbmi = get_winner_mode_stats(
+        x, best_mbmode, rd_cost, best_rate_y, best_rate_uv, best_mode_index,
+        &winner_rd_stats, &winner_rate_y, &winner_rate_uv, &winner_mode_index,
+        cpi->sf.enable_multiwinner_mode_process, mode_idx);
+
+    if (xd->lossless[winner_mbmi->segment_id] == 0 &&
+        winner_mode_index != THR_INVALID &&
+        is_winner_mode_processing_enabled(cpi, winner_mbmi,
+                                          winner_mbmi->mode)) {
+      RD_STATS rd_stats = *winner_rd_stats;
       int skip_blk = 0;
       RD_STATS rd_stats_y, rd_stats_uv;
       const int skip_ctx = av1_get_skip_context(xd);
 
-      *mbmi = *best_mbmode;
+      *mbmi = *winner_mbmi;
 
       set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
 
@@ -11855,18 +11954,20 @@
         skip_blk = 0;
         rd_stats_y.rate += x->skip_cost[skip_ctx][0];
       }
-
-      if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
-          RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
-                 (rd_stats_y.dist + rd_stats_uv.dist))) {
-        best_mbmode->tx_size = mbmi->tx_size;
-        av1_copy(best_mbmode->inter_tx_size, mbmi->inter_tx_size);
+      int this_rate = rd_stats.rate + rd_stats_y.rate + rd_stats_uv.rate -
+                      winner_rate_y - winner_rate_uv;
+      int64_t this_rd =
+          RDCOST(x->rdmult, this_rate, (rd_stats_y.dist + rd_stats_uv.dist));
+      if (best_rd > this_rd) {
+        *best_mbmode = *mbmi;
+        *best_mode_index = winner_mode_index;
         av1_copy_array(ctx->blk_skip, x->blk_skip, ctx->num_4x4_blk);
         av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
-        rd_cost->rate +=
-            (rd_stats_y.rate + rd_stats_uv.rate - best_rate_y - best_rate_uv);
+        rd_cost->rate = this_rate;
         rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
-        rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
+        rd_cost->sse = rd_stats_y.sse + rd_stats_uv.sse;
+        rd_cost->rdcost = this_rd;
+        best_rd = this_rd;
         *best_skip2 = skip_blk;
       }
     }
@@ -13124,9 +13225,14 @@
 
     if (ret_value != INT64_MAX) {
       rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
+      const THR_MODES mode_enum = get_prediction_mode_idx(
+          mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+      // Collect mode stats for multiwinner mode processing
+      store_winner_mode_stats(
+          &cpi->common, x, mbmi, &rd_stats, &rd_stats_y, &rd_stats_uv,
+          mode_enum, NULL, bsize, rd_stats.rdcost,
+          cpi->sf.enable_multiwinner_mode_process, do_tx_search);
       if (rd_stats.rdcost < search_state->best_rd) {
-        const THR_MODES mode_enum = get_prediction_mode_idx(
-            mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
         update_search_state(search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
                             &rd_stats_uv, mode_enum, x, do_tx_search);
       }
@@ -13275,6 +13381,13 @@
       find_last_single_ref_mode_idx(av1_default_mode_order);
   int prune_cpd_using_sr_stats_ready = 0;
 
+  // Initialize best mode stats for winner mode processing
+  av1_zero(x->winner_mode_stats);
+  x->winner_mode_count = 0;
+  store_winner_mode_stats(&cpi->common, x, mbmi, NULL, NULL, NULL, THR_INVALID,
+                          NULL, bsize, best_rd_so_far,
+                          cpi->sf.enable_multiwinner_mode_process, 0);
+
   // Here midx is just an interator index that should not be used by itself
   // except to keep track of the number of modes searched. It should be used
   // with av1_default_mode_order to get the enum that defines the mode, which
@@ -13499,6 +13612,11 @@
     inter_modes_info_sort(inter_modes_info, inter_modes_info->rd_idx_pair_arr);
     search_state.best_rd = best_rd_so_far;
     search_state.best_mode_index = THR_INVALID;
+    // Initialize best mode stats for winner mode processing
+    x->winner_mode_count = 0;
+    store_winner_mode_stats(
+        &cpi->common, x, mbmi, NULL, NULL, NULL, THR_INVALID, NULL, bsize,
+        best_rd_so_far, cpi->sf.enable_multiwinner_mode_process, do_tx_search);
     inter_modes_info->num =
         inter_modes_info->num < cpi->sf.num_inter_modes_for_tx_search
             ? inter_modes_info->num
@@ -13544,13 +13662,19 @@
                                  x->skip_cost[skip_ctx][mbmi->skip]);
       }
       rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
+      // TODO(chiyotsai@google.com): get_prediction_mode_idx gives incorrect
+      // output once we change the mode order. Fix this!
+      const THR_MODES mode_enum = get_prediction_mode_idx(
+          mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+
+      // Collect mode stats for multiwinner mode processing
+      const int txfm_search_done = 1;
+      store_winner_mode_stats(
+          &cpi->common, x, mbmi, &rd_stats, &rd_stats_y, &rd_stats_uv,
+          mode_enum, NULL, bsize, rd_stats.rdcost,
+          cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
 
       if (rd_stats.rdcost < search_state.best_rd) {
-        // TODO(chiyotsai@google.com): get_prediction_mode_idx gives incorrect
-        // output once we change the mode order. Fix this!
-        const THR_MODES mode_enum = get_prediction_mode_idx(
-            mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
-        const int txfm_search_done = 1;
         update_search_state(&search_state, rd_cost, ctx, &rd_stats, &rd_stats_y,
                             &rd_stats_uv, mode_enum, x, txfm_search_done);
       }
@@ -13628,8 +13752,15 @@
     intra_rd_stats.rdcost = handle_intra_mode(
         &search_state, cpi, x, bsize, intra_ref_frame_cost, ctx, 0,
         &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
+    // Collect mode stats for multiwinner mode processing
+    const int txfm_search_done = 1;
+    if (intra_rd_stats.rdcost != INT64_MAX) {
+      store_winner_mode_stats(
+          &cpi->common, x, mbmi, &intra_rd_stats, &intra_rd_stats_y,
+          &intra_rd_stats_uv, mode_enum, NULL, bsize, intra_rd_stats.rdcost,
+          cpi->sf.enable_multiwinner_mode_process, txfm_search_done);
+    }
     if (intra_rd_stats.rdcost < search_state.best_rd) {
-      const int txfm_search_done = 1;
       update_search_state(&search_state, rd_cost, ctx, &intra_rd_stats,
                           &intra_rd_stats_y, &intra_rd_stats_uv, mode_enum, x,
                           txfm_search_done);
@@ -13639,11 +13770,13 @@
   end_timing(cpi, handle_intra_mode_time);
 #endif
 
+  int winner_mode_count =
+      cpi->sf.enable_multiwinner_mode_process ? x->winner_mode_count : 1;
   // In effect only when fast tx search speed features are enabled.
-  refine_winner_mode_tx(cpi, x, rd_cost, bsize, ctx,
-                        search_state.best_mode_index, &search_state.best_mbmode,
-                        yv12_mb, search_state.best_rate_y,
-                        search_state.best_rate_uv, &search_state.best_skip2);
+  refine_winner_mode_tx(
+      cpi, x, rd_cost, bsize, ctx, &search_state.best_mode_index,
+      &search_state.best_mbmode, yv12_mb, search_state.best_rate_y,
+      search_state.best_rate_uv, &search_state.best_skip2, winner_mode_count);
 
   // Initialize default mode evaluation params
   set_mode_eval_params(cpi, x, DEFAULT_EVAL);