diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 5c5d91c..2a82927 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -617,6 +617,10 @@
   CANDIDATE_MV ref_mv_stack[MODE_CTX_REF_FRAMES][MAX_REF_MV_STACK_SIZE];
   uint8_t is_sec_rect;
 
+  // Reference frame counts over the above/left inter neighbors (compound
+  // neighbors count both refs); see av1_collect_neighbors_ref_counts().
+  uint8_t neighbors_ref_counts[TOTAL_REFS_PER_FRAME];
+
   FRAME_CONTEXT *tile_ctx;
   /* Bit depth: 8, 10, 12 */
   int bd;
diff --git a/av1/common/mvref_common.h b/av1/common/mvref_common.h
index 5fced59..0dc262a 100644
--- a/av1/common/mvref_common.h
+++ b/av1/common/mvref_common.h
@@ -439,6 +439,30 @@
 #endif  // CONFIG_MFMV
 #endif  // CONFIG_FRAME_MARKER
 
+// Fills xd->neighbors_ref_counts with how often each reference frame is used
+// by the above and left neighbors of the current block. Only neighbors that
+// are available and inter coded contribute: ref_frame[0] is always counted,
+// and ref_frame[1] too when the neighbor has a second (compound) reference.
+// NOTE(review): the reference-frame context helpers in pred_common.c read
+// these counts, so this presumably has to run for each block before they are
+// queried -- confirm against the call sites.
+static INLINE void av1_collect_neighbors_ref_counts(MACROBLOCKD *const xd) {
+  const MB_MODE_INFO *const nb_mbmi[2] = { xd->above_mbmi, xd->left_mbmi };
+  const int nb_available[2] = { xd->up_available, xd->left_available };
+  uint8_t *const counts = xd->neighbors_ref_counts;
+
+  av1_zero(xd->neighbors_ref_counts);
+
+  // Above neighbor first, then left.
+  for (int i = 0; i < 2; ++i) {
+    const MB_MODE_INFO *const mbmi = nb_mbmi[i];
+    if (!nb_available[i] || !is_inter_block(mbmi)) continue;
+
+    ++counts[mbmi->ref_frame[0]];
+    if (has_second_ref(mbmi)) ++counts[mbmi->ref_frame[1]];
+  }
+}
+
 void av1_copy_frame_mvs(const AV1_COMMON *const cm, MODE_INFO *mi, int mi_row,
                         int mi_col, int x_mis, int y_mis);
 
diff --git a/av1/common/pred_common.c b/av1/common/pred_common.c
index 92f7415..24cb6a4 100644
--- a/av1/common/pred_common.c
+++ b/av1/common/pred_common.c
@@ -228,8 +228,6 @@
 }
 
 #if CONFIG_EXT_COMP_REFS
-// TODO(zoeliu): To try on the design of 3 contexts, instead of 5:
-//               COMP_REF_TYPE_CONTEXTS = 3
 int av1_get_comp_reference_type_context(const MACROBLOCKD *xd) {
   int pred_context;
   const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
@@ -310,44 +308,17 @@
 // 3 contexts: Voting is used to compare the count of forward references with
 //             that of backward references from the spatial neighbors.
 int av1_get_pred_context_uni_comp_ref_p(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int above_in_image = xd->up_available;
-  const int left_in_image = xd->left_available;
+  // Neighbor reference counts, filled by av1_collect_neighbors_ref_counts().
+  const uint8_t *const nb_counts = xd->neighbors_ref_counts;
 
   // Count of forward references (L, L2, L3, or G)
-  int frf_count = 0;
+  const int frf_count = nb_counts[LAST_FRAME] + nb_counts[LAST2_FRAME] +
+                        nb_counts[LAST3_FRAME] + nb_counts[GOLDEN_FRAME];
   // Count of backward references (B or A)
-  int brf_count = 0;
+  const int brf_count = nb_counts[BWDREF_FRAME] + nb_counts[ALTREF2_FRAME] +
+                        nb_counts[ALTREF_FRAME];
 
-  if (above_in_image && is_inter_block(above_mbmi)) {
-    if (above_mbmi->ref_frame[0] <= GOLDEN_FRAME)
-      ++frf_count;
-    else
-      ++brf_count;
-    if (has_second_ref(above_mbmi)) {
-      if (above_mbmi->ref_frame[1] <= GOLDEN_FRAME)
-        ++frf_count;
-      else
-        ++brf_count;
-    }
-  }
-
-  if (left_in_image && is_inter_block(left_mbmi)) {
-    if (left_mbmi->ref_frame[0] <= GOLDEN_FRAME)
-      ++frf_count;
-    else
-      ++brf_count;
-    if (has_second_ref(left_mbmi)) {
-      if (left_mbmi->ref_frame[1] <= GOLDEN_FRAME)
-        ++frf_count;
-      else
-        ++brf_count;
-    }
-  }
-
-  pred_context =
+  const int pred_context =
       (frf_count == brf_count) ? 1 : ((frf_count < brf_count) ? 0 : 2);
 
   assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
@@ -363,50 +334,18 @@
 // 3 contexts: Voting is used to compare the count of LAST2_FRAME with the
 //             total count of LAST3/GOLDEN from the spatial neighbors.
 int av1_get_pred_context_uni_comp_ref_p1(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int above_in_image = xd->up_available;
-  const int left_in_image = xd->left_available;
+  // Neighbor reference counts, filled by av1_collect_neighbors_ref_counts().
+  const uint8_t *const nb_counts = xd->neighbors_ref_counts;
 
   // Count of LAST2
-  int last2_count = 0;
+  const int last2_count = nb_counts[LAST2_FRAME];
   // Count of LAST3 or GOLDEN
-  int last3_or_gld_count = 0;
+  const int last3_or_gld_count =
+      nb_counts[LAST3_FRAME] + nb_counts[GOLDEN_FRAME];
 
-  if (above_in_image && is_inter_block(above_mbmi)) {
-    last2_count = (above_mbmi->ref_frame[0] == LAST2_FRAME) ? last2_count + 1
-                                                            : last2_count;
-    last3_or_gld_count = CHECK_GOLDEN_OR_LAST3(above_mbmi->ref_frame[0])
-                             ? last3_or_gld_count + 1
-                             : last3_or_gld_count;
-    if (has_second_ref(above_mbmi)) {
-      last2_count = (above_mbmi->ref_frame[1] == LAST2_FRAME) ? last2_count + 1
-                                                              : last2_count;
-      last3_or_gld_count = CHECK_GOLDEN_OR_LAST3(above_mbmi->ref_frame[1])
-                               ? last3_or_gld_count + 1
-                               : last3_or_gld_count;
-    }
-  }
-
-  if (left_in_image && is_inter_block(left_mbmi)) {
-    last2_count = (left_mbmi->ref_frame[0] == LAST2_FRAME) ? last2_count + 1
-                                                           : last2_count;
-    last3_or_gld_count = CHECK_GOLDEN_OR_LAST3(left_mbmi->ref_frame[0])
-                             ? last3_or_gld_count + 1
-                             : last3_or_gld_count;
-    if (has_second_ref(left_mbmi)) {
-      last2_count = (left_mbmi->ref_frame[1] == LAST2_FRAME) ? last2_count + 1
-                                                             : last2_count;
-      last3_or_gld_count = CHECK_GOLDEN_OR_LAST3(left_mbmi->ref_frame[1])
-                               ? last3_or_gld_count + 1
-                               : last3_or_gld_count;
-    }
-  }
-
-  pred_context = (last2_count == last3_or_gld_count)
-                     ? 1
-                     : ((last2_count < last3_or_gld_count) ? 0 : 2);
+  const int pred_context = (last2_count == last3_or_gld_count)
+                               ? 1
+                               : ((last2_count < last3_or_gld_count) ? 0 : 2);
 
   assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
   return pred_context;
@@ -421,44 +360,15 @@
 // 3 contexts: Voting is used to compare the count of LAST3_FRAME with the
 //             total count of GOLDEN_FRAME from the spatial neighbors.
 int av1_get_pred_context_uni_comp_ref_p2(const MACROBLOCKD *xd) {
-  int pred_context;
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int above_in_image = xd->up_available;
-  const int left_in_image = xd->left_available;
+  // Neighbor reference counts, filled by av1_collect_neighbors_ref_counts().
+  const uint8_t *const nb_counts = xd->neighbors_ref_counts;
 
   // Count of LAST3
-  int last3_count = 0;
+  const int last3_count = nb_counts[LAST3_FRAME];
   // Count of GOLDEN
-  int gld_count = 0;
+  const int gld_count = nb_counts[GOLDEN_FRAME];
 
-  if (above_in_image && is_inter_block(above_mbmi)) {
-    last3_count = (above_mbmi->ref_frame[0] == LAST3_FRAME) ? last3_count + 1
-                                                            : last3_count;
-    gld_count =
-        (above_mbmi->ref_frame[0] == GOLDEN_FRAME) ? gld_count + 1 : gld_count;
-    if (has_second_ref(above_mbmi)) {
-      last3_count = (above_mbmi->ref_frame[1] == LAST3_FRAME) ? last3_count + 1
-                                                              : last3_count;
-      gld_count = (above_mbmi->ref_frame[1] == GOLDEN_FRAME) ? gld_count + 1
-                                                             : gld_count;
-    }
-  }
-
-  if (left_in_image && is_inter_block(left_mbmi)) {
-    last3_count = (left_mbmi->ref_frame[0] == LAST3_FRAME) ? last3_count + 1
-                                                           : last3_count;
-    gld_count =
-        (left_mbmi->ref_frame[0] == GOLDEN_FRAME) ? gld_count + 1 : gld_count;
-    if (has_second_ref(left_mbmi)) {
-      last3_count = (left_mbmi->ref_frame[1] == LAST3_FRAME) ? last3_count + 1
-                                                             : last3_count;
-      gld_count =
-          (left_mbmi->ref_frame[1] == GOLDEN_FRAME) ? gld_count + 1 : gld_count;
-    }
-  }
-
-  pred_context =
+  const int pred_context =
       (last3_count == gld_count) ? 1 : ((last3_count < gld_count) ? 0 : 2);
 
   assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
@@ -466,9 +376,6 @@
 }
 #endif  // CONFIG_EXT_COMP_REFS
 
-// TODO(zoeliu): Future work will be conducted to optimize the context design
-//               for the coding of the reference frames.
-
 #define CHECK_LAST_OR_LAST2(ref_frame) \
   ((ref_frame == LAST_FRAME) || (ref_frame == LAST2_FRAME))
 
@@ -779,35 +686,13 @@
 // Obtain contexts to signal a reference frame be either BWDREF/ALTREF2, or
 // ALTREF.
 int av1_get_pred_context_brfarf2_or_arf(const MACROBLOCKD *xd) {
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int above_in_image = xd->up_available;
-  const int left_in_image = xd->left_available;
+  // Neighbor reference counts, filled by av1_collect_neighbors_ref_counts().
+  const uint8_t *const nb_counts = xd->neighbors_ref_counts;
 
   // Counts of BWDREF, ALTREF2, or ALTREF frames (B, A2, or A)
-  int bwdref_counts[ALTREF_FRAME - BWDREF_FRAME + 1] = { 0 };
+  const int brfarf2_count = nb_counts[BWDREF_FRAME] + nb_counts[ALTREF2_FRAME];
+  const int arf_count = nb_counts[ALTREF_FRAME];
 
-  if (above_in_image && is_inter_block(above_mbmi)) {
-    if (above_mbmi->ref_frame[0] >= BWDREF_FRAME)
-      ++bwdref_counts[above_mbmi->ref_frame[0] - BWDREF_FRAME];
-    if (has_second_ref(above_mbmi)) {
-      if (above_mbmi->ref_frame[1] >= BWDREF_FRAME)
-        ++bwdref_counts[above_mbmi->ref_frame[1] - BWDREF_FRAME];
-    }
-  }
-
-  if (left_in_image && is_inter_block(left_mbmi)) {
-    if (left_mbmi->ref_frame[0] >= BWDREF_FRAME)
-      ++bwdref_counts[left_mbmi->ref_frame[0] - BWDREF_FRAME];
-    if (has_second_ref(left_mbmi)) {
-      if (left_mbmi->ref_frame[1] >= BWDREF_FRAME)
-        ++bwdref_counts[left_mbmi->ref_frame[1] - BWDREF_FRAME];
-    }
-  }
-
-  const int brfarf2_count = bwdref_counts[BWDREF_FRAME - BWDREF_FRAME] +
-                            bwdref_counts[ALTREF2_FRAME - BWDREF_FRAME];
-  const int arf_count = bwdref_counts[ALTREF_FRAME - BWDREF_FRAME];
   const int pred_context =
       (brfarf2_count == arf_count) ? 1 : ((brfarf2_count < arf_count) ? 0 : 2);
 
@@ -817,41 +702,13 @@
 
 // Obtain contexts to signal a reference frame be either BWDREF or ALTREF2.
 int av1_get_pred_context_brf_or_arf2(const MACROBLOCKD *xd) {
-  const MB_MODE_INFO *const above_mbmi = xd->above_mbmi;
-  const MB_MODE_INFO *const left_mbmi = xd->left_mbmi;
-  const int above_in_image = xd->up_available;
-  const int left_in_image = xd->left_available;
+  // Neighbor reference counts, filled by av1_collect_neighbors_ref_counts().
+  const uint8_t *const nb_counts = xd->neighbors_ref_counts;
 
   // Count of BWDREF frames (B)
-  int brf_count = 0;
+  const int brf_count = nb_counts[BWDREF_FRAME];
   // Count of ALTREF2 frames (A2)
-  int arf2_count = 0;
-
-  if (above_in_image && is_inter_block(above_mbmi)) {
-    if (above_mbmi->ref_frame[0] == BWDREF_FRAME)
-      ++brf_count;
-    else if (above_mbmi->ref_frame[0] == ALTREF2_FRAME)
-      ++arf2_count;
-    if (has_second_ref(above_mbmi)) {
-      if (above_mbmi->ref_frame[1] == BWDREF_FRAME)
-        ++brf_count;
-      else if (above_mbmi->ref_frame[1] == ALTREF2_FRAME)
-        ++arf2_count;
-    }
-  }
-
-  if (left_in_image && is_inter_block(left_mbmi)) {
-    if (left_mbmi->ref_frame[0] == BWDREF_FRAME)
-      ++brf_count;
-    else if (left_mbmi->ref_frame[0] == ALTREF2_FRAME)
-      ++arf2_count;
-    if (has_second_ref(left_mbmi)) {
-      if (left_mbmi->ref_frame[1] == BWDREF_FRAME)
-        ++brf_count;
-      else if (left_mbmi->ref_frame[1] == ALTREF2_FRAME)
-        ++arf2_count;
-    }
-  }
+  const int arf2_count = nb_counts[ALTREF2_FRAME];
 
   const int pred_context =
       (brf_count == arf2_count) ? 1 : ((brf_count < arf2_count) ? 0 : 2);
