Expand skip mode support for all comp frames

All compound predicted frames now have the choice to either turn on
or off the use of the skip mode. The current encoder design only
turns on the use of skip mode when the temporal distances of the two
reference frames to the current frame are only different by 1.

This patch also includes a fix on the calculating of the temporal
distance between the two reference frames to the current frame for
skip mode.

Performance wise, this patch does not have noticeable impact as the
encoder stays with the same choice with the frame-level skip mode
flag.

Change-Id: I34f370940b3b25d2ab429b8721344133ae6288ad
diff --git a/av1/common/mvref_common.c b/av1/common/mvref_common.c
index ab2d141..505a6a1 100644
--- a/av1/common/mvref_common.c
+++ b/av1/common/mvref_common.c
@@ -2022,7 +2022,7 @@
   cm->is_skip_mode_allowed = 0;
   cm->ref_frame_idx_0 = cm->ref_frame_idx_1 = INVALID_IDX;
 
-  if (cm->frame_type == KEY_FRAME || cm->intra_only) return;
+  if (frame_is_intra_only(cm) || cm->reference_mode == SINGLE_REFERENCE) return;
 
   RefCntBuffer *const frame_bufs = cm->buffer_pool->frame_bufs;
   const int cur_frame_offset = cm->frame_offset;
@@ -2030,7 +2030,7 @@
   int ref_idx[2] = { INVALID_IDX, INVALID_IDX };
   int ref_buf_idx[2] = { INVALID_IDX, INVALID_IDX };
 
-  // Identify the nearest forward and backward references
+  // Identify the nearest forward and backward references.
   for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
     const int buf_idx = cm->frame_refs[i].idx;
     if (buf_idx == INVALID_IDX) continue;
@@ -2053,30 +2053,15 @@
     }
   }
 
-  // Flag is set when and only when both forward and backward references
-  // are available and the difference between their temporal distance to
-  // the current frame is no greater than 1, i.e. they may have the same
-  // temporal distance to the current frame, or the distances are off by 1.
   if (ref_idx[0] != INVALID_IDX && ref_idx[1] != INVALID_IDX) {
-    const int cur_to_fwd = cm->frame_offset - ref_frame_offset[0];
-    const int cur_to_bwd = ref_frame_offset[1] - cm->frame_offset;
-#if 0
-    if ((ref_frame_offset[1] - ref_frame_offset[0]) <= 3)
-#endif  // 0
-    if (abs(cur_to_fwd - cur_to_bwd) <= 1) {
-      cm->is_skip_mode_allowed = 1;
-      cm->ref_frame_idx_0 = ref_idx[0];
-      cm->ref_frame_idx_1 = ref_idx[1];
-    }
-  }
-
-  // NOTE: Low delay scenario
-  if (av1_refs_are_one_sided(cm)) {
-    assert(ref_idx[1] == INVALID_IDX);
-    assert(ref_idx[0] != INVALID_IDX);
-    // Identify the second closest forward reference frame.
+    // == Bi-directional prediction ==
+    cm->is_skip_mode_allowed = 1;
+    cm->ref_frame_idx_0 = ref_idx[0];
+    cm->ref_frame_idx_1 = ref_idx[1];
+  } else if (ref_idx[0] != INVALID_IDX && ref_idx[1] == INVALID_IDX) {
+    // == Forward prediction only ==
+    // Identify the second nearest forward reference.
     ref_frame_offset[1] = -1;
-    cm->is_skip_mode_allowed = 0;
     for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
       const int buf_idx = cm->frame_refs[i].idx;
       if (buf_idx == INVALID_IDX) continue;
@@ -2091,14 +2076,9 @@
       }
     }
     if (ref_frame_offset[1] >= 0) {
-      const int cur_to_fwd0 = cur_frame_offset - ref_frame_offset[0];
-      const int cur_to_fwd1 = cur_frame_offset - ref_frame_offset[1];
-      assert(cur_to_fwd1 > cur_to_fwd0);
-      if ((cur_to_fwd1 - cur_to_fwd0) <= 1) {
-        cm->is_skip_mode_allowed = 1;
-        cm->ref_frame_idx_0 = AOMMIN(ref_idx[0], ref_idx[1]);
-        cm->ref_frame_idx_1 = AOMMAX(ref_idx[0], ref_idx[1]);
-      }
+      cm->is_skip_mode_allowed = 1;
+      cm->ref_frame_idx_0 = AOMMIN(ref_idx[0], ref_idx[1]);
+      cm->ref_frame_idx_1 = AOMMAX(ref_idx[0], ref_idx[1]);
     }
   }
 
