Misc cleanups and enhancements on loop restoration

Includes:
Some cleanups and refactoring.
Better buffer management.
Some preparation for future chrominance restoration.

Change-Id: Ia264b8989b5f4a53c0764ed3e8258ddc212723fc
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index 86051c8..31c1bf7 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -87,15 +87,17 @@
 }
 
 #if CONFIG_LOOP_RESTORATION
+void av1_alloc_restoration_buffers(AV1_COMMON *cm) {
+  av1_alloc_restoration_struct(&cm->rst_info, cm->width, cm->height);
+  cm->rst_internal.tmpbuf =
+      (uint8_t *)aom_realloc(cm->rst_internal.tmpbuf, RESTORATION_TMPBUF_SIZE);
+  assert(cm->rst_internal.tmpbuf != NULL);
+}
+
 void av1_free_restoration_buffers(AV1_COMMON *cm) {
-  aom_free(cm->rst_info.restoration_type);
-  cm->rst_info.restoration_type = NULL;
-  aom_free(cm->rst_info.wiener_info);
-  cm->rst_info.wiener_info = NULL;
-  aom_free(cm->rst_info.sgrproj_info);
-  cm->rst_info.sgrproj_info = NULL;
-  aom_free(cm->rst_info.domaintxfmrf_info);
-  cm->rst_info.domaintxfmrf_info = NULL;
+  av1_free_restoration_struct(&cm->rst_info);
+  aom_free(cm->rst_internal.tmpbuf);
+  cm->rst_internal.tmpbuf = NULL;
 }
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/common/alloccommon.h b/av1/common/alloccommon.h
index 0a0c38c..51863cd 100644
--- a/av1/common/alloccommon.h
+++ b/av1/common/alloccommon.h
@@ -29,6 +29,7 @@
 
 void av1_free_ref_frame_buffers(struct BufferPool *pool);
 #if CONFIG_LOOP_RESTORATION
+void av1_alloc_restoration_buffers(struct AV1Common *cm);
 void av1_free_restoration_buffers(struct AV1Common *cm);
 #endif  // CONFIG_LOOP_RESTORATION
 
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index c4f79c3..e905a8c 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -50,6 +50,35 @@
                                          int dst_stride);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
+int av1_alloc_restoration_struct(RestorationInfo *rst_info, int width,
+                                 int height) {
+  const int ntiles = av1_get_rest_ntiles(width, height, NULL, NULL, NULL, NULL);
+  rst_info->restoration_type = (RestorationType *)aom_realloc(
+      rst_info->restoration_type, sizeof(*rst_info->restoration_type) * ntiles);
+  rst_info->wiener_info = (WienerInfo *)aom_realloc(
+      rst_info->wiener_info, sizeof(*rst_info->wiener_info) * ntiles);
+  assert(rst_info->wiener_info != NULL);
+  rst_info->sgrproj_info = (SgrprojInfo *)aom_realloc(
+      rst_info->sgrproj_info, sizeof(*rst_info->sgrproj_info) * ntiles);
+  assert(rst_info->sgrproj_info != NULL);
+  rst_info->domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
+      rst_info->domaintxfmrf_info,
+      sizeof(*rst_info->domaintxfmrf_info) * ntiles);
+  assert(rst_info->domaintxfmrf_info != NULL);
+  return ntiles;
+}
+
+void av1_free_restoration_struct(RestorationInfo *rst_info) {
+  aom_free(rst_info->restoration_type);
+  rst_info->restoration_type = NULL;
+  aom_free(rst_info->wiener_info);
+  rst_info->wiener_info = NULL;
+  aom_free(rst_info->sgrproj_info);
+  rst_info->sgrproj_info = NULL;
+  aom_free(rst_info->domaintxfmrf_info);
+  rst_info->domaintxfmrf_info = NULL;
+}
+
 static void GenDomainTxfmRFVtable() {
   int i, j;
   const double sigma_s = sqrt(2.0);
@@ -520,15 +549,16 @@
 
 static void loop_sgrproj_filter_tile(uint8_t *data, int tile_idx, int width,
                                      int height, int stride,
-                                     RestorationInternal *rst, void *tmpbuf,
-                                     uint8_t *dst, int dst_stride) {
+                                     RestorationInternal *rst, uint8_t *dst,
+                                     int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int i, j;
   int h_start, h_end, v_start, v_end;
   uint8_t *data_p, *dst_p;
-  int64_t *dat = (int64_t *)tmpbuf;
-  tmpbuf = (uint8_t *)tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
+  int64_t *dat = (int64_t *)rst->tmpbuf;
+  uint8_t *tmpbuf =
+      (uint8_t *)rst->tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
 
   if (rst->rsi->sgrproj_info[tile_idx].level == 0) {
     loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -561,12 +591,10 @@
                                 int stride, RestorationInternal *rst,
                                 uint8_t *dst, int dst_stride) {
   int tile_idx;
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, tmpbuf,
-                             dst, dst_stride);
+    loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, dst,
+                             dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void apply_domaintxfmrf_hor(int iter, int param, uint8_t *img, int width,
@@ -664,11 +692,11 @@
 static void loop_domaintxfmrf_filter_tile(uint8_t *data, int tile_idx,
                                           int width, int height, int stride,
                                           RestorationInternal *rst,
-                                          uint8_t *dst, int dst_stride,
-                                          int32_t *tmpbuf) {
+                                          uint8_t *dst, int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int h_start, h_end, v_start, v_end;
+  int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
   if (rst->rsi->domaintxfmrf_info[tile_idx].level == 0) {
     loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -688,23 +716,16 @@
                                      int stride, RestorationInternal *rst,
                                      uint8_t *dst, int dst_stride) {
   int tile_idx;
-  int32_t *tmpbuf =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
-
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride, rst,
-                                  dst, dst_stride, tmpbuf);
+                                  dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void loop_switchable_filter(uint8_t *data, int width, int height,
                                    int stride, RestorationInternal *rst,
                                    uint8_t *dst, int dst_stride) {
   int tile_idx;
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
-  int32_t *tmpbuf32 =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf32));
   extend_frame(data, width, height, stride);
   copy_border(data, width, height, stride, dst, dst_stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -715,15 +736,13 @@
       loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst, dst,
                               dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
-      loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst,
-                               tmpbuf, dst, dst_stride);
+      loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst, dst,
+                               dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
       loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride, rst,
-                                    dst, dst_stride, tmpbuf32);
+                                    dst, dst_stride);
     }
   }
