Add Neon implementation of av1_apply_temporal_filter

Add an Arm Neon implementation of av1_apply_temporal_filter and use
it instead of the scalar C implementation for 32x32 blocks.

Also add test coverage for the new Neon implementation.

Change-Id: I965f489be562d13b3bc74a1ed57c2689c916fbd0
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index e7872d9..a49ede4 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -15,6 +15,25 @@
 #include <string.h>
 #include "aom_dsp/aom_dsp_common.h"
 
+// Support for these xN intrinsics is lacking in older compilers.
+#if (defined(_MSC_VER) && !defined(__clang__) && !defined(_M_ARM64)) || \
+    (defined(__GNUC__) &&                                               \
+     ((!defined(__clang__) && (__GNUC__ < 8 || defined(__arm__))) ||    \
+      (defined(__clang__) && defined(__arm__) &&                        \
+       (__clang_major__ <= 6 ||                                         \
+        (defined(__ANDROID__) && __clang_major__ <= 7)))))
+static INLINE uint16x8x4_t vld1q_u16_x4(uint16_t const *ptr) {
+  uint16x8x4_t res = { { vld1q_u16(ptr + 0 * 8), vld1q_u16(ptr + 1 * 8),
+                         vld1q_u16(ptr + 2 * 8), vld1q_u16(ptr + 3 * 8) } };
+  return res;
+}
+#endif  // (defined(_MSC_VER) && !defined(__clang__) && !defined(_M_ARM64)) ||
+        // (defined(__GNUC__) &&
+        //  ((!defined(__clang__) && (__GNUC__ < 8 || defined(__arm__))) ||
+        //   (defined(__clang__) && defined(__arm__) &&
+        //    (__clang_major__ <= 6 ||
+        //     (defined(__ANDROID__) && __clang_major__ <= 7)))))
+
 static INLINE void store_row2_u8_8x8(uint8_t *s, int p, const uint8x8_t s0,
                                      const uint8x8_t s1) {
   vst1_u8(s, s0);
diff --git a/av1/av1.cmake b/av1/av1.cmake
index fc99b25..ace30c5 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -365,7 +365,8 @@
             "${AOM_ROOT}/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c"
-            "${AOM_ROOT}/av1/encoder/arm/neon/wedge_utils_neon.c")
+            "${AOM_ROOT}/av1/encoder/arm/neon/wedge_utils_neon.c"
+            "${AOM_ROOT}/av1/encoder/arm/neon/temporal_filter_neon.c")
 
 list(APPEND AOM_AV1_ENCODER_INTRIN_CRC32
             "${AOM_ROOT}/av1/encoder/arm/crc32/hash_crc32.c")
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index fb650a8..333a72d 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -403,7 +403,7 @@
   #
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
     add_proto qw/void av1_apply_temporal_filter/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double *noise_levels, const MV *subblock_mvs, const int *subblock_mses, const int q_factor, const int filter_strength, const uint8_t *pred, uint32_t *accum, uint16_t *count";
-    specialize qw/av1_apply_temporal_filter sse2 avx2/;
+    specialize qw/av1_apply_temporal_filter sse2 avx2 neon/;
     if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
       add_proto qw/void av1_highbd_apply_temporal_filter/, "const struct yv12_buffer_config *ref_frame, const struct macroblockd *mbd, const BLOCK_SIZE block_size, const int mb_row, const int mb_col, const int num_planes, const double *noise_levels, const MV *subblock_mvs, const int *subblock_mses, const int q_factor, const int filter_strength, const uint8_t *pred, uint32_t *accum, uint16_t *count";
       specialize qw/av1_highbd_apply_temporal_filter sse2 avx2/;