diff --git a/av1/common/mvref_common.h b/av1/common/mvref_common.h
index c1a91d7..c161c86 100644
--- a/av1/common/mvref_common.h
+++ b/av1/common/mvref_common.h
@@ -420,6 +420,21 @@
   return one_sided_refs;
 }
 
+#if CONFIG_EXT_SKIP
+static INLINE void get_skip_mode_ref_offsets(const AV1_COMMON *cm,
+                                             int ref_offset[2]) {
+  ref_offset[0] = ref_offset[1] = 0;
+  if (!cm->is_skip_mode_allowed) return;
+
+  const int buf_idx_0 = cm->frame_refs[cm->ref_frame_idx_0].idx;
+  const int buf_idx_1 = cm->frame_refs[cm->ref_frame_idx_1].idx;
+  assert(buf_idx_0 != INVALID_IDX && buf_idx_1 != INVALID_IDX);
+
+  ref_offset[0] = cm->buffer_pool->frame_bufs[buf_idx_0].cur_frame_offset;
+  ref_offset[1] = cm->buffer_pool->frame_bufs[buf_idx_1].cur_frame_offset;
+}
+#endif  // CONFIG_EXT_SKIP
+
 void av1_setup_frame_buf_refs(AV1_COMMON *cm);
 #if CONFIG_FRAME_SIGN_BIAS
 void av1_setup_frame_sign_bias(AV1_COMMON *cm);
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 84dd492..7394e2c 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2850,23 +2850,6 @@
   }
   av1_setup_frame_buf_refs(cm);
 
-#if CONFIG_EXT_SKIP
-  av1_setup_skip_mode_allowed(cm);
-  cm->skip_mode_flag = cm->is_skip_mode_allowed ? aom_rb_read_bit(rb) : 0;
-  xd->all_one_sided_refs = cm->skip_mode_flag ? av1_refs_are_one_sided(cm) : 0;
-#if 0
-  printf(
-      "DECODER: Frame=%d, frame_offset=%d, show_frame=%d, "
-      "show_existing_frame=%d, is_skip_mode_allowed=%d, "
-      "ref_frame_idx=(%d,%d), frame_reference_mode=%d, "
-      "tpl_frame_ref0_idx=%d, skip_mode_flag=%d\n\n",
-      cm->current_video_frame, cm->frame_offset, cm->show_frame,
-      cm->show_existing_frame, cm->is_skip_mode_allowed, cm->ref_frame_idx_0,
-      cm->ref_frame_idx_1, cm->reference_mode, cm->tpl_frame_ref0_idx,
-      cm->skip_mode_flag);
-#endif  // 0
-#endif  // CONFIG_EXT_SKIP
-
 #if CONFIG_FRAME_SIGN_BIAS
 #if CONFIG_OBU
   if (cm->frame_type != S_FRAME)
@@ -3060,6 +3043,24 @@
   cm->tx_mode = read_tx_mode(cm, rb);
   cm->reference_mode = read_frame_reference_mode(cm, rb);
   if (cm->reference_mode != SINGLE_REFERENCE) setup_compound_reference_mode(cm);
+
+#if CONFIG_EXT_SKIP
+  av1_setup_skip_mode_allowed(cm);
+  cm->skip_mode_flag = cm->is_skip_mode_allowed ? aom_rb_read_bit(rb) : 0;
+  xd->all_one_sided_refs =
+      frame_is_intra_only(cm) ? 0 : av1_refs_are_one_sided(cm);
+#if 0
+  printf(
+      "DECODER: Frame=%d, frame_offset=%d, show_frame=%d, "
+      "show_existing_frame=%d, is_skip_mode_allowed=%d, "
+      "ref_frame_idx=(%d,%d), reference_mode=%d, "
+      "skip_mode_flag=%d\n\n",
+      cm->current_video_frame, cm->frame_offset, cm->show_frame,
+      cm->show_existing_frame, cm->is_skip_mode_allowed, cm->ref_frame_idx_0,
+      cm->ref_frame_idx_1, cm->reference_mode, cm->skip_mode_flag);
+#endif  // 0
+#endif  // CONFIG_EXT_SKIP
+
   read_compound_tools(cm, rb);
 
   cm->reduced_tx_set_used = aom_rb_read_bit(rb);
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index a734894..db489dc 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2198,9 +2198,11 @@
 #if CONFIG_JNT_COMP && SKIP_MODE_WITH_JNT_COMP
   if (mbmi->skip_mode) {
     const int cur_offset = (int)cm->frame_offset;
-    const int cur_to_fwd = cur_offset - cm->ref_frame_idx_0;
-    const int cur_to_bwd = abs(cm->ref_frame_idx_1 - cur_offset);
-    if (cur_to_fwd != cur_to_bwd && xd->all_one_sided_refs) {
+    int ref_offset[2];
+    get_skip_mode_ref_offsets(cm, ref_offset);
+    const int cur_to_ref0 = cur_offset - ref_offset[0];
+    const int cur_to_ref1 = abs(cur_offset - ref_offset[1]);
+    if (cur_to_ref0 != cur_to_ref1 && xd->all_one_sided_refs) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
       mbmi->compound_idx = aom_read_symbol(
           r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
@@ -2349,12 +2351,6 @@
 
 #if CONFIG_EXT_SKIP
   mbmi->skip_mode = read_skip_mode(cm, xd, mbmi->segment_id, r);
-#if 0
-  if (mbmi->skip_mode)
-    printf("Frame=%d, frame_offset=%d, (mi_row,mi_col)=(%d,%d), skip_mode=%d\n",
-           cm->current_video_frame, cm->frame_offset, mi_row, mi_col,
-           mbmi->skip_mode);
-#endif  // 0
 
   if (mbmi->skip_mode)
     mbmi->skip = 1;
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index f98ce30..7386ff3 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1389,9 +1389,11 @@
   if (mbmi->skip_mode) {
 #if CONFIG_JNT_COMP && SKIP_MODE_WITH_JNT_COMP
     const int cur_offset = (int)cm->frame_offset;
-    const int cur_to_fwd = cur_offset - cm->ref_frame_idx_0;
-    const int cur_to_bwd = abs(cm->ref_frame_idx_1 - cur_offset);
-    if (cur_to_fwd != cur_to_bwd && xd->all_one_sided_refs) {
+    int ref_offset[2];
+    get_skip_mode_ref_offsets(cm, ref_offset);
+    const int cur_to_ref0 = cur_offset - ref_offset[0];
+    const int cur_to_ref1 = abs(cur_offset - ref_offset[1]);
+    if (cur_to_ref0 != cur_to_ref1 && xd->all_one_sided_refs) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
       aom_write_symbol(w, mbmi->compound_idx,
                        ec_ctx->compound_index_cdf[comp_index_ctx], 2);
@@ -3805,10 +3807,6 @@
     arf_offset = AOMMIN((MAX_GF_INTERVAL - 1), arf_offset + brf_offset);
     aom_wb_write_literal(wb, arf_offset, FRAME_OFFSET_BITS);
   }
-
-#if CONFIG_EXT_SKIP
-  if (cm->is_skip_mode_allowed) aom_wb_write_bit(wb, cm->skip_mode_flag);
-#endif  // CONFIG_EXT_SKIP
 #endif  // CONFIG_FRAME_MARKER
 
 #if CONFIG_REFERENCE_BUFFER
@@ -3889,6 +3887,11 @@
     if (!use_hybrid_pred) aom_wb_write_bit(wb, use_compound_pred);
 #endif  // !CONFIG_REF_ADAPT
   }
+
+#if CONFIG_EXT_SKIP
+  if (cm->is_skip_mode_allowed) aom_wb_write_bit(wb, cm->skip_mode_flag);
+#endif  // CONFIG_EXT_SKIP
+
   write_compound_tools(cm, wb);
 
   aom_wb_write_bit(wb, cm->reduced_tx_set_used);
@@ -4167,10 +4170,6 @@
     arf_offset = AOMMIN((MAX_GF_INTERVAL - 1), arf_offset + brf_offset);
     aom_wb_write_literal(wb, arf_offset, FRAME_OFFSET_BITS);
   }
-
-#if CONFIG_EXT_SKIP
-  if (cm->is_skip_mode_allowed) aom_wb_write_bit(wb, cm->skip_mode_flag);
-#endif  // CONFIG_EXT_SKIP
 #endif  // CONFIG_FRAME_MARKER
 
 #if CONFIG_REFERENCE_BUFFER
@@ -4245,6 +4244,15 @@
     if (!use_hybrid_pred) aom_wb_write_bit(wb, use_compound_pred);
 #endif  // !CONFIG_REF_ADAPT
   }
+
+#if CONFIG_EXT_SKIP
+#if 0
+  printf("\n[ENCODER] Frame=%d, is_skip_mode_allowed=%d, skip_mode_flag=%d\n\n",
+         (int)cm->frame_offset, cm->is_skip_mode_allowed, cm->skip_mode_flag);
+#endif  // 0
+  if (cm->is_skip_mode_allowed) aom_wb_write_bit(wb, cm->skip_mode_flag);
+#endif  // CONFIG_EXT_SKIP
+
   write_compound_tools(cm, wb);
 
   aom_wb_write_bit(wb, cm->reduced_tx_set_used);
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index ed8b7d8..f52b8a8 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -871,9 +871,11 @@
       }
 #if CONFIG_JNT_COMP && SKIP_MODE_WITH_JNT_COMP
       const int cur_offset = (int)cm->frame_offset;
