ncobmc-adapt-weight: refactoring the mode selection function

Change-Id: I7393596d98f11aa53ba4b9e329386b5168b3e086
diff --git a/av1/common/enums.h b/av1/common/enums.h
index a55d3f7..6241b0e 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -24,7 +24,12 @@
 
 #if CONFIG_NCOBMC_ADAPT_WEIGHT
 #define TWO_MODE
-// #define FOUR_MODE
+#endif
+
+#if CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT
+#define NC_MODE_INFO 1
+#else
+#define NC_MODE_INFO 1
 #endif
 
 // Max superblock size
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index db28a30..dbdc126 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2114,7 +2114,7 @@
   aom_merge_corrupted_flag(&xd->corrupted, reader_corrupted_flag);
 }
 
-#if (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT) && CONFIG_MOTION_VAR
+#if NC_MODE_INFO && CONFIG_MOTION_VAR
 static void detoken_and_recon_sb(AV1Decoder *const pbi, MACROBLOCKD *const xd,
                                  int mi_row, int mi_col, aom_reader *r,
                                  BLOCK_SIZE bsize) {
@@ -2214,7 +2214,7 @@
 #endif
                     bsize);
 
-#if !(CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT))
+#if !(CONFIG_MOTION_VAR && NC_MODE_INFO)
 #if CONFIG_SUPERTX
   if (!supertx_enabled)
 #endif  // CONFIG_SUPERTX
@@ -3846,7 +3846,7 @@
                            0,
 #endif  // CONFIG_SUPERTX
                            mi_row, mi_col, &td->bit_reader, cm->sb_size);
-#if (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT) && CONFIG_MOTION_VAR
+#if NC_MODE_INFO && CONFIG_MOTION_VAR
           detoken_and_recon_sb(pbi, &td->xd, mi_row, mi_col, &td->bit_reader,
                                cm->sb_size);
 #endif
@@ -4000,7 +4000,7 @@
                        0,
 #endif
                        mi_row, mi_col, &tile_data->bit_reader, cm->sb_size);
-#if (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT) && CONFIG_MOTION_VAR
+#if NC_MODE_INFO && CONFIG_MOTION_VAR
       detoken_and_recon_sb(pbi, &tile_data->xd, mi_row, mi_col,
                            &tile_data->bit_reader, cm->sb_size);
 #endif
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index efefdb3..860f751 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -2743,7 +2743,7 @@
 #endif  // CONFIG_COEF_INTERLEAVE
 }
 
-#if CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT)
+#if CONFIG_MOTION_VAR && NC_MODE_INFO
 static void write_tokens_sb(AV1_COMP *cpi, const TileInfo *const tile,
                             aom_writer *w, const TOKENEXTRA **tok,
                             const TOKENEXTRA *const tok_end, int mi_row,
@@ -2830,7 +2830,7 @@
 #endif
                mi_row, mi_col);
 
-#if CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT)
+#if CONFIG_MOTION_VAR && NC_MODE_INFO
   (void)tok;
   (void)tok_end;
 #else
@@ -3199,7 +3199,7 @@
     for (mi_col = mi_col_start; mi_col < mi_col_end; mi_col += cm->mib_size) {
       write_modes_sb_wrapper(cpi, tile, w, tok, tok_end, 0, mi_row, mi_col,
                              cm->sb_size);
-#if CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT)
+#if CONFIG_MOTION_VAR && NC_MODE_INFO
       write_tokens_sb(cpi, tile, w, tok, tok_end, mi_row, mi_col, cm->sb_size);
 #endif
     }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a631a25..bce9b62 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1165,7 +1165,7 @@
 }
 #endif  // CONFIG_SUPERTX
 
-#if CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT)
+#if CONFIG_MOTION_VAR && NC_MODE_INFO
 static void set_mode_info_b(const AV1_COMP *const cpi,
                             const TileInfo *const tile, ThreadData *td,
                             int mi_row, int mi_col, BLOCK_SIZE bsize,
@@ -2117,9 +2117,12 @@
 #endif
 
 #if CONFIG_NCOBMC_ADAPT_WEIGHT
-  if (dry_run == OUTPUT_ENABLED && motion_allowed == NCOBMC_ADAPT_WEIGHT) {
-    get_ncobmc_intrpl_pred(cpi, td, mi_row, mi_col, bsize);
-    av1_check_ncobmc_adapt_weight_rd(cpi, x, mi_row, mi_col);
+  if (dry_run == OUTPUT_ENABLED && !frame_is_intra_only(&cpi->common)) {
+    // we also need to handle inter-intra
+    if (motion_allowed == NCOBMC_ADAPT_WEIGHT && is_inter_block(mbmi)) {
+      get_ncobmc_intrpl_pred(cpi, td, mi_row, mi_col, bsize);
+      av1_check_ncobmc_adapt_weight_rd(cpi, x, mi_row, mi_col);
+    }
     av1_setup_dst_planes(x->e_mbd.plane, bsize,
                          get_frame_new_buffer(&cpi->common), mi_row, mi_col);
   }
@@ -4500,7 +4503,7 @@
   if (best_rdc.rate < INT_MAX && best_rdc.dist < INT64_MAX &&
       pc_tree->index != 3) {
     if (bsize == cm->sb_size) {
-#if CONFIG_MOTION_VAR && (CONFIG_NCOBMC || CONFIG_NCOBMC_ADAPT_WEIGHT)
+#if CONFIG_MOTION_VAR && NC_MODE_INFO
       set_mode_info_sb(cpi, td, tile_info, tp, mi_row, mi_col, bsize, pc_tree);
 #endif
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 313de49..d1dde09 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -12591,102 +12591,55 @@
 }
 #endif  // CONFIG_NCOBMC
 
-#if CONFIG_NCOBMC_ADAPT_WEIGHT
-void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
-                                      struct macroblock *x, int mi_row,
-                                      int mi_col) {
+int64_t get_prediction_rd_cost(const struct AV1_COMP *cpi, struct macroblock *x,
+                               int mi_row, int mi_col, int *skip_blk,
+                               MB_MODE_INFO *backup_mbmi) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   BLOCK_SIZE bsize = mbmi->sb_type;
-#if CONFIG_VAR_TX
-  const int n4 = bsize_to_num_blk(bsize);
-  uint8_t backup_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
-#endif
-  MB_MODE_INFO backup_mbmi;
-  int plane, ref, skip_blk, backup_skip;
-  RD_STATS rd_stats_y, rd_stats_uv, rd_stats_y2, rd_stats_uv2;
+  RD_STATS rd_stats_y, rd_stats_uv;
   int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
   int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
-  int64_t prev_rd, naw_rd;  // ncobmc_adapt_weight_rd
+  int64_t this_rd;
+  int ref;
 
-  // Recompute the rd for the motion mode decided in rd loop
-  if (mbmi->motion_mode == SIMPLE_TRANSLATION ||
-      mbmi->motion_mode == OBMC_CAUSAL) {
-    set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
-    for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
-      YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
-      assert(cfg != NULL);
-      av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
-                           &xd->block_refs[ref]->sf);
-    }
-    av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
-                         mi_col);
+#if CONFIG_CB4X4
+  x->skip_chroma_rd =
+      !is_chroma_reference(mi_row, mi_col, bsize, xd->plane[1].subsampling_x,
+                           xd->plane[1].subsampling_y);
+#endif
 
+  set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+  for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
+    YV12_BUFFER_CONFIG *cfg = get_ref_frame_buffer(cpi, mbmi->ref_frame[ref]);
+    assert(cfg != NULL);
+    av1_setup_pre_planes(xd, ref, cfg, mi_row, mi_col,
+                         &xd->block_refs[ref]->sf);
+  }
+  av1_setup_dst_planes(x->e_mbd.plane, bsize,
+                       get_frame_new_buffer(&cpi->common), mi_row, mi_col);
+
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+  if (mbmi->motion_mode != NCOBMC_ADAPT_WEIGHT)
+#endif
     av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
-    if (mbmi->motion_mode == OBMC_CAUSAL) {
-      av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
-    }
 
