Simplify the search in pickrst.c

The end results should be essentially the same, except hopefully a
little better because the previous code got the costing slightly
wrong if there were multiple tiles.

The code was doing something quite complicated to search for the best
restoration types; this patch makes it much simpler. Basically,
av1_pick_filter_restoration loops over the planes in the image. For
each plane, it loops over the possible restoration types, calling
search_rest_type for each one.

search_rest_type iterates over the restoration units in the image,
resetting the current context on tile boundaries and calling
search_<rest_type> for each restoration unit.

The search_norestore function just computes the SSE error with no
restoration. The search_wiener and search_sgrproj functions compute
the best set of coefficients and then the resulting SSE error with
those coefficients (ignoring the bit cost, so the result can be
re-used for switchable restoration).

In all cases but search_norestore, the search function has to decide
what restoration type is best for each restoration unit. For example,
search_wiener could choose to enable or disable Wiener filtering on
this unit.

Eventually, search_rest_type calculates the RDCOST after summing bit
rates and SSE errors over the restoration units. This cost gets
returned to av1_pick_filter_restoration which can then choose the best
frame-level restoration type.

Change-Id: I9bc17eb47cc46413adae749a43a440825c41bba6
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 105eb92..1a9ee17 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -545,12 +545,6 @@
 #if CONFIG_LOOP_RESTORATION
   av1_free_restoration_buffers(cm);
   aom_free_frame_buffer(&cpi->trial_frame_rst);
-  aom_free(cpi->extra_rstbuf);
-  {
-    int i;
-    for (i = 0; i < MAX_MB_PLANE; ++i)
-      av1_free_restoration_struct(&cpi->rst_search[i]);
-  }
 #endif  // CONFIG_LOOP_RESTORATION
   aom_free_frame_buffer(&cpi->scaled_source);
   aom_free_frame_buffer(&cpi->scaled_last_source);
@@ -827,14 +821,6 @@
           AOM_BORDER_IN_PIXELS, cm->byte_alignment, NULL, NULL, NULL))
     aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
                        "Failed to allocate trial restored frame buffer");
-  int extra_rstbuf_sz = RESTORATION_EXTBUF_SIZE;
-  if (extra_rstbuf_sz > 0) {
-    aom_free(cpi->extra_rstbuf);
-    CHECK_MEM_ERROR(cm, cpi->extra_rstbuf,
-                    (uint8_t *)aom_malloc(extra_rstbuf_sz));
-  } else {
-    cpi->extra_rstbuf = NULL;
-  }
 #endif  // CONFIG_LOOP_RESTORATION
 
   if (aom_realloc_frame_buffer(&cpi->scaled_source, cm->width, cm->height,
@@ -4260,16 +4246,6 @@
                        "Failed to allocate frame buffer");
 
 #if CONFIG_LOOP_RESTORATION
-  set_restoration_tilesize(
-#if CONFIG_FRAME_SUPERRES
-      cm->superres_upscaled_width, cm->superres_upscaled_height,
-#else
-      cm->width, cm->height,
-#endif  // CONFIG_FRAME_SUPERRES
-      cm->subsampling_x, cm->subsampling_y, cm->rst_info);
-  for (int i = 0; i < MAX_MB_PLANE; ++i)
-    cm->rst_info[i].frame_restoration_type = RESTORE_NONE;
-
 #if CONFIG_FRAME_SUPERRES
   const int frame_width = cm->superres_upscaled_width;
   const int frame_height = cm->superres_upscaled_height;
@@ -4277,25 +4253,12 @@
   const int frame_width = cm->width;
   const int frame_height = cm->height;
 #endif
+  set_restoration_tilesize(frame_width, frame_height, cm->subsampling_x,
+                           cm->subsampling_y, cm->rst_info);
+  for (int i = 0; i < MAX_MB_PLANE; ++i)
+    cm->rst_info[i].frame_restoration_type = RESTORE_NONE;
 
   av1_alloc_restoration_buffers(cm);
-
-  // Set up the rst_search RestorationInfo structures. These are the same as
-  // the rst_info ones except need their own arrays of types and coefficients,
-  // allocated in av1_alloc_restoration_struct.
-  for (int i = 0; i < MAX_MB_PLANE; ++i) {
-    RestorationInfo *search = &cpi->rst_search[i];
-    RestorationInfo *rsi = &cm->rst_info[i];
-
-    search->restoration_tilesize = rsi->restoration_tilesize;
-    search->procunit_width = rsi->procunit_width;
-    search->procunit_height = rsi->procunit_height;
-    av1_alloc_restoration_struct(cm, search, frame_width, frame_height);
-#if CONFIG_STRIPED_LOOP_RESTORATION
-    // We can share boundary buffers between the search info and the main one
-    search->boundaries = rsi->boundaries;
-#endif
-  }
 #endif                            // CONFIG_LOOP_RESTORATION
   alloc_util_frame_buffers(cpi);  // TODO(afergs): Remove? Gets called anyways.
   init_motion_estimation(cpi);
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index b326d5c..d01a31a 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -435,9 +435,7 @@
   YV12_BUFFER_CONFIG last_frame_uf;
 #if CONFIG_LOOP_RESTORATION
   YV12_BUFFER_CONFIG trial_frame_rst;
-  uint8_t *extra_rstbuf;  // Extra buffers used in restoration search
-  RestorationInfo rst_search[MAX_MB_PLANE];  // Used for encoder side search
-#endif                                       // CONFIG_LOOP_RESTORATION
+#endif
 
   // Ambient reconstruction err target for force key frames
   int64_t ambient_err;
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 1299982..646c08c 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -40,13 +40,6 @@
 // Number of Wiener iterations
 #define NUM_WIENER_ITERS 5
 
-typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
-                                      AV1_COMP *cpi, int plane,
-                                      RestorationInfo *info,
-                                      RestorationType *rest_level,
-                                      int64_t *best_tile_cost,
-                                      YV12_BUFFER_CONFIG *dst_frame);
-
 const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
 
 typedef int64_t (*sse_extractor_type)(const YV12_BUFFER_CONFIG *a,
@@ -67,13 +60,6 @@
 #endif  // CONFIG_HIGHBITDEPTH
 };
 
-static const sse_extractor_type sse_extractors[NUM_EXTRACTORS] = {
-  aom_get_y_sse,        aom_get_u_sse,        aom_get_v_sse,
-#if CONFIG_HIGHBITDEPTH
-  aom_highbd_get_y_sse, aom_highbd_get_u_sse, aom_highbd_get_v_sse,
-#endif  // CONFIG_HIGHBITDEPTH
-};
-
 static int64_t sse_restoration_tile(const RestorationTileLimits *limits,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst, int plane,
@@ -84,19 +70,12 @@
       limits->v_start, limits->v_end - limits->v_start);
 }
 
-static int64_t sse_restoration_frame(const YV12_BUFFER_CONFIG *src,
-                                     const YV12_BUFFER_CONFIG *dst, int plane,
-                                     int highbd) {
-  assert(CONFIG_HIGHBITDEPTH || !highbd);
-  return sse_extractors[3 * highbd + plane](src, dst);
-}
-
 static int64_t try_restoration_tile(const AV1_COMMON *cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const RestorationTileLimits *limits,
                                     const RestorationUnitInfo *rui,
                                     YV12_BUFFER_CONFIG *dst, int plane) {
-  const RestorationInfo *prsi = &cm->rst_info[plane];
+  const RestorationInfo *rsi = &cm->rst_info[plane];
   const int is_uv = plane > 0;
 #if CONFIG_STRIPED_LOOP_RESTORATION
   RestorationLineBuffers rlbs;
@@ -114,9 +93,9 @@
 
   av1_loop_restoration_filter_unit(limits, rui,
 #if CONFIG_STRIPED_LOOP_RESTORATION
-                                   &prsi->boundaries, &rlbs, ss_y,
+                                   &rsi->boundaries, &rlbs, ss_y,
 #endif
-                                   prsi->procunit_width, prsi->procunit_height,
+                                   rsi->procunit_width, rsi->procunit_height,
                                    highbd, bit_depth, fts->buffers[plane],
                                    fts->strides[is_uv], dst->buffers[plane],
                                    dst->strides[is_uv], cm->rst_tmpbuf);
@@ -124,20 +103,6 @@
   return sse_restoration_tile(limits, src, dst, plane, highbd);
 }
 
-static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
-                                     AV1_COMMON *cm, RestorationInfo *rsi,
-                                     YV12_BUFFER_CONFIG *dst, int plane) {
-#if CONFIG_HIGHBITDEPTH
-  const int highbd = cm->use_highbitdepth;
-#else
-  const int highbd = 0;
-#endif  // CONFIG_HIGHBITDEPTH
-
-  av1_loop_restoration_filter_frame(cm->frame_to_show, cm, rsi, 1 << plane,
-                                    dst);
-  return sse_restoration_frame(src, dst, plane, highbd);
-}
-
 static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                     int src_stride, const uint8_t *dat8,
                                     int dat_stride, int use_highbitdepth,
@@ -312,12 +277,10 @@
   xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
 }
 
-static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
-                                          int dat_stride, const uint8_t *src8,
-                                          int src_stride, int use_highbitdepth,
-                                          int bit_depth, int pu_width,
-                                          int pu_height, int *eps, int *xqd,
-                                          int32_t *rstbuf) {
+static SgrprojInfo search_selfguided_restoration(
+    uint8_t *dat8, int width, int height, int dat_stride, const uint8_t *src8,
+    int src_stride, int use_highbitdepth, int bit_depth, int pu_width,
+    int pu_height, int32_t *rstbuf) {
   int32_t *flt1 = rstbuf;
   int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
   int ep, bestep = 0;
@@ -397,9 +360,12 @@
       bestxqd[1] = exqd[1];
     }
   }
-  *eps = bestep;
-  xqd[0] = bestxqd[0];
-  xqd[1] = bestxqd[1];
+
+  SgrprojInfo ret;
+  ret.ep = bestep;
+  ret.xqd[0] = bestxqd[0];
+  ret.xqd[1] = bestxqd[1];
+  return ret;
 }
 
 static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
@@ -416,63 +382,96 @@
   return bits;
 }
 
-struct rest_search_ctxt {
+typedef struct {
+  // The best coefficients for Wiener or Sgrproj restoration
+  WienerInfo wiener;
+  SgrprojInfo sgrproj;
+
+  // The sum of squared errors for this rtype.
+  int64_t sse[RESTORE_SWITCHABLE_TYPES];
+
+  // The rtype to use for this unit given a frame rtype as
+  // index. Indices: WIENER, SGRPROJ, SWITCHABLE.
+  RestorationType best_rtype[RESTORE_TYPES - 1];
+} RestUnitSearchInfo;
+
+typedef struct {
   const YV12_BUFFER_CONFIG *src;
-  AV1_COMP *cpi;
-  uint8_t *dgd_buffer;
-  const uint8_t *src_buffer;
-  int dgd_stride;
-  int src_stride;
-  RestorationInfo *info;
-  RestorationType *type;
-  int64_t *best_tile_cost;
+  const AV1_COMMON *cm;
+  const MACROBLOCK *x;
   int plane;
   int plane_width;
   int plane_height;
+  RestUnitSearchInfo *rusi;
+  YV12_BUFFER_CONFIG *dst_frame;
+
+  uint8_t *dgd_buffer;
+  int dgd_stride;
+  const uint8_t *src_buffer;
+  int src_stride;
+
   int nrtiles_x;
   int nrtiles_y;
-  YV12_BUFFER_CONFIG *dst_frame;
-};
+} RestSearchInfo;
 
-// Fill in ctxt. Returns the number of restoration tiles for this plane
-static INLINE int init_rest_search_ctxt(
-    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int plane,
-    RestorationInfo *info, RestorationType *type, int64_t *best_tile_cost,
-    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
-  AV1_COMMON *const cm = &cpi->common;
-  ctxt->src = src;
-  ctxt->cpi = cpi;
-  ctxt->info = info;
-  ctxt->type = type;
-  ctxt->best_tile_cost = best_tile_cost;
-  ctxt->plane = plane;
-  ctxt->dst_frame = dst_frame;
+static INLINE int init_rest_search_info(const YV12_BUFFER_CONFIG *src,
+                                        const AV1_COMMON *cm,
+                                        const MACROBLOCK *x, int plane,
+                                        RestUnitSearchInfo *rusi,
+                                        YV12_BUFFER_CONFIG *dst_frame,
+                                        RestSearchInfo *info) {
+  info->src = src;
+  info->cm = cm;
+  info->x = x;
+  info->plane = plane;
+  info->rusi = rusi;
+  info->dst_frame = dst_frame;
 
   const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
   const int is_uv = plane != AOM_PLANE_Y;
-  ctxt->plane_width = src->crop_widths[is_uv];
-  ctxt->plane_height = src->crop_heights[is_uv];
-  ctxt->src_buffer = src->buffers[plane];
-  ctxt->src_stride = src->strides[is_uv];
-  ctxt->dgd_buffer = dgd->buffers[plane];
-  ctxt->dgd_stride = dgd->strides[is_uv];
+  info->plane_width = src->crop_widths[is_uv];
+  info->plane_height = src->crop_heights[is_uv];
+  info->src_buffer = src->buffers[plane];
+  info->src_stride = src->strides[is_uv];
+  info->dgd_buffer = dgd->buffers[plane];
+  info->dgd_stride = dgd->strides[is_uv];
   assert(src->crop_widths[is_uv] == dgd->crop_widths[is_uv]);
   assert(src->crop_heights[is_uv] == dgd->crop_heights[is_uv]);
 
-  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
+  return av1_get_rest_ntiles(info->plane_width, info->plane_height,
                              cm->rst_info[plane].restoration_tilesize,
-                             &ctxt->nrtiles_x, &ctxt->nrtiles_y);
+                             &info->nrtiles_x, &info->nrtiles_y);
 }
 
