Refactor iteration over neighbours for OBMC

There are six pieces of code in reconinter.c and two in rdopt.c which
iterate over the blocks along the top or left edge of the current
block for OBMC. Before this patch, each of them had its own
implementation of the iteration, which is fiddly to get right.

This patch factors that logic out into a pair of helpers
(foreach_overlappable_nb_above and foreach_overlappable_nb_left). The
helpers take a "fun" parameter, a callback containing the loop body,
together with a "fun_ctxt" pointer that is passed through to it. Note
that the iteration is too complicated for us to define a macro that
could be used like

  FOREACH_NB_ABOVE(rel_pos, nb_size, nb_mi) { ... }

Although C's syntax doesn't let us write the loop that way, once the
compiler's optimisation passes have inlined everything, the results
are essentially the same.

The iteration logic is also slightly generalised: the old code checked
whether a block was shorter or narrower than 8 pixels by comparing its
block size with BLOCK_8X8. This doesn't work for 4x16 or 16x4 blocks:
for example, BLOCK_16X4 is only 4 pixels high, but it is not less than
BLOCK_8X8. This generalisation is (unsurprisingly) needed in order to
support 16x4 and 4x16 blocks.

This patch doesn't address the CONFIG_NCOBMC functions in reconinter.c
that do prediction from right and bottom edges.

This patch shouldn't affect the generated bitstream in any way: the
code is supposed to be equivalent.

Change-Id: I9e5a116b012c18645604a7d98fb98be99697d363
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 272ac57..b63c24c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -27,6 +27,7 @@
 #include "av1/common/entropymode.h"
 #include "av1/common/idct.h"
 #include "av1/common/mvref_common.h"
+#include "av1/common/obmc.h"
 #include "av1/common/pred_common.h"
 #include "av1/common/quant_common.h"
 #include "av1/common/reconinter.h"
@@ -12286,6 +12287,124 @@
 }
 
 #if CONFIG_MOTION_VAR
