Move loop restoration coefficients to within the frame

Rather than encoding the loop restoration coefficients at the start of
the frame header, this patch moves them to occur just after certain
top-level superblocks.

You might hope that we could just encode coefficients on top-level
superblocks where the top-left corner of the superblock was also the
top-left corner of the loop restoration tile. Unfortunately, this
can't work with the superres experiment, where the loop restoration
tiles don't necessarily line up with the superblocks. Indeed, in
general there can be multiple different loop restoration coefficients
that apply in a given top-level superblock. This patch defines a
function, av1_loop_restoration_corners_in_sb, which yields the
rectangle [rrow0, rrow1) x [rcol0, rcol1) of loop restoration tiles
whose top left corners lie in this top-level superblock.

The total file size should be unchanged by this patch: the bits have
just been moved from the frame header and spread out among the rest of
the frame.

Change-Id: Icf43b0560964a63dea0d2cd801313f04139188d7
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 57bf41a..a5aa4d8 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -670,6 +670,18 @@
 typedef int16_t EobThresholdMD[TX_TYPES][EOB_THRESHOLD_NUM];
 #endif
 
+#if CONFIG_LOOP_RESTORATION
+typedef struct {
+  DECLARE_ALIGNED(16, InterpKernel, vfilter);
+  DECLARE_ALIGNED(16, InterpKernel, hfilter);
+} WienerInfo;
+
+typedef struct {
+  int ep;
+  int xqd[2];
+} SgrprojInfo;
+#endif  // CONFIG_LOOP_RESTORATION
+
 typedef struct macroblockd {
   struct macroblockd_plane plane[MAX_MB_PLANE];
   uint8_t bmode_blocks_wl;
@@ -731,6 +743,11 @@
 #endif
 #endif
 
+#if CONFIG_LOOP_RESTORATION
+  WienerInfo wiener_info[MAX_MB_PLANE];
+  SgrprojInfo sgrproj_info[MAX_MB_PLANE];
+#endif  // CONFIG_LOOP_RESTORATION
+
   // block dimension in the unit of mode_info.
   uint8_t n8_w, n8_h;
 
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index 910ecc3..11ee1d3 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -1426,3 +1426,67 @@
   loop_restoration_rows(frame, cm, start_mi_row, end_mi_row, components_pattern,
                         rsi, dst);
 }
+
+int av1_loop_restoration_corners_in_sb(const struct AV1Common *cm, int plane,
+                                       int mi_row, int mi_col, BLOCK_SIZE bsize,
+                                       int *rcol0, int *rcol1, int *rrow0,
+                                       int *rrow1, int *nhtiles) {
+  assert(rcol0 && rcol1 && rrow0 && rrow1 && nhtiles);
+
+  if (bsize != cm->sb_size) return 0;
+
+#if CONFIG_FRAME_SUPERRES
+  const int frame_w = cm->superres_upscaled_width;
+  const int frame_h = cm->superres_upscaled_height;
+  const int mi_to_px = MI_SIZE * cm->superres_scale_numerator;
+  const int denom = SCALE_DENOMINATOR;
+#else
+  const int frame_w = cm->width;
+  const int frame_h = cm->height;
+  const int mi_to_px = MI_SIZE;
+  const int denom = 1;
+#endif  // CONFIG_FRAME_SUPERRES
+
+  const int ss_frame_w = frame_w >> (plane > 0 && cm->subsampling_x != 0);
+  const int ss_frame_h = frame_h >> (plane > 0 && cm->subsampling_y != 0);
+
+  int rtile_w, rtile_h, nvtiles;
+  av1_get_rest_ntiles(ss_frame_w, ss_frame_h,
+                      cm->rst_info[0].restoration_tilesize, &rtile_w, &rtile_h,
+                      nhtiles, &nvtiles);
+
+  const int rnd_w = rtile_w * denom - 1;
+  const int rnd_h = rtile_h * denom - 1;
+
+  // rcol0/rrow0 should be the first column/row of rtiles that doesn't start
+  // left/below of mi_col/mi_row. For this calculation, we need to round up the
+  // division (if the sb starts at rtile column 10.1, the first matching rtile
+  // has column index 11)
+  *rcol0 = (mi_col * mi_to_px + rnd_w) / (rtile_w * denom);
+  *rrow0 = (mi_row * mi_to_px + rnd_h) / (rtile_h * denom);
+
+  // rcol1/rrow1 is the equivalent calculation, but for the superblock
+  // below-right. There are some slightly strange boundary effects. First, we
+  // need to clamp to nhtiles/nvtiles for the case where it appears there are,
+  // say, 2.4 restoration tiles horizontally. There we need a maximum mi_row1
+  // of 2 because tile 1 gets extended.
+  //
+  // Second, if mi_col1 >= cm->mi_cols then we must manually set *rcol1 to
+  // nhtiles. This is needed whenever the frame's width rounded up to the next
+  // toplevel superblock is smaller than nhtiles * rtile_w. The same logic is
+  // needed for rows.
+  const int mi_row1 = mi_row + mi_size_high[bsize];
+  const int mi_col1 = mi_col + mi_size_wide[bsize];
+
+  if (mi_col1 >= cm->mi_cols)
+    *rcol1 = *nhtiles;
+  else
+    *rcol1 = AOMMIN(*nhtiles, (mi_col1 * mi_to_px + rnd_w) / (rtile_w * denom));
+
+  if (mi_row1 >= cm->mi_rows)
+    *rrow1 = nvtiles;
+  else
+    *rrow1 = AOMMIN(nvtiles, (mi_row1 * mi_to_px + rnd_h) / (rtile_h * denom));
+
+  return *rcol0 < *rcol1 && *rrow0 < *rrow1;
+}
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index b718539..4a9ade9 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -135,10 +135,6 @@
 #if WIENER_FILT_PREC_BITS != 7
 #error "Wiener filter currently only works if WIENER_FILT_PREC_BITS == 7"
 #endif
