Refactor `temporal_filter.c` (Step-2).

Refactor Step-2: Function to apply temporal filter on YUV planes.

The refactoring is from following aspects:
  (1) Improve the interface by reducing number of arguments.
  (2) Align unit test with the new interface.
  (3) Handle different planes uniformly.
  (4) Handle low bit-depth and high bit-depth video uniformly.
  (5) Remove redundant functions.
  (6) Improve readablity.

The implementation with sse4 version is untouched.

Change-Id: Iba7fc490249faadfc20927f2b12a41948a54f324
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 819f480..1aa4508 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -277,8 +277,8 @@
   add_proto qw/int av1_full_range_search/, "const struct macroblock *x, const struct search_site_config *cfg, MV *ref_mv, MV *best_mv, int search_param, int sad_per_bit, int *num00, const struct aom_variance_vtable *fn_ptr, const MV *center_mv";
 
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/void av1_apply_temporal_filter/, "const uint8_t *y_frame1, int y_stride, const uint8_t *y_pred, int y_buf_stride, const uint8_t *u_frame1, const uint8_t *v_frame1, int uv_stride, const uint8_t *u_pred, const uint8_t *v_pred, int uv_buf_stride, unsigned int block_width, unsigned int block_height, int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32, uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator, uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count";
-    specialize qw/av1_apply_temporal_filter sse4_1/;
+    add_proto qw/void av1_apply_temporal_filter_yuv/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int strength, const int use_subblock, const int *subblock_filter_weights, const uint8_t *pred, uint32_t *accum, uint16_t *count";
+    specialize qw/av1_apply_temporal_filter_yuv sse4_1/;
   }
 
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
@@ -294,8 +294,8 @@
   }
 
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/void av1_highbd_apply_temporal_filter/, "const uint8_t *yf, int y_stride, const uint8_t *yp, int y_buf_stride, const uint8_t *uf, const uint8_t *vf, int uv_stride, const uint8_t *up, const uint8_t *vp, int uv_buf_stride, unsigned int block_width, unsigned int block_height, int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32, uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator, uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count";
-    specialize qw/av1_highbd_apply_temporal_filter sse4_1/;
+    add_proto qw/void av1_highbd_apply_temporal_filter_yuv/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int strength, const int use_subblock, const int *subblock_filter_weights, const uint8_t *pred, uint32_t *accum, uint16_t *count";
+    specialize qw/av1_highbd_apply_temporal_filter_yuv sse4_1/;
   }
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index a9d554e..f9953c7 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -38,55 +38,6 @@
 
 // NOTE: All `tf` in this file means `temporal filtering`.
 
-// Magic numbers used to compute pixel-wise averaging weights for filtering.
-// Only supports 3x3 window for filtering, hence, there are totally 9 non-zero
-// numbers within the 14-element multiplier lookup table. 5 zeros should never
-// be visited.
-static const unsigned int index_mult[14] = { 0,     0,     0,     0,     49152,
-                                             39322, 32768, 28087, 24576, 21846,
-                                             19661, 17874, 0,     15124 };
-// Magic numbers (for high bit-depth).
-static const int64_t highbd_index_mult[14] = {
-  0U,          0U,          0U,          0U,          3221225472U,
-  2576980378U, 2147483648U, 1840700270U, 1610612736U, 1431655766U,
-  1288490189U, 1171354718U, 0U,          991146300U
-};
-
-static INLINE int mod_index(int sum_dist, int index, int rounding, int strength,
-                            int filter_weight) {
-  assert(index >= 0 && index <= 13);
-  assert(index_mult[index] != 0);
-
-  int mod = (clamp(sum_dist, 0, UINT16_MAX) * index_mult[index]) >> 16;
-  mod += rounding;
-  mod >>= strength;
-
-  mod = AOMMIN(16, mod);
-
-  mod = 16 - mod;
-  mod *= filter_weight;
-
-  return mod;
-}
-
-static INLINE int highbd_mod_index(int64_t sum_dist, int index, int rounding,
-                                   int strength, int filter_weight) {
-  assert(index >= 0 && index <= 13);
-  assert(highbd_index_mult[index] != 0);
-
-  int mod =
-      (int)((AOMMIN(sum_dist, INT32_MAX) * highbd_index_mult[index]) >> 32);
-  mod += rounding;
-  mod >>= strength;
-
-  mod = AOMMIN(16, mod);
-
-  mod = 16 - mod;
-  mod *= filter_weight;
-
-  return mod;
-}
-
 // Builds predictors for blocks in temporal filtering.
 // Inputs:
 //   ref_frame: Pointer to the frame for filtering.
@@ -97,7 +48,7 @@
 //   scale: Scaling factor.
 //   num_planes: Number of planes in the frame.
 //   use_subblock: Whether to use four sub-blocks to replace the original block.
-//   subblock_mvs: The motion vectors for each sub-blocks. (row-major order).
+//   subblock_mvs: The motion vectors for each sub-blocks (row-major order).
 //   pred: Pointer to the predictors to build.
 // Returns:
 //   Nothing will be returned. But the content to which `pred` points will be
@@ -112,7 +63,7 @@
   // Information of the entire block.
   const int mb_height = block_size_high[block_size];  // Height.
   const int mb_width = block_size_wide[block_size];   // Width.
-  const int mb_size = mb_height * mb_width;           // Number of pixels.
+  const int mb_pels = mb_height * mb_width;           // Number of pixels.
   const int mb_y = mb_height * mb_row;                // Y-coord (Top-left).
   const int mb_x = mb_width * mb_col;                 // X-coord (Top-left).
   const int bit_depth = mbd->bd;                      // Bit depth.
@@ -174,7 +125,7 @@
         ++subblock_idx;
       }
     }
-    plane_offset += mb_size;
+    plane_offset += mb_pels;
   }
 }
 
@@ -215,318 +166,272 @@
   }
 }
 
-static INLINE void calculate_squared_errors(const uint8_t *s, int s_stride,
-                                            const uint8_t *p, int p_stride,
-                                            uint16_t *diff_sse, unsigned int w,
-                                            unsigned int h) {
+// Function to compute pixel-wise squared difference between two buffers.
+// Inputs:
+//   ref: Pointer to reference buffer.
+//   ref_offset: Start position of reference buffer for computation.
+//   ref_stride: Stride for reference buffer.
+//   tgt: Pointer to target buffer.
+//   tgt_offset: Start position of target buffer for computation.
+//   tgt_stride: Stride for target buffer.
+//   height: Height of block for computation.
+//   width: Width of block for computation.
+//   is_high_bitdepth: Whether the two buffers point to high bit-depth frames.
+//   square_diff: Pointer to save the squared differces.
+// Returns:
+//   Nothing will be returned. But the content to which `square_diff` points
+//   will be modified.
+static INLINE void compute_square_diff(const uint8_t *ref, const int ref_offset,
+                                       const int ref_stride, const uint8_t *tgt,
+                                       const int tgt_offset,
+                                       const int tgt_stride, const int height,
+                                       const int width,
+                                       const int is_high_bitdepth,
+                                       uint32_t *square_diff) {
+  const uint16_t *ref16 = CONVERT_TO_SHORTPTR(ref);
+  const uint16_t *tgt16 = CONVERT_TO_SHORTPTR(tgt);
+
+  int ref_idx = 0;
+  int tgt_idx = 0;
   int idx = 0;
-  unsigned int i, j;
+  for (int i = 0; i < height; ++i) {
+    for (int j = 0; j < width; ++j) {
+      const uint16_t ref_value = is_high_bitdepth ? ref16[ref_offset + ref_idx]
+                                                  : ref[ref_offset + ref_idx];
+      const uint16_t tgt_value = is_high_bitdepth ? tgt16[tgt_offset + tgt_idx]
+                                                  : tgt[tgt_offset + tgt_idx];
+      const uint32_t diff = (ref_value > tgt_value) ? (ref_value - tgt_value)
+                                                    : (tgt_value - ref_value);
+      square_diff[idx] = diff * diff;
 
-  for (i = 0; i < h; i++) {
-    for (j = 0; j < w; j++) {
-      const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
-      diff_sse[idx] = diff * diff;
-      idx++;
+      ++ref_idx;
+      ++tgt_idx;
+      ++idx;
     }
+    ref_idx += (ref_stride - width);
+    tgt_idx += (tgt_stride - width);
   }
 }
 
-static INLINE int get_filter_weight(unsigned int i, unsigned int j,
-                                    unsigned int block_height,
-                                    unsigned int block_width, const int *blk_fw,
-                                    int use_32x32) {
-  if (use_32x32)
-    // blk_fw[0] ~ blk_fw[3] are the same.
-    return blk_fw[0];
+// Magic numbers used to adjust the pixel-wise weight used in YUV filtering.
+// For now, it only supports 3x3 window for filtering.
+// The adjustment is performed with following steps:
+//   (1) For a particular pixel, compute the sum of squared difference between
+//       input frame and prediction in a small window (i.e., 3x3). There are
+//       three possible outcomes:
+//       (a) If the pixel locates in the middle of the plane, it has 9
+//           neighbours (self-included).
+//       (b) If the pixel locates on the edge of the plane, it has 6
+//           neighbours (self-included).
+//       (c) If the pixel locates on the corner of the plane, it has 4
+//           neighbours (self-included).
+//   (2) For Y-plane, it will also consider the squared difference from U-plane
+//       and V-plane at the corresponding position as reference. This leads to
+//       2 more neighbours.
+//   (3) For U-plane and V-plane, it will consider the squared difference from
+//       Y-plane at the corresponding position (after upsampling) as reference.
+//       This leads to 1 more (subsampling = 0) or 4 more (subsampling = 1)
+//       neighbours.
+//   (4) Find the modifier for adjustment from the lookup table according to
+//       number of reference pixels (neighbours) used. From above, the number
+//       of neighbours can be 9+2 (11), 6+2 (8), 4+2 (6), 9+1 (10), 6+1 (7),
+//       4+1 (5), 9+4 (13), 6+4 (10), 4+4 (8).
+// TODO(yjshen): Not sure what index 4 and index 9 are for.
+static const uint32_t filter_weight_adjustment_lookup_table_yuv[14] = {
+  0, 0, 0, 0, 49152, 39322, 32768, 28087, 24576, 21846, 19661, 17874, 0, 15124
+};
+// Lookup table for high bit-depth.
+static const uint64_t highbd_filter_weight_adjustment_lookup_table_yuv[14] = {
+  0U,          0U,          0U,          0U,          3221225472U,
+  2576980378U, 2147483648U, 1840700270U, 1610612736U, 1431655766U,
+  1288490189U, 1171354718U, 0U,          991146300U
+};
 
-  int filter_weight = 0;
-  if (i < block_height / 2) {
-    if (j < block_width / 2)
-      filter_weight = blk_fw[0];
-    else
-      filter_weight = blk_fw[1];
-  } else {
-    if (j < block_width / 2)
-      filter_weight = blk_fw[2];
-    else
-      filter_weight = blk_fw[3];
-  }
-  return filter_weight;
-}
+// Function to adjust the filter weight when applying YUV filter.
+// Inputs:
+//   filter_weight: Original filter weight.
+//   sum_square_diff: Sum of squared difference between input frame and
+//                    prediction. This field is computed pixel by pixel, and
+//                    is used as a reference for the filter weight adjustment.
+//   num_ref_pixels: Number of pixels used to compute the `sum_square_diff`.
+//                   This field should align with the above lookup tables
+//                   `filter_weight_adjustment_lookup_table_yuv` and
+//                   `highbd_filter_weight_adjustment_lookup_table_yuv`.
+//   strength: Strength for filter weight adjustment.
+//   is_high_bitdepth: Whether apply temporal filter to high bie-depth video.
+// Returns:
+//   Adjusted filter weight which will finally be used for filtering..
+static INLINE int adjust_filter_weight_yuv(const int filter_weight,
+                                           const uint64_t sum_square_diff,
+                                           const int num_ref_pixels,
+                                           const int strength,
+                                           const int is_high_bitdepth) {
+  assert(YUV_FILTER_WINDOW_LENGTH == 3);
+  assert(num_ref_pixels >= 0 && num_ref_pixels <= 13);
 
-void av1_apply_temporal_filter_c(
-    const uint8_t *y_frame1, int y_stride, const uint8_t *y_pred,
-    int y_buf_stride, const uint8_t *u_frame1, const uint8_t *v_frame1,
-    int uv_stride, const uint8_t *u_pred, const uint8_t *v_pred,
-    int uv_buf_stride, unsigned int block_width, unsigned int block_height,
-    int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32,
-    uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator,
-    uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count) {
-  unsigned int i, j, k, m;
-  int modifier;
+  const uint64_t multiplier =
+      is_high_bitdepth
+          ? highbd_filter_weight_adjustment_lookup_table_yuv[num_ref_pixels]
+          : filter_weight_adjustment_lookup_table_yuv[num_ref_pixels];
+  assert(multiplier != 0);
+
+  const uint32_t max_value = is_high_bitdepth ? UINT32_MAX : UINT16_MAX;
+  const int shift = is_high_bitdepth ? 32 : 16;
+  int modifier =
+      (int)((AOMMIN(sum_square_diff, max_value) * multiplier) >> shift);
+
   const int rounding = (1 << strength) >> 1;
-  const unsigned int uv_block_width = block_width >> ss_x;
-  const unsigned int uv_block_height = block_height >> ss_y;
-  DECLARE_ALIGNED(16, uint16_t, y_diff_sse[BLK_PELS]);
-  DECLARE_ALIGNED(16, uint16_t, u_diff_sse[BLK_PELS]);
-  DECLARE_ALIGNED(16, uint16_t, v_diff_sse[BLK_PELS]);
+  modifier = (modifier + rounding) >> strength;
+  return (modifier >= 16) ? 0 : (16 - modifier) * filter_weight;
+}
 
-  int idx = 0, idy;
+// Applies temporal filter to YUV planes.
+// Inputs:
+//   ref_frame: Pointer to the frame for filtering.
+//   mbd: Pointer to the block for filtering.
+//   block_size: Size of the block.
+//   mb_row: Row index of the block in the entire frame.
+//   mb_col: Column index of the block in the entire frame.
+//   strength: Strength for filter weight adjustment.
+//   use_subblock: Whether to use four sub-blocks to replace the original block.
+//   subblock_filter_weights: The filter weights for each sub-block (row-major
+//                            order). If `use_subblock` is set as 0, the first
+//                            weight will be applied to the entire block.
+//   pred: Pointer to the well-built predictors.
+//   accum: Pointer to the pixel-wise accumulator for filtering.
+//   count: Pointer to the pixel-wise counter fot filtering.
+// Returns:
+//   Nothing will be returned. But the content to which `accum` and `pred`
+//   point will be modified.
+void av1_apply_temporal_filter_yuv_c(const YV12_BUFFER_CONFIG *ref_frame,
+                                     const MACROBLOCKD *mbd,
+                                     const BLOCK_SIZE block_size,
+                                     const int mb_row, const int mb_col,
+                                     const int strength, const int use_subblock,
+                                     const int *subblock_filter_weights,
+                                     const uint8_t *pred, uint32_t *accum,
+                                     uint16_t *count) {
+  // Block information.
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  // TODO(yjshen): Not sure if this is equivalent to is_cur_buf_hbd(mbd).
+  const int is_high_bitdepth = (mbd->bd > 8);
+  const uint16_t *pred16 = CONVERT_TO_SHORTPTR(pred);
 
-  memset(y_diff_sse, 0, BLK_PELS * sizeof(uint16_t));
-  memset(u_diff_sse, 0, BLK_PELS * sizeof(uint16_t));
-  memset(v_diff_sse, 0, BLK_PELS * sizeof(uint16_t));
+  // Allocate memory for pixel-wise squared differences for Y, U, V planes. All
+  // planes, regardless of the subsampling, are assigned with memory of size
+  // `mb_pels`.
+  uint32_t *square_diff =
+      aom_memalign(16, MAX_MB_PLANE * mb_pels * sizeof(uint32_t));
+  memset(square_diff, 0, MAX_MB_PLANE * mb_pels * sizeof(uint32_t));
 
-  // Calculate diff^2 for each pixel of the block.
-  // TODO(yunqing): the following code needs to be optimized.
-  calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride, y_diff_sse,
-                           block_width, block_height);
-  calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
-                           u_diff_sse, uv_block_width, uv_block_height);
-  calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
-                           v_diff_sse, uv_block_width, uv_block_height);
+  int plane_offset = 0;
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    // Locate pixel on reference frame.
+    const int plane_h = mb_height >> mbd->plane[plane].subsampling_y;
+    const int plane_w = mb_width >> mbd->plane[plane].subsampling_x;
+    const int frame_stride = ref_frame->strides[plane == 0 ? 0 : 1];
+    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
+    const uint8_t *ref = ref_frame->buffers[plane];
+    compute_square_diff(ref, frame_offset, frame_stride, pred, plane_offset,
+                        plane_w, plane_h, plane_w, is_high_bitdepth,
+                        square_diff + plane_offset);
+    plane_offset += mb_pels;
+  }
 
-  for (i = 0, k = 0, m = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++) {
-      const int pixel_value = y_pred[i * y_buf_stride + j];
-      int filter_weight =
-          get_filter_weight(i, j, block_height, block_width, blk_fw, use_32x32);
+  // Get window size for pixel-wise filtering.
+  assert(YUV_FILTER_WINDOW_LENGTH % 2 == 1);
+  const int half_window = YUV_FILTER_WINDOW_LENGTH >> 1;
 
-      // non-local mean approach
-      int y_index = 0;
+  // Handle Y-plane, U-plane, V-plane in sequence.
+  plane_offset = 0;
+  for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+    const int subsampling_y = mbd->plane[plane].subsampling_y;
+    const int subsampling_x = mbd->plane[plane].subsampling_x;
+    // Only 0 and 1 are supported for filter weight adjustment.
+    assert(subsampling_y == 0 || subsampling_y == 1);
+    assert(subsampling_x == 0 || subsampling_x == 1);
+    const int h = mb_height >> subsampling_y;  // Plane height.
+    const int w = mb_width >> subsampling_x;   // Plane width.
 
-      const int uv_r = i >> ss_y;
-      const int uv_c = j >> ss_x;
-      modifier = 0;
-
-      for (idy = -1; idy <= 1; ++idy) {
-        for (idx = -1; idx <= 1; ++idx) {
-          const int row = (int)i + idy;
-          const int col = (int)j + idx;
-
-          if (row >= 0 && row < (int)block_height && col >= 0 &&
-              col < (int)block_width) {
-            modifier += y_diff_sse[row * (int)block_width + col];
-            ++y_index;
-          }
-        }
-      }
-
-      assert(y_index > 0);
-
-      modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
-      modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
-
-      y_index += 2;
-
-      modifier =
-          (int)mod_index(modifier, y_index, rounding, strength, filter_weight);
-
-      y_count[k] += modifier;
-      y_accumulator[k] += modifier * pixel_value;
-
-      ++k;
-
-      // Process chroma component
-      if (!(i & ss_y) && !(j & ss_x)) {
-        const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
-        const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
+    // Perform filtering.
+    int pred_idx = 0;
+    for (int i = 0; i < h; ++i) {
+      for (int j = 0; j < w; ++j) {
+        const int subblock_idx =
+            use_subblock ? (i >= h / 2) * 2 + (j >= w / 2) : 0;
+        const int filter_weight = subblock_filter_weights[subblock_idx];
 
         // non-local mean approach
-        int cr_index = 0;
-        int u_mod = 0, v_mod = 0;
-        int y_diff = 0;
+        uint64_t sum_square_diff = 0;
+        int num_ref_pixels = 0;
 
-        for (idy = -1; idy <= 1; ++idy) {
-          for (idx = -1; idx <= 1; ++idx) {
-            const int row = uv_r + idy;
-            const int col = uv_c + idx;
-
-            if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
-                col < (int)uv_block_width) {
-              u_mod += u_diff_sse[row * uv_block_width + col];
-              v_mod += v_diff_sse[row * uv_block_width + col];
-              ++cr_index;
+        for (int wi = -half_window; wi <= half_window; ++wi) {
+          for (int wj = -half_window; wj <= half_window; ++wj) {
+            const int y = i + wi;  // Y-coord on the current plane.
+            const int x = j + wj;  // X-coord on the current plane.
+            if (y >= 0 && y < h && x >= 0 && x < w) {
+              sum_square_diff += square_diff[plane_offset + y * w + x];
+              ++num_ref_pixels;
             }
           }
         }
 
-        assert(cr_index > 0);
-
-        for (idy = 0; idy < 1 + ss_y; ++idy) {
-          for (idx = 0; idx < 1 + ss_x; ++idx) {
-            const int row = (uv_r << ss_y) + idy;
-            const int col = (uv_c << ss_x) + idx;
-            y_diff += y_diff_sse[row * (int)block_width + col];
-            ++cr_index;
+        if (plane == 0) {  // Filter Y-plane using both U-plane and V-plane.
+          for (int p = 1; p < MAX_MB_PLANE; ++p) {
+            const int ss_y_shift = mbd->plane[p].subsampling_y - subsampling_y;
+            const int ss_x_shift = mbd->plane[p].subsampling_x - subsampling_x;
+            const int yy = i >> ss_y_shift;  // Y-coord on UV-plane.
+            const int xx = j >> ss_x_shift;  // X-coord on UV-plane.
+            const int ww = w >> ss_x_shift;  // Width of UV-plane.
+            sum_square_diff += square_diff[p * mb_pels + yy * ww + xx];
+            ++num_ref_pixels;
           }
-        }
-
-        u_mod += y_diff;
-        v_mod += y_diff;
-
-        u_mod =
-            (int)mod_index(u_mod, cr_index, rounding, strength, filter_weight);
-        v_mod =
-            (int)mod_index(v_mod, cr_index, rounding, strength, filter_weight);
-
-        u_count[m] += u_mod;
-        u_accumulator[m] += u_mod * u_pixel_value;
-        v_count[m] += v_mod;
-        v_accumulator[m] += v_mod * v_pixel_value;
-
-        ++m;
-      }  // Complete YUV pixel
-    }
-  }
-}
-
-static INLINE void highbd_calculate_squared_errors(
-    const uint16_t *s, int s_stride, const uint16_t *p, int p_stride,
-    uint32_t *diff_sse, unsigned int w, unsigned int h) {
-  int idx = 0;
-  unsigned int i, j;
-
-  for (i = 0; i < h; i++) {
-    for (j = 0; j < w; j++) {
-      const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
-      diff_sse[idx] = diff * diff;
-      idx++;
-    }
-  }
-}
-
-void av1_highbd_apply_temporal_filter_c(
-    const uint8_t *yf, int y_stride, const uint8_t *yp, int y_buf_stride,
-    const uint8_t *uf, const uint8_t *vf, int uv_stride, const uint8_t *up,
-    const uint8_t *vp, int uv_buf_stride, unsigned int block_width,
-    unsigned int block_height, int ss_x, int ss_y, int strength,
-    const int *blk_fw, int use_32x32, uint32_t *y_accumulator,
-    uint16_t *y_count, uint32_t *u_accumulator, uint16_t *u_count,
-    uint32_t *v_accumulator, uint16_t *v_count) {
-  unsigned int i, j, k, m;
-  int64_t modifier;
-  const int rounding = (1 << strength) >> 1;
-  const unsigned int uv_block_width = block_width >> ss_x;
-  const unsigned int uv_block_height = block_height >> ss_y;
-  DECLARE_ALIGNED(16, uint32_t, y_diff_sse[BLK_PELS]);
-  DECLARE_ALIGNED(16, uint32_t, u_diff_sse[BLK_PELS]);
-  DECLARE_ALIGNED(16, uint32_t, v_diff_sse[BLK_PELS]);
-
-  const uint16_t *y_frame1 = CONVERT_TO_SHORTPTR(yf);
-  const uint16_t *u_frame1 = CONVERT_TO_SHORTPTR(uf);
-  const uint16_t *v_frame1 = CONVERT_TO_SHORTPTR(vf);
-  const uint16_t *y_pred = CONVERT_TO_SHORTPTR(yp);
-  const uint16_t *u_pred = CONVERT_TO_SHORTPTR(up);
-  const uint16_t *v_pred = CONVERT_TO_SHORTPTR(vp);
-  int idx = 0, idy;
-
-  memset(y_diff_sse, 0, BLK_PELS * sizeof(uint32_t));
-  memset(u_diff_sse, 0, BLK_PELS * sizeof(uint32_t));
-  memset(v_diff_sse, 0, BLK_PELS * sizeof(uint32_t));
-
-  // Calculate diff^2 for each pixel of the block.
-  // TODO(yunqing): the following code needs to be optimized.
-  highbd_calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride,
-                                  y_diff_sse, block_width, block_height);
-  highbd_calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
-                                  u_diff_sse, uv_block_width, uv_block_height);
-  highbd_calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
-                                  v_diff_sse, uv_block_width, uv_block_height);
-
-  for (i = 0, k = 0, m = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++) {
-      const int pixel_value = y_pred[i * y_buf_stride + j];
-      int filter_weight =
-          get_filter_weight(i, j, block_height, block_width, blk_fw, use_32x32);
-
-      // non-local mean approach
-      int y_index = 0;
-
-      const int uv_r = i >> ss_y;
-      const int uv_c = j >> ss_x;
-      modifier = 0;
-
-      for (idy = -1; idy <= 1; ++idy) {
-        for (idx = -1; idx <= 1; ++idx) {
-          const int row = (int)i + idy;
-          const int col = (int)j + idx;
-
-          if (row >= 0 && row < (int)block_height && col >= 0 &&
-              col < (int)block_width) {
-            modifier += y_diff_sse[row * (int)block_width + col];
-            ++y_index;
-          }
-        }
-      }
-
-      assert(y_index > 0);
-
-      modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
-      modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
-
-      y_index += 2;
-
-      const int final_y_mod = highbd_mod_index(modifier, y_index, rounding,
-                                               strength, filter_weight);
-
-      y_count[k] += final_y_mod;
-      y_accumulator[k] += final_y_mod * pixel_value;
-
-      ++k;
-
-      // Process chroma component
-      if (!(i & ss_y) && !(j & ss_x)) {
-        const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
-        const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
-
-        // non-local mean approach
-        int cr_index = 0;
-        int64_t u_mod = 0, v_mod = 0;
-        int y_diff = 0;
-
-        for (idy = -1; idy <= 1; ++idy) {
-          for (idx = -1; idx <= 1; ++idx) {
-            const int row = uv_r + idy;
-            const int col = uv_c + idx;
-
-            if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
-                col < (int)uv_block_width) {
-              u_mod += u_diff_sse[row * uv_block_width + col];
-              v_mod += v_diff_sse[row * uv_block_width + col];
-              ++cr_index;
+        } else {  // Filter U-plane and V-plane using Y-plane.
+          const int ss_y_shift = subsampling_y - mbd->plane[0].subsampling_y;
+          const int ss_x_shift = subsampling_x - mbd->plane[0].subsampling_x;
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              const int ww = w << ss_x_shift;         // Width of Y-plane.
+              sum_square_diff += square_diff[yy * ww + xx];
+              ++num_ref_pixels;
             }
           }
         }
 
-        assert(cr_index > 0);
+        const int adjusted_weight = adjust_filter_weight_yuv(
+            filter_weight, sum_square_diff, num_ref_pixels, strength,
+            is_high_bitdepth);
 
-        for (idy = 0; idy < 1 + ss_y; ++idy) {
-          for (idx = 0; idx < 1 + ss_x; ++idx) {
-            const int row = (uv_r << ss_y) + idy;
-            const int col = (uv_c << ss_x) + idx;
-            y_diff += y_diff_sse[row * (int)block_width + col];
-            ++cr_index;
-          }
-        }
+        const int pred_value = is_high_bitdepth
+                                   ? pred16[plane_offset + pred_idx]
+                                   : pred[plane_offset + pred_idx];
+        accum[plane_offset + pred_idx] += adjusted_weight * pred_value;
+        count[plane_offset + pred_idx] += adjusted_weight;
 
-        u_mod += y_diff;
-        v_mod += y_diff;
-
-        const int final_u_mod = highbd_mod_index(u_mod, cr_index, rounding,
-                                                 strength, filter_weight);
-        const int final_v_mod = highbd_mod_index(v_mod, cr_index, rounding,
-                                                 strength, filter_weight);
-
-        u_count[m] += final_u_mod;
-        u_accumulator[m] += final_u_mod * u_pixel_value;
-        v_count[m] += final_v_mod;
-        v_accumulator[m] += final_v_mod * v_pixel_value;
-
-        ++m;
-      }  // Complete YUV pixel
+        ++pred_idx;
+      }
     }
+    plane_offset += mb_pels;
   }
+
+  aom_free(square_diff);
+}
+
+// Applies temporal filter to YUV planes (high bit-depth video).
+// NOTE: This function is now merged to `av1_apply_temporal_filter_yuv_c()`.
+void av1_highbd_apply_temporal_filter_yuv_c(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int strength, const int use_subblock,
+    const int *subblock_filter_weights, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
+  av1_apply_temporal_filter_yuv_c(ref_frame, mbd, block_size, mb_row, mb_col,
+                                  strength, use_subblock,
+                                  subblock_filter_weights, pred, accum, count);
 }
 
 // Only used in single plane case
@@ -543,8 +448,9 @@
   for (i = 0, k = 0; i < block_height; i++) {
     for (j = 0; j < block_width; j++, k++) {
       int pixel_value = *frame2;
-      int filter_weight =
-          get_filter_weight(i, j, block_height, block_width, blk_fw, use_32x32);
+      const int subblock_idx =
+          use_32x32 ? 0 : (i >= block_height / 2) * 2 + (j >= block_width / 2);
+      const int filter_weight = blk_fw[subblock_idx];
 
       // non-local mean approach
       int diff_sse[9] = { 0 };
@@ -608,8 +514,9 @@
   for (i = 0, k = 0; i < block_height; i++) {
     for (j = 0; j < block_width; j++, k++) {
       int pixel_value = *frame2;
-      int filter_weight =
-          get_filter_weight(i, j, block_height, block_width, blk_fw, use_32x32);
+      const int subblock_idx =
+          use_32x32 ? 0 : (i >= block_height / 2) * 2 + (j >= block_width / 2);
+      const int filter_weight = blk_fw[subblock_idx];
 
       // non-local mean approach
       int diff_sse[9] = { 0 };
@@ -748,14 +655,12 @@
   }
 }
 
-void apply_temporal_filter_block(YV12_BUFFER_CONFIG *frame, MACROBLOCKD *mbd,
-                                 int mb_y_src_offset, int mb_uv_src_offset,
-                                 int mb_uv_width, int mb_uv_height,
-                                 int num_planes, uint8_t *predictor,
-                                 int frame_height, int strength, double sigma,
-                                 int *blk_fw, int use_32x32,
-                                 unsigned int *accumulator, uint16_t *count,
-                                 int use_new_temporal_mode) {
+void apply_temporal_filter_block(
+    YV12_BUFFER_CONFIG *frame, MACROBLOCKD *mbd, int mb_y_src_offset,
+    int mb_uv_src_offset, int mb_uv_width, int mb_uv_height, int num_planes,
+    uint8_t *predictor, int frame_height, int strength, double sigma,
+    int *blk_fw, int use_32x32, unsigned int *accumulator, uint16_t *count,
+    int use_new_temporal_mode, int mb_row, int mb_col) {
   const int is_hbd = is_cur_buf_hbd(mbd);
   // High bitdepth
   if (is_hbd) {
@@ -796,15 +701,9 @@
             BH, adj_strength, blk_fw, use_32x32, accumulator, count);
       } else {
         // Process 3 planes together.
-        av1_highbd_apply_temporal_filter(
-            frame->y_buffer + mb_y_src_offset, frame->y_stride, predictor, BW,
-            frame->u_buffer + mb_uv_src_offset,
-            frame->v_buffer + mb_uv_src_offset, frame->uv_stride,
-            predictor + BLK_PELS, predictor + (BLK_PELS << 1), mb_uv_width, BW,
-            BH, mbd->plane[1].subsampling_x, mbd->plane[1].subsampling_y,
-            adj_strength, blk_fw, use_32x32, accumulator, count,
-            accumulator + BLK_PELS, count + BLK_PELS,
-            accumulator + (BLK_PELS << 1), count + (BLK_PELS << 1));
+        av1_highbd_apply_temporal_filter_yuv(
+            frame, mbd, TF_BLOCK, mb_row, mb_col, adj_strength, !(use_32x32),
+            blk_fw, predictor, accumulator, count);
       }
     }
     return;
@@ -847,15 +746,9 @@
                                   blk_fw, use_32x32, accumulator, count);
     } else {
       // Process 3 planes together.
-      av1_apply_temporal_filter(
-          frame->y_buffer + mb_y_src_offset, frame->y_stride, predictor, BW,
-          frame->u_buffer + mb_uv_src_offset,
-          frame->v_buffer + mb_uv_src_offset, frame->uv_stride,
-          predictor + BLK_PELS, predictor + (BLK_PELS << 1), mb_uv_width, BW,
-          BH, mbd->plane[1].subsampling_x, mbd->plane[1].subsampling_y,
-          strength, blk_fw, use_32x32, accumulator, count,
-          accumulator + BLK_PELS, count + BLK_PELS,
-          accumulator + (BLK_PELS << 1), count + (BLK_PELS << 1));
+      av1_apply_temporal_filter_yuv(frame, mbd, TF_BLOCK, mb_row, mb_col,
+                                    strength, !(use_32x32), blk_fw, predictor,
+                                    accumulator, count);
     }
   }
 }
@@ -1186,7 +1079,7 @@
                   f, mbd, mb_y_src_offset, mb_uv_src_offset, mb_uv_width,
                   mb_uv_height, num_planes, predictor, cm->height, strength,
                   sigma, blk_fw, use_32x32, accumulator, count,
-                  use_new_temporal_mode);
+                  use_new_temporal_mode, mb_row, mb_col);
 #else
               const int adj_strength = strength + 2 * (mbd->bd - 8);
               if (num_planes <= 1) {
@@ -1196,16 +1089,9 @@
                     BH, adj_strength, blk_fw, use_32x32, accumulator, count);
               } else {
                 // Process 3 planes together.
-                av1_highbd_apply_temporal_filter(
-                    f->y_buffer + mb_y_src_offset, f->y_stride, predictor, BW,
-                    f->u_buffer + mb_uv_src_offset,
-                    f->v_buffer + mb_uv_src_offset, f->uv_stride,
-                    predictor + BLK_PELS, predictor + (BLK_PELS << 1),
-                    mb_uv_width, BW, BH, mbd->plane[1].subsampling_x,
-                    mbd->plane[1].subsampling_y, adj_strength, blk_fw,
-                    use_32x32, accumulator, count, accumulator + BLK_PELS,
-                    count + BLK_PELS, accumulator + (BLK_PELS << 1),
-                    count + (BLK_PELS << 1));
+                av1_highbd_apply_temporal_filter_yuv(
+                    frame, mbd, TF_BLOCK, mb_row, mb_col, adj_strength,
+                    !(use_32x32), blk_fw, predictor, accumulator, count);
               }
 #endif  // EXPERIMENT_TEMPORAL_FILTER
             } else {
@@ -1214,7 +1100,7 @@
                   f, mbd, mb_y_src_offset, mb_uv_src_offset, mb_uv_width,
                   mb_uv_height, num_planes, predictor, cm->height, strength,
                   sigma, blk_fw, use_32x32, accumulator, count,
-                  use_new_temporal_mode);
+                  use_new_temporal_mode, mb_row, mb_col);
 #else
               if (num_planes <= 1) {
                 // Single plane case
@@ -1223,16 +1109,9 @@
                     BH, strength, blk_fw, use_32x32, accumulator, count);
               } else {
                 // Process 3 planes together.
-                av1_apply_temporal_filter(
-                    f->y_buffer + mb_y_src_offset, f->y_stride, predictor, BW,
-                    f->u_buffer + mb_uv_src_offset,
-                    f->v_buffer + mb_uv_src_offset, f->uv_stride,
-                    predictor + BLK_PELS, predictor + (BLK_PELS << 1),
-                    mb_uv_width, BW, BH, mbd->plane[1].subsampling_x,
-                    mbd->plane[1].subsampling_y, strength, blk_fw, use_32x32,
-                    accumulator, count, accumulator + BLK_PELS,
-                    count + BLK_PELS, accumulator + (BLK_PELS << 1),
-                    count + (BLK_PELS << 1));
+                av1_apply_temporal_filter_yuv(
+                    frame, mbd, TF_BLOCK, mb_row, mb_col, strength,
+                    !(use_32x32), blk_fw, predictor, accumulator, count);
               }
 #endif  // EXPERIMENT_TEMPORAL_FILTER
             }
diff --git a/av1/encoder/temporal_filter.h b/av1/encoder/temporal_filter.h
index a8982f5..d7bbd66 100644
--- a/av1/encoder/temporal_filter.h
+++ b/av1/encoder/temporal_filter.h
@@ -39,6 +39,10 @@
 #define WINDOW_SIZE 25
 #define SCALE 1000
 
+// Window size for temporal filtering on YUV planes.
+// This is particually used for function `av1_apply_temporal_filter_yuv_c()`.
+static const int YUV_FILTER_WINDOW_LENGTH = 3;
+
 static INLINE BLOCK_SIZE dims_to_size(int w, int h) {
   if (w != h) return -1;
   switch (w) {
diff --git a/av1/encoder/x86/highbd_temporal_filter_sse4.c b/av1/encoder/x86/highbd_temporal_filter_sse4.c
index 768e193..18ff882 100644
--- a/av1/encoder/x86/highbd_temporal_filter_sse4.c
+++ b/av1/encoder/x86/highbd_temporal_filter_sse4.c
@@ -858,14 +858,47 @@
       top_weight, bottom_weight, NULL);
 }
 
-void av1_highbd_apply_temporal_filter_sse4_1(
-    const uint8_t *y_src, int y_src_stride, const uint8_t *y_pre,
-    int y_pre_stride, const uint8_t *u_src, const uint8_t *v_src,
-    int uv_src_stride, const uint8_t *u_pre, const uint8_t *v_pre,
-    int uv_pre_stride, unsigned int block_width, unsigned int block_height,
-    int ss_x, int ss_y, int strength, const int *blk_fw, int use_whole_blk,
-    uint32_t *y_accum, uint16_t *y_count, uint32_t *u_accum, uint16_t *u_count,
-    uint32_t *v_accum, uint16_t *v_count) {
+void av1_highbd_apply_temporal_filter_yuv_sse4_1(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int strength, const int use_subblock,
+    const int *subblock_filter_weights, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
+  const int use_whole_blk = !use_subblock;
+  const int *blk_fw = subblock_filter_weights;
+
+  // Block information (Y-plane).
+  const unsigned int block_height = block_size_high[block_size];
+  const unsigned int block_width = block_size_wide[block_size];
+  const int mb_pels = block_height * block_width;
+  const int y_src_stride = ref_frame->y_stride;
+  const int y_pre_stride = block_width;
+  const int mb_y_src_offset =
+      mb_row * block_height * ref_frame->y_stride + mb_col * block_width;
+
+  // Block information (UV-plane).
+  const int ss_y = mbd->plane[1].subsampling_y;
+  const int ss_x = mbd->plane[1].subsampling_x;
+  const unsigned int uv_height = block_height >> ss_y;
+  const unsigned int uv_width = block_width >> ss_x;
+  const int uv_src_stride = ref_frame->uv_stride;
+  const int uv_pre_stride = block_width >> ss_x;
+  const int mb_uv_src_offset =
+      mb_row * uv_height * ref_frame->uv_stride + mb_col * uv_width;
+
+  const uint8_t *y_src = ref_frame->y_buffer + mb_y_src_offset;
+  const uint8_t *u_src = ref_frame->u_buffer + mb_uv_src_offset;
+  const uint8_t *v_src = ref_frame->v_buffer + mb_uv_src_offset;
+  const uint8_t *y_pre = pred;
+  const uint8_t *u_pre = pred + mb_pels;
+  const uint8_t *v_pre = pred + mb_pels * 2;
+  uint32_t *y_accum = accum;
+  uint32_t *u_accum = accum + mb_pels;
+  uint32_t *v_accum = accum + mb_pels * 2;
+  uint16_t *y_count = count;
+  uint16_t *u_count = count + mb_pels;
+  uint16_t *v_count = count + mb_pels * 2;
+
   const unsigned int chroma_height = block_height >> ss_y,
                      chroma_width = block_width >> ss_x;
 
diff --git a/av1/encoder/x86/temporal_filter_sse4.c b/av1/encoder/x86/temporal_filter_sse4.c
index 6151e87..a6d0508 100644
--- a/av1/encoder/x86/temporal_filter_sse4.c
+++ b/av1/encoder/x86/temporal_filter_sse4.c
@@ -921,14 +921,47 @@
       bottom_weight, NULL);
 }
 
-void av1_apply_temporal_filter_sse4_1(
-    const uint8_t *y_src, int y_src_stride, const uint8_t *y_pre,
-    int y_pre_stride, const uint8_t *u_src, const uint8_t *v_src,
-    int uv_src_stride, const uint8_t *u_pre, const uint8_t *v_pre,
-    int uv_pre_stride, unsigned int block_width, unsigned int block_height,
-    int ss_x, int ss_y, int strength, const int *blk_fw, int use_whole_blk,
-    uint32_t *y_accum, uint16_t *y_count, uint32_t *u_accum, uint16_t *u_count,
-    uint32_t *v_accum, uint16_t *v_count) {
+void av1_apply_temporal_filter_yuv_sse4_1(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int strength, const int use_subblock,
+    const int *subblock_filter_weights, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
+  const int use_whole_blk = !use_subblock;
+  const int *blk_fw = subblock_filter_weights;
+
+  // Block information (Y-plane).
+  const unsigned int block_height = block_size_high[block_size];
+  const unsigned int block_width = block_size_wide[block_size];
+  const int mb_pels = block_height * block_width;
+  const int y_src_stride = ref_frame->y_stride;
+  const int y_pre_stride = block_width;
+  const int mb_y_src_offset =
+      mb_row * block_height * ref_frame->y_stride + mb_col * block_width;
+
+  // Block information (UV-plane).
+  const int ss_y = mbd->plane[1].subsampling_y;
+  const int ss_x = mbd->plane[1].subsampling_x;
+  const unsigned int uv_height = block_height >> ss_y;
+  const unsigned int uv_width = block_width >> ss_x;
+  const int uv_src_stride = ref_frame->uv_stride;
+  const int uv_pre_stride = block_width >> ss_x;
+  const int mb_uv_src_offset =
+      mb_row * uv_height * ref_frame->uv_stride + mb_col * uv_width;
+
+  const uint8_t *y_src = ref_frame->y_buffer + mb_y_src_offset;
+  const uint8_t *u_src = ref_frame->u_buffer + mb_uv_src_offset;
+  const uint8_t *v_src = ref_frame->v_buffer + mb_uv_src_offset;
+  const uint8_t *y_pre = pred;
+  const uint8_t *u_pre = pred + mb_pels;
+  const uint8_t *v_pre = pred + mb_pels * 2;
+  uint32_t *y_accum = accum;
+  uint32_t *u_accum = accum + mb_pels;
+  uint32_t *v_accum = accum + mb_pels * 2;
+  uint16_t *y_count = count;
+  uint16_t *u_count = count + mb_pels;
+  uint16_t *v_count = count + mb_pels * 2;
+
   const unsigned int chroma_height = block_height >> ss_y,
                      chroma_width = block_width >> ss_x;
 
diff --git a/test/yuv_temporal_filter_test.cc b/test/yuv_temporal_filter_test.cc
index f504794..e1dc177 100644
--- a/test/yuv_temporal_filter_test.cc
+++ b/test/yuv_temporal_filter_test.cc
@@ -25,13 +25,10 @@
 const int MAX_HEIGHT = 32;
 
 typedef void (*YUVTemporalFilterFunc)(
-    const uint8_t *y_src, int y_src_stride, const uint8_t *y_pre,
-    int y_pre_stride, const uint8_t *u_src, const uint8_t *v_src,
-    int uv_src_stride, const uint8_t *u_pre, const uint8_t *v_pre,
-    int uv_pre_stride, unsigned int block_width, unsigned int block_height,
-    int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32,
-    uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator,
-    uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count);
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int strength, const int use_subblock, const int *blk_fw,
+    const uint8_t *pred, uint32_t *accum, uint16_t *count);
 
 struct TemporalFilterWithBd {
   TemporalFilterWithBd(YUVTemporalFilterFunc func, int bitdepth)
@@ -407,11 +404,68 @@
     int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32,
     uint32_t *y_accum, uint16_t *y_count, uint32_t *u_accum, uint16_t *u_count,
     uint32_t *v_accum, uint16_t *v_count) {
-  ASM_REGISTER_STATE_CHECK(
-      filter_func_(y_src, y_src_stride, y_pre, y_pre_stride, u_src, v_src,
-                   uv_src_stride, u_pre, v_pre, uv_pre_stride, block_width,
-                   block_height, ss_x, ss_y, strength, blk_fw, use_32x32,
-                   y_accum, y_count, u_accum, u_count, v_accum, v_count));
+  (void)block_width;
+  (void)block_height;
+  (void)y_src_stride;
+  (void)uv_src_stride;
+
+  assert(block_width == MAX_WIDTH && MAX_WIDTH == 32);
+  assert(block_height == MAX_HEIGHT && MAX_HEIGHT == 32);
+  const BLOCK_SIZE block_size = BLOCK_32X32;
+  const int mb_pels = MAX_WIDTH * MAX_HEIGHT;
+  const int mb_row = 0;
+  const int mb_col = 0;
+  const int use_subblock = !(use_32x32);
+
+  YV12_BUFFER_CONFIG *ref_frame =
+      (YV12_BUFFER_CONFIG *)malloc(sizeof(YV12_BUFFER_CONFIG));
+  ref_frame->strides[0] = y_pre_stride;
+  ref_frame->strides[1] = uv_pre_stride;
+  const int alloc_size = MAX_MB_PLANE * mb_pels;
+  DECLARE_ALIGNED(16, uint8_t, src[alloc_size]);
+  ref_frame->buffer_alloc = src;
+  ref_frame->buffers[0] = ref_frame->buffer_alloc + 0 * mb_pels;
+  ref_frame->buffers[1] = ref_frame->buffer_alloc + 1 * mb_pels;
+  ref_frame->buffers[2] = ref_frame->buffer_alloc + 2 * mb_pels;
+
+  MACROBLOCKD *mbd = (MACROBLOCKD *)malloc(sizeof(MACROBLOCKD));
+  mbd->bd = bd_;
+  mbd->plane[0].subsampling_y = 0;
+  mbd->plane[0].subsampling_x = 0;
+  mbd->plane[1].subsampling_y = ss_y;
+  mbd->plane[1].subsampling_x = ss_x;
+  mbd->plane[2].subsampling_y = ss_y;
+  mbd->plane[2].subsampling_x = ss_x;
+
+  DECLARE_ALIGNED(16, uint8_t, pred[alloc_size]);
+  DECLARE_ALIGNED(16, uint32_t, accum[alloc_size]);
+  DECLARE_ALIGNED(16, uint16_t, count[alloc_size]);
+  memcpy(src + 0 * mb_pels, y_src, mb_pels * sizeof(uint8_t));
+  memcpy(src + 1 * mb_pels, u_src, mb_pels * sizeof(uint8_t));
+  memcpy(src + 2 * mb_pels, v_src, mb_pels * sizeof(uint8_t));
+  memcpy(pred + 0 * mb_pels, y_pre, mb_pels * sizeof(uint8_t));
+  memcpy(pred + 1 * mb_pels, u_pre, mb_pels * sizeof(uint8_t));
+  memcpy(pred + 2 * mb_pels, v_pre, mb_pels * sizeof(uint8_t));
+  memcpy(accum + 0 * mb_pels, y_accum, mb_pels * sizeof(uint32_t));
+  memcpy(accum + 1 * mb_pels, u_accum, mb_pels * sizeof(uint32_t));
+  memcpy(accum + 2 * mb_pels, v_accum, mb_pels * sizeof(uint32_t));
+  memcpy(count + 0 * mb_pels, y_count, mb_pels * sizeof(uint16_t));
+  memcpy(count + 1 * mb_pels, u_count, mb_pels * sizeof(uint16_t));
+  memcpy(count + 2 * mb_pels, v_count, mb_pels * sizeof(uint16_t));
+
+  ASM_REGISTER_STATE_CHECK(filter_func_(ref_frame, mbd, block_size, mb_row,
+                                        mb_col, strength, use_subblock, blk_fw,
+                                        pred, accum, count));
+
+  memcpy(y_accum, accum + 0 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(u_accum, accum + 1 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(v_accum, accum + 2 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(y_count, count + 0 * mb_pels, mb_pels * sizeof(uint16_t));
+  memcpy(u_count, count + 1 * mb_pels, mb_pels * sizeof(uint16_t));
+  memcpy(v_count, count + 2 * mb_pels, mb_pels * sizeof(uint16_t));
+
+  free(ref_frame);
+  free(mbd);
 }
 
 template <>
@@ -423,12 +477,69 @@
     int ss_x, int ss_y, int strength, const int *blk_fw, int use_32x32,
     uint32_t *y_accum, uint16_t *y_count, uint32_t *u_accum, uint16_t *u_count,
     uint32_t *v_accum, uint16_t *v_count) {
-  ASM_REGISTER_STATE_CHECK(filter_func_(
-      CONVERT_TO_BYTEPTR(y_src), y_src_stride, CONVERT_TO_BYTEPTR(y_pre),
-      y_pre_stride, CONVERT_TO_BYTEPTR(u_src), CONVERT_TO_BYTEPTR(v_src),
-      uv_src_stride, CONVERT_TO_BYTEPTR(u_pre), CONVERT_TO_BYTEPTR(v_pre),
-      uv_pre_stride, block_width, block_height, ss_x, ss_y, strength, blk_fw,
-      use_32x32, y_accum, y_count, u_accum, u_count, v_accum, v_count));
+  (void)block_width;
+  (void)block_height;
+  (void)y_src_stride;
+  (void)uv_src_stride;
+
+  assert(block_width == MAX_WIDTH && MAX_WIDTH == 32);
+  assert(block_height == MAX_HEIGHT && MAX_HEIGHT == 32);
+  const BLOCK_SIZE block_size = BLOCK_32X32;
+  const int mb_pels = MAX_WIDTH * MAX_HEIGHT;
+  const int mb_row = 0;
+  const int mb_col = 0;
+  const int use_subblock = !(use_32x32);
+
+  YV12_BUFFER_CONFIG *ref_frame =
+      (YV12_BUFFER_CONFIG *)malloc(sizeof(YV12_BUFFER_CONFIG));
+  ref_frame->strides[0] = y_pre_stride;
+  ref_frame->strides[1] = uv_pre_stride;
+  const int alloc_size = MAX_MB_PLANE * mb_pels;
+  DECLARE_ALIGNED(16, uint16_t, src16[alloc_size]);
+  ref_frame->buffer_alloc = CONVERT_TO_BYTEPTR(src16);
+  ref_frame->buffers[0] = ref_frame->buffer_alloc + 0 * mb_pels;
+  ref_frame->buffers[1] = ref_frame->buffer_alloc + 1 * mb_pels;
+  ref_frame->buffers[2] = ref_frame->buffer_alloc + 2 * mb_pels;
+
+  MACROBLOCKD *mbd = (MACROBLOCKD *)malloc(sizeof(MACROBLOCKD));
+  mbd->bd = bd_;
+  mbd->plane[0].subsampling_y = 0;
+  mbd->plane[0].subsampling_x = 0;
+  mbd->plane[1].subsampling_y = ss_y;
+  mbd->plane[1].subsampling_x = ss_x;
+  mbd->plane[2].subsampling_y = ss_y;
+  mbd->plane[2].subsampling_x = ss_x;
+
+  DECLARE_ALIGNED(16, uint16_t, pred16[alloc_size]);
+  DECLARE_ALIGNED(16, uint32_t, accum[alloc_size]);
+  DECLARE_ALIGNED(16, uint16_t, count[alloc_size]);
+  memcpy(src16 + 0 * mb_pels, y_src, mb_pels * sizeof(uint16_t));
+  memcpy(src16 + 1 * mb_pels, u_src, mb_pels * sizeof(uint16_t));
+  memcpy(src16 + 2 * mb_pels, v_src, mb_pels * sizeof(uint16_t));
+  memcpy(pred16 + 0 * mb_pels, y_pre, mb_pels * sizeof(uint16_t));
+  memcpy(pred16 + 1 * mb_pels, u_pre, mb_pels * sizeof(uint16_t));
+  memcpy(pred16 + 2 * mb_pels, v_pre, mb_pels * sizeof(uint16_t));
+  memcpy(accum + 0 * mb_pels, y_accum, mb_pels * sizeof(uint32_t));
+  memcpy(accum + 1 * mb_pels, u_accum, mb_pels * sizeof(uint32_t));
+  memcpy(accum + 2 * mb_pels, v_accum, mb_pels * sizeof(uint32_t));
+  memcpy(count + 0 * mb_pels, y_count, mb_pels * sizeof(uint16_t));
+  memcpy(count + 1 * mb_pels, u_count, mb_pels * sizeof(uint16_t));
+  memcpy(count + 2 * mb_pels, v_count, mb_pels * sizeof(uint16_t));
+  const uint8_t *pred = CONVERT_TO_BYTEPTR(pred16);
+
+  ASM_REGISTER_STATE_CHECK(filter_func_(ref_frame, mbd, block_size, mb_row,
+                                        mb_col, strength, use_subblock, blk_fw,
+                                        pred, accum, count));
+
+  memcpy(y_accum, accum + 0 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(u_accum, accum + 1 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(v_accum, accum + 2 * mb_pels, mb_pels * sizeof(uint32_t));
+  memcpy(y_count, count + 0 * mb_pels, mb_pels * sizeof(uint16_t));
+  memcpy(u_count, count + 1 * mb_pels, mb_pels * sizeof(uint16_t));
+  memcpy(v_count, count + 2 * mb_pels, mb_pels * sizeof(uint16_t));
+
+  free(ref_frame);
+  free(mbd);
 }
 
 template <typename PixelType>
@@ -710,17 +821,18 @@
 INSTANTIATE_TEST_CASE_P(
     C, YUVTemporalFilterTest,
     ::testing::Values(
-        TemporalFilterWithBd(&av1_apply_temporal_filter_c, 8),
-        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_c, 10),
-        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_c, 12)));
+        TemporalFilterWithBd(&av1_apply_temporal_filter_yuv_c, 8),
+        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_yuv_c, 10),
+        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_yuv_c, 12)));
 
 #if HAVE_SSE4_1
 INSTANTIATE_TEST_CASE_P(
     SSE4_1, YUVTemporalFilterTest,
     ::testing::Values(
-        TemporalFilterWithBd(&av1_apply_temporal_filter_sse4_1, 8),
-        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_sse4_1, 10),
-        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_sse4_1, 12)));
+        TemporalFilterWithBd(&av1_apply_temporal_filter_yuv_sse4_1, 8),
+        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_yuv_sse4_1, 10),
+        TemporalFilterWithBd(&av1_highbd_apply_temporal_filter_yuv_sse4_1,
+                             12)));
 #endif  // HAVE_SSE4_1
 
 }  // namespace