Add UV wiener loop restoration

Enables Wiener-based loop restoration for the UV (chroma)
planes; only the Wiener filter is enabled for them. The
self-guided and domain-transform filters do not work very
well for UV components, hence they are disabled for UV.
For each UV plane of a frame, a single set of Wiener
parameters is sent. The parameters are applied tile-wise,
but all tiles use the same parameters.

BDRATE (Global PSNR) results:
-----------------------------
lowres: -1.266% (previously -0.666%; a good improvement)
midres: -1.815% (previously -1.792%; a tiny improvement)

Tiling on UV components will be explored subsequently.

Change-Id: Ib5be93121c4e88e05edf3c36c46488df3cfcd1e2
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 40c3486..a07cea2 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -23,6 +23,7 @@
 
 #include "av1/common/onyxc_int.h"
 #include "av1/common/quant_common.h"
+#include "av1/common/restoration.h"
 
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/picklpf.h"
@@ -41,16 +42,20 @@
                                     const YV12_BUFFER_CONFIG *dst,
                                     const AV1_COMMON *cm, int h_start,
                                     int width, int v_start, int height,
-                                    int y_only) {
-  int64_t filt_err;
+                                    int components_pattern) {
+  int64_t filt_err = 0;
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
-    filt_err =
-        aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
-    if (!y_only) {
+    if ((components_pattern >> AOM_PLANE_Y) & 1) {
+      filt_err +=
+          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
+    }
+    if ((components_pattern >> AOM_PLANE_U) & 1) {
       filt_err += aom_highbd_get_u_sse_part(
           src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
           v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    }
+    if ((components_pattern >> AOM_PLANE_V) & 1) {
       filt_err += aom_highbd_get_v_sse_part(
           src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
           v_start >> cm->subsampling_y, height >> cm->subsampling_y);
@@ -58,11 +63,15 @@
     return filt_err;
   }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-  filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
-  if (!y_only) {
+  if ((components_pattern >> AOM_PLANE_Y) & 1) {
+    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
+  }
+  if ((components_pattern >> AOM_PLANE_U) & 1) {
     filt_err += aom_get_u_sse_part(
         src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
         v_start >> cm->subsampling_y, height >> cm->subsampling_y);
-    filt_err += aom_get_u_sse_part(
+  }
+  if ((components_pattern >> AOM_PLANE_V) & 1) {
+    filt_err += aom_get_v_sse_part(
         src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
         v_start >> cm->subsampling_y, height >> cm->subsampling_y);
@@ -70,10 +79,41 @@
   return filt_err;
 }
 
+static int64_t sse_restoration_frame(const YV12_BUFFER_CONFIG *src,
+                                     const YV12_BUFFER_CONFIG *dst,
+                                     int components_pattern) {
+  int64_t filt_err = 0;  // SSE summed over planes set in components_pattern
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (dst->flags & YV12_FLAG_HIGHBITDEPTH) {  // cm is not in scope
+    if ((components_pattern >> AOM_PLANE_Y) & 1) {
+      filt_err += aom_highbd_get_y_sse(src, dst);
+    }
+    if ((components_pattern >> AOM_PLANE_U) & 1) {
+      filt_err += aom_highbd_get_u_sse(src, dst);
+    }
+    if ((components_pattern >> AOM_PLANE_V) & 1) {
+      filt_err += aom_highbd_get_v_sse(src, dst);
+    }
+    return filt_err;
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+  if ((components_pattern >> AOM_PLANE_Y) & 1) {
+    filt_err += aom_get_y_sse(src, dst);
+  }
+  if ((components_pattern >> AOM_PLANE_U) & 1) {
+    filt_err += aom_get_u_sse(src, dst);
+  }
+  if ((components_pattern >> AOM_PLANE_V) & 1) {
+    filt_err += aom_get_v_sse(src, dst);
+  }
+  return filt_err;
+}
+
 static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
-                                    int y_only, int partial_frame, int tile_idx,
-                                    int subtile_idx, int subtile_bits,
+                                    int components_pattern, int partial_frame,
+                                    int tile_idx, int subtile_idx,
+                                    int subtile_bits,
                                     YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
   int64_t filt_err;
@@ -83,41 +123,27 @@
                                          &tile_height, &nhtiles, &nvtiles);
   (void)ntiles;
 
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
-                             dst_frame);
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
+                             partial_frame, dst_frame);
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                            nvtiles, tile_width, tile_height, cm->width,
                            cm->height, 0, 0, &h_start, &h_end, &v_start,
                            &v_end);
   filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
-                                  v_start, v_end - v_start, y_only);
+                                  v_start, v_end - v_start, components_pattern);
 
   return filt_err;
 }
 
 static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *const cpi, RestorationInfo *rsi,
