High precision Wiener filter rework

Implements the high precision Wiener filter with an offset
to reduce the error due to saturation without increasing
the number of bits needed for intermediate precision.

Also turns the high precision filter on.

Change-Id: I34037a5746a6a89c5fce67753c1b027749085edf
diff --git a/aom_dsp/aom_convolve.c b/aom_dsp/aom_convolve.c
index 1abd9a2..4dac6aa 100644
--- a/aom_dsp/aom_convolve.c
+++ b/aom_dsp/aom_convolve.c
@@ -337,14 +337,14 @@
                                    uint8_t *dst, ptrdiff_t dst_stride,
                                    const InterpKernel *x_filters, int x0_q4,
                                    int x_step_q4, int w, int h) {
-  int x, y;
+  int x, y, k;
   src -= SUBPEL_TAPS / 2 - 1;
   for (y = 0; y < h; ++y) {
     int x_q4 = x0_q4;
     for (x = 0; x < w; ++x) {
       const uint8_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
       const int16_t *const x_filter = x_filters[x_q4 & SUBPEL_MASK];
-      int k, sum = 0;
+      int sum = 0;
       for (k = 0; k < SUBPEL_TAPS; ++k) sum += src_x[k] * x_filter[k];
       dst[x] = clip_pixel(ROUND_POWER_OF_TWO(sum, FILTER_BITS) +
                           src_x[SUBPEL_TAPS / 2 - 1]);
@@ -359,7 +359,7 @@
                                   uint8_t *dst, ptrdiff_t dst_stride,
                                   const InterpKernel *y_filters, int y0_q4,
                                   int y_step_q4, int w, int h) {
-  int x, y;
+  int x, y, k;
   src -= src_stride * (SUBPEL_TAPS / 2 - 1);
 
   for (x = 0; x < w; ++x) {
@@ -367,7 +367,7 @@
     for (y = 0; y < h; ++y) {
       const unsigned char *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
       const int16_t *const y_filter = y_filters[y_q4 & SUBPEL_MASK];
-      int k, sum = 0;
+      int sum = 0;
       for (k = 0; k < SUBPEL_TAPS; ++k)
         sum += src_y[k * src_stride] * y_filter[k];
       dst[y * dst_stride] =
@@ -451,18 +451,20 @@
                                        uint16_t *dst, ptrdiff_t dst_stride,
                                        const InterpKernel *x_filters, int x0_q4,
                                        int x_step_q4, int w, int h) {
-  int x, y;
+  const int bd = 8;
+  int x, y, k;
   src -= SUBPEL_TAPS / 2 - 1;
   for (y = 0; y < h; ++y) {
     int x_q4 = x0_q4;
     for (x = 0; x < w; ++x) {
       const uint8_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
       const int16_t *const x_filter = x_filters[x_q4 & SUBPEL_MASK];
-      int k, sum = ((int)src_x[SUBPEL_TAPS / 2 - 1] << FILTER_BITS);
+      int sum = ((int)src_x[SUBPEL_TAPS / 2 - 1] << FILTER_BITS) +
+                (1 << (bd + FILTER_BITS - 1));
       for (k = 0; k < SUBPEL_TAPS; ++k) sum += src_x[k] * x_filter[k];
       dst[x] =
           (uint16_t)clamp(ROUND_POWER_OF_TWO(sum, FILTER_BITS - EXTRAPREC_BITS),
-                          0, EXTRAPREC_CLAMP_LIMIT - 1);
+                          0, EXTRAPREC_CLAMP_LIMIT(bd) - 1);
       x_q4 += x_step_q4;
     }
     src += src_stride;
@@ -474,7 +476,8 @@
                                       uint8_t *dst, ptrdiff_t dst_stride,
                                       const InterpKernel *y_filters, int y0_q4,
                                       int y_step_q4, int w, int h) {
-  int x, y;
+  const int bd = 8;
+  int x, y, k;
   src -= src_stride * (SUBPEL_TAPS / 2 - 1);
 
   for (x = 0; x < w; ++x) {
@@ -482,8 +485,9 @@
     for (y = 0; y < h; ++y) {
       const uint16_t *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
       const int16_t *const y_filter = y_filters[y_q4 & SUBPEL_MASK];
-      int k,
-          sum = ((int)src_y[(SUBPEL_TAPS / 2 - 1) * src_stride] << FILTER_BITS);
+      int sum =
+          ((int)src_y[(SUBPEL_TAPS / 2 - 1) * src_stride] << FILTER_BITS) -
+          (1 << (bd + FILTER_BITS + EXTRAPREC_BITS - 1));
       for (k = 0; k < SUBPEL_TAPS; ++k)
         sum += src_y[k * src_stride] * y_filter[k];
       dst[y * dst_stride] =
@@ -838,7 +842,7 @@
                                           const InterpKernel *x_filters,
                                           int x0_q4, int x_step_q4, int w,
                                           int h, int bd) {
-  int x, y;
+  int x, y, k;
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   src -= SUBPEL_TAPS / 2 - 1;
@@ -847,7 +851,7 @@
     for (x = 0; x < w; ++x) {
       const uint16_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
       const int16_t *const x_filter = x_filters[x_q4 & SUBPEL_MASK];
-      int k, sum = 0;
+      int sum = 0;
       for (k = 0; k < SUBPEL_TAPS; ++k) sum += src_x[k] * x_filter[k];
       dst[x] = clip_pixel_highbd(
           ROUND_POWER_OF_TWO(sum, FILTER_BITS) + src_x[SUBPEL_TAPS / 2 - 1],
@@ -865,7 +869,7 @@
                                          const InterpKernel *y_filters,
                                          int y0_q4, int y_step_q4, int w, int h,
                                          int bd) {
-  int x, y;
+  int x, y, k;
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   src -= src_stride * (SUBPEL_TAPS / 2 - 1);
@@ -874,7 +878,7 @@
     for (y = 0; y < h; ++y) {
       const uint16_t *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
       const int16_t *const y_filter = y_filters[y_q4 & SUBPEL_MASK];
-      int k, sum = 0;
+      int sum = 0;
       for (k = 0; k < SUBPEL_TAPS; ++k)
         sum += src_y[k * src_stride] * y_filter[k];
       dst[y * dst_stride] =
@@ -972,8 +976,8 @@
     const uint8_t *src8, ptrdiff_t src_stride, uint16_t *dst,
     ptrdiff_t dst_stride, const InterpKernel *x_filters, int x0_q4,
     int x_step_q4, int w, int h, int bd) {
-  const int extraprec_clamp_limit = (EXTRAPREC_CLAMP_LIMIT << (bd - 8));
-  int x, y;
+  const int extraprec_clamp_limit = EXTRAPREC_CLAMP_LIMIT(bd);
+  int x, y, k;
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
   src -= SUBPEL_TAPS / 2 - 1;
   for (y = 0; y < h; ++y) {
@@ -981,7 +985,8 @@
     for (x = 0; x < w; ++x) {
       const uint16_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
       const int16_t *const x_filter = x_filters[x_q4 & SUBPEL_MASK];
-      int k, sum = ((int)src_x[SUBPEL_TAPS / 2 - 1] << FILTER_BITS);
+      int sum = ((int)src_x[SUBPEL_TAPS / 2 - 1] << FILTER_BITS) +
+                (1 << (bd + FILTER_BITS - 1));
       for (k = 0; k < SUBPEL_TAPS; ++k) sum += src_x[k] * x_filter[k];
       dst[x] =
           (uint16_t)clamp(ROUND_POWER_OF_TWO(sum, FILTER_BITS - EXTRAPREC_BITS),
@@ -997,7 +1002,7 @@
     const uint16_t *src, ptrdiff_t src_stride, uint8_t *dst8,
     ptrdiff_t dst_stride, const InterpKernel *y_filters, int y0_q4,
     int y_step_q4, int w, int h, int bd) {
-  int x, y;
+  int x, y, k;
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   src -= src_stride * (SUBPEL_TAPS / 2 - 1);
   for (x = 0; x < w; ++x) {
@@ -1005,8 +1010,9 @@
     for (y = 0; y < h; ++y) {
       const uint16_t *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
       const int16_t *const y_filter = y_filters[y_q4 & SUBPEL_MASK];
-      int k,
-          sum = ((int)src_y[(SUBPEL_TAPS / 2 - 1) * src_stride] << FILTER_BITS);
+      int sum =
+          ((int)src_y[(SUBPEL_TAPS / 2 - 1) * src_stride] << FILTER_BITS) -
+          (1 << (bd + FILTER_BITS + EXTRAPREC_BITS - 1));
       for (k = 0; k < SUBPEL_TAPS; ++k)
         sum += src_y[k * src_stride] * y_filter[k];
       dst[y * dst_stride] = clip_pixel_highbd(
diff --git a/aom_dsp/aom_convolve.h b/aom_dsp/aom_convolve.h
index d922d01..c7943dc 100644
--- a/aom_dsp/aom_convolve.h
+++ b/aom_dsp/aom_convolve.h
@@ -38,7 +38,7 @@
 
 #if CONFIG_AV1 && CONFIG_LOOP_RESTORATION
 #define EXTRAPREC_BITS 2
-#define EXTRAPREC_CLAMP_LIMIT (512 << EXTRAPREC_BITS)
+#define EXTRAPREC_CLAMP_LIMIT(bd) (1 << ((bd) + 1 + EXTRAPREC_BITS))
 #endif
 
 typedef void (*convolve_fn_t)(const uint8_t *src, ptrdiff_t src_stride,
diff --git a/aom_dsp/x86/aom_convolve_hip_sse2.c b/aom_dsp/x86/aom_convolve_hip_sse2.c
index a27239b..1435289 100644
--- a/aom_dsp/x86/aom_convolve_hip_sse2.c
+++ b/aom_dsp/x86/aom_convolve_hip_sse2.c
@@ -22,6 +22,7 @@
                                     const int16_t *filter_x, int x_step_q4,
                                     const int16_t *filter_y, int y_step_q4,
                                     int w, int h) {
+  const int bd = 8;
   assert(x_step_q4 == 16 && y_step_q4 == 16);
   assert(!(w & 7));
   (void)x_step_q4;
@@ -57,7 +58,8 @@
     const __m128i coeff_67 = _mm_unpackhi_epi64(tmp_1, tmp_1);
 
     const __m128i round_const =
-        _mm_set1_epi32((1 << (FILTER_BITS - EXTRAPREC_BITS)) >> 1);
+        _mm_set1_epi32((1 << (FILTER_BITS - EXTRAPREC_BITS - 1)) +
+                       (1 << (bd + FILTER_BITS - 1)));
 
     for (i = 0; i < intermediate_height; ++i) {
       for (j = 0; j < w; j += 8) {
@@ -97,7 +99,7 @@
         // Pack in the column order 0, 2, 4, 6, 1, 3, 5, 7
         __m128i res = _mm_packs_epi32(res_even, res_odd);
         res = _mm_min_epi16(_mm_max_epi16(res, zero),
-                            _mm_set1_epi16(EXTRAPREC_CLAMP_LIMIT - 1));
+                            _mm_set1_epi16(EXTRAPREC_CLAMP_LIMIT(bd) - 1));
         _mm_storeu_si128((__m128i *)&temp[i * MAX_SB_SIZE + j], res);
       }
     }
@@ -123,7 +125,8 @@
     const __m128i coeff_67 = _mm_unpackhi_epi64(tmp_1, tmp_1);
 
     const __m128i round_const =
-        _mm_set1_epi32((1 << (FILTER_BITS + EXTRAPREC_BITS)) >> 1);
+        _mm_set1_epi32((1 << (FILTER_BITS + EXTRAPREC_BITS - 1)) -
+                       (1 << (bd + FILTER_BITS + EXTRAPREC_BITS - 1)));
 
     for (i = 0; i < h; ++i) {
       for (j = 0; j < w; j += 8) {
diff --git a/aom_dsp/x86/aom_highbd_convolve_hip_ssse3.c b/aom_dsp/x86/aom_highbd_convolve_hip_ssse3.c
index 7fad65f..74ce80e 100644
--- a/aom_dsp/x86/aom_highbd_convolve_hip_ssse3.c
+++ b/aom_dsp/x86/aom_highbd_convolve_hip_ssse3.c
@@ -64,7 +64,8 @@
     const __m128i coeff_67 = _mm_unpackhi_epi64(tmp_1, tmp_1);
 
     const __m128i round_const =
-        _mm_set1_epi32((1 << (FILTER_BITS - EXTRAPREC_BITS)) >> 1);
+        _mm_set1_epi32((1 << (FILTER_BITS - EXTRAPREC_BITS - 1)) +
+                       (1 << (bd + FILTER_BITS - 1)));
 
     for (i = 0; i < intermediate_height; ++i) {
       for (j = 0; j < w; j += 8) {
@@ -103,8 +104,7 @@
                                  FILTER_BITS - EXTRAPREC_BITS);
 
         // Pack in the column order 0, 2, 4, 6, 1, 3, 5, 7
-        const __m128i maxval =
-            _mm_set1_epi16((EXTRAPREC_CLAMP_LIMIT << (bd - 8)) - 1);
+        const __m128i maxval = _mm_set1_epi16((EXTRAPREC_CLAMP_LIMIT(bd)) - 1);
         __m128i res = _mm_packs_epi32(res_even, res_odd);
         res = _mm_min_epi16(_mm_max_epi16(res, zero), maxval);
         _mm_storeu_si128((__m128i *)&temp[i * MAX_SB_SIZE + j], res);
@@ -132,7 +132,8 @@
     const __m128i coeff_67 = _mm_unpackhi_epi64(tmp_1, tmp_1);
 
     const __m128i round_const =
-        _mm_set1_epi32((1 << (FILTER_BITS + EXTRAPREC_BITS)) >> 1);
+        _mm_set1_epi32((1 << (FILTER_BITS + EXTRAPREC_BITS - 1)) -
+                       (1 << (bd + FILTER_BITS + EXTRAPREC_BITS - 1)));
 
     for (i = 0; i < h; ++i) {
       for (j = 0; j < w; j += 8) {
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 8770abc..477f20a 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -82,7 +82,7 @@
 #define WIENER_FILT_STEP (1 << WIENER_FILT_PREC_BITS)
 
 // Whether to use high intermediate precision filtering
-#define USE_WIENER_HIGH_INTERMEDIATE_PRECISION 0
+#define USE_WIENER_HIGH_INTERMEDIATE_PRECISION 1
 
 // Central values for the taps
 #define WIENER_FILT_TAP0_MIDV (3)