more ref_mv changes from aom/master

Change-Id: I9152f898dfacdf3877ed719f193bb1e0dbee0a1a
diff --git a/av1/common/mv.h b/av1/common/mv.h
index a02d598..d49fc3f 100644
--- a/av1/common/mv.h
+++ b/av1/common/mv.h
@@ -152,7 +152,7 @@
 typedef struct candidate_mv {
   int_mv this_mv;
   int_mv comp_mv;
-  int_mv pred_mv;
+  int_mv pred_mv[2];
   int weight;
 } CANDIDATE_MV;
 #endif
diff --git a/av1/common/mvref_common.c b/av1/common/mvref_common.c
index ff05c71..0f3f949 100644
--- a/av1/common/mvref_common.c
+++ b/av1/common/mvref_common.c
@@ -35,7 +35,7 @@
         // Add a new item to the list.
         if (index == *refmv_count) {
           ref_mv_stack[index].this_mv = this_refmv;
-          ref_mv_stack[index].pred_mv =
+          ref_mv_stack[index].pred_mv[0] =
               get_sub_block_pred_mv(candidate_mi, ref, col, block);
           ref_mv_stack[index].weight = 2 * len;
           ++(*refmv_count);
@@ -61,7 +61,7 @@
           // Add a new item to the list.
           if (index == *refmv_count) {
             ref_mv_stack[index].this_mv = this_refmv;
-            ref_mv_stack[index].pred_mv =
+            ref_mv_stack[index].pred_mv[0] =
                 get_sub_block_pred_mv(candidate_mi, ref, col, alt_block);
             ref_mv_stack[index].weight = len;
             ++(*refmv_count);
@@ -97,6 +97,10 @@
       if (index == *refmv_count) {
         ref_mv_stack[index].this_mv = this_refmv[0];
         ref_mv_stack[index].comp_mv = this_refmv[1];
+        ref_mv_stack[index].pred_mv[0] =
+            get_sub_block_pred_mv(candidate_mi, 0, col, block);
+        ref_mv_stack[index].pred_mv[1] =
+            get_sub_block_pred_mv(candidate_mi, 1, col, block);
         ref_mv_stack[index].weight = 2 * len;
         ++(*refmv_count);
 
@@ -127,6 +131,10 @@
         if (index == *refmv_count) {
           ref_mv_stack[index].this_mv = this_refmv[0];
           ref_mv_stack[index].comp_mv = this_refmv[1];
+          ref_mv_stack[index].pred_mv[0] =
+              get_sub_block_pred_mv(candidate_mi, 0, col, block);
+          ref_mv_stack[index].pred_mv[1] =
+              get_sub_block_pred_mv(candidate_mi, 1, col, block);
           ref_mv_stack[index].weight = len;
           ++(*refmv_count);
 
@@ -230,6 +238,7 @@
         candidate_mi, candidate, rf, refmv_count, ref_mv_stack,
         cm->allow_high_precision_mv, len, block, mi_pos.col);
   }  // Analyze a single 8x8 block motion information.
+
   return newmv_count;
 }
 
@@ -333,6 +342,7 @@
 
       if (idx == *refmv_count && *refmv_count < MAX_REF_MV_STACK_SIZE) {
         ref_mv_stack[idx].this_mv.as_int = this_refmv.as_int;
+        ref_mv_stack[idx].pred_mv[0] = prev_frame_mvs->pred_mv[ref];
         ref_mv_stack[idx].weight = 2;
         ++(*refmv_count);
       }
@@ -389,8 +399,8 @@
     int blk_row, blk_col;
     int coll_blk_count = 0;
 
-    for (blk_row = 0; blk_row < xd->n8_h; blk_row += 2) {
-      for (blk_col = 0; blk_col < xd->n8_w; blk_col += 2) {
+    for (blk_row = 0; blk_row < xd->n8_h; ++blk_row) {
+      for (blk_col = 0; blk_col < xd->n8_w; ++blk_col) {
         coll_blk_count += add_col_ref_mv(
             cm, prev_frame_mvs_base, xd, mi_row, mi_col, ref_frame, blk_row,
             blk_col, refmv_count, ref_mv_stack, mode_context);
@@ -705,19 +715,30 @@
 #else
                         mode_context);
 #endif  // CONFIG_REF_MV
-  find_mv_refs_idx(cm, xd, mi, ref_frame, mv_ref_list, -1, mi_row, mi_col, sync,
-                   data, NULL);
+#endif  // CONFIG_EXT_INTER
+#if CONFIG_REF_MV
+  if (ref_frame <= ALTREF_FRAME)
+    find_mv_refs_idx(cm, xd, mi, ref_frame, mv_ref_list, -1, mi_row, mi_col,
+                     sync, data, mode_context);
 #else
   find_mv_refs_idx(cm, xd, mi, ref_frame, mv_ref_list, -1, mi_row, mi_col, sync,
                    data, mode_context);
-#endif  // CONFIG_EXT_INTER
+#endif  // CONFIG_REF_MV
 
 #if CONFIG_REF_MV
   setup_ref_mv_list(cm, xd, ref_frame, ref_mv_count, ref_mv_stack, mv_ref_list,
                     -1, mi_row, mi_col, mode_context);
 
-  for (idx = 0; idx < MAX_MV_REF_CANDIDATES; ++idx)
-    if (mv_ref_list[idx].as_int != 0) all_zero = 0;
+  if (*ref_mv_count >= 2) {
+    for (idx = 0; idx < AOMMIN(3, *ref_mv_count); ++idx) {
+      if (ref_mv_stack[idx].this_mv.as_int != 0) all_zero = 0;
+      if (ref_frame > ALTREF_FRAME)
+        if (ref_mv_stack[idx].comp_mv.as_int != 0) all_zero = 0;
+    }
+  } else if (ref_frame <= ALTREF_FRAME) {
+    for (idx = 0; idx < MAX_MV_REF_CANDIDATES; ++idx)
+      if (mv_ref_list[idx].as_int != 0) all_zero = 0;
+  }
 
   if (all_zero) mode_context[ref_frame] |= (1 << ALL_ZERO_FLAG_OFFSET);
 #endif
diff --git a/av1/common/mvref_common.h b/av1/common/mvref_common.h
index 96646d0..b6a5e22 100644
--- a/av1/common/mvref_common.h
+++ b/av1/common/mvref_common.h
@@ -350,19 +350,24 @@
 
 #if CONFIG_REF_MV
 static INLINE int av1_nmv_ctx(const uint8_t ref_mv_count,
-                              const CANDIDATE_MV *ref_mv_stack) {
+                              const CANDIDATE_MV *ref_mv_stack, int ref,
+                              int ref_mv_idx) {
+  int_mv this_mv = (ref == 0) ? ref_mv_stack[ref_mv_idx].this_mv
+                              : ref_mv_stack[ref_mv_idx].comp_mv;
 #if CONFIG_EXT_INTER
   return 0;
 #endif
-  if (ref_mv_stack[0].weight > REF_CAT_LEVEL && ref_mv_count > 0) {
-    if (abs(ref_mv_stack[0].this_mv.as_mv.row -
-            ref_mv_stack[0].pred_mv.as_mv.row) <= 4 &&
-        abs(ref_mv_stack[0].this_mv.as_mv.col -
-            ref_mv_stack[0].pred_mv.as_mv.col) <= 4)
+
+  if (ref_mv_stack[ref_mv_idx].weight >= REF_CAT_LEVEL && ref_mv_count > 0) {
+    if (abs(this_mv.as_mv.row -
+            ref_mv_stack[ref_mv_idx].pred_mv[ref].as_mv.row) <= 4 &&
+        abs(this_mv.as_mv.col -
+            ref_mv_stack[ref_mv_idx].pred_mv[ref].as_mv.col) <= 4)
       return 2;
     else
       return 1;
   }
+
   return 0;
 }
 
@@ -404,6 +409,8 @@
     const int16_t *const mode_context, const MV_REFERENCE_FRAME *const rf,
     BLOCK_SIZE bsize, int block) {
   int16_t mode_ctx = 0;
+  int8_t ref_frame_type = av1_ref_frame_type(rf);
+
   if (block >= 0) {
     mode_ctx = mode_context[rf[0]] & 0x00ff;
 
@@ -418,7 +425,7 @@
   else if (rf[0] != ALTREF_FRAME)
     return mode_context[rf[0]] & ~(mode_context[ALTREF_FRAME] & 0xfe00);
   else
-    return mode_context[rf[0]];
+    return mode_context[ref_frame_type];
 }
 
 static INLINE uint8_t av1_drl_ctx(const CANDIDATE_MV *ref_mv_stack,
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 6cd6cbe..62b2c7a 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -82,6 +82,9 @@
 
 typedef struct {
   int_mv mv[2];
+#if CONFIG_REF_MV
+  int_mv pred_mv[2];
+#endif
   MV_REFERENCE_FRAME ref_frame[2];
 } MV_REF;
 
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 1bc76c4..ceed8b3 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -966,8 +966,10 @@
 #endif
       for (i = 0; i < 1 + is_compound; ++i) {
 #if CONFIG_REF_MV
-        int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[i]],
-                                  xd->ref_mv_stack[mbmi->ref_frame[i]]);
+        int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+        int nmv_ctx =
+            av1_nmv_ctx(xd->ref_mv_count[rf_type], xd->ref_mv_stack[rf_type], i,
+                        mbmi->ref_mv_idx);
         nmv_context_counts *const mv_counts =
             counts ? &counts->mv[nmv_ctx] : NULL;
         read_mv(r, &mv[i].as_mv, &ref_mv[i].as_mv, is_compound,
@@ -1031,8 +1033,10 @@
       assert(is_compound);
       for (i = 0; i < 2; ++i) {
 #if CONFIG_REF_MV
-        int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[i]],
-                                  xd->ref_mv_stack[mbmi->ref_frame[i]]);
+        int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+        int nmv_ctx =
+            av1_nmv_ctx(xd->ref_mv_count[rf_type], xd->ref_mv_stack[rf_type], i,
+                        mbmi->ref_mv_idx);
         nmv_context_counts *const mv_counts =
             counts ? &counts->mv[nmv_ctx] : NULL;
         read_mv(r, &mv[i].as_mv, &ref_mv[i].as_mv, is_compound,
@@ -1072,8 +1076,9 @@
     case NEW_NEARESTMV: {
       FRAME_COUNTS *counts = xd->counts;
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[0]],
-                                xd->ref_mv_stack[mbmi->ref_frame[0]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[rf_type],
+                                xd->ref_mv_stack[rf_type], 0, mbmi->ref_mv_idx);
       nmv_context_counts *const mv_counts =
           counts ? &counts->mv[nmv_ctx] : NULL;
       read_mv(r, &mv[0].as_mv, &ref_mv[0].as_mv, is_compound,
@@ -1091,8 +1096,9 @@
     case NEAREST_NEWMV: {
       FRAME_COUNTS *counts = xd->counts;
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[1]],
-                                xd->ref_mv_stack[mbmi->ref_frame[1]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[rf_type],
+                                xd->ref_mv_stack[rf_type], 1, mbmi->ref_mv_idx);
       nmv_context_counts *const mv_counts =
           counts ? &counts->mv[nmv_ctx] : NULL;
       mv[0].as_int = nearest_mv[0].as_int;
@@ -1111,8 +1117,9 @@
     case NEAR_NEWMV: {
       FRAME_COUNTS *counts = xd->counts;
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[1]],
-                                xd->ref_mv_stack[mbmi->ref_frame[1]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[rf_type],
+                                xd->ref_mv_stack[rf_type], 1, mbmi->ref_mv_idx);
       nmv_context_counts *const mv_counts =
           counts ? &counts->mv[nmv_ctx] : NULL;
       mv[0].as_int = near_mv[0].as_int;
@@ -1132,8 +1139,9 @@
     case NEW_NEARMV: {
       FRAME_COUNTS *counts = xd->counts;
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[0]],
-                                xd->ref_mv_stack[mbmi->ref_frame[0]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[rf_type],
+                                xd->ref_mv_stack[rf_type], 0, mbmi->ref_mv_idx);
       nmv_context_counts *const mv_counts =
           counts ? &counts->mv[nmv_ctx] : NULL;
       read_mv(r, &mv[0].as_mv, &ref_mv[0].as_mv, is_compound,
@@ -1235,6 +1243,29 @@
   }
 
 #if CONFIG_REF_MV
+  for (; ref_frame < MODE_CTX_REF_FRAMES; ++ref_frame) {
+    av1_find_mv_refs(cm, xd, mi, ref_frame, &xd->ref_mv_count[ref_frame],
+                     xd->ref_mv_stack[ref_frame],
+#if CONFIG_EXT_INTER
+                     compound_inter_mode_ctx,
+#endif  // CONFIG_EXT_INTER
+                     ref_mvs[ref_frame], mi_row, mi_col, fpm_sync, (void *)pbi,
+                     inter_mode_ctx);
+
+    if (xd->ref_mv_count[ref_frame] < 2) {
+      MV_REFERENCE_FRAME rf[2];
+      av1_set_ref_frame(rf, ref_frame);
+      for (ref = 0; ref < 2; ++ref) {
+        lower_mv_precision(&ref_mvs[rf[ref]][0].as_mv, allow_hp);
+        lower_mv_precision(&ref_mvs[rf[ref]][1].as_mv, allow_hp);
+      }
+
+      if (ref_mvs[rf[0]][0].as_int != 0 || ref_mvs[rf[0]][1].as_int != 0 ||
+          ref_mvs[rf[1]][0].as_int != 0 || ref_mvs[rf[1]][1].as_int != 0)
+        inter_mode_ctx[ref_frame] &= ~(1 << ALL_ZERO_FLAG_OFFSET);
+    }
+  }
+
 #if CONFIG_EXT_INTER
   if (is_compound)
     mode_ctx = compound_inter_mode_ctx[mbmi->ref_frame[0]];
@@ -1802,6 +1833,10 @@
         mv->ref_frame[1] = mi->mbmi.ref_frame[1];
         mv->mv[0].as_int = mi->mbmi.mv[0].as_int;
         mv->mv[1].as_int = mi->mbmi.mv[1].as_int;
+#if CONFIG_REF_MV
+        mv->pred_mv[0].as_int = mi->mbmi.pred_mv[0].as_int;
+        mv->pred_mv[1].as_int = mi->mbmi.pred_mv[1].as_int;
+#endif
       }
     }
   }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 7636151..7ef24bb 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1256,9 +1256,10 @@
 #endif  // CONFIG_EXT_INTER
             for (ref = 0; ref < 1 + is_compound; ++ref) {
 #if CONFIG_REF_MV
-              int nmv_ctx =
-                  av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[ref]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[ref]]);
+              int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+              int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                                        mbmi_ext->ref_mv_stack[rf_type], ref,
+                                        mbmi->ref_mv_idx);
               const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
               av1_encode_mv(cpi, w, &mi->bmi[j].as_mv[ref].as_mv,
@@ -1280,9 +1281,10 @@
 #if CONFIG_EXT_INTER
           else if (b_mode == NEAREST_NEWMV || b_mode == NEAR_NEWMV) {
 #if CONFIG_REF_MV
-            int nmv_ctx =
-                av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[1]],
-                            mbmi_ext->ref_mv_stack[mbmi->ref_frame[1]]);
+            int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+            int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                                      mbmi_ext->ref_mv_stack[rf_type], 1,
+                                      mbmi->ref_mv_idx);
             const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
             av1_encode_mv(cpi, w, &mi->bmi[j].as_mv[1].as_mv,
@@ -1293,9 +1295,10 @@
                           nmvc, allow_hp);
           } else if (b_mode == NEW_NEARESTMV || b_mode == NEW_NEARMV) {
 #if CONFIG_REF_MV
-            int nmv_ctx =
-                av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[0]],
-                            mbmi_ext->ref_mv_stack[mbmi->ref_frame[0]]);
+            int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+            int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                                      mbmi_ext->ref_mv_stack[rf_type], 0,
+                                      mbmi->ref_mv_idx);
             const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
             av1_encode_mv(cpi, w, &mi->bmi[j].as_mv[0].as_mv,
@@ -1317,9 +1320,10 @@
         int_mv ref_mv;
         for (ref = 0; ref < 1 + is_compound; ++ref) {
 #if CONFIG_REF_MV
-          int nmv_ctx =
-              av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[ref]],
-                          mbmi_ext->ref_mv_stack[mbmi->ref_frame[ref]]);
+          int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+          int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                                    mbmi_ext->ref_mv_stack[rf_type], ref,
+                                    mbmi->ref_mv_idx);
           const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
           ref_mv = mbmi_ext->ref_mvs[mbmi->ref_frame[ref]][0];
@@ -1342,8 +1346,10 @@
 #if CONFIG_EXT_INTER
       } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV) {
 #if CONFIG_REF_MV
-        int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[1]],
-                                  mbmi_ext->ref_mv_stack[mbmi->ref_frame[1]]);
+        int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+        int nmv_ctx =
+            av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                        mbmi_ext->ref_mv_stack[rf_type], 1, mbmi->ref_mv_idx);
         const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
         av1_encode_mv(cpi, w, &mbmi->mv[1].as_mv,
@@ -1354,8 +1360,10 @@
                       nmvc, allow_hp);
       } else if (mode == NEW_NEARESTMV || mode == NEW_NEARMV) {
 #if CONFIG_REF_MV
-        int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[0]],
-                                  mbmi_ext->ref_mv_stack[mbmi->ref_frame[0]]);
+        int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+        int nmv_ctx =
+            av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                        mbmi_ext->ref_mv_stack[rf_type], 0, mbmi->ref_mv_idx);
         const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
 #endif
         av1_encode_mv(cpi, w, &mbmi->mv[0].as_mv,
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index cfc4718..ab896f4 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1075,6 +1075,7 @@
       clamp_mv_ref(&this_mv.as_mv, xd->n8_w << 3, xd->n8_h << 3, xd);
       x->mbmi_ext->ref_mvs[mbmi->ref_frame[i]][0] = this_mv;
       mbmi->pred_mv[i] = this_mv;
+      mi->mbmi.pred_mv[i] = this_mv;
     }
   }
 #endif
diff --git a/av1/encoder/encodemv.c b/av1/encoder/encodemv.c
index ee627bd..00a95ee 100644
--- a/av1/encoder/encodemv.c
+++ b/av1/encoder/encodemv.c
@@ -298,8 +298,10 @@
       const MV diff = { mvs[i].as_mv.row - ref->row,
                         mvs[i].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[i]],
-                                mbmi_ext->ref_mv_stack[mbmi->ref_frame[i]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx =
+          av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                      mbmi_ext->ref_mv_stack[rf_type], i, mbmi->ref_mv_idx);
       nmv_context_counts *counts = &nmv_counts[nmv_ctx];
       (void)pred_mvs;
 #endif
@@ -310,8 +312,10 @@
     const MV diff = { mvs[1].as_mv.row - ref->row,
                       mvs[1].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[1]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[1]]);
+    int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+    int nmv_ctx =
+        av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                    mbmi_ext->ref_mv_stack[rf_type], 1, mbmi->ref_mv_idx);
     nmv_context_counts *counts = &nmv_counts[nmv_ctx];
 #endif
     av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
@@ -320,8 +324,10 @@
     const MV diff = { mvs[0].as_mv.row - ref->row,
                       mvs[0].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[0]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[0]]);
+    int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+    int nmv_ctx =
+        av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                    mbmi_ext->ref_mv_stack[rf_type], 0, mbmi->ref_mv_idx);
     nmv_context_counts *counts = &nmv_counts[nmv_ctx];
 #endif
     av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
@@ -347,8 +353,10 @@
       const MV diff = { mvs[i].as_mv.row - ref->row,
                         mvs[i].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-      int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[i]],
-                                mbmi_ext->ref_mv_stack[mbmi->ref_frame[i]]);
+      int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+      int nmv_ctx =
+          av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                      mbmi_ext->ref_mv_stack[rf_type], i, mbmi->ref_mv_idx);
       nmv_context_counts *counts = &nmv_counts[nmv_ctx];
 #endif
       av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
@@ -358,8 +366,10 @@
     const MV diff = { mvs[1].as_mv.row - ref->row,
                       mvs[1].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[1]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[1]]);
+    int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+    int nmv_ctx =
+        av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                    mbmi_ext->ref_mv_stack[rf_type], 1, mbmi->ref_mv_idx);
     nmv_context_counts *counts = &nmv_counts[nmv_ctx];
 #endif
     av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
@@ -368,8 +378,10 @@
     const MV diff = { mvs[0].as_mv.row - ref->row,
                       mvs[0].as_mv.col - ref->col };
 #if CONFIG_REF_MV
-    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[0]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[0]]);
+    int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+    int nmv_ctx =
+        av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                    mbmi_ext->ref_mv_stack[rf_type], 0, mbmi->ref_mv_idx);
     nmv_context_counts *counts = &nmv_counts[nmv_ctx];
 #endif
     av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
