Extract compound type search as a function

1. Extract compound_type_rd() for the inter-inter compound type search.
2. Merge branches and remove identical code.
3. Skip some memory copies when possible.
4. Move av1_subtract_plane() into estimate_yrd_for_sb().
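
As context, a minimal before/after sketch of the new get_buf_by_bd()
helper added in blockd.h, which folds the repeated high-bit-depth
pointer selection into one place (taken from the reconinter.c hunk
below; not an additional change):

    // before (repeated at several call sites):
    uint8_t *tmp_dst = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
                           ? CONVERT_TO_BYTEPTR(tmp_buf)
                           : tmp_buf;

    // after:
    uint8_t *tmp_dst = get_buf_by_bd(xd, tmp_buf);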

Change-Id: Ie93d114586d0cf31a3d0c082fa2489512339db9a
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 3e8d1d6..d1c8263 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -605,6 +605,12 @@
   return xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH ? 1 : 0;
 }
 
+static INLINE uint8_t *get_buf_by_bd(const MACROBLOCKD *xd, uint8_t *buf16) {
+  return (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+             ? CONVERT_TO_BYTEPTR(buf16)
+             : buf16;
+}
+
 static INLINE int get_sqr_bsize_idx(BLOCK_SIZE bsize) {
   switch (bsize) {
     case BLOCK_4X4: return 0;
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index b6ac436..3cb4408 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -627,9 +627,7 @@
                   tmp_buf[INTER_PRED_BYTES_PER_PIXEL * MAX_SB_SQUARE]);
 #undef INTER_PRED_BYTES_PER_PIXEL
 
-  uint8_t *tmp_dst = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-                         ? CONVERT_TO_BYTEPTR(tmp_buf)
-                         : tmp_buf;
+  uint8_t *tmp_dst = get_buf_by_bd(xd, tmp_buf);
 
   const int tmp_buf_stride = MAX_SB_SIZE;
   CONV_BUF_TYPE *org_dst = conv_params->dst;
@@ -1713,9 +1711,7 @@
 
   const struct scale_factors *const sf = &xd->block_refs[ref]->sf;
   struct buf_2d *const pre_buf = &pd->pre[ref];
-  const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
-  uint8_t *const dst =
-      (hbd ? CONVERT_TO_BYTEPTR(ext_dst) : ext_dst) + ext_dst_stride * y + x;
+  uint8_t *const dst = get_buf_by_bd(xd, ext_dst) + ext_dst_stride * y + x;
   const MV mv = mi->mv[ref].as_mv;
 
   ConvolveParams conv_params = get_conv_params(ref, 0, plane, xd->bd);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 727c67b..ac2646d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -60,6 +60,7 @@
 
 // Set this macro as 1 to collect data about tx size selection.
 #define COLLECT_TX_SIZE_DATA 0
+
 #if COLLECT_TX_SIZE_DATA
 static const char av1_tx_size_data_output_file[] = "tx_size_data.txt";
 #endif
@@ -3097,6 +3098,7 @@
                                    MACROBLOCK *x, int *r, int64_t *d, int *s,
                                    int64_t *sse, int64_t ref_best_rd) {
   RD_STATS rd_stats;
+  av1_subtract_plane(x, bs, 0);
   x->rd_model = LOW_TXFM_RD;
   int64_t rd = txfm_yrd(cpi, x, &rd_stats, ref_best_rd, bs,
                         max_txsize_rect_lookup[bs], FTXS_NONE);
@@ -7375,7 +7377,6 @@
       av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                                preds1, strides);
     }
-    av1_subtract_plane(x, bsize, 0);
     rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                              &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
     if (rd != INT64_MAX)
@@ -7385,7 +7386,6 @@
   } else {
     av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                              preds1, strides);
-    av1_subtract_plane(x, bsize, 0);
     rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                              &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
     if (rd != INT64_MAX)
@@ -7869,7 +7869,6 @@
       av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                 intrapred, bw);
       av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
-      av1_subtract_plane(x, bsize, 0);
       rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
       if (rd != INT64_MAX)
@@ -7933,7 +7932,6 @@
             av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
           }
           // Evaluate closer to true rd
-          av1_subtract_plane(x, bsize, 0);
           rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                                    &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
                                    INT64_MAX);
@@ -8331,6 +8329,123 @@
   return cost;
 }
 
+static int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
+                            BLOCK_SIZE bsize, int mi_col, int mi_row,
+                            int_mv *cur_mv, int masked_compound_used,
+                            BUFFER_SET *orig_dst, int *rate_mv, int64_t *rd,
+                            RD_STATS *rd_stats, int64_t ref_best_rd) {
+  const AV1_COMMON *cm = &cpi->common;
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  const int this_mode = mbmi->mode;
+  const int bw = block_size_wide[bsize];
+  int rate_sum, rs2;
+  int64_t dist_sum;
+
+  int_mv best_mv[2];
+  int best_tmp_rate_mv = *rate_mv;
+  int tmp_skip_txfm_sb;
+  int64_t tmp_skip_sse_sb;
+  INTERINTER_COMPOUND_DATA best_compound_data;
+  best_compound_data.type = COMPOUND_AVERAGE;
+  DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]);
+  uint8_t tmp_best_mask_buf[2 * MAX_SB_SQUARE];
+  uint8_t *preds0[1] = { pred0 };
+  uint8_t *preds1[1] = { pred1 };
+  int strides[1] = { bw };
+  int tmp_rate_mv;
+  const int num_pix = 1 << num_pels_log2_lookup[bsize];
+  const int mask_len = 2 * num_pix * sizeof(uint8_t);
+  COMPOUND_TYPE cur_type;
+  int best_compmode_interinter_cost = 0;
+  int can_use_previous = cm->allow_warped_motion;
+
+  best_mv[0].as_int = cur_mv[0].as_int;
+  best_mv[1].as_int = cur_mv[1].as_int;
+  *rd = INT64_MAX;
+  if (masked_compound_used) {
+    // get inter predictors to use for masked compound modes
+    av1_build_inter_predictors_for_planes_single_buf(
+        xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides, can_use_previous);
+    av1_build_inter_predictors_for_planes_single_buf(
+        xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides, can_use_previous);
+  }
+  for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
+    if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
+    if (!is_interinter_compound_used(cur_type, bsize)) continue;
+    tmp_rate_mv = *rate_mv;
+    int64_t best_rd_cur = INT64_MAX;
+    mbmi->interinter_comp.type = cur_type;
+    int masked_type_cost = 0;
+
+    const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+    const int comp_index_ctx = get_comp_index_context(cm, xd);
+    mbmi->compound_idx = 1;
+    if (cur_type == COMPOUND_AVERAGE) {
+      mbmi->comp_group_idx = 0;
+      if (masked_compound_used) {
+        masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      }
+      masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+      rs2 = masked_type_cost;
+      av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
+      int64_t est_rd =
+          estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                              &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+      if (est_rd != INT64_MAX)
+        best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
+    } else {
+      mbmi->comp_group_idx = 1;
+      masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
+      masked_type_cost += x->compound_type_cost[bsize][cur_type - 1];
+      rs2 = masked_type_cost;
+      if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
+          *rd / 3 < ref_best_rd) {
+        best_rd_cur = build_and_cost_compound_type(
+            cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+            &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
+      }
+    }
+    if (best_rd_cur < *rd) {
+      *rd = best_rd_cur;
+      best_compound_data = mbmi->interinter_comp;
+      if (masked_compound_used && cur_type != COMPOUND_TYPES - 1) {
+        memcpy(tmp_best_mask_buf, xd->seg_mask, mask_len);
+      }
+      best_compmode_interinter_cost = rs2;
+      if (have_newmv_in_inter_mode(this_mode)) {
+        if (use_masked_motion_search(cur_type)) {
+          best_tmp_rate_mv = tmp_rate_mv;
+          best_mv[0].as_int = mbmi->mv[0].as_int;
+          best_mv[1].as_int = mbmi->mv[1].as_int;
+        } else {
+          best_mv[0].as_int = cur_mv[0].as_int;
+          best_mv[1].as_int = cur_mv[1].as_int;
+        }
+      }
+    }
+    // reset to original mvs for next iteration
+    mbmi->mv[0].as_int = cur_mv[0].as_int;
+    mbmi->mv[1].as_int = cur_mv[1].as_int;
+  }
+  if (mbmi->interinter_comp.type != best_compound_data.type) {
+    mbmi->comp_group_idx =
+        (best_compound_data.type == COMPOUND_AVERAGE) ? 0 : 1;
+    mbmi->interinter_comp = best_compound_data;
+    memcpy(xd->seg_mask, tmp_best_mask_buf, mask_len);
+  }
+  if (have_newmv_in_inter_mode(this_mode)) {
+    mbmi->mv[0].as_int = best_mv[0].as_int;
+    mbmi->mv[1].as_int = best_mv[1].as_int;
+    if (use_masked_motion_search(mbmi->interinter_comp.type)) {
+      rd_stats->rate += best_tmp_rate_mv - *rate_mv;
+      *rate_mv = best_tmp_rate_mv;
+    }
+  }
+  return best_compmode_interinter_cost;
+}
+
 static int64_t handle_inter_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                  BLOCK_SIZE bsize, RD_STATS *rd_stats,
                                  RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv,
@@ -8352,9 +8467,8 @@
   int refs[2] = { mbmi->ref_frame[0],
                   (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
   int rate_mv = 0;
-  const int bw = block_size_wide[bsize];
   DECLARE_ALIGNED(32, uint8_t, tmp_buf_[2 * MAX_MB_PLANE * MAX_SB_SQUARE]);
-  uint8_t *tmp_buf;
+  uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
   int64_t rd = INT64_MAX;
   BUFFER_SET orig_dst, tmp_dst;
 
@@ -8368,15 +8482,6 @@
   if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME;
 
   mode_ctx = av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame);
-
-  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-    tmp_buf = CONVERT_TO_BYTEPTR(tmp_buf_);
-  else
-    tmp_buf = tmp_buf_;
-  // Make sure that we didn't leave the plane destination buffers set
-  // to tmp_buf at the end of the last iteration
-  assert(xd->plane[0].dst.buf != tmp_buf);
-
   mbmi->num_proj_ref[0] = 0;
   mbmi->num_proj_ref[1] = 0;
 
@@ -8401,8 +8506,6 @@
   const RD_STATS backup_rd_stats_y = *rd_stats_y;
   const RD_STATS backup_rd_stats_uv = *rd_stats_uv;
   const MB_MODE_INFO backup_mbmi = *mbmi;
-  INTERINTER_COMPOUND_DATA best_compound_data;
-  uint8_t tmp_best_mask_buf[2 * MAX_SB_SQUARE];
   RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
   int64_t best_rd = INT64_MAX;
   int64_t best_ret_val = INT64_MAX;
@@ -8543,147 +8646,14 @@
     }
 
     if (is_comp_pred && comp_idx) {
-      int rate_sum, rs2;
-      int64_t dist_sum;
-      int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX;
-      int_mv best_mv[2];
-      int best_tmp_rate_mv = rate_mv;
-      int tmp_skip_txfm_sb;
-      int64_t tmp_skip_sse_sb;
-      DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]);
-      DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]);
-      uint8_t *preds0[1] = { pred0 };
-      uint8_t *preds1[1] = { pred1 };
-      int strides[1] = { bw };
-      int tmp_rate_mv;
-      const int num_pix = 1 << num_pels_log2_lookup[bsize];
-      COMPOUND_TYPE cur_type;
-      int best_compmode_interinter_cost = 0;
-      int can_use_previous = cm->allow_warped_motion;
-
-      best_mv[0].as_int = cur_mv[0].as_int;
-      best_mv[1].as_int = cur_mv[1].as_int;
-
-      if (masked_compound_used) {
-        // get inter predictors to use for masked compound modes
-        av1_build_inter_predictors_for_planes_single_buf(
-            xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides,
-            can_use_previous);
-        av1_build_inter_predictors_for_planes_single_buf(
-            xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides,
-            can_use_previous);
-      }
-
-      int best_comp_group_idx = 0;
-      int best_compound_idx = 1;
-      for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
-        if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
-        if (!is_interinter_compound_used(cur_type, bsize)) continue;
-        tmp_rate_mv = rate_mv;
-        best_rd_cur = INT64_MAX;
-        mbmi->interinter_comp.type = cur_type;
-        int masked_type_cost = 0;
-
-        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-        const int comp_index_ctx = get_comp_index_context(cm, xd);
-        if (masked_compound_used) {
-          if (cur_type == COMPOUND_AVERAGE) {
-            mbmi->comp_group_idx = 0;
-            mbmi->compound_idx = 1;
-
-            masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
-            masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
-          } else {
-            mbmi->comp_group_idx = 1;
-            mbmi->compound_idx = 1;
-
-            masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
-            masked_type_cost +=
-                x->compound_type_cost[bsize][mbmi->interinter_comp.type - 1];
-          }
-        } else {
-          mbmi->comp_group_idx = 0;
-          mbmi->compound_idx = 1;
-
-          masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
-        }
-        rs2 = masked_type_cost;
-
-        switch (cur_type) {
-          case COMPOUND_AVERAGE:
-            av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
-                                           bsize);
-            av1_subtract_plane(x, bsize, 0);
-            rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                                     &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
-                                     INT64_MAX);
-            if (rd != INT64_MAX)
-              best_rd_cur =
-                  RDCOST(x->rdmult, rs2 + rate_mv + rate_sum, dist_sum);
-            break;
-          case COMPOUND_WEDGE:
-            if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
-                best_rd_compound / 3 < ref_best_rd) {
-              best_rd_cur = build_and_cost_compound_type(
-                  cpi, x, cur_mv, bsize, this_mode, &rs2, rate_mv, &orig_dst,
-                  &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-            }
-            break;
-          case COMPOUND_DIFFWTD:
-            if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
-                best_rd_compound / 3 < ref_best_rd) {
-              best_rd_cur = build_and_cost_compound_type(
-                  cpi, x, cur_mv, bsize, this_mode, &rs2, rate_mv, &orig_dst,
-                  &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-            }
-            break;
-          default: assert(0); return INT64_MAX;
-        }
-
-        if (best_rd_cur < best_rd_compound) {
-          best_comp_group_idx = mbmi->comp_group_idx;
-          best_compound_idx = mbmi->compound_idx;
-          best_rd_compound = best_rd_cur;
-          best_compound_data = mbmi->interinter_comp;
-          memcpy(tmp_best_mask_buf, xd->seg_mask,
-                 2 * num_pix * sizeof(uint8_t));
-          best_compmode_interinter_cost = rs2;
-          if (have_newmv_in_inter_mode(this_mode)) {
-            if (use_masked_motion_search(cur_type)) {
-              best_tmp_rate_mv = tmp_rate_mv;
-              best_mv[0].as_int = mbmi->mv[0].as_int;
-              best_mv[1].as_int = mbmi->mv[1].as_int;
-            } else {
-              best_mv[0].as_int = cur_mv[0].as_int;
-              best_mv[1].as_int = cur_mv[1].as_int;
-            }
-          }
-        }
-        // reset to original mvs for next iteration
-        mbmi->mv[0].as_int = cur_mv[0].as_int;
-        mbmi->mv[1].as_int = cur_mv[1].as_int;
-      }
-      mbmi->comp_group_idx = best_comp_group_idx;
-      mbmi->compound_idx = best_compound_idx;
-      mbmi->interinter_comp = best_compound_data;
-      assert(IMPLIES(mbmi->comp_group_idx == 1,
-                     mbmi->interinter_comp.type != COMPOUND_AVERAGE));
-      memcpy(xd->seg_mask, tmp_best_mask_buf, 2 * num_pix * sizeof(uint8_t));
-      if (have_newmv_in_inter_mode(this_mode)) {
-        mbmi->mv[0].as_int = best_mv[0].as_int;
-        mbmi->mv[1].as_int = best_mv[1].as_int;
-        if (use_masked_motion_search(mbmi->interinter_comp.type)) {
-          rd_stats->rate += best_tmp_rate_mv - rate_mv;
-          rate_mv = best_tmp_rate_mv;
-        }
-      }
-
-      if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) {
+      compmode_interinter_cost = compound_type_rd(
+          cpi, x, bsize, mi_col, mi_row, cur_mv, masked_compound_used,
+          &orig_dst, &rate_mv, &rd, rd_stats, ref_best_rd);
+      if (ref_best_rd < INT64_MAX && rd / 3 > ref_best_rd) {
         restore_dst_buf(xd, orig_dst, num_planes);
         early_terminate = INT64_MAX;
         continue;
       }
-      compmode_interinter_cost = best_compmode_interinter_cost;
     }
 
     if (is_comp_pred) {