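Simplify second_level_check_fast in mcomp.c

Replace the integer whichdir quadrant code returned by
first_level_check_fast with an MV diag_step, so that neither
first_level_check_fast nor second_level_check_fast needs a switch over
the four quadrants. The this_mv argument of the *_check(s)_fast helpers
is now passed by value, and the commented-out NEW_DIAMOND_SEARCH define
is removed.

Note: get_best_diag_step compares the cardinal costs with <= where the
old whichdir computation used <, so an exact tie between opposing
cardinal costs now selects the up/left side instead of the down/right
side.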

Change-Id: I6911c4bbfd5f3da32f281a0bf42892e0aad636e4
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index 5397a3b..b6aa18d 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -31,8 +31,6 @@
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/reconinter_enc.h"
 
-// #define NEW_DIAMOND_SEARCH
-
 // TODO(any): Adaptively adjust the regularization strength based on image size
 // and motion activity instead of using hard-coded values. It seems like we
 // roughly half the lambda for each increase in resolution
@@ -2123,155 +2121,121 @@
   return cost;
 }
 
+static INLINE MV get_best_diag_step(int step_size, unsigned int left_cost,
+                                    unsigned int right_cost,
+                                    unsigned int up_cost,
+                                    unsigned int down_cost) {
+  const MV diag_step = { up_cost <= down_cost ? -step_size : step_size,
+                         left_cost <= right_cost ? -step_size : step_size };
+
+  return diag_step;
+}
+
 // Searches the four cardinal direction for a better mv, then follows up with a
 // search in the best quadrant. This uses bilinear filter to speed up the
 // calculation.
-static AOM_FORCE_INLINE int first_level_check_fast(
-    const MV *this_mv, MV *best_mv, int hstep, const SubpelMvLimits *mv_limits,
+static AOM_FORCE_INLINE MV first_level_check_fast(
+    const MV this_mv, MV *best_mv, int hstep, const SubpelMvLimits *mv_limits,
     const uint8_t *const src, const int src_stride, const uint8_t *const ref,
     int ref_stride, const SUBPEL_SEARCH_VAR_PARAMS *var_params,
     const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
     unsigned int *sse1, int *distortion) {
   // Check the four cardinal directions
-  const MV left_mv = { this_mv->row, this_mv->col - hstep };
+  const MV left_mv = { this_mv.row, this_mv.col - hstep };
   int dummy = 0;
   const unsigned int left = check_better_fast(
       &left_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
       var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
 
-  const MV right_mv = { this_mv->row, this_mv->col + hstep };
+  const MV right_mv = { this_mv.row, this_mv.col + hstep };
   const unsigned int right = check_better_fast(
       &right_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
       var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
 
-  const MV top_mv = { this_mv->row - hstep, this_mv->col };
+  const MV top_mv = { this_mv.row - hstep, this_mv.col };
   const unsigned int up = check_better_fast(
       &top_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride, var_params,
       mv_cost_params, besterr, sse1, distortion, &dummy);
 
-  const MV bottom_mv = { this_mv->row + hstep, this_mv->col };
+  const MV bottom_mv = { this_mv.row + hstep, this_mv.col };
   const unsigned int down = check_better_fast(
       &bottom_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
       var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
 
+  const MV diag_step = get_best_diag_step(hstep, left, right, up, down);
+  const MV diag_mv = { this_mv.row + diag_step.row,
+                       this_mv.col + diag_step.col };
+
   // Check the diagonal direction with the best mv
-  const int whichdir = (left < right ? 0 : 1) + (up < down ? 0 : 2);
-  switch (whichdir) {
-    case 0: {
-      const MV top_left_mv = { this_mv->row - hstep, this_mv->col - hstep };
-      check_better_fast(&top_left_mv, best_mv, mv_limits, src, src_stride, ref,
-                        ref_stride, var_params, mv_cost_params, besterr, sse1,
-                        distortion, &dummy);
-      break;
-    }
-    case 1: {
-      const MV top_right_mv = { this_mv->row - hstep, this_mv->col + hstep };
-      check_better_fast(&top_right_mv, best_mv, mv_limits, src, src_stride, ref,
-                        ref_stride, var_params, mv_cost_params, besterr, sse1,
-                        distortion, &dummy);
-      break;
-    }
-    case 2: {
-      const MV bottom_left_mv = { this_mv->row + hstep, this_mv->col - hstep };
-      check_better_fast(&bottom_left_mv, best_mv, mv_limits, src, src_stride,
-                        ref, ref_stride, var_params, mv_cost_params, besterr,
-                        sse1, distortion, &dummy);
-      break;
-    }
-    case 3: {
-      const MV bottom_right_mv = { this_mv->row + hstep, this_mv->col + hstep };
-      check_better_fast(&bottom_right_mv, best_mv, mv_limits, src, src_stride,
-                        ref, ref_stride, var_params, mv_cost_params, besterr,
-                        sse1, distortion, &dummy);
-      break;
-    }
-  }
-  return whichdir;
+  check_better_fast(&diag_mv, best_mv, mv_limits, src, src_stride, ref,
+                    ref_stride, var_params, mv_cost_params, besterr, sse1,
+                    distortion, &dummy);
+
+  return diag_step;
 }
 
 // Performs a following up search after first_level_check_fast is called. This
 // performs two extra chess pattern searches in the best quadrant.
 static AOM_FORCE_INLINE void second_level_check_fast(
-    const MV *this_mv, MV *best_mv, int hstep, const SubpelMvLimits *mv_limits,
-    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
-    int ref_stride, const SUBPEL_SEARCH_VAR_PARAMS *var_params,
+    const MV this_mv, const MV diag_step, MV *best_mv, int hstep,
+    const SubpelMvLimits *mv_limits, const uint8_t *const src,
+    const int src_stride, const uint8_t *const ref, int ref_stride,
+    const SUBPEL_SEARCH_VAR_PARAMS *var_params,
     const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
-    unsigned int *sse1, int *distortion, int whichdir) {
-  const int tr = this_mv->row;
-  const int tc = this_mv->col;
+    unsigned int *sse1, int *distortion) {
+  assert(diag_step.row == hstep || diag_step.row == -hstep);
+  assert(diag_step.col == hstep || diag_step.col == -hstep);
+  const int tr = this_mv.row;
+  const int tc = this_mv.col;
   const int br = best_mv->row;
   const int bc = best_mv->col;
   int dummy = 0;
   if (tr != br && tc != bc) {
-    const int kr = br - tr;
-    const int kc = bc - tc;
-
-    const MV chess_mv_1 = { tr + kr, tc + 2 * kc };
+    assert(diag_step.col == bc - tc);
+    assert(diag_step.row == br - tr);
+    const MV chess_mv_1 = { br, bc + diag_step.col };
+    const MV chess_mv_2 = { br + diag_step.row, bc };
     check_better_fast(&chess_mv_1, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
 
-    const MV chess_mv_2 = { tr + 2 * kr, tc + kc };
     check_better_fast(&chess_mv_2, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
   } else if (tr == br && tc != bc) {
-    const int kc = bc - tc;
-    const MV bottom_long_mv = { tr + hstep, tc + 2 * kc };
+    assert(diag_step.col == bc - tc);
+    // Continue searching in the best direction
+    const MV bottom_long_mv = { br + hstep, bc + diag_step.col };
+    const MV top_long_mv = { br - hstep, bc + diag_step.col };
     check_better_fast(&bottom_long_mv, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
-    const MV top_long_mv = { tr - hstep, tc + 2 * kc };
     check_better_fast(&top_long_mv, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
 
-    switch (whichdir) {
-      case 0:
-      case 1: {
-        const MV bottom_mv = { tr + hstep, tc + kc };
-        check_better_fast(&bottom_mv, best_mv, mv_limits, src, src_stride, ref,
-                          ref_stride, var_params, mv_cost_params, besterr, sse1,
-                          distortion, &dummy);
-        break;
-      }
-      case 2:
-      case 3: {
-        const MV top_mv = { tr - hstep, tc + kc };
-        check_better_fast(&top_mv, best_mv, mv_limits, src, src_stride, ref,
-                          ref_stride, var_params, mv_cost_params, besterr, sse1,
-                          distortion, &dummy);
-        break;
-      }
-    }
+    // Search in the direction opposite of the best quadrant
+    const MV rev_mv = { br - diag_step.row, bc };
+    check_better_fast(&rev_mv, best_mv, mv_limits, src, src_stride, ref,
+                      ref_stride, var_params, mv_cost_params, besterr, sse1,
+                      distortion, &dummy);
   } else if (tr != br && tc == bc) {
-    const int kr = br - tr;
-    const MV right_long_mv = { tr + 2 * kr, tc + hstep };
+    assert(diag_step.row == br - tr);
+    // Continue searching in the best direction
+    const MV right_long_mv = { br + diag_step.row, bc + hstep };
+    const MV left_long_mv = { br + diag_step.row, bc - hstep };
     check_better_fast(&right_long_mv, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
-    const MV left_long_mv = { tr + 2 * kr, tc - hstep };
     check_better_fast(&left_long_mv, best_mv, mv_limits, src, src_stride, ref,
                       ref_stride, var_params, mv_cost_params, besterr, sse1,
                       distortion, &dummy);
 
-    switch (whichdir) {
-      case 0:
-      case 2: {
-        const MV right_mv = { tr + kr, tc + hstep };
-        check_better_fast(&right_mv, best_mv, mv_limits, src, src_stride, ref,
-                          ref_stride, var_params, mv_cost_params, besterr, sse1,
-                          distortion, &dummy);
-        break;
-      }
-      case 1:
-      case 3: {
-        const MV left_mv = { tr + kr, tc - hstep };
-        check_better_fast(&left_mv, best_mv, mv_limits, src, src_stride, ref,
-                          ref_stride, var_params, mv_cost_params, besterr, sse1,
-                          distortion, &dummy);
-      }
-    }
+    // Search in the direction opposite of the best quadrant
+    const MV rev_mv = { br, bc - diag_step.col };
+    check_better_fast(&rev_mv, best_mv, mv_limits, src, src_stride, ref,
+                      ref_stride, var_params, mv_cost_params, besterr, sse1,
+                      distortion, &dummy);
   }
 }
 
@@ -2279,18 +2243,18 @@
 // searches the four cardinal directions, and perform several
 // diagonal/chess-pattern searches in the best quadrant.
 static AOM_FORCE_INLINE void two_level_checks_fast(
-    const MV *this_mv, MV *best_mv, int hstep, const SubpelMvLimits *mv_limits,
+    const MV this_mv, MV *best_mv, int hstep, const SubpelMvLimits *mv_limits,
     const uint8_t *const src, const int src_stride, const uint8_t *const ref,
     int ref_stride, const SUBPEL_SEARCH_VAR_PARAMS *var_params,
     const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
     unsigned int *sse1, int *distortion, int iters) {
-  unsigned int whichdir = first_level_check_fast(
+  const MV diag_step = first_level_check_fast(
       this_mv, best_mv, hstep, mv_limits, src, src_stride, ref, ref_stride,
       var_params, mv_cost_params, besterr, sse1, distortion);
   if (iters > 1) {
-    second_level_check_fast(this_mv, best_mv, hstep, mv_limits, src, src_stride,
-                            ref, ref_stride, var_params, mv_cost_params,
-                            besterr, sse1, distortion, whichdir);
+    second_level_check_fast(this_mv, diag_step, best_mv, hstep, mv_limits, src,
+                            src_stride, ref, ref_stride, var_params,
+                            mv_cost_params, besterr, sse1, distortion);
   }
 }
 
@@ -2494,7 +2458,7 @@
                         &besterr, sse1, distortion, &dummy);
     }
   } else {
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2504,7 +2468,7 @@
     if (forced_stop != HALF_PEL) {
       hstep >>= 1;
       start_mv = *bestmv;
-      two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+      two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                             src_stride, ref_address, ref_stride, var_params,
                             mv_cost_params, &besterr, sse1, distortion,
                             iters_per_step);
@@ -2514,7 +2478,7 @@
   if (allow_hp && forced_stop == EIGHTH_PEL) {
     hstep >>= 1;
     start_mv = *bestmv;
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2570,7 +2534,7 @@
                         &besterr, sse1, distortion, &dummy);
     }
   } else {
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2581,7 +2545,7 @@
   if (forced_stop != HALF_PEL) {
     hstep >>= 1;
     start_mv = *bestmv;
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2590,7 +2554,7 @@
   if (allow_hp && forced_stop == EIGHTH_PEL) {
     hstep >>= 1;
     start_mv = *bestmv;
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2696,7 +2660,7 @@
         break;
     }
   } else {
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2707,7 +2671,7 @@
   if (forced_stop != HALF_PEL) {
     hstep >>= 1;
     start_mv = *bestmv;
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);
@@ -2716,7 +2680,7 @@
   if (allow_hp && forced_stop == EIGHTH_PEL) {
     hstep >>= 1;
     start_mv = *bestmv;
-    two_level_checks_fast(&start_mv, bestmv, hstep, &mv_limits, src_address,
+    two_level_checks_fast(start_mv, bestmv, hstep, &mv_limits, src_address,
                           src_stride, ref_address, ref_stride, var_params,
                           mv_cost_params, &besterr, sse1, distortion,
                           iters_per_step);