Speed up weight calculation during temporal filtering

In the parent version, the weights used during temporal
filtering were calculated using the exp() function. In this CL,
exp() is replaced by approx_exp() for speed >= 5 (non-highbd only).

For 'good' encoding,
               Instruction Count      BD-Rate Loss(%)
cpu Resolution   Reduction(%)    avg.psnr  ovr.psnr    ssim
 5     LOWRES2      3.190         0.0171    -0.0000  -0.0147
 5     MIDRES2      4.240         0.0463     0.0181   0.0438
 5      HDRES2      5.180         0.0697     0.0537   0.0875
 6     LOWRES2      1.812         0.0216     0.0070  -0.0269
 6     MIDRES2      3.849         0.0256    -0.0472  -0.0362
 6      HDRES2      4.654        -0.0597    -0.0750  -0.0581

STATS_CHANGED for speed=5,6

Change-Id: I6b1b7086b834f230efa37002ef7c3e3ba3624aa9
diff --git a/aom_dsp/mathutils.h b/aom_dsp/mathutils.h
index 72572eb..035ca39 100644
--- a/aom_dsp/mathutils.h
+++ b/aom_dsp/mathutils.h
@@ -121,4 +121,19 @@
   }
 }
 
+static AOM_INLINE float approx_exp(float y) {
+#define A ((1 << 23) / 0.69314718056f)  // (1 << 23) / ln(2)
+#define B \
+  127  // Exponent bias of the IEEE single-precision floating point format.
+#define C 60801  // Magic number that controls the accuracy of approximation.
+  union {
+    float as_float;
+    int32_t as_int32;
+  } container;  // Lets the integer bit pattern be read back as a float.
+  container.as_int32 = ((int32_t)(y * A)) + ((B << 23) - C);
+  return container.as_float;
+#undef A
+#undef B
+#undef C
+}
 #endif  // AOM_AOM_DSP_MATHUTILS_H_
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 4aa0d42..ac44a5c 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -402,7 +402,7 @@
   # Motion search
   #
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/void av1_apply_temporal_filter/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double *noise_levels, const MV *subblock_mvs, const int *subblock_mses, const int q_factor, const int filter_strength, const uint8_t *pred, uint32_t *accum, uint16_t *count";
+    add_proto qw/void av1_apply_temporal_filter/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double *noise_levels, const MV *subblock_mvs, const int *subblock_mses, const int q_factor, const int filter_strength, int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum, uint16_t *count";
     specialize qw/av1_apply_temporal_filter sse2 avx2 neon/;
     if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
       add_proto qw/void av1_highbd_apply_temporal_filter/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double *noise_levels, const MV *subblock_mvs, const int *subblock_mses, const int q_factor, const int filter_strength, const uint8_t *pred, uint32_t *accum, uint16_t *count";
diff --git a/av1/encoder/arm/neon/temporal_filter_neon.c b/av1/encoder/arm/neon/temporal_filter_neon.c
index cae44f9..f4ec20f 100644
--- a/av1/encoder/arm/neon/temporal_filter_neon.c
+++ b/av1/encoder/arm/neon/temporal_filter_neon.c
@@ -14,6 +14,7 @@
 #include "config/av1_rtcd.h"
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/temporal_filter.h"
+#include "aom_dsp/mathutils.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/sum_neon.h"
 
@@ -77,7 +78,7 @@
     unsigned int *accumulator, uint16_t *count, uint8_t *frame_abs_diff,
     uint32_t *luma_sse_sum, const double inv_num_ref_pixels,
     const double decay_factor, const double inv_factor,
-    const double weight_factor, double *d_factor) {
+    const double weight_factor, double *d_factor, int tf_wgt_calc_lvl) {
   assert(((block_width == 16) || (block_width == 32)) &&
          ((block_height == 16) || (block_height == 32)));
 
@@ -142,24 +143,48 @@
   }
 
   // Perform filtering.
