[CFL] Compute and Subtract Average on the Fly

Instead of storing the per-transform-block averages of the luma
reconstructed values, each average is computed and immediately
subtracted from the subsampled pixels. This change does not alter the
bitstream and reduces CfL complexity.

Change-Id: Ia5038b336abf1ec01e295b235734318906d3bae6
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index e7cb1f8..02ecba0 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -709,16 +709,6 @@
   // this context
   int uv_height, uv_width;
 
-  // Transform level averages of the luma reconstructed values over the entire
-  // prediction unit
-  // Fixed point y_averages is Q12.3:
-  //   * Worst case division is 1/1024
-  //   * Max error will be 1/16th.
-  // Note: 3 is chosen so that y_averages fits in 15 bits when 12 bit input is
-  // used
-  int y_averages_q3[MAX_NUM_TXB_SQUARE];
-  int y_averages_stride;
-
   int are_parameters_computed;
 
   // Chroma subsampling
diff --git a/av1/common/cfl.c b/av1/common/cfl.c
index f387c7e..6ca7df0 100644
--- a/av1/common/cfl.c
+++ b/av1/common/cfl.c
@@ -57,24 +57,18 @@
 }
 
 // Load from the CfL pixel buffer into output
-static void cfl_load(CFL_CTX *cfl, int row, int col, int width, int height) {
+static void cfl_load(CFL_CTX *cfl, int width, int height) {
   const int sub_x = cfl->subsampling_x;
   const int sub_y = cfl->subsampling_y;
-  const int off_log2 = tx_size_wide_log2[0];
 
   int *output_q3 = cfl->y_down_pix_q3;
 
   // TODO(ltrudeau) should be faster to downsample when we store the values
   // TODO(ltrudeau) add support for 4:2:2
   if (sub_y == 0 && sub_x == 0) {
-    // TODO(ltrudeau) convert to uint16 to add HBD support
-    const uint8_t *y_pix = cfl->y_pix + ((row * MAX_SB_SIZE + col) << off_log2);
-    cfl_luma_subsampling_444(y_pix, output_q3, width, height);
+    cfl_luma_subsampling_444(cfl->y_pix, output_q3, width, height);
   } else if (sub_y == 1 && sub_x == 1) {
-    // TODO(ltrudeau) convert to uint16 to add HBD support
-    const uint8_t *y_pix =
-        cfl->y_pix + ((row * MAX_SB_SIZE + col) << (off_log2 + sub_y));
-    cfl_luma_subsampling_420(y_pix, output_q3, width, height);
+    cfl_luma_subsampling_420(cfl->y_pix, output_q3, width, height);
   } else {
     assert(0);  // Unsupported chroma subsampling
   }
@@ -86,11 +80,8 @@
   // overrun,
   // we apply rows first. This way, when the rows overrun the bottom of the
   // frame, the columns will be copied over them.
-  const int uv_width = (col << off_log2) + width;
-  const int uv_height = (row << off_log2) + height;
-
-  const int diff_width = uv_width - (cfl->y_width >> sub_x);
-  const int diff_height = uv_height - (cfl->y_height >> sub_y);
+  const int diff_width = width - (cfl->y_width >> sub_x);
+  const int diff_height = height - (cfl->y_height >> sub_y);
 
   if (diff_width > 0) {
     int last_pixel;
@@ -103,7 +94,7 @@
       }
       output_row_offset += MAX_SB_SIZE;
     }
-    cfl->y_width = uv_width << sub_x;
+    cfl->y_width = width << sub_x;
   }
 
   if (diff_height > 0) {
@@ -116,7 +107,7 @@
       }
       output_row_offset += MAX_SB_SIZE;
     }
-    cfl->y_height = uv_height << sub_y;
+    cfl->y_height = height << sub_y;
   }
 }
 
@@ -191,45 +182,44 @@
   cfl->dc_pred[CFL_PRED_V] = (sum_v + (num_pel >> 1)) / num_pel;
 }
 
-static void cfl_compute_averages(CFL_CTX *cfl, TX_SIZE tx_size) {
+static void cfl_subtract_averages(CFL_CTX *cfl, TX_SIZE tx_size) {
   const int width = cfl->uv_width;
   const int height = cfl->uv_height;
   const int tx_height = tx_size_high[tx_size];
   const int tx_width = tx_size_wide[tx_size];
-  const int stride = width >> tx_size_wide_log2[tx_size];
   const int block_row_stride = MAX_SB_SIZE << tx_size_high_log2[tx_size];
   const int num_pel_log2 =
       (tx_size_high_log2[tx_size] + tx_size_wide_log2[tx_size]);
 
-  const int *y_pix_q3 = cfl->y_down_pix_q3;
-  const int *t_y_pix_q3;
-  int *averages_q3 = cfl->y_averages_q3;
+  int *y_pix_q3 = cfl->y_down_pix_q3;
 
-  cfl_load(cfl, 0, 0, width, height);
-
-  int a = 0;
+  cfl_load(cfl, width, height);
   for (int b_j = 0; b_j < height; b_j += tx_height) {
     for (int b_i = 0; b_i < width; b_i += tx_width) {
       int sum_q3 = 0;
-      t_y_pix_q3 = y_pix_q3;
+      int *t_y_pix_q3 = y_pix_q3;
       for (int t_j = 0; t_j < tx_height; t_j++) {
         for (int t_i = b_i; t_i < b_i + tx_width; t_i++) {
           sum_q3 += t_y_pix_q3[t_i];
         }
         t_y_pix_q3 += MAX_SB_SIZE;
       }
-      assert(a < MAX_NUM_TXB_SQUARE);
-      averages_q3[a++] = (sum_q3 + (1 << (num_pel_log2 - 1))) >> num_pel_log2;
-
+      int avg_q3 = (sum_q3 + (1 << (num_pel_log2 - 1))) >> num_pel_log2;
       // Loss is never more than 1/2 (in Q3)
-      assert(fabs((double)averages_q3[a - 1] -
-                  (sum_q3 / ((double)(1 << num_pel_log2)))) <= 0.5);
+      assert(fabs((double)avg_q3 - (sum_q3 / ((double)(1 << num_pel_log2)))) <=
+             0.5);
+
+      t_y_pix_q3 = y_pix_q3;
+      for (int t_j = 0; t_j < tx_height; t_j++) {
+        for (int t_i = b_i; t_i < b_i + tx_width; t_i++) {
+          t_y_pix_q3[t_i] -= avg_q3;
+        }
+
+        t_y_pix_q3 += MAX_SB_SIZE;
+      }
     }
-    assert(a % stride == 0);
     y_pix_q3 += block_row_stride;
   }
-
-  cfl->y_averages_stride = stride;
 }
 
 static INLINE int cfl_idx_to_alpha(int alpha_idx, int joint_sign,
@@ -253,28 +243,21 @@
 
   const int width = tx_size_wide[tx_size];
   const int height = tx_size_high[tx_size];
-  const int *y_pix_q3 = cfl->y_down_pix_q3;
+  const int *y_down_pix_q3 =
+      cfl->y_down_pix_q3 + ((row * MAX_SB_SIZE + col) << tx_size_wide_log2[0]);
 
   const int dc_pred = cfl->dc_pred[plane - 1];
   const int alpha_q3 =
       cfl_idx_to_alpha(mbmi->cfl_alpha_idx, mbmi->cfl_alpha_signs, plane - 1);
 
-  const int avg_row =
-      (row << tx_size_wide_log2[0]) >> tx_size_wide_log2[tx_size];
-  const int avg_col =
-      (col << tx_size_high_log2[0]) >> tx_size_high_log2[tx_size];
-  const int avg_q3 =
-      cfl->y_averages_q3[cfl->y_averages_stride * avg_row + avg_col];
-
-  cfl_load(cfl, row, col, width, height);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       // TODO(ltrudeau) add support for HBD.
-      dst[i] = clip_pixel(get_scaled_luma_q0(alpha_q3, y_pix_q3[i], avg_q3) +
-                          dc_pred);
+      dst[i] =
+          clip_pixel(get_scaled_luma_q0(alpha_q3, y_down_pix_q3[i]) + dc_pred);
     }
     dst += dst_stride;
-    y_pix_q3 += MAX_SB_SIZE;
+    y_down_pix_q3 += MAX_SB_SIZE;
   }
 }
 
