Add UV wiener loop restoration

Enables loop restoration for the U and V planes, using only the
Wiener filter. The self-guided and domain-transform filters do not
work very well for the UV components, hence they are disabled there.
For each U and V plane, a single set of Wiener parameters is
sent. The filter is applied tile-wise, but all tiles of a plane
use the same parameters.

BDRATE (Global PSNR) results:
-----------------------------
lowres: -1.266% (up from -0.666%, good improvement)
midres: -1.815% (up from -1.792%, tiny improvement)

Tiling on UV components will be explored subsequently.

Change-Id: Ib5be93121c4e88e05edf3c36c46488df3cfcd1e2
diff --git a/aom_scale/aom_scale_rtcd.pl b/aom_scale/aom_scale_rtcd.pl
index 925530a..c91e60d 100644
--- a/aom_scale/aom_scale_rtcd.pl
+++ b/aom_scale/aom_scale_rtcd.pl
@@ -18,10 +18,14 @@
 
 add_proto qw/void aom_yv12_extend_frame_borders/, "struct yv12_buffer_config *ybf";
 
-add_proto qw/void aom_yv12_copy_frame/, "const struct yv12_buffer_config *src_ybc, struct yv12_buffer_config *dst_ybc";
+add_proto qw/void aom_yv12_copy_frame/, "const struct yv12_buffer_config *src_bc, struct yv12_buffer_config *dst_bc";
 
 add_proto qw/void aom_yv12_copy_y/, "const struct yv12_buffer_config *src_ybc, struct yv12_buffer_config *dst_ybc";
 