+// State passed as fun_ctxt to calc_target_weighted_pred_{above,left}.
+struct calc_target_weighted_pred_ctxt {
+  const MACROBLOCK *x;  // supplies wsrc_buf and mask_buf (both written)
+  const uint8_t *tmp;   // the caller's above or left pixels for this edge
+  int tmp_stride;       // row stride of tmp, in samples
+  int overlap;          // overlap depth: rows (above) or columns (left)
+};
+// Above-edge OBMC callback: seeds wsrc_buf/mask_buf from one above neighbour.
+static INLINE void calc_target_weighted_pred_above(MACROBLOCKD *xd,
+                                                   int rel_mi_col,
+                                                   uint8_t nb_mi_width,
+                                                   MODE_INFO *nb_mi,
+                                                   void *fun_ctxt) {
+  (void)nb_mi;  // only the neighbour's position and width are used
+  // fun_ctxt is the calc_target_weighted_pred_ctxt set up by the caller.
+  struct calc_target_weighted_pred_ctxt *ctxt =
+      (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
+  // is_hbd selects the 8-bit or the 16-bit sample path below.
+#if CONFIG_HIGHBITDEPTH
+  const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
+#else
+  const int is_hbd = 0;
+#endif  // CONFIG_HIGHBITDEPTH
+  // wsrc_buf and mask_buf are row-major with a row pitch of bw entries.
+  const int bw = xd->n8_w << MI_SIZE_LOG2;
+  const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
+  // Start at this neighbour's columns; it spans nb_mi_width * MI_SIZE.
+  int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_col * MI_SIZE);
+  int32_t *mask = ctxt->x->mask_buf + (rel_mi_col * MI_SIZE);
+  const uint8_t *tmp = ctxt->tmp + rel_mi_col * MI_SIZE;
+  // Overlap row r stores m1 * tmp in wsrc and m0 in mask (m0 = mask1d[r]).
+  if (!is_hbd) {
+    for (int row = 0; row < ctxt->overlap; ++row) {
+      const uint8_t m0 = mask1d[row];
+      const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
+      for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
+        wsrc[col] = m1 * tmp[col];
+        mask[col] = m0;
+      }
+      wsrc += bw;
+      mask += bw;
+      tmp += ctxt->tmp_stride;
+    }
+#if CONFIG_HIGHBITDEPTH
+  } else {
+    const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
+    // Same as the 8-bit path; tmp_stride counts 16-bit samples here.
+    for (int row = 0; row < ctxt->overlap; ++row) {
+      const uint8_t m0 = mask1d[row];
+      const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
+      for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
+        wsrc[col] = m1 * tmp16[col];
+        mask[col] = m0;
+      }
+      wsrc += bw;
+      mask += bw;
+      tmp16 += ctxt->tmp_stride;
+    }
+#endif  // CONFIG_HIGHBITDEPTH
+  }
+}
+// Left-edge OBMC callback: folds one left neighbour into wsrc_buf/mask_buf.
+static INLINE void calc_target_weighted_pred_left(MACROBLOCKD *xd,
+                                                  int rel_mi_row,
+                                                  uint8_t nb_mi_height,
+                                                  MODE_INFO *nb_mi,
+                                                  void *fun_ctxt) {
+  (void)nb_mi;  // only the neighbour's position and height are used
+  // fun_ctxt is the calc_target_weighted_pred_ctxt set up by the caller.
+  struct calc_target_weighted_pred_ctxt *ctxt =
+      (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
+  // is_hbd selects the 8-bit or the 16-bit sample path below.
+#if CONFIG_HIGHBITDEPTH
+  const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
+#else
+  const int is_hbd = 0;
+#endif  // CONFIG_HIGHBITDEPTH
+  // wsrc_buf and mask_buf are row-major with a row pitch of bw entries.
+  const int bw = xd->n8_w << MI_SIZE_LOG2;
+  const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
+  // Start at this neighbour's rows; it spans nb_mi_height * MI_SIZE.
+  int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_row * MI_SIZE * bw);
+  int32_t *mask = ctxt->x->mask_buf + (rel_mi_row * MI_SIZE * bw);
+  const uint8_t *tmp = ctxt->tmp + (rel_mi_row * MI_SIZE * ctxt->tmp_stride);
+  // Column c keeps the prior entry scaled by m0 and adds tmp scaled by m1.
+  if (!is_hbd) {
+    for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
+      for (int col = 0; col < ctxt->overlap; ++col) {
+        const uint8_t m0 = mask1d[col];
+        const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
+        wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
+                    (tmp[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
+        mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
+      }
+      wsrc += bw;
+      mask += bw;
+      tmp += ctxt->tmp_stride;
+    }
+#if CONFIG_HIGHBITDEPTH
+  } else {
+    const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
+    // Same as the 8-bit path; tmp_stride counts 16-bit samples here.
+    for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
+      for (int col = 0; col < ctxt->overlap; ++col) {
+        const uint8_t m0 = mask1d[col];
+        const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
+        wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
+                    (tmp16[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
+        mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
+      }
+      wsrc += bw;
+      mask += bw;
+      tmp16 += ctxt->tmp_stride;
+    }
+#endif  // CONFIG_HIGHBITDEPTH
+  }
+}
+
 // This function has a structure similar to av1_build_obmc_inter_prediction
 //
 // The OBMC predictor is computed as:
@@ -12330,13 +12449,11 @@
                                       int above_stride, const uint8_t *left,
                                       int left_stride) {
   const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
-  int row, col, i;
   const int bw = xd->n8_w << MI_SIZE_LOG2;
   const int bh = xd->n8_h << MI_SIZE_LOG2;
   int32_t *mask_buf = x->mask_buf;
   int32_t *wsrc_buf = x->wsrc_buf;
-  const int wsrc_stride = bw;
-  const int mask_stride = bw;
+
   const int src_scale = AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA;
 #if CONFIG_HIGHBITDEPTH
   const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
@@ -12349,86 +12466,20 @@
   assert(xd->plane[0].subsampling_y == 0);
 
   av1_zero_array(wsrc_buf, bw * bh);
-  for (i = 0; i < bw * bh; ++i) mask_buf[i] = AOM_BLEND_A64_MAX_ALPHA;
+  for (int i = 0; i < bw * bh; ++i) mask_buf[i] = AOM_BLEND_A64_MAX_ALPHA;
 
   // handle above row
   if (xd->up_available) {
     const int overlap =
-        AOMMIN(block_size_high[bsize] >> 1, block_size_high[BLOCK_64X64] >> 1);
-    const int miw = AOMMIN(xd->n8_w, cm->mi_cols - mi_col);
-    const int mi_row_offset = -1;
-    const uint8_t *const mask1d = av1_get_obmc_mask(overlap);
-    const int neighbor_limit = max_neighbor_obmc[b_width_log2_lookup[bsize]];
-    int neighbor_count = 0;
-
-    assert(miw > 0);
-
-    i = 0;
-    do {  // for each mi in the above row
-      const int mi_col_offset = i;
-      const MB_MODE_INFO *above_mbmi =
-          &xd->mi[mi_col_offset + mi_row_offset * xd->mi_stride]->mbmi;
-#if CONFIG_CHROMA_SUB8X8
-      if (above_mbmi->sb_type < BLOCK_8X8)
-        above_mbmi =
-            &xd->mi[mi_col_offset + 1 + mi_row_offset * xd->mi_stride]->mbmi;
-#endif
-      const BLOCK_SIZE a_bsize = AOMMAX(above_mbmi->sb_type, BLOCK_8X8);
-      const int above_step =
-          AOMMIN(mi_size_wide[a_bsize], mi_size_wide[BLOCK_64X64]);
-      const int mi_step = AOMMIN(xd->n8_w, above_step);
-      const int neighbor_bw = mi_step * MI_SIZE;
-
-      if (is_neighbor_overlappable(above_mbmi)) {
-        if (!CONFIG_CB4X4 && (a_bsize == BLOCK_4X4 || a_bsize == BLOCK_4X8))
-          neighbor_count += 2;
-        else
-          neighbor_count++;
-        if (neighbor_count > neighbor_limit) break;
-
-        const int tmp_stride = above_stride;
-        int32_t *wsrc = wsrc_buf + (i * MI_SIZE);
-        int32_t *mask = mask_buf + (i * MI_SIZE);
-
-        if (!is_hbd) {
-          const uint8_t *tmp = above;
-
-          for (row = 0; row < overlap; ++row) {
-            const uint8_t m0 = mask1d[row];
-            const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
-            for (col = 0; col < neighbor_bw; ++col) {
-              wsrc[col] = m1 * tmp[col];
-              mask[col] = m0;
-            }
-            wsrc += wsrc_stride;
-            mask += mask_stride;
-            tmp += tmp_stride;
-          }
-#if CONFIG_HIGHBITDEPTH
-        } else {
-          const uint16_t *tmp = CONVERT_TO_SHORTPTR(above);
-
-          for (row = 0; row < overlap; ++row) {
-            const uint8_t m0 = mask1d[row];
-            const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
-            for (col = 0; col < neighbor_bw; ++col) {
-              wsrc[col] = m1 * tmp[col];
-              mask[col] = m0;
-            }
-            wsrc += wsrc_stride;
-            mask += mask_stride;
-            tmp += tmp_stride;
-          }
-#endif  // CONFIG_HIGHBITDEPTH
-        }
-      }
-
-      above += neighbor_bw;
-      i += mi_step;
-    } while (i < miw);
+        AOMMIN(block_size_high[bsize], block_size_high[BLOCK_64X64]) >> 1;
+    struct calc_target_weighted_pred_ctxt ctxt = { x, above, above_stride,
+                                                   overlap };
+    foreach_overlappable_nb_above(cm, (MACROBLOCKD *)xd, mi_col,
+                                  max_neighbor_obmc[b_width_log2_lookup[bsize]],
+                                  calc_target_weighted_pred_above, &ctxt);
   }
 
-  for (i = 0; i < bw * bh; ++i) {
+  for (int i = 0; i < bw * bh; ++i) {
     wsrc_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
     mask_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
   }
@@ -12436,102 +12487,33 @@
   // handle left column
   if (xd->left_available) {
     const int overlap =
-        AOMMIN(block_size_wide[bsize] >> 1, block_size_wide[BLOCK_64X64] >> 1);
-    const int mih = AOMMIN(xd->n8_h, cm->mi_rows - mi_row);
-    const int mi_col_offset = -1;
-    const uint8_t *const mask1d = av1_get_obmc_mask(overlap);
-    const int neighbor_limit = max_neighbor_obmc[b_height_log2_lookup[bsize]];
-    int neighbor_count = 0;
-
-    assert(mih > 0);
-
-    i = 0;
-    do {  // for each mi in the left column
-      const int mi_row_offset = i;
-      MB_MODE_INFO *left_mbmi =
-          &xd->mi[mi_col_offset + mi_row_offset * xd->mi_stride]->mbmi;
-
-#if CONFIG_CHROMA_SUB8X8
-      if (left_mbmi->sb_type < BLOCK_8X8)
-        left_mbmi =
-            &xd->mi[mi_col_offset + (mi_row_offset + 1) * xd->mi_stride]->mbmi;
-#endif
-      const BLOCK_SIZE l_bsize = AOMMAX(left_mbmi->sb_type, BLOCK_8X8);
-      const int left_step =
-          AOMMIN(mi_size_high[l_bsize], mi_size_high[BLOCK_64X64]);
-      const int mi_step = AOMMIN(xd->n8_h, left_step);
-      const int neighbor_bh = mi_step * MI_SIZE;
-
-      if (is_neighbor_overlappable(left_mbmi)) {
-        if (!CONFIG_CB4X4 && (l_bsize == BLOCK_4X4 || l_bsize == BLOCK_8X4))
-          neighbor_count += 2;
-        else
-          neighbor_count++;
-        if (neighbor_count > neighbor_limit) break;
-
-        const int tmp_stride = left_stride;
-        int32_t *wsrc = wsrc_buf + (i * MI_SIZE * wsrc_stride);
-        int32_t *mask = mask_buf + (i * MI_SIZE * mask_stride);
-
-        if (!is_hbd) {
-          const uint8_t *tmp = left;
-
-          for (row = 0; row < neighbor_bh; ++row) {
-            for (col = 0; col < overlap; ++col) {
-              const uint8_t m0 = mask1d[col];
-              const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
-              wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
-                          (tmp[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
-              mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
-            }
-            wsrc += wsrc_stride;
-            mask += mask_stride;
-            tmp += tmp_stride;
-          }
-#if CONFIG_HIGHBITDEPTH
-        } else {
-          const uint16_t *tmp = CONVERT_TO_SHORTPTR(left);
-
-          for (row = 0; row < neighbor_bh; ++row) {
-            for (col = 0; col < overlap; ++col) {
-              const uint8_t m0 = mask1d[col];
-              const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
-              wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
-                          (tmp[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
-              mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
-            }
-            wsrc += wsrc_stride;
-            mask += mask_stride;
-            tmp += tmp_stride;
-          }
-#endif  // CONFIG_HIGHBITDEPTH
-        }
-      }
-
-      left += neighbor_bh * left_stride;
-      i += mi_step;
-    } while (i < mih);
+        AOMMIN(block_size_wide[bsize], block_size_wide[BLOCK_64X64]) >> 1;
+    struct calc_target_weighted_pred_ctxt ctxt = { x, left, left_stride,
+                                                   overlap };
+    foreach_overlappable_nb_left(cm, (MACROBLOCKD *)xd, mi_row,
+                                 max_neighbor_obmc[b_height_log2_lookup[bsize]],
+                                 calc_target_weighted_pred_left, &ctxt);
   }
 
   if (!is_hbd) {
     const uint8_t *src = x->plane[0].src.buf;
 
-    for (row = 0; row < bh; ++row) {
-      for (col = 0; col < bw; ++col) {
+    for (int row = 0; row < bh; ++row) {
+      for (int col = 0; col < bw; ++col) {
         wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
       }
-      wsrc_buf += wsrc_stride;
+      wsrc_buf += bw;
       src += x->plane[0].src.stride;
     }
 #if CONFIG_HIGHBITDEPTH
   } else {
     const uint16_t *src = CONVERT_TO_SHORTPTR(x->plane[0].src.buf);
 
-    for (row = 0; row < bh; ++row) {
-      for (col = 0; col < bw; ++col) {
+    for (int row = 0; row < bh; ++row) {
+      for (int col = 0; col < bw; ++col) {
         wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
       }
-      wsrc_buf += wsrc_stride;
+      wsrc_buf += bw;
       src += x->plane[0].src.stride;
     }
 #endif  // CONFIG_HIGHBITDEPTH