@@ -425,10 +408,7 @@
   }
 #endif  // CONFIG_DEBUG
 
-  // Compute block-level DC_PRED for both chromatic planes.
-  // DC_PRED replaces beta in the linear model.
   cfl_dc_pred(xd, plane_bsize);
-  // Compute transform-level average on reconstructed luma input.
-  cfl_compute_averages(cfl, tx_size);
+  cfl_subtract_averages(cfl, tx_size);
   cfl->are_parameters_computed = 1;
 }
diff --git a/av1/common/cfl.h b/av1/common/cfl.h
index e6de1b1..26e0606 100644
--- a/av1/common/cfl.h
+++ b/av1/common/cfl.h
@@ -14,8 +14,8 @@
 
 #include "av1/common/blockd.h"
 
-static INLINE int get_scaled_luma_q0(int alpha_q3, int y_pix_q3, int avg_q3) {
-  int scaled_luma_q6 = alpha_q3 * (y_pix_q3 - avg_q3);
+static INLINE int get_scaled_luma_q0(int alpha_q3, int y_down_pix_q3) {
+  int scaled_luma_q6 = alpha_q3 * y_down_pix_q3;
   return ROUND_POWER_OF_TWO_SIGNED(scaled_luma_q6, 6);
 }
 
diff --git a/av1/common/enums.h b/av1/common/enums.h
index aa87406..0ad9e91 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -220,9 +220,6 @@
 #define MAX_TX_BLOCKS_IN_MAX_SB_LOG2 ((MAX_SB_SIZE_LOG2 - MAX_TX_SIZE_LOG2) * 2)
 #define MAX_TX_BLOCKS_IN_MAX_SB (1 << MAX_TX_BLOCKS_IN_MAX_SB_LOG2)
 
-#define MAX_NUM_TXB (1 << (MAX_SB_SIZE_LOG2 - MIN_TX_SIZE_LOG2))
-#define MAX_NUM_TXB_SQUARE (MAX_NUM_TXB * MAX_NUM_TXB)
-
 #if CONFIG_NCOBMC_ADAPT_WEIGHT
 typedef enum ATTRIBUTE_PACKED {
   NCOBMC_MODE_0,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b540607..b97797f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5845,11 +5845,10 @@
 #endif  // CONFIG_EXT_INTRA
 
 #if CONFIG_CFL
-static int64_t cfl_alpha_dist(const int *y_pix_q3, int y_stride,
-                              const int y_averages_q3[MAX_NUM_TXB_SQUARE],
-                              const uint8_t *src, int src_stride, int width,
-                              int height, TX_SIZE tx_size, int dc_pred,
-                              int alpha_q3, int64_t *dist_neg_out) {
+static int64_t cfl_alpha_dist(const int *y_pix_q3, const uint8_t *src,
+                              int src_stride, int width, int height,
+                              int dc_pred, int alpha_q3,
+                              int64_t *dist_neg_out) {
   int64_t dist = 0;
   int diff;
 
@@ -5868,41 +5867,21 @@
   }
 
   int64_t dist_neg = 0;
-  const int tx_height = tx_size_high[tx_size];
-  const int tx_width = tx_size_wide[tx_size];
-  const int y_block_row_off = y_stride * tx_height;
-  const int src_block_row_off = src_stride * tx_height;
-  const int *t_y_pix_q3;
-  const uint8_t *t_src;
-  int a = 0;
-  for (int b_j = 0; b_j < height; b_j += tx_height) {
-    const int h = b_j + tx_height;
-    for (int b_i = 0; b_i < width; b_i += tx_width) {
-      const int w = b_i + tx_width;
-      const int tx_avg_q3 = y_averages_q3[a++];
-      t_y_pix_q3 = y_pix_q3;
-      t_src = src;
-      for (int t_j = b_j; t_j < h; t_j++) {
-        for (int t_i = b_i; t_i < w; t_i++) {
-          const int uv = t_src[t_i];
+  for (int j = 0; j < height; j++) {
+    for (int i = 0; i < width; i++) {
+      const int uv = src[i];
+      const int scaled_luma = get_scaled_luma_q0(alpha_q3, y_pix_q3[i]);
 
-          const int scaled_luma =
-              get_scaled_luma_q0(alpha_q3, t_y_pix_q3[t_i], tx_avg_q3);
+      // TODO(ltrudeau) add support for HBD.
+      diff = uv - clamp(scaled_luma + dc_pred, 0, 255);
+      dist += diff * diff;
 
-          // TODO(ltrudeau) add support for HBD.
-          diff = uv - clamp(scaled_luma + dc_pred, 0, 255);
-          dist += diff * diff;
-
-          // TODO(ltrudeau) add support for HBD.
-          diff = uv - clamp(-scaled_luma + dc_pred, 0, 255);
-          dist_neg += diff * diff;
-        }
-        t_y_pix_q3 += y_stride;
-        t_src += src_stride;
-      }
+      // TODO(ltrudeau) add support for HBD.
+      diff = uv - clamp(-scaled_luma + dc_pred, 0, 255);
+      dist_neg += diff * diff;
     }
-    y_pix_q3 += y_block_row_off;
-    src += src_block_row_off;
+    y_pix_q3 += MAX_SB_SIZE;
+    src += src_stride;
   }
 
   if (dist_neg_out) *dist_neg_out = dist_neg;
@@ -5927,26 +5906,23 @@
   const int height = cfl->uv_height;
   const int dc_pred_u = cfl->dc_pred[CFL_PRED_U];
   const int dc_pred_v = cfl->dc_pred[CFL_PRED_V];
-  const int *y_averages_q3 = cfl->y_averages_q3;
   const int *y_pix_q3 = cfl->y_down_pix_q3;
 
   int64_t sse[CFL_PRED_PLANES][CFL_MAGS_SIZE];
-  sse[CFL_PRED_U][0] =
-      cfl_alpha_dist(y_pix_q3, MAX_SB_SIZE, y_averages_q3, src_u, src_stride_u,
-                     width, height, tx_size, dc_pred_u, 0, NULL);
-  sse[CFL_PRED_V][0] =
-      cfl_alpha_dist(y_pix_q3, MAX_SB_SIZE, y_averages_q3, src_v, src_stride_v,
-                     width, height, tx_size, dc_pred_v, 0, NULL);
+  sse[CFL_PRED_U][0] = cfl_alpha_dist(y_pix_q3, src_u, src_stride_u, width,
+                                      height, dc_pred_u, 0, NULL);
+  sse[CFL_PRED_V][0] = cfl_alpha_dist(y_pix_q3, src_v, src_stride_v, width,
+                                      height, dc_pred_v, 0, NULL);
 
   for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
     const int m = c * 2 + 1;
     const int abs_alpha_q3 = c + 1;
-    sse[CFL_PRED_U][m] = cfl_alpha_dist(
-        y_pix_q3, MAX_SB_SIZE, y_averages_q3, src_u, src_stride_u, width,
-        height, tx_size, dc_pred_u, abs_alpha_q3, &sse[CFL_PRED_U][m + 1]);
-    sse[CFL_PRED_V][m] = cfl_alpha_dist(
-        y_pix_q3, MAX_SB_SIZE, y_averages_q3, src_v, src_stride_v, width,
-        height, tx_size, dc_pred_v, abs_alpha_q3, &sse[CFL_PRED_V][m + 1]);
+    sse[CFL_PRED_U][m] =
+        cfl_alpha_dist(y_pix_q3, src_u, src_stride_u, width, height, dc_pred_u,
+                       abs_alpha_q3, &sse[CFL_PRED_U][m + 1]);
+    sse[CFL_PRED_V][m] =
+        cfl_alpha_dist(y_pix_q3, src_v, src_stride_v, width, height, dc_pred_v,
+                       abs_alpha_q3, &sse[CFL_PRED_V][m + 1]);
   }
 
   int64_t dist;