Stop loop rest units from straddling tile boundaries

With this patch, restoration units are allocated within each tile as
if it were its own image. Arrays of information that need one entry
per restoration unit are laid out in tiles, with rsi->units_per_tile
units for each tile.

Change-Id: I485c17166f33e24d281079b3138b76f98f0fe081
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index 67ad9f1..b135bfb 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -132,9 +132,15 @@
 }
 
 #if CONFIG_LOOP_RESTORATION
-// Assumes cm->rst_info[p].restoration_tilesize is already initialized
+// Assumes cm->rst_info[p].restoration_unit_size is already initialized
 void av1_alloc_restoration_buffers(AV1_COMMON *cm) {
-  int p;
+  for (int p = 0; p < MAX_MB_PLANE; ++p)
+    av1_alloc_restoration_struct(cm, &cm->rst_info[p], p > 0);
+  aom_free(cm->rst_tmpbuf);
+  CHECK_MEM_ERROR(cm, cm->rst_tmpbuf,
+                  (int32_t *)aom_memalign(16, RESTORATION_TMPBUF_SIZE));
+
+#if CONFIG_STRIPED_LOOP_RESTORATION
 #if CONFIG_FRAME_SUPERRES
   int width = cm->superres_upscaled_width;
   int height = cm->superres_upscaled_height;
@@ -142,18 +148,9 @@
   int width = cm->width;
   int height = cm->height;
 #endif  // CONFIG_FRAME_SUPERRES
-  av1_alloc_restoration_struct(cm, &cm->rst_info[0], width, height);
-  for (p = 1; p < MAX_MB_PLANE; ++p)
-    av1_alloc_restoration_struct(cm, &cm->rst_info[p],
-                                 ROUND_POWER_OF_TWO(width, cm->subsampling_x),
-                                 ROUND_POWER_OF_TWO(height, cm->subsampling_y));
-  aom_free(cm->rst_tmpbuf);
-  CHECK_MEM_ERROR(cm, cm->rst_tmpbuf,
-                  (int32_t *)aom_memalign(16, RESTORATION_TMPBUF_SIZE));
 
-#if CONFIG_STRIPED_LOOP_RESTORATION
   // Allocate internal storage for the loop restoration stripe boundary lines
-  for (p = 0; p < MAX_MB_PLANE; ++p) {
+  for (int p = 0; p < MAX_MB_PLANE; ++p) {
     int w = p == 0 ? width : ROUND_POWER_OF_TWO(width, cm->subsampling_x);
     int align_bits = 5;  // align for efficiency
     int stride = ALIGN_POWER_OF_TWO(w + 2 * RESTORATION_EXTRA_HORZ, align_bits);
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index eb103f3..9c09b5d 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -16,6 +16,9 @@
 #include "./aom_dsp_rtcd.h"
 #include "./aom_scale_rtcd.h"
 #include "av1/common/onyxc_int.h"
+#if CONFIG_FRAME_SUPERRES
+#include "av1/common/resize.h"
+#endif
 #include "av1/common/restoration.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_mem/aom_mem.h"
@@ -45,15 +48,91 @@
 #endif
 };
 
-int av1_alloc_restoration_struct(AV1_COMMON *cm, RestorationInfo *rst_info,
-                                 int width, int height) {
-  const int ntiles = av1_get_rest_ntiles(
-      width, height, rst_info->restoration_tilesize, NULL, NULL);
-  aom_free(rst_info->unit_info);
-  CHECK_MEM_ERROR(
-      cm, rst_info->unit_info,
-      (RestorationUnitInfo *)aom_malloc(sizeof(*rst_info->unit_info) * ntiles));
-  return ntiles;
+#if CONFIG_MAX_TILE
+static void tile_width_and_height(const AV1_COMMON *cm, int is_uv, int sb_w,
+                                  int sb_h, int *px_w, int *px_h) {
+  // sb_w/sb_h are in superblocks: convert to mi units, then to luma pixels.
+  const int scaled_sb_w = (sb_w << MAX_MIB_SIZE_LOG2) * MI_SIZE;
+  const int scaled_sb_h = (sb_h << MAX_MIB_SIZE_LOG2) * MI_SIZE;
+  const int ss_x = is_uv && cm->subsampling_x;
+  const int ss_y = is_uv && cm->subsampling_y;
+
+  *px_w = (scaled_sb_w + ss_x) >> ss_x;
+  *px_h = (scaled_sb_h + ss_y) >> ss_y;
+#if CONFIG_FRAME_SUPERRES
+  if (!av1_superres_unscaled(cm)) {
+    av1_calculate_unscaled_superres_size(px_w, px_h,
+                                         cm->superres_scale_denominator);
+  }
+#endif  // CONFIG_FRAME_SUPERRES
+}
+#endif  // CONFIG_MAX_TILE
+
+// Count horizontal or vertical units per tile (use a width or height for
+// tile_size, respectively). We basically want to divide the tile size by the
+// size of a restoration unit. Rather than rounding up unconditionally as you
+// might expect, we round to nearest, which models the way a right or bottom
+// restoration unit can extend to up to 150% its normal width or height. The
+// max with 1 is to deal with tiles that are smaller than half of a restoration
+// unit.
+static int count_units_in_tile(int unit_size, int tile_size) {
+  return AOMMAX((tile_size + (unit_size >> 1)) / unit_size, 1);
+}
+
+void av1_alloc_restoration_struct(AV1_COMMON *cm, RestorationInfo *rsi,
+                                  int is_uv) {
+#if CONFIG_MAX_TILE
+  // We need to allocate enough space for restoration units to cover the
+  // largest tile. Without CONFIG_MAX_TILE, this is always the tile at the
+  // top-left and we can use av1_get_tile_rect. With CONFIG_MAX_TILE, we have
+  // to do the computation ourselves, iterating over the tiles and keeping
+  // track of the largest width and height, then upscaling.
+  int max_sb_w = 0;
+  int max_sb_h = 0;
+  for (int i = 0; i < cm->tile_cols; ++i) {
+    const int sb_w = cm->tile_col_start_sb[i + 1] - cm->tile_col_start_sb[i];
+    max_sb_w = AOMMAX(max_sb_w, sb_w);
+  }
+  for (int i = 0; i < cm->tile_rows; ++i) {
+    const int sb_h = cm->tile_row_start_sb[i + 1] - cm->tile_row_start_sb[i];
+    max_sb_h = AOMMAX(max_sb_h, sb_h);
+  }
+
+  int max_tile_w, max_tile_h;
+  tile_width_and_height(cm, is_uv, max_sb_w, max_sb_h, &max_tile_w,
+                        &max_tile_h);
+#else
+  TileInfo tile_info;
+  av1_tile_init(&tile_info, cm, 0, 0);
+
+  const AV1PixelRect tile_rect = av1_get_tile_rect(&tile_info, cm, is_uv);
+  assert(tile_rect.left == 0 && tile_rect.top == 0);
+
+  const int max_tile_w = tile_rect.right;
+  const int max_tile_h = tile_rect.bottom;
+#endif  // CONFIG_MAX_TILE
+
+  // To calculate hpertile and vpertile (horizontal and vertical units per
+  // tile), we basically want to divide the largest tile width or height by the
+  // size of a restoration unit. Rather than rounding up unconditionally as you
+  // might expect, we round to nearest, which models the way a right or bottom
+  // restoration unit can extend to up to 150% its normal width or height. The
+  // max with 1 is to deal with tiles that are smaller than half of a
+  // restoration unit.
+  const int unit_size = rsi->restoration_unit_size;
+  const int hpertile = count_units_in_tile(unit_size, max_tile_w);
+  const int vpertile = count_units_in_tile(unit_size, max_tile_h);
+
+  rsi->units_per_tile = hpertile * vpertile;
+  rsi->horz_units_per_tile = hpertile;
+  rsi->vert_units_per_tile = vpertile;
+
+  const int ntiles = cm->tile_rows * cm->tile_cols;
+  const int nunits = ntiles * rsi->units_per_tile;
+
+  aom_free(rsi->unit_info);
+  CHECK_MEM_ERROR(cm, rsi->unit_info, (RestorationUnitInfo *)aom_malloc(
+                                          sizeof(*rsi->unit_info) * nunits));
 }
 
 void av1_free_restoration_struct(RestorationInfo *rst_info) {
@@ -1405,28 +1484,31 @@
   { RESTORATION_BORDER_HORZ, RESTORATION_BORDER_VERT }
 };
 
-static RestorationTileLimits get_rest_tile_limits(int tile_idx, int nhtiles,
-                                                  int nvtiles, int rtile_size,
-                                                  int im_width, int im_height,
-                                                  int subsampling_y) {
-  const int htile_idx = tile_idx % nhtiles;
-  const int vtile_idx = tile_idx / nhtiles;
-  RestorationTileLimits limits;
-  limits.h_start = htile_idx * rtile_size;
-  limits.v_start = vtile_idx * rtile_size;
-  limits.h_end =
-      (htile_idx < nhtiles - 1) ? limits.h_start + rtile_size : im_width;
-  limits.v_end =
-      (vtile_idx < nvtiles - 1) ? limits.v_start + rtile_size : im_height;
+typedef struct {
+  const RestorationInfo *rsi;
 #if CONFIG_STRIPED_LOOP_RESTORATION
-  // Offset the tile upwards to align with the restoration processing stripe
-  const int voffset = RESTORATION_TILE_OFFSET >> subsampling_y;
-  limits.v_start = AOMMAX(0, limits.v_start - voffset);
-  if (limits.v_end < im_height) limits.v_end -= voffset;
-#else
-  (void)subsampling_y;
+  RestorationLineBuffers *rlbs;
+  int ss_y;
 #endif
-  return limits;
+  int highbd, bit_depth;
+  uint8_t *data8, *dst8;
+  int data_stride, dst_stride;
+  int32_t *tmpbuf;
+} FilterFrameCtxt;
+
+static void filter_frame_on_unit(const RestorationTileLimits *limits,
+                                 int rest_unit_idx, void *priv) {
+  FilterFrameCtxt *ctxt = (FilterFrameCtxt *)priv;
+  const RestorationInfo *rsi = ctxt->rsi;
+
+  av1_loop_restoration_filter_unit(limits, &rsi->unit_info[rest_unit_idx],
+#if CONFIG_STRIPED_LOOP_RESTORATION
+                                   &rsi->boundaries, ctxt->rlbs, ctxt->ss_y,
+#endif
+                                   rsi->procunit_width, rsi->procunit_height,
+                                   ctxt->highbd, ctxt->bit_depth, ctxt->data8,
+                                   ctxt->data_stride, ctxt->dst8,
+                                   ctxt->dst_stride, ctxt->tmpbuf);
 }
 
 void av1_loop_restoration_filter_frame(YV12_BUFFER_CONFIG *frame,
@@ -1493,36 +1575,32 @@
     }
 
     const int is_uv = plane > 0;
-    const int ss_y = is_uv && cm->subsampling_y;
-
     const int plane_width = frame->crop_widths[is_uv];
     const int plane_height = frame->crop_heights[is_uv];
 
-    int nhtiles, nvtiles;
-    const int ntiles =
-        av1_get_rest_ntiles(plane_width, plane_height,
-                            prsi->restoration_tilesize, &nhtiles, &nvtiles);
-
     const struct restore_borders *borders =
         &restore_borders[prsi->frame_restoration_type];
     extend_frame(frame->buffers[plane], plane_width, plane_height,
                  frame->strides[is_uv], borders->hborder, borders->vborder,
                  highbd);
 
-    for (int tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-      RestorationTileLimits limits = get_rest_tile_limits(
-          tile_idx, nhtiles, nvtiles, prsi->restoration_tilesize, plane_width,
-          plane_height, ss_y);
-
-      av1_loop_restoration_filter_unit(
-          &limits, &prsi->unit_info[tile_idx],
+    FilterFrameCtxt ctxt;
+    ctxt.rsi = prsi;
 #if CONFIG_STRIPED_LOOP_RESTORATION
-          &prsi->boundaries, &rlbs, ss_y,
+    const int ss_y = is_uv && cm->subsampling_y;
+    ctxt.rlbs = &rlbs;
+    ctxt.ss_y = ss_y;
 #endif
-          prsi->procunit_width, prsi->procunit_height, highbd, bit_depth,
-          frame->buffers[plane], frame->strides[is_uv], dst->buffers[plane],
-          dst->strides[is_uv], cm->rst_tmpbuf);
-    }
+    ctxt.highbd = highbd;
+    ctxt.bit_depth = bit_depth;
+    ctxt.data8 = frame->buffers[plane];
+    ctxt.dst8 = dst->buffers[plane];
+    ctxt.data_stride = frame->strides[is_uv];
+    ctxt.dst_stride = dst->strides[is_uv];
+    ctxt.tmpbuf = cm->rst_tmpbuf;
+
+    av1_foreach_rest_unit_in_frame(cm, plane, NULL, filter_frame_on_unit,
+                                   &ctxt);
   }
 
   if (dst == &dst_) {
@@ -1536,25 +1614,54 @@
 }
 
 static void foreach_rest_unit_in_tile(const AV1PixelRect *tile_rect,
-                                      int nunits_x, int nunits_y, int unit_size,
-                                      int plane_w, int plane_h, int ss_y,
+                                      int tile_row, int tile_col, int tile_cols,
+                                      int hunits_per_tile, int units_per_tile,
+                                      int unit_size, int ss_y,
                                       rest_unit_visitor_t on_rest_unit,
                                       void *priv) {
-  const int col0 = (tile_rect->left + unit_size - 1) / unit_size;
-  const int col1 =
-      AOMMIN((tile_rect->right + unit_size - 1) / unit_size, nunits_x);
-  const int row0 = (tile_rect->top + unit_size - 1) / unit_size;
-  const int row1 =
-      AOMMIN((tile_rect->bottom + unit_size - 1) / unit_size, nunits_y);
+  const int tile_w = tile_rect->right - tile_rect->left;
+  const int tile_h = tile_rect->bottom - tile_rect->top;
+  const int ext_size = unit_size * 3 / 2;
 
-  for (int i = row0; i < row1; ++i) {
-    for (int j = col0; j < col1; ++j) {
-      const int rtile_idx = i * nunits_x + j;
-      RestorationTileLimits limits = get_rest_tile_limits(
-          rtile_idx, nunits_x, nunits_y, unit_size, plane_w, plane_h, ss_y);
+  const int tile_idx = tile_col + tile_row * tile_cols;
+  const int unit_idx0 = tile_idx * units_per_tile;
 
-      on_rest_unit(&limits, rtile_idx, priv);
+  int y0 = 0, i = 0;
+  while (y0 < tile_h) {
+    int remaining_h = tile_h - y0;
+    int h = (remaining_h < ext_size) ? remaining_h : unit_size;
+
+    RestorationTileLimits limits;
+    limits.v_start = tile_rect->top + y0;
+    limits.v_end = tile_rect->top + y0 + h;
+    assert(limits.v_end <= tile_rect->bottom);
+#if CONFIG_STRIPED_LOOP_RESTORATION
+    // Offset the tile upwards to align with the restoration processing stripe
+    const int voffset = RESTORATION_TILE_OFFSET >> ss_y;
+    limits.v_start = AOMMAX(0, limits.v_start - voffset);
+    if (limits.v_end < tile_rect->bottom) limits.v_end -= voffset;
+#else
+    (void)ss_y;
+#endif
+
+    int x0 = 0, j = 0;
+    while (x0 < tile_w) {
+      int remaining_w = tile_w - x0;
+      int w = (remaining_w < ext_size) ? remaining_w : unit_size;
+
+      limits.h_start = tile_rect->left + x0;
+      limits.h_end = tile_rect->left + x0 + w;
+      assert(limits.h_end <= tile_rect->right);
+
+      const int unit_idx = unit_idx0 + i * hunits_per_tile + j;
+      on_rest_unit(&limits, unit_idx, priv);
+
+      x0 += w;
+      ++j;
     }
+
+    y0 += h;
+    ++i;
   }
 }
 
@@ -1563,24 +1670,9 @@
                                     rest_unit_visitor_t on_rest_unit,
                                     void *priv) {
   const int is_uv = plane > 0;
-  const int ss_x = is_uv && cm->subsampling_x;
   const int ss_y = is_uv && cm->subsampling_y;
 
-#if CONFIG_FRAME_SUPERRES
-  const int frame_w = cm->superres_upscaled_width;
-  const int frame_h = cm->superres_upscaled_height;
-#else
-  const int frame_w = cm->width;
-  const int frame_h = cm->height;
-#endif
-
-  const int plane_w = (frame_w + ss_x) >> ss_x;
-  const int plane_h = (frame_h + ss_y) >> ss_y;
-
-  const int unit_size = cm->rst_info[plane].restoration_tilesize;
-
-  int nunits_x, nunits_y;
-  av1_get_rest_ntiles(plane_w, plane_h, unit_size, &nunits_x, &nunits_y);
+  const RestorationInfo *rsi = &cm->rst_info[plane];
 
   TileInfo tile_info;
   for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
@@ -1588,78 +1680,121 @@
     for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
       av1_tile_set_col(&tile_info, cm, tile_col);
 
-      on_tile(tile_row, tile_col, priv);
+      if (on_tile) on_tile(tile_row, tile_col, priv);
 
       AV1PixelRect tile_rect = av1_get_tile_rect(&tile_info, cm, is_uv);
-      foreach_rest_unit_in_tile(&tile_rect, nunits_x, nunits_y, unit_size,
-                                plane_w, plane_h, ss_y, on_rest_unit, priv);
+      foreach_rest_unit_in_tile(&tile_rect, tile_row, tile_col, cm->tile_cols,
+                                rsi->horz_units_per_tile, rsi->units_per_tile,
+                                rsi->restoration_unit_size, ss_y, on_rest_unit,
+                                priv);
     }
   }
 }
 
+#if CONFIG_MAX_TILE
+// Get the horizontal or vertical index of the tile containing mi_x. For a
+// horizontal index, mi_x should be the left-most column for some block in mi
+// units and tile_x_start_sb should be cm->tile_col_start_sb. The return value
+// will be "tile_col" for the tile containing that block.
+//
+// For a vertical index, mi_x should be the block's top row and tile_x_start_sb
+// should be cm->tile_row_start_sb. The return value will be "tile_row" for the
+// tile containing the block.
+static int get_tile_idx(const int *tile_x_start_sb, int mi_x) {
+  const int sb_x = mi_x >> MAX_MIB_SIZE_LOG2;  // mi units -> superblock units
+
+  for (int i = 0; i < MAX_TILE_COLS; ++i) {
+    if (tile_x_start_sb[i + 1] > sb_x) return i;
+  }
+
+  // This shouldn't happen if tile_x_start_sb has been filled in
+  // correctly.
+  assert(0);
+  return 0;
+}
+#endif
+
 int av1_loop_restoration_corners_in_sb(const struct AV1Common *cm, int plane,
                                        int mi_row, int mi_col, BLOCK_SIZE bsize,
                                        int *rcol0, int *rcol1, int *rrow0,
-                                       int *rrow1, int *nhtiles) {
-  assert(rcol0 && rcol1 && rrow0 && rrow1 && nhtiles);
+                                       int *rrow1, int *tile_tl_idx) {
+  assert(rcol0 && rcol1 && rrow0 && rrow1 && tile_tl_idx);
 
   if (bsize != cm->sb_size) return 0;
 
+  const int is_uv = plane > 0;
+
+// Which tile contains the superblock? Find that tile's top-left in mi-units,
+// together with the tile's size in pixels.
+#if CONFIG_MAX_TILE
+  const int tile_row = get_tile_idx(cm->tile_row_start_sb, mi_row);
+  const int tile_col = get_tile_idx(cm->tile_col_start_sb, mi_col);
+
+  const int sb_t = cm->tile_row_start_sb[tile_row];
+  const int sb_l = cm->tile_col_start_sb[tile_col];
+  const int sb_b = cm->tile_row_start_sb[tile_row + 1];
+  const int sb_r = cm->tile_col_start_sb[tile_col + 1];
+
+  int tile_w, tile_h;
+  tile_width_and_height(cm, is_uv, sb_r - sb_l, sb_b - sb_t, &tile_w, &tile_h);
+
+  const int mi_top = sb_t << MAX_MIB_SIZE_LOG2;
+  const int mi_left = sb_l << MAX_MIB_SIZE_LOG2;
+#else
+  const int tile_row = mi_row / cm->tile_height;
+  const int tile_col = mi_col / cm->tile_width;
+
+  TileInfo tile_info;
+  av1_tile_init(&tile_info, cm, tile_row, tile_col);
+
+  const AV1PixelRect tile_rect = av1_get_tile_rect(&tile_info, cm, is_uv);
+  const int tile_w = tile_rect.right - tile_rect.left;
+  const int tile_h = tile_rect.bottom - tile_rect.top;
+
+  const int mi_top = tile_info.mi_row_start;
+  const int mi_left = tile_info.mi_col_start;
+#endif  // CONFIG_MAX_TILE
+
+  // Compute the mi-unit corners of the superblock relative to the top-left of
+  // the tile
+  const int mi_rel_row0 = mi_row - mi_top;
+  const int mi_rel_col0 = mi_col - mi_left;
+  const int mi_rel_row1 = mi_rel_row0 + mi_size_high[bsize];
+  const int mi_rel_col1 = mi_rel_col0 + mi_size_wide[bsize];
+
 #if CONFIG_FRAME_SUPERRES
-  const int frame_w = cm->superres_upscaled_width;
-  const int frame_h = cm->superres_upscaled_height;
-  const int mi_to_px = MI_SIZE * SCALE_NUMERATOR;
+  const int mi_to_num = MI_SIZE * SCALE_NUMERATOR;
   const int denom = cm->superres_scale_denominator;
 #else
-  const int frame_w = cm->width;
-  const int frame_h = cm->height;
-  const int mi_to_px = MI_SIZE;
+  const int mi_to_num = MI_SIZE;
   const int denom = 1;
 #endif  // CONFIG_FRAME_SUPERRES
 
-  const int ss_x = plane > 0 && cm->subsampling_x != 0;
-  const int ss_y = plane > 0 && cm->subsampling_y != 0;
+  const RestorationInfo *rsi = &cm->rst_info[plane];
+  const int size = rsi->restoration_unit_size;
+  const int rnd = size * denom - 1;
 
-  const int ss_frame_w = (frame_w + ss_x) >> ss_x;
-  const int ss_frame_h = (frame_h + ss_y) >> ss_y;
+  // Calculate the number of restoration units in this tile (which might be
+  // strictly less than rsi->horz_units_per_tile and rsi->vert_units_per_tile)
+  const int horz_units = count_units_in_tile(size, tile_w);
+  const int vert_units = count_units_in_tile(size, tile_h);
 
-  const int rtile_size = cm->rst_info[plane].restoration_tilesize;
+  // rcol0/rrow0 should be the first column/row of restoration units (relative
+  // to the top-left of the tile) that doesn't start left/below of
+  // mi_col/mi_row. For this calculation, we need to round up the division (if
+  // the sb starts at rtile column 10.1, the first matching rtile has column
+  // index 11)
+  *rcol0 = (mi_rel_col0 * mi_to_num + rnd) / (size * denom);
+  *rrow0 = (mi_rel_row0 * mi_to_num + rnd) / (size * denom);
 
-  int nvtiles;
-  av1_get_rest_ntiles(ss_frame_w, ss_frame_h, rtile_size, nhtiles, &nvtiles);
+  // rcol1/rrow1 is the equivalent calculation, but for the superblock
+  // below-right. If we're at the bottom or right of the tile, this restoration
+  // unit might not exist, in which case we'll clamp accordingly.
+  *rcol1 = AOMMIN((mi_rel_col1 * mi_to_num + rnd) / (size * denom), horz_units);
+  *rrow1 = AOMMIN((mi_rel_row1 * mi_to_num + rnd) / (size * denom), vert_units);
 
-  const int rnd = rtile_size * denom - 1;
-
-  // rcol0/rrow0 should be the first column/row of rtiles that doesn't start
-  // left/below of mi_col/mi_row. For this calculation, we need to round up the
-  // division (if the sb starts at rtile column 10.1, the first matching rtile
-  // has column index 11)
-  *rcol0 = (mi_col * mi_to_px + rnd) / (rtile_size * denom);
-  *rrow0 = (mi_row * mi_to_px + rnd) / (rtile_size * denom);
-
-  // rcol1/rrow1 is the equivalent calculation, but for the superblock
-  // below-right. There are some slightly strange boundary effects. First, we
-  // need to clamp to nhtiles/nvtiles for the case where it appears there are,
-  // say, 2.4 restoration tiles horizontally. There we need a maximum mi_row1
-  // of 2 because tile 1 gets extended.
-  //
-  // Second, if mi_col1 >= cm->mi_cols then we must manually set *rcol1 to
-  // nhtiles. This is needed whenever the frame's width rounded up to the next
-  // toplevel superblock is smaller than nhtiles * rtile_w. The same logic is
-  // needed for rows.
-  const int mi_row1 = mi_row + mi_size_high[bsize];
-  const int mi_col1 = mi_col + mi_size_wide[bsize];
-
-  if (mi_col1 >= cm->mi_cols)
-    *rcol1 = *nhtiles;
-  else
-    *rcol1 =
-        AOMMIN(*nhtiles, (mi_col1 * mi_to_px + rnd) / (rtile_size * denom));
-
-  if (mi_row1 >= cm->mi_rows)
-    *rrow1 = nvtiles;
-  else
-    *rrow1 = AOMMIN(nvtiles, (mi_row1 * mi_to_px + rnd) / (rtile_size * denom));
+  const int tile_idx = tile_col + tile_row * cm->tile_cols;
+  *tile_tl_idx = tile_idx * rsi->units_per_tile;
 
   return *rcol0 < *rcol1 && *rrow0 < *rrow1;
 }
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 6a77f2b..51e2d81 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -234,8 +234,19 @@
 
 typedef struct {
   RestorationType frame_restoration_type;
-  int restoration_tilesize;
+  int restoration_unit_size;
   int procunit_width, procunit_height;
+
+  // Fields below here are allocated and initialised by
+  // av1_alloc_restoration_struct. (horz_)units_per_tile give the number of
+  // restoration units in (one row of) the largest tile in the frame. The data
+  // in unit_info is laid out with units_per_tile entries for each tile, which
+  // have stride horz_units_per_tile.
+  //
+  // Even if there are tiles of different sizes, the data in unit_info is laid
+  // out as if all tiles are of full size.
+  int units_per_tile;
+  int vert_units_per_tile, horz_units_per_tile;
   RestorationUnitInfo *unit_info;
 #if CONFIG_STRIPED_LOOP_RESTORATION
   RestorationStripeBoundaries boundaries;
@@ -259,20 +270,6 @@
   wiener_info->vfilter[6] = wiener_info->hfilter[6] = WIENER_FILT_TAP0_MIDV;
 }
 
-static INLINE int av1_get_rest_ntiles(int width, int height, int tilesize,
-                                      int *nhtiles, int *nvtiles) {
-  int nhtiles_, nvtiles_;
-  const int tile_width = (tilesize < 0) ? width : AOMMIN(tilesize, width);
-  const int tile_height = (tilesize < 0) ? height : AOMMIN(tilesize, height);
-  assert(tile_width > 0 && tile_height > 0);
-
-  nhtiles_ = (width + (tile_width >> 1)) / tile_width;
-  nvtiles_ = (height + (tile_height >> 1)) / tile_height;
-  if (nhtiles) *nhtiles = nhtiles_;
-  if (nvtiles) *nvtiles = nvtiles_;
-  return (nhtiles_ * nvtiles_);
-}
-
 typedef struct { int h_start, h_end, v_start, v_end; } RestorationTileLimits;
 
 extern const sgr_params_type sgr_params[SGRPROJ_PARAMS];
@@ -280,9 +277,8 @@
 extern const int32_t x_by_xplus1[256];
 extern const int32_t one_by_x[MAX_NELEM];
 
-int av1_alloc_restoration_struct(struct AV1Common *cm,
-                                 RestorationInfo *rst_info, int width,
-                                 int height);
+void av1_alloc_restoration_struct(struct AV1Common *cm, RestorationInfo *rsi,
+                                  int is_uv);
 void av1_free_restoration_struct(RestorationInfo *rst_info);
 
 void extend_frame(uint8_t *data, int width, int height, int stride,
@@ -344,14 +340,14 @@
 // loop restoration tile.
 //
 // If the block is a top-level superblock, the function writes to
-// *rcol0, *rcol1, *rrow0, *rrow1. The rectangle of indices given by
-// [*rcol0, *rcol1) x [*rrow0, *rrow1) will point at the set of rtiles
-// whose top left corners lie in the superblock. Note that the set is
-// only nonempty if *rcol0 < *rcol1 and *rrow0 < *rrow1.
+// *rcol0, *rcol1, *rrow0, *rrow1. The restoration unit indices in the
+// rectangle [*rcol0, *rcol1) x [*rrow0, *rrow1) are relative to the current
+// tile, whose first unit has index *tile_tl_idx in unit_info. The rectangle
+// is only nonempty if *rcol0 < *rcol1 and *rrow0 < *rrow1.
 int av1_loop_restoration_corners_in_sb(const struct AV1Common *cm, int plane,
                                        int mi_row, int mi_col, BLOCK_SIZE bsize,
                                        int *rcol0, int *rcol1, int *rrow0,
-                                       int *rrow1, int *nhtiles);
+                                       int *rrow1, int *tile_tl_idx);
 
 void av1_loop_restoration_save_boundary_lines(const YV12_BUFFER_CONFIG *frame,
                                               struct AV1Common *cm);
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 7d86a53..cd14201 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -1287,13 +1287,14 @@
 #endif  // CONFIG_CDEF
 #if CONFIG_LOOP_RESTORATION
   for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    int rcol0, rcol1, rrow0, rrow1, nhtiles;
+    int rcol0, rcol1, rrow0, rrow1, tile_tl_idx;
     if (av1_loop_restoration_corners_in_sb(cm, plane, mi_row, mi_col, bsize,
                                            &rcol0, &rcol1, &rrow0, &rrow1,
-                                           &nhtiles)) {
+                                           &tile_tl_idx)) {
+      const int rstride = cm->rst_info[plane].horz_units_per_tile;
       for (int rrow = rrow0; rrow < rrow1; ++rrow) {
         for (int rcol = rcol0; rcol < rcol1; ++rcol) {
-          int rtile_idx = rcol + rrow * nhtiles;
+          const int rtile_idx = tile_tl_idx + rcol + rrow * rstride;
           loop_restoration_read_sb_coeffs(cm, xd, r, plane, rtile_idx);
         }
       }
@@ -1389,30 +1390,31 @@
           aom_rb_read_bit(rb) ? RESTORE_SWITCHABLE : RESTORE_NONE;
     }
   }
