Guided CNN restoration: signal unit size.

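An index `unit_index` is signaled at frame level, as a single bit, to select
one of two unit sizes. Here "low-res" means a frame area of at most 1280x720
pixels and "high-res" means anything larger:
- unit_index = 0 means unit_size = 256 for low-res and 512 for high-res.
- unit_index = 1 means unit_size = 512 for low-res and 1024 for high-res.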

For RDO, a loop over both unit sizes is added to the encoder-side function
`restore_cnn_quadtree_encode_img_tflite_highbd`, which picks the unit size
with the lowest total RD cost.
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index 61d813f..e99b8ed 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -993,18 +993,18 @@
       RDCOST_DBL_WITH_NATIVE_BD_DIST(rdmult, bitrate >> 4, sse, bit_depth);
 }
 
-// Given intermediate restoration 'interm' and 'src', computes the best
-// partitioning out of NONE, SPLIT, HORZ and VERT based on RD cost, and uses it
-// to restore the unit starting at 'row' and 'col' inside 'dgd'.
-// Also, stores the split decisions in 'split' and a0,a1 pairs in 'A'.
+// Given intermediate restoration 'interm', source 'src' and degraded frame
+// 'dgd', computes the best partitioning out of NONE, SPLIT, HORZ and VERT based
+// on RD cost for the unit at 'start_row', 'start_col' (clipped to the frame).
+// The split decisions are stored in 'split' and a0,a1 pairs are stored in 'A'.
 static void select_quadtree_partitioning(
     const std::vector<std::vector<std::vector<double>>> &interm,
     const uint16_t *src, int src_stride, int start_row, int start_col,
     int width, int height, int quadtree_max_size, const int *quadtset,
     int rdmult, const std::pair<int, int> &prev_A, const int *splitcosts,
-    const int norestorecosts[2], int bit_depth, uint16_t *dgd, int dgd_stride,
-    std::vector<int> &split, std::vector<std::pair<int, int>> &A,
-    double *rdcost) {
+    const int norestorecosts[2], int bit_depth, const uint16_t *dgd,
+    int dgd_stride, std::vector<int> &split,
+    std::vector<std::pair<int, int>> &A, double *rdcost) {
   const int end_row = AOMMIN(start_row + quadtree_max_size, height);
   const int end_col = AOMMIN(start_col + quadtree_max_size, width);
   const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
@@ -1044,13 +1044,6 @@
   // Save RDCost.
   *rdcost = best_rdcost;
 
-  // Restore this unit in 'dgd'.
-  for (int row = start_row; row < end_row; ++row) {
-    for (int col = start_col; col < end_col; ++col) {
-      dgd[row * dgd_stride + col] = best_out[row - start_row][col - start_col];
-    }
-  }
-
   // Save a0, a1 pairs.
   for (auto &a0a1 : best_A) {
     A.push_back(a0a1);
@@ -1083,6 +1076,13 @@
   }
 }
 
+static void apply_quadtree_partitioning(
+    const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
+    int start_col, int width, int height, int quadtree_max_size,
+    const int *quadtset, int bit_depth, const std::vector<int> &split,
+    size_t &split_index, const std::vector<std::pair<int, int>> &A,
+    size_t &A_index, uint16_t *dgd, int dgd_stride);
+
 // Top-level function to apply guided restoration on encoder side.
 static int restore_cnn_quadtree_encode_img_tflite_highbd(
     YV12_BUFFER_CONFIG *source_frame, AV1_COMMON *cm, int superres_denom,
@@ -1108,9 +1108,6 @@
   // Initialization.
   const uint16_t *src = CONVERT_TO_SHORTPTR(source_frame->y_buffer);
   const int src_stride = source_frame->y_stride;
-  const int unit_index = quad_tree_get_unit_index(width, height);
-  const int quadtree_max_size =
-      quad_tree_get_unit_size(width, height, unit_index);
   const int *quadtset = get_quadparm_from_qindex(
       qindex, superres_denom, is_intra_only, is_luma, cnn_index);
   const int A0_min = quadtset[2];
@@ -1121,41 +1118,70 @@
   const int *this_norestorecosts =
       norestore_ctx == -1 ? null_norestorecosts : norestorecosts[norestore_ctx];
 
-  // For each quadtree unit, compute the best partitioning out of
-  // NONE, SPLIT, HORZ and VERT based on RD cost.
-  std::vector<int> split;              // selected partitioning options.
-  std::vector<std::pair<int, int>> A;  // selected a0, a1 weight pairs.
-  // Previous a0, a1 pair is mid-point of the range by default.
-  std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
-  *rdcost = 0;
-  // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
-  // If so, need to modify / replace quad_tree_get_unit_info_length().
-  // Also double check: quad_tree_get_split_info_length().
-  for (int row = 0; row < height; row += quadtree_max_size) {
-    for (int col = 0; col < width; col += quadtree_max_size) {
-      double this_rdcost;
-      select_quadtree_partitioning(
-          interm, src, src_stride, row, col, width, height, quadtree_max_size,
-          quadtset, rdmult, prev_A, splitcosts, this_norestorecosts, bit_depth,
-          dgd, dgd_stride, split, A, &this_rdcost);
-      // updates.
-      *rdcost += this_rdcost;
-      prev_A = A.back();
+  // Try all possible quadtree unit sizes.
+  int best_unit_index = -1;
+  std::vector<int> best_split;              // selected partitioning options.
+  std::vector<std::pair<int, int>> best_A;  // selected a0, a1 weight pairs.
+  double best_rdcost_total = DBL_MAX;
+  for (int this_unit_index = 0; this_unit_index <= 1; ++this_unit_index) {
+    const int quadtree_max_size =
+        quad_tree_get_unit_size(width, height, this_unit_index);
+    // For each quadtree unit, compute the best partitioning out of
+    // NONE, SPLIT, HORZ and VERT based on RD cost.
+    std::vector<int> this_split;              // selected partitioning options.
+    std::vector<std::pair<int, int>> this_A;  // selected a0, a1 weight pairs.
+    double this_rdcost_total = 0.0;
+    // Previous a0, a1 pair is mid-point of the range by default.
+    std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
+    // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
+    // If so, need to modify / replace quad_tree_get_unit_info_length().
+    // Also double check: quad_tree_get_split_info_length().
+    for (int row = 0; row < height; row += quadtree_max_size) {
+      for (int col = 0; col < width; col += quadtree_max_size) {
+        double this_rdcost;
+        select_quadtree_partitioning(
+            interm, src, src_stride, row, col, width, height, quadtree_max_size,
+            quadtset, rdmult, prev_A, splitcosts, this_norestorecosts,
+            bit_depth, dgd, dgd_stride, this_split, this_A, &this_rdcost);
+        // updates.
+        this_rdcost_total += this_rdcost;
+        prev_A = this_A.back();
+      }
+    }
+    // Update best options.
+    if (this_rdcost_total < best_rdcost_total) {
+      best_unit_index = this_unit_index;
+      best_split = this_split;
+      best_A = this_A;
+      best_rdcost_total = this_rdcost_total;
     }
   }
 
-  // Fill in the decisions.
-  quad_info->unit_index = unit_index;
-  quad_info->split_info_length = (int)split.size();
-  quad_info->unit_info_length = (int)A.size();
+  // Fill in the best options.
+  quad_info->unit_index = best_unit_index;
+  quad_info->split_info_length = (int)best_split.size();
+  quad_info->unit_info_length = (int)best_A.size();
   av1_alloc_quadtree_struct(cm, quad_info);
-  for (unsigned int i = 0; i < split.size(); ++i) {
-    quad_info->split_info[i].split = split[i];
+  for (unsigned int i = 0; i < best_split.size(); ++i) {
+    quad_info->split_info[i].split = best_split[i];
   }
-  for (unsigned int i = 0; i < A.size(); ++i) {
-    quad_info->unit_info[i].xqd[0] = A[i].first;
-    quad_info->unit_info[i].xqd[1] = A[i].second;
+  for (unsigned int i = 0; i < best_A.size(); ++i) {
+    quad_info->unit_info[i].xqd[0] = best_A[i].first;
+    quad_info->unit_info[i].xqd[1] = best_A[i].second;
   }
+  *rdcost = best_rdcost_total;
+
+  // Apply guided restoration to 'dgd' using best options above.
+  size_t split_index = 0;
+  size_t A_index = 0;
+  for (int row = 0; row < height; row += quad_info->unit_size) {
+    for (int col = 0; col < width; col += quad_info->unit_size) {
+      apply_quadtree_partitioning(
+          interm, row, col, width, height, quad_info->unit_size, quadtset,
+          bit_depth, best_split, split_index, best_A, A_index, dgd, dgd_stride);
+    }
+  }
+
   return 1;
 }
 
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 2d953de..31550eb 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -46,17 +46,13 @@
   return 0;
 }
 
-// Get quad tree unit index based on dimensions.
-static INLINE int quad_tree_get_unit_index(int width, int height) {
-  return (width * height <= 1280 * 720);
-}
-
 // Get quad tree unit size.
 static INLINE int quad_tree_get_unit_size(int width, int height,
-                                          int quad_level) {
-  (void)width;
-  (void)height;
-  return 512 >> quad_level;
+                                          int unit_index) {
+  const bool is_720p_or_smaller = (width * height <= 1280 * 720);
+  const int min_unit_size = is_720p_or_smaller ? 256 : 512;
+  assert(unit_index >= 0 && unit_index <= 1);
+  return min_unit_size << unit_index;
 }
 
 // Allocates buffers in 'quad_info' assuming 'quad_info->unit_index',
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 4604ad6..0304328 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -450,8 +450,8 @@
 /*!\cond */
 #if CONFIG_CNN_GUIDED_QUADTREE
 typedef struct {
-  int unit_index;  // unit size index inferred from frame dimensions.
-  int unit_size;   // inferred from unit_index.
+  int unit_index;  // index signaled in bitstream.
+  int unit_size;   // inferred from unit_index and frame dimensions.
   int split_info_length;
   int unit_info_length;
   QUADUnitInfo *unit_info;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index f7628d7..a1088f0 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2146,6 +2146,15 @@
 }
 
 #if CONFIG_CNN_RESTORATION
+
+#if CONFIG_CNN_GUIDED_QUADTREE
+// Reads the quadtree unit index, signaled as a single bit, from 'rb'.
+static INLINE int quad_tree_read_unit_index(struct aom_read_bit_buffer *rb) {
+  const int unit_index = aom_rb_read_bit(rb);
+  return unit_index;
+}
+#endif  // CONFIG_CNN_GUIDED_QUADTREE
+
 static void decode_cnn(AV1_COMMON *cm, struct aom_read_bit_buffer *rb) {
   for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
     if (av1_allow_cnn_for_plane(cm, plane)) {
@@ -2169,8 +2178,7 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
   if (cm->use_cnn[0]) {
     QUADInfo *qi = &cm->cnn_quad_info;
-    qi->unit_index = quad_tree_get_unit_index(cm->superres_upscaled_width,
-                                              cm->superres_upscaled_height);
+    qi->unit_index = quad_tree_read_unit_index(rb);
     qi->unit_size =
         quad_tree_get_unit_size(cm->superres_upscaled_width,
                                 cm->superres_upscaled_height, qi->unit_index);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c9a6216..88817af 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3120,6 +3120,16 @@
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 
 #if CONFIG_CNN_RESTORATION
+
+#if CONFIG_CNN_GUIDED_QUADTREE
+// Writes the quadtree unit index (0 or 1) to 'wb' as a single bit.
+static INLINE void quad_tree_write_unit_index(struct aom_write_bit_buffer *wb,
+                                              const QUADInfo *const qi) {
+  assert(qi->unit_index >= 0 && qi->unit_index <= 1);
+  aom_wb_write_bit(wb, qi->unit_index);
+}
+#endif  // CONFIG_CNN_GUIDED_QUADTREE
+
 static void encode_cnn(AV1_COMMON *cm, struct aom_write_bit_buffer *wb) {
   for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
     if (av1_allow_cnn_for_plane(cm, plane)) {
@@ -3141,9 +3151,11 @@
   }
 #if CONFIG_CNN_GUIDED_QUADTREE
   if (cm->use_cnn[0]) {
-    assert(cm->cnn_quad_info.unit_index ==
-           quad_tree_get_unit_index(cm->superres_upscaled_width,
-                                    cm->superres_upscaled_height));
+    quad_tree_write_unit_index(wb, &cm->cnn_quad_info);
+    assert(cm->cnn_quad_info.unit_size ==
+           quad_tree_get_unit_size(cm->superres_upscaled_width,
+                                   cm->superres_upscaled_height,
+                                   cm->cnn_quad_info.unit_index));
     assert(cm->cnn_quad_info.split_info_length ==
            quad_tree_get_split_info_length(cm->superres_upscaled_width,
                                            cm->superres_upscaled_height,