Prune interpolation filter search

Prune interpolation filter search based on previous frames'
filter selection stats. Calculate probabilities of each filter,
and prune the search at prediction block level. Currently, this
feature is turned on at speed 5.

Borg test(150 frames) results:
       avg_psnr:  ovr_psnr:  ssim:   avg speedups over whole set:
lowres: 0.153     0.163      0.149         1.6%
midres: 0.145     0.158      0.070         1.8%

STATS_CHANGED

Change-Id: I21edcc76ac30b2fd2eec8f5efd7edeb2c7119cbf
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a685152..fc9b4f6 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5507,6 +5507,33 @@
     cpi->warped_probs[update_type] =
         (cpi->warped_probs[update_type] + new_prob) >> 1;
   }
+
+  if (cm->current_frame.frame_type != KEY_FRAME &&
+      cpi->sf.interp_sf.adaptive_interp_filter_search == 2 &&
+      cm->interp_filter == SWITCHABLE) {
+    const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
+
+    for (i = 0; i < SWITCHABLE_FILTER_CONTEXTS; i++) {
+      int sum = 0;
+      int j;
+      int left = 1536;
+
+      for (j = 0; j < SWITCHABLE_FILTERS; j++) {
+        sum += cpi->td.counts->switchable_interp[i][j];
+      }
+
+      for (j = SWITCHABLE_FILTERS - 1; j >= 0; j--) {
+        const int new_prob =
+            sum ? 1536 * cpi->td.counts->switchable_interp[i][j] / sum
+                : (j ? 0 : 1536);
+        int prob =
+            (cpi->switchable_interp_probs[update_type][i][j] + new_prob) >> 1;
+        left -= prob;
+        if (j == 0) prob += left;
+        cpi->switchable_interp_probs[update_type][i][j] = prob;
+      }
+    }
+  }
 }
 
 #define CHECK_PRECOMPUTED_REF_FRAME_MAP 0
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index a5373eb..9e85e2b 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -251,6 +251,125 @@
 const int default_warped_probs[FRAME_UPDATE_TYPES] = { 64, 64, 64, 64,
                                                        64, 64, 64 };
 
+// TODO(yunqing): the default probs can be trained later from better
+// performance.
+const int default_switchable_interp_probs[FRAME_UPDATE_TYPES]
+                                         [SWITCHABLE_FILTER_CONTEXTS]
+                                         [SWITCHABLE_FILTERS] = {
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } },
+                                           { { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 },
+                                             { 512, 512, 512 } }
+                                         };
+
 static INLINE void Scale2Ratio(AOM_SCALING mode, int *hr, int *hs) {
   switch (mode) {
     case NORMAL:
@@ -5030,7 +5149,7 @@
       cm->current_frame.frame_type == KEY_FRAME) {
     av1_copy(cpi->tx_type_probs, default_tx_type_probs);
 
-    int thr[2][2] = { { 15, 10 }, { 17, 10 } };
+    const int thr[2][2] = { { 15, 10 }, { 17, 10 } };
     for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
       int kf_arf_update = (f == KF_UPDATE || f == ARF_UPDATE);
       cpi->tx_type_probs_thresh[f] =
@@ -5049,6 +5168,16 @@
     av1_copy(cpi->warped_probs, default_warped_probs);
   }
 
+  if (cpi->sf.interp_sf.adaptive_interp_filter_search == 2 &&
+      cm->current_frame.frame_type == KEY_FRAME) {
+    av1_copy(cpi->switchable_interp_probs, default_switchable_interp_probs);
+
+    const int thr[7] = { 0, 8, 8, 8, 8, 0, 8 };
+    for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
+      cpi->switchable_interp_thresh[f] = thr[f];
+    }
+  }
+
   // Loop variables
   int loop_count = 0;
   int loop_at_this_size = 0;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index b2b8638..997c673 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -990,6 +990,9 @@
   int warped_probs[FRAME_UPDATE_TYPES];
   int tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
   int tx_type_probs_thresh[FRAME_UPDATE_TYPES];
+  int switchable_interp_probs[FRAME_UPDATE_TYPES][SWITCHABLE_FILTER_CONTEXTS]
+                             [SWITCHABLE_FILTERS];
+  int switchable_interp_thresh[FRAME_UPDATE_TYPES];
 
   // Multi-threading
   int num_workers;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 170776d..754e860 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8257,6 +8257,29 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
 
+  uint16_t interp_filter_search_mask = cpi->interp_filter_search_mask;
+
+  if (cpi->sf.interp_sf.adaptive_interp_filter_search == 2) {
+    const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
+    const int ctx0 = av1_get_pred_context_switchable_interp(xd, 0);
+    const int ctx1 = av1_get_pred_context_switchable_interp(xd, 1);
+    const int *switchable_interp_p0 =
+        cpi->switchable_interp_probs[update_type][ctx0];
+    const int *switchable_interp_p1 =
+        cpi->switchable_interp_probs[update_type][ctx1];
+
+    const int thresh = cpi->switchable_interp_thresh[update_type];
+    for (i = 0; i < SWITCHABLE_FILTERS; i++) {
+      // For non-dual case, the 2 dir's prob should be identical.
+      assert(switchable_interp_p0[i] == switchable_interp_p1[i]);
+      if (switchable_interp_p0[i] < thresh &&
+          switchable_interp_p1[i] < thresh) {
+        DUAL_FILTER_TYPE filt_type = i + SWITCHABLE_FILTERS * i;
+        reset_interp_filter_allowed_mask(&interp_filter_search_mask, filt_type);
+      }
+    }
+  }
+
   // Regular filter evaluation should have been done and hence the same should
   // be the winner
   assert(x->e_mbd.mi[0]->interp_filters.as_int == filter_sets[0].as_int);
@@ -8273,7 +8296,7 @@
       assert(filter_sets[filter_idx].as_filters.x_filter ==
              filter_sets[filter_idx].as_filters.y_filter);
       if (cpi->sf.interp_sf.adaptive_interp_filter_search &&
-          !(get_interp_filter_allowed_mask(cpi->interp_filter_search_mask,
+          !(get_interp_filter_allowed_mask(interp_filter_search_mask,
                                            filter_idx))) {
         return;
       }
@@ -8308,7 +8331,7 @@
     set_interp_filter_allowed_mask(&allowed_interp_mask, SHARP_SHARP);
     set_interp_filter_allowed_mask(&allowed_interp_mask, SMOOTH_SMOOTH);
     if (cpi->sf.interp_sf.adaptive_interp_filter_search)
-      allowed_interp_mask &= cpi->interp_filter_search_mask;
+      allowed_interp_mask &= interp_filter_search_mask;
 
     find_best_interp_rd_facade(x, cpi, tile_data, bsize, orig_dst, rd,
                                rd_stats_y, rd_stats, switchable_rate, dst_bufs,
@@ -8322,8 +8345,7 @@
       assert(filter_sets[i].as_filters.x_filter ==
              filter_sets[i].as_filters.y_filter);
       if (cpi->sf.interp_sf.adaptive_interp_filter_search &&
-          !(get_interp_filter_allowed_mask(cpi->interp_filter_search_mask,
-                                           i))) {
+          !(get_interp_filter_allowed_mask(interp_filter_search_mask, i))) {
         continue;
       }
       interpolation_filter_rd(x, cpi, tile_data, bsize, orig_dst, rd,
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index e54f42c..741ea4c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -539,6 +539,8 @@
     sf->inter_sf.disable_smooth_interintra = 1;
 
     sf->lpf_sf.disable_lr_filter = 1;
+
+    sf->interp_sf.adaptive_interp_filter_search = 2;
   }
 }