JNT_COMP: 3. RD-select the best weight

Select the best compound_idx in the RD search.
The rate/cost for compound_idx and its context will be added in patch 4.
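
A minimal, self-contained sketch of the selection step, assuming a
simplified cost function and hypothetical rate/distortion numbers; rd_cost()
here is an illustrative stand-in, not libaom's RDCOST macro or
handle_inter_mode():

    #include <inttypes.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Simplified stand-in for the encoder's RD cost (illustrative only). */
    static int64_t rd_cost(int rdmult, int rate, int64_t dist) {
      return ((int64_t)rdmult * rate >> 9) + (dist << 7);
    }

    int main(void) {
      /* Hypothetical results of trying each compound_idx on a compound block. */
      const int rate[2] = { 120, 132 };
      const int64_t dist[2] = { 9000, 7800 };
      const int rdmult = 128;

      int best_idx = 0;
      int64_t best_rd = INT64_MAX;
      for (int compound_idx = 0; compound_idx < 2; ++compound_idx) {
        const int64_t rd =
            rd_cost(rdmult, rate[compound_idx], dist[compound_idx]);
        if (rd < best_rd) {
          best_rd = rd;
          best_idx = compound_idx;
        }
      }
      printf("best compound_idx = %d (rd = %" PRId64 ")\n", best_idx, best_rd);
      return 0;
    }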

For now there is a bug if we do not encode one more time using the
selected compound_idx; removing that extra encode remains an issue to be
solved in the future.
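
A rough sketch of the save/try/restore flow the patch wraps around
handle_inter_mode(); the type and names below are placeholders, not the
encoder's real state, and the reason for the final re-encode is the author's
stated workaround, not a verified root cause:

    /* Placeholder for the search state (MVs, filters, modelled RD, ...). */
    typedef struct {
      int mv[2];
      int interp_filter;
    } SearchState;

    void select_compound_idx(SearchState *state) {
      const SearchState backup = *state;  /* snapshot before any trial */
      int best_idx = 0;

      for (int compound_idx = 0; compound_idx < 2; ++compound_idx) {
        *state = backup;  /* every trial starts from the same snapshot */
        /* ... run the inter-mode search with this compound_idx and keep
           best_idx if its RD cost is lower ... */
      }

      /* Final re-encode with the winner: without it, the prediction and
         coefficient state left behind may belong to the last trial rather
         than to best_idx; this is the bug the commit message refers to. */
      *state = backup;
      /* ... run the inter-mode search once more with best_idx ... */
      (void)best_idx;
    }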

Change-Id: I5e1ba51da2b6ab5bacd8aba752dda43bd2257014
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 042448e..ce4c03d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5795,43 +5795,48 @@
 static void jnt_comp_weight_assign(const AV1_COMMON *cm,
                                    const MB_MODE_INFO *mbmi, int order_idx,
                                    uint8_t *second_pred) {
-  int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
-  int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
-  int bck_frame_index = 0, fwd_frame_index = 0;
-  int cur_frame_index = cm->cur_frame->cur_frame_offset;
-
-  if (bck_idx >= 0) {
-    bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
-  }
-
-  if (fwd_idx >= 0) {
-    fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
-  }
-
-  const double fwd = abs(fwd_frame_index - cur_frame_index);
-  const double bck = abs(cur_frame_index - bck_frame_index);
-  int order;
-  double ratio;
-
-  if (COMPOUND_WEIGHT_MODE == DIST) {
-    if (fwd > bck) {
-      ratio = (bck != 0) ? fwd / bck : 5.0;
-      order = 0;
-    } else {
-      ratio = (fwd != 0) ? bck / fwd : 5.0;
-      order = 1;
-    }
-    int quant_dist_idx;
-    for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
-      if (ratio < quant_dist_category[quant_dist_idx]) break;
-    }
-    second_pred[4096] =
-        quant_dist_lookup_table[order_idx][quant_dist_idx][order];
-    second_pred[4097] =
-        quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
+  if (mbmi->compound_idx) {
+    second_pred[4096] = -1;
+    second_pred[4097] = -1;
   } else {
-    second_pred[4096] = (DIST_PRECISION >> 1);
-    second_pred[4097] = (DIST_PRECISION >> 1);
+    int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
+    int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
+    int bck_frame_index = 0, fwd_frame_index = 0;
+    int cur_frame_index = cm->cur_frame->cur_frame_offset;
+
+    if (bck_idx >= 0) {
+      bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
+    }
+
+    if (fwd_idx >= 0) {
+      fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
+    }
+
+    const double fwd = abs(fwd_frame_index - cur_frame_index);
+    const double bck = abs(cur_frame_index - bck_frame_index);
+    int order;
+    double ratio;
+
+    if (COMPOUND_WEIGHT_MODE == DIST) {
+      if (fwd > bck) {
+        ratio = (bck != 0) ? fwd / bck : 5.0;
+        order = 0;
+      } else {
+        ratio = (fwd != 0) ? bck / fwd : 5.0;
+        order = 1;
+      }
+      int quant_dist_idx;
+      for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
+        if (ratio < quant_dist_category[quant_dist_idx]) break;
+      }
+      second_pred[4096] =
+          quant_dist_lookup_table[order_idx][quant_dist_idx][order];
+      second_pred[4097] =
+          quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
+    } else {
+      second_pred[4096] = (DIST_PRECISION >> 1);
+      second_pred[4097] = (DIST_PRECISION >> 1);
+    }
   }
 }
 #endif  // CONFIG_JNT_COMP
@@ -10217,6 +10222,130 @@
           }
         }
       }
+#if CONFIG_JNT_COMP
+      {
+        int cum_rate = rate2;
+        MB_MODE_INFO backup_mbmi = *mbmi;
+
+        int_mv backup_frame_mv[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
+        int_mv backup_single_newmv[TOTAL_REFS_PER_FRAME];
+        int backup_single_newmv_rate[TOTAL_REFS_PER_FRAME];
+        int64_t backup_modelled_rd[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
+
+        memcpy(backup_frame_mv, frame_mv, sizeof(frame_mv));
+        memcpy(backup_single_newmv, single_newmv, sizeof(single_newmv));
+        memcpy(backup_single_newmv_rate, single_newmv_rate,
+               sizeof(single_newmv_rate));
+        memcpy(backup_modelled_rd, modelled_rd, sizeof(modelled_rd));
+
+        InterpFilters backup_interp_filters = mbmi->interp_filters;
+
+        for (int comp_idx = 0; comp_idx < 1 + has_second_ref(mbmi);
+             ++comp_idx) {
+          RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
+          av1_init_rd_stats(&rd_stats);
+          av1_init_rd_stats(&rd_stats_y);
+          av1_init_rd_stats(&rd_stats_uv);
+          rd_stats.rate = cum_rate;
+
+          memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
+          memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
+          memcpy(single_newmv_rate, backup_single_newmv_rate,
+                 sizeof(single_newmv_rate));
+          memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
+
+          mbmi->interp_filters = backup_interp_filters;
+
+          int dummy_disable_skip = 0;
+
+          // Point to variables that are maintained between loop iterations
+          args.single_newmv = single_newmv;
+          args.single_newmv_rate = single_newmv_rate;
+          args.modelled_rd = modelled_rd;
+          mbmi->compound_idx = comp_idx;
+
+          int64_t tmp_rd = handle_inter_mode(
+              cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
+              &dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
+
+          if (tmp_rd < INT64_MAX) {
+            if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
+                RDCOST(x->rdmult, 0, rd_stats.sse))
+              tmp_rd =
+                  RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
+                         rd_stats.dist);
+            else
+              tmp_rd = RDCOST(x->rdmult,
+                              rd_stats.rate + x->skip_cost[skip_ctx][1] -
+                                  rd_stats_y.rate - rd_stats_uv.rate,
+                              rd_stats.sse);
+          }
+
+          if (tmp_rd < this_rd) {
+            this_rd = tmp_rd;
+            rate2 = rd_stats.rate;
+            skippable = rd_stats.skip;
+            distortion2 = rd_stats.dist;
+            total_sse = rd_stats.sse;
+            rate_y = rd_stats_y.rate;
+            rate_uv = rd_stats_uv.rate;
+            disable_skip = dummy_disable_skip;
+            backup_mbmi = *mbmi;
+          }
+        }
+        *mbmi = backup_mbmi;
+
+        // TODO(chengchen): Redo the encoding using the selected compound_idx.
+        // Ideally, this extra pass should be unnecessary.
+        {
+          RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
+          av1_init_rd_stats(&rd_stats);
+          av1_init_rd_stats(&rd_stats_y);
+          av1_init_rd_stats(&rd_stats_uv);
+          rd_stats.rate = cum_rate;
+
+          memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
+          memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
+          memcpy(single_newmv_rate, backup_single_newmv_rate,
+                 sizeof(single_newmv_rate));
+          memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
+
+          mbmi->interp_filters = backup_interp_filters;
+
+          int dummy_disable_skip = 0;
+
+          args.single_newmv = single_newmv;
+          args.single_newmv_rate = single_newmv_rate;
+          args.modelled_rd = modelled_rd;
+
+          int64_t tmp_rd = handle_inter_mode(
+              cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
+              &dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
+
+          if (tmp_rd < INT64_MAX) {
+            if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
+                RDCOST(x->rdmult, 0, rd_stats.sse))
+              tmp_rd =
+                  RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
+                         rd_stats.dist);
+            else
+              tmp_rd = RDCOST(x->rdmult,
+                              rd_stats.rate + x->skip_cost[skip_ctx][1] -
+                                  rd_stats_y.rate - rd_stats_uv.rate,
+                              rd_stats.sse);
+          }
+
+          this_rd = tmp_rd;
+          rate2 = rd_stats.rate;
+          skippable = rd_stats.skip;
+          distortion2 = rd_stats.dist;
+          total_sse = rd_stats.sse;
+          rate_y = rd_stats_y.rate;
+          rate_uv = rd_stats_uv.rate;
+          disable_skip = dummy_disable_skip;
+        }
+      }
+#else  // CONFIG_JNT_COMP
       {
         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
         av1_init_rd_stats(&rd_stats);
@@ -10240,6 +10369,7 @@
         rate_y = rd_stats_y.rate;
         rate_uv = rd_stats_uv.rate;
       }
+#endif  // CONFIG_JNT_COMP
 
 // TODO(jingning): This needs some refactoring to improve code quality
 // and reduce redundant steps.
@@ -10293,12 +10423,22 @@
           memcpy(x->blk_skip_drl[i], x->blk_skip[i],
                  sizeof(uint8_t) * ctx->num_4x4_blk);
 
-        for (ref_idx = 0; ref_idx < ref_set; ++ref_idx) {
+#if CONFIG_JNT_COMP
+        for (int sidx = 0; sidx < ref_set * (1 + has_second_ref(mbmi)); ++sidx)
+#else
+        for (ref_idx = 0; ref_idx < ref_set; ++ref_idx)
+#endif  // CONFIG_JNT_COMP
+        {
           int64_t tmp_alt_rd = INT64_MAX;
           int dummy_disable_skip = 0;
           int ref;
           int_mv cur_mv;
           RD_STATS tmp_rd_stats, tmp_rd_stats_y, tmp_rd_stats_uv;
+#if CONFIG_JNT_COMP
+          ref_idx = sidx;
+          if (has_second_ref(mbmi)) ref_idx /= 2;
+          mbmi->compound_idx = sidx % 2;
+#endif  // CONFIG_JNT_COMP
 
           av1_invalid_rd_stats(&tmp_rd_stats);
 
@@ -10480,6 +10620,9 @@
         for (i = 0; i < MAX_MB_PLANE; ++i)
           memcpy(x->blk_skip[i], x->blk_skip_drl[i],
                  sizeof(uint8_t) * ctx->num_4x4_blk);
+#if CONFIG_JNT_COMP
+        *mbmi = backup_mbmi;
+#endif  // CONFIG_JNT_COMP
       }
       mbmi_ext->ref_mvs[ref_frame][0] = backup_ref_mv[0];
       if (comp_pred) mbmi_ext->ref_mvs[second_ref_frame][0] = backup_ref_mv[1];