-typedef void (*rtile_visitor_t)(const struct rest_search_ctxt *search_ctxt,
-                                int rtile_idx,
-                                const RestorationTileLimits *limits, void *arg);
+typedef struct {
+  SgrprojInfo sgrproj;
+  WienerInfo wiener;
+  int64_t sse;
+  int64_t bits;
+} RestSearchCtxt;
 
-static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
-                                  int tile_row, int tile_col,
-                                  rtile_visitor_t fun, void *arg) {
-  const AV1_COMMON *const cm = &ctxt->cpi->common;
-  const RestorationInfo *rsi = ctxt->cpi->rst_search;
+static void rsc_on_tile(RestSearchCtxt *rsc) {
+  set_default_sgrproj(&rsc->sgrproj);
+  set_default_wiener(&rsc->wiener);
+}
+
+static void rsc_init(RestSearchCtxt *rsc) {
+  rsc->sse = rsc->bits = 0;
+  rsc_on_tile(rsc);
+}
+
+static double rsc_cost(const RestSearchCtxt *rsc, int rdmult) {
+  return RDCOST_DBL(rdmult, rsc->bits >> 4, rsc->sse);
+}
+
+typedef void (*on_rest_unit_fun)(const RestSearchInfo *info,
+                                 const RestorationTileLimits *limits,
+                                 RestUnitSearchInfo *rusi, RestSearchCtxt *rsc);
+
+static void foreach_rtile_in_tile(const RestSearchInfo *info, int tile_row,
+                                  int tile_col, on_rest_unit_fun fun,
+                                  RestSearchCtxt *rsc) {
+  const AV1_COMMON *const cm = info->cm;
   TileInfo tile_info;
 
   av1_tile_set_row(&tile_info, cm, tile_row);
@@ -482,7 +481,7 @@
   int tile_col_end = tile_info.mi_col_end * MI_SIZE;
   int tile_row_start = tile_info.mi_row_start * MI_SIZE;
   int tile_row_end = tile_info.mi_row_end * MI_SIZE;
-  if (ctxt->plane > 0) {
+  if (info->plane > 0) {
     tile_col_start = ROUND_POWER_OF_TWO(tile_col_start, cm->subsampling_x);
     tile_col_end = ROUND_POWER_OF_TWO(tile_col_end, cm->subsampling_x);
     tile_row_start = ROUND_POWER_OF_TWO(tile_row_start, cm->subsampling_y);
@@ -499,39 +498,37 @@
     av1_calculate_unscaled_superres_size(&tile_col_end, &tile_row_end,
                                          cm->superres_scale_denominator);
     // Make sure we don't fall off the bottom-right of the frame.
-    tile_col_end = AOMMIN(tile_col_end, ctxt->plane_width);
-    tile_row_end = AOMMIN(tile_row_end, ctxt->plane_height);
+    tile_col_end = AOMMIN(tile_col_end, info->plane_width);
+    tile_row_end = AOMMIN(tile_row_end, info->plane_height);
   }
 #endif  // CONFIG_FRAME_SUPERRES
 
-  const int rtile_size = rsi->restoration_tilesize;
+  const int rtile_size = cm->rst_info[info->plane].restoration_tilesize;
   const int rtile_col0 = (tile_col_start + rtile_size - 1) / rtile_size;
   const int rtile_col1 =
-      AOMMIN((tile_col_end + rtile_size - 1) / rtile_size, ctxt->nrtiles_x);
+      AOMMIN((tile_col_end + rtile_size - 1) / rtile_size, info->nrtiles_x);
   const int rtile_row0 = (tile_row_start + rtile_size - 1) / rtile_size;
   const int rtile_row1 =
-      AOMMIN((tile_row_end + rtile_size - 1) / rtile_size, ctxt->nrtiles_y);
-  const int ss_y = ctxt->plane > 0 && cm->subsampling_y;
+      AOMMIN((tile_row_end + rtile_size - 1) / rtile_size, info->nrtiles_y);
+  const int ss_y = info->plane > 0 && cm->subsampling_y;
 
   for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
     for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
-      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
+      const int rtile_idx = rtile_row * info->nrtiles_x + rtile_col;
       RestorationTileLimits limits = av1_get_rest_tile_limits(
-          rtile_idx, ctxt->nrtiles_x, ctxt->nrtiles_y, rtile_size,
-          ctxt->plane_width, ctxt->plane_height, ss_y);
-      fun(ctxt, rtile_idx, &limits, arg);
+          rtile_idx, info->nrtiles_x, info->nrtiles_y, rtile_size,
+          info->plane_width, info->plane_height, ss_y);
+      fun(info, &limits, &info->rusi[rtile_idx], rsc);
     }
   }
 }
 
-static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
-                                     int rtile_idx,
-                                     const RestorationTileLimits *limits,
-                                     void *arg) {
-  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
-  const AV1_COMMON *const cm = &ctxt->cpi->common;
-  RestorationInfo *rsi = ctxt->cpi->rst_search;
-  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;
+static void search_sgrproj(const RestSearchInfo *info,
+                           const RestorationTileLimits *limits,
+                           RestUnitSearchInfo *rusi, RestSearchCtxt *rsc) {
+  const MACROBLOCK *const x = info->x;
+  const AV1_COMMON *const cm = info->cm;
+  const RestorationInfo *rsi = &cm->rst_info[info->plane];
 
 #if CONFIG_HIGHBITDEPTH
   const int highbd = cm->use_highbitdepth;
@@ -541,105 +538,41 @@
   const int bit_depth = 8;
 #endif  // CONFIG_HIGHBITDEPTH
 
-  int64_t err = sse_restoration_tile(limits, ctxt->src, cm->frame_to_show,
-                                     ctxt->plane, highbd);
-  // #bits when a tile is not restored
-  int bits = x->sgrproj_restore_cost[0];
-  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-  ctxt->best_tile_cost[rtile_idx] = INT64_MAX;
-
-  RestorationUnitInfo *plane_rui = &rsi[ctxt->plane].unit_info[rtile_idx];
-  SgrprojInfo *rtile_sgrproj_info = &plane_rui->sgrproj_info;
   uint8_t *dgd_start =
-      ctxt->dgd_buffer + limits->v_start * ctxt->dgd_stride + limits->h_start;
+      info->dgd_buffer + limits->v_start * info->dgd_stride + limits->h_start;
   const uint8_t *src_start =
-      ctxt->src_buffer + limits->v_start * ctxt->src_stride + limits->h_start;
+      info->src_buffer + limits->v_start * info->src_stride + limits->h_start;
 
-  search_selfguided_restoration(
+  rusi->sgrproj = search_selfguided_restoration(
       dgd_start, limits->h_end - limits->h_start,
-      limits->v_end - limits->v_start, ctxt->dgd_stride, src_start,
-      ctxt->src_stride, highbd, bit_depth, rsi[ctxt->plane].procunit_width,
-      rsi[ctxt->plane].procunit_height, &rtile_sgrproj_info->ep,
-      rtile_sgrproj_info->xqd, cm->rst_tmpbuf);
+      limits->v_end - limits->v_start, info->dgd_stride, src_start,
+      info->src_stride, highbd, bit_depth, rsi->procunit_width,
+      rsi->procunit_height, cm->rst_tmpbuf);
 
-  plane_rui->restoration_type = RESTORE_SGRPROJ;
-  err = try_restoration_tile(cm, ctxt->src, limits, plane_rui, ctxt->dst_frame,
-                             ctxt->plane);
+  RestorationUnitInfo rui;
+  rui.restoration_type = RESTORE_SGRPROJ;
+  rui.sgrproj_info = rusi->sgrproj;
 
-  bits = count_sgrproj_bits(rtile_sgrproj_info, ref_sgrproj_info)
-         << AV1_PROB_COST_SHIFT;
-  bits += x->sgrproj_restore_cost[1];
-  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-  if (cost_sgrproj >= cost_norestore) {
-    ctxt->type[rtile_idx] = RESTORE_NONE;
-  } else {
-    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
-    SgrprojInfo *sgrproj_info = &ctxt->info->unit_info[rtile_idx].sgrproj_info;
-    *ref_sgrproj_info = *sgrproj_info = plane_rui->sgrproj_info;
-    ctxt->best_tile_cost[rtile_idx] = err;
-  }
-  plane_rui->restoration_type = RESTORE_NONE;
-}
+  rusi->sse[RESTORE_SGRPROJ] = try_restoration_tile(
+      cm, info->src, limits, &rui, info->dst_frame, info->plane);
 
-static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
-                             int plane, RestorationInfo *info,
-                             RestorationType *type, int64_t *best_tile_cost,
-                             YV12_BUFFER_CONFIG *dst_frame) {
-  const MACROBLOCK *const x = &cpi->td.mb;
-  struct rest_search_ctxt ctxt;
-  const int nrtiles = init_rest_search_ctxt(src, cpi, plane, info, type,
-                                            best_tile_cost, dst_frame, &ctxt);
+  const int64_t bits_none = x->sgrproj_restore_cost[0];
+  const int64_t bits_sgr = x->sgrproj_restore_cost[1] +
+                           (count_sgrproj_bits(&rusi->sgrproj, &rsc->sgrproj)
+                            << AV1_PROB_COST_SHIFT);
 
-  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
-  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
-  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
-    plane_rsi->unit_info[rtile_idx].restoration_type = RESTORE_NONE;
-  }
+  double cost_none =
+      RDCOST_DBL(x->rdmult, bits_none >> 4, rusi->sse[RESTORE_NONE]);
+  double cost_sgr =
+      RDCOST_DBL(x->rdmult, bits_sgr >> 4, rusi->sse[RESTORE_SGRPROJ]);
 
-  // Compute best Sgrproj filters for each rtile, one (encoder/decoder)
-  // tile at a time.
-  AV1_COMMON *const cm = &cpi->common;
-#if CONFIG_HIGHBITDEPTH
-  const int highbd = cm->use_highbitdepth;
-#else
-  const int highbd = 0;
-#endif
-  extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
-               ctxt.dgd_stride, SGRPROJ_BORDER_HORZ, SGRPROJ_BORDER_VERT,
-               highbd);
+  RestorationType rtype =
+      (cost_sgr < cost_none) ? RESTORE_SGRPROJ : RESTORE_NONE;
+  rusi->best_rtype[RESTORE_SGRPROJ - 1] = rtype;
 
-  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
-    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
-      SgrprojInfo ref_sgrproj_info;
-      set_default_sgrproj(&ref_sgrproj_info);
-      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
-                            &ref_sgrproj_info);
-    }
-  }
-
-  // Cost for Sgrproj filtering
-  SgrprojInfo ref_sgrproj_info;
-  set_default_sgrproj(&ref_sgrproj_info);
-
-  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
-             << AV1_PROB_COST_SHIFT;
-  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
-    RestorationUnitInfo *plane_rui = &plane_rsi->unit_info[rtile_idx];
-    RestorationUnitInfo *search_rui = &info->unit_info[rtile_idx];
-
-    bits += x->sgrproj_restore_cost[type[rtile_idx] != RESTORE_NONE];
-    plane_rui->sgrproj_info = search_rui->sgrproj_info;
-    if (type[rtile_idx] == RESTORE_SGRPROJ) {
-      bits += count_sgrproj_bits(&plane_rui->sgrproj_info, &ref_sgrproj_info)
-              << AV1_PROB_COST_SHIFT;
-      ref_sgrproj_info = plane_rui->sgrproj_info;
-    }
-    plane_rui->restoration_type = type[rtile_idx];
-  }
-  int64_t err =
-      try_restoration_frame(src, cm, cpi->rst_search, dst_frame, plane);
-  double cost_sgrproj = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
-  return cost_sgrproj;
+  rsc->sse += rusi->sse[rtype];
+  rsc->bits += (cost_sgr < cost_none) ? bits_sgr : bits_none;
+  if (cost_sgr < cost_none) rsc->sgrproj = rusi->sgrproj;
 }
 
 static double find_average(const uint8_t *src, int h_start, int h_end,
@@ -1104,367 +1037,246 @@
   return err;
 }
 
-static void search_wiener_for_rtile(const struct rest_search_ctxt *ctxt,
-                                    int rtile_idx,
-                                    const RestorationTileLimits *limits,
-                                    void *arg) {
-  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
-  const AV1_COMMON *const cm = &ctxt->cpi->common;
-  RestorationInfo *rsi = ctxt->cpi->rst_search;
-
+static void search_wiener(const RestSearchInfo *info,
+                          const RestorationTileLimits *limits,
+                          RestUnitSearchInfo *rusi, RestSearchCtxt *rsc) {
   const int wiener_win =
-      (ctxt->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
+      (info->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
 
   double M[WIENER_WIN2];
   double H[WIENER_WIN2 * WIENER_WIN2];
   double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
 
-  WienerInfo *ref_wiener_info = (WienerInfo *)arg;
-
 #if CONFIG_HIGHBITDEPTH
-  const int highbd = cm->use_highbitdepth;
-#else
-  const int highbd = 0;
-#endif  // CONFIG_HIGHBITDEPTH
-
-  int64_t err = sse_restoration_tile(limits, ctxt->src, cm->frame_to_show,
-                                     ctxt->plane, highbd);
-  // #bits when a tile is not restored
-  int bits = x->wiener_restore_cost[0];
-  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-  ctxt->best_tile_cost[rtile_idx] = INT64_MAX;
-
-#if CONFIG_HIGHBITDEPTH
+  const AV1_COMMON *const cm = info->cm;
   if (cm->use_highbitdepth)
-    compute_stats_highbd(wiener_win, ctxt->dgd_buffer, ctxt->src_buffer,
+    compute_stats_highbd(wiener_win, info->dgd_buffer, info->src_buffer,
                          limits->h_start, limits->h_end, limits->v_start,
-                         limits->v_end, ctxt->dgd_stride, ctxt->src_stride, M,
+                         limits->v_end, info->dgd_stride, info->src_stride, M,
                          H);
   else
 #endif  // CONFIG_HIGHBITDEPTH
-    compute_stats(wiener_win, ctxt->dgd_buffer, ctxt->src_buffer,
+    compute_stats(wiener_win, info->dgd_buffer, info->src_buffer,
                   limits->h_start, limits->h_end, limits->v_start,
-                  limits->v_end, ctxt->dgd_stride, ctxt->src_stride, M, H);
+                  limits->v_end, info->dgd_stride, info->src_stride, M, H);
 
-  ctxt->type[rtile_idx] = RESTORE_WIENER;
+  const MACROBLOCK *const x = info->x;
+  const int64_t bits_none = x->wiener_restore_cost[0];
 
   if (!wiener_decompose_sep_sym(wiener_win, M, H, vfilterd, hfilterd)) {
-    ctxt->type[rtile_idx] = RESTORE_NONE;
+    rsc->bits += bits_none;
+    rsc->sse += rusi->sse[RESTORE_NONE];
+    rusi->best_rtype[RESTORE_WIENER - 1] = RESTORE_NONE;
+    rusi->sse[RESTORE_WIENER] = INT64_MAX;
     return;
   }
 
-  RestorationUnitInfo *plane_rui = &rsi[ctxt->plane].unit_info[rtile_idx];
-  WienerInfo *rtile_wiener_info = &plane_rui->wiener_info;
-  quantize_sym_filter(wiener_win, vfilterd, rtile_wiener_info->vfilter);
-  quantize_sym_filter(wiener_win, hfilterd, rtile_wiener_info->hfilter);
+  RestorationUnitInfo rui;
+  memset(&rui, 0, sizeof(rui));
+  rui.restoration_type = RESTORE_WIENER;
+  quantize_sym_filter(wiener_win, vfilterd, rui.wiener_info.vfilter);
+  quantize_sym_filter(wiener_win, hfilterd, rui.wiener_info.hfilter);
 
   // Filter score computes the value of the function x'*A*x - x'*b for the
   // learned filter and compares it against identity filer. If there is no
   // reduction in the function, the filter is reverted back to identity
-  double score = compute_score(wiener_win, M, H, rtile_wiener_info->vfilter,
-                               rtile_wiener_info->hfilter);
-  if (score > 0.0) {
-    ctxt->type[rtile_idx] = RESTORE_NONE;
+  if (compute_score(wiener_win, M, H, rui.wiener_info.vfilter,
+                    rui.wiener_info.hfilter) > 0) {
+    rsc->bits += bits_none;
+    rsc->sse += rusi->sse[RESTORE_NONE];
+    rusi->best_rtype[RESTORE_WIENER - 1] = RESTORE_NONE;
+    rusi->sse[RESTORE_WIENER] = INT64_MAX;
     return;
   }
+
   aom_clear_system_state();
 
-  plane_rui->restoration_type = RESTORE_WIENER;
-  err =
-      finer_tile_search_wiener(&ctxt->cpi->common, ctxt->src, limits, plane_rui,
-                               4, ctxt->plane, wiener_win, ctxt->dst_frame);
+  rusi->sse[RESTORE_WIENER] =
+      finer_tile_search_wiener(info->cm, info->src, limits, &rui, 4,
+                               info->plane, wiener_win, info->dst_frame);
+  rusi->wiener = rui.wiener_info;
+
   if (wiener_win != WIENER_WIN) {
-    assert(rtile_wiener_info->vfilter[0] == 0 &&
-           rtile_wiener_info->vfilter[WIENER_WIN - 1] == 0);
-    assert(rtile_wiener_info->hfilter[0] == 0 &&
-           rtile_wiener_info->hfilter[WIENER_WIN - 1] == 0);
+    assert(rui.wiener_info.vfilter[0] == 0 &&
+           rui.wiener_info.vfilter[WIENER_WIN - 1] == 0);
+    assert(rui.wiener_info.hfilter[0] == 0 &&
+           rui.wiener_info.hfilter[WIENER_WIN - 1] == 0);
   }
-  bits = count_wiener_bits(wiener_win, rtile_wiener_info, ref_wiener_info)
-         << AV1_PROB_COST_SHIFT;
-  bits += x->wiener_restore_cost[1];
-  double cost_wiener = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-  if (cost_wiener >= cost_norestore) {
-    ctxt->type[rtile_idx] = RESTORE_NONE;
-  } else {
-    ctxt->type[rtile_idx] = RESTORE_WIENER;
-    *ref_wiener_info = ctxt->info->unit_info[rtile_idx].wiener_info =
-        *rtile_wiener_info;
-    ctxt->best_tile_cost[rtile_idx] = err;
-  }
-  plane_rui->restoration_type = RESTORE_NONE;
+
+  const int64_t bits_wiener =
+      x->wiener_restore_cost[1] +
+      (count_wiener_bits(wiener_win, &rusi->wiener, &rsc->wiener)
+       << AV1_PROB_COST_SHIFT);
+
+  double cost_none =
+      RDCOST_DBL(x->rdmult, bits_none >> 4, rusi->sse[RESTORE_NONE]);
+  double cost_wiener =
+      RDCOST_DBL(x->rdmult, bits_wiener >> 4, rusi->sse[RESTORE_WIENER]);
+
+  RestorationType rtype =
+      (cost_wiener < cost_none) ? RESTORE_WIENER : RESTORE_NONE;
+  rusi->best_rtype[RESTORE_WIENER - 1] = rtype;
+
+  rsc->sse += rusi->sse[rtype];
+  rsc->bits += (cost_wiener < cost_none) ? bits_wiener : bits_none;
+  if (cost_wiener < cost_none) rsc->wiener = rusi->wiener;
 }
 
-static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
-                            int plane, RestorationInfo *info,
-                            RestorationType *type, int64_t *best_tile_cost,
-                            YV12_BUFFER_CONFIG *dst_frame) {
-  const MACROBLOCK *const x = &cpi->td.mb;
-  struct rest_search_ctxt ctxt;
-  const int nrtiles = init_rest_search_ctxt(src, cpi, plane, info, type,
-                                            best_tile_cost, dst_frame, &ctxt);
-
-  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
-  plane_rsi->frame_restoration_type = RESTORE_WIENER;
-  for (int tile_idx = 0; tile_idx < nrtiles; ++tile_idx) {
-    plane_rsi->unit_info[tile_idx].restoration_type = RESTORE_NONE;
-  }
-
-  AV1_COMMON *const cm = &cpi->common;
-// Construct a (WIENER_HALFWIN)-pixel border around the frame
-// Note use this border to gather stats even though the actual filter
-// may use less border on the top/bottom of a processing unit.
+static void search_norestore(const RestSearchInfo *info,
+                             const RestorationTileLimits *limits,
+                             RestUnitSearchInfo *rusi, RestSearchCtxt *rsc) {
 #if CONFIG_HIGHBITDEPTH
-  const int highbd = cm->use_highbitdepth;
-#else
-  const int highbd = 0;
-#endif
-  extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
-               ctxt.dgd_stride, WIENER_HALFWIN, WIENER_HALFWIN, highbd);
-
-  // Compute best Wiener filters for each rtile, one (encoder/decoder)
-  // tile at a time.
-  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
-    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
-      WienerInfo ref_wiener_info;
-      set_default_wiener(&ref_wiener_info);
-
-      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_wiener_for_rtile,
-                            &ref_wiener_info);
-    }
-  }
-
-  // cost for Wiener filtering
-  WienerInfo ref_wiener_info;
-  set_default_wiener(&ref_wiener_info);
-  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
-             << AV1_PROB_COST_SHIFT;
-  const int wiener_win =
-      (plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
-
-  for (int tile_idx = 0; tile_idx < nrtiles; ++tile_idx) {
-    bits += x->wiener_restore_cost[type[tile_idx] != RESTORE_NONE];
-    RestorationUnitInfo *plane_rui = &plane_rsi->unit_info[tile_idx];
-    plane_rui->wiener_info = info->unit_info[tile_idx].wiener_info;
-
-    if (type[tile_idx] == RESTORE_WIENER) {
-      bits += count_wiener_bits(wiener_win, &plane_rui->wiener_info,
-                                &ref_wiener_info)
-              << AV1_PROB_COST_SHIFT;
-      ref_wiener_info = plane_rui->wiener_info;
-    }
-    plane_rui->restoration_type = type[tile_idx];
-  }
-  int64_t err =
-      try_restoration_frame(src, cm, cpi->rst_search, dst_frame, plane);
-  double cost_wiener = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
-
-  return cost_wiener;
-}
-
-static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
-                               int plane, RestorationInfo *info,
-                               RestorationType *type, int64_t *best_tile_cost,
-                               YV12_BUFFER_CONFIG *dst_frame) {
-  int64_t err;
-  double cost_norestore;
-  int bits;
-  MACROBLOCK *x = &cpi->td.mb;
-  AV1_COMMON *const cm = &cpi->common;
-  int tile_idx, nhtiles, nvtiles;
-
-#if CONFIG_HIGHBITDEPTH
-  const int highbd = cm->use_highbitdepth;
+  const int highbd = info->cm->use_highbitdepth;
 #else
   const int highbd = 0;
 #endif  // CONFIG_HIGHBITDEPTH
 
-  const int is_uv = plane > 0;
-  const int ss_y = plane > 0 && cm->subsampling_y;
-  const int width = src->crop_widths[is_uv];
-  const int height = src->crop_heights[is_uv];
+  rusi->sse[RESTORE_NONE] = sse_restoration_tile(
+      limits, info->src, info->cm->frame_to_show, info->plane, highbd);
 
-  const int rtile_size = cm->rst_info[plane].restoration_tilesize;
-
-  const int ntiles =
-      av1_get_rest_ntiles(width, height, rtile_size, &nhtiles, &nvtiles);
-  (void)dst_frame;
-
-  info->frame_restoration_type = RESTORE_NONE;
-  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    RestorationTileLimits limits = av1_get_rest_tile_limits(
-        tile_idx, nhtiles, nvtiles, rtile_size, width, height, ss_y);
-    err = sse_restoration_tile(&limits, src, cm->frame_to_show, plane, highbd);
-    type[tile_idx] = RESTORE_NONE;
-    best_tile_cost[tile_idx] = err;
-  }
-  // RD cost associated with no restoration
-  err = sse_restoration_frame(src, cm->frame_to_show, plane, highbd);
-  bits = frame_level_restore_bits[RESTORE_NONE] << AV1_PROB_COST_SHIFT;
-  cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
-  return cost_norestore;
+  rsc->sse += rusi->sse[RESTORE_NONE];
 }
 
-struct switchable_rest_search_ctxt {
-  SgrprojInfo sgrproj_info;
-  WienerInfo wiener_info;
-  RestorationType *const *restore_types;
-  int64_t *const *tile_cost;
-  double cost_switchable;
-};
-
-static void search_switchable_for_rtile(const struct rest_search_ctxt *ctxt,
-                                        int rtile_idx,
-                                        const RestorationTileLimits *limits,
-                                        void *arg) {
-  const MACROBLOCK *x = &ctxt->cpi->td.mb;
-  RestorationUnitInfo *rui =
-      &ctxt->cpi->common.rst_info[ctxt->plane].unit_info[rtile_idx];
-  struct switchable_rest_search_ctxt *swctxt =
-      (struct switchable_rest_search_ctxt *)arg;
-
+static void search_switchable(const RestSearchInfo *info,
+                              const RestorationTileLimits *limits,
+                              RestUnitSearchInfo *rusi, RestSearchCtxt *rsc) {
   (void)limits;
 
-  double best_cost =
-      RDCOST_DBL(x->rdmult, (x->switchable_restore_cost[RESTORE_NONE] >> 4),
-                 swctxt->tile_cost[RESTORE_NONE][rtile_idx]);
+  const MACROBLOCK *const x = info->x;
 
-  rui->restoration_type = RESTORE_NONE;
-  for (RestorationType r = 1; r < RESTORE_SWITCHABLE_TYPES; r++) {
-    if (force_restore_type != RESTORE_TYPES)
-      if (r != force_restore_type) continue;
-    int tilebits = 0;
-    if (swctxt->restore_types[r][rtile_idx] != r) continue;
-    if (r == RESTORE_WIENER)
-      tilebits += count_wiener_bits(
-          (ctxt->plane == AOM_PLANE_Y ? WIENER_WIN : WIENER_WIN - 2),
-          &rui->wiener_info, &swctxt->wiener_info);
-    else if (r == RESTORE_SGRPROJ)
-      tilebits += count_sgrproj_bits(&rui->sgrproj_info, &swctxt->sgrproj_info);
-    tilebits <<= AV1_PROB_COST_SHIFT;
-    tilebits += x->switchable_restore_cost[r];
-    double cost =
-        RDCOST_DBL(x->rdmult, tilebits >> 4, swctxt->tile_cost[r][rtile_idx]);
+  const int wiener_win =
+      (info->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;
 
-    if (cost < best_cost) {
-      rui->restoration_type = r;
+  double best_cost = 0;
+  int64_t best_bits = 0;
+  RestorationType best_rtype = RESTORE_NONE;
+
+  for (RestorationType r = 0; r < RESTORE_SWITCHABLE_TYPES; ++r) {
+    const int64_t sse = rusi->sse[r];
+    int64_t coeff_pcost = 0;
+    switch (r) {
+      case RESTORE_NONE: coeff_pcost = 0; break;
+      case RESTORE_WIENER:
+        coeff_pcost =
+            count_wiener_bits(wiener_win, &rusi->wiener, &rsc->wiener);
+        break;
+      default:
+        assert(r == RESTORE_SGRPROJ);
+        coeff_pcost = count_sgrproj_bits(&rusi->sgrproj, &rsc->sgrproj);
+        break;
+    }
+    const int64_t coeff_bits = coeff_pcost << AV1_PROB_COST_SHIFT;
+    const int64_t bits = x->switchable_restore_cost[r] + coeff_bits;
+    double cost = RDCOST_DBL(x->rdmult, bits >> 4, sse);
+    if (r == 0 || cost < best_cost) {
       best_cost = cost;
+      best_bits = bits;
+      best_rtype = r;
     }
   }
-  if (rui->restoration_type == RESTORE_WIENER)
-    swctxt->wiener_info = rui->wiener_info;
-  else if (rui->restoration_type == RESTORE_SGRPROJ)
-    swctxt->sgrproj_info = rui->sgrproj_info;
-  if (force_restore_type != RESTORE_TYPES)
-    assert(rui->restoration_type == force_restore_type ||
-           rui->restoration_type == RESTORE_NONE);
-  swctxt->cost_switchable += best_cost;
+
+  rusi->best_rtype[RESTORE_SWITCHABLE - 1] = best_rtype;
+
+  rsc->sse += rusi->sse[best_rtype];
+  rsc->bits += best_bits;
+  if (best_rtype == RESTORE_WIENER) rsc->wiener = rusi->wiener;
+  if (best_rtype == RESTORE_SGRPROJ) rsc->sgrproj = rusi->sgrproj;
 }
 
-static double search_switchable_restoration(
-    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int plane,
-    RestorationType *const restore_types[RESTORE_SWITCHABLE_TYPES],
-    int64_t *const tile_cost[RESTORE_SWITCHABLE_TYPES], RestorationInfo *rsi) {
-  const AV1_COMMON *const cm = &cpi->common;
-  struct rest_search_ctxt ctxt;
-  init_rest_search_ctxt(src, cpi, plane, NULL, NULL, NULL, NULL, &ctxt);
-  struct switchable_rest_search_ctxt swctxt;
-  swctxt.restore_types = restore_types;
-  swctxt.tile_cost = tile_cost;
+static void copy_unit_info(RestorationType frame_rtype,
+                           const RestUnitSearchInfo *rusi,
+                           RestorationUnitInfo *rui) {
+  assert(frame_rtype > 0);
+  rui->restoration_type = rusi->best_rtype[frame_rtype - 1];
+  if (rui->restoration_type == RESTORE_WIENER)
+    rui->wiener_info = rusi->wiener;
+  else
+    rui->sgrproj_info = rusi->sgrproj;
+}
 
-  rsi->frame_restoration_type = RESTORE_SWITCHABLE;
-  int bits = frame_level_restore_bits[rsi->frame_restoration_type]
-             << AV1_PROB_COST_SHIFT;
-  swctxt.cost_switchable = RDCOST_DBL(cpi->td.mb.rdmult, bits >> 4, 0);
+static double search_rest_type(const RestSearchInfo *info,
+                               RestorationType rtype) {
+  static const on_rest_unit_fun funs[RESTORE_TYPES] = {
+    search_norestore, search_wiener, search_sgrproj, search_switchable
+  };
+  static const int hborders[RESTORE_TYPES] = { 0, WIENER_HALFWIN,
+                                               SGRPROJ_BORDER_HORZ, 0 };
+  static const int vborders[RESTORE_TYPES] = { 0, WIENER_HALFWIN,
+                                               SGRPROJ_BORDER_VERT, 0 };
 
+  const AV1_COMMON *const cm = info->cm;
+
+  if (hborders[rtype] || vborders[rtype]) {
+#if CONFIG_HIGHBITDEPTH
+    const int highbd = cm->use_highbitdepth;
+#else
+    const int highbd = 0;
+#endif
+    extend_frame(info->dgd_buffer, info->plane_width, info->plane_height,
+                 info->dgd_stride, hborders[rtype], vborders[rtype], highbd);
+  }
+
+  RestSearchCtxt rsc;
+  rsc_init(&rsc);
+  rsc.bits = frame_level_restore_bits[rtype] << AV1_PROB_COST_SHIFT;
   for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
     for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
-      set_default_sgrproj(&swctxt.sgrproj_info);
-      set_default_wiener(&swctxt.wiener_info);
-      foreach_rtile_in_tile(&ctxt, tile_row, tile_col,
-                            search_switchable_for_rtile, &swctxt);
+      if (tile_row || tile_col) rsc_on_tile(&rsc);
+      foreach_rtile_in_tile(info, tile_row, tile_col, funs[rtype], &rsc);
     }
   }
 
-  return swctxt.cost_switchable;
+  return rsc_cost(&rsc, info->x->rdmult);
 }
 
 void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
-  static search_restore_type search_restore_fun[RESTORE_SWITCHABLE_TYPES] = {
-    search_norestore, search_wiener, search_sgrproj,
-  };
   AV1_COMMON *const cm = &cpi->common;
-  double cost_restore[RESTORE_TYPES];
-  int64_t *tile_cost[RESTORE_SWITCHABLE_TYPES];
-  RestorationType *restore_types[RESTORE_SWITCHABLE_TYPES];
-  double best_cost_restore;
-  RestorationType r, best_restore;
-  const int ywidth = src->y_crop_width;
-  const int yheight = src->y_crop_height;
-  const int uvwidth = src->uv_crop_width;
-  const int uvheight = src->uv_crop_height;
 
-  const int ntiles_y = av1_get_rest_ntiles(
-      ywidth, yheight, cm->rst_info[0].restoration_tilesize, NULL, NULL);
-  const int ntiles_uv = av1_get_rest_ntiles(
-      uvwidth, uvheight, cm->rst_info[1].restoration_tilesize, NULL, NULL);
+  int ntiles[2];
+  for (int is_uv = 0; is_uv < 2; ++is_uv)
+    ntiles[is_uv] = av1_get_rest_ntiles(
+        src->crop_widths[is_uv], src->crop_heights[is_uv],
+        cm->rst_info[is_uv].restoration_tilesize, NULL, NULL);
 
-  // Assume ntiles_uv is never larger that ntiles_y and so the same arrays work.
-  assert(ntiles_uv <= ntiles_y);
-  for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) {
-    tile_cost[r] = (int64_t *)aom_malloc(sizeof(*tile_cost[0]) * ntiles_y);
-    restore_types[r] =
-        (RestorationType *)aom_malloc(sizeof(*restore_types[0]) * ntiles_y);
-  }
+  assert(ntiles[1] <= ntiles[0]);
+  RestUnitSearchInfo *rusi =
+      (RestUnitSearchInfo *)aom_malloc(sizeof(*rusi) * ntiles[0]);
 
   for (int plane = AOM_PLANE_Y; plane <= AOM_PLANE_V; ++plane) {
-    const int ntiles = (plane == AOM_PLANE_Y ? ntiles_y : ntiles_uv);
-    for (r = 0; r < RESTORE_SWITCHABLE_TYPES; ++r) {
-      cost_restore[r] = DBL_MAX;
-      if (force_restore_type != RESTORE_TYPES)
-        if (r != RESTORE_NONE && r != force_restore_type) continue;
-      cost_restore[r] = search_restore_fun[r](
-          src, cpi, plane, &cm->rst_info[plane], restore_types[r], tile_cost[r],
-          &cpi->trial_frame_rst);
-    }
-    if (ntiles > 1)
-      cost_restore[RESTORE_SWITCHABLE] = search_switchable_restoration(
-          src, cpi, plane, restore_types, tile_cost, &cm->rst_info[plane]);
-    else
-      cost_restore[RESTORE_SWITCHABLE] = DBL_MAX;
-    best_cost_restore = DBL_MAX;
-    best_restore = 0;
-    for (r = 0; r < RESTORE_TYPES; ++r) {
-      if (force_restore_type != RESTORE_TYPES)
-        if (r != RESTORE_NONE && r != force_restore_type) continue;
-      if (cost_restore[r] < best_cost_restore) {
-        best_restore = r;
-        best_cost_restore = cost_restore[r];
-      }
-    }
-    cm->rst_info[plane].frame_restoration_type = best_restore;
-    if (force_restore_type != RESTORE_TYPES)
-      assert(best_restore == force_restore_type ||
-             best_restore == RESTORE_NONE);
-    if (best_restore != RESTORE_SWITCHABLE) {
-      for (int u = 0; u < ntiles; ++u) {
-        cm->rst_info[plane].unit_info[u].restoration_type =
-            restore_types[best_restore][u];
-      }
-    }
-  }
-  /*
-  printf("Frame %d/%d restore types: %d %d %d\n", cm->current_video_frame,
-         cm->show_frame, cm->rst_info[0].frame_restoration_type,
-         cm->rst_info[1].frame_restoration_type,
-         cm->rst_info[2].frame_restoration_type);
-  printf("Frame %d/%d frame_restore_type %d : %f %f %f %f\n",
-         cm->current_video_frame, cm->show_frame,
-         cm->rst_info[0].frame_restoration_type, cost_restore[0],
-         cost_restore[1], cost_restore[2], cost_restore[3]);
-         */
+    RestSearchInfo info;
+    init_rest_search_info(src, &cpi->common, &cpi->td.mb, plane, rusi,
+                          &cpi->trial_frame_rst, &info);
 
-  for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) {
-    aom_free(tile_cost[r]);
-    aom_free(restore_types[r]);
+    const int plane_ntiles = ntiles[plane > 0];
+    const RestorationType num_rtypes =
+        (plane_ntiles > 1) ? RESTORE_TYPES : RESTORE_SWITCHABLE_TYPES;
+
+    double best_cost = 0;
+    RestorationType best_rtype = RESTORE_NONE;
+
+    for (RestorationType r = 0; r < num_rtypes; ++r) {
+      if ((force_restore_type != RESTORE_TYPES) && (r != RESTORE_NONE) &&
+          (r != force_restore_type))
+        continue;
+
+      double cost = search_rest_type(&info, r);
+
+      if (r == 0 || cost < best_cost) {
+        best_cost = cost;
+        best_rtype = r;
+      }
+    }
+
+    cm->rst_info[plane].frame_restoration_type = best_rtype;
+    if (force_restore_type != RESTORE_TYPES)
+      assert(best_rtype == force_restore_type || best_rtype == RESTORE_NONE);
+
+    if (best_rtype != RESTORE_NONE) {
+      for (int u = 0; u < plane_ntiles; ++u) {
+        copy_unit_info(best_rtype, &rusi[u], &cm->rst_info[plane].unit_info[u]);
+      }
+    }
   }
+
+  aom_free(rusi);
 }