Enable compound motion search in tpl model

STATS_CHANGED

Change-Id: I71bca71ba41ded81f546d05befcdabe1e27f1449
diff --git a/av1/common/convolve.h b/av1/common/convolve.h
index 490d778..5f3e596 100644
--- a/av1/common/convolve.h
+++ b/av1/common/convolve.h
@@ -68,6 +68,7 @@
   assert(IMPLIES(cmp_index, is_compound));
 
   conv_params.is_compound = is_compound;
+  conv_params.use_dist_wtd_comp_avg = 0;
   conv_params.round_0 = ROUND0_BITS;
   conv_params.round_1 = is_compound ? COMPOUND_ROUND1_BITS
                                     : 2 * FILTER_BITS - conv_params.round_0;
diff --git a/av1/encoder/motion_search_facade.c b/av1/encoder/motion_search_facade.c
index 6294895..50cbdc7 100644
--- a/av1/encoder/motion_search_facade.c
+++ b/av1/encoder/motion_search_facade.c
@@ -535,53 +535,15 @@
           xd, cm, &ms_params, start_mv, &best_mv.as_mv, &dis, &sse, NULL);
 
       if (try_second) {
-        struct macroblockd_plane *p = xd->plane;
-        const BUFFER_SET orig_dst = {
-          { p[0].dst.buf, p[1].dst.buf, p[2].dst.buf },
-          { p[0].dst.stride, p[1].dst.stride, p[2].dst.stride },
-        };
-        mbmi->mv[id].as_mv = best_mv.as_mv;
-        mbmi->mv[!id].as_mv = cur_mv[!id].as_mv;
-
-        xd->plane[plane].pre[0] = ref_yv12[0];
-        xd->plane[plane].pre[1] = ref_yv12[1];
-
-        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst, bsize,
-                                      0, 0);
-        av1_subtract_plane(x, bsize, 0);
-        RD_STATS this_rd_stats;
-        av1_init_rd_stats(&this_rd_stats);
-        av1_estimate_txfm_yrd(cpi, x, &this_rd_stats, INT64_MAX, bsize,
-                              max_txsize_rect_lookup[bsize]);
-        int this_mv_rate = av1_mv_bit_cost(
-            &best_mv.as_mv, &ref_mv[id].as_mv, mv_costs->nmv_joint_cost,
-            mv_costs->mv_cost_stack, MV_COST_WEIGHT);
-        int64_t rd = RDCOST(x->rdmult, this_mv_rate + this_rd_stats.rate,
-                            this_rd_stats.dist);
-
         MV this_best_mv;
         MV subpel_start_mv = get_mv_from_fullmv(&second_best_mv.as_fullmv);
         if (av1_is_subpelmv_in_range(&ms_params.mv_limits, subpel_start_mv)) {
-          const int this_var = cpi->mv_search_params.find_fractional_mv_step(
+          const int thissme = cpi->mv_search_params.find_fractional_mv_step(
               xd, cm, &ms_params, subpel_start_mv, &this_best_mv, &dis, &sse,
               NULL);
-          mbmi->mv[id].as_mv = this_best_mv;
-          mbmi->mv[!id].as_mv = cur_mv[!id].as_mv;
-          av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst,
-                                        bsize, 0, 0);
-          av1_subtract_plane(x, bsize, 0);
-          RD_STATS tmp_rd_stats;
-          av1_init_rd_stats(&tmp_rd_stats);
-          av1_estimate_txfm_yrd(cpi, x, &tmp_rd_stats, INT64_MAX, bsize,
-                                max_txsize_rect_lookup[bsize]);
-          int tmp_mv_rate = av1_mv_bit_cost(
-              &this_best_mv, &ref_mv[id].as_mv, mv_costs->nmv_joint_cost,
-              mv_costs->mv_cost_stack, MV_COST_WEIGHT);
-          int64_t tmp_rd = RDCOST(x->rdmult, tmp_rd_stats.rate + tmp_mv_rate,
-                                  tmp_rd_stats.dist);
-          if (tmp_rd < rd) {
+          if (thissme < bestsme) {
             best_mv.as_mv = this_best_mv;
-            bestsme = AOMMIN(bestsme, this_var);
+            bestsme = thissme;
           }
         }
       }
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index 20da822..6020b94 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -77,7 +77,8 @@
                                        InterPredParams *inter_pred_params) {
   av1_build_one_inter_predictor(
       dst, dst_stride, src_mv, inter_pred_params, NULL /* xd */, 0 /* mi_x */,
-      0 /* mi_y */, 0 /* ref */, NULL /* mc_buf */, enc_calc_subpel_params);
+      0 /* mi_y */, inter_pred_params->conv_params.do_average /* ref */,
+      NULL /* mc_buf */, enc_calc_subpel_params);
 }
 
 static void enc_build_inter_predictors(const AV1_COMMON *cm, MACROBLOCKD *xd,
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 5344dc4..807cd6c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -571,6 +571,7 @@
     sf->tpl_sf.reduce_first_step_size = 6;
     sf->tpl_sf.subpel_force_stop = QUARTER_PEL;
     sf->tpl_sf.search_method = DIAMOND;
+    sf->tpl_sf.allow_compound_pred = 0;
 
     sf->tx_sf.adaptive_txb_search_level = boosted ? 2 : 3;
     sf->tx_sf.tx_type_search.use_skip_flag_prediction = 2;
@@ -1064,6 +1065,7 @@
   tpl_sf->search_method = NSTEP;
   tpl_sf->disable_filtered_key_tpl = 0;
   tpl_sf->prune_ref_frames_in_tpl = 0;
+  tpl_sf->allow_compound_pred = 1;
 }
 
 static AOM_INLINE void init_gm_sf(GLOBAL_MOTION_SPEED_FEATURES *gm_sf) {
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 4d11f14..21cee75 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -369,6 +369,9 @@
 
   // Prune reference frames in TPL.
   int prune_ref_frames_in_tpl;
+
+  // Support compound predictions.
+  int allow_compound_pred;
 } TPL_SPEED_FEATURES;
 
 typedef struct GLOBAL_MOTION_SPEED_FEATURES {
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 821cbea..5949128 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -29,6 +29,7 @@
 #include "av1/encoder/encodeframe_utils.h"
 #include "av1/encoder/encode_strategy.h"
 #include "av1/encoder/hybrid_fwd_txfm.h"
+#include "av1/encoder/motion_search_facade.h"
 #include "av1/encoder/rd.h"
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/reconinter_enc.h"
@@ -434,16 +435,19 @@
   // Motion compensated prediction
   xd->mi[0]->ref_frame[0] = INTRA_FRAME;
   xd->mi[0]->ref_frame[1] = NONE_FRAME;
+  xd->mi[0]->compound_idx = 1;
 
   int best_rf_idx = -1;
   int_mv best_mv;
   int64_t inter_cost;
   int64_t best_inter_cost = INT64_MAX;
   int rf_idx;
+  int_mv single_mv[INTER_REFS_PER_FRAME];
 
   best_mv.as_int = INVALID_MV;
 
   for (rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx) {
+    single_mv[rf_idx].as_int = INVALID_MV;
     if (tpl_data->ref_frame[rf_idx] == NULL ||
         tpl_data->src_ref_frame[rf_idx] == NULL) {
       tpl_stats->mv[rf_idx].as_int = INVALID_MV;
@@ -535,6 +539,7 @@
     }
 
     tpl_stats->mv[rf_idx].as_int = best_rfidx_mv.as_int;
+    single_mv[rf_idx] = best_rfidx_mv;
 
     struct buf_2d ref_buf = { NULL, ref_frame_ptr->y_buffer,
                               ref_frame_ptr->y_width, ref_frame_ptr->y_height,
@@ -567,6 +572,68 @@
     }
   }
 
+  int comp_ref_frames[3][2] = {
+    { 0, 4 },
+    { 0, 6 },
+    { 3, 6 },
+  };
+
+  xd->mi_row = mi_row;
+  xd->mi_col = mi_col;
+  for (int cmp_rf_idx = 0; cmp_rf_idx < 3; ++cmp_rf_idx) {
+    int rf_idx0 = comp_ref_frames[cmp_rf_idx][0];
+    int rf_idx1 = comp_ref_frames[cmp_rf_idx][1];
+
+    if (tpl_data->ref_frame[rf_idx0] == NULL ||
+        tpl_data->src_ref_frame[rf_idx0] == NULL ||
+        tpl_data->ref_frame[rf_idx1] == NULL ||
+        tpl_data->src_ref_frame[rf_idx1] == NULL) {
+      continue;
+    }
+
+    const YV12_BUFFER_CONFIG *ref_frame_ptr[2] = {
+      tpl_data->src_ref_frame[rf_idx0],
+      tpl_data->src_ref_frame[rf_idx1],
+    };
+
+    xd->mi[0]->ref_frame[0] = LAST_FRAME;
+    xd->mi[0]->ref_frame[1] = ALTREF_FRAME;
+
+    struct buf_2d yv12_mb[2][MAX_MB_PLANE];
+    for (int i = 0; i < 2; ++i) {
+      av1_setup_pred_block(xd, yv12_mb[i], ref_frame_ptr[i],
+                           xd->block_ref_scale_factors[i],
+                           xd->block_ref_scale_factors[i], MAX_MB_PLANE);
+      for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+        xd->plane[plane].pre[i] = yv12_mb[i][plane];
+      }
+    }
+
+    int_mv tmp_mv[2] = { single_mv[rf_idx0], single_mv[rf_idx1] };
+    int rate_mv;
+    av1_joint_motion_search(cpi, x, bsize, tmp_mv, NULL, 0, &rate_mv);
+
+    for (int ref = 0; ref < 2; ++ref) {
+      struct buf_2d ref_buf = { NULL, ref_frame_ptr[ref]->y_buffer,
+                                ref_frame_ptr[ref]->y_width,
+                                ref_frame_ptr[ref]->y_height,
+                                ref_frame_ptr[ref]->y_stride };
+      InterPredParams inter_pred_params;
+      av1_init_inter_params(&inter_pred_params, bw, bh, mi_row * MI_SIZE,
+                            mi_col * MI_SIZE, 0, 0, xd->bd, is_cur_buf_hbd(xd),
+                            0, &tpl_data->sf, &ref_buf, kernel);
+      av1_init_comp_mode(&inter_pred_params);
+
+      inter_pred_params.conv_params = get_conv_params_no_round(
+          ref, 0, xd->tmp_conv_dst, MAX_SB_SIZE, 1, xd->bd);
+
+      av1_enc_build_one_inter_predictor(predictor, bw, &tmp_mv[ref].as_mv,
+                                        &inter_pred_params);
+    }
+    tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride, predictor, bw,
+                      coeff, bw, bh, tx_size);
+  }
+
   if (best_inter_cost < INT64_MAX) {
     const YV12_BUFFER_CONFIG *ref_frame_ptr =
         tpl_data->src_ref_frame[best_rf_idx];
@@ -1257,6 +1324,9 @@
   tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read_dummy;
   tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write_dummy;
 
+  av1_setup_scale_factors_for_frame(&cm->sf_identity, cm->width, cm->height,
+                                    cm->width, cm->height);
+
   // Backward propagation from tpl_group_frames to 1.
   for (int frame_idx = gf_group->index; frame_idx < tpl_gf_group_frames;
        ++frame_idx) {
@@ -1265,7 +1335,7 @@
       continue;
 
     init_mc_flow_dispenser(cpi, frame_idx, pframe_qindex);
-    if (mt_info->num_workers > 1) {
+    if (mt_info->num_workers > 1 && !cpi->sf.tpl_sf.allow_compound_pred) {
       tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read;
       tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write;
       av1_mc_flow_dispenser_mt(cpi);