-  cm->rst_info[0].restoration_tilesize = RESTORATION_TILESIZE_MAX;
-  cm->rst_info[1].restoration_tilesize = RESTORATION_TILESIZE_MAX;
-  cm->rst_info[2].restoration_tilesize = RESTORATION_TILESIZE_MAX;
+  cm->rst_info[0].restoration_unit_size = RESTORATION_TILESIZE_MAX;
+  cm->rst_info[1].restoration_unit_size = RESTORATION_TILESIZE_MAX;
+  cm->rst_info[2].restoration_unit_size = RESTORATION_TILESIZE_MAX;
   if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
       cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
       cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
-    cm->rst_info[0].restoration_tilesize = RESTORATION_TILESIZE_MAX >> 2;
-    cm->rst_info[1].restoration_tilesize = RESTORATION_TILESIZE_MAX >> 2;
-    cm->rst_info[2].restoration_tilesize = RESTORATION_TILESIZE_MAX >> 2;
+    cm->rst_info[0].restoration_unit_size = RESTORATION_TILESIZE_MAX >> 2;
+    cm->rst_info[1].restoration_unit_size = RESTORATION_TILESIZE_MAX >> 2;
+    cm->rst_info[2].restoration_unit_size = RESTORATION_TILESIZE_MAX >> 2;
     rsi = &cm->rst_info[0];
