[NORMATIVE] Unify context design for single ref

The CL makes the context design for single reference frame coding the
same as that for the compound reference frame coding. There are 3
contexts designed for each of the binary symbols for the single
reference frame scenario, and the designed contexts simply rely on the
counts of the references used in the neighboring two blocks.

Once this CL is merged, the coding of the reference frames, regardless
of single prediction or compound prediction, will all follow the same
context design pattern for all the binary symbols. The design logic is
much simpler and the lines of code for each binary symbol context
identification are reduced by 80%.

Further, this CL has obtained a small coding gain for 30 frames with
the default coding tools:

lowres: avg_psnr -0.015%; ovr_psnr -0.021%; ssim -0.002%
midres: avg_psnr -0.108%; ovr_psnr -0.139%; ssim -0.135%

BUG=aomedia:1402
BUG=aomedia:973

Change-Id: Ia72a1d18e85ac3a05308675b60b95f80f2219c46
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 832dfb2..d7c547f 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -235,51 +235,40 @@
 #endif  // CONFIG_EXT_COMP_REFS
 
 static const aom_cdf_prob
-    default_comp_ref_cdf[COMP_REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)] = {
+    default_comp_ref_cdf[REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)] = {
       { { AOM_CDF2(4412) }, { AOM_CDF2(11499) }, { AOM_CDF2(478) } },
       { { AOM_CDF2(17926) }, { AOM_CDF2(26419) }, { AOM_CDF2(8615) } },
       { { AOM_CDF2(30449) }, { AOM_CDF2(31477) }, { AOM_CDF2(28035) } }
     };
 
 static const aom_cdf_prob
-    default_comp_bwdref_cdf[COMP_REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)] = {
+    default_comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)] = {
       { { AOM_CDF2(2762) }, { AOM_CDF2(1614) } },
       { { AOM_CDF2(17976) }, { AOM_CDF2(15912) } },
       { { AOM_CDF2(30894) }, { AOM_CDF2(30639) } },
     };
 
+// TODO(zoelu): To use aom_entropy_optimizer to update the following defaults.
 static const aom_cdf_prob default_single_ref_cdf[REF_CONTEXTS][SINGLE_REFS - 1]
                                                 [CDF_SIZE(2)] = {
-                                                  { { AOM_CDF2(4623) },
-                                                    { AOM_CDF2(2110) },
-                                                    { AOM_CDF2(4132) },
-                                                    { AOM_CDF2(7309) },
-                                                    { AOM_CDF2(1392) },
-                                                    { AOM_CDF2(1781) } },
-                                                  { { AOM_CDF2(8659) },
-                                                    { AOM_CDF2(16372) },
-                                                    { AOM_CDF2(9371) },
-                                                    { AOM_CDF2(16322) },
-                                                    { AOM_CDF2(6216) },
-                                                    { AOM_CDF2(15834) } },
-                                                  { { AOM_CDF2(17353) },
-                                                    { AOM_CDF2(30182) },
-                                                    { AOM_CDF2(16300) },
-                                                    { AOM_CDF2(21702) },
-                                                    { AOM_CDF2(10365) },
-                                                    { AOM_CDF2(30486) } },
-                                                  { { AOM_CDF2(16384) },
-                                                    { AOM_CDF2(16384) },
-                                                    { AOM_CDF2(24426) },
-                                                    { AOM_CDF2(26972) },
-                                                    { AOM_CDF2(14760) },
-                                                    { AOM_CDF2(16384) } },
-                                                  { { AOM_CDF2(28634) },
-                                                    { AOM_CDF2(16384) },
-                                                    { AOM_CDF2(29425) },
-                                                    { AOM_CDF2(30969) },
-                                                    { AOM_CDF2(26676) },
-                                                    { AOM_CDF2(16384) } }
+                                                  { { AOM_CDF2(6500) },
+                                                    { AOM_CDF2(3089) },
+                                                    { AOM_CDF2(4026) },
+                                                    { AOM_CDF2(8549) },
+                                                    { AOM_CDF2(184) },
+                                                    { AOM_CDF2(2264) } },
+                                                  { { AOM_CDF2(17037) },
+                                                    { AOM_CDF2(19408) },
+                                                    { AOM_CDF2(15521) },
+                                                    { AOM_CDF2(27640) },
+                                                    { AOM_CDF2(5047) },
+                                                    { AOM_CDF2(16251) } },
+                                                  { { AOM_CDF2(28292) },
+                                                    { AOM_CDF2(30427) },
+                                                    { AOM_CDF2(29003) },
+                                                    { AOM_CDF2(31436) },
+                                                    { AOM_CDF2(28466) },
+                                                    { AOM_CDF2(29371) } }
                                                 };
 
 // TODO(huisu): tune these cdfs
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index ac78466..f28596d 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -132,8 +132,8 @@
   aom_cdf_prob uni_comp_ref_cdf[UNI_COMP_REF_CONTEXTS][UNIDIR_COMP_REFS - 1]
                                [CDF_SIZE(2)];
 #endif  // CONFIG_EXT_COMP_REFS
-  aom_cdf_prob comp_ref_cdf[COMP_REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)];
-  aom_cdf_prob comp_bwdref_cdf[COMP_REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)];
+  aom_cdf_prob comp_ref_cdf[REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)];
+  aom_cdf_prob comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)];
   aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
 #if CONFIG_JNT_COMP
   aom_cdf_prob compound_index_cdf[COMP_INDEX_CONTEXTS][CDF_SIZE(2)];
@@ -274,8 +274,8 @@
   unsigned int uni_comp_ref[UNI_COMP_REF_CONTEXTS][UNIDIR_COMP_REFS - 1][2];
 #endif  // CONFIG_EXT_COMP_REFS
   unsigned int single_ref[REF_CONTEXTS][SINGLE_REFS - 1][2];
-  unsigned int comp_ref[COMP_REF_CONTEXTS][FWD_REFS - 1][2];
-  unsigned int comp_bwdref[COMP_REF_CONTEXTS][BWD_REFS - 1][2];
+  unsigned int comp_ref[REF_CONTEXTS][FWD_REFS - 1][2];
+  unsigned int comp_bwdref[REF_CONTEXTS][BWD_REFS - 1][2];
 #if CONFIG_INTRABC
   unsigned int intrabc[2];
 #endif  // CONFIG_INTRABC
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 02335df..e44d97b 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -579,8 +579,7 @@
 
 #define INTRA_INTER_CONTEXTS 4
 #define COMP_INTER_CONTEXTS 5
-#define COMP_REF_CONTEXTS 3
-#define REF_CONTEXTS 5
+#define REF_CONTEXTS 3
 
 #if CONFIG_EXT_COMP_REFS
 #define COMP_REF_TYPE_CONTEXTS 5
diff --git a/av1/common/pred_common.c b/av1/common/pred_common.c
index ee36f01..d1e28b6 100644
--- a/av1/common/pred_common.c
+++ b/av1/common/pred_common.c
@@ -372,10 +372,11 @@
 }
 #endif  // CONFIG_EXT_COMP_REFS
 
-// Returns a context number for the given MB prediction signal
-// Signal the first reference frame for a compound mode be either
-// GOLDEN/LAST3, or LAST/LAST2.
-int av1_get_pred_context_comp_ref_p(const MACROBLOCKD *xd) {
+// == Common context functions for both comp and single ref ==
+//
+// Obtain contexts to signal a reference frame to be either LAST/LAST2 or
+// LAST3/GOLDEN.
+static int get_pred_context_ll2_or_l3gld(const MACROBLOCKD *xd) {
   const uint8_t *const ref_counts = &xd->neighbors_ref_counts[0];
 
   // Count of LAST + LAST2
@@ -388,14 +389,12 @@
                                ? 1
                                : ((last_last2_count < last3_gld_count) ? 0 : 2);
 
-  assert(pred_context >= 0 && pred_context < COMP_REF_CONTEXTS);
+  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
 }
 
-// Returns a context number for the given MB prediction signal
-// Signal the first reference frame for a compound mode be LAST,
-// conditioning on that it is known either LAST/LAST2.
-int av1_get_pred_context_comp_ref_p1(const MACROBLOCKD *xd) {
+// Obtain contexts to signal a reference frame to be either LAST or LAST2.
+static int get_pred_context_last_or_last2(const MACROBLOCKD *xd) {
   const uint8_t *const ref_counts = &xd->neighbors_ref_counts[0];
 
   // Count of LAST
@@ -406,14 +405,12 @@
   const int pred_context =
       (last_count == last2_count) ? 1 : ((last_count < last2_count) ? 0 : 2);
 
-  assert(pred_context >= 0 && pred_context < COMP_REF_CONTEXTS);
+  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
 }
 
-// Returns a context number for the given MB prediction signal
-// Signal the first reference frame for a compound mode be GOLDEN,
-// conditioning on that it is known either GOLDEN or LAST3.
-int av1_get_pred_context_comp_ref_p2(const MACROBLOCKD *xd) {
+// Obtain contexts to signal a reference frame to be either LAST3 or GOLDEN.
+static int get_pred_context_last3_or_gld(const MACROBLOCKD *xd) {
   const uint8_t *const ref_counts = &xd->neighbors_ref_counts[0];
 
   // Count of LAST3
@@ -424,7 +421,7 @@
   const int pred_context =
       (last3_count == gld_count) ? 1 : ((last3_count < gld_count) ? 0 : 2);
 
-  assert(pred_context >= 0 && pred_context < COMP_REF_CONTEXTS);
+  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
 }
 
@@ -441,7 +438,7 @@
   const int pred_context =
       (brfarf2_count == arf_count) ? 1 : ((brfarf2_count < arf_count) ? 0 : 2);
 
-  assert(pred_context >= 0 && pred_context < COMP_REF_CONTEXTS);
+  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
 }
 
@@ -457,10 +454,33 @@
   const int pred_context =
       (brf_count == arf2_count) ? 1 : ((brf_count < arf2_count) ? 0 : 2);
 
-  assert(pred_context >= 0 && pred_context < COMP_REF_CONTEXTS);
+  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
 }
 
+// == Context functions for comp ref ==
+//
+// Returns a context number for the given MB prediction signal
+// Signal the first reference frame for a compound mode be either
+// GOLDEN/LAST3, or LAST/LAST2.
+int av1_get_pred_context_comp_ref_p(const MACROBLOCKD *xd) {
+  return get_pred_context_ll2_or_l3gld(xd);
+}
+
+// Returns a context number for the given MB prediction signal
+// Signal the first reference frame for a compound mode be LAST,
+// conditioning on that it is known either LAST/LAST2.
+int av1_get_pred_context_comp_ref_p1(const MACROBLOCKD *xd) {
+  return get_pred_context_last_or_last2(xd);
+}
+
+// Returns a context number for the given MB prediction signal
+// Signal the first reference frame for a compound mode be GOLDEN,
+// conditioning on that it is known either GOLDEN or LAST3.
+int av1_get_pred_context_comp_ref_p2(const MACROBLOCKD *xd) {
+  return get_pred_context_last3_or_gld(xd);
+}
+
 // Signal the 2nd reference frame for a compound mode be either
 // ALTREF, or ALTREF2/BWDREF.
 int av1_get_pred_context_comp_bwdref_p(const MACROBLOCKD *xd) {
@@ -473,63 +493,22 @@
   return get_pred_context_brf_or_arf2(xd);
 }
 
+// == Context functions for single ref ==
+//
 // For the bit to signal whether the single reference is a forward reference
 // frame or a backward reference frame.
 int av1_get_pred_context_single_ref_p1(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int has_above = xd->up_available;
-  const int has_left = xd->left_available;
+  const uint8_t *const ref_counts = &xd->neighbors_ref_counts[0];
 
-  // Note:
-  // The mode info data structure has a one element border above and to the
-  // left of the entries correpsonding to real macroblocks.
-  // The prediction flags in these dummy entries are initialised to 0.
-  if (has_above && has_left) {  // both edges available
-    const int above_intra = !is_inter_block(above_mbmi);
-    const int left_intra = !is_inter_block(left_mbmi);
+  // Count of forward reference frames
+  const int fwd_count = ref_counts[LAST_FRAME] + ref_counts[LAST2_FRAME] +
+                        ref_counts[LAST3_FRAME] + ref_counts[GOLDEN_FRAME];
+  // Count of backward reference frames
+  const int bwd_count = ref_counts[BWDREF_FRAME] + ref_counts[ALTREF2_FRAME] +
+                        ref_counts[ALTREF_FRAME];
 
-    if (above_intra && left_intra) {  // intra/intra
-      pred_context = 2;
-    } else if (above_intra || left_intra) {  // intra/inter or inter/intra
-      const MB_MODE_INFO *edge_mbmi = above_intra ? left_mbmi : above_mbmi;
-
-      if (!has_second_ref(edge_mbmi))  // single
-        pred_context = 4 * (!CHECK_BACKWARD_REFS(edge_mbmi->ref_frame[0]));
-      else  // comp
-        pred_context = 2;
-    } else {  // inter/inter
-      const int above_has_second = has_second_ref(above_mbmi);
-      const int left_has_second = has_second_ref(left_mbmi);
-
-      const MV_REFERENCE_FRAME above0 = above_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME left0 = left_mbmi->ref_frame[0];
-
-      if (above_has_second && left_has_second) {  // comp/comp
-        pred_context = 2;
-      } else if (above_has_second || left_has_second) {  // single/comp
-        const MV_REFERENCE_FRAME rfs = !above_has_second ? above0 : left0;
-
-        pred_context = (!CHECK_BACKWARD_REFS(rfs)) ? 4 : 1;
-      } else {  // single/single
-        pred_context = 2 * (!CHECK_BACKWARD_REFS(above0)) +
-                       2 * (!CHECK_BACKWARD_REFS(left0));
-      }
-    }
-  } else if (has_above || has_left) {  // one edge available
-    const MB_MODE_INFO *edge_mbmi = has_above ? above_mbmi : left_mbmi;
-    if (!is_inter_block(edge_mbmi)) {  // intra
-      pred_context = 2;
-    } else {                           // inter
-      if (!has_second_ref(edge_mbmi))  // single
-        pred_context = 4 * (!CHECK_BACKWARD_REFS(edge_mbmi->ref_frame[0]));
-      else  // comp
-        pred_context = 2;
-    }
-  } else {  // no edges available
-    pred_context = 2;
-  }
+  const int pred_context =
+      (fwd_count == bwd_count) ? 1 : ((fwd_count < bwd_count) ? 0 : 2);
 
   assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
   return pred_context;
@@ -542,282 +521,22 @@
   return get_pred_context_brfarf2_or_arf(xd);
 }
 
-#define CHECK_LAST_OR_LAST2(ref_frame) \
-  ((ref_frame == LAST_FRAME) || (ref_frame == LAST2_FRAME))
-
 // For the bit to signal whether the single reference is LAST3/GOLDEN or
 // LAST2/LAST, knowing that it shall be either of these 2 choices.
 int av1_get_pred_context_single_ref_p3(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int has_above = xd->up_available;
-  const int has_left = xd->left_available;
-
-  // Note:
-  // The mode info data structure has a one element border above and to the
-  // left of the entries correpsonding to real macroblocks.
-  // The prediction flags in these dummy entries are initialised to 0.
-  if (has_above && has_left) {  // both edges available
-    const int above_intra = !is_inter_block(above_mbmi);
-    const int left_intra = !is_inter_block(left_mbmi);
-
-    if (above_intra && left_intra) {  // intra/intra
-      pred_context = 2;
-    } else if (above_intra || left_intra) {  // intra/inter or inter/intra
-      const MB_MODE_INFO *edge_mbmi = above_intra ? left_mbmi : above_mbmi;
-      if (!has_second_ref(edge_mbmi)) {  // single
-        if (CHECK_BACKWARD_REFS(edge_mbmi->ref_frame[0]))
-          pred_context = 3;
-        else
-          pred_context = 4 * CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]);
-      } else {  // comp
-        pred_context = 1 + 2 * (CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]) ||
-                                CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[1]));
-      }
-    } else {  // inter/inter
-      const int above_has_second = has_second_ref(above_mbmi);
-      const int left_has_second = has_second_ref(left_mbmi);
-      const MV_REFERENCE_FRAME above0 = above_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME above1 = above_mbmi->ref_frame[1];
-      const MV_REFERENCE_FRAME left0 = left_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME left1 = left_mbmi->ref_frame[1];
-
-      if (above_has_second && left_has_second) {  // comp/comp
-        if (above0 == left0 && above1 == left1)
-          pred_context =
-              3 * (CHECK_LAST_OR_LAST2(above0) || CHECK_LAST_OR_LAST2(above1) ||
-                   CHECK_LAST_OR_LAST2(left0) || CHECK_LAST_OR_LAST2(left1));
-        else
-          pred_context = 2;
-      } else if (above_has_second || left_has_second) {  // single/comp
-        const MV_REFERENCE_FRAME rfs = !above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf1 = above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf2 = above_has_second ? above1 : left1;
-
-        if (CHECK_LAST_OR_LAST2(rfs))
-          pred_context =
-              3 + (CHECK_LAST_OR_LAST2(crf1) || CHECK_LAST_OR_LAST2(crf2));
-        else if (CHECK_GOLDEN_OR_LAST3(rfs))
-          pred_context =
-              (CHECK_LAST_OR_LAST2(crf1) || CHECK_LAST_OR_LAST2(crf2));
-        else
-          pred_context =
-              1 + 2 * (CHECK_LAST_OR_LAST2(crf1) || CHECK_LAST_OR_LAST2(crf2));
-      } else {  // single/single
-        if (CHECK_BACKWARD_REFS(above0) && CHECK_BACKWARD_REFS(left0)) {
-          pred_context = 2 + (above0 == left0);
-        } else if (CHECK_BACKWARD_REFS(above0) || CHECK_BACKWARD_REFS(left0)) {
-          const MV_REFERENCE_FRAME edge0 =
-              CHECK_BACKWARD_REFS(above0) ? left0 : above0;
-          pred_context = 4 * CHECK_LAST_OR_LAST2(edge0);
-        } else {
-          pred_context =
-              2 * CHECK_LAST_OR_LAST2(above0) + 2 * CHECK_LAST_OR_LAST2(left0);
-        }
-      }
-    }
-  } else if (has_above || has_left) {  // one edge available
-    const MB_MODE_INFO *edge_mbmi = has_above ? above_mbmi : left_mbmi;
-
-    if (!is_inter_block(edge_mbmi) ||
-        (CHECK_BACKWARD_REFS(edge_mbmi->ref_frame[0]) &&
-         !has_second_ref(edge_mbmi)))
-      pred_context = 2;
-    else if (!has_second_ref(edge_mbmi))  // single
-      pred_context = 4 * (CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]));
-    else  // comp
-      pred_context = 3 * (CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]) ||
-                          CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[1]));
-  } else {  // no edges available (2)
-    pred_context = 2;
-  }
-
-  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
-  return pred_context;
+  return get_pred_context_ll2_or_l3gld(xd);
 }
 
 // For the bit to signal whether the single reference is LAST2_FRAME or
 // LAST_FRAME, knowing that it shall be either of these 2 choices.
-//
-// NOTE(zoeliu): The probability of ref_frame[0] is LAST2_FRAME, conditioning
-// on it is either LAST2_FRAME/LAST_FRAME.
 int av1_get_pred_context_single_ref_p4(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int has_above = xd->up_available;
-  const int has_left = xd->left_available;
-
-  // Note:
-  // The mode info data structure has a one element border above and to the
-  // left of the entries correpsonding to real macroblocks.
-  // The prediction flags in these dummy entries are initialised to 0.
-  if (has_above && has_left) {  // both edges available
-    const int above_intra = !is_inter_block(above_mbmi);
-    const int left_intra = !is_inter_block(left_mbmi);
-
-    if (above_intra && left_intra) {  // intra/intra
-      pred_context = 2;
-    } else if (above_intra || left_intra) {  // intra/inter or inter/intra
-      const MB_MODE_INFO *edge_mbmi = above_intra ? left_mbmi : above_mbmi;
-      if (!has_second_ref(edge_mbmi)) {  // single
-        if (!CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]))
-          pred_context = 3;
-        else
-          pred_context = 4 * (edge_mbmi->ref_frame[0] == LAST_FRAME);
-      } else {  // comp
-        pred_context = 1 + 2 * (edge_mbmi->ref_frame[0] == LAST_FRAME ||
-                                edge_mbmi->ref_frame[1] == LAST_FRAME);
-      }
-    } else {  // inter/inter
-      const int above_has_second = has_second_ref(above_mbmi);
-      const int left_has_second = has_second_ref(left_mbmi);
-      const MV_REFERENCE_FRAME above0 = above_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME above1 = above_mbmi->ref_frame[1];
-      const MV_REFERENCE_FRAME left0 = left_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME left1 = left_mbmi->ref_frame[1];
-
-      if (above_has_second && left_has_second) {  // comp/comp
-        if (above0 == left0 && above1 == left1)
-          pred_context = 3 * (above0 == LAST_FRAME || above1 == LAST_FRAME ||
-                              left0 == LAST_FRAME || left1 == LAST_FRAME);
-        else
-          pred_context = 2;
-      } else if (above_has_second || left_has_second) {  // single/comp
-        const MV_REFERENCE_FRAME rfs = !above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf1 = above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf2 = above_has_second ? above1 : left1;
-
-        if (rfs == LAST_FRAME)
-          pred_context = 3 + (crf1 == LAST_FRAME || crf2 == LAST_FRAME);
-        else if (rfs == LAST2_FRAME)
-          pred_context = (crf1 == LAST_FRAME || crf2 == LAST_FRAME);
-        else
-          pred_context = 1 + 2 * (crf1 == LAST_FRAME || crf2 == LAST_FRAME);
-      } else {  // single/single
-        if (!CHECK_LAST_OR_LAST2(above0) && !CHECK_LAST_OR_LAST2(left0)) {
-          pred_context = 2 + (above0 == left0);
-        } else if (!CHECK_LAST_OR_LAST2(above0) ||
-                   !CHECK_LAST_OR_LAST2(left0)) {
-          const MV_REFERENCE_FRAME edge0 =
-              !CHECK_LAST_OR_LAST2(above0) ? left0 : above0;
-          pred_context = 4 * (edge0 == LAST_FRAME);
-        } else {
-          pred_context = 2 * (above0 == LAST_FRAME) + 2 * (left0 == LAST_FRAME);
-        }
-      }
-    }
-  } else if (has_above || has_left) {  // one edge available
-    const MB_MODE_INFO *edge_mbmi = has_above ? above_mbmi : left_mbmi;
-
-    if (!is_inter_block(edge_mbmi) ||
-        (!CHECK_LAST_OR_LAST2(edge_mbmi->ref_frame[0]) &&
-         !has_second_ref(edge_mbmi)))
-      pred_context = 2;
-    else if (!has_second_ref(edge_mbmi))  // single
-      pred_context = 4 * (edge_mbmi->ref_frame[0] == LAST_FRAME);
-    else  // comp
-      pred_context = 3 * (edge_mbmi->ref_frame[0] == LAST_FRAME ||
-                          edge_mbmi->ref_frame[1] == LAST_FRAME);
-  } else {  // no edges available (2)
-    pred_context = 2;
-  }
-
-  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
-  return pred_context;
+  return get_pred_context_last_or_last2(xd);
 }
 
 // For the bit to signal whether the single reference is GOLDEN_FRAME or
 // LAST3_FRAME, knowing that it shall be either of these 2 choices.