-                                     int y_only, int partial_frame,
+                                     int components_pattern, int partial_frame,
                                      YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
   int64_t filt_err;
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
-                             dst_frame);
-#if CONFIG_AOM_HIGHBITDEPTH
-  if (cm->use_highbitdepth) {
-    filt_err = aom_highbd_get_y_sse(src, dst_frame);
-    if (!y_only) {
-      filt_err += aom_highbd_get_u_sse(src, dst_frame);
-      filt_err += aom_highbd_get_v_sse(src, dst_frame);
-    }
-    return filt_err;
-  }
-#endif  // CONFIG_AOM_HIGHBITDEPTH
-  filt_err = aom_get_y_sse(src, dst_frame);
-  if (!y_only) {
-    filt_err += aom_get_u_sse(src, dst_frame);
-    filt_err += aom_get_v_sse(src, dst_frame);
-  }
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
+                             partial_frame, dst_frame);
+  filt_err = sse_restoration_frame(src, dst_frame, components_pattern);
   return filt_err;
 }
 
@@ -299,7 +325,7 @@
   MACROBLOCK *x = &cpi->td.mb;
   AV1_COMMON *const cm = &cpi->common;
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = &cpi->rst_search[0];
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   // Allocate for the src buffer at high precision
@@ -516,7 +542,7 @@
   MACROBLOCK *x = &cpi->td.mb;
   AV1_COMMON *const cm = &cpi->common;
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = &cpi->rst_search[0];
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
@@ -528,7 +554,6 @@
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
 
   rsi->frame_restoration_type = RESTORE_DOMAINTXFMRF;
-  rsi->domaintxfmrf_info = cpi->rst_search.domaintxfmrf_info;
 
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
     rsi->domaintxfmrf_info[tile_idx].level = 0;
@@ -843,7 +868,7 @@
   return 1;
 }
 
-// Computes the function x'*A*x - x'*b for the learned filters, and compares
+// Computes x'*H*x - 2*x'*M for the learned 2D filter x, and compares
 // against identity filters; Final score is defined as the difference between
 // the function values
 static double compute_score(double *M, double *H, int *vfilt, int *hfilt) {
@@ -852,9 +877,7 @@
   double P = 0, Q = 0;
   double iP = 0, iQ = 0;
   double Score, iScore;
-  int w;
   double a[WIENER_WIN], b[WIENER_WIN];
-  w = WIENER_WIN;
   a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
   for (i = 0; i < WIENER_HALFWIN; ++i) {
     a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
@@ -862,17 +885,18 @@
     a[WIENER_HALFWIN] -= 2 * a[i];
     b[WIENER_HALFWIN] -= 2 * b[i];
   }
-  for (k = 0; k < w; ++k) {
-    for (l = 0; l < w; ++l) ab[k * w + l] = a[l] * b[k];
+  for (k = 0; k < WIENER_WIN; ++k) {
+    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
   }
-  for (k = 0; k < w * w; ++k) {
+  for (k = 0; k < WIENER_WIN2; ++k) {
     P += ab[k] * M[k];
-    for (l = 0; l < w * w; ++l) Q += ab[k] * H[k * w * w + l] * ab[l];
+    for (l = 0; l < WIENER_WIN2; ++l)
+      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
   }
   Score = Q - 2 * P;
 
-  iP = M[(w * w) >> 1];
-  iQ = H[((w * w) >> 1) * w * w + ((w * w) >> 1)];
+  iP = M[WIENER_WIN2 >> 1];
+  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
   iScore = iQ - 2 * iP;
 
   return Score - iScore;
@@ -887,6 +911,121 @@
   fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
   fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
   fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
+  // Satisfy filter constraints: symmetric taps summing to WIENER_FILT_STEP
+  fi[WIENER_WIN - 1] = fi[0];
+  fi[WIENER_WIN - 2] = fi[1];
+  fi[WIENER_WIN - 3] = fi[2];
+  fi[3] = WIENER_FILT_STEP - 2 * (fi[0] + fi[1] + fi[2]);
+}
+
+static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
+                               int filter_level, int partial_frame, int plane,
+                               RestorationInfo *info,
+                               YV12_BUFFER_CONFIG *dst_frame) {
+  WienerInfo *wiener_info = info->wiener_info;
+  AV1_COMMON *const cm = &cpi->common;
+  RestorationInfo *rsi = cpi->rst_search;
+  int64_t err;
+  int bits;
+  double cost_wiener = 0, cost_norestore = 0;
+  MACROBLOCK *x = &cpi->td.mb;
+  double M[WIENER_WIN2];
+  double H[WIENER_WIN2 * WIENER_WIN2];
+  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
+  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
+  const int width = src->uv_crop_width;
+  const int height = src->uv_crop_height;
+  const int src_stride = src->uv_stride;
+  const int dgd_stride = dgd->uv_stride;
+  double score;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+
+  assert(width == dgd->uv_crop_width);
+  assert(height == dgd->uv_crop_height);
+
+  // Make a copy of the unfiltered / processed recon buffer
+  aom_yv12_copy_frame(cm->frame_to_show, &cpi->last_frame_uf);
+  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
+                        0, partial_frame);
+  aom_yv12_copy_frame(cm->frame_to_show, &cpi->last_frame_db);
+
+  rsi[plane].frame_restoration_type = RESTORE_NONE;
+
+  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
+                              dst_frame);
+  bits = 0;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
+  rsi[plane].frame_restoration_type = RESTORE_WIENER;
+  h_start = v_start = WIENER_HALFWIN;
+  h_end = width - WIENER_HALFWIN;
+  v_end = height - WIENER_HALFWIN;
+  if (plane == AOM_PLANE_U) {
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (cm->use_highbitdepth)
+      compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
+                           v_start, v_end, dgd_stride, src_stride, M, H);
+    else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+      compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
+                    v_end, dgd_stride, src_stride, M, H);
+  } else if (plane == AOM_PLANE_V) {
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (cm->use_highbitdepth)
+      compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
+                           v_start, v_end, dgd_stride, src_stride, M, H);
+    else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+      compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
+                    v_end, dgd_stride, src_stride, M, H);
+  } else {
+    assert(0);
+  }
+  if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+  quantize_sym_filter(vfilterd, rsi[plane].wiener_info[0].vfilter);
+  quantize_sym_filter(hfilterd, rsi[plane].wiener_info[0].hfilter);
+
+  // The score is x'*H*x - 2*x'*M for the learned filter minus the same value
+  // for the identity filter. If the learned filter does not reduce the
+  // function (score > 0), Wiener restoration is disabled for this plane.
+  score = compute_score(M, H, rsi[plane].wiener_info[0].vfilter,
+                        rsi[plane].wiener_info[0].hfilter);
+  if (score > 0.0) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+
+  info->frame_restoration_type = RESTORE_WIENER;
+  rsi[plane].restoration_type[0] = info->restoration_type[0] = RESTORE_WIENER;
+  rsi[plane].wiener_info[0].level = 1;
+  memcpy(&wiener_info[0], &rsi[plane].wiener_info[0], sizeof(wiener_info[0]));
+  for (tile_idx = 1; tile_idx < ntiles; ++tile_idx) {
+    info->restoration_type[tile_idx] = RESTORE_WIENER;
+    memcpy(&rsi[plane].wiener_info[tile_idx], &rsi[plane].wiener_info[0],
+           sizeof(rsi[plane].wiener_info[0]));
+    memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[0],
+           sizeof(rsi[plane].wiener_info[0]));
+  }
+  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
+                              dst_frame);
+  bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
+  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  if (cost_wiener > cost_norestore) {
+    info->frame_restoration_type = RESTORE_NONE;
+    aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+    return cost_norestore;
+  }
+
+  aom_yv12_copy_frame(&cpi->last_frame_uf, cm->frame_to_show);
+  return cost_wiener;
 }
 
 static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
