Guided CNN Restoration: further improvements to unit sizes

Includes following changes:

* Similar to traditional loop restoration framework: If remaining rows/cols

  near boundary are half-unit-height/width, extend the previous unit to cover those rows/cols.
* For narrow units near boundary, allow splitting either bigger dimension only (to HORZ/VERT type).
* Signal one of 4 unit sizes at frame level:  largest allowed size is 

  min(2048, next_power_of_2 of largest dimension). And next 3 sizes are max_size/2, max_size/4 and max_size/8.

Also some code cleanups related to this.
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index b94792b..6ee12cc 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -618,30 +618,19 @@
   return 1;
 }
 
-typedef enum {
-  GUIDED_QT_NONE,
-  GUIDED_QT_SPLIT,
-  GUIDED_QT_HORZ,
-  GUIDED_QT_VERT,
-  GUIDED_QT_TYPES,
-  GUIDED_QT_INVALID = -1
-} GuidedQuadTreePartitionType;
-
 // Get unit width and height based on max size and partition type.
-static void get_unit_size(int quadtree_max_size,
+static void get_unit_size(int max_unit_width, int max_unit_height,
                           GuidedQuadTreePartitionType partition_type,
                           int *unit_width, int *unit_height) {
   assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES);
-  const int full_size = quadtree_max_size;
-  const int half_size = quadtree_max_size >> 1;
   *unit_width =
       (partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ)
-          ? full_size
-          : half_size;
+          ? max_unit_width
+          : max_unit_width >> 1;
   *unit_height =
       (partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT)
-          ? full_size
-          : half_size;
+          ? max_unit_height
+          : max_unit_height >> 1;
 }
 
 // ------------------- Guided Quadtree: Encoder ------------------------------//
@@ -968,17 +957,24 @@
     const std::vector<std::vector<std::vector<double>>> &interm,
     GuidedQuadTreePartitionType partition_type, const uint16_t *src,
     int src_stride, const uint16_t *dgd, int dgd_stride, int start_row,
-    int end_row, int start_col, int end_col, int quadtree_max_size,
-    const int *quadtset, int rdmult, const std::pair<int, int> &prev_A,
-    const int *splitcosts, const int *norestorecosts, int bit_depth,
-    bool is_partial_unit, double *this_rdcost,
-    std::vector<std::vector<uint16_t>> &out,
+    int end_row, int start_col, int end_col, int max_unit_width,
+    int max_unit_height, const int *quadtset, int rdmult,
+    const std::pair<int, int> &prev_A, const int *quad_split_costs,
+    const int *binary_split_costs, const int *norestorecosts, int bit_depth,
+    bool is_horz_partitioning_allowed, int is_vert_partitioning_allowed,
+    double *this_rdcost, std::vector<std::vector<uint16_t>> &out,
     std::vector<std::pair<int, int>> &A) {
-  assert(IMPLIES(is_partial_unit, partition_type == GUIDED_QT_NONE));
+  assert(IMPLIES(
+      !is_horz_partitioning_allowed,
+      partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT));
+  assert(IMPLIES(
+      !is_vert_partitioning_allowed,
+      partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ));
   // Get unit width and height based on partition type.
   int unit_width;
   int unit_height;
-  get_unit_size(quadtree_max_size, partition_type, &unit_width, &unit_height);
+  get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width,
+                &unit_height);
 
   // Compute restored unit, a0 and a1.
   generate_linear_combination(interm, src, src_stride, dgd, dgd_stride,
@@ -995,11 +991,16 @@
       compute_sse(out, src, src_stride, start_row, end_row, start_col, end_col);
 
   // Compute Rate.
-  const int num_bits_for_a = compute_rate(A, prev_A, quadtset, norestorecosts);
-  // Partition is implied to be NONE in case of partial unit.
+  const int a_signaling_cost =
+      compute_rate(A, prev_A, quadtset, norestorecosts);
+  // Partition signaling cost depending on 1, 2 or 4 possible partition types.
   const int partition_signaling_cost =
-      is_partial_unit ? 0 : splitcosts[partition_type];
-  const int bitrate = num_bits_for_a + partition_signaling_cost;
+      is_horz_partitioning_allowed && is_vert_partitioning_allowed
+          ? quad_split_costs[partition_type]
+      : (is_horz_partitioning_allowed || is_vert_partitioning_allowed)
+          ? binary_split_costs[partition_type]
+          : 0;
+  const int bitrate = a_signaling_cost + partition_signaling_cost;
 
   // Compute RDCost.
   *this_rdcost =
@@ -1013,39 +1014,54 @@
 static void select_quadtree_partitioning(
     const std::vector<std::vector<std::vector<double>>> &interm,
     const uint16_t *src, int src_stride, int start_row, int start_col,
-    int width, int height, int quadtree_max_size, const int *quadtset,
-    int rdmult, const std::pair<int, int> &prev_A, const int *splitcosts,
-    const int norestorecosts[2], int bit_depth, const uint16_t *dgd,
-    int dgd_stride, std::vector<int> &split,
+    int width, int height, int quadtree_max_size, int max_unit_width,
+    int max_unit_height, const int *quadtset, int rdmult,
+    const std::pair<int, int> &prev_A, const int *quad_split_costs,
+    const int *binary_split_costs, const int norestorecosts[2], int bit_depth,
+    const uint16_t *dgd, int dgd_stride, std::vector<int> &split,
     std::vector<std::pair<int, int>> &A, double *rdcost) {
-  const int end_row = AOMMIN(start_row + quadtree_max_size, height);
-  const int end_col = AOMMIN(start_col + quadtree_max_size, width);
-  const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
-                               (start_col + quadtree_max_size > width);
+  const int end_row = AOMMIN(start_row + max_unit_height, height);
+  const int end_col = AOMMIN(start_col + max_unit_width, width);
+  // Check for special cases near boundary.
+  const bool is_horz_partitioning_allowed =
+      (max_unit_height >= quadtree_max_size);
+  const bool is_vert_partitioning_allowed =
+      (max_unit_width >= quadtree_max_size);
+  const bool is_split_partitioning_allowed =
+      is_horz_partitioning_allowed && is_vert_partitioning_allowed;
 
   auto best_rdcost = DBL_MAX;
   std::vector<std::pair<int, int>> best_A;
   std::vector<std::vector<uint16_t>> best_out(
-      quadtree_max_size, std::vector<uint16_t>(quadtree_max_size));
+      max_unit_height, std::vector<uint16_t>(max_unit_width));
   GuidedQuadTreePartitionType best_partition_type = GUIDED_QT_INVALID;
 
   for (int type = 0; type < GUIDED_QT_TYPES; ++type) {
     const auto this_partition_type = (GuidedQuadTreePartitionType)type;
-    // Special case: if only partial unit is within boundary, we implicitly
-    // use NONE partitioning and do not try the splitting options.
-    if (is_partial_unit && (this_partition_type != GUIDED_QT_NONE)) {
+    // Check for special cases near boundary.
+    if (!is_horz_partitioning_allowed &&
+        (this_partition_type == GUIDED_QT_HORZ)) {
       continue;
     }
-
+    if (!is_vert_partitioning_allowed &&
+        (this_partition_type == GUIDED_QT_VERT)) {
+      continue;
+    }
+    if (!is_split_partitioning_allowed &&
+        (this_partition_type == GUIDED_QT_SPLIT)) {
+      continue;
+    }
+    // Try this partition type.
     double this_rdcost;
     std::vector<std::pair<int, int>> this_A;
     std::vector<std::vector<uint16_t>> this_out(
-        quadtree_max_size, std::vector<uint16_t>(quadtree_max_size));
-    try_one_partition(interm, this_partition_type, src, src_stride, dgd,
-                      dgd_stride, start_row, end_row, start_col, end_col,
-                      quadtree_max_size, quadtset, rdmult, prev_A, splitcosts,
-                      norestorecosts, bit_depth, is_partial_unit, &this_rdcost,
-                      this_out, this_A);
+        max_unit_height, std::vector<uint16_t>(max_unit_width));
+    try_one_partition(
+        interm, this_partition_type, src, src_stride, dgd, dgd_stride,
+        start_row, end_row, start_col, end_col, max_unit_width, max_unit_height,
+        quadtset, rdmult, prev_A, quad_split_costs, binary_split_costs,
+        norestorecosts, bit_depth, is_horz_partitioning_allowed,
+        is_vert_partitioning_allowed, &this_rdcost, this_out, this_A);
     if (this_rdcost < best_rdcost) {
       best_rdcost = this_rdcost;
       best_A = this_A;
@@ -1063,45 +1079,29 @@
   }
 
   // Save split decision.
-  if (is_partial_unit) {
+  if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
     // Nothing should be added to 'split' array.
     assert(best_partition_type == GUIDED_QT_NONE);
     return;
   }
-  switch (best_partition_type) {
-    case GUIDED_QT_NONE:
-      split.push_back(0);
-      split.push_back(0);
-      break;
-    case GUIDED_QT_SPLIT:
-      split.push_back(0);
-      split.push_back(1);
-      break;
-    case GUIDED_QT_HORZ:
-      split.push_back(1);
-      split.push_back(1);
-      break;
-    case GUIDED_QT_VERT:
-      split.push_back(1);
-      split.push_back(0);
-      break;
-    default: assert(0 && "Wrong partition type"); break;
-  }
+  assert(best_partition_type >= 0 && best_partition_type < GUIDED_QT_TYPES);
+  split.push_back(best_partition_type);
 }
 
 static void apply_quadtree_partitioning(
     const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
     int start_col, int width, int height, int quadtree_max_size,
-    const int *quadtset, int bit_depth, const std::vector<int> &split,
-    size_t &split_index, const std::vector<std::pair<int, int>> &A,
-    size_t &A_index, uint16_t *dgd, int dgd_stride);
+    int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth,
+    const std::vector<int> &split, size_t &split_index,
+    const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd,
+    int dgd_stride);
 
 // Top-level function to apply guided restoration on encoder side.
 static int restore_cnn_quadtree_encode_img_tflite_highbd(
     YV12_BUFFER_CONFIG *source_frame, AV1_COMMON *cm, int superres_denom,
-    int rdmult, const int *splitcosts, int (*norestorecosts)[2],
-    int num_threads, int bit_depth, int is_intra_only, int is_luma,
-    int cnn_index, QUADInfo *quad_info, double *rdcost) {
+    int rdmult, const int *quad_split_costs, const int *binary_split_costs,
+    int (*norestorecosts)[2], int num_threads, int bit_depth, int is_intra_only,
+    int is_luma, int cnn_index, QUADInfo *quad_info, double *rdcost) {
   YV12_BUFFER_CONFIG *dgd_buf = &cm->cur_frame->buf;
   uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd_buf->y_buffer);
   const int dgd_stride = dgd_buf->y_stride;
@@ -1136,7 +1136,8 @@
   std::vector<int> best_split;              // selected partitioning options.
   std::vector<std::pair<int, int>> best_A;  // selected a0, a1 weight pairs.
   double best_rdcost_total = DBL_MAX;
-  for (int this_unit_index = 0; this_unit_index <= 1; ++this_unit_index) {
+  for (int this_unit_index = 0; this_unit_index < GUIDED_QT_UNIT_SIZES;
+       ++this_unit_index) {
     const int quadtree_max_size =
         quad_tree_get_unit_size(width, height, this_unit_index);
     // For each quadtree unit, compute the best partitioning out of
@@ -1147,20 +1148,27 @@
     // Previous a0, a1 pair is mid-point of the range by default.
     std::pair<int, int> prev_A =
         std::make_pair(GUIDED_A_MID + A0_min, GUIDED_A_MID + A1_min);
-    // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
-    // If so, need to modify / replace quad_tree_get_unit_info_length().
-    // Also double check: quad_tree_get_split_info_length().
-    for (int row = 0; row < height; row += quadtree_max_size) {
-      for (int col = 0; col < width; col += quadtree_max_size) {
+    const int ext_size = quadtree_max_size * 3 / 2;
+    for (int row = 0; row < height;) {
+      const int remaining_height = height - row;
+      const int this_unit_height =
+          (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+      for (int col = 0; col < width;) {
+        const int remaining_width = width - col;
+        const int this_unit_width =
+            (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
         double this_rdcost;
         select_quadtree_partitioning(
             interm, src, src_stride, row, col, width, height, quadtree_max_size,
-            quadtset, rdmult, prev_A, splitcosts, this_norestorecosts,
+            this_unit_width, this_unit_height, quadtset, rdmult, prev_A,
+            quad_split_costs, binary_split_costs, this_norestorecosts,
             bit_depth, dgd, dgd_stride, this_split, this_A, &this_rdcost);
         // updates.
         this_rdcost_total += this_rdcost;
         prev_A = this_A.back();
+        col += this_unit_width;
       }
+      row += this_unit_height;
     }
     // Update best options.
     if (this_rdcost_total < best_rdcost_total) {
@@ -1188,22 +1196,33 @@
   // Apply guided restoration to 'dgd' using best options above.
   size_t split_index = 0;
   size_t A_index = 0;
-  for (int row = 0; row < height; row += quad_info->unit_size) {
-    for (int col = 0; col < width; col += quad_info->unit_size) {
+  const int quadtree_max_size = quad_info->unit_size;
+  const int ext_size = quadtree_max_size * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
       apply_quadtree_partitioning(
-          interm, row, col, width, height, quad_info->unit_size, quadtset,
-          bit_depth, best_split, split_index, best_A, A_index, dgd, dgd_stride);
+          interm, row, col, width, height, quadtree_max_size, this_unit_width,
+          this_unit_height, quadtset, bit_depth, best_split, split_index,
+          best_A, A_index, dgd, dgd_stride);
+      col += this_unit_width;
     }
+    row += this_unit_height;
   }
 
   return 1;
 }
 
 extern "C" int av1_restore_cnn_quadtree_encode_tflite(
-    AV1_COMMON *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
-    int *splitcosts, int (*norestorecosts)[2], int num_threads,
-    const int apply_cnn[MAX_MB_PLANE], const int cnn_indices[MAX_MB_PLANE],
-    QUADInfo *quad_info, double *rdcost) {
+    struct AV1Common *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
+    int *quad_split_costs, int *binary_split_costs, int (*norestorecosts)[2],
+    int num_threads, const int apply_cnn[MAX_MB_PLANE],
+    const int cnn_indices[MAX_MB_PLANE], QUADInfo *quad_info, double *rdcost) {
   YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
   const int is_intra_only = frame_is_intra_only(cm);
   for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
@@ -1217,8 +1236,9 @@
       case AOM_PLANE_Y:
         ret = restore_cnn_quadtree_encode_img_tflite_highbd(
             source_frame, cm, cm->superres_scale_denominator, RDMULT,
-            splitcosts, norestorecosts, num_threads, cm->seq_params.bit_depth,
-            is_intra_only, is_luma, cnn_index, quad_info, rdcost);
+            quad_split_costs, binary_split_costs, norestorecosts, num_threads,
+            cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index,
+            quad_info, rdcost);
         if (ret == 0) return ret;
         break;
       case AOM_PLANE_U:
@@ -1290,42 +1310,30 @@
 static void apply_quadtree_partitioning(
     const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
     int start_col, int width, int height, int quadtree_max_size,
-    const int *quadtset, int bit_depth, const std::vector<int> &split,
-    size_t &split_index, const std::vector<std::pair<int, int>> &A,
-    size_t &A_index, uint16_t *dgd, int dgd_stride) {
-  const int end_row = AOMMIN(start_row + quadtree_max_size, height);
-  const int end_col = AOMMIN(start_col + quadtree_max_size, width);
-  const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
-                               (start_col + quadtree_max_size > width);
+    int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth,
+    const std::vector<int> &split, size_t &split_index,
+    const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd,
+    int dgd_stride) {
+  const int end_row = AOMMIN(start_row + max_unit_height, height);
+  const int end_col = AOMMIN(start_col + max_unit_width, width);
+  // Check for special cases near boundary.
+  const bool is_horz_partitioning_allowed =
+      (max_unit_height >= quadtree_max_size);
+  const bool is_vert_partitioning_allowed =
+      (max_unit_width >= quadtree_max_size);
 
   // Get partition type.
   GuidedQuadTreePartitionType partition_type = GUIDED_QT_NONE;
-  if (!is_partial_unit) {
-    const int spl1 = split[split_index++];
-    const int spl2 = split[split_index++];
-    if (spl1 == 0) {
-      if (spl2 == 0) {
-        partition_type = GUIDED_QT_NONE;  // (0, 0)
-      } else {
-        assert(spl2 == 1);
-        partition_type = GUIDED_QT_SPLIT;  // (0, 1)
-      }
-    } else {
-      assert(spl1 == 1);
-      if (spl2 == 1) {
-        partition_type = GUIDED_QT_HORZ;  // (1, 1)
-      } else {
-        assert(spl2 == 0);
-        partition_type = GUIDED_QT_VERT;  // (1, 0)
-      }
-    }
+  if (is_horz_partitioning_allowed || is_vert_partitioning_allowed) {
+    partition_type = (GuidedQuadTreePartitionType)split[split_index++];
   }
   assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES);
 
   // Get unit width and height based on partition type.
   int unit_width;
   int unit_height;
-  get_unit_size(quadtree_max_size, partition_type, &unit_width, &unit_height);
+  get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width,
+                &unit_height);
 
   // Compute restored unit, a0 and a1 with given A parameters.
   apply_linear_combination(interm, start_row, end_row, start_col, end_col,
@@ -1377,12 +1385,22 @@
   // For each quadtree unit, apply given quadtree partitioning.
   size_t split_index = 0;
   size_t A_index = 0;
-  for (int row = 0; row < height; row += quadtree_max_size) {
-    for (int col = 0; col < width; col += quadtree_max_size) {
+  const int ext_size = quadtree_max_size * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
       apply_quadtree_partitioning(interm, row, col, width, height,
-                                  quadtree_max_size, quadtset, bit_depth, split,
+                                  quadtree_max_size, this_unit_width,
+                                  this_unit_height, quadtset, bit_depth, split,
                                   split_index, A, A_index, dgd, dgd_stride);
+      col += this_unit_width;
     }
+    row += this_unit_height;
   }
   assert(split_index == split.size());
   assert(A_index == A.size());
diff --git a/av1/common/cnn_tflite.h b/av1/common/cnn_tflite.h
index f74165c..3ed3682 100644
--- a/av1/common/cnn_tflite.h
+++ b/av1/common/cnn_tflite.h
@@ -114,9 +114,9 @@
 // Apply Guided CNN restoration on encoder side.
 int av1_restore_cnn_quadtree_encode_tflite(
     struct AV1Common *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
-    int *splitcosts, int (*norestorecosts)[2], int num_threads,
-    const int apply_cnn[MAX_MB_PLANE], const int cnn_indices[MAX_MB_PLANE],
-    QUADInfo *quad_info, double *rdcost);
+    int *quad_split_costs, int *binary_split_costs, int (*norestorecosts)[2],
+    int num_threads, const int apply_cnn[MAX_MB_PLANE],
+    const int cnn_indices[MAX_MB_PLANE], QUADInfo *quad_info, double *rdcost);
 
 // Apply Guided CNN restoration on decoder side.
 int av1_restore_cnn_quadtree_decode_tflite(struct AV1Common *cm,
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 89a47ff..1e63d77 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -241,6 +241,7 @@
 #endif  // CONFIG_RST_MERGECOEFFS
 #if CONFIG_CNN_GUIDED_QUADTREE
   RESET_CDF_COUNTER(fc->cnn_guided_quad_cdf, 4);
+  RESET_CDF_COUNTER(fc->cnn_guided_binary_cdf, 2);
   RESET_CDF_COUNTER(fc->cnn_guided_norestore_cdf, 2);
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 #if CONFIG_AIMC
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 0d15acb..def65a8 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1560,6 +1560,9 @@
 static const aom_cdf_prob default_cnn_guided_quad_cdf[CDF_SIZE(4)] = {
   AOM_CDF4(23552, 24576, 28672),
 };
+static const aom_cdf_prob default_cnn_guided_binary_cdf[CDF_SIZE(2)] = {
+  AOM_CDF2(23552),
+};
 static const aom_cdf_prob
     default_cnn_guided_norestore_cdf[GUIDED_NORESTORE_CONTEXTS][CDF_SIZE(2)] = {
       { AOM_CDF2(16384) }, { AOM_CDF2(24576) }
@@ -1967,6 +1970,7 @@
 #endif  // CONFIG_PC_WIENER
 #if CONFIG_CNN_GUIDED_QUADTREE
   av1_copy(fc->cnn_guided_quad_cdf, default_cnn_guided_quad_cdf);
+  av1_copy(fc->cnn_guided_binary_cdf, default_cnn_guided_binary_cdf);
   av1_copy(fc->cnn_guided_norestore_cdf, default_cnn_guided_norestore_cdf);
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 #if CONFIG_AIMC
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 3b4e443..e277f52 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -240,7 +240,12 @@
   aom_cdf_prob merged_param_cdf[CDF_SIZE(2)];
 #endif  // CONFIG_RST_MERGECOEFFS
 #if CONFIG_CNN_GUIDED_QUADTREE
+  // CDF to signal one of GUIDED_QT_TYPES partition types for each unit, when
+  // all types are allowed.
   aom_cdf_prob cnn_guided_quad_cdf[CDF_SIZE(4)];
+  // CDF to signal one of 2 partition types for each unit: GUIDED_QT_NONE or
+  // GUIDED_QT_HORZ/VERT, when only 2 types are allowed.
+  aom_cdf_prob cnn_guided_binary_cdf[CDF_SIZE(2)];
   aom_cdf_prob cnn_guided_norestore_cdf[GUIDED_NORESTORE_CONTEXTS][CDF_SIZE(2)];
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 #if !CONFIG_AIMC
diff --git a/av1/common/guided_quadtree.c b/av1/common/guided_quadtree.c
index e296c0b..d9c4400 100644
--- a/av1/common/guided_quadtree.c
+++ b/av1/common/guided_quadtree.c
@@ -275,52 +275,63 @@
   dst->signaled = src->signaled;
 }
 
-// Returns (int)floor(x / y),
-#define DIVIDE_WITH_FLOOR(x, y) ((x) / (y))
-// Returns (int)ceil(x / y),
-#define DIVIDE_WITH_CEILING(x, y) (((x) + (y)-1) / (y))
-
-int quad_tree_get_unit_info_length(int width, int height, int unit_length,
-                                   const QUADSplitInfo *split_info,
-                                   int split_info_length) {
-  // We can compute total units as follows:
-  // (1) regular units: they may / may not be split. So, compute length of
-  // regular unit info by going through the split_info array. (2) unregular
-  // units (blocks near boundary that are NOT unit_length in size): they are
-  // never split. So, length of unregular unit info is same as number of
-  // unregular units.
-  const int regular_units = DIVIDE_WITH_FLOOR(width, unit_length) *
-                            DIVIDE_WITH_FLOOR(height, unit_length);
-  assert(regular_units * 2 == split_info_length);
-  const int total_units = DIVIDE_WITH_CEILING(width, unit_length) *
-                          DIVIDE_WITH_CEILING(height, unit_length);
-  const int unregular_unit_info_len = total_units - regular_units;
-
-  int regular_unit_info_len = 0;
-  for (int i = 0; i < split_info_length; i += 2) {
-    if (split_info == NULL ||
-        (split_info[i].split == 0 && split_info[i + 1].split == 1)) {
-      regular_unit_info_len += 4;  // Split
-    } else if (split_info[i].split == 1 && split_info[i + 1].split == 1) {
-      regular_unit_info_len += 2;  // Horz
-    } else if (split_info[i].split == 1 && split_info[i + 1].split == 0) {
-      regular_unit_info_len += 2;  // Vert
-    } else {
-      assert(split_info[i].split == 0 && split_info[i + 1].split == 0);
-      regular_unit_info_len += 1;  // No split
+int quad_tree_get_max_unit_info_length(int width, int height, int unit_length) {
+  int unit_info_length = 0;
+  const int ext_size = unit_length * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : unit_length;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : unit_length;
+      // Check for special cases near boundary.
+      const bool is_horz_partitioning_allowed =
+          (this_unit_height >= unit_length);
+      const bool is_vert_partitioning_allowed =
+          (this_unit_width >= unit_length);
+      if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+        // Implicitly no split, so single unit info will be signaled.
+        ++unit_info_length;
+      } else {
+        // Assume maximum possible sub-units.
+        const int max_sub_units =
+            is_horz_partitioning_allowed && is_vert_partitioning_allowed ? 4
+                                                                         : 2;
+        unit_info_length += max_sub_units;
+      }
+      col += this_unit_width;
     }
+    row += this_unit_height;
   }
-
-  return regular_unit_info_len + unregular_unit_info_len;
+  return unit_info_length;
 }
 
 int quad_tree_get_split_info_length(int width, int height, int unit_length) {
-  // Split info only signaled for units of full size. Blocks near boundaries are
-  // never split, so no info is signaled for those.
-  const int num_split_info_wide = DIVIDE_WITH_FLOOR(width, unit_length);
-  const int num_split_info_high = DIVIDE_WITH_FLOOR(height, unit_length);
-  // 2 bits signaled for each split info.
-  return num_split_info_wide * num_split_info_high * 2;
+  int split_info_len = 0;
+  const int ext_size = unit_length * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : unit_length;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : unit_length;
+      // Check for special cases near boundary.
+      const bool is_horz_partitioning_allowed =
+          (this_unit_height >= unit_length);
+      const bool is_vert_partitioning_allowed =
+          (this_unit_width >= unit_length);
+      if (is_horz_partitioning_allowed || is_vert_partitioning_allowed) {
+        ++split_info_len;
+      }
+      col += this_unit_width;
+    }
+    row += this_unit_height;
+  }
+  return split_info_len;
 }
 
 void av1_alloc_quadtree_struct(struct AV1Common *cm, QUADInfo *quad_info) {
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 4fef1e6..f51d9fe 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -30,17 +30,27 @@
 #define GUIDED_A_MID (GUIDED_A_NUM_VALUES >> 1)
 #define GUIDED_A_RANGE (GUIDED_A_NUM_VALUES - 1)
 #define GUIDED_A_PAIR_BITS (GUIDED_A_BITS * 2 - 1)
+#define GUIDED_QT_UNIT_SIZES_LOG2 2
+#define GUIDED_QT_UNIT_SIZES (1 << GUIDED_QT_UNIT_SIZES_LOG2)
+
+typedef enum {
+  GUIDED_QT_NONE,
+  GUIDED_QT_SPLIT,
+  GUIDED_QT_HORZ,
+  GUIDED_QT_VERT,
+  GUIDED_QT_TYPES,
+  GUIDED_QT_INVALID = -1
+} GuidedQuadTreePartitionType;
 
 int *get_quadparm_from_qindex(int qindex, int superres_denom, int is_intra_only,
                               int is_luma, int cnn_index);
 
 void quad_copy(const QUADInfo *src, QUADInfo *dst, struct AV1Common *cm);
-// Get the length of unit info array based on dimensions and split info.
-// If split_info == NULL, assumes each block uses split, thereby returning
-// longest possible unit info length.
-int quad_tree_get_unit_info_length(int width, int height, int unit_length,
-                                   const QUADSplitInfo *split_info,
-                                   int split_info_length);
+
+// Get the maximum possible length of unit info array based on dimensions,
+// assuming each block uses split.
+int quad_tree_get_max_unit_info_length(int width, int height, int unit_length);
+
 // Get the length of split info array based on dimensions.
 int quad_tree_get_split_info_length(int width, int height, int unit_length);
 
@@ -56,10 +66,12 @@
 // Get quad tree unit size.
 static INLINE int quad_tree_get_unit_size(int width, int height,
                                           int unit_index) {
-  const bool is_720p_or_smaller = (width * height <= 1280 * 720);
-  const int min_unit_size = is_720p_or_smaller ? 256 : 512;
-  assert(unit_index >= 0 && unit_index <= 1);
-  return min_unit_size << unit_index;
+  const int max_dim = AOMMAX(width, height);
+  const int max_dim_pow_2_bits = 1 + get_msb(max_dim);
+  const int max_dim_pow_2 = 1 << max_dim_pow_2_bits;
+  const int max_unit_size = AOMMAX(AOMMIN(max_dim_pow_2, 2048), 256);
+  assert(unit_index >= 0 && unit_index < GUIDED_QT_UNIT_SIZES);
+  return max_unit_size >> unit_index;
 }
 
 // Allocates buffers in 'quad_info' assuming 'quad_info->unit_index',
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a71b6a6..a736131 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -137,7 +137,8 @@
 
 #if CONFIG_CNN_GUIDED_QUADTREE
 static AOM_INLINE void read_filter_quadtree(FRAME_CONTEXT *ctx, int QP,
-                                            int cnn_index, int superres_denom,
+                                            int cnn_index, int width,
+                                            int height, int superres_denom,
                                             int is_intra_only, QUADInfo *qi,
                                             aom_reader *rb);
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
@@ -1822,19 +1823,81 @@
   }
 }
 #if CONFIG_CNN_GUIDED_QUADTREE
+static void read_quadtree_split_info(FRAME_CONTEXT *ctx, int width, int height,
+                                     QUADInfo *qi, aom_reader *rb) {
+  int unit_info_length = 0;
+  int split_info_index = 0;
+  const int ext_size = qi->unit_size * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : qi->unit_size;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : qi->unit_size;
+      // Check for special cases near boundary.
+      const bool is_horz_partitioning_allowed =
+          (this_unit_height >= qi->unit_size);
+      const bool is_vert_partitioning_allowed =
+          (this_unit_width >= qi->unit_size);
+      if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+        // Implicitly no split, and single unit info will be signaled.
+        ++unit_info_length;
+      } else {
+        assert(split_info_index < qi->split_info_length);
+        // Read split info, depending on how many partition types are allowed.
+        const bool all_partition_types_allowed =
+            (is_horz_partitioning_allowed && is_vert_partitioning_allowed);
+        if (all_partition_types_allowed) {
+          qi->split_info[split_info_index].split =
+              aom_read_symbol(rb, ctx->cnn_guided_quad_cdf, 4, ACCT_STR);
+        } else {
+          const int do_split =
+              aom_read_symbol(rb, ctx->cnn_guided_binary_cdf, 2, ACCT_STR);
+          if (do_split) {
+            qi->split_info[split_info_index].split =
+                is_horz_partitioning_allowed ? GUIDED_QT_HORZ : GUIDED_QT_VERT;
+          } else {
+            qi->split_info[split_info_index].split = GUIDED_QT_NONE;
+          }
+        }
+        // Look at the split info to determine number of (sub)units.
+        const GuidedQuadTreePartitionType partition_type =
+            qi->split_info[split_info_index].split;
+        switch (partition_type) {
+          case GUIDED_QT_NONE: unit_info_length += 1; break;
+          case GUIDED_QT_HORZ:
+          case GUIDED_QT_VERT: unit_info_length += 2; break;
+          case GUIDED_QT_SPLIT: unit_info_length += 4; break;
+          default: assert(0 && "Wrong guided quadtree split type."); break;
+        }
+        ++split_info_index;
+      }
+      col += this_unit_width;
+    }
+    row += this_unit_height;
+  }
+  assert(qi->split_info_length == split_info_index);
+  qi->unit_info_length = unit_info_length;
+}
+
 static AOM_INLINE void read_filter_quadtree(FRAME_CONTEXT *ctx, int QP,
-                                            int cnn_index, int superres_denom,
+                                            int cnn_index, int width,
+                                            int height, int superres_denom,
                                             int is_intra_only, QUADInfo *qi,
                                             aom_reader *rb) {
-  int A0_min, A1_min;
-  int *quadtset;
-  quadtset =
+  // Read partitioning info.
+  read_quadtree_split_info(ctx, width, height, qi, rb);
+
+  // Read weight parameters 'a'.
+  const int *quadtset =
       get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
   const int norestore_ctx =
       get_guided_norestore_ctx(QP, superres_denom, is_intra_only);
 
-  A0_min = quadtset[2];
-  A1_min = quadtset[3];
+  const int A0_min = quadtset[2];
+  const int A1_min = quadtset[3];
 
   int ref_0 = GUIDED_A_MID;
   int ref_1 = GUIDED_A_MID;
@@ -1861,8 +1924,6 @@
       ref_0 = qi->unit_info[i].xqd[0] - A0_min;
       ref_1 = qi->unit_info[i].xqd[1] - A1_min;
     }