-//
-// NOTE(zoeliu): The probability of ref_frame[0] is GOLDEN_FRAME, conditioning
-// on it is either GOLDEN_FRAME/LAST3_FRAME.
 int av1_get_pred_context_single_ref_p5(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int has_above = xd->up_available;
-  const int has_left = xd->left_available;
-
-  // Note:
-  // The mode info data structure has a one element border above and to the
-  // left of the entries correpsonding to real macroblocks.
-  // The prediction flags in these dummy entries are initialised to 0.
-  if (has_above && has_left) {  // both edges available
-    const int above_intra = !is_inter_block(above_mbmi);
-    const int left_intra = !is_inter_block(left_mbmi);
-
-    if (above_intra && left_intra) {  // intra/intra
-      pred_context = 2;
-    } else if (above_intra || left_intra) {  // intra/inter or inter/intra
-      const MB_MODE_INFO *edge_mbmi = above_intra ? left_mbmi : above_mbmi;
-      if (!has_second_ref(edge_mbmi)) {  // single
-        if (!CHECK_GOLDEN_OR_LAST3(edge_mbmi->ref_frame[0]))
-          pred_context = 3;
-        else
-          pred_context = 4 * (edge_mbmi->ref_frame[0] == LAST3_FRAME);
-      } else {  // comp
-        pred_context = 1 + 2 * (edge_mbmi->ref_frame[0] == LAST3_FRAME ||
-                                edge_mbmi->ref_frame[1] == LAST3_FRAME);
-      }
-    } else {  // inter/inter
-      const int above_has_second = has_second_ref(above_mbmi);
-      const int left_has_second = has_second_ref(left_mbmi);
-      const MV_REFERENCE_FRAME above0 = above_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME above1 = above_mbmi->ref_frame[1];
-      const MV_REFERENCE_FRAME left0 = left_mbmi->ref_frame[0];
-      const MV_REFERENCE_FRAME left1 = left_mbmi->ref_frame[1];
-
-      if (above_has_second && left_has_second) {  // comp/comp
-        if (above0 == left0 && above1 == left1)
-          pred_context = 3 * (above0 == LAST3_FRAME || above1 == LAST3_FRAME ||
-                              left0 == LAST3_FRAME || left1 == LAST3_FRAME);
-        else
-          pred_context = 2;
-      } else if (above_has_second || left_has_second) {  // single/comp
-        const MV_REFERENCE_FRAME rfs = !above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf1 = above_has_second ? above0 : left0;
-        const MV_REFERENCE_FRAME crf2 = above_has_second ? above1 : left1;
-
-        if (rfs == LAST3_FRAME)
-          pred_context = 3 + (crf1 == LAST3_FRAME || crf2 == LAST3_FRAME);
-        else if (rfs == GOLDEN_FRAME)
-          pred_context = (crf1 == LAST3_FRAME || crf2 == LAST3_FRAME);
-        else
-          pred_context = 1 + 2 * (crf1 == LAST3_FRAME || crf2 == LAST3_FRAME);
-      } else {  // single/single
-        if (!CHECK_GOLDEN_OR_LAST3(above0) && !CHECK_GOLDEN_OR_LAST3(left0)) {
-          pred_context = 2 + (above0 == left0);
-        } else if (!CHECK_GOLDEN_OR_LAST3(above0) ||
-                   !CHECK_GOLDEN_OR_LAST3(left0)) {
-          const MV_REFERENCE_FRAME edge0 =
-              !CHECK_GOLDEN_OR_LAST3(above0) ? left0 : above0;
-          pred_context = 4 * (edge0 == LAST3_FRAME);
-        } else {
-          pred_context =
-              2 * (above0 == LAST3_FRAME) + 2 * (left0 == LAST3_FRAME);
-        }
-      }
-    }
-  } else if (has_above || has_left) {  // one edge available
-    const MB_MODE_INFO *edge_mbmi = has_above ? above_mbmi : left_mbmi;
-
-    if (!is_inter_block(edge_mbmi) ||
-        (!CHECK_GOLDEN_OR_LAST3(edge_mbmi->ref_frame[0]) &&
-         !has_second_ref(edge_mbmi)))
-      pred_context = 2;
-    else if (!has_second_ref(edge_mbmi))  // single
-      pred_context = 4 * (edge_mbmi->ref_frame[0] == LAST3_FRAME);
-    else  // comp
-      pred_context = 3 * (edge_mbmi->ref_frame[0] == LAST3_FRAME ||
-                          edge_mbmi->ref_frame[1] == LAST3_FRAME);
-  } else {  // no edges available (2)
-    pred_context = 2;
-  }
-
-  assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
-  return pred_context;
+  return get_pred_context_last3_or_gld(xd);
 }
 
 // For the bit to signal whether the single reference is ALTREF2_FRAME or
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index d438c98..08e1da1 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -299,10 +299,10 @@
 #endif  // CONFIG_EXT_COMP_REFS
   // Cost for signaling ref_frame[0] (LAST_FRAME, LAST2_FRAME, LAST3_FRAME or
   // GOLDEN_FRAME) in bidir-comp mode.
