Refactor `temporal_filter.c` (Step-3).

Refactor Step-3: Unify the interface for applying temporal filtering.

The refactoring covers the following aspects:
  (1) Unify the interface for all three different temporal filtering
      strategies, i.e., YUV filtering, Y-plane only filtering,
      Plane-wise filtering.
  (2) Unify the interface for low bit-depth video and high bit-depth
      video.
  (3) Update unit test for Plane-wise temporal filtering.

Change-Id: I50c00618c806e068e470da01a1c5dc2b8fa73f32
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 620d500..0b6e96c 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -282,8 +282,8 @@
   }
 
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/void av1_temporal_filter_plane/, "uint8_t *frame1, unsigned int stride, uint8_t *frame2, unsigned int stride2, int block_width, int block_height, int strength, double sigma, int decay_control, const int *blk_fw, int use_32x32, unsigned int *accumulator, uint16_t *count";
-    specialize qw/av1_temporal_filter_plane sse2 avx2/;
+    add_proto qw/void av1_apply_temporal_filter_planewise/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double noise_level, const uint8_t *pred, uint32_t *accum, uint16_t *count";
+    specialize qw/av1_apply_temporal_filter_planewise sse2 avx2/;
   }
   add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, const qm_val_t * qm_ptr, const qm_val_t * iqm_ptr, int log_scale";
 
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index f0519cf..59f218d 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -129,40 +129,50 @@
   }
 }
 
-static void apply_temporal_filter_self(const uint8_t *pred, int buf_stride,
-                                       unsigned int block_width,
-                                       unsigned int block_height,
-                                       int filter_weight, uint32_t *accumulator,
-                                       uint16_t *count,
-                                       int use_new_temporal_mode) {
-  const int modifier = use_new_temporal_mode ? SCALE : filter_weight * 16;
-  unsigned int i, j, k = 0;
+// Computes temporal filter weights and accumulators for the reference frame.
+// More concretely, the filter weights for all pixels are the same.
+// Inputs:
+//   mbd: Pointer to the block for filtering.
+//   block_size: Size of the block.
+//   num_planes: Number of planes in the frame.
+//   filter_weight: Weight used for filtering.
+//   pred: Pointer to the well-built predictors.
+//   accum: Pointer to the pixel-wise accumulator for filtering.
+//   count: Pointer to the pixel-wise counter for filtering.
+// Returns:
+//   Nothing will be returned. But the content to which `accum` and `count`
+//   point will be modified.
+void av1_apply_temporal_filter_self(const MACROBLOCKD *mbd,
+                                    const BLOCK_SIZE block_size,
+                                    const int num_planes,
+                                    const int filter_weight,
+                                    const uint8_t *pred, uint32_t *accum,
+                                    uint16_t *count) {
+  // Block information.
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  const int is_high_bitdepth = is_cur_buf_hbd(mbd);
+  const uint16_t *pred16 = CONVERT_TO_SHORTPTR(pred);
 
-  for (i = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++) {
-      const int pixel_value = pred[i * buf_stride + j];
-      count[k] += modifier;
-      accumulator[k] += modifier * pixel_value;
-      ++k;
+  int plane_offset = 0;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    const int subsampling_y = mbd->plane[plane].subsampling_y;
+    const int subsampling_x = mbd->plane[plane].subsampling_x;
+    const int h = mb_height >> subsampling_y;  // Plane height.
+    const int w = mb_width >> subsampling_x;   // Plane width.
+
+    int pred_idx = 0;
+    for (int i = 0; i < h; ++i) {
+      for (int j = 0; j < w; ++j) {
+        const int idx = plane_offset + pred_idx;  // Index with plane shift.
+        const int pred_value = is_high_bitdepth ? pred16[idx] : pred[idx];
+        accum[idx] += filter_weight * pred_value;
+        count[idx] += filter_weight;
+        ++pred_idx;
+      }
     }
-  }
-}
-
-static void highbd_apply_temporal_filter_self(
-    const uint8_t *pred8, int buf_stride, unsigned int block_width,
-    unsigned int block_height, int filter_weight, uint32_t *accumulator,
-    uint16_t *count, int use_new_temporal_mode) {
-  const int modifier = use_new_temporal_mode ? SCALE : filter_weight * 16;
-  const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
-  unsigned int i, j, k = 0;
-
-  for (i = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++) {
-      const int pixel_value = pred[i * buf_stride + j];
-      count[k] += modifier;
-      accumulator[k] += modifier * pixel_value;
-      ++k;
-    }
+    plane_offset += mb_pels;
   }
 }
 
@@ -401,15 +411,13 @@
           }
         }
 
+        const int idx = plane_offset + pred_idx;  // Index with plane shift.
+        const int pred_value = is_high_bitdepth ? pred16[idx] : pred[idx];
         const int adjusted_weight = adjust_filter_weight_yuv(
             filter_weight, sum_square_diff, num_ref_pixels, strength,
             is_high_bitdepth);
-
-        const int pred_value = is_high_bitdepth
-                                   ? pred16[plane_offset + pred_idx]
-                                   : pred[plane_offset + pred_idx];
-        accum[plane_offset + pred_idx] += adjusted_weight * pred_value;
-        count[plane_offset + pred_idx] += adjusted_weight;
+        accum[idx] += adjusted_weight * pred_value;
+        count[idx] += adjusted_weight;
 
         ++pred_idx;
       }
@@ -420,325 +428,287 @@
   aom_free(square_diff);
 }
 
-// Only used in single plane case
-void av1_temporal_filter_apply_c(uint8_t *frame1, unsigned int stride,
-                                 uint8_t *frame2, unsigned int block_width,
-                                 unsigned int block_height, int strength,
-                                 const int *blk_fw, int use_32x32,
-                                 unsigned int *accumulator, uint16_t *count) {
-  unsigned int i, j, k;
-  int modifier;
-  int byte = 0;
-  const int rounding = strength > 0 ? 1 << (strength - 1) : 0;
+// Function to adjust the filter weight when applying filter to Y-plane only.
+// Inputs:
+//   filter_weight: Original filter weight.
+//   sum_square_diff: Sum of squared difference between input frame and
+//                    prediction. This field is computed pixel by pixel, and
+//                    is used as a reference for the filter weight adjustment.
+//   num_ref_pixels: Number of pixels used to compute the `sum_square_diff`.
+//   strength: Strength for filter weight adjustment.
+// Returns:
+//   Adjusted filter weight which will finally be used for filtering.
+static INLINE int adjust_filter_weight_yonly(const int filter_weight,
+                                             const uint64_t sum_square_diff,
+                                             const int num_ref_pixels,
+                                             const int strength) {
+  assert(YONLY_FILTER_WINDOW_LENGTH == 3);
 
-  for (i = 0, k = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++, k++) {
-      int pixel_value = *frame2;
+  int modifier = (int)(AOMMIN(sum_square_diff * 3, INT32_MAX));
+  modifier /= num_ref_pixels;
+
+  const int rounding = (1 << strength) >> 1;
+  modifier = (modifier + rounding) >> strength;
+  return (modifier >= 16) ? 0 : (16 - modifier) * filter_weight;
+}
+
+// Applies temporal filter to Y-plane ONLY.
+// Different from the function `av1_apply_temporal_filter_yuv_c()`, this
+// function only applies temporal filter to Y-plane. This should be used when
+// the input video frame only has one plane.
+// Inputs:
+//   ref_frame: Pointer to the frame for filtering.
+//   mbd: Pointer to the block for filtering.
+//   block_size: Size of the block.
+//   mb_row: Row index of the block in the entire frame.
+//   mb_col: Column index of the block in the entire frame.
+//   strength: Strength for filter weight adjustment.
+//   use_subblock: Whether to use four sub-blocks to replace the original block.
+//   subblock_filter_weights: The filter weights for each sub-block (row-major
+//                            order). If `use_subblock` is set as 0, the first
+//                            weight will be applied to the entire block.
+//   pred: Pointer to the well-built predictors.
+//   accum: Pointer to the pixel-wise accumulator for filtering.
+//   count: Pointer to the pixel-wise counter for filtering.
+// Returns:
+//   Nothing will be returned. But the content to which `accum` and `count`
+//   point will be modified.
+void av1_apply_temporal_filter_yonly(const YV12_BUFFER_CONFIG *ref_frame,
+                                     const MACROBLOCKD *mbd,
+                                     const BLOCK_SIZE block_size,
+                                     const int mb_row, const int mb_col,
+                                     const int strength, const int use_subblock,
+                                     const int *subblock_filter_weights,
+                                     const uint8_t *pred, uint32_t *accum,
+                                     uint16_t *count) {
+  // Block information.
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  const int is_high_bitdepth = ref_frame->flags & YV12_FLAG_HIGHBITDEPTH;
+  const uint16_t *pred16 = CONVERT_TO_SHORTPTR(pred);
+
+  // Y-plane information.
+  const int subsampling_y = mbd->plane[0].subsampling_y;
+  const int subsampling_x = mbd->plane[0].subsampling_x;
+  const int h = mb_height >> subsampling_y;
+  const int w = mb_width >> subsampling_x;
+
+  // Pre-compute squared difference before filtering.
+  const int frame_stride = ref_frame->strides[0];
+  const int frame_offset = mb_row * h * frame_stride + mb_col * w;
+  const uint8_t *ref = ref_frame->buffers[0];
+  uint32_t *square_diff = aom_memalign(16, mb_pels * sizeof(uint32_t));
+  memset(square_diff, 0, mb_pels * sizeof(uint32_t));
+  compute_square_diff(ref, frame_offset, frame_stride, pred, 0, w, h, w,
+                      is_high_bitdepth, square_diff);
+
+  // Get window size for pixel-wise filtering.
+  assert(YONLY_FILTER_WINDOW_LENGTH % 2 == 1);
+  const int half_window = YONLY_FILTER_WINDOW_LENGTH >> 1;
+
+  // Perform filtering.
+  int idx = 0;
+  for (int i = 0; i < h; ++i) {
+    for (int j = 0; j < w; ++j) {
       const int subblock_idx =
-          use_32x32 ? 0 : (i >= block_height / 2) * 2 + (j >= block_width / 2);
-      const int filter_weight = blk_fw[subblock_idx];
+          use_subblock ? (i >= h / 2) * 2 + (j >= w / 2) : 0;
+      const int filter_weight = subblock_filter_weights[subblock_idx];
 
       // non-local mean approach
-      int diff_sse[9] = { 0 };
-      int idx, idy, index = 0;
+      uint64_t sum_square_diff = 0;
+      int num_ref_pixels = 0;
 
-      for (idy = -1; idy <= 1; ++idy) {
-        for (idx = -1; idx <= 1; ++idx) {
-          int row = (int)i + idy;
-          int col = (int)j + idx;
-
-          if (row >= 0 && row < (int)block_height && col >= 0 &&
-              col < (int)block_width) {
-            int diff = frame1[byte + idy * (int)stride + idx] -
-                       frame2[idy * (int)block_width + idx];
-            diff_sse[index] = diff * diff;
-            ++index;
+      for (int wi = -half_window; wi <= half_window; ++wi) {
+        for (int wj = -half_window; wj <= half_window; ++wj) {
+          const int y = i + wi;  // Y-coord on the current plane.
+          const int x = j + wj;  // X-coord on the current plane.
+          if (y >= 0 && y < h && x >= 0 && x < w) {
+            sum_square_diff += square_diff[y * w + x];
+            ++num_ref_pixels;
           }
         }
       }
 
-      assert(index > 0);
+      const int pred_value = is_high_bitdepth ? pred16[idx] : pred[idx];
+      const int adjusted_weight = adjust_filter_weight_yonly(
+          filter_weight, sum_square_diff, num_ref_pixels, strength);
+      accum[idx] += adjusted_weight * pred_value;
+      count[idx] += adjusted_weight;
 
-      modifier = 0;
-      for (idx = 0; idx < 9; ++idx) modifier += diff_sse[idx];
-
-      modifier *= 3;
-      modifier /= index;
-
-      ++frame2;
-
-      modifier += rounding;
-      modifier >>= strength;
-
-      if (modifier > 16) modifier = 16;
-
-      modifier = 16 - modifier;
-      modifier *= filter_weight;
-
-      count[k] += modifier;
-      accumulator[k] += modifier * pixel_value;
-
-      byte++;
+      ++idx;
     }
-
-    byte += stride - block_width;
   }
+
+  aom_free(square_diff);
 }
 
-// Only used in single plane case
-void av1_highbd_temporal_filter_apply_c(
-    uint8_t *frame1_8, unsigned int stride, uint8_t *frame2_8,
-    unsigned int block_width, unsigned int block_height, int strength,
-    int *blk_fw, int use_32x32, unsigned int *accumulator, uint16_t *count) {
-  uint16_t *frame1 = CONVERT_TO_SHORTPTR(frame1_8);
-  uint16_t *frame2 = CONVERT_TO_SHORTPTR(frame2_8);
-  unsigned int i, j, k;
-  int modifier;
-  int byte = 0;
-  const int rounding = strength > 0 ? 1 << (strength - 1) : 0;
+// Applies temporal filter plane by plane.
+// Different from the function `av1_apply_temporal_filter_yuv_c()` and the
+// function `av1_apply_temporal_filter_yonly()`, this function applies temporal
+// filter to each plane independently. Besides, the strategy of filter weight
+// adjustment is different from the other two functions.
+// Inputs:
+//   ref_frame: Pointer to the frame for filtering.
+//   mbd: Pointer to the block for filtering.
+//   block_size: Size of the block.
+//   mb_row: Row index of the block in the entire frame.
+//   mb_col: Column index of the block in the entire frame.
+//   num_planes: Number of planes in the frame.
+//   noise_level: Estimated noise level for the current block.
+//   pred: Pointer to the well-built predictors.
+//   accum: Pointer to the pixel-wise accumulator for filtering.
+//   count: Pointer to the pixel-wise counter for filtering.
+// Returns:
+//   Nothing will be returned. But the content to which `accum` and `count`
+//   point will be modified.
+void av1_apply_temporal_filter_planewise_c(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const double noise_level, const uint8_t *pred,
+    uint32_t *accum, uint16_t *count) {
+  // Hyper-parameter for filter weight adjustment.
+  const int frame_height = ref_frame->heights[0] << mbd->plane[0].subsampling_y;
+  const int decay_control = frame_height >= 480 ? 4 : 3;
+  // Control factor for non-local mean approach.
+  const double r = (double)decay_control * (0.7 + log(noise_level + 0.5));
 
-  for (i = 0, k = 0; i < block_height; i++) {
-    for (j = 0; j < block_width; j++, k++) {
-      int pixel_value = *frame2;
-      const int subblock_idx =
-          use_32x32 ? 0 : (i >= block_height / 2) * 2 + (j >= block_width / 2);
-      const int filter_weight = blk_fw[subblock_idx];
+  // Block information.
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  const int is_high_bitdepth = ref_frame->flags & YV12_FLAG_HIGHBITDEPTH;
+  const uint16_t *pred16 = CONVERT_TO_SHORTPTR(pred);
 
-      // non-local mean approach
-      int diff_sse[9] = { 0 };
-      int idx, idy, index = 0;
+  // Allocate memory for pixel-wise squared differences for all planes. They,
+  // regardless of the subsampling, are assigned with memory of size `mb_pels`.
+  uint32_t *square_diff =
+      aom_memalign(16, num_planes * mb_pels * sizeof(uint32_t));
+  memset(square_diff, 0, num_planes * mb_pels * sizeof(uint32_t));
 
-      for (idy = -1; idy <= 1; ++idy) {
-        for (idx = -1; idx <= 1; ++idx) {
-          int row = (int)i + idy;
-          int col = (int)j + idx;
+  int plane_offset = 0;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    // Locate pixel on reference frame.
+    const int plane_h = mb_height >> mbd->plane[plane].subsampling_y;
+    const int plane_w = mb_width >> mbd->plane[plane].subsampling_x;
+    const int frame_stride = ref_frame->strides[plane == 0 ? 0 : 1];
+    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
+    const uint8_t *ref = ref_frame->buffers[plane];
+    compute_square_diff(ref, frame_offset, frame_stride, pred, plane_offset,
+                        plane_w, plane_h, plane_w, is_high_bitdepth,
+                        square_diff + plane_offset);
+    plane_offset += mb_pels;
+  }
 
-          if (row >= 0 && row < (int)block_height && col >= 0 &&
-              col < (int)block_width) {
-            int diff = frame1[byte + idy * (int)stride + idx] -
-                       frame2[idy * (int)block_width + idx];
-            diff_sse[index] = diff * diff;
-            ++index;
+  // Get window size for pixel-wise filtering.
+  assert(PLANEWISE_FILTER_WINDOW_LENGTH % 2 == 1);
+  const int half_window = PLANEWISE_FILTER_WINDOW_LENGTH >> 1;
+
+  // Handle planes in sequence.
+  plane_offset = 0;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    const int subsampling_y = mbd->plane[plane].subsampling_y;
+    const int subsampling_x = mbd->plane[plane].subsampling_x;
+    const int h = mb_height >> subsampling_y;  // Plane height.
+    const int w = mb_width >> subsampling_x;   // Plane width.
+
+    // Perform filtering.
+    int pred_idx = 0;
+    for (int i = 0; i < h; ++i) {
+      for (int j = 0; j < w; ++j) {
+        // non-local mean approach
+        uint64_t sum_square_diff = 0;
+        int num_ref_pixels = 0;
+
+        for (int wi = -half_window; wi <= half_window; ++wi) {
+          for (int wj = -half_window; wj <= half_window; ++wj) {
+            const int y = CLIP(i + wi, 0, h - 1);  // Y-coord on current plane.
+            const int x = CLIP(j + wj, 0, w - 1);  // X-coord on current plane.
+            sum_square_diff += square_diff[plane_offset + y * w + x];
+            ++num_ref_pixels;
           }
         }
+
+        const int idx = plane_offset + pred_idx;  // Index with plane shift.
+        const int pred_value = is_high_bitdepth ? pred16[idx] : pred[idx];
+        const double scaled_diff = AOMMAX(
+            -(double)(sum_square_diff / num_ref_pixels) / (2 * r * r), -15.0);
+        const int adjusted_weight =
+            (int)(exp(scaled_diff) * PLANEWISE_FILTER_WEIGHT_SCALE);
+        accum[idx] += adjusted_weight * pred_value;
+        count[idx] += adjusted_weight;
+
+        ++pred_idx;
       }
-
-      assert(index > 0);
-
-      modifier = 0;
-      for (idx = 0; idx < 9; ++idx) modifier += diff_sse[idx];
-
-      modifier *= 3;
-      modifier /= index;
-
-      ++frame2;
-
-      modifier += rounding;
-      modifier >>= strength;
-
-      if (modifier > 16) modifier = 16;
-
-      modifier = 16 - modifier;
-      modifier *= filter_weight;
-
-      count[k] += modifier;
-      accumulator[k] += modifier * pixel_value;
-
-      byte++;
     }
-
-    byte += stride - block_width;
+    plane_offset += mb_pels;
   }
+
+  aom_free(square_diff);
 }
 
-#if EXPERIMENT_TEMPORAL_FILTER
-void av1_temporal_filter_plane_c(uint8_t *frame1, unsigned int stride,
-                                 uint8_t *frame2, unsigned int stride2,
-                                 int block_width, int block_height,
-                                 int strength, double sigma, int decay_control,
-                                 const int *blk_fw, int use_32x32,
-                                 unsigned int *accumulator, uint16_t *count) {
-  (void)strength;
-  (void)blk_fw;
-  (void)use_32x32;
-  const double h = decay_control * (0.7 + log(sigma + 0.5));
-  const double beta = 1.0;
-  for (int i = 0, k = 0; i < block_height; i++) {
-    for (int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame2[i * stride2 + j];
-
-      int diff_sse = 0;
-      for (int idy = -WINDOW_LENGTH; idy <= WINDOW_LENGTH; ++idy) {
-        for (int idx = -WINDOW_LENGTH; idx <= WINDOW_LENGTH; ++idx) {
-          int row = i + idy;
-          int col = j + idx;
-          if (row < 0) row = 0;
-          if (row >= block_height) row = block_height - 1;
-          if (col < 0) col = 0;
-          if (col >= block_width) col = block_width - 1;
-
-          int diff = frame1[row * (int)stride + col] -
-                     frame2[row * (int)stride2 + col];
-          diff_sse += diff * diff;
-        }
-      }
-      diff_sse /= WINDOW_SIZE;
-
-      double scaled_diff = -diff_sse / (2 * beta * h * h);
-      // clamp the value to avoid underflow in exp()
-      if (scaled_diff < -15) scaled_diff = -15;
-      double w = exp(scaled_diff);
-      const int weight = (int)(w * SCALE);
-
-      count[k] += weight;
-      accumulator[k] += weight * pixel_value;
-    }
-  }
-}
-
-void av1_highbd_temporal_filter_plane_c(
-    uint8_t *frame1_8bit, unsigned int stride, uint8_t *frame2_8bit,
-    unsigned int stride2, int block_width, int block_height, int strength,
-    double sigma, int decay_control, const int *blk_fw, int use_32x32,
-    unsigned int *accumulator, uint16_t *count) {
-  (void)strength;
-  (void)blk_fw;
-  (void)use_32x32;
-  uint16_t *frame1 = CONVERT_TO_SHORTPTR(frame1_8bit);
-  uint16_t *frame2 = CONVERT_TO_SHORTPTR(frame2_8bit);
-  const double h = decay_control * (0.7 + log(sigma + 0.5));
-  const double beta = 1.0;
-  for (int i = 0, k = 0; i < block_height; i++) {
-    for (int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame2[i * stride2 + j];
-
-      int diff_sse = 0;
-      for (int idy = -WINDOW_LENGTH; idy <= WINDOW_LENGTH; ++idy) {
-        for (int idx = -WINDOW_LENGTH; idx <= WINDOW_LENGTH; ++idx) {
-          int row = i + idy;
-          int col = j + idx;
-          if (row < 0) row = 0;
-          if (row >= block_height) row = block_height - 1;
-          if (col < 0) col = 0;
-          if (col >= block_width) col = block_width - 1;
-
-          int diff = frame1[row * (int)stride + col] -
-                     frame2[row * (int)stride2 + col];
-          diff_sse += diff * diff;
-        }
-      }
-      diff_sse /= WINDOW_SIZE;
-
-      double scaled_diff = -diff_sse / (2 * beta * h * h);
-      // clamp the value to avoid underflow in exp()
-      if (scaled_diff < -20) scaled_diff = -20;
-      double w = exp(scaled_diff);
-      const int weight = (int)(w * SCALE);
-
-      count[k] += weight;
-      accumulator[k] += weight * pixel_value;
-    }
-  }
-}
-
-void apply_temporal_filter_block(
-    YV12_BUFFER_CONFIG *frame, MACROBLOCKD *mbd, int mb_y_src_offset,
-    int mb_uv_src_offset, int mb_uv_width, int mb_uv_height, int num_planes,
-    uint8_t *predictor, int frame_height, int strength, double sigma,
-    int *blk_fw, int use_32x32, unsigned int *accumulator, uint16_t *count,
-    int use_new_temporal_mode, int mb_row, int mb_col) {
-  const int is_hbd = is_cur_buf_hbd(mbd);
-  // High bitdepth
-  if (is_hbd) {
-    if (use_new_temporal_mode) {
-      // Apply frame size dependent non-local means filtering.
-      int decay_control;
-      // The decay is obtained empirically, subject to better tuning.
-      if (frame_height >= 720) {
-        decay_control = 4;
-      } else if (frame_height >= 480) {
-        decay_control = 4;
-      } else {
-        decay_control = 3;
-      }
-      av1_highbd_temporal_filter_plane_c(frame->y_buffer + mb_y_src_offset,
-                                         frame->y_stride, predictor, BW, BW, BH,
-                                         strength, sigma, decay_control, blk_fw,
-                                         use_32x32, accumulator, count);
-      if (num_planes > 1) {
-        av1_highbd_temporal_filter_plane_c(
-            frame->u_buffer + mb_uv_src_offset, frame->uv_stride,
-            predictor + BLK_PELS, mb_uv_width, mb_uv_width, mb_uv_height,
-            strength, sigma, decay_control, blk_fw, use_32x32,
-            accumulator + BLK_PELS, count + BLK_PELS);
-        av1_highbd_temporal_filter_plane_c(
-            frame->v_buffer + mb_uv_src_offset, frame->uv_stride,
-            predictor + (BLK_PELS << 1), mb_uv_width, mb_uv_width, mb_uv_height,
-            strength, sigma, decay_control, blk_fw, use_32x32,
-            accumulator + (BLK_PELS << 1), count + (BLK_PELS << 1));
-      }
+// Computes temporal filter weights and accumulators from other frames excluding
+// the reference frame.
+// Inputs:
+//   ref_frame: Pointer to the frame for filtering.
+//   mbd: Pointer to the block for filtering.
+//   block_size: Size of the block.
+//   mb_row: Row index of the block in the entire frame.
+//   mb_col: Column index of the block in the entire frame.
+//   num_planes: Number of planes in the frame.
+//   use_planewise_strategy: Whether to use the Plane-wise filtering strategy.
+//                           If non-zero, Plane-wise filtering is used;
+//                           otherwise, YUV or YONLY filtering is used
+//                           (depending on the number of planes).
+//   strength: Strength for filter weight adjustment. (Used in YUV filtering and
+//             YONLY filtering.)
+//   use_subblock: Whether to use four sub-blocks to replace the original block.
+//                 (Used in YUV filtering and YONLY filtering.)
+//   subblock_filter_weights: The filter weights for each sub-block (row-major
+//                            order). If `use_subblock` is set as 0, the first
+//                            weight will be applied to the entire block. (Used
+//                            in YUV filtering and YONLY filtering.)
+//   noise_level: Estimated noise level for the current block. (Used in
+//                Plane-wise filtering.)
+//   pred: Pointer to the well-built predictors.
+//   accum: Pointer to the pixel-wise accumulator for filtering.
+//   count: Pointer to the pixel-wise counter for filtering.
+// Returns:
+//   Nothing will be returned. But the content to which `accum` and `count`
+//   point will be modified.
+void av1_apply_temporal_filter_others(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const int use_planewise_strategy, const int strength,
+    const int use_subblock, const int *subblock_filter_weights,
+    const double noise_level, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
+  if (use_planewise_strategy) {  // Commonly used for high-resolution video.
+    const int is_high_bitdepth = ref_frame->flags & YV12_FLAG_HIGHBITDEPTH;
+    if (is_high_bitdepth) {
+      av1_apply_temporal_filter_planewise_c(ref_frame, mbd, block_size, mb_row,
+                                            mb_col, num_planes, noise_level,
+                                            pred, accum, count);
     } else {
-      // Apply original non-local means filtering for small resolution
-      const int adj_strength = strength + 2 * (mbd->bd - 8);
-      if (num_planes <= 1) {
-        // Single plane case
-        av1_highbd_temporal_filter_apply_c(
-            frame->y_buffer + mb_y_src_offset, frame->y_stride, predictor, BW,
-            BH, adj_strength, blk_fw, use_32x32, accumulator, count);
-      } else {
-        // Process 3 planes together.
-        av1_apply_temporal_filter_yuv(frame, mbd, TF_BLOCK, mb_row, mb_col,
-                                      adj_strength, !(use_32x32), blk_fw,
-                                      predictor, accumulator, count);
-      }
+      av1_apply_temporal_filter_planewise(ref_frame, mbd, block_size, mb_row,
+                                          mb_col, num_planes, noise_level, pred,
+                                          accum, count);
     }
-    return;
-  }
-
-  // Low bitdepth
-  if (use_new_temporal_mode) {
-    // Apply frame size dependent non-local means filtering.
-    int decay_control;
-    // The decay is obtained empirically, subject to better tuning.
-    if (frame_height >= 720) {
-      decay_control = 4;
-    } else if (frame_height >= 480) {
-      decay_control = 4;
+  } else {  // Commonly used for low-resolution video.
+    const int adj_strength = strength + 2 * (mbd->bd - 8);
+    if (num_planes == 1) {
+      av1_apply_temporal_filter_yonly(
+          ref_frame, mbd, block_size, mb_row, mb_col, adj_strength,
+          use_subblock, subblock_filter_weights, pred, accum, count);
+    } else if (num_planes == 3) {
+      av1_apply_temporal_filter_yuv(
+          ref_frame, mbd, block_size, mb_row, mb_col, adj_strength,
+          use_subblock, subblock_filter_weights, pred, accum, count);
     } else {
-      decay_control = 3;
-    }
-    av1_temporal_filter_plane(frame->y_buffer + mb_y_src_offset,
-                              frame->y_stride, predictor, BW, BW, BH, strength,
-                              sigma, decay_control, blk_fw, use_32x32,
-                              accumulator, count);
-    if (num_planes > 1) {
-      av1_temporal_filter_plane(
-          frame->u_buffer + mb_uv_src_offset, frame->uv_stride,
-          predictor + BLK_PELS, mb_uv_width, mb_uv_width, mb_uv_height,
-          strength, sigma, decay_control, blk_fw, use_32x32,
-          accumulator + BLK_PELS, count + BLK_PELS);
-      av1_temporal_filter_plane(
-          frame->v_buffer + mb_uv_src_offset, frame->uv_stride,
-          predictor + (BLK_PELS << 1), mb_uv_width, mb_uv_width, mb_uv_height,
-          strength, sigma, decay_control, blk_fw, use_32x32,
-          accumulator + (BLK_PELS << 1), count + (BLK_PELS << 1));
-    }
-  } else {
-    // Apply original non-local means filtering for small resolution
-    if (num_planes <= 1) {
-      // Single plane case
-      av1_temporal_filter_apply_c(frame->y_buffer + mb_y_src_offset,
-                                  frame->y_stride, predictor, BW, BH, strength,
-                                  blk_fw, use_32x32, accumulator, count);
-    } else {
-      // Process 3 planes together.
-      av1_apply_temporal_filter_yuv(frame, mbd, TF_BLOCK, mb_row, mb_col,
-                                    strength, !(use_32x32), blk_fw, predictor,
-                                    accumulator, count);
+      assert(0 && "Only support Y-plane and YUV-plane modes.");
     }
   }
 }
-#endif  // EXPERIMENT_TEMPORAL_FILTER
 
 static int temporal_filter_find_matching_mb_c(
     AV1_COMP *cpi, uint8_t *arf_frame_buf, uint8_t *frame_ptr_buf, int stride,
@@ -1034,73 +1004,17 @@
 
           // Apply the filter (YUV)
           if (frame == alt_ref_index) {
-            uint8_t *pred = predictor;
-            uint32_t *accum = accumulator;
-            uint16_t *cnt = count;
-            int plane;
-
-            // All 4 blk_fws are equal to 2.
-            for (plane = 0; plane < num_planes; ++plane) {
-              const int pred_stride = plane ? mb_uv_width : BW;
-              const unsigned int w = plane ? mb_uv_width : BW;
-              const unsigned int h = plane ? mb_uv_height : BH;
-
-              if (is_hbd) {
-                highbd_apply_temporal_filter_self(pred, pred_stride, w, h,
-                                                  blk_fw[0], accum, cnt,
-                                                  use_new_temporal_mode);
-              } else {
-                apply_temporal_filter_self(pred, pred_stride, w, h, blk_fw[0],
-                                           accum, cnt, use_new_temporal_mode);
-              }
-
-              pred += BLK_PELS;
-              accum += BLK_PELS;
-              cnt += BLK_PELS;
-            }
+            const int filter_weight = use_new_temporal_mode
+                                          ? PLANEWISE_FILTER_WEIGHT_SCALE
+                                          : blk_fw[0] * 16;
+            av1_apply_temporal_filter_self(mbd, TF_BLOCK, num_planes,
+                                           filter_weight, predictor,
+                                           accumulator, count);
           } else {
-            if (is_hbd) {
-#if EXPERIMENT_TEMPORAL_FILTER
-              apply_temporal_filter_block(
-                  f, mbd, mb_y_src_offset, mb_uv_src_offset, mb_uv_width,
-                  mb_uv_height, num_planes, predictor, cm->height, strength,
-                  sigma, blk_fw, use_32x32, accumulator, count,
-                  use_new_temporal_mode, mb_row, mb_col);
-#else
-              const int adj_strength = strength + 2 * (mbd->bd - 8);
-              if (num_planes <= 1) {
-                // Single plane case
-                av1_highbd_temporal_filter_apply_c(
-                    f->y_buffer + mb_y_src_offset, f->y_stride, predictor, BW,
-                    BH, adj_strength, blk_fw, use_32x32, accumulator, count);
-              } else {
-                // Process 3 planes together.
-                av1_apply_temporal_filter_yuv(
-                    frame, mbd, TF_BLOCK, mb_row, mb_col, adj_strength,
-                    !(use_32x32), blk_fw, predictor, accumulator, count);
-              }
-#endif  // EXPERIMENT_TEMPORAL_FILTER
-            } else {
-#if EXPERIMENT_TEMPORAL_FILTER
-              apply_temporal_filter_block(
-                  f, mbd, mb_y_src_offset, mb_uv_src_offset, mb_uv_width,
-                  mb_uv_height, num_planes, predictor, cm->height, strength,
-                  sigma, blk_fw, use_32x32, accumulator, count,
-                  use_new_temporal_mode, mb_row, mb_col);
-#else
-              if (num_planes <= 1) {
-                // Single plane case
-                av1_temporal_filter_apply_c(
-                    f->y_buffer + mb_y_src_offset, f->y_stride, predictor, BW,
-                    BH, strength, blk_fw, use_32x32, accumulator, count);
-              } else {
-                // Process 3 planes together.
-                av1_apply_temporal_filter_yuv(
-                    frame, mbd, TF_BLOCK, mb_row, mb_col, strength,
-                    !(use_32x32), blk_fw, predictor, accumulator, count);
-              }
-#endif  // EXPERIMENT_TEMPORAL_FILTER
-            }
+            av1_apply_temporal_filter_others(
+                f, mbd, TF_BLOCK, mb_row, mb_col, num_planes,
+                use_new_temporal_mode, strength, !(use_32x32), blk_fw, sigma,
+                predictor, accumulator, count);
           }
         }
       }
diff --git a/av1/encoder/temporal_filter.h b/av1/encoder/temporal_filter.h
index d7bbd66..b64e947 100644
--- a/av1/encoder/temporal_filter.h
+++ b/av1/encoder/temporal_filter.h
@@ -35,14 +35,22 @@
 #define SQRT_PI_BY_2 1.25331413732
 
 #define EXPERIMENT_TEMPORAL_FILTER 1
-#define WINDOW_LENGTH 2
-#define WINDOW_SIZE 25
-#define SCALE 1000
 
 // Window size for temporal filtering on YUV planes.
-// This is particually used for function `av1_apply_temporal_filter_yuv_c()`.
+// This is particularly used for function `av1_apply_temporal_filter_yuv()`.
 static const int YUV_FILTER_WINDOW_LENGTH = 3;
 
+// Window size for temporal filtering on Y planes.
+// This is particularly used for function `av1_apply_temporal_filter_yonly()`.
+static const int YONLY_FILTER_WINDOW_LENGTH = 3;
+
+// Window size for plane-wise temporal filtering.
+// This is particularly used by `av1_apply_temporal_filter_planewise()`.
+static const int PLANEWISE_FILTER_WINDOW_LENGTH = 5;
+// A scale factor used in plane-wise temporal filtering to raise the filter
+// weight from `double` with range [0, 1] to `int` with range [0, 1000].
+static const int PLANEWISE_FILTER_WEIGHT_SCALE = 1000;
+
 static INLINE BLOCK_SIZE dims_to_size(int w, int h) {
   if (w != h) return -1;
   switch (w) {
diff --git a/av1/encoder/x86/temporal_filter_avx2.c b/av1/encoder/x86/temporal_filter_avx2.c
index 09325aa..af39913 100644
--- a/av1/encoder/x86/temporal_filter_avx2.c
+++ b/av1/encoder/x86/temporal_filter_avx2.c
@@ -18,7 +18,6 @@
 
 #define SSE_STRIDE (BW + 2)
 
-#if EXPERIMENT_TEMPORAL_FILTER
 DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask[4][8]) = {
   { 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000 },
   { 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000 },
@@ -32,12 +31,12 @@
 };
 
 static AOM_FORCE_INLINE void get_squared_error_16x16_avx2(
-    uint8_t *frame1, unsigned int stride, uint8_t *frame2, unsigned int stride2,
-    int block_width, int block_height, uint16_t *frame_sse,
-    unsigned int sse_stride) {
+    const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
+    const unsigned int stride2, const int block_width, const int block_height,
+    uint16_t *frame_sse, const unsigned int sse_stride) {
   (void)block_width;
-  uint8_t *src1 = frame1;
-  uint8_t *src2 = frame2;
+  const uint8_t *src1 = frame1;
+  const uint8_t *src2 = frame2;
   uint16_t *dst = frame_sse;
   for (int i = 0; i < block_height; i++) {
     __m128i vf1_128, vf2_128;
@@ -60,12 +59,12 @@
 }
 
 static AOM_FORCE_INLINE void get_squared_error_32x32_avx2(
-    uint8_t *frame1, unsigned int stride, uint8_t *frame2, unsigned int stride2,
-    int block_width, int block_height, uint16_t *frame_sse,
-    unsigned int sse_stride) {
+    const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
+    const unsigned int stride2, const int block_width, const int block_height,
+    uint16_t *frame_sse, const unsigned int sse_stride) {
   (void)block_width;
-  uint8_t *src1 = frame1;
-  uint8_t *src2 = frame2;
+  const uint8_t *src1 = frame1;
+  const uint8_t *src2 = frame2;
   uint16_t *dst = frame_sse;
   for (int i = 0; i < block_height; i++) {
     __m256i vsrc1, vsrc2, vmin, vmax, vdiff, vdiff1, vdiff2, vres1, vres2;
@@ -128,22 +127,18 @@
   return _mm_extract_epi32(v128a, 0);
 }
 
-void av1_temporal_filter_plane_avx2(uint8_t *frame1, unsigned int stride,
-                                    uint8_t *frame2, unsigned int stride2,
-                                    int block_width, int block_height,
-                                    int strength, double sigma,
-                                    int decay_control, const int *blk_fw,
-                                    int use_32x32, unsigned int *accumulator,
-                                    uint16_t *count) {
-  (void)strength;
-  (void)blk_fw;
-  (void)use_32x32;
+static void apply_temporal_filter_planewise(
+    const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
+    const unsigned int stride2, const int block_width, const int block_height,
+    const double sigma, const int decay_control, unsigned int *accumulator,
+    uint16_t *count) {
   const double h = decay_control * (0.7 + log(sigma + 0.5));
   const double beta = 1.0;
 
   uint16_t frame_sse[SSE_STRIDE * BH];
   uint32_t acc_5x5_sse[BH][BW];
 
+  assert(PLANEWISE_FILTER_WINDOW_LENGTH == 5);
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
 
@@ -205,17 +200,47 @@
       const int pixel_value = frame2[i * stride2 + j];
 
       int diff_sse = acc_5x5_sse[i][j];
-      diff_sse /= WINDOW_SIZE;
+      diff_sse /=
+          (PLANEWISE_FILTER_WINDOW_LENGTH * PLANEWISE_FILTER_WINDOW_LENGTH);
 
       double scaled_diff = -diff_sse / (2 * beta * h * h);
       // clamp the value to avoid underflow in exp()
       if (scaled_diff < -15) scaled_diff = -15;
       double w = exp(scaled_diff);
-      const int weight = (int)(w * SCALE);
+      const int weight = (int)(w * PLANEWISE_FILTER_WEIGHT_SCALE);
 
       count[k] += weight;
       accumulator[k] += weight * pixel_value;
     }
   }
 }
-#endif
+
+void av1_apply_temporal_filter_planewise_avx2(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const double noise_level, const uint8_t *pred,
+    uint32_t *accum, uint16_t *count) {
+  const int is_high_bitdepth = ref_frame->flags & YV12_FLAG_HIGHBITDEPTH;
+  if (is_high_bitdepth) {
+    assert(0 && "Only support low bit-depth with avx2!");
+  }
+
+  const int frame_height = ref_frame->heights[0] << mbd->plane[0].subsampling_y;
+  const int decay_control = frame_height >= 480 ? 4 : 3;
+
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
+    const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
+    const uint32_t frame_stride = ref_frame->strides[plane == 0 ? 0 : 1];
+    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
+
+    const uint8_t *ref = ref_frame->buffers[plane] + frame_offset;
+    apply_temporal_filter_planewise(ref, frame_stride, pred + mb_pels * plane,
+                                    plane_w, plane_w, plane_h, noise_level,
+                                    decay_control, accum + mb_pels * plane,
+                                    count + mb_pels * plane);
+  }
+}
diff --git a/av1/encoder/x86/temporal_filter_sse2.c b/av1/encoder/x86/temporal_filter_sse2.c
index bc0dd51..4baa724 100644
--- a/av1/encoder/x86/temporal_filter_sse2.c
+++ b/av1/encoder/x86/temporal_filter_sse2.c
@@ -19,8 +19,6 @@
 // For the squared error buffer, keep a padding for 4 samples
 #define SSE_STRIDE (BW + 4)
 
-#if EXPERIMENT_TEMPORAL_FILTER
-
 DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask_2x4[4][2][4]) = {
   { { 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF }, { 0xFFFF, 0x0000, 0x0000, 0x0000 } },
   { { 0x0000, 0xFFFF, 0xFFFF, 0xFFFF }, { 0xFFFF, 0xFFFF, 0x0000, 0x0000 } },
@@ -28,12 +26,13 @@
   { { 0x0000, 0x0000, 0x0000, 0xFFFF }, { 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF } }
 };
 
-static void get_squared_error(uint8_t *frame1, unsigned int stride,
-                              uint8_t *frame2, unsigned int stride2,
-                              int block_width, int block_height,
-                              uint16_t *frame_sse, unsigned int dst_stride) {
-  uint8_t *src1 = frame1;
-  uint8_t *src2 = frame2;
+static void get_squared_error(const uint8_t *frame1, const unsigned int stride,
+                              const uint8_t *frame2, const unsigned int stride2,
+                              const int block_width, const int block_height,
+                              uint16_t *frame_sse,
+                              const unsigned int dst_stride) {
+  const uint8_t *src1 = frame1;
+  const uint8_t *src2 = frame2;
   uint16_t *dst = frame_sse;
 
   for (int i = 0; i < block_height; i++) {
@@ -99,23 +98,18 @@
   return _mm_cvtsi128_si32(veca);
 }
 
-void av1_temporal_filter_plane_sse2(uint8_t *frame1, unsigned int stride,
-                                    uint8_t *frame2, unsigned int stride2,
-                                    int block_width, int block_height,
-                                    int strength, double sigma,
-                                    int decay_control, const int *blk_fw,
-                                    int use_32x32, unsigned int *accumulator,
-                                    uint16_t *count) {
-  (void)strength;
-  (void)blk_fw;
-  (void)use_32x32;
+static void apply_temporal_filter_planewise(
+    const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
+    const unsigned int stride2, const int block_width, const int block_height,
+    const double sigma, const int decay_control, unsigned int *accumulator,
+    uint16_t *count) {
   const double h = decay_control * (0.7 + log(sigma + 0.5));
   const double beta = 1.0;
 
   uint16_t frame_sse[SSE_STRIDE * BH];
   uint32_t acc_5x5_sse[BH][BW];
 
-  assert(WINDOW_LENGTH == 2);
+  assert(PLANEWISE_FILTER_WINDOW_LENGTH == 5);
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
 
@@ -179,17 +173,47 @@
       const int pixel_value = frame2[i * stride2 + j];
 
       int diff_sse = acc_5x5_sse[i][j];
-      diff_sse /= WINDOW_SIZE;
+      diff_sse /=
+          (PLANEWISE_FILTER_WINDOW_LENGTH * PLANEWISE_FILTER_WINDOW_LENGTH);
 
       double scaled_diff = -diff_sse / (2 * beta * h * h);
       // clamp the value to avoid underflow in exp()
       if (scaled_diff < -15) scaled_diff = -15;
       double w = exp(scaled_diff);
-      const int weight = (int)(w * SCALE);
+      const int weight = (int)(w * PLANEWISE_FILTER_WEIGHT_SCALE);
 
       count[k] += weight;
       accumulator[k] += weight * pixel_value;
     }
   }
 }
-#endif
+
+void av1_apply_temporal_filter_planewise_sse2(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const double noise_level, const uint8_t *pred,
+    uint32_t *accum, uint16_t *count) {
+  const int is_high_bitdepth = ref_frame->flags & YV12_FLAG_HIGHBITDEPTH;
+  if (is_high_bitdepth) {
+    assert(0 && "Only support low bit-depth with sse2!");
+  }
+
+  const int frame_height = ref_frame->heights[0] << mbd->plane[0].subsampling_y;
+  const int decay_control = frame_height >= 480 ? 4 : 3;
+
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  const int mb_pels = mb_height * mb_width;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
+    const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
+    const uint32_t frame_stride = ref_frame->strides[plane == 0 ? 0 : 1];
+    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
+
+    const uint8_t *ref = ref_frame->buffers[plane] + frame_offset;
+    apply_temporal_filter_planewise(ref, frame_stride, pred + mb_pels * plane,
+                                    plane_w, plane_w, plane_h, noise_level,
+                                    decay_control, accum + mb_pels * plane,
+                                    count + mb_pels * plane);
+  }
+}
diff --git a/test/temporal_filter_plane_test.cc b/test/temporal_filter_plane_test.cc
deleted file mode 100644
index 513cfc5..0000000
--- a/test/temporal_filter_plane_test.cc
+++ /dev/null
@@ -1,221 +0,0 @@
-/*
- * Copyright (c) 2019, Alliance for Open Media. All rights reserved
- *
- * This source code is subject to the terms of the BSD 2 Clause License and
- * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
- * was not distributed with this source code in the LICENSE file, you can
- * obtain it at www.aomedia.org/license/software. If the Alliance for Open
- * Media Patent License 1.0 was not distributed with this source code in the
- * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
- */
-
-#include <cmath>
-#include <cstdlib>
-#include <string>
-
-#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
-
-#include "config/aom_config.h"
-#include "config/aom_dsp_rtcd.h"
-#include "config/av1_rtcd.h"
-
-#include "aom_ports/mem.h"
-#include "test/acm_random.h"
-#include "test/clear_system_state.h"
-#include "test/register_state_check.h"
-#include "test/util.h"
-#include "test/function_equivalence_test.h"
-
-using libaom_test::ACMRandom;
-using libaom_test::FunctionEquivalenceTest;
-using ::testing::Combine;
-using ::testing::Range;
-using ::testing::Values;
-using ::testing::ValuesIn;
-
-#if !CONFIG_REALTIME_ONLY
-namespace {
-
-typedef void (*temporal_filter_plane_func)(
-    uint8_t *frame1, unsigned int stride, uint8_t *frame2, unsigned int stride2,
-    int block_width, int block_height, int strength, double sigma,
-    int decay_control, const int *blk_fw, int use_32x32,
-    unsigned int *accumulator, uint16_t *count);
-typedef libaom_test::FuncParam<temporal_filter_plane_func>
-    TestTemporal_FilterPlane;
-
-typedef ::testing::tuple<TestTemporal_FilterPlane, int> TemporalFilter_Params;
-
-class TemporalFilterTest
-    : public ::testing::TestWithParam<TemporalFilter_Params> {
- public:
-  virtual ~TemporalFilterTest() {}
-  virtual void SetUp() {
-    params_ = GET_PARAM(0);
-    rnd_.Reset(ACMRandom::DeterministicSeed());
-    src1_ = reinterpret_cast<uint8_t *>(aom_memalign(8, 256 * 256));
-    src2_ = reinterpret_cast<uint8_t *>(aom_memalign(8, 256 * 256));
-
-    ASSERT_TRUE(src1_ != NULL);
-    ASSERT_TRUE(src2_ != NULL);
-  }
-
-  virtual void TearDown() {
-    libaom_test::ClearSystemState();
-    aom_free(src1_);
-    aom_free(src2_);
-  }
-  void RunTest(int isRandom, int width, int height, int run_times);
-
-  void GenRandomData(int width, int height, int stride, int stride2) {
-    for (int ii = 0; ii < height; ii++) {
-      for (int jj = 0; jj < width; jj++) {
-        src1_[ii * stride + jj] = rnd_.Rand8();
-        src2_[ii * stride2 + jj] = rnd_.Rand8();
-      }
-    }
-  }
-
-  void GenExtremeData(int width, int height, int stride, uint8_t *data,
-                      int stride2, uint8_t *data2, uint8_t val) {
-    for (int ii = 0; ii < height; ii++) {
-      for (int jj = 0; jj < width; jj++) {
-        data[ii * stride + jj] = val;
-        data2[ii * stride2 + jj] = (255 - val);
-      }
-    }
-  }
-
- protected:
-  TestTemporal_FilterPlane params_;
-  uint8_t *src1_;
-  uint8_t *src2_;
-  ACMRandom rnd_;
-};
-
-void TemporalFilterTest::RunTest(int isRandom, int width, int height,
-                                 int run_times) {
-  aom_usec_timer ref_timer, test_timer;
-  for (int k = 0; k < 3; k++) {
-    int stride = 5 << rnd_(6);  // Up to 256 stride
-    int stride2 = 5 << rnd_(6);
-
-    while (stride < width) {  // Make sure it's valid
-      stride = 5 << rnd_(6);
-      stride2 = 5 << rnd_(6);
-    }
-    if (isRandom) {
-      GenRandomData(width, height, stride, stride2);
-    } else {
-      const int msb = 8;  // Up to 12 bit input
-      const int limit = (1 << msb) - 1;
-      if (k == 0) {
-        GenExtremeData(width, height, stride, src1_, stride2, src2_, limit);
-      } else {
-        GenExtremeData(width, height, stride, src1_, stride2, src2_, 0);
-      }
-    }
-    int use32X32 = 1;
-    int strength = rnd_(16);
-    double sigma = 2.1002103677063437;
-    int decay_control = 5;
-    int blk_fw = rnd_(16);
-    DECLARE_ALIGNED(16, unsigned int, accumulator_ref[1024 * 3]);
-    DECLARE_ALIGNED(16, uint16_t, count_ref[1024 * 3]);
-    memset(accumulator_ref, 0, 1024 * 3 * sizeof(accumulator_ref[0]));
-    memset(count_ref, 0, 1024 * 3 * sizeof(count_ref[0]));
-    DECLARE_ALIGNED(16, unsigned int, accumulator_mod[1024 * 3]);
-    DECLARE_ALIGNED(16, uint16_t, count_mod[1024 * 3]);
-    memset(accumulator_mod, 0, 1024 * 3 * sizeof(accumulator_mod[0]));
-    memset(count_mod, 0, 1024 * 3 * sizeof(count_mod[0]));
-
-    params_.ref_func(src1_, stride, src2_, stride2, width, height, strength,
-                     sigma, decay_control, &blk_fw, use32X32, accumulator_ref,
-                     count_ref);
-    params_.tst_func(src1_, stride, src2_, stride2, width, height, strength,
-                     sigma, decay_control, &blk_fw, use32X32, accumulator_mod,
-                     count_mod);
-
-    if (run_times > 1) {
-      aom_usec_timer_start(&ref_timer);
-      for (int j = 0; j < run_times; j++) {
-        params_.ref_func(src1_, stride, src2_, stride2, width, height, strength,
-                         sigma, decay_control, &blk_fw, use32X32,
-                         accumulator_ref, count_ref);
-      }
-      aom_usec_timer_mark(&ref_timer);
-      const int elapsed_time_c =
-          static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
-
-      aom_usec_timer_start(&test_timer);
-      for (int j = 0; j < run_times; j++) {
-        params_.tst_func(src1_, stride, src2_, stride2, width, height, strength,
-                         sigma, decay_control, &blk_fw, use32X32,
-                         accumulator_mod, count_mod);
-      }
-      aom_usec_timer_mark(&test_timer);
-      const int elapsed_time_simd =
-          static_cast<int>(aom_usec_timer_elapsed(&test_timer));
-
-      printf(
-          "c_time=%d \t simd_time=%d \t "
-          "gain=%f\t width=%d\t height=%d \n",
-          elapsed_time_c, elapsed_time_simd,
-          (float)((float)elapsed_time_c / (float)elapsed_time_simd), width,
-          height);
-
-    } else {
-      for (int i = 0, l = 0; i < height; i++) {
-        for (int j = 0; j < width; j++, l++) {
-          EXPECT_EQ(accumulator_ref[l], accumulator_mod[l])
-              << "Error:" << k << " SSE Sum Test [" << width << "x" << height
-              << "] C accumulator does not match optimized accumulator.";
-          EXPECT_EQ(count_ref[l], count_mod[l])
-              << "Error:" << k << " SSE Sum Test [" << width << "x" << height
-              << "] C count does not match optimized count.";
-        }
-      }
-    }
-  }
-}
-
-TEST_P(TemporalFilterTest, OperationCheck) {
-  for (int height = 16; height <= 32; height = height * 2) {
-    RunTest(1, height, height, 1);  // GenRandomData
-  }
-}
-
-TEST_P(TemporalFilterTest, ExtremeValues) {
-  for (int height = 16; height <= 32; height = height * 2) {
-    RunTest(0, height, height, 1);
-  }
-}
-
-TEST_P(TemporalFilterTest, DISABLED_Speed) {
-  for (int height = 16; height <= 32; height = height * 2) {
-    RunTest(1, height, height, 100000);
-  }
-}
-
-#if HAVE_AVX2
-TestTemporal_FilterPlane Temporal_filter_test_avx2[] = {
-  TestTemporal_FilterPlane(&av1_temporal_filter_plane_c,
-                           &av1_temporal_filter_plane_avx2)
-};
-INSTANTIATE_TEST_CASE_P(AVX2, TemporalFilterTest,
-                        Combine(ValuesIn(Temporal_filter_test_avx2),
-                                Range(64, 65, 4)));
-#endif  // HAVE_AVX2
-
-#if HAVE_SSE2
-TestTemporal_FilterPlane Temporal_filter_test_sse2[] = {
-  TestTemporal_FilterPlane(&av1_temporal_filter_plane_c,
-                           &av1_temporal_filter_plane_sse2)
-};
-INSTANTIATE_TEST_CASE_P(SSE2, TemporalFilterTest,
-                        Combine(ValuesIn(Temporal_filter_test_sse2),
-                                Range(64, 65, 4)));
-#endif  // HAVE_SSE2
-
-}  // namespace
-#endif
diff --git a/test/temporal_filter_planewise_test.cc b/test/temporal_filter_planewise_test.cc
new file mode 100644
index 0000000..947ffde
--- /dev/null
+++ b/test/temporal_filter_planewise_test.cc
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2019, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <cmath>
+#include <cstdlib>
+#include <string>
+
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+#include "config/av1_rtcd.h"
+
+#include "aom_ports/mem.h"
+#include "test/acm_random.h"
+#include "test/clear_system_state.h"
+#include "test/register_state_check.h"
+#include "test/util.h"
+#include "test/function_equivalence_test.h"
+
+using libaom_test::ACMRandom;
+using libaom_test::FunctionEquivalenceTest;
+using ::testing::Combine;
+using ::testing::Range;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+#if !CONFIG_REALTIME_ONLY
+namespace {
+
+typedef void (*TemporalFilterPlanewiseFunc)(
+    const YV12_BUFFER_CONFIG *ref_frame, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const double noise_level, const uint8_t *pred,
+    uint32_t *accum, uint16_t *count);
+typedef libaom_test::FuncParam<TemporalFilterPlanewiseFunc>
+    TemporalFilterPlanewiseFuncParam;
+
+typedef ::testing::tuple<TemporalFilterPlanewiseFuncParam, int>
+    TemporalFilterPlanewiseWithParam;
+
+class TemporalFilterPlanewiseTest
+    : public ::testing::TestWithParam<TemporalFilterPlanewiseWithParam> {
+ public:
+  virtual ~TemporalFilterPlanewiseTest() {}
+  virtual void SetUp() {
+    params_ = GET_PARAM(0);
+    rnd_.Reset(ACMRandom::DeterministicSeed());
+    src1_ = reinterpret_cast<uint8_t *>(aom_memalign(8, 256 * 256));
+    src2_ = reinterpret_cast<uint8_t *>(aom_memalign(8, 256 * 256));
+
+    ASSERT_TRUE(src1_ != NULL);
+    ASSERT_TRUE(src2_ != NULL);
+  }
+
+  virtual void TearDown() {
+    libaom_test::ClearSystemState();
+    aom_free(src1_);
+    aom_free(src2_);
+  }
+  void RunTest(int isRandom, int width, int height, int run_times);
+
+  void GenRandomData(int width, int height, int stride, int stride2) {
+    for (int ii = 0; ii < height; ii++) {
+      for (int jj = 0; jj < width; jj++) {
+        src1_[ii * stride + jj] = rnd_.Rand8();
+        src2_[ii * stride2 + jj] = rnd_.Rand8();
+      }
+    }
+  }
+
+  void GenExtremeData(int width, int height, int stride, uint8_t *data,
+                      int stride2, uint8_t *data2, uint8_t val) {
+    for (int ii = 0; ii < height; ii++) {
+      for (int jj = 0; jj < width; jj++) {
+        data[ii * stride + jj] = val;
+        data2[ii * stride2 + jj] = (255 - val);
+      }
+    }
+  }
+
+ protected:
+  TemporalFilterPlanewiseFuncParam params_;
+  uint8_t *src1_;
+  uint8_t *src2_;
+  ACMRandom rnd_;
+};
+
+void TemporalFilterPlanewiseTest::RunTest(int isRandom, int width, int height,
+                                          int run_times) {
+  aom_usec_timer ref_timer, test_timer;
+  for (int k = 0; k < 3; k++) {
+    const int stride = width;
+    const int stride2 = width;
+    if (isRandom) {
+      GenRandomData(width, height, stride, stride2);
+    } else {
+      const int msb = 8;  // Up to 8 bit input
+      const int limit = (1 << msb) - 1;
+      if (k == 0) {
+        GenExtremeData(width, height, stride, src1_, stride2, src2_, limit);
+      } else {
+        GenExtremeData(width, height, stride, src1_, stride2, src2_, 0);
+      }
+    }
+    double sigma = 2.1002103677063437;
+    DECLARE_ALIGNED(16, unsigned int, accumulator_ref[1024 * 3]);
+    DECLARE_ALIGNED(16, uint16_t, count_ref[1024 * 3]);
+    memset(accumulator_ref, 0, 1024 * 3 * sizeof(accumulator_ref[0]));
+    memset(count_ref, 0, 1024 * 3 * sizeof(count_ref[0]));
+    DECLARE_ALIGNED(16, unsigned int, accumulator_mod[1024 * 3]);
+    DECLARE_ALIGNED(16, uint16_t, count_mod[1024 * 3]);
+    memset(accumulator_mod, 0, 1024 * 3 * sizeof(accumulator_mod[0]));
+    memset(count_mod, 0, 1024 * 3 * sizeof(count_mod[0]));
+
+    assert(width == 32 && height == 32);
+    const BLOCK_SIZE block_size = BLOCK_32X32;
+    const int mb_row = 0;
+    const int mb_col = 0;
+    const int num_planes = 1;
+    YV12_BUFFER_CONFIG *ref_frame =
+        (YV12_BUFFER_CONFIG *)malloc(sizeof(YV12_BUFFER_CONFIG));
+    ref_frame->heights[0] = height;
+    ref_frame->strides[0] = stride;
+    DECLARE_ALIGNED(16, uint8_t, src[1024 * 3]);
+    ref_frame->buffer_alloc = src;
+    ref_frame->buffers[0] = ref_frame->buffer_alloc;
+    ref_frame->flags = 0;  // Only support low bit-depth test.
+    memcpy(src, src1_, 1024 * 3 * sizeof(uint8_t));
+
+    MACROBLOCKD *mbd = (MACROBLOCKD *)malloc(sizeof(MACROBLOCKD));
+    mbd->plane[0].subsampling_y = 0;
+    mbd->plane[0].subsampling_x = 0;
+
+    params_.ref_func(ref_frame, mbd, block_size, mb_row, mb_col, num_planes,
+                     sigma, src2_, accumulator_ref, count_ref);
+    params_.tst_func(ref_frame, mbd, block_size, mb_row, mb_col, num_planes,
+                     sigma, src2_, accumulator_mod, count_mod);
+
+    if (run_times > 1) {
+      aom_usec_timer_start(&ref_timer);
+      for (int j = 0; j < run_times; j++) {
+        params_.ref_func(ref_frame, mbd, block_size, mb_row, mb_col, num_planes,
+                         sigma, src2_, accumulator_ref, count_ref);
+      }
+      aom_usec_timer_mark(&ref_timer);
+      const int elapsed_time_c =
+          static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
+
+      aom_usec_timer_start(&test_timer);
+      for (int j = 0; j < run_times; j++) {
+        params_.tst_func(ref_frame, mbd, block_size, mb_row, mb_col, num_planes,
+                         sigma, src2_, accumulator_mod, count_mod);
+      }
+      aom_usec_timer_mark(&test_timer);
+      const int elapsed_time_simd =
+          static_cast<int>(aom_usec_timer_elapsed(&test_timer));
+
+      printf(
+          "c_time=%d \t simd_time=%d \t "
+          "gain=%f\t width=%d\t height=%d \n",
+          elapsed_time_c, elapsed_time_simd,
+          (float)((float)elapsed_time_c / (float)elapsed_time_simd), width,
+          height);
+
+    } else {
+      for (int i = 0, l = 0; i < height; i++) {
+        for (int j = 0; j < width; j++, l++) {
+          EXPECT_EQ(accumulator_ref[l], accumulator_mod[l])
+              << "Error:" << k << " SSE Sum Test [" << width << "x" << height
+              << "] C accumulator does not match optimized accumulator.";
+          EXPECT_EQ(count_ref[l], count_mod[l])
+              << "Error:" << k << " SSE Sum Test [" << width << "x" << height
+              << "] C count does not match optimized count.";
+        }
+      }
+    }
+
+    free(ref_frame);
+    free(mbd);
+  }
+}
+
+TEST_P(TemporalFilterPlanewiseTest, OperationCheck) {
+  for (int height = 32; height <= 32; height = height * 2) {
+    RunTest(1, height, height, 1);  // GenRandomData
+  }
+}
+
+TEST_P(TemporalFilterPlanewiseTest, ExtremeValues) {
+  for (int height = 32; height <= 32; height = height * 2) {
+    RunTest(0, height, height, 1);
+  }
+}
+
+TEST_P(TemporalFilterPlanewiseTest, DISABLED_Speed) {
+  for (int height = 32; height <= 32; height = height * 2) {
+    RunTest(1, height, height, 100000);
+  }
+}
+
+#if HAVE_AVX2
+TemporalFilterPlanewiseFuncParam temporal_filter_planewise_test_avx2[] = {
+  TemporalFilterPlanewiseFuncParam(&av1_apply_temporal_filter_planewise_c,
+                                   &av1_apply_temporal_filter_planewise_avx2)
+};
+INSTANTIATE_TEST_CASE_P(AVX2, TemporalFilterPlanewiseTest,
+                        Combine(ValuesIn(temporal_filter_planewise_test_avx2),
+                                Range(64, 65, 4)));
+#endif  // HAVE_AVX2
+
+#if HAVE_SSE2
+TemporalFilterPlanewiseFuncParam temporal_filter_planewise_test_sse2[] = {
+  TemporalFilterPlanewiseFuncParam(&av1_apply_temporal_filter_planewise_c,
+                                   &av1_apply_temporal_filter_planewise_sse2)
+};
+INSTANTIATE_TEST_CASE_P(SSE2, TemporalFilterPlanewiseTest,
+                        Combine(ValuesIn(temporal_filter_planewise_test_sse2),
+                                Range(64, 65, 4)));
+#endif  // HAVE_SSE2
+
+}  // namespace
+#endif
diff --git a/test/test.cmake b/test/test.cmake
index 05f0864..b66159b 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -131,7 +131,7 @@
                 "${AOM_ROOT}/test/segment_binarization_sync.cc"
                 "${AOM_ROOT}/test/superframe_test.cc"
                 "${AOM_ROOT}/test/tile_independence_test.cc"
-                "${AOM_ROOT}/test/temporal_filter_plane_test.cc"
+                "${AOM_ROOT}/test/temporal_filter_planewise_test.cc"
                 "${AOM_ROOT}/test/temporal_filter_yuv_test.cc")
     if(CONFIG_REALTIME_ONLY)
       list(REMOVE_ITEM AOM_UNIT_TEST_COMMON_SOURCES