-typedef struct {
-  DECLARE_ALIGNED(16, InterpKernel, vfilter);
-  DECLARE_ALIGNED(16, InterpKernel, hfilter);
-} WienerInfo;
 
 typedef struct {
 #if USE_HIGHPASS_IN_SGRPROJ
@@ -153,11 +149,6 @@
 } sgr_params_type;
 
 typedef struct {
-  int ep;
-  int xqd[2];
-} SgrprojInfo;
-
-typedef struct {
   int restoration_tilesize;
   RestorationType frame_restoration_type;
   RestorationType *restoration_type;
@@ -261,6 +252,20 @@
                                 RestorationInfo *rsi, int components_pattern,
                                 int partial_frame, YV12_BUFFER_CONFIG *dst);
 void av1_loop_restoration_precal();
+
+// Return 1 iff the block at mi_row, mi_col with size bsize is a
+// top-level superblock containing the top-left corner of at least one
+// loop restoration tile.
+//
+// If the block is a top-level superblock, the function writes to
+// *rcol0, *rcol1, *rrow0, *rrow1. The rectangle of indices given by
+// [*rcol0, *rcol1) x [*rrow0, *rrow1) will point at the set of rtiles
+// whose top left corners lie in the superblock. Note that the set is
+// only nonempty if *rcol0 < *rcol1 and *rrow0 < *rrow1.
+int av1_loop_restoration_corners_in_sb(const struct AV1Common *cm, int plane,
+                                       int mi_row, int mi_col, BLOCK_SIZE bsize,
+                                       int *rcol0, int *rcol1, int *rrow0,
+                                       int *rrow1, int *nhtiles);
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index fb0dbd7..9978a9f 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -86,6 +86,13 @@
 #include "av1/common/cfl.h"
 #endif
 
+#if CONFIG_LOOP_RESTORATION
+static void loop_restoration_read_sb_coeffs(const AV1_COMMON *const cm,
+                                            MACROBLOCKD *xd,
+                                            aom_reader *const r, int plane,
+                                            int rtile_idx);
+#endif
+
 static struct aom_read_bit_buffer *init_read_bit_buffer(
     AV1Decoder *pbi, struct aom_read_bit_buffer *rb, const uint8_t *data,
     const uint8_t *data_end, uint8_t clear_data[MAX_AV1_HEADER_SIZE]);
@@ -2549,6 +2556,21 @@
     }
   }
 #endif  // CONFIG_CDEF