-    // printf("a0:%d a1:%d\n", qi->unit_info[i].xqd[0],
-    // qi->unit_info[i].xqd[1]);
   }
 }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
@@ -1922,18 +1983,10 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
     if (cm->use_cnn[0] && !cm->cnn_quad_info.signaled) {
       QUADInfo *qi = (QUADInfo *)&cm->cnn_quad_info;
-      for (int s = 0; s < qi->split_info_length; s += 2) {
-        const int split_index = aom_read_symbol(
-            reader, xd->tile_ctx->cnn_guided_quad_cdf, 4, ACCT_STR);
-        qi->split_info[s].split = split_index >> 1;
-        qi->split_info[s + 1].split = split_index & 1;
-      }
-      qi->unit_info_length = quad_tree_get_unit_info_length(
+      read_filter_quadtree(
+          xd->tile_ctx, cm->quant_params.base_qindex, cm->cnn_indices[0],
           cm->superres_upscaled_width, cm->superres_upscaled_height,
-          qi->unit_size, qi->split_info, qi->split_info_length);
-      read_filter_quadtree(xd->tile_ctx, cm->quant_params.base_qindex,
-                           cm->cnn_indices[0], cm->superres_scale_denominator,
-                           frame_is_intra_only(cm), qi, reader);
+          cm->superres_scale_denominator, frame_is_intra_only(cm), qi, reader);
       cm->cnn_quad_info.signaled = 1;
     }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