-  aom_free(tmpbuf);
-  aom_free(tmpbuf32);
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -830,7 +849,7 @@
                                       int dst_stride) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-  int tile_idx, i;
+  int tile_idx;
   copy_border_highbd(data, width, height, stride, dst, dst_stride);
   extend_frame_highbd(data, width, height, stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -842,15 +861,16 @@
 static void loop_sgrproj_filter_tile_highbd(uint16_t *data, int tile_idx,
                                             int width, int height, int stride,
                                             RestorationInternal *rst,
-                                            int bit_depth, void *tmpbuf,
-                                            uint16_t *dst, int dst_stride) {
+                                            int bit_depth, uint16_t *dst,
+                                            int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int i, j;
   int h_start, h_end, v_start, v_end;
   uint16_t *data_p, *dst_p;
-  int64_t *dat = (int64_t *)tmpbuf;
-  tmpbuf = (uint8_t *)tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
+  int64_t *dat = (int64_t *)rst->tmpbuf;
+  uint8_t *tmpbuf =
+      (uint8_t *)rst->tmpbuf + RESTORATION_TILEPELS_MAX * sizeof(*dat);
 
   if (rst->rsi->sgrproj_info[tile_idx].level == 0) {
     loop_copy_tile_highbd(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -885,13 +905,11 @@
                                        int dst_stride) {
   int tile_idx;
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_sgrproj_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
-                                    bit_depth, tmpbuf, dst, dst_stride);
+                                    bit_depth, dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void apply_domaintxfmrf_hor_highbd(int iter, int param, uint16_t *img,
@@ -989,11 +1007,11 @@
 
 static void loop_domaintxfmrf_filter_tile_highbd(
     uint16_t *data, int tile_idx, int width, int height, int stride,
-    RestorationInternal *rst, int bit_depth, uint16_t *dst, int dst_stride,
-    int32_t *tmpbuf) {
+    RestorationInternal *rst, int bit_depth, uint16_t *dst, int dst_stride) {
   const int tile_width = rst->tile_width >> rst->subsampling_x;
   const int tile_height = rst->tile_height >> rst->subsampling_y;
   int h_start, h_end, v_start, v_end;
+  int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
   if (rst->rsi->domaintxfmrf_info[tile_idx].level == 0) {
     loop_copy_tile_highbd(data, tile_idx, 0, 0, width, height, stride, rst, dst,
@@ -1015,16 +1033,12 @@
                                             int bit_depth, uint8_t *dst8,
                                             int dst_stride) {
   int tile_idx;
-  int32_t *tmpbuf =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf));
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
     loop_domaintxfmrf_filter_tile_highbd(data, tile_idx, width, height, stride,
-                                         rst, bit_depth, dst, dst_stride,
-                                         tmpbuf);
+                                         rst, bit_depth, dst, dst_stride);
   }
-  aom_free(tmpbuf);
 }
 
 static void loop_switchable_filter_highbd(uint8_t *data8, int width, int height,
@@ -1032,11 +1046,8 @@
                                           int bit_depth, uint8_t *dst8,
                                           int dst_stride) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint8_t *tmpbuf = aom_malloc(SGRPROJ_TMPBUF_SIZE);
-  int32_t *tmpbuf32 =
-      (int32_t *)aom_malloc(RESTORATION_TILEPELS_MAX * sizeof(*tmpbuf32));
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-  int i, tile_idx;
+  int tile_idx;
   copy_border_highbd(data, width, height, stride, dst, dst_stride);
   extend_frame_highbd(data, width, height, stride);
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
@@ -1048,15 +1059,13 @@
                                      bit_depth, dst, dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
       loop_sgrproj_filter_tile_highbd(data, tile_idx, width, height, stride,
-                                      rst, bit_depth, tmpbuf, dst, dst_stride);
+                                      rst, bit_depth, dst, dst_stride);
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
       loop_domaintxfmrf_filter_tile_highbd(data, tile_idx, width, height,
                                            stride, rst, bit_depth, dst,
-                                           dst_stride, tmpbuf32);
+                                           dst_stride);
     }
   }