+#if CONFIG_LOOP_RESTORATION
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    int rcol0, rcol1, rrow0, rrow1, nhtiles;
+    if (av1_loop_restoration_corners_in_sb(cm, plane, mi_row, mi_col, bsize,
+                                           &rcol0, &rcol1, &rrow0, &rrow1,
+                                           &nhtiles)) {
+      for (int rrow = rrow0; rrow < rrow1; ++rrow) {
+        for (int rcol = rcol0; rcol < rcol1; ++rcol) {
+          int rtile_idx = rcol + rrow * nhtiles;
+          loop_restoration_read_sb_coeffs(cm, xd, r, plane, rtile_idx);
+        }
+      }
+    }
+  }
+#endif
 }
 
 static void setup_bool_decoder(const uint8_t *data, const uint8_t *data_end,
@@ -2739,97 +2761,43 @@
   memcpy(ref_sgrproj_info, sgrproj_info, sizeof(*sgrproj_info));
 }
 
-static void decode_restoration_for_tile(AV1_COMMON *cm, aom_reader *rb,
-                                        int tile_row, int tile_col,
-                                        const int nrtiles_x[2],
-                                        const int nrtiles_y[2]) {
-  for (int p = 0; p < MAX_MB_PLANE; ++p) {
-    RestorationInfo *rsi = &cm->rst_info[p];
-    if (rsi->frame_restoration_type == RESTORE_NONE) continue;
+static void loop_restoration_read_sb_coeffs(const AV1_COMMON *const cm,
+                                            MACROBLOCKD *xd,
+                                            aom_reader *const r, int plane,
+                                            int rtile_idx) {
+  const RestorationInfo *rsi = cm->rst_info + plane;
+  if (rsi->frame_restoration_type == RESTORE_NONE) return;
 
-    const int tile_width =
-        (p > 0) ? ROUND_POWER_OF_TWO(cm->tile_width, cm->subsampling_x)
-                : cm->tile_width;
-    const int tile_height =
-        (p > 0) ? ROUND_POWER_OF_TWO(cm->tile_height, cm->subsampling_y)
-                : cm->tile_height;
-    const int rtile_size = rsi->restoration_tilesize;
-    const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
-    const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;
+  const int wiener_win = (plane > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
+  WienerInfo *wiener_info = xd->wiener_info + plane;
+  SgrprojInfo *sgrproj_info = xd->sgrproj_info + plane;
 
-    const int rtile_row0 = rtiles_per_tile_y * tile_row;
-    const int rtile_row1 =
-        AOMMIN(rtile_row0 + rtiles_per_tile_y, nrtiles_y[p > 0]);
+  if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
+    assert(plane == 0);
+    rsi->restoration_type[rtile_idx] =
+        aom_read_tree(r, av1_switchable_restore_tree,
+                      cm->fc->switchable_restore_prob, ACCT_STR);
 
-    const int rtile_col0 = rtiles_per_tile_x * tile_col;
-    const int rtile_col1 =
-        AOMMIN(rtile_col0 + rtiles_per_tile_x, nrtiles_x[p > 0]);
-
-    WienerInfo wiener_info;
-    SgrprojInfo sgrproj_info;
-    set_default_wiener(&wiener_info);
-    set_default_sgrproj(&sgrproj_info);
-
-    const int wiener_win = (p > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
-
-    for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
-      for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
-        const int rtile_idx = rtile_row * nrtiles_x[p > 0] + rtile_col;
-        if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
-          assert(p == 0);
-          rsi->restoration_type[rtile_idx] =
-              aom_read_tree(rb, av1_switchable_restore_tree,
-                            cm->fc->switchable_restore_prob, ACCT_STR);
-          if (rsi->restoration_type[rtile_idx] == RESTORE_WIENER) {
-            read_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx],
-                               &wiener_info, rb);
-          } else if (rsi->restoration_type[rtile_idx] == RESTORE_SGRPROJ) {
-            read_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], &sgrproj_info,
-                                rb);
-          }
-        } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
-          if (aom_read(rb, RESTORE_NONE_WIENER_PROB, ACCT_STR)) {
-            rsi->restoration_type[rtile_idx] = RESTORE_WIENER;
-            read_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx],
-                               &wiener_info, rb);
-          } else {
-            rsi->restoration_type[rtile_idx] = RESTORE_NONE;
-          }
-        } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
-          if (aom_read(rb, RESTORE_NONE_SGRPROJ_PROB, ACCT_STR)) {
-            rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
-            read_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], &sgrproj_info,
-                                rb);
-          } else {
-            rsi->restoration_type[rtile_idx] = RESTORE_NONE;
-          }
-        }
-      }
+    if (rsi->restoration_type[rtile_idx] == RESTORE_WIENER) {
+      read_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx], wiener_info,
+                         r);
+    } else if (rsi->restoration_type[rtile_idx] == RESTORE_SGRPROJ) {
+      read_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], sgrproj_info, r);
     }
-  }
-}
-
-static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
-#if CONFIG_FRAME_SUPERRES
-  const int width = cm->superres_upscaled_width;
-  const int height = cm->superres_upscaled_height;
-#else
-  const int width = cm->width;
-  const int height = cm->height;
-#endif  // CONFIG_FRAME_SUPERRES
-
-  int nrtiles_x[2], nrtiles_y[2];
-  av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize, NULL,
-                      NULL, &nrtiles_x[0], &nrtiles_y[0]);
-  av1_get_rest_ntiles(ROUND_POWER_OF_TWO(width, cm->subsampling_x),
-                      ROUND_POWER_OF_TWO(height, cm->subsampling_y),
-                      cm->rst_info[1].restoration_tilesize, NULL, NULL,
-                      &nrtiles_x[1], &nrtiles_y[1]);
-
-  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
-    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
-      decode_restoration_for_tile(cm, rb, tile_row, tile_col, nrtiles_x,
-                                  nrtiles_y);
+  } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
+    if (aom_read(r, RESTORE_NONE_WIENER_PROB, ACCT_STR)) {
+      rsi->restoration_type[rtile_idx] = RESTORE_WIENER;
+      read_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx], wiener_info,
+                         r);
+    } else {
+      rsi->restoration_type[rtile_idx] = RESTORE_NONE;
+    }
+  } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
+    if (aom_read(r, RESTORE_NONE_SGRPROJ_PROB, ACCT_STR)) {
+      rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
+      read_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], sgrproj_info, r);
+    } else {
+      rsi->restoration_type[rtile_idx] = RESTORE_NONE;
     }
   }
 }
@@ -3740,6 +3708,12 @@
 #else
       av1_zero_above_context(cm, tile_info.mi_col_start, tile_info.mi_col_end);
 #endif
+#if CONFIG_LOOP_RESTORATION
+      for (int p = 0; p < MAX_MB_PLANE; ++p) {
+        set_default_wiener(td->xd.wiener_info + p);
+        set_default_sgrproj(td->xd.sgrproj_info + p);
+      }
+#endif  // CONFIG_LOOP_RESTORATION
 
 #if CONFIG_LOOPFILTERING_ACROSS_TILES
       dec_setup_across_tile_boundary_info(cm, &tile_info);
@@ -4902,15 +4876,6 @@
     aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
                        "Failed to allocate bool decoder 0");
 
-#if CONFIG_LOOP_RESTORATION
-  if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
-      cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
-      cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
-    av1_alloc_restoration_buffers(cm);
-    decode_restoration(cm, &r);
-  }
-#endif
-
 #if CONFIG_RECT_TX_EXT && (CONFIG_EXT_TX || CONFIG_VAR_TX)
   if (cm->tx_mode == TX_MODE_SELECT)
     av1_diff_update_prob(&r, &fc->quarter_tx_size_prob, ACCT_STR);
@@ -5303,6 +5268,14 @@
     aom_internal_error(&cm->error, AOM_CODEC_CORRUPT_FRAME,
                        "Decode failed. Frame data header is corrupted.");
 
+#if CONFIG_LOOP_RESTORATION
+  if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
+      cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
+    av1_alloc_restoration_buffers(cm);
+  }
+#endif
+
 #if CONFIG_LOOPFILTER_LEVEL
   if ((cm->lf.filter_level[0] || cm->lf.filter_level[1]) &&
       !cm->skip_loop_filter) {
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index be8073c..da7951d 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -102,6 +102,10 @@
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 #if CONFIG_LOOP_RESTORATION
 static struct av1_token switchable_restore_encodings[RESTORE_SWITCHABLE_TYPES];
+static void loop_restoration_write_sb_coeffs(const AV1_COMMON *const cm,
+                                             MACROBLOCKD *xd,
+                                             aom_writer *const w, int plane,
+                                             int rtile_idx);
 #endif  // CONFIG_LOOP_RESTORATION
 static void write_uncompressed_header(AV1_COMP *cpi,
                                       struct aom_write_bit_buffer *wb);
@@ -3126,6 +3130,21 @@
     }
   }
 #endif
+#if CONFIG_LOOP_RESTORATION
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    int rcol0, rcol1, rrow0, rrow1, nhtiles;
+    if (av1_loop_restoration_corners_in_sb(cm, plane, mi_row, mi_col, bsize,
+                                           &rcol0, &rcol1, &rrow0, &rrow1,
+                                           &nhtiles)) {
+      for (int rrow = rrow0; rrow < rrow1; ++rrow) {
+        for (int rcol = rcol0; rcol < rcol1; ++rcol) {
+          int rtile_idx = rcol + rrow * nhtiles;
+          loop_restoration_write_sb_coeffs(cm, xd, w, plane, rtile_idx);
+        }
+      }
+    }
+  }
+#endif
 }
 
 static void write_modes(AV1_COMP *const cpi, const TileInfo *const tile,
@@ -3307,98 +3326,44 @@
   memcpy(ref_sgrproj_info, sgrproj_info, sizeof(*sgrproj_info));
 }
 
-static void encode_restoration_for_tile(AV1_COMMON *cm, aom_writer *wb,
-                                        int tile_row, int tile_col,
-                                        const int nrtiles_x[2],
-                                        const int nrtiles_y[2]) {
-  for (int p = 0; p < MAX_MB_PLANE; ++p) {
-    RestorationInfo *rsi = &cm->rst_info[p];
-    if (rsi->frame_restoration_type == RESTORE_NONE) continue;
+static void loop_restoration_write_sb_coeffs(const AV1_COMMON *const cm,
+                                             MACROBLOCKD *xd,
+                                             aom_writer *const w, int plane,
+                                             int rtile_idx) {
+  const RestorationInfo *rsi = cm->rst_info + plane;
+  if (rsi->frame_restoration_type == RESTORE_NONE) return;
 
-    const int tile_width =
-        (p > 0) ? ROUND_POWER_OF_TWO(cm->tile_width, cm->subsampling_x)
-                : cm->tile_width;
-    const int tile_height =
-        (p > 0) ? ROUND_POWER_OF_TWO(cm->tile_height, cm->subsampling_y)
-                : cm->tile_height;
-    const int rtile_size = rsi->restoration_tilesize;
-    const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
-    const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;
+  const int wiener_win = (plane > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
+  WienerInfo *wiener_info = xd->wiener_info + plane;
+  SgrprojInfo *sgrproj_info = xd->sgrproj_info + plane;
 
-    const int rtile_row0 = rtiles_per_tile_y * tile_row;
-    const int rtile_row1 =
-        AOMMIN(rtile_row0 + rtiles_per_tile_y, nrtiles_y[p > 0]);
-
-    const int rtile_col0 = rtiles_per_tile_x * tile_col;
-    const int rtile_col1 =
-        AOMMIN(rtile_col0 + rtiles_per_tile_x, nrtiles_x[p > 0]);
-
-    WienerInfo wiener_info;
-    SgrprojInfo sgrproj_info;
-    set_default_wiener(&wiener_info);
-    set_default_sgrproj(&sgrproj_info);
-
-    const int wiener_win = (p > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
-
-    for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
-      for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
-        const int rtile_idx = rtile_row * nrtiles_x[p > 0] + rtile_col;
-        if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
-          assert(p == 0);
-          av1_write_token(
-              wb, av1_switchable_restore_tree, cm->fc->switchable_restore_prob,
-              &switchable_restore_encodings[rsi->restoration_type[rtile_idx]]);
-          if (rsi->restoration_type[rtile_idx] == RESTORE_WIENER) {
-            write_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx],
-                                &wiener_info, wb);
-          } else if (rsi->restoration_type[rtile_idx] == RESTORE_SGRPROJ) {
-            write_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], &sgrproj_info,
-                                 wb);
-          }
-        } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
-          aom_write(wb, rsi->restoration_type[rtile_idx] != RESTORE_NONE,
-                    RESTORE_NONE_WIENER_PROB);
-          if (rsi->restoration_type[rtile_idx] != RESTORE_NONE) {
-            write_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx],
-                                &wiener_info, wb);
-          }
-        } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
-          aom_write(wb, rsi->restoration_type[rtile_idx] != RESTORE_NONE,
-                    RESTORE_NONE_SGRPROJ_PROB);
-          if (rsi->restoration_type[rtile_idx] != RESTORE_NONE) {
-            write_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], &sgrproj_info,
-                                 wb);
-          }
-        }
-      }
+  if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
+    assert(plane == 0);
+    av1_write_token(
+        w, av1_switchable_restore_tree, cm->fc->switchable_restore_prob,
+        &switchable_restore_encodings[rsi->restoration_type[rtile_idx]]);
+    if (rsi->restoration_type[rtile_idx] == RESTORE_WIENER) {
+      write_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx], wiener_info,
+                          w);
+    } else if (rsi->restoration_type[rtile_idx] == RESTORE_SGRPROJ) {
+      write_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], sgrproj_info, w);
+    }
+  } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
+    aom_write(w, rsi->restoration_type[rtile_idx] != RESTORE_NONE,
+              RESTORE_NONE_WIENER_PROB);
+    if (rsi->restoration_type[rtile_idx] != RESTORE_NONE) {
+      write_wiener_filter(wiener_win, &rsi->wiener_info[rtile_idx], wiener_info,
+                          w);
+    }
+  } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
+    aom_write(w, rsi->restoration_type[rtile_idx] != RESTORE_NONE,
+              RESTORE_NONE_SGRPROJ_PROB);
+    if (rsi->restoration_type[rtile_idx] != RESTORE_NONE) {
+      write_sgrproj_filter(&rsi->sgrproj_info[rtile_idx], sgrproj_info, w);
     }
   }
 }
 
-static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
-#if CONFIG_FRAME_SUPERRES
-  const int width = cm->superres_upscaled_width;
-  const int height = cm->superres_upscaled_height;
-#else
-  const int width = cm->width;
-  const int height = cm->height;
-#endif  // CONFIG_FRAME_SUPERRES
-
-  int nrtiles_x[2], nrtiles_y[2];
-  av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize, NULL,
-                      NULL, &nrtiles_x[0], &nrtiles_y[0]);
-  av1_get_rest_ntiles(ROUND_POWER_OF_TWO(width, cm->subsampling_x),
-                      ROUND_POWER_OF_TWO(height, cm->subsampling_y),
-                      cm->rst_info[1].restoration_tilesize, NULL, NULL,
-                      &nrtiles_x[1], &nrtiles_y[1]);
-
-  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
-    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
-      encode_restoration_for_tile(cm, wb, tile_row, tile_col, nrtiles_x,
-                                  nrtiles_y);
-    }
-  }
-}
 #endif  // CONFIG_LOOP_RESTORATION
 
 static void encode_loopfilter(AV1_COMMON *cm, struct aom_write_bit_buffer *wb) {
@@ -4009,6 +3974,13 @@
 #if CONFIG_ANS
         mode_bc.size = 1 << cpi->common.ans_window_size_log2;
 #endif  // CONFIG_ANS
+#if CONFIG_LOOP_RESTORATION
+        for (int p = 0; p < MAX_MB_PLANE; ++p) {
+          set_default_wiener(cpi->td.mb.e_mbd.wiener_info + p);
+          set_default_sgrproj(cpi->td.mb.e_mbd.sgrproj_info + p);
+        }
+#endif  // CONFIG_LOOP_RESTORATION
+
         aom_start_encode(&mode_bc, dst + total_size);
         write_modes(cpi, &tile_info, &mode_bc, &tok, tok_end);
 #if !CONFIG_LV_MAP
@@ -4640,9 +4612,6 @@
 #endif
   aom_start_encode(header_bc, data);
 
-#if CONFIG_LOOP_RESTORATION
-  encode_restoration(cm, header_bc);
-#endif  // CONFIG_LOOP_RESTORATION
 #if CONFIG_RECT_TX_EXT && (CONFIG_EXT_TX || CONFIG_VAR_TX)
   if (cm->tx_mode == TX_MODE_SELECT)
     av1_cond_prob_diff_update(header_bc, &cm->fc->quarter_tx_size_prob,