@@ -895,7 +1034,7 @@
                             YV12_BUFFER_CONFIG *dst_frame) {
   WienerInfo *wiener_info = info->wiener_info;
   AV1_COMMON *const cm = &cpi->common;
-  RestorationInfo *rsi = &cpi->rst_search;
+  RestorationInfo *rsi = cpi->rst_search;
   int64_t err;
   int bits;
   double cost_wiener, cost_norestore;
@@ -912,7 +1051,6 @@
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   int i;
-
   const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                          &tile_height, &nhtiles, &nvtiles);
   assert(width == dgd->y_crop_width);
@@ -984,7 +1122,7 @@
       wiener_info[tile_idx].level = 0;
     } else {
       wiener_info[tile_idx].level = 1;
-      for (i = 0; i < WIENER_HALFWIN; ++i) {
+      for (i = 0; i < WIENER_WIN; ++i) {
         wiener_info[tile_idx].vfilter[i] =
             rsi->wiener_info[tile_idx].vfilter[i];
         wiener_info[tile_idx].hfilter[i] =
@@ -1005,7 +1143,7 @@
     rsi->wiener_info[tile_idx].level = wiener_info[tile_idx].level;
     if (wiener_info[tile_idx].level) {
       bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
-      for (i = 0; i < WIENER_HALFWIN; ++i) {
+      for (i = 0; i < WIENER_WIN; ++i) {
         rsi->wiener_info[tile_idx].vfilter[i] =
             wiener_info[tile_idx].vfilter[i];
         rsi->wiener_info[tile_idx].hfilter[i] =
@@ -1117,7 +1255,7 @@
 
   if (method == LPF_PICK_MINIMAL_LPF && lf->filter_level) {
     lf->filter_level = 0;
-    cm->rst_info.frame_restoration_type = RESTORE_NONE;
+    cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
   } else if (method >= LPF_PICK_FROM_Q) {
     const int min_filter_level = 0;
     const int max_filter_level = av1_get_max_filter_level(cpi);
@@ -1155,10 +1293,10 @@
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; ++r) {
     cost_restore[r] = search_restore_fun[r](
         src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
-        &cm->rst_info, tile_cost[r], &cpi->trial_frame_rst);
+        &cm->rst_info[0], tile_cost[r], &cpi->trial_frame_rst);
   }
   cost_restore[RESTORE_SWITCHABLE] = search_switchable_restoration(
-      cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE, &cm->rst_info,
+      cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE, &cm->rst_info[0],
       tile_cost);
 
   best_cost_restore = DBL_MAX;
@@ -1169,12 +1307,22 @@
       best_cost_restore = cost_restore[r];
     }
   }
-  cm->rst_info.frame_restoration_type = best_restore;
+  cm->rst_info[0].frame_restoration_type = best_restore;
+
+  // Chroma planes (U, V): only Wiener restoration is searched for these.
+  search_wiener_uv(src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+                   AOM_PLANE_U, &cm->rst_info[AOM_PLANE_U],
+                   &cpi->trial_frame_rst);
+  search_wiener_uv(src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+                   AOM_PLANE_V, &cm->rst_info[AOM_PLANE_V],
+                   &cpi->trial_frame_rst);
   /*
   printf("Frame %d/%d frame_restore_type %d : %f %f %f %f %f\n",
          cm->current_video_frame, cm->show_frame,
-         cm->rst_info.frame_restoration_type, cost_restore[0], cost_restore[1],
+         cm->rst_info[0].frame_restoration_type, cost_restore[0],
+         cost_restore[1],
          cost_restore[2], cost_restore[3], cost_restore[4]);
          */
+
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) aom_free(tile_cost[r]);
 }