-    rsi->restoration_tilesize <<= aom_rb_read_bit(rb);
-    if (rsi->restoration_tilesize != (RESTORATION_TILESIZE_MAX >> 2)) {
-      rsi->restoration_tilesize <<= aom_rb_read_bit(rb);
+    rsi->restoration_unit_size <<= aom_rb_read_bit(rb);
+    if (rsi->restoration_unit_size != (RESTORATION_TILESIZE_MAX >> 2)) {
+      rsi->restoration_unit_size <<= aom_rb_read_bit(rb);
     }
   }
   int s = AOMMIN(cm->subsampling_x, cm->subsampling_y);
   if (s && (cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
             cm->rst_info[2].frame_restoration_type != RESTORE_NONE)) {
-    cm->rst_info[1].restoration_tilesize =
-        cm->rst_info[0].restoration_tilesize >> (aom_rb_read_bit(rb) * s);
+    cm->rst_info[1].restoration_unit_size =
+        cm->rst_info[0].restoration_unit_size >> (aom_rb_read_bit(rb) * s);
   } else {
-    cm->rst_info[1].restoration_tilesize = cm->rst_info[0].restoration_tilesize;
+    cm->rst_info[1].restoration_unit_size =
+        cm->rst_info[0].restoration_unit_size;
   }
