Add Neon implementation for av1_highbd_calc_frame_error

Add Neon implementation for av1_highbd_calc_frame_error and the
corresponding tests as well.

Change-Id: If452abbbede2ac1093c6b0977cc20a513f9b6d4a
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 4f7cdda..f5a8a9f 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -487,6 +487,7 @@
               "${AOM_ROOT}/av1/encoder/x86/highbd_temporal_filter_avx2.c")
 
   list(APPEND AOM_AV1_ENCODER_INTRIN_NEON
+              "${AOM_ROOT}/av1/encoder/arm/neon/highbd_frame_error_neon.c"
               "${AOM_ROOT}/av1/encoder/arm/neon/highbd_pickrst_neon.c"
               "${AOM_ROOT}/av1/encoder/arm/neon/highbd_rdopt_neon.c"
               "${AOM_ROOT}/av1/encoder/arm/neon/highbd_temporal_filter_neon.c")
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 997777f..9941947 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -502,7 +502,7 @@
 
     if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
       add_proto qw/int64_t av1_calc_highbd_frame_error/, "const uint16_t *const ref, int ref_stride, const uint16_t *const dst, int dst_stride, int p_width, int p_height, int bd";
-      specialize qw/av1_calc_highbd_frame_error sse2 avx2/;
+      specialize qw/av1_calc_highbd_frame_error sse2 avx2 neon/;
     }
   }
 }
