Simplify computations in av1_apply_temporal_filter

This CL simplifies the computations in av1_apply_temporal_filter()
by replacing floating-point divisions in double precision with
multiplication operations. The C, SSE2 (low-bd, high-bd) and
AVX2 (low-bd) variants have been modified accordingly.

This CL is expected to be near bit-exact with minimal quality impact.

STATS_CHANGED expected

BUG=aomedia:2761

Change-Id: I4e373758ebbe363834a68c62f57c6670921541be
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 5449a33..4d954b5 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -514,6 +514,24 @@
   const int frame_height = frame_to_filter->y_crop_height;
   const int frame_width = frame_to_filter->y_crop_width;
   const int min_frame_size = AOMMIN(frame_height, frame_width);
+  // Variables to simplify combined error calculation.
+  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
+                                   TF_SEARCH_ERROR_NORM_WEIGHT);
+  const double weight_factor =
+      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
+  // Decay factors for non-local mean approach.
+  double decay_factor[MAX_MB_PLANE] = { 0 };
+  // Smaller q -> smaller filtering weight.
+  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
+  q_decay = CLIP(q_decay, 1e-5, 1);
+  // Smaller strength -> smaller filtering weight.
+  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
+  s_decay = CLIP(s_decay, 1e-5, 1);
+  for (int plane = 0; plane < num_planes; plane++) {
+    // Larger noise -> larger filtering weight.
+    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
+    decay_factor[plane] = 1 / (n_decay * q_decay * s_decay);
+  }
 
   // Allocate memory for pixel-wise squared differences for all planes. They,
   // regardless of the subsampling, are assigned with memory of size `mb_pels`.
@@ -546,6 +564,13 @@
     const int subsampling_x = mbd->plane[plane].subsampling_x;
     const int h = mb_height >> subsampling_y;  // Plane height.
     const int w = mb_width >> subsampling_x;   // Plane width.
+    const int ss_y_shift =
+        subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
+    const int ss_x_shift =
+        subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
+    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
+                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
+    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
 
     // Perform filtering.
     int pred_idx = 0;
@@ -553,14 +578,12 @@
       for (int j = 0; j < w; ++j) {
         // non-local mean approach
         uint64_t sum_square_diff = 0;
-        int num_ref_pixels = 0;
 
         for (int wi = -half_window; wi <= half_window; ++wi) {
           for (int wj = -half_window; wj <= half_window; ++wj) {
             const int y = CLIP(i + wi, 0, h - 1);  // Y-coord on current plane.
             const int x = CLIP(j + wj, 0, w - 1);  // X-coord on current plane.
             sum_square_diff += square_diff[plane_offset + y * w + x];
-            ++num_ref_pixels;
           }
         }
 
@@ -568,15 +591,12 @@
         // search is only done on Y-plane, so the information from Y-plane will
         // be more accurate.
         if (plane != 0) {
-          const int ss_y_shift = subsampling_y - mbd->plane[0].subsampling_y;
-          const int ss_x_shift = subsampling_x - mbd->plane[0].subsampling_x;
           for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
             for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
               const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
               const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
               const int ww = w << ss_x_shift;         // Width of Y-plane.
               sum_square_diff += square_diff[yy * ww + xx];
-              ++num_ref_pixels;
             }
           }
         }
@@ -585,22 +605,12 @@
         if (mbd->bd > 8) sum_square_diff >>= ((mbd->bd - 8) * 2);
 
         // Combine window error and block error, and normalize it.
-        const double window_error = (double)sum_square_diff / num_ref_pixels;
+        const double window_error = sum_square_diff * inv_num_ref_pixels;
         const int subblock_idx = (i >= h / 2) * 2 + (j >= w / 2);
         const double block_error = (double)subblock_mses[subblock_idx];
         const double combined_error =
-            (TF_WINDOW_BLOCK_BALANCE_WEIGHT * window_error + block_error) /
-            (TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) / TF_SEARCH_ERROR_NORM_WEIGHT;
+            weight_factor * window_error + block_error * inv_factor;
 
-        // Decay factors for non-local mean approach.
-        // Larger noise -> larger filtering weight.
-        const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
-        // Smaller q -> smaller filtering weight.
-        const double q_decay =
-            CLIP(pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2), 1e-5, 1);
-        // Smaller strength -> smaller filtering weight.
-        const double s_decay = CLIP(
-            pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2), 1e-5, 1);
         // Larger motion vector -> smaller filtering weight.
         const MV mv = subblock_mvs[subblock_idx];
         const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
@@ -609,8 +619,8 @@
         const double d_factor = AOMMAX(distance / distance_threshold, 1);
 
         // Compute filter weight.
-        const double scaled_error =
-            AOMMIN(combined_error * d_factor / n_decay / q_decay / s_decay, 7);
+        double scaled_error = combined_error * d_factor * decay_factor[plane];
+        scaled_error = AOMMIN(scaled_error, 7);
         const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
         const int idx = plane_offset + pred_idx;  // Index with plane shift.
diff --git a/av1/encoder/x86/highbd_temporal_filter_sse2.c b/av1/encoder/x86/highbd_temporal_filter_sse2.c
index af623b8..cf08b66 100644
--- a/av1/encoder/x86/highbd_temporal_filter_sse2.c
+++ b/av1/encoder/x86/highbd_temporal_filter_sse2.c
@@ -92,11 +92,11 @@
 static void highbd_apply_temporal_filter(
     const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const double sigma, const MV *subblock_mvs,
-    const int *subblock_mses, const int q_factor, const int filter_strength,
+    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
     unsigned int *accumulator, uint16_t *count, uint32_t *luma_sq_error,
     uint32_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
-    int bd) {
+    int bd, const double inv_num_ref_pixels, const double decay_factor,
+    const double inv_factor, const double weight_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
   if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
@@ -110,12 +110,6 @@
 
   __m128i vsrc[5][2];
 
-  const double n_decay = 0.5 + log(2 * sigma + 5.0);
-  const double q_decay =
-      CLIP(pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2), 1e-5, 1);
-  const double s_decay =
-      CLIP(pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2), 1e-5, 1);
-
   // Traverse 4 columns at a time
   // First and last columns will require padding
   for (int col = 0; col < block_width; col += 4) {
@@ -214,13 +208,12 @@
       // Scale down the difference for high bit depth input.
       diff_sse >>= ((bd - 8) * 2);
 
-      const double window_error = (double)(diff_sse) / num_ref_pixels;
+      const double window_error = diff_sse * inv_num_ref_pixels;
       const int subblock_idx =
           (i >= block_height / 2) * 2 + (j >= block_width / 2);
       const double block_error = (double)subblock_mses[subblock_idx];
       const double combined_error =
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT * window_error + block_error) /
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) / TF_SEARCH_ERROR_NORM_WEIGHT;
+          weight_factor * window_error + block_error * inv_factor;
 
       const MV mv = subblock_mvs[subblock_idx];
       const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
@@ -228,8 +221,8 @@
           (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
       const double d_factor = AOMMAX(distance / distance_threshold, 1);
 
-      const double scaled_error =
-          AOMMIN(combined_error * d_factor / n_decay / q_decay / s_decay, 7);
+      double scaled_error = combined_error * d_factor * decay_factor;
+      scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
       count[k] += weight;
@@ -256,6 +249,18 @@
   const int frame_height = frame_to_filter->y_crop_height;
   const int frame_width = frame_to_filter->y_crop_width;
   const int min_frame_size = AOMMIN(frame_height, frame_width);
+  // Variables to simplify combined error calculation.
+  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
+                                   TF_SEARCH_ERROR_NORM_WEIGHT);
+  const double weight_factor =
+      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
+  // Decay factors for non-local mean approach.
+  // Smaller q -> smaller filtering weight.
+  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
+  q_decay = CLIP(q_decay, 1e-5, 1);
+  // Smaller strength -> smaller filtering weight.
+  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
+  s_decay = CLIP(s_decay, 1e-5, 1);
   uint32_t luma_sq_error[SSE_STRIDE * BH];
   uint32_t *chroma_sq_error =
       (num_planes > 0)
@@ -274,13 +279,19 @@
         mbd->plane[plane].subsampling_x - mbd->plane[0].subsampling_x;
     const int ss_y_shift =
         mbd->plane[plane].subsampling_y - mbd->plane[0].subsampling_y;
+    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
+                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
+    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
+    // Larger noise -> larger filtering weight.
+    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
+    const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
     highbd_apply_temporal_filter(
         ref, frame_stride, pred1 + mb_pels * plane, plane_w, plane_w, plane_h,
-        min_frame_size, noise_levels[plane], subblock_mvs, subblock_mses,
-        q_factor, filter_strength, accum + mb_pels * plane,
+        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
         count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
-        ss_x_shift, ss_y_shift, mbd->bd);
+        ss_x_shift, ss_y_shift, mbd->bd, inv_num_ref_pixels, decay_factor,
+        inv_factor, weight_factor);
   }
   if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }
diff --git a/av1/encoder/x86/temporal_filter_avx2.c b/av1/encoder/x86/temporal_filter_avx2.c
index 5f36737..52a3964 100644
--- a/av1/encoder/x86/temporal_filter_avx2.c
+++ b/av1/encoder/x86/temporal_filter_avx2.c
@@ -130,10 +130,11 @@
 static void apply_temporal_filter(
     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const double sigma, const MV *subblock_mvs,
-    const int *subblock_mses, const int q_factor, const int filter_strength,
+    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
     unsigned int *accumulator, uint16_t *count, uint16_t *luma_sq_error,
-    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift) {
+    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
+    const double inv_num_ref_pixels, const double decay_factor,
+    const double inv_factor, const double weight_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
   if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
@@ -152,12 +153,6 @@
 
   __m256i vsrc[5];
 
-  const double n_decay = 0.5 + log(2 * sigma + 5.0);
-  const double q_decay =
-      CLIP(pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2), 1e-5, 1);
-  const double s_decay =
-      CLIP(pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2), 1e-5, 1);
-
   // Traverse 4 columns at a time
   // First and last columns will require padding
   for (int col = 0; col < block_width; col += 4) {
@@ -187,7 +182,7 @@
       }
 
       // Load next row to the last element
-      if (row <= block_width - 4) {
+      if (row <= block_height - 4) {
         vsrc[4] = xx_load_and_pad(src, col, block_width);
         src += SSE_STRIDE;
       } else {
@@ -205,8 +200,7 @@
     for (int j = 0; j < block_width; j++, k++) {
       const int pixel_value = frame2[i * stride2 + j];
 
-      int diff_sse = acc_5x5_sse[i][j];
-      int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH;
+      uint32_t diff_sse = acc_5x5_sse[i][j];
 
       // Filter U-plane and V-plane using Y-plane. This is because motion
       // search is only done on Y-plane, so the information from Y-plane will
@@ -217,18 +211,16 @@
             const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
             const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
             diff_sse += luma_sq_error[yy * SSE_STRIDE + xx];
-            ++num_ref_pixels;
           }
         }
       }
 
-      const double window_error = (double)(diff_sse) / num_ref_pixels;
+      const double window_error = diff_sse * inv_num_ref_pixels;
       const int subblock_idx =
           (i >= block_height / 2) * 2 + (j >= block_width / 2);
       const double block_error = (double)subblock_mses[subblock_idx];
       const double combined_error =
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT * window_error + block_error) /
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) / TF_SEARCH_ERROR_NORM_WEIGHT;
+          weight_factor * window_error + block_error * inv_factor;
 
       const MV mv = subblock_mvs[subblock_idx];
       const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
