Refactor prune_ref_by_selective_ref_frame

Refactored the pruning logic for reference frames
in prune_ref_by_selective_ref_frame

Change-Id: I43fedc7bd0da9cd7a7aa58c07a87329790ff5005
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index bfeeac2..020d39f 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -383,6 +383,26 @@
   }
 }
 
+// This function prunes the mode if either of the reference frame falls in the
+// pruning list
+static INLINE int prune_ref(const MV_REFERENCE_FRAME *const ref_frame,
+                            const OrderHintInfo *const order_hint_info,
+                            const unsigned int *const ref_display_order_hint,
+                            const unsigned int frame_display_order_hint,
+                            const int *ref_frame_list) {
+  for (int i = 0; i < 2; i++) {
+    if (ref_frame[0] == ref_frame_list[i] ||
+        ref_frame[1] == ref_frame_list[i]) {
+      if (av1_encoder_get_relative_dist(
+              order_hint_info,
+              ref_display_order_hint[ref_frame_list[i] - LAST_FRAME],
+              frame_display_order_hint) < 0)
+        return 1;
+    }
+  }
+  return 0;
+}
+
 static INLINE int prune_ref_by_selective_ref_frame(
     const AV1_COMP *const cpi, const MV_REFERENCE_FRAME *const ref_frame,
     const unsigned int *const ref_display_order_hint,
@@ -395,54 +415,37 @@
     const int comp_pred = ref_frame[1] > INTRA_FRAME;
     if (sf->inter_sf.selective_ref_frame >= 2 ||
         (sf->inter_sf.selective_ref_frame == 1 && comp_pred)) {
-      if (ref_frame[0] == LAST3_FRAME || ref_frame[1] == LAST3_FRAME) {
-        if (av1_encoder_get_relative_dist(
-                order_hint_info,
-                ref_display_order_hint[LAST3_FRAME - LAST_FRAME],
-                ref_display_order_hint[GOLDEN_FRAME - LAST_FRAME]) < 0)
-          return 1;
-      }
-      if (ref_frame[0] == LAST2_FRAME || ref_frame[1] == LAST2_FRAME) {
-        if (av1_encoder_get_relative_dist(
-                order_hint_info,
-                ref_display_order_hint[LAST2_FRAME - LAST_FRAME],
-                ref_display_order_hint[GOLDEN_FRAME - LAST_FRAME]) < 0)
-          return 1;
-      }
+      const int ref_frame_list[2] = { LAST3_FRAME, LAST2_FRAME };
+      if (prune_ref(ref_frame, order_hint_info, ref_display_order_hint,
+                    ref_display_order_hint[GOLDEN_FRAME - LAST_FRAME],
+                    ref_frame_list))
+        return 1;
     }
 
     // One-sided compound is used only when all reference frames are one-sided.
     if (sf->inter_sf.selective_ref_frame >= 2 && comp_pred &&
         !cpi->all_one_sided_refs) {
       unsigned int ref_offsets[2];
+      int ref_dist[2];
       for (int i = 0; i < 2; ++i) {
         const RefCntBuffer *const buf = get_ref_frame_buf(cm, ref_frame[i]);
         assert(buf != NULL);
         ref_offsets[i] = buf->display_order_hint;
+        ref_dist[i] = av1_encoder_get_relative_dist(
+            order_hint_info, ref_offsets[i], cur_frame_display_order_hint);
       }
-      const int ref0_dist = av1_encoder_get_relative_dist(
-          order_hint_info, ref_offsets[0], cur_frame_display_order_hint);
-      const int ref1_dist = av1_encoder_get_relative_dist(
-          order_hint_info, ref_offsets[1], cur_frame_display_order_hint);
-      if ((ref0_dist <= 0 && ref1_dist <= 0) ||
-          (ref0_dist > 0 && ref1_dist > 0)) {
+
+      // If both references are in same direction
+      if ((ref_dist[0] > 0) == (ref_dist[1] > 0)) {
         return 1;
       }
     }
 
     if (sf->inter_sf.selective_ref_frame >= 3) {
-      if (ref_frame[0] == ALTREF2_FRAME || ref_frame[1] == ALTREF2_FRAME)
-        if (av1_encoder_get_relative_dist(
-                order_hint_info,
-                ref_display_order_hint[ALTREF2_FRAME - LAST_FRAME],
-                cur_frame_display_order_hint) < 0)
-          return 1;
-      if (ref_frame[0] == BWDREF_FRAME || ref_frame[1] == BWDREF_FRAME)
-        if (av1_encoder_get_relative_dist(
-                order_hint_info,
-                ref_display_order_hint[BWDREF_FRAME - LAST_FRAME],
-                cur_frame_display_order_hint) < 0)
-          return 1;
+      static const int ref_frame_list[2] = { ALTREF2_FRAME, BWDREF_FRAME };
+      if (prune_ref(ref_frame, order_hint_info, ref_display_order_hint,
+                    cur_frame_display_order_hint, ref_frame_list))
+        return 1;
     }
 
     if (sf->inter_sf.selective_ref_frame >= 4 && comp_pred) {