-    av1_subtract_plane(x, bsize, 0);
-
-#if CONFIG_VAR_TX
-    if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
-      select_tx_type_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-    } else {
-      int idx, idy;
-      super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-      for (idy = 0; idy < xd->n8_h; ++idy)
-        for (idx = 0; idx < xd->n8_w; ++idx)
-          mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
-      memset(x->blk_skip[0], rd_stats_y2.skip,
-             sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
-    }
-    inter_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+#if CONFIG_MOTION_VAR
+  if (mbmi->motion_mode == OBMC_CAUSAL) {
+#if CONFIG_NCOBMC
+    av1_build_ncobmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
 #else
-    super_block_yrd(cpi, x, &rd_stats_y2, bsize, INT64_MAX);
-    super_block_uvrd(cpi, x, &rd_stats_uv2, bsize, INT64_MAX);
+    av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
 #endif
   }
+#endif  // CONFIG_MOTION_VAR
 
-  if (rd_stats_y2.skip && rd_stats_uv2.skip) {
-    rd_stats_y2.rate = rate_skip1;
-    rd_stats_uv2.rate = 0;
-    rd_stats_y2.dist = rd_stats_y2.sse;
-    rd_stats_uv2.dist = rd_stats_uv2.sse;
-    skip_blk = 1;
-  } else if (RDCOST(x->rdmult,
-                    (rd_stats_y2.rate + rd_stats_uv2.rate + rate_skip0),
-                    (rd_stats_y2.dist + rd_stats_uv2.dist)) >
-             RDCOST(x->rdmult, rate_skip1,
-                    (rd_stats_y2.sse + rd_stats_uv2.sse))) {
-    rd_stats_y2.rate = rate_skip1;
-    rd_stats_uv2.rate = 0;
-    rd_stats_y2.dist = rd_stats_y2.sse;
-    rd_stats_uv2.dist = rd_stats_uv2.sse;
-    skip_blk = 1;
-  } else {
-    rd_stats_y2.rate += rate_skip0;
-    skip_blk = 0;
-  }
-
-  backup_mbmi = *mbmi;
-  backup_skip = skip_blk;
-#if CONFIG_VAR_TX
-  memcpy(backup_blk_skip, x->blk_skip[0], sizeof(backup_blk_skip[0]) * n4);
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+  if (mbmi->motion_mode == NCOBMC_ADAPT_WEIGHT)
+    for (int plane = 0; plane < MAX_MB_PLANE; ++plane)
+      get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
 #endif
-  prev_rd = RDCOST(x->rdmult, (rd_stats_y2.rate + rd_stats_uv2.rate),
-                   (rd_stats_y2.dist + rd_stats_uv2.dist));
-  prev_rd +=
-      RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
-
-  // Compute the rd cost for ncobmc adaptive weight
-  mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
-
-  for (plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    get_pred_from_intrpl_buf(xd, mi_row, mi_col, bsize, plane);
-  }
-
   av1_subtract_plane(x, bsize, 0);
 
 #if CONFIG_VAR_TX
@@ -12713,7 +12666,7 @@
     rd_stats_uv.rate = 0;
     rd_stats_y.dist = rd_stats_y.sse;
     rd_stats_uv.dist = rd_stats_uv.sse;
-    skip_blk = 1;
+    *skip_blk = 1;
   } else if (RDCOST(x->rdmult,
                     (rd_stats_y.rate + rd_stats_uv.rate + rate_skip0),
                     (rd_stats_y.dist + rd_stats_uv.dist)) >
@@ -12723,32 +12676,78 @@
     rd_stats_uv.rate = 0;
     rd_stats_y.dist = rd_stats_y.sse;
     rd_stats_uv.dist = rd_stats_uv.sse;
-    skip_blk = 1;
+    *skip_blk = 1;
   } else {
     rd_stats_y.rate += rate_skip0;
-    skip_blk = 0;
+    *skip_blk = 0;
   }