@@ -236,8 +228,8 @@
           (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
       const double d_factor = AOMMAX(distance / distance_threshold, 1);
 
-      const double scaled_error =
-          AOMMIN(combined_error * d_factor / n_decay / q_decay / s_decay, 7);
+      double scaled_error = combined_error * d_factor * decay_factor;
+      scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
       count[k] += weight;
@@ -265,6 +257,18 @@
   const int frame_height = frame_to_filter->y_crop_height;
   const int frame_width = frame_to_filter->y_crop_width;
   const int min_frame_size = AOMMIN(frame_height, frame_width);
+  // Variables to simplify combined error calculation.
+  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
+                                   TF_SEARCH_ERROR_NORM_WEIGHT);
+  const double weight_factor =
+      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
+  // Decay factors for non-local mean approach.
+  // Smaller q -> smaller filtering weight.
+  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
+  q_decay = CLIP(q_decay, 1e-5, 1);
+  // Smaller strength -> smaller filtering weight.
+  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
+  s_decay = CLIP(s_decay, 1e-5, 1);
   uint16_t luma_sq_error[SSE_STRIDE * BH];
   uint16_t *chroma_sq_error =
       (num_planes > 0)
@@ -279,16 +283,22 @@
 
     const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset;
     const int ss_x_shift =
-        mbd->plane[plane].subsampling_x - mbd->plane[0].subsampling_x;
+        mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
     const int ss_y_shift =
-        mbd->plane[plane].subsampling_y - mbd->plane[0].subsampling_y;
+        mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
+    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
+                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
+    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
+    // Larger noise -> larger filtering weight.
+    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
+    const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
-    apply_temporal_filter(ref, frame_stride, pred + mb_pels * plane, plane_w,
-                          plane_w, plane_h, min_frame_size, noise_levels[plane],
-                          subblock_mvs, subblock_mses, q_factor,
-                          filter_strength, accum + mb_pels * plane,
-                          count + mb_pels * plane, luma_sq_error,
-                          chroma_sq_error, plane, ss_x_shift, ss_y_shift);
+    apply_temporal_filter(
+        ref, frame_stride, pred + mb_pels * plane, plane_w, plane_w, plane_h,
+        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
+        count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
+        ss_x_shift, ss_y_shift, inv_num_ref_pixels, decay_factor, inv_factor,
+        weight_factor);
   }
   if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }
diff --git a/av1/encoder/x86/temporal_filter_sse2.c b/av1/encoder/x86/temporal_filter_sse2.c
index 9fc92a6..a3a0981 100644
--- a/av1/encoder/x86/temporal_filter_sse2.c
+++ b/av1/encoder/x86/temporal_filter_sse2.c
@@ -105,10 +105,11 @@
 static void apply_temporal_filter(
     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
-    const int min_frame_size, const double sigma, const MV *subblock_mvs,
-    const int *subblock_mses, const int q_factor, const int filter_strength,
+    const int min_frame_size, const MV *subblock_mvs, const int *subblock_mses,
     unsigned int *accumulator, uint16_t *count, uint16_t *luma_sq_error,
-    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift) {
+    uint16_t *chroma_sq_error, int plane, int ss_x_shift, int ss_y_shift,
+    const double inv_num_ref_pixels, const double decay_factor,
+    const double inv_factor, const double weight_factor) {
   assert(((block_width == 32) && (block_height == 32)) ||
          ((block_width == 16) && (block_height == 16)));
   if (plane > PLANE_TYPE_Y) assert(chroma_sq_error != NULL);
@@ -122,12 +123,6 @@
 
   __m128i vsrc[5][2];
 
-  const double n_decay = 0.5 + log(2 * sigma + 5.0);
-  const double q_decay =
-      CLIP(pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2), 1e-5, 1);
-  const double s_decay =
-      CLIP(pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2), 1e-5, 1);
-
   // Traverse 4 columns at a time
   // First and last columns will require padding
   for (int col = 0; col < block_width; col += 4) {
@@ -182,8 +177,7 @@
     for (int j = 0; j < block_width; j++, k++) {
       const int pixel_value = frame2[i * stride2 + j];
 
-      int diff_sse = acc_5x5_sse[i][j];
-      int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH;
+      uint32_t diff_sse = acc_5x5_sse[i][j];
 
       // Filter U-plane and V-plane using Y-plane. This is because motion
       // search is only done on Y-plane, so the information from Y-plane will
@@ -195,18 +189,16 @@
             const int xx = (j << ss_x_shift) + jj + 2;  // X-coord on Y-plane.
             const int ww = SSE_STRIDE;                  // Stride of Y-plane.
             diff_sse += luma_sq_error[yy * ww + xx];
-            ++num_ref_pixels;
           }
         }
       }
 
