Modify pred_filter_search function

Change-Id: Ie6f246f622ddfd2c255f0cfa14ec18c49c3a2d44
diff --git a/av1/common/filter.h b/av1/common/filter.h
index 98c0846..d6a86a7 100644
--- a/av1/common/filter.h
+++ b/av1/common/filter.h
@@ -52,6 +52,12 @@
   INTERP_SKIP_LUMA_SKIP_CHROMA,
 } UENUM1BYTE(INTERP_EVAL_PLANE);
 
+enum {
+  INTERP_HORZ_NEQ_VERT_NEQ = 0,
+  INTERP_HORZ_EQ_VERT_NEQ,
+  INTERP_HORZ_NEQ_VERT_EQ,
+  INTERP_HORZ_EQ_VERT_EQ,
+} UENUM1BYTE(INTERP_PRED_TYPE);
 // Pack two InterpFilter's into a uint32_t: since there are at most 10 filters,
 // we can use 16 bits for each and have more than enough space. This reduces
 // argument passing and unifies the operation of setting a (pair of) filters.
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9184a46..9b8ce17 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8253,22 +8253,39 @@
   return 0;
 }
 
-static INLINE int is_pred_filter_search_allowed(
-    const AV1_COMP *const cpi, BLOCK_SIZE bsize, int mi_row, int mi_col,
-    InterpFilter af_horiz, InterpFilter af_vert, InterpFilter lf_horiz,
-    InterpFilter lf_vert) {
+static INLINE INTERP_PRED_TYPE is_pred_filter_search_allowed(
+    const AV1_COMP *const cpi, MACROBLOCKD *xd, BLOCK_SIZE bsize, int mi_row,
+    int mi_col, int_interpfilters *af, int_interpfilters *lf) {
   const AV1_COMMON *cm = &cpi->common;
+  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
+  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
   const int bsl = mi_size_wide_log2[bsize];
-  int pred_filter_search =
+  int is_horiz_eq = 0, is_vert_eq = 0;
+
+  if (above_mbmi && is_inter_block(above_mbmi))
+    *af = above_mbmi->interp_filters;
+
+  if (left_mbmi && is_inter_block(left_mbmi)) *lf = left_mbmi->interp_filters;
+
+  if (af->as_filters.x_filter != INTERP_INVALID)
+    is_horiz_eq = af->as_filters.x_filter == lf->as_filters.x_filter;
+  if (af->as_filters.y_filter != INTERP_INVALID)
+    is_vert_eq = af->as_filters.y_filter == lf->as_filters.y_filter;
+
+  INTERP_PRED_TYPE pred_filter_type = (is_vert_eq << 1) + is_horiz_eq;
+  int pred_filter_enable =
       cpi->sf.cb_pred_filter_search
           ? (((mi_row + mi_col) >> bsl) +
              get_chessboard_index(cm->current_frame.frame_number)) &
                 0x1
           : 0;
-  pred_filter_search &=
-      ((af_horiz == lf_horiz) && (af_horiz != INTERP_INVALID)) ||
-      ((af_vert == lf_vert) && (af_vert != INTERP_INVALID));
-  return pred_filter_search;
+  pred_filter_enable &= is_horiz_eq || is_vert_eq;
+  // pred_filter_search = 0: pred_filter is disabled
+  // pred_filter_search = 1: pred_filter is enabled and only horz pred matching
+  // pred_filter_search = 2: pred_filter is enabled and only vert pred matching
+  // pred_filter_search = 3: pred_filter is enabled and
+  //                         both vert, horz pred matching
+  return pred_filter_enable * pred_filter_type;
 }
 
 static INLINE void pred_dual_interp_filter_rd(
@@ -8276,30 +8293,30 @@
     const TileDataEnc *tile_data, BLOCK_SIZE bsize, int mi_row, int mi_col,
     const BUFFER_SET *const orig_dst, int64_t *const rd, RD_STATS *rd_stats_y,
     RD_STATS *rd_stats, int *const switchable_rate,
-    const BUFFER_SET *dst_bufs[2], int filter_idx, const int switchable_ctx[2],
-    const int skip_pred, InterpFilter af_horiz, InterpFilter af_vert,
-    InterpFilter lf_horiz, InterpFilter lf_vert) {
-  if ((af_horiz == lf_horiz) && (af_horiz != INTERP_INVALID)) {
-    if (((af_vert == lf_vert) && (af_vert != INTERP_INVALID))) {
-      filter_idx = af_horiz + (af_vert * SWITCHABLE_FILTERS);
+    const BUFFER_SET *dst_bufs[2], const int switchable_ctx[2],
+    const int skip_pred, INTERP_PRED_TYPE pred_filt_type, int_interpfilters *af,
+    int_interpfilters *lf) {
+  (void)lf;
+  int filter_idx = 0;
+  InterpFilter af_horiz = INTERP_INVALID, af_vert = INTERP_INVALID;
+  af_horiz = af->as_filters.x_filter;
+  af_vert = af->as_filters.y_filter;
+  assert(pred_filt_type != INTERP_HORZ_NEQ_VERT_NEQ);
+
+  // pred_filter_search = 1: pred_filter is enabled and only horz pred matching
+  if (pred_filt_type == INTERP_HORZ_EQ_VERT_NEQ) {
+    for (filter_idx = af_horiz; filter_idx < (DUAL_FILTER_SET_SIZE);
+         filter_idx += SWITCHABLE_FILTERS) {
       if (filter_idx) {
         interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
                                 orig_dst, rd, rd_stats_y, rd_stats,
                                 switchable_rate, dst_bufs, filter_idx,
                                 switchable_ctx, skip_pred);
       }
-    } else {
-      for (filter_idx = af_horiz; filter_idx < (DUAL_FILTER_SET_SIZE);
-           filter_idx += SWITCHABLE_FILTERS) {
-        if (filter_idx) {
-          interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
-                                  orig_dst, rd, rd_stats_y, rd_stats,
-                                  switchable_rate, dst_bufs, filter_idx,
-                                  switchable_ctx, skip_pred);
-        }
-      }
     }
-  } else if ((af_vert == lf_vert) && (af_vert != INTERP_INVALID)) {
+  } else if (pred_filt_type == INTERP_HORZ_NEQ_VERT_EQ) {
+    // pred_filter_search = 2: pred_filter is enabled and
+    //                         only vert pred matching
     for (filter_idx = (af_vert * SWITCHABLE_FILTERS);
          filter_idx <= ((af_vert * SWITCHABLE_FILTERS) + 2); filter_idx += 1) {
       if (filter_idx) {
@@ -8309,6 +8326,18 @@
                                 switchable_ctx, skip_pred);
       }
     }
+  } else if (pred_filt_type == INTERP_HORZ_EQ_VERT_EQ) {
+    // pred_filter_search = 3: pred_filter is enabled and
+    //                         both vert, horz pred matching
+    filter_idx = af_horiz + (af_vert * SWITCHABLE_FILTERS);
+    if (filter_idx) {
+      interpolation_filter_rd(x, cpi, tile_data, bsize, mi_row, mi_col,
+                              orig_dst, rd, rd_stats_y, rd_stats,
+                              switchable_rate, dst_bufs, filter_idx,
+                              switchable_ctx, skip_pred);
+    }
+  } else {
+    assert(0);
   }
 }
 // Evaluate dual filter type
@@ -8324,30 +8353,19 @@
     const int skip_hor, const int skip_ver) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
-  int pred_filter_search = 0;
-  InterpFilter af_horiz = INTERP_INVALID, af_vert = INTERP_INVALID,
-               lf_horiz = INTERP_INVALID, lf_vert = INTERP_INVALID;
-  if (!have_newmv_in_inter_mode(mbmi->mode)) {
-    const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-    const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-    if (above_mbmi && is_inter_block(above_mbmi)) {
-      af_horiz = above_mbmi->interp_filters.as_filters.x_filter;
-      af_vert = above_mbmi->interp_filters.as_filters.y_filter;
-    }
-    if (left_mbmi && is_inter_block(left_mbmi)) {
-      lf_horiz = left_mbmi->interp_filters.as_filters.x_filter;
-      lf_vert = left_mbmi->interp_filters.as_filters.y_filter;
-    }
-    pred_filter_search = is_pred_filter_search_allowed(
-        cpi, bsize, mi_row, mi_col, af_horiz, af_vert, lf_horiz, lf_vert);
-  }
+  INTERP_PRED_TYPE pred_filter_type = INTERP_HORZ_NEQ_VERT_NEQ;
+  int_interpfilters af = av1_broadcast_interp_filter(INTERP_INVALID);
+  int_interpfilters lf = af;
 
