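ncobmc-adapt-weight: refactor the mode selection function

Factor the prediction and RD-cost computation out of
av1_check_ncobmc_adapt_weight_rd() into get_prediction_rd_cost(), and
compare NCOBMC_ADAPT_WEIGHT against both SIMPLE_TRANSLATION and
OBMC_CAUSAL rather than only against the motion mode chosen by the RD
loop. With CONFIG_NCOBMC, the OBMC_CAUSAL candidate is now built with
av1_build_ncobmc_inter_predictors_sb().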

Change-Id: I7393596d98f11aa53ba4b9e329386b5168b3e086
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 313de49..d1dde09 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -12591,102 +12591,65 @@
 }
 #endif  // CONFIG_NCOBMC
 
 #if CONFIG_NCOBMC_ADAPT_WEIGHT
-void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
-                                      struct macroblock *x, int mi_row,
-                                      int mi_col) {
+// Builds the inter prediction for the motion mode currently stored in
+// mbmi->motion_mode (SIMPLE_TRANSLATION, OBMC_CAUSAL or NCOBMC_ADAPT_WEIGHT)
+// and returns its RD cost, including the cost of signaling the motion mode.
+// *skip_blk is set to 1 when coding the block as skip is cheaper, else 0.
+// If backup_mbmi is non-NULL it receives a copy of the mode info as it stands
+// after the RD search, so the caller can restore this candidate later.
+static int64_t get_prediction_rd_cost(const struct AV1_COMP *cpi,
+                                      struct macroblock *x, int mi_row,
+                                      int mi_col, int *skip_blk,
+                                      MB_MODE_INFO *backup_mbmi) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   BLOCK_SIZE bsize = mbmi->sb_type;
-#if CONFIG_VAR_TX
-  const int n4 = bsize_to_num_blk(bsize);
-  uint8_t backup_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
-#endif
-  MB_MODE_INFO backup_mbmi;
-  int plane, ref, skip_blk, backup_skip;
-  RD_STATS rd_stats_y, rd_stats_uv, rd_stats_y2, rd_stats_uv2;
+  RD_STATS rd_stats_y, rd_stats_uv;
   int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
   int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
-  int64_t prev_rd, naw_rd;  // ncobmc_adapt_weight_rd
+  int64_t this_rd;
+  int ref;
 
-  // Recompute the rd for the motion mode decided in rd loop
-  if (mbmi->motion_mode == SIMPLE_TRANSLATION ||
-      mbmi->motion_mode == OBMC_CAUSAL) {
-    set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
-    for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
-      YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
-      assert(cfg != NULL);
-      av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
-                           &xd->block_refs[ref]->sf);
-    }
-    av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
-                         mi_col);
+#if CONFIG_CB4X4
+  x->skip_chroma_rd =
+      !is_chroma_reference(mi_row, mi_col, bsize, xd->plane[1].subsampling_x,
+                           xd->plane[1].subsampling_y);
+#endif
 
+  set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+  for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
+    YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
+    assert(cfg != NULL);
+    av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
+                         &xd->block_refs[ref]->sf);
+  }
+  av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
+                       mi_col);
+
+  if (mbmi->motion_mode != NCOBMC_ADAPT_WEIGHT)
     av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
-    if (mbmi->motion_mode == OBMC_CAUSAL) {
-      av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
-    }
 
-    av1_subtract_plane(x, bsize, 0);
-
-#if CONFIG_VAR_TX
-    if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
-      select_tx_type_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-    } else {
-      int idx, idy;
-      super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-      for (idy = 0; idy < xd->n8_h; ++idy)
-        for (idx = 0; idx < xd->n8_w; ++idx)
-          mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
-      memset(x->blk_skip[0], rd_stats_y2.skip,
-             sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
-    }
-    inter_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#if CONFIG_MOTION_VAR
+  if (mbmi->motion_mode == OBMC_CAUSAL) {
+#if CONFIG_NCOBMC
+    av1_build_ncobmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
 #else
-    super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-    super_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+    av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
 #endif
   }
+#endif  // CONFIG_MOTION_VAR
 
-  if (rd_stats_y2.skip && rd_stats_uv2.skip) {
-    rd_stats_y2.rate = rate_skip1;
-    rd_stats_uv2.rate = 0;
-    rd_stats_y2.dist = rd_stats_y2.sse;
-    rd_stats_uv2.dist = rd_stats_uv2.sse;
-    skip_blk = 1;
-  } else if (RDCOST(x->rdmult,
-                    (rd_stats_y2.rate + rd_stats_uv2.rate + rate_skip0),
-                    (rd_stats_y2.dist + rd_stats_uv2.dist)) >
-             RDCOST(x->rdmult, rate_skip1,
-                    (rd_stats_y2.sse + rd_stats_uv2.sse))) {
-    rd_stats_y2.rate = rate_skip1;
-    rd_stats_uv2.rate = 0;
-    rd_stats_y2.dist = rd_stats_y2.sse;
-    rd_stats_uv2.dist = rd_stats_uv2.sse;
-    skip_blk = 1;
-  } else {
-    rd_stats_y2.rate += rate_skip0;
-    skip_blk = 0;
-  }
-
-  backup_mbmi = *mbmi;
-  backup_skip = skip_blk;
-#if CONFIG_VAR_TX
-  memcpy(backup_blk_skip, x->blk_skip[0], sizeof(backup_blk_skip[0]) * n4);
-#endif
-  prev_rd = RDCOST(x->rdmult, (rd_stats_y2.rate + rd_stats_uv2.rate),
-                   (rd_stats_y2.dist + rd_stats_uv2.dist));
-  prev_rd +=
-      RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
-
-  // Compute the rd cost for ncobmc adaptive weight
-  mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
-
-  for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
-  }
+  // NCOBMC_ADAPT_WEIGHT skipped the regular predictor build above; its
+  // prediction is copied from the interpolation buffer instead, which is
+  // assumed to have been filled by the caller (TODO confirm).
+  if (mbmi->motion_mode == NCOBMC_ADAPT_WEIGHT) {
+    for (int plane = 0; plane < MAX_MB_PLANE; ++plane)
+      get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
+  }
 
   av1_subtract_plane(x, bsize, 0);
 
 #if CONFIG_VAR_TX
@@ -12713,7 +12676,7 @@
     rd_stats_uv.rate = 0;
     rd_stats_y.dist = rd_stats_y.sse;
     rd_stats_uv.dist = rd_stats_uv.sse;
-    skip_blk = 1;
+    *skip_blk = 1;
   } else if (RDCOST(x->rdmult,
                     (rd_stats_y.rate + rd_stats_uv.rate + rate_skip0),
                     (rd_stats_y.dist + rd_stats_uv.dist)) >
@@ -12723,32 +12686,87 @@
     rd_stats_uv.rate = 0;
     rd_stats_y.dist = rd_stats_y.sse;
     rd_stats_uv.dist = rd_stats_uv.sse;
-    skip_blk = 1;
+    *skip_blk = 1;
   } else {
     rd_stats_y.rate += rate_skip0;
-    skip_blk = 0;
+    *skip_blk = 0;
   }
-  naw_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
-                  (rd_stats_y.dist + rd_stats_uv.dist));
-  naw_rd += RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  if (backup_mbmi) *backup_mbmi = *mbmi;
+
+  this_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
+                   (rd_stats_y.dist + rd_stats_uv.dist));
+  this_rd +=
+      RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  return this_rd;
+}
+
+// Chooses between SIMPLE_TRANSLATION, OBMC_CAUSAL and NCOBMC_ADAPT_WEIGHT for
+// the current block by recomputing the RD cost of each. NCOBMC_ADAPT_WEIGHT is
+// additionally charged for signaling its ncobmc_mode (two modes for
+// non-square blocks) and is kept only if strictly cheaper than both others;
+// otherwise the cheaper of OBMC_CAUSAL and SIMPLE_TRANSLATION (ties go to
+// SIMPLE_TRANSLATION) is restored into mbmi, x->skip and, with CONFIG_VAR_TX,
+// x->blk_skip. NOTE(review): no prediction is rebuilt for the restored mode;
+// the buffers still hold whatever the NCOBMC_ADAPT_WEIGHT evaluation produced.
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+                                      struct macroblock *x, int mi_row,
+                                      int mi_col) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  BLOCK_SIZE bsize = mbmi->sb_type;
+#if CONFIG_VAR_TX
+  const int n4 = bsize_to_num_blk(bsize);
+  uint8_t st_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+  uint8_t obmc_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+#endif
+  MB_MODE_INFO st_mbmi, obmc_mbmi;
+  int skip_blk, st_skip, obmc_skip;
+  int64_t st_rd, obmc_rd, ncobmc_rd;
+
+  // Evaluate SIMPLE_TRANSLATION and OBMC_CAUSAL regardless of the mode the RD
+  // loop picked, snapshotting each candidate's mode info, skip flag and (with
+  // CONFIG_VAR_TX) per-block skip map so the winner can be restored below.
+  mbmi->motion_mode = SIMPLE_TRANSLATION;
+  st_rd = get_prediction_rd_cost(cpi, x, mi_row, mi_col, &st_skip, &st_mbmi);
+#if CONFIG_VAR_TX
+  memcpy(st_blk_skip, x->blk_skip[0], sizeof(st_blk_skip[0]) * n4);
+#endif
+
+  mbmi->motion_mode = OBMC_CAUSAL;
+  obmc_rd =
+      get_prediction_rd_cost(cpi, x, mi_row, mi_col, &obmc_skip, &obmc_mbmi);
+#if CONFIG_VAR_TX
+  memcpy(obmc_blk_skip, x->blk_skip[0], sizeof(obmc_blk_skip[0]) * n4);
+#endif
+  // Compute the rd cost for ncobmc adaptive weight
+  mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
+  ncobmc_rd = get_prediction_rd_cost(cpi, x, mi_row, mi_col, &skip_blk, NULL);
 
   // Calculate the ncobmc mode costs
   {
     ADAPT_OVERLAP_BLOCK aob = adapt_overlap_block_lookup[bsize];
-    naw_rd +=
+    ncobmc_rd +=
         RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[0]], 0);
     if (mi_size_wide[bsize] != mi_size_high[bsize])
-      naw_rd +=
+      ncobmc_rd +=
           RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[1]], 0);
   }
 
-  if (prev_rd > naw_rd) {
+  if (ncobmc_rd < AOMMIN(st_rd, obmc_rd)) {
     x->skip = skip_blk;
-  } else {
-    *mbmi = backup_mbmi;
-    x->skip = backup_skip;
+  } else if (obmc_rd < st_rd) {
+    *mbmi = obmc_mbmi;
+    x->skip = obmc_skip;
 #if CONFIG_VAR_TX
-    memcpy(x->blk_skip[0], backup_blk_skip, sizeof(backup_blk_skip[0]) * n4);
+    memcpy(x->blk_skip[0], obmc_blk_skip, sizeof(obmc_blk_skip[0]) * n4);
+#endif
+  } else {
+    *mbmi = st_mbmi;
+    x->skip = st_skip;
+#if CONFIG_VAR_TX
+    memcpy(x->blk_skip[0], st_blk_skip, sizeof(st_blk_skip[0]) * n4);
 #endif
   }
 }