@@ -2154,7 +2207,7 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
 // Read quad tree unit index.
 static INLINE int quad_tree_read_unit_index(struct aom_read_bit_buffer *rb) {
-  const int unit_index = aom_rb_read_bit(rb);
+  const int unit_index = aom_rb_read_literal(rb, GUIDED_QT_UNIT_SIZES_LOG2);
   return unit_index;
 }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
@@ -2191,9 +2244,9 @@
         qi->unit_size);
     // We allocate unit info assuming maximum number of possible units for now.
     // Actual length will be set later after actually reading split info.
-    qi->unit_info_length = quad_tree_get_unit_info_length(
+    qi->unit_info_length = quad_tree_get_max_unit_info_length(
         cm->superres_upscaled_width, cm->superres_upscaled_height,
-        qi->unit_size, NULL, qi->split_info_length);
+        qi->unit_size);
     av1_alloc_quadtree_struct(cm, qi);
   }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 0c0c827..88d4b7e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -85,8 +85,9 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
 // Write crlc coeffs for one frame
 static void write_filter_quadtree(FRAME_CONTEXT *ctx, int qp, int cnn_index,
-                                  int superres_denom, int is_intra_only,
-                                  const QUADInfo *ci, aom_writer *wb);
+                                  int width, int height, int superres_denom,
+                                  int is_intra_only, const QUADInfo *qi,
+                                  aom_writer *wb);
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 
 #if CONFIG_IBC_SR_EXT
@@ -2334,22 +2335,10 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
   if (cm->use_cnn[0] && !cm->cnn_quad_info.signaled) {
     QUADInfo *qi = (QUADInfo *)&cm->cnn_quad_info;
-    assert(qi->split_info_length ==
-           quad_tree_get_split_info_length(cm->superres_upscaled_width,
-                                           cm->superres_upscaled_height,
-                                           qi->unit_size));
-    for (int s = 0; s < qi->split_info_length; s += 2) {
-      const int split_index =
-          qi->split_info[s].split * 2 + qi->split_info[s + 1].split;
-      aom_write_symbol(w, split_index, xd->tile_ctx->cnn_guided_quad_cdf, 4);
-    }
-    assert(qi->unit_info_length ==
-           quad_tree_get_unit_info_length(
-               cm->superres_upscaled_width, cm->superres_upscaled_height,
-               qi->unit_size, qi->split_info, qi->split_info_length));
-    write_filter_quadtree(xd->tile_ctx, cm->quant_params.base_qindex,
-                          cm->cnn_indices[0], cm->superres_scale_denominator,
-                          frame_is_intra_only(cm), qi, w);
+    write_filter_quadtree(
+        xd->tile_ctx, cm->quant_params.base_qindex, cm->cnn_indices[0],
+        cm->superres_upscaled_width, cm->superres_upscaled_height,
+        cm->superres_scale_denominator, frame_is_intra_only(cm), qi, w);
     qi->signaled = 1;
   }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
