Encode loop restoration coefficients per tile

This is a baby-step towards encoding the coefficients at the start of
superblocks at the top-left of loop restoration tiles. Note that this
patch causes us to reset the reference "wiener_info" and "sgrproj_info"
(used for delta-encoding) to their defaults at each tile boundary,
which will cause a drop in compression performance.

This is necessary because, in order for tiles to be processed in
parallel, we cannot delta-encode coefficients across tile boundaries
if the coefficients are signalled within tiles. We could probably do
better than the current patch by, say, delta-encoding against previous
frames.

This patch also fixes up the costing in pickrst.c to match the new
per-tile encoding.

Change-Id: I5b8b91d63aaf49627cde40219c31c0ac776dfd38
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index a198f9c..47e00aa 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -172,8 +172,8 @@
   return filt_err;
 }
 
-static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
-                                    int src_stride, uint8_t *dat8,
+static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
+                                    int src_stride, const uint8_t *dat8,
                                     int dat_stride, int bit_depth,
                                     int32_t *flt1, int flt1_stride,
                                     int32_t *flt2, int flt2_stride, int *xqd) {
@@ -219,9 +219,9 @@
 
 #define USE_SGRPROJ_REFINEMENT_SEARCH 1
 static int64_t finer_search_pixel_proj_error(
-    uint8_t *src8, int width, int height, int src_stride, uint8_t *dat8,
-    int dat_stride, int bit_depth, int32_t *flt1, int flt1_stride,
-    int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
+    const uint8_t *src8, int width, int height, int src_stride,
+    const uint8_t *dat8, int dat_stride, int bit_depth, int32_t *flt1,
+    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
   int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                      dat_stride, bit_depth, flt1, flt1_stride,
                                      flt2, flt2_stride, xqd);
@@ -273,7 +273,7 @@
   return err;
 }
 
-static void get_proj_subspace(uint8_t *src8, int width, int height,
+static void get_proj_subspace(const uint8_t *src8, int width, int height,
                               int src_stride, uint8_t *dat8, int dat_stride,
                               int bit_depth, int32_t *flt1, int flt1_stride,
                               int32_t *flt2, int flt2_stride, int *xq) {
@@ -346,7 +346,7 @@
 }
 
 static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
-                                          int dat_stride, uint8_t *src8,
+                                          int dat_stride, const uint8_t *src8,
                                           int src_stride, int bit_depth,
                                           int *eps, int *xqd, int32_t *rstbuf) {
   int32_t *flt1 = rstbuf;
@@ -420,124 +420,227 @@
   return bits;
 }
 