-  if (pred_filter_search) {
-    int filter_idx = 0;
+  if (!have_newmv_in_inter_mode(mbmi->mode))
+    pred_filter_type =
+        is_pred_filter_search_allowed(cpi, xd, bsize, mi_row, mi_col, &af, &lf);
+
+  if (pred_filter_type) {
     pred_dual_interp_filter_rd(
         x, cpi, tile_data, bsize, mi_row, mi_col, orig_dst, rd, rd_stats_y,
-        rd_stats, switchable_rate, dst_bufs, filter_idx, switchable_ctx,
-        (skip_hor & skip_ver), af_horiz, af_vert, lf_horiz, lf_vert);
+        rd_stats, switchable_rate, dst_bufs, switchable_ctx,
+        (skip_hor & skip_ver), pred_filter_type, &af, &lf);
   } else {
     const int bw = block_size_wide[bsize];
     const int bh = block_size_high[bsize];
@@ -8391,22 +8409,15 @@
   assert(x->e_mbd.mi[0]->interp_filters.as_int == filter_sets[0].as_int);
   assert(filter_set_size == DUAL_FILTER_SET_SIZE);
   if ((skip_hor & skip_ver) != cpi->default_interp_skip_flags) {
-    int pred_filter_search;
-    InterpFilter af_horiz = INTERP_INVALID, lf_horiz = INTERP_INVALID;
-    int filter_idx;
-    const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-    const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-    if (above_mbmi && is_inter_block(above_mbmi)) {
-      af_horiz = above_mbmi->interp_filters.as_filters.x_filter;
-    }
-    if (left_mbmi && is_inter_block(left_mbmi)) {
-      lf_horiz = left_mbmi->interp_filters.as_filters.x_filter;
-    }
-    pred_filter_search = is_pred_filter_search_allowed(
-        cpi, bsize, mi_row, mi_col, af_horiz, af_horiz, lf_horiz, lf_horiz);
-    if (pred_filter_search) {
-      assert(af_horiz != INTERP_INVALID);
-      filter_idx = SWITCHABLE * af_horiz;
+    INTERP_PRED_TYPE pred_filter_type = INTERP_HORZ_NEQ_VERT_NEQ;
+    int_interpfilters af = av1_broadcast_interp_filter(INTERP_INVALID);
+    int_interpfilters lf = af;
+
+    pred_filter_type =
+        is_pred_filter_search_allowed(cpi, xd, bsize, mi_row, mi_col, &af, &lf);
+    if (pred_filter_type) {
+      assert(af.as_filters.x_filter != INTERP_INVALID);
+      int filter_idx = SWITCHABLE * af.as_filters.x_filter;
       // This assert tells that (filter_x == filter_y) for non-dual filter case
       assert(filter_sets[filter_idx].as_filters.x_filter ==
              filter_sets[filter_idx].as_filters.y_filter);