MOTION_SEARCH: Small optimizations to diamond_search_sad

This commit consists of the following changes:
 * Split diamond_search_sad into a compound case (second_pred is set)
   and a single case.
 * Move a couple of variables out of the inner loop to avoid reloading.
 * Track num00 and second_best_mv in local variables and write them back
   once at the end of diamond_search_sad, so that second_best_mv is no
   longer checked for NULL on every search step.
 * Remove redundant checks for next_step_size.

Performance:
| SPD_SET | TESTSET | AVG_PSNR | OVR_PSNR |  SSIM   | ENC_T |
|---------|---------|----------|----------|---------|-------|
|    1    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.4% |
|    1    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    1    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|---------|---------|----------|----------|---------|-------|
|    2    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    2    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    2    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|---------|---------|----------|----------|---------|-------|
|    3    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.4% |
|    3    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    3    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|---------|---------|----------|----------|---------|-------|
|    4    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.4% |
|    4    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    4    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|---------|---------|----------|----------|---------|-------|
|    5    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.5% |
|    5    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.2% |
|    5    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|---------|---------|----------|----------|---------|-------|
|    6    | hdres2  | +0.000%  | +0.000%  | +0.000% | -0.5% |
|    6    | lowres2 | +0.000%  | +0.000%  | +0.000% | -0.3% |
|    6    | midres2 | +0.000%  | +0.000%  | +0.000% | -0.4% |

Change-Id: Id978d6fd17671dbdcc02a24a94b1218a89eef730
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index 7d239c7..8337592 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -1300,84 +1300,77 @@
                               const FULLPEL_MOTION_SEARCH_PARAMS *ms_params,
                               const int search_step, int *num00,
                               FULLPEL_MV *best_mv, FULLPEL_MV *second_best_mv) {
+#define UPDATE_SEARCH_STEP                                      \
+  do {                                                          \
+    if (best_site != 0) {                                       \
+      tmp_second_best_mv = *best_mv;                            \
+      best_mv->row += site[best_site].mv.row;                   \
+      best_mv->col += site[best_site].mv.col;                   \
+      best_address += site[best_site].offset;                   \
+      is_off_center = 1;                                        \
+    }                                                           \
+                                                                \
+    if (is_off_center == 0) num_center_steps++;                 \
+                                                                \
+    if (best_site == 0 && step > 2) {                           \
+      int next_step_size = cfg->radius[step - 1];               \
+      while (next_step_size == cfg->radius[step] && step > 2) { \
+        num_center_steps++;                                     \
+        --step;                                                 \
+        next_step_size = cfg->radius[step - 1];                 \
+      }                                                         \
+    }                                                           \
+  } while (0)
+
   const struct buf_2d *const src = ms_params->ms_buffers.src;
   const struct buf_2d *const ref = ms_params->ms_buffers.ref;
 
+  const uint8_t *src_buf = src->buf;
+  const int src_stride = src->stride;
   const int ref_stride = ref->stride;
-  const uint8_t *best_address;
 
-  const uint8_t *mask = ms_params->ms_buffers.mask;
-  const uint8_t *second_pred = ms_params->ms_buffers.second_pred;
   const MV_COST_PARAMS *mv_cost_params = &ms_params->mv_cost_params;
 
   const search_site_config *cfg = ms_params->search_sites;
 
-  unsigned int bestsad = INT_MAX;
-  int best_site = 0;
   int is_off_center = 0;
+  // Number of times that we have stayed in the middle. This is used to skip
+  // search steps in the future if diamond_search_sad is called again.
+  int num_center_steps = 0;
 
   clamp_fullmv(&start_mv, &ms_params->mv_limits);
 
   // search_step determines the length of the initial step and hence the number
   // of iterations.
   const int tot_steps = cfg->num_search_steps - search_step;
+  FULLPEL_MV tmp_second_best_mv;
+  if (second_best_mv) {
+    tmp_second_best_mv = *second_best_mv;
+  }
 
-  *num00 = 0;
   *best_mv = start_mv;
 
   // Check the starting position
-  best_address = get_buf_from_fullmv(ref, &start_mv);
-  bestsad = get_mvpred_compound_sad(ms_params, src, best_address, ref_stride);
-  bestsad += mvsad_err_cost_(best_mv, &ms_params->mv_cost_params);
+  const uint8_t *best_address = get_buf_from_fullmv(ref, &start_mv);
+  unsigned int bestsad = mvsad_err_cost_(best_mv, &ms_params->mv_cost_params);
 
-  int next_step_size = tot_steps > 2 ? cfg->radius[tot_steps - 2] : 1;
-  for (int step = tot_steps - 1; step >= 0; --step) {
-    const search_site *site = cfg->site[step];
-    best_site = 0;
-    if (step > 0) next_step_size = cfg->radius[step - 1];
+  // TODO(chiyotsai@google.com): Implement 4 points search for msdf&sdaf
+  if (ms_params->ms_buffers.second_pred) {
+    bestsad +=
+        get_mvpred_compound_sad(ms_params, src, best_address, ref_stride);
 
-    int all_in = 1, j;
-    // Trap illegal vectors
-    all_in &= best_mv->row + site[1].mv.row >= ms_params->mv_limits.row_min;
-    all_in &= best_mv->row + site[2].mv.row <= ms_params->mv_limits.row_max;
-    all_in &= best_mv->col + site[3].mv.col >= ms_params->mv_limits.col_min;
-    all_in &= best_mv->col + site[4].mv.col <= ms_params->mv_limits.col_max;
+    for (int step = tot_steps - 1; step >= 0; --step) {
+      const search_site *site = cfg->site[step];
+      const int num_searches = cfg->searches_per_step[step];
+      int best_site = 0;
 
-    // TODO(anyone): Implement 4 points search for msdf&sdaf
-    if (all_in && !mask && !second_pred) {
-      const uint8_t *src_buf = src->buf;
-      const int src_stride = src->stride;
-      for (int idx = 1; idx <= cfg->searches_per_step[step]; idx += 4) {
-        unsigned char const *block_offset[4];
-        unsigned int sads[4];
-
-        for (j = 0; j < 4; j++)
-          block_offset[j] = site[idx + j].offset + best_address;
-
-        ms_params->sdx4df(src_buf, src_stride, block_offset, ref_stride, sads);
-        for (j = 0; j < 4; j++) {
-          if (sads[j] < bestsad) {
-            const FULLPEL_MV this_mv = { best_mv->row + site[idx + j].mv.row,
-                                         best_mv->col + site[idx + j].mv.col };
-            unsigned int thissad =
-                sads[j] + mvsad_err_cost_(&this_mv, mv_cost_params);
-            if (thissad < bestsad) {
-              bestsad = thissad;
-              best_site = idx + j;
-            }
-          }
-        }
-      }
-    } else {
-      for (int idx = 1; idx <= cfg->searches_per_step[step]; idx++) {
+      for (int idx = 1; idx <= num_searches; idx++) {
         const FULLPEL_MV this_mv = { best_mv->row + site[idx].mv.row,
                                      best_mv->col + site[idx].mv.col };
 
         if (av1_is_fullmv_in_range(&ms_params->mv_limits, this_mv)) {
           const uint8_t *const check_here = site[idx].offset + best_address;
-          unsigned int thissad;
-
-          thissad =
+          unsigned int thissad =
               get_mvpred_compound_sad(ms_params, src, check_here, ref_stride);
 
           if (thissad < bestsad) {
@@ -1389,47 +1382,91 @@
           }
         }
       }
+      UPDATE_SEARCH_STEP;
     }
+  } else {
+    bestsad += get_mvpred_sad(ms_params, src, best_address, ref_stride);
 
-    if (best_site != 0) {
-      if (second_best_mv) {
-        *second_best_mv = *best_mv;
+    for (int step = tot_steps - 1; step >= 0; --step) {
+      const search_site *site = cfg->site[step];
+      const int num_searches = cfg->searches_per_step[step];
+      int best_site = 0;
+
+      int all_in = 1;
+      // Trap illegal vectors
+      all_in &= best_mv->row + site[1].mv.row >= ms_params->mv_limits.row_min;
+      all_in &= best_mv->row + site[2].mv.row <= ms_params->mv_limits.row_max;
+      all_in &= best_mv->col + site[3].mv.col >= ms_params->mv_limits.col_min;
+      all_in &= best_mv->col + site[4].mv.col <= ms_params->mv_limits.col_max;
+
+      if (all_in) {
+        for (int idx = 1; idx <= num_searches; idx += 4) {
+          unsigned char const *block_offset[4];
+          unsigned int sads[4];
+
+          for (int j = 0; j < 4; j++)
+            block_offset[j] = site[idx + j].offset + best_address;
+
+          ms_params->sdx4df(src_buf, src_stride, block_offset, ref_stride,
+                            sads);
+          for (int j = 0; j < 4; j++) {
+            if (sads[j] < bestsad) {
+              const FULLPEL_MV this_mv = { best_mv->row + site[idx + j].mv.row,
+                                           best_mv->col +
+                                               site[idx + j].mv.col };
+              unsigned int thissad =
+                  sads[j] + mvsad_err_cost_(&this_mv, mv_cost_params);
+              if (thissad < bestsad) {
+                bestsad = thissad;
+                best_site = idx + j;
+              }
+            }
+          }
+        }
+      } else {
+        for (int idx = 1; idx <= num_searches; idx++) {
+          const FULLPEL_MV this_mv = { best_mv->row + site[idx].mv.row,
+                                       best_mv->col + site[idx].mv.col };
+
+          if (av1_is_fullmv_in_range(&ms_params->mv_limits, this_mv)) {
+            const uint8_t *const check_here = site[idx].offset + best_address;
+            unsigned int thissad =
+                get_mvpred_sad(ms_params, src, check_here, ref_stride);
+
+            if (thissad < bestsad) {
+              thissad += mvsad_err_cost_(&this_mv, mv_cost_params);
+              if (thissad < bestsad) {
+                bestsad = thissad;
+                best_site = idx;
+              }
+            }
+          }
+        }
       }
-      best_mv->row += site[best_site].mv.row;
-      best_mv->col += site[best_site].mv.col;
-      best_address += site[best_site].offset;
-      is_off_center = 1;
-    }
-
-    if (is_off_center == 0) (*num00)++;
-
-    if (best_site == 0) {
-      while (next_step_size == cfg->radius[step] && step > 2) {
-        ++(*num00);
-        --step;
-        next_step_size = cfg->radius[step - 1];
-      }
+      UPDATE_SEARCH_STEP;
     }
   }
 
+  *num00 = num_center_steps;
+  if (second_best_mv) {
+    *second_best_mv = tmp_second_best_mv;
+  }
+
   return bestsad;
+
+#undef UPDATE_SEARCH_STEP
 }
 
-/* do_refine: If last step (1-away) of n-step search doesn't pick the center
-              point as the best match, we will do a final 1-away diamond
-              refining search  */
 static int full_pixel_diamond(const FULLPEL_MV start_mv,
                               const FULLPEL_MOTION_SEARCH_PARAMS *ms_params,
                               const int step_param, int *cost_list,
                               FULLPEL_MV *best_mv, FULLPEL_MV *second_best_mv) {
   const search_site_config *cfg = ms_params->search_sites;
   int thissme, n, num00 = 0;
-  int bestsme = diamond_search_sad(start_mv, ms_params, step_param, &n, best_mv,
-                                   second_best_mv);
+  diamond_search_sad(start_mv, ms_params, step_param, &n, best_mv,
+                     second_best_mv);
 
-  if (bestsme < INT_MAX) {
-    bestsme = get_mvpred_compound_var_cost(ms_params, best_mv);
-  }
+  int bestsme = get_mvpred_compound_var_cost(ms_params, best_mv);
 
   // If there won't be more n-step search, check to see if refining search is
   // needed.
@@ -1437,23 +1474,23 @@
   while (n < further_steps) {
     ++n;
 
+    // TODO(chiyotsai@google.com): There is another bug here where the second
+    // best mv gets incorrectly overwritten. Fix it later.
+    FULLPEL_MV tmp_best_mv;
+    diamond_search_sad(start_mv, ms_params, step_param + n, &num00,
+                       &tmp_best_mv, second_best_mv);
+
+    thissme = get_mvpred_compound_var_cost(ms_params, &tmp_best_mv);
+
+    if (thissme < bestsme) {
+      bestsme = thissme;
+      *best_mv = tmp_best_mv;
+    }
+
     if (num00) {
-      num00--;
-    } else {
-      // TODO(chiyotsai@google.com): There is another bug here where the second
-      // best mv gets incorrectly overwritten. Fix it later.
-      FULLPEL_MV tmp_best_mv;
-      thissme = diamond_search_sad(start_mv, ms_params, step_param + n, &num00,
-                                   &tmp_best_mv, second_best_mv);
-
-      if (thissme < INT_MAX) {
-        thissme = get_mvpred_compound_var_cost(ms_params, &tmp_best_mv);
-      }
-
-      if (thissme < bestsme) {
-        bestsme = thissme;
-        *best_mv = tmp_best_mv;
-      }
+      // Advance the loop by num00 steps
+      n += num00;
+      num00 = 0;
     }
   }