Simplify av1_find_best_(obmc_)sub_pixel_tree

Change-Id: I599c875e23eee96d9eb202a19fd5f1da0ca68eeb
diff --git a/av1/common/mv.h b/av1/common/mv.h
index 8aa364e..dac13d7 100644
--- a/av1/common/mv.h
+++ b/av1/common/mv.h
@@ -24,6 +24,8 @@
 #define GET_MV_RAWPEL(x) (((x) + 3 + ((x) >= 0)) >> 3)
 #define GET_MV_SUBPEL(x) ((x)*8)
 
+#define CHECK_MV_EQUAL(x, y) (((int_mv)(x)).as_int == ((int_mv)(y)).as_int)
+
 // The motion vector in units of full pixel
 typedef struct fullpel_mv {
   int16_t row;
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index b6aa18d..3d587b0 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -13,6 +13,7 @@
 #include <math.h>
 #include <stdio.h>
 
+#include "av1/common/filter.h"
 #include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
 
@@ -2258,28 +2259,70 @@
   }
 }
 
-// A newer version of second level check that gives better quality.
-// TODO(chiyotsai@google.com): evaluate this on subpel_search_types different
-// from av1_find_best_sub_pixel_tree
-static AOM_FORCE_INLINE void second_level_check_v2(
-    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV *diag_mv, MV *best_mv,
-    int kr, int kc, const SubpelMvLimits *mv_limits, const uint8_t *const src,
+static AOM_FORCE_INLINE MV first_level_check(
+    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV this_mv, MV *best_mv,
+    const int hstep, const SubpelMvLimits *mv_limits, const uint8_t *const src,
     const int src_stride, const uint8_t *const ref, int ref_stride,
     const SUBPEL_SEARCH_VAR_PARAMS *var_params,
     const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
     unsigned int *sse1, int *distortion) {
-  const MV center_mv = *best_mv;
+  int dummy = 0;
+  const MV left_mv = { this_mv.row, this_mv.col - hstep };
+  const MV right_mv = { this_mv.row, this_mv.col + hstep };
+  const MV top_mv = { this_mv.row - hstep, this_mv.col };
+  const MV bottom_mv = { this_mv.row + hstep, this_mv.col };
 
-  assert(diag_mv->row == best_mv->row || diag_mv->col == best_mv->col);
-  if (best_mv->row == diag_mv->row && best_mv->col != diag_mv->col) {
-    kc = best_mv->col - diag_mv->col;
-  } else if (best_mv->row != diag_mv->row && best_mv->col == diag_mv->col) {
-    kr = best_mv->row - diag_mv->row;
+  const unsigned int left = check_better(
+      xd, cm, &left_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
+      var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+  const unsigned int right = check_better(
+      xd, cm, &right_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
+      var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+  const unsigned int up = check_better(
+      xd, cm, &top_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
+      var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+  const unsigned int down = check_better(
+      xd, cm, &bottom_mv, best_mv, mv_limits, src, src_stride, ref, ref_stride,
+      var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+
+  const MV diag_step = get_best_diag_step(hstep, left, right, up, down);
+  const MV diag_mv = { this_mv.row + diag_step.row,
+                       this_mv.col + diag_step.col };
+
+  // Check the diagonal direction with the best mv
+  check_better(xd, cm, &diag_mv, best_mv, mv_limits, src, src_stride, ref,
+               ref_stride, var_params, mv_cost_params, besterr, sse1,
+               distortion, &dummy);
+
+  return diag_step;
+}
+
+// A newer version of second level check that gives better quality.
+// TODO(chiyotsai@google.com): evaluate this on subpel_search_types different
+// from av1_find_best_sub_pixel_tree
+static AOM_FORCE_INLINE void second_level_check_v2(
+    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV this_mv, MV diag_step,
+    MV *best_mv, const SubpelMvLimits *mv_limits, const uint8_t *const src,
+    const int src_stride, const uint8_t *const ref, int ref_stride,
+    const SUBPEL_SEARCH_VAR_PARAMS *var_params,
+    const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
+    unsigned int *sse1, int *distortion) {
+  assert(best_mv->row == this_mv.row + diag_step.row ||
+         best_mv->col == this_mv.col + diag_step.col);
+  if (CHECK_MV_EQUAL(this_mv, *best_mv)) {
+    return;
+  } else if (this_mv.row == best_mv->row) {
+    // Search away from diagonal step since diagonal search did not provide any
+    // improvement
+    diag_step.row *= -1;
+  } else if (this_mv.col == best_mv->col) {
+    diag_step.col *= -1;
   }
 
-  const MV row_bias_mv = { center_mv.row + kr, center_mv.col };
-  const MV col_bias_mv = { center_mv.row, center_mv.col + kc };
-  const MV diag_bias_mv = { center_mv.row + kr, center_mv.col + kc };
+  const MV row_bias_mv = { best_mv->row + diag_step.row, best_mv->col };
+  const MV col_bias_mv = { best_mv->row, best_mv->col + diag_step.col };
+  const MV diag_bias_mv = { best_mv->row + diag_step.row,
+                            best_mv->col + diag_step.col };
   int has_better_mv = 0;
 
   if (var_params->subpel_search_type != USE_2_TAPS_ORIG) {
@@ -2292,10 +2335,9 @@
 
     // Do an additional search if the second iteration gives a better mv
     if (has_better_mv) {
-      int dummy = 0;
       check_better(xd, cm, &diag_bias_mv, best_mv, mv_limits, src, src_stride,
                    ref, ref_stride, var_params, mv_cost_params, besterr, sse1,
-                   distortion, &dummy);
+                   distortion, &has_better_mv);
     }
   } else {
     check_better_fast(&row_bias_mv, best_mv, mv_limits, src, src_stride, ref,
@@ -2307,10 +2349,9 @@
 
     // Do an additional search if the second iteration gives a better mv
     if (has_better_mv) {
-      int dummy = 0;
       check_better_fast(&diag_bias_mv, best_mv, mv_limits, src, src_stride, ref,
                         ref_stride, var_params, mv_cost_params, besterr, sse1,
-                        distortion, &dummy);
+                        distortion, &has_better_mv);
     }
   }
 }
@@ -2689,18 +2730,6 @@
   return besterr;
 }
 
-/* clang-format off */
-static const MV search_step_table[12] = {
-  // left, right, up, down
-  { 0, -INIT_SUBPEL_STEP_SIZE },        { 0, INIT_SUBPEL_STEP_SIZE },
-  { -INIT_SUBPEL_STEP_SIZE, 0 },        { INIT_SUBPEL_STEP_SIZE, 0 },
-  { 0, -(INIT_SUBPEL_STEP_SIZE >> 1) }, { 0, (INIT_SUBPEL_STEP_SIZE >> 1) },
-  { -(INIT_SUBPEL_STEP_SIZE >> 1), 0 }, { (INIT_SUBPEL_STEP_SIZE >> 1), 0 },
-  { 0, -(INIT_SUBPEL_STEP_SIZE >> 2) }, { 0, (INIT_SUBPEL_STEP_SIZE >> 2) },
-  { -(INIT_SUBPEL_STEP_SIZE >> 2), 0 }, { (INIT_SUBPEL_STEP_SIZE >> 2), 0 }
-};
-/* clang-format on */
-
 int av1_find_best_sub_pixel_tree(MACROBLOCK *x, const AV1_COMMON *const cm,
                                  const SUBPEL_MOTION_SEARCH_PARAMS *ms_params,
                                  int *distortion, unsigned int *sse1) {
@@ -2716,9 +2745,9 @@
 
   MACROBLOCKD *xd = &x->e_mbd;
   const uint8_t *const src_address = x->plane[0].src.buf;
+  const uint8_t *const ref_address = xd->plane[0].pre[0].buf;
   const int src_stride = x->plane[0].src.stride;
   const int ref_stride = xd->plane[0].pre[0].stride;
-  const uint8_t *const ref_address = xd->plane[0].pre[0].buf;
 
   convert_fullmv_to_mv(&x->best_mv);
   MV *bestmv = &x->best_mv.as_mv;
@@ -2728,8 +2757,6 @@
 
   int hstep = INIT_SUBPEL_STEP_SIZE;
   int iter, round = FULL_PEL - forced_stop;
-  const MV *search_step = search_step_table;
-  unsigned int cost_array[5];
   unsigned int besterr = INT_MAX;
 
   if (!allow_hp)
@@ -2749,77 +2776,38 @@
     av1_set_fractional_mv(x->fractional_best_mv);
   }
 
-  MV iter_center_mv = *bestmv;
   for (iter = 0; iter < round; ++iter) {
-    if (x->fractional_best_mv[iter].as_mv.row == iter_center_mv.row &&
-        x->fractional_best_mv[iter].as_mv.col == iter_center_mv.col)
+    MV iter_center_mv = *bestmv;
+    if (CHECK_MV_EQUAL(x->fractional_best_mv[iter], iter_center_mv)) {
       return INT_MAX;
+    }
 
     x->fractional_best_mv[iter].as_mv = iter_center_mv;
 
-    MV best_iter_mv = iter_center_mv;
-    int iter_best_idx = -1;
-
-    // Check vertical and horizontal sub-pixel positions.
-    for (int idx = 0; idx < 4; ++idx) {
-      const MV this_mv = { iter_center_mv.row + search_step[idx].row,
-                           iter_center_mv.col + search_step[idx].col };
-
-      int has_better_mv = 0;
-      if (subpel_search_type != USE_2_TAPS_ORIG) {
-        cost_array[idx] = check_better(
-            xd, cm, &this_mv, &best_iter_mv, &mv_limits, src_address,
-            src_stride, ref_address, ref_stride, var_params, mv_cost_params,
-            &besterr, sse1, distortion, &has_better_mv);
-      } else {
-        cost_array[idx] = check_better_fast(
-            &this_mv, &best_iter_mv, &mv_limits, src_address, src_stride,
-            ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
-            distortion, &has_better_mv);
-      }
-      if (has_better_mv) {
-        iter_best_idx = idx;
-      }
+    MV diag_step;
+    if (subpel_search_type != USE_2_TAPS_ORIG) {
+      diag_step = first_level_check(xd, cm, iter_center_mv, bestmv, hstep,
+                                    &mv_limits, src_address, src_stride,
+                                    ref_address, ref_stride, var_params,
+                                    mv_cost_params, &besterr, sse1, distortion);
+    } else {
+      diag_step = first_level_check_fast(
+          iter_center_mv, bestmv, hstep, &mv_limits, src_address, src_stride,
+          ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
+          distortion);
     }
 
     // Check diagonal sub-pixel position
-    const MV diag_step = { (cost_array[2] <= cost_array[3] ? -hstep : hstep),
-                           (cost_array[0] <= cost_array[1] ? -hstep : hstep) };
-    const MV diag_mv = { iter_center_mv.row + diag_step.row,
-                         iter_center_mv.col + diag_step.col };
-    int has_better_mv = 0;
-    if (subpel_search_type != USE_2_TAPS_ORIG) {
-      cost_array[4] = check_better(xd, cm, &diag_mv, &best_iter_mv, &mv_limits,
-                                   src_address, src_stride, ref_address,
-                                   ref_stride, var_params, mv_cost_params,
-                                   &besterr, sse1, distortion, &has_better_mv);
-    } else {
-      cost_array[4] = check_better_fast(
-          &diag_mv, &best_iter_mv, &mv_limits, src_address, src_stride,
-          ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
-          distortion, &has_better_mv);
-    }
-    if (has_better_mv) {
-      iter_best_idx = 4;
+    if (!CHECK_MV_EQUAL(iter_center_mv, *bestmv) && iters_per_step > 1) {
+      second_level_check_v2(xd, cm, iter_center_mv, diag_step, bestmv,
+                            &mv_limits, src_address, src_stride, ref_address,
+                            ref_stride, var_params, mv_cost_params, &besterr,
+                            sse1, distortion);
     }
 
-    if (iter_best_idx != -1) {
-      iter_center_mv = best_iter_mv;
-
-      if (iters_per_step > 1) {
-        second_level_check_v2(xd, cm, &diag_mv, &iter_center_mv, diag_step.row,
-                              diag_step.col, &mv_limits, src_address,
-                              src_stride, ref_address, ref_stride, var_params,
-                              mv_cost_params, &besterr, sse1, distortion);
-      }
-    }
-
-    search_step += 4;
     hstep >>= 1;
   }
 
-  *bestmv = iter_center_mv;
-
   return besterr;
 }
 
@@ -3132,25 +3120,96 @@
   return cost;
 }
 
+static AOM_FORCE_INLINE MV obmc_first_level_check(
+    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV this_mv, MV *best_mv,
+    const int hstep, const SubpelMvLimits *mv_limits, const int32_t *const src,
+    const int32_t *const mask, const uint8_t *const ref, int ref_stride,
+    const SUBPEL_SEARCH_VAR_PARAMS *var_params,
+    const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
+    unsigned int *sse1, int *distortion) {
+  int dummy = 0;
+  const MV left_mv = { this_mv.row, this_mv.col - hstep };
+  const MV right_mv = { this_mv.row, this_mv.col + hstep };
+  const MV top_mv = { this_mv.row - hstep, this_mv.col };
+  const MV bottom_mv = { this_mv.row + hstep, this_mv.col };
+
+  if (var_params->subpel_search_type != USE_2_TAPS_ORIG) {
+    const unsigned int left = obmc_check_better(
+        xd, cm, &left_mv, best_mv, mv_limits, src, mask, ref, ref_stride,
+        var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+    const unsigned int right = obmc_check_better(
+        xd, cm, &right_mv, best_mv, mv_limits, src, mask, ref, ref_stride,
+        var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+    const unsigned int up = obmc_check_better(
+        xd, cm, &top_mv, best_mv, mv_limits, src, mask, ref, ref_stride,
+        var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+    const unsigned int down = obmc_check_better(
+        xd, cm, &bottom_mv, best_mv, mv_limits, src, mask, ref, ref_stride,
+        var_params, mv_cost_params, besterr, sse1, distortion, &dummy);
+
+    const MV diag_step = get_best_diag_step(hstep, left, right, up, down);
+    const MV diag_mv = { this_mv.row + diag_step.row,
+                         this_mv.col + diag_step.col };
+
+    // Check the diagonal direction with the best mv
+    obmc_check_better(xd, cm, &diag_mv, best_mv, mv_limits, src, mask, ref,
+                      ref_stride, var_params, mv_cost_params, besterr, sse1,
+                      distortion, &dummy);
+
+    return diag_step;
+  } else {
+    const unsigned int left = obmc_check_better_fast(
+        &left_mv, best_mv, mv_limits, src, mask, ref, ref_stride, var_params,
+        mv_cost_params, besterr, sse1, distortion, &dummy);
+    const unsigned int right = obmc_check_better_fast(
+        &right_mv, best_mv, mv_limits, src, mask, ref, ref_stride, var_params,
+        mv_cost_params, besterr, sse1, distortion, &dummy);
+
+    const unsigned int up = obmc_check_better_fast(
+        &top_mv, best_mv, mv_limits, src, mask, ref, ref_stride, var_params,
+        mv_cost_params, besterr, sse1, distortion, &dummy);
+
+    const unsigned int down = obmc_check_better_fast(
+        &bottom_mv, best_mv, mv_limits, src, mask, ref, ref_stride, var_params,
+        mv_cost_params, besterr, sse1, distortion, &dummy);
+
+    const MV diag_step = get_best_diag_step(hstep, left, right, up, down);
+    const MV diag_mv = { this_mv.row + diag_step.row,
+                         this_mv.col + diag_step.col };
+
+    // Check the diagonal direction with the best mv
+    obmc_check_better_fast(&diag_mv, best_mv, mv_limits, src, mask, ref,
+                           ref_stride, var_params, mv_cost_params, besterr,
+                           sse1, distortion, &dummy);
+
+    return diag_step;
+  }
+}
+
 // A newer version of second level check for obmc that gives better quality.
 static AOM_FORCE_INLINE void obmc_second_level_check_v2(
-    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV *diag_mv, MV *best_mv,
-    int kr, int kc, const SubpelMvLimits *mv_limits, const int32_t *src,
+    MACROBLOCKD *xd, const AV1_COMMON *const cm, const MV this_mv, MV diag_step,
+    MV *best_mv, const SubpelMvLimits *mv_limits, const int32_t *src,
     const int32_t *mask, const uint8_t *const ref, int ref_stride,
     const SUBPEL_SEARCH_VAR_PARAMS *var_params,
     const MV_COST_PARAMS *mv_cost_params, unsigned int *besterr,
     unsigned int *sse1, int *distortion) {
-  assert(diag_mv->row == best_mv->row || diag_mv->col == best_mv->col);
-
-  const MV center_mv = *best_mv;
-  if (best_mv->row == diag_mv->row && best_mv->col != diag_mv->col) {
-    kc = best_mv->col - diag_mv->col;
-  } else if (best_mv->row != diag_mv->row && best_mv->col == diag_mv->col) {
-    kr = best_mv->row - diag_mv->row;
+  assert(best_mv->row == this_mv.row + diag_step.row ||
+         best_mv->col == this_mv.col + diag_step.col);
+  if (CHECK_MV_EQUAL(this_mv, *best_mv)) {
+    return;
+  } else if (this_mv.row == best_mv->row) {
+    // Search away from diagonal step since diagonal search did not provide any
+    // improvement
+    diag_step.row *= -1;
+  } else if (this_mv.col == best_mv->col) {
+    diag_step.col *= -1;
   }
 
-  const MV row_bias_mv = { center_mv.row + kr, center_mv.col };
-  const MV col_bias_mv = { center_mv.row, center_mv.col + kc };
+  const MV row_bias_mv = { best_mv->row + diag_step.row, best_mv->col };
+  const MV col_bias_mv = { best_mv->row, best_mv->col + diag_step.col };
+  const MV diag_bias_mv = { best_mv->row + diag_step.row,
+                            best_mv->col + diag_step.col };
   int has_better_mv = 0;
 
   if (var_params->subpel_search_type != USE_2_TAPS_ORIG) {
@@ -3163,11 +3222,9 @@
 
     // Do an additional search if the second iteration gives a better mv
     if (has_better_mv) {
-      int dummy = 0;
-      const MV diag_bias_mv = { center_mv.row + kr, center_mv.col + kc };
       obmc_check_better(xd, cm, &diag_bias_mv, best_mv, mv_limits, src, mask,
                         ref, ref_stride, var_params, mv_cost_params, besterr,
-                        sse1, distortion, &dummy);
+                        sse1, distortion, &has_better_mv);
     }
   } else {
     obmc_check_better_fast(&row_bias_mv, best_mv, mv_limits, src, mask, ref,
@@ -3179,11 +3236,9 @@
 
     // Do an additional search if the second iteration gives a better mv
     if (has_better_mv) {
-      int dummy = 0;
-      const MV diag_bias_mv = { center_mv.row + kr, center_mv.col + kc };
       obmc_check_better_fast(&diag_bias_mv, best_mv, mv_limits, src, mask, ref,
                              ref_stride, var_params, mv_cost_params, besterr,
-                             sse1, distortion, &dummy);
+                             sse1, distortion, &has_better_mv);
     }
   }
 }
@@ -3223,8 +3278,6 @@
   if (!allow_hp)
     if (round == 3) round = 2;
 
-  const MV *search_step = search_step_table;
-  unsigned int cost_array[5];
   unsigned int besterr = INT_MAX;
 
   if (subpel_search_type != USE_2_TAPS_ORIG)
@@ -3236,71 +3289,22 @@
                                       ref_stride, var_params, mv_cost_params,
                                       sse1, distortion);
 
-  MV iter_center_mv = *bestmv;
   for (int iter = 0; iter < round; ++iter) {
-    MV best_iter_mv = iter_center_mv;
-    int iter_best_idx = -1;
+    MV iter_center_mv = *bestmv;
+    MV diag_step = obmc_first_level_check(
+        xd, cm, iter_center_mv, bestmv, hstep, &mv_limits, src_address, mask,
+        ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
+        distortion);
 
-    // Check vertical and horizontal sub-pixel positions.
-    for (int idx = 0; idx < 4; ++idx) {
-      const MV this_mv = { iter_center_mv.row + search_step[idx].row,
-                           iter_center_mv.col + search_step[idx].col };
-
-      int has_better_mv = 0;
-      if (subpel_search_type != USE_2_TAPS_ORIG) {
-        cost_array[idx] = obmc_check_better(
-            xd, cm, &this_mv, &best_iter_mv, &mv_limits, src_address, mask,
-            ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
-            distortion, &has_better_mv);
-      } else {
-        cost_array[idx] = obmc_check_better_fast(
-            &this_mv, &best_iter_mv, &mv_limits, src_address, mask, ref_address,
-            ref_stride, var_params, mv_cost_params, &besterr, sse1, distortion,
-            &has_better_mv);
-      }
-      if (has_better_mv) {
-        iter_best_idx = idx;
-      }
+    if (!CHECK_MV_EQUAL(iter_center_mv, *bestmv) && iters_per_step > 1) {
+      obmc_second_level_check_v2(xd, cm, iter_center_mv, diag_step, bestmv,
+                                 &mv_limits, src_address, mask, ref_address,
+                                 ref_stride, var_params, mv_cost_params,
+                                 &besterr, sse1, distortion);
     }
-
-    // Check diagonal sub-pixel position
-    const MV diag_step = { (cost_array[2] <= cost_array[3] ? -hstep : hstep),
-                           (cost_array[0] <= cost_array[1] ? -hstep : hstep) };
-    const MV diag_mv = { iter_center_mv.row + diag_step.row,
-                         iter_center_mv.col + diag_step.col };
-    int has_better_mv = 0;
-    if (subpel_search_type != USE_2_TAPS_ORIG) {
-      cost_array[4] = obmc_check_better(
-          xd, cm, &diag_mv, &best_iter_mv, &mv_limits, src_address, mask,
-          ref_address, ref_stride, var_params, mv_cost_params, &besterr, sse1,
-          distortion, &has_better_mv);
-    } else {
-      cost_array[4] = obmc_check_better_fast(
-          &diag_mv, &best_iter_mv, &mv_limits, src_address, mask, ref_address,
-          ref_stride, var_params, mv_cost_params, &besterr, sse1, distortion,
-          &has_better_mv);
-    }
-    if (has_better_mv) {
-      iter_best_idx = 4;
-    }
-
-    if (iter_best_idx != -1) {
-      iter_center_mv = best_iter_mv;
-
-      if (iters_per_step > 1) {
-        obmc_second_level_check_v2(
-            xd, cm, &diag_mv, &iter_center_mv, diag_step.row, diag_step.col,
-            &mv_limits, src_address, mask, ref_address, ref_stride, var_params,
-            mv_cost_params, &besterr, sse1, distortion);
-      }
-    }
-
-    search_step += 4;
     hstep >>= 1;
   }
 
-  *bestmv = iter_center_mv;
-
   return besterr;
 }