+add_proto qw/void aom_yv12_copy_u/, "const struct yv12_buffer_config *src_bc, struct yv12_buffer_config *dst_bc";
+
+add_proto qw/void aom_yv12_copy_v/, "const struct yv12_buffer_config *src_bc, struct yv12_buffer_config *dst_bc";
+
 if (aom_config("CONFIG_AV1") eq "yes") {
     add_proto qw/void aom_extend_frame_borders/, "struct yv12_buffer_config *ybf";
     specialize qw/aom_extend_frame_borders dspr2/;
diff --git a/aom_scale/generic/yv12extend.c b/aom_scale/generic/yv12extend.c
index 5c7a052..28431a11 100644
--- a/aom_scale/generic/yv12extend.c
+++ b/aom_scale/generic/yv12extend.c
@@ -228,79 +228,79 @@
 // Copies the source image into the destination image and updates the
 // destination's UMV borders.
 // Note: The frames are assumed to be identical in size.
-void aom_yv12_copy_frame_c(const YV12_BUFFER_CONFIG *src_ybc,
-                           YV12_BUFFER_CONFIG *dst_ybc) {
+void aom_yv12_copy_frame_c(const YV12_BUFFER_CONFIG *src_bc,
+                           YV12_BUFFER_CONFIG *dst_bc) {
   int row;
-  const uint8_t *src = src_ybc->y_buffer;
-  uint8_t *dst = dst_ybc->y_buffer;
+  const uint8_t *src = src_bc->y_buffer;
+  uint8_t *dst = dst_bc->y_buffer;
 
 #if 0
   /* These assertions are valid in the codec, but the libaom-tester uses
    * this code slightly differently.
    */
-  assert(src_ybc->y_width == dst_ybc->y_width);
-  assert(src_ybc->y_height == dst_ybc->y_height);
+  assert(src_bc->y_width == dst_bc->y_width);
+  assert(src_bc->y_height == dst_bc->y_height);
 #endif
 
 #if CONFIG_AOM_HIGHBITDEPTH
-  if (src_ybc->flags & YV12_FLAG_HIGHBITDEPTH) {
-    assert(dst_ybc->flags & YV12_FLAG_HIGHBITDEPTH);
-    for (row = 0; row < src_ybc->y_height; ++row) {
-      memcpy_short_addr(dst, src, src_ybc->y_width);
-      src += src_ybc->y_stride;
-      dst += dst_ybc->y_stride;
+  if (src_bc->flags & YV12_FLAG_HIGHBITDEPTH) {
+    assert(dst_bc->flags & YV12_FLAG_HIGHBITDEPTH);
+    for (row = 0; row < src_bc->y_height; ++row) {
+      memcpy_short_addr(dst, src, src_bc->y_width);
+      src += src_bc->y_stride;
+      dst += dst_bc->y_stride;
     }
 
-    src = src_ybc->u_buffer;
-    dst = dst_ybc->u_buffer;
+    src = src_bc->u_buffer;
+    dst = dst_bc->u_buffer;
 
-    for (row = 0; row < src_ybc->uv_height; ++row) {
-      memcpy_short_addr(dst, src, src_ybc->uv_width);
-      src += src_ybc->uv_stride;
-      dst += dst_ybc->uv_stride;
+    for (row = 0; row < src_bc->uv_height; ++row) {
+      memcpy_short_addr(dst, src, src_bc->uv_width);
+      src += src_bc->uv_stride;
+      dst += dst_bc->uv_stride;
     }
 
-    src = src_ybc->v_buffer;
-    dst = dst_ybc->v_buffer;
+    src = src_bc->v_buffer;
+    dst = dst_bc->v_buffer;
 
-    for (row = 0; row < src_ybc->uv_height; ++row) {
-      memcpy_short_addr(dst, src, src_ybc->uv_width);
-      src += src_ybc->uv_stride;
-      dst += dst_ybc->uv_stride;
+    for (row = 0; row < src_bc->uv_height; ++row) {
+      memcpy_short_addr(dst, src, src_bc->uv_width);
+      src += src_bc->uv_stride;
+      dst += dst_bc->uv_stride;
     }
 
-    aom_yv12_extend_frame_borders_c(dst_ybc);
+    aom_yv12_extend_frame_borders_c(dst_bc);
     return;
   } else {
-    assert(!(dst_ybc->flags & YV12_FLAG_HIGHBITDEPTH));
+    assert(!(dst_bc->flags & YV12_FLAG_HIGHBITDEPTH));
   }
 #endif
 
-  for (row = 0; row < src_ybc->y_height; ++row) {
-    memcpy(dst, src, src_ybc->y_width);
-    src += src_ybc->y_stride;
-    dst += dst_ybc->y_stride;
+  for (row = 0; row < src_bc->y_height; ++row) {
+    memcpy(dst, src, src_bc->y_width);
+    src += src_bc->y_stride;
+    dst += dst_bc->y_stride;
   }
 
-  src = src_ybc->u_buffer;
-  dst = dst_ybc->u_buffer;
+  src = src_bc->u_buffer;
+  dst = dst_bc->u_buffer;
 
-  for (row = 0; row < src_ybc->uv_height; ++row) {
-    memcpy(dst, src, src_ybc->uv_width);
-    src += src_ybc->uv_stride;
-    dst += dst_ybc->uv_stride;
+  for (row = 0; row < src_bc->uv_height; ++row) {
+    memcpy(dst, src, src_bc->uv_width);
+    src += src_bc->uv_stride;
+    dst += dst_bc->uv_stride;
   }
 
-  src = src_ybc->v_buffer;
-  dst = dst_ybc->v_buffer;
+  src = src_bc->v_buffer;
+  dst = dst_bc->v_buffer;
 
-  for (row = 0; row < src_ybc->uv_height; ++row) {
-    memcpy(dst, src, src_ybc->uv_width);
-    src += src_ybc->uv_stride;
-    dst += dst_ybc->uv_stride;
+  for (row = 0; row < src_bc->uv_height; ++row) {
+    memcpy(dst, src, src_bc->uv_width);
+    src += src_bc->uv_stride;
+    dst += dst_bc->uv_stride;
   }
 
-  aom_yv12_extend_frame_borders_c(dst_ybc);
+  aom_yv12_extend_frame_borders_c(dst_bc);
 }
 
 void aom_yv12_copy_y_c(const YV12_BUFFER_CONFIG *src_ybc,
@@ -320,7 +320,7 @@
     }
     return;
   }
-#endif
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 
   for (row = 0; row < src_ybc->y_height; ++row) {
     memcpy(dst, src, src_ybc->y_width);
@@ -328,3 +328,55 @@
     dst += dst_ybc->y_stride;
   }
 }
+
+void aom_yv12_copy_u_c(const YV12_BUFFER_CONFIG *src_bc,
+                       YV12_BUFFER_CONFIG *dst_bc) {
+  int row;
+  const uint8_t *src = src_bc->u_buffer;
+  uint8_t *dst = dst_bc->u_buffer;
+
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (src_bc->flags & YV12_FLAG_HIGHBITDEPTH) {
+    const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
+    uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
+    for (row = 0; row < src_bc->uv_height; ++row) {
+      memcpy(dst16, src16, src_bc->uv_width * sizeof(uint16_t));
+      src16 += src_bc->uv_stride;
+      dst16 += dst_bc->uv_stride;
+    }
+    return;
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
+  for (row = 0; row < src_bc->uv_height; ++row) {
+    memcpy(dst, src, src_bc->uv_width);
+    src += src_bc->uv_stride;
+    dst += dst_bc->uv_stride;
+  }
+}
+
+void aom_yv12_copy_v_c(const YV12_BUFFER_CONFIG *src_bc,
+                       YV12_BUFFER_CONFIG *dst_bc) {
+  int row;
+  const uint8_t *src = src_bc->v_buffer;
+  uint8_t *dst = dst_bc->v_buffer;
+
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (src_bc->flags & YV12_FLAG_HIGHBITDEPTH) {
+    const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
+    uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
+    for (row = 0; row < src_bc->uv_height; ++row) {
+      memcpy(dst16, src16, src_bc->uv_width * sizeof(uint16_t));
+      src16 += src_bc->uv_stride;
+      dst16 += dst_bc->uv_stride;
+    }
+    return;
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
+  for (row = 0; row < src_bc->uv_height; ++row) {
+    memcpy(dst, src, src_bc->uv_width);
+    src += src_bc->uv_stride;
+    dst += dst_bc->uv_stride;
+  }
+}
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index 471ae6c..50446b2 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -89,7 +89,9 @@
 
 #if CONFIG_LOOP_RESTORATION
 void av1_alloc_restoration_buffers(AV1_COMMON *cm) {
-  av1_alloc_restoration_struct(&cm->rst_info, cm->width, cm->height);
+  int p;
+  for (p = 0; p < MAX_MB_PLANE; ++p)
+    av1_alloc_restoration_struct(&cm->rst_info[p], cm->width, cm->height);
   cm->rst_internal.tmpbuf =
       (int32_t *)aom_realloc(cm->rst_internal.tmpbuf, RESTORATION_TMPBUF_SIZE);
   if (cm->rst_internal.tmpbuf == NULL)
@@ -98,7 +100,9 @@
 }
 
 void av1_free_restoration_buffers(AV1_COMMON *cm) {
-  av1_free_restoration_struct(&cm->rst_info);
+  int p;
+  for (p = 0; p < MAX_MB_PLANE; ++p)
+    av1_free_restoration_struct(&cm->rst_info[p]);
   aom_free(cm->rst_internal.tmpbuf);
   cm->rst_internal.tmpbuf = NULL;
 }
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 4d88c2c..7c9416d 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -306,7 +306,7 @@
 
   loop_filter_info_n lf_info;
 #if CONFIG_LOOP_RESTORATION
-  RestorationInfo rst_info;
+  RestorationInfo rst_info[MAX_MB_PLANE];
   RestorationInternal rst_internal;
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index 1c00dd2..d725302 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -23,9 +23,6 @@
 
 static int domaintxfmrf_vtable[DOMAINTXFMRF_ITERS][DOMAINTXFMRF_PARAMS][256];
 
-// Whether to filter only y or not
-static const int override_y_only[RESTORE_TYPES] = { 1, 1, 1, 1, 1 };
-
 static const int domaintxfmrf_params[DOMAINTXFMRF_PARAMS] = {
   32,  40,  48,  56,  64,  68,  72,  76,  80,  82,  84,  86,  88,
   90,  92,  94,  96,  97,  98,  99,  100, 101, 102, 103, 104, 105,
@@ -100,53 +97,12 @@
 
 void av1_loop_restoration_precal() { GenDomainTxfmRFVtable(); }
 
-void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
-                               int kf, int width, int height) {
-  int i, tile_idx;
-  rst->rsi = rsi;
+static void loop_restoration_init(RestorationInternal *rst, int kf, int width,
+                                  int height) {
   rst->keyframe = kf;
-  rst->subsampling_x = 0;
-  rst->subsampling_y = 0;
   rst->ntiles =
       av1_get_rest_ntiles(width, height, &rst->tile_width, &rst->tile_height,
                           &rst->nhtiles, &rst->nvtiles);
-  if (rsi->frame_restoration_type == RESTORE_WIENER) {
-    for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      if (rsi->wiener_info[tile_idx].level) {
-        rsi->wiener_info[tile_idx].vfilter[WIENER_HALFWIN] =
-            rsi->wiener_info[tile_idx].hfilter[WIENER_HALFWIN] =
-                WIENER_FILT_STEP;
-        for (i = 0; i < WIENER_HALFWIN; ++i) {
-          rsi->wiener_info[tile_idx].vfilter[WIENER_WIN - 1 - i] =
-              rsi->wiener_info[tile_idx].vfilter[i];
-          rsi->wiener_info[tile_idx].hfilter[WIENER_WIN - 1 - i] =
-              rsi->wiener_info[tile_idx].hfilter[i];
-          rsi->wiener_info[tile_idx].vfilter[WIENER_HALFWIN] -=
-              2 * rsi->wiener_info[tile_idx].vfilter[i];
-          rsi->wiener_info[tile_idx].hfilter[WIENER_HALFWIN] -=
-              2 * rsi->wiener_info[tile_idx].hfilter[i];
-        }
-      }
-    }
-  } else if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
-    for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      if (rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
-        rsi->wiener_info[tile_idx].vfilter[WIENER_HALFWIN] =
-            rsi->wiener_info[tile_idx].hfilter[WIENER_HALFWIN] =
-                WIENER_FILT_STEP;
-        for (i = 0; i < WIENER_HALFWIN; ++i) {
-          rsi->wiener_info[tile_idx].vfilter[WIENER_WIN - 1 - i] =
-              rsi->wiener_info[tile_idx].vfilter[i];
-          rsi->wiener_info[tile_idx].hfilter[WIENER_WIN - 1 - i] =
-              rsi->wiener_info[tile_idx].hfilter[i];
-          rsi->wiener_info[tile_idx].vfilter[WIENER_HALFWIN] -=
-              2 * rsi->wiener_info[tile_idx].vfilter[i];
-          rsi->wiener_info[tile_idx].hfilter[WIENER_HALFWIN] -=
-              2 * rsi->wiener_info[tile_idx].hfilter[i];
-        }
-      }
-    }
-  }
 }
 
 static void extend_frame(uint8_t *data, int width, int height, int stride) {
@@ -1043,9 +999,10 @@
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
-void av1_loop_restoration_rows(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
-                               int start_mi_row, int end_mi_row, int y_only,
-                               YV12_BUFFER_CONFIG *dst) {
+static void loop_restoration_rows(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
+                                  int start_mi_row, int end_mi_row,
+                                  int components_pattern, RestorationInfo *rsi,
+                                  YV12_BUFFER_CONFIG *dst) {
   const int ywidth = frame->y_crop_width;
   const int ystride = frame->y_stride;
   const int uvwidth = frame->uv_crop_width;
@@ -1064,29 +1021,44 @@
     loop_domaintxfmrf_filter_highbd, loop_switchable_filter_highbd
   };
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-  restore_func_type restore_func =
-      restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
+  restore_func_type restore_func;
 #if CONFIG_AOM_HIGHBITDEPTH
-  restore_func_highbd_type restore_func_highbd =
-      restore_funcs_highbd[cm->rst_internal.rsi->frame_restoration_type];
+  restore_func_highbd_type restore_func_highbd;
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   YV12_BUFFER_CONFIG dst_;
 
   yend = AOMMIN(yend, cm->height);
   uvend = AOMMIN(uvend, cm->subsampling_y ? (cm->height + 1) >> 1 : cm->height);
 
-  if (cm->rst_internal.rsi->frame_restoration_type == RESTORE_NONE) {
-    if (dst) {
-      if (y_only)
-        aom_yv12_copy_y(frame, dst);
-      else
-        aom_yv12_copy_frame(frame, dst);
+  if (components_pattern == (1 << AOM_PLANE_Y)) {
+    // Only y
+    if (rsi[0].frame_restoration_type == RESTORE_NONE) {
+      if (dst) aom_yv12_copy_y(frame, dst);
+      return;
     }
-    return;
+  } else if (components_pattern == (1 << AOM_PLANE_U)) {
+    // Only U
+    if (rsi[1].frame_restoration_type == RESTORE_NONE) {
+      if (dst) aom_yv12_copy_u(frame, dst);
+      return;
+    }
+  } else if (components_pattern == (1 << AOM_PLANE_V)) {
+    // Only V
+    if (rsi[2].frame_restoration_type == RESTORE_NONE) {
+      if (dst) aom_yv12_copy_v(frame, dst);
+      return;
+    }
+  } else if (components_pattern ==
+             ((1 << AOM_PLANE_Y) | (1 << AOM_PLANE_U) | (1 << AOM_PLANE_V))) {
+    // All components
+    if (rsi[0].frame_restoration_type == RESTORE_NONE &&
+        rsi[1].frame_restoration_type == RESTORE_NONE &&
+        rsi[2].frame_restoration_type == RESTORE_NONE) {
+      if (dst) aom_yv12_copy_frame(frame, dst);
+      return;
+    }
   }
 
-  if (y_only == 0)
-    y_only = override_y_only[cm->rst_internal.rsi->frame_restoration_type];
   if (!dst) {
     dst = &dst_;
     memset(dst, 0, sizeof(YV12_BUFFER_CONFIG));
@@ -1100,68 +1072,103 @@
                          "Failed to allocate restoration dst buffer");
   }
 
+  if ((components_pattern >> AOM_PLANE_Y) & 1) {
+    if (rsi[0].frame_restoration_type != RESTORE_NONE) {
+      cm->rst_internal.subsampling_x = 0;
+      cm->rst_internal.subsampling_y = 0;
+      cm->rst_internal.rsi = &rsi[0];
+      restore_func =
+          restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
 #if CONFIG_AOM_HIGHBITDEPTH
-  if (cm->use_highbitdepth)
-    restore_func_highbd(frame->y_buffer + ystart * ystride, ywidth,
-                        yend - ystart, ystride, &cm->rst_internal,
-                        cm->bit_depth, dst->y_buffer + ystart * dst->y_stride,
-                        dst->y_stride);
-  else
+      restore_func_highbd =
+          restore_funcs_highbd[cm->rst_internal.rsi->frame_restoration_type];
+      if (cm->use_highbitdepth)
+        restore_func_highbd(
+            frame->y_buffer + ystart * ystride, ywidth, yend - ystart, ystride,
+            &cm->rst_internal, cm->bit_depth,
+            dst->y_buffer + ystart * dst->y_stride, dst->y_stride);
+      else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-    restore_func(frame->y_buffer + ystart * ystride, ywidth, yend - ystart,
-                 ystride, &cm->rst_internal,
-                 dst->y_buffer + ystart * dst->y_stride, dst->y_stride);
-  if (!y_only) {
+        restore_func(frame->y_buffer + ystart * ystride, ywidth, yend - ystart,
+                     ystride, &cm->rst_internal,
+                     dst->y_buffer + ystart * dst->y_stride, dst->y_stride);
+    } else {
+      aom_yv12_copy_y(frame, dst);
+    }
+  }
+
+  if ((components_pattern >> AOM_PLANE_U) & 1) {
     cm->rst_internal.subsampling_x = cm->subsampling_x;
     cm->rst_internal.subsampling_y = cm->subsampling_y;
+    cm->rst_internal.rsi = &rsi[1];
+    if (rsi[1].frame_restoration_type != RESTORE_NONE) {
+      restore_func =
+          restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
 #if CONFIG_AOM_HIGHBITDEPTH
-    if (cm->use_highbitdepth) {
-      restore_func_highbd(
-          frame->u_buffer + uvstart * uvstride, uvwidth, uvend - uvstart,
-          uvstride, &cm->rst_internal, cm->bit_depth,
-          dst->u_buffer + uvstart * dst->uv_stride, dst->uv_stride);
-      restore_func_highbd(
-          frame->v_buffer + uvstart * uvstride, uvwidth, uvend - uvstart,
-          uvstride, &cm->rst_internal, cm->bit_depth,
-          dst->v_buffer + uvstart * dst->uv_stride, dst->uv_stride);
+      restore_func_highbd =
+          restore_funcs_highbd[cm->rst_internal.rsi->frame_restoration_type];
+      if (cm->use_highbitdepth)
+        restore_func_highbd(
+            frame->u_buffer + uvstart * uvstride, uvwidth, uvend - uvstart,
+            uvstride, &cm->rst_internal, cm->bit_depth,
+            dst->u_buffer + uvstart * dst->uv_stride, dst->uv_stride);
+      else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        restore_func(frame->u_buffer + uvstart * uvstride, uvwidth,
+                     uvend - uvstart, uvstride, &cm->rst_internal,
+                     dst->u_buffer + uvstart * dst->uv_stride, dst->uv_stride);
     } else {
-#endif  // CONFIG_AOM_HIGHBITDEPTH
-      restore_func(frame->u_buffer + uvstart * uvstride, uvwidth,
-                   uvend - uvstart, uvstride, &cm->rst_internal,
-                   dst->u_buffer + uvstart * dst->uv_stride, dst->uv_stride);
-      restore_func(frame->v_buffer + uvstart * uvstride, uvwidth,
-                   uvend - uvstart, uvstride, &cm->rst_internal,
-                   dst->v_buffer + uvstart * dst->uv_stride, dst->uv_stride);
-#if CONFIG_AOM_HIGHBITDEPTH
+      aom_yv12_copy_u(frame, dst);
     }
+  }
+
+  if ((components_pattern >> AOM_PLANE_V) & 1) {
+    cm->rst_internal.subsampling_x = cm->subsampling_x;
+    cm->rst_internal.subsampling_y = cm->subsampling_y;
+    cm->rst_internal.rsi = &rsi[2];
+    if (rsi[2].frame_restoration_type != RESTORE_NONE) {
+      restore_func =
+          restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
+#if CONFIG_AOM_HIGHBITDEPTH
+      restore_func_highbd =
+          restore_funcs_highbd[cm->rst_internal.rsi->frame_restoration_type];
+      if (cm->use_highbitdepth)
+        restore_func_highbd(
+            frame->v_buffer + uvstart * uvstride, uvwidth, uvend - uvstart,
+            uvstride, &cm->rst_internal, cm->bit_depth,
+            dst->v_buffer + uvstart * dst->uv_stride, dst->uv_stride);
+      else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
+        restore_func(frame->v_buffer + uvstart * uvstride, uvwidth,
+                     uvend - uvstart, uvstride, &cm->rst_internal,
+                     dst->v_buffer + uvstart * dst->uv_stride, dst->uv_stride);
+    } else {
+      aom_yv12_copy_v(frame, dst);
+    }
   }
 
   if (dst == &dst_) {
-    if (y_only)
-      aom_yv12_copy_y(dst, frame);
-    else
-      aom_yv12_copy_frame(dst, frame);
+    if ((components_pattern >> AOM_PLANE_Y) & 1) aom_yv12_copy_y(dst, frame);
+    if ((components_pattern >> AOM_PLANE_U) & 1) aom_yv12_copy_u(dst, frame);
+    if ((components_pattern >> AOM_PLANE_V) & 1) aom_yv12_copy_v(dst, frame);
     aom_free_frame_buffer(dst);
   }
 }
 
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
-                                RestorationInfo *rsi, int y_only,
+                                RestorationInfo *rsi, int components_pattern,
                                 int partial_frame, YV12_BUFFER_CONFIG *dst) {
   int start_mi_row, end_mi_row, mi_rows_to_filter;
-  if (rsi->frame_restoration_type != RESTORE_NONE) {
-    start_mi_row = 0;
-    mi_rows_to_filter = cm->mi_rows;
-    if (partial_frame && cm->mi_rows > 8) {
-      start_mi_row = cm->mi_rows >> 1;
-      start_mi_row &= 0xfffffff8;
-      mi_rows_to_filter = AOMMAX(cm->mi_rows / 8, 8);
-    }
-    end_mi_row = start_mi_row + mi_rows_to_filter;
-    av1_loop_restoration_init(&cm->rst_internal, rsi,
-                              cm->frame_type == KEY_FRAME, cm->width,
-                              cm->height);
-    av1_loop_restoration_rows(frame, cm, start_mi_row, end_mi_row, y_only, dst);
+  start_mi_row = 0;
+  mi_rows_to_filter = cm->mi_rows;
+  if (partial_frame && cm->mi_rows > 8) {
+    start_mi_row = cm->mi_rows >> 1;
+    start_mi_row &= 0xfffffff8;
+    mi_rows_to_filter = AOMMAX(cm->mi_rows / 8, 8);
   }
+  end_mi_row = start_mi_row + mi_rows_to_filter;
+  loop_restoration_init(&cm->rst_internal, cm->frame_type == KEY_FRAME,
+                        cm->width, cm->height);
+  loop_restoration_rows(frame, cm, start_mi_row, end_mi_row, components_pattern,
+                        rsi, dst);
 }
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 8bcbb91..cf44962 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -233,14 +233,9 @@
                                          int32_t *tmpbuf);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 void decode_xq(int *xqd, int *xq);
-void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
-                               int kf, int width, int height);
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, struct AV1Common *cm,
-                                RestorationInfo *rsi, int y_only,
+                                RestorationInfo *rsi, int components_pattern,
                                 int partial_frame, YV12_BUFFER_CONFIG *dst);
-void av1_loop_restoration_rows(YV12_BUFFER_CONFIG *frame, struct AV1Common *cm,
-                               int start_mi_row, int end_mi_row, int y_only,
-                               YV12_BUFFER_CONFIG *dst);
 void av1_loop_restoration_precal();
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 37ec17b..499da3d 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2294,7 +2294,8 @@
 #if CONFIG_LOOP_RESTORATION
 static void decode_restoration_mode(AV1_COMMON *cm,
                                     struct aom_read_bit_buffer *rb) {
-  RestorationInfo *rsi = &cm->rst_info;
+  int p;
+  RestorationInfo *rsi = &cm->rst_info[0];
   if (aom_rb_read_bit(rb)) {
     if (aom_rb_read_bit(rb))
       rsi->frame_restoration_type =
@@ -2305,27 +2306,39 @@
     rsi->frame_restoration_type =
         aom_rb_read_bit(rb) ? RESTORE_SWITCHABLE : RESTORE_NONE;
   }
+  for (p = 1; p < MAX_MB_PLANE; ++p) {
+    cm->rst_info[p].frame_restoration_type =
+        aom_rb_read_bit(rb) ? RESTORE_WIENER : RESTORE_NONE;
+  }
 }
 
 static void read_wiener_filter(WienerInfo *wiener_info, aom_reader *rb) {
-  wiener_info->vfilter[0] =
+  wiener_info->vfilter[0] = wiener_info->vfilter[WIENER_WIN - 1] =
       aom_read_literal(rb, WIENER_FILT_TAP0_BITS, ACCT_STR) +
       WIENER_FILT_TAP0_MINV;
-  wiener_info->vfilter[1] =
+  wiener_info->vfilter[1] = wiener_info->vfilter[WIENER_WIN - 2] =
       aom_read_literal(rb, WIENER_FILT_TAP1_BITS, ACCT_STR) +
       WIENER_FILT_TAP1_MINV;
-  wiener_info->vfilter[2] =
+  wiener_info->vfilter[2] = wiener_info->vfilter[WIENER_WIN - 3] =
       aom_read_literal(rb, WIENER_FILT_TAP2_BITS, ACCT_STR) +
       WIENER_FILT_TAP2_MINV;
-  wiener_info->hfilter[0] =
+  wiener_info->vfilter[WIENER_HALFWIN] =
+      WIENER_FILT_STEP -
+      2 * (wiener_info->vfilter[0] + wiener_info->vfilter[1] +
+           wiener_info->vfilter[2]);
+  wiener_info->hfilter[0] = wiener_info->hfilter[WIENER_WIN - 1] =
       aom_read_literal(rb, WIENER_FILT_TAP0_BITS, ACCT_STR) +
       WIENER_FILT_TAP0_MINV;
-  wiener_info->hfilter[1] =
+  wiener_info->hfilter[1] = wiener_info->hfilter[WIENER_WIN - 2] =
       aom_read_literal(rb, WIENER_FILT_TAP1_BITS, ACCT_STR) +
       WIENER_FILT_TAP1_MINV;
-  wiener_info->hfilter[2] =
+  wiener_info->hfilter[2] = wiener_info->hfilter[WIENER_WIN - 3] =
       aom_read_literal(rb, WIENER_FILT_TAP2_BITS, ACCT_STR) +
       WIENER_FILT_TAP2_MINV;
+  wiener_info->hfilter[WIENER_HALFWIN] =
+      WIENER_FILT_STEP -
+      2 * (wiener_info->hfilter[0] + wiener_info->hfilter[1] +
+           wiener_info->hfilter[2]);
 }
 
 static void read_sgrproj_filter(SgrprojInfo *sgrproj_info, aom_reader *rb) {
@@ -2343,10 +2356,10 @@
 }
 
 static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
-  int i;
-  RestorationInfo *rsi = &cm->rst_info;
+  int i, p;
   const int ntiles =
       av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  RestorationInfo *rsi = &cm->rst_info[0];
   if (rsi->frame_restoration_type != RESTORE_NONE) {
     if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
       for (i = 0; i < ntiles; ++i) {
@@ -2399,6 +2412,19 @@
       }
     }
   }
+  for (p = 1; p < MAX_MB_PLANE; ++p) {
+    rsi = &cm->rst_info[p];
+    if (rsi->frame_restoration_type == RESTORE_WIENER) {
+      rsi->restoration_type[0] = RESTORE_WIENER;
+      rsi->wiener_info[0].level = 1;
+      read_wiener_filter(&rsi->wiener_info[0], rb);
+      for (i = 1; i < ntiles; ++i) {
+        rsi->restoration_type[i] = RESTORE_WIENER;
+        memcpy(&rsi->wiener_info[i], &rsi->wiener_info[0],
+               sizeof(rsi->wiener_info[0]));
+      }
+    }
+  }
 }
 #endif  // CONFIG_LOOP_RESTORATION
 
@@ -4596,8 +4622,10 @@
     *p_data_end = decode_tiles(pbi, data + first_partition_size, data_end);
   }
 #if CONFIG_LOOP_RESTORATION
-  if (cm->rst_info.frame_restoration_type != RESTORE_NONE) {
-    av1_loop_restoration_frame(new_fb, cm, &cm->rst_info, 0, 0, NULL);
+  if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
+    av1_loop_restoration_frame(new_fb, cm, cm->rst_info, 7, 0, NULL);
   }
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 6152652..9e153c5 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3022,8 +3022,9 @@
 #if CONFIG_LOOP_RESTORATION
 static void encode_restoration_mode(AV1_COMMON *cm,
                                     struct aom_write_bit_buffer *wb) {
-  RestorationInfo *rst = &cm->rst_info;
-  switch (rst->frame_restoration_type) {
+  int p;
+  RestorationInfo *rsi = &cm->rst_info[0];
+  switch (rsi->frame_restoration_type) {
     case RESTORE_NONE:
       aom_wb_write_bit(wb, 0);
       aom_wb_write_bit(wb, 0);
@@ -3048,6 +3049,14 @@
       break;
     default: assert(0);
   }
+  for (p = 1; p < MAX_MB_PLANE; ++p) {
+    rsi = &cm->rst_info[p];
+    switch (rsi->frame_restoration_type) {
+      case RESTORE_NONE: aom_wb_write_bit(wb, 0); break;
+      case RESTORE_WIENER: aom_wb_write_bit(wb, 1); break;
+      default: assert(0);
+    }
+  }
 }
 
 static void write_wiener_filter(WienerInfo *wiener_info, aom_writer *wb) {
@@ -3079,8 +3088,8 @@
 }
 
 static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
-  int i;
-  RestorationInfo *rsi = &cm->rst_info;
+  int i, p;
+  RestorationInfo *rsi = &cm->rst_info[0];
   if (rsi->frame_restoration_type != RESTORE_NONE) {
     if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
       // RESTORE_SWITCHABLE
@@ -3121,6 +3130,14 @@
       }
     }
   }
+  for (p = 1; p < MAX_MB_PLANE; ++p) {
+    rsi = &cm->rst_info[p];
+    if (rsi->frame_restoration_type == RESTORE_WIENER) {
+      write_wiener_filter(&rsi->wiener_info[0], wb);
+    } else if (rsi->frame_restoration_type != RESTORE_NONE) {
+      assert(0);
+    }
+  }
 }
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 1aff23f..1713dd4 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -459,7 +459,8 @@
   aom_free_frame_buffer(&cpi->last_frame_db);
   aom_free_frame_buffer(&cpi->trial_frame_rst);
   aom_free(cpi->extra_rstbuf);
-  av1_free_restoration_struct(&cpi->rst_search);
+  for (i = 0; i < MAX_MB_PLANE; ++i)
+    av1_free_restoration_struct(&cpi->rst_search[i]);
 #endif  // CONFIG_LOOP_RESTORATION
   aom_free_frame_buffer(&cpi->scaled_source);
   aom_free_frame_buffer(&cpi->scaled_last_source);
@@ -712,6 +713,9 @@
 }
 
 static void alloc_util_frame_buffers(AV1_COMP *cpi) {
+#if CONFIG_LOOP_RESTORATION
+  int i;
+#endif  // CONFIG_LOOP_RESTORATION
   AV1_COMMON *const cm = &cpi->common;
   if (aom_realloc_frame_buffer(&cpi->last_frame_uf, cm->width, cm->height,
                                cm->subsampling_x, cm->subsampling_y,
@@ -747,7 +751,8 @@
   if (!cpi->extra_rstbuf)
     aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
                        "Failed to allocate extra rstbuf for restoration");
-  av1_alloc_restoration_struct(&cpi->rst_search, cm->width, cm->height);
+  for (i = 0; i < MAX_MB_PLANE; ++i)
+    av1_alloc_restoration_struct(&cpi->rst_search[i], cm->width, cm->height);
 #endif  // CONFIG_LOOP_RESTORATION
 
   if (aom_realloc_frame_buffer(&cpi->scaled_source, cm->width, cm->height,
@@ -3496,9 +3501,10 @@
   }
 #endif
 #if CONFIG_LOOP_RESTORATION
-  if (cm->rst_info.frame_restoration_type != RESTORE_NONE) {
-    av1_loop_restoration_frame(cm->frame_to_show, cm, &cm->rst_info, 0, 0,
-                               NULL);
+  if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
+    av1_loop_restoration_frame(cm->frame_to_show, cm, cm->rst_info, 7, 0, NULL);
   }
 #endif  // CONFIG_LOOP_RESTORATION
   aom_extend_frame_inner_borders(cm->frame_to_show);
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 59ebb23..f1d06b98 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -404,9 +404,9 @@
 #if CONFIG_LOOP_RESTORATION
   YV12_BUFFER_CONFIG last_frame_db;
   YV12_BUFFER_CONFIG trial_frame_rst;
-  uint8_t *extra_rstbuf;       // Extra buffers used in restoration search
-  RestorationInfo rst_search;  // Used for encoder side search
-#endif                         // CONFIG_LOOP_RESTORATION
+  uint8_t *extra_rstbuf;  // Extra buffers used in restoration search
+  RestorationInfo rst_search[MAX_MB_PLANE];  // Used for encoder side search
+#endif                                       // CONFIG_LOOP_RESTORATION
 
   // Ambient reconstruction err target for force key frames
   int64_t ambient_err;
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 40c3486..a07cea2 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -23,6 +23,7 @@
 
 #include "av1/common/onyxc_int.h"
 #include "av1/common/quant_common.h"
+#include "av1/common/restoration.h"
 
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/picklpf.h"
@@ -41,16 +42,20 @@
                                     const YV12_BUFFER_CONFIG *dst,
                                     const AV1_COMMON *cm, int h_start,
                                     int width, int v_start, int height,
-                                    int y_only) {
-  int64_t filt_err;
+                                    int components_pattern) {
+  int64_t filt_err = 0;
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
-    filt_err =
-        aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
-    if (!y_only) {
+    if ((components_pattern >> AOM_PLANE_Y) & 1) {
+      filt_err +=
+          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
+    }
+    if ((components_pattern >> AOM_PLANE_U) & 1) {
       filt_err += aom_highbd_get_u_sse_part(
           src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
           v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    }
+    if ((components_pattern >> AOM_PLANE_V) & 1) {
       filt_err += aom_highbd_get_v_sse_part(
           src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
           v_start >> cm->subsampling_y, height >> cm->subsampling_y);
@@ -58,11 +63,15 @@
     return filt_err;
   }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-  filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
-  if (!y_only) {
+  if ((components_pattern >> AOM_PLANE_Y) & 1) {
+    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
+  }
+  if ((components_pattern >> AOM_PLANE_U) & 1) {
     filt_err += aom_get_u_sse_part(
         src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
         v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+  }
+  if ((components_pattern >> AOM_PLANE_V) & 1) {
     filt_err += aom_get_u_sse_part(
         src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
         v_start >> cm->subsampling_y, height >> cm->subsampling_y);
@@ -70,10 +79,41 @@
   return filt_err;
 }
 
+// Computes the sum of squared errors between two whole frames.
+//   src, dst: the two frames whose planes are compared.
+//   components_pattern: plane mask; bit AOM_PLANE_Y / AOM_PLANE_U /
+//     AOM_PLANE_V selects whether that plane contributes to the sum.
+// Returns the SSE accumulated over the selected planes (0 if none is set).
+static int64_t sse_restoration_frame(const YV12_BUFFER_CONFIG *src,
+                                     const YV12_BUFFER_CONFIG *dst,
+                                     int components_pattern) {
+  const int use_y = (components_pattern >> AOM_PLANE_Y) & 1;
+  const int use_u = (components_pattern >> AOM_PLANE_U) & 1;
+  const int use_v = (components_pattern >> AOM_PLANE_V) & 1;
+  int64_t filt_err = 0;
+#if CONFIG_AOM_HIGHBITDEPTH
+  // This helper receives no AV1_COMMON, so the high-bitdepth decision is
+  // taken from the buffer flags rather than cm->use_highbitdepth.
+  if (src->flags & YV12_FLAG_HIGHBITDEPTH) {
+    if (use_y) filt_err += aom_highbd_get_y_sse(src, dst);
+    if (use_u) filt_err += aom_highbd_get_u_sse(src, dst);
+    if (use_v) filt_err += aom_highbd_get_v_sse(src, dst);
+    return filt_err;
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+  // Each plane accumulates with +=, as in the high-bitdepth branch, so the
+  // total does not depend on which plane is tested first.
+  if (use_y) filt_err += aom_get_y_sse(src, dst);
+  if (use_u) filt_err += aom_get_u_sse(src, dst);
+  if (use_v) filt_err += aom_get_v_sse(src, dst);
+  return filt_err;
+}
+
 static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
-                                    int y_only, int partial_frame, int tile_idx,
-                                    int subtile_idx, int subtile_bits,
+                                    int components_pattern, int partial_frame,
+                                    int tile_idx, int subtile_idx,
+                                    int subtile_bits,
                                     YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
   int64_t filt_err;
@@ -83,41 +123,27 @@
                                          &tile_height, &nhtiles, &nvtiles);
   (void)ntiles;
 
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
-                             dst_frame);
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
+                             partial_frame, dst_frame);
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                            nvtiles, tile_width, tile_height, cm->width,
                            cm->height, 0, 0, &h_start, &h_end, &v_start,
                            &v_end);
   filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
-                                  v_start, v_end - v_start, y_only);
+                                  v_start, v_end - v_start, components_pattern);
 
   return filt_err;
 }
 
 static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *const cpi, RestorationInfo *rsi,
-                                     int y_only, int partial_frame,
+                                     int components_pattern, int partial_frame,
                                      YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
   int64_t filt_err;
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
-                             dst_frame);
-#if CONFIG_AOM_HIGHBITDEPTH
-  if (cm->use_highbitdepth) {
-    filt_err = aom_highbd_get_y_sse(src, dst_frame);
-    if (!y_only) {
-      filt_err += aom_highbd_get_u_sse(src, dst_frame);
-      filt_err += aom_highbd_get_v_sse(src, dst_frame);
-    }
-    return filt_err;
-  }
-#endif  // CONFIG_AOM_HIGHBITDEPTH
-  filt_err = aom_get_y_sse(src, dst_frame);
-  if (!y_only) {
-    filt_err += aom_get_u_sse(src, dst_frame);
-    filt_err += aom_get_v_sse(src, dst_frame);
-  }
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
+                             partial_frame, dst_frame);  // scored vs. src below
+  filt_err = sse_restoration_frame(src, dst_frame, components_pattern);
   return filt_err;
 }
 
@@ -299,7 +325,7 @@
   MACROBLOCK *x = &cpi->td.mb;
   AV1_COMMON *const cm = &cpi->common;
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = &cpi->rst_search[0];
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   // Allocate for the src buffer at high precision
@@ -516,7 +542,7 @@
   MACROBLOCK *x = &cpi->td.mb;
   AV1_COMMON *const cm = &cpi->common;
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = &cpi->rst_search[0];
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
@@ -528,7 +554,6 @@
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
 
   rsi->frame_restoration_type = RESTORE_DOMAINTXFMRF;
-  rsi->domaintxfmrf_info = cpi->rst_search.domaintxfmrf_info;
 
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
     rsi->domaintxfmrf_info[tile_idx].level = 0;
@@ -843,7 +868,7 @@
   return 1;
 }
 
-// Computes the function x'*A*x - x'*b for the learned filters, and compares
+// Computes the function x'*H*x - 2*x'*M for learned 2D filter x, and compares
 // against identity filters; Final score is defined as the difference between
 // the function values
 static double compute_score(double *M, double *H, int *vfilt, int *hfilt) {
@@ -852,9 +877,7 @@
   double P = 0, Q = 0;
   double iP = 0, iQ = 0;
   double Score, iScore;
-  int w;
   double a[WIENER_WIN], b[WIENER_WIN];
-  w = WIENER_WIN;
   a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
   for (i = 0; i < WIENER_HALFWIN; ++i) {
     a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
@@ -862,17 +885,18 @@
     a[WIENER_HALFWIN] -= 2 * a[i];
     b[WIENER_HALFWIN] -= 2 * b[i];
   }
-  for (k = 0; k < w; ++k) {
-    for (l = 0; l < w; ++l) ab[k * w + l] = a[l] * b[k];
+  for (k = 0; k < WIENER_WIN; ++k) {
+    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
   }
-  for (k = 0; k < w * w; ++k) {
+  for (k = 0; k < WIENER_WIN2; ++k) {
     P += ab[k] * M[k];
-    for (l = 0; l < w * w; ++l) Q += ab[k] * H[k * w * w + l] * ab[l];
+    for (l = 0; l < WIENER_WIN2; ++l)
+      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
   }
   Score = Q - 2 * P;
 
-  iP = M[(w * w) >> 1];
-  iQ = H[((w * w) >> 1) * w * w + ((w * w) >> 1)];
+  iP = M[WIENER_WIN2 >> 1];
+  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
   iScore = iQ - 2 * iP;
 
   return Score - iScore;
@@ -887,6 +911,121 @@
   fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
   fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
   fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
+  // Satisfy filter constraints
+  fi[WIENER_WIN - 1] = fi[0];
+  fi[WIENER_WIN - 2] = fi[1];
+  fi[WIENER_WIN - 3] = fi[2];
+  fi[3] = WIENER_FILT_STEP - 2 * (fi[0] + fi[1] + fi[2]);
+}
+
+// Chooses RESTORE_WIENER or RESTORE_NONE for chroma `plane` by RD cost, stores
+// the choice (and filter taps) in `info`, and returns the cost of that choice.
+static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
+                               int filter_level, int partial_frame, int plane,
+                               RestorationInfo *info,
+                               YV12_BUFFER_CONFIG *dst_frame) {
+  WienerInfo *wiener_info = info->wiener_info;  // output copy of the taps
+  AV1_COMMON *const cm = &cpi->common;
+  RestorationInfo *rsi = cpi->rst_search;  // search state, indexed by plane
+  int64_t err;  // SSE from try_restoration_frame()
+  int bits;  // rate term handed to RDCOST_DBL (after >> 4)
+  double cost_wiener = 0, cost_norestore = 0;
+  MACROBLOCK *x = &cpi->td.mb;
+  double M[WIENER_WIN2];  // filled by compute_stats*()
+  double H[WIENER_WIN2 * WIENER_WIN2];  // filled by compute_stats*()
+  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];  // unquantized taps
+  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
+  const int width = src->uv_crop_width;
+  const int height = src->uv_crop_height;
+  const int src_stride = src->uv_stride;
+  const int dgd_stride = dgd->uv_stride;
+  double score;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+
+  assert(width == dgd->uv_crop_width);
+  assert(height == dgd->uv_crop_height);
+
+  // Back up the unfiltered recon, loop-filter it in place, back that up too.
+  aom_yv12_copy_frame(cm->frame_to_show, &cpi->last_frame_uf);
+  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
+                        0, partial_frame);
+  aom_yv12_copy_frame(cm->frame_to_show, &cpi->last_frame_db);
+
+  rsi[plane].frame_restoration_type = RESTORE_NONE;
+
+  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
+                              dst_frame);
+  bits = 0;  // baseline: no rate charged
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
+  rsi[plane].frame_restoration_type = RESTORE_WIENER;
+  h_start = v_start = WIENER_HALFWIN;  // skip a WIENER_HALFWIN border
+  h_end = width - WIENER_HALFWIN;
+  v_end = height - WIENER_HALFWIN;
+  if (plane == AOM_PLANE_U) {
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (cm->use_highbitdepth)
+      compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
+                           v_start, v_end, dgd_stride, src_stride, M, H);
+    else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+      compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
+                    v_end, dgd_stride, src_stride, M, H);
+  } else if (plane == AOM_PLANE_V) {
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (cm->use_highbitdepth)
+      compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
+                           v_start, v_end, dgd_stride, src_stride, M, H);
+    else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+      compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
+                    v_end, dgd_stride, src_stride, M, H);
+  } else {
+    assert(0);  // only the U and V planes are handled here
+  }
+  if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+  quantize_sym_filter(vfilterd, rsi[plane].wiener_info[0].vfilter);
+  quantize_sym_filter(hfilterd, rsi[plane].wiener_info[0].hfilter);
+
+  // score = (x'*H*x - 2*x'*M) minus its identity-filter value; > 0: give up.
+  score = compute_score(M, H, rsi[plane].wiener_info[0].vfilter,
+                        rsi[plane].wiener_info[0].hfilter);
+  if (score > 0.0) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+
+  info->frame_restoration_type = RESTORE_WIENER;
+  rsi[plane].restoration_type[0] = info->restoration_type[0] = RESTORE_WIENER;
+  rsi[plane].wiener_info[0].level = 1;
+  memcpy(&wiener_info[0], &rsi[plane].wiener_info[0], sizeof(wiener_info[0]));
+  for (tile_idx = 1; tile_idx < ntiles; ++tile_idx) {  // replicate tile 0
+    info->restoration_type[tile_idx] = RESTORE_WIENER;
+    memcpy(&rsi[plane].wiener_info[tile_idx], &rsi[plane].wiener_info[0],
+           sizeof(rsi[plane].wiener_info[0]));
+    memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[0],
+           sizeof(rsi[plane].wiener_info[0]));
+  }
+  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
+                              dst_frame);
+  bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;  // one filter per plane
+  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  if (cost_wiener > cost_norestore) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+
+  aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);  // restore recon
+  return cost_wiener;
 }
 
 static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