-  for (unsigned int i = 0, k = 0; i < block_height; i++) {
-    for (unsigned int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame[i * stride + j];
-      uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+  if (tf_wgt_calc_lvl == 0) {
+    for (unsigned int i = 0, k = 0; i < block_height; i++) {
+      for (unsigned int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame[i * stride + j];
+        uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
 
-      const double window_error = diff_sse * inv_num_ref_pixels;
-      const int subblock_idx =
-          (i >= block_height / 2) * 2 + (j >= block_width / 2);
-      const double block_error = (double)subblock_mses[subblock_idx];
-      const double combined_error =
-          weight_factor * window_error + block_error * inv_factor;
-      // Compute filter weight.
-      double scaled_error =
-          combined_error * d_factor[subblock_idx] * decay_factor;
-      scaled_error = AOMMIN(scaled_error, 7);
-      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
-      accumulator[k] += weight * pixel_value;
-      count[k] += weight;
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx =
+            (i >= block_height / 2) * 2 + (j >= block_width / 2);
+        const double block_error = (double)subblock_mses[subblock_idx];
+        const double combined_error =
+            weight_factor * window_error + block_error * inv_factor;
+        // Compute filter weight.
+        double scaled_error =
+            combined_error * d_factor[subblock_idx] * decay_factor;
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+      }
+    }
+  } else {
+    for (unsigned int i = 0, k = 0; i < block_height; i++) {
+      for (unsigned int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame[i * stride + j];
+        uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx =
+            (i >= block_height / 2) * 2 + (j >= block_width / 2);
+        const double block_error = (double)subblock_mses[subblock_idx];
+        const double combined_error =
+            weight_factor * window_error + block_error * inv_factor;
+        // Compute filter weight.
+        double scaled_error =
+            combined_error * d_factor[subblock_idx] * decay_factor;
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight =
+            (int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
+        accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+      }
     }
   }
 }
@@ -225,7 +250,7 @@
     unsigned int *accumulator, uint16_t *count, uint16_t *frame_sse,
     uint32_t *luma_sse_sum, const double inv_num_ref_pixels,
     const double decay_factor, const double inv_factor,
-    const double weight_factor, double *d_factor) {
+    const double weight_factor, double *d_factor, int tf_wgt_calc_lvl) {
   assert(((block_width == 16) || (block_width == 32)) &&
          ((block_height == 16) || (block_height == 32)));
 
@@ -273,24 +298,48 @@
   }
 
   // Perform filtering.
-  for (unsigned int i = 0, k = 0; i < block_height; i++) {
-    for (unsigned int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame[i * stride + j];
-      uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+  if (tf_wgt_calc_lvl == 0) {
+    for (unsigned int i = 0, k = 0; i < block_height; i++) {
+      for (unsigned int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame[i * stride + j];
+        uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
 
-      const double window_error = diff_sse * inv_num_ref_pixels;
-      const int subblock_idx =
-          (i >= block_height / 2) * 2 + (j >= block_width / 2);
-      const double block_error = (double)subblock_mses[subblock_idx];
-      const double combined_error =
-          weight_factor * window_error + block_error * inv_factor;
-      // Compute filter weight.
-      double scaled_error =
-          combined_error * d_factor[subblock_idx] * decay_factor;
-      scaled_error = AOMMIN(scaled_error, 7);
-      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
-      accumulator[k] += weight * pixel_value;
-      count[k] += weight;
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx =
+            (i >= block_height / 2) * 2 + (j >= block_width / 2);
+        const double block_error = (double)subblock_mses[subblock_idx];
+        const double combined_error =
+            weight_factor * window_error + block_error * inv_factor;
+        // Compute filter weight.
+        double scaled_error =
+            combined_error * d_factor[subblock_idx] * decay_factor;
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+      }
+    }
+  } else {
+    for (unsigned int i = 0, k = 0; i < block_height; i++) {
+      for (unsigned int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame[i * stride + j];
+        uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx =
+            (i >= block_height / 2) * 2 + (j >= block_width / 2);
+        const double block_error = (double)subblock_mses[subblock_idx];
+        const double combined_error =
+            weight_factor * window_error + block_error * inv_factor;
+        // Compute filter weight.
+        double scaled_error =
+            combined_error * d_factor[subblock_idx] * decay_factor;
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight =
+            (int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
+        accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+      }
     }
   }
 }
@@ -302,7 +351,8 @@
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
     const int *subblock_mses, const int q_factor, const int filter_strength,
-    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
+    int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
   assert(block_size == BLOCK_32X32 && "Only support 32x32 block with Neon!");
   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!");
@@ -403,7 +453,7 @@
                           subblock_mses, accum + plane_offset,
                           count + plane_offset, frame_abs_diff, luma_sse_sum,
                           inv_num_ref_pixels, decay_factor, inv_factor,
-                          weight_factor, d_factor);
+                          weight_factor, d_factor, tf_wgt_calc_lvl);
 #else   // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
     if (plane == AOM_PLANE_U) {
       for (unsigned int i = 0; i < plane_h; i++) {
@@ -422,10 +472,11 @@
     get_squared_error(ref, frame_stride, pred + plane_offset, plane_w, plane_w,
                       plane_h, frame_sse, SSE_STRIDE);
 
-    apply_temporal_filter(
-        pred + plane_offset, plane_w, plane_w, plane_h, subblock_mses,
-        accum + plane_offset, count + plane_offset, frame_sse, luma_sse_sum,
-        inv_num_ref_pixels, decay_factor, inv_factor, weight_factor, d_factor);
+    apply_temporal_filter(pred + plane_offset, plane_w, plane_w, plane_h,
+                          subblock_mses, accum + plane_offset,
+                          count + plane_offset, frame_sse, luma_sse_sum,
+                          inv_num_ref_pixels, decay_factor, inv_factor,
+                          weight_factor, d_factor, tf_wgt_calc_lvl);
 #endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
     plane_offset += plane_h * plane_w;
diff --git a/av1/encoder/ml.c b/av1/encoder/ml.c
index 5078fb1..94cd56c 100644
--- a/av1/encoder/ml.c
+++ b/av1/encoder/ml.c
@@ -13,6 +13,7 @@
 #include <math.h>
 
 #include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/mathutils.h"
 #include "av1/encoder/ml.h"
 
 void av1_nn_output_prec_reduce(float *const output, int num_output) {
@@ -155,22 +156,6 @@
   for (int i = 0; i < n; i++) output[i] /= sum_out;
 }
 
-static AOM_INLINE float approx_exp(float y) {
-#define A ((1 << 23) / 0.69314718056f)  // (1 << 23) / ln(2)
-#define B \
-  127  // Offset for the exponent according to IEEE floating point standard.
-#define C 60801  // Magic number controls the accuracy of approximation
-  union {
-    float as_float;
-    int32_t as_int32;
-  } container;
-  container.as_int32 = ((int32_t)(y * A)) + ((B << 23) - C);
-  return container.as_float;
-#undef A
-#undef B
-#undef C
-}
-
 void av1_nn_fast_softmax_16_c(const float *input, float *output) {
   const int kNumClasses = 16;
   float max_input = input[0];
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 9e42391..a0cc5a6 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -1183,6 +1183,9 @@
   }
 
   if (speed >= 5) {
+    // TODO(Ranjit): Enable the optimization for highbd encoding mode
+    sf->hl_sf.weight_calc_level_in_tf = use_hbd ? 0 : 1;
+
     sf->fp_sf.reduce_mv_step_param = 4;
 
     sf->part_sf.simple_motion_search_prune_agg =
@@ -1817,6 +1820,7 @@
   hl_sf->second_alt_ref_filtering = 1;
   hl_sf->num_frames_used_in_tf = INT_MAX;
   hl_sf->accurate_bit_estimate = 0;
+  hl_sf->weight_calc_level_in_tf = 0;
 }
 
 static AOM_INLINE void init_fp_sf(FIRST_PASS_SPEED_FEATURES *fp_sf) {
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 4ff85f9..bef24c1 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -442,6 +442,13 @@
    * 1: estimate bits more accurately based on the frame complexity.
    */
   int accurate_bit_estimate;
+
+  /*!
+   * Decide the approach for weight calculation during temporal filtering.
+   * 0: Calculate weight using exp()
+   * 1: Calculate weight using approx_exp(), a fast approximation of exp().
+   */
+  int weight_calc_level_in_tf;
 } HIGH_LEVEL_SPEED_FEATURES;
 
 /*!
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 7540bd8..4cfb4e5 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -16,6 +16,7 @@
 #include "config/aom_scale_rtcd.h"
 
 #include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/mathutils.h"
 #include "aom_dsp/odintrin.h"
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/aom_timer.h"
@@ -545,6 +546,8 @@
  *                              defined in libaom, converted from `qindex`
  * \param[in]   filter_strength Filtering strength. This value lies in range
  *                              [0, 6] where 6 is the maximum strength.
+ * \param[in]   tf_wgt_calc_lvl Weight calculation method during temporal
+ *                              filtering: 0 uses exp(), otherwise approx_exp()
  * \param[out]  pred            Pointer to the well-built predictors
  * \param[out]  accum           Pointer to the pixel-wise accumulator for
  *                              filtering
@@ -559,7 +562,8 @@
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
     const int *subblock_mses, const int q_factor, const int filter_strength,
-    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
+    int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
   // Block information.
   const int mb_height = block_size_high[block_size];
   const int mb_width = block_size_wide[block_size];
@@ -689,7 +693,13 @@
         double scaled_error =
             combined_error * d_factor[subblock_idx] * decay_factor[plane];
         scaled_error = AOMMIN(scaled_error, 7);
-        const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        int weight;
+        if (tf_wgt_calc_lvl == 0) {
+          weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        } else {
+          weight =
+              (int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
+        }
 
         const int idx = plane_offset + pred_idx;  // Index with plane shift.
         const int pred_value = is_high_bitdepth ? pred16[idx] : pred[idx];
@@ -715,7 +725,7 @@
     const uint8_t *pred, uint32_t *accum, uint16_t *count) {
   av1_apply_temporal_filter_c(frame_to_filter, mbd, block_size, mb_row, mb_col,
                               num_planes, noise_levels, subblock_mvs,
-                              subblock_mses, q_factor, filter_strength, pred,
+                              subblock_mses, q_factor, filter_strength, 0, pred,
                               accum, count);
 }
 #endif  // CONFIG_AV1_HIGHBITDEPTH
@@ -867,21 +877,23 @@
             av1_apply_temporal_filter_c(
                 frame_to_filter, mbd, block_size, mb_row, mb_col, num_planes,
                 noise_levels, subblock_mvs, subblock_mses, q_factor,
-                filter_strength, pred, accum, count);
+                filter_strength, 0, pred, accum, count);
 #if CONFIG_AV1_HIGHBITDEPTH
           }
 #endif            // CONFIG_AV1_HIGHBITDEPTH
         } else {  // for 8-bit
           if (TF_BLOCK_SIZE == BLOCK_32X32 && TF_WINDOW_LENGTH == 5) {
-            av1_apply_temporal_filter(frame_to_filter, mbd, block_size, mb_row,
-                                      mb_col, num_planes, noise_levels,
-                                      subblock_mvs, subblock_mses, q_factor,
-                                      filter_strength, pred, accum, count);
+            av1_apply_temporal_filter(
+                frame_to_filter, mbd, block_size, mb_row, mb_col, num_planes,
+                noise_levels, subblock_mvs, subblock_mses, q_factor,
+                filter_strength, cpi->sf.hl_sf.weight_calc_level_in_tf, pred,
+                accum, count);
           } else {
             av1_apply_temporal_filter_c(
                 frame_to_filter, mbd, block_size, mb_row, mb_col, num_planes,
                 noise_levels, subblock_mvs, subblock_mses, q_factor,
-                filter_strength, pred, accum, count);
+                filter_strength, cpi->sf.hl_sf.weight_calc_level_in_tf, pred,
+                accum, count);
           }
         }
       }
diff --git a/av1/encoder/x86/temporal_filter_avx2.c b/av1/encoder/x86/temporal_filter_avx2.c
index 32b9d4d..ff1b49f 100644
--- a/av1/encoder/x86/temporal_filter_avx2.c
+++ b/av1/encoder/x86/temporal_filter_avx2.c
@@ -13,6 +13,7 @@
 #include <immintrin.h>
 
 #include "config/av1_rtcd.h"
+#include "aom_dsp/mathutils.h"
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/temporal_filter.h"
 
@@ -133,7 +134,8 @@
     const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
     uint16_t *frame_sse, uint32_t *luma_sse_sum,
     const double inv_num_ref_pixels, const double decay_factor,
-    const double inv_factor, const double weight_factor, double *d_factor) {
+    const double inv_factor, const double weight_factor, double *d_factor,
+    int tf_wgt_calc_lvl) {
   assert(((block_width == 16) || (block_width == 32)) &&
          ((block_height == 16) || (block_height == 32)));
 
@@ -198,23 +200,46 @@
     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
   }
-  for (int i = 0, k = 0; i < block_height; i++) {
-    const int y_blk_raster_offset = (i >= block_height / 2) * 2;
-    for (int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame2[i * stride2 + j];
-      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
+  if (tf_wgt_calc_lvl == 0) {
+    for (int i = 0, k = 0; i < block_height; i++) {
+      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      for (int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame2[i * stride2 + j];
+        uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
 
-      const double window_error = diff_sse * inv_num_ref_pixels;
-      const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
-      const double combined_error =
-          weight_factor * window_error + subblock_mses_scaled[subblock_idx];
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const double combined_error =
+            weight_factor * window_error + subblock_mses_scaled[subblock_idx];
 
-      double scaled_error = combined_error * d_factor_decayed[subblock_idx];
-      scaled_error = AOMMIN(scaled_error, 7);
-      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        double scaled_error = combined_error * d_factor_decayed[subblock_idx];
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
-      count[k] += weight;
-      accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+        accumulator[k] += weight * pixel_value;
+      }
+    }
+  } else {
+    for (int i = 0, k = 0; i < block_height; i++) {
+      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      for (int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame2[i * stride2 + j];
+        uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
+
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const double combined_error =
+            weight_factor * window_error + subblock_mses_scaled[subblock_idx];
+
+        double scaled_error = combined_error * d_factor_decayed[subblock_idx];
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight =
+            (int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
+
+        count[k] += weight;
+        accumulator[k] += weight * pixel_value;
+      }
     }
   }
 }
@@ -224,7 +249,8 @@
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
     const int *subblock_mses, const int q_factor, const int filter_strength,
-    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
+    int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
   assert(block_size == BLOCK_32X32 && "Only support 32x32 block with avx2!");
   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with avx2!");
@@ -312,7 +338,7 @@
                           plane_w, plane_h, subblock_mses, accum + plane_offset,
                           count + plane_offset, frame_sse, luma_sse_sum,
                           inv_num_ref_pixels, decay_factor, inv_factor,
-                          weight_factor, d_factor);
+                          weight_factor, d_factor, tf_wgt_calc_lvl);
     plane_offset += plane_h * plane_w;
   }
 }
diff --git a/av1/encoder/x86/temporal_filter_sse2.c b/av1/encoder/x86/temporal_filter_sse2.c
index 9bb7148..b0eb2f1 100644
--- a/av1/encoder/x86/temporal_filter_sse2.c
+++ b/av1/encoder/x86/temporal_filter_sse2.c
@@ -13,6 +13,7 @@
 #include <emmintrin.h>
 
 #include "config/av1_rtcd.h"
+#include "aom_dsp/mathutils.h"
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/temporal_filter.h"
 
@@ -107,7 +108,8 @@
     const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
     uint16_t *frame_sse, uint32_t *luma_sse_sum,
     const double inv_num_ref_pixels, const double decay_factor,
-    const double inv_factor, const double weight_factor, double *d_factor) {
+    const double inv_factor, const double weight_factor, double *d_factor,
+    int tf_wgt_calc_lvl) {
   assert(((block_width == 16) || (block_width == 32)) &&
          ((block_height == 16) || (block_height == 32)));
 
@@ -174,23 +176,46 @@
     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
   }
-  for (int i = 0, k = 0; i < block_height; i++) {
-    const int y_blk_raster_offset = (i >= block_height / 2) * 2;
-    for (int j = 0; j < block_width; j++, k++) {
-      const int pixel_value = frame2[i * stride2 + j];
-      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
+  if (tf_wgt_calc_lvl == 0) {
+    for (int i = 0, k = 0; i < block_height; i++) {
+      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      for (int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame2[i * stride2 + j];
+        uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
 
-      const double window_error = diff_sse * inv_num_ref_pixels;
-      const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
-      const double combined_error =
-          weight_factor * window_error + subblock_mses_scaled[subblock_idx];
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const double combined_error =
+            weight_factor * window_error + subblock_mses_scaled[subblock_idx];
 
-      double scaled_error = combined_error * d_factor_decayed[subblock_idx];
-      scaled_error = AOMMIN(scaled_error, 7);
-      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+        double scaled_error = combined_error * d_factor_decayed[subblock_idx];
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
 
-      count[k] += weight;
-      accumulator[k] += weight * pixel_value;
+        count[k] += weight;
+        accumulator[k] += weight * pixel_value;
+      }
+    }
+  } else {
+    for (int i = 0, k = 0; i < block_height; i++) {
+      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      for (int j = 0; j < block_width; j++, k++) {
+        const int pixel_value = frame2[i * stride2 + j];
+        uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
+
+        const double window_error = diff_sse * inv_num_ref_pixels;
+        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const double combined_error =
+            weight_factor * window_error + subblock_mses_scaled[subblock_idx];
+
+        double scaled_error = combined_error * d_factor_decayed[subblock_idx];
+        scaled_error = AOMMIN(scaled_error, 7);
+        const int weight =
+            (int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
+
+        count[k] += weight;
+        accumulator[k] += weight * pixel_value;
+      }
     }
   }
 }
@@ -200,7 +225,8 @@
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
     const int *subblock_mses, const int q_factor, const int filter_strength,
-    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
+    int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
+    uint16_t *count) {
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
   assert(block_size == BLOCK_32X32 && "Only support 32x32 block with sse2!");
   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with sse2!");
@@ -288,7 +314,7 @@
                           plane_w, plane_h, subblock_mses, accum + plane_offset,
                           count + plane_offset, frame_sse, luma_sse_sum,
                           inv_num_ref_pixels, decay_factor, inv_factor,
-                          weight_factor, d_factor);
+                          weight_factor, d_factor, tf_wgt_calc_lvl);
     plane_offset += plane_h * plane_w;
   }
 }
diff --git a/test/temporal_filter_test.cc b/test/temporal_filter_test.cc
index 154fd5d..79d7e52 100644
--- a/test/temporal_filter_test.cc
+++ b/test/temporal_filter_test.cc
@@ -51,7 +51,7 @@
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
     const int num_planes, const double *noise_level, const MV *subblock_mvs,
     const int *subblock_mses, const int q_factor, const int filter_strenght,
-    const uint8_t *pred, uint32_t *accum, uint16_t *count);
+    int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum, uint16_t *count);
 typedef libaom_test::FuncParam<TemporalFilterFunc> TemporalFilterFuncParam;
 
 typedef std::tuple<TemporalFilterFuncParam, int> TemporalFilterWithParam;
@@ -62,6 +62,7 @@
   virtual ~TemporalFilterTest() {}
   virtual void SetUp() {
     params_ = GET_PARAM(0);
+    tf_wgt_calc_lvl_ = GET_PARAM(1);
     rnd_.Reset(ACMRandom::DeterministicSeed());
     src1_ = reinterpret_cast<uint8_t *>(
         aom_memalign(8, sizeof(uint8_t) * MAX_MB_PLANE * BH * BW));
@@ -121,6 +122,7 @@
 
  protected:
   TemporalFilterFuncParam params_;
+  int32_t tf_wgt_calc_lvl_;
   uint8_t *src1_;
   uint8_t *src2_;
   ACMRandom rnd_;
@@ -208,18 +210,20 @@
 
     params_.ref_func(ref_frame.get(), mbd.get(), block_size, mb_row, mb_col,
                      num_planes, sigma, subblock_mvs, subblock_mses, q_factor,
-                     filter_strength, src2_, accumulator_ref, count_ref);
+                     filter_strength, tf_wgt_calc_lvl_, src2_, accumulator_ref,
+                     count_ref);
     params_.tst_func(ref_frame.get(), mbd.get(), block_size, mb_row, mb_col,
                      num_planes, sigma, subblock_mvs, subblock_mses, q_factor,
-                     filter_strength, src2_, accumulator_mod, count_mod);
+                     filter_strength, tf_wgt_calc_lvl_, src2_, accumulator_mod,
+                     count_mod);
 
     if (run_times > 1) {
       aom_usec_timer_start(&ref_timer);
       for (int j = 0; j < run_times; j++) {
         params_.ref_func(ref_frame.get(), mbd.get(), block_size, mb_row, mb_col,
                          num_planes, sigma, subblock_mvs, subblock_mses,
-                         q_factor, filter_strength, src2_, accumulator_ref,
-                         count_ref);
+                         q_factor, filter_strength, tf_wgt_calc_lvl_, src2_,
+                         accumulator_ref, count_ref);
       }
       aom_usec_timer_mark(&ref_timer);
       const int elapsed_time_c =
@@ -229,8 +233,8 @@
       for (int j = 0; j < run_times; j++) {
         params_.tst_func(ref_frame.get(), mbd.get(), block_size, mb_row, mb_col,
                          num_planes, sigma, subblock_mvs, subblock_mses,
-                         q_factor, filter_strength, src2_, accumulator_mod,
-                         count_mod);
+                         q_factor, filter_strength, tf_wgt_calc_lvl_, src2_,
+                         accumulator_mod, count_mod);
       }
       aom_usec_timer_mark(&test_timer);
       const int elapsed_time_simd =
@@ -286,7 +290,7 @@
     &av1_apply_temporal_filter_c, &av1_apply_temporal_filter_avx2) };
 INSTANTIATE_TEST_SUITE_P(AVX2, TemporalFilterTest,
                          Combine(ValuesIn(temporal_filter_test_avx2),
-                                 Range(64, 65, 4)));
+                                 Values(0, 1)));
 #endif  // HAVE_AVX2
 
 #if HAVE_SSE2
@@ -294,7 +298,7 @@
     &av1_apply_temporal_filter_c, &av1_apply_temporal_filter_sse2) };
 INSTANTIATE_TEST_SUITE_P(SSE2, TemporalFilterTest,
                          Combine(ValuesIn(temporal_filter_test_sse2),
-                                 Range(64, 65, 4)));
+                                 Values(0, 1)));
 #endif  // HAVE_SSE2
 
 #if HAVE_NEON
@@ -302,7 +306,7 @@
     &av1_apply_temporal_filter_c, &av1_apply_temporal_filter_neon) };
 INSTANTIATE_TEST_SUITE_P(NEON, TemporalFilterTest,
                          Combine(ValuesIn(temporal_filter_test_neon),
-                                 Range(64, 65, 4)));
+                                 Values(0, 1)));
 #endif  // HAVE_NEON
 
 #if CONFIG_AV1_HIGHBITDEPTH