diff --git a/av1/encoder/arm/neon/temporal_filter_neon.c b/av1/encoder/arm/neon/temporal_filter_neon.c
new file mode 100644
index 0000000..4765e1a
--- /dev/null
+++ b/av1/encoder/arm/neon/temporal_filter_neon.c
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "config/av1_rtcd.h"
+#include "av1/encoder/encoder.h"
+#include "av1/encoder/temporal_filter.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+
+// For the squared error buffer, add padding for 4 samples.
+#define SSE_STRIDE (BW + 4)
+
+DECLARE_ALIGNED(16, static const uint16_t, kSlidingWindowMask[]) = {
+  0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000,
+  0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000,
+  0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000,
+  0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF
+};
+
+static INLINE void get_squared_error(
+    const uint8_t *frame1, const uint32_t stride1, const uint8_t *frame2,
+    const uint32_t stride2, const uint32_t block_width,
+    const uint32_t block_height, uint16_t *frame_sse,
+    const unsigned int dst_stride) {
+  uint16_t *dst = frame_sse;
+
+  uint32_t i = 0;
+  do {
+    uint32_t j = 0;
+    do {
+      uint8x16_t s = vld1q_u8(frame1 + i * stride1 + j);
+      uint8x16_t r = vld1q_u8(frame2 + i * stride2 + j);
+
+      uint8x16_t abs_diff = vabdq_u8(s, r);
+      uint16x8_t sse_lo =
+          vmull_u8(vget_low_u8(abs_diff), vget_low_u8(abs_diff));
+      uint16x8_t sse_hi =
+          vmull_u8(vget_high_u8(abs_diff), vget_high_u8(abs_diff));
+
+      vst1q_u16(dst + j + 2, sse_lo);
+      vst1q_u16(dst + j + 10, sse_hi);
+
+      j += 16;
+    } while (j < block_width);
+
+    dst += dst_stride;
+    i++;
+  } while (i < block_height);
+}
+
+static INLINE uint16x8_t load_and_pad(uint16_t *src, const uint32_t col,
+                                      const uint32_t block_width) {
+  uint16x8_t s = vld1q_u16(src);
+
+  if (col == 0) {
+    s[0] = s[2];
+    s[1] = s[2];
+  } else if (col >= block_width - 4) {
+    s[6] = s[5];
+    s[7] = s[5];
+  }
+  return s;
+}
+
+static void apply_temporal_filter(
+    const uint8_t *frame, const unsigned int stride, const uint32_t block_width,
+    const uint32_t block_height, const int *subblock_mses,
+    unsigned int *accumulator, uint16_t *count, uint16_t *frame_sse,
+    uint32_t *luma_sse_sum, const double inv_num_ref_pixels,
+    const double decay_factor, const double inv_factor,
+    const double weight_factor, double *d_factor) {
+  assert(((block_width == 16) || (block_width == 32)) &&
+         ((block_height == 16) || (block_height == 32)));
+
+  uint32_t acc_5x5_neon[BH][BW];
+  const uint16x8x4_t vmask = vld1q_u16_x4(kSlidingWindowMask);
+
+  // Traverse 4 columns at a time - first and last two columns need padding.
+  for (uint32_t col = 0; col < block_width; col += 4) {
+    uint16x8_t vsrc[5];
+    uint16_t *src = frame_sse + col;
+
+    // Load and pad (for first and last two columns) 3 rows from the top.
+    for (int i = 2; i < 5; i++) {
+      vsrc[i] = load_and_pad(src, col, block_width);
+      src += SSE_STRIDE;
+    }
+
+    // Pad the top 2 rows.
+    vsrc[0] = vsrc[2];
+    vsrc[1] = vsrc[2];
+
+    for (unsigned int row = 0; row < block_height; row++) {
+      for (int i = 0; i < 4; i++) {
+        uint32x4_t vsum = vdupq_n_u32(0);
+        for (int j = 0; j < 5; j++) {
+          vsum = vpadalq_u16(vsum, vandq_u16(vsrc[j], vmask.val[i]));
+        }
+        acc_5x5_neon[row][col + i] = horizontal_add_u32x4(vsum);
+      }
+
+      // Push all rows in the sliding window up one.
+      for (int i = 0; i < 4; i++) {
+        vsrc[i] = vsrc[i + 1];
+      }
+
+      if (row <= block_height - 4) {
+        // Load next row into the bottom of the sliding window.
+        vsrc[4] = load_and_pad(src, col, block_width);
+        src += SSE_STRIDE;
+      } else {
+        // Pad the bottom 2 rows.
+        vsrc[4] = vsrc[3];
+      }
+    }
+  }
+
+  // Perform filtering.
+  for (unsigned int i = 0, k = 0; i < block_height; i++) {
+    for (unsigned int j = 0; j < block_width; j++, k++) {
+      const int pixel_value = frame[i * stride + j];
+      uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+
+      const double window_error = diff_sse * inv_num_ref_pixels;
+      const int subblock_idx =
+          (i >= block_height / 2) * 2 + (j >= block_width / 2);
+      const double block_error = (double)subblock_mses[subblock_idx];
+      const double combined_error =
+          weight_factor * window_error + block_error * inv_factor;
+      // Compute filter weight.
+      double scaled_error =
+          combined_error * d_factor[subblock_idx] * decay_factor;
+      scaled_error = AOMMIN(scaled_error, 7);
+      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+      accumulator[k] += weight * pixel_value;
+      count[k] += weight;
+    }
+  }
+}
+
+void av1_apply_temporal_filter_neon(
+    const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
+    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
+    const int num_planes, const double *noise_levels, const MV *subblock_mvs,
+    const int *subblock_mses, const int q_factor, const int filter_strength,
+    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
+  const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
+  assert(block_size == BLOCK_32X32 && "Only support 32x32 block with Neon!");
+  assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!");
+  assert(!is_high_bitdepth && "Only support low bit-depth with Neon!");
+  assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
+  (void)is_high_bitdepth;
+
+  // Block information.
+  const int mb_height = block_size_high[block_size];
+  const int mb_width = block_size_wide[block_size];
+  // Frame information.
+  const int frame_height = frame_to_filter->y_crop_height;
+  const int frame_width = frame_to_filter->y_crop_width;
+  const int min_frame_size = AOMMIN(frame_height, frame_width);
+  // Variables to simplify combined error calculation.
+  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
+                                   TF_SEARCH_ERROR_NORM_WEIGHT);
+  const double weight_factor =
+      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
+  // Adjust filtering based on q.
+  // Larger q -> stronger filtering -> larger weight.
+  // Smaller q -> weaker filtering -> smaller weight.
+  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
+  q_decay = CLIP(q_decay, 1e-5, 1);
+  if (q_factor >= TF_QINDEX_CUTOFF) {
+    // Max q_factor is 255, therefore the upper bound of q_decay is 8.
+    // We do not need a clip here.
+    q_decay = 0.5 * pow((double)q_factor / 64, 2);
+  }
+  // Smaller strength -> smaller filtering weight.
+  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
+  s_decay = CLIP(s_decay, 1e-5, 1);
+  double d_factor[4] = { 0 };
+  uint16_t frame_sse[SSE_STRIDE * BH] = { 0 };
+  uint32_t luma_sse_sum[BW * BH] = { 0 };
+
+  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+    // Larger motion vector -> smaller filtering weight.
+    const MV mv = subblock_mvs[subblock_idx];
+    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
+    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+    distance_threshold = AOMMAX(distance_threshold, 1);
+    d_factor[subblock_idx] = distance / distance_threshold;
+    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
+  }
+
+  // Handle planes in sequence.
+  int plane_offset = 0;
+  for (int plane = 0; plane < num_planes; ++plane) {
+    const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
+    const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
+    const uint32_t frame_stride =
+        frame_to_filter->strides[plane == AOM_PLANE_Y ? 0 : 1];
+    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
+
+    const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset;
+    const int ss_x_shift =
+        mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
+    const int ss_y_shift =
+        mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
+    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
+                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
+    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
+    // Larger noise -> larger filtering weight.
+    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
+    // Decay factors for non-local mean approach.
+    const double decay_factor = 1 / (n_decay * q_decay * s_decay);
+
+    // Filter U-plane and V-plane using Y-plane. This is because motion
+    // search is only done on Y-plane, so the information from Y-plane
+    // will be more accurate. The luma sse sum is reused in both chroma
+    // planes.
+    if (plane == AOM_PLANE_U) {
+      for (unsigned int i = 0; i < plane_h; i++) {
+        for (unsigned int j = 0; j < plane_w; j++) {
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
+            }
+          }
+        }
+      }
+    }
+
+    get_squared_error(ref, frame_stride, pred + plane_offset, plane_w, plane_w,
+                      plane_h, frame_sse, SSE_STRIDE);
+
+    apply_temporal_filter(
+        pred + plane_offset, plane_w, plane_w, plane_h, subblock_mses,
+        accum + plane_offset, count + plane_offset, frame_sse, luma_sse_sum,
+        inv_num_ref_pixels, decay_factor, inv_factor, weight_factor, d_factor);
+
+    plane_offset += plane_h * plane_w;
+  }
+}
diff --git a/test/temporal_filter_test.cc b/test/temporal_filter_test.cc
index bf61f02..154fd5d 100644
--- a/test/temporal_filter_test.cc
+++ b/test/temporal_filter_test.cc
@@ -296,6 +296,15 @@
                          Combine(ValuesIn(temporal_filter_test_sse2),
                                  Range(64, 65, 4)));
 #endif  // HAVE_SSE2
+
+#if HAVE_NEON
+TemporalFilterFuncParam temporal_filter_test_neon[] = { TemporalFilterFuncParam(
+    &av1_apply_temporal_filter_c, &av1_apply_temporal_filter_neon) };
+INSTANTIATE_TEST_SUITE_P(NEON, TemporalFilterTest,
+                         Combine(ValuesIn(temporal_filter_test_neon),
+                                 Range(64, 65, 4)));
+#endif  // HAVE_NEON
+
 #if CONFIG_AV1_HIGHBITDEPTH
 
 typedef void (*HBDTemporalFilterFunc)(