ncobmc-adapt-weight: add interpolation mode search functions

Change-Id: I5370e38f6fe00f467e1945bc46866adea9422b22
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 96a496d..6220cfd 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8439,6 +8439,9 @@
 #if CONFIG_GLOBAL_MOTION
                                                          0, xd->global_motion,
 #endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_WARPED_MOTION
+                                                         xd,
+#endif
                                                          mi);
 #else
   last_motion_mode_allowed = motion_mode_allowed(
@@ -12587,4 +12590,261 @@
   }
 }
 #endif  // CONFIG_NCOBMC
+
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+                                      struct macroblock *x, int mi_row,
+                                      int mi_col) {
+  const AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  BLOCK_SIZE bsize = mbmi->sb_type;
+#if CONFIG_VAR_TX
+  const int n4 = bsize_to_num_blk(bsize);
+  uint8_t backup_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+#endif
+  MB_MODE_INFO backup_mbmi;
+  int plane, ref, skip_blk, backup_skip;
+  RD_STATS rd_stats_y, rd_stats_uv, rd_stats_y2, rd_stats_uv2;
+  int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+  int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+  int64_t prev_rd, naw_rd;  // ncobmc_adapt_weight_rd
+
+  // Recompute the rd for the motion mode decided in rd loop
+  if (mbmi->motion_mode == SIMPLE_TRANSLATION ||
+      mbmi->motion_mode == OBMC_CAUSAL) {
+    set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+    for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
+      YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
+      assert(cfg != NULL);
+      av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
+                           &xd->block_refs[ref]->sf);
+    }
+    av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
+                         mi_col);
+
+    av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
+    if (mbmi->motion_mode == OBMC_CAUSAL) {
+      av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
+    }
+
+    av1_subtract_plane(x, bsize, 0);
+
+#if CONFIG_VAR_TX
+    if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+      select_tx_type_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+    } else {
+      int idx, idy;
+      super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+      for (idy = 0; idy < xd->n8_h; ++idy)
+        for (idx = 0; idx < xd->n8_w; ++idx)
+          mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+      memset(x->blk_skip[0], rd_stats_y2.skip,
+             sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+    }
+    inter_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#else
+    super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
+    super_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#endif
+  }
+
+  if (rd_stats_y2.skip && rd_stats_uv2.skip) {
+    rd_stats_y2.rate = rate_skip1;
+    rd_stats_uv2.rate = 0;
+    rd_stats_y2.dist = rd_stats_y2.sse;
+    rd_stats_uv2.dist = rd_stats_uv2.sse;
+    skip_blk = 1;
+  } else if (RDCOST(x->rdmult,
+                    (rd_stats_y2.rate + rd_stats_uv2.rate + rate_skip0),
+                    (rd_stats_y2.dist + rd_stats_uv2.dist)) >
+             RDCOST(x->rdmult, rate_skip1,
+                    (rd_stats_y2.sse + rd_stats_uv2.sse))) {
+    rd_stats_y2.rate = rate_skip1;
+    rd_stats_uv2.rate = 0;
+    rd_stats_y2.dist = rd_stats_y2.sse;
+    rd_stats_uv2.dist = rd_stats_uv2.sse;
+    skip_blk = 1;
+  } else {
+    rd_stats_y2.rate += rate_skip0;
+    skip_blk = 0;
+  }
+
+  backup_mbmi = *mbmi;
+  backup_skip = skip_blk;
+#if CONFIG_VAR_TX
+  memcpy(backup_blk_skip, x->blk_skip[0], sizeof(backup_blk_skip[0]) * n4);
+#endif
+  prev_rd = RDCOST(x->rdmult, (rd_stats_y2.rate + rd_stats_uv2.rate),
+                   (rd_stats_y2.dist + rd_stats_uv2.dist));
+  prev_rd +=
+      RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  // Compute the rd cost for ncobmc adaptive weight
+  mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
+
+  for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
+  }
+
+  av1_subtract_plane(x, bsize, 0);
+
+#if CONFIG_VAR_TX
+  if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+    select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+  } else {
+    int idx, idy;
+    super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+    for (idy = 0; idy < xd->n8_h; ++idy)
+      for (idx = 0; idx < xd->n8_w; ++idx)
+        mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+    memset(x->blk_skip[0], rd_stats_y.skip,
+           sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+  }
+  inter_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
+#else
+  super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
+  super_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
+#endif
+  assert(rd_stats_y.rate != INT_MAX && rd_stats_uv.rate != INT_MAX);
+
+  if (rd_stats_y.skip && rd_stats_uv.skip) {
+    rd_stats_y.rate = rate_skip1;
+    rd_stats_uv.rate = 0;
+    rd_stats_y.dist = rd_stats_y.sse;
+    rd_stats_uv.dist = rd_stats_uv.sse;
+    skip_blk = 1;
+  } else if (RDCOST(x->rdmult,
+                    (rd_stats_y.rate + rd_stats_uv.rate + rate_skip0),
+                    (rd_stats_y.dist + rd_stats_uv.dist)) >
+             RDCOST(x->rdmult, rate_skip1,
+                    (rd_stats_y.sse + rd_stats_uv.sse))) {
+    rd_stats_y.rate = rate_skip1;
+    rd_stats_uv.rate = 0;
+    rd_stats_y.dist = rd_stats_y.sse;
+    rd_stats_uv.dist = rd_stats_uv.sse;
+    skip_blk = 1;
+  } else {
+    rd_stats_y.rate += rate_skip0;
+    skip_blk = 0;
+  }
+  naw_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
+                  (rd_stats_y.dist + rd_stats_uv.dist));
+  naw_rd += RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  // Calculate the ncobmc mode costs
+  {
+    ADAPT_OVERLAP_BLOCK aob = adapt_overlap_block_lookup[bsize];
+    naw_rd +=
+        RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[0]], 0);
+    if (mi_size_wide[bsize] != mi_size_high[bsize])
+      naw_rd +=
+          RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[1]], 0);
+  }
+
+  if (prev_rd > naw_rd) {
+    x->skip = skip_blk;
+  } else {
+    *mbmi = backup_mbmi;
+    x->skip = backup_skip;
+#if CONFIG_VAR_TX
+    memcpy(x->blk_skip[0], backup_blk_skip, sizeof(backup_blk_skip[0]) * n4);
+#endif
+  }
+}
+
+int64_t get_ncobmc_error(MACROBLOCKD *xd, int pxl_row, int pxl_col,
+                         BLOCK_SIZE bsize, int plane, struct buf_2d *src) {
+  const int wide = AOMMIN(mi_size_wide[bsize] * MI_SIZE,
+                          (xd->sb_mi_bd.mi_col_end + 1) * MI_SIZE - pxl_col);
+  const int high = AOMMIN(mi_size_high[bsize] * MI_SIZE,
+                          (xd->sb_mi_bd.mi_row_end + 1) * MI_SIZE - pxl_row);
+  const int ss_x = xd->plane[plane].subsampling_x;
+  const int ss_y = xd->plane[plane].subsampling_y;
+  int row_offset = (pxl_row - xd->sb_mi_bd.mi_row_begin * MI_SIZE) >> ss_y;
+  int col_offset = (pxl_col - xd->sb_mi_bd.mi_col_begin * MI_SIZE) >> ss_x;
+  int dst_stride = xd->ncobmc_pred_buf_stride[plane];
+  int dst_offset = row_offset * dst_stride + col_offset;
+  int src_stride = src->stride;
+
+  int r, c;
+  int64_t tmp, error = 0;
+
+  for (r = 0; r < (high >> ss_y); ++r) {
+    for (c = 0; c < (wide >> ss_x); ++c) {
+      tmp = xd->ncobmc_pred_buf[plane][r * dst_stride + c + dst_offset] -
+            src->buf[r * src_stride + c];
+      error += tmp * tmp;
+    }
+  }
+  return error;
+}
+
+int get_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                    MACROBLOCKD *xd, int mi_row, int mi_col, int bsize) {
+  const AV1_COMMON *const cm = &cpi->common;
+#if CONFIG_HIGHBITDEPTH
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_0[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_1[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_2[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_3[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
+#else
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_0[MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_1[MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_2[MAX_MB_PLANE * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, tmp_buf_3[MAX_MB_PLANE * MAX_SB_SQUARE]);
+#endif
+  uint8_t *pred_buf[4][MAX_MB_PLANE];
+
+  // TODO(weitinglin): stride size needs to be fixed for high-bit depth
+  int pred_stride[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
+
+  // target block in pxl
+  int pxl_row = mi_row << MI_SIZE_LOG2;
+  int pxl_col = mi_col << MI_SIZE_LOG2;
+  int64_t error, best_error = INT64_MAX;
+  int plane, tmp_mode, best_mode = 0;
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    int len = sizeof(uint16_t);
+    ASSIGN_ALIGNED_PTRS_HBD(pred_buf[0], tmp_buf_0, MAX_SB_SQUARE, len);
+    ASSIGN_ALIGNED_PTRS_HBD(pred_buf[1], tmp_buf_0, MAX_SB_SQUARE, len);
+    ASSIGN_ALIGNED_PTRS_HBD(pred_buf[2], tmp_buf_0, MAX_SB_SQUARE, len);
+    ASSIGN_ALIGNED_PTRS_HBD(pred_buf[3], tmp_buf_0, MAX_SB_SQUARE, len);
+  } else {
+#endif  // CONFIG_HIGHBITDEPTH
+    ASSIGN_ALIGNED_PTRS(pred_buf[0], tmp_buf_0, MAX_SB_SQUARE);
+    ASSIGN_ALIGNED_PTRS(pred_buf[1], tmp_buf_1, MAX_SB_SQUARE);
+    ASSIGN_ALIGNED_PTRS(pred_buf[2], tmp_buf_2, MAX_SB_SQUARE);
+    ASSIGN_ALIGNED_PTRS(pred_buf[3], tmp_buf_3, MAX_SB_SQUARE);
+#if CONFIG_HIGHBITDEPTH
+  }
+#endif
+
+  av1_get_ext_blk_preds(cm, xd, bsize, mi_row, mi_col, pred_buf, pred_stride);
+  av1_get_ori_blk_pred(cm, xd, bsize, mi_row, mi_col, pred_buf[3], pred_stride);
+
+  for (tmp_mode = 0; tmp_mode < MAX_NCOBMC_MODES; ++tmp_mode) {
+    error = 0;
+    for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+      build_ncobmc_intrpl_pred(cm, xd, plane, pxl_row, pxl_col, bsize, pred_buf,
+                               pred_stride, tmp_mode);
+      error += get_ncobmc_error(xd, pxl_row, pxl_col, bsize, plane,
+                                &x->plane[plane].src);
+    }
+    if (error < best_error) {
+      best_mode = tmp_mode;
+      best_error = error;
+    }
+  }
+
+  for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    build_ncobmc_intrpl_pred(cm, xd, plane, pxl_row, pxl_col, bsize, pred_buf,
+                             pred_stride, best_mode);
+  }
+
+  return best_mode;
+}
+
+#endif  // CONFIG_NCOBMC_ADAPT_WEIGHT
 #endif  // CONFIG_MOTION_VAR
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 4923952..bcad8f8 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -144,4 +144,20 @@
                      const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type);
 
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+                                      struct macroblock *x, int mi_row,
+                                      int mi_col);
+
+int get_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                    MACROBLOCKD *xd, int mi_row, int mi_col, int bsize);
+
+void av1_setup_src_planes_pxl(MACROBLOCK *x, const YV12_BUFFER_CONFIG *src,
+                              int pxl_row, int pxl_col);
+
+void rebuild_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                         MACROBLOCKD *xd, int mi_row, int mi_col, int bsize,
+                         int xd_mi_offset, NCOBMC_MODE best_mode, int rebuild);
+#endif
+
 #endif  // AV1_ENCODER_RDOPT_H_