Misc cleanups and enhancements on loop restoration

Includes:
Some cleanups/refactoring
Better buffer management.
Some preps for future chrominance restoration.

Change-Id: Ia264b8989b5f4a53c0764ed3e8258ddc212723fc
diff --git a/aom_dsp/psnr.c b/aom_dsp/psnr.c
index 93899ba..4414393 100644
--- a/aom_dsp/psnr.c
+++ b/aom_dsp/psnr.c
@@ -194,6 +194,14 @@
                  a->y_crop_width, a->y_crop_height);
 }
 
+int64_t aom_get_u_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b, int hstart, int width,
+                           int vstart, int height) {
+  return get_sse(a->u_buffer + vstart * a->uv_stride + hstart, a->uv_stride,
+                 b->u_buffer + vstart * b->uv_stride + hstart, b->uv_stride,
+                 width, height);
+}
+
 int64_t aom_get_u_sse(const YV12_BUFFER_CONFIG *a,
                       const YV12_BUFFER_CONFIG *b) {
   assert(a->uv_crop_width == b->uv_crop_width);
@@ -203,6 +211,14 @@
                  a->uv_crop_width, a->uv_crop_height);
 }
 
+int64_t aom_get_v_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b, int hstart, int width,
+                           int vstart, int height) {
+  return get_sse(a->v_buffer + vstart * a->uv_stride + hstart, a->uv_stride,
+                 b->v_buffer + vstart * b->uv_stride + hstart, b->uv_stride,
+                 width, height);
+}
+
 int64_t aom_get_v_sse(const YV12_BUFFER_CONFIG *a,
                       const YV12_BUFFER_CONFIG *b) {
   assert(a->uv_crop_width == b->uv_crop_width);
@@ -232,6 +248,15 @@
                         a->y_crop_width, a->y_crop_height);
 }
 
+int64_t aom_highbd_get_u_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b, int hstart,
+                                  int width, int vstart, int height) {
+  return highbd_get_sse(a->u_buffer + vstart * a->uv_stride + hstart,
+                        a->uv_stride,
+                        b->u_buffer + vstart * b->uv_stride + hstart,
+                        b->uv_stride, width, height);
+}
+
 int64_t aom_highbd_get_u_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b) {
   assert(a->uv_crop_width == b->uv_crop_width);
@@ -243,6 +268,15 @@
                         a->uv_crop_width, a->uv_crop_height);
 }
 
+int64_t aom_highbd_get_v_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b, int hstart,
+                                  int width, int vstart, int height) {
+  return highbd_get_sse(a->v_buffer + vstart * a->uv_stride + hstart,
+                        a->uv_stride,
+                        b->v_buffer + vstart * b->uv_stride + hstart,
+                        b->uv_stride, width, height);
+}
+
 int64_t aom_highbd_get_v_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b) {
   assert(a->uv_crop_width == b->uv_crop_width);
diff --git a/aom_dsp/psnr.h b/aom_dsp/psnr.h
index 1cd6b19..71432fe 100644
--- a/aom_dsp/psnr.h
+++ b/aom_dsp/psnr.h
@@ -39,7 +39,13 @@
                            const YV12_BUFFER_CONFIG *b, int hstart, int width,
                            int vstart, int height);
 int64_t aom_get_y_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
+int64_t aom_get_u_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b, int hstart, int width,
+                           int vstart, int height);
 int64_t aom_get_u_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
+int64_t aom_get_v_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b, int hstart, int width,
+                           int vstart, int height);
 int64_t aom_get_v_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
 #if CONFIG_AOM_HIGHBITDEPTH
 int64_t aom_highbd_get_y_sse_part(const YV12_BUFFER_CONFIG *a,
@@ -47,8 +53,14 @@
                                   int width, int vstart, int height);
 int64_t aom_highbd_get_y_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b);
+int64_t aom_highbd_get_u_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b, int hstart,
+                                  int width, int vstart, int height);
 int64_t aom_highbd_get_u_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b);
+int64_t aom_highbd_get_v_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b, int hstart,
+                                  int width, int vstart, int height);
 int64_t aom_highbd_get_v_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b);
 void aom_calc_highbd_psnr(const YV12_BUFFER_CONFIG *a,
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index 86051c8..31c1bf7 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -87,15 +87,17 @@
 }
 
 #if CONFIG_LOOP_RESTORATION
+void av1_alloc_restoration_buffers(AV1_COMMON *cm) {
+  av1_alloc_restoration_struct(&cm->rst_info, cm->width, cm->height);
+  cm->rst_internal.tmpbuf =
+      (uint8_t *)aom_realloc(cm->rst_internal.tmpbuf, RESTORATION_TMPBUF_SIZE);
+  assert(cm->rst_internal.tmpbuf != NULL);
+}
+
 void av1_free_restoration_buffers(AV1_COMMON *cm) {
-  aom_free(cm->rst_info.restoration_type);
-  cm->rst_info.restoration_type = NULL;
-  aom_free(cm->rst_info.wiener_info);
-  cm->rst_info.wiener_info = NULL;
-  aom_free(cm->rst_info.sgrproj_info);
-  cm->rst_info.sgrproj_info = NULL;
-  aom_free(cm->rst_info.domaintxfmrf_info);
-  cm->rst_info.domaintxfmrf_info = NULL;
+  av1_free_restoration_struct(&cm->rst_info);
+  aom_free(cm->rst_internal.tmpbuf);
+  cm->rst_internal.tmpbuf = NULL;
 }
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/common/alloccommon.h b/av1/common/alloccommon.h
index 0a0c38c..51863cd 100644
--- a/av1/common/alloccommon.h
+++ b/av1/common/alloccommon.h
@@ -29,6 +29,7 @@
 
 void av1_free_ref_frame_buffers(struct BufferPool *pool);
 #if CONFIG_LOOP_RESTORATION
+void av1_alloc_restoration_buffers(struct AV1Common *cm);
 void av1_free_restoration_buffers(struct AV1Common *cm);
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index c4f79c3..e905a8c 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -50,6 +50,35 @@
                                          int dst_stride);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
+int av1_alloc_restoration_struct(RestorationInfo *rst_info, int width,
+                                 int height) {
+  const int ntiles = av1_get_rest_ntiles(width, height, NULL, NULL, NULL, NULL);
+  rst_info->restoration_type = (RestorationType *)aom_realloc(
+      rst_info->restoration_type, sizeof(*rst_info->restoration_type) * ntiles);
+  rst_info->wiener_info = (WienerInfo *)aom_realloc(
+      rst_info->wiener_info, sizeof(*rst_info->wiener_info) * ntiles);
+  assert(rst_info->wiener_info != NULL);
+  rst_info->sgrproj_info = (SgrprojInfo *)aom_realloc(
+      rst_info->sgrproj_info, sizeof(*rst_info->sgrproj_info) * ntiles);
+  assert(rst_info->sgrproj_info != NULL);
+  rst_info->domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
+      rst_info->domaintxfmrf_info,
+      sizeof(*rst_info->domaintxfmrf_info) * ntiles);
+  assert(rst_info->domaintxfmrf_info != NULL);
+  return ntiles;
+}
+
+void av1_free_restoration_struct(RestorationInfo *rst_info) {
+  aom_free(rst_info->restoration_type);
+  rst_info->restoration_type = NULL;
+  aom_free(rst_info->wiener_info);
+  rst_info->wiener_info = NULL;
+  aom_free(rst_info->sgrproj_info);
+  rst_info->sgrproj_info = NULL;
+  aom_free(rst_info->domaintxfmrf_info);
+  rst_info->domaintxfmrf_info = NULL;
+}
+
 static void GenDomainTxfmRFVtable() {
   int i, j;
   const double sigma_s = sqrt(2.0);
@@ -520,15 +549,16 @@
 
 static void loop_sgrproj_filter_tile(uint8_t *data, int tile_idx, int width,
                                      int height, int stride,
-                                     RestorationInternal *rst, void *tmpbuf,
-                                     uint8_t *dst, int dst_stride) {
+                                     RestorationInternal *rst, uint8_t *dst,
+                                     int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int i, j;
   int h_start, h_end, v_start, v_end;
   uint8_t *data_p, *dst_p;
-  int64_t *dat = (int64_t *)tmpbuf;
-  tmpbuf = (uint8_t *)tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
+  int64_t *dat = (int64_t *)rst->tmpbuf;
+  uint8_t *tmpbuf =
+      (uint8_t *)rst->tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
 
   if (rst->rsi->sgrproj_info[tile_idx].level == 0) {
     loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -561,12 +591,10 @@
                                 int stride, RestorationInternal *rst,
                                 uint8_t *dst, int dst_stride) {
   int tile_idx;
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, tmpbuf,
-                             dst, dst_stride);
+    loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, dst,
+                             dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void apply_domaintxfmrf_hor(int iter, int param, uint8_t *img, int width,
@@ -664,11 +692,11 @@
 static void loop_domaintxfmrf_filter_tile(uint8_t *data, int tile_idx,
                                           int width, int height, int stride,
                                           RestorationInternal *rst,
-                                          uint8_t *dst, int dst_stride,
-                                          int32_t *tmpbuf) {
+                                          uint8_t *dst, int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int h_start, h_end, v_start, v_end;
+  int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
   if (rst->rsi->domaintxfmrf_info[tile_idx].level == 0) {
     loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -688,23 +716,16 @@
                                      int stride, RestorationInternal *rst,
                                      uint8_t *dst, int dst_stride) {
   int tile_idx;
-  int32_t *tmpbuf =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
-
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride, rst,
-                                  dst, dst_stride, tmpbuf);
+                                  dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void loop_switchable_filter(uint8_t *data, int width, int height,
                                    int stride, RestorationInternal *rst,
                                    uint8_t *dst, int dst_stride) {
   int tile_idx;
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
-  int32_t *tmpbuf32 =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf32));
   extend_frame(data, width, height, stride);
   copy_border(data, width, height, stride, dst, dst_stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -715,15 +736,13 @@
       loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst, dst,
                               dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
-      loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst,
-                               tmpbuf, dst, dst_stride);
+      loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, dst,
+                               dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
       loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride, rst,
-                                    dst, dst_stride, tmpbuf32);
+                                    dst, dst_stride);
     }
   }
-  aom_free(tmpbuf);
-  aom_free(tmpbuf32);
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -830,7 +849,7 @@
                                       int dst_stride) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-  int tile_idx, i;
+  int tile_idx;
   copy_border_highbd(data, width, height, stride, dst, dst_stride);
   extend_frame_highbd(data, width, height, stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -842,15 +861,16 @@
 static void loop_sgrproj_filter_tile_highbd(uint16_t *data, int tile_idx,
                                             int width, int height, int stride,
                                             RestorationInternal *rst,
-                                            int bit_depth, void *tmpbuf,
-                                            uint16_t *dst, int dst_stride) {
+                                            int bit_depth, uint16_t *dst,
+                                            int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int i, j;
   int h_start, h_end, v_start, v_end;
   uint16_t *data_p, *dst_p;
-  int64_t *dat = (int64_t *)tmpbuf;
-  tmpbuf = (uint8_t *)tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
+  int64_t *dat = (int64_t *)rst->tmpbuf;
+  uint8_t *tmpbuf =
+      (uint8_t *)rst->tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
 
   if (rst->rsi->sgrproj_info[tile_idx].level == 0) {
     loop_copy_tile_highbd(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -885,13 +905,11 @@
                                        int dst_stride) {
   int tile_idx;
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_sgrproj_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
-                                    bit_depth, tmpbuf, dst, dst_stride);
+                                    bit_depth, dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void apply_domaintxfmrf_hor_highbd(int iter, int param, uint16_t *img,
@@ -989,11 +1007,11 @@
 
 static void loop_domaintxfmrf_filter_tile_highbd(
     uint16_t *data, int tile_idx, int width, int height, int stride,
-    RestorationInternal *rst, int bit_depth, uint16_t *dst, int dst_stride,
-    int32_t *tmpbuf) {
+    RestorationInternal *rst, int bit_depth, uint16_t *dst, int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int h_start, h_end, v_start, v_end;
+  int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
   if (rst->rsi->domaintxfmrf_info[tile_idx].level == 0) {
     loop_copy_tile_highbd(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -1015,16 +1033,12 @@
                                             int bit_depth, uint8_t *dst8,
                                             int dst_stride) {
   int tile_idx;
-  int32_t *tmpbuf =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_domaintxfmrf_filter_tile_highbd(data, tile_idx, width, height, stride,
-                                         rst, bit_depth, dst, dst_stride,
-                                         tmpbuf);
+                                         rst, bit_depth, dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void loop_switchable_filter_highbd(uint8_t *data8, int width, int height,
@@ -1032,11 +1046,8 @@
                                           int bit_depth, uint8_t *dst8,
                                           int dst_stride) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
-  int32_t *tmpbuf32 =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf32));
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-  int i, tile_idx;
+  int tile_idx;
   copy_border_highbd(data, width, height, stride, dst, dst_stride);
   extend_frame_highbd(data, width, height, stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -1048,15 +1059,13 @@
                                      bit_depth, dst, dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
       loop_sgrproj_filter_tile_highbd(data, tile_idx, width, height, stride,
-                                      rst, bit_depth, tmpbuf, dst, dst_stride);
+                                      rst, bit_depth, dst, dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
       loop_domaintxfmrf_filter_tile_highbd(data, tile_idx, width, height,
                                            stride, rst, bit_depth, dst,
-                                           dst_stride, tmpbuf32);
+                                           dst_stride);
     }
   }
-  aom_free(tmpbuf);
-  aom_free(tmpbuf32);
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 5773c77..2274fa8 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -38,10 +38,15 @@
 #define DOMAINTXFMRF_VTABLE_PREC (1 << DOMAINTXFMRF_VTABLE_PRECBITS)
 #define DOMAINTXFMRF_MULT \
   sqrt(((1 << (DOMAINTXFMRF_ITERS * 2)) - 1) * 2.0 / 3.0)
-#define DOMAINTXFMRF_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX)
+// A single 32 bit buffer needed for the filter
+#define DOMAINTXFMRF_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * sizeof(int32_t))
 #define DOMAINTXFMRF_BITS (DOMAINTXFMRF_PARAMS_BITS)
 
-#define SGRPROJ_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * 6 * 8)
+// 6 highprecision 64-bit buffers needed for the filter:
+// 1 for the degraded frame, 2 for the restored versions and
+// 3 for each restoration operation
+// TODO(debargha): Explore if we can use 32-bit buffers
+#define SGRPROJ_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * 6 * sizeof(int64_t))
 #define SGRPROJ_PARAMS_BITS 3
 #define SGRPROJ_PARAMS (1 << SGRPROJ_PARAMS_BITS)
 
@@ -60,6 +65,9 @@
 
 #define SGRPROJ_BITS (SGRPROJ_PRJ_BITS * 2 + SGRPROJ_PARAMS_BITS)
 
+// Max of SGRPROJ_TMPBUF_SIZE and DOMAINTXFMRF_TMPBUF_SIZE
+#define RESTORATION_TMPBUF_SIZE (SGRPROJ_TMPBUF_SIZE)
+
 #define RESTORATION_HALFWIN 3
 #define RESTORATION_HALFWIN1 (RESTORATION_HALFWIN + 1)
 #define RESTORATION_WIN (2 * RESTORATION_HALFWIN + 1)
@@ -128,6 +136,7 @@
   int ntiles;
   int tile_width, tile_height;
   int nhtiles, nvtiles;
+  uint8_t *tmpbuf;
 } RestorationInternal;
 
 static INLINE int get_rest_tilesize(int width, int height) {
@@ -189,6 +198,10 @@
 
 extern const sgr_params_type sgr_params[SGRPROJ_PARAMS];
 
+int av1_alloc_restoration_struct(RestorationInfo *rst_info, int width,
+                                 int height);
+void av1_free_restoration_struct(RestorationInfo *rst_info);
+
 void av1_selfguided_restoration(int64_t *dgd, int width, int height, int stride,
                                 int bit_depth, int r, int eps, void *tmpbuf);
 void av1_domaintxfmrf_restoration(uint8_t *dgd, int width, int height,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a64ba14..45e82ea 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2319,18 +2319,7 @@
   const int ntiles =
       av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
   if (rsi->frame_restoration_type != RESTORE_NONE) {
-    rsi->restoration_type = (RestorationType *)aom_realloc(
-        rsi->restoration_type, sizeof(*rsi->restoration_type) * ntiles);
     if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
-      rsi->wiener_info = (WienerInfo *)aom_realloc(
-          rsi->wiener_info, sizeof(*rsi->wiener_info) * ntiles);
-      assert(rsi->wiener_info != NULL);
-      rsi->sgrproj_info = (SgrprojInfo *)aom_realloc(
-          rsi->sgrproj_info, sizeof(*rsi->sgrproj_info) * ntiles);
-      assert(rsi->sgrproj_info != NULL);
-      rsi->domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
-          rsi->domaintxfmrf_info, sizeof(*rsi->domaintxfmrf_info) * ntiles);
-      assert(rsi->domaintxfmrf_info != NULL);
       for (i = 0; i < ntiles; ++i) {
         rsi->restoration_type[i] =
             aom_read_tree(rb, av1_switchable_restore_tree,
@@ -2347,9 +2336,6 @@
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
-      rsi->wiener_info = (WienerInfo *)aom_realloc(
-          rsi->wiener_info, sizeof(*rsi->wiener_info) * ntiles);
-      assert(rsi->wiener_info != NULL);
       for (i = 0; i < ntiles; ++i) {
         if (aom_read(rb, RESTORE_NONE_WIENER_PROB, ACCT_STR)) {
           rsi->restoration_type[i] = RESTORE_WIENER;
@@ -2361,9 +2347,6 @@
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
-      rsi->sgrproj_info = (SgrprojInfo *)aom_realloc(
-          rsi->sgrproj_info, sizeof(*rsi->sgrproj_info) * ntiles);
-      assert(rsi->sgrproj_info != NULL);
       for (i = 0; i < ntiles; ++i) {
         if (aom_read(rb, RESTORE_NONE_SGRPROJ_PROB, ACCT_STR)) {
           rsi->restoration_type[i] = RESTORE_SGRPROJ;
@@ -2375,9 +2358,6 @@
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_DOMAINTXFMRF) {
-      rsi->domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
-          rsi->domaintxfmrf_info, sizeof(*rsi->domaintxfmrf_info) * ntiles);
-      assert(rsi->domaintxfmrf_info != NULL);
       for (i = 0; i < ntiles; ++i) {
         if (aom_read(rb, RESTORE_NONE_DOMAINTXFMRF_PROB, ACCT_STR)) {
           rsi->restoration_type[i] = RESTORE_DOMAINTXFMRF;
@@ -3986,6 +3966,7 @@
   setup_clpf(pbi, rb);
 #endif
 #if CONFIG_LOOP_RESTORATION
+  av1_alloc_restoration_buffers(cm);
   decode_restoration_mode(cm, rb);
 #endif  // CONFIG_LOOP_RESTORATION
   setup_quantization(cm, rb);
@@ -4572,11 +4553,8 @@
     *p_data_end = decode_tiles(pbi, data + first_partition_size, data_end);
   }
 #if CONFIG_LOOP_RESTORATION
-  if (cm->rst_info.restoration_type != RESTORE_NONE) {
-    av1_loop_restoration_init(&cm->rst_internal, &cm->rst_info,
-                              cm->frame_type == KEY_FRAME, cm->width,
-                              cm->height);
-    av1_loop_restoration_rows(new_fb, cm, 0, cm->mi_rows, 0, NULL);
+  if (cm->rst_info.frame_restoration_type != RESTORE_NONE) {
+    av1_loop_restoration_frame(new_fb, cm, &cm->rst_info, 0, 0, NULL);
   }
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 3b95405..621f1f5 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -452,6 +452,7 @@
   aom_free_frame_buffer(&cpi->last_frame_uf);
 #if CONFIG_LOOP_RESTORATION
   aom_free_frame_buffer(&cpi->last_frame_db);
+  aom_free(cpi->highprec_srcbuf);
   av1_free_restoration_buffers(cm);
 #endif  // CONFIG_LOOP_RESTORATION
   aom_free_frame_buffer(&cpi->scaled_source);
@@ -726,6 +727,11 @@
                                NULL, NULL))
     aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
                        "Failed to allocate last frame deblocked buffer");
+  cpi->highprec_srcbuf = (uint8_t *)aom_realloc(
+      cpi->highprec_srcbuf, RESTORATION_TILEPELS_MAX * sizeof(int64_t));
+  if (!cpi->highprec_srcbuf)
+    aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
+                       "Failed to allocate highprec srcbuf for restoration");
 #endif  // CONFIG_LOOP_RESTORATION
 
   if (aom_realloc_frame_buffer(&cpi->scaled_source, cm->width, cm->height,
@@ -3428,14 +3434,11 @@
   }
 #endif
 #if CONFIG_LOOP_RESTORATION
-  if (cm->rst_info.restoration_type != RESTORE_NONE) {
-    av1_loop_restoration_init(&cm->rst_internal, &cm->rst_info,
-                              cm->frame_type == KEY_FRAME, cm->width,
-                              cm->height);
-    av1_loop_restoration_rows(cm->frame_to_show, cm, 0, cm->mi_rows, 0, NULL);
+  if (cm->rst_info.frame_restoration_type != RESTORE_NONE) {
+    av1_loop_restoration_frame(cm->frame_to_show, cm, &cm->rst_info, 0, 0,
+                               NULL);
   }
 #endif  // CONFIG_LOOP_RESTORATION
-
   aom_extend_frame_inner_borders(cm->frame_to_show);
 }
 
@@ -3830,6 +3833,9 @@
 
   alloc_util_frame_buffers(cpi);
   init_motion_estimation(cpi);
+#if CONFIG_LOOP_RESTORATION
+  av1_alloc_restoration_buffers(cm);
+#endif  // CONFIG_LOOP_RESTORATION
 
   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
     RefBuffer *const ref_buf = &cm->frame_refs[ref_frame - LAST_FRAME];
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index ed4f694..ea591e2 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -400,6 +400,7 @@
   YV12_BUFFER_CONFIG last_frame_uf;
 #if CONFIG_LOOP_RESTORATION
   YV12_BUFFER_CONFIG last_frame_db;
+  uint8_t *highprec_srcbuf;
 #endif  // CONFIG_LOOP_RESTORATION
 
   // Ambient reconstruction err target for force key frames
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 3b25efa..9222906 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -40,25 +40,39 @@
 static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     const AV1_COMMON *cm, int h_start,
-                                    int width, int v_start, int height) {
+                                    int width, int v_start, int height,
+                                    int y_only) {
   int64_t filt_err;
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
     filt_err =
         aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
-  } else {
-    filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
+    if (!y_only) {
+      filt_err += aom_highbd_get_u_sse_part(
+          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
+          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+      filt_err += aom_highbd_get_v_sse_part(
+          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
+          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    }
+    return filt_err;
   }
-#else
-  (void)cm;
-  filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
+  filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
+  if (!y_only) {
+    filt_err += aom_get_u_sse_part(
+        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
+        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    filt_err += aom_get_u_sse_part(
+        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
+        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+  }
   return filt_err;
 }
 
 static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
-                                    int partial_frame, int tile_idx,
+                                    int y_only, int partial_frame, int tile_idx,
                                     int subtile_idx, int subtile_bits,
                                     YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
@@ -69,36 +83,41 @@
                                          &tile_height, &nhtiles, &nvtiles);
   (void)ntiles;
 
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame,
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
                              dst_frame);
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                            nvtiles, tile_width, tile_height, cm->width,
                            cm->height, 0, 0, &h_start, &h_end, &v_start,
                            &v_end);
   filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
-                                  v_start, v_end - v_start);
+                                  v_start, v_end - v_start, y_only);
 
   return filt_err;
 }
 
 static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *const cpi, RestorationInfo *rsi,
-                                     int partial_frame,
+                                     int y_only, int partial_frame,
                                      YV12_BUFFER_CONFIG *dst_frame) {
   AV1_COMMON *const cm = &cpi->common;
   int64_t filt_err;
-  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame,
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
                              dst_frame);
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
     filt_err = aom_highbd_get_y_sse(src, dst_frame);
-  } else {
-    filt_err = aom_get_y_sse(src, dst_frame);
+    if (!y_only) {
+      filt_err += aom_highbd_get_u_sse(src, dst_frame);
+      filt_err += aom_highbd_get_v_sse(src, dst_frame);
+    }
+    return filt_err;
   }
-#else
-  filt_err = aom_get_y_sse(src, dst_frame);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-
+  filt_err = aom_get_y_sse(src, dst_frame);
+  if (!y_only) {
+    filt_err += aom_get_u_sse(src, dst_frame);
+    filt_err += aom_get_v_sse(src, dst_frame);
+  }
   return filt_err;
 }
 
@@ -177,9 +196,10 @@
 static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                           int dat_stride, uint8_t *src8,
                                           int src_stride, int bit_depth,
-                                          int *eps, int *xqd, void *tmpbuf) {
-  int64_t *srd = (int64_t *)tmpbuf;
-  int64_t *dgd = srd + RESTORATION_TILEPELS_MAX;
+                                          int *eps, int *xqd, void *srcbuf,
+                                          void *tmpbuf) {
+  int64_t *srd = (int64_t *)srcbuf;
+  int64_t *dgd = (int64_t *)tmpbuf;
   int64_t *flt1 = dgd + RESTORATION_TILEPELS_MAX;
   int64_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
   uint8_t *tmpbuf2 = (uint8_t *)(flt2 + RESTORATION_TILEPELS_MAX);
@@ -249,8 +269,7 @@
   RestorationInfo rsi;
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE +
-                               RESTORATION_TILEPELS_MAX * sizeof(int64_t) * 2);
+  // Allocate for the src buffer at high precision
   const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                          &tile_height, &nhtiles, &nvtiles);
   //  Make a copy of the unfiltered / processed recon buffer
@@ -272,7 +291,7 @@
                              tile_height, cm->width, cm->height, 0, 0, &h_start,
                              &h_end, &v_start, &v_end);
     err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start);
+                               h_end - h_start, v_start, v_end - v_start, 1);
     // #bits when a tile is not restored
     bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
     cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
@@ -286,9 +305,10 @@
 #else
         8,
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-        &rsi.sgrproj_info[tile_idx].ep, rsi.sgrproj_info[tile_idx].xqd, tmpbuf);
+        &rsi.sgrproj_info[tile_idx].ep, rsi.sgrproj_info[tile_idx].xqd,
+        cpi->highprec_srcbuf, cm->rst_internal.tmpbuf);
     rsi.sgrproj_info[tile_idx].level = 1;
-    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0,
+    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
                                dst_frame);
     bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
     bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
@@ -317,11 +337,10 @@
       bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
     }
   }
-  err = try_restoration_frame(src, cpi, &rsi, partial_frame, dst_frame);
+  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
   cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   aom_free(rsi.sgrproj_info);
-  aom_free(tmpbuf);
 
   aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
   return cost_sgrproj;
@@ -371,8 +390,7 @@
   int64_t best_sse = INT64_MAX, sse;
   if (bit_depth == 8) {
     uint8_t *tmp = (uint8_t *)aom_malloc(width * height * sizeof(*tmp));
-    int32_t *tmpbuf =
-        (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
+    int32_t *tmpbuf = (int32_t *)aom_malloc(DOMAINTXFMRF_TMPBUF_SIZE);
     uint8_t *dgd = dgd8;
     uint8_t *src = src8;
     // First phase
@@ -412,11 +430,11 @@
       }
     }
     aom_free(tmp);
+    aom_free(tmpbuf);
   } else {
 #if CONFIG_AOM_HIGHBITDEPTH
     uint16_t *tmp = (uint16_t *)aom_malloc(width * height * sizeof(*tmp));
-    int32_t *tmpbuf =
-        (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
+    int32_t *tmpbuf = (int32_t *)aom_malloc(DOMAINTXFMRF_TMPBUF_SIZE);
     uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
     uint16_t *src = CONVERT_TO_SHORTPTR(src8);
     // First phase
@@ -456,6 +474,7 @@
       }
     }
     aom_free(tmp);
+    aom_free(tmpbuf);
 #else
     assert(0);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
@@ -498,7 +517,7 @@
                              tile_height, cm->width, cm->height, 0, 0, &h_start,
                              &h_end, &v_start, &v_end);
     err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start);
+                               h_end - h_start, v_start, v_end - v_start, 1);
     // #bits when a tile is not restored
     bits = av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 0);
     cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
@@ -516,7 +535,7 @@
         &rsi.domaintxfmrf_info[tile_idx].sigma_r);
 
     rsi.domaintxfmrf_info[tile_idx].level = 1;
-    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0,
+    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
                                dst_frame);
     bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
     bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 1);
@@ -546,7 +565,7 @@
       bits += (DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT);
     }
   }
-  err = try_restoration_frame(src, cpi, &rsi, partial_frame, dst_frame);
+  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
   cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   aom_free(rsi.domaintxfmrf_info);
@@ -904,7 +923,7 @@
                              tile_height, width, height, 0, 0, &h_start, &h_end,
                              &v_start, &v_end);
     err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start);
+                               h_end - h_start, v_start, v_end - v_start, 1);
     // #bits when a tile is not restored
     bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
     cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
@@ -941,7 +960,7 @@
     }
 
     rsi.wiener_info[tile_idx].level = 1;
-    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0,
+    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
                                dst_frame);
     bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
     bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
@@ -975,7 +994,7 @@
       }
     }
   }
-  err = try_restoration_frame(src, cpi, &rsi, partial_frame, dst_frame);
+  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
   cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   aom_free(rsi.wiener_info);
@@ -1010,14 +1029,14 @@
                              tile_height, cm->width, cm->height, 0, 0, &h_start,
                              &h_end, &v_start, &v_end);
     err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
-                               h_end - h_start, v_start, v_end - v_start);
+                               h_end - h_start, v_start, v_end - v_start, 1);
     best_tile_cost[tile_idx] =
         RDCOST_DBL(x->rdmult, x->rddiv,
                    (cpi->switchable_restore_cost[RESTORE_NONE] >> 4), err);
   }
   // RD cost associated with no restoration
   err = sse_restoration_tile(src, cm->frame_to_show, cm, 0, cm->width, 0,
-                             cm->height);
+                             cm->height, 1);
   bits = frame_level_restore_bits[RESTORE_NONE] << AV1_PROB_COST_SHIFT;
   cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
   aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
@@ -1074,19 +1093,6 @@
 
   const int ntiles =
       av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
-  cm->rst_info.restoration_type = (RestorationType *)aom_realloc(
-      cm->rst_info.restoration_type,
-      sizeof(*cm->rst_info.restoration_type) * ntiles);
-  cm->rst_info.wiener_info = (WienerInfo *)aom_realloc(
-      cm->rst_info.wiener_info, sizeof(*cm->rst_info.wiener_info) * ntiles);
-  assert(cm->rst_info.wiener_info != NULL);
-  cm->rst_info.sgrproj_info = (SgrprojInfo *)aom_realloc(
-      cm->rst_info.sgrproj_info, sizeof(*cm->rst_info.sgrproj_info) * ntiles);
-  assert(cm->rst_info.sgrproj_info != NULL);
-  cm->rst_info.domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
-      cm->rst_info.domaintxfmrf_info,
-      sizeof(*cm->rst_info.domaintxfmrf_info) * ntiles);
-  assert(cm->rst_info.domaintxfmrf_info != NULL);
 
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++)
     tile_cost[r] = (double *)aom_malloc(sizeof(*tile_cost[0]) * ntiles);
@@ -1166,4 +1172,5 @@
          cost_restore[2], cost_restore[3], cost_restore[4]);
          */
   for (r = 0; r < RESTORE_SWITCHABLE_TYPES; r++) aom_free(tile_cost[r]);
+  aom_free_frame_buffer(&dst_frame);
 }