-      const double window_error = (double)(diff_sse) / num_ref_pixels;
+      const double window_error = diff_sse * inv_num_ref_pixels;
       const int subblock_idx =
           (i >= block_height / 2) * 2 + (j >= block_width / 2);
       const double block_error = (double)subblock_mses[subblock_idx];
       const double combined_error =
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT * window_error + block_error) /
-          (TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) / TF_SEARCH_ERROR_NORM_WEIGHT;
+          weight_factor * window_error + block_error * inv_factor;
 
       const MV mv = subblock_mvs[subblock_idx];
       const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
@@ -214,8 +206,8 @@
           (double)AOMMAX(min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD, 1);
       const double d_factor = AOMMAX(distance / distance_threshold, 1);
 
-      const double scaled_error =
-          AOMMIN(combined_error * d_factor / n_decay / q_decay / s_decay, 7);
+      double scaled_error = combined_error * d_factor * decay_factor;
+      scaled_error = AOMMIN(scaled_error, 7);
       const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
       count[k] += weight;
@@ -243,6 +235,18 @@
   const int frame_height = frame_to_filter->y_crop_height;
   const int frame_width = frame_to_filter->y_crop_width;
   const int min_frame_size = AOMMIN(frame_height, frame_width);
+  // Variables to simplify combined error calculation.
+  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
+                                   TF_SEARCH_ERROR_NORM_WEIGHT);
+  const double weight_factor =
+      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
+  // Decay factors for non-local mean approach.
+  // Smaller q -> smaller filtering weight.
+  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
+  q_decay = CLIP(q_decay, 1e-5, 1);
+  // Smaller strength -> smaller filtering weight.
+  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
+  s_decay = CLIP(s_decay, 1e-5, 1);
   uint16_t luma_sq_error[SSE_STRIDE * BH];
   uint16_t *chroma_sq_error =
       (num_planes > 0)
@@ -257,16 +261,22 @@
 
     const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset;
     const int ss_x_shift =
-        mbd->plane[plane].subsampling_x - mbd->plane[0].subsampling_x;
+        mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
     const int ss_y_shift =
-        mbd->plane[plane].subsampling_y - mbd->plane[0].subsampling_y;
+        mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
+    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
+                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
+    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
+    // Larger noise -> larger filtering weight.
+    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
+    const double decay_factor = 1 / (n_decay * q_decay * s_decay);
 
-    apply_temporal_filter(ref, frame_stride, pred + mb_pels * plane, plane_w,
-                          plane_w, plane_h, min_frame_size, noise_levels[plane],
-                          subblock_mvs, subblock_mses, q_factor,
-                          filter_strength, accum + mb_pels * plane,
-                          count + mb_pels * plane, luma_sq_error,
-                          chroma_sq_error, plane, ss_x_shift, ss_y_shift);
+    apply_temporal_filter(
+        ref, frame_stride, pred + mb_pels * plane, plane_w, plane_w, plane_h,
+        min_frame_size, subblock_mvs, subblock_mses, accum + mb_pels * plane,
+        count + mb_pels * plane, luma_sq_error, chroma_sq_error, plane,
+        ss_x_shift, ss_y_shift, inv_num_ref_pixels, decay_factor, inv_factor,
+        weight_factor);
   }
   if (chroma_sq_error != NULL) aom_free(chroma_sq_error);
 }