Modify the warped motion mode context

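Derive the context used to read the motion mode symbol from the neighbors'
motion modes and the current block's mode. The context is 0 when no warped
neighbor is found, 1 when a warped neighbor exists, and 2 when a warped
neighbor exists and the current block uses NEARESTMV. The search for the
best warped neighbor now runs before read_motion_mode() so that its result
can select the CDF.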

Change-Id: I77ca35fab37ec640bb38661ff1799f643d5aafdc
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index c83ea98..677dfae 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -283,8 +283,13 @@
   }
 }
 
+#if CONFIG_EXT_WARPED_MOTION
+static MOTION_MODE read_motion_mode(AV1_COMMON *cm, MACROBLOCKD *xd,
+                                    MODE_INFO *mi, aom_reader *r, int best) {
+#else
 static MOTION_MODE read_motion_mode(AV1_COMMON *cm, MACROBLOCKD *xd,
                                     MODE_INFO *mi, aom_reader *r) {
+#endif  // CONFIG_EXT_WARPED_MOTION
   MB_MODE_INFO *mbmi = &mi->mbmi;
   (void)cm;
 
@@ -305,10 +310,24 @@
     if (counts) ++counts->obmc[mbmi->sb_type][motion_mode];
     return (MOTION_MODE)(SIMPLE_TRANSLATION + motion_mode);
   } else {
+#if CONFIG_EXT_WARPED_MOTION
+    int wm_ctx = 0;
+    if (best != -1) {
+      wm_ctx = 1;
+      if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+    }
+
+    motion_mode =
+        aom_read_symbol(r, xd->tile_ctx->motion_mode_cdf[wm_ctx][mbmi->sb_type],
+                        MOTION_MODES, ACCT_STR);
+    if (counts) ++counts->motion_mode[wm_ctx][mbmi->sb_type][motion_mode];
+#else
     motion_mode =
         aom_read_symbol(r, xd->tile_ctx->motion_mode_cdf[mbmi->sb_type],
                         MOTION_MODES, ACCT_STR);
     if (counts) ++counts->motion_mode[mbmi->sb_type][motion_mode];
+#endif  // CONFIG_EXT_WARPED_MOTION
+
     return (MOTION_MODE)(SIMPLE_TRANSLATION + motion_mode);
   }
 }
@@ -2138,6 +2157,9 @@
     xd->block_refs[ref] = ref_buf;
   }
 
+#if CONFIG_EXT_WARPED_MOTION
+  int best_cand = -1;
+#endif  // CONFIG_EXT_WARPED_MOTION
   mbmi->motion_mode = SIMPLE_TRANSLATION;
   if (mbmi->sb_type >= BLOCK_8X8 &&
 #if CONFIG_EXT_SKIP
@@ -2145,15 +2167,33 @@
 #endif  // CONFIG_EXT_SKIP
       !has_second_ref(mbmi))
 #if CONFIG_EXT_WARPED_MOTION
+  {
     mbmi->num_proj_ref[0] =
         findSamples(cm, xd, mi_row, mi_col, pts, pts_inref, pts_mv, pts_wm);
+
+    // Find a warped neighbor.
+    int cand;
+    int best_weight = 0;
+
+    // if (mbmi->mode == NEARESTMV)
+    for (cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
+      if (pts_wm[cand * 2 + 1] > best_weight) {
+        best_weight = pts_wm[cand * 2 + 1];
+        best_cand = cand;
+      }
+    }
+  }
 #else
     mbmi->num_proj_ref[0] = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
 #endif  // CONFIG_EXT_WARPED_MOTION
   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
 
   if (mbmi->ref_frame[1] != INTRA_FRAME)
+#if CONFIG_EXT_WARPED_MOTION
+    mbmi->motion_mode = read_motion_mode(cm, xd, mi, r, best_cand);
+#else
     mbmi->motion_mode = read_motion_mode(cm, xd, mi, r);
+#endif  // CONFIG_EXT_WARPED_MOTION
 
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
   if (cm->reference_mode != SINGLE_REFERENCE &&
@@ -2199,21 +2239,7 @@
     mbmi->wm_params[0].wmtype = DEFAULT_WMTYPE;
 
 #if CONFIG_EXT_WARPED_MOTION
-    // Find a warped neighbor.
-    int best_cand = -1;
-    int best_weight = 0;
-
-    assert(mbmi->mode >= NEARESTMV && mbmi->mode <= NEWMV);
-    if (mbmi->mode == NEARESTMV) {
-      for (int cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
-        if (pts_wm[cand * 2 + 1] > best_weight) {
-          best_weight = pts_wm[cand * 2 + 1];
-          best_cand = cand;
-        }
-      }
-    }
-
-    if (best_cand != -1) {
+    if (mbmi->mode == NEARESTMV && best_cand != -1) {
       MODE_INFO *best_mi = xd->mi[pts_wm[2 * best_cand]];
       assert(best_mi->mbmi.motion_mode == WARPED_CAUSAL);
       mbmi->wm_params[0] = best_mi->mbmi.wm_params[0];