Move tpl-based MV generation outside of av1_single_motion_search

Change-Id: I3b6e23c766b494f23e18cc34850983c22fab85ee
diff --git a/av1/encoder/motion_search_facade.c b/av1/encoder/motion_search_facade.c
index 059c8f2..9834b19 100644
--- a/av1/encoder/motion_search_facade.c
+++ b/av1/encoder/motion_search_facade.c
@@ -44,6 +44,77 @@
          cpi->oxcf.speed <= 2;
 }
 
+// Iterate through the tpl and collect the mvs to be used as candidates
+static INLINE void get_mv_candidate_from_tpl(const AV1_COMP *const cpi,
+                                             const MACROBLOCK *x,
+                                             BLOCK_SIZE bsize, int ref,
+                                             cand_mv_t *cand, int *cand_count,
+                                             int *total_cand_weight) {
+  const SuperBlockEnc *sb_enc = &x->sb_enc;
+  if (!sb_enc->tpl_data_count) {
+    return;
+  }
+
+  const AV1_COMMON *cm = &cpi->common;
+  const MACROBLOCKD *xd = &x->e_mbd;
+  const int mi_row = xd->mi_row;
+  const int mi_col = xd->mi_col;
+
+  const BLOCK_SIZE tpl_bsize =
+      convert_length_to_bsize(cpi->tpl_data.tpl_bsize_1d);
+  const int tplw = mi_size_wide[tpl_bsize];
+  const int tplh = mi_size_high[tpl_bsize];
+  const int nw = mi_size_wide[bsize] / tplw;
+  const int nh = mi_size_high[bsize] / tplh;
+
+  if (nw >= 1 && nh >= 1) {
+    const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
+    const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
+    const int start = of_h / tplh * sb_enc->tpl_stride + of_w / tplw;
+    int valid = 1;
+
+    // Assign large weight to start_mv, so it is always tested.
+    cand[0].weight = nw * nh;
+
+    for (int k = 0; k < nh; k++) {
+      for (int l = 0; l < nw; l++) {
+        const int_mv mv =
+            sb_enc
+                ->tpl_mv[start + k * sb_enc->tpl_stride + l][ref - LAST_FRAME];
+        if (mv.as_int == INVALID_MV) {
+          valid = 0;
+          break;
+        }
+
+        const FULLPEL_MV fmv = { GET_MV_RAWPEL(mv.as_mv.row),
+                                 GET_MV_RAWPEL(mv.as_mv.col) };
+        int unique = 1;
+        for (int m = 0; m < *cand_count; m++) {
+          if (RIGHT_SHIFT_MV(fmv.row) == RIGHT_SHIFT_MV(cand[m].fmv.row) &&
+              RIGHT_SHIFT_MV(fmv.col) == RIGHT_SHIFT_MV(cand[m].fmv.col)) {
+            unique = 0;
+            cand[m].weight++;
+            break;
+          }
+        }
+
+        if (unique) {
+          cand[*cand_count].fmv = fmv;
+          cand[*cand_count].weight = 1;
+          (*cand_count)++;
+        }
+      }
+      if (!valid) break;
+    }
+
+    if (valid) {
+      *total_cand_weight = 2 * nh * nw;
+      if (*cand_count > 2)
+        qsort(cand, *cand_count, sizeof(cand[0]), &compare_weight);
+    }
+  }
+}
+
 void av1_single_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
                               BLOCK_SIZE bsize, int ref_idx, int *rate_mv,
                               int search_range, inter_mode_info *mode_info,
@@ -104,60 +175,7 @@
 
   if (!cpi->sf.mv_sf.full_pixel_search_level &&
       mbmi->motion_mode == SIMPLE_TRANSLATION) {
-    SuperBlockEnc *sb_enc = &x->sb_enc;
-    if (sb_enc->tpl_data_count) {
-      const BLOCK_SIZE tpl_bsize =
-          convert_length_to_bsize(cpi->tpl_data.tpl_bsize_1d);
-      const int tplw = mi_size_wide[tpl_bsize];
-      const int tplh = mi_size_high[tpl_bsize];
-      const int nw = mi_size_wide[bsize] / tplw;
-      const int nh = mi_size_high[bsize] / tplh;
-
-      if (nw >= 1 && nh >= 1) {
-        const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
-        const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
-        const int start = of_h / tplh * sb_enc->tpl_stride + of_w / tplw;
-        int valid = 1;
-
-        // Assign large weight to start_mv, so it is always tested.
-        cand[0].weight = nw * nh;
-
-        for (int k = 0; k < nh; k++) {
-          for (int l = 0; l < nw; l++) {
-            const int_mv mv = sb_enc->tpl_mv[start + k * sb_enc->tpl_stride + l]
-                                            [ref - LAST_FRAME];
-            if (mv.as_int == INVALID_MV) {
-              valid = 0;
-              break;
-            }
-
-            const FULLPEL_MV fmv = { GET_MV_RAWPEL(mv.as_mv.row),
-                                     GET_MV_RAWPEL(mv.as_mv.col) };
-            int unique = 1;
-            for (int m = 0; m < cnt; m++) {
-              if (RIGHT_SHIFT_MV(fmv.row) == RIGHT_SHIFT_MV(cand[m].fmv.row) &&
-                  RIGHT_SHIFT_MV(fmv.col) == RIGHT_SHIFT_MV(cand[m].fmv.col)) {
-                unique = 0;
-                cand[m].weight++;
-                break;
-              }
-            }
-
-            if (unique) {
-              cand[cnt].fmv = fmv;
-              cand[cnt].weight = 1;
-              cnt++;
-            }
-          }
-          if (!valid) break;
-        }
-
-        if (valid) {
-          total_weight = 2 * nh * nw;
-          if (cnt > 2) qsort(cand, cnt, sizeof(cand[0]), &compare_weight);
-        }
-      }
-    }
+    get_mv_candidate_from_tpl(cpi, x, bsize, ref, cand, &cnt, &total_weight);
   }
 
   // Further reduce the search range.
@@ -190,9 +208,9 @@
 
   switch (mbmi->motion_mode) {
     case SIMPLE_TRANSLATION: {
+      // Perform a search with the top 2 candidates
       int sum_weight = 0;
-
-      for (int m = 0; m < cnt; m++) {
+      for (int m = 0; m < AOMMIN(2, cnt); m++) {
         FULLPEL_MV smv = cand[m].fmv;
         FULLPEL_MV this_best_mv, this_second_best_mv;
 
@@ -207,7 +225,7 @@
         }
 
         sum_weight += cand[m].weight;
-        if (m >= 2 || 4 * sum_weight > 3 * total_weight) break;
+        if (4 * sum_weight > 3 * total_weight) break;
       }
     } break;
     case OBMC_CAUSAL: