Refactor av1_apply_temporal_filter()

This CL simplifies av1_apply_temporal_filter() by refactoring some
sections of code so as to reduce the complexity and to avoid redundancy.
The C, SSE2 (low-bd, high-bd) and AVX2 (low-bd) variants have been
modified accordingly.

This CL is bit-exact with speed improvement across all presets.

BUG=aomedia:2761

Change-Id: I6591ade03c98cc1fc69beb968a51049edd798a9f
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 4d954b5..ac8064f 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -465,6 +465,35 @@
   }
 }
 
+// Function to accumulate pixel-wise squared difference between two luma buffers
+// to be consumed while filtering the chroma planes.
+// Inputs:
+//   square_diff: Pointer to squared differences from luma plane.
+//   luma_sse_sum: Pointer to save the sum of luma squared differences.
+//   block_height: Height of block for computation.
+//   block_width: Width of block for computation.
+//   ss_x_shift: Chroma subsampling shift in 'X' direction
+//   ss_y_shift: Chroma subsampling shift in 'Y' direction
+// Returns:
+//   Nothing will be returned. But the content to which `luma_sse_sum` points
+//   will be modified.
+void compute_luma_sq_error_sum(uint32_t *square_diff, uint32_t *luma_sse_sum,
+                               int block_height, int block_width,
+                               int ss_x_shift, int ss_y_shift) {
+  for (int i = 0; i < block_height; ++i) {
+    for (int j = 0; j < block_width; ++j) {
+      for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+        for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+          const int yy = (i << ss_y_shift) + ii;     // Y-coord on Y-plane.
+          const int xx = (j << ss_x_shift) + jj;     // X-coord on Y-plane.
+          const int ww = block_width << ss_x_shift;  // Width of Y-plane.
+          luma_sse_sum[i * block_width + j] += square_diff[yy * ww + xx];
+        }
+      }
+    }
+  }
+}
+
 /*!\endcond */
 /*!\brief Applies temporal filtering. NOTE that there are various optimised
  * versions of this function called where the appropriate instruction set is
@@ -532,38 +561,43 @@
     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
     decay_factor[plane] = 1 / (n_decay * q_decay * s_decay);
   }
-
-  // Allocate memory for pixel-wise squared differences for all planes. They,
-  // regardless of the subsampling, are assigned with memory of size `mb_pels`.
-  uint32_t *square_diff =
-      aom_memalign(16, num_planes * mb_pels * sizeof(uint32_t));
-  memset(square_diff, 0, num_planes * mb_pels * sizeof(square_diff[0]));
-
-  int plane_offset = 0;
-  for (int plane = 0; plane < num_planes; ++plane) {
-    // Locate pixel on reference frame.
-    const int plane_h = mb_height >> mbd->plane[plane].subsampling_y;
-    const int plane_w = mb_width >> mbd->plane[plane].subsampling_x;
-    const int frame_stride = frame_to_filter->strides[plane == 0 ? 0 : 1];
-    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
-    const uint8_t *ref = frame_to_filter->buffers[plane];
-    compute_square_diff(ref, frame_offset, frame_stride, pred, plane_offset,
-                        plane_w, plane_h, plane_w, is_high_bitdepth,
-                        square_diff + plane_offset);
-    plane_offset += mb_pels;
+  double d_factor[4] = { 0 };
+  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+    // Larger motion vector -> smaller filtering weight.
+    const MV mv = subblock_mvs[subblock_idx];
+    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
+    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+    distance_threshold = AOMMAX(distance_threshold, 1);
+    d_factor[subblock_idx] = distance / distance_threshold;
+    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
   }
 
+  // Allocate memory for pixel-wise squared differences. They,
+  // regardless of the subsampling, are assigned with memory of size `mb_pels`.
+  uint32_t *square_diff = aom_memalign(16, mb_pels * sizeof(uint32_t));
+  memset(square_diff, 0, mb_pels * sizeof(square_diff[0]));
+
+  // Allocate memory for accumulated luma squared error. This value will be
+  // consumed while filtering the chroma planes.
+  uint32_t *luma_sse_sum = aom_memalign(32, mb_pels * sizeof(uint32_t));
+  memset(luma_sse_sum, 0, mb_pels * sizeof(luma_sse_sum[0]));
+
   // Get window size for pixel-wise filtering.
   assert(TF_WINDOW_LENGTH % 2 == 1);
   const int half_window = TF_WINDOW_LENGTH >> 1;
 
   // Handle planes in sequence.
-  plane_offset = 0;
+  int plane_offset = 0;
   for (int plane = 0; plane < num_planes; ++plane) {
+    // Locate pixel on reference frame.
     const int subsampling_y = mbd->plane[plane].subsampling_y;
     const int subsampling_x = mbd->plane[plane].subsampling_x;
     const int h = mb_height >> subsampling_y;  // Plane height.
     const int w = mb_width >> subsampling_x;   // Plane width.
+    const int frame_stride =
+        frame_to_filter->strides[plane == AOM_PLANE_Y ? 0 : 1];
+    const int frame_offset = mb_row * h * frame_stride + mb_col * w;
+    const uint8_t *ref = frame_to_filter->buffers[plane];
     const int ss_y_shift =
         subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
     const int ss_x_shift =
@@ -572,6 +606,15 @@
                                ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
     const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
 
+    // Filter U-plane and V-plane using Y-plane. This is because motion
+    // search is only done on Y-plane, so the information from Y-plane will
+    // be more accurate. The luma sse sum is reused in both chroma planes.
+    if (plane == AOM_PLANE_U)
+      compute_luma_sq_error_sum(square_diff, luma_sse_sum, h, w, ss_x_shift,
+                                ss_y_shift);
+    compute_square_diff(ref, frame_offset, frame_stride, pred, plane_offset, w,
+                        h, w, is_high_bitdepth, square_diff);
+
     // Perform filtering.
     int pred_idx = 0;
     for (int i = 0; i < h; ++i) {
@@ -583,23 +626,11 @@
           for (int wj = -half_window; wj <= half_window; ++wj) {
             const int y = CLIP(i + wi, 0, h - 1);  // Y-coord on current plane.
             const int x = CLIP(j + wj, 0, w - 1);  // X-coord on current plane.
-            sum_square_diff += square_diff[plane_offset + y * w + x];
+            sum_square_diff += square_diff[y * w + x];
           }
         }
 
-        // Filter U-plane and V-plane using Y-plane. This is because motion
-        // search is only done on Y-plane, so the information from Y-plane will
-        // be more accurate.
-        if (plane != 0) {
-          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
-            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
-              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
-              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
-              const int ww = w << ss_x_shift;         // Width of Y-plane.
-              sum_square_diff += square_diff[yy * ww + xx];
-            }
-          }
-        }
+        sum_square_diff += luma_sse_sum[i * w + j];
 
         // Scale down the difference for high bit depth input.
         if (mbd->bd > 8) sum_square_diff >>= ((mbd->bd - 8) * 2);
@@ -611,15 +642,9 @@
         const double combined_error =
             weight_factor * window_error + block_error * inv_factor;
 
-        // Larger motion vector -> smaller filtering weight.
-        const MV mv = subblock_mvs[subblock_idx];
-        const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-        const double distance_threshold =
-            (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
-        const double d_factor = AOMMAX(distance / distance_threshold, 1);
-
         // Compute filter weight.
-        double scaled_error = combined_error * d_factor * decay_factor[plane];
+        double scaled_error =
+            combined_error * d_factor[subblock_idx] * decay_factor[plane];
         scaled_error = AOMMIN(scaled_error, 7);
         const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
@@ -635,6 +660,7 @@
   }
 
   aom_free(square_diff);
+  aom_free(luma_sse_sum);
 }
 #if CONFIG_AV1_HIGHBITDEPTH
 // Calls High bit-depth temporal filter
diff --git a/av1/encoder/x86/highbd_temporal_filter_sse2.c b/av1/encoder/x86/highbd_temporal_filter_sse2.c
index cf08b66..c0d214f 100644
--- a/av1/encoder/x86/highbd_temporal_filter_sse2.c
+++ b/av1/encoder/x86/highbd_temporal_filter_sse2.c
@@ -92,18 +92,14 @@
 static void highbd_apply_temporal_filter(
     const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
-    unsigned int *accumulator, uint16_t *count, uint32_t *luma_sq_error,
-    uint32_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
-    int bd, const double inv_num_ref_pixels, const double decay_factor,
-    const double inv_factor, const double weight_factor) {
+    const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
+    uint32_t *frame_sse, uint32_t *luma_sse_sum, int bd,
+    const double inv_num_ref_pixels, const double decay_factor,
+    const double inv_factor, const double weight_factor, double *d_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
-  if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
 
   uint32_t acc_5x5_sse[BH][BW];
-  uint32_t *frame_sse =
-      (plane == PLANE_TYPE_Y) ? luma_sq_error : chroma_sq_error;
 
   get_squared_error(frame1, stride, frame2, stride2, block_width, block_height,
                     frame_sse, SSE_STRIDE);
@@ -186,24 +182,7 @@
   for (int i = 0, k = 0; i < block_height; i++) {
     for (int j = 0; j < block_width; j++, k++) {
       const int pixel_value = frame2[i * stride2 + j];
-
-      int diff_sse = acc_5x5_sse[i][j];
-      int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH;
-
-      // Filter U-plane and V-plane using Y-plane. This is because motion
-      // search is only done on Y-plane, so the information from Y-plane will
-      // be more accurate.
-      if (plane != PLANE_TYPE_Y) {
-        for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
-          for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
-            const int yy = (i << ss_y_shift) + ii;      // Y-coord on Y-plane.
-            const int xx = (j << ss_x_shift) + jj + 2;  // X-coord on Y-plane.
-            const int ww = SSE_STRIDE;                  // Stride of Y-plane.
-            diff_sse += luma_sq_error[yy * ww + xx];
-            ++num_ref_pixels;
-          }
-        }
-      }
+      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
 
       // Scale down the difference for high bit depth input.
       diff_sse >>= ((bd - 8) * 2);
@@ -215,13 +194,8 @@
       const double combined_error =
           weight_factor * window_error + block_error * inv_factor;
 
-      const MV mv = subblock_mvs[subblock_idx];
-      const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-      const double distance_threshold =
-          (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
-      const double d_factor = AOMMAX(distance / distance_threshold, 1);
-
-      double scaled_error = combined_error * d_factor * decay_factor;
+      double scaled_error =
+          combined_error * d_factor[subblock_idx] * decay_factor;
       scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
@@ -261,12 +235,21 @@
   // Smaller strength -> smaller filtering weight.
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
-  uint32_t luma_sq_error[SSE_STRIDE * BH];
-  uint32_t *chroma_sq_error =
-      (num_planes > 0)
-          ? (uint32_t *)aom_malloc(SSE_STRIDE * BH * sizeof(uint32_t))
-          : NULL;
+  double d_factor[4] = { 0 };
+  uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
+  uint32_t luma_sse_sum[BW * BH] = { 0 };
   uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);
+
+  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+    // Larger motion vector -> smaller filtering weight.
+    const MV mv = subblock_mvs[subblock_idx];
+    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
+    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+    distance_threshold = AOMMAX(distance_threshold, 1);
+    d_factor[subblock_idx] = distance / distance_threshold;
+    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
+  }
+
   for (int plane = 0; plane < num_planes; ++plane) {
     const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
     const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
@@ -286,12 +269,28 @@
     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
     const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
+    // Filter U-plane and V-plane using Y-plane. This is because motion
+    // search is only done on Y-plane, so the information from Y-plane
+    // will be more accurate. The luma sse sum is reused in both chroma
+    // planes.
+    if (plane == AOM_PLANE_U) {
+      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
+        for (unsigned int j = 0; j < plane_w; j++, k++) {
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
+            }
+          }
+        }
+      }
+    }
+
     highbd_apply_temporal_filter(
         ref, frame_stride, pred1 + mb_pels * plane, plane_w, plane_w, plane_h,
-        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
-        count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
-        ss_x_shift, ss_y_shift, mbd->bd, inv_num_ref_pixels, decay_factor,
-        inv_factor, weight_factor);
+        subblock_mses, accum + mb_pels * plane, count + mb_pels * plane,
+        frame_sse, luma_sse_sum, mbd->bd, inv_num_ref_pixels, decay_factor,
+        inv_factor, weight_factor, d_factor);
   }
-  if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }
diff --git a/av1/encoder/x86/temporal_filter_avx2.c b/av1/encoder/x86/temporal_filter_avx2.c
index 52a3964..65f7fe4 100644
--- a/av1/encoder/x86/temporal_filter_avx2.c
+++ b/av1/encoder/x86/temporal_filter_avx2.c
@@ -130,18 +130,14 @@
 static void apply_temporal_filter(
     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
-    unsigned int *accumulator, uint16_t *count, uint16_t *luma_sq_error,
-    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
+    const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
+    uint16_t *frame_sse, uint32_t *luma_sse_sum,
     const double inv_num_ref_pixels, const double decay_factor,
-    const double inv_factor, const double weight_factor) {
+    const double inv_factor, const double weight_factor, double *d_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
-  if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
 
   uint32_t acc_5x5_sse[BH][BW];
-  uint16_t *frame_sse =
-      (plane == PLANE_TYPE_Y) ? luma_sq_error : chroma_sq_error;
 
   if (block_width == 32) {
     get_squared_error_32x32_avx2(frame1, stride, frame2, stride2, block_width,
@@ -199,21 +195,7 @@
   for (int i = 0, k = 0; i < block_height; i++) {
     for (int j = 0; j < block_width; j++, k++) {
       const int pixel_value = frame2[i * stride2 + j];
-
-      uint32_t diff_sse = acc_5x5_sse[i][j];
-
-      // Filter U-plane and V-plane using Y-plane. This is because motion
-      // search is only done on Y-plane, so the information from Y-plane will
-      // be more accurate.
-      if (plane != PLANE_TYPE_Y) {
-        for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
-          for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
-            const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
-            const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
-            diff_sse += luma_sq_error[yy * SSE_STRIDE + xx];
-          }
-        }
-      }
+      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
 
       const double window_error = diff_sse * inv_num_ref_pixels;
       const int subblock_idx =
@@ -222,13 +204,8 @@
       const double combined_error =
           weight_factor * window_error + block_error * inv_factor;
 
-      const MV mv = subblock_mvs[subblock_idx];
-      const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-      const double distance_threshold =
-          (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
-      const double d_factor = AOMMAX(distance / distance_threshold, 1);
-
-      double scaled_error = combined_error * d_factor * decay_factor;
+      double scaled_error =
+          combined_error * d_factor[subblock_idx] * decay_factor;
       scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
@@ -269,11 +246,19 @@
   // Smaller strength -> smaller filtering weight.
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
-  uint16_t luma_sq_error[SSE_STRIDE * BH];
-  uint16_t *chroma_sq_error =
-      (num_planes > 0)
-          ? (uint16_t *)aom_malloc(SSE_STRIDE * BH * sizeof(uint16_t))
-          : NULL;
+  double d_factor[4] = { 0 };
+  uint16_t frame_sse[SSE_STRIDE * BH] = { 0 };
+  uint32_t luma_sse_sum[BW * BH] = { 0 };
+
+  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+    // Larger motion vector -> smaller filtering weight.
+    const MV mv = subblock_mvs[subblock_idx];
+    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
+    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+    distance_threshold = AOMMAX(distance_threshold, 1);
+    d_factor[subblock_idx] = distance / distance_threshold;
+    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
+  }
 
   for (int plane = 0; plane < num_planes; ++plane) {
     const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
@@ -293,12 +278,28 @@
     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
     const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
-    apply_temporal_filter(
-        ref, frame_stride, pred + mb_pels * plane, plane_w, plane_w, plane_h,
-        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
-        count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
-        ss_x_shift, ss_y_shift, inv_num_ref_pixels, decay_factor, inv_factor,
-        weight_factor);
+    // Filter U-plane and V-plane using Y-plane. This is because motion
+    // search is only done on Y-plane, so the information from Y-plane
+    // will be more accurate. The luma sse sum is reused in both chroma
+    // planes.
+    if (plane == AOM_PLANE_U) {
+      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
+        for (unsigned int j = 0; j < plane_w; j++, k++) {
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx];
+            }
+          }
+        }
+      }
+    }
+
+    apply_temporal_filter(ref, frame_stride, pred + mb_pels * plane, plane_w,
+                          plane_w, plane_h, subblock_mses,
+                          accum + mb_pels * plane, count + mb_pels * plane,
+                          frame_sse, luma_sse_sum, inv_num_ref_pixels,
+                          decay_factor, inv_factor, weight_factor, d_factor);
   }
-  if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }
diff --git a/av1/encoder/x86/temporal_filter_sse2.c b/av1/encoder/x86/temporal_filter_sse2.c
index a3a0981..b366d0f 100644
--- a/av1/encoder/x86/temporal_filter_sse2.c
+++ b/av1/encoder/x86/temporal_filter_sse2.c
@@ -105,18 +105,14 @@
 static void apply_temporal_filter(
     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
-    unsigned int *accumulator, uint16_t *count, uint16_t *luma_sq_error,
-    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
+    const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
+    uint16_t *frame_sse, uint32_t *luma_sse_sum,
     const double inv_num_ref_pixels, const double decay_factor,
-    const double inv_factor, const double weight_factor) {
+    const double inv_factor, const double weight_factor, double *d_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
-  if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
 
   uint32_t acc_5x5_sse[BH][BW];
-  uint16_t *frame_sse =
-      (plane == PLANE_TYPE_Y) ? luma_sq_error : chroma_sq_error;
 
   get_squared_error(frame1, stride, frame2, stride2, block_width, block_height,
                     frame_sse, SSE_STRIDE);
@@ -176,22 +172,7 @@
   for (int i = 0, k = 0; i < block_height; i++) {
     for (int j = 0; j < block_width; j++, k++) {
       const int pixel_value = frame2[i * stride2 + j];
-
-      uint32_t diff_sse = acc_5x5_sse[i][j];
-
-      // Filter U-plane and V-plane using Y-plane. This is because motion
-      // search is only done on Y-plane, so the information from Y-plane will
-      // be more accurate.
-      if (plane != PLANE_TYPE_Y) {
-        for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
-          for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
-            const int yy = (i << ss_y_shift) + ii;      // Y-coord on Y-plane.
-            const int xx = (j << ss_x_shift) + jj + 2;  // X-coord on Y-plane.
-            const int ww = SSE_STRIDE;                  // Stride of Y-plane.
-            diff_sse += luma_sq_error[yy * ww + xx];
-          }
-        }
-      }
+      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
 
       const double window_error = diff_sse * inv_num_ref_pixels;
       const int subblock_idx =
@@ -200,13 +181,8 @@
       const double combined_error =
           weight_factor * window_error + block_error * inv_factor;
 
-      const MV mv = subblock_mvs[subblock_idx];
-      const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-      const double distance_threshold =
-          (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
-      const double d_factor = AOMMAX(distance / distance_threshold, 1);
-
-      double scaled_error = combined_error * d_factor * decay_factor;
+      double scaled_error =
+          combined_error * d_factor[subblock_idx] * decay_factor;
       scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
@@ -247,11 +223,19 @@
   // Smaller strength -> smaller filtering weight.
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
-  uint16_t luma_sq_error[SSE_STRIDE * BH];
-  uint16_t *chroma_sq_error =
-      (num_planes > 0)
-          ? (uint16_t *)aom_malloc(SSE_STRIDE * BH * sizeof(uint16_t))
-          : NULL;
+  double d_factor[4] = { 0 };
+  uint16_t frame_sse[SSE_STRIDE * BH] = { 0 };
+  uint32_t luma_sse_sum[BW * BH] = { 0 };
+
+  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+    // Larger motion vector -> smaller filtering weight.
+    const MV mv = subblock_mvs[subblock_idx];
+    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
+    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+    distance_threshold = AOMMAX(distance_threshold, 1);
+    d_factor[subblock_idx] = distance / distance_threshold;
+    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
+  }
 
   for (int plane = 0; plane < num_planes; ++plane) {
     const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
@@ -271,12 +255,28 @@
     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
     const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
-    apply_temporal_filter(
-        ref, frame_stride, pred + mb_pels * plane, plane_w, plane_w, plane_h,
-        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
-        count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
-        ss_x_shift, ss_y_shift, inv_num_ref_pixels, decay_factor, inv_factor,
-        weight_factor);
+    // Filter U-plane and V-plane using Y-plane. This is because motion
+    // search is only done on Y-plane, so the information from Y-plane
+    // will be more accurate. The luma sse sum is reused in both chroma
+    // planes.
+    if (plane == AOM_PLANE_U) {
+      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
+        for (unsigned int j = 0; j < plane_w; j++, k++) {
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
+            }
+          }
+        }
+      }
+    }
+
+    apply_temporal_filter(ref, frame_stride, pred + mb_pels * plane, plane_w,
+                          plane_w, plane_h, subblock_mses,
+                          accum + mb_pels * plane, count + mb_pels * plane,
+                          frame_sse, luma_sse_sum, inv_num_ref_pixels,
+                          decay_factor, inv_factor, weight_factor, d_factor);
   }
-  if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }