[jnt-comp, normative] Avoid double-rounding in prediction

As per the linked bug report, the distance-weighted compound
prediction has two separate round operations, first by 3
bits (inside the various convolve functions), then by 10 bits
(after the convolution functions).

We can improve on this by replacing the first round with a plain
right shift by 3 bits inside the convolve functions - combined with
the later 10-bit round, this is equivalent to doing a single round
by 13 bits at the end.

Note: In the encoder, when doing joint_motion_search(), we do
things a bit differently: so that we can try modifying the two
"sides" of the prediction independently, we predict each side as
if it were a single prediction (including rounding), then blend
these single predictions together.

This is already an approximation to the "real" prediction, even
in the non-jnt-comp case. So we leave that code path as-is.

BUG=aomedia:1289

Change-Id: I9ad1fbcb3e12db2b5fc3c82b407f0fd9e6b39750
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 7de1074..4b8e39a 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -702,9 +702,7 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -742,9 +740,7 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -818,9 +814,7 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -985,9 +979,7 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -1060,9 +1052,7 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index f36b8be..63fbc27 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -516,7 +516,7 @@
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
                 *p += sum * conv_params->bck_offset;
-                *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
+                *p >>= (DIST_PRECISION_BITS - 1);
               } else {
                 *p = sum * conv_params->fwd_offset;
               }
@@ -820,7 +820,7 @@
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
                 *p += sum * conv_params->bck_offset;
-                *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
+                *p >>= (DIST_PRECISION_BITS - 1);
               } else {
                 *p = sum * conv_params->fwd_offset;
               }
diff --git a/av1/common/x86/av1_convolve_scale_sse4.c b/av1/common/x86/av1_convolve_scale_sse4.c
index 3959b06..1c27200 100644
--- a/av1/common/x86/av1_convolve_scale_sse4.c
+++ b/av1/common/x86/av1_convolve_scale_sse4.c
@@ -263,7 +263,6 @@
 #if CONFIG_JNT_COMP
   const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
   const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
-  const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
 #endif  // CONFIG_JNT_COMP
 
   int y_qn = subpel_y_qn;
@@ -315,11 +314,10 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result = _mm_srai_epi32(
-              _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                          _mm_mullo_epi32(subbed, bck_offset)),
-                            jnt_round),
-              DIST_PRECISION_BITS - 1);
+          result =
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+                                           _mm_mullo_epi32(subbed, bck_offset)),
+                             DIST_PRECISION_BITS - 1);
         } else {
           result = _mm_mullo_epi32(subbed, fwd_offset);
         }
@@ -347,8 +345,7 @@
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
 
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -385,7 +382,6 @@
 #if CONFIG_JNT_COMP
   const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
   const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
-  const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
 #endif  // CONFIG_JNT_COMP
 
   int y_qn = subpel_y_qn;
@@ -434,11 +430,10 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result = _mm_srai_epi32(
-              _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                          _mm_mullo_epi32(subbed, bck_offset)),
-                            jnt_round),
-              DIST_PRECISION_BITS - 1);
+          result =
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+                                           _mm_mullo_epi32(subbed, bck_offset)),
+                             DIST_PRECISION_BITS - 1);
         } else {
           result = _mm_mullo_epi32(subbed, fwd_offset);
         }
@@ -466,8 +461,7 @@
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
 
-          dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
-                                                       DIST_PRECISION_BITS - 1);
+          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index c327735..111bc85 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -676,8 +676,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set1_epi32(w0);
   const __m128i wt1 = _mm_set1_epi32(w1);
-  const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
-  const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
 
   if (!(w % 16)) {
     for (i = 0; i < h; ++i) {
@@ -697,26 +695,22 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
 
             mul = _mm_mullo_epi16(d32_2, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
-            d32_2 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
 
             mul = _mm_mullo_epi16(d32_3, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
-            d32_3 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
             d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -763,14 +757,12 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
             d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -806,8 +798,7 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
           }
@@ -838,8 +829,7 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
-            d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
-                                   DIST_PRECISION_BITS - 1);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
           }
diff --git a/av1/common/x86/convolve_2d_sse4.c b/av1/common/x86/convolve_2d_sse4.c
index 8e5e286..123350a 100644
--- a/av1/common/x86/convolve_2d_sse4.c
+++ b/av1/common/x86/convolve_2d_sse4.c
@@ -47,9 +47,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set_epi32(w0, w0, w0, w0);
   const __m128i wt1 = _mm_set_epi32(w1, w1, w1, w1);
-  const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
-  const __m128i jnt_r = _mm_set_epi32(jnt_round_const, jnt_round_const,
-                                      jnt_round_const, jnt_round_const);
 
   /* Horizontal filter */
   {
@@ -207,18 +204,14 @@
           if (do_average) {
             _mm_storeu_si128(
                 p + 0, _mm_srai_epi32(
-                           _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
-                                                       _mm_mullo_epi32(
-                                                           res_lo_round, wt1)),
-                                         jnt_r),
+                           _mm_add_epi32(_mm_loadu_si128(p + 0),
+                                         _mm_mullo_epi32(res_lo_round, wt1)),
                            DIST_PRECISION_BITS - 1));
 
             _mm_storeu_si128(
                 p + 1, _mm_srai_epi32(
-                           _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
-                                                       _mm_mullo_epi32(
-                                                           res_hi_round, wt1)),
-                                         jnt_r),
+                           _mm_add_epi32(_mm_loadu_si128(p + 1),
+                                         _mm_mullo_epi32(res_hi_round, wt1)),
                            DIST_PRECISION_BITS - 1));
           } else {
             _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
diff --git a/av1/common/x86/highbd_convolve_2d_sse4.c b/av1/common/x86/highbd_convolve_2d_sse4.c
index e23e0d0..853ecfd 100644
--- a/av1/common/x86/highbd_convolve_2d_sse4.c
+++ b/av1/common/x86/highbd_convolve_2d_sse4.c
@@ -39,8 +39,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set1_epi32(w0);
   const __m128i wt1 = _mm_set1_epi32(w1);
-  const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
-  const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
 
   // Check that, even with 12-bit input, the intermediate values will fit
   // into an unsigned 15-bit intermediate array.
@@ -202,12 +200,10 @@
                 _mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
             const __m128i jnt_sum_hi = _mm_add_epi32(
                 _mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
-            const __m128i jnt_round_res_lo = _mm_add_epi32(jnt_sum_lo, jnt_r);
-            const __m128i jnt_round_res_hi = _mm_add_epi32(jnt_sum_hi, jnt_r);
             const __m128i final_lo =
-                _mm_srai_epi32(jnt_round_res_lo, DIST_PRECISION_BITS - 1);
+                _mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS - 1);
             const __m128i final_hi =
-                _mm_srai_epi32(jnt_round_res_hi, DIST_PRECISION_BITS - 1);
+                _mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS - 1);
 
             _mm_storeu_si128(p + 0, final_lo);
             _mm_storeu_si128(p + 1, final_hi);
diff --git a/av1/common/x86/highbd_warp_plane_sse4.c b/av1/common/x86/highbd_warp_plane_sse4.c
index d40a969..0d4d17f 100644
--- a/av1/common/x86/highbd_warp_plane_sse4.c
+++ b/av1/common/x86/highbd_warp_plane_sse4.c
@@ -42,8 +42,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set1_epi32(w0);
   const __m128i wt1 = _mm_set1_epi32(w1);
-  const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
-  const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
 #endif  // CONFIG_JNT_COMP
 
   /* Note: For this code to work, the left/right frame borders need to be
@@ -320,8 +318,7 @@
             if (comp_avg) {
               const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
                                                 _mm_mullo_epi32(res_lo, wt1));
-              const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
-              res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+              res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
             } else {
               res_lo = _mm_mullo_epi32(res_lo, wt0);
             }
@@ -345,8 +342,7 @@
               if (comp_avg) {
                 const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
                                                   _mm_mullo_epi32(res_hi, wt1));
-                const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
-                res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+                res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
               } else {
                 res_hi = _mm_mullo_epi32(res_hi, wt0);
               }
diff --git a/av1/common/x86/warp_plane_sse4.c b/av1/common/x86/warp_plane_sse4.c
index e0d6206..94b5f09 100644
--- a/av1/common/x86/warp_plane_sse4.c
+++ b/av1/common/x86/warp_plane_sse4.c
@@ -39,8 +39,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set1_epi32(w0);
   const __m128i wt1 = _mm_set1_epi32(w1);
-  const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
-  const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
 #endif  // CONFIG_JNT_COMP
 
   /* Note: For this code to work, the left/right frame borders need to be
@@ -317,8 +315,7 @@
             if (comp_avg) {
               const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
                                                 _mm_mullo_epi32(res_lo, wt1));
-              const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
-              res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+              res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
             } else {
               res_lo = _mm_mullo_epi32(res_lo, wt0);
             }
@@ -340,8 +337,7 @@
               if (comp_avg) {
                 const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
                                                   _mm_mullo_epi32(res_hi, wt1));
-                const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
-                res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
+                res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
               } else {
                 res_hi = _mm_mullo_epi32(res_hi, wt0);
               }