-  naw_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
-                  (rd_stats_y.dist + rd_stats_uv.dist));
-  naw_rd += RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  if (backup_mbmi) *backup_mbmi = *mbmi;
+
+  this_rd = RDCOST(x->rdmult, (rd_stats_y.rate + rd_stats_uv.rate),
+                   (rd_stats_y.dist + rd_stats_uv.dist));
+  this_rd +=
+      RDCOST(x->rdmult, x->motion_mode_cost[bsize][mbmi->motion_mode], 0);
+
+  return this_rd;
+}
+
+#if CONFIG_NCOBMC_ADAPT_WEIGHT
+void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
+                                      struct macroblock *x, int mi_row,
+                                      int mi_col) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  BLOCK_SIZE bsize = mbmi->sb_type;
+#if CONFIG_VAR_TX
+  const int n4 = bsize_to_num_blk(bsize);
+  uint8_t st_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+  uint8_t obmc_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+#endif
+  MB_MODE_INFO st_mbmi, obmc_mbmi;
+  int skip_blk, st_skip, obmc_skip;
+  int64_t st_rd, obmc_rd, ncobmc_rd;
+
+  // Recompute the rd for the motion mode decided in rd loop
+  mbmi->motion_mode = SIMPLE_TRANSLATION;
+  st_rd = get_prediction_rd_cost(cpi, x, mi_row, mi_col, &st_skip, &st_mbmi);
+#if CONFIG_VAR_TX
+  memcpy(st_blk_skip, x->blk_skip[0], sizeof(st_blk_skip[0]) * n4);
+#endif
+
+  mbmi->motion_mode = OBMC_CAUSAL;
+  obmc_rd =
+      get_prediction_rd_cost(cpi, x, mi_row, mi_col, &obmc_skip, &obmc_mbmi);
+#if CONFIG_VAR_TX
+  memcpy(obmc_blk_skip, x->blk_skip[0], sizeof(obmc_blk_skip[0]) * n4);
+#endif
+  // Compute the rd cost for ncobmc adaptive weight
+  mbmi->motion_mode = NCOBMC_ADAPT_WEIGHT;
+  ncobmc_rd = get_prediction_rd_cost(cpi, x, mi_row, mi_col, &skip_blk, NULL);
 
   // Calculate the ncobmc mode costs
   {
     ADAPT_OVERLAP_BLOCK aob = adapt_overlap_block_lookup[bsize];
-    naw_rd +=
+    ncobmc_rd +=
         RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[0]], 0);
     if (mi_size_wide[bsize] != mi_size_high[bsize])
-      naw_rd +=
+      ncobmc_rd +=
           RDCOST(x->rdmult, x->ncobmc_mode_cost[aob][mbmi->ncobmc_mode[1]], 0);
   }
 
-  if (prev_rd > naw_rd) {
+  if (ncobmc_rd < AOMMIN(st_rd, obmc_rd)) {
     x->skip = skip_blk;
-  } else {
-    *mbmi = backup_mbmi;
-    x->skip = backup_skip;
+  } else if (obmc_rd < st_rd) {
+    *mbmi = obmc_mbmi;
+    x->skip = obmc_skip;
 #if CONFIG_VAR_TX
-    memcpy(x->blk_skip[0], backup_blk_skip, sizeof(backup_blk_skip[0]) * n4);
+    memcpy(x->blk_skip[0], obmc_blk_skip, sizeof(obmc_blk_skip[0]) * n4);
+#endif
+  } else {
+    *mbmi = st_mbmi;
+    x->skip = st_skip;
+#if CONFIG_VAR_TX
+    memcpy(x->blk_skip[0], st_blk_skip, sizeof(st_blk_skip[0]) * n4);
 #endif
   }
 }
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index bcad8f8..355fb52 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -144,6 +144,10 @@
                      const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type);
 
+int64_t get_prediction_rd_cost(const struct AV1_COMP *cpi, struct macroblock *x,
+                               int mi_row, int mi_col, int *skip_blk,
+                               MB_MODE_INFO *backup_mbmi);
+
 #if CONFIG_NCOBMC_ADAPT_WEIGHT
 void av1_check_ncobmc_adapt_weight_rd(const struct AV1_COMP *cpi,
                                       struct macroblock *x, int mi_row,
@@ -152,12 +156,6 @@
 int get_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
                     MACROBLOCKD *xd, int mi_row, int mi_col, int bsize);
 
-void av1_setup_src_planes_pxl(MACROBLOCK *x, const YV12_BUFFER_CONFIG *src,
-                              int pxl_row, int pxl_col);
-
-void rebuild_ncobmc_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
-                         MACROBLOCKD *xd, int mi_row, int mi_col, int bsize,
-                         int xd_mi_offset, NCOBMC_MODE best_mode, int rebuild);
 #endif
 
 #endif  // AV1_ENCODER_RDOPT_H_