-      const int cur_to_fwd = cur_offset - cm->ref_frame_idx_0;
-      const int cur_to_bwd = abs(cm->ref_frame_idx_1 - cur_offset);
-      if (cur_to_fwd != cur_to_bwd && xd->all_one_sided_refs) {
+      int ref_offset[2];
+      get_skip_mode_ref_offsets(cm, ref_offset);
+      const int cur_to_ref0 = cur_offset - ref_offset[0];
+      const int cur_to_ref1 = abs(cur_offset - ref_offset[1]);
+      if (cur_to_ref0 != cur_to_ref1 && xd->all_one_sided_refs) {
         const int comp_index_ctx = get_comp_index_context(cm, xd);
         ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
         if (allow_update_cdf)
@@ -3649,6 +3651,43 @@
 }
 #endif  // CONFIG_FRAME_MARKER
 
+#if CONFIG_EXT_SKIP
+static int check_skip_mode_enabled(AV1_COMP *const cpi) {
+  AV1_COMMON *const cm = &cpi->common;
+
+  av1_setup_skip_mode_allowed(cm);
+  if (!cm->is_skip_mode_allowed) return 0;
+
+  // Turn off skip mode if the temporal distances of the reference pair to the
+  // current frame are different by more than 1 frame.
+  const int cur_offset = (int)cm->frame_offset;
+  int ref_offset[2];
+  get_skip_mode_ref_offsets(cm, ref_offset);
+  const int cur_to_ref0 = cur_offset - ref_offset[0];
+  const int cur_to_ref1 = abs(cur_offset - ref_offset[1]);
+  if (abs(cur_to_ref0 - cur_to_ref1) > 1) return 0;
+
+  // High Latency: Turn off skip mode if all refs are fwd.
+  if (cpi->all_one_sided_refs && cpi->oxcf.lag_in_frames > 0) return 0;
+
+  static const int flag_list[TOTAL_REFS_PER_FRAME] = { 0,
+                                                       AOM_LAST_FLAG,
+                                                       AOM_LAST2_FLAG,
+                                                       AOM_LAST3_FLAG,
+                                                       AOM_GOLD_FLAG,
+                                                       AOM_BWD_FLAG,
+                                                       AOM_ALT2_FLAG,
+                                                       AOM_ALT_FLAG };
+  const int ref_frame[2] = { cm->ref_frame_idx_0 + LAST_FRAME,
+                             cm->ref_frame_idx_1 + LAST_FRAME };
+  if (!(cpi->ref_frame_flags & flag_list[ref_frame[0]]) ||
+      !(cpi->ref_frame_flags & flag_list[ref_frame[1]]))
+    return 0;
+
+  return 1;
+}
+#endif  // CONFIG_EXT_SKIP
+
 static void encode_frame_internal(AV1_COMP *cpi) {
   ThreadData *const td = &cpi->td;
   MACROBLOCK *const x = &td->mb;
@@ -3991,43 +4030,18 @@
 #endif  // CONFIG_FRAME_MARKER
 
 #if CONFIG_EXT_SKIP
-  av1_setup_skip_mode_allowed(cm);
-  cm->skip_mode_flag = cm->is_skip_mode_allowed;
-  if (cm->skip_mode_flag && cpi->all_one_sided_refs &&
-      cpi->oxcf.lag_in_frames > 0) {
-    // High latency: Turn off skip mode if all refs are fwd.
-    cm->skip_mode_flag = 0;
-  }
-  if (cm->skip_mode_flag) {
-    if (cm->reference_mode == SINGLE_REFERENCE) {
-      cm->skip_mode_flag = 0;
-    } else {
-      static const int flag_list[TOTAL_REFS_PER_FRAME] = { 0,
-                                                           AOM_LAST_FLAG,
-                                                           AOM_LAST2_FLAG,
-                                                           AOM_LAST3_FLAG,
-                                                           AOM_GOLD_FLAG,
-                                                           AOM_BWD_FLAG,
-                                                           AOM_ALT2_FLAG,
-                                                           AOM_ALT_FLAG };
-      const int ref_frame[2] = { cm->ref_frame_idx_0 + LAST_FRAME,
-                                 cm->ref_frame_idx_1 + LAST_FRAME };
-      if (!(cpi->ref_frame_flags & flag_list[ref_frame[0]]) ||
-          !(cpi->ref_frame_flags & flag_list[ref_frame[1]]))
-        cm->skip_mode_flag = 0;
-    }
-  }
-  xd->all_one_sided_refs = cm->skip_mode_flag ? cpi->all_one_sided_refs : 0;
+  cm->skip_mode_flag = check_skip_mode_enabled(cpi);
+  xd->all_one_sided_refs = cpi->all_one_sided_refs;
 #if 0
   printf(
       "\nENCODER: Frame=%d, frame_offset=%d, show_frame=%d, "
       "show_existing_frame=%d, is_skip_mode_allowed=%d, "
-      "ref_frame_idx=(%d,%d), frame_reference_mode=%d, "
-      "tpl_frame_ref0_idx=%d, skip_mode_flag=%d, lag_in_frames=%d\n",
+      "ref_frame_idx=(%d,%d), reference_mode=%d, "
+      "skip_mode_flag=%d, lag_in_frames=%d\n",
       cm->current_video_frame, cm->frame_offset, cm->show_frame,
       cm->show_existing_frame, cm->is_skip_mode_allowed, cm->ref_frame_idx_0,
-      cm->ref_frame_idx_1, cm->reference_mode, cm->tpl_frame_ref0_idx,
-      cm->skip_mode_flag, cpi->oxcf.lag_in_frames);
+      cm->ref_frame_idx_1, cm->reference_mode, cm->skip_mode_flag,
+      cpi->oxcf.lag_in_frames);
 #endif  // 0
 #endif  // CONFIG_EXT_SKIP
 
@@ -4203,8 +4217,12 @@
     }
     make_consistent_compound_tools(cm);
 #if CONFIG_EXT_SKIP
