Add an option to use a 5x5 Wiener filter for chroma planes

Change-Id: I1b789acc18f1e69fb5db069ccd8bd17815938e9d
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 8863942..a198f9c 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -548,41 +548,44 @@
   return avg;
 }
 
-static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
-                          int v_start, int v_end, int dgd_stride,
-                          int src_stride, double *M, double *H) {
+static void compute_stats(int wiener_win, uint8_t *dgd, uint8_t *src,
+                          int h_start, int h_end, int v_start, int v_end,
+                          int dgd_stride, int src_stride, double *M,
+                          double *H) {  // per pixel: M += Y*X, H += Y*Y^T
   int i, j, k, l;
   double Y[WIENER_WIN2];
+  const int wiener_win2 = wiener_win * wiener_win;  // assumed <= WIENER_WIN2
+  const int wiener_halfwin = (wiener_win >> 1);  // window radius
   const double avg =
       find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
 
-  memset(M, 0, sizeof(*M) * WIENER_WIN2);
-  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
+  memset(M, 0, sizeof(*M) * wiener_win2);  // M[k] accumulates Y[k] * X
+  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);  // win2 x win2, packed
   for (i = v_start; i < v_end; i++) {
     for (j = h_start; j < h_end; j++) {
       const double X = (double)src[i * src_stride + j] - avg;
       int idx = 0;
-      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
-        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
+      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {  // column offset
+        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {  // row offset
           Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
           idx++;
         }
       }
-      for (k = 0; k < WIENER_WIN2; ++k) {
+      for (k = 0; k < wiener_win2; ++k) {
         M[k] += Y[k] * X;
-        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
-        for (l = k + 1; l < WIENER_WIN2; ++l) {
+        H[k * wiener_win2 + k] += Y[k] * Y[k];  // diagonal term
+        for (l = k + 1; l < wiener_win2; ++l) {
           // H is a symmetric matrix, so we only need to fill out the upper
           // triangle here. We can copy it down to the lower triangle outside
           // the (i, j) loops.
-          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
+          H[k * wiener_win2 + l] += Y[k] * Y[l];
         }
       }
     }
   }
-  for (k = 0; k < WIENER_WIN2; ++k) {
-    for (l = k + 1; l < WIENER_WIN2; ++l) {
-      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
+  for (k = 0; k < wiener_win2; ++k) {
+    for (l = k + 1; l < wiener_win2; ++l) {
+      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];  // mirror to lower half
     }
   }
 }
@@ -600,168 +603,183 @@
   return avg;
 }
 
-static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
-                                 int h_end, int v_start, int v_end,
+static void compute_stats_highbd(int wiener_win, uint8_t *dgd8, uint8_t *src8,
+                                 int h_start, int h_end, int v_start, int v_end,
                                  int dgd_stride, int src_stride, double *M,
                                  double *H) {
   int i, j, k, l;
   double Y[WIENER_WIN2];
+  const int wiener_win2 = wiener_win * wiener_win;  // assumed <= WIENER_WIN2
+  const int wiener_halfwin = (wiener_win >> 1);  // window radius
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
   uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
   const double avg =
       find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
 
-  memset(M, 0, sizeof(*M) * WIENER_WIN2);
-  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
+  memset(M, 0, sizeof(*M) * wiener_win2);  // M[k] accumulates Y[k] * X
+  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);  // win2 x win2, packed
   for (i = v_start; i < v_end; i++) {
     for (j = h_start; j < h_end; j++) {
       const double X = (double)src[i * src_stride + j] - avg;
       int idx = 0;
-      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
-        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
+      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {  // column offset
+        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {  // row offset
           Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
           idx++;
         }
       }
-      for (k = 0; k < WIENER_WIN2; ++k) {
+      for (k = 0; k < wiener_win2; ++k) {
         M[k] += Y[k] * X;
-        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
-        for (l = k + 1; l < WIENER_WIN2; ++l) {
+        H[k * wiener_win2 + k] += Y[k] * Y[k];  // diagonal term
+        for (l = k + 1; l < wiener_win2; ++l) {
           // H is a symmetric matrix, so we only need to fill out the upper
           // triangle here. We can copy it down to the lower triangle outside
           // the (i, j) loops.
-          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
+          H[k * wiener_win2 + l] += Y[k] * Y[l];
         }
       }
     }
   }
-  for (k = 0; k < WIENER_WIN2; ++k) {
-    for (l = k + 1; l < WIENER_WIN2; ++l) {
-      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
+  for (k = 0; k < wiener_win2; ++k) {
+    for (l = k + 1; l < wiener_win2; ++l) {
+      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];  // mirror to lower half
     }
   }
 }
 #endif  // CONFIG_HIGHBITDEPTH
 
-static INLINE int wrap_index(int i) {
-  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
+static INLINE int wrap_index(int i, int wiener_win) {
+  const int wiener_halfwin1 = (wiener_win >> 1) + 1;  // taps up to center
+  return (i >= wiener_halfwin1 ? wiener_win - 1 - i : i);  // mirror upper taps
 }
 
 // Fix vector b, update vector a
-static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
+static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
+                             double *a, double *b) {  // b is read, a is written
   int i, j;
   double S[WIENER_WIN];
   double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
-  int w, w2;
+  const int wiener_win2 = wiener_win * wiener_win;  // stride for k in Hc[][]
+  const int wiener_halfwin1 = (wiener_win >> 1) + 1;  // distinct symmetric taps
   memset(A, 0, sizeof(A));
   memset(B, 0, sizeof(B));
-  for (i = 0; i < WIENER_WIN; i++) {
-    for (j = 0; j < WIENER_WIN; ++j) {
-      const int jj = wrap_index(j);
+  for (i = 0; i < wiener_win; i++) {
+    for (j = 0; j < wiener_win; ++j) {
+      const int jj = wrap_index(j, wiener_win);  // fold mirrored tap
       A[jj] += Mc[i][j] * b[i];
     }
   }
-  for (i = 0; i < WIENER_WIN; i++) {
-    for (j = 0; j < WIENER_WIN; j++) {
+  for (i = 0; i < wiener_win; i++) {
+    for (j = 0; j < wiener_win; j++) {
       int k, l;
-      for (k = 0; k < WIENER_WIN; ++k)
-        for (l = 0; l < WIENER_WIN; ++l) {
-          const int kk = wrap_index(k);
-          const int ll = wrap_index(l);
-          B[ll * WIENER_HALFWIN1 + kk] +=
-              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
+      for (k = 0; k < wiener_win; ++k)
+        for (l = 0; l < wiener_win; ++l) {
+          const int kk = wrap_index(k, wiener_win);
+          const int ll = wrap_index(l, wiener_win);
+          B[ll * wiener_halfwin1 + kk] +=
+              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
         }
     }
   }
   // Normalization enforcement in the system of equations itself
-  w = WIENER_WIN;
-  w2 = (w >> 1) + 1;
-  for (i = 0; i < w2 - 1; ++i)
+  for (i = 0; i < wiener_halfwin1 - 1; ++i)  // every tap but the center
     A[i] -=
-        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
-  for (i = 0; i < w2 - 1; ++i)
-    for (j = 0; j < w2 - 1; ++j)
-      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
-                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
-  if (linsolve(w2 - 1, B, w2, A, S)) {
-    S[w2 - 1] = 1.0;
-    for (i = w2; i < w; ++i) {
-      S[i] = S[w - 1 - i];
-      S[w2 - 1] -= 2 * S[i];
+        A[wiener_halfwin1 - 1] * 2 +
+        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
+        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
+  for (i = 0; i < wiener_halfwin1 - 1; ++i)
+    for (j = 0; j < wiener_halfwin1 - 1; ++j)
+      B[i * wiener_halfwin1 + j] -=
+          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
+               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
+               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
+                     (wiener_halfwin1 - 1)]);
+  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
+    S[wiener_halfwin1 - 1] = 1.0;  // center = 1 - 2 * sum(mirrored taps)
+    for (i = wiener_halfwin1; i < wiener_win; ++i) {
+      S[i] = S[wiener_win - 1 - i];  // mirror the solved half
+      S[wiener_halfwin1 - 1] -= 2 * S[i];
     }
-    memcpy(a, S, w * sizeof(*a));
+    memcpy(a, S, wiener_win * sizeof(*a));  // skipped if linsolve() gave 0
   }
 }
 
 // Fix vector a, update vector b
-static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
+static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
+                             double *a, double *b) {  // a is read, b is written
   int i, j;
   double S[WIENER_WIN];
   double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
-  int w, w2;
+  const int wiener_win2 = wiener_win * wiener_win;  // stride for k in Hc[][]
+  const int wiener_halfwin1 = (wiener_win >> 1) + 1;  // distinct symmetric taps
   memset(A, 0, sizeof(A));
   memset(B, 0, sizeof(B));
-  for (i = 0; i < WIENER_WIN; i++) {
-    const int ii = wrap_index(i);
-    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
+  for (i = 0; i < wiener_win; i++) {
+    const int ii = wrap_index(i, wiener_win);  // fold mirrored tap
+    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
   }
 
-  for (i = 0; i < WIENER_WIN; i++) {
-    for (j = 0; j < WIENER_WIN; j++) {
-      const int ii = wrap_index(i);
-      const int jj = wrap_index(j);
+  for (i = 0; i < wiener_win; i++) {
+    for (j = 0; j < wiener_win; j++) {
+      const int ii = wrap_index(i, wiener_win);
+      const int jj = wrap_index(j, wiener_win);
       int k, l;
-      for (k = 0; k < WIENER_WIN; ++k)
-        for (l = 0; l < WIENER_WIN; ++l)
-          B[jj * WIENER_HALFWIN1 + ii] +=
-              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
+      for (k = 0; k < wiener_win; ++k)
+        for (l = 0; l < wiener_win; ++l)
+          B[jj * wiener_halfwin1 + ii] +=
+              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
     }
   }
   // Normalization enforcement in the system of equations itself
-  w = WIENER_WIN;
-  w2 = WIENER_HALFWIN1;
-  for (i = 0; i < w2 - 1; ++i)
+  for (i = 0; i < wiener_halfwin1 - 1; ++i)  // every tap but the center
     A[i] -=
-        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
-  for (i = 0; i < w2 - 1; ++i)
-    for (j = 0; j < w2 - 1; ++j)
-      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
-                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
-  if (linsolve(w2 - 1, B, w2, A, S)) {
-    S[w2 - 1] = 1.0;
-    for (i = w2; i < w; ++i) {
-      S[i] = S[w - 1 - i];
-      S[w2 - 1] -= 2 * S[i];
+        A[wiener_halfwin1 - 1] * 2 +
+        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
+        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
+  for (i = 0; i < wiener_halfwin1 - 1; ++i)
+    for (j = 0; j < wiener_halfwin1 - 1; ++j)
+      B[i * wiener_halfwin1 + j] -=
+          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
+               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
+               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
+                     (wiener_halfwin1 - 1)]);
+  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
+    S[wiener_halfwin1 - 1] = 1.0;  // center = 1 - 2 * sum(mirrored taps)
+    for (i = wiener_halfwin1; i < wiener_win; ++i) {
+      S[i] = S[wiener_win - 1 - i];  // mirror the solved half
+      S[wiener_halfwin1 - 1] -= 2 * S[i];
     }
-    memcpy(b, S, w * sizeof(*b));
+    memcpy(b, S, wiener_win * sizeof(*b));  // skipped if linsolve() gave 0
   }
 }
 
-static int wiener_decompose_sep_sym(double *M, double *H, double *a,
-                                    double *b) {
+static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
+                                    double *a, double *b) {
   static const int init_filt[WIENER_WIN] = {
     WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
     WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
     WIENER_FILT_TAP0_MIDV,
   };
-  int i, j, iter;
   double *Hc[WIENER_WIN2];
   double *Mc[WIENER_WIN];
-  for (i = 0; i < WIENER_WIN; i++) {
-    Mc[i] = M + i * WIENER_WIN;
-    for (j = 0; j < WIENER_WIN; j++) {
-      Hc[i * WIENER_WIN + j] =
-          H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
-    }
+  int i, j, iter;
+  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
+  const int wiener_win2 = wiener_win * wiener_win;
+  for (i = 0; i < wiener_win; i++) {
+    a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
   }
-  for (i = 0; i < WIENER_WIN; i++) {
-    a[i] = b[i] = (double)init_filt[i] / WIENER_FILT_STEP;
+  for (i = 0; i < wiener_win; i++) {
+    Mc[i] = M + i * wiener_win;
+    for (j = 0; j < wiener_win; j++) {
+      Hc[i * wiener_win + j] =
+          H + i * wiener_win * wiener_win2 + j * wiener_win;
+    }
   }
 
   iter = 1;
   while (iter < NUM_WIENER_ITERS) {
-    update_a_sep_sym(Mc, Hc, a, b);
-    update_b_sep_sym(Mc, Hc, a, b);
+    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
+    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
     iter++;
   }
   return 1;
@@ -770,14 +788,16 @@
 // Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
 // against identity filters; Final score is defined as the difference between
 // the function values
-static double compute_score(double *M, double *H, InterpKernel vfilt,
-                            InterpKernel hfilt) {
+static double compute_score(int wiener_win, double *M, double *H,
+                            InterpKernel vfilt, InterpKernel hfilt) {
   double ab[WIENER_WIN * WIENER_WIN];
   int i, k, l;
   double P = 0, Q = 0;
   double iP = 0, iQ = 0;
   double Score, iScore;
   double a[WIENER_WIN], b[WIENER_WIN];
+  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
+  const int wiener_win2 = wiener_win * wiener_win;
 
   aom_clear_system_state();
 
@@ -788,32 +808,40 @@
     a[WIENER_HALFWIN] -= 2 * a[i];
     b[WIENER_HALFWIN] -= 2 * b[i];
   }
-  for (k = 0; k < WIENER_WIN; ++k) {
-    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
+  for (k = 0; k < wiener_win; ++k) {
+    for (l = 0; l < wiener_win; ++l)
+      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
   }
-  for (k = 0; k < WIENER_WIN2; ++k) {
+  for (k = 0; k < wiener_win2; ++k) {
     P += ab[k] * M[k];
-    for (l = 0; l < WIENER_WIN2; ++l)
-      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
+    for (l = 0; l < wiener_win2; ++l)
+      Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
   }
   Score = Q - 2 * P;
 
-  iP = M[WIENER_WIN2 >> 1];
-  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
+  iP = M[wiener_win2 >> 1];
+  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
   iScore = iQ - 2 * iP;
 
   return Score - iScore;
 }
 
-static void quantize_sym_filter(double *f, InterpKernel fi) {
+static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
   int i;
-  for (i = 0; i < WIENER_HALFWIN; ++i) {
+  const int wiener_halfwin = (wiener_win >> 1);
+  for (i = 0; i < wiener_halfwin; ++i) {
     fi[i] = RINT(f[i] * WIENER_FILT_STEP);
   }
   // Specialize for 7-tap filter
-  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
-  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
-  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
+  if (wiener_win == WIENER_WIN) {
+    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
+    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
+    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
+  } else {
+    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
+    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
+    fi[0] = 0;
+  }
   // Satisfy filter constraints
   fi[WIENER_WIN - 1] = fi[0];
   fi[WIENER_WIN - 2] = fi[1];
@@ -822,14 +850,15 @@
   fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
 }
 
-static int count_wiener_bits(WienerInfo *wiener_info,
+static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                              WienerInfo *ref_wiener_info) {
   int bits = 0;
-  bits += aom_count_primitive_refsubexpfin(
-      WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
-      WIENER_FILT_TAP0_SUBEXP_K,
-      ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
-      wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
+  if (wiener_win == WIENER_WIN)
+    bits += aom_count_primitive_refsubexpfin(
+        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
+        WIENER_FILT_TAP0_SUBEXP_K,
+        ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
+        wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
   bits += aom_count_primitive_refsubexpfin(
       WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
       WIENER_FILT_TAP1_SUBEXP_K,
@@ -840,11 +869,12 @@
       WIENER_FILT_TAP2_SUBEXP_K,
       ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
       wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);
-  bits += aom_count_primitive_refsubexpfin(
-      WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
-      WIENER_FILT_TAP0_SUBEXP_K,
-      ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
-      wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
+  if (wiener_win == WIENER_WIN)
+    bits += aom_count_primitive_refsubexpfin(
+        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
+        WIENER_FILT_TAP0_SUBEXP_K,
+        ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
+        wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
   bits += aom_count_primitive_refsubexpfin(
       WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
       WIENER_FILT_TAP1_SUBEXP_K,
@@ -861,9 +891,11 @@
 #define USE_WIENER_REFINEMENT_SEARCH 1
 static int64_t finer_tile_search_wiener(const YV12_BUFFER_CONFIG *src,
                                         AV1_COMP *cpi, RestorationInfo *rsi,
-                                        int start_step, int plane, int tile_idx,
+                                        int start_step, int plane,
+                                        int wiener_win, int tile_idx,
                                         int partial_frame,
                                         YV12_BUFFER_CONFIG *dst_frame) {
+  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
   int64_t err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                      tile_idx, 0, 0, dst_frame);
   (void)start_step;
@@ -875,7 +907,7 @@
                     WIENER_FILT_TAP2_MAXV };
   // printf("err  pre = %"PRId64"\n", err);
   for (int s = start_step; s >= 1; s >>= 1) {
-    for (int p = 0; p < WIENER_HALFWIN; ++p) {
+    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
       int skip = 0;
       do {
         if (rsi[plane].wiener_info[tile_idx].hfilter[p] - s >= tap_min[p]) {
@@ -918,7 +950,7 @@
         break;
       } while (1);
     }
-    for (int p = 0; p < WIENER_HALFWIN; ++p) {
+    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
       int skip = 0;
       do {
         if (rsi[plane].wiener_info[tile_idx].vfilter[p] - s >= tap_min[p]) {
@@ -982,7 +1014,7 @@
   double H[WIENER_WIN2 * WIENER_WIN2];
   double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  int width, height, src_stride, dgd_stride;
+  int width, height, src_stride, dgd_stride, wiener_win;
   uint8_t *dgd_buffer, *src_buffer;
   if (plane == AOM_PLANE_Y) {
     width = src->y_crop_width;
@@ -995,6 +1027,7 @@
     assert(height == dgd->y_crop_height);
     assert(width == src->y_crop_width);
     assert(height == src->y_crop_height);
+    wiener_win = WIENER_WIN;
   } else {
     width = src->uv_crop_width;
     height = src->uv_crop_height;
@@ -1004,6 +1037,7 @@
     dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
     assert(width == dgd->uv_crop_width);
     assert(height == dgd->uv_crop_height);
+    wiener_win = WIENER_WIN_CHROMA;
   }
   double score;
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
@@ -1047,26 +1081,29 @@
                              &v_start, &v_end);
 #if CONFIG_HIGHBITDEPTH
     if (cm->use_highbitdepth)
-      compute_stats_highbd(dgd_buffer, src_buffer, h_start, h_end, v_start,
-                           v_end, dgd_stride, src_stride, M, H);
+      compute_stats_highbd(wiener_win, dgd_buffer, src_buffer, h_start, h_end,
+                           v_start, v_end, dgd_stride, src_stride, M, H);
     else
 #endif  // CONFIG_HIGHBITDEPTH
-      compute_stats(dgd_buffer, src_buffer, h_start, h_end, v_start, v_end,
-                    dgd_stride, src_stride, M, H);
+      compute_stats(wiener_win, dgd_buffer, src_buffer, h_start, h_end, v_start,
+                    v_end, dgd_stride, src_stride, M, H);
 
     type[tile_idx] = RESTORE_WIENER;
 
-    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
+    if (!wiener_decompose_sep_sym(wiener_win, M, H, vfilterd, hfilterd)) {
       type[tile_idx] = RESTORE_NONE;
       continue;
     }
-    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
-    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
+    quantize_sym_filter(wiener_win, vfilterd,
+                        rsi[plane].wiener_info[tile_idx].vfilter);
+    quantize_sym_filter(wiener_win, hfilterd,
+                        rsi[plane].wiener_info[tile_idx].hfilter);
 
     // Filter score computes the value of the function x'*A*x - x'*b for the
     // learned filter and compares it against identity filer. If there is no
     // reduction in the function, the filter is reverted back to identity
-    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
+    score = compute_score(wiener_win, M, H,
+                          rsi[plane].wiener_info[tile_idx].vfilter,
                           rsi[plane].wiener_info[tile_idx].hfilter);
     if (score > 0.0) {
       type[tile_idx] = RESTORE_NONE;
@@ -1075,11 +1112,17 @@
     aom_clear_system_state();
 
     rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
-    err = finer_tile_search_wiener(src, cpi, rsi, 4, plane, tile_idx,
-                                   partial_frame, dst_frame);
-    bits =
-        count_wiener_bits(&rsi[plane].wiener_info[tile_idx], &ref_wiener_info)
-        << AV1_PROB_COST_SHIFT;
+    err = finer_tile_search_wiener(src, cpi, rsi, 4, plane, wiener_win,
+                                   tile_idx, partial_frame, dst_frame);
+    if (wiener_win != WIENER_WIN) {
+      assert(rsi[plane].wiener_info[tile_idx].vfilter[0] == 0 &&
+             rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - 1] == 0);
+      assert(rsi[plane].wiener_info[tile_idx].hfilter[0] == 0 &&
+             rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - 1] == 0);
+    }
+    bits = count_wiener_bits(wiener_win, &rsi[plane].wiener_info[tile_idx],
+                             &ref_wiener_info)
+           << AV1_PROB_COST_SHIFT;
     bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
     cost_wiener = RDCOST_DBL(x->rdmult, (bits >> 4), err);
     if (cost_wiener >= cost_norestore) {
@@ -1104,9 +1147,9 @@
     memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
            sizeof(wiener_info[tile_idx]));
     if (type[tile_idx] == RESTORE_WIENER) {
-      bits +=
-          count_wiener_bits(&rsi[plane].wiener_info[tile_idx], &ref_wiener_info)
-          << AV1_PROB_COST_SHIFT;
+      bits += count_wiener_bits(wiener_win, &rsi[plane].wiener_info[tile_idx],
+                                &ref_wiener_info)
+              << AV1_PROB_COST_SHIFT;
       memcpy(&ref_wiener_info, &rsi[plane].wiener_info[tile_idx],
              sizeof(ref_wiener_info));
     }
@@ -1205,8 +1248,9 @@
       int tilebits = 0;
       if (restore_types[r][tile_idx] != r) continue;
       if (r == RESTORE_WIENER)
-        tilebits +=
-            count_wiener_bits(&rsi->wiener_info[tile_idx], &ref_wiener_info);
+        tilebits += count_wiener_bits(  // tap0 bits only when win == WIENER_WIN
+            (plane == AOM_PLANE_Y ? WIENER_WIN : WIENER_WIN_CHROMA),
+            &rsi->wiener_info[tile_idx], &ref_wiener_info);
       else if (r == RESTORE_SGRPROJ)
         tilebits +=
             count_sgrproj_bits(&rsi->sgrproj_info[tile_idx], &ref_sgrproj_info);