Facilitate ref frame pruning in gm_search

This patch facilitates the changes to introduce ref frame pruning
in global motion estimation.

Change-Id: Iceb70fe91bf1f283b17e790a4b030566590e98e0
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 89d2701..120d7d2 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4716,6 +4716,220 @@
     return INT64_MAX;
 }
 
+static void compute_global_motion_for_ref_frame(
+    AV1_COMP *cpi, YV12_BUFFER_CONFIG *ref_buf[REF_FRAMES], int frame,
+    int *num_frm_corners, int *frm_corners, unsigned char *frm_buffer,
+    MotionModel *params_by_motion, uint8_t *segment_map,
+    const int segment_map_w, const int segment_map_h,
+    const WarpedMotionParams *ref_params) {
+  ThreadData *const td = &cpi->td;
+  MACROBLOCK *const x = &td->mb;
+  AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  int i;
+  // clang-format off
+  static const double kIdentityParams[MAX_PARAMDIM - 1] = {
+     0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0
+  };
+  // clang-format on
+  WarpedMotionParams tmp_wm_params;
+  const double *params_this_motion;
+  int inliers_by_motion[RANSAC_NUM_MOTIONS];
+  assert(ref_buf[frame] != NULL);
+  if (*num_frm_corners < 0) {
+    // compute interest points using FAST features
+    *num_frm_corners = av1_fast_corner_detect(
+        frm_buffer, cpi->source->y_width, cpi->source->y_height,
+        cpi->source->y_stride, frm_corners, MAX_CORNERS);
+  }
+  TransformationType model;
+
+  aom_clear_system_state();
+
+  // TODO(sarahparker, debargha): Explore do_adaptive_gm_estimation = 1
+  const int do_adaptive_gm_estimation = 0;
+
+  const int ref_frame_dist = get_relative_dist(
+      &cm->seq_params.order_hint_info, cm->current_frame.order_hint,
+      cm->cur_frame->ref_order_hints[frame - LAST_FRAME]);
+  const GlobalMotionEstimationType gm_estimation_type =
+      cm->seq_params.order_hint_info.enable_order_hint &&
+              abs(ref_frame_dist) <= 2 && do_adaptive_gm_estimation
+          ? GLOBAL_MOTION_DISFLOW_BASED
+          : GLOBAL_MOTION_FEATURE_BASED;
+  for (model = ROTZOOM; model < GLOBAL_TRANS_TYPES_ENC; ++model) {
+    int64_t best_warp_error = INT64_MAX;
+    // Initially set all params to identity.
+    for (i = 0; i < RANSAC_NUM_MOTIONS; ++i) {
+      memcpy(params_by_motion[i].params, kIdentityParams,
+             (MAX_PARAMDIM - 1) * sizeof(*(params_by_motion[i].params)));
+      params_by_motion[i].num_inliers = 0;
+    }
+
+    av1_compute_global_motion(
+        model, frm_buffer, cpi->source->y_width, cpi->source->y_height,
+        cpi->source->y_stride, frm_corners, *num_frm_corners, ref_buf[frame],
+        cpi->common.seq_params.bit_depth, gm_estimation_type, inliers_by_motion,
+        params_by_motion, RANSAC_NUM_MOTIONS);
+    int64_t ref_frame_error = 0;
+    for (i = 0; i < RANSAC_NUM_MOTIONS; ++i) {
+      if (inliers_by_motion[i] == 0) continue;
+
+      params_this_motion = params_by_motion[i].params;
+      av1_convert_model_to_params(params_this_motion, &tmp_wm_params);
+
+      if (tmp_wm_params.wmtype != IDENTITY) {
+        av1_compute_feature_segmentation_map(
+            segment_map, segment_map_w, segment_map_h,
+            params_by_motion[i].inliers, params_by_motion[i].num_inliers);
+
+        ref_frame_error = av1_segmented_frame_error(
+            is_cur_buf_hbd(xd), xd->bd, ref_buf[frame]->y_buffer,
+            ref_buf[frame]->y_stride, cpi->source->y_buffer,
+            cpi->source->y_width, cpi->source->y_height, cpi->source->y_stride,
+            segment_map, segment_map_w);
+
+        int64_t erroradv_threshold =
+            calc_erroradv_threshold(cpi, ref_frame_error);
+
+        const int64_t warp_error = av1_refine_integerized_param(
+            &tmp_wm_params, tmp_wm_params.wmtype, is_cur_buf_hbd(xd), xd->bd,
+            ref_buf[frame]->y_buffer, ref_buf[frame]->y_width,
+            ref_buf[frame]->y_height, ref_buf[frame]->y_stride,
+            cpi->source->y_buffer, cpi->source->y_width, cpi->source->y_height,
+            cpi->source->y_stride, GM_REFINEMENT_COUNT, best_warp_error,
+            segment_map, segment_map_w, erroradv_threshold);
+
+        if (warp_error < best_warp_error) {
+          best_warp_error = warp_error;
+          // Save the wm_params modified by
+          // av1_refine_integerized_param() rather than motion index to
+          // avoid rerunning refine() below.
+          memcpy(&(cm->global_motion[frame]), &tmp_wm_params,
+                 sizeof(WarpedMotionParams));
+        }
+      }
+    }
+    if (cm->global_motion[frame].wmtype <= AFFINE)
+      if (!av1_get_shear_params(&cm->global_motion[frame]))
+        cm->global_motion[frame] = default_warp_params;
+
+    if (cm->global_motion[frame].wmtype == TRANSLATION) {
+      cm->global_motion[frame].wmmat[0] =
+          convert_to_trans_prec(cm->allow_high_precision_mv,
+                                cm->global_motion[frame].wmmat[0]) *
+          GM_TRANS_ONLY_DECODE_FACTOR;
+      cm->global_motion[frame].wmmat[1] =
+          convert_to_trans_prec(cm->allow_high_precision_mv,
+                                cm->global_motion[frame].wmmat[1]) *
+          GM_TRANS_ONLY_DECODE_FACTOR;
+    }
+
+    if (cm->global_motion[frame].wmtype == IDENTITY) continue;
+
+    if (ref_frame_error == 0) continue;
+
+    // If the best error advantage found doesn't meet the threshold for
+    // this motion type, revert to IDENTITY.
+    if (!av1_is_enough_erroradvantage(
+            (double)best_warp_error / ref_frame_error,
+            gm_get_params_cost(&cm->global_motion[frame], ref_params,
+                               cm->allow_high_precision_mv),
+            cpi->sf.gm_erroradv_type)) {
+      cm->global_motion[frame] = default_warp_params;
+    }
+
+    if (cm->global_motion[frame].wmtype != IDENTITY) break;
+  }
+
+  aom_clear_system_state();
+}
+
+typedef struct {
+  int distance;
+  MV_REFERENCE_FRAME frame;
+} FrameDistPair;
+
+static INLINE void update_valid_ref_frames_for_gm(
+    AV1_COMP *cpi, YV12_BUFFER_CONFIG *ref_buf[REF_FRAMES],
+    FrameDistPair *past_ref_frame, FrameDistPair *future_ref_frame,
+    int *num_past_ref_frames, int *num_future_ref_frames) {
+  AV1_COMMON *const cm = &cpi->common;
+  const OrderHintInfo *const order_hint_info = &cm->seq_params.order_hint_info;
+  for (int frame = ALTREF_FRAME; frame >= LAST_FRAME; --frame) {
+    const MV_REFERENCE_FRAME ref_frame[2] = { frame, NONE_FRAME };
+    RefCntBuffer *buf = get_ref_frame_buf(cm, frame);
+    const int ref_disabled =
+        !(cpi->ref_frame_flags & av1_ref_frame_flag_list[frame]);
+    ref_buf[frame] = NULL;
+    cm->global_motion[frame] = default_warp_params;
+    // Skip global motion estimation for invalid ref frames
+    if (buf == NULL ||
+        (ref_disabled && cpi->sf.recode_loop != DISALLOW_RECODE)) {
+      cpi->gmparams_cost[frame] = 0;
+      continue;
+    } else {
+      ref_buf[frame] = &buf->buf;
+    }
+
+    if (ref_buf[frame]->y_crop_width == cpi->source->y_crop_width &&
+        ref_buf[frame]->y_crop_height == cpi->source->y_crop_height &&
+        do_gm_search_logic(&cpi->sf, frame) &&
+        !prune_ref_by_selective_ref_frame(
+            cpi, ref_frame, cm->cur_frame->ref_display_order_hint,
+            cm->current_frame.display_order_hint) &&
+        !(cpi->sf.selective_ref_gm && skip_gm_frame(cm, frame))) {
+      assert(ref_buf[frame] != NULL);
+      int relative_frame_dist = av1_encoder_get_relative_dist(
+          order_hint_info, buf->display_order_hint,
+          cm->cur_frame->display_order_hint);
+      // Populate past and future ref frames
+      if (relative_frame_dist < 0) {
+        past_ref_frame[*num_past_ref_frames].distance =
+            abs(relative_frame_dist);
+        past_ref_frame[*num_past_ref_frames].frame = frame;
+        (*num_past_ref_frames)++;
+      } else {
+        future_ref_frame[*num_future_ref_frames].distance =
+            abs(relative_frame_dist);
+        future_ref_frame[*num_future_ref_frames].frame = frame;
+        (*num_future_ref_frames)++;
+      }
+    }
+  }
+}
+
+static INLINE void compute_gm_for_valid_ref_frames(
+    AV1_COMP *cpi, YV12_BUFFER_CONFIG *ref_buf[REF_FRAMES], int frame,
+    int *num_frm_corners, int *frm_corners, unsigned char *frm_buffer,
+    MotionModel *params_by_motion, uint8_t *segment_map,
+    const int segment_map_w, const int segment_map_h) {
+  AV1_COMMON *const cm = &cpi->common;
+  const WarpedMotionParams *ref_params =
+      cm->prev_frame ? &cm->prev_frame->global_motion[frame]
+                     : &default_warp_params;
+
+  compute_global_motion_for_ref_frame(
+      cpi, ref_buf, frame, num_frm_corners, frm_corners, frm_buffer,
+      params_by_motion, segment_map, segment_map_w, segment_map_h, ref_params);
+
+  cpi->gmparams_cost[frame] =
+      gm_get_params_cost(&cm->global_motion[frame], ref_params,
+                         cm->allow_high_precision_mv) +
+      cpi->gmtype_cost[cm->global_motion[frame].wmtype] -
+      cpi->gmtype_cost[IDENTITY];
+}
+
+static int compare_distance(const void *a, const void *b) {
+  const int diff =
+      ((FrameDistPair *)a)->distance - ((FrameDistPair *)b)->distance;
+  if (diff < 0)
+    return 1;
+  else if (diff > 0)
+    return -1;
+  return 0;
+}
+
 static AOM_INLINE void encode_frame_internal(AV1_COMP *cpi) {
   ThreadData *const td = &cpi->td;
   MACROBLOCK *const x = &td->mb;
@@ -4888,7 +5102,6 @@
   if (cpi->common.current_frame.frame_type == INTER_FRAME && cpi->source &&
       cpi->oxcf.enable_global_motion && !cpi->global_motion_search_done) {
     YV12_BUFFER_CONFIG *ref_buf[REF_FRAMES];
-    int frame;
     MotionModel params_by_motion[RANSAC_NUM_MOTIONS];
     for (int m = 0; m < RANSAC_NUM_MOTIONS; m++) {
       memset(&params_by_motion[m], 0, sizeof(params_by_motion[m]));
@@ -4896,14 +5109,6 @@
           aom_malloc(sizeof(*(params_by_motion[m].inliers)) * 2 * MAX_CORNERS);
     }
 
-    const double *params_this_motion;
-    int inliers_by_motion[RANSAC_NUM_MOTIONS];
-    WarpedMotionParams tmp_wm_params;
-    // clang-format off
-    static const double kIdentityParams[MAX_PARAMDIM - 1] = {
-      0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0
-    };
-    // clang-format on
     int num_frm_corners = -1;
     int frm_corners[2 * MAX_CORNERS];
     unsigned char *frm_buffer = cpi->source->y_buffer;
@@ -4923,146 +5128,44 @@
     memset(segment_map, 0,
            sizeof(*segment_map) * segment_map_w * segment_map_h);
 
-    for (frame = ALTREF_FRAME; frame >= LAST_FRAME; --frame) {
-      const WarpedMotionParams *ref_params;
-      const MV_REFERENCE_FRAME ref_frame[2] = { frame, NONE_FRAME };
-      RefCntBuffer *buf = get_ref_frame_buf(cm, frame);
-      const int ref_disabled =
-          !(cpi->ref_frame_flags & av1_ref_frame_flag_list[frame]);
-      ref_buf[frame] = NULL;
-      cm->global_motion[frame] = default_warp_params;
-      // Skip global motion estimation for invalid ref frames
-      if (buf == NULL ||
-          (ref_disabled && cpi->sf.recode_loop != DISALLOW_RECODE)) {
-        cpi->gmparams_cost[frame] = 0;
-        continue;
-      } else {
-        ref_buf[frame] = &buf->buf;
-        ref_params = cm->prev_frame ? &cm->prev_frame->global_motion[frame]
-                                    : &default_warp_params;
-      }
+    FrameDistPair future_ref_frame[REF_FRAMES - 1] = {
+      { -1, NONE_FRAME }, { -1, NONE_FRAME }, { -1, NONE_FRAME },
+      { -1, NONE_FRAME }, { -1, NONE_FRAME }, { -1, NONE_FRAME },
+      { -1, NONE_FRAME }
+    };
+    FrameDistPair past_ref_frame[REF_FRAMES - 1] = {
+      { -1, NONE_FRAME }, { -1, NONE_FRAME }, { -1, NONE_FRAME },
+      { -1, NONE_FRAME }, { -1, NONE_FRAME }, { -1, NONE_FRAME },
+      { -1, NONE_FRAME }
+    };
+    int num_past_ref_frames = 0;
+    int num_future_ref_frames = 0;
+    // Populate ref_buf for valid ref frames in global motion
+    update_valid_ref_frames_for_gm(cpi, ref_buf, past_ref_frame,
+                                   future_ref_frame, &num_past_ref_frames,
+                                   &num_future_ref_frames);
 
-      if (ref_buf[frame]->y_crop_width == cpi->source->y_crop_width &&
-          ref_buf[frame]->y_crop_height == cpi->source->y_crop_height &&
-          do_gm_search_logic(&cpi->sf, frame) &&
-          !prune_ref_by_selective_ref_frame(
-              cpi, ref_frame, cm->cur_frame->ref_display_order_hint,
-              cm->current_frame.display_order_hint) &&
-          !(cpi->sf.selective_ref_gm && skip_gm_frame(cm, frame))) {
-        assert(ref_buf[frame] != NULL);
-        if (num_frm_corners < 0) {
-          // compute interest points using FAST features
-          num_frm_corners = av1_fast_corner_detect(
-              frm_buffer, cpi->source->y_width, cpi->source->y_height,
-              cpi->source->y_stride, frm_corners, MAX_CORNERS);
-        }
-        TransformationType model;
+    // Sort the ref frames based on the distance from current frame
+    qsort(past_ref_frame, num_past_ref_frames, sizeof(past_ref_frame[0]),
+          compare_distance);
+    qsort(future_ref_frame, num_future_ref_frames, sizeof(future_ref_frame[0]),
+          compare_distance);
 
-        aom_clear_system_state();
+    // Compute global motion w.r.t. past reference frames
+    for (int past_frame = 0; past_frame < num_past_ref_frames; past_frame++) {
+      int frame = past_ref_frame[past_frame].frame;
+      compute_gm_for_valid_ref_frames(
+          cpi, ref_buf, frame, &num_frm_corners, frm_corners, frm_buffer,
+          params_by_motion, segment_map, segment_map_w, segment_map_h);
+    }
 
-        // TODO(sarahparker, debargha): Explore do_adaptive_gm_estimation = 1
-        const int do_adaptive_gm_estimation = 0;
-
-        const int ref_frame_dist = get_relative_dist(
-            &cm->seq_params.order_hint_info, cm->current_frame.order_hint,
-            cm->cur_frame->ref_order_hints[frame - LAST_FRAME]);
-        const GlobalMotionEstimationType gm_estimation_type =
-            cm->seq_params.order_hint_info.enable_order_hint &&
-                    abs(ref_frame_dist) <= 2 && do_adaptive_gm_estimation
-                ? GLOBAL_MOTION_DISFLOW_BASED
-                : GLOBAL_MOTION_FEATURE_BASED;
-        for (model = ROTZOOM; model < GLOBAL_TRANS_TYPES_ENC; ++model) {
-          int64_t best_warp_error = INT64_MAX;
-          // Initially set all params to identity.
-          for (i = 0; i < RANSAC_NUM_MOTIONS; ++i) {
-            memcpy(params_by_motion[i].params, kIdentityParams,
-                   (MAX_PARAMDIM - 1) * sizeof(*(params_by_motion[i].params)));
-            params_by_motion[i].num_inliers = 0;
-          }
-
-          av1_compute_global_motion(
-              model, frm_buffer, cpi->source->y_width, cpi->source->y_height,
-              cpi->source->y_stride, frm_corners, num_frm_corners,
-              ref_buf[frame], cpi->common.seq_params.bit_depth,
-              gm_estimation_type, inliers_by_motion, params_by_motion,
-              RANSAC_NUM_MOTIONS);
-          int64_t ref_frame_error = 0;
-          for (i = 0; i < RANSAC_NUM_MOTIONS; ++i) {
-            if (inliers_by_motion[i] == 0) continue;
-
-            params_this_motion = params_by_motion[i].params;
-            av1_convert_model_to_params(params_this_motion, &tmp_wm_params);
-
-            if (tmp_wm_params.wmtype != IDENTITY) {
-              av1_compute_feature_segmentation_map(
-                  segment_map, segment_map_w, segment_map_h,
-                  params_by_motion[i].inliers, params_by_motion[i].num_inliers);
-
-              ref_frame_error = av1_segmented_frame_error(
-                  is_cur_buf_hbd(xd), xd->bd, ref_buf[frame]->y_buffer,
-                  ref_buf[frame]->y_stride, cpi->source->y_buffer,
-                  cpi->source->y_width, cpi->source->y_height,
-                  cpi->source->y_stride, segment_map, segment_map_w);
-
-              int64_t erroradv_threshold =
-                  calc_erroradv_threshold(cpi, ref_frame_error);
-
-              const int64_t warp_error = av1_refine_integerized_param(
-                  &tmp_wm_params, tmp_wm_params.wmtype, is_cur_buf_hbd(xd),
-                  xd->bd, ref_buf[frame]->y_buffer, ref_buf[frame]->y_width,
-                  ref_buf[frame]->y_height, ref_buf[frame]->y_stride,
-                  cpi->source->y_buffer, cpi->source->y_width,
-                  cpi->source->y_height, cpi->source->y_stride,
-                  GM_REFINEMENT_COUNT, best_warp_error, segment_map,
-                  segment_map_w, erroradv_threshold);
-
-              if (warp_error < best_warp_error) {
-                best_warp_error = warp_error;
-                // Save the wm_params modified by
-                // av1_refine_integerized_param() rather than motion index to
-                // avoid rerunning refine() below.
-                memcpy(&(cm->global_motion[frame]), &tmp_wm_params,
-                       sizeof(WarpedMotionParams));
-              }
-            }
-          }
-          if (cm->global_motion[frame].wmtype <= AFFINE)
-            if (!av1_get_shear_params(&cm->global_motion[frame]))
-              cm->global_motion[frame] = default_warp_params;
-
-          if (cm->global_motion[frame].wmtype == TRANSLATION) {
-            cm->global_motion[frame].wmmat[0] =
-                convert_to_trans_prec(cm->allow_high_precision_mv,
-                                      cm->global_motion[frame].wmmat[0]) *
-                GM_TRANS_ONLY_DECODE_FACTOR;
-            cm->global_motion[frame].wmmat[1] =
-                convert_to_trans_prec(cm->allow_high_precision_mv,
-                                      cm->global_motion[frame].wmmat[1]) *
-                GM_TRANS_ONLY_DECODE_FACTOR;
-          }
-
-          if (cm->global_motion[frame].wmtype == IDENTITY) continue;
-
-          if (ref_frame_error == 0) continue;
-
-          // If the best error advantage found doesn't meet the threshold for
-          // this motion type, revert to IDENTITY.
-          if (!av1_is_enough_erroradvantage(
-                  (double)best_warp_error / ref_frame_error,
-                  gm_get_params_cost(&cm->global_motion[frame], ref_params,
-                                     cm->allow_high_precision_mv),
-                  cpi->sf.gm_erroradv_type)) {
-            cm->global_motion[frame] = default_warp_params;
-          }
-          if (cm->global_motion[frame].wmtype != IDENTITY) break;
-        }
-        aom_clear_system_state();
-      }
-      cpi->gmparams_cost[frame] =
-          gm_get_params_cost(&cm->global_motion[frame], ref_params,
-                             cm->allow_high_precision_mv) +
-          cpi->gmtype_cost[cm->global_motion[frame].wmtype] -
-          cpi->gmtype_cost[IDENTITY];
+    // Compute global motion w.r.t. future reference frames
+    for (int future_frame = 0; future_frame < num_future_ref_frames;
+         future_frame++) {
+      int frame = future_ref_frame[future_frame].frame;
+      compute_gm_for_valid_ref_frames(
+          cpi, ref_buf, frame, &num_frm_corners, frm_corners, frm_buffer,
+          params_by_motion, segment_map, segment_map_w, segment_map_h);
     }
     aom_free(segment_map);