Refine the prediction mode reused in AB partitions

This commit refines how the prediction mode is reused by properly taking
reference frames into account: the inter-mode cache now stores a pointer
to the best MB_MODE_INFO instead of just the PREDICTION_MODE, so the
reference frames required by the cached mode are exempted from reference
frame pruning.

This gives about a 0.03% bitrate reduction on speed 3 and above with a
neutral speed change.

STATS_CHANGED

Change-Id: If53da1ef86fa2339115af892674d929e3e4304cc
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 12a4598..ba9933c 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -15,6 +15,7 @@
 #ifndef AOM_AV1_ENCODER_BLOCK_H_
 #define AOM_AV1_ENCODER_BLOCK_H_
 
+#include "av1/common/blockd.h"
 #include "av1/common/entropymv.h"
 #include "av1/common/entropy.h"
 #include "av1/common/enums.h"
@@ -1084,7 +1085,7 @@
   /*! \brief Whether to reuse the mode stored in intermode_cache. */
   int use_intermode_cache;
   /*! \brief The mode to reuse during \ref av1_rd_pick_inter_mode_sb. */
-  PREDICTION_MODE intermode_cache;
+  const MB_MODE_INFO *intermode_cache;
   /**@}*/
 
   /*****************************************************************************
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index a213738..c5e8368 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -2251,7 +2251,7 @@
                                PARTITION_TYPE partition,
                                const BLOCK_SIZE ab_subsize[SUB_PARTITIONS_AB],
                                const int ab_mi_pos[SUB_PARTITIONS_AB][2],
-                               const PREDICTION_MODE *mode_cache) {
+                               const MB_MODE_INFO **mode_cache) {
   MACROBLOCK *const x = &td->mb;
   const MACROBLOCKD *const xd = &x->e_mbd;
   const int pl = partition_plane_context(xd, mi_row, mi_col, bsize);
@@ -2261,7 +2261,7 @@
   sum_rdc.rdcost = RDCOST(x->rdmult, sum_rdc.rate, 0);
   // Loop over sub-partitions in AB partition type.
   for (int i = 0; i < SUB_PARTITIONS_AB; i++) {
-    if (mode_cache && mode_cache[i] != PRED_MODE_INVALID) {
+    if (mode_cache && mode_cache[i]) {
       x->use_intermode_cache = 1;
       x->intermode_cache = mode_cache[i];
     }
@@ -2270,7 +2270,7 @@
                         ab_mi_pos[i][0], ab_mi_pos[i][1], ab_subsize[i],
                         *best_rdc, &sum_rdc, partition, ctxs[i]);
     x->use_intermode_cache = 0;
-    x->intermode_cache = PRED_MODE_INVALID;
+    x->intermode_cache = NULL;
     if (!mode_search_success) {
       return false;
     }
@@ -2628,7 +2628,7 @@
     PartitionSearchState *part_search_state, RD_STATS *best_rdc,
     const BLOCK_SIZE ab_subsize[SUB_PARTITIONS_AB],
     const int ab_mi_pos[SUB_PARTITIONS_AB][2], const PARTITION_TYPE part_type,
-    const PREDICTION_MODE *mode_cache) {
+    const MB_MODE_INFO **mode_cache) {
   const AV1_COMMON *const cm = &cpi->common;
   PartitionBlkParams blk_params = part_search_state->part_blk_params;
   const int mi_row = blk_params.mi_row;
@@ -2696,25 +2696,25 @@
 }
 
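+// Store a pointer to the winning MB_MODE_INFO from the given mode context in
+// *dst_mode, or NULL if the context does not hold a valid mode.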
 static AOM_INLINE void copy_partition_mode_from_mode_context(
-    PREDICTION_MODE *dst_mode, const PICK_MODE_CONTEXT *ctx) {
+    const MB_MODE_INFO **dst_mode, const PICK_MODE_CONTEXT *ctx) {
   if (ctx && ctx->rd_stats.rate < INT_MAX) {
-    *dst_mode = ctx->mic.mode;
+    *dst_mode = &ctx->mic;
   } else {
-    *dst_mode = PRED_MODE_INVALID;
+    *dst_mode = NULL;
   }
 }
 
 static AOM_INLINE void copy_partition_mode_from_pc_tree(
-    PREDICTION_MODE *dst_mode, const PC_TREE *pc_tree) {
+    const MB_MODE_INFO **dst_mode, const PC_TREE *pc_tree) {
   if (pc_tree) {
     copy_partition_mode_from_mode_context(dst_mode, pc_tree->none);
   } else {
-    *dst_mode = PRED_MODE_INVALID;
+    *dst_mode = NULL;
   }
 }
 
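+// Gather pointers to the best MB_MODE_INFO of each sub-block of the AB
+// partition from the previously searched partitions recorded in pc_tree.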
 static AOM_INLINE void set_mode_cache_for_partition_ab(
-    PREDICTION_MODE *mode_cache, const PC_TREE *pc_tree,
+    const MB_MODE_INFO **mode_cache, const PC_TREE *pc_tree,
     AB_PART_TYPE ab_part_type) {
   switch (ab_part_type) {
     case HORZ_A:
@@ -2854,8 +2854,7 @@
 
     // Even if the contexts don't match, we can still speed up by reusing the
     // previous prediction mode.
-    PREDICTION_MODE mode_cache[3] = { PRED_MODE_INVALID, PRED_MODE_INVALID,
-                                      PRED_MODE_INVALID };
+    const MB_MODE_INFO *mode_cache[3] = { NULL, NULL, NULL };
     if (cpi->sf.inter_sf.reuse_best_prediction_for_part_ab) {
       set_mode_cache_for_partition_ab(mode_cache, pc_tree, ab_part_type);
     }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b10d910..9e3718f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3713,6 +3713,22 @@
   return 0;
 }
 
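+// Check whether `ref_frame` is used by the mode cached in `mi_cache`: a single
+// reference frame matches either of the cached references, while a compound
+// reference type must match the cached compound type. Returns 0 when the cache
+// is empty.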
+static AOM_INLINE int is_ref_frame_used_in_cache(MV_REFERENCE_FRAME ref_frame,
+                                                 const MB_MODE_INFO *mi_cache) {
+  if (!mi_cache) {
+    return 0;
+  }
+
+  if (ref_frame < REF_FRAMES) {
+    return (ref_frame == mi_cache->ref_frame[0] ||
+            ref_frame == mi_cache->ref_frame[1]);
+  }
+
+  // If we are here, then the reference frame being checked is a compound type.
+  MV_REFERENCE_FRAME cached_ref_type = av1_ref_frame_type(mi_cache->ref_frame);
+  return ref_frame == cached_ref_type;
+}
+
 // Please add/modify parameter setting in this function, making it consistent
 // and easy to read and maintain.
 static AOM_INLINE void set_params_rd_pick_inter_mode(
@@ -3744,7 +3760,8 @@
       // Skip the ref frame if the mask says skip and the ref is not used by
       // compound ref.
       if (skip_ref_frame_mask & (1 << ref_frame) &&
-          !is_ref_frame_used_by_compound_ref(ref_frame, skip_ref_frame_mask)) {
+          !is_ref_frame_used_by_compound_ref(ref_frame, skip_ref_frame_mask) &&
+          !is_ref_frame_used_in_cache(ref_frame, x->intermode_cache)) {
         continue;
       }
       assert(get_ref_frame_yv12_buf(cm, ref_frame) != NULL);
@@ -3769,7 +3786,8 @@
         continue;
       }
 
-      if (skip_ref_frame_mask & (1 << ref_frame)) {
+      if (skip_ref_frame_mask & (1 << ref_frame) &&
+          !is_ref_frame_used_in_cache(ref_frame, x->intermode_cache)) {
         continue;
       }
       // Ref mv list population is not required, when compound references are
@@ -4012,6 +4030,51 @@
     return 1;
   }
 
+  // Reuse the prediction mode in cache by skipping incompatible modes.
+  if (x->use_intermode_cache) {
+    const MB_MODE_INFO *cached_mi = x->intermode_cache;
+    const PREDICTION_MODE cached_mode = cached_mi->mode;
+    const MV_REFERENCE_FRAME *cached_frame = cached_mi->ref_frame;
+    const int cached_mode_is_single = cached_frame[1] <= INTRA_FRAME;
+
+    // If the cached mode is intra, then we just need to match the mode.
+    if (is_mode_intra(cached_mode) && mode != cached_mode) {
+      return 1;
+    }
+
+    // If the cached mode is single inter mode, then we match the mode and
+    // reference frame.
+    if (cached_mode_is_single) {
+      if (mode != cached_mode || ref_frame[0] != cached_frame[0]) {
+        return 1;
+      }
+    } else {
+      // If the cached mode is compound, then we need to consider several cases.
+      const int mode_is_single = ref_frame[1] <= INTRA_FRAME;
+      if (mode_is_single) {
+        // If the mode is single, we know the modes can't match. But we might
+        // still want to search it if compound mode depends on the current mode.
+        int skip_motion_mode_only = 0;
+        if (cached_mode == NEW_NEARMV || cached_mode == NEW_NEARESTMV) {
+          skip_motion_mode_only = (ref_frame[0] == cached_frame[0]);
+        } else if (cached_mode == NEAR_NEWMV || cached_mode == NEAREST_NEWMV) {
+          skip_motion_mode_only = (ref_frame[0] == cached_frame[1]);
+        } else if (cached_mode == NEW_NEWMV) {
+          skip_motion_mode_only = (ref_frame[0] == cached_frame[0] ||
+                                   ref_frame[0] == cached_frame[1]);
+        }
+
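+        // Returning 2 rather than 1 signals the caller to skip only the
+        // motion mode search, since the cached compound mode's NEWMV
+        // component depends on the result of this single-reference search.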
+        return 1 + skip_motion_mode_only;
+      } else {
+        // If both modes are compound, then everything must match.
+        if (mode != cached_mode || ref_frame[0] != cached_frame[0] ||
+            ref_frame[1] != cached_frame[1]) {
+          return 1;
+        }
+      }
+    }
+  }
+
   const MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
   // If no valid mode has been found so far in PARTITION_NONE when finding a
   // valid partition is required, do not skip mode.
@@ -4062,6 +4125,15 @@
         skip_ref = 0;
       }
     }
+    // If we are reusing the prediction from cache, and the current frame is
+    // required by the cache, then we cannot prune it.
+    if (is_ref_frame_used_in_cache(ref_type, x->intermode_cache)) {
+      skip_ref = 0;
+      // If the cache only needs the current reference type for compound
+      // prediction, then we can skip motion mode search.
+      skip_motion_mode = (ref_type <= ALTREF_FRAME &&
+                          x->intermode_cache->ref_frame[1] > INTRA_FRAME);
+    }
     if (skip_ref) return 1;
   }
 
@@ -4439,9 +4511,8 @@
   const MACROBLOCKD *xd = &x->e_mbd;
   const MB_MODE_INFO *mbmi = xd->mi[0];
   const int skip_ctx = av1_get_skip_txfm_context(xd);
-  const int mode_is_intra =
-      (av1_mode_defs[new_best_mode].mode < INTRA_MODE_END);
-  const int skip_txfm = mbmi->skip_txfm && !mode_is_intra;
+  const int skip_txfm =
+      mbmi->skip_txfm && !is_mode_intra(av1_mode_defs[new_best_mode].mode);
   const TxfmSearchInfo *txfm_info = &x->txfm_search_info;
 
   search_state->best_rd = new_best_rd_stats->rdcost;
@@ -5093,10 +5164,6 @@
     num_single_modes_processed += is_single_pred;
     set_ref_ptrs(cm, xd, ref_frame, second_ref_frame);
 
-    if (x->use_intermode_cache && this_mode != x->intermode_cache) {
-      continue;
-    }
-
     // Apply speed features to decide if this inter mode can be skipped
     if (skip_inter_mode(cpi, x, bsize, ref_frame_rd, midx, &sf_args)) continue;
 
diff --git a/av1/encoder/rdopt_utils.h b/av1/encoder/rdopt_utils.h
index c7c7d17..e636df8 100644
--- a/av1/encoder/rdopt_utils.h
+++ b/av1/encoder/rdopt_utils.h
@@ -683,6 +683,10 @@
                                                 const struct buf_2d *ref,
                                                 BLOCK_SIZE bs, int bd);
 
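+// Return 1 if the given mode is an intra prediction mode, relying on the intra
+// modes preceding INTRA_MODE_END in the PREDICTION_MODE enum.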
+static INLINE int is_mode_intra(PREDICTION_MODE mode) {
+  return mode < INTRA_MODE_END;
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif