Enable tile-adaptive restoration

Includes a major refactoring/enhancement to support
tile-adaptive switchable restoration. The framework can be
readily extended to add more restoration schemes in the
future. Also includes various cleanups and fixes.

Specifically, the framework allows restoration to be conducted
on tiles such that each tile can either be left unrestored or
use bilateral or Wiener filtering.

There is a modest improvement in coding efficiency (0.1 - 0.2%).

Further enhancements will be added subsequently to improve coding
efficiency and complexity.

Change-Id: I5ebedb04785ce1ef6f324abe209e925c2d6cbe8a
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index dab0116..db4fbf7 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -83,6 +83,8 @@
 
 #if CONFIG_LOOP_RESTORATION
 void av1_free_restoration_buffers(AV1_COMMON *cm) {
+  aom_free(cm->rst_info.restoration_type);
+  cm->rst_info.restoration_type = NULL;
   aom_free(cm->rst_info.bilateral_level);
   cm->rst_info.bilateral_level = NULL;
   aom_free(cm->rst_info.vfilter);
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index e1593e3..856fa35 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -857,6 +857,17 @@
       },
     };
 
+#if CONFIG_LOOP_RESTORATION
+const aom_tree_index
+    av1_switchable_restore_tree[TREE_SIZE(RESTORE_SWITCHABLE_TYPES)] = {
+      -RESTORE_NONE, 2,  // node 0: NONE vs. {BILATERAL, WIENER}
+      -RESTORE_BILATERAL, -RESTORE_WIENER,  // node 2: BILATERAL vs. WIENER
+    };
+
+static const aom_prob  // one entry per internal node of the tree above
+    default_switchable_restore_prob[RESTORE_SWITCHABLE_TYPES - 1] = {32, 128};
+#endif  // CONFIG_LOOP_RESTORATION
+
 #if CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_VAR_TX
 // the probability of (0) using recursive square tx partition vs.
 // (1) biggest rect tx for 4X8-8X4/8X16-16X8/16X32-32X16 blocks
@@ -1340,6 +1351,9 @@
 #endif  // CONFIG_EXT_INTRA
   av1_copy(fc->inter_ext_tx_prob, default_inter_ext_tx_prob);
   av1_copy(fc->intra_ext_tx_prob, default_intra_ext_tx_prob);
+#if CONFIG_LOOP_RESTORATION
+  av1_copy(fc->switchable_restore_prob, default_switchable_restore_prob);
+#endif  // CONFIG_LOOP_RESTORATION
 }
 
 #if CONFIG_EXT_INTERP
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 7968484..c389e18 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -128,6 +128,9 @@
 #if CONFIG_GLOBAL_MOTION
   aom_prob global_motion_types_prob[GLOBAL_MOTION_TYPES - 1];
 #endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_LOOP_RESTORATION
+  aom_prob switchable_restore_prob[RESTORE_SWITCHABLE_TYPES - 1];
+#endif  // CONFIG_LOOP_RESTORATION
 } FRAME_CONTEXT;
 
 typedef struct FRAME_COUNTS {
@@ -263,6 +266,13 @@
 extern const aom_tree_index av1_motvar_tree[TREE_SIZE(MOTION_VARIATIONS)];
 #endif  // CONFIG_OBMC || CONFIG_WARPED_MOTION
 
+#if CONFIG_LOOP_RESTORATION
+#define RESTORE_NONE_BILATERAL_PROB 16  // NOTE(review): not used in hunks shown
+#define RESTORE_NONE_WIENER_PROB 64  // NOTE(review): not used in hunks shown
+extern const aom_tree_index  // defined in av1/common/entropymode.c
+    av1_switchable_restore_tree[TREE_SIZE(RESTORE_SWITCHABLE_TYPES)];
+#endif  // CONFIG_LOOP_RESTORATION
+
 void av1_setup_past_independence(struct AV1Common *cm);
 
 void av1_adapt_intra_frame_probs(struct AV1Common *cm);
diff --git a/av1/common/enums.h b/av1/common/enums.h
index b1ac2a0..c9d3211 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -433,6 +433,16 @@
 #define MAX_SUPERTX_BLOCK_SIZE BLOCK_32X32
 #endif  // CONFIG_SUPERTX
 
+#if CONFIG_LOOP_RESTORATION
+typedef enum {
+  RESTORE_NONE,
+  RESTORE_BILATERAL,
+  RESTORE_WIENER,
+  RESTORE_SWITCHABLE,  // frame-level only: each tile picks one of the above
+  RESTORE_SWITCHABLE_TYPES = RESTORE_SWITCHABLE,  // number of per-tile types
+  RESTORE_TYPES,  // number of distinct values (NONE..SWITCHABLE)
+} RestorationType;
+#endif  // CONFIG_LOOP_RESTORATION
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/loopfilter.c b/av1/common/loopfilter.c
index 2147bb8..f45f3db 100644
--- a/av1/common/loopfilter.c
+++ b/av1/common/loopfilter.c
@@ -16,7 +16,6 @@
 #include "av1/common/loopfilter.h"
 #include "av1/common/onyxc_int.h"
 #include "av1/common/reconinter.h"
-#include "av1/common/restoration.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/mem.h"
diff --git a/av1/common/loopfilter.h b/av1/common/loopfilter.h
index ae0ef8a..975cbdf 100644
--- a/av1/common/loopfilter.h
+++ b/av1/common/loopfilter.h
@@ -16,7 +16,6 @@
 #include "./aom_config.h"
 
 #include "av1/common/blockd.h"
-#include "av1/common/restoration.h"
 #include "av1/common/seg_common.h"
 
 #ifdef __cplusplus
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index a14b34f..6cd6cbe 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -25,7 +25,9 @@
 #include "av1/common/frame_buffers.h"
 #include "av1/common/quant_common.h"
 #include "av1/common/tile_common.h"
+#if CONFIG_LOOP_RESTORATION
 #include "av1/common/restoration.h"
+#endif  // CONFIG_LOOP_RESTORATION
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index d50181e..4f44e12 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -70,36 +70,6 @@
             : bilateral_level_to_params_arr[index];
 }
 
-typedef struct TileParams {
-  int width;
-  int height;
-} TileParams;
-
-static TileParams restoration_tile_sizes[RESTORATION_TILESIZES] = {
-  { 64, 64 }, { 128, 128 }, { 256, 256 }
-};
-
-void av1_get_restoration_tile_size(int tilesize, int width, int height,
-                                   int *tile_width, int *tile_height,
-                                   int *nhtiles, int *nvtiles) {
-  *tile_width = (tilesize < 0)
-                    ? width
-                    : AOMMIN(restoration_tile_sizes[tilesize].width, width);
-  *tile_height = (tilesize < 0)
-                     ? height
-                     : AOMMIN(restoration_tile_sizes[tilesize].height, height);
-  *nhtiles = (width + (*tile_width >> 1)) / *tile_width;
-  *nvtiles = (height + (*tile_height >> 1)) / *tile_height;
-}
-
-int av1_get_restoration_ntiles(int tilesize, int width, int height) {
-  int nhtiles, nvtiles;
-  int tile_width, tile_height;
-  av1_get_restoration_tile_size(tilesize, width, height, &tile_width,
-                                &tile_height, &nhtiles, &nvtiles);
-  return (nhtiles * nvtiles);
-}
-
 void av1_loop_restoration_precal() {
   int i;
   for (i = 0; i < BILATERAL_LEVELS_KF; i++) {
@@ -169,90 +139,75 @@
 void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
                                int kf, int width, int height) {
   int i, tile_idx;
-  rst->restoration_type = rsi->restoration_type;
+  rst->rsi = rsi;  // borrowed pointer; filters read per-tile params from it
+  rst->keyframe = kf;  // selects the *_kf bilateral coefficient tables
   rst->subsampling_x = 0;
   rst->subsampling_y = 0;
-  if (rsi->restoration_type == RESTORE_BILATERAL) {
-    rst->tilesize_index = BILATERAL_TILESIZE;
-    rst->ntiles =
-        av1_get_restoration_ntiles(rst->tilesize_index, width, height);
-    av1_get_restoration_tile_size(rst->tilesize_index, width, height,
-                                  &rst->tile_width, &rst->tile_height,
-                                  &rst->nhtiles, &rst->nvtiles);
-    rst->bilateral_level = rsi->bilateral_level;
-    rst->wr_lut = (uint8_t **)malloc(sizeof(*rst->wr_lut) * rst->ntiles);
-    assert(rst->wr_lut != NULL);
-    rst->wx_lut = (uint8_t(**)[RESTORATION_WIN])malloc(sizeof(*rst->wx_lut) *
-                                                       rst->ntiles);
-    assert(rst->wx_lut != NULL);
+  rst->ntiles =
+      av1_get_rest_ntiles(width, height, &rst->tile_width,
+                          &rst->tile_height, &rst->nhtiles, &rst->nvtiles);
+  if (rsi->frame_restoration_type == RESTORE_WIENER) {  // edits rsi in place
     for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      const int level = rsi->bilateral_level[tile_idx];
-      if (level >= 0) {
-        rst->wr_lut[tile_idx] = kf ? bilateral_filter_coeffs_r_kf[level]
-                                   : bilateral_filter_coeffs_r[level];
-        rst->wx_lut[tile_idx] = kf ? bilateral_filter_coeffs_s_kf[level]
-                                   : bilateral_filter_coeffs_s[level];
+      rsi->vfilter[tile_idx][RESTORATION_HALFWIN] =
+          rsi->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
+      for (i = 0; i < RESTORATION_HALFWIN; ++i) {  // mirror; rebalance center
+        rsi->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+            rsi->vfilter[tile_idx][i];
+        rsi->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+            rsi->hfilter[tile_idx][i];
+        rsi->vfilter[tile_idx][RESTORATION_HALFWIN] -=
+            2 * rsi->vfilter[tile_idx][i];
+        rsi->hfilter[tile_idx][RESTORATION_HALFWIN] -=
+            2 * rsi->hfilter[tile_idx][i];
       }
     }
-  } else if (rsi->restoration_type == RESTORE_WIENER) {
-    rst->tilesize_index = WIENER_TILESIZE;
-    rst->ntiles =
-        av1_get_restoration_ntiles(rst->tilesize_index, width, height);
-    av1_get_restoration_tile_size(rst->tilesize_index, width, height,
-                                  &rst->tile_width, &rst->tile_height,
-                                  &rst->nhtiles, &rst->nvtiles);
-    rst->wiener_level = rsi->wiener_level;
-    rst->vfilter =
-        (int(*)[RESTORATION_WIN])malloc(sizeof(*rst->vfilter) * rst->ntiles);
-    assert(rst->vfilter != NULL);
-    rst->hfilter =
-        (int(*)[RESTORATION_WIN])malloc(sizeof(*rst->hfilter) * rst->ntiles);
-    assert(rst->hfilter != NULL);
+  } else if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
     for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      rst->vfilter[tile_idx][RESTORATION_HALFWIN] =
-          rst->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
-      for (i = 0; i < RESTORATION_HALFWIN; ++i) {
-        rst->vfilter[tile_idx][i] =
-            rst->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
-                rsi->vfilter[tile_idx][i];
-        rst->hfilter[tile_idx][i] =
-            rst->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
-                rsi->hfilter[tile_idx][i];
-        rst->vfilter[tile_idx][RESTORATION_HALFWIN] -=
-            2 * rsi->vfilter[tile_idx][i];
-        rst->hfilter[tile_idx][RESTORATION_HALFWIN] -=
-            2 * rsi->hfilter[tile_idx][i];
+      if (rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+        rsi->vfilter[tile_idx][RESTORATION_HALFWIN] =
+            rsi->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
+        for (i = 0; i < RESTORATION_HALFWIN; ++i) {
+          rsi->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+              rsi->vfilter[tile_idx][i];
+          rsi->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+              rsi->hfilter[tile_idx][i];
+          rsi->vfilter[tile_idx][RESTORATION_HALFWIN] -=
+              2 * rsi->vfilter[tile_idx][i];
+          rsi->hfilter[tile_idx][RESTORATION_HALFWIN] -=
+              2 * rsi->hfilter[tile_idx][i];
+        }
       }
     }
   }
 }
 
-static void loop_bilateral_filter(uint8_t *data, int width, int height,
-                                  int stride, RestorationInternal *rst,
-                                  uint8_t *tmpdata, int tmpstride) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
+static void loop_bilateral_filter_tile(uint8_t *data, int tile_idx, int width,
+                                       int height, int stride,
+                                       RestorationInternal *rst,
+                                       uint8_t *tmpdata, int tmpstride) {
+  int i, j, subtile_idx;
   int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
 
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
-  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+  for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
     uint8_t *data_p, *tmpdata_p;
-    const uint8_t *wr_lut_ = rst->wr_lut[tile_idx] + BILATERAL_AMP_RANGE;
+    const int level =
+        rst->rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + subtile_idx];
+    uint8_t(*wx_lut)[RESTORATION_WIN];
+    uint8_t *wr_lut_;
 
-    if (rst->bilateral_level[tile_idx] < 0) continue;
+    if (level < 0) continue;
+    wr_lut_ = (rst->keyframe ? bilateral_filter_coeffs_r_kf[level]
+                             : bilateral_filter_coeffs_r[level]) +
+              BILATERAL_AMP_RANGE;
+    wx_lut = rst->keyframe ? bilateral_filter_coeffs_s_kf[level]
+                           : bilateral_filter_coeffs_s[level];
 
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
+    av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                             rst->nhtiles, rst->nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 
     data_p = data + h_start + v_start * stride;
     tmpdata_p = tmpdata + h_start + v_start * tmpstride;
@@ -264,8 +219,7 @@
         uint8_t *data_p2 = data_p + j - RESTORATION_HALFWIN * stride;
         for (y = -RESTORATION_HALFWIN; y <= RESTORATION_HALFWIN; ++y) {
           for (x = -RESTORATION_HALFWIN; x <= RESTORATION_HALFWIN; ++x) {
-            wt = (int)rst->wx_lut[tile_idx][y + RESTORATION_HALFWIN]
-                                 [x + RESTORATION_HALFWIN] *
+            wt = (int)wx_lut[y + RESTORATION_HALFWIN][x + RESTORATION_HALFWIN] *
                  (int)wr_lut_[data_p2[x] - data_p[j]];
             wtsum += wt;
             flsum += wt * data_p2[x];
@@ -287,6 +241,16 @@
   }
 }
 
+static void loop_bilateral_filter(uint8_t *data, int width, int height,
+                                  int stride, RestorationInternal *rst,
+                                  uint8_t *tmpdata, int tmpstride) {
+  int tile_idx;  // frame-level BILATERAL: every tile goes to the tile filter
+  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+    loop_bilateral_filter_tile(data, tile_idx, width, height, stride, rst,
+                               tmpdata, tmpstride);
+  }
+}
+
 uint8_t hor_sym_filter(uint8_t *d, int *hfilter) {
   int32_t s =
       (1 << (RESTORATION_FILT_BITS - 1)) + d[0] * hfilter[RESTORATION_HALFWIN];
@@ -305,17 +269,52 @@
   return clip_pixel(s >> RESTORATION_FILT_BITS);
 }
 
+static void loop_wiener_filter_tile(uint8_t *data, int tile_idx, int width,
+                                    int height, int stride,
+                                    RestorationInternal *rst, uint8_t *tmpdata,
+                                    int tmpstride) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int i, j;
+  int h_start, h_end, v_start, v_end;
+  uint8_t *data_p, *tmpdata_p;
+  // NOTE(review): row+col now run per tile (was: all rows, then all cols).
+  if (rst->rsi->wiener_level[tile_idx] == 0) return;  // tile not filtered
+  // Filter row-wise: data -> tmpdata (1, 0: presumably clamp h edges only)
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 1, 0,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *tmpdata_p++ = hor_sym_filter(data_p++, rst->rsi->hfilter[tile_idx]);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+  // Filter col-wise: tmpdata -> data (0, 1: presumably clamp v edges only)
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 0, 1,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *data_p++ =
+          ver_sym_filter(tmpdata_p++, tmpstride, rst->rsi->vfilter[tile_idx]);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+}
+
 static void loop_wiener_filter(uint8_t *data, int width, int height, int stride,
                                RestorationInternal *rst, uint8_t *tmpdata,
                                int tmpstride) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
-  int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
+  int i, tile_idx;
   uint8_t *data_p, *tmpdata_p;
 
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
   // Initialize tmp buffer
   data_p = data;
   tmpdata_p = tmpdata;
@@ -324,88 +323,65 @@
     data_p += stride;
     tmpdata_p += tmpstride;
   }
-
-  // Filter row-wise tile-by-tile
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start = vtile_idx * tile_height;
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : height;
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *tmpdata_p++ = hor_sym_filter(data_p++, rst->hfilter[tile_idx]);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
-    }
+    loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst, tmpdata,
+                            tmpstride);
   }
+}
 
-  // Filter column-wise tile-by-tile (bands of thickness RESTORATION_HALFWIN
-  // at top and bottom of tiles allow filtering overlap, and are not optimally
-  // filtered)
+static void loop_switchable_filter(uint8_t *data, int width, int height,
+                                   int stride, RestorationInternal *rst,
+                                   uint8_t *tmpdata, int tmpstride) {
+  int i, tile_idx;  // tiles typed RESTORE_NONE are left unfiltered
+  uint8_t *data_p, *tmpdata_p;
+
+  // Initialize tmp buffer with a copy of data (shared by all tile filters)
+  data_p = data;
+  tmpdata_p = tmpdata;
+  for (i = 0; i < height; ++i) {
+    memcpy(tmpdata_p, data_p, sizeof(*data_p) * width);
+    data_p += stride;
+    tmpdata_p += tmpstride;
+  }
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start = htile_idx * tile_width;
-    h_end =
-        (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width) : width;
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *data_p++ =
-            ver_sym_filter(tmpdata_p++, tmpstride, rst->vfilter[tile_idx]);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
+    if (rst->rsi->restoration_type[tile_idx] == RESTORE_BILATERAL) {
+      loop_bilateral_filter_tile(data, tile_idx, width, height, stride, rst,
+                                 tmpdata, tmpstride);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+      loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst,
+                              tmpdata, tmpstride);
     }
   }
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
-static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
-                                         int stride, RestorationInternal *rst,
-                                         uint8_t *tmpdata8, int tmpstride,
-                                         int bit_depth) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
+static void loop_bilateral_filter_tile_highbd(uint16_t *data, int tile_idx,
+                                              int width, int height, int stride,
+                                              RestorationInternal *rst,
+                                              uint16_t *tmpdata, int tmpstride,
+                                              int bit_depth) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int i, j, subtile_idx;
   int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
 
-  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
-
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
-  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+  for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
     uint16_t *data_p, *tmpdata_p;
-    const uint8_t *wr_lut_ = rst->wr_lut[tile_idx] + BILATERAL_AMP_RANGE;
+    const int level =
+        rst->rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + subtile_idx];
+    uint8_t(*wx_lut)[RESTORATION_WIN];
+    uint8_t *wr_lut_;
 
-    if (rst->bilateral_level[tile_idx] < 0) continue;
-
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
+    if (level < 0) continue;
+    wr_lut_ = (rst->keyframe ? bilateral_filter_coeffs_r_kf[level]
+                             : bilateral_filter_coeffs_r[level]) +
+              BILATERAL_AMP_RANGE;
+    wx_lut = rst->keyframe ? bilateral_filter_coeffs_s_kf[level]
+                           : bilateral_filter_coeffs_s[level];
+    av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                             rst->nhtiles, rst->nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 
     data_p = data + h_start + v_start * stride;
     tmpdata_p = tmpdata + h_start + v_start * tmpstride;
@@ -417,8 +393,7 @@
         uint16_t *data_p2 = data_p + j - RESTORATION_HALFWIN * stride;
         for (y = -RESTORATION_HALFWIN; y <= RESTORATION_HALFWIN; ++y) {
           for (x = -RESTORATION_HALFWIN; x <= RESTORATION_HALFWIN; ++x) {
-            wt = (int)rst->wx_lut[tile_idx][y + RESTORATION_HALFWIN]
-                                 [x + RESTORATION_HALFWIN] *
+            wt = (int)wx_lut[y + RESTORATION_HALFWIN][x + RESTORATION_HALFWIN] *
                  (int)wr_lut_[data_p2[x] - data_p[j]];
             wtsum += wt;
             flsum += wt * data_p2[x];
@@ -441,6 +416,20 @@
   }
 }
 
+static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
+                                         int stride, RestorationInternal *rst,
+                                         uint8_t *tmpdata8, int tmpstride,
+                                         int bit_depth) {
+  int tile_idx;
+  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
+  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
+  // Frame-level BILATERAL: every tile goes through the bilateral tile filter.
+  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+    loop_bilateral_filter_tile_highbd(data, tile_idx, width, height, stride,
+                                      rst, tmpdata, tmpstride, bit_depth);
+  }
+}
+
 uint16_t hor_sym_filter_highbd(uint16_t *d, int *hfilter, int bd) {
   int32_t s =
       (1 << (RESTORATION_FILT_BITS - 1)) + d[0] * hfilter[RESTORATION_HALFWIN];
@@ -459,20 +448,57 @@
   return clip_pixel_highbd(s >> RESTORATION_FILT_BITS, bd);
 }
 
+static void loop_wiener_filter_tile_highbd(uint16_t *data, int tile_idx,
+                                           int width, int height, int stride,
+                                           RestorationInternal *rst,
+                                           uint16_t *tmpdata, int tmpstride,
+                                           int bit_depth) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int h_start, h_end, v_start, v_end;
+  int i, j;
+  uint16_t *data_p, *tmpdata_p;
+  // NOTE(review): row+col now run per tile (was: all rows, then all cols).
+  if (rst->rsi->wiener_level[tile_idx] == 0) return;  // tile not filtered
+  // Filter row-wise: data -> tmpdata (1, 0: presumably clamp h edges only)
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 1, 0,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *tmpdata_p++ = hor_sym_filter_highbd(
+          data_p++, rst->rsi->hfilter[tile_idx], bit_depth);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+  // Filter col-wise: tmpdata -> data (0, 1: presumably clamp v edges only)
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 0, 1,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *data_p++ = ver_sym_filter_highbd(tmpdata_p++, tmpstride,
+                                        rst->rsi->vfilter[tile_idx], bit_depth);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+}
+
 static void loop_wiener_filter_highbd(uint8_t *data8, int width, int height,
                                       int stride, RestorationInternal *rst,
                                       uint8_t *tmpdata8, int tmpstride,
                                       int bit_depth) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
-  int i, j, tile_idx, htile_idx, vtile_idx;
-  int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
+  int i, tile_idx;
   uint16_t *data_p, *tmpdata_p;
 
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
   // Initialize tmp buffer
   data_p = data;
   tmpdata_p = tmpdata;
@@ -481,54 +507,36 @@
     data_p += stride;
     tmpdata_p += tmpstride;
   }
-
-  // Filter row-wise tile-by-tile
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start = vtile_idx * tile_height;
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : height;
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *tmpdata_p++ =
-            hor_sym_filter_highbd(data_p++, rst->hfilter[tile_idx], bit_depth);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
-    }
+    loop_wiener_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
+                                   tmpdata, tmpstride, bit_depth);
   }
+}
 
-  // Filter column-wise tile-by-tile (bands of thickness RESTORATION_HALFWIN
-  // at top and bottom of tiles allow filtering overlap, and are not optimally
-  // filtered)
+static void loop_switchable_filter_highbd(uint8_t *data8, int width, int height,
+                                          int stride, RestorationInternal *rst,
+                                          uint8_t *tmpdata8, int tmpstride,
+                                          int bit_depth) {
+  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
+  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
+  int i, tile_idx;  // tiles typed RESTORE_NONE are left unfiltered
+  uint16_t *data_p, *tmpdata_p;
+
+  // Initialize tmp buffer with a copy of data (shared by all tile filters)
+  data_p = data;
+  tmpdata_p = tmpdata;
+  for (i = 0; i < height; ++i) {
+    memcpy(tmpdata_p, data_p, sizeof(*data_p) * width);
+    data_p += stride;
+    tmpdata_p += tmpstride;
+  }
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start = htile_idx * tile_width;
-    h_end =
-        (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width) : width;
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *data_p++ = ver_sym_filter_highbd(tmpdata_p++, tmpstride,
-                                          rst->vfilter[tile_idx], bit_depth);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
+    if (rst->rsi->restoration_type[tile_idx] == RESTORE_BILATERAL) {
+      loop_bilateral_filter_tile_highbd(data, tile_idx, width, height, stride,
+                                        rst, tmpdata, tmpstride, bit_depth);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+      loop_wiener_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
+                                     tmpdata, tmpstride, bit_depth);
     }
   }
 }
@@ -545,16 +553,23 @@
   int yend = end_mi_row << MI_SIZE_LOG2;
   int uvend = yend >> cm->subsampling_y;
   restore_func_type restore_func =
-      cm->rst_internal.restoration_type == RESTORE_BILATERAL
+      cm->rst_internal.rsi->frame_restoration_type == RESTORE_BILATERAL
           ? loop_bilateral_filter
-          : loop_wiener_filter;
+          : (cm->rst_internal.rsi->frame_restoration_type == RESTORE_WIENER
+                 ? loop_wiener_filter
+                 : loop_switchable_filter);
 #if CONFIG_AOM_HIGHBITDEPTH
   restore_func_highbd_type restore_func_highbd =
-      cm->rst_internal.restoration_type == RESTORE_BILATERAL
+      cm->rst_internal.rsi->frame_restoration_type == RESTORE_BILATERAL
           ? loop_bilateral_filter_highbd
-          : loop_wiener_filter_highbd;
+          : (cm->rst_internal.rsi->frame_restoration_type == RESTORE_WIENER
+                 ? loop_wiener_filter_highbd
+                 : loop_switchable_filter_highbd);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   YV12_BUFFER_CONFIG tmp_buf;
+
+  if (cm->rst_internal.rsi->frame_restoration_type == RESTORE_NONE) return;
+
   memset(&tmp_buf, 0, sizeof(YV12_BUFFER_CONFIG));
 
   yend = AOMMIN(yend, cm->height);
@@ -609,25 +624,13 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   }
   aom_free_frame_buffer(&tmp_buf);
-  if (cm->rst_internal.restoration_type == RESTORE_BILATERAL) {
-    free(cm->rst_internal.wr_lut);
-    cm->rst_internal.wr_lut = NULL;
-    free(cm->rst_internal.wx_lut);
-    cm->rst_internal.wx_lut = NULL;
-  }
-  if (cm->rst_internal.restoration_type == RESTORE_WIENER) {
-    free(cm->rst_internal.vfilter);
-    cm->rst_internal.vfilter = NULL;
-    free(cm->rst_internal.hfilter);
-    cm->rst_internal.hfilter = NULL;
-  }
 }
 
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
                                 RestorationInfo *rsi, int y_only,
                                 int partial_frame) {
   int start_mi_row, end_mi_row, mi_rows_to_filter;
-  if (rsi->restoration_type != RESTORE_NONE) {
+  if (rsi->frame_restoration_type != RESTORE_NONE) {
     start_mi_row = 0;
     mi_rows_to_filter = cm->mi_rows;
     if (partial_frame && cm->mi_rows > 8) {
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index d8a312d..3d4802f 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -26,9 +26,10 @@
 #define BILATERAL_LEVELS (1 << BILATERAL_LEVEL_BITS)
 // #define DEF_BILATERAL_LEVEL     2
 
-#define RESTORATION_TILESIZES 3
-#define BILATERAL_TILESIZE 1
-#define WIENER_TILESIZE 2
+#define RESTORATION_TILESIZE_SML 128
+#define RESTORATION_TILESIZE_BIG 256
+#define BILATERAL_SUBTILE_BITS 1
+#define BILATERAL_SUBTILES (1 << (2 * BILATERAL_SUBTILE_BITS))
 
 #define RESTORATION_HALFWIN 3
 #define RESTORATION_HALFWIN1 (RESTORATION_HALFWIN + 1)
@@ -56,43 +57,84 @@
 #define WIENER_FILT_TAP2_MAXV \
   (WIENER_FILT_TAP2_MINV - 1 + (1 << WIENER_FILT_TAP2_BITS))
 
-typedef enum {
-  RESTORE_NONE,
-  RESTORE_BILATERAL,
-  RESTORE_WIENER,
-} RestorationType;
-
 typedef struct {
-  RestorationType restoration_type;
+  RestorationType frame_restoration_type;
+  RestorationType *restoration_type;
   // Bilateral filter
   int *bilateral_level;
   // Wiener filter
   int *wiener_level;
-  int (*vfilter)[RESTORATION_HALFWIN], (*hfilter)[RESTORATION_HALFWIN];
-} RestorationInfo;
-
-typedef struct {
-  RestorationType restoration_type;
-  int subsampling_x;
-  int subsampling_y;
-  int tilesize_index;
-  int ntiles;
-  int tile_width, tile_height;
-  int nhtiles, nvtiles;
-  // Bilateral filter
-  int *bilateral_level;
-  uint8_t (**wx_lut)[RESTORATION_WIN];
-  uint8_t **wr_lut;
-  // Wiener filter
-  int *wiener_level;
   int (*vfilter)[RESTORATION_WIN], (*hfilter)[RESTORATION_WIN];
+} RestorationInfo;
+
+typedef struct {
+  RestorationInfo *rsi;
+  int keyframe;
+  int subsampling_x;
+  int subsampling_y;
+  int ntiles;
+  int tile_width, tile_height;
+  int nhtiles, nvtiles;
 } RestorationInternal;
 
+static INLINE int get_rest_tilesize(int width, int height) {
+  if (width * height <= 352 * 288)
+    return RESTORATION_TILESIZE_SML;
+  else
+    return RESTORATION_TILESIZE_BIG;
+}
+
+static INLINE int av1_get_rest_ntiles(int width, int height,
+                                      int *tile_width, int *tile_height,
+                                      int *nhtiles, int *nvtiles) {
+  int nhtiles_, nvtiles_;
+  int tile_width_, tile_height_;
+  int tilesize = get_rest_tilesize(width, height);
+  tile_width_ = (tilesize < 0) ? width : AOMMIN(tilesize, width);
+  tile_height_ = (tilesize < 0) ? height : AOMMIN(tilesize, height);
+  nhtiles_ = (width + (tile_width_ >> 1)) / tile_width_;
+  nvtiles_ = (height + (tile_height_ >> 1)) / tile_height_;
+  if (tile_width) *tile_width = tile_width_;
+  if (tile_height) *tile_height = tile_height_;
+  if (nhtiles) *nhtiles = nhtiles_;
+  if (nvtiles) *nvtiles = nvtiles_;
+  return (nhtiles_ * nvtiles_);
+}
+
+static INLINE void av1_get_rest_tile_limits(
+    int tile_idx, int subtile_idx, int subtile_bits, int nhtiles, int nvtiles,
+    int tile_width, int tile_height, int im_width, int im_height, int clamp_h,
+    int clamp_v, int *h_start, int *h_end, int *v_start, int *v_end) {
+  const int htile_idx = tile_idx % nhtiles;
+  const int vtile_idx = tile_idx / nhtiles;
+  *h_start = htile_idx * tile_width;
+  *v_start = vtile_idx * tile_height;
+  *h_end = (htile_idx < nhtiles - 1) ? *h_start + tile_width : im_width;
+  *v_end = (vtile_idx < nvtiles - 1) ? *v_start + tile_height : im_height;
+  if (subtile_bits) {
+    const int num_subtiles_1d = (1 << subtile_bits);
+    const int subtile_width = (*h_end - *h_start) >> subtile_bits;
+    const int subtile_height = (*v_end - *v_start) >> subtile_bits;
+    const int subtile_idx_h = subtile_idx & (num_subtiles_1d - 1);
+    const int subtile_idx_v = subtile_idx >> subtile_bits;
+    *h_start += subtile_idx_h * subtile_width;
+    *v_start += subtile_idx_v * subtile_height;
+    *h_end = subtile_idx_h == num_subtiles_1d - 1 ? *h_end
+                                                  : *h_start + subtile_width;
+    *v_end = subtile_idx_v == num_subtiles_1d - 1 ? *v_end
+                                                  : *v_start + subtile_height;
+  }
+  if (clamp_h) {
+    *h_start = AOMMAX(*h_start, RESTORATION_HALFWIN);
+    *h_end = AOMMIN(*h_end, im_width - RESTORATION_HALFWIN);
+  }
+  if (clamp_v) {
+    *v_start = AOMMAX(*v_start, RESTORATION_HALFWIN);
+    *v_end = AOMMIN(*v_end, im_height - RESTORATION_HALFWIN);
+  }
+}
+
 int av1_bilateral_level_bits(const struct AV1Common *const cm);
-int av1_get_restoration_ntiles(int tilesize, int width, int height);
-void av1_get_restoration_tile_size(int tilesize, int width, int height,
-                                   int *tile_width, int *tile_height,
-                                   int *nhtiles, int *nvtiles);
 void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
                                int kf, int width, int height);
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, struct AV1Common *cm,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index de0b502..2f32e94 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -1899,62 +1899,134 @@
 }
 
 #if CONFIG_LOOP_RESTORATION
-static void setup_restoration(AV1_COMMON *cm, struct aom_read_bit_buffer *rb) {
+static void decode_restoration_mode(AV1_COMMON *cm,
+                                    struct aom_read_bit_buffer *rb) {
+  RestorationInfo *rsi = &cm->rst_info;
+  if (aom_rb_read_bit(rb)) {
+    rsi->frame_restoration_type =
+        aom_rb_read_bit(rb) ? RESTORE_WIENER : RESTORE_BILATERAL;
+  } else {
+    rsi->frame_restoration_type =
+        aom_rb_read_bit(rb) ? RESTORE_SWITCHABLE : RESTORE_NONE;
+  }
+}
+
+static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
   int i;
   RestorationInfo *rsi = &cm->rst_info;
-  int ntiles;
-  if (aom_rb_read_bit(rb)) {
-    if (aom_rb_read_bit(rb)) {
-      rsi->restoration_type = RESTORE_BILATERAL;
-      ntiles =
-          av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height,
+                                         NULL, NULL, NULL, NULL);
+  if (rsi->frame_restoration_type != RESTORE_NONE) {
+    rsi->restoration_type = (RestorationType *)aom_realloc(
+        rsi->restoration_type, sizeof(*rsi->restoration_type) * ntiles);
+    if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
       rsi->bilateral_level = (int *)aom_realloc(
-          rsi->bilateral_level, sizeof(*rsi->bilateral_level) * ntiles);
+          rsi->bilateral_level,
+          sizeof(*rsi->bilateral_level) * ntiles * BILATERAL_SUBTILES);
       assert(rsi->bilateral_level != NULL);
-      for (i = 0; i < ntiles; ++i) {
-        if (aom_rb_read_bit(rb)) {
-          rsi->bilateral_level[i] =
-              aom_rb_read_literal(rb, av1_bilateral_level_bits(cm));
-        } else {
-          rsi->bilateral_level[i] = -1;
-        }
-      }
-    } else {
-      rsi->restoration_type = RESTORE_WIENER;
-      ntiles =
-          av1_get_restoration_ntiles(WIENER_TILESIZE, cm->width, cm->height);
       rsi->wiener_level = (int *)aom_realloc(
           rsi->wiener_level, sizeof(*rsi->wiener_level) * ntiles);
       assert(rsi->wiener_level != NULL);
-      rsi->vfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+      rsi->vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
           rsi->vfilter, sizeof(*rsi->vfilter) * ntiles);
       assert(rsi->vfilter != NULL);
-      rsi->hfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+      rsi->hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
           rsi->hfilter, sizeof(*rsi->hfilter) * ntiles);
       assert(rsi->hfilter != NULL);
       for (i = 0; i < ntiles; ++i) {
-        rsi->wiener_level[i] = aom_rb_read_bit(rb);
-        if (rsi->wiener_level[i]) {
-          rsi->vfilter[i][0] = aom_rb_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+        rsi->restoration_type[i] = aom_read_tree(
+            rb, av1_switchable_restore_tree, cm->fc->switchable_restore_prob);
+        if (rsi->restoration_type[i] == RESTORE_WIENER) {
+          rsi->wiener_level[i] = 1;
+          rsi->vfilter[i][0] =
+              aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+              WIENER_FILT_TAP0_MINV;
+          rsi->vfilter[i][1] =
+              aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+              WIENER_FILT_TAP1_MINV;
+          rsi->vfilter[i][2] =
+              aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+              WIENER_FILT_TAP2_MINV;
+          rsi->hfilter[i][0] =
+              aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+              WIENER_FILT_TAP0_MINV;
+          rsi->hfilter[i][1] =
+              aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+              WIENER_FILT_TAP1_MINV;
+          rsi->hfilter[i][2] =
+              aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+              WIENER_FILT_TAP2_MINV;
+        } else if (rsi->restoration_type[i] == RESTORE_BILATERAL) {
+          int s;
+          for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+            const int j = i * BILATERAL_SUBTILES + s;
+#if BILATERAL_SUBTILES == 0
+            rsi->bilateral_level[j] =
+                aom_read_literal(rb, av1_bilateral_level_bits(cm));
+#else
+            if (aom_read(rb, RESTORE_NONE_BILATERAL_PROB)) {
+              rsi->bilateral_level[j] =
+                  aom_read_literal(rb, av1_bilateral_level_bits(cm));
+            } else {
+              rsi->bilateral_level[j] = -1;
+            }
+#endif
+          }
+        }
+      }
+    } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
+      rsi->wiener_level = (int *)aom_realloc(
+          rsi->wiener_level, sizeof(*rsi->wiener_level) * ntiles);
+      assert(rsi->wiener_level != NULL);
+      rsi->vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
+          rsi->vfilter, sizeof(*rsi->vfilter) * ntiles);
+      assert(rsi->vfilter != NULL);
+      rsi->hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
+          rsi->hfilter, sizeof(*rsi->hfilter) * ntiles);
+      assert(rsi->hfilter != NULL);
+      for (i = 0; i < ntiles; ++i) {
+        if (aom_read(rb, RESTORE_NONE_WIENER_PROB)) {
+          rsi->wiener_level[i] = 1;
+          rsi->restoration_type[i] = RESTORE_WIENER;
+          rsi->vfilter[i][0] = aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
                                WIENER_FILT_TAP0_MINV;
-          rsi->vfilter[i][1] = aom_rb_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+          rsi->vfilter[i][1] = aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
                                WIENER_FILT_TAP1_MINV;
-          rsi->vfilter[i][2] = aom_rb_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+          rsi->vfilter[i][2] = aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
                                WIENER_FILT_TAP2_MINV;
-          rsi->hfilter[i][0] = aom_rb_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+          rsi->hfilter[i][0] = aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
                                WIENER_FILT_TAP0_MINV;
-          rsi->hfilter[i][1] = aom_rb_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+          rsi->hfilter[i][1] = aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
                                WIENER_FILT_TAP1_MINV;
-          rsi->hfilter[i][2] = aom_rb_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+          rsi->hfilter[i][2] = aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
                                WIENER_FILT_TAP2_MINV;
         } else {
-          rsi->vfilter[i][0] = rsi->vfilter[i][1] = rsi->vfilter[i][2] = 0;
-          rsi->hfilter[i][0] = rsi->hfilter[i][1] = rsi->hfilter[i][2] = 0;
+          rsi->wiener_level[i] = 0;
+          rsi->restoration_type[i] = RESTORE_NONE;
+        }
+      }
+    } else {
+      rsi->frame_restoration_type = RESTORE_BILATERAL;
+      rsi->bilateral_level = (int *)aom_realloc(
+          rsi->bilateral_level,
+          sizeof(*rsi->bilateral_level) * ntiles * BILATERAL_SUBTILES);
+      assert(rsi->bilateral_level != NULL);
+      for (i = 0; i < ntiles; ++i) {
+        int s;
+        rsi->restoration_type[i] = RESTORE_BILATERAL;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+          const int j = i * BILATERAL_SUBTILES + s;
+          if (aom_read(rb, RESTORE_NONE_BILATERAL_PROB)) {
+            rsi->bilateral_level[j] =
+                aom_read_literal(rb, av1_bilateral_level_bits(cm));
+          } else {
+            rsi->bilateral_level[j] = -1;
+          }
         }
       }
     }
   } else {
-    rsi->restoration_type = RESTORE_NONE;
+    rsi->frame_restoration_type = RESTORE_NONE;
   }
 }
 #endif  // CONFIG_LOOP_RESTORATION
@@ -3286,7 +3358,7 @@
   setup_dering(cm, rb);
 #endif
 #if CONFIG_LOOP_RESTORATION
-  setup_restoration(cm, rb);
+  decode_restoration_mode(cm, rb);
 #endif  // CONFIG_LOOP_RESTORATION
   setup_quantization(cm, rb);
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -3468,6 +3540,10 @@
                        "Failed to allocate compressed header ANS decoder");
 #endif  // !CONFIG_ANS
 
+#if CONFIG_LOOP_RESTORATION
+  decode_restoration(cm, &r);
+#endif
+
   if (cm->tx_mode == TX_MODE_SELECT) {
     for (i = 0; i < TX_SIZES - 1; ++i)
       for (j = 0; j < TX_SIZE_CONTEXTS; ++j)
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 6578c0c..f09a5cd 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -150,6 +150,9 @@
 #if CONFIG_OBMC || CONFIG_WARPED_MOTION
 static struct av1_token motvar_encodings[MOTION_VARIATIONS];
 #endif  // CONFIG_OBMC || CONFIG_WARPED_MOTION
+#if CONFIG_LOOP_RESTORATION
+static struct av1_token switchable_restore_encodings[RESTORE_SWITCHABLE_TYPES];
+#endif  // CONFIG_LOOP_RESTORATION
 
 void av1_encode_token_init(void) {
 #if CONFIG_EXT_TX
@@ -176,6 +179,10 @@
   av1_tokens_from_tree(global_motion_types_encodings,
                        av1_global_motion_types_tree);
 #endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_LOOP_RESTORATION
+  av1_tokens_from_tree(switchable_restore_encodings,
+                       av1_switchable_restore_tree);
+#endif  // CONFIG_LOOP_RESTORATION
 }
 
 static void write_intra_mode(aom_writer *w, PREDICTION_MODE mode,
@@ -2420,42 +2427,102 @@
 }
 
 #if CONFIG_LOOP_RESTORATION
-static void encode_restoration(AV1_COMMON *cm,
-                               struct aom_write_bit_buffer *wb) {
+static void encode_restoration_mode(AV1_COMMON *cm,
+                                    struct aom_write_bit_buffer *wb) {
+  RestorationInfo *rst = &cm->rst_info;
+  switch (rst->frame_restoration_type) {
+    case RESTORE_NONE:
+      aom_wb_write_bit(wb, 0);
+      aom_wb_write_bit(wb, 0);
+      break;
+    case RESTORE_SWITCHABLE:
+      aom_wb_write_bit(wb, 0);
+      aom_wb_write_bit(wb, 1);
+      break;
+    case RESTORE_BILATERAL:
+      aom_wb_write_bit(wb, 1);
+      aom_wb_write_bit(wb, 0);
+      break;
+    case RESTORE_WIENER:
+      aom_wb_write_bit(wb, 1);
+      aom_wb_write_bit(wb, 1);
+      break;
+    default: assert(0);
+  }
+}
+
+static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
   int i;
   RestorationInfo *rst = &cm->rst_info;
-  aom_wb_write_bit(wb, rst->restoration_type != RESTORE_NONE);
-  if (rst->restoration_type != RESTORE_NONE) {
-    if (rst->restoration_type == RESTORE_BILATERAL) {
-      aom_wb_write_bit(wb, 1);
+  if (rst->frame_restoration_type != RESTORE_NONE) {
+    if (rst->frame_restoration_type == RESTORE_SWITCHABLE) {
+      // RESTORE_SWITCHABLE
       for (i = 0; i < cm->rst_internal.ntiles; ++i) {
-        if (rst->bilateral_level[i] >= 0) {
-          aom_wb_write_bit(wb, 1);
-          aom_wb_write_literal(wb, rst->bilateral_level[i],
-                               av1_bilateral_level_bits(cm));
+        av1_write_token(
+            wb, av1_switchable_restore_tree,
+            cm->fc->switchable_restore_prob,
+            &switchable_restore_encodings[rst->restoration_type[i]]);
+        if (rst->restoration_type[i] == RESTORE_NONE) {
+        } else if (rst->restoration_type[i] == RESTORE_BILATERAL) {
+          int s;
+          for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+            const int j = i * BILATERAL_SUBTILES + s;
+#if BILATERAL_SUBTILES == 0
+            aom_write_literal(wb, rst->bilateral_level[j],
+                              av1_bilateral_level_bits(cm));
+#else
+            aom_write(wb, rst->bilateral_level[j] >= 0,
+                      RESTORE_NONE_BILATERAL_PROB);
+            if (rst->bilateral_level[j] >= 0) {
+              aom_write_literal(wb, rst->bilateral_level[j],
+                                av1_bilateral_level_bits(cm));
+            }
+#endif
+          }
         } else {
-          aom_wb_write_bit(wb, 0);
+          aom_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                               WIENER_FILT_TAP2_BITS);
+          aom_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                               WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                               WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                               WIENER_FILT_TAP2_BITS);
         }
       }
-    } else {
-      aom_wb_write_bit(wb, 0);
+    } else if (rst->frame_restoration_type == RESTORE_BILATERAL) {
       for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+        int s;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+          const int j = i * BILATERAL_SUBTILES + s;
+          aom_write(wb, rst->bilateral_level[j] >= 0,
+                    RESTORE_NONE_BILATERAL_PROB);
+          if (rst->bilateral_level[j] >= 0) {
+            aom_write_literal(wb, rst->bilateral_level[j],
+                                 av1_bilateral_level_bits(cm));
+          }
+        }
+      }
+    } else if (rst->frame_restoration_type == RESTORE_WIENER) {
+      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+        aom_write(wb, rst->wiener_level[i] != 0, RESTORE_NONE_WIENER_PROB);
         if (rst->wiener_level[i]) {
-          aom_wb_write_bit(wb, 1);
-          aom_wb_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
-                               WIENER_FILT_TAP0_BITS);
-          aom_wb_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
-                               WIENER_FILT_TAP1_BITS);
-          aom_wb_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
-                               WIENER_FILT_TAP2_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
-                               WIENER_FILT_TAP0_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
-                               WIENER_FILT_TAP1_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
-                               WIENER_FILT_TAP2_BITS);
-        } else {
-          aom_wb_write_bit(wb, 0);
+          aom_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                            WIENER_FILT_TAP2_BITS);
+          aom_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                            WIENER_FILT_TAP2_BITS);
         }
       }
     }
@@ -3183,7 +3250,7 @@
   encode_dering(cm->dering_level, wb);
 #endif  // CONFIG_DERING
 #if CONFIG_LOOP_RESTORATION
-  encode_restoration(cm, wb);
+  encode_restoration_mode(cm, wb);
 #endif  // CONFIG_LOOP_RESTORATION
   encode_quantization(cm, wb);
   encode_segmentation(cm, xd, wb);
@@ -3282,6 +3349,11 @@
   header_bc = &real_header_bc;
   aom_start_encode(header_bc, data);
 #endif
+
+#if CONFIG_LOOP_RESTORATION
+  encode_restoration(cm, header_bc);
+#endif  // CONFIG_LOOP_RESTORATION
+
   update_txfm_probs(cm, header_bc, counts);
   update_coef_probs(cpi, header_bc);
 
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 821d2f1..9902517 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -572,6 +572,9 @@
 #if CONFIG_EXT_INTRA
   int intra_filter_cost[INTRA_FILTERS + 1][INTRA_FILTERS];
 #endif  // CONFIG_EXT_INTRA
+#if CONFIG_LOOP_RESTORATION
+  int switchable_restore_cost[RESTORE_SWITCHABLE_TYPES];
+#endif  // CONFIG_LOOP_RESTORATION
 
   int multi_arf_allowed;
   int multi_arf_enabled;
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 22bd019..00e46a68 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -28,7 +28,53 @@
 #include "av1/encoder/pickrst.h"
 #include "av1/encoder/quantize.h"
 
-static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *sd,
+const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
+
+static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
+                                    AV1_COMMON *const cm, int h_start,
+                                    int width, int v_start, int height) {
+  int64_t filt_err;
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (cm->use_highbitdepth) {
+    filt_err = aom_highbd_get_y_sse_part(src, cm->frame_to_show, h_start, width,
+                                         v_start, height);
+  } else {
+    filt_err = aom_get_y_sse_part(src, cm->frame_to_show, h_start, width,
+                                  v_start, height);
+  }
+#else
+  filt_err = aom_get_y_sse_part(src, cm->frame_to_show, h_start, width, v_start,
+                                height);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+  return filt_err;
+}
+
+static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
+                                    AV1_COMP *const cpi, RestorationInfo *rsi,
+                                    int partial_frame, int tile_idx,
+                                    int subtile_idx, int subtile_bits) {
+  AV1_COMMON *const cm = &cpi->common;
+  int64_t filt_err;
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+  (void)ntiles;
+
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame);
+  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
+                           nvtiles, tile_width, tile_height, cm->width,
+                           cm->height, 0, 0, &h_start, &h_end, &v_start,
+                           &v_end);
+  filt_err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                                  v_end - v_start);
+
+  // Re-instate the unfiltered frame
+  aom_yv12_copy_y(&cpi->last_frame_db, cm->frame_to_show);
+  return filt_err;
+}
+
+static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *const cpi, RestorationInfo *rsi,
                                      int partial_frame) {
   AV1_COMMON *const cm = &cpi->common;
@@ -36,12 +82,12 @@
   av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame);
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
-    filt_err = aom_highbd_get_y_sse(sd, cm->frame_to_show);
+    filt_err = aom_highbd_get_y_sse(src, cm->frame_to_show);
   } else {
-    filt_err = aom_get_y_sse(sd, cm->frame_to_show);
+    filt_err = aom_get_y_sse(src, cm->frame_to_show);
   }
 #else
-  filt_err = aom_get_y_sse(sd, cm->frame_to_show);
+  filt_err = aom_get_y_sse(src, cm->frame_to_show);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
   // Re-instate the unfiltered frame
@@ -49,20 +95,24 @@
   return filt_err;
 }
 
-static int search_bilateral_level(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
+static int search_bilateral_level(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                   int filter_level, int partial_frame,
-                                  int *bilateral_level, double *best_cost_ret) {
+                                  int *bilateral_level, double *best_cost_ret,
+                                  double *best_tile_cost) {
   AV1_COMMON *const cm = &cpi->common;
   int i, j, tile_idx;
   int64_t err;
   int bits;
-  double cost, best_cost, cost_norestore, cost_bilateral;
+  double cost, best_cost, cost_norestore, cost_bilateral,
+      cost_norestore_subtile;
   const int bilateral_level_bits = av1_bilateral_level_bits(&cpi->common);
   const int bilateral_levels = 1 << bilateral_level_bits;
   MACROBLOCK *x = &cpi->td.mb;
   RestorationInfo rsi;
-  const int ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
 
   //  Make a copy of the unfiltered / processed recon buffer
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
@@ -71,53 +121,94 @@
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
 
   // RD cost associated with no restoration
-  rsi.restoration_type = RESTORE_NONE;
-  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-  bits = 0;
-  cost_norestore =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-  best_cost = cost_norestore;
+  rsi.frame_restoration_type = RESTORE_NONE;
+  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
+  // err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   // RD cost associated with bilateral filtering
-  rsi.restoration_type = RESTORE_BILATERAL;
-  rsi.bilateral_level =
-      (int *)aom_malloc(sizeof(*rsi.bilateral_level) * ntiles);
+  rsi.frame_restoration_type = RESTORE_BILATERAL;
+  rsi.bilateral_level = (int *)aom_malloc(sizeof(*rsi.bilateral_level) *
+                                          ntiles * BILATERAL_SUBTILES);
   assert(rsi.bilateral_level != NULL);
 
-  for (j = 0; j < ntiles; ++j) bilateral_level[j] = -1;
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) bilateral_level[j] = -1;
 
+  // TODO(debargha): This is a pretty inefficient way to find the best
+  // parameters per tile. Needs fixing.
   // Find best filter for each tile
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    for (j = 0; j < ntiles; ++j) rsi.bilateral_level[j] = -1;
-    best_cost = cost_norestore;
-    for (i = 0; i < bilateral_levels; ++i) {
-      rsi.bilateral_level[tile_idx] = i;
-      err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-      bits = bilateral_level_bits + 1;
-      // Normally the rate is rate in bits * 256 and dist is sum sq err * 64
-      // when RDCOST is used.  However below we just scale both in the correct
-      // ratios appropriately but not exactly by these values.
-      cost = RDCOST_DBL(x->rdmult, x->rddiv,
-                        (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-      if (cost < best_cost) {
-        bilateral_level[tile_idx] = i;
-        best_cost = cost;
+    int subtile_idx;
+    for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
+      const int fulltile_idx = tile_idx * BILATERAL_SUBTILES + subtile_idx;
+      av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                               nhtiles, nvtiles, tile_width, tile_height,
+                               cm->width, cm->height, 0, 0, &h_start, &h_end,
+                               &v_start, &v_end);
+      err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                                 v_end - v_start);
+#if BILATERAL_SUBTILES
+      // #bits when a subtile is not restored
+      bits = av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, 0);
+#else
+      bits = 0;
+#endif
+      cost_norestore_subtile =
+          RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+      best_cost = cost_norestore_subtile;
+      for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j)
+        rsi.bilateral_level[j] = -1;
+
+      for (i = 0; i < bilateral_levels; ++i) {
+        rsi.bilateral_level[fulltile_idx] = i;
+        err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx,
+                                   subtile_idx, BILATERAL_SUBTILE_BITS);
+        bits = bilateral_level_bits << AV1_PROB_COST_SHIFT;
+        bits += av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, 1);
+        cost = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+        if (cost < best_cost) {
+          bilateral_level[fulltile_idx] = i;
+          best_cost = cost;
+        }
       }
     }
+    if (best_tile_cost) {
+      bits = 0;
+      for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j)
+        rsi.bilateral_level[j] = -1;
+      for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
+        const int fulltile_idx = tile_idx * BILATERAL_SUBTILES + subtile_idx;
+        rsi.bilateral_level[fulltile_idx] = bilateral_level[fulltile_idx];
+        if (rsi.bilateral_level[fulltile_idx] >= 0)
+          bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
+#if BILATERAL_SUBTILES
+        bits += av1_cost_bit(RESTORE_NONE_BILATERAL_PROB,
+                             rsi.bilateral_level[fulltile_idx] >= 0);
+#endif
+      }
+      err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0);
+      best_tile_cost[tile_idx] = RDCOST_DBL(
+          x->rdmult, x->rddiv,
+          (bits + cpi->switchable_restore_cost[RESTORE_BILATERAL]) >> 4, err);
+    }
   }
   // Find cost for combined configuration
-  bits = 0;
-  for (j = 0; j < ntiles; ++j) {
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
     rsi.bilateral_level[j] = bilateral_level[j];
     if (rsi.bilateral_level[j] >= 0) {
-      bits += (bilateral_level_bits + 1);
-    } else {
-      bits += 1;
+      bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
     }
+#if BILATERAL_SUBTILES
+    bits +=
+        av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, rsi.bilateral_level[j] >= 0);
+#endif
   }
-  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-  cost_bilateral =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
+  cost_bilateral = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   aom_free(rsi.bilateral_level);
 
@@ -131,10 +222,11 @@
   }
 }

-static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
+static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *src,
                                          AV1_COMP *cpi, int partial_frame,
                                          int *filter_best, int *bilateral_level,
-                                         double *best_cost_ret) {
+                                         double *best_cost_ret,  // out: RD cost
+                                         double *best_tile_cost) {
   const AV1_COMMON *const cm = &cpi->common;
   const struct loopfilter *const lf = &cm->lf;
   const int min_filter_level = 0;
@@ -147,7 +239,8 @@
   int bilateral_success[MAX_LOOP_FILTER + 1];

   const int ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+      av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  double *tile_cost = (double *)aom_malloc(sizeof(*tile_cost) * ntiles);

   // Start the search at the previous frame filter level unless it is now out of
   // range.
@@ -157,13 +250,14 @@
   // Set each entry to -1
   for (i = 0; i <= MAX_LOOP_FILTER; ++i) ss_err[i] = -1.0;

-  tmp_level = (int *)aom_malloc(sizeof(*tmp_level) * ntiles);
+  tmp_level =
+      (int *)aom_malloc(sizeof(*tmp_level) * ntiles * BILATERAL_SUBTILES);

   bilateral_success[filt_mid] = search_bilateral_level(
-      sd, cpi, filt_mid, partial_frame, tmp_level, &best_err);
+      src, cpi, filt_mid, partial_frame, tmp_level, &best_err, best_tile_cost);
   filt_best = filt_mid;
   ss_err[filt_mid] = best_err;
-  for (j = 0; j < ntiles; ++j) {
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {  // one level per subtile
     bilateral_level[j] = tmp_level[j];
   }

@@ -183,8 +277,9 @@
     if (filt_direction <= 0 && filt_low != filt_mid) {
       // Get Low filter error score
       if (ss_err[filt_low] < 0) {
-        bilateral_success[filt_low] = search_bilateral_level(
-            sd, cpi, filt_low, partial_frame, tmp_level, &ss_err[filt_low]);
+        bilateral_success[filt_low] =
+            search_bilateral_level(src, cpi, filt_low, partial_frame, tmp_level,
+                                   &ss_err[filt_low], tile_cost);
       }
       // If value is close to the best so far then bias towards a lower loop
       // filter value.
@@ -194,26 +289,29 @@
           best_err = ss_err[filt_low];
         }
         filt_best = filt_low;
-        for (j = 0; j < ntiles; ++j) {
+        for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
           bilateral_level[j] = tmp_level[j];
         }
+        memcpy(best_tile_cost, tile_cost, sizeof(*tile_cost) * ntiles);
       }
     }

     // Now look at filt_high
     if (filt_direction >= 0 && filt_high != filt_mid) {
       if (ss_err[filt_high] < 0) {
-        bilateral_success[filt_high] = search_bilateral_level(
-            sd, cpi, filt_high, partial_frame, tmp_level, &ss_err[filt_high]);
+        bilateral_success[filt_high] =
+            search_bilateral_level(src, cpi, filt_high, partial_frame,
+                                   tmp_level, &ss_err[filt_high], tile_cost);
       }
       // If value is significantly better than previous best, bias added against
       // raising filter value
       if (ss_err[filt_high] < (best_err - bias)) {
         best_err = ss_err[filt_high];
         filt_best = filt_high;
-        for (j = 0; j < ntiles; ++j) {
+        for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
           bilateral_level[j] = tmp_level[j];
         }
+        memcpy(best_tile_cost, tile_cost, sizeof(*tile_cost) * ntiles);
       }
     }

@@ -226,12 +324,11 @@
       filt_mid = filt_best;
     }
   }
-
   aom_free(tmp_level);
+  aom_free(tile_cost);  // scratch per-tile costs from the search

   // Update best error
   best_err = ss_err[filt_best];
-
   if (best_cost_ret) *best_cost_ret = best_err;
   if (filter_best) *filter_best = filt_best;

@@ -546,14 +643,15 @@

 static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                 int filter_level, int partial_frame,
-                                int (*vfilter)[RESTORATION_HALFWIN],
-                                int (*hfilter)[RESTORATION_HALFWIN],
-                                int *process_tile, double *best_cost_ret) {
+                                int (*vfilter)[RESTORATION_WIN],
+                                int (*hfilter)[RESTORATION_WIN],
+                                int *wiener_level, double *best_cost_ret,
+                                double *best_tile_cost) {
   AV1_COMMON *const cm = &cpi->common;
   RestorationInfo rsi;
   int64_t err;
   int bits;
-  double cost_wiener, cost_norestore;
+  double cost_wiener, cost_norestore, cost_norestore_tile;  // RD costs
   MACROBLOCK *x = &cpi->td.mb;
   double M[RESTORATION_WIN2];
   double H[RESTORATION_WIN2 * RESTORATION_WIN2];
@@ -564,56 +662,55 @@
   const int src_stride = src->y_stride;
   const int dgd_stride = dgd->y_stride;
   double score;
-  int tile_idx, htile_idx, vtile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   int i, j;

-  const int tilesize = WIENER_TILESIZE;
-  const int ntiles = av1_get_restoration_ntiles(tilesize, width, height);
-
+  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
   assert(width == dgd->y_crop_width);
   assert(height == dgd->y_crop_height);
   assert(width == src->y_crop_width);
   assert(height == src->y_crop_height);

-  av1_get_restoration_tile_size(tilesize, width, height, &tile_width,
-                                &tile_height, &nhtiles, &nvtiles);
-
   //  Make a copy of the unfiltered / processed recon buffer
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
   av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                         1, partial_frame);
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

-  rsi.restoration_type = RESTORE_NONE;
-  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-  bits = 0;
-  cost_norestore =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  rsi.frame_restoration_type = RESTORE_NONE;
+  err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

-  rsi.restoration_type = RESTORE_WIENER;
+  rsi.frame_restoration_type = RESTORE_WIENER;
   rsi.vfilter =
-      (int(*)[RESTORATION_HALFWIN])aom_malloc(sizeof(*rsi.vfilter) * ntiles);
+      (int(*)[RESTORATION_WIN])aom_malloc(sizeof(*rsi.vfilter) * ntiles);
   assert(rsi.vfilter != NULL);
   rsi.hfilter =
-      (int(*)[RESTORATION_HALFWIN])aom_malloc(sizeof(*rsi.hfilter) * ntiles);
+      (int(*)[RESTORATION_WIN])aom_malloc(sizeof(*rsi.hfilter) * ntiles);
   assert(rsi.hfilter != NULL);
   rsi.wiener_level = (int *)aom_malloc(sizeof(*rsi.wiener_level) * ntiles);
   assert(rsi.wiener_level != NULL);

   // Compute best Wiener filters for each tile
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    htile_idx = tile_idx % nhtiles;
-    vtile_idx = tile_idx / nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                      : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                      : (height - RESTORATION_HALFWIN);
+    wiener_level[tile_idx] = 0;  // default: tile left unrestored
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, width, height, 0, 0, &h_start, &h_end,
+                             &v_start, &v_end);
+    err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                               v_end - v_start);
+    // #bits when a tile is not restored
+    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
+    cost_norestore_tile = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (best_tile_cost) best_tile_cost[tile_idx] = cost_norestore_tile;

+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 #if CONFIG_AOM_HIGHBITDEPTH
     if (cm->use_highbitdepth)
       compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
@@ -626,12 +723,12 @@
     if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
       for (i = 0; i < RESTORATION_HALFWIN; ++i)
         rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
-      process_tile[tile_idx] = 0;
+      wiener_level[tile_idx] = 0;
       continue;
     }
     quantize_sym_filter(vfilterd, rsi.vfilter[tile_idx]);
     quantize_sym_filter(hfilterd, rsi.hfilter[tile_idx]);
-    process_tile[tile_idx] = 1;
+    wiener_level[tile_idx] = 1;  // may be reset to 0 below

     // Filter score computes the value of the function x'*A*x - x'*b for the
     // learned filter and compares it against identity filer. If there is no
@@ -640,31 +737,41 @@
     if (score > 0.0) {
       for (i = 0; i < RESTORATION_HALFWIN; ++i)
         rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
-      process_tile[tile_idx] = 0;
+      wiener_level[tile_idx] = 0;
       continue;
     }

     for (j = 0; j < ntiles; ++j) rsi.wiener_level[j] = 0;
     rsi.wiener_level[tile_idx] = 1;

-    err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-    bits = 1 + WIENER_FILT_BITS;
-    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv,
-                             (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-    if (cost_wiener >= cost_norestore) process_tile[tile_idx] = 0;
+    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0);
+    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
+    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
+    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (cost_wiener >= cost_norestore_tile) wiener_level[tile_idx] = 0;
+    if (best_tile_cost) {  // cost under switchable signaling
+      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
+      best_tile_cost[tile_idx] = RDCOST_DBL(
+          x->rdmult, x->rddiv,
+          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
+    }
   }
   // Cost for Wiener filtering
-  bits = 0;
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    bits += (process_tile[tile_idx] ? (WIENER_FILT_BITS + 1) : 1);
-    rsi.wiener_level[tile_idx] = process_tile[tile_idx];
+    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, wiener_level[tile_idx]);
+    if (wiener_level[tile_idx])  // filtered tiles also pay for the filter
+      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
+    rsi.wiener_level[tile_idx] = wiener_level[tile_idx];
   }
+  // TODO(debargha): This is a pretty inefficient way to find the error
+  // for the whole frame. Specialize for a specific tile.
   err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-  cost_wiener =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    if (process_tile[tile_idx] == 0) continue;
+    if (wiener_level[tile_idx] == 0) continue;
     for (i = 0; i < RESTORATION_HALFWIN; ++i) {
       vfilter[tile_idx][i] = rsi.vfilter[tile_idx][i];
       hfilter[tile_idx][i] = rsi.hfilter[tile_idx][i];
@@ -685,40 +792,125 @@
   }
 }

-void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
+// Chooses a restoration type per tile; returns 1 if that beats RESTORE_NONE.
+static int search_switchable_restoration(
+    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int filter_level,
+    int partial_frame, RestorationInfo *rsi, double *tile_cost_bilateral,
+    double *tile_cost_wiener, double *best_cost_ret) {
+  AV1_COMMON *const cm = &cpi->common;
+  const int bilateral_level_bits = av1_bilateral_level_bits(&cpi->common);
+  MACROBLOCK *x = &cpi->td.mb;
+  double err, cost_norestore, cost_norestore_tile, cost_switchable;
+  int bits, tile_idx;
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+
+  //  Make a copy of the unfiltered / processed recon buffer
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
+  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
+                        1, partial_frame);
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
+  // RD cost associated with no restoration
+  rsi->frame_restoration_type = RESTORE_NONE;
+  err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi->frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
+  rsi->frame_restoration_type = RESTORE_SWITCHABLE;
+  bits = frame_level_restore_bits[rsi->frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, cm->width, cm->height, 0, 0, &h_start,
+                             &h_end, &v_start, &v_end);
+    err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                               v_end - v_start);
+    cost_norestore_tile =
+        RDCOST_DBL(x->rdmult, x->rddiv,
+                   (cpi->switchable_restore_cost[RESTORE_NONE] >> 4), err);
+    if (tile_cost_wiener[tile_idx] > cost_norestore_tile &&
+        tile_cost_bilateral[tile_idx] > cost_norestore_tile) {
+      rsi->restoration_type[tile_idx] = RESTORE_NONE;
+    } else {
+      rsi->restoration_type[tile_idx] =
+          tile_cost_wiener[tile_idx] < tile_cost_bilateral[tile_idx]
+              ? RESTORE_WIENER
+              : RESTORE_BILATERAL;
+      if (rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+        if (rsi->wiener_level[tile_idx]) {
+          bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
+        } else {
+          rsi->restoration_type[tile_idx] = RESTORE_NONE;
+        }
+      } else {
+        int s;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+#if BILATERAL_SUBTILES
+          bits += av1_cost_bit(
+              RESTORE_NONE_BILATERAL_PROB,
+              rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + s] >= 0);
+#endif
+          if (rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + s] >= 0)
+            bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
+        }
+      }
+    }
+    bits += cpi->switchable_restore_cost[rsi->restoration_type[tile_idx]];
+  }
+  err = try_restoration_frame(src, cpi, rsi, partial_frame);
+  cost_switchable = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
+  if (cost_switchable < cost_norestore) {
+    *best_cost_ret = cost_switchable;
+    return 1;
+  } else {
+    *best_cost_ret = cost_norestore;
+    return 0;
+  }
+}
+
+void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  LPF_PICK_METHOD method) {
   AV1_COMMON *const cm = &cpi->common;
   struct loopfilter *const lf = &cm->lf;
   int wiener_success = 0;
   int bilateral_success = 0;
+  int switchable_success = 0;
   double cost_bilateral = DBL_MAX;
   double cost_wiener = DBL_MAX;
-  double cost_norestore = DBL_MAX;
-  int ntiles;
-
-  ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
-  cm->rst_info.bilateral_level =
-      (int *)aom_realloc(cm->rst_info.bilateral_level,
-                         sizeof(*cm->rst_info.bilateral_level) * ntiles);
+  double cost_switchable = DBL_MAX;
+  double *tile_cost_bilateral, *tile_cost_wiener;  // per-tile RD costs
+  const int ntiles =
+      av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  cm->rst_info.restoration_type = (RestorationType *)aom_realloc(
+      cm->rst_info.restoration_type,
+      sizeof(*cm->rst_info.restoration_type) * ntiles);
+  assert(cm->rst_info.restoration_type != NULL);
+  cm->rst_info.bilateral_level = (int *)aom_realloc(
+      cm->rst_info.bilateral_level,
+      sizeof(*cm->rst_info.bilateral_level) * ntiles * BILATERAL_SUBTILES);
   assert(cm->rst_info.bilateral_level != NULL);

-  ntiles = av1_get_restoration_ntiles(WIENER_TILESIZE, cm->width, cm->height);
   cm->rst_info.wiener_level = (int *)aom_realloc(
       cm->rst_info.wiener_level, sizeof(*cm->rst_info.wiener_level) * ntiles);
   assert(cm->rst_info.wiener_level != NULL);
-  cm->rst_info.vfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+  cm->rst_info.vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
       cm->rst_info.vfilter, sizeof(*cm->rst_info.vfilter) * ntiles);
   assert(cm->rst_info.vfilter != NULL);
-  cm->rst_info.hfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+  cm->rst_info.hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
       cm->rst_info.hfilter, sizeof(*cm->rst_info.hfilter) * ntiles);
   assert(cm->rst_info.hfilter != NULL);
+  tile_cost_wiener = (double *)aom_malloc(sizeof(cost_wiener) * ntiles);
+  tile_cost_bilateral = (double *)aom_malloc(sizeof(cost_bilateral) * ntiles);

   lf->sharpness_level = cm->frame_type == KEY_FRAME ? 0 : cpi->oxcf.sharpness;

   if (method == LPF_PICK_MINIMAL_LPF && lf->filter_level) {
     lf->filter_level = 0;
-    cm->rst_info.restoration_type = RESTORE_NONE;
+    cm->rst_info.frame_restoration_type = RESTORE_NONE;
   } else if (method >= LPF_PICK_FROM_Q) {
     const int min_filter_level = 0;
     const int max_filter_level = av1_get_max_filter_level(cpi);
@@ -749,60 +941,45 @@
     if (cm->frame_type == KEY_FRAME) filt_guess -= 4;
     lf->filter_level = clamp(filt_guess, min_filter_level, max_filter_level);
     bilateral_success = search_bilateral_level(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
-        cm->rst_info.bilateral_level, &cost_bilateral);
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        cm->rst_info.bilateral_level, &cost_bilateral, tile_cost_bilateral);
     wiener_success = search_wiener_filter(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
         cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
-        &cost_wiener);
-    if (cost_bilateral < cost_wiener) {
-      if (bilateral_success)
-        cm->rst_info.restoration_type = RESTORE_BILATERAL;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    } else {
-      if (wiener_success)
-        cm->rst_info.restoration_type = RESTORE_WIENER;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    }
+        &cost_wiener, tile_cost_wiener);
+    switchable_success = search_switchable_restoration(
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        &cm->rst_info, tile_cost_bilateral, tile_cost_wiener, &cost_switchable);
   } else {
-    int blf_filter_level = -1;
     bilateral_success = search_filter_bilateral_level(
-        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &blf_filter_level,
-        cm->rst_info.bilateral_level, &cost_bilateral);
-    lf->filter_level = av1_search_filter_level(
-        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &cost_norestore);
+        src, cpi, method == LPF_PICK_FROM_SUBIMAGE, &lf->filter_level,
+        cm->rst_info.bilateral_level, &cost_bilateral, tile_cost_bilateral);
     wiener_success = search_wiener_filter(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
         cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
-        &cost_wiener);
-    if (cost_bilateral < cost_wiener) {
-      lf->filter_level = blf_filter_level;
-      if (bilateral_success)
-        cm->rst_info.restoration_type = RESTORE_BILATERAL;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    } else {
-      if (wiener_success)
-        cm->rst_info.restoration_type = RESTORE_WIENER;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    }
-    // printf("[%d] Costs %g %g (%d) %g (%d)\n", cm->rst_info.restoration_type,
-    //        cost_norestore, cost_bilateral, lf->filter_level, cost_wiener,
-    //        wiener_success);
+        &cost_wiener, tile_cost_wiener);
+    switchable_success = search_switchable_restoration(
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        &cm->rst_info, tile_cost_bilateral, tile_cost_wiener, &cost_switchable);
   }
-  if (cm->rst_info.restoration_type != RESTORE_BILATERAL) {
-    aom_free(cm->rst_info.bilateral_level);
-    cm->rst_info.bilateral_level = NULL;
+  // Keep the cheapest frame-level scheme (ties go to switchable); fall back
+  // to RESTORE_NONE when that scheme's search did not succeed.
+  if (cost_bilateral < AOMMIN(cost_wiener, cost_switchable)) {
+    if (bilateral_success)
+      cm->rst_info.frame_restoration_type = RESTORE_BILATERAL;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
+  } else if (cost_wiener < AOMMIN(cost_bilateral, cost_switchable)) {
+    if (wiener_success)
+      cm->rst_info.frame_restoration_type = RESTORE_WIENER;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
+  } else {
+    if (switchable_success)
+      cm->rst_info.frame_restoration_type = RESTORE_SWITCHABLE;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
   }
-  if (cm->rst_info.restoration_type != RESTORE_WIENER) {
-    aom_free(cm->rst_info.vfilter);
-    cm->rst_info.vfilter = NULL;
-    aom_free(cm->rst_info.hfilter);
-    cm->rst_info.hfilter = NULL;
-    aom_free(cm->rst_info.wiener_level);
-    cm->rst_info.wiener_level = NULL;
-  }
+  aom_free(tile_cost_bilateral);
+  aom_free(tile_cost_wiener);
 }
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 5660369..b56e3c1 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -146,6 +146,10 @@
     av1_cost_tokens(cpi->intra_filter_cost[i], fc->intra_filter_probs[i],
                     av1_intra_filter_tree);
 #endif  // CONFIG_EXT_INTRA
+#if CONFIG_LOOP_RESTORATION
+  av1_cost_tokens(cpi->switchable_restore_cost, fc->switchable_restore_prob,
+                  av1_switchable_restore_tree);  // rate of each restore type
+#endif  // CONFIG_LOOP_RESTORATION
 }

 void av1_fill_token_costs(av1_coeff_cost *c,