diff --git a/av1/encoder/arm/neon/highbd_frame_error_neon.c b/av1/encoder/arm/neon/highbd_frame_error_neon.c
new file mode 100644
index 0000000..58875f6
--- /dev/null
+++ b/av1/encoder/arm/neon/highbd_frame_error_neon.c
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <stdlib.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+static const uint16_t error_measure_lut_diff[257] = {
+  339, 211, 181, 163, 151, 142, 136, 129, 125, 121, 117, 114, 111, 108, 107,
+  104, 102, 101, 99,  97,  96,  94,  93,  92,  91,  90,  88,  88,  87,  86,
+  85,  84,  84,  82,  82,  82,  80,  80,  79,  79,  78,  78,  77,  76,  76,
+  76,  74,  75,  74,  73,  74,  72,  73,  71,  72,  71,  71,  70,  70,  69,
+  70,  69,  68,  68,  68,  68,  67,  67,  67,  66,  67,  65,  66,  65,  65,
+  65,  65,  64,  64,  64,  63,  64,  63,  63,  62,  63,  62,  62,  62,  61,
+  62,  61,  61,  61,  60,  61,  60,  60,  60,  59,  60,  59,  59,  59,  59,
+  58,  59,  58,  58,  58,  58,  58,  57,  58,  57,  57,  57,  56,  57,  56,
+  57,  56,  56,  56,  56,  55,  56,  55,  55,  56,  55,  54,  55,  55,  54,
+  55,  54,  54,  54,  54,  54,  53,  54,  53,  54,  53,  53,  53,  53,  53,
+  52,  53,  52,  53,  52,  52,  52,  52,  52,  52,  52,  51,  52,  51,  51,
+  52,  51,  51,  51,  50,  51,  51,  50,  51,  50,  51,  50,  50,  50,  50,
+  50,  50,  49,  50,  50,  49,  50,  49,  49,  49,  49,  49,  49,  49,  49,
+  49,  48,  49,  48,  49,  48,  48,  49,  48,  48,  48,  48,  47,  48,  48,
+  48,  47,  48,  47,  47,  48,  47,  47,  47,  47,  47,  47,  47,  47,  47,
+  46,  47,  46,  47,  46,  47,  46,  46,  46,  47,  46,  46,  46,  45,  46,
+  46,  46,  45,  46,  46,  45,  45,  46,  45,  45,  46,  45,  45,  45,  45,
+  0
+};
+
+static const int hbd_12_error_measure_lut[257] = {
+  0,      5424,   8800,   11696,  14304,  16720,  18992,  21168,  23232,
+  25232,  27168,  29040,  30864,  32640,  34368,  36080,  37744,  39376,
+  40992,  42576,  44128,  45664,  47168,  48656,  50128,  51584,  53024,
+  54432,  55840,  57232,  58608,  59968,  61312,  62656,  63968,  65280,
+  66592,  67872,  69152,  70416,  71680,  72928,  74176,  75408,  76624,
+  77840,  79056,  80240,  81440,  82624,  83792,  84976,  86128,  87296,
+  88432,  89584,  90720,  91856,  92976,  94096,  95200,  96320,  97424,
+  98512,  99600,  100688, 101776, 102848, 103920, 104992, 106048, 107120,
+  108160, 109216, 110256, 111296, 112336, 113376, 114400, 115424, 116448,
+  117456, 118480, 119488, 120496, 121488, 122496, 123488, 124480, 125472,
+  126448, 127440, 128416, 129392, 130368, 131328, 132304, 133264, 134224,
+  135184, 136128, 137088, 138032, 138976, 139920, 140864, 141792, 142736,
+  143664, 144592, 145520, 146448, 147376, 148288, 149216, 150128, 151040,
+  151952, 152848, 153760, 154656, 155568, 156464, 157360, 158256, 159152,
+  160032, 160928, 161808, 162688, 163584, 164464, 165328, 166208, 167088,
+  167952, 168832, 169696, 170560, 171424, 172288, 173152, 174000, 174864,
+  175712, 176576, 177424, 178272, 179120, 179968, 180816, 181648, 182496,
+  183328, 184176, 185008, 185840, 186672, 187504, 188336, 189168, 190000,
+  190816, 191648, 192464, 193280, 194112, 194928, 195744, 196560, 197360,
+  198176, 198992, 199792, 200608, 201408, 202224, 203024, 203824, 204624,
+  205424, 206224, 207024, 207808, 208608, 209408, 210192, 210992, 211776,
+  212560, 213344, 214128, 214912, 215696, 216480, 217264, 218048, 218816,
+  219600, 220368, 221152, 221920, 222688, 223472, 224240, 225008, 225776,
+  226544, 227296, 228064, 228832, 229600, 230352, 231120, 231872, 232624,
+  233392, 234144, 234896, 235648, 236400, 237152, 237904, 238656, 239408,
+  240160, 240896, 241648, 242384, 243136, 243872, 244624, 245360, 246096,
+  246832, 247584, 248320, 249056, 249792, 250512, 251248, 251984, 252720,
+  253440, 254176, 254912, 255632, 256352, 257088, 257808, 258528, 259264,
+  259984, 260704, 261424, 262144, 262144,
+};
+
+static const int hbd_10_error_measure_lut[257] = {
+  0,     1356,  2200,  2924,  3576,  4180,  4748,  5292,  5808,  6308,  6792,
+  7260,  7716,  8160,  8592,  9020,  9436,  9844,  10248, 10644, 11032, 11416,
+  11792, 12164, 12532, 12896, 13256, 13608, 13960, 14308, 14652, 14992, 15328,
+  15664, 15992, 16320, 16648, 16968, 17288, 17604, 17920, 18232, 18544, 18852,
+  19156, 19460, 19764, 20060, 20360, 20656, 20948, 21244, 21532, 21824, 22108,
+  22396, 22680, 22964, 23244, 23524, 23800, 24080, 24356, 24628, 24900, 25172,
+  25444, 25712, 25980, 26248, 26512, 26780, 27040, 27304, 27564, 27824, 28084,
+  28344, 28600, 28856, 29112, 29364, 29620, 29872, 30124, 30372, 30624, 30872,
+  31120, 31368, 31612, 31860, 32104, 32348, 32592, 32832, 33076, 33316, 33556,
+  33796, 34032, 34272, 34508, 34744, 34980, 35216, 35448, 35684, 35916, 36148,
+  36380, 36612, 36844, 37072, 37304, 37532, 37760, 37988, 38212, 38440, 38664,
+  38892, 39116, 39340, 39564, 39788, 40008, 40232, 40452, 40672, 40896, 41116,
+  41332, 41552, 41772, 41988, 42208, 42424, 42640, 42856, 43072, 43288, 43500,
+  43716, 43928, 44144, 44356, 44568, 44780, 44992, 45204, 45412, 45624, 45832,
+  46044, 46252, 46460, 46668, 46876, 47084, 47292, 47500, 47704, 47912, 48116,
+  48320, 48528, 48732, 48936, 49140, 49340, 49544, 49748, 49948, 50152, 50352,
+  50556, 50756, 50956, 51156, 51356, 51556, 51756, 51952, 52152, 52352, 52548,
+  52748, 52944, 53140, 53336, 53532, 53728, 53924, 54120, 54316, 54512, 54704,
+  54900, 55092, 55288, 55480, 55672, 55868, 56060, 56252, 56444, 56636, 56824,
+  57016, 57208, 57400, 57588, 57780, 57968, 58156, 58348, 58536, 58724, 58912,
+  59100, 59288, 59476, 59664, 59852, 60040, 60224, 60412, 60596, 60784, 60968,
+  61156, 61340, 61524, 61708, 61896, 62080, 62264, 62448, 62628, 62812, 62996,
+  63180, 63360, 63544, 63728, 63908, 64088, 64272, 64452, 64632, 64816, 64996,
+  65176, 65356, 65536, 65536,
+};
+
+static const int hbd_8_error_measure_lut[257] = {
+  0,     339,   550,   731,   894,   1045,  1187,  1323,  1452,  1577,  1698,
+  1815,  1929,  2040,  2148,  2255,  2359,  2461,  2562,  2661,  2758,  2854,
+  2948,  3041,  3133,  3224,  3314,  3402,  3490,  3577,  3663,  3748,  3832,
+  3916,  3998,  4080,  4162,  4242,  4322,  4401,  4480,  4558,  4636,  4713,
+  4789,  4865,  4941,  5015,  5090,  5164,  5237,  5311,  5383,  5456,  5527,
+  5599,  5670,  5741,  5811,  5881,  5950,  6020,  6089,  6157,  6225,  6293,
+  6361,  6428,  6495,  6562,  6628,  6695,  6760,  6826,  6891,  6956,  7021,
+  7086,  7150,  7214,  7278,  7341,  7405,  7468,  7531,  7593,  7656,  7718,
+  7780,  7842,  7903,  7965,  8026,  8087,  8148,  8208,  8269,  8329,  8389,
+  8449,  8508,  8568,  8627,  8686,  8745,  8804,  8862,  8921,  8979,  9037,
+  9095,  9153,  9211,  9268,  9326,  9383,  9440,  9497,  9553,  9610,  9666,
+  9723,  9779,  9835,  9891,  9947,  10002, 10058, 10113, 10168, 10224, 10279,
+  10333, 10388, 10443, 10497, 10552, 10606, 10660, 10714, 10768, 10822, 10875,
+  10929, 10982, 11036, 11089, 11142, 11195, 11248, 11301, 11353, 11406, 11458,
+  11511, 11563, 11615, 11667, 11719, 11771, 11823, 11875, 11926, 11978, 12029,
+  12080, 12132, 12183, 12234, 12285, 12335, 12386, 12437, 12487, 12538, 12588,
+  12639, 12689, 12739, 12789, 12839, 12889, 12939, 12988, 13038, 13088, 13137,
+  13187, 13236, 13285, 13334, 13383, 13432, 13481, 13530, 13579, 13628, 13676,
+  13725, 13773, 13822, 13870, 13918, 13967, 14015, 14063, 14111, 14159, 14206,
+  14254, 14302, 14350, 14397, 14445, 14492, 14539, 14587, 14634, 14681, 14728,
+  14775, 14822, 14869, 14916, 14963, 15010, 15056, 15103, 15149, 15196, 15242,
+  15289, 15335, 15381, 15427, 15474, 15520, 15566, 15612, 15657, 15703, 15749,
+  15795, 15840, 15886, 15932, 15977, 16022, 16068, 16113, 16158, 16204, 16249,
+  16294, 16339, 16384, 16384,
+};
+
+// Split error into two parts and do an interpolated table lookup.
+// To compute the table index and interpolation value, we want to calculate
+// the quotient and remainder of (dst - ref) / 2^(bd - 8).
+#define HBD_CALC_FRAME_ERROR(bd, offset, mask)                              \
+  static INLINE int highbd_##bd##_error_measure(int q, int r) {             \
+    return (hbd_##bd##_error_measure_lut[q]) +                              \
+           (error_measure_lut_diff[q]) * r;                                 \
+  }                                                                         \
+                                                                            \
+  int64_t av1_calc_highbd_##bd##_frame_error_neon(                          \
+      const uint16_t *const ref, int ref_stride, const uint16_t *const dst, \
+      int dst_stride, int width, int height) {                              \
+    int64_t sum_error[4] = { 0, 0, 0, 0 };                                  \
+    int r = 0;                                                              \
+    int d = 0;                                                              \
+                                                                            \
+    do {                                                                    \
+      int w = width;                                                        \
+      int rr = r;                                                           \
+      int dd = d;                                                           \
+                                                                            \
+      do {                                                                  \
+        uint16x8_t dst_v = vld1q_u16(&dst[dd]);                             \
+        uint16x8_t ref_v = vld1q_u16(&ref[rr]);                             \
+                                                                            \
+        uint64x2_t abs_v = vreinterpretq_u64_u16(vabdq_u16(dst_v, ref_v));  \
+                                                                            \
+        uint64_t abs0 = vgetq_lane_u64(abs_v, 0);                           \
+        uint64_t abs1 = vgetq_lane_u64(abs_v, 1);                           \
+                                                                            \
+        sum_error[0] += highbd_##bd##_error_measure(                        \
+            (abs0 >> (0 + offset)) & 0xFF, (abs0 >> 0) & mask);             \
+        sum_error[1] += highbd_##bd##_error_measure(                        \
+            (abs0 >> (16 + offset)) & 0xFF, (abs0 >> 16) & mask);           \
+        sum_error[2] += highbd_##bd##_error_measure(                        \
+            (abs0 >> (32 + offset)) & 0xFF, (abs0 >> 32) & mask);           \
+        sum_error[3] += highbd_##bd##_error_measure(                        \
+            (abs0 >> (48 + offset)) & 0xFF, (abs0 >> 48) & mask);           \
+                                                                            \
+        sum_error[0] += highbd_##bd##_error_measure(                        \
+            (abs1 >> (0 + offset)) & 0xFF, (abs1 >> 0) & mask);             \
+        sum_error[1] += highbd_##bd##_error_measure(                        \
+            (abs1 >> (16 + offset)) & 0xFF, (abs1 >> 16) & mask);           \
+        sum_error[2] += highbd_##bd##_error_measure(                        \
+            (abs1 >> (32 + offset)) & 0xFF, (abs1 >> 32) & mask);           \
+        sum_error[3] += highbd_##bd##_error_measure(                        \
+            (abs1 >> (48 + offset)) & 0xFF, (abs1 >> 48) & mask);           \
+                                                                            \
+        dd += 8;                                                            \
+        rr += 8;                                                            \
+        w -= 8;                                                             \
+      } while (w >= 8);                                                     \
+                                                                            \
+      while (w-- != 0) {                                                    \
+        uint16_t abs_u16 = abs(dst[dd] - ref[rr]);                          \
+        sum_error[0] +=                                                     \
+            highbd_##bd##_error_measure(abs_u16 >> offset, abs_u16 & mask); \
+        dd++;                                                               \
+        rr++;                                                               \
+      }                                                                     \
+                                                                            \
+      r += ref_stride;                                                      \
+      d += dst_stride;                                                      \
+    } while (--height != 0);                                                \
+                                                                            \
+    return sum_error[0] + sum_error[1] + sum_error[2] + sum_error[3];       \
+  }
+
+// 12 bitdepth
+HBD_CALC_FRAME_ERROR(12, 4, 0xF)
+// 10 bitdepth
+HBD_CALC_FRAME_ERROR(10, 2, 0x3)
+// 8 bitdepth
+HBD_CALC_FRAME_ERROR(8, 0, 0x0)
+
+int64_t av1_calc_highbd_frame_error_neon(const uint16_t *const ref,
+                                         int ref_stride,
+                                         const uint16_t *const dst,
+                                         int dst_stride, int width, int height,
+                                         int bd) {
+  switch (bd) {
+    case 8:
+    default:
+      return av1_calc_highbd_8_frame_error_neon(ref, ref_stride, dst,
+                                                dst_stride, width, height);
+    case 10:
+      return av1_calc_highbd_10_frame_error_neon(ref, ref_stride, dst,
+                                                 dst_stride, width, height);
+    case 12:
+      return av1_calc_highbd_12_frame_error_neon(ref, ref_stride, dst,
+                                                 dst_stride, width, height);
+  }
+}
diff --git a/test/frame_error_test.cc b/test/frame_error_test.cc
index 265fba6..35733d9 100644
--- a/test/frame_error_test.cc
+++ b/test/frame_error_test.cc
@@ -183,7 +183,7 @@
 const int kBlockHeight[] = {
   480, 482, 360, 720, 1080,
 };
-#if HAVE_AVX2 || HAVE_SSE2
+#if HAVE_AVX2 || HAVE_SSE2 || HAVE_NEON
 const int kBitDepths[] = { 8, 10, 12 };
 #endif
 typedef std::tuple<highbd_frame_error_func, int, int, int>
@@ -327,6 +327,15 @@
                        ::testing::ValuesIn(kBitDepths)));
 #endif
 
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, AV1HighbdFrameErrorTest,
+    ::testing::Combine(::testing::Values(&av1_calc_highbd_frame_error_neon),
+                       ::testing::ValuesIn(kBlockWidth),
+                       ::testing::ValuesIn(kBlockHeight),
+                       ::testing::ValuesIn(kBitDepths)));
+#endif
+
 // Check that 8-bit and 16-bit code paths give the same results for
 // 8-bit content
 typedef std::tuple<int, int> HighbdFrameErrorConsistencyParam;