Modify the warped motion mode context

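Derive the warped motion mode context from the neighbors' motion modes
and the current block's mode: context 0 when no warped neighbor is
found, 1 when a warped neighbor is found, and 2 when a warped neighbor
is found and the current block's mode is NEARESTMV.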

Change-Id: I77ca35fab37ec640bb38661ff1799f643d5aafdc
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 6573e0d..098f11a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7655,6 +7655,7 @@
   int pts0[SAMPLES_ARRAY_SIZE], pts_inref0[SAMPLES_ARRAY_SIZE];
   int pts_mv0[SAMPLES_ARRAY_SIZE], pts_wm[SAMPLES_ARRAY_SIZE];
   int total_samples;
+  int best_cand = -1;
 #else
   int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
 #endif  // CONFIG_EXT_WARPED_MOTION
@@ -7667,6 +7668,20 @@
   mbmi->num_proj_ref[0] =
       findSamples(cm, xd, mi_row, mi_col, pts0, pts_inref0, pts_mv0, pts_wm);
   total_samples = mbmi->num_proj_ref[0];
+
+  // Find a warped neighbor.
+  int cand;
+  int best_weight = 0;
+
+  // if (this_mode == NEARESTMV)
+  for (cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
+    if (pts_wm[cand * 2 + 1] > best_weight) {
+      best_weight = pts_wm[cand * 2 + 1];
+      best_cand = cand;
+    }
+  }
+  mbmi->wm_ctx = best_cand;
+  best_bmc_mbmi->wm_ctx = mbmi->wm_ctx;
 #else
   mbmi->num_proj_ref[0] = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
 #endif  // CONFIG_EXT_WARPED_MOTION
@@ -7725,22 +7740,7 @@
           av1_unswitchable_filter(cm->interp_filter));
 
 #if CONFIG_EXT_WARPED_MOTION
-      // Find a warped neighbor.
-      int cand;
-      int best_cand = -1;
-      int best_weight = 0;
-
-      assert(this_mode >= NEARESTMV && this_mode <= NEWMV);
-      if (this_mode == NEARESTMV) {
-        for (cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
-          if (pts_wm[cand * 2 + 1] > best_weight) {
-            best_weight = pts_wm[cand * 2 + 1];
-            best_cand = cand;
-          }
-        }
-      }
-
-      if (best_cand != -1) {
+      if (this_mode == NEARESTMV && best_cand != -1) {
         MODE_INFO *best_mi = xd->mi[pts_wm[2 * best_cand]];
         assert(best_mi->mbmi.motion_mode == WARPED_CAUSAL);
         mbmi->wm_params[0] = best_mi->mbmi.wm_params[0];
@@ -7839,10 +7839,21 @@
     rd_stats->skip = 1;
     rd_stats->rate = tmp_rate2;
     if (last_motion_mode_allowed > SIMPLE_TRANSLATION) {
-      if (last_motion_mode_allowed == WARPED_CAUSAL)
+      if (last_motion_mode_allowed == WARPED_CAUSAL) {
+#if CONFIG_EXT_WARPED_MOTION
+        int wm_ctx = 0;
+        if (mbmi->wm_ctx != -1) {
+          wm_ctx = 1;
+          if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+        }
+
+        rd_stats->rate += x->motion_mode_cost[wm_ctx][bsize][mbmi->motion_mode];
+#else
         rd_stats->rate += x->motion_mode_cost[bsize][mbmi->motion_mode];
-      else
+#endif  // CONFIG_EXT_WARPED_MOTION
+      } else {
         rd_stats->rate += x->motion_mode_cost1[bsize][mbmi->motion_mode];
+      }
     }
     if (mbmi->motion_mode == WARPED_CAUSAL) {
       rd_stats->rate -= rs;