-    if (frame_is_intra_only(cm) || cm->reference_mode == SINGLE_REFERENCE ||
-        rdc->skip_mode_used_flag == 0)
+    // Re-check on the skip mode status as reference mode may have been changed.
+    if (frame_is_intra_only(cm) || cm->reference_mode == SINGLE_REFERENCE) {
+      cm->is_skip_mode_allowed = 0;
+      cm->skip_mode_flag = 0;
+    }
+    if (cm->skip_mode_flag && rdc->skip_mode_used_flag == 0)
       cm->skip_mode_flag = 0;
 #endif  // CONFIG_EXT_SKIP
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9498087..d55ded9 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -10647,9 +10647,11 @@
     x->compound_idx = 1;  // COMPOUND_AVERAGE
 #if SKIP_MODE_WITH_JNT_COMP
     const int cur_offset = (int)cm->frame_offset;
-    const int cur_to_fwd = cur_offset - cm->ref_frame_idx_0;
-    const int cur_to_bwd = abs(cm->ref_frame_idx_1 - cur_offset);
-    if (cur_to_fwd != cur_to_bwd && xd->all_one_sided_refs) {
+    int ref_offset[2];
+    get_skip_mode_ref_offsets(cm, ref_offset);
+    const int cur_to_ref0 = cur_offset - ref_offset[0];
+    const int cur_to_ref1 = abs(cur_offset - ref_offset[0]);
+    if (cur_to_ref0 != cur_to_ref1 && xd->all_one_sided_refs) {
       // Decide on the JNT_COMP mode.
       int64_t best_skip_mode_rd = INT64_MAX;
       int best_compound_idx = 0;