Guided CNN restoration: use macros for constants
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index e99b8ed..b94792b 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -776,8 +776,8 @@
         // finer search
         double flrA0 = (floor(A0));
         double flrA1 = (floor(A1));
-        flrA0 = AOMMIN(AOMMAX(flrA0, A0_min), A0_min + 15);
-        flrA1 = AOMMIN(AOMMAX(flrA1, A1_min), A1_min + 15);
+        flrA0 = AOMMIN(AOMMAX(flrA0, A0_min), A0_min + GUIDED_A_RANGE);
+        flrA1 = AOMMIN(AOMMAX(flrA1, A1_min), A1_min + GUIDED_A_RANGE);
         {
           A0 = flrA0;
           A1 = flrA1;
@@ -792,17 +792,20 @@
               err += diff * diff;
             }
           }
-          // approx RD cost assuming 7 bits per a0, a1 pair
+          // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
           cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
-              rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
-              bit_depth);
+              rdmult,
+              (norestorecost[0] +
+               (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+                  4,
+              err, bit_depth);
           if (cost < bestcost) {
             bestA0 = A0;
             bestA1 = A1;
             bestcost = cost;
           }
         }
-        if (flrA0 < A0_min + 15) {
+        if (flrA0 < A0_min + GUIDED_A_RANGE) {
           A0 = flrA0 + 1;
           A1 = flrA1;
           err = 0;
@@ -816,17 +819,20 @@
               err += diff * diff;
             }
           }
-          // approx RD cost assuming 7 bits per a0, a1 pair
+          // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
           cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
-              rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
-              bit_depth);
+              rdmult,
+              (norestorecost[0] +
+               (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+                  4,
+              err, bit_depth);
           if (cost < bestcost) {
             bestA0 = A0;
             bestA1 = A1;
             bestcost = cost;
           }
         }
-        if (flrA1 < A1_min + 15) {
+        if (flrA1 < A1_min + GUIDED_A_RANGE) {
           A0 = flrA0;
           A1 = flrA1 + 1;
           err = 0;
@@ -840,17 +846,21 @@
               err += diff * diff;
             }
           }
-          // approx RD cost assuming 7 bits per a0, a1 pair
+          // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
           cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
-              rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
-              bit_depth);
+              rdmult,
+              (norestorecost[0] +
+               (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+                  4,
+              err, bit_depth);
           if (cost < bestcost) {
             bestA0 = A0;
             bestA1 = A1;
             bestcost = cost;
           }
         }
-        if (flrA0 < A0_min + 15 && flrA1 < A1_min + 15) {
+        if (flrA0 < A0_min + GUIDED_A_RANGE &&
+            flrA1 < A1_min + GUIDED_A_RANGE) {
           A0 = flrA0 + 1;
           A1 = flrA1 + 1;
           err = 0;
@@ -864,10 +874,13 @@
               err += diff * diff;
             }
           }
-          // approx RD cost assuming 7 bits per a0, a1 pair
+          // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
           cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
-              rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
-              bit_depth);
+              rdmult,
+              (norestorecost[0] +
+               (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+                  4,
+              err, bit_depth);
           if (cost < bestcost) {
             bestA0 = A0;
             bestA1 = A1;
@@ -879,12 +892,12 @@
       } else {
         A0 = (round(A0));
         A1 = (round(A1));
-        A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + 15);
-        A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + 15);
+        A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE);
+        A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE);
       }
 
-      A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + 15);
-      A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + 15);
+      A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE);
+      A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE);
       A.emplace_back((int)A0, (int)A1);
       for (int i = this_start_row; i < this_end_row; i++) {
         for (int j = this_start_col; j < this_end_col; j++) {
@@ -929,21 +942,21 @@
   const int A0_min = quadtset[2];
   const int A1_min = quadtset[3];
   int num_bits = 0;
-  int ref0 = AOMMIN(AOMMAX(prev_A.first - A0_min, 0), 15);
-  int ref1 = AOMMIN(AOMMAX(prev_A.second - A1_min, 0), 15);
+  int ref0 = AOMMIN(AOMMAX(prev_A.first - A0_min, 0), GUIDED_A_RANGE);
+  int ref1 = AOMMIN(AOMMAX(prev_A.second - A1_min, 0), GUIDED_A_RANGE);
   for (auto &this_A : A) {
     if (this_A.first == 0 && this_A.second == 0) {
       num_bits += norestorecosts[1];
     } else {
       num_bits += norestorecosts[0];
-      num_bits += (aom_count_primitive_refsubexpfin(16, 1, ref0,
-                                                    this_A.first - A0_min) +
-                   aom_count_primitive_refsubexpfin(16, 1, ref1,
-                                                    this_A.second - A1_min))
+      num_bits += (aom_count_primitive_refsubexpfin(
+                       GUIDED_A_NUM_VALUES, 1, ref0, this_A.first - A0_min) +
+                   aom_count_primitive_refsubexpfin(
+                       GUIDED_A_NUM_VALUES, 1, ref1, this_A.second - A1_min))
                   << AV1_PROB_COST_SHIFT;
     }
-    ref0 = AOMMIN(AOMMAX(this_A.first - A0_min, 0), 15);
-    ref1 = AOMMIN(AOMMAX(this_A.second - A1_min, 0), 15);
+    ref0 = AOMMIN(AOMMAX(this_A.first - A0_min, 0), GUIDED_A_RANGE);
+    ref1 = AOMMIN(AOMMAX(this_A.second - A1_min, 0), GUIDED_A_RANGE);
   }
   return num_bits;
 }
@@ -1132,7 +1145,8 @@
     std::vector<std::pair<int, int>> this_A;  // selected a0, a1 weight pairs.
     double this_rdcost_total = 0.0;
     // Previous a0, a1 pair is mid-point of the range by default.
-    std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
+    std::pair<int, int> prev_A =
+        std::make_pair(GUIDED_A_MID + A0_min, GUIDED_A_MID + A1_min);
     // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
     // If so, need to modify / replace quad_tree_get_unit_info_length().
     // Also double check: quad_tree_get_split_info_length().
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 31550eb..4fef1e6 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -24,6 +24,13 @@
 #endif
 
 #if CONFIG_CNN_GUIDED_QUADTREE
+
+#define GUIDED_A_BITS 4
+#define GUIDED_A_NUM_VALUES (1 << GUIDED_A_BITS)
+#define GUIDED_A_MID (GUIDED_A_NUM_VALUES >> 1)
+#define GUIDED_A_RANGE (GUIDED_A_NUM_VALUES - 1)
+#define GUIDED_A_PAIR_BITS (GUIDED_A_BITS * 2 - 1)
+
 int *get_quadparm_from_qindex(int qindex, int superres_denom, int is_intra_only,
                               int is_luma, int cnn_index);
 
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a1088f0..a71b6a6 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -1836,8 +1836,8 @@
   A0_min = quadtset[2];
   A1_min = quadtset[3];
 
-  int ref_0 = 8;
-  int ref_1 = 8;
+  int ref_0 = GUIDED_A_MID;
+  int ref_1 = GUIDED_A_MID;
   for (int i = 0; i < qi->unit_info_length; i++) {
     const int norestore =
         norestore_ctx == -1
@@ -1847,13 +1847,17 @@
     if (norestore) {
       qi->unit_info[i].xqd[0] = 0;
       qi->unit_info[i].xqd[1] = 0;
-      ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + 15, 0)) - A0_min;
-      ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + 15, 0)) - A1_min;
+      ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + GUIDED_A_RANGE, 0)) - A0_min;
+      ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + GUIDED_A_RANGE, 0)) - A1_min;
     } else {
       qi->unit_info[i].xqd[0] =
-          aom_read_primitive_refsubexpfin(rb, 16, 1, ref_0, ACCT_STR) + A0_min;
+          aom_read_primitive_refsubexpfin(rb, GUIDED_A_NUM_VALUES, 1, ref_0,
+                                          ACCT_STR) +
+          A0_min;
       qi->unit_info[i].xqd[1] =
-          aom_read_primitive_refsubexpfin(rb, 16, 1, ref_1, ACCT_STR) + A1_min;
+          aom_read_primitive_refsubexpfin(rb, GUIDED_A_NUM_VALUES, 1, ref_1,
+                                          ACCT_STR) +
+          A1_min;
       ref_0 = qi->unit_info[i].xqd[0] - A0_min;
       ref_1 = qi->unit_info[i].xqd[1] - A1_min;
     }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 88817af..0c0c827 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3065,26 +3065,18 @@
 static void write_filter_quadtree(FRAME_CONTEXT *ctx, int QP, int cnn_index,
                                   int superres_denom, int is_intra_only,
                                   const QUADInfo *ci, aom_writer *wb) {
-  int A0_min, A1_min;
-  int *quadtset;
-  quadtset =
+  const int *const quadtset =
       get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
   const int norestore_ctx =
       get_guided_norestore_ctx(QP, superres_denom, is_intra_only);
-  A0_min = quadtset[2];
-  A1_min = quadtset[3];
-  int a0;
-  int a1;
-  int b_a0;
-  int b_a1;
-  int ref_0 = 8;
-  int ref_1 = 8;
+  const int A0_min = quadtset[2];
+  const int A1_min = quadtset[3];
+  int ref_0 = GUIDED_A_MID;
+  int ref_1 = GUIDED_A_MID;
   for (int i = 0; i < ci->unit_info_length; i++) {
-    a0 = ci->unit_info[i].xqd[0];
-    a1 = ci->unit_info[i].xqd[1];
+    const int a0 = ci->unit_info[i].xqd[0];
+    const int a1 = ci->unit_info[i].xqd[1];
     int norestore;
-
-    // printf("a0:%d  a1:%d\n", a0, a1);
     if (norestore_ctx != -1) {
       norestore = (a0 == 0 && a1 == 0);
       aom_write_symbol(wb, norestore,
@@ -3093,27 +3085,17 @@
       norestore = 0;
     }
     if (norestore) {
-      ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + 15, 0)) - A0_min;
-      ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + 15, 0)) - A1_min;
+      ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + GUIDED_A_RANGE, 0)) - A0_min;
+      ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + GUIDED_A_RANGE, 0)) - A1_min;
     } else {
-      b_a0 = a0 - A0_min;
-      if (b_a0 < 0) {
-        b_a0 = 0;
-      }
-      if (b_a0 > 15) {
-        b_a0 = 15;
-      }
-      b_a1 = a1 - A1_min;
-      if (b_a1 < 0) {
-        b_a1 = 0;
-      }
-      if (b_a1 > 15) {
-        b_a1 = 15;
-      }
-      aom_write_primitive_refsubexpfin(wb, 16, 1, ref_0, b_a0);
-      aom_write_primitive_refsubexpfin(wb, 16, 1, ref_1, b_a1);
-      ref_0 = b_a0;
-      ref_1 = b_a1;
+      const int a0_offset = AOMMIN(AOMMAX(a0 - A0_min, 0), GUIDED_A_RANGE);
+      const int a1_offset = AOMMIN(AOMMAX(a1 - A1_min, 0), GUIDED_A_RANGE);
+      aom_write_primitive_refsubexpfin(wb, GUIDED_A_NUM_VALUES, 1, ref_0,
+                                       a0_offset);
+      aom_write_primitive_refsubexpfin(wb, GUIDED_A_NUM_VALUES, 1, ref_1,
+                                       a1_offset);
+      ref_0 = a0_offset;
+      ref_1 = a1_offset;
     }
   }
 }