[NORMATIVE jnt_comp] remove double rounding

Previously, the jnt_comp weighted-average paths shifted the accumulated
result right by (DIST_PRECISION_BITS - 1) inside each convolve/warp
kernel, and the result was then rounded again in the post-rounding
stage, producing a double rounding. Remove the in-kernel shift from all
implementations (C, SSE2/SSE4 convolve, warp_plane, and their highbd
variants) and instead fold DIST_PRECISION_BITS - 1 into the single
post-rounding shift (round_bits) in reconinter.c when
use_jnt_comp_avg is set.
Change-Id: Ib325e33bee8aa3a8445a7f61c55adfd3fb210792
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 7645653..95458c1 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -703,7 +703,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -746,7 +745,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -788,7 +786,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -826,7 +823,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -900,7 +896,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -1065,7 +1060,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -1138,7 +1132,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index c8c58ce..bd3493f 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -1106,17 +1106,19 @@
                 xs, ys, xd);
         }  // for (ref = 0; ref < 1 + is_compound; ++ref)
         if (conv_params.do_post_rounding) {
+          int round_bits = FILTER_BITS * 2 + is_compound - conv_params.round_0 -
+                           conv_params.round_1;
+#if CONFIG_JNT_COMP
+          if (conv_params.use_jnt_comp_avg)
+            round_bits += DIST_PRECISION_BITS - 1;
+#endif  // CONFIG_JNT_COMP
           if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-            av1_highbd_convolve_rounding(
-                tmp_dst, tmp_dst_stride, dst, dst_buf->stride, b4_w, b4_h,
-                FILTER_BITS * 2 + is_compound - conv_params.round_0 -
-                    conv_params.round_1,
-                xd->bd);
+            av1_highbd_convolve_rounding(tmp_dst, tmp_dst_stride, dst,
+                                         dst_buf->stride, b4_w, b4_h,
+                                         round_bits, xd->bd);
           else
-            av1_convolve_rounding(
-                tmp_dst, tmp_dst_stride, dst, dst_buf->stride, b4_w, b4_h,
-                FILTER_BITS * 2 + is_compound - conv_params.round_0 -
-                    conv_params.round_1);
+            av1_convolve_rounding(tmp_dst, tmp_dst_stride, dst, dst_buf->stride,
+                                  b4_w, b4_h, round_bits);
         }
         ++col;
       }
@@ -1250,16 +1252,17 @@
 
     // TODO(angiebird): This part needs optimization
     if (conv_params.do_post_rounding) {
+      int round_bits = FILTER_BITS * 2 + is_compound - conv_params.round_0 -
+                       conv_params.round_1;
+#if CONFIG_JNT_COMP
+      if (conv_params.use_jnt_comp_avg) round_bits += DIST_PRECISION_BITS - 1;
+#endif  // CONFIG_JNT_COMP
       if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-        av1_highbd_convolve_rounding(
-            tmp_dst, MAX_SB_SIZE, dst, dst_buf->stride, w, h,
-            FILTER_BITS * 2 + is_compound - conv_params.round_0 -
-                conv_params.round_1,
-            xd->bd);
+        av1_highbd_convolve_rounding(tmp_dst, MAX_SB_SIZE, dst, dst_buf->stride,
+                                     w, h, round_bits, xd->bd);
       else
         av1_convolve_rounding(tmp_dst, MAX_SB_SIZE, dst, dst_buf->stride, w, h,
-                              FILTER_BITS * 2 + is_compound -
-                                  conv_params.round_0 - conv_params.round_1);
+                              round_bits);
     }
   }
 }
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index cf41301..3b3eb8c 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -516,7 +516,6 @@
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
                 *p += sum * conv_params->bck_offset;
-                *p >>= (DIST_PRECISION_BITS - 1);
               } else {
                 *p = sum * conv_params->fwd_offset;
               }
@@ -820,7 +819,6 @@
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
                 *p += sum * conv_params->bck_offset;
-                *p >>= (DIST_PRECISION_BITS - 1);
               } else {
                 *p = sum * conv_params->fwd_offset;
               }
diff --git a/av1/common/x86/av1_convolve_scale_sse4.c b/av1/common/x86/av1_convolve_scale_sse4.c
index 7323403..e53717d 100644
--- a/av1/common/x86/av1_convolve_scale_sse4.c
+++ b/av1/common/x86/av1_convolve_scale_sse4.c
@@ -314,10 +314,8 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result =
-              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                           _mm_mullo_epi32(subbed, bck_offset)),
-                             DIST_PRECISION_BITS - 1);
+          result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+                                 _mm_mullo_epi32(subbed, bck_offset));
         } else {
           result = _mm_mullo_epi32(subbed, fwd_offset);
         }
@@ -344,8 +342,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
@@ -430,10 +426,8 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result =
-              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                           _mm_mullo_epi32(subbed, bck_offset)),
-                             DIST_PRECISION_BITS - 1);
+          result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
+                                 _mm_mullo_epi32(subbed, bck_offset));
         } else {
           result = _mm_mullo_epi32(subbed, fwd_offset);
         }
@@ -460,8 +454,6 @@
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
           dst[y * dst_stride + x] += res * conv_params->bck_offset;
-
-          dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
         } else {
           dst[y * dst_stride + x] = res * conv_params->fwd_offset;
         }
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index 111bc85..e1b731a 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -695,22 +695,22 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_0 = sum;
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_1 = sum;
 
             mul = _mm_mullo_epi16(d32_2, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
-            d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_2 = sum;
 
             mul = _mm_mullo_epi16(d32_3, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
-            d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_3 = sum;
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
             d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -757,12 +757,12 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_0 = sum;
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
             sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_1 = sum;
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
             d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
@@ -798,7 +798,7 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_0 = sum;
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
           }
@@ -829,7 +829,7 @@
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
             __m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
-            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+            d32_0 = sum;
           } else {
             d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
           }
diff --git a/av1/common/x86/highbd_convolve_2d_sse4.c b/av1/common/x86/highbd_convolve_2d_sse4.c
index 853ecfd..80adce9 100644
--- a/av1/common/x86/highbd_convolve_2d_sse4.c
+++ b/av1/common/x86/highbd_convolve_2d_sse4.c
@@ -200,13 +200,9 @@
                 _mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
             const __m128i jnt_sum_hi = _mm_add_epi32(
                 _mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
-            const __m128i final_lo =
-                _mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS - 1);
-            const __m128i final_hi =
-                _mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS - 1);
 
-            _mm_storeu_si128(p + 0, final_lo);
-            _mm_storeu_si128(p + 1, final_hi);
+            _mm_storeu_si128(p + 0, jnt_sum_lo);
+            _mm_storeu_si128(p + 1, jnt_sum_hi);
           } else {
             _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
             _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
diff --git a/av1/common/x86/highbd_warp_plane_sse4.c b/av1/common/x86/highbd_warp_plane_sse4.c
index 0d4d17f..0cd438a 100644
--- a/av1/common/x86/highbd_warp_plane_sse4.c
+++ b/av1/common/x86/highbd_warp_plane_sse4.c
@@ -318,7 +318,7 @@
             if (comp_avg) {
               const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
                                                 _mm_mullo_epi32(res_lo, wt1));
-              res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+              res_lo = sum;
             } else {
               res_lo = _mm_mullo_epi32(res_lo, wt0);
             }
@@ -342,7 +342,7 @@
               if (comp_avg) {
                 const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
                                                   _mm_mullo_epi32(res_hi, wt1));
-                res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+                res_hi = sum;
               } else {
                 res_hi = _mm_mullo_epi32(res_hi, wt0);
               }
diff --git a/av1/common/x86/jnt_convolve_sse4.c b/av1/common/x86/jnt_convolve_sse4.c
index 9de2744..bc23365 100644
--- a/av1/common/x86/jnt_convolve_sse4.c
+++ b/av1/common/x86/jnt_convolve_sse4.c
@@ -88,12 +88,11 @@
 static INLINE void mult_add_store(CONV_BUF_TYPE *const dst,
                                   const __m128i *const res,
                                   const __m128i *const avg_mask,
-                                  const __m128i *const wt, int shift) {
+                                  const __m128i *const wt) {
   __m128i d;
   d = _mm_load_si128((__m128i *)dst);
   d = _mm_and_si128(d, *avg_mask);
   d = _mm_add_epi32(d, _mm_mullo_epi32(*res, *wt));
-  if (shift) d = _mm_srai_epi32(d, DIST_PRECISION_BITS - 1);
   _mm_store_si128((__m128i *)dst, d);
 }
 
@@ -157,8 +156,7 @@
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_shift, &avg_mask, &wt,
-                       conv_params->do_average);
+        mult_add_store(dst, &res_shift, &avg_mask, &wt);
       else
         add_store(dst, &res_shift, &avg_mask);
       src_ptr += src_stride;
@@ -169,8 +167,7 @@
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_shift, &avg_mask, &wt,
-                       conv_params->do_average);
+        mult_add_store(dst, &res_shift, &avg_mask, &wt);
       else
         add_store(dst, &res_shift, &avg_mask);
       src_ptr += src_stride;
@@ -229,9 +226,9 @@
                                      round_shift);
         if (conv_params->use_jnt_comp_avg) {
           mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
           mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
         } else {
           add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
           add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
@@ -248,9 +245,9 @@
                                      round_shift);
         if (conv_params->use_jnt_comp_avg) {
           mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
           mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
         } else {
           add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
           add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
@@ -317,8 +314,7 @@
 
       // Accumulate values into the destination buffer
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_lo_shift, &avg_mask, &wt,
-                       conv_params->do_average);
+        mult_add_store(dst, &res_lo_shift, &avg_mask, &wt);
       else
         add_store(dst, &res_lo_shift, &avg_mask);
       src_ptr += src_stride;
@@ -361,9 +357,9 @@
         // Accumulate values into the destination buffer
         if (conv_params->use_jnt_comp_avg) {
           mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
           mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt, conv_params->do_average);
+                         &wt);
         } else {
           add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
           add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
@@ -557,17 +553,12 @@
           // original c function at: av1/common/convolve.c: av1_convolve_2d_c
           __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
           if (do_average) {
-            _mm_storeu_si128(
-                p + 0, _mm_srai_epi32(
-                           _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                         _mm_mullo_epi32(res_lo_round, wt1)),
-                           DIST_PRECISION_BITS - 1));
-
-            _mm_storeu_si128(
-                p + 1, _mm_srai_epi32(
-                           _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                         _mm_mullo_epi32(res_hi_round, wt1)),
-                           DIST_PRECISION_BITS - 1));
+            _mm_storeu_si128(p + 0,
+                             _mm_add_epi32(_mm_loadu_si128(p + 0),
+                                           _mm_mullo_epi32(res_lo_round, wt1)));
+            _mm_storeu_si128(p + 1,
+                             _mm_add_epi32(_mm_loadu_si128(p + 1),
+                                           _mm_mullo_epi32(res_hi_round, wt1)));
           } else {
             _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
             _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
diff --git a/av1/common/x86/warp_plane_sse4.c b/av1/common/x86/warp_plane_sse4.c
index 94b5f09..043e6a7 100644
--- a/av1/common/x86/warp_plane_sse4.c
+++ b/av1/common/x86/warp_plane_sse4.c
@@ -315,7 +315,7 @@
             if (comp_avg) {
               const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
                                                 _mm_mullo_epi32(res_lo, wt1));
-              res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+              res_lo = sum;
             } else {
               res_lo = _mm_mullo_epi32(res_lo, wt0);
             }
@@ -337,7 +337,7 @@
               if (comp_avg) {
                 const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
                                                   _mm_mullo_epi32(res_hi, wt1));
-                res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
+                res_hi = sum;
               } else {
                 res_hi = _mm_mullo_epi32(res_hi, wt0);
               }