/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11
#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/blend.h"
George Steed25876b72023-01-04 16:59:50 +000022
// Accumulate one row of 16 masked-SAD values: blend 16 pixels of `a` and `b`
// under the 0..64 alpha mask `m` (via alpha_blend_a64_u8x16, declared in
// blend_neon.h), take absolute differences against `src`, and widen-add the
// result into the running uint16x8_t accumulator `sad`.
static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
                                              const uint8_t *src,
                                              const uint8_t *a,
                                              const uint8_t *b,
                                              const uint8_t *m) {
  const uint8x16_t s0 = vld1q_u8(src);
  const uint8x16_t m0 = vld1q_u8(m);
  const uint8x16_t a0 = vld1q_u8(a);
  const uint8x16_t b0 = vld1q_u8(b);

  const uint8x16_t blend = alpha_blend_a64_u8x16(m0, a0, b0);

  // vpadalq_u8 pairwise-widens the 16 u8 absolute differences and adds them
  // into the 8 u16 accumulator lanes.
  return vpadalq_u8(sad, vabdq_u8(blend, s0));
}
37
// Masked SAD for a 128-wide block over `height` rows, processed as eight
// 16-pixel columns per row.
static INLINE unsigned masked_sad_128xh_neon(const uint8_t *src, int src_stride,
                                             const uint8_t *a, int a_stride,
                                             const uint8_t *b, int b_stride,
                                             const uint8_t *m, int m_stride,
                                             int height) {
  // Eight accumulator vectors are required to avoid overflow in the 128x128
  // case.
  assert(height <= 128);
  uint16x8_t sad[8];
  for (int i = 0; i < 8; i++) {
    sad[i] = vdupq_n_u16(0);
  }

  do {
    // One independent accumulator per 16-pixel column chunk.
    for (int i = 0; i < 8; i++) {
      const int c = 16 * i;
      sad[i] = masked_sad_16x1_neon(sad[i], src + c, a + c, b + c, m + c);
    }

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  // Widen to 32 bits while reducing, pairing accumulators to limit the
  // number of horizontal-add calls.
  return horizontal_long_add_u16x8(sad[0], sad[1]) +
         horizontal_long_add_u16x8(sad[2], sad[3]) +
         horizontal_long_add_u16x8(sad[4], sad[5]) +
         horizontal_long_add_u16x8(sad[6], sad[7]);
}
72
// Masked SAD for a 64-wide block over `height` rows, processed as four
// 16-pixel columns per row.
static INLINE unsigned masked_sad_64xh_neon(const uint8_t *src, int src_stride,
                                            const uint8_t *a, int a_stride,
                                            const uint8_t *b, int b_stride,
                                            const uint8_t *m, int m_stride,
                                            int height) {
  // Four accumulator vectors are required to avoid overflow in the 64x128 case.
  assert(height <= 128);
  uint16x8_t sad[4];
  for (int i = 0; i < 4; i++) {
    sad[i] = vdupq_n_u16(0);
  }

  do {
    // One independent accumulator per 16-pixel column chunk.
    for (int i = 0; i < 4; i++) {
      const int c = 16 * i;
      sad[i] = masked_sad_16x1_neon(sad[i], src + c, a + c, b + c, m + c);
    }

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_long_add_u16x8(sad[0], sad[1]) +
         horizontal_long_add_u16x8(sad[2], sad[3]);
}
99
// Masked SAD for a 32-wide block over `height` rows (two 16-pixel chunks per
// row).
static INLINE unsigned masked_sad_32xh_neon(const uint8_t *src, int src_stride,
                                            const uint8_t *a, int a_stride,
                                            const uint8_t *b, int b_stride,
                                            const uint8_t *m, int m_stride,
                                            int height) {
  // We could use a single accumulator up to height=64 without overflow.
  assert(height <= 64);
  uint16x8_t acc = vdupq_n_u16(0);

  do {
    acc = masked_sad_16x1_neon(acc, src, a, b, m);
    acc = masked_sad_16x1_neon(acc, src + 16, a + 16, b + 16, m + 16);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(acc);
}
122
// Masked SAD for a 16-wide block over `height` rows.
static INLINE unsigned masked_sad_16xh_neon(const uint8_t *src, int src_stride,
                                            const uint8_t *a, int a_stride,
                                            const uint8_t *b, int b_stride,
                                            const uint8_t *m, int m_stride,
                                            int height) {
  // We could use a single accumulator up to height=128 without overflow.
  assert(height <= 128);
  uint16x8_t acc = vdupq_n_u16(0);

  do {
    acc = masked_sad_16x1_neon(acc, src, a, b, m);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(acc);
}
144
// Masked SAD for an 8-wide block over `height` rows, using 64-bit vectors.
static INLINE unsigned masked_sad_8xh_neon(const uint8_t *src, int src_stride,
                                           const uint8_t *a, int a_stride,
                                           const uint8_t *b, int b_stride,
                                           const uint8_t *m, int m_stride,
                                           int height) {
  // We could use a single accumulator up to height=128 without overflow.
  assert(height <= 128);
  uint16x4_t acc = vdup_n_u16(0);

  do {
    const uint8x8_t s0 = vld1_u8(src);
    const uint8x8_t m0 = vld1_u8(m);
    const uint8x8_t a0 = vld1_u8(a);
    const uint8x8_t b0 = vld1_u8(b);

    // Blend predictors under the alpha mask, then accumulate |blend - src|.
    const uint8x8_t blend = alpha_blend_a64_u8x8(m0, a0, b0);
    acc = vpadal_u8(acc, vabd_u8(blend, s0));

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x4(acc);
}
173
// Masked SAD for a 4-wide block: two rows are packed into each 64-bit vector,
// so the loop consumes two rows per iteration.
static INLINE unsigned masked_sad_4xh_neon(const uint8_t *src, int src_stride,
                                           const uint8_t *a, int a_stride,
                                           const uint8_t *b, int b_stride,
                                           const uint8_t *m, int m_stride,
                                           int height) {
  // Process two rows per loop iteration.
  assert(height % 2 == 0);

  // We could use a single accumulator up to height=256 without overflow.
  assert(height <= 256);
  uint16x4_t acc = vdup_n_u16(0);

  do {
    // load_unaligned_u8 gathers two 4-byte rows into one uint8x8_t.
    const uint8x8_t s0 = load_unaligned_u8(src, src_stride);
    const uint8x8_t m0 = load_unaligned_u8(m, m_stride);
    const uint8x8_t a0 = load_unaligned_u8(a, a_stride);
    const uint8x8_t b0 = load_unaligned_u8(b, b_stride);

    const uint8x8_t blend = alpha_blend_a64_u8x8(m0, a0, b0);
    acc = vpadal_u8(acc, vabd_u8(blend, s0));

    src += 2 * src_stride;
    a += 2 * a_stride;
    b += 2 * b_stride;
    m += 2 * m_stride;
    height -= 2;
  } while (height != 0);

  return horizontal_add_u16x4(acc);
}
205
// Emit the public entry point aom_masked_sadWxH_neon. `invert_mask` selects
// which of (ref, second_pred) the mask weights apply to, so the inverted case
// simply swaps the two predictor arguments (second_pred has stride `width`).
#define MASKED_SAD_WXH_NEON(width, height)                                    \
  unsigned aom_masked_sad##width##x##height##_neon(                           \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    if (invert_mask) {                                                        \
      return masked_sad_##width##xh_neon(src, src_stride, second_pred, width, \
                                         ref, ref_stride, msk, msk_stride,    \
                                         height);                             \
    }                                                                         \
    return masked_sad_##width##xh_neon(src, src_stride, ref, ref_stride,      \
                                       second_pred, width, msk, msk_stride,   \
                                       height);                               \
  }
220
// Instantiate aom_masked_sadWxH_neon for every block size the codec uses.
MASKED_SAD_WXH_NEON(4, 4)
MASKED_SAD_WXH_NEON(4, 8)
MASKED_SAD_WXH_NEON(8, 4)
MASKED_SAD_WXH_NEON(8, 8)
MASKED_SAD_WXH_NEON(8, 16)
MASKED_SAD_WXH_NEON(16, 8)
MASKED_SAD_WXH_NEON(16, 16)
MASKED_SAD_WXH_NEON(16, 32)
MASKED_SAD_WXH_NEON(32, 16)
MASKED_SAD_WXH_NEON(32, 32)
MASKED_SAD_WXH_NEON(32, 64)
MASKED_SAD_WXH_NEON(64, 32)
MASKED_SAD_WXH_NEON(64, 64)
MASKED_SAD_WXH_NEON(64, 128)
MASKED_SAD_WXH_NEON(128, 64)
MASKED_SAD_WXH_NEON(128, 128)
#if !CONFIG_REALTIME_ONLY
// Non-square 4:1/1:4 extended partition sizes, excluded from realtime-only
// builds.
MASKED_SAD_WXH_NEON(4, 16)
MASKED_SAD_WXH_NEON(16, 4)
MASKED_SAD_WXH_NEON(8, 32)
MASKED_SAD_WXH_NEON(32, 8)
MASKED_SAD_WXH_NEON(16, 64)
MASKED_SAD_WXH_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY