Allow NEAR_NEARMV and NEW_NEWMV modes to use ref_mv_idx

When ext-inter and ref-mv are both enabled, this patch
allows the NEAR_NEARMV and NEW_NEWMV modes to pick from
the extended reference mv list, just like the NEARMV and
NEWMV modes can.

Change-Id: Ibcc9e19dba7779422c1c9589d5498159e83bf61e
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index da0af3b..a376948 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -252,7 +252,11 @@
   uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
   mbmi->ref_mv_idx = 0;
 
+#if CONFIG_EXT_INTER
+  if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
+#else
   if (mbmi->mode == NEWMV) {
+#endif
     int idx;
     for (idx = 0; idx < 2; ++idx) {
       if (xd->ref_mv_count[ref_frame_type] > idx + 1) {
@@ -269,7 +273,11 @@
     }
   }
 
+#if CONFIG_EXT_INTER
+  if (mbmi->mode == NEARMV || mbmi->mode == NEAR_NEARMV) {
+#else
   if (mbmi->mode == NEARMV) {
+#endif
     int idx;
     // Offset the NEARESTMV mode.
     // TODO(jingning): Unify the two syntax decoding loops after the NEARESTMV
@@ -1732,7 +1740,12 @@
 #endif  // CONFIG_REF_MV && CONFIG_EXT_INTER
                                      r, mode_ctx);
 #if CONFIG_REF_MV
+#if CONFIG_EXT_INTER
+      if (mbmi->mode == NEARMV || mbmi->mode == NEAR_NEARMV ||
+          mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV)
+#else
       if (mbmi->mode == NEARMV || mbmi->mode == NEWMV)
+#endif
         read_drl_idx(cm, xd, mbmi, r);
 #endif
     }
@@ -1791,15 +1804,16 @@
 
 #if CONFIG_EXT_INTER
     if (xd->ref_mv_count[ref_frame_type] > 1) {
+      int ref_mv_idx = 1 + mbmi->ref_mv_idx;
       if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEAR_NEARESTMV ||
           mbmi->mode == NEAR_NEARMV) {
-        nearmv[0] = xd->ref_mv_stack[ref_frame_type][1].this_mv;
+        nearmv[0] = xd->ref_mv_stack[ref_frame_type][ref_mv_idx].this_mv;
         lower_mv_precision(&nearmv[0].as_mv, allow_hp);
       }
 
       if (mbmi->mode == NEW_NEARMV || mbmi->mode == NEAREST_NEARMV ||
           mbmi->mode == NEAR_NEARMV) {
-        nearmv[1] = xd->ref_mv_stack[ref_frame_type][1].comp_mv;
+        nearmv[1] = xd->ref_mv_stack[ref_frame_type][ref_mv_idx].comp_mv;
         lower_mv_precision(&nearmv[1].as_mv, allow_hp);
       }
     }
@@ -1935,19 +1949,25 @@
     ref_mv[0] = nearestmv[0];
     ref_mv[1] = nearestmv[1];
 