-  aom_free(tmpbuf);
-  aom_free(tmpbuf32);
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 5773c77..2274fa8 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -38,10 +38,15 @@
 #define DOMAINTXFMRF_VTABLE_PREC (1 << DOMAINTXFMRF_VTABLE_PRECBITS)
 #define DOMAINTXFMRF_MULT \
   sqrt(((1 << (DOMAINTXFMRF_ITERS * 2)) - 1) * 2.0 / 3.0)
-#define DOMAINTXFMRF_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX)
+// A single 32-bit buffer is needed for the filter
+#define DOMAINTXFMRF_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * sizeof(int32_t))
 #define DOMAINTXFMRF_BITS (DOMAINTXFMRF_PARAMS_BITS)
 
-#define SGRPROJ_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * 6 * 8)
+// 6 high-precision 64-bit buffers are needed for the filter:
+// 1 for the degraded frame, 2 for the restored versions, and
+// 3 for each restoration operation.
+// TODO(debargha): Explore if we can use 32-bit buffers
+#define SGRPROJ_TMPBUF_SIZE (RESTORATION_TILEPELS_MAX * 6 * sizeof(int64_t))
 #define SGRPROJ_PARAMS_BITS 3
 #define SGRPROJ_PARAMS (1 << SGRPROJ_PARAMS_BITS)
 
@@ -60,6 +65,9 @@
 
 #define SGRPROJ_BITS (SGRPROJ_PRJ_BITS * 2 + SGRPROJ_PARAMS_BITS)
 
+// Max of SGRPROJ_TMPBUF_SIZE and DOMAINTXFMRF_TMPBUF_SIZE (i.e. the former)
+#define RESTORATION_TMPBUF_SIZE (SGRPROJ_TMPBUF_SIZE)
+
 #define RESTORATION_HALFWIN 3
 #define RESTORATION_HALFWIN1 (RESTORATION_HALFWIN + 1)
 #define RESTORATION_WIN (2 * RESTORATION_HALFWIN + 1)
@@ -128,6 +136,7 @@
   int ntiles;
   int tile_width, tile_height;
   int nhtiles, nvtiles;
+  uint8_t *tmpbuf;
 } RestorationInternal;
 
 static INLINE int get_rest_tilesize(int width, int height) {
@@ -189,6 +198,10 @@
 
 extern const sgr_params_type sgr_params[SGRPROJ_PARAMS];
 
+int av1_alloc_restoration_struct(RestorationInfo *rst_info, int width,
+                                 int height);
+void av1_free_restoration_struct(RestorationInfo *rst_info);
+
 void av1_selfguided_restoration(int64_t *dgd, int width, int height, int stride,
                                 int bit_depth, int r, int eps, void *tmpbuf);
 void av1_domaintxfmrf_restoration(uint8_t *dgd, int width, int height,