Modify the warped motion mode context

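Code the motion mode symbol with one of three contexts (MOTION_MODE_CTX),
derived from the neighbors' motion modes and the current block's mode:
0 if no warped neighbor is found, 1 if one is found, and 2 if one is
found and the current block uses NEARESTMV.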

Change-Id: I77ca35fab37ec640bb38661ff1799f643d5aafdc
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index e25fb23..afc92f4 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -319,6 +319,9 @@
   int wedge_sign;
   SEG_MASK_TYPE mask_type;
   MOTION_MODE motion_mode;
+#if CONFIG_EXT_WARPED_MOTION
+  int wm_ctx;  // Encoder only: best warped-neighbor index, or -1 if none.
+#endif  // CONFIG_EXT_WARPED_MOTION
   int overlappable_neighbors[2];
   int_mv mv[2];
   int_mv pred_mv[2];
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index db6f282..59b255d 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -818,6 +818,7 @@
   -SIMPLE_TRANSLATION, 2, -OBMC_CAUSAL, -WARPED_CAUSAL,
 };

+#if !CONFIG_EXT_WARPED_MOTION
 static const aom_prob
     default_motion_mode_prob[BLOCK_SIZES_ALL][MOTION_MODES - 1] = {
       { 128, 128 }, { 128, 128 }, { 128, 128 }, { 128, 128 },
@@ -833,6 +834,56 @@
       { 252, 200 }, { 252, 200 }
 #endif  // CONFIG_EXT_PARTITION
     };
+#else  // CONFIG_EXT_WARPED_MOTION
+static const aom_prob
+    default_motion_mode_prob[MOTION_MODE_CTX][BLOCK_SIZES_ALL][MOTION_MODES -
+                                                               1] = {
+      {  // ctx 0: no warped neighbor found
+          { 128, 128 }, { 128, 128 }, { 128, 128 }, { 128, 128 },
+          { 128, 128 }, { 128, 128 }, { 62, 115 },  { 39, 131 },
+          { 39, 132 },  { 118, 94 },  { 77, 125 },  { 100, 121 },
+          { 190, 66 },  { 207, 102 }, { 197, 100 }, { 239, 76 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+          { 208, 200 }, { 208, 200 }, { 208, 200 }, { 208, 200 },
+          { 208, 200 }, { 208, 200 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+      },
+      {  // ctx 1: warped neighbor found, mode != NEARESTMV
+          { 128, 128 }, { 128, 128 }, { 128, 128 }, { 128, 128 },
+          { 128, 128 }, { 128, 128 }, { 62, 115 },  { 39, 131 },
+          { 39, 132 },  { 118, 94 },  { 77, 125 },  { 100, 121 },
+          { 190, 66 },  { 207, 102 }, { 197, 100 }, { 239, 76 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+          { 208, 200 }, { 208, 200 }, { 208, 200 }, { 208, 200 },
+          { 208, 200 }, { 208, 200 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+      },
+      {  // ctx 2: warped neighbor found, mode == NEARESTMV
+          { 128, 128 }, { 128, 128 }, { 128, 128 }, { 128, 128 },
+          { 128, 128 }, { 128, 128 }, { 62, 115 },  { 39, 131 },
+          { 39, 132 },  { 118, 94 },  { 77, 125 },  { 100, 121 },
+          { 190, 66 },  { 207, 102 }, { 197, 100 }, { 239, 76 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+          { 208, 200 }, { 208, 200 }, { 208, 200 }, { 208, 200 },
+          { 208, 200 }, { 208, 200 },
+#if CONFIG_EXT_PARTITION
+          { 252, 200 }, { 252, 200 },
+#endif  // CONFIG_EXT_PARTITION
+      },
+    };  // All MOTION_MODE_CTX contexts start from the same values.
+#endif  // CONFIG_EXT_WARPED_MOTION
+
+#if !CONFIG_EXT_WARPED_MOTION
 static const aom_cdf_prob
     default_motion_mode_cdf[BLOCK_SIZES_ALL][CDF_SIZE(MOTION_MODES)] = {
       { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
@@ -854,6 +905,72 @@
       { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
 #endif
     };
+#else  // CONFIG_EXT_WARPED_MOTION
+static const aom_cdf_prob
+    default_motion_mode_cdf[MOTION_MODE_CTX][BLOCK_SIZES_ALL][CDF_SIZE(
+        MOTION_MODES)] = {
+      {  // ctx 0: no warped neighbor found
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(7936, 19091) },  { AOM_CDF3(4991, 19205) },
+          { AOM_CDF3(4992, 19314) },  { AOM_CDF3(15104, 21590) },
+          { AOM_CDF3(9855, 21043) },  { AOM_CDF3(12800, 22238) },
+          { AOM_CDF3(24320, 26498) }, { AOM_CDF3(26496, 28995) },
+          { AOM_CDF3(25216, 28166) }, { AOM_CDF3(30592, 31238) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+          { AOM_CDF3(32256, 32656) },
+#endif
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+#endif
+      },
+      {  // ctx 1: warped neighbor found, mode != NEARESTMV
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(7936, 19091) },  { AOM_CDF3(4991, 19205) },
+          { AOM_CDF3(4992, 19314) },  { AOM_CDF3(15104, 21590) },
+          { AOM_CDF3(9855, 21043) },  { AOM_CDF3(12800, 22238) },
+          { AOM_CDF3(24320, 26498) }, { AOM_CDF3(26496, 28995) },
+          { AOM_CDF3(25216, 28166) }, { AOM_CDF3(30592, 31238) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+          { AOM_CDF3(32256, 32656) },
+#endif
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+#endif
+      },
+      {  // ctx 2: warped neighbor found, mode == NEARESTMV
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(16384, 24576) }, { AOM_CDF3(16384, 24576) },
+          { AOM_CDF3(7936, 19091) },  { AOM_CDF3(4991, 19205) },
+          { AOM_CDF3(4992, 19314) },  { AOM_CDF3(15104, 21590) },
+          { AOM_CDF3(9855, 21043) },  { AOM_CDF3(12800, 22238) },
+          { AOM_CDF3(24320, 26498) }, { AOM_CDF3(26496, 28995) },
+          { AOM_CDF3(25216, 28166) }, { AOM_CDF3(30592, 31238) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+          { AOM_CDF3(32256, 32656) },
+#endif
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+          { AOM_CDF3(32640, 32740) }, { AOM_CDF3(32640, 32740) },
+#if CONFIG_EXT_PARTITION
+          { AOM_CDF3(32256, 32656) }, { AOM_CDF3(32256, 32656) },
+#endif
+      },
+    };  // All MOTION_MODE_CTX contexts start from the same values.
+#endif  // CONFIG_EXT_WARPED_MOTION

 static const aom_cdf_prob default_obmc_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)] = {
   { AOM_CDF2(128 * 128) }, { AOM_CDF2(128 * 128) }, { AOM_CDF2(128 * 128) },
@@ -2898,9 +3015,18 @@
       fc->single_ref_prob[i][j] = av1_mode_mv_merge_probs(
           pre_fc->single_ref_prob[i][j], counts->single_ref[i][j]);

+#if !CONFIG_EXT_WARPED_MOTION
   for (i = BLOCK_8X8; i < BLOCK_SIZES_ALL; ++i)
     aom_tree_merge_probs(av1_motion_mode_tree, pre_fc->motion_mode_prob[i],
                          counts->motion_mode[i], fc->motion_mode_prob[i]);
+#else  // CONFIG_EXT_WARPED_MOTION
+  for (i = 0; i < MOTION_MODE_CTX; ++i) {  // Merge each context separately.
+    for (j = BLOCK_8X8; j < BLOCK_SIZES_ALL; ++j)
+      aom_tree_merge_probs(av1_motion_mode_tree, pre_fc->motion_mode_prob[i][j],
+                           counts->motion_mode[i][j],
+                           fc->motion_mode_prob[i][j]);
+  }
+#endif  // CONFIG_EXT_WARPED_MOTION

 #if CONFIG_JNT_COMP
   for (i = 0; i < COMP_INDEX_CONTEXTS; ++i)
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 45d44ca..ec17d1c 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -206,8 +206,14 @@
   aom_cdf_prob wedge_interintra_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)];
   aom_cdf_prob interintra_mode_cdf[BLOCK_SIZE_GROUPS]
                                   [CDF_SIZE(INTERINTRA_MODES)];
+#if CONFIG_EXT_WARPED_MOTION  // Indexed by [motion mode ctx][bsize].
+  aom_prob motion_mode_prob[MOTION_MODE_CTX][BLOCK_SIZES_ALL][MOTION_MODES - 1];
+  aom_cdf_prob motion_mode_cdf[MOTION_MODE_CTX][BLOCK_SIZES_ALL]
+                              [CDF_SIZE(MOTION_MODES)];
+#else  // CONFIG_EXT_WARPED_MOTION
   aom_prob motion_mode_prob[BLOCK_SIZES_ALL][MOTION_MODES - 1];
   aom_cdf_prob motion_mode_cdf[BLOCK_SIZES_ALL][CDF_SIZE(MOTION_MODES)];
+#endif  // CONFIG_EXT_WARPED_MOTION
   aom_cdf_prob obmc_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)];
   aom_prob comp_inter_prob[COMP_INTER_CONTEXTS];
   aom_cdf_prob palette_y_size_cdf[PALETTE_BLOCK_SIZES][CDF_SIZE(PALETTE_SIZES)];
@@ -366,7 +372,12 @@
   unsigned int interintra_mode[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
   unsigned int wedge_interintra[BLOCK_SIZES_ALL][2];
   unsigned int compound_interinter[BLOCK_SIZES_ALL][COMPOUND_TYPES];
+#if CONFIG_EXT_WARPED_MOTION  // Counts per [motion mode ctx][bsize].
+  unsigned int motion_mode[MOTION_MODE_CTX][BLOCK_SIZES_ALL][MOTION_MODES];
+#else  // CONFIG_EXT_WARPED_MOTION
   unsigned int motion_mode[BLOCK_SIZES_ALL][MOTION_MODES];
+#endif  // CONFIG_EXT_WARPED_MOTION
+
   unsigned int obmc[BLOCK_SIZES_ALL][2];
   unsigned int intra_inter[INTRA_INTER_CONTEXTS][2];
   unsigned int comp_inter[COMP_INTER_CONTEXTS][2];
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 44e815c..4684eb9 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -537,6 +537,10 @@
   MOTION_MODES
 } MOTION_MODE;

+#if CONFIG_EXT_WARPED_MOTION  // Number of motion_mode contexts; values:
+#define MOTION_MODE_CTX 3  // 0: none, 1: warped nbr, 2: warped nbr + NEARESTMV
+#endif  // CONFIG_EXT_WARPED_MOTION
+
 typedef enum ATTRIBUTE_PACKED {
   II_DC_PRED,
   II_V_PRED,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index c83ea98..677dfae 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -283,8 +283,13 @@
   }
 }

+#if CONFIG_EXT_WARPED_MOTION  // best: warped-neighbor cand index or -1.
+static MOTION_MODE read_motion_mode(AV1_COMMON *cm, MACROBLOCKD *xd,
+                                    MODE_INFO *mi, aom_reader *r, int best) {
+#else  // CONFIG_EXT_WARPED_MOTION
 static MOTION_MODE read_motion_mode(AV1_COMMON *cm, MACROBLOCKD *xd,
                                     MODE_INFO *mi, aom_reader *r) {
+#endif  // CONFIG_EXT_WARPED_MOTION
   MB_MODE_INFO *mbmi = &mi->mbmi;
   (void)cm;

@@ -305,10 +310,24 @@
     if (counts) ++counts->obmc[mbmi->sb_type][motion_mode];
     return (MOTION_MODE)(SIMPLE_TRANSLATION + motion_mode);
   } else {
+#if CONFIG_EXT_WARPED_MOTION
+    int wm_ctx = 0;  // Must match the encoder's derivation.
+    if (best != -1) {  // A warped neighbor was found.
+      wm_ctx = 1;
+      if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+    }
+
+    motion_mode =
+        aom_read_symbol(r, xd->tile_ctx->motion_mode_cdf[wm_ctx][mbmi->sb_type],
+                        MOTION_MODES, ACCT_STR);
+    if (counts) ++counts->motion_mode[wm_ctx][mbmi->sb_type][motion_mode];
+#else  // CONFIG_EXT_WARPED_MOTION
     motion_mode =
         aom_read_symbol(r, xd->tile_ctx->motion_mode_cdf[mbmi->sb_type],
                         MOTION_MODES, ACCT_STR);
     if (counts) ++counts->motion_mode[mbmi->sb_type][motion_mode];
+#endif  // CONFIG_EXT_WARPED_MOTION
+
     return (MOTION_MODE)(SIMPLE_TRANSLATION + motion_mode);
   }
 }
@@ -2138,6 +2157,9 @@
     xd->block_refs[ref] = ref_buf;
   }

+#if CONFIG_EXT_WARPED_MOTION
+  int best_cand = -1;  // Warped-neighbor candidate index; -1 if none.
+#endif  // CONFIG_EXT_WARPED_MOTION
   mbmi->motion_mode = SIMPLE_TRANSLATION;
   if (mbmi->sb_type >= BLOCK_8X8 &&
 #if CONFIG_EXT_SKIP
@@ -2145,15 +2167,33 @@
 #endif  // CONFIG_EXT_SKIP
       !has_second_ref(mbmi))
 #if CONFIG_EXT_WARPED_MOTION
+  {
     mbmi->num_proj_ref[0] =
         findSamples(cm, xd, mi_row, mi_col, pts, pts_inref, pts_mv, pts_wm);
+
+    // Pick the candidate with the largest positive pts_wm weight (-1 if
+    // none). The scan runs for every mode, not only NEARESTMV, because the
+    // result also selects the motion mode context in read_motion_mode(); it
+    // must match the encoder's scan exactly.
+    int best_weight = 0;
+    for (int cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
+      if (pts_wm[cand * 2 + 1] > best_weight) {
+        best_weight = pts_wm[cand * 2 + 1];
+        best_cand = cand;
+      }
+    }
+  }
 #else
     mbmi->num_proj_ref[0] = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
 #endif  // CONFIG_EXT_WARPED_MOTION
   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);

   if (mbmi->ref_frame[1] != INTRA_FRAME)
+#if CONFIG_EXT_WARPED_MOTION
+    mbmi->motion_mode = read_motion_mode(cm, xd, mi, r, best_cand);
+#else  // CONFIG_EXT_WARPED_MOTION
     mbmi->motion_mode = read_motion_mode(cm, xd, mi, r);
+#endif  // CONFIG_EXT_WARPED_MOTION

   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
   if (cm->reference_mode != SINGLE_REFERENCE &&
@@ -2199,21 +2239,7 @@
     mbmi->wm_params[0].wmtype = DEFAULT_WMTYPE;

 #if CONFIG_EXT_WARPED_MOTION
-    // Find a warped neighbor.
-    int best_cand = -1;
-    int best_weight = 0;
-
-    assert(mbmi->mode >= NEARESTMV && mbmi->mode <= NEWMV);
-    if (mbmi->mode == NEARESTMV) {
-      for (int cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
-        if (pts_wm[cand * 2 + 1] > best_weight) {
-          best_weight = pts_wm[cand * 2 + 1];
-          best_cand = cand;
-        }
-      }
-    }
-
-    if (best_cand != -1) {
+    if (mbmi->mode == NEARESTMV && best_cand != -1) {
       MODE_INFO *best_mi = xd->mi[pts_wm[2 * best_cand]];
       assert(best_mi->mbmi.motion_mode == WARPED_CAUSAL);
       mbmi->wm_params[0] = best_mi->mbmi.wm_params[0];
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 4b18e9a..3b50a6b 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -325,9 +325,22 @@
                        xd->tile_ctx->obmc_cdf[mbmi->sb_type], 2);
       break;
     default:
+#if CONFIG_EXT_WARPED_MOTION
+    {
+      int wm_ctx = 0;  // Must match the decoder's derivation.
+      if (mbmi->wm_ctx != -1) {  // A warped neighbor was found.
+        wm_ctx = 1;
+        if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+      }
+      aom_write_symbol(w, mbmi->motion_mode,
+                       xd->tile_ctx->motion_mode_cdf[wm_ctx][mbmi->sb_type],
+                       MOTION_MODES);
+    }
+#else  // CONFIG_EXT_WARPED_MOTION
       aom_write_symbol(w, mbmi->motion_mode,
                        xd->tile_ctx->motion_mode_cdf[mbmi->sb_type],
                        MOTION_MODES);
+#endif  // CONFIG_EXT_WARPED_MOTION
   }
 }

diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 139f332..8de0a78 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -275,7 +275,11 @@
   int interintra_cost[BLOCK_SIZE_GROUPS][2];
   int wedge_interintra_cost[BLOCK_SIZES_ALL][2];
   int interintra_mode_cost[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
+#if CONFIG_EXT_WARPED_MOTION  // Indexed [motion mode ctx][bsize][mode].
+  int motion_mode_cost[MOTION_MODE_CTX][BLOCK_SIZES_ALL][MOTION_MODES];
+#else  // CONFIG_EXT_WARPED_MOTION
   int motion_mode_cost[BLOCK_SIZES_ALL][MOTION_MODES];
+#endif  // CONFIG_EXT_WARPED_MOTION
   int motion_mode_cost1[BLOCK_SIZES_ALL][2];
   int intra_uv_mode_cost[INTRA_MODES][UV_INTRA_MODES];
   int y_mode_costs[INTRA_MODES][INTRA_MODES][INTRA_MODES];
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 0f4af1d..c476a92 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1233,10 +1233,23 @@
             motion_mode_allowed(0, xd->global_motion, xd, mi);
         if (mbmi->ref_frame[1] != INTRA_FRAME) {
           if (motion_allowed == WARPED_CAUSAL) {
+#if CONFIG_EXT_WARPED_MOTION
+            int wm_ctx = 0;  // Must match the decoder's derivation.
+            if (mbmi->wm_ctx != -1) {  // A warped neighbor was found.
+              wm_ctx = 1;
+              if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+            }
+
+            counts->motion_mode[wm_ctx][mbmi->sb_type][mbmi->motion_mode]++;
+            if (allow_update_cdf)
+              update_cdf(fc->motion_mode_cdf[wm_ctx][mbmi->sb_type],
+                         mbmi->motion_mode, MOTION_MODES);
+#else  // CONFIG_EXT_WARPED_MOTION
             counts->motion_mode[mbmi->sb_type][mbmi->motion_mode]++;
             if (allow_update_cdf)
               update_cdf(fc->motion_mode_cdf[mbmi->sb_type], mbmi->motion_mode,
                          MOTION_MODES);
+#endif  // CONFIG_EXT_WARPED_MOTION
           } else if (motion_allowed == OBMC_CAUSAL) {
             counts->obmc[mbmi->sb_type][mbmi->motion_mode == OBMC_CAUSAL]++;
             if (allow_update_cdf)
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 4e74799..f1b109b 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -262,10 +262,20 @@
       av1_cost_tokens_from_cdf(x->wedge_interintra_cost[i],
                                fc->wedge_interintra_cdf[i], NULL);
     }
+#if CONFIG_EXT_WARPED_MOTION  // One cost table per motion mode context.
+    for (i = 0; i < MOTION_MODE_CTX; i++) {
+      for (j = BLOCK_8X8; j < BLOCK_SIZES_ALL; j++) {
+        av1_cost_tokens_from_cdf(x->motion_mode_cost[i][j],
+                                 fc->motion_mode_cdf[i][j], NULL);
+      }
+    }
+#else  // CONFIG_EXT_WARPED_MOTION
     for (i = BLOCK_8X8; i < BLOCK_SIZES_ALL; i++) {
       av1_cost_tokens_from_cdf(x->motion_mode_cost[i], fc->motion_mode_cdf[i],
                                NULL);
     }
+#endif  // CONFIG_EXT_WARPED_MOTION
+
     for (i = BLOCK_8X8; i < BLOCK_SIZES_ALL; i++) {
       av1_cost_tokens_from_cdf(x->motion_mode_cost1[i], fc->obmc_cdf[i], NULL);
     }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 6573e0d..098f11a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7655,6 +7655,7 @@
   int pts0[SAMPLES_ARRAY_SIZE], pts_inref0[SAMPLES_ARRAY_SIZE];
   int pts_mv0[SAMPLES_ARRAY_SIZE], pts_wm[SAMPLES_ARRAY_SIZE];
   int total_samples;
+  int best_cand = -1;  // Warped-neighbor candidate index; -1 if none.
 #else
   int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
 #endif  // CONFIG_EXT_WARPED_MOTION
@@ -7667,6 +7668,20 @@
   mbmi->num_proj_ref[0] =
       findSamples(cm, xd, mi_row, mi_col, pts0, pts_inref0, pts_mv0, pts_wm);
   total_samples = mbmi->num_proj_ref[0];
+
+  // Pick the candidate with the largest positive pts_wm weight (-1 if none).
+  // Scan for every mode, not only NEARESTMV: the result also selects the
+  // motion mode context, so it must match the decoder's scan (decodemv.c).
+  int best_weight = 0;
+  for (int cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
+    if (pts_wm[cand * 2 + 1] > best_weight) {
+      best_weight = pts_wm[cand * 2 + 1];
+      best_cand = cand;
+    }
+  }
+  // Cached so later stages can derive the same motion mode context.
+  mbmi->wm_ctx = best_cand;
+  best_bmc_mbmi->wm_ctx = mbmi->wm_ctx;
 #else
   mbmi->num_proj_ref[0] = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
 #endif  // CONFIG_EXT_WARPED_MOTION
@@ -7725,22 +7740,7 @@
           av1_unswitchable_filter(cm->interp_filter));

 #if CONFIG_EXT_WARPED_MOTION
-      // Find a warped neighbor.
-      int cand;
-      int best_cand = -1;
-      int best_weight = 0;
-
-      assert(this_mode >= NEARESTMV && this_mode <= NEWMV);
-      if (this_mode == NEARESTMV) {
-        for (cand = 0; cand < mbmi->num_proj_ref[0]; cand++) {
-          if (pts_wm[cand * 2 + 1] > best_weight) {
-            best_weight = pts_wm[cand * 2 + 1];
-            best_cand = cand;
-          }
-        }
-      }
-
-      if (best_cand != -1) {
+      if (this_mode == NEARESTMV && best_cand != -1) {
         MODE_INFO *best_mi = xd->mi[pts_wm[2 * best_cand]];
         assert(best_mi->mbmi.motion_mode == WARPED_CAUSAL);
         mbmi->wm_params[0] = best_mi->mbmi.wm_params[0];
@@ -7839,10 +7839,21 @@
     rd_stats->skip = 1;
     rd_stats->rate = tmp_rate2;
     if (last_motion_mode_allowed > SIMPLE_TRANSLATION) {
-      if (last_motion_mode_allowed == WARPED_CAUSAL)
+      if (last_motion_mode_allowed == WARPED_CAUSAL) {
+#if CONFIG_EXT_WARPED_MOTION
+        int wm_ctx = 0;  // Must match the decoder's derivation.
+        if (mbmi->wm_ctx != -1) {  // A warped neighbor was found.
+          wm_ctx = 1;
+          if (mbmi->mode == NEARESTMV) wm_ctx = 2;
+        }
+
+        rd_stats->rate += x->motion_mode_cost[wm_ctx][bsize][mbmi->motion_mode];
+#else  // CONFIG_EXT_WARPED_MOTION
         rd_stats->rate += x->motion_mode_cost[bsize][mbmi->motion_mode];
-      else
+#endif  // CONFIG_EXT_WARPED_MOTION
+      } else {
         rd_stats->rate += x->motion_mode_cost1[bsize][mbmi->motion_mode];
+      }
     }
     if (mbmi->motion_mode == WARPED_CAUSAL) {
       rd_stats->rate -= rs;