-  cm->rst_info[2].restoration_tilesize = cm->rst_info[1].restoration_tilesize;
+  cm->rst_info[2].restoration_unit_size = cm->rst_info[1].restoration_unit_size;
 
   cm->rst_info[0].procunit_width = cm->rst_info[0].procunit_height =
       RESTORATION_PROC_UNIT_SIZE;
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 4a41eb7..da53814 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -2675,13 +2675,14 @@
 #endif
 #if CONFIG_LOOP_RESTORATION
   for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
-    int rcol0, rcol1, rrow0, rrow1, nhtiles;
+    int rcol0, rcol1, rrow0, rrow1, tile_tl_idx;
     if (av1_loop_restoration_corners_in_sb(cm, plane, mi_row, mi_col, bsize,
                                            &rcol0, &rcol1, &rrow0, &rrow1,
-                                           &nhtiles)) {
+                                           &tile_tl_idx)) {
+      const int rstride = cm->rst_info[plane].horz_units_per_tile;
       for (int rrow = rrow0; rrow < rrow1; ++rrow) {
         for (int rcol = rcol0; rcol < rcol1; ++rcol) {
-          int rtile_idx = rcol + rrow * nhtiles;
+          const int rtile_idx = tile_tl_idx + rcol + rrow * rstride;
           const RestorationUnitInfo *rui =
               &cm->rst_info[plane].unit_info[rtile_idx];
           loop_restoration_write_sb_coeffs(cm, xd, rui, w, plane);
@@ -2767,29 +2768,29 @@
       cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
       cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
     aom_wb_write_bit(
-        wb, rsi->restoration_tilesize != (RESTORATION_TILESIZE_MAX >> 2));
-    if (rsi->restoration_tilesize != (RESTORATION_TILESIZE_MAX >> 2)) {
+        wb, rsi->restoration_unit_size != (RESTORATION_TILESIZE_MAX >> 2));
+    if (rsi->restoration_unit_size != (RESTORATION_TILESIZE_MAX >> 2)) {
       aom_wb_write_bit(
-          wb, rsi->restoration_tilesize != (RESTORATION_TILESIZE_MAX >> 1));
+          wb, rsi->restoration_unit_size != (RESTORATION_TILESIZE_MAX >> 1));
     }
   }
   int s = AOMMIN(cm->subsampling_x, cm->subsampling_y);
   if (s && (cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
             cm->rst_info[2].frame_restoration_type != RESTORE_NONE)) {
     aom_wb_write_bit(wb,
-                     cm->rst_info[1].restoration_tilesize !=
-                         cm->rst_info[0].restoration_tilesize);
-    assert(cm->rst_info[1].restoration_tilesize ==
-               cm->rst_info[0].restoration_tilesize ||
-           cm->rst_info[1].restoration_tilesize ==
-               (cm->rst_info[0].restoration_tilesize >> s));
-    assert(cm->rst_info[2].restoration_tilesize ==
-           cm->rst_info[1].restoration_tilesize);
+                     cm->rst_info[1].restoration_unit_size !=
+                         cm->rst_info[0].restoration_unit_size);
+    assert(cm->rst_info[1].restoration_unit_size ==
+               cm->rst_info[0].restoration_unit_size ||
+           cm->rst_info[1].restoration_unit_size ==
+               (cm->rst_info[0].restoration_unit_size >> s));
+    assert(cm->rst_info[2].restoration_unit_size ==
+           cm->rst_info[1].restoration_unit_size);
   } else if (!s) {
-    assert(cm->rst_info[1].restoration_tilesize ==
-           cm->rst_info[0].restoration_tilesize);
-    assert(cm->rst_info[2].restoration_tilesize ==
-           cm->rst_info[1].restoration_tilesize);
+    assert(cm->rst_info[1].restoration_unit_size ==
+           cm->rst_info[0].restoration_unit_size);
+    assert(cm->rst_info[2].restoration_unit_size ==
+           cm->rst_info[1].restoration_unit_size);
   }
 }
 
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index d883f96..d75642b 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -4171,8 +4171,8 @@
 
 #if CONFIG_LOOP_RESTORATION
 #define COUPLED_CHROMA_FROM_LUMA_RESTORATION 0