@@ -389,8 +401,10 @@
 
   for (i = 0; i < 1 + has_second_ref(mbmi); ++i) {
 #if CONFIG_REF_MV
-    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[i]],
-                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[i]]);
+    int8_t rf_type = av1_ref_frame_type(mbmi->ref_frame);
+    int nmv_ctx =
+        av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                    mbmi_ext->ref_mv_stack[rf_type], i, mbmi->ref_mv_idx);
     nmv_context_counts *counts = &nmv_counts[nmv_ctx];
     const MV *ref = &pred_mvs[i].as_mv;
 #else
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 0c80174..4d157ff 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -335,10 +335,13 @@
 }
 
 #if CONFIG_REF_MV
-void av1_set_mvcost(MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame) {
+void av1_set_mvcost(MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame, int ref,
+                    int ref_mv_idx) {
   MB_MODE_INFO_EXT *mbmi_ext = x->mbmi_ext;
-  int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[ref_frame],
-                            mbmi_ext->ref_mv_stack[ref_frame]);
+  int8_t rf_type = av1_ref_frame_type(x->e_mbd.mi[0]->mbmi.ref_frame);
+  int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[rf_type],
+                            mbmi_ext->ref_mv_stack[rf_type], ref, ref_mv_idx);
+  (void)ref_frame;
   x->mvcost = x->mv_cost_stack[nmv_ctx];
   x->nmvjointcost = x->nmv_vec_cost[nmv_ctx];
   x->mvsadcost = x->mvcost;
diff --git a/av1/encoder/rd.h b/av1/encoder/rd.h
index 54c10b2..933733b 100644
--- a/av1/encoder/rd.h
+++ b/av1/encoder/rd.h
@@ -413,7 +413,8 @@
 void av1_init_me_luts(void);
 
 #if CONFIG_REF_MV
-void av1_set_mvcost(MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame);
+void av1_set_mvcost(MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame, int ref,
+                    int ref_mv_idx);
 #endif
 
 void av1_get_entropy_contexts(BLOCK_SIZE bsize, TX_SIZE tx_size,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9721abe..3d00687 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1664,7 +1664,7 @@
 
   for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
 #if CONFIG_REF_MV
-    if (tx_type != DCT_DCT && is_inter && mbmi->ref_mv_idx > 0) continue;
+    if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
 #endif
     rd = choose_tx_size_fix_type(cpi, bs, x, &r, &d, &s, &sse, ref_best_rd,
                                  tx_type, prune);
@@ -4159,7 +4159,7 @@
 #if CONFIG_REF_MV
       for (idx = 0; idx < 1 + is_compound; ++idx) {
         this_mv[idx] = seg_mvs[mbmi->ref_frame[idx]];
-        av1_set_mvcost(x, mbmi->ref_frame[idx]);
+        av1_set_mvcost(x, mbmi->ref_frame[idx], idx, mbmi->ref_mv_idx);
         thismvcost +=
             av1_mv_bit_cost(&this_mv[idx].as_mv, &best_ref_mv[idx]->as_mv,
                             x->nmvjointcost, x->mvcost, MV_COST_WEIGHT_SUB);
@@ -4264,8 +4264,11 @@
 
 #if CONFIG_REF_MV
   if (mode == NEWMV) {
-    mic->bmi[i].pred_mv[0].as_int = best_ref_mv[0]->as_int;
-    if (is_compound) mic->bmi[i].pred_mv[1].as_int = best_ref_mv[1]->as_int;
+    mic->bmi[i].pred_mv[0].as_int =
+        mbmi_ext->ref_mvs[mbmi->ref_frame[0]][0].as_int;
+    if (is_compound)
+      mic->bmi[i].pred_mv[1].as_int =
+          mbmi_ext->ref_mvs[mbmi->ref_frame[1]][0].as_int;
   } else {
     mic->bmi[i].pred_mv[0].as_int = this_mv[0].as_int;
     if (is_compound) mic->bmi[i].pred_mv[1].as_int = this_mv[1].as_int;
@@ -4703,7 +4706,7 @@
     best_mv->row >>= 3;
 
 #if CONFIG_REF_MV
-    av1_set_mvcost(x, refs[id]);
+    av1_set_mvcost(x, refs[id], id, mbmi->ref_mv_idx);
 #endif
 
     // Small-range full-pixel motion search.
@@ -4782,7 +4785,7 @@
         xd->plane[i].pre[ref] = backup_yv12[ref][i];
     }
 #if CONFIG_REF_MV
-    av1_set_mvcost(x, refs[ref]);
+    av1_set_mvcost(x, refs[ref], ref, mbmi->ref_mv_idx);
 #endif
 #if CONFIG_EXT_INTER
     if (bsize >= BLOCK_8X8)
@@ -5141,7 +5144,7 @@
           x->best_mv.as_int = x->second_best_mv.as_int = INVALID_MV;
 
 #if CONFIG_REF_MV
-          av1_set_mvcost(x, mbmi->ref_frame[0]);
+          av1_set_mvcost(x, mbmi->ref_frame[0], 0, mbmi->ref_mv_idx);
 #endif
           bestsme = av1_full_pixel_search(
               cpi, x, bsize, &mvp_full, step_param, sadpb,
@@ -5820,10 +5823,6 @@
   pred_mv[1] = x->mbmi_ext->ref_mvs[ref][1].as_mv;
   pred_mv[2] = x->pred_mv[ref];
 
-#if CONFIG_REF_MV
-  av1_set_mvcost(x, ref);
-#endif
-
   if (scaled_ref_frame) {
     int i;
     // Swap out the reference frame for a version that's been scaled to
@@ -5835,6 +5834,12 @@
     av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL);
   }
 
+  av1_set_mv_search_range(x, &ref_mv);
+
+#if CONFIG_REF_MV
+  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
+#endif
+
   // Work out the size of the first step in the mv step search.
   // 0 here is maximum length first step. 1 is AOMMAX >> 1 etc.
   if (cpi->sf.mv.auto_mv_step_size && cm->show_frame) {
@@ -6017,7 +6022,7 @@
       av1_get_scaled_ref_frame(cpi, ref);
 
 #if CONFIG_REF_MV
-  av1_set_mvcost(x, ref);
+  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
 #endif
 
   if (scaled_ref_frame) {
@@ -6145,7 +6150,7 @@
   pred_mv[2] = x->pred_mv[ref];
 
 #if CONFIG_REF_MV
-  av1_set_mvcost(x, ref);
+  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
 #endif
 
   if (scaled_ref_frame) {
@@ -6791,13 +6796,13 @@
                               single_newmv, &rate_mv, 0);
         } else {
 #if CONFIG_REF_MV
-          av1_set_mvcost(x, mbmi->ref_frame[0]);
+          av1_set_mvcost(x, mbmi->ref_frame[0], 0, mbmi->ref_mv_idx);
 #endif  // CONFIG_REF_MV
           rate_mv = av1_mv_bit_cost(&frame_mv[refs[0]].as_mv,
                                     &x->mbmi_ext->ref_mvs[refs[0]][0].as_mv,
                                     x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
 #if CONFIG_REF_MV
-          av1_set_mvcost(x, mbmi->ref_frame[1]);
+          av1_set_mvcost(x, mbmi->ref_frame[1], 1, mbmi->ref_mv_idx);
 #endif  // CONFIG_REF_MV
           rate_mv += av1_mv_bit_cost(
               &frame_mv[refs[1]].as_mv, &x->mbmi_ext->ref_mvs[refs[1]][0].as_mv,
@@ -6824,13 +6829,13 @@
                             single_newmv, &rate_mv, 0);
       } else {
 #if CONFIG_REF_MV
-        av1_set_mvcost(x, mbmi->ref_frame[0]);
+        av1_set_mvcost(x, mbmi->ref_frame[0], 0, mbmi->ref_mv_idx);
 #endif  // CONFIG_REF_MV
         rate_mv = av1_mv_bit_cost(&frame_mv[refs[0]].as_mv,
                                   &x->mbmi_ext->ref_mvs[refs[0]][0].as_mv,
                                   x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
 #if CONFIG_REF_MV
-        av1_set_mvcost(x, mbmi->ref_frame[1]);
+        av1_set_mvcost(x, mbmi->ref_frame[1], 1, mbmi->ref_mv_idx);
 #endif  // CONFIG_REF_MV
         rate_mv += av1_mv_bit_cost(&frame_mv[refs[1]].as_mv,
                                    &x->mbmi_ext->ref_mvs[refs[1]][0].as_mv,
@@ -8228,6 +8233,10 @@
   int64_t best_pred_diff[REFERENCE_MODES];
   int64_t best_pred_rd[REFERENCE_MODES];
   MB_MODE_INFO best_mbmode;
+#if CONFIG_REF_MV
+  int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+  int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+#endif
   int best_mode_skippable = 0;
   int midx, best_mode_index = -1;
   unsigned int ref_costs_single[TOTAL_REFS_PER_FRAME];
@@ -8560,6 +8569,9 @@
     this_mode = av1_mode_order[mode_index].mode;
     ref_frame = av1_mode_order[mode_index].ref_frame[0];
     second_ref_frame = av1_mode_order[mode_index].ref_frame[1];
+#if CONFIG_REF_MV
+    mbmi->ref_mv_idx = 0;
+#endif
 
 #if CONFIG_EXT_INTER
     if (ref_frame > INTRA_FRAME && second_ref_frame == INTRA_FRAME) {
@@ -9145,8 +9157,14 @@
         // Cost the skip mb case
         rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
       } else if (ref_frame != INTRA_FRAME && !xd->lossless[mbmi->segment_id]) {
+#if CONFIG_REF_MV
+        if (RDCOST(x->rdmult, x->rddiv, rate_y + rate_uv + rate_skip0,
+                   distortion2) <
+            RDCOST(x->rdmult, x->rddiv, rate_skip1, total_sse)) {
+#else
         if (RDCOST(x->rdmult, x->rddiv, rate_y + rate_uv, distortion2) <
             RDCOST(x->rdmult, x->rddiv, 0, total_sse)) {
+#endif
           // Add in the cost of the no skip flag.
           rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
         } else {
@@ -10038,6 +10056,10 @@
     ref_frame = av1_ref_order[ref_index].ref_frame[0];
     second_ref_frame = av1_ref_order[ref_index].ref_frame[1];
 
+#if CONFIG_REF_MV
+    mbmi->ref_mv_idx = 0;
+#endif
+
     // Look at the reference frame of the best mode so far and set the
     // skip mask to look at a subset of the remaining modes.
     if (ref_index > 2 && sf->mode_skip_start < MAX_MODES) {
@@ -10679,12 +10701,12 @@
     for (i = 0; i < 4; ++i)
       memcpy(&xd->mi[0]->bmi[i], &best_bmodes[i], sizeof(b_mode_info));
 
-    mbmi->mv[0].as_int = xd->mi[0]->bmi[3].as_mv[0].as_int;
-    mbmi->mv[1].as_int = xd->mi[0]->bmi[3].as_mv[1].as_int;
 #if CONFIG_REF_MV
     mbmi->pred_mv[0].as_int = xd->mi[0]->bmi[3].pred_mv[0].as_int;
     mbmi->pred_mv[1].as_int = xd->mi[0]->bmi[3].pred_mv[1].as_int;
 #endif
+    mbmi->mv[0].as_int = xd->mi[0]->bmi[3].as_mv[0].as_int;
+    mbmi->mv[1].as_int = xd->mi[0]->bmi[3].as_mv[1].as_int;
   }
 
   for (i = 0; i < REFERENCE_MODES; ++i) {