Optimize Neon temporal filter using UDOT instruction

Add an alternative AArch64 implementation of
av1_apply_temporal_filter_neon that uses the Armv8.4-A UDOT (unsigned
dot-product) instruction to accelerate the 5x5 sum of squared error
(SSE) calculation.
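
As an illustrative sketch (not part of the patch), the building block is
that a UDOT dot product of an absolute-difference vector with itself
accumulates four squared differences into each 32-bit lane; the helper
name sse_u8x16 below is hypothetical:

  #include <arm_neon.h>

  // Requires __ARM_FEATURE_DOTPROD. Each u32 lane receives the sum of
  // four squared differences.
  static inline uint32x4_t sse_u8x16(uint8x16_t a, uint8x16_t b) {
    const uint8x16_t abs_diff = vabdq_u8(a, b);
    return vdotq_u32(vdupq_n_u32(0), abs_diff, abs_diff);
  }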

The existing implementation is retained for use on target CPUs that do
not implement the UDOT instruction (or CPUs executing in AArch32
mode). The availability of the UDOT instruction is indicated by the
feature macro __ARM_FEATURE_DOTPROD.
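
As an illustrative example (not part of the patch), compiling with
-march=armv8.2-a+dotprod, or with any -march level of Armv8.4-A or
later, causes GCC and Clang to define __ARM_FEATURE_DOTPROD and select
the new code path; other configurations keep the existing Neon path.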

Change-Id: I4e5eb8edeb0a3b629a9aabde21058ba87afb501b
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index a49ede4..1a93f35 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -22,6 +22,11 @@
       (defined(__clang__) && defined(__arm__) &&                        \
        (__clang_major__ <= 6 ||                                         \
         (defined(__ANDROID__) && __clang_major__ <= 7)))))
+static INLINE uint8x16x2_t vld1q_u8_x2(uint8_t const *ptr) {
+  uint8x16x2_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16) } };
+  return res;
+}
+
 static INLINE uint16x8x4_t vld1q_u16_x4(uint16_t const *ptr) {
   uint16x8x4_t res = { { vld1q_u16(ptr + 0 * 8), vld1q_u16(ptr + 1 * 8),
                          vld1q_u16(ptr + 2 * 8), vld1q_u16(ptr + 3 * 8) } };
diff --git a/av1/encoder/arm/neon/temporal_filter_neon.c b/av1/encoder/arm/neon/temporal_filter_neon.c
index 4765e1a..65fb332b 100644
--- a/av1/encoder/arm/neon/temporal_filter_neon.c
+++ b/av1/encoder/arm/neon/temporal_filter_neon.c
@@ -20,6 +20,152 @@
 // For the squared error buffer, add padding for 4 samples.
 #define SSE_STRIDE (BW + 4)
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
+// clang-format off
+
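+// Each 8-byte half of the two mask vectors keeps a 5-sample window, shifted
+// along by one column per half. Masking a padded source row and taking its
+// dot product with itself therefore accumulates the squared differences for
+// four adjacent output columns in parallel.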
+DECLARE_ALIGNED(16, static const uint8_t, kSlidingWindowMask[]) = {
+  0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00,
+  0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
+  0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00,
+  0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
+};
+
+// clang-format on
+
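+// Compute the absolute difference of every sample in the two blocks, writing
+// the results into the padded buffer so that each row has a two-sample border
+// for the sliding-window loads below.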
+static INLINE void get_abs_diff(const uint8_t *frame1, const uint32_t stride1,
+                                const uint8_t *frame2, const uint32_t stride2,
+                                const uint32_t block_width,
+                                const uint32_t block_height,
+                                uint8_t *frame_abs_diff,
+                                const unsigned int dst_stride) {
+  uint8_t *dst = frame_abs_diff;
+
+  uint32_t i = 0;
+  do {
+    uint32_t j = 0;
+    do {
+      uint8x16_t s = vld1q_u8(frame1 + i * stride1 + j);
+      uint8x16_t r = vld1q_u8(frame2 + i * stride2 + j);
+      uint8x16_t abs_diff = vabdq_u8(s, r);
+      vst1q_u8(dst + j + 2, abs_diff);
+      j += 16;
+    } while (j < block_width);
+
+    dst += dst_stride;
+    i++;
+  } while (i < block_height);
+}
+
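+// Load an 8-byte row of absolute differences and, at the left and right block
+// edges, replicate the outermost valid sample into the two out-of-range
+// lanes. The row is duplicated into both halves of the vector so that each
+// mask vector can select two of the four window positions from it.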
+static INLINE uint8x16_t load_and_pad(const uint8_t *src, const uint32_t col,
+                                      const uint32_t block_width) {
+  uint8x8_t s = vld1_u8(src);
+
+  if (col == 0) {
+    const uint8_t lane2 = vget_lane_u8(s, 2);
+    s = vset_lane_u8(lane2, s, 0);
+    s = vset_lane_u8(lane2, s, 1);
+  } else if (col >= block_width - 4) {
+    const uint8_t lane5 = vget_lane_u8(s, 5);
+    s = vset_lane_u8(lane5, s, 6);
+    s = vset_lane_u8(lane5, s, 7);
+  }
+  return vcombine_u8(s, s);
+}
+
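+// Compute the 5x5 sum of squared differences around each sample using UDOT
+// on the absolute-difference buffer, then convert each windowed sum into a
+// filter weight and accumulate the predicted samples into the accumulator
+// and count buffers.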
+static void apply_temporal_filter(
+    const uint8_t *frame, const unsigned int stride, const uint32_t block_width,
+    const uint32_t block_height, const int *subblock_mses,
+    unsigned int *accumulator, uint16_t *count, uint8_t *frame_abs_diff,
+    uint32_t *luma_sse_sum, const double inv_num_ref_pixels,
+    const double decay_factor, const double inv_factor,
+    const double weight_factor, double *d_factor) {
+  assert(((block_width == 16) || (block_width == 32)) &&
+         ((block_height == 16) || (block_height == 32)));
+
+  uint32_t acc_5x5_neon[BH][BW];
+  const uint8x16x2_t vmask = vld1q_u8_x2(kSlidingWindowMask);
+
+  // Traverse 4 columns at a time - first and last two columns need padding.
+  for (uint32_t col = 0; col < block_width; col += 4) {
+    uint8x16_t vsrc[5][2];
+    uint8_t *src = frame_abs_diff + col;
+
+    // Load, pad (for first and last two columns) and mask 3 rows from the top.
+    for (int i = 2; i < 5; i++) {
+      uint8x16_t s = load_and_pad(src, col, block_width);
+      vsrc[i][0] = vandq_u8(s, vmask.val[0]);
+      vsrc[i][1] = vandq_u8(s, vmask.val[1]);
+      src += SSE_STRIDE;
+    }
+
+    // Pad the top 2 rows.
+    vsrc[0][0] = vsrc[2][0];
+    vsrc[0][1] = vsrc[2][1];
+    vsrc[1][0] = vsrc[2][0];
+    vsrc[1][1] = vsrc[2][1];
+
+    for (unsigned int row = 0; row < block_height; row++) {
+      uint32x4_t sum_01 = vdupq_n_u32(0);
+      uint32x4_t sum_23 = vdupq_n_u32(0);
+
+      sum_01 = vdotq_u32(sum_01, vsrc[0][0], vsrc[0][0]);
+      sum_01 = vdotq_u32(sum_01, vsrc[1][0], vsrc[1][0]);
+      sum_01 = vdotq_u32(sum_01, vsrc[2][0], vsrc[2][0]);
+      sum_01 = vdotq_u32(sum_01, vsrc[3][0], vsrc[3][0]);
+      sum_01 = vdotq_u32(sum_01, vsrc[4][0], vsrc[4][0]);
+
+      sum_23 = vdotq_u32(sum_23, vsrc[0][1], vsrc[0][1]);
+      sum_23 = vdotq_u32(sum_23, vsrc[1][1], vsrc[1][1]);
+      sum_23 = vdotq_u32(sum_23, vsrc[2][1], vsrc[2][1]);
+      sum_23 = vdotq_u32(sum_23, vsrc[3][1], vsrc[3][1]);
+      sum_23 = vdotq_u32(sum_23, vsrc[4][1], vsrc[4][1]);
+
+      vst1q_u32(&acc_5x5_neon[row][col], vpaddq_u32(sum_01, sum_23));
+
+      // Push all rows in the sliding window up one.
+      for (int i = 0; i < 4; i++) {
+        vsrc[i][0] = vsrc[i + 1][0];
+        vsrc[i][1] = vsrc[i + 1][1];
+      }
+
+      if (row <= block_height - 4) {
+        // Load next row into the bottom of the sliding window.
+        uint8x16_t s = load_and_pad(src, col, block_width);
+        vsrc[4][0] = vandq_u8(s, vmask.val[0]);
+        vsrc[4][1] = vandq_u8(s, vmask.val[1]);
+        src += SSE_STRIDE;
+      } else {
+        // Pad the bottom 2 rows.
+        vsrc[4][0] = vsrc[3][0];
+        vsrc[4][1] = vsrc[3][1];
+      }
+    }
+  }
+
+  // Perform filtering.
+  for (unsigned int i = 0, k = 0; i < block_height; i++) {
+    for (unsigned int j = 0; j < block_width; j++, k++) {
+      const int pixel_value = frame[i * stride + j];
+      uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j];
+
+      const double window_error = diff_sse * inv_num_ref_pixels;
+      const int subblock_idx =
+          (i >= block_height / 2) * 2 + (j >= block_width / 2);
+      const double block_error = (double)subblock_mses[subblock_idx];
+      const double combined_error =
+          weight_factor * window_error + block_error * inv_factor;
+      // Compute filter weight.
+      double scaled_error =
+          combined_error * d_factor[subblock_idx] * decay_factor;
+      scaled_error = AOMMIN(scaled_error, 7);
+      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
+      accumulator[k] += weight * pixel_value;
+      count[k] += weight;
+    }
+  }
+}
+
+#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
 DECLARE_ALIGNED(16, static const uint16_t, kSlidingWindowMask[]) = {
   0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000,
   0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000,
@@ -148,6 +294,8 @@
   }
 }
 
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
 void av1_apply_temporal_filter_neon(
     const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
@@ -187,7 +335,11 @@
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
   double d_factor[4] = { 0 };
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+  uint8_t frame_abs_diff[SSE_STRIDE * BH] = { 0 };
+#else   // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
   uint16_t frame_sse[SSE_STRIDE * BH] = { 0 };
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
   uint32_t luma_sse_sum[BW * BH] = { 0 };
 
   for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
@@ -226,6 +378,32 @@
     // search is only done on Y-plane, so the information from Y-plane
     // will be more accurate. The luma sse sum is reused in both chroma
     // planes.
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
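+    // frame_abs_diff holds absolute differences rather than squared errors,
+    // so square the samples while accumulating the luma SSE sum.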
+    if (plane == AOM_PLANE_U) {
+      for (unsigned int i = 0; i < plane_h; i++) {
+        for (unsigned int j = 0; j < plane_w; j++) {
+          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+              luma_sse_sum[i * BW + j] +=
+                  (frame_abs_diff[yy * SSE_STRIDE + xx + 2] *
+                   frame_abs_diff[yy * SSE_STRIDE + xx + 2]);
+            }
+          }
+        }
+      }
+    }
+
+    get_abs_diff(ref, frame_stride, pred + plane_offset, plane_w, plane_w,
+                 plane_h, frame_abs_diff, SSE_STRIDE);
+
+    apply_temporal_filter(pred + plane_offset, plane_w, plane_w, plane_h,
+                          subblock_mses, accum + plane_offset,
+                          count + plane_offset, frame_abs_diff, luma_sse_sum,
+                          inv_num_ref_pixels, decay_factor, inv_factor,
+                          weight_factor, d_factor);
+#else   // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
     if (plane == AOM_PLANE_U) {
       for (unsigned int i = 0; i < plane_h; i++) {
         for (unsigned int j = 0; j < plane_w; j++) {
@@ -247,6 +425,7 @@
         pred + plane_offset, plane_w, plane_w, plane_h, subblock_mses,
         accum + plane_offset, count + plane_offset, frame_sse, luma_sse_sum,
         inv_num_ref_pixels, decay_factor, inv_factor, weight_factor, d_factor);
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
     plane_offset += plane_h * plane_w;
   }