Refactor UV restoration to use same tilesize as Y

Change-Id: I56e741551f74624a84250d7565520db9c5127d1b
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index 50446b2..4860ae4 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -90,8 +90,11 @@
 #if CONFIG_LOOP_RESTORATION
 void av1_alloc_restoration_buffers(AV1_COMMON *cm) {
   int p;
-  for (p = 0; p < MAX_MB_PLANE; ++p)
-    av1_alloc_restoration_struct(&cm->rst_info[p], cm->width, cm->height);
+  av1_alloc_restoration_struct(&cm->rst_info[0], cm->width, cm->height);
+  for (p = 1; p < MAX_MB_PLANE; ++p)
+    av1_alloc_restoration_struct(&cm->rst_info[p],
+                                 cm->width >> cm->subsampling_x,
+                                 cm->height >> cm->subsampling_y);
   cm->rst_internal.tmpbuf =
       (int32_t *)aom_realloc(cm->rst_internal.tmpbuf, RESTORATION_TMPBUF_SIZE);
   if (cm->rst_internal.tmpbuf == NULL)
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index e718807..8228170 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -97,12 +97,8 @@
 
 void av1_loop_restoration_precal() { GenDomainTxfmRFVtable(); }
 
-static void loop_restoration_init(RestorationInternal *rst, int kf, int width,
-                                  int height) {
+static void loop_restoration_init(RestorationInternal *rst, int kf) {
   rst->keyframe = kf;
-  rst->ntiles =
-      av1_get_rest_ntiles(width, height, &rst->tile_width, &rst->tile_height,
-                          &rst->nhtiles, &rst->nvtiles);
 }
 
 void extend_frame(uint8_t *data, int width, int height, int stride) {
@@ -127,8 +123,8 @@
                            int subtile_bits, int width, int height, int stride,
                            RestorationInternal *rst, uint8_t *dst,
                            int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int i;
   int h_start, h_end, v_start, v_end;
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, rst->nhtiles,
@@ -143,8 +139,8 @@
                                     int height, int stride,
                                     RestorationInternal *rst, uint8_t *dst,
                                     int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int i, j;
   int h_start, h_end, v_start, v_end;
   DECLARE_ALIGNED(16, InterpKernel, hkernel);
@@ -631,8 +627,8 @@
                                      int height, int stride,
                                      RestorationInternal *rst, uint8_t *dst,
                                      int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int h_start, h_end, v_start, v_end;
   uint8_t *data_p, *dst_p;
 
@@ -790,8 +786,8 @@
                                           int width, int height, int stride,
                                           RestorationInternal *rst,
                                           uint8_t *dst, int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int h_start, h_end, v_start, v_end;
   int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
@@ -866,8 +862,8 @@
                                   int subtile_bits, int width, int height,
                                   int stride, RestorationInternal *rst,
                                   uint16_t *dst, int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int i;
   int h_start, h_end, v_start, v_end;
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, rst->nhtiles,
@@ -883,8 +879,8 @@
                                            RestorationInternal *rst,
                                            int bit_depth, uint16_t *dst,
                                            int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int h_start, h_end, v_start, v_end;
   int i, j;
   DECLARE_ALIGNED(16, InterpKernel, hkernel);
@@ -978,8 +974,8 @@
                                             RestorationInternal *rst,
                                             int bit_depth, uint16_t *dst,
                                             int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int h_start, h_end, v_start, v_end;
   uint16_t *data_p, *dst_p;
 
@@ -1108,8 +1104,8 @@
 static void loop_domaintxfmrf_filter_tile_highbd(
     uint16_t *data, int tile_idx, int width, int height, int stride,
     RestorationInternal *rst, int bit_depth, uint16_t *dst, int dst_stride) {
-  const int tile_width = rst->tile_width >> rst->subsampling_x;
-  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width;
+  const int tile_height = rst->tile_height;
   int h_start, h_end, v_start, v_end;
   int32_t *tmpbuf = (int32_t *)rst->tmpbuf;
 
@@ -1243,8 +1239,10 @@
 
   if ((components_pattern >> AOM_PLANE_Y) & 1) {
     if (rsi[0].frame_restoration_type != RESTORE_NONE) {
-      cm->rst_internal.subsampling_x = 0;
-      cm->rst_internal.subsampling_y = 0;
+      cm->rst_internal.ntiles = av1_get_rest_ntiles(
+          cm->width, cm->height, &cm->rst_internal.tile_width,
+          &cm->rst_internal.tile_height, &cm->rst_internal.nhtiles,
+          &cm->rst_internal.nvtiles);
       cm->rst_internal.rsi = &rsi[0];
       restore_func =
           restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
@@ -1267,10 +1265,12 @@
   }
 
   if ((components_pattern >> AOM_PLANE_U) & 1) {
-    cm->rst_internal.subsampling_x = cm->subsampling_x;
-    cm->rst_internal.subsampling_y = cm->subsampling_y;
-    cm->rst_internal.rsi = &rsi[1];
-    if (rsi[1].frame_restoration_type != RESTORE_NONE) {
+    if (rsi[AOM_PLANE_U].frame_restoration_type != RESTORE_NONE) {
+      cm->rst_internal.ntiles = av1_get_rest_ntiles(
+          cm->width >> cm->subsampling_x, cm->height >> cm->subsampling_y,
+          &cm->rst_internal.tile_width, &cm->rst_internal.tile_height,
+          &cm->rst_internal.nhtiles, &cm->rst_internal.nvtiles);
+      cm->rst_internal.rsi = &rsi[AOM_PLANE_U];
       restore_func =
           restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -1292,10 +1292,12 @@
   }
 
   if ((components_pattern >> AOM_PLANE_V) & 1) {
-    cm->rst_internal.subsampling_x = cm->subsampling_x;
-    cm->rst_internal.subsampling_y = cm->subsampling_y;
-    cm->rst_internal.rsi = &rsi[2];
-    if (rsi[2].frame_restoration_type != RESTORE_NONE) {
+    if (rsi[AOM_PLANE_V].frame_restoration_type != RESTORE_NONE) {
+      cm->rst_internal.ntiles = av1_get_rest_ntiles(
+          cm->width >> cm->subsampling_x, cm->height >> cm->subsampling_y,
+          &cm->rst_internal.tile_width, &cm->rst_internal.tile_height,
+          &cm->rst_internal.nhtiles, &cm->rst_internal.nvtiles);
+      cm->rst_internal.rsi = &rsi[AOM_PLANE_V];
       restore_func =
           restore_funcs[cm->rst_internal.rsi->frame_restoration_type];
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -1336,8 +1338,7 @@
     mi_rows_to_filter = AOMMAX(cm->mi_rows / 8, 8);
   }
   end_mi_row = start_mi_row + mi_rows_to_filter;
-  loop_restoration_init(&cm->rst_internal, cm->frame_type == KEY_FRAME,
-                        cm->width, cm->height);
+  loop_restoration_init(&cm->rst_internal, cm->frame_type == KEY_FRAME);
   loop_restoration_rows(frame, cm, start_mi_row, end_mi_row, components_pattern,
                         rsi, dst);
 }
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 3a1d5ea..cb596c6 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -151,8 +151,6 @@
 typedef struct {
   RestorationInfo *rsi;
   int keyframe;
-  int subsampling_x;
-  int subsampling_y;
   int ntiles;
   int tile_width, tile_height;
   int nhtiles, nvtiles;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index aea295d..b946ddb 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2358,6 +2358,9 @@
   int i, p;
   const int ntiles =
       av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  const int ntiles_uv = av1_get_rest_ntiles(cm->width >> cm->subsampling_x,
+                                            cm->height >> cm->subsampling_x,
+                                            NULL, NULL, NULL, NULL);
   RestorationInfo *rsi = &cm->rst_info[0];
   if (rsi->frame_restoration_type != RESTORE_NONE) {
     if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
@@ -2417,7 +2420,7 @@
       rsi->restoration_type[0] = RESTORE_WIENER;
       rsi->wiener_info[0].level = 1;
       read_wiener_filter(&rsi->wiener_info[0], rb);
-      for (i = 1; i < ntiles; ++i) {
+      for (i = 1; i < ntiles_uv; ++i) {
         rsi->restoration_type[i] = RESTORE_WIENER;
         memcpy(&rsi->wiener_info[i], &rsi->wiener_info[0],
                sizeof(rsi->wiener_info[0]));
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 5cf6244..ae2c36e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3086,11 +3086,13 @@
 
 static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
   int i, p;
+  const int ntiles =
+      av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
   RestorationInfo *rsi = &cm->rst_info[0];
   if (rsi->frame_restoration_type != RESTORE_NONE) {
     if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
       // RESTORE_SWITCHABLE
-      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+      for (i = 0; i < ntiles; ++i) {
         av1_write_token(
             wb, av1_switchable_restore_tree, cm->fc->switchable_restore_prob,
             &switchable_restore_encodings[rsi->restoration_type[i]]);
@@ -3103,14 +3105,14 @@
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
-      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+      for (i = 0; i < ntiles; ++i) {
         aom_write(wb, rsi->wiener_info[i].level != 0, RESTORE_NONE_WIENER_PROB);
         if (rsi->wiener_info[i].level) {
           write_wiener_filter(&rsi->wiener_info[i], wb);
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
-      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+      for (i = 0; i < ntiles; ++i) {
         aom_write(wb, rsi->sgrproj_info[i].level != 0,
                   RESTORE_NONE_SGRPROJ_PROB);
         if (rsi->sgrproj_info[i].level) {
@@ -3118,7 +3120,7 @@
         }
       }
     } else if (rsi->frame_restoration_type == RESTORE_DOMAINTXFMRF) {
-      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+      for (i = 0; i < ntiles; ++i) {
         aom_write(wb, rsi->domaintxfmrf_info[i].level != 0,
                   RESTORE_NONE_DOMAINTXFMRF_PROB);
         if (rsi->domaintxfmrf_info[i].level) {
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 0fa07c8..f8a374e 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -44,6 +44,10 @@
                                     int width, int v_start, int height,
                                     int components_pattern) {
   int64_t filt_err = 0;
+  (void)cm;
+  // Y and UV components cannot be mixed
+  assert(components_pattern == 1 || components_pattern == 2 ||
+         components_pattern == 4 || components_pattern == 6);
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
     if ((components_pattern >> AOM_PLANE_Y) & 1) {
@@ -51,14 +55,12 @@
           aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
     }
     if ((components_pattern >> AOM_PLANE_U) & 1) {
-      filt_err += aom_highbd_get_u_sse_part(
-          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
-          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+      filt_err +=
+          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
     }
     if ((components_pattern >> AOM_PLANE_V) & 1) {
-      filt_err += aom_highbd_get_v_sse_part(
-          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
-          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+      filt_err +=
+          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
     }
     return filt_err;
   }
@@ -67,14 +69,10 @@
     filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
   }
   if ((components_pattern >> AOM_PLANE_U) & 1) {
-    filt_err += aom_get_u_sse_part(
-        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
-        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
   }
   if ((components_pattern >> AOM_PLANE_V) & 1) {
-    filt_err += aom_get_u_sse_part(
-        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
-        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
+    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
   }
   return filt_err;
 }
@@ -119,16 +117,28 @@
   int64_t filt_err;
   int tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
-  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
-                                         &tile_height, &nhtiles, &nvtiles);
+  int ntiles, width, height;
+
+  // Y and UV components cannot be mixed
+  assert(components_pattern == 1 || components_pattern == 2 ||
+         components_pattern == 4 || components_pattern == 6);
+
+  if (components_pattern == 1) {  // Y only
+    width = src->y_crop_width;
+    height = src->y_crop_height;
+  } else {  // Color
+    width = src->uv_crop_width;
+    height = src->uv_crop_height;
+  }
+  ntiles = av1_get_rest_ntiles(width, height, &tile_width, &tile_height,
+                               &nhtiles, &nvtiles);
   (void)ntiles;
 
   av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                              partial_frame, dst_frame);
   av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
-                           nvtiles, tile_width, tile_height, cm->width,
-                           cm->height, 0, 0, &h_start, &h_end, &v_start,
-                           &v_end);
+                           nvtiles, tile_width, tile_height, width, height, 0,
+                           0, &h_start, &h_end, &v_start, &v_end);
   filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                                   v_start, v_end - v_start, components_pattern);
 
@@ -951,7 +961,7 @@
   const int dgd_stride = dgd->uv_stride;
   double score;
   int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
-  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                          &tile_height, &nhtiles, &nvtiles);
 
   assert(width == dgd->uv_crop_width);
@@ -1361,6 +1371,8 @@
          cm->rst_info[0].frame_restoration_type,
          cm->rst_info[1].frame_restoration_type,
          cm->rst_info[2].frame_restoration_type);
+         */
+  /*
   printf("Frame %d/%d frame_restore_type %d : %f %f %f %f %f\n",
          cm->current_video_frame, cm->show_frame,
          cm->rst_info[0].frame_restoration_type, cost_restore[0],