Compute compound average in warp_plane only for COMPOUND_AVERAGE

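Previously the warp functions averaged the new prediction into the
destination buffer whenever 'ref' was nonzero. With CONFIG_EXT_INTER
the averaging is now requested only when the block's
interinter_compound_type is COMPOUND_AVERAGE. This fixes a mismatch
which occurs when global/warped motion and a masked compound type are
used together.

The 'ref_frm' parameter of the warp functions is renamed to 'comp_avg'
to reflect what it actually controls.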

Change-Id: I08b2702cdb3b85f8d8817b9286a73951c97cf379
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index ec682cf..f6fa6be 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -608,11 +608,11 @@
 
 if ((aom_config("CONFIG_WARPED_MOTION") eq "yes") ||
     (aom_config("CONFIG_GLOBAL_MOTION") eq "yes")) {
-  add_proto qw/void av1_warp_affine/, "const int32_t *mat, const uint8_t *ref, int width, int height, int stride, uint8_t *pred, int p_col, int p_row, int p_width, int p_height, int p_stride, int subsampling_x, int subsampling_y, int ref_frm, int16_t alpha, int16_t beta, int16_t gamma, int16_t delta";
+  add_proto qw/void av1_warp_affine/, "const int32_t *mat, const uint8_t *ref, int width, int height, int stride, uint8_t *pred, int p_col, int p_row, int p_width, int p_height, int p_stride, int subsampling_x, int subsampling_y, int comp_avg, int16_t alpha, int16_t beta, int16_t gamma, int16_t delta";
   specialize qw/av1_warp_affine sse2 ssse3/;
 
   if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
-    add_proto qw/void av1_highbd_warp_affine/, "const int32_t *mat, const uint16_t *ref, int width, int height, int stride, uint16_t *pred, int p_col, int p_row, int p_width, int p_height, int p_stride, int subsampling_x, int subsampling_y, int bd, int ref_frm, int16_t alpha, int16_t beta, int16_t gamma, int16_t delta";
+    add_proto qw/void av1_highbd_warp_affine/, "const int32_t *mat, const uint16_t *ref, int width, int height, int stride, uint16_t *pred, int p_col, int p_row, int p_width, int p_height, int p_stride, int subsampling_x, int subsampling_y, int bd, int comp_avg, int16_t alpha, int16_t beta, int16_t gamma, int16_t delta";
     specialize qw/av1_highbd_warp_affine ssse3/;
   }
 }
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 10933a7..a70c13a 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -415,13 +415,19 @@
   if (do_warp) {
     const struct macroblockd_plane *const pd = &xd->plane[plane];
     const struct buf_2d *const pre_buf = &pd->pre[ref];
+#if CONFIG_EXT_INTER
+    int compute_avg =
+        ref && mi->mbmi.interinter_compound_type == COMPOUND_AVERAGE;
+#else
+    int compute_avg = ref;
+#endif  // CONFIG_EXT_INTER
     av1_warp_plane(&final_warp_params,
 #if CONFIG_HIGHBITDEPTH
                    xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH, xd->bd,
 #endif  // CONFIG_HIGHBITDEPTH
                    pre_buf->buf0, pre_buf->width, pre_buf->height,
                    pre_buf->stride, dst, p_col, p_row, w, h, dst_stride,
-                   pd->subsampling_x, pd->subsampling_y, xs, ys, ref);
+                   pd->subsampling_x, pd->subsampling_y, xs, ys, compute_avg);
     return;
   }
 #endif  // CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index 7e27985..1b95293 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -909,7 +909,7 @@
                                   int p_width, int p_height, int p_stride,
                                   int subsampling_x, int subsampling_y,
                                   int x_scale, int y_scale, int bd,
-                                  int ref_frm) {
+                                  int comp_avg) {
   int i, j;
   ProjectPointsFunc projectpoints = get_project_points_type(wm->wmtype);
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
@@ -923,7 +923,7 @@
       projectpoints(wm->wmmat, in, out, 1, 2, 2, subsampling_x, subsampling_y);
       out[0] = ROUND_POWER_OF_TWO_SIGNED(out[0] * x_scale, 4);
       out[1] = ROUND_POWER_OF_TWO_SIGNED(out[1] * y_scale, 4);
-      if (ref_frm)
+      if (comp_avg)
         pred[(j - p_col) + (i - p_row) * p_stride] = ROUND_POWER_OF_TWO(
             pred[(j - p_col) + (i - p_row) * p_stride] +
                 highbd_warp_interpolate(ref, out[0], out[1], width, height,
@@ -953,7 +953,7 @@
                               int width, int height, int stride, uint16_t *pred,
                               int p_col, int p_row, int p_width, int p_height,
                               int p_stride, int subsampling_x,
-                              int subsampling_y, int bd, int ref_frm,
+                              int subsampling_y, int bd, int comp_avg,
                               int16_t alpha, int16_t beta, int16_t gamma,
                               int16_t delta) {
 #if HORSHEAR_REDUCE_PREC_BITS >= 5
@@ -1059,7 +1059,7 @@
           }
           sum = clip_pixel_highbd(
               ROUND_POWER_OF_TWO(sum, VERSHEAR_REDUCE_PREC_BITS), bd);
-          if (ref_frm)
+          if (comp_avg)
             *p = ROUND_POWER_OF_TWO(*p + sum, 1);
           else
             *p = sum;
@@ -1075,7 +1075,7 @@
                               int p_row, int p_width, int p_height,
                               int p_stride, int subsampling_x,
                               int subsampling_y, int x_scale, int y_scale,
-                              int bd, int ref_frm) {
+                              int bd, int comp_avg) {
   if (wm->wmtype == ROTZOOM) {
     wm->wmmat[5] = wm->wmmat[2];
     wm->wmmat[4] = -wm->wmmat[3];
@@ -1092,12 +1092,12 @@
     uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
     av1_highbd_warp_affine(mat, ref, width, height, stride, pred, p_col, p_row,
                            p_width, p_height, p_stride, subsampling_x,
-                           subsampling_y, bd, ref_frm, alpha, beta, gamma,
+                           subsampling_y, bd, comp_avg, alpha, beta, gamma,
                            delta);
   } else {
     highbd_warp_plane_old(wm, ref8, width, height, stride, pred8, p_col, p_row,
                           p_width, p_height, p_stride, subsampling_x,
-                          subsampling_y, x_scale, y_scale, bd, ref_frm);
+                          subsampling_y, x_scale, y_scale, bd, comp_avg);
   }
 }
 
@@ -1138,7 +1138,7 @@
                            int height, int stride, uint8_t *pred, int p_col,
                            int p_row, int p_width, int p_height, int p_stride,
                            int subsampling_x, int subsampling_y, int x_scale,
-                           int y_scale, int ref_frm) {
+                           int y_scale, int comp_avg) {
   int i, j;
   ProjectPointsFunc projectpoints = get_project_points_type(wm->wmtype);
   if (projectpoints == NULL) return;
@@ -1150,7 +1150,7 @@
       projectpoints(wm->wmmat, in, out, 1, 2, 2, subsampling_x, subsampling_y);
       out[0] = ROUND_POWER_OF_TWO_SIGNED(out[0] * x_scale, 4);
       out[1] = ROUND_POWER_OF_TWO_SIGNED(out[1] * y_scale, 4);
-      if (ref_frm)
+      if (comp_avg)
         pred[(j - p_col) + (i - p_row) * p_stride] = ROUND_POWER_OF_TWO(
             pred[(j - p_col) + (i - p_row) * p_stride] +
                 warp_interpolate(ref, out[0], out[1], width, height, stride),
@@ -1202,7 +1202,7 @@
 void av1_warp_affine_c(const int32_t *mat, const uint8_t *ref, int width,
                        int height, int stride, uint8_t *pred, int p_col,
                        int p_row, int p_width, int p_height, int p_stride,
-                       int subsampling_x, int subsampling_y, int ref_frm,
+                       int subsampling_x, int subsampling_y, int comp_avg,
                        int16_t alpha, int16_t beta, int16_t gamma,
                        int16_t delta) {
   int16_t tmp[15 * 8];
@@ -1316,7 +1316,7 @@
             sum += tmp[(k + m + 4) * 8 + (l + 4)] * coeffs[m];
           }
           sum = clip_pixel(ROUND_POWER_OF_TWO(sum, VERSHEAR_REDUCE_PREC_BITS));
-          if (ref_frm)
+          if (comp_avg)
             *p = ROUND_POWER_OF_TWO(*p + sum, 1);
           else
             *p = sum;
@@ -1331,7 +1331,7 @@
                        int height, int stride, uint8_t *pred, int p_col,
                        int p_row, int p_width, int p_height, int p_stride,
                        int subsampling_x, int subsampling_y, int x_scale,
-                       int y_scale, int ref_frm) {
+                       int y_scale, int comp_avg) {
   if (wm->wmtype == ROTZOOM) {
     wm->wmmat[5] = wm->wmmat[2];
     wm->wmmat[4] = -wm->wmmat[3];
@@ -1346,11 +1346,11 @@
 
     av1_warp_affine(mat, ref, width, height, stride, pred, p_col, p_row,
                     p_width, p_height, p_stride, subsampling_x, subsampling_y,
-                    ref_frm, alpha, beta, gamma, delta);
+                    comp_avg, alpha, beta, gamma, delta);
   } else {
     warp_plane_old(wm, ref, width, height, stride, pred, p_col, p_row, p_width,
                    p_height, p_stride, subsampling_x, subsampling_y, x_scale,
-                   y_scale, ref_frm);
+                   y_scale, comp_avg);
   }
 }
 
@@ -1409,17 +1409,17 @@
                     uint8_t *ref, int width, int height, int stride,
                     uint8_t *pred, int p_col, int p_row, int p_width,
                     int p_height, int p_stride, int subsampling_x,
-                    int subsampling_y, int x_scale, int y_scale, int ref_frm) {
+                    int subsampling_y, int x_scale, int y_scale, int comp_avg) {
 #if CONFIG_HIGHBITDEPTH
   if (use_hbd)
     highbd_warp_plane(wm, ref, width, height, stride, pred, p_col, p_row,
                       p_width, p_height, p_stride, subsampling_x, subsampling_y,
-                      x_scale, y_scale, bd, ref_frm);
+                      x_scale, y_scale, bd, comp_avg);
   else
 #endif  // CONFIG_HIGHBITDEPTH
     warp_plane(wm, ref, width, height, stride, pred, p_col, p_row, p_width,
                p_height, p_stride, subsampling_x, subsampling_y, x_scale,
-               y_scale, ref_frm);
+               y_scale, comp_avg);
 }
 
 #if CONFIG_WARPED_MOTION
diff --git a/av1/common/warped_motion.h b/av1/common/warped_motion.h
index dfd8dae..7c011ad 100644
--- a/av1/common/warped_motion.h
+++ b/av1/common/warped_motion.h
@@ -87,7 +87,7 @@
                     uint8_t *ref, int width, int height, int stride,
                     uint8_t *pred, int p_col, int p_row, int p_width,
                     int p_height, int p_stride, int subsampling_x,
-                    int subsampling_y, int x_scale, int y_scale, int ref_frm);
+                    int subsampling_y, int x_scale, int y_scale, int comp_avg);
 
 int find_projection(int np, int *pts1, int *pts2, BLOCK_SIZE bsize, int mvy,
                     int mvx, WarpedMotionParams *wm_params, int mi_row,
diff --git a/av1/common/x86/highbd_warp_plane_ssse3.c b/av1/common/x86/highbd_warp_plane_ssse3.c
index 4762340..18f120f 100644
--- a/av1/common/x86/highbd_warp_plane_ssse3.c
+++ b/av1/common/x86/highbd_warp_plane_ssse3.c
@@ -20,7 +20,7 @@
                                   uint16_t *pred, int p_col, int p_row,
                                   int p_width, int p_height, int p_stride,
                                   int subsampling_x, int subsampling_y, int bd,
-                                  int ref_frm, int16_t alpha, int16_t beta,
+                                  int comp_avg, int16_t alpha, int16_t beta,
                                   int16_t gamma, int16_t delta) {
 #if HORSHEAR_REDUCE_PREC_BITS >= 5
   __m128i tmp[15];
@@ -304,10 +304,12 @@
         // to only output 4 pixels at this point, to avoid encode/decode
         // mismatches when encoding with multiple threads.
         if (p_width == 4) {
-          if (ref_frm) res_16bit = _mm_avg_epu16(res_16bit, _mm_loadl_epi64(p));
+          if (comp_avg)
+            res_16bit = _mm_avg_epu16(res_16bit, _mm_loadl_epi64(p));
           _mm_storel_epi64(p, res_16bit);
         } else {
-          if (ref_frm) res_16bit = _mm_avg_epu16(res_16bit, _mm_loadu_si128(p));
+          if (comp_avg)
+            res_16bit = _mm_avg_epu16(res_16bit, _mm_loadu_si128(p));
           _mm_storeu_si128(p, res_16bit);
         }
       }
diff --git a/av1/common/x86/warp_plane_sse2.c b/av1/common/x86/warp_plane_sse2.c
index 81145b6..055a1f6 100644
--- a/av1/common/x86/warp_plane_sse2.c
+++ b/av1/common/x86/warp_plane_sse2.c
@@ -18,7 +18,7 @@
 void av1_warp_affine_sse2(const int32_t *mat, const uint8_t *ref, int width,
                           int height, int stride, uint8_t *pred, int p_col,
                           int p_row, int p_width, int p_height, int p_stride,
-                          int subsampling_x, int subsampling_y, int ref_frm,
+                          int subsampling_x, int subsampling_y, int comp_avg,
                           int16_t alpha, int16_t beta, int16_t gamma,
                           int16_t delta) {
   __m128i tmp[15];
@@ -296,13 +296,13 @@
         // to only output 4 pixels at this point, to avoid encode/decode
         // mismatches when encoding with multiple threads.
         if (p_width == 4) {
-          if (ref_frm) {
+          if (comp_avg) {
             const __m128i orig = _mm_cvtsi32_si128(*(uint32_t *)p);
             res_8bit = _mm_avg_epu8(res_8bit, orig);
           }
           *(uint32_t *)p = _mm_cvtsi128_si32(res_8bit);
         } else {
-          if (ref_frm) res_8bit = _mm_avg_epu8(res_8bit, _mm_loadl_epi64(p));
+          if (comp_avg) res_8bit = _mm_avg_epu8(res_8bit, _mm_loadl_epi64(p));
           _mm_storel_epi64(p, res_8bit);
         }
       }
diff --git a/av1/common/x86/warp_plane_ssse3.c b/av1/common/x86/warp_plane_ssse3.c
index f6cc2d6..1cca425 100644
--- a/av1/common/x86/warp_plane_ssse3.c
+++ b/av1/common/x86/warp_plane_ssse3.c
@@ -205,7 +205,7 @@
 void av1_warp_affine_ssse3(const int32_t *mat, const uint8_t *ref, int width,
                            int height, int stride, uint8_t *pred, int p_col,
                            int p_row, int p_width, int p_height, int p_stride,
-                           int subsampling_x, int subsampling_y, int ref_frm,
+                           int subsampling_x, int subsampling_y, int comp_avg,
                            int16_t alpha, int16_t beta, int16_t gamma,
                            int16_t delta) {
   __m128i tmp[15];
@@ -473,13 +473,13 @@
         // to only output 4 pixels at this point, to avoid encode/decode
         // mismatches when encoding with multiple threads.
         if (p_width == 4) {
-          if (ref_frm) {
+          if (comp_avg) {
             const __m128i orig = _mm_cvtsi32_si128(*(uint32_t *)p);
             res_8bit = _mm_avg_epu8(res_8bit, orig);
           }
           *(uint32_t *)p = _mm_cvtsi128_si32(res_8bit);
         } else {
-          if (ref_frm) res_8bit = _mm_avg_epu8(res_8bit, _mm_loadl_epi64(p));
+          if (comp_avg) res_8bit = _mm_avg_epu8(res_8bit, _mm_loadl_epi64(p));
           _mm_storel_epi64(p, res_8bit);
         }
       }