+struct rest_search_ctxt {
+  const YV12_BUFFER_CONFIG *src;
+  AV1_COMP *cpi;
+  uint8_t *dgd_buffer;
+  const uint8_t *src_buffer;
+  int dgd_stride;
+  int src_stride;
+  int partial_frame;
+  RestorationInfo *info;
+  RestorationType *type;
+  double *best_tile_cost;
+  int plane;
+  int plane_width;
+  int plane_height;
+  int nrtiles_x;
+  int nrtiles_y;
+  YV12_BUFFER_CONFIG *dst_frame;
+};
+
+// Fill in ctxt. Returns the number of restoration tiles for this plane
+static INLINE int init_rest_search_ctxt(
+    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
+    RestorationInfo *info, RestorationType *type, double *best_tile_cost,
+    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
+  AV1_COMMON *const cm = &cpi->common;
+  ctxt->src = src;
+  ctxt->cpi = cpi;
+  ctxt->partial_frame = partial_frame;
+  ctxt->info = info;
+  ctxt->type = type;
+  ctxt->best_tile_cost = best_tile_cost;
+  ctxt->plane = plane;
+  ctxt->dst_frame = dst_frame;
+
+  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
+  if (plane == AOM_PLANE_Y) {
+    ctxt->plane_width = src->y_crop_width;
+    ctxt->plane_height = src->y_crop_height;
+    ctxt->src_buffer = src->y_buffer;
+    ctxt->src_stride = src->y_stride;
+    ctxt->dgd_buffer = dgd->y_buffer;
+    ctxt->dgd_stride = dgd->y_stride;
+    assert(ctxt->plane_width == dgd->y_crop_width);
+    assert(ctxt->plane_height == dgd->y_crop_height);
+    assert(ctxt->plane_width == src->y_crop_width);
+    assert(ctxt->plane_height == src->y_crop_height);
+  } else {
+    ctxt->plane_width = src->uv_crop_width;
+    ctxt->plane_height = src->uv_crop_height;
+    ctxt->src_stride = src->uv_stride;
+    ctxt->dgd_stride = dgd->uv_stride;
+    ctxt->src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
+    ctxt->dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
+    assert(ctxt->plane_width == dgd->uv_crop_width);
+    assert(ctxt->plane_height == dgd->uv_crop_height);
+  }
+
+  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
+                             cm->rst_info[plane].restoration_tilesize, NULL,
+                             NULL, &ctxt->nrtiles_x, &ctxt->nrtiles_y);
+}
+
+typedef void (*rtile_visitor_t)(const struct rest_search_ctxt *search_ctxt,
+                                int rtile_idx, int h_start, int h_end,
+                                int v_start, int v_end, void *arg);
+
+static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
+                                  int tile_row, int tile_col,
+                                  rtile_visitor_t fun, void *arg) {
+  const AV1_COMMON *const cm = &ctxt->cpi->common;
+  const RestorationInfo *rsi = ctxt->cpi->rst_search;
+
+  const int tile_width_y = cm->tile_width * MI_SIZE;
+  const int tile_height_y = cm->tile_height * MI_SIZE;
+
+  const int tile_width =
+      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_width_y, cm->subsampling_x)
+                        : tile_width_y;
+  const int tile_height =
+      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_height_y, cm->subsampling_y)
+                        : tile_height_y;
+
+  const int rtile_size = rsi->restoration_tilesize;
+  const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
+  const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;
+
+  const int rtile_row0 = rtiles_per_tile_y * tile_row;
+  const int rtile_row1 =
+      AOMMIN(rtile_row0 + rtiles_per_tile_y, ctxt->nrtiles_y);
+
+  const int rtile_col0 = rtiles_per_tile_x * tile_col;
+  const int rtile_col1 =
+      AOMMIN(rtile_col0 + rtiles_per_tile_x, ctxt->nrtiles_x);
+
+  const int rtile_width = AOMMIN(tile_width, rtile_size);
+  const int rtile_height = AOMMIN(tile_height, rtile_size);
+
+  for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
+    for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
+      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
+      int h_start, h_end, v_start, v_end;
+      av1_get_rest_tile_limits(rtile_idx, 0, 0, ctxt->nrtiles_x,
+                               ctxt->nrtiles_y, rtile_width, rtile_height,
+                               ctxt->plane_width, ctxt->plane_height, 0, 0,
+                               &h_start, &h_end, &v_start, &v_end);
+
+      fun(ctxt, rtile_idx, h_start, h_end, v_start, v_end, arg);
+    }
+  }
+}
+
+static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
+                                     int rtile_idx, int h_start, int h_end,
+                                     int v_start, int v_end, void *arg) {
+  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
+  const AV1_COMMON *const cm = &ctxt->cpi->common;
+  RestorationInfo *rsi = ctxt->cpi->rst_search;
+  SgrprojInfo *sgrproj_info = ctxt->info->sgrproj_info;
+
+  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;
+
+  int64_t err = sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, h_start,
+                                     h_end - h_start, v_start, v_end - v_start,
+                                     (1 << ctxt->plane));
+  // #bits when a tile is not restored
+  int bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
+  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
+  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;
+
+  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
+  SgrprojInfo *rtile_sgrproj_info = &plane_rsi->sgrproj_info[rtile_idx];
+  uint8_t *dgd_start = ctxt->dgd_buffer + v_start * ctxt->dgd_stride + h_start;
+  const uint8_t *src_start =
+      ctxt->src_buffer + v_start * ctxt->src_stride + h_start;
+
+  search_selfguided_restoration(dgd_start, h_end - h_start, v_end - v_start,
+                                ctxt->dgd_stride, src_start, ctxt->src_stride,
+#if CONFIG_HIGHBITDEPTH
+                                cm->bit_depth,
+#else
+                                8,
+#endif  // CONFIG_HIGHBITDEPTH
+                                &rtile_sgrproj_info->ep,
+                                rtile_sgrproj_info->xqd,
+                                cm->rst_internal.tmpbuf);
+  plane_rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
+  err = try_restoration_tile(ctxt->src, ctxt->cpi, rsi, (1 << ctxt->plane),
+                             ctxt->partial_frame, rtile_idx, 0, 0,
+                             ctxt->dst_frame);
+  bits =
+      count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx], ref_sgrproj_info)
+      << AV1_PROB_COST_SHIFT;
+  bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
+  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
+  if (cost_sgrproj >= cost_norestore) {
+    ctxt->type[rtile_idx] = RESTORE_NONE;
+  } else {
+    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
+    *ref_sgrproj_info = sgrproj_info[rtile_idx] =
+        plane_rsi->sgrproj_info[rtile_idx];
+    ctxt->best_tile_cost[rtile_idx] = err;
+  }
+  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
+}
+
 static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                              int partial_frame, int plane,
                              RestorationInfo *info, RestorationType *type,
                              double *best_tile_cost,
                              YV12_BUFFER_CONFIG *dst_frame) {
-  SgrprojInfo *sgrproj_info = info->sgrproj_info;
-  double err, cost_norestore, cost_sgrproj;
-  int bits;
-  MACROBLOCK *x = &cpi->td.mb;
-  AV1_COMMON *const cm = &cpi->common;
-  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  RestorationInfo *rsi = &cpi->rst_search[0];
-  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
-  int h_start, h_end, v_start, v_end;
-  int width, height, src_stride, dgd_stride;
-  uint8_t *dgd_buffer, *src_buffer;
-  if (plane == AOM_PLANE_Y) {
-    width = src->y_crop_width;
-    height = src->y_crop_height;
-    src_buffer = src->y_buffer;
-    src_stride = src->y_stride;
-    dgd_buffer = dgd->y_buffer;
-    dgd_stride = dgd->y_stride;
-    assert(width == dgd->y_crop_width);
-    assert(height == dgd->y_crop_height);
-    assert(width == src->y_crop_width);
-    assert(height == src->y_crop_height);
-  } else {
-    width = src->uv_crop_width;
-    height = src->uv_crop_height;
-    src_stride = src->uv_stride;
-    dgd_stride = dgd->uv_stride;
-    src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
-    dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
-    assert(width == dgd->uv_crop_width);
-    assert(height == dgd->uv_crop_height);
+  struct rest_search_ctxt ctxt;
+  const int nrtiles =
+      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
+                            best_tile_cost, dst_frame, &ctxt);
+
+  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
+  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
+  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
+    plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
   }
-  const int ntiles =
-      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
-                          &tile_width, &tile_height, &nhtiles, &nvtiles);
+
+  // Compute best Sgrproj filters for each rtile, one (encoder/decoder)
+  // tile at a time.
+  const AV1_COMMON *const cm = &cpi->common;
+  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
+    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
+      SgrprojInfo ref_sgrproj_info;
+      set_default_sgrproj(&ref_sgrproj_info);
+
+      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
+                            &ref_sgrproj_info);
+    }
+  }
+
+  // Cost for Sgrproj filtering
   SgrprojInfo ref_sgrproj_info;
   set_default_sgrproj(&ref_sgrproj_info);
+  SgrprojInfo *sgrproj_info = info->sgrproj_info;
 
-  rsi[plane].frame_restoration_type = RESTORE_SGRPROJ;
-
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
-  }
-  // Compute best Sgrproj filters for each tile
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
-                             tile_height, width, height, 0, 0, &h_start, &h_end,
-                             &v_start, &v_end);
-    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start,
-                               (1 << plane));
-    // #bits when a tile is not restored
-    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
-    cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-    best_tile_cost[tile_idx] = DBL_MAX;
-    search_selfguided_restoration(
-        dgd_buffer + v_start * dgd_stride + h_start, h_end - h_start,
-        v_end - v_start, dgd_stride,
-        src_buffer + v_start * src_stride + h_start, src_stride,
-#if CONFIG_HIGHBITDEPTH
-        cm->bit_depth,
-#else
-        8,
-#endif  // CONFIG_HIGHBITDEPTH
-        &rsi[plane].sgrproj_info[tile_idx].ep,
-        rsi[plane].sgrproj_info[tile_idx].xqd, cm->rst_internal.tmpbuf);
-    rsi[plane].restoration_type[tile_idx] = RESTORE_SGRPROJ;
-    err = try_restoration_tile(src, cpi, rsi, (1 << plane), partial_frame,
-                               tile_idx, 0, 0, dst_frame);
-    bits = count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
-                              &ref_sgrproj_info)
-           << AV1_PROB_COST_SHIFT;
-    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
-    cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-    if (cost_sgrproj >= cost_norestore) {
-      type[tile_idx] = RESTORE_NONE;
-    } else {
-      type[tile_idx] = RESTORE_SGRPROJ;
-      memcpy(&sgrproj_info[tile_idx], &rsi[plane].sgrproj_info[tile_idx],
-             sizeof(sgrproj_info[tile_idx]));
-      memcpy(&ref_sgrproj_info, &sgrproj_info[tile_idx],
-             sizeof(ref_sgrproj_info));
-      best_tile_cost[tile_idx] = err;
-    }
-    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
-  }
-  // Cost for Sgrproj filtering
-  set_default_sgrproj(&ref_sgrproj_info);
-  bits = frame_level_restore_bits[rsi[plane].frame_restoration_type]
-         << AV1_PROB_COST_SHIFT;
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    bits +=
-        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
-    memcpy(&rsi[plane].sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
-           sizeof(sgrproj_info[tile_idx]));
-    if (type[tile_idx] == RESTORE_SGRPROJ) {
-      bits += count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
+  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
+             << AV1_PROB_COST_SHIFT;
+  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
+    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB,
+                         type[rtile_idx] != RESTORE_NONE);
+    plane_rsi->sgrproj_info[rtile_idx] = sgrproj_info[rtile_idx];
+    if (type[rtile_idx] == RESTORE_SGRPROJ) {
+      bits += count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx],
                                  &ref_sgrproj_info)
               << AV1_PROB_COST_SHIFT;
-      memcpy(&ref_sgrproj_info, &rsi[plane].sgrproj_info[tile_idx],
-             sizeof(ref_sgrproj_info));
+      ref_sgrproj_info = plane_rsi->sgrproj_info[rtile_idx];
     }
-    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
+    plane_rsi->restoration_type[rtile_idx] = type[rtile_idx];
   }
-  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
-                              dst_frame);
-  cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-
+  double err = try_restoration_frame(src, cpi, cpi->rst_search, (1 << plane),
+                                     partial_frame, dst_frame);
+  double cost_sgrproj = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
   return cost_sgrproj;
 }
 
-static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
-                           int v_end, int stride) {
+static double find_average(const uint8_t *src, int h_start, int h_end,
+                           int v_start, int v_end, int stride) {
   uint64_t sum = 0;
   double avg = 0;
   int i, j;
@@ -548,10 +651,10 @@
   return avg;
 }
 
-static void compute_stats(int wiener_win, uint8_t *dgd, uint8_t *src,
-                          int h_start, int h_end, int v_start, int v_end,
-                          int dgd_stride, int src_stride, double *M,
-                          double *H) {
+static void compute_stats(int wiener_win, const uint8_t *dgd,
+                          const uint8_t *src, int h_start, int h_end,
+                          int v_start, int v_end, int dgd_stride,
+                          int src_stride, double *M, double *H) {
   int i, j, k, l;
   double Y[WIENER_WIN2];
   const int wiener_win2 = wiener_win * wiener_win;
@@ -591,7 +694,7 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
-static double find_average_highbd(uint16_t *src, int h_start, int h_end,
+static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                   int v_start, int v_end, int stride) {
   uint64_t sum = 0;
   double avg = 0;
@@ -603,16 +706,16 @@
   return avg;
 }
 
-static void compute_stats_highbd(int wiener_win, uint8_t *dgd8, uint8_t *src8,
-                                 int h_start, int h_end, int v_start, int v_end,
-                                 int dgd_stride, int src_stride, double *M,
-                                 double *H) {
+static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
+                                 const uint8_t *src8, int h_start, int h_end,
+                                 int v_start, int v_end, int dgd_stride,
+                                 int src_stride, double *M, double *H) {
   int i, j, k, l;
   double Y[WIENER_WIN2];
   const int wiener_win2 = wiener_win * wiener_win;
   const int wiener_halfwin = (wiener_win >> 1);
-  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
-  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
   const double avg =
       find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
 
@@ -999,165 +1102,151 @@
   return err;
 }
 
+static void search_wiener_for_rtile(const struct rest_search_ctxt *ctxt,
+                                    int rtile_idx, int h_start, int h_end,
+                                    int v_start, int v_end, void *arg) {
+  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
+  const AV1_COMMON *const cm = &ctxt->cpi->common;
+  RestorationInfo *rsi = ctxt->cpi->rst_search;
+
+  const int wiener_win =
+      (ctxt->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
+
+  double M[WIENER_WIN2];
+  double H[WIENER_WIN2 * WIENER_WIN2];
+  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
+
+  WienerInfo *ref_wiener_info = (WienerInfo *)arg;
+
+  int64_t err = sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, h_start,
+                                     h_end - h_start, v_start, v_end - v_start,
+                                     (1 << ctxt->plane));
+  // #bits when a tile is not restored
+  int bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
+  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
+  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;
+
+#if CONFIG_HIGHBITDEPTH
+  if (cm->use_highbitdepth)
+    compute_stats_highbd(wiener_win, ctxt->dgd_buffer, ctxt->src_buffer,
+                         h_start, h_end, v_start, v_end, ctxt->dgd_stride,
+                         ctxt->src_stride, M, H);
+  else
+#endif  // CONFIG_HIGHBITDEPTH
+    compute_stats(wiener_win, ctxt->dgd_buffer, ctxt->src_buffer, h_start,
+                  h_end, v_start, v_end, ctxt->dgd_stride, ctxt->src_stride, M,
+                  H);
+
+  ctxt->type[rtile_idx] = RESTORE_WIENER;
+
+  if (!wiener_decompose_sep_sym(wiener_win, M, H, vfilterd, hfilterd)) {
+    ctxt->type[rtile_idx] = RESTORE_NONE;
+    return;
+  }
+
+  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
+  WienerInfo *rtile_wiener_info = &plane_rsi->wiener_info[rtile_idx];
+  quantize_sym_filter(wiener_win, vfilterd, rtile_wiener_info->vfilter);
+  quantize_sym_filter(wiener_win, hfilterd, rtile_wiener_info->hfilter);
+
+  // Filter score computes the value of the function x'*A*x - x'*b for the
+  // learned filter and compares it against identity filer. If there is no
+  // reduction in the function, the filter is reverted back to identity
+  double score = compute_score(wiener_win, M, H, rtile_wiener_info->vfilter,
+                               rtile_wiener_info->hfilter);
+  if (score > 0.0) {
+    ctxt->type[rtile_idx] = RESTORE_NONE;
+    return;
+  }
+  aom_clear_system_state();
+
+  plane_rsi->restoration_type[rtile_idx] = RESTORE_WIENER;
+  err = finer_tile_search_wiener(ctxt->src, ctxt->cpi, rsi, 4, ctxt->plane,
+                                 wiener_win, rtile_idx, ctxt->partial_frame,
+                                 ctxt->dst_frame);
+  if (wiener_win != WIENER_WIN) {
+    assert(rtile_wiener_info->vfilter[0] == 0 &&
+           rtile_wiener_info->vfilter[WIENER_WIN - 1] == 0);
+    assert(rtile_wiener_info->hfilter[0] == 0 &&
+           rtile_wiener_info->hfilter[WIENER_WIN - 1] == 0);
+  }
+  bits = count_wiener_bits(wiener_win, rtile_wiener_info, ref_wiener_info)
+         << AV1_PROB_COST_SHIFT;
+  bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
+  double cost_wiener = RDCOST_DBL(x->rdmult, (bits >> 4), err);
+  if (cost_wiener >= cost_norestore) {
+    ctxt->type[rtile_idx] = RESTORE_NONE;
+  } else {
+    ctxt->type[rtile_idx] = RESTORE_WIENER;
+    *ref_wiener_info = ctxt->info->wiener_info[rtile_idx] = *rtile_wiener_info;
+    ctxt->best_tile_cost[rtile_idx] = err;
+  }
+  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
+}
+
 static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
-  WienerInfo *wiener_info = info->wiener_info;
+  struct rest_search_ctxt ctxt;
+  const int nrtiles =
+      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
+                            best_tile_cost, dst_frame, &ctxt);
+
+  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
+  plane_rsi->frame_restoration_type = RESTORE_WIENER;
+  for (int tile_idx = 0; tile_idx < nrtiles; ++tile_idx) {
+    plane_rsi->restoration_type[tile_idx] = RESTORE_NONE;
+  }
+
   AV1_COMMON *const cm = &cpi->common;
-  RestorationInfo *rsi = cpi->rst_search;
-  int64_t err;
-  int bits;
-  double cost_wiener, cost_norestore;
-  MACROBLOCK *x = &cpi->td.mb;
-  double M[WIENER_WIN2];
-  double H[WIENER_WIN2 * WIENER_WIN2];
-  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
-  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
-  int width, height, src_stride, dgd_stride, wiener_win;
-  uint8_t *dgd_buffer, *src_buffer;
-  if (plane == AOM_PLANE_Y) {
-    width = src->y_crop_width;
-    height = src->y_crop_height;
-    src_buffer = src->y_buffer;
-    src_stride = src->y_stride;
-    dgd_buffer = dgd->y_buffer;
-    dgd_stride = dgd->y_stride;
-    assert(width == dgd->y_crop_width);
-    assert(height == dgd->y_crop_height);
-    assert(width == src->y_crop_width);
-    assert(height == src->y_crop_height);
-    wiener_win = WIENER_WIN;
-  } else {
-    width = src->uv_crop_width;
-    height = src->uv_crop_height;
-    src_stride = src->uv_stride;
-    dgd_stride = dgd->uv_stride;
-    src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
-    dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
-    assert(width == dgd->uv_crop_width);
-    assert(height == dgd->uv_crop_height);
-    wiener_win = WIENER_WIN_CHROMA;
-  }
-  double score;
-  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
-  int h_start, h_end, v_start, v_end;
-  const int ntiles = av1_get_rest_ntiles(
-      width, height, cm->rst_info[plane].restoration_tilesize, &tile_width,
-      &tile_height, &nhtiles, &nvtiles);
-  WienerInfo ref_wiener_info;
-  set_default_wiener(&ref_wiener_info);
-
-  rsi[plane].frame_restoration_type = RESTORE_WIENER;
-
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
-  }
-
 // Construct a (WIENER_HALFWIN)-pixel border around the frame
 #if CONFIG_HIGHBITDEPTH
   if (cm->use_highbitdepth)
-    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd_buffer), width, height,
-                        dgd_stride);
+    extend_frame_highbd(CONVERT_TO_SHORTPTR(ctxt.dgd_buffer), ctxt.plane_width,
+                        ctxt.plane_height, ctxt.dgd_stride);
   else
 #endif
-    extend_frame(dgd_buffer, width, height, dgd_stride);
+    extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
+                 ctxt.dgd_stride);
 
-  // Compute best Wiener filters for each tile
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
-                             tile_height, width, height, 0, 0, &h_start, &h_end,
-                             &v_start, &v_end);
-    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start,
-                               (1 << plane));
-    // #bits when a tile is not restored
-    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
-    cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-    best_tile_cost[tile_idx] = DBL_MAX;
+  // Compute best Wiener filters for each rtile, one (encoder/decoder)
+  // tile at a time.
+  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
+    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
+      WienerInfo ref_wiener_info;
+      set_default_wiener(&ref_wiener_info);
 
-    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
-                             tile_height, width, height, 0, 0, &h_start, &h_end,
-                             &v_start, &v_end);
-#if CONFIG_HIGHBITDEPTH
-    if (cm->use_highbitdepth)
-      compute_stats_highbd(wiener_win, dgd_buffer, src_buffer, h_start, h_end,
-                           v_start, v_end, dgd_stride, src_stride, M, H);
-    else
-#endif  // CONFIG_HIGHBITDEPTH
-      compute_stats(wiener_win, dgd_buffer, src_buffer, h_start, h_end, v_start,
-                    v_end, dgd_stride, src_stride, M, H);
-
-    type[tile_idx] = RESTORE_WIENER;
-
-    if (!wiener_decompose_sep_sym(wiener_win, M, H, vfilterd, hfilterd)) {
-      type[tile_idx] = RESTORE_NONE;
-      continue;
+      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_wiener_for_rtile,
+                            &ref_wiener_info);
     }
-    quantize_sym_filter(wiener_win, vfilterd,
-                        rsi[plane].wiener_info[tile_idx].vfilter);
-    quantize_sym_filter(wiener_win, hfilterd,
-                        rsi[plane].wiener_info[tile_idx].hfilter);
-
-    // Filter score computes the value of the function x'*A*x - x'*b for the
-    // learned filter and compares it against identity filer. If there is no
-    // reduction in the function, the filter is reverted back to identity
-    score = compute_score(wiener_win, M, H,
-                          rsi[plane].wiener_info[tile_idx].vfilter,
-                          rsi[plane].wiener_info[tile_idx].hfilter);
-    if (score > 0.0) {
-      type[tile_idx] = RESTORE_NONE;
-      continue;
-    }
-    aom_clear_system_state();
-
-    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
-    err = finer_tile_search_wiener(src, cpi, rsi, 4, plane, wiener_win,
-                                   tile_idx, partial_frame, dst_frame);
-    if (wiener_win != WIENER_WIN) {
-      assert(rsi[plane].wiener_info[tile_idx].vfilter[0] == 0 &&
-             rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - 1] == 0);
-      assert(rsi[plane].wiener_info[tile_idx].hfilter[0] == 0 &&
-             rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - 1] == 0);
-    }
-    bits = count_wiener_bits(wiener_win, &rsi[plane].wiener_info[tile_idx],
-                             &ref_wiener_info)
-           << AV1_PROB_COST_SHIFT;
-    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
-    cost_wiener = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-    if (cost_wiener >= cost_norestore) {
-      type[tile_idx] = RESTORE_NONE;
-    } else {
-      type[tile_idx] = RESTORE_WIENER;
-      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
-             sizeof(wiener_info[tile_idx]));
-      memcpy(&ref_wiener_info, &rsi[plane].wiener_info[tile_idx],
-             sizeof(ref_wiener_info));
-      best_tile_cost[tile_idx] = err;
-    }
-    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
   }
-  // Cost for Wiener filtering
+
+  // cost for Wiener filtering
+  WienerInfo ref_wiener_info;
   set_default_wiener(&ref_wiener_info);
-  bits = frame_level_restore_bits[rsi[plane].frame_restoration_type]
-         << AV1_PROB_COST_SHIFT;
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
+             << AV1_PROB_COST_SHIFT;
+  WienerInfo *wiener_info = info->wiener_info;
+  const int wiener_win =
+      (plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
+
+  for (int tile_idx = 0; tile_idx < nrtiles; ++tile_idx) {
     bits +=
         av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
-    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
-           sizeof(wiener_info[tile_idx]));
+    plane_rsi->wiener_info[tile_idx] = wiener_info[tile_idx];
+
     if (type[tile_idx] == RESTORE_WIENER) {
-      bits += count_wiener_bits(wiener_win, &rsi[plane].wiener_info[tile_idx],
+      bits += count_wiener_bits(wiener_win, &plane_rsi->wiener_info[tile_idx],
                                 &ref_wiener_info)
               << AV1_PROB_COST_SHIFT;
-      memcpy(&ref_wiener_info, &rsi[plane].wiener_info[tile_idx],
-             sizeof(ref_wiener_info));
+      ref_wiener_info = plane_rsi->wiener_info[tile_idx];
     }
-    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
+    plane_rsi->restoration_type[tile_idx] = type[tile_idx];
   }
-  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
-                              dst_frame);
-  cost_wiener = RDCOST_DBL(x->rdmult, (bits >> 4), err);
+  int64_t err = try_restoration_frame(src, cpi, cpi->rst_search, 1 << plane,
+                                      partial_frame, dst_frame);
+  double cost_wiener = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
 
   return cost_wiener;
 }
@@ -1207,75 +1296,90 @@
   return cost_norestore;
 }
 
+struct switchable_rest_search_ctxt {
+  SgrprojInfo sgrproj_info;               // ref for count_sgrproj_bits
+  WienerInfo wiener_info;                 // ref for count_wiener_bits
+  RestorationType *const *restore_types;  // indexed [type][rtile_idx]
+  double *const *tile_cost;               // indexed [type][rtile_idx]
+  double cost_switchable;                 // accumulated RD cost
+};
+
+// Callback for foreach_rtile_in_tile: stores the cheapest restoration type for
+// rtile_idx in rst_info and adds its RD cost to the running total in arg.
+static void search_switchable_for_rtile(const struct rest_search_ctxt *ctxt,
+                                        int rtile_idx, int h_start, int h_end,
+                                        int v_start, int v_end, void *arg) {
+  const MACROBLOCK *x = &ctxt->cpi->td.mb;
+  RestorationInfo *rsi = &ctxt->cpi->common.rst_info[ctxt->plane];
+  struct switchable_rest_search_ctxt *swctxt = arg;
+  // The rtile limits are unused: only costs stored in tile_cost are read.
+  (void)h_start;
+  (void)h_end;
+  (void)v_start;
+  (void)v_end;
+  // RESTORE_NONE is the baseline every other type has to beat.
+  double best_cost =
+      RDCOST_DBL(x->rdmult, (x->switchable_restore_cost[RESTORE_NONE] >> 4),
+                 swctxt->tile_cost[RESTORE_NONE][rtile_idx]);
+  rsi->restoration_type[rtile_idx] = RESTORE_NONE;
+  for (RestorationType r = 1; r < RESTORE_SWITCHABLE_TYPES; r++) {
+    if (force_restore_type != 0 && r != force_restore_type) continue;
+    if (swctxt->restore_types[r][rtile_idx] != r) continue;
+    int tilebits = 0;
+    if (r == RESTORE_WIENER)
+      tilebits += count_wiener_bits(
+          (ctxt->plane == AOM_PLANE_Y ? WIENER_WIN : WIENER_WIN_CHROMA),
+          &rsi->wiener_info[rtile_idx], &swctxt->wiener_info);
+    else if (r == RESTORE_SGRPROJ)
+      tilebits += count_sgrproj_bits(&rsi->sgrproj_info[rtile_idx],
+                                     &swctxt->sgrproj_info);
+    tilebits <<= AV1_PROB_COST_SHIFT;
+    tilebits += x->switchable_restore_cost[r];
+    double cost =
+        RDCOST_DBL(x->rdmult, tilebits >> 4, swctxt->tile_cost[r][rtile_idx]);
+    if (cost < best_cost) {
+      rsi->restoration_type[rtile_idx] = r;
+      best_cost = cost;
+    }
+  }
+  // The winner's coefficients become the reference for the next rtile.
+  if (rsi->restoration_type[rtile_idx] == RESTORE_WIENER)
+    swctxt->wiener_info = rsi->wiener_info[rtile_idx];
+  else if (rsi->restoration_type[rtile_idx] == RESTORE_SGRPROJ)
+    swctxt->sgrproj_info = rsi->sgrproj_info[rtile_idx];
+  if (force_restore_type != 0)
+    assert(rsi->restoration_type[rtile_idx] == force_restore_type ||
+           rsi->restoration_type[rtile_idx] == RESTORE_NONE);
+  swctxt->cost_switchable += best_cost;
+}
+
 static double search_switchable_restoration(
     const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
     RestorationType *const restore_types[RESTORE_SWITCHABLE_TYPES],
     double *const tile_cost[RESTORE_SWITCHABLE_TYPES], RestorationInfo *rsi) {
-  AV1_COMMON *const cm = &cpi->common;
-  MACROBLOCK *x = &cpi->td.mb;
-  double cost_switchable = 0;
-  int bits, tile_idx;
-  RestorationType r;
-  int width, height;
-  if (plane == AOM_PLANE_Y) {
-    width = src->y_crop_width;
-    height = src->y_crop_height;
-  } else {
-    width = src->uv_crop_width;
-    height = src->uv_crop_height;
-  }
-  const int ntiles = av1_get_rest_ntiles(
-      width, height, cm->rst_info[plane].restoration_tilesize, NULL, NULL, NULL,
-      NULL);
-  SgrprojInfo ref_sgrproj_info;
-  set_default_sgrproj(&ref_sgrproj_info);
-  WienerInfo ref_wiener_info;
-  set_default_wiener(&ref_wiener_info);
-  (void)partial_frame;
+  const AV1_COMMON *const cm = &cpi->common;
+  struct rest_search_ctxt ctxt;
+  init_rest_search_ctxt(src, cpi, partial_frame, plane, NULL, NULL, NULL, NULL,
+                        &ctxt);
+  struct switchable_rest_search_ctxt swctxt;
+  swctxt.restore_types = restore_types;
+  swctxt.tile_cost = tile_cost;
 
   rsi->frame_restoration_type = RESTORE_SWITCHABLE;
-  bits = frame_level_restore_bits[rsi->frame_restoration_type]
-         << AV1_PROB_COST_SHIFT;
-  cost_switchable = RDCOST_DBL(x->rdmult, bits >> 4, 0);
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    double best_cost =
-        RDCOST_DBL(x->rdmult, (x->switchable_restore_cost[RESTORE_NONE] >> 4),
-                   tile_cost[RESTORE_NONE][tile_idx]);
-    rsi->restoration_type[tile_idx] = RESTORE_NONE;
-    for (r = 1; r < RESTORE_SWITCHABLE_TYPES; r++) {
-      if (force_restore_type != 0)
-        if (r != force_restore_type) continue;
-      int tilebits = 0;
-      if (restore_types[r][tile_idx] != r) continue;
-      if (r == RESTORE_WIENER)
-        tilebits += count_wiener_bits(
-            (plane == AOM_PLANE_Y ? WIENER_WIN : WIENER_WIN - 2),
-            &rsi->wiener_info[tile_idx], &ref_wiener_info);
-      else if (r == RESTORE_SGRPROJ)
-        tilebits +=
-            count_sgrproj_bits(&rsi->sgrproj_info[tile_idx], &ref_sgrproj_info);
-      tilebits <<= AV1_PROB_COST_SHIFT;
-      tilebits += x->switchable_restore_cost[r];
-      double cost =
-          RDCOST_DBL(x->rdmult, tilebits >> 4, tile_cost[r][tile_idx]);
+  int bits = frame_level_restore_bits[rsi->frame_restoration_type]
+             << AV1_PROB_COST_SHIFT;
+  swctxt.cost_switchable = RDCOST_DBL(cpi->td.mb.rdmult, bits >> 4, 0);
 
-      if (cost < best_cost) {
-        rsi->restoration_type[tile_idx] = r;
-        best_cost = cost;
-      }
+  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
+    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
+      set_default_sgrproj(&swctxt.sgrproj_info);  // reset per tile
+      set_default_wiener(&swctxt.wiener_info);    // reset per tile
+      foreach_rtile_in_tile(&ctxt, tile_row, tile_col,
+                            search_switchable_for_rtile, &swctxt);
     }
-    if (rsi->restoration_type[tile_idx] == RESTORE_WIENER)
-      memcpy(&ref_wiener_info, &rsi->wiener_info[tile_idx],
-             sizeof(ref_wiener_info));
-    else if (rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ)
-      memcpy(&ref_sgrproj_info, &rsi->sgrproj_info[tile_idx],
-             sizeof(ref_sgrproj_info));
-    if (force_restore_type != 0)
-      assert(rsi->restoration_type[tile_idx] == force_restore_type ||
-             rsi->restoration_type[tile_idx] == RESTORE_NONE);
-    cost_switchable += best_cost;
   }
-  return cost_switchable;
+  // NOTE(review): callback writes cpi->common.rst_info, not rsi; confirm.
+  return swctxt.cost_switchable;  // frame-level bits + per-rtile best costs
 }
 
 void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,