Refactor full pixel motion search

Unify the seach scheme to support multiple search applications.

Change-Id: I927dabd5c318f4bbf5b3137b6a56c67b89947a8a
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 7a5a0e5..3dd0423 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -274,8 +274,6 @@
   #
   # Motion search
   #
-  add_proto qw/int av1_diamond_search_sad/, "struct macroblock *x, const struct search_site_config *cfg,  MV *ref_mv, MV *best_mv, int search_param, int sad_per_bit, int *num00, const struct aom_variance_vtable *fn_ptr, const MV *center_mv";
-
   add_proto qw/int av1_full_range_search/, "const struct macroblock *x, const struct search_site_config *cfg, MV *ref_mv, MV *best_mv, int search_param, int sad_per_bit, int *num00, const struct aom_variance_vtable *fn_ptr, const MV *center_mv";
 
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
diff --git a/av1/encoder/firstpass.c b/av1/encoder/firstpass.c
index 0247959..42c28fa3 100644
--- a/av1/encoder/firstpass.c
+++ b/av1/encoder/firstpass.c
@@ -239,9 +239,9 @@
   }
 #endif
   // Center the initial step/diamond search on best mv.
-  tmp_err = av1_diamond_search_sad_c(x, &cpi->ss_cfg[SS_CFG_SRC], &ref_mv_full,
-                                     &tmp_mv, step_param, x->sadperbit16,
-                                     &num00, &v_fn_ptr, ref_mv);
+  tmp_err = av1_diamond_search_sad_c(
+      x, &cpi->ss_cfg[SS_CFG_SRC], &ref_mv_full, &tmp_mv, step_param,
+      x->sadperbit16, &num00, &v_fn_ptr, ref_mv, NULL, NULL, 0, 0);
   if (tmp_err < INT_MAX)
     tmp_err = av1_get_mvpred_var(x, &tmp_mv, ref_mv, &v_fn_ptr, 1);
   if (tmp_err < INT_MAX - new_mv_mode_penalty) tmp_err += new_mv_mode_penalty;
@@ -263,7 +263,7 @@
     } else {
       tmp_err = av1_diamond_search_sad_c(
           x, &cpi->ss_cfg[SS_CFG_SRC], &ref_mv_full, &tmp_mv, step_param + n,
-          x->sadperbit16, &num00, &v_fn_ptr, ref_mv);
+          x->sadperbit16, &num00, &v_fn_ptr, ref_mv, NULL, NULL, 0, 0);
       if (tmp_err < INT_MAX)
         tmp_err = av1_get_mvpred_var(x, &tmp_mv, ref_mv, &v_fn_ptr, 1);
       if (tmp_err < INT_MAX - new_mv_mode_penalty)
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index f771410..1dff653 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -1698,7 +1698,8 @@
                              MV *ref_mv, MV *best_mv, int search_param,
                              int sad_per_bit, int *num00,
                              const aom_variance_fn_ptr_t *fn_ptr,
-                             const MV *center_mv) {
+                             const MV *center_mv, uint8_t *second_pred,
+                             uint8_t *mask, int mask_stride, int inv_mask) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   uint8_t *what = x->plane[0].src.buf;
   const int what_stride = x->plane[0].src.stride;
@@ -1731,8 +1732,18 @@
   best_address = in_what;
 
   // Check the starting position
