Some fixes and clean-ups on convolve functions

Make av1_convolve_x_sr_sse2/avx2 support various bit-shift
options.

Add asserts to the convolve functions.

Change-Id: Ib6d1ada6c00a20e6e498af2672bd0bb76040d7d0
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index 9c44041..a2fb693 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -447,6 +447,8 @@
   (void)dst0;
   (void)dst_stride0;
 
+  assert(bits >= 0);
+
   // vertical filter
   const int16_t *y_filter = av1_get_interp_filter_subpel_kernel(
       *filter_params_y, subpel_y_q4 & SUBPEL_MASK);
@@ -481,6 +483,8 @@
   (void)dst0;
   (void)dst_stride0;
 
+  assert(bits >= 0);
+
   // horizontal filter
   const int16_t *x_filter = av1_get_interp_filter_subpel_kernel(
       *filter_params_x, subpel_x_q4 & SUBPEL_MASK);
@@ -590,6 +594,10 @@
   (void)subpel_x_q4;
   (void)conv_params;
 
+  assert(conv_params->round_0 <= FILTER_BITS);
+  assert(((conv_params->round_0 + conv_params->round_1) <= (FILTER_BITS + 1)) ||
+         ((conv_params->round_0 + conv_params->round_1) == (2 * FILTER_BITS)));
+
   // vertical filter
   const int16_t *y_filter = av1_get_interp_filter_subpel_kernel(
       *filter_params_y, subpel_y_q4 & SUBPEL_MASK);
@@ -617,6 +625,10 @@
   (void)subpel_y_q4;
   (void)conv_params;
 
+  assert(bits >= 0);
+  assert((FILTER_BITS - conv_params->round_1) >= 0 ||
+         ((conv_params->round_0 + conv_params->round_1) == 2 * FILTER_BITS));
+
   // horizontal filter
   const int16_t *x_filter = av1_get_interp_filter_subpel_kernel(
       *filter_params_x, subpel_x_q4 & SUBPEL_MASK);
diff --git a/av1/common/x86/convolve_2d_avx2.c b/av1/common/x86/convolve_2d_avx2.c
index 396e80f..9c1a32b 100644
--- a/av1/common/x86/convolve_2d_avx2.c
+++ b/av1/common/x86/convolve_2d_avx2.c
@@ -43,6 +43,8 @@
 
   __m256i filt[4], s[8], coeffs_x[4], coeffs_y[4];
 
+  assert(conv_params->round_0 > 0);
+
   filt[0] = _mm256_load_si256((__m256i const *)filt1_global_avx2);
   filt[1] = _mm256_load_si256((__m256i const *)filt2_global_avx2);
   filt[2] = _mm256_load_si256((__m256i const *)filt3_global_avx2);
@@ -176,6 +178,8 @@
 
   __m256i filt[4], coeffs_h[4], coeffs_v[4];
 
+  assert(conv_params->round_0 > 0);
+
   filt[0] = _mm256_load_si256((__m256i const *)filt1_global_avx2);
   filt[1] = _mm256_load_si256((__m256i const *)filt2_global_avx2);
   filt[2] = _mm256_load_si256((__m256i const *)filt3_global_avx2);
diff --git a/av1/common/x86/convolve_2d_sse2.c b/av1/common/x86/convolve_2d_sse2.c
index f2c3561..96a6042 100644
--- a/av1/common/x86/convolve_2d_sse2.c
+++ b/av1/common/x86/convolve_2d_sse2.c
@@ -41,6 +41,8 @@
 
   const __m128i zero = _mm_setzero_si128();
 