@@ -895,7 +1034,7 @@
                             YV12_BUFFER_CONFIG *dst_frame) {
   WienerInfo *wiener_info = info->wiener_info;
   AV1_COMMON *const cm = &cpi->common;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = cpi->rst_search;
   int64_t err;
   int bits;
   double cost_wiener, cost_norestore;
@@ -912,7 +1051,6 @@
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   int i;
-
   const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                          &tile_height, &nhtiles, &nvtiles);
   assert(width == dgd->y_crop_width);
@@ -984,7 +1122,7 @@
       wiener_info[tile_idx].level = 0;
     } else {
       wiener_info[tile_idx].level = 1;
-      for (i = 0; i < WIENER_HALFWIN; ++i) {
+      for (i = 0; i < WIENER_WIN; ++i) {
         wiener_info[tile_idx].vfilter[i] =
             rsi->wiener_info[tile_idx].vfilter[i];
         wiener_info[tile_idx].hfilter[i] =
@@ -1005,7 +1143,7 @@
     rsi->wiener_info[tile_idx].level = wiener_info[tile_idx].level;
     if (wiener_info[tile_idx].level) {
       bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
-      for (i = 0; i < WIENER_HALFWIN; ++i) {
+      for (i = 0; i < WIENER_WIN; ++i) {
         rsi->wiener_info[tile_idx].vfilter[i] =
             wiener_info[tile_idx].vfilter[i];
         rsi->wiener_info[tile_idx].hfilter[i] =
@@ -1117,7 +1255,7 @@
 
   if (method == LPF_PICK_MINIMAL_LPF && lf->filter_level) {
     lf->filter_level = 0;
-    cm->rst_info.frame_restoration_type = RESTORE_NONE;
+    cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
   } else if (method >= LPF_PICK_FROM_Q) {
     const int min_filter_level = 0;
     const int max_filter_level = av1_get_max_filter_level(cpi);
@@ -1155,10 +1293,10 @@
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; ++r) {
     cost_restore[r] = search_restore_fun[r](
         src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
-        &cm->rst_info, tile_cost[r], &cpi->trial_frame_rst);
+        &cm->rst_info[0], tile_cost[r], &cpi->trial_frame_rst);
   }
   cost_restore[RESTORE_SWITCHABLE] = search_switchable_restoration(
-      cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE, &cm->rst_info,
+      cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE, &cm->rst_info[0],
       tile_cost);
 
   best_cost_restore = DBL_MAX;
@@ -1169,12 +1307,26 @@
       best_cost_restore = cost_restore[r];
     }
   }
-  cm->rst_info.frame_restoration_type = best_restore;
+  cm->rst_info[0].frame_restoration_type = best_restore;
+
+  // Color components
+  search_wiener_uv(src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+                   AOM_PLANE_U, &cm->rst_info[AOM_PLANE_U],
+                   &cpi->trial_frame_rst);
+  search_wiener_uv(src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+                   AOM_PLANE_V, &cm->rst_info[AOM_PLANE_V],
+                   &cpi->trial_frame_rst);
   /*
+  printf("restore types: %d %d %d\n",
+         cm->rst_info[0].frame_restoration_type,
+         cm->rst_info[1].frame_restoration_type,
+         cm->rst_info[2].frame_restoration_type);
   printf("Frame %d/%d frame_restore_type %d : %f %f %f %f %f\n",
          cm->current_video_frame, cm->show_frame,
-         cm->rst_info.frame_restoration_type, cost_restore[0], cost_restore[1],
+         cm->rst_info[0].frame_restoration_type, cost_restore[0],
+  cost_restore[1],
          cost_restore[2], cost_restore[3], cost_restore[4]);
          */
+
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) aom_free(tile_cost[r]);
 }