Make sub-pel masked motion search work with upsampled refs

Add vp10_find_best_masked_sub_pixel_tree_up(), a masked sub-pel motion
search that can evaluate candidate positions directly in the upsampled
reference buffers, and call it from the masked motion search in rdopt.c,
passing cpi->sf.use_upsampled_references to select the reference path.

Change-Id: Id483354e73e983793370b55a1a6a1f2dcd137dc9
diff --git a/vp10/encoder/mcomp.c b/vp10/encoder/mcomp.c
index 0c8ec43..3b935b0 100644
--- a/vp10/encoder/mcomp.c
+++ b/vp10/encoder/mcomp.c
@@ -24,6 +24,7 @@
 
 #include "vp10/encoder/encoder.h"
 #include "vp10/encoder/mcomp.h"
+#include "vp10/encoder/rdopt.h"
 
 // #define NEW_DIAMOND_SEARCH
 
@@ -2655,6 +2656,29 @@
     v = INT_MAX;                                                       \
   }
 
+#undef CHECK_BETTER0
+#define CHECK_BETTER0(v, r, c) CHECK_BETTER(v, r, c)
+
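+// CHECK_BETTER1: as CHECK_BETTER, but measures the candidate's error against
+// the upsampled reference via upsampled_masked_pref_error(). It is selected
+// by SECOND_LEVEL_CHECKS_BEST(1) in the search loop below.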
+#undef CHECK_BETTER1
+#define CHECK_BETTER1(v, r, c) \
+  if (c >= minc && c <= maxc && r >= minr && r <= maxr) {              \
+    thismse = upsampled_masked_pref_error(xd,                          \
+                                          mask, mask_stride,           \
+                                          vfp, z, src_stride,          \
+                                          upre(y, y_stride, r, c),     \
+                                          y_stride,                    \
+                                          w, h, &sse);                 \
+    if ((v = MVC(r, c) + thismse) < besterr) {                         \
+      besterr = v;                                                     \
+      br = r;                                                          \
+      bc = c;                                                          \
+      *distortion = thismse;                                           \
+      *sse1 = sse;                                                     \
+    }                                                                  \
+  } else {                                                             \
+    v = INT_MAX;                                                       \
+  }
+
 int vp10_find_best_masked_sub_pixel_tree(const MACROBLOCK *x,
                                          const uint8_t *mask, int mask_stride,
                                          MV *bestmv, const MV *ref_mv,
@@ -2671,8 +2695,8 @@
   const MACROBLOCKD *xd = &x->e_mbd;
   unsigned int besterr = INT_MAX;
   unsigned int sse;
-  unsigned int whichdir;
   int thismse;
+  unsigned int whichdir;
   unsigned int halfiters = iters_per_step;
   unsigned int quarteriters = iters_per_step;
   unsigned int eighthiters = iters_per_step;
@@ -2747,6 +2771,276 @@
   return besterr;
 }
 
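+// Masked error at the starting full-pel position plus the MV rate cost; used
+// to seed besterr before sub-pel refinement when upsampled refs are not used.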
+static unsigned int setup_masked_center_error(const uint8_t *mask,
+                                              int mask_stride,
+                                              const MV *bestmv,
+                                              const MV *ref_mv,
+                                              int error_per_bit,
+                                              const vp10_variance_fn_ptr_t *vfp,
+                                              const uint8_t *const src,
+                                              const int src_stride,
+                                              const uint8_t *const y,
+                                              int y_stride,
+                                              int offset,
+                                              int *mvjcost, int *mvcost[2],
+                                              unsigned int *sse1,
+                                              int *distortion) {
+  unsigned int besterr;
+  besterr = vfp->mvf(y + offset, y_stride, src, src_stride,
+                     mask, mask_stride, sse1);
+  *distortion = besterr;
+  besterr += mv_err_cost(bestmv, ref_mv, mvjcost, mvcost, error_per_bit);
+  return besterr;
+}
+
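+// Masked prediction error at a candidate position in the upsampled reference:
+// extract a w x h block with vpx_upsampled_pred() (or its high-bitdepth
+// counterpart) and run the masked variance function against the source block.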
+static int upsampled_masked_pref_error(const MACROBLOCKD *xd,
+                                       const uint8_t *mask,
+                                       int mask_stride,
+                                       const vp10_variance_fn_ptr_t *vfp,
+                                       const uint8_t *const src,
+                                       const int src_stride,
+                                       const uint8_t *const y, int y_stride,
+                                       int w, int h, unsigned int *sse) {
+  unsigned int besterr;
+#if CONFIG_VP9_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
+    vpx_highbd_upsampled_pred(pred16, w, h, y, y_stride);
+
+    besterr = vfp->mvf(CONVERT_TO_BYTEPTR(pred16), w, src, src_stride,
+                       mask, mask_stride, sse);
+  } else {
+    DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
+#else
+    DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
+    (void) xd;
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+    vpx_upsampled_pred(pred, w, h, y, y_stride);
+
+    besterr = vfp->mvf(pred, w, src, src_stride,
+                       mask, mask_stride, sse);
+#if CONFIG_VP9_HIGHBITDEPTH
+  }
+#endif
+  return besterr;
+}
+
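+// As setup_masked_center_error(), but evaluates the starting position in the
+// upsampled reference.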
+static unsigned int upsampled_setup_masked_center_error(
+    const MACROBLOCKD *xd,
+    const uint8_t *mask, int mask_stride,
+    const MV *bestmv, const MV *ref_mv,
+    int error_per_bit, const vp10_variance_fn_ptr_t *vfp,
+    const uint8_t *const src, const int src_stride,
+    const uint8_t *const y, int y_stride,
+    int w, int h, int offset, int *mvjcost, int *mvcost[2],
+    unsigned int *sse1, int *distortion) {
+  unsigned int besterr = upsampled_masked_pref_error(
+      xd, mask, mask_stride, vfp, src, src_stride,
+      y + offset, y_stride, w, h, sse1);
+  *distortion = besterr;
+  besterr += mv_err_cost(bestmv, ref_mv, mvjcost, mvcost, error_per_bit);
+  return besterr;
+}
+
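+// Masked sub-pel tree search. When use_upsampled_ref is set, the pre buffer
+// is temporarily pointed at the upsampled reference and candidates are
+// evaluated with upsampled_masked_pref_error() instead of the masked
+// sub-pel variance functions.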
+int vp10_find_best_masked_sub_pixel_tree_up(VP10_COMP *cpi,
+                                            MACROBLOCK *x,
+                                            const uint8_t *mask,
+                                            int mask_stride,
+                                            int mi_row, int mi_col,
+                                            MV *bestmv, const MV *ref_mv,
+                                            int allow_hp,
+                                            int error_per_bit,
+                                            const vp10_variance_fn_ptr_t *vfp,
+                                            int forced_stop,
+                                            int iters_per_step,
+                                            int *mvjcost, int *mvcost[2],
+                                            int *distortion,
+                                            unsigned int *sse1,
+                                            int is_second,
+                                            int use_upsampled_ref) {
+  const uint8_t *const z = x->plane[0].src.buf;
+  const uint8_t *const src_address = z;
+  const int src_stride = x->plane[0].src.stride;
+  MACROBLOCKD *xd = &x->e_mbd;
+  struct macroblockd_plane *const pd = &xd->plane[0];
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  unsigned int besterr = INT_MAX;
+  unsigned int sse;
+  unsigned int thismse;
+
+  int rr = ref_mv->row;
+  int rc = ref_mv->col;
+  int br = bestmv->row * 8;
+  int bc = bestmv->col * 8;
+  int hstep = 4;
+  int iter;
+  int round = 3 - forced_stop;
+  const int minc = VPXMAX(x->mv_col_min * 8, ref_mv->col - MV_MAX);
+  const int maxc = VPXMIN(x->mv_col_max * 8, ref_mv->col + MV_MAX);
+  const int minr = VPXMAX(x->mv_row_min * 8, ref_mv->row - MV_MAX);
+  const int maxr = VPXMIN(x->mv_row_max * 8, ref_mv->row + MV_MAX);
+  int tr = br;
+  int tc = bc;
+  const MV *search_step = search_step_table;
+  int idx, best_idx = -1;
+  unsigned int cost_array[5];
+  int kr, kc;
+  const int w = 4 * num_4x4_blocks_wide_lookup[mbmi->sb_type];
+  const int h = 4 * num_4x4_blocks_high_lookup[mbmi->sb_type];
+  int offset;
+  int y_stride;
+  const uint8_t *y;
+
+  const struct buf_2d backup_pred = pd->pre[is_second];
+  if (use_upsampled_ref) {
+    int ref = xd->mi[0]->mbmi.ref_frame[is_second];
+    const YV12_BUFFER_CONFIG *upsampled_ref = get_upsampled_ref(cpi, ref);
+    setup_pred_plane(&pd->pre[is_second], upsampled_ref->y_buffer,
+                     upsampled_ref->y_stride, (mi_row << 3), (mi_col << 3),
+                     NULL, pd->subsampling_x, pd->subsampling_y);
+  }
+  y = pd->pre[is_second].buf;
+  y_stride = pd->pre[is_second].stride;
+  offset = bestmv->row * y_stride + bestmv->col;
+
+  if (!(allow_hp && vp10_use_mv_hp(ref_mv)))
+    if (round == 3)
+      round = 2;
+
+  bestmv->row *= 8;
+  bestmv->col *= 8;
+
+  // In the upsampled reference a full-pel step spans 8 samples, so the
+  // full-pel offset computed above is scaled by 8 before use.
+  if (use_upsampled_ref)
+    besterr = upsampled_setup_masked_center_error(
+        xd, mask, mask_stride, bestmv, ref_mv, error_per_bit,
+        vfp, z, src_stride, y, y_stride,
+        w, h, (offset << 3),
+        mvjcost, mvcost, sse1, distortion);
+  else
+    besterr = setup_masked_center_error(
+        mask, mask_stride, bestmv, ref_mv, error_per_bit,
+        vfp, z, src_stride, y, y_stride,
+        offset, mvjcost, mvcost, sse1, distortion);
+
+  for (iter = 0; iter < round; ++iter) {
+    // Check vertical and horizontal sub-pixel positions.
+    for (idx = 0; idx < 4; ++idx) {
+      tr = br + search_step[idx].row;
+      tc = bc + search_step[idx].col;
+      if (tc >= minc && tc <= maxc && tr >= minr && tr <= maxr) {
+        MV this_mv = {tr, tc};
+
+        if (use_upsampled_ref) {
+          const uint8_t *const pre_address = y + tr * y_stride + tc;
+
+          thismse = upsampled_masked_pref_error(xd,
+                                                mask, mask_stride,
+                                                vfp, src_address, src_stride,
+                                                pre_address, y_stride,
+                                                w, h, &sse);
+        } else {
+          const uint8_t *const pre_address = y + (tr >> 3) * y_stride +
+              (tc >> 3);
+          thismse = vfp->msvf(pre_address, y_stride, sp(tc), sp(tr),
+                              src_address, src_stride,
+                              mask, mask_stride, &sse);
+        }
+
+        cost_array[idx] = thismse +
+            mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost, error_per_bit);
+
+        if (cost_array[idx] < besterr) {
+          best_idx = idx;
+          besterr = cost_array[idx];
+          *distortion = thismse;
+          *sse1 = sse;
+        }
+      } else {
+        cost_array[idx] = INT_MAX;
+      }
+    }
+
+    // Check diagonal sub-pixel position
+    kc = (cost_array[0] <= cost_array[1] ? -hstep : hstep);
+    kr = (cost_array[2] <= cost_array[3] ? -hstep : hstep);
+
+    tc = bc + kc;
+    tr = br + kr;
+    if (tc >= minc && tc <= maxc && tr >= minr && tr <= maxr) {
+      MV this_mv = {tr, tc};
+
+      if (use_upsampled_ref) {
+        const uint8_t *const pre_address = y + tr * y_stride + tc;
+
+        thismse = upsampled_masked_pref_error(xd,
+                                              mask, mask_stride,
+                                              vfp, src_address, src_stride,
+                                              pre_address, y_stride,
+                                              w, h, &sse);
+      } else {
+        const uint8_t *const pre_address = y + (tr >> 3) * y_stride + (tc >> 3);
+
+        thismse = vfp->msvf(pre_address, y_stride, sp(tc), sp(tr),
+                            src_address, src_stride, mask, mask_stride, &sse);
+      }
+
+      cost_array[4] = thismse +
+          mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost, error_per_bit);
+
+      if (cost_array[4] < besterr) {
+        best_idx = 4;
+        besterr = cost_array[4];
+        *distortion = thismse;
+        *sse1 = sse;
+      }
+    } else {
+      cost_array[4] = INT_MAX;
+    }
+
+    if (best_idx < 4 && best_idx >= 0) {
+      br += search_step[best_idx].row;
+      bc += search_step[best_idx].col;
+    } else if (best_idx == 4) {
+      br = tr;
+      bc = tc;
+    }
+
+    if (iters_per_step > 1 && best_idx != -1) {
+      if (use_upsampled_ref) {
+        SECOND_LEVEL_CHECKS_BEST(1);
+      } else {
+        SECOND_LEVEL_CHECKS_BEST(0);
+      }
+    }
+
+    tr = br;
+    tc = bc;
+
+    search_step += 4;
+    hstep >>= 1;
+    best_idx = -1;
+  }
+
+  // These lines ensure static analysis doesn't warn that tr and tc are
+  // unused after this point.
+  (void) tr;
+  (void) tc;
+
+  bestmv->row = br;
+  bestmv->col = bc;
+
+  if (use_upsampled_ref) {
+    pd->pre[is_second] = backup_pred;
+  }
+
+  if ((abs(bestmv->col - ref_mv->col) > (MAX_FULL_PEL_VAL << 3)) ||
+      (abs(bestmv->row - ref_mv->row) > (MAX_FULL_PEL_VAL << 3)))
+    return INT_MAX;
+
+  return besterr;
+}
+
 #undef DIST
 #undef MVC
 #undef CHECK_BETTER
diff --git a/vp10/encoder/mcomp.h b/vp10/encoder/mcomp.h
index f99cd8b..c12e7af 100644
--- a/vp10/encoder/mcomp.h
+++ b/vp10/encoder/mcomp.h
@@ -169,7 +169,24 @@
                                          int iters_per_step,
                                          int *mvjcost, int *mvcost[2],
                                          int *distortion,
-                                         unsigned int *sse1, int is_second);
+                                         unsigned int *sse1,
+                                         int is_second);
+int vp10_find_best_masked_sub_pixel_tree_up(struct VP10_COMP *cpi,
+                                            MACROBLOCK *x,
+                                            const uint8_t *mask,
+                                            int mask_stride,
+                                            int mi_row, int mi_col,
+                                            MV *bestmv, const MV *ref_mv,
+                                            int allow_hp,
+                                            int error_per_bit,
+                                            const vp10_variance_fn_ptr_t *vfp,
+                                            int forced_stop,
+                                            int iters_per_step,
+                                            int *mvjcost, int *mvcost[2],
+                                            int *distortion,
+                                            unsigned int *sse1,
+                                            int is_second,
+                                            int use_upsampled_ref);
 int vp10_masked_full_pixel_diamond(const struct VP10_COMP *cpi, MACROBLOCK *x,
                                    const uint8_t *mask, int mask_stride,
                                    MV *mvp_full, int step_param,
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index 918ad3e..254529c 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -5978,15 +5978,18 @@
 
   if (bestsme < INT_MAX) {
     int dis;  /* TODO: use dis in distortion calculation later. */
-    vp10_find_best_masked_sub_pixel_tree(x, mask, mask_stride,
-                                         &tmp_mv->as_mv, &ref_mv,
-                                         cm->allow_high_precision_mv,
-                                         x->errorperbit,
-                                         &cpi->fn_ptr[bsize],
-                                         cpi->sf.mv.subpel_force_stop,
-                                         cpi->sf.mv.subpel_iters_per_step,
-                                         x->nmvjointcost, x->mvcost,
-                                         &dis, &x->pred_sse[ref], ref_idx);
+    vp10_find_best_masked_sub_pixel_tree_up(cpi, x, mask, mask_stride,
+                                            mi_row, mi_col,
+                                            &tmp_mv->as_mv, &ref_mv,
+                                            cm->allow_high_precision_mv,
+                                            x->errorperbit,
+                                            &cpi->fn_ptr[bsize],
+                                            cpi->sf.mv.subpel_force_stop,
+                                            cpi->sf.mv.subpel_iters_per_step,
+                                            x->nmvjointcost, x->mvcost,
+                                            &dis, &x->pred_sse[ref],
+                                            ref_idx,
+                                            cpi->sf.use_upsampled_references);
   }
   *rate_mv = vp10_mv_bit_cost(&tmp_mv->as_mv, &ref_mv,
                               x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);