+  assert(conv_params->round_0 > 0);
+
   /* Horizontal filter */
   {
     const int16_t *x_filter = av1_get_interp_filter_subpel_kernel(
@@ -226,6 +228,8 @@
       FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
   const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
 
+  assert(conv_params->round_0 > 0);
+
   /* Horizontal filter */
   {
     const int16_t *x_filter = av1_get_interp_filter_subpel_kernel(
diff --git a/av1/common/x86/convolve_avx2.c b/av1/common/x86/convolve_avx2.c
index 2843a91..c4d7447 100644
--- a/av1/common/x86/convolve_avx2.c
+++ b/av1/common/x86/convolve_avx2.c
@@ -359,6 +359,8 @@
   const __m256i avg_mask = _mm256_set1_epi32(conv_params->do_average ? -1 : 0);
   __m256i coeffs[4], s[8];
 
+  assert((FILTER_BITS - conv_params->round_0) >= 0);
+
   prepare_coeffs(filter_params_y, subpel_y_q4, coeffs);
 
   (void)conv_params;
@@ -514,6 +516,10 @@
       _mm256_set1_epi16((1 << right_shift_bits) >> 1);
   __m256i coeffs[4], s[8];
 
+  assert(conv_params->round_0 <= FILTER_BITS);
+  assert(((conv_params->round_0 + conv_params->round_1) <= (FILTER_BITS + 1)) ||
+         ((conv_params->round_0 + conv_params->round_1) == (2 * FILTER_BITS)));
+
   prepare_coeffs(filter_params_y, subpel_y_q4, coeffs);
 
   (void)filter_params_x;
@@ -665,6 +671,9 @@
 
   __m256i filt[4], coeffs[4];
 
+  assert(bits >= 0);
+  assert(conv_params->round_0 > 0);
+
   filt[0] = _mm256_load_si256((__m256i const *)filt1_global_avx2);
   filt[1] = _mm256_load_si256((__m256i const *)filt2_global_avx2);
   filt[2] = _mm256_load_si256((__m256i const *)filt3_global_avx2);
@@ -720,6 +729,7 @@
   int i, j;
   const int fo_horiz = filter_params_x->taps / 2 - 1;
   const uint8_t *const src_ptr = src - fo_horiz;
+  const int bits = FILTER_BITS - conv_params->round_0;
 
   __m256i filt[4], coeffs[4];
 
@@ -730,14 +740,20 @@
 
   prepare_coeffs(filter_params_x, subpel_x_q4, coeffs);
 
-  const __m256i round_const =
-      _mm256_set1_epi16(((1 << (conv_params->round_0 - 1)) >> 1) +
-                        ((1 << (FILTER_BITS - 1)) >> 1));
-  const __m128i round_shift = _mm_cvtsi32_si128(FILTER_BITS - 1);
+  const __m256i round_0_const =
+      _mm256_set1_epi16((1 << (conv_params->round_0 - 1)) >> 1);
+  const __m128i round_0_shift = _mm_cvtsi32_si128(conv_params->round_0 - 1);
+  const __m256i round_const = _mm256_set1_epi16((1 << bits) >> 1);
+  const __m128i round_shift = _mm_cvtsi32_si128(bits);
 
   (void)filter_params_y;
   (void)subpel_y_q4;
 
+  assert(bits >= 0);
+  assert((FILTER_BITS - conv_params->round_1) >= 0 ||
+         ((conv_params->round_0 + conv_params->round_1) == 2 * FILTER_BITS));
+  assert(conv_params->round_0 > 0);
+
   for (i = 0; i < h; ++i) {
     for (j = 0; j < w; j += 16) {
       // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 8 9 10 11 12 13 14 15 16 17 18
@@ -748,7 +764,9 @@
 
       __m256i res_16b = convolve_x(data, coeffs, filt);
 
-      // Combine V round and 2F-H-V round into a single rounding
+      res_16b = _mm256_sra_epi16(_mm256_add_epi16(res_16b, round_0_const),
+                                 round_0_shift);
+
       res_16b =
           _mm256_sra_epi16(_mm256_add_epi16(res_16b, round_const), round_shift);
 
diff --git a/av1/common/x86/convolve_sse2.c b/av1/common/x86/convolve_sse2.c
index f8081d2..ab35226 100644
--- a/av1/common/x86/convolve_sse2.c
+++ b/av1/common/x86/convolve_sse2.c
@@ -105,6 +105,8 @@
   (void)dst0;
   (void)dst_stride0;
 
+  assert(bits >= 0);
+
   prepare_coeffs(filter_params_y, subpel_y_q4, coeffs);
 
   if (w == 4) {
@@ -252,6 +254,8 @@
   (void)dst0;
   (void)dst_stride0;
 
+  assert(bits >= 0);
+
   prepare_coeffs(filter_params_x, subpel_x_q4, coeffs);
 
   if (w == 4) {
@@ -335,6 +339,10 @@
   (void)subpel_x_q4;
   (void)conv_params;
 
+  assert(conv_params->round_0 <= FILTER_BITS);
+  assert(((conv_params->round_0 + conv_params->round_1) <= (FILTER_BITS + 1)) ||
+         ((conv_params->round_0 + conv_params->round_1) == (2 * FILTER_BITS)));
+
   prepare_coeffs(filter_params_y, subpel_y_q4, coeffs);
 
   if (w <= 4) {
@@ -484,14 +492,21 @@
                             ConvolveParams *conv_params) {
   const int fo_horiz = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - fo_horiz;
-  const __m128i round_const = _mm_set1_epi32(
-      ((1 << conv_params->round_0) >> 1) + (1 << (FILTER_BITS - 1)));
-  const __m128i round_shift = _mm_cvtsi32_si128(FILTER_BITS);
+  const int bits = FILTER_BITS - conv_params->round_0;
+  const __m128i round_0_const =
+      _mm_set1_epi32((1 << conv_params->round_0) >> 1);
+  const __m128i round_const = _mm_set1_epi32((1 << bits) >> 1);
+  const __m128i round_0_shift = _mm_cvtsi32_si128(conv_params->round_0);
+  const __m128i round_shift = _mm_cvtsi32_si128(bits);
   __m128i coeffs[4];
 
   (void)filter_params_y;
   (void)subpel_y_q4;
 
+  assert(bits >= 0);
+  assert((FILTER_BITS - conv_params->round_1) >= 0 ||
+         ((conv_params->round_0 + conv_params->round_1) == 2 * FILTER_BITS));
+
   prepare_coeffs(filter_params_x, subpel_x_q4, coeffs);
 
   if (w <= 4) {
@@ -507,8 +522,10 @@
       s[3] =
           _mm_unpacklo_epi8(_mm_srli_si128(data, 6), _mm_srli_si128(data, 7));
       const __m128i res_lo = convolve_lo_x(s, coeffs);
-      const __m128i res_lo_round =
-          _mm_sra_epi32(_mm_add_epi32(res_lo, round_const), round_shift);
+      __m128i res_lo_round =
+          _mm_sra_epi32(_mm_add_epi32(res_lo, round_0_const), round_0_shift);
+      res_lo_round =
+          _mm_sra_epi32(_mm_add_epi32(res_lo_round, round_const), round_shift);
 
       const __m128i res16 = _mm_packs_epi32(res_lo_round, res_lo_round);
       const __m128i res = _mm_packus_epi16(res16, res16);
@@ -549,10 +566,14 @@
         // Rearrange pixels back into the order 0 ... 7
         const __m128i res_lo = _mm_unpacklo_epi32(res_even, res_odd);
         const __m128i res_hi = _mm_unpackhi_epi32(res_even, res_odd);
-        const __m128i res_lo_round =
-            _mm_sra_epi32(_mm_add_epi32(res_lo, round_const), round_shift);
-        const __m128i res_hi_round =
-            _mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift);
+        __m128i res_lo_round =
+            _mm_sra_epi32(_mm_add_epi32(res_lo, round_0_const), round_0_shift);
+        res_lo_round = _mm_sra_epi32(_mm_add_epi32(res_lo_round, round_const),
+                                     round_shift);
+        __m128i res_hi_round =
+            _mm_sra_epi32(_mm_add_epi32(res_hi, round_0_const), round_0_shift);
+        res_hi_round = _mm_sra_epi32(_mm_add_epi32(res_hi_round, round_const),
+                                     round_shift);
 
         const __m128i res16 = _mm_packs_epi32(res_lo_round, res_hi_round);
         const __m128i res = _mm_packus_epi16(res16, res16);