/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <string.h>

#include "config/aom_dsp_rtcd.h"
#include "config/aom_config.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom/aom_integer.h"
#include "aom_ports/mem.h"

#if defined(__ARM_FEATURE_DOTPROD)

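// In the dot-product paths below, the per-block sum is computed by dotting
// each 16-byte vector of pixels with a vector of all ones, and the sum of
// squared errors is computed by dotting the 8-bit absolute differences with
// themselves. Both accumulate directly into 32-bit lanes, so no intermediate
// widening or overflow handling is needed.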
static INLINE void variance_4xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    uint8x16_t s = load_unaligned_u8q(src, src_stride);
    uint8x16_t r = load_unaligned_u8q(ref, ref_stride);

    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 4 * src_stride;
    ref += 4 * ref_stride;
    i -= 4;
  } while (i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

static INLINE void variance_8xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
    uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));

    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

static INLINE void variance_16xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src);
    uint8x16_t r = vld1q_u8(ref);

    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

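// Handle block widths that are a multiple of 16 pixels, processing one row of
// 'w' pixels per iteration of the outer loop.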
static INLINE void variance_large_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int w, int h, uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s = vld1q_u8(src + j);
      uint8x16_t r = vld1q_u8(ref + j);

      src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
      ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));

      uint8x16_t abs_diff = vabdq_u8(s, r);
      sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

      j += 16;
    } while (j < w);

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

static INLINE void variance_32xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static INLINE void variance_64xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static INLINE void variance_128xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int h, uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 128, h, sse, sum);
}

#else  // !defined(__ARM_FEATURE_DOTPROD)

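// Without the dot-product instructions, pixel differences are widened to 16
// bits with vsubl_u8 and accumulated in 16-bit lanes, which limits how many
// rows can be summed before overflow (hence the asserts and the 'h_limit'
// handling below). Squares are accumulated into 32-bit lanes with vmlal_s16.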
static INLINE void variance_4xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32 = vdupq_n_s32(0);

  // Number of rows we can process before 'sum_s16' overflows:
  // 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows.
  assert(h <= 256);

  int i = h;
  do {
    uint8x8_t s = load_unaligned_u8(src, src_stride);
    uint8x8_t r = load_unaligned_u8(ref, ref_stride);
    int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r));

    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff));
    sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff));

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(sse_s32);
}

static INLINE void variance_8xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // Number of rows we can process before 'sum_s16' overflows:
  // 32767 / 255 ~= 128
  assert(h <= 128);

  int i = h;
  do {
    uint8x8_t s = vld1_u8(src);
    uint8x8_t r = vld1_u8(ref);
    int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r));

    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32[0] = vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static INLINE void variance_16xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // Number of rows we can process before 'sum_s16' accumulators overflow:
  // 32767 / 255 ~= 128, so 128 16-wide rows.
  assert(h <= 128);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src);
    uint8x16_t r = vld1q_u8(ref);

    int16x8_t diff_l =
        vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r)));
    int16x8_t diff_h =
        vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r)));

    sum_s16[0] = vaddq_s16(sum_s16[0], diff_l);
    sum_s16[1] = vaddq_s16(sum_s16[1], diff_h);

    sse_s32[0] =
        vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l));
    sse_s32[0] =
        vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h));

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1]));
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static INLINE void variance_large_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int w, int h, int h_limit, uint32_t *sse,
                                       int *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // 'h_limit' is the number of 'w'-width rows we can process before our 16-bit
  // accumulator overflows. After hitting this limit we accumulate into 32-bit
  // elements.
  int h_tmp = h > h_limit ? h_limit : h;

  int i = 0;
  do {
    int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
    do {
      int j = 0;
      do {
        uint8x16_t s = vld1q_u8(src + j);
        uint8x16_t r = vld1q_u8(ref + j);

        int16x8_t diff_l =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r)));
        int16x8_t diff_h =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r)));

        sum_s16[0] = vaddq_s16(sum_s16[0], diff_l);
        sum_s16[1] = vaddq_s16(sum_s16[1], diff_h);

        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l));
        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h));

        j += 16;
      } while (j < w);

      src += src_stride;
      ref += ref_stride;
      i++;
    } while (i < h_tmp);

    sum_s32 = vpadalq_s16(sum_s32, sum_s16[0]);
    sum_s32 = vpadalq_s16(sum_s32, sum_s16[1]);

    h_tmp += h_limit;
  } while (i < h);

  *sum = horizontal_add_s32x4(sum_s32);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static INLINE void variance_32xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 32, h, 64, sse, sum);
}

static INLINE void variance_64xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 64, h, 32, sse, sum);
}

static INLINE void variance_128xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int h, uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 128, h, 16, sse, sum);
}

#endif  // defined(__ARM_FEATURE_DOTPROD)

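// variance = sse - sum * sum / (w * h), where w * h == 1 << shift.
// For example, aom_variance16x16_neon below uses shift = 8 since 16 * 16 = 256.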
#define VARIANCE_WXH_NEON(w, h, shift)                                        \
  unsigned int aom_variance##w##x##h##_neon(                                  \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse, &sum);    \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);                  \
  }

VARIANCE_WXH_NEON(4, 4, 4)
VARIANCE_WXH_NEON(4, 8, 5)
VARIANCE_WXH_NEON(4, 16, 6)

VARIANCE_WXH_NEON(8, 4, 5)
VARIANCE_WXH_NEON(8, 8, 6)
VARIANCE_WXH_NEON(8, 16, 7)
VARIANCE_WXH_NEON(8, 32, 8)

VARIANCE_WXH_NEON(16, 4, 6)
VARIANCE_WXH_NEON(16, 8, 7)
VARIANCE_WXH_NEON(16, 16, 8)
VARIANCE_WXH_NEON(16, 32, 9)
VARIANCE_WXH_NEON(16, 64, 10)

VARIANCE_WXH_NEON(32, 8, 8)
VARIANCE_WXH_NEON(32, 16, 9)
VARIANCE_WXH_NEON(32, 32, 10)
VARIANCE_WXH_NEON(32, 64, 11)

VARIANCE_WXH_NEON(64, 16, 10)
VARIANCE_WXH_NEON(64, 32, 11)
VARIANCE_WXH_NEON(64, 64, 12)
VARIANCE_WXH_NEON(64, 128, 13)

VARIANCE_WXH_NEON(128, 64, 13)
VARIANCE_WXH_NEON(128, 128, 14)

#undef VARIANCE_WXH_NEON

// TODO(yunqingwang): Perform the variance of two/four 8x8 blocks at a time,
// similar to the AVX2 version. Also, implement the variance computation
// present in this function with Neon.
void aom_get_var_sse_sum_8x8_quad_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       uint32_t *sse8x8, int *sum8x8,
                                       unsigned int *tot_sse, int *tot_sum,
                                       uint32_t *var8x8) {
  // Loop over 4 8x8 blocks. Process one 8x32 block.
  for (int k = 0; k < 4; k++) {
    variance_8xh_neon(src + (k * 8), src_stride, ref + (k * 8), ref_stride, 8,
                      &sse8x8[k], &sum8x8[k]);
  }

  *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3];
  *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3];
  for (int i = 0; i < 4; i++)
    var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6);
}

void aom_get_var_sse_sum_16x16_dual_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         uint32_t *sse16x16,
                                         unsigned int *tot_sse, int *tot_sum,
                                         uint32_t *var16x16) {
  int sum16x16[2] = { 0 };
  // Loop over 2 16x16 blocks. Process one 16x32 block.
  for (int k = 0; k < 2; k++) {
    variance_16xh_neon(src + (k * 16), src_stride, ref + (k * 16), ref_stride,
                       16, &sse16x16[k], &sum16x16[k]);
  }

  *tot_sse += sse16x16[0] + sse16x16[1];
  *tot_sum += sum16x16[0] + sum16x16[1];
  for (int i = 0; i < 2; i++)
    var16x16[i] =
        sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8);
}

#if defined(__ARM_FEATURE_DOTPROD)

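// As in the variance kernels above, the sum of squared errors is obtained by
// dotting the 8-bit absolute differences with themselves, accumulating
// directly into 32-bit lanes.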
static INLINE unsigned int mse8xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       unsigned int *sse, int h) {
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
    uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));

    uint8x16_t abs_diff = vabdq_u8(s, r);

    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sse = horizontal_add_u32x4(sse_u32);
  return horizontal_add_u32x4(sse_u32);
}

static INLINE unsigned int mse16xh_neon(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        unsigned int *sse, int h) {
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t s1 = vld1q_u8(src + src_stride);
    uint8x16_t r0 = vld1q_u8(ref);
    uint8x16_t r1 = vld1q_u8(ref + ref_stride);

    uint8x16_t abs_diff0 = vabdq_u8(s0, r0);
    uint8x16_t abs_diff1 = vabdq_u8(s1, r1);

    sse_u32[0] = vdotq_u32(sse_u32[0], abs_diff0, abs_diff0);
    sse_u32[1] = vdotq_u32(sse_u32[1], abs_diff1, abs_diff1);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
  return horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
}

#else  // !defined(__ARM_FEATURE_DOTPROD)

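// Fallback MSE kernels: widen the pixel differences to 16 bits with vsubl_u8
// and square them into 32-bit accumulators with vmlal_s16.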
static INLINE unsigned int mse8xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       unsigned int *sse, int h) {
  uint8x8_t s[2], r[2];
  int16x4_t diff_lo[2], diff_hi[2];
  uint16x8_t diff[2];
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int i = h;
  do {
    s[0] = vld1_u8(src);
    src += src_stride;
    s[1] = vld1_u8(src);
    src += src_stride;
    r[0] = vld1_u8(ref);
    ref += ref_stride;
    r[1] = vld1_u8(ref);
    ref += ref_stride;

    diff[0] = vsubl_u8(s[0], r[0]);
    diff[1] = vsubl_u8(s[1], r[1]);

    diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0]));
    diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]);

    diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0]));
    diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);

    i -= 2;
  } while (i != 0);

  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);

  *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
  return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
}

static INLINE unsigned int mse16xh_neon(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        unsigned int *sse, int h) {
  uint8x16_t s[2], r[2];
  int16x4_t diff_lo[4], diff_hi[4];
  uint16x8_t diff[4];
  int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };

  int i = h;
  do {
    s[0] = vld1q_u8(src);
    src += src_stride;
    s[1] = vld1q_u8(src);
    src += src_stride;
    r[0] = vld1q_u8(ref);
    ref += ref_stride;
    r[1] = vld1q_u8(ref);
    ref += ref_stride;

    diff[0] = vsubl_u8(vget_low_u8(s[0]), vget_low_u8(r[0]));
    diff[1] = vsubl_u8(vget_high_u8(s[0]), vget_high_u8(r[0]));
    diff[2] = vsubl_u8(vget_low_u8(s[1]), vget_low_u8(r[1]));
    diff[3] = vsubl_u8(vget_high_u8(s[1]), vget_high_u8(r[1]));

    diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0]));
    diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]);

    diff_lo[2] = vreinterpret_s16_u16(vget_low_u16(diff[2]));
    diff_lo[3] = vreinterpret_s16_u16(vget_low_u16(diff[3]));
    sse_s32[2] = vmlal_s16(sse_s32[2], diff_lo[2], diff_lo[2]);
    sse_s32[3] = vmlal_s16(sse_s32[3], diff_lo[3], diff_lo[3]);

    diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0]));
    diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);

    diff_hi[2] = vreinterpret_s16_u16(vget_high_u16(diff[2]));
    diff_hi[3] = vreinterpret_s16_u16(vget_high_u16(diff[3]));
    sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]);
    sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]);

    i -= 2;
  } while (i != 0);

  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
  sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]);
  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[2]);

  *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
  return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
}

#endif  // defined(__ARM_FEATURE_DOTPROD)

#define MSE_WXH_NEON(w, h)                                                 \
  unsigned int aom_mse##w##x##h##_neon(const uint8_t *src, int src_stride, \
                                       const uint8_t *ref, int ref_stride, \
                                       unsigned int *sse) {                \
    return mse##w##xh_neon(src, src_stride, ref, ref_stride, sse, h);      \
  }

MSE_WXH_NEON(8, 8)
MSE_WXH_NEON(8, 16)

MSE_WXH_NEON(16, 8)
MSE_WXH_NEON(16, 16)

#undef MSE_WXH_NEON

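// Accumulates the sum of squared differences between two vectors of eight
// 16-bit samples into the 64-bit vector 'square_result', which must be in
// scope at the expansion site.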
#define COMPUTE_MSE_16BIT(src_16x8, dst_16x8)                           \
  /* r7 r6 r5 r4 r3 r2 r1 r0 - 16 bit */                                \
  const uint16x8_t diff = vabdq_u16(src_16x8, dst_16x8);                \
  /* r3 r2 r1 r0 - 16 bit */                                            \
  const uint16x4_t res0_low_16x4 = vget_low_u16(diff);                  \
  /* r7 r6 r5 r4 - 16 bit */                                            \
  const uint16x4_t res0_high_16x4 = vget_high_u16(diff);                \
  /* (r3*r3)=b3 (r2*r2)=b2 (r1*r1)=b1 (r0*r0)=b0 - 32 bit */            \
  const uint32x4_t res0_32x4 = vmull_u16(res0_low_16x4, res0_low_16x4); \
  /* (r7*r7)=b7 (r6*r6)=b6 (r5*r5)=b5 (r4*r4)=b4 - 32 bit */            \
  /* b3+b7 b2+b6 b1+b5 b0+b4 - 32 bit */                                \
  const uint32x4_t res_32x4 =                                           \
      vmlal_u16(res0_32x4, res0_high_16x4, res0_high_16x4);             \
                                                                        \
  /* Pairwise add to 64 bit: a1=(b3+b7)+(b2+b6) a0=(b1+b5)+(b0+b4) */   \
  const uint64x2_t vl = vpaddlq_u32(res_32x4);                          \
  /* Accumulate a1 and a0 into the running 64-bit total. */             \
  square_result = vaddq_u64(square_result, vl);

static AOM_INLINE uint64_t mse_4xh_16bit_neon(uint8_t *dst, int dstride,
                                              uint16_t *src, int sstride,
                                              int h) {
  uint64x2_t square_result = vdupq_n_u64(0);
  uint32_t d0, d1;
  int i = h;
  uint8_t *dst_ptr = dst;
  uint16_t *src_ptr = src;
  do {
    // d03 d02 d01 d00 - 8 bit
    memcpy(&d0, dst_ptr, 4);
    dst_ptr += dstride;
    // d13 d12 d11 d10 - 8 bit
    memcpy(&d1, dst_ptr, 4);
    dst_ptr += dstride;
    // duplication
    uint8x8_t tmp0_8x8 = vreinterpret_u8_u32(vdup_n_u32(d0));
    // d03 d02 d01 d00 - 16 bit
    const uint16x4_t dst0_16x4 = vget_low_u16(vmovl_u8(tmp0_8x8));
    // duplication
    tmp0_8x8 = vreinterpret_u8_u32(vdup_n_u32(d1));
    // d13 d12 d11 d10 - 16 bit
    const uint16x4_t dst1_16x4 = vget_low_u16(vmovl_u8(tmp0_8x8));
    // d13 d12 d11 d10 d03 d02 d01 d00 - 16 bit
    const uint16x8_t dst_16x8 = vcombine_u16(dst0_16x4, dst1_16x4);

    // b1r0 - s03 s02 s01 s00 - 16 bit
    const uint16x4_t src0_16x4 = vld1_u16(src_ptr);
    src_ptr += sstride;
    // b1r1 - s13 s12 s11 s10 - 16 bit
    const uint16x4_t src1_16x4 = vld1_u16(src_ptr);
    src_ptr += sstride;
    // s13 s12 s11 s10 s03 s02 s01 s00 - 16 bit
    const uint16x8_t src_16x8 = vcombine_u16(src0_16x4, src1_16x4);

    COMPUTE_MSE_16BIT(src_16x8, dst_16x8)
    i -= 2;
  } while (i != 0);
  uint64x1_t sum =
      vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
  return vget_lane_u64(sum, 0);
}

static AOM_INLINE uint64_t mse_8xh_16bit_neon(uint8_t *dst, int dstride,
                                              uint16_t *src, int sstride,
                                              int h) {
  uint64x2_t square_result = vdupq_n_u64(0);
  int i = h;
  do {
    // d7 d6 d5 d4 d3 d2 d1 d0 - 8 bit
    const uint16x8_t dst_16x8 = vmovl_u8(vld1_u8(dst));
    // s7 s6 s5 s4 s3 s2 s1 s0 - 16 bit
    const uint16x8_t src_16x8 = vld1q_u16(src);

    COMPUTE_MSE_16BIT(src_16x8, dst_16x8)

    dst += dstride;
    src += sstride;
  } while (--i != 0);
  uint64x1_t sum =
      vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
  return vget_lane_u64(sum, 0);
}

// Computes mse for a given block size. This function gets called for specific
// block sizes, which are 8x8, 8x4, 4x8 and 4x4.
uint64_t aom_mse_wxh_16bit_neon(uint8_t *dst, int dstride, uint16_t *src,
                                int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must satisfy");
  switch (w) {
    case 4: return mse_4xh_16bit_neon(dst, dstride, src, sstride, h);
    case 8: return mse_8xh_16bit_neon(dst, dstride, src, sstride, h);
    default: assert(0 && "unsupported width"); return -1;
  }
}