Extend color restoration to support guided filters

Frame-level self-guided (SGRPROJ) filters can now be selected for the
chroma components, as an alternative to the Wiener filter.

Change-Id: Ie33299d22c15b69741ede55686177b7b8ce8e2b3
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 7d482f0..c1a8dae 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2460,8 +2460,13 @@
         aom_rb_read_bit(rb) ? RESTORE_SWITCHABLE : RESTORE_NONE;
   }
   for (p = 1; p < MAX_MB_PLANE; ++p) {
-    cm->rst_info[p].frame_restoration_type =
-        aom_rb_read_bit(rb) ? RESTORE_WIENER : RESTORE_NONE;
+    rsi = &cm->rst_info[p];
+    if (aom_rb_read_bit(rb)) {
+      rsi->frame_restoration_type =
+          aom_rb_read_bit(rb) ? RESTORE_SGRPROJ : RESTORE_WIENER;
+    } else {
+      rsi->frame_restoration_type = RESTORE_NONE;
+    }
   }
 
   cm->rst_info[0].restoration_tilesize = RESTORATION_TILESIZE_MAX;
@@ -2594,6 +2599,7 @@
   }
   for (p = 1; p < MAX_MB_PLANE; ++p) {
     set_default_wiener(&ref_wiener_info);
+    set_default_sgrproj(&ref_sgrproj_info);
     rsi = &cm->rst_info[p];
     if (rsi->frame_restoration_type == RESTORE_WIENER) {
       for (i = 0; i < ntiles_uv; ++i) {
@@ -2607,6 +2613,21 @@
           read_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, rb);
         }
       }
+    } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
+      for (i = 0; i < ntiles_uv; ++i) {
+        if (ntiles_uv > 1)
+          rsi->restoration_type[i] =
+              aom_read(rb, RESTORE_NONE_SGRPROJ_PROB, ACCT_STR)
+                  ? RESTORE_SGRPROJ
+                  : RESTORE_NONE;
+        else
+          rsi->restoration_type[i] = RESTORE_SGRPROJ;
+        if (rsi->restoration_type[i] == RESTORE_SGRPROJ) {
+          read_sgrproj_filter(&rsi->sgrproj_info[i], &ref_sgrproj_info, rb);
+        }
+      }
+    } else if (rsi->frame_restoration_type != RESTORE_NONE) {
+      assert(0);
     }
   }
 }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 5f92c41..a972573 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3370,7 +3370,14 @@
     rsi = &cm->rst_info[p];
     switch (rsi->frame_restoration_type) {
       case RESTORE_NONE: aom_wb_write_bit(wb, 0); break;
-      case RESTORE_WIENER: aom_wb_write_bit(wb, 1); break;
+      case RESTORE_WIENER:
+        aom_wb_write_bit(wb, 1);
+        aom_wb_write_bit(wb, 0);
+        break;
+      case RESTORE_SGRPROJ:
+        aom_wb_write_bit(wb, 1);
+        aom_wb_write_bit(wb, 1);
+        break;
       default: assert(0);
     }
   }
@@ -3483,6 +3490,7 @@
   }
   for (p = 1; p < MAX_MB_PLANE; ++p) {
     set_default_wiener(&ref_wiener_info);
+    set_default_sgrproj(&ref_sgrproj_info);
     rsi = &cm->rst_info[p];
     if (rsi->frame_restoration_type == RESTORE_WIENER) {
       for (i = 0; i < ntiles_uv; ++i) {
@@ -3493,6 +3501,15 @@
           write_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, wb);
         }
       }
+    } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
+      for (i = 0; i < ntiles_uv; ++i) {
+        if (ntiles_uv > 1)
+          aom_write(wb, rsi->restoration_type[i] != RESTORE_NONE,
+                    RESTORE_NONE_SGRPROJ_PROB);
+        if (rsi->restoration_type[i] != RESTORE_NONE) {
+          write_sgrproj_filter(&rsi->sgrproj_info[i], &ref_sgrproj_info, wb);
+        }
+      }
     } else if (rsi->frame_restoration_type != RESTORE_NONE) {
       assert(0);
     }
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 80c6153..d73f32c 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -364,6 +364,121 @@
   return bits;
 }
 
+static double search_sgrproj_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
+                                int partial_frame, int plane,
+                                RestorationInfo *info, RestorationType *type,
+                                YV12_BUFFER_CONFIG *dst_frame) {
+  SgrprojInfo *sgrproj_info = info->sgrproj_info;  // receives chosen filters
+  int64_t err;
+  double cost_norestore, cost_sgrproj, cost_sgrproj_frame;
+  int bits;  // rate term; handed to RDCOST_DBL as (bits >> 4)
+  MACROBLOCK *x = &cpi->td.mb;
+  AV1_COMMON *const cm = &cpi->common;
+  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
+  // Per-tile sgrproj search for one chroma plane; returns the frame RD cost.
+  const int width = src->uv_crop_width;
+  const int height = src->uv_crop_height;
+  const int src_stride = src->uv_stride;
+  const int dgd_stride = dgd->uv_stride;
+
+  RestorationInfo *rsi = cpi->rst_search;  // state used for trial filtering
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  // Tile layout for this plane (nothing is allocated here)
+  const int ntiles = av1_get_rest_ntiles(
+      width, height, cm->rst_info[plane].restoration_tilesize, &tile_width,
+      &tile_height, &nhtiles, &nvtiles);
+  SgrprojInfo ref_sgrproj_info;
+  set_default_sgrproj(&ref_sgrproj_info);
+  assert(width == dgd->uv_crop_width);
+  assert(height == dgd->uv_crop_height);
+
+  rsi[plane].frame_restoration_type = RESTORE_SGRPROJ;
+  // All tiles start as NONE; each trial below enables one tile at a time.
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
+  }
+  // Per tile: search sgrproj parameters, keep them only if RD beats NONE
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, width, height, 0, 0, &h_start, &h_end,
+                             &v_start, &v_end);
+    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
+                               h_end - h_start, v_start, v_end - v_start,
+                               (1 << plane));
+    // Rate of signalling that this tile is not restored
+    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
+    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (plane == AOM_PLANE_U) {  // U and V branches differ only in buffers
+      search_selfguided_restoration(
+          dgd->u_buffer + v_start * dgd_stride + h_start, h_end - h_start,
+          v_end - v_start, dgd_stride,
+          src->u_buffer + v_start * src_stride + h_start, src_stride,
+#if CONFIG_HIGHBITDEPTH
+          cm->bit_depth,
+#else
+          8,
+#endif  // CONFIG_HIGHBITDEPTH
+          &rsi[plane].sgrproj_info[tile_idx].ep,
+          rsi[plane].sgrproj_info[tile_idx].xqd, cm->rst_internal.tmpbuf);
+    } else if (plane == AOM_PLANE_V) {
+      search_selfguided_restoration(
+          dgd->v_buffer + v_start * dgd_stride + h_start, h_end - h_start,
+          v_end - v_start, dgd_stride,
+          src->v_buffer + v_start * src_stride + h_start, src_stride,
+#if CONFIG_HIGHBITDEPTH
+          cm->bit_depth,
+#else
+          8,
+#endif  // CONFIG_HIGHBITDEPTH
+          &rsi[plane].sgrproj_info[tile_idx].ep,
+          rsi[plane].sgrproj_info[tile_idx].xqd, cm->rst_internal.tmpbuf);
+    } else {
+      assert(0);
+    }
+    rsi[plane].restoration_type[tile_idx] = RESTORE_SGRPROJ;
+    err = try_restoration_tile(src, cpi, rsi, (1 << plane), partial_frame,
+                               tile_idx, 0, 0, dst_frame);
+    bits = count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
+                              &ref_sgrproj_info)
+           << AV1_PROB_COST_SHIFT;
+    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
+    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (cost_sgrproj >= cost_norestore) {  // ties favor NONE
+      type[tile_idx] = RESTORE_NONE;
+    } else {
+      type[tile_idx] = RESTORE_SGRPROJ;
+      memcpy(&sgrproj_info[tile_idx], &rsi[plane].sgrproj_info[tile_idx],
+             sizeof(sgrproj_info[tile_idx]));
+      memcpy(&ref_sgrproj_info, &sgrproj_info[tile_idx],
+             sizeof(ref_sgrproj_info));  // reference for the next tile's rate
+    }
+    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;  // end of trial
+  }
+  // Frame-level cost: recount rate from the default reference, filter frame
+  set_default_sgrproj(&ref_sgrproj_info);
+  bits = frame_level_restore_bits[rsi[plane].frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    bits +=
+        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
+    memcpy(&rsi[plane].sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
+           sizeof(sgrproj_info[tile_idx]));
+    if (type[tile_idx] == RESTORE_SGRPROJ) {
+      bits += count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
+                                 &ref_sgrproj_info)
+              << AV1_PROB_COST_SHIFT;
+      memcpy(&ref_sgrproj_info, &rsi[plane].sgrproj_info[tile_idx],
+             sizeof(ref_sgrproj_info));
+    }
+    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
+  }
+  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
+                              dst_frame);
+  cost_sgrproj_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  return cost_sgrproj_frame;
+}
+
 static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                              int partial_frame, RestorationInfo *info,
                              RestorationType *type, double *best_tile_cost,
@@ -895,7 +1010,7 @@
   RestorationInfo *rsi = cpi->rst_search;
   int64_t err;
   int bits;
-  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
+  double cost_wiener, cost_norestore, cost_wiener_frame;
   MACROBLOCK *x = &cpi->td.mb;
   double M[WIENER_WIN2];
   double H[WIENER_WIN2 * WIENER_WIN2];
@@ -908,19 +1023,14 @@
   double score;
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
-  const int ntiles =
-      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
-                          &tile_width, &tile_height, &nhtiles, &nvtiles);
+  const int ntiles = av1_get_rest_ntiles(
+      width, height, cm->rst_info[plane].restoration_tilesize, &tile_width,
+      &tile_height, &nhtiles, &nvtiles);
   WienerInfo ref_wiener_info;
   set_default_wiener(&ref_wiener_info);
   assert(width == dgd->uv_crop_width);
   assert(height == dgd->uv_crop_height);
 
-  rsi[plane].frame_restoration_type = RESTORE_NONE;
-  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
-  bits = 0;
-  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
-
   rsi[plane].frame_restoration_type = RESTORE_WIENER;
 
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
@@ -1036,15 +1146,7 @@
   err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
                               dst_frame);
   cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
-
-  if (cost_wiener_frame < cost_norestore_frame) {
-    info->frame_restoration_type = RESTORE_WIENER;
-  } else {
-    info->frame_restoration_type = RESTORE_NONE;
-  }
-
-  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
-                                                        : cost_norestore_frame;
+  return cost_wiener_frame;
 }
 
 static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
@@ -1183,11 +1285,41 @@
   return cost_wiener;
 }
 
+static double search_norestore_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
+                                  int partial_frame, int plane,
+                                  RestorationInfo *info, RestorationType *type,
+                                  YV12_BUFFER_CONFIG *dst_frame) {
+  double cost_norestore;
+  int64_t err;
+  int bits;
+  MACROBLOCK *x = &cpi->td.mb;
+  AV1_COMMON *const cm = &cpi->common;
+  const int width = src->uv_crop_width;
+  const int height = src->uv_crop_height;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  const int ntiles = av1_get_rest_ntiles(
+      width, height, cm->rst_info[plane].restoration_tilesize, &tile_width,
+      &tile_height, &nhtiles, &nvtiles);
+  (void)dst_frame;  // unused
+  (void)partial_frame;  // unused
+  // Marks the plane and all its tiles RESTORE_NONE and returns that RD cost.
+  info->frame_restoration_type = RESTORE_NONE;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    type[tile_idx] = RESTORE_NONE;
+  }
+  // RD cost of leaving the plane unrestored: frame SSE plus frame-level bits
+  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
+  bits = frame_level_restore_bits[RESTORE_NONE] << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  return cost_norestore;
+}
+
 static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                int partial_frame, RestorationInfo *info,
                                RestorationType *type, double *best_tile_cost,
                                YV12_BUFFER_CONFIG *dst_frame) {
-  double err, cost_norestore;
+  int64_t err;
+  double cost_norestore;
   int bits;
   MACROBLOCK *x = &cpi->td.mb;
   AV1_COMMON *const cm = &cpi->common;
@@ -1328,14 +1460,39 @@
   }
 
   // Color components
-  search_wiener_uv(src, cpi, method == LPF_PICK_FROM_SUBIMAGE, AOM_PLANE_U,
-                   &cm->rst_info[AOM_PLANE_U],
-                   cm->rst_info[AOM_PLANE_U].restoration_type,
-                   &cpi->trial_frame_rst);
-  search_wiener_uv(src, cpi, method == LPF_PICK_FROM_SUBIMAGE, AOM_PLANE_V,
-                   &cm->rst_info[AOM_PLANE_V],
-                   cm->rst_info[AOM_PLANE_V].restoration_type,
-                   &cpi->trial_frame_rst);
+  const int ntiles_uv = av1_get_rest_ntiles(
+      ROUND_POWER_OF_TWO(cm->width, cm->subsampling_x),
+      ROUND_POWER_OF_TWO(cm->height, cm->subsampling_y),
+      cm->rst_info[1].restoration_tilesize, NULL, NULL, NULL, NULL);
+  for (int plane = AOM_PLANE_U; plane <= AOM_PLANE_V; ++plane) {
+    double cost_uv[RESTORE_SWITCHABLE_TYPES];
+    cost_uv[RESTORE_NONE] = search_norestore_uv(
+        src, cpi, method == LPF_PICK_FROM_SUBIMAGE, plane, &cm->rst_info[plane],
+        restore_types[0], &cpi->trial_frame_rst);
+    cost_uv[RESTORE_WIENER] = search_wiener_uv(
+        src, cpi, method == LPF_PICK_FROM_SUBIMAGE, plane, &cm->rst_info[plane],
+        restore_types[RESTORE_WIENER], &cpi->trial_frame_rst);
+    cost_uv[RESTORE_SGRPROJ] = search_sgrproj_uv(
+        src, cpi, method == LPF_PICK_FROM_SUBIMAGE, plane, &cm->rst_info[plane],
+        restore_types[RESTORE_SGRPROJ], &cpi->trial_frame_rst);
+    if (cost_uv[RESTORE_SGRPROJ] < cost_uv[RESTORE_WIENER] &&
+        cost_uv[RESTORE_SGRPROJ] < cost_uv[RESTORE_NONE]) {
+      cm->rst_info[plane].frame_restoration_type = RESTORE_SGRPROJ;
+      memcpy(cm->rst_info[plane].restoration_type,
+             restore_types[RESTORE_SGRPROJ],
+             ntiles_uv * sizeof(restore_types[RESTORE_SGRPROJ][0]));
+    } else if (cost_uv[RESTORE_WIENER] < cost_uv[RESTORE_NONE] &&
+               cost_uv[RESTORE_WIENER] < cost_uv[RESTORE_SGRPROJ]) {
+      cm->rst_info[plane].frame_restoration_type = RESTORE_WIENER;
+      memcpy(cm->rst_info[plane].restoration_type,
+             restore_types[RESTORE_WIENER],
+             ntiles_uv * sizeof(restore_types[RESTORE_WIENER][0]));
+    } else {
+      cm->rst_info[plane].frame_restoration_type = RESTORE_NONE;
+      memcpy(cm->rst_info[plane].restoration_type, restore_types[RESTORE_NONE],
+             ntiles_uv * sizeof(restore_types[RESTORE_NONE][0]));
+    }
+  }
   /*
   printf("Frame %d/%d restore types: %d %d %d\n", cm->current_video_frame,
          cm->show_frame, cm->rst_info[0].frame_restoration_type,