-  bestsad = fn_ptr->sdf(what, what_stride, in_what, in_what_stride) +
-            mvsad_err_cost(x, best_mv, &fcenter_mv, sad_per_bit);
+  // TODO(jingning): unify the parameter interface for the following
+  // computation modes.
+  if (mask)
+    bestsad = fn_ptr->msdf(what, what_stride, in_what, in_what_stride,
+                           second_pred, mask, mask_stride, inv_mask);
+  else if (second_pred)
+    bestsad =
+        fn_ptr->sdaf(what, what_stride, in_what, in_what_stride, second_pred);
+  else
+    bestsad = fn_ptr->sdf(what, what_stride, in_what, in_what_stride);
+
+  bestsad += mvsad_err_cost(x, best_mv, &fcenter_mv, sad_per_bit);
 
   for (int step = tot_steps; step >= 0; --step) {
     const search_site *ss = cfg->ss[step];
@@ -1746,8 +1757,16 @@
 
       if (is_mv_in(&x->mv_limits, &this_mv)) {
         const uint8_t *const check_here = ss[idx].offset + best_address;
-        unsigned int thissad =
-            fn_ptr->sdf(what, what_stride, check_here, in_what_stride);
+        unsigned int thissad;
+
+        if (mask)
+          thissad = fn_ptr->msdf(what, what_stride, check_here, in_what_stride,
+                                 second_pred, mask, mask_stride, inv_mask);
+        else if (second_pred)
+          thissad = fn_ptr->sdaf(what, what_stride, check_here, in_what_stride,
+                                 second_pred);
+        else
+          thissad = fn_ptr->sdf(what, what_stride, check_here, in_what_stride);
 
         if (thissad < bestsad) {
           thissad += mvsad_err_cost(x, &this_mv, &fcenter_mv, sad_per_bit);
@@ -1780,11 +1799,14 @@
                               int sadpb, int further_steps, int do_refine,
                               int *cost_list,
                               const aom_variance_fn_ptr_t *fn_ptr,
-                              const MV *ref_mv, const search_site_config *cfg) {
+                              const MV *ref_mv, const search_site_config *cfg,
+                              uint8_t *second_pred, uint8_t *mask,
+                              int mask_stride, int inv_mask) {
   MV temp_mv;
   int thissme, n, num00 = 0;
   int bestsme = av1_diamond_search_sad_c(x, cfg, mvp_full, &temp_mv, step_param,
-                                         sadpb, &n, fn_ptr, ref_mv);
+                                         sadpb, &n, fn_ptr, ref_mv, second_pred,
+                                         mask, mask_stride, inv_mask);
   if (bestsme < INT_MAX)
     bestsme = av1_get_mvpred_var(x, &temp_mv, ref_mv, fn_ptr, 1);
   x->best_mv.as_mv = temp_mv;
@@ -1799,9 +1821,9 @@
     if (num00) {
       num00--;
     } else {
-      thissme =
-          av1_diamond_search_sad_c(x, cfg, mvp_full, &temp_mv, step_param + n,
-                                   sadpb, &num00, fn_ptr, ref_mv);
+      thissme = av1_diamond_search_sad_c(
+          x, cfg, mvp_full, &temp_mv, step_param + n, sadpb, &num00, fn_ptr,
+          ref_mv, second_pred, mask, mask_stride, inv_mask);
       if (thissme < INT_MAX)
         thissme = av1_get_mvpred_var(x, &temp_mv, ref_mv, fn_ptr, 1);
 
@@ -2331,9 +2353,10 @@
                           fn_ptr, 1, ref_mv);
       break;
     case NSTEP:
-      var = full_pixel_diamond(x, mvp_full, step_param, error_per_bit,
-                               MAX_MVSEARCH_STEPS - 1 - step_param, 1,
-                               cost_list, fn_ptr, ref_mv, cfg);
+      var =
+          full_pixel_diamond(x, mvp_full, step_param, error_per_bit,
+                             MAX_MVSEARCH_STEPS - 1 - step_param, 1, cost_list,
+                             fn_ptr, ref_mv, cfg, NULL, NULL, 0, 0);
       break;
     default: assert(0 && "Invalid search method.");
   }
diff --git a/av1/encoder/mcomp.h b/av1/encoder/mcomp.h
index fa8958a..7a6f227 100644
--- a/av1/encoder/mcomp.h
+++ b/av1/encoder/mcomp.h
@@ -130,6 +130,13 @@
                              const struct buf_2d *src,
                              const struct buf_2d *pre);
 
+int av1_diamond_search_sad_c(MACROBLOCK *x, const search_site_config *cfg,
+                             MV *ref_mv, MV *best_mv, int search_param,
+                             int sad_per_bit, int *num00,
+                             const aom_variance_fn_ptr_t *fn_ptr,
+                             const MV *center_mv, uint8_t *second_pred,
+                             uint8_t *mask, int mask_stride, int inv_mask);
+
 int av1_full_pixel_search(const struct AV1_COMP *cpi, MACROBLOCK *x,
                           BLOCK_SIZE bsize, MV *mvp_full, int step_param,
                           int method, int run_mesh_search, int error_per_bit,