-static void set_restoration_tilesize(int width, int height, int sx, int sy,
-                                     RestorationInfo *rst) {
+static void set_restoration_unit_size(int width, int height, int sx, int sy,
+                                      RestorationInfo *rst) {
   (void)width;
   (void)height;
   (void)sx;
@@ -4183,9 +4183,9 @@
   int s = 0;
 #endif  // !COUPLED_CHROMA_FROM_LUMA_RESTORATION
 
-  rst[0].restoration_tilesize = (RESTORATION_TILESIZE_MAX >> 1);
-  rst[1].restoration_tilesize = rst[0].restoration_tilesize >> s;
-  rst[2].restoration_tilesize = rst[1].restoration_tilesize;
+  rst[0].restoration_unit_size = (RESTORATION_TILESIZE_MAX >> 1);
+  rst[1].restoration_unit_size = rst[0].restoration_unit_size >> s;
+  rst[2].restoration_unit_size = rst[1].restoration_unit_size;
 
   rst[0].procunit_width = rst[0].procunit_height = RESTORATION_PROC_UNIT_SIZE;
   rst[1].procunit_width = rst[2].procunit_width =
@@ -4307,8 +4307,8 @@
   const int frame_width = cm->width;
   const int frame_height = cm->height;
 #endif
-  set_restoration_tilesize(frame_width, frame_height, cm->subsampling_x,
-                           cm->subsampling_y, cm->rst_info);
+  set_restoration_unit_size(frame_width, frame_height, cm->subsampling_x,
+                            cm->subsampling_y, cm->rst_info);
   for (int i = 0; i < MAX_MB_PLANE; ++i)
     cm->rst_info[i].frame_restoration_type = RESTORE_NONE;
 
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 6b187bd..26330d9 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -1158,14 +1158,17 @@
   return RDCOST_DBL(rsc->x->rdmult, rsc->bits >> 4, rsc->sse);
 }
 
+static int rest_tiles_in_plane(const AV1_COMMON *cm, int plane) {
+  const RestorationInfo *rsi = &cm->rst_info[plane];
+  return cm->tile_rows * cm->tile_cols * rsi->units_per_tile;
+}
+
 void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
   AV1_COMMON *const cm = &cpi->common;
 
   int ntiles[2];
   for (int is_uv = 0; is_uv < 2; ++is_uv)
-    ntiles[is_uv] = av1_get_rest_ntiles(
-        src->crop_widths[is_uv], src->crop_heights[is_uv],
-        cm->rst_info[is_uv].restoration_tilesize, NULL, NULL);
+    ntiles[is_uv] = rest_tiles_in_plane(cm, is_uv);
 
   assert(ntiles[1] <= ntiles[0]);
   RestUnitSearchInfo *rusi =