-    for (ref = 0; ref < 1 + is_compound && mbmi->mode == NEWMV; ++ref) {
-#if CONFIG_REF_MV
-      uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
-      if (xd->ref_mv_count[ref_frame_type] > 1) {
-        ref_mv[ref] =
-            (ref == 0)
-                ? xd->ref_mv_stack[ref_frame_type][mbmi->ref_mv_idx].this_mv
-                : xd->ref_mv_stack[ref_frame_type][mbmi->ref_mv_idx].comp_mv;
-        clamp_mv_ref(&ref_mv[ref].as_mv, xd->n8_w << MI_SIZE_LOG2,
-                     xd->n8_h << MI_SIZE_LOG2, xd);
-      }
+#if CONFIG_EXT_INTER
+    if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
+#else
+    if (mbmi->mode == NEWMV) {
 #endif
-      nearestmv[ref] = ref_mv[ref];
+      for (ref = 0; ref < 1 + is_compound; ++ref) {
+#if CONFIG_REF_MV
+        uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
+        if (xd->ref_mv_count[ref_frame_type] > 1) {
+          ref_mv[ref] =
+              (ref == 0)
+                  ? xd->ref_mv_stack[ref_frame_type][mbmi->ref_mv_idx].this_mv
+                  : xd->ref_mv_stack[ref_frame_type][mbmi->ref_mv_idx].comp_mv;
+          clamp_mv_ref(&ref_mv[ref].as_mv, xd->n8_w << MI_SIZE_LOG2,
+                       xd->n8_h << MI_SIZE_LOG2, xd);
+        }
+#endif
+        nearestmv[ref] = ref_mv[ref];
+      }
     }
 
     int mv_corrupted_flag =
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 70064a1..7999525 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -304,7 +304,11 @@
 
   assert(mbmi->ref_mv_idx < 3);
 
+#if CONFIG_EXT_INTER
+  if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
+#else
   if (mbmi->mode == NEWMV) {
+#endif
     int idx;
     for (idx = 0; idx < 2; ++idx) {
       if (mbmi_ext->ref_mv_count[ref_frame_type] > idx + 1) {
@@ -319,7 +323,11 @@
     return;
   }
 
+#if CONFIG_EXT_INTER
+  if (mbmi->mode == NEARMV || mbmi->mode == NEAR_NEARMV) {
+#else
   if (mbmi->mode == NEARMV) {
+#endif
     int idx;
     // TODO(jingning): Temporary solution to compensate the NEARESTMV offset.
     for (idx = 1; idx < 3; ++idx) {
@@ -1732,13 +1740,16 @@
                            mode_ctx);
 
 #if CONFIG_REF_MV
-        if (mode == NEARMV || mode == NEWMV)
-          write_drl_idx(cm, mbmi, mbmi_ext, w);
 #if CONFIG_EXT_INTER
+        if (mode == NEARMV || mode == NEAR_NEARMV || mode == NEWMV ||
+            mode == NEW_NEWMV)
+#else
+        if (mode == NEARMV || mode == NEWMV)
+#endif
+          write_drl_idx(cm, mbmi, mbmi_ext, w);
         else
           assert(mbmi->ref_mv_idx == 0);
 #endif
-#endif
       }
     }
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index d8f6fcd..49ab325 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1093,7 +1093,12 @@
 #if CONFIG_REF_MV
   rf_type = av1_ref_frame_type(mbmi->ref_frame);
   if (x->mbmi_ext->ref_mv_count[rf_type] > 1 &&
-      (mbmi->sb_type >= BLOCK_8X8 || unify_bsize) && mbmi->mode == NEWMV) {
+      (mbmi->sb_type >= BLOCK_8X8 || unify_bsize) &&
+#if CONFIG_EXT_INTER
+      (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV)) {
+#else
+      mbmi->mode == NEWMV) {
+#endif
     for (i = 0; i < 1 + has_second_ref(mbmi); ++i) {
       int_mv this_mv =
           (i == 0)
@@ -1299,7 +1304,11 @@
 #if !CONFIG_CB4X4
       mbmi->sb_type >= BLOCK_8X8 &&
 #endif  // !CONFIG_CB4X4
+#if CONFIG_EXT_INTER
+      (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV)) {
+#else
       mbmi->mode == NEWMV) {
+#endif
     for (i = 0; i < 1 + has_second_ref(mbmi); ++i) {
       int_mv this_mv =
           (i == 0)
@@ -2211,7 +2220,11 @@
         }
 #endif  // CONFIG_EXT_INTER
 
-        if (mode == NEWMV) {
+#if CONFIG_EXT_INTER
+        if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
+#else
+        if (mbmi->mode == NEWMV) {
+#endif
           uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
           int idx;
 
@@ -2226,7 +2239,11 @@
           }
         }
 
-        if (mode == NEARMV) {
+#if CONFIG_EXT_INTER
+        if (mbmi->mode == NEARMV || mbmi->mode == NEAR_NEARMV) {
+#else
+        if (mbmi->mode == NEARMV) {
+#endif
           uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
           int idx;
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5461612..c24820e 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5260,8 +5260,10 @@
         lower_mv_precision(&this_mv[0].as_mv, 0);
       if (!cpi->common.allow_high_precision_mv)
         lower_mv_precision(&this_mv[1].as_mv, 0);
+      av1_set_mvcost(x, mbmi->ref_frame[0], 0, mbmi->ref_mv_idx);
       thismvcost += av1_mv_bit_cost(&this_mv[0].as_mv, &best_ref_mv[0]->as_mv,
                                     mvjcost, mvcost, MV_COST_WEIGHT_SUB);
+      av1_set_mvcost(x, mbmi->ref_frame[1], 1, mbmi->ref_mv_idx);
       thismvcost += av1_mv_bit_cost(&this_mv[1].as_mv, &best_ref_mv[1]->as_mv,
                                     mvjcost, mvcost, MV_COST_WEIGHT_SUB);
       break;
@@ -5270,6 +5272,7 @@
       this_mv[0].as_int = seg_mvs[mbmi->ref_frame[0]].as_int;
       if (!cpi->common.allow_high_precision_mv)
         lower_mv_precision(&this_mv[0].as_mv, 0);
+      av1_set_mvcost(x, mbmi->ref_frame[0], 0, mbmi->ref_mv_idx);
       thismvcost += av1_mv_bit_cost(&this_mv[0].as_mv, &best_ref_mv[0]->as_mv,
                                     mvjcost, mvcost, MV_COST_WEIGHT_SUB);
       this_mv[1].as_int = frame_mv[mode][mbmi->ref_frame[1]].as_int;
@@ -5280,6 +5283,7 @@
       this_mv[1].as_int = seg_mvs[mbmi->ref_frame[1]].as_int;
       if (!cpi->common.allow_high_precision_mv)
         lower_mv_precision(&this_mv[1].as_mv, 0);
+      av1_set_mvcost(x, mbmi->ref_frame[1], 1, mbmi->ref_mv_idx);
       thismvcost += av1_mv_bit_cost(&this_mv[1].as_mv, &best_ref_mv[1]->as_mv,
                                     mvjcost, mvcost, MV_COST_WEIGHT_SUB);
       break;
@@ -8902,9 +8906,10 @@
   }
 
   if (mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
+    int ref_mv_idx = mbmi->ref_mv_idx + 1;
     if (this_mode == NEAR_NEWMV || this_mode == NEAR_NEARESTMV ||
         this_mode == NEAR_NEARMV) {
-      cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][1].this_mv;
+      cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].this_mv;
 
       lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv);
       clamp_mv2(&cur_mv[0].as_mv, xd);
@@ -8914,7 +8919,7 @@
 
     if (this_mode == NEW_NEARMV || this_mode == NEAREST_NEARMV ||
         this_mode == NEAR_NEARMV) {
-      cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][1].comp_mv;
+      cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].comp_mv;
 
       lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv);
       clamp_mv2(&cur_mv[1].as_mv, xd);
@@ -10505,17 +10510,40 @@
       mbmi->ref_mv_idx = 0;
       ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
 
-      if (this_mode == NEWMV && mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
-        int ref;
-        for (ref = 0; ref < 1 + comp_pred; ++ref) {
-          int_mv this_mv =
-              (ref == 0) ? mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv
-                         : mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
-          clamp_mv_ref(&this_mv.as_mv, xd->n8_w << MI_SIZE_LOG2,
-                       xd->n8_h << MI_SIZE_LOG2, xd);
-          mbmi_ext->ref_mvs[mbmi->ref_frame[ref]][0] = this_mv;
+#if CONFIG_EXT_INTER
+      if (comp_pred) {
+        if (mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
+          if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV ||
+              this_mode == NEW_NEWMV) {
+            int_mv this_mv = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;
+            clamp_mv_ref(&this_mv.as_mv, xd->n8_w << MI_SIZE_LOG2,
+                         xd->n8_h << MI_SIZE_LOG2, xd);
+            mbmi_ext->ref_mvs[mbmi->ref_frame[0]][0] = this_mv;
+          }
+          if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV ||
+              this_mode == NEW_NEWMV) {
+            int_mv this_mv = mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
+            clamp_mv_ref(&this_mv.as_mv, xd->n8_w << MI_SIZE_LOG2,
+                         xd->n8_h << MI_SIZE_LOG2, xd);
+            mbmi_ext->ref_mvs[mbmi->ref_frame[1]][0] = this_mv;
+          }
         }
+      } else {
+#endif  // CONFIG_EXT_INTER
+        if (this_mode == NEWMV && mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
+          int ref;
+          for (ref = 0; ref < 1 + comp_pred; ++ref) {
+            int_mv this_mv =
+                (ref == 0) ? mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv
+                           : mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
+            clamp_mv_ref(&this_mv.as_mv, xd->n8_w << MI_SIZE_LOG2,
+                         xd->n8_h << MI_SIZE_LOG2, xd);
+            mbmi_ext->ref_mvs[mbmi->ref_frame[ref]][0] = this_mv;
+          }
+        }
+#if CONFIG_EXT_INTER
       }
+#endif  // CONFIG_EXT_INTER
 #endif  // CONFIG_REF_MV
       {
         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
@@ -10550,11 +10578,18 @@
       }
 
 #if CONFIG_REF_MV
-      // TODO(jingning): This needs some refactoring to improve code quality
-      // and reduce redundant steps.
+// TODO(jingning): This needs some refactoring to improve code quality
+// and reduce redundant steps.
+#if CONFIG_EXT_INTER
+      if (((mbmi->mode == NEARMV || mbmi->mode == NEAR_NEARMV) &&
+           mbmi_ext->ref_mv_count[ref_frame_type] > 2) ||
+          ((mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) &&
+           mbmi_ext->ref_mv_count[ref_frame_type] > 1)) {
+#else
       if ((mbmi->mode == NEARMV &&
            mbmi_ext->ref_mv_count[ref_frame_type] > 2) ||
           (mbmi->mode == NEWMV && mbmi_ext->ref_mv_count[ref_frame_type] > 1)) {
+#endif
         int_mv backup_mv = frame_mv[NEARMV][ref_frame];
         MB_MODE_INFO backup_mbmi = *mbmi;
         int backup_skip = x->skip;
@@ -11287,26 +11322,42 @@
       }
 
       if (nearestmv[0].as_int == best_mbmode.mv[0].as_int &&
-          nearestmv[1].as_int == best_mbmode.mv[1].as_int)
+          nearestmv[1].as_int == best_mbmode.mv[1].as_int) {
 #if CONFIG_EXT_INTER
         best_mbmode.mode = NEAREST_NEARESTMV;
-      else if (nearestmv[0].as_int == best_mbmode.mv[0].as_int &&
-               nearmv[1].as_int == best_mbmode.mv[1].as_int)
+      } else if (nearestmv[0].as_int == best_mbmode.mv[0].as_int &&
+                 nearmv[1].as_int == best_mbmode.mv[1].as_int) {
         best_mbmode.mode = NEAREST_NEARMV;
-      else if (nearmv[0].as_int == best_mbmode.mv[0].as_int &&
-               nearestmv[1].as_int == best_mbmode.mv[1].as_int)
+      } else if (nearmv[0].as_int == best_mbmode.mv[0].as_int &&
+                 nearestmv[1].as_int == best_mbmode.mv[1].as_int) {
         best_mbmode.mode = NEAR_NEARESTMV;
-      else if (nearmv[0].as_int == best_mbmode.mv[0].as_int &&
-               nearmv[1].as_int == best_mbmode.mv[1].as_int)
-        best_mbmode.mode = NEAR_NEARMV;
-      else if (best_mbmode.mv[0].as_int == zeromv[0].as_int &&
-               best_mbmode.mv[1].as_int == zeromv[1].as_int)
-        best_mbmode.mode = ZERO_ZEROMV;
+      } else {
+        int ref_set = (mbmi_ext->ref_mv_count[rf_type] >= 2)
+                          ? AOMMIN(2, mbmi_ext->ref_mv_count[rf_type] - 2)
+                          : INT_MAX;
+
+        for (i = 0; i <= ref_set && ref_set != INT_MAX; ++i) {
+          nearmv[0] = mbmi_ext->ref_mv_stack[rf_type][i + 1].this_mv;
+          nearmv[1] = mbmi_ext->ref_mv_stack[rf_type][i + 1].comp_mv;
+
+          if (nearmv[0].as_int == best_mbmode.mv[0].as_int &&
+              nearmv[1].as_int == best_mbmode.mv[1].as_int) {
+            best_mbmode.mode = NEAR_NEARMV;
+            best_mbmode.ref_mv_idx = i;
+          }
+        }
+
+        if (best_mbmode.mode != NEAR_NEARMV &&
+            best_mbmode.mv[0].as_int == zeromv[0].as_int &&
+            best_mbmode.mv[1].as_int == zeromv[1].as_int)
+          best_mbmode.mode = ZERO_ZEROMV;
+      }
 #else
         best_mbmode.mode = NEARESTMV;
-      else if (best_mbmode.mv[0].as_int == zeromv[0].as_int &&
-               best_mbmode.mv[1].as_int == zeromv[1].as_int)
+      } else if (best_mbmode.mv[0].as_int == zeromv[0].as_int &&
+                 best_mbmode.mv[1].as_int == zeromv[1].as_int) {
         best_mbmode.mode = ZEROMV;
+      }
 #endif  // CONFIG_EXT_INTER
     }
 #else
@@ -11381,7 +11432,12 @@
   // Make sure that the ref_mv_idx is only nonzero when we're
   // using a mode which can support ref_mv_idx
   if (best_mbmode.ref_mv_idx != 0 &&
+#if CONFIG_EXT_INTER
+      !(best_mbmode.mode == NEARMV || best_mbmode.mode == NEAR_NEARMV ||
+        best_mbmode.mode == NEWMV || best_mbmode.mode == NEW_NEWMV)) {
+#else
       !(best_mbmode.mode == NEARMV || best_mbmode.mode == NEWMV)) {
+#endif
     best_mbmode.ref_mv_idx = 0;
   }