Round compound prediction to 16 bits

Apply the shifts for compound prediction immediately when the sum
or weighted sum (for jnt_comp) is computed, so that all
intermediate results fit into 16 bits.

Note: the buffer itself is still 32 bits for now. We need new SIMD
functions for 16 bits before the buffer can finally be reduced to
16 bits.
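
In C terms, the rounding pattern applied throughout the patch looks
like the sketch below (illustration only; accum stands for the value
already in the compound buffer, res for the newly computed prediction,
and dst for the stored output):

    /* Plain compound average: round the sum right away. */
    int32_t avg = accum + res;
    dst = avg >> 1;

    /* jnt_comp weighted average: apply the distance weights and
     * drop DIST_PRECISION_BITS immediately. */
    int32_t jnt = accum * fwd_offset + res * bck_offset;
    dst = jnt >> DIST_PRECISION_BITS;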

Change-Id: Ia46a4736d69aa028623dfb9f036a6ce527e5cd9f
diff --git a/aom_dsp/x86/convolve_avx2.h b/aom_dsp/x86/convolve_avx2.h
index ec5868e..3ca424c 100644
--- a/aom_dsp/x86/convolve_avx2.h
+++ b/aom_dsp/x86/convolve_avx2.h
@@ -122,11 +122,12 @@
 
 static INLINE void add_store_aligned(CONV_BUF_TYPE *const dst,
                                      const __m256i *const res,
-                                     const __m256i *const avg_mask) {
+                                     const __m256i *const avg_mask, int shift) {
   __m256i d;
   d = _mm256_load_si256((__m256i *)dst);
   d = _mm256_and_si256(d, *avg_mask);
   d = _mm256_add_epi32(d, *res);
+  if (shift) d = _mm256_srai_epi32(d, 1);
   _mm256_store_si256((__m256i *)dst, d);
 }
 
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index fef179c..ba9b96f 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -424,10 +424,13 @@
       CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
                           ((1 << (offset_bits - conv_params->round_1)) +
                            (1 << (offset_bits - conv_params->round_1 - 1)));
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
     }
   }
 }
@@ -460,10 +463,13 @@
       }
       res *= (1 << bits);
       res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
     }
   }
 }
@@ -495,10 +501,13 @@
         res += x_filter[k] * src[y * src_stride + x - fo_horiz + k];
       }
       res = (1 << bits) * ROUND_POWER_OF_TWO(res, conv_params->round_0);
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
     }
   }
 }
@@ -524,10 +533,13 @@
   for (int y = 0; y < h; ++y) {
     for (int x = 0; x < w; ++x) {
       CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
     }
   }
 }
@@ -714,15 +726,20 @@
                            (1 << (offset_bits - conv_params->round_1 - 1)));
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
     }
   }
@@ -756,15 +773,20 @@
       res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
     }
   }
@@ -797,15 +819,20 @@
       res = (1 << bits) * ROUND_POWER_OF_TWO(res, conv_params->round_0);
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
     }
   }
@@ -834,15 +861,20 @@
       CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
     }
   }
@@ -907,21 +939,29 @@
 #if CONFIG_JNT_COMP
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
 #else
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
 #endif  // CONFIG_JNT_COMP
     }
     src_vert++;
@@ -1044,10 +1084,13 @@
       CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
                           ((1 << (offset_bits - conv_params->round_1)) +
                            (1 << (offset_bits - conv_params->round_1 - 1)));
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
     }
   }
 }
@@ -1339,15 +1382,20 @@
                            (1 << (offset_bits - conv_params->round_1 - 1)));
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
     }
   }
@@ -1533,21 +1581,29 @@
 #if CONFIG_JNT_COMP
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
       }
 #else
-      if (conv_params->do_average)
-        dst[y * dst_stride + x] += res;
-      else
+      if (conv_params->do_average) {
+        int32_t tmp = dst[y * dst_stride + x];
+        tmp += res;
+        dst[y * dst_stride + x] = tmp >> 1;
+      } else {
         dst[y * dst_stride + x] = res;
+      }
 #endif  // CONFIG_JNT_COMP
     }
     src_vert++;
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 18e316a..d975aa8 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -30,17 +30,9 @@
 // prediction.
 
 static INLINE int get_compound_post_rounding_bits(
-    const MB_MODE_INFO *const mbmi, const ConvolveParams *conv_params) {
+    const ConvolveParams *conv_params) {
   assert(conv_params->is_compound);
-  int round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
-  if (is_masked_compound_type(mbmi->interinter_compound_type))
-    return round_bits;
-  round_bits += conv_params->is_compound;
-#if CONFIG_JNT_COMP
-  if (conv_params->use_jnt_comp_avg) round_bits += DIST_PRECISION_BITS - 1;
-#endif  // CONFIG_JNT_COMP
-  return round_bits;
+  return 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
 }
 
 static INLINE int allow_warp(const MODE_INFO *const mi,
@@ -1105,8 +1097,7 @@
 
         if (conv_params.is_compound) {
           assert(conv_params.dst != NULL);
-          int round_bits =
-              get_compound_post_rounding_bits(&mi->mbmi, &conv_params);
+          int round_bits = get_compound_post_rounding_bits(&conv_params);
           if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
             av1_highbd_convolve_rounding(tmp_dst, tmp_dst_stride, dst,
                                          dst_buf->stride, b4_w, b4_h,
@@ -1242,7 +1233,7 @@
     // TODO(angiebird): This part needs optimization
     if (conv_params.is_compound) {
       assert(conv_params.dst != NULL);
-      int round_bits = get_compound_post_rounding_bits(&mi->mbmi, &conv_params);
+      int round_bits = get_compound_post_rounding_bits(&conv_params);
       if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
         av1_highbd_convolve_rounding(tmp_dst, MAX_SB_SIZE, dst, dst_buf->stride,
                                      w, h, round_bits, xd->bd);
diff --git a/av1/common/warped_motion.c b/av1/common/warped_motion.c
index fc4d8da..fa6f7ec 100644
--- a/av1/common/warped_motion.c
+++ b/av1/common/warped_motion.c
@@ -508,21 +508,30 @@
 #if CONFIG_JNT_COMP
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
-                *p += sum * conv_params->bck_offset;
+                int32_t tmp32 = *p;
+                tmp32 = tmp32 * conv_params->fwd_offset +
+                        sum * conv_params->bck_offset;
+                *p = tmp32 >> DIST_PRECISION_BITS;
               } else {
-                *p = sum * conv_params->fwd_offset;
+                *p = sum;
               }
             } else {
-              if (conv_params->do_average)
-                *p += sum;
-              else
+              if (conv_params->do_average) {
+                int32_t tmp32 = *p;
+                tmp32 += sum;
+                *p = tmp32 >> 1;
+              } else {
                 *p = sum;
+              }
             }
 #else
-            if (conv_params->do_average)
-              *p += sum;
-            else
+            if (conv_params->do_average) {
+              int32_t tmp32 = *p;
+              tmp32 += sum;
+              *p = tmp32 >> 1;
+            } else {
               *p = sum;
+            }
 #endif  // CONFIG_JNT_COMP
           } else {
             uint16_t *p =
@@ -802,21 +811,30 @@
 #if CONFIG_JNT_COMP
             if (conv_params->use_jnt_comp_avg) {
               if (conv_params->do_average) {
-                *p += sum * conv_params->bck_offset;
+                int32_t tmp32 = *p;
+                tmp32 = tmp32 * conv_params->fwd_offset +
+                        sum * conv_params->bck_offset;
+                *p = tmp32 >> DIST_PRECISION_BITS;
               } else {
-                *p = sum * conv_params->fwd_offset;
+                *p = sum;
               }
             } else {
-              if (conv_params->do_average)
-                *p += sum;
-              else
+              if (conv_params->do_average) {
+                int32_t tmp32 = *p;
+                tmp32 += sum;
+                *p = tmp32 >> 1;
+              } else {
                 *p = sum;
+              }
             }
 #else
-            if (conv_params->do_average)
-              *p += sum;
-            else
+            if (conv_params->do_average) {
+              int32_t tmp32 = *p;
+              tmp32 += sum;
+              *p = tmp32 >> 1;
+            } else {
               *p = sum;
+            }
 #endif  // CONFIG_JNT_COMP
           } else {
             uint8_t *p =
diff --git a/av1/common/x86/av1_convolve_scale_sse4.c b/av1/common/x86/av1_convolve_scale_sse4.c
index e53717d..973dbe3 100644
--- a/av1/common/x86/av1_convolve_scale_sse4.c
+++ b/av1/common/x86/av1_convolve_scale_sse4.c
@@ -314,20 +314,26 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                 _mm_mullo_epi32(subbed, bck_offset));
+          __m128i tmp = _mm_loadu_si128((__m128i *)dst_x);
+          tmp = _mm_add_epi32(_mm_mullo_epi32(tmp, fwd_offset),
+                              _mm_mullo_epi32(subbed, bck_offset));
+          result = _mm_srai_epi32(tmp, DIST_PRECISION_BITS);
         } else {
-          result = _mm_mullo_epi32(subbed, fwd_offset);
+          result = subbed;
         }
       } else {
-        result = (conv_params->do_average)
-                     ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
-                     : subbed;
+        result =
+            (conv_params->do_average)
+                ? _mm_srai_epi32(
+                      _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)),
+                      1)
+                : subbed;
       }
 #else
       const __m128i result =
           (conv_params->do_average)
-              ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
+              ? _mm_srai_epi32(
+                    _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)), 1)
               : subbed;
 #endif  // CONFIG_JNT_COMP
 
@@ -341,16 +347,21 @@
 #if CONFIG_JNT_COMP
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
 #endif  // CONFIG_JNT_COMP
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
 #if CONFIG_JNT_COMP
       }
 #endif  // CONFIG_JNT_COMP
@@ -426,20 +437,26 @@
       __m128i result;
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          result = _mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
-                                 _mm_mullo_epi32(subbed, bck_offset));
+          __m128i tmp = _mm_loadu_si128((__m128i *)dst_x);
+          tmp = _mm_add_epi32(_mm_mullo_epi32(tmp, fwd_offset),
+                              _mm_mullo_epi32(subbed, bck_offset));
+          result = _mm_srai_epi32(tmp, DIST_PRECISION_BITS);
         } else {
-          result = _mm_mullo_epi32(subbed, fwd_offset);
+          result = subbed;
         }
       } else {
-        result = (conv_params->do_average)
-                     ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
-                     : subbed;
+        result =
+            (conv_params->do_average)
+                ? _mm_srai_epi32(
+                      _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)),
+                      1)
+                : subbed;
       }
 #else
       const __m128i result =
           (conv_params->do_average)
-              ? _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x))
+              ? _mm_srai_epi32(
+                    _mm_add_epi32(subbed, _mm_loadu_si128((__m128i *)dst_x)), 1)
               : subbed;
 #endif  // CONFIG_JNT_COMP
 
@@ -453,16 +470,21 @@
 #if CONFIG_JNT_COMP
       if (conv_params->use_jnt_comp_avg) {
         if (conv_params->do_average) {
-          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
+          dst[y * dst_stride + x] = tmp >> DIST_PRECISION_BITS;
         } else {
-          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+          dst[y * dst_stride + x] = res;
         }
       } else {
 #endif  // CONFIG_JNT_COMP
-        if (conv_params->do_average)
-          dst[y * dst_stride + x] += res;
-        else
+        if (conv_params->do_average) {
+          int32_t tmp = dst[y * dst_stride + x];
+          tmp += res;
+          dst[y * dst_stride + x] = tmp >> 1;
+        } else {
           dst[y * dst_stride + x] = res;
+        }
 #if CONFIG_JNT_COMP
       }
 #endif  // CONFIG_JNT_COMP
diff --git a/av1/common/x86/convolve_2d_avx2.c b/av1/common/x86/convolve_2d_avx2.c
index fafe344..6407c3a 100644
--- a/av1/common/x86/convolve_2d_avx2.c
+++ b/av1/common/x86/convolve_2d_avx2.c
@@ -126,9 +126,10 @@
           const __m256i res_bx =
               _mm256_permute2x128_si256(res_a_round, res_b_round, 0x31);
 
-          add_store_aligned(&dst[i * dst_stride + j], &res_ax, &avg_mask);
+          add_store_aligned(&dst[i * dst_stride + j], &res_ax, &avg_mask,
+                            conv_params->do_average);
           add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_bx,
-                            &avg_mask);
+                            &avg_mask, conv_params->do_average);
         } else {
           const __m128i res_ax = _mm256_extracti128_si256(res_a_round, 0);
           const __m128i res_bx = _mm256_extracti128_si256(res_a_round, 1);
@@ -140,6 +141,10 @@
           r1 = _mm_and_si128(r1, _mm256_extracti128_si256(avg_mask, 0));
           r0 = _mm_add_epi32(r0, res_ax);
           r1 = _mm_add_epi32(r1, res_bx);
+          if (conv_params->do_average) {
+            r0 = _mm_srai_epi32(r0, 1);
+            r1 = _mm_srai_epi32(r1, 1);
+          }
           _mm_store_si128((__m128i *)&dst[i * dst_stride + j], r0);
           _mm_store_si128((__m128i *)&dst[i * dst_stride + j + dst_stride], r1);
         }
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index 96a6042..941d195 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -193,10 +193,14 @@
         // Accumulate values into the destination buffer
         __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
         if (do_average) {
-          _mm_storeu_si128(p + 0,
-                           _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
-          _mm_storeu_si128(p + 1,
-                           _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+          _mm_storeu_si128(
+              p + 0,
+              _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
+          _mm_storeu_si128(
+              p + 1,
+              _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
         } else {
           _mm_storeu_si128(p + 0, res_lo_round);
           _mm_storeu_si128(p + 1, res_hi_round);
@@ -444,10 +448,18 @@
 
         __m128i *const p = (__m128i *)&dst[j];
         if (do_average) {
-          _mm_storeu_si128(p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0));
-          _mm_storeu_si128(p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1));
-          _mm_storeu_si128(p + 2, _mm_add_epi32(_mm_loadu_si128(p + 2), d32_2));
-          _mm_storeu_si128(p + 3, _mm_add_epi32(_mm_loadu_si128(p + 3), d32_3));
+          _mm_storeu_si128(
+              p + 0,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0), d32_0), 1));
+          _mm_storeu_si128(
+              p + 1,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1), d32_1), 1));
+          _mm_storeu_si128(
+              p + 2,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2), d32_2), 1));
+          _mm_storeu_si128(
+              p + 3,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3), d32_3), 1));
         } else {
           _mm_storeu_si128(p + 0, d32_0);
           _mm_storeu_si128(p + 1, d32_1);
@@ -471,8 +483,12 @@
 
         __m128i *const p = (__m128i *)&dst[j];
         if (do_average) {
-          _mm_storeu_si128(p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0));
-          _mm_storeu_si128(p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1));
+          _mm_storeu_si128(
+              p + 0,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0), d32_0), 1));
+          _mm_storeu_si128(
+              p + 1,
+              _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1), d32_1), 1));
         } else {
           _mm_storeu_si128(p + 0, d32_0);
           _mm_storeu_si128(p + 1, d32_1);
@@ -491,7 +507,8 @@
         d32_0 = _mm_sll_epi32(d32_0, left_shift);
         __m128i *const p = (__m128i *)&dst[j];
         if (do_average) {
-          _mm_storeu_si128(p, _mm_add_epi32(_mm_loadu_si128(p), d32_0));
+          _mm_storeu_si128(
+              p, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), d32_0), 1));
         } else {
           _mm_storeu_si128(p, d32_0);
         }
@@ -509,7 +526,8 @@
         d32_0 = _mm_sll_epi32(d32_0, left_shift);
         __m128i *const p = (__m128i *)&dst[j];
         if (do_average) {
-          _mm_storel_epi64(p, _mm_add_epi32(_mm_loadl_epi64(p), d32_0));
+          _mm_storel_epi64(
+              p, _mm_srai_epi32(_mm_add_epi32(_mm_loadl_epi64(p), d32_0), 1));
         } else {
           _mm_storel_epi64(p, d32_0);
         }
@@ -707,39 +725,52 @@
           if (do_average) {
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
-            __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = sum;
+            __m128i tmp = _mm_loadu_si128(p + 0);
+            __m128i sum =
+                _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
-            sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = sum;
+            tmp = _mm_loadu_si128(p + 1);
+            sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
 
             mul = _mm_mullo_epi16(d32_2, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
-            sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
-            d32_2 = sum;
+            tmp = _mm_loadu_si128(p + 2);
+            sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
 
             mul = _mm_mullo_epi16(d32_3, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
-            sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
-            d32_3 = sum;
+            tmp = _mm_loadu_si128(p + 3);
+            sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
           } else {
-            d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
-            d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
-            d32_2 = _mm_sll_epi32(_mm_mullo_epi16(d32_2, wt0), left_shift);
-            d32_3 = _mm_sll_epi32(_mm_mullo_epi16(d32_3, wt0), left_shift);
+            d32_0 = _mm_sll_epi32(d32_0, left_shift);
+            d32_1 = _mm_sll_epi32(d32_1, left_shift);
+            d32_2 = _mm_sll_epi32(d32_2, left_shift);
+            d32_3 = _mm_sll_epi32(d32_3, left_shift);
           }
         } else {
           if (do_average) {
-            d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                  _mm_sll_epi32(d32_0, left_shift));
-            d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                  _mm_sll_epi32(d32_1, left_shift));
-            d32_2 = _mm_add_epi32(_mm_loadu_si128(p + 2),
-                                  _mm_sll_epi32(d32_2, left_shift));
-            d32_3 = _mm_add_epi32(_mm_loadu_si128(p + 3),
-                                  _mm_sll_epi32(d32_3, left_shift));
+            d32_0 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+                                             _mm_sll_epi32(d32_0, left_shift)),
+                               1);
+            d32_1 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+                                             _mm_sll_epi32(d32_1, left_shift)),
+                               1);
+            d32_2 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2),
+                                             _mm_sll_epi32(d32_2, left_shift)),
+                               1);
+            d32_3 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3),
+                                             _mm_sll_epi32(d32_3, left_shift)),
+                               1);
           } else {
             d32_0 = _mm_sll_epi32(d32_0, left_shift);
             d32_1 = _mm_sll_epi32(d32_1, left_shift);
@@ -769,23 +800,30 @@
           if (do_average) {
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
-            __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = sum;
+            __m128i tmp = _mm_loadu_si128(p + 0);
+            __m128i sum =
+                _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
 
             mul = _mm_mullo_epi16(d32_1, wt1);
             weighted_res = _mm_sll_epi32(mul, left_shift);
-            sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
-            d32_1 = sum;
+            tmp = _mm_loadu_si128(p + 1);
+            sum = _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
           } else {
-            d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
-            d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
+            d32_0 = _mm_sll_epi32(d32_0, left_shift);
+            d32_1 = _mm_sll_epi32(d32_1, left_shift);
           }
         } else {
           if (do_average) {
-            d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                  _mm_sll_epi32(d32_0, left_shift));
-            d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                  _mm_sll_epi32(d32_1, left_shift));
+            d32_0 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+                                             _mm_sll_epi32(d32_0, left_shift)),
+                               1);
+            d32_1 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+                                             _mm_sll_epi32(d32_1, left_shift)),
+                               1);
           } else {
             d32_0 = _mm_sll_epi32(d32_0, left_shift);
             d32_1 = _mm_sll_epi32(d32_1, left_shift);
@@ -810,15 +848,19 @@
           if (do_average) {
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
-            __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
-            d32_0 = sum;
+            __m128i tmp = _mm_loadu_si128(p + 0);
+            __m128i sum =
+                _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
           } else {
-            d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
+            d32_0 = _mm_sll_epi32(d32_0, left_shift);
           }
         } else {
           if (do_average) {
-            d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                  _mm_sll_epi32(d32_0, left_shift));
+            d32_0 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+                                             _mm_sll_epi32(d32_0, left_shift)),
+                               1);
           } else {
             d32_0 = _mm_sll_epi32(d32_0, left_shift);
           }
@@ -841,15 +883,19 @@
           if (do_average) {
             __m128i mul = _mm_mullo_epi16(d32_0, wt1);
             __m128i weighted_res = _mm_sll_epi32(mul, left_shift);
-            __m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
-            d32_0 = sum;
+            __m128i tmp = _mm_loadl_epi64(p);
+            __m128i sum =
+                _mm_add_epi32(_mm_mullo_epi16(tmp, wt0), weighted_res);
+            d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
           } else {
-            d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
+            d32_0 = _mm_sll_epi32(d32_0, left_shift);
           }
         } else {
           if (do_average) {
-            d32_0 = _mm_add_epi32(_mm_loadl_epi64(p),
-                                  _mm_sll_epi32(d32_0, left_shift));
+            d32_0 =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadl_epi64(p),
+                                             _mm_sll_epi32(d32_0, left_shift)),
+                               1);
           } else {
             d32_0 = _mm_sll_epi32(d32_0, left_shift);
           }
diff --git a/av1/common/x86/convolve_avx2.c b/av1/common/x86/convolve_avx2.c
index 63f6138..ffc870b 100644
--- a/av1/common/x86/convolve_avx2.c
+++ b/av1/common/x86/convolve_avx2.c
@@ -452,7 +452,8 @@
           _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);
 
       // Accumulate values into the destination buffer
-      add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask);
+      add_store_aligned(&dst[i * dst_stride + j], &res_lo_0_round, &avg_mask,
+                        conv_params->do_average);
 
       const __m256i res_lo_1_32b =
           _mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_lo, 1));
@@ -462,7 +463,7 @@
           _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);
 
       add_store_aligned(&dst[i * dst_stride + j + dst_stride], &res_lo_1_round,
-                        &avg_mask);
+                        &avg_mask, conv_params->do_average);
 
       if (w - j > 8) {
         const __m256i res_hi = convolve_lowbd(s + 4, coeffs);
@@ -475,7 +476,7 @@
             _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);
 
         add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_0_round,
-                          &avg_mask);
+                          &avg_mask, conv_params->do_average);
 
         const __m256i res_hi_1_32b =
             _mm256_cvtepi16_epi32(_mm256_extracti128_si256(res_hi, 1));
@@ -485,7 +486,7 @@
             _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);
 
         add_store_aligned(&dst[i * dst_stride + j + 8 + dst_stride],
-                          &res_hi_1_round, &avg_mask);
+                          &res_hi_1_round, &avg_mask, conv_params->do_average);
       }
       s[0] = s[1];
       s[1] = s[2];
@@ -711,10 +712,11 @@
       const __m256i res_hi_shift = _mm256_slli_epi32(res_hi_round, bits);
 
       // Accumulate values into the destination buffer
-      add_store_aligned(&dst[i * dst_stride + j], &res_lo_shift, &avg_mask);
+      add_store_aligned(&dst[i * dst_stride + j], &res_lo_shift, &avg_mask,
+                        conv_params->do_average);
       if (w - j > 8) {
         add_store_aligned(&dst[i * dst_stride + j + 8], &res_hi_shift,
-                          &avg_mask);
+                          &avg_mask, conv_params->do_average);
       }
     }
   }
diff --git a/av1/common/x86/convolve_sse2.c b/av1/common/x86/convolve_sse2.c
index ab35226..a03f0ef 100644
--- a/av1/common/x86/convolve_sse2.c
+++ b/av1/common/x86/convolve_sse2.c
@@ -75,11 +75,12 @@
 }
 
 static INLINE void add_store(CONV_BUF_TYPE *const dst, const __m128i *const res,
-                             const __m128i *const avg_mask) {
+                             const __m128i *const avg_mask, int shift) {
   __m128i d;
   d = _mm_load_si128((__m128i *)dst);
   d = _mm_and_si128(d, *avg_mask);
   d = _mm_add_epi32(d, *res);
+  if (shift) d = _mm_srai_epi32(d, 1);
   _mm_store_si128((__m128i *)dst, d);
 }
 
@@ -141,7 +142,7 @@
       res_shift = _mm_sll_epi32(res, left_shift);
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
-      add_store(dst, &res_shift, &avg_mask);
+      add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
 
@@ -149,7 +150,7 @@
       res_shift = _mm_sll_epi32(res, left_shift);
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
-      add_store(dst, &res_shift, &avg_mask);
+      add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
 
@@ -204,8 +205,10 @@
                                      round_shift);
         res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
                                      round_shift);
-        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                  conv_params->do_average);
+        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                  conv_params->do_average);
         i++;
 
         res_lo = convolve_lo_y(s + 1, coeffs);  // Filter low index pixels
@@ -216,8 +219,10 @@
                                      round_shift);
         res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
                                      round_shift);
-        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                  conv_params->do_average);
+        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                  conv_params->do_average);
         i++;
 
         s[0] = s[2];
@@ -276,7 +281,7 @@
       const __m128i res_lo_shift = _mm_sll_epi32(res_lo_round, left_shift);
 
       // Accumulate values into the destination buffer
-      add_store(dst, &res_lo_shift, &avg_mask);
+      add_store(dst, &res_lo_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
     } while (--h);
@@ -315,8 +320,10 @@
         const __m128i res_hi_shift = _mm_sll_epi32(res_hi_round, left_shift);
 
         // Accumulate values into the destination buffer
-        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+        add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                  conv_params->do_average);
+        add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                  conv_params->do_average);
         j += 8;
       } while (j < w);
     } while (++i < h);
diff --git a/av1/common/x86/highbd_convolve_2d_avx2.c b/av1/common/x86/highbd_convolve_2d_avx2.c
index cb13e9b..73c2bea 100644
--- a/av1/common/x86/highbd_convolve_2d_avx2.c
+++ b/av1/common/x86/highbd_convolve_2d_avx2.c
@@ -374,18 +374,26 @@
         __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
         if (do_average) {
           _mm_storeu_si128(
-              p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                   _mm256_extractf128_si256(res_lo_round, 0)));
+              p + 0, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
+                                                  _mm256_extractf128_si256(
+                                                      res_lo_round, 0)),
+                                    1));
           _mm_storeu_si128(
-              p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                   _mm256_extractf128_si256(res_hi_round, 0)));
+              p + 1, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
+                                                  _mm256_extractf128_si256(
+                                                      res_hi_round, 0)),
+                                    1));
           if (w - j > 8) {
-            _mm_storeu_si128(p + 2, _mm_add_epi32(_mm_loadu_si128(p + 2),
-                                                  _mm256_extractf128_si256(
-                                                      res_lo_round, 1)));
-            _mm_storeu_si128(p + 3, _mm_add_epi32(_mm_loadu_si128(p + 3),
-                                                  _mm256_extractf128_si256(
-                                                      res_hi_round, 1)));
+            _mm_storeu_si128(
+                p + 2, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 2),
+                                                    _mm256_extractf128_si256(
+                                                        res_lo_round, 1)),
+                                      1));
+            _mm_storeu_si128(
+                p + 3, _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p + 3),
+                                                    _mm256_extractf128_si256(
+                                                        res_hi_round, 1)),
+                                      1));
           }
         } else {
           _mm_storeu_si128(p + 0, _mm256_extractf128_si256(res_lo_round, 0));
diff --git a/av1/common/x86/highbd_convolve_2d_sse4.c b/av1/common/x86/highbd_convolve_2d_sse4.c
index 979d1dd..6980960 100644
--- a/av1/common/x86/highbd_convolve_2d_sse4.c
+++ b/av1/common/x86/highbd_convolve_2d_sse4.c
@@ -201,23 +201,35 @@
         __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
         if (conv_params->use_jnt_comp_avg) {
           if (do_average) {
-            const __m128i jnt_sum_lo = _mm_add_epi32(
-                _mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
-            const __m128i jnt_sum_hi = _mm_add_epi32(
-                _mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
+            const __m128i tmp_lo = _mm_loadu_si128(p + 0);
+            const __m128i tmp_hi = _mm_loadu_si128(p + 1);
+            const __m128i jnt_sum_lo =
+                _mm_add_epi32(_mm_mullo_epi32(tmp_lo, wt0),
+                              _mm_mullo_epi32(res_lo_round, wt1));
+            const __m128i jnt_sum_hi =
+                _mm_add_epi32(_mm_mullo_epi32(tmp_hi, wt0),
+                              _mm_mullo_epi32(res_hi_round, wt1));
+            const __m128i final_lo =
+                _mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS);
+            const __m128i final_hi =
+                _mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS);
 
-            _mm_storeu_si128(p + 0, jnt_sum_lo);
-            _mm_storeu_si128(p + 1, jnt_sum_hi);
+            _mm_storeu_si128(p + 0, final_lo);
+            _mm_storeu_si128(p + 1, final_hi);
           } else {
-            _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
-            _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
+            _mm_storeu_si128(p + 0, res_lo_round);
+            _mm_storeu_si128(p + 1, res_hi_round);
           }
         } else {
           if (do_average) {
             _mm_storeu_si128(
-                p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
+                p + 0,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
             _mm_storeu_si128(
-                p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+                p + 1,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
           } else {
             _mm_storeu_si128(p + 0, res_lo_round);
             _mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/highbd_convolve_2d_ssse3.c b/av1/common/x86/highbd_convolve_2d_ssse3.c
index ee948f8..ce348ac 100644
--- a/av1/common/x86/highbd_convolve_2d_ssse3.c
+++ b/av1/common/x86/highbd_convolve_2d_ssse3.c
@@ -192,10 +192,14 @@
         // Accumulate values into the destination buffer
         __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
         if (do_average) {
-          _mm_storeu_si128(p + 0,
-                           _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
-          _mm_storeu_si128(p + 1,
-                           _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+          _mm_storeu_si128(
+              p + 0,
+              _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
+          _mm_storeu_si128(
+              p + 1,
+              _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
         } else {
           _mm_storeu_si128(p + 0, res_lo_round);
           _mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/highbd_warp_plane_sse4.c b/av1/common/x86/highbd_warp_plane_sse4.c
index 5647eb3..4ebd8a6 100644
--- a/av1/common/x86/highbd_warp_plane_sse4.c
+++ b/av1/common/x86/highbd_warp_plane_sse4.c
@@ -309,19 +309,22 @@
 #if CONFIG_JNT_COMP
           if (conv_params->use_jnt_comp_avg) {
             if (comp_avg) {
-              const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
-                                                _mm_mullo_epi32(res_lo, wt1));
-              res_lo = sum;
-            } else {
-              res_lo = _mm_mullo_epi32(res_lo, wt0);
+              const __m128i sum =
+                  _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p), wt0),
+                                _mm_mullo_epi32(res_lo, wt1));
+              res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
             }
           } else {
-            if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+            if (comp_avg)
+              res_lo =
+                  _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
           }
 
           _mm_storeu_si128(p, res_lo);
 #else
-          if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+          if (comp_avg)
+            res_lo =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
           _mm_storeu_si128(p, res_lo);
 #endif
 
@@ -332,21 +335,22 @@
 #if CONFIG_JNT_COMP
             if (conv_params->use_jnt_comp_avg) {
               if (comp_avg) {
-                const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                                  _mm_mullo_epi32(res_hi, wt1));
-                res_hi = sum;
-              } else {
-                res_hi = _mm_mullo_epi32(res_hi, wt0);
+                const __m128i sum =
+                    _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+                                  _mm_mullo_epi32(res_hi, wt1));
+                res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS);
               }
             } else {
               if (comp_avg)
-                res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+                res_hi = _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
             }
 
             _mm_storeu_si128(p + 1, res_hi);
 #else
             if (comp_avg)
-              res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+              res_hi = _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
             _mm_storeu_si128(p + 1, res_hi);
 #endif
           }
diff --git a/av1/common/x86/jnt_convolve_sse4.c b/av1/common/x86/jnt_convolve_sse4.c
index bc23365..54bef5a 100644
--- a/av1/common/x86/jnt_convolve_sse4.c
+++ b/av1/common/x86/jnt_convolve_sse4.c
@@ -76,23 +76,28 @@
 }
 
 static INLINE void add_store(CONV_BUF_TYPE *const dst, const __m128i *const res,
-                             const __m128i *const avg_mask) {
+                             const __m128i *const avg_mask, int shift) {
   __m128i d;
   d = _mm_load_si128((__m128i *)dst);
   d = _mm_and_si128(d, *avg_mask);
   d = _mm_add_epi32(d, *res);
+  if (shift) d = _mm_srai_epi32(d, 1);
   _mm_store_si128((__m128i *)dst, d);
 }
 
 #if CONFIG_JNT_COMP
 static INLINE void mult_add_store(CONV_BUF_TYPE *const dst,
                                   const __m128i *const res,
-                                  const __m128i *const avg_mask,
-                                  const __m128i *const wt) {
+                                  const __m128i *const wt0,
+                                  const __m128i *const wt1, int do_average) {
   __m128i d;
-  d = _mm_load_si128((__m128i *)dst);
-  d = _mm_and_si128(d, *avg_mask);
-  d = _mm_add_epi32(d, _mm_mullo_epi32(*res, *wt));
+  if (do_average) {
+    d = _mm_load_si128((__m128i *)dst);
+    d = _mm_add_epi32(_mm_mullo_epi32(d, *wt0), _mm_mullo_epi32(*res, *wt1));
+    d = _mm_srai_epi32(d, DIST_PRECISION_BITS);
+  } else {
+    d = *res;
+  }
   _mm_store_si128((__m128i *)dst, d);
 }
 
@@ -111,7 +116,6 @@
   const __m128i avg_mask = _mm_set1_epi32(conv_params->do_average ? -1 : 0);
   const __m128i wt0 = _mm_set1_epi32(conv_params->fwd_offset);
   const __m128i wt1 = _mm_set1_epi32(conv_params->bck_offset);
-  const __m128i wt = conv_params->do_average ? wt1 : wt0;
   const __m128i round_const = _mm_set1_epi32((1 << conv_params->round_1) >> 1);
   const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
   __m128i coeffs[4];
@@ -156,9 +160,9 @@
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_shift, &avg_mask, &wt);
+        mult_add_store(dst, &res_shift, &wt0, &wt1, conv_params->do_average);
       else
-        add_store(dst, &res_shift, &avg_mask);
+        add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
 
@@ -167,9 +171,9 @@
       res_shift =
           _mm_sra_epi32(_mm_add_epi32(res_shift, round_const), round_shift);
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_shift, &avg_mask, &wt);
+        mult_add_store(dst, &res_shift, &wt0, &wt1, conv_params->do_average);
       else
-        add_store(dst, &res_shift, &avg_mask);
+        add_store(dst, &res_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
 
@@ -225,13 +229,15 @@
         res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
                                      round_shift);
         if (conv_params->use_jnt_comp_avg) {
-          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt);
-          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt);
+          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+                         &wt1, conv_params->do_average);
+          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+                         &wt1, conv_params->do_average);
         } else {
-          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                    conv_params->do_average);
+          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                    conv_params->do_average);
         }
         i++;
 
@@ -244,13 +250,15 @@
         res_hi_shift = _mm_sra_epi32(_mm_add_epi32(res_hi_shift, round_const),
                                      round_shift);
         if (conv_params->use_jnt_comp_avg) {
-          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt);
-          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt);
+          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+                         &wt1, conv_params->do_average);
+          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+                         &wt1, conv_params->do_average);
         } else {
-          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                    conv_params->do_average);
+          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                    conv_params->do_average);
         }
         i++;
 
@@ -285,7 +293,6 @@
   const int w1 = conv_params->bck_offset;
   const __m128i wt0 = _mm_set1_epi32(w0);
   const __m128i wt1 = _mm_set1_epi32(w1);
-  const __m128i wt = conv_params->do_average ? wt1 : wt0;
   __m128i coeffs[4];
 
   (void)filter_params_y;
@@ -314,9 +321,9 @@
 
       // Accumulate values into the destination buffer
       if (conv_params->use_jnt_comp_avg)
-        mult_add_store(dst, &res_lo_shift, &avg_mask, &wt);
+        mult_add_store(dst, &res_lo_shift, &wt0, &wt1, conv_params->do_average);
       else
-        add_store(dst, &res_lo_shift, &avg_mask);
+        add_store(dst, &res_lo_shift, &avg_mask, conv_params->do_average);
       src_ptr += src_stride;
       dst += dst_stride;
     } while (--h);
@@ -356,13 +363,15 @@
 
         // Accumulate values into the destination buffer
         if (conv_params->use_jnt_comp_avg) {
-          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
-                         &wt);
-          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
-                         &wt);
+          mult_add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &wt0,
+                         &wt1, conv_params->do_average);
+          mult_add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &wt0,
+                         &wt1, conv_params->do_average);
         } else {
-          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask);
-          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask);
+          add_store(dst + i * dst_stride + j + 0, &res_lo_shift, &avg_mask,
+                    conv_params->do_average);
+          add_store(dst + i * dst_stride + j + 4, &res_hi_shift, &avg_mask,
+                    conv_params->do_average);
         }
         j += 8;
       } while (j < w);
@@ -553,24 +562,34 @@
           // original c function at: av1/common/convolve.c: av1_convolve_2d_c
           __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
           if (do_average) {
-            _mm_storeu_si128(p + 0,
-                             _mm_add_epi32(_mm_loadu_si128(p + 0),
-                                           _mm_mullo_epi32(res_lo_round, wt1)));
-            _mm_storeu_si128(p + 1,
-                             _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                           _mm_mullo_epi32(res_hi_round, wt1)));
+            _mm_storeu_si128(
+                p + 0,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 0), wt0),
+                                  _mm_mullo_epi32(res_lo_round, wt1)),
+                    DIST_PRECISION_BITS));
+            _mm_storeu_si128(
+                p + 1,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+                                  _mm_mullo_epi32(res_hi_round, wt1)),
+                    DIST_PRECISION_BITS));
           } else {
-            _mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
-            _mm_storeu_si128(p + 1, _mm_mullo_epi32(res_hi_round, wt0));
+            _mm_storeu_si128(p + 0, res_lo_round);
+            _mm_storeu_si128(p + 1, res_hi_round);
           }
         } else {
           // Accumulate values into the destination buffer
           __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
           if (do_average) {
             _mm_storeu_si128(
-                p + 0, _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round));
+                p + 0,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 0), res_lo_round), 1));
             _mm_storeu_si128(
-                p + 1, _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round));
+                p + 1,
+                _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi_round), 1));
           } else {
             _mm_storeu_si128(p + 0, res_lo_round);
             _mm_storeu_si128(p + 1, res_hi_round);
diff --git a/av1/common/x86/warp_plane_sse4.c b/av1/common/x86/warp_plane_sse4.c
index 1e8ad47..b05b3b8 100644
--- a/av1/common/x86/warp_plane_sse4.c
+++ b/av1/common/x86/warp_plane_sse4.c
@@ -484,18 +484,21 @@
 #if CONFIG_JNT_COMP
           if (conv_params->use_jnt_comp_avg) {
             if (comp_avg) {
-              res_lo = _mm_add_epi32(_mm_loadu_si128(p),
+              res_lo = _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p), wt0),
                                      _mm_mullo_epi32(res_lo, wt1));
-            } else {
-              res_lo = _mm_mullo_epi32(res_lo, wt0);
+              res_lo = _mm_srai_epi32(res_lo, DIST_PRECISION_BITS);
             }
           } else {
-            if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+            if (comp_avg)
+              res_lo =
+                  _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
           }
 
           _mm_storeu_si128(p, res_lo);
 #else
-          if (comp_avg) res_lo = _mm_add_epi32(_mm_loadu_si128(p), res_lo);
+          if (comp_avg)
+            res_lo =
+                _mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128(p), res_lo), 1);
           _mm_storeu_si128(p, res_lo);
 #endif
           if (p_width > 4) {
@@ -504,20 +507,22 @@
 #if CONFIG_JNT_COMP
             if (conv_params->use_jnt_comp_avg) {
               if (comp_avg) {
-                res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1),
-                                       _mm_mullo_epi32(res_hi, wt1));
-              } else {
-                res_hi = _mm_mullo_epi32(res_hi, wt0);
+                res_hi =
+                    _mm_add_epi32(_mm_mullo_epi32(_mm_loadu_si128(p + 1), wt0),
+                                  _mm_mullo_epi32(res_hi, wt1));
+                res_hi = _mm_srai_epi32(res_hi, DIST_PRECISION_BITS);
               }
             } else {
               if (comp_avg)
-                res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+                res_hi = _mm_srai_epi32(
+                    _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
             }
 
             _mm_storeu_si128(p + 1, res_hi);
 #else
             if (comp_avg)
-              res_hi = _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi);
+              res_hi = _mm_srai_epi32(
+                  _mm_add_epi32(_mm_loadu_si128(p + 1), res_hi), 1);
             _mm_storeu_si128(p + 1, res_hi);
 #endif
           }