@@ -3062,9 +3051,75 @@
 #endif  // !CONFIG_NEW_DF
 
 #if CONFIG_CNN_GUIDED_QUADTREE
+static void write_quadtree_split_info(FRAME_CONTEXT *ctx, int width, int height,
+                                      const QUADInfo *qi, aom_writer *wb) {
+#ifndef NDEBUG
+  int unit_info_length = 0;
+#endif  // NDEBUG
+  int split_info_index = 0;
+  const int ext_size = qi->unit_size * 3 / 2;
+  for (int row = 0; row < height;) {
+    const int remaining_height = height - row;
+    const int this_unit_height =
+        (remaining_height < ext_size) ? remaining_height : qi->unit_size;
+    for (int col = 0; col < width;) {
+      const int remaining_width = width - col;
+      const int this_unit_width =
+          (remaining_width < ext_size) ? remaining_width : qi->unit_size;
+      // Check for special cases near boundary.
+      const bool is_horz_partitioning_allowed =
+          (this_unit_height >= qi->unit_size);
+      const bool is_vert_partitioning_allowed =
+          (this_unit_width >= qi->unit_size);
+      if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+        // Implicitly no split, and single unit info will be signaled.
+#ifndef NDEBUG
+        ++unit_info_length;
+#endif  // NDEBUG
+      } else {
+        assert(split_info_index < qi->split_info_length);
+        // Write split info, depending on how many partition types are allowed.
+        const bool all_partition_types_allowed =
+            (is_horz_partitioning_allowed && is_vert_partitioning_allowed);
+        const GuidedQuadTreePartitionType partition_type =
+            qi->split_info[split_info_index].split;
+        if (all_partition_types_allowed) {
+          aom_write_symbol(wb, partition_type, ctx->cnn_guided_quad_cdf,
+                           GUIDED_QT_TYPES);
+        } else {
+          aom_write_symbol(wb, partition_type != GUIDED_QT_NONE,
+                           ctx->cnn_guided_binary_cdf, 2);
+        }
+#ifndef NDEBUG
+        // Look at the split info to determine number of (sub)units.
+        switch (partition_type) {
+          case GUIDED_QT_NONE: unit_info_length += 1; break;
+          case GUIDED_QT_HORZ:
+          case GUIDED_QT_VERT: unit_info_length += 2; break;
+          case GUIDED_QT_SPLIT: unit_info_length += 4; break;
+          default: assert(0 && "Wrong guided quadtree split type."); break;
+        }
+#endif  // NDEBUG
+        ++split_info_index;
+      }
+      col += this_unit_width;
+    }
+    row += this_unit_height;
+  }
+  assert(qi->split_info_length == split_info_index);
+#ifndef NDEBUG
+  assert(qi->unit_info_length == unit_info_length);
+#endif  // NDEBUG
+}
+
 static void write_filter_quadtree(FRAME_CONTEXT *ctx, int QP, int cnn_index,
-                                  int superres_denom, int is_intra_only,
-                                  const QUADInfo *ci, aom_writer *wb) {
+                                  int width, int height, int superres_denom,
+                                  int is_intra_only, const QUADInfo *qi,
+                                  aom_writer *wb) {
+  // Write partitioning info.
+  write_quadtree_split_info(ctx, width, height, qi, wb);
+
+  // Write weight parameters 'a'.
   const int *const quadtset =
       get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
   const int norestore_ctx =
@@ -3073,9 +3128,9 @@
   const int A1_min = quadtset[3];
   int ref_0 = GUIDED_A_MID;
   int ref_1 = GUIDED_A_MID;
-  for (int i = 0; i < ci->unit_info_length; i++) {
-    const int a0 = ci->unit_info[i].xqd[0];
-    const int a1 = ci->unit_info[i].xqd[1];
+  for (int i = 0; i < qi->unit_info_length; i++) {
+    const int a0 = qi->unit_info[i].xqd[0];
+    const int a1 = qi->unit_info[i].xqd[1];
     int norestore;
     if (norestore_ctx != -1) {
       norestore = (a0 == 0 && a1 == 0);
@@ -3107,8 +3162,8 @@
 // Write quad tree unit index.
 static INLINE void quad_tree_write_unit_index(struct aom_write_bit_buffer *wb,
                                               const QUADInfo *const qi) {
-  assert(qi->unit_index >= 0 && qi->unit_index <= 1);
-  aom_wb_write_bit(wb, qi->unit_index);
+  assert(qi->unit_index >= 0 && qi->unit_index < GUIDED_QT_UNIT_SIZES);
+  aom_wb_write_literal(wb, qi->unit_index, GUIDED_QT_UNIT_SIZES_LOG2);
 }
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
 
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 02cc6a3..d740110 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -884,6 +884,10 @@
    */
   int cnn_guided_quad_cost[4];
   /*!
+   * cnn_guided binary split cost
+   */
+  int cnn_guided_binary_cost[2];
+  /*!
    * cnn_guided quad norestore cost
    */
   int cnn_guided_norestore_cost[2][2];
diff --git a/av1/encoder/encodeframe_utils.c b/av1/encoder/encodeframe_utils.c
index 8f5780b..20e1191 100644
--- a/av1/encoder/encodeframe_utils.c
+++ b/av1/encoder/encodeframe_utils.c
@@ -1271,6 +1271,8 @@
 #endif  // CONFIG_RST_MERGECEOFFS
 #if CONFIG_CNN_GUIDED_QUADTREE
   AVERAGE_CDF(ctx_left->cnn_guided_quad_cdf, ctx_tr->cnn_guided_quad_cdf, 4);
+  AVERAGE_CDF(ctx_left->cnn_guided_binary_cdf, ctx_tr->cnn_guided_binary_cdf,
+              2);
   AVERAGE_CDF(ctx_left->cnn_guided_norestore_cdf,
               ctx_tr->cnn_guided_norestore_cdf, 2);
 #endif  // CONFIG_CNN_GUIDED_QUADTREE
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 0b7e69a..e47994d 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -2339,6 +2339,7 @@
         if (!av1_restore_cnn_quadtree_encode_tflite(
                 cm, cpi->source, cpi->rd.RDMULT,
                 x->mode_costs.cnn_guided_quad_cost,
+                x->mode_costs.cnn_guided_binary_cost,
                 x->mode_costs.cnn_guided_norestore_cost,
                 cpi->mt_info.num_workers, quadtree_cnn, curr_cnn_indices,
                 &quad_info, &curr_cnn_rdcosts)) {
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index ebb64df..d8c2073 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -455,6 +455,8 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
   av1_cost_tokens_from_cdf(mode_costs->cnn_guided_quad_cost,
                            fc->cnn_guided_quad_cdf, NULL);
+  av1_cost_tokens_from_cdf(mode_costs->cnn_guided_binary_cost,
+                           fc->cnn_guided_binary_cdf, NULL);
   for (int i = 0; i < GUIDED_NORESTORE_CONTEXTS; ++i)
     av1_cost_tokens_from_cdf(mode_costs->cnn_guided_norestore_cost[i],
                              fc->cnn_guided_norestore_cdf[i], NULL);
@@ -502,6 +504,8 @@
 #if CONFIG_CNN_GUIDED_QUADTREE
   av1_cost_tokens_from_cdf(mode_costs->cnn_guided_quad_cost,
                            fc->cnn_guided_quad_cdf, NULL);
+  av1_cost_tokens_from_cdf(mode_costs->cnn_guided_binary_cost,
+                           fc->cnn_guided_binary_cdf, NULL);
   for (int c = 0; c < GUIDED_NORESTORE_CONTEXTS; ++c)
     av1_cost_tokens_from_cdf(mode_costs->cnn_guided_norestore_cost[c],
                              fc->cnn_guided_norestore_cdf[c], NULL);