Add guided projection filter to loop restoration

BDRATE:
lowres: -1.01% (up from -0.7%)
midres: -1.90% (up from -1.5%)
hdres:  -2.11% (up from ~-1.7%)
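
On the encoder side, each restoration tile is searched over the
self-guided filter parameter sets: for every set the two guided filter
outputs are computed, the projection coefficients are fitted with a
least-squares solve, and the per-tile and frame-level RD costs are
compared against no restoration.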

Change-Id: I1fe04ec9ef90ccc4cc990e09cd45eea82c752e0c
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 62303b7..408c76d 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -33,7 +33,7 @@
                                       int partial_frame, RestorationInfo *info,
                                       double *best_tile_cost);
 
-const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
+const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 3, 3, 2 };
 
 static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                     AV1_COMMON *const cm, int h_start,
@@ -100,6 +100,245 @@
   return filt_err;
 }
 
+static int64_t get_pixel_proj_error(int64_t *src, int width, int height,
+                                    int src_stride, int64_t *dgd,
+                                    int dgd_stride, int64_t *flt1,
+                                    int flt1_stride, int64_t *flt2,
+                                    int flt2_stride, int *xqd) {
+  int i, j;
+  int64_t err = 0;
+  int xq[2];
+  decode_xq(xqd, xq);
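+  // Squared error between the source and the projected reconstruction
+  // u + (xq[0] * f1 + xq[1] * f2) / (1 << SGRPROJ_PRJ_BITS), evaluated
+  // at pixel precision.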
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      const int64_t s = (int64_t)src[i * src_stride + j];
+      const int64_t u = (int64_t)dgd[i * dgd_stride + j];
+      const int64_t f1 = (int64_t)flt1[i * flt1_stride + j] - u;
+      const int64_t f2 = (int64_t)flt2[i * flt2_stride + j] - u;
+      const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
+      const int64_t e =
+          ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
+          ROUND_POWER_OF_TWO(s, SGRPROJ_RST_BITS);
+      err += e * e;
+    }
+  }
+  return err;
+}
+
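+// Finds the projection coefficients that minimize the squared error between
+// the source residual (src - dgd) and a linear combination of the two
+// guided-filter residuals (flt1 - dgd) and (flt2 - dgd). The result is
+// returned in xq with SGRPROJ_PRJ_BITS fractional bits.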
+static void get_proj_subspace(int64_t *src, int width, int height,
+                              int src_stride, int64_t *dgd, int dgd_stride,
+                              int64_t *flt1, int flt1_stride, int64_t *flt2,
+                              int flt2_stride, int *xq) {
+  int i, j;
+  double H[2][2] = { { 0, 0 }, { 0, 0 } };
+  double C[2] = { 0, 0 };
+  double Det;
+  double x[2];
+  const int size = width * height;
+
+  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
+  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
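+  // Accumulate the 2x2 normal equations H * x = C of the least-squares
+  // problem; the defaults above are kept if the system is ill-posed.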
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      const double u = (double)dgd[i * dgd_stride + j];
+      const double s = (double)src[i * src_stride + j] - u;
+      const double f1 = (double)flt1[i * flt1_stride + j] - u;
+      const double f2 = (double)flt2[i * flt2_stride + j] - u;
+      H[0][0] += f1 * f1;
+      H[1][1] += f2 * f2;
+      H[0][1] += f1 * f2;
+      C[0] += f1 * s;
+      C[1] += f2 * s;
+    }
+  }
+  H[0][0] /= size;
+  H[0][1] /= size;
+  H[1][1] /= size;
+  H[1][0] = H[0][1];
+  C[0] /= size;
+  C[1] /= size;
+  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
+  if (Det < 1e-8) return;  // ill-posed, return default values
+  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
+  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
+  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
+  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
+}
+
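+// Converts the projection coefficients xq into the clamped xqd
+// representation stored in SgrprojInfo.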
+void encode_xq(int *xq, int *xqd) {
+  xqd[0] = -xq[0];
+  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
+  xqd[1] = (1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1];
+  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
+}
+
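+// For each candidate self-guided parameter set, runs the two guided filters,
+// fits the projection coefficients and keeps the (eps, xqd) combination with
+// the smallest projection error.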
+static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
+                                          int dat_stride, uint8_t *src8,
+                                          int src_stride, int bit_depth,
+                                          int *eps, int *xqd, void *tmpbuf) {
+  int64_t *flt1 = (int64_t *)tmpbuf;
+  int64_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
+  uint8_t *tmpbuf2 = (uint8_t *)(flt2 + RESTORATION_TILEPELS_MAX);
+  int64_t srd[RESTORATION_TILEPELS_MAX];
+  int64_t dgd[RESTORATION_TILEPELS_MAX];
+  int i, j, ep, bestep = 0;
+  int64_t err, besterr = -1;
+  int exqd[2], bestxqd[2] = { 0, 0 };
+  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
+    int exq[2];
+    if (bit_depth > 8) {
+      uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+      for (i = 0; i < height; ++i) {
+        for (j = 0; j < width; ++j) {
+          flt1[i * width + j] = (int64_t)dat[i * dat_stride + j];
+          flt2[i * width + j] = (int64_t)dat[i * dat_stride + j];
+          dgd[i * width + j] = (int64_t)dat[i * dat_stride + j]
+                               << SGRPROJ_RST_BITS;
+          srd[i * width + j] = (int64_t)src[i * src_stride + j]
+                               << SGRPROJ_RST_BITS;
+        }
+      }
+    } else {
+      uint8_t *src = src8;
+      uint8_t *dat = dat8;
+      for (i = 0; i < height; ++i) {
+        for (j = 0; j < width; ++j) {
+          const int k = i * width + j;
+          const int l = i * dat_stride + j;
+          flt1[k] = (int64_t)dat[l];
+          flt2[k] = (int64_t)dat[l];
+          dgd[k] = (int64_t)dat[l] << SGRPROJ_RST_BITS;
+          srd[k] = (int64_t)src[i * src_stride + j] << SGRPROJ_RST_BITS;
+        }
+      }
+    }
+    av1_selfguided_restoration(flt1, width, height, width, bit_depth,
+                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
+    av1_selfguided_restoration(flt2, width, height, width, bit_depth,
+                               sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
+    get_proj_subspace(srd, width, height, width, dgd, width, flt1, width, flt2,
+                      width, exq);
+    encode_xq(exq, exqd);
+    err = get_pixel_proj_error(srd, width, height, width, dgd, width, flt1,
+                               width, flt2, width, exqd);
+    if (besterr == -1 || err < besterr) {
+      bestep = ep;
+      besterr = err;
+      bestxqd[0] = exqd[0];
+      bestxqd[1] = exqd[1];
+    }
+  }
+  *eps = bestep;
+  xqd[0] = bestxqd[0];
+  xqd[1] = bestxqd[1];
+}
+
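+// RD search for the guided projection (SGRPROJ) restoration type: picks the
+// best self-guided filter per tile, compares it against no restoration and
+// returns the frame-level RD cost.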
+static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
+                             int filter_level, int partial_frame,
+                             RestorationInfo *info, double *best_tile_cost) {
+  SgrprojInfo *sgrproj_info = info->sgrproj_info;
+  double err, cost_norestore, cost_sgrproj;
+  int bits;
+  MACROBLOCK *x = &cpi->td.mb;
+  AV1_COMMON *const cm = &cpi->common;
+  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
+  RestorationInfo rsi;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+  //  Make a copy of the unfiltered / processed recon buffer
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
+  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
+                        1, partial_frame);
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
+
+  rsi.frame_restoration_type = RESTORE_SGRPROJ;
+  rsi.sgrproj_info =
+      (SgrprojInfo *)aom_malloc(sizeof(*rsi.sgrproj_info) * ntiles);
+  assert(rsi.sgrproj_info != NULL);
+
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
+    rsi.sgrproj_info[tile_idx].level = 0;
+  // Compute best Sgrproj filters for each tile
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, cm->width, cm->height, 0, 0, &h_start,
+                             &h_end, &v_start, &v_end);
+    err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                               v_end - v_start);
+    // #bits when a tile is not restored
+    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
+    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    best_tile_cost[tile_idx] = DBL_MAX;
+    search_selfguided_restoration(
+        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
+        v_end - v_start, dgd->y_stride,
+        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
+#if CONFIG_AOM_HIGHBITDEPTH
+        cm->bit_depth,
+#else
+        8,
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        &rsi.sgrproj_info[tile_idx].ep, rsi.sgrproj_info[tile_idx].xqd, tmpbuf);
+    rsi.sgrproj_info[tile_idx].level = 1;
+    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0);
+    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
+    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
+    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (cost_sgrproj >= cost_norestore) {
+      sgrproj_info[tile_idx].level = 0;
+    } else {
+      memcpy(&sgrproj_info[tile_idx], &rsi.sgrproj_info[tile_idx],
+             sizeof(sgrproj_info[tile_idx]));
+      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
+      best_tile_cost[tile_idx] = RDCOST_DBL(
+          x->rdmult, x->rddiv,
+          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
+    }
+    rsi.sgrproj_info[tile_idx].level = 0;
+  }
+  // Cost for Sgrproj filtering
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    bits +=
+        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, sgrproj_info[tile_idx].level);
+    memcpy(&rsi.sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
+           sizeof(sgrproj_info[tile_idx]));
+    if (sgrproj_info[tile_idx].level) {
+      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
+    }
+  }
+  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
+  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
+  aom_free(rsi.sgrproj_info);
+  aom_free(tmpbuf);
+
+  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
+  return cost_sgrproj;
+}
+
 static double search_bilateral(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                int filter_level, int partial_frame,
                                RestorationInfo *info, double *best_tile_cost) {
@@ -520,7 +742,7 @@
   RestorationInfo rsi;
   int64_t err;
   int bits;
-  double cost_wiener, cost_norestore_tile;
+  double cost_wiener, cost_norestore;
   MACROBLOCK *x = &cpi->td.mb;
   double M[RESTORATION_WIN2];
   double H[RESTORATION_WIN2 * RESTORATION_WIN2];
@@ -533,7 +755,7 @@
   double score;
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
-  int i, j;
+  int i;
 
   const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                          &tile_height, &nhtiles, &nvtiles);
@@ -552,7 +774,8 @@
   rsi.wiener_info = (WienerInfo *)aom_malloc(sizeof(*rsi.wiener_info) * ntiles);
   assert(rsi.wiener_info != NULL);
 
-  for (j = 0; j < ntiles; ++j) rsi.wiener_info[j].level = 0;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
+    rsi.wiener_info[tile_idx].level = 0;
 
   // Compute best Wiener filters for each tile
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
@@ -563,7 +786,7 @@
                                v_end - v_start);
     // #bits when a tile is not restored
     bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
-    cost_norestore_tile = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
     best_tile_cost[tile_idx] = DBL_MAX;
 
     av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
@@ -601,7 +824,7 @@
     bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
     bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
     cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
-    if (cost_wiener >= cost_norestore_tile) {
+    if (cost_wiener >= cost_norestore) {
       wiener_info[tile_idx].level = 0;
     } else {
       wiener_info[tile_idx].level = 1;
@@ -632,6 +855,7 @@
   }
   err = try_restoration_frame(src, cpi, &rsi, partial_frame);
   cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
   aom_free(rsi.wiener_info);
 
   aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
@@ -713,7 +937,7 @@
 void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  LPF_PICK_METHOD method) {
   static search_restore_type search_restore_fun[RESTORE_SWITCHABLE_TYPES] = {
-    search_norestore, search_bilateral, search_wiener,
+    search_norestore, search_sgrproj, search_bilateral, search_wiener,
   };
   AV1_COMMON *const cm = &cpi->common;
   struct loopfilter *const lf = &cm->lf;
@@ -734,6 +958,9 @@
   cm->rst_info.wiener_info = (WienerInfo *)aom_realloc(
       cm->rst_info.wiener_info, sizeof(*cm->rst_info.wiener_info) * ntiles);
   assert(cm->rst_info.wiener_info != NULL);
+  cm->rst_info.sgrproj_info = (SgrprojInfo *)aom_realloc(
+      cm->rst_info.sgrproj_info, sizeof(*cm->rst_info.sgrproj_info) * ntiles);
+  assert(cm->rst_info.sgrproj_info != NULL);
 
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++)
     tile_cost[r] = (double *)aom_malloc(sizeof(*tile_cost[0]) * ntiles);
@@ -796,10 +1023,10 @@
   }
   cm->rst_info.frame_restoration_type = best_restore;
   /*
-  printf("Frame %d/%d frame_restore_type %d : %f %f %f %f\n",
+  printf("Frame %d/%d frame_restore_type %d : %f %f %f %f %f\n",
          cm->current_video_frame, cm->show_frame,
-         cm->rst_info.frame_restoration_type,
-         cost_restore[0], cost_restore[1], cost_restore[2], cost_restore[3]);
+         cm->rst_info.frame_restoration_type, cost_restore[0], cost_restore[1],
+         cost_restore[2], cost_restore[3], cost_restore[4]);
          */
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) aom_free(tile_cost[r]);
 }