-  int comp_ref_cost[COMP_REF_CONTEXTS][FWD_REFS - 1][2];
+  int comp_ref_cost[REF_CONTEXTS][FWD_REFS - 1][2];
   // Cost for signaling ref_frame[1] (ALTREF_FRAME, ALTREF2_FRAME, or
   // BWDREF_FRAME) in bidir-comp mode.
-  int comp_bwdref_cost[COMP_REF_CONTEXTS][BWD_REFS - 1][2];
+  int comp_bwdref_cost[REF_CONTEXTS][BWD_REFS - 1][2];
   int inter_compound_mode_cost[INTER_MODE_CONTEXTS][INTER_COMPOUND_MODES];
 #if CONFIG_JNT_COMP
   int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES - 1];
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 401fd03..4a13638 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -270,14 +270,14 @@
     }
 #endif  // CONFIG_EXT_COMP_REFS
 
-    for (i = 0; i < COMP_REF_CONTEXTS; ++i) {
+    for (i = 0; i < REF_CONTEXTS; ++i) {
       for (j = 0; j < FWD_REFS - 1; ++j) {
         av1_cost_tokens_from_cdf(x->comp_ref_cost[i][j], fc->comp_ref_cdf[i][j],
                                  NULL);
       }
     }
 
-    for (i = 0; i < COMP_REF_CONTEXTS; ++i) {
+    for (i = 0; i < REF_CONTEXTS; ++i) {
       for (j = 0; j < BWD_REFS - 1; ++j) {
         av1_cost_tokens_from_cdf(x->comp_bwdref_cost[i][j],
                                  fc->comp_bwdref_cdf[i][j], NULL);
diff --git a/tools/aom_entropy_optimizer.c b/tools/aom_entropy_optimizer.c
index 1fa0bf5..b5a1bf3 100644
--- a/tools/aom_entropy_optimizer.c
+++ b/tools/aom_entropy_optimizer.c
@@ -470,21 +470,21 @@
       "default_single_ref_cdf[REF_CONTEXTS][SINGLE_REFS - 1][CDF_SIZE(2)]");
 
   /* ext_refs experiment */
-  cts_each_dim[0] = COMP_REF_CONTEXTS;
+  cts_each_dim[0] = REF_CONTEXTS;
   cts_each_dim[1] = FWD_REFS - 1;
   cts_each_dim[2] = 2;
   optimize_cdf_table(
       &fc.comp_ref[0][0][0], probsfile, 3, cts_each_dim,
       "static const aom_cdf_prob\n"
-      "default_comp_ref_cdf[COMP_REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)]");
+      "default_comp_ref_cdf[REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)]");
 
-  cts_each_dim[0] = COMP_REF_CONTEXTS;
+  cts_each_dim[0] = REF_CONTEXTS;
   cts_each_dim[1] = BWD_REFS - 1;
   cts_each_dim[2] = 2;
   optimize_cdf_table(
       &fc.comp_bwdref[0][0][0], probsfile, 3, cts_each_dim,
       "static const aom_cdf_prob\n"
-      "default_comp_bwdref_cdf[COMP_REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)]");
+      "default_comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)]");
 
   /* Transform size */
   cts_each_dim[0] = TXFM_PARTITION_CONTEXTS;