/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "./aom_config.h"
#include "./av1_rtcd.h"
#include "av1/common/restoration.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"

// Load 8 bytes from the possibly-misaligned pointer p, extend each byte to
// 32-bit precision and return them in an AVX2 register.
static __m256i yy256_load_extend_8_32(const void *p) {
  return _mm256_cvtepu8_epi32(xx_loadl_64(p));
}

// Load 8 halfwords from the possibly-misaligned pointer p, extend each
// halfword to 32-bit precision and return them in an AVX2 register.
static __m256i yy256_load_extend_16_32(const void *p) {
  return _mm256_cvtepu16_epi32(xx_loadu_128(p));
}

// Compute the scan of an AVX2 register holding 8 32-bit integers. If the
// register holds x0..x7 then the scan will hold x0, x0+x1, x0+x1+x2, ...,
// x0+x1+...+x7.
//
// Let [...] represent a 128-bit block, and let a, ..., h be 32-bit integers
// (assumed small enough to be able to add them without overflow).
//
// Use -> as shorthand for summing, i.e. a->d = a + b + c + d.
//
// x   = [h g f e][d c b a]
// x01 = [g f e 0][c b a 0]
// x02 = [g+h f+g e+f e][c+d b+c a+b a]
// x03 = [e+f e 0 0][a+b a 0 0]
// x04 = [e->h e->g e->f e][a->d a->c a->b a]
// s   = a->d
// s01 = [a->d a->d a->d a->d]
// s02 = [a->d a->d a->d a->d][0 0 0 0]
// ret = [a->h a->g a->f a->e][a->d a->c a->b a]
static __m256i scan_32(__m256i x) {
  const __m256i x01 = _mm256_slli_si256(x, 4);
  const __m256i x02 = _mm256_add_epi32(x, x01);
  const __m256i x03 = _mm256_slli_si256(x02, 8);
  const __m256i x04 = _mm256_add_epi32(x02, x03);
  const int32_t s = _mm256_extract_epi32(x04, 3);
  const __m128i s01 = _mm_set1_epi32(s);
  const __m256i s02 = _mm256_insertf128_si256(_mm256_setzero_si256(), s01, 1);
  return _mm256_add_epi32(x04, s02);
}
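
// For reference, a scalar sketch of the same computation: an inclusive prefix
// sum over eight 32-bit lanes. Illustrative only, not part of the build; the
// helper name below is hypothetical.
#if 0
static void scan_32_scalar(const int32_t in[8], int32_t out[8]) {
  int32_t acc = 0;
  for (int k = 0; k < 8; ++k) {
    acc += in[k];  // acc = in[0] + in[1] + ... + in[k]
    out[k] = acc;
  }
}
#endif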

// Compute two integral images from src: B sums elements and A sums their
// squares. The images are offset by one pixel, so they have dimensions
// (width + 1) x (height + 1) and their first row and column are zero.
//
// A + 1 and B + 1 should be aligned to 32 bytes. buf_stride should be a
// multiple of 8.
static void integral_images(const uint8_t *src, int src_stride, int width,
                            int height, int32_t *A, int32_t *B,
                            int buf_stride) {
  // Write out the zero top row
  memset(A, 0, sizeof(*A) * (width + 1));
  memset(B, 0, sizeof(*B) * (width + 1));

  const __m256i zero = _mm256_setzero_si256();
  for (int i = 0; i < height; ++i) {
    // Zero the left column.
    A[(i + 1) * buf_stride] = B[(i + 1) * buf_stride] = 0;

    // ldiff is the difference H - D where H is the output sample immediately
    // to the left and D is the output sample above it. These are scalars,
    // replicated across the eight lanes.
    __m256i ldiff1 = zero, ldiff2 = zero;
    for (int j = 0; j < width; j += 8) {
      const int ABj = 1 + j;

      const __m256i above1 = yy_load_256(B + ABj + i * buf_stride);
      const __m256i above2 = yy_load_256(A + ABj + i * buf_stride);

      const __m256i x1 = yy256_load_extend_8_32(src + j + i * src_stride);
      const __m256i x2 = _mm256_madd_epi16(x1, x1);

      const __m256i sc1 = scan_32(x1);
      const __m256i sc2 = scan_32(x2);

      const __m256i row1 =
          _mm256_add_epi32(_mm256_add_epi32(sc1, above1), ldiff1);
      const __m256i row2 =
          _mm256_add_epi32(_mm256_add_epi32(sc2, above2), ldiff2);

      yy_store_256(B + ABj + (i + 1) * buf_stride, row1);
      yy_store_256(A + ABj + (i + 1) * buf_stride, row2);

      // Calculate the new H - D.
      ldiff1 = _mm256_set1_epi32(
          _mm256_extract_epi32(_mm256_sub_epi32(row1, above1), 7));
      ldiff2 = _mm256_set1_epi32(
          _mm256_extract_epi32(_mm256_sub_epi32(row2, above2), 7));
    }
  }
}
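
// For reference, a direct scalar form of the recurrence the vector loop above
// implements for the B (sum) image; the A (sum of squares) image is analogous
// with squared source samples. Illustrative only, not part of the build; the
// helper name below is hypothetical.
#if 0
static void integral_image_scalar(const uint8_t *src, int src_stride,
                                  int width, int height, int32_t *B,
                                  int buf_stride) {
  for (int j = 0; j <= width; ++j) B[j] = 0;  // Zero top row
  for (int i = 0; i < height; ++i) {
    B[(i + 1) * buf_stride] = 0;  // Zero left column
    for (int j = 0; j < width; ++j) {
      // ii(i + 1, j + 1) = ii(i, j + 1) + ii(i + 1, j) - ii(i, j) + src(i, j)
      B[(i + 1) * buf_stride + (j + 1)] =
          B[i * buf_stride + (j + 1)] + B[(i + 1) * buf_stride + j] -
          B[i * buf_stride + j] + src[i * src_stride + j];
    }
  }
}
#endif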

// Compute two integral images from src: B sums elements and A sums their
// squares.
//
// A and B should be aligned to 32 bytes. buf_stride should be a multiple of 8.
static void integral_images_highbd(const uint16_t *src, int src_stride,
                                   int width, int height, int32_t *A,
                                   int32_t *B, int buf_stride) {
  // Write out the zero top row
  memset(A, 0, sizeof(*A) * (width + 1));
  memset(B, 0, sizeof(*B) * (width + 1));

  const __m256i zero = _mm256_setzero_si256();
  for (int i = 0; i < height; ++i) {
    // Zero the left column.
    A[(i + 1) * buf_stride] = B[(i + 1) * buf_stride] = 0;

    // ldiff is the difference H - D where H is the output sample immediately
    // to the left and D is the output sample above it. These are scalars,
    // replicated across the eight lanes.
    __m256i ldiff1 = zero, ldiff2 = zero;
    for (int j = 0; j < width; j += 8) {
      const int ABj = 1 + j;

      const __m256i above1 = yy_load_256(B + ABj + i * buf_stride);
      const __m256i above2 = yy_load_256(A + ABj + i * buf_stride);

      const __m256i x1 = yy256_load_extend_16_32(src + j + i * src_stride);
      const __m256i x2 = _mm256_madd_epi16(x1, x1);

      const __m256i sc1 = scan_32(x1);
      const __m256i sc2 = scan_32(x2);

      const __m256i row1 =
          _mm256_add_epi32(_mm256_add_epi32(sc1, above1), ldiff1);
      const __m256i row2 =
          _mm256_add_epi32(_mm256_add_epi32(sc2, above2), ldiff2);

      yy_store_256(B + ABj + (i + 1) * buf_stride, row1);
      yy_store_256(A + ABj + (i + 1) * buf_stride, row2);

      // Calculate the new H - D.
      ldiff1 = _mm256_set1_epi32(
          _mm256_extract_epi32(_mm256_sub_epi32(row1, above1), 7));
      ldiff2 = _mm256_set1_epi32(
          _mm256_extract_epi32(_mm256_sub_epi32(row2, above2), 7));
    }
  }
}

// Compute eight values of boxsum from the given integral image. ii should
// point at the middle of the box (for the first value). r is the box radius.
static __m256i boxsum_from_ii(const int32_t *ii, int stride, int r) {
  const __m256i tl = yy_loadu_256(ii - (r + 1) - (r + 1) * stride);
  const __m256i tr = yy_loadu_256(ii + (r + 0) - (r + 1) * stride);
  const __m256i bl = yy_loadu_256(ii - (r + 1) + r * stride);
  const __m256i br = yy_loadu_256(ii + (r + 0) + r * stride);
  const __m256i u = _mm256_sub_epi32(tr, tl);
  const __m256i v = _mm256_sub_epi32(br, bl);
  return _mm256_sub_epi32(v, u);
}
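
// For reference, the scalar identity behind the four loads above: the sum of
// the original samples in a (2r + 1) x (2r + 1) box is recovered from four
// corner samples of the one-pixel-offset integral image. Illustrative only,
// not part of the build; the helper name below is hypothetical.
#if 0
static int32_t boxsum_from_ii_scalar(const int32_t *ii, int stride, int r) {
  const int32_t tl = ii[-(r + 1) * stride - (r + 1)];
  const int32_t tr = ii[-(r + 1) * stride + r];
  const int32_t bl = ii[r * stride - (r + 1)];
  const int32_t br = ii[r * stride + r];
  return (br - bl) - (tr - tl);  // == br - bl - tr + tl
}
#endif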

static __m256i round_for_shift(unsigned shift) {
  return _mm256_set1_epi32((1 << shift) >> 1);
}

static __m256i compute_p(__m256i sum1, __m256i sum2, int bit_depth, int n) {
  __m256i an, bb;
  if (bit_depth > 8) {
    const __m256i rounding_a = round_for_shift(2 * (bit_depth - 8));
    const __m256i rounding_b = round_for_shift(bit_depth - 8);
    const __m128i shift_a = _mm_cvtsi32_si128(2 * (bit_depth - 8));
    const __m128i shift_b = _mm_cvtsi32_si128(bit_depth - 8);
    const __m256i a =
        _mm256_srl_epi32(_mm256_add_epi32(sum2, rounding_a), shift_a);
    const __m256i b =
        _mm256_srl_epi32(_mm256_add_epi32(sum1, rounding_b), shift_b);
    // b < 2^14, so we can use a 16-bit madd rather than a 32-bit
    // mullo to square it
    bb = _mm256_madd_epi16(b, b);
    an = _mm256_max_epi32(_mm256_mullo_epi32(a, _mm256_set1_epi32(n)), bb);
  } else {
    bb = _mm256_madd_epi16(sum1, sum1);
    an = _mm256_mullo_epi32(sum2, _mm256_set1_epi32(n));
  }
  return _mm256_sub_epi32(an, bb);
}
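
// A scalar sketch of what compute_p evaluates per lane (illustrative only,
// not compiled; the helper name is hypothetical): with sum1 the box sum of
// pixels and sum2 the box sum of squared pixels, p = n * sum2 - sum1 * sum1,
// i.e. n^2 times the box variance. For bit depths above 8 the sums are first
// scaled down so the products fit in 32 bits, and the max() mirrors the clamp
// above, which stops the rounding from making p negative.
#if 0
static int32_t compute_p_scalar(int32_t sum1, int32_t sum2, int bit_depth,
                                int n) {
  if (bit_depth > 8) {
    const int32_t a = ROUND_POWER_OF_TWO(sum2, 2 * (bit_depth - 8));
    const int32_t b = ROUND_POWER_OF_TWO(sum1, bit_depth - 8);
    const int32_t bb = b * b;
    const int32_t an = AOMMAX(a * n, bb);
    return an - bb;
  }
  return n * sum2 - sum1 * sum1;
}
#endif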

// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, B, C, D point at logical position (0, 0).
static void calc_ab(int32_t *A, int32_t *B, const int32_t *C, const int32_t *D,
                    int width, int height, int buf_stride, int eps,
                    int bit_depth, int r) {
  const int n = (2 * r + 1) * (2 * r + 1);
  const __m256i s = _mm256_set1_epi32(sgrproj_mtable[eps - 1][n - 1]);
  // one_by_x[n - 1] is 2^12/n, so it easily fits in an int16
  const __m256i one_over_n = _mm256_set1_epi32(one_by_x[n - 1]);

  const __m256i rnd_z = round_for_shift(SGRPROJ_MTABLE_BITS);
  const __m256i rnd_res = round_for_shift(SGRPROJ_RECIP_BITS);

  // Set up masks
  const __m128i ones32 = _mm_set_epi64x(0, 0xffffffffffffffffULL);
  __m256i mask[8];
  for (int idx = 0; idx < 8; idx++) {
    const __m128i shift = _mm_set_epi64x(0, 8 * (8 - idx));
    mask[idx] = _mm256_cvtepi8_epi32(_mm_srl_epi64(ones32, shift));
  }
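
  // For example, with idx == 3 the shift is 8 * (8 - 3) = 40 bits, leaving
  // three 0xff bytes in the low half of the 128-bit register;
  // _mm256_cvtepi8_epi32 sign-extends each of those bytes to an all-ones
  // 32-bit lane, so mask[3] keeps lanes 0..2 and zeroes lanes 3..7.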
  for (int i = -1; i < height + 1; ++i) {
    for (int j = -1; j < width + 1; j += 8) {
      const int32_t *Cij = C + i * buf_stride + j;
      const int32_t *Dij = D + i * buf_stride + j;

      __m256i sum1 = boxsum_from_ii(Dij, buf_stride, r);
      __m256i sum2 = boxsum_from_ii(Cij, buf_stride, r);

      // When width + 2 isn't a multiple of 8, sum1 and sum2 will contain
      // some uninitialised data in their upper lanes. We use a mask to
      // ensure that these lanes are set to 0.
      int idx = AOMMIN(8, width + 1 - j);
      assert(idx >= 1);

      if (idx < 8) {
        sum1 = _mm256_and_si256(mask[idx], sum1);
        sum2 = _mm256_and_si256(mask[idx], sum2);
      }

      const __m256i p = compute_p(sum1, sum2, bit_depth, n);

      const __m256i z = _mm256_min_epi32(
          _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                            SGRPROJ_MTABLE_BITS),
          _mm256_set1_epi32(255));

      const __m256i a_res = _mm256_i32gather_epi32(x_by_xplus1, z, 4);

      yy_storeu_256(A + i * buf_stride + j, a_res);

      const __m256i a_complement =
          _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR), a_res);

      // sum1 might have lanes greater than 2^15, so we can't use madd to do
      // multiplication involving sum1. However, a_complement and one_over_n
      // are both less than 256, so we can multiply them first.
      const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
      const __m256i b_int = _mm256_mullo_epi32(a_comp_over_n, sum1);
      const __m256i b_res = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                              SGRPROJ_RECIP_BITS);

      yy_storeu_256(B + i * buf_stride + j, b_res);
    }
  }
}

// Calculate 8 values of the "cross sum" starting at buf. This is a 3x3 filter
// where the outer four corners have weight 3 and all other pixels have weight
// 4.
//
// Pixels are indexed as follows:
// xtl  xt   xtr
// xl   x    xr
// xbl  xb   xbr
//
// buf points to x
//
// fours = xl + xt + xr + xb + x
// threes = xtl + xtr + xbr + xbl
// cross_sum = 4 * fours + 3 * threes
//           = 4 * (fours + threes) - threes
//           = (fours + threes) << 2 - threes
static __m256i cross_sum(const int32_t *buf, int stride) {
  const __m256i xtl = yy_loadu_256(buf - 1 - stride);
  const __m256i xt = yy_loadu_256(buf - stride);
  const __m256i xtr = yy_loadu_256(buf + 1 - stride);
  const __m256i xl = yy_loadu_256(buf - 1);
  const __m256i x = yy_loadu_256(buf);
  const __m256i xr = yy_loadu_256(buf + 1);
  const __m256i xbl = yy_loadu_256(buf - 1 + stride);
  const __m256i xb = yy_loadu_256(buf + stride);
  const __m256i xbr = yy_loadu_256(buf + 1 + stride);

  const __m256i fours = _mm256_add_epi32(
      xl, _mm256_add_epi32(xt, _mm256_add_epi32(xr, _mm256_add_epi32(xb, x))));
  const __m256i threes =
      _mm256_add_epi32(xtl, _mm256_add_epi32(xtr, _mm256_add_epi32(xbr, xbl)));

  return _mm256_sub_epi32(_mm256_slli_epi32(_mm256_add_epi32(fours, threes), 2),
                          threes);
}
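
// Note the total weight is 5 * 4 + 4 * 3 = 32 = 2^5, which is where the
// nb = 5 shift in final_filter below comes from.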

// The final filter for self-guided restoration. Computes a weighted average
// across A, B with "cross sums" (see cross_sum implementation above)
static void final_filter(int32_t *dst, int dst_stride, const int32_t *A,
                         const int32_t *B, int buf_stride, const void *dgd8,
                         int dgd_stride, int width, int height, int highbd) {
  const int nb = 5;
  const __m256i rounding =
      round_for_shift(SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS);
  const uint8_t *dgd_real =
      highbd ? (const uint8_t *)CONVERT_TO_SHORTPTR(dgd8) : dgd8;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 8) {
      const __m256i a = cross_sum(A + i * buf_stride + j, buf_stride);
      const __m256i b = cross_sum(B + i * buf_stride + j, buf_stride);

      const __m128i raw =
          xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
      const __m256i src =
          highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

      __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
      __m256i w = _mm256_srai_epi32(_mm256_add_epi32(v, rounding),
                                    SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS);

      yy_storeu_256(dst + i * dst_stride + j, w);
    }
  }
}
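
// A scalar sketch of one output sample of the loop above (illustrative only,
// not compiled; the helper name is hypothetical): A provides a per-pixel gain
// and B an offset, both already smoothed by cross_sum, and the result is
// rounded back down by the combined shift.
#if 0
static int32_t final_filter_scalar(int32_t a_sum, int32_t b_sum,
                                   int32_t src_px) {
  const int nb = 5;  // log2 of the total cross_sum weight (32)
  const int shift = SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS;
  return ROUND_POWER_OF_TWO(a_sum * src_px + b_sum, shift);
}
#endif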

#if CONFIG_FAST_SGR
// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, B, C, D point at logical position (0, 0).
static void calc_ab_fast(int32_t *A, int32_t *B, const int32_t *C,
                         const int32_t *D, int width, int height,
                         int buf_stride, int eps, int bit_depth, int r) {
  const int n = (2 * r + 1) * (2 * r + 1);
  const __m256i s = _mm256_set1_epi32(sgrproj_mtable[eps - 1][n - 1]);
  // one_by_x[n - 1] is 2^12/n, so it easily fits in an int16
  const __m256i one_over_n = _mm256_set1_epi32(one_by_x[n - 1]);

  const __m256i rnd_z = round_for_shift(SGRPROJ_MTABLE_BITS);
  const __m256i rnd_res = round_for_shift(SGRPROJ_RECIP_BITS);

  // Set up masks
  const __m128i ones32 = _mm_set_epi64x(0, 0xffffffffffffffffULL);
  __m256i mask[8];
  for (int idx = 0; idx < 8; idx++) {
    const __m128i shift = _mm_set_epi64x(0, 8 * (8 - idx));
    mask[idx] = _mm256_cvtepi8_epi32(_mm_srl_epi64(ones32, shift));
  }

  for (int i = -1; i < height + 1; i += 2) {
    for (int j = -1; j < width + 1; j += 8) {
      const int32_t *Cij = C + i * buf_stride + j;
      const int32_t *Dij = D + i * buf_stride + j;

      __m256i sum1 = boxsum_from_ii(Dij, buf_stride, r);
      __m256i sum2 = boxsum_from_ii(Cij, buf_stride, r);

      // When width + 2 isn't a multiple of 8, sum1 and sum2 will contain
      // some uninitialised data in their upper lanes. We use a mask to
      // ensure that these lanes are set to 0.
      int idx = AOMMIN(8, width + 1 - j);
      assert(idx >= 1);

      if (idx < 8) {
        sum1 = _mm256_and_si256(mask[idx], sum1);
        sum2 = _mm256_and_si256(mask[idx], sum2);
      }

      const __m256i p = compute_p(sum1, sum2, bit_depth, n);

      const __m256i z = _mm256_min_epi32(
          _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                            SGRPROJ_MTABLE_BITS),
          _mm256_set1_epi32(255));

      const __m256i a_res = _mm256_i32gather_epi32(x_by_xplus1, z, 4);

      yy_storeu_256(A + i * buf_stride + j, a_res);

      const __m256i a_complement =
          _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR), a_res);

      // sum1 might have lanes greater than 2^15, so we can't use madd to do
      // multiplication involving sum1. However, a_complement and one_over_n
      // are both less than 256, so we can multiply them first.
      const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
      const __m256i b_int = _mm256_mullo_epi32(a_comp_over_n, sum1);
      const __m256i b_res = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                              SGRPROJ_RECIP_BITS);

      yy_storeu_256(B + i * buf_stride + j, b_res);
    }
  }
}

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xtl  xt   xtr
//  -  buf    -
// xbl  xb   xbr
//
// Pixels are weighted like this:
//  5    6    5
//  0    0    0
//  5    6    5
//
// fives = xtl + xtr + xbl + xbr
// sixes = xt + xb
// cross_sum = 6 * sixes + 5 * fives
//           = 5 * (fives + sixes) + sixes
//           = (fives + sixes) << 2 + (fives + sixes) + sixes
static __m256i cross_sum_fast_even(const int32_t *buf, int stride) {
  const __m256i xtl = yy_loadu_256(buf - 1 - stride);
  const __m256i xt = yy_loadu_256(buf - stride);
  const __m256i xtr = yy_loadu_256(buf + 1 - stride);
  const __m256i xbl = yy_loadu_256(buf - 1 + stride);
  const __m256i xb = yy_loadu_256(buf + stride);
  const __m256i xbr = yy_loadu_256(buf + 1 + stride);

  const __m256i fives =
      _mm256_add_epi32(xtl, _mm256_add_epi32(xtr, _mm256_add_epi32(xbr, xbl)));
  const __m256i sixes = _mm256_add_epi32(xt, xb);
  const __m256i fives_plus_sixes = _mm256_add_epi32(fives, sixes);

  return _mm256_add_epi32(
      _mm256_add_epi32(_mm256_slli_epi32(fives_plus_sixes, 2),
                       fives_plus_sixes),
      sixes);
}
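
// Note the total weight is 4 * 5 + 2 * 6 = 32 = 2^5, matching the nb0 = 5
// shift used for even rows in the final_filter_fast functions below.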

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xl    x   xr
//
// Pixels are weighted like this:
//  5    6    5
//
// buf points to x
//
// fives = xl + xr
// sixes = x
// cross_sum = 5 * fives + 6 * sixes
//           = 4 * (fives + sixes) + (fives + sixes) + sixes
//           = (fives + sixes) << 2 + (fives + sixes) + sixes
static __m256i cross_sum_fast_odd(const int32_t *buf) {
  const __m256i xl = yy_loadu_256(buf - 1);
  const __m256i x = yy_loadu_256(buf);
  const __m256i xr = yy_loadu_256(buf + 1);

  const __m256i fives = _mm256_add_epi32(xl, xr);
  const __m256i sixes = x;

  const __m256i fives_plus_sixes = _mm256_add_epi32(fives, sixes);

  return _mm256_add_epi32(
      _mm256_add_epi32(_mm256_slli_epi32(fives_plus_sixes, 2),
                       fives_plus_sixes),
      sixes);
}
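
// Note the total weight is 2 * 5 + 6 = 16 = 2^4, matching the nb1 = 4 shift
// used for odd rows in final_filter_fast2 below.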

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xtl  xt   xtr
//  -    -    -
// xl    x   xr
//  -    -    -
// xbl  xb   xbr
//
// Pixels are weighted like this:
//  3    4    3
//  0    0    0
// 14   16   14
//  0    0    0
//  3    4    3
//
// buf points to x
//
// threes = xtl + xtr + xbr + xbl
// fours = xt + xb
// fourteens = xl + xr
// sixteens = x
// cross_sum = 4 * fours + 3 * threes + 14 * fourteens + 16 * sixteens
//           = 4 * (fours + threes) + 16 * (sixteens + fourteens)
//             - (threes + fourteens) - fourteens
//           = (fours + threes) << 2 + (sixteens + fourteens) << 4
//             - (threes + fourteens) - fourteens
static __m256i cross_sum_fast_odd_not_last(const int32_t *buf, int stride) {
  const int two_stride = 2 * stride;
  const __m256i xtl = yy_loadu_256(buf - 1 - two_stride);
  const __m256i xt = yy_loadu_256(buf - two_stride);
  const __m256i xtr = yy_loadu_256(buf + 1 - two_stride);
  const __m256i xl = yy_loadu_256(buf - 1);
  const __m256i x = yy_loadu_256(buf);
  const __m256i xr = yy_loadu_256(buf + 1);
  const __m256i xbl = yy_loadu_256(buf - 1 + two_stride);
  const __m256i xb = yy_loadu_256(buf + two_stride);
  const __m256i xbr = yy_loadu_256(buf + 1 + two_stride);

  const __m256i threes =
      _mm256_add_epi32(xtl, _mm256_add_epi32(xtr, _mm256_add_epi32(xbr, xbl)));
  const __m256i fours = _mm256_add_epi32(xt, xb);
  const __m256i fourteens = _mm256_add_epi32(xl, xr);
  const __m256i sixteens = x;

  const __m256i fours_plus_threes = _mm256_add_epi32(fours, threes);
  const __m256i sixteens_plus_fourteens = _mm256_add_epi32(sixteens, fourteens);
  const __m256i threes_plus_fourteens = _mm256_add_epi32(threes, fourteens);

  return _mm256_sub_epi32(
      _mm256_sub_epi32(
          _mm256_add_epi32(_mm256_slli_epi32(fours_plus_threes, 2),
                           _mm256_slli_epi32(sixteens_plus_fourteens, 4)),
          threes_plus_fourteens),
      fourteens);
}
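
// Note the total weight is 4 * 3 + 2 * 4 + 2 * 14 + 16 = 64 = 2^6, matching
// the nb1 = 6 shift used for odd rows in final_filter_fast1 below.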

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xtl  xt   xtr
//  -    -    -
// xl    x   xr
//
// Pixels are weighted like this:
//  4    6    4
//  0    0    0
// 16   18   16
//
// buf points to x
//
// fours = xtl + xtr
// sixes = xt
// sixteens = xl + xr
// eighteens = x
// cross_sum = 4 * fours + 6 * sixes + 16 * sixteens + 18 * eighteens
//           = 4 * (fours + sixes) + 16 * (sixteens + eighteens)
//             + 2 * (sixes + eighteens)
//           = (fours + sixes) << 2 + (sixteens + eighteens) << 4
//             + (sixes + eighteens) << 1
static __m256i cross_sum_fast_odd_last(const int32_t *buf, int stride) {
  const int two_stride = 2 * stride;
  const __m256i xtl = yy_loadu_256(buf - 1 - two_stride);
  const __m256i xt = yy_loadu_256(buf - two_stride);
  const __m256i xtr = yy_loadu_256(buf + 1 - two_stride);
  const __m256i xl = yy_loadu_256(buf - 1);
  const __m256i x = yy_loadu_256(buf);
  const __m256i xr = yy_loadu_256(buf + 1);

  const __m256i fours = _mm256_add_epi32(xtl, xtr);
  const __m256i sixes = xt;
  const __m256i sixteens = _mm256_add_epi32(xl, xr);
  const __m256i eighteens = x;

  const __m256i fours_plus_sixes = _mm256_add_epi32(fours, sixes);
  const __m256i sixteens_plus_eighteens = _mm256_add_epi32(sixteens, eighteens);
  const __m256i sixes_plus_eighteens = _mm256_add_epi32(sixes, eighteens);

  return _mm256_add_epi32(
      _mm256_add_epi32(_mm256_slli_epi32(fours_plus_sixes, 2),
                       _mm256_slli_epi32(sixteens_plus_eighteens, 4)),
      _mm256_slli_epi32(sixes_plus_eighteens, 1));
}
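
// Note the total weight is 2 * 4 + 6 + 2 * 16 + 18 = 64 = 2^6, matching the
// nb1 = 6 shift used for the last (odd) row in final_filter_fast1 below.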

// The final filter for selfguided restoration. Computes a weighted average
// across A, B with "cross sums" (see cross_sum_... implementations above).
// Designed for the first vertical sub-sampling version of FAST_SGR.
static void final_filter_fast1(int32_t *dst, int dst_stride, const int32_t *A,
                               const int32_t *B, int buf_stride,
                               const void *dgd8, int dgd_stride, int width,
                               int height, int highbd) {
  const int nb0 = 5;
  const int nb1 = 6;

  const __m256i rounding0 =
      round_for_shift(SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);
  const __m256i rounding1 =
      round_for_shift(SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);

  const uint8_t *dgd_real =
      highbd ? (const uint8_t *)CONVERT_TO_SHORTPTR(dgd8) : dgd8;

  for (int i = 0; i < height; ++i) {
    if (!(i & 1)) {  // even row
      for (int j = 0; j < width; j += 8) {
        const __m256i a =
            cross_sum_fast_even(A + i * buf_stride + j, buf_stride);
        const __m256i b =
            cross_sum_fast_even(B + i * buf_stride + j, buf_stride);

        const __m128i raw =
            xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
        const __m256i src =
            highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

        __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
        __m256i w =
            _mm256_srai_epi32(_mm256_add_epi32(v, rounding0),
                              SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);

        yy_storeu_256(dst + i * dst_stride + j, w);
      }
    } else if (i != height - 1) {  // odd row and not last
      for (int j = 0; j < width; j += 8) {
        const __m256i a =
            cross_sum_fast_odd_not_last(A + i * buf_stride + j, buf_stride);
        const __m256i b =
            cross_sum_fast_odd_not_last(B + i * buf_stride + j, buf_stride);

        const __m128i raw =
            xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
        const __m256i src =
            highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

        __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
        __m256i w =
            _mm256_srai_epi32(_mm256_add_epi32(v, rounding1),
                              SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);

        yy_storeu_256(dst + i * dst_stride + j, w);
      }
    } else {  // odd row and last
      for (int j = 0; j < width; j += 8) {
        const __m256i a =
            cross_sum_fast_odd_last(A + i * buf_stride + j, buf_stride);
        const __m256i b =
            cross_sum_fast_odd_last(B + i * buf_stride + j, buf_stride);

        const __m128i raw =
            xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
        const __m256i src =
            highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

        __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
        __m256i w =
            _mm256_srai_epi32(_mm256_add_epi32(v, rounding1),
                              SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);

        yy_storeu_256(dst + i * dst_stride + j, w);
      }
    }
  }
}

// The final filter for selfguided restoration. Computes a weighted average
// across A, B with "cross sums" (see cross_sum_... implementations above).
// Designed for the second vertical sub-sampling version of FAST_SGR.
static void final_filter_fast2(int32_t *dst, int dst_stride, const int32_t *A,
                               const int32_t *B, int buf_stride,
                               const void *dgd8, int dgd_stride, int width,
                               int height, int highbd) {
  const int nb0 = 5;
  const int nb1 = 4;

  const __m256i rounding0 =
      round_for_shift(SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);
  const __m256i rounding1 =
      round_for_shift(SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);

  const uint8_t *dgd_real =
      highbd ? (const uint8_t *)CONVERT_TO_SHORTPTR(dgd8) : dgd8;

  for (int i = 0; i < height; ++i) {
    if (!(i & 1)) {  // even row
      for (int j = 0; j < width; j += 8) {
        const __m256i a =
            cross_sum_fast_even(A + i * buf_stride + j, buf_stride);
        const __m256i b =
            cross_sum_fast_even(B + i * buf_stride + j, buf_stride);

        const __m128i raw =
            xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
        const __m256i src =
            highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

        __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
        __m256i w =
            _mm256_srai_epi32(_mm256_add_epi32(v, rounding0),
                              SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);

        yy_storeu_256(dst + i * dst_stride + j, w);
      }
    } else {  // odd row
      for (int j = 0; j < width; j += 8) {
        const __m256i a = cross_sum_fast_odd(A + i * buf_stride + j);
        const __m256i b = cross_sum_fast_odd(B + i * buf_stride + j);

        const __m128i raw =
            xx_loadu_128(dgd_real + ((i * dgd_stride + j) << highbd));
        const __m256i src =
            highbd ? _mm256_cvtepu16_epi32(raw) : _mm256_cvtepu8_epi32(raw);

        __m256i v = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
        __m256i w =
            _mm256_srai_epi32(_mm256_add_epi32(v, rounding1),
                              SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);

        yy_storeu_256(dst + i * dst_stride + j, w);
      }
    }
  }
}
#endif

void av1_selfguided_restoration_avx2(const uint8_t *dgd8, int width, int height,
                                     int dgd_stride, int32_t *flt1,
                                     int32_t *flt2, int flt_stride,
                                     const sgr_params_type *params,
                                     int bit_depth, int highbd) {
  // The ALIGN_POWER_OF_TWO macro here ensures that column 1 of Atl, Btl,
  // Ctl and Dtl is 32-byte aligned.
  const int buf_elts = ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3);

  DECLARE_ALIGNED(32, int32_t,
                  buf[4 * ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3)]);
  memset(buf, 0, sizeof(buf));

  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;

  // Adjusting the stride of A and B here appears to avoid bad cache effects,
  // leading to a significant speed improvement.
  // We also align the stride to a multiple of 32 bytes for efficiency.
  int buf_stride = ALIGN_POWER_OF_TWO(width_ext + 16, 3);

  // The "tl" pointers point at the top-left of the initialised data for the
  // array.
  int32_t *Atl = buf + 0 * buf_elts + 7;
  int32_t *Btl = buf + 1 * buf_elts + 7;
  int32_t *Ctl = buf + 2 * buf_elts + 7;
  int32_t *Dtl = buf + 3 * buf_elts + 7;

  // The "0" pointers are (-SGRPROJ_BORDER_VERT, -SGRPROJ_BORDER_HORZ). Note
  // there's a zero row and column in A, B (integral images), so we move down
  // and right one for them.
  const int buf_diag_border =
      SGRPROJ_BORDER_HORZ + buf_stride * SGRPROJ_BORDER_VERT;

  int32_t *A0 = Atl + 1 + buf_stride;
  int32_t *B0 = Btl + 1 + buf_stride;
  int32_t *C0 = Ctl + 1 + buf_stride;
  int32_t *D0 = Dtl + 1 + buf_stride;

  // Finally, A, B, C, D point at position (0, 0).
  int32_t *A = A0 + buf_diag_border;
  int32_t *B = B0 + buf_diag_border;
  int32_t *C = C0 + buf_diag_border;
  int32_t *D = D0 + buf_diag_border;

  const int dgd_diag_border =
      SGRPROJ_BORDER_HORZ + dgd_stride * SGRPROJ_BORDER_VERT;
  const uint8_t *dgd0 = dgd8 - dgd_diag_border;

  // Generate integral images from the input. C will contain sums of squares; D
  // will contain just sums.
  if (highbd)
    integral_images_highbd(CONVERT_TO_SHORTPTR(dgd0), dgd_stride, width_ext,
                           height_ext, Ctl, Dtl, buf_stride);
  else
    integral_images(dgd0, dgd_stride, width_ext, height_ext, Ctl, Dtl,
                    buf_stride);

// Write to flt1 and flt2
#if CONFIG_FAST_SGR
  assert(params->r1 < AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));

  // r == 2 filter
  assert(params->r1 == 2);
  calc_ab_fast(A, B, C, D, width, height, buf_stride, params->e1, bit_depth,
               params->r1);
  final_filter_fast2(flt1, flt_stride, A, B, buf_stride, dgd8, dgd_stride,
                     width, height, highbd);

  // r == 1 filter
  assert(params->r2 == 1);
  calc_ab(A, B, C, D, width, height, buf_stride, params->e2, bit_depth,
          params->r2);
  final_filter(flt2, flt_stride, A, B, buf_stride, dgd8, dgd_stride, width,
               height, highbd);
#else
  for (int i = 0; i < 2; ++i) {
    int r = i ? params->r2 : params->r1;
    int e = i ? params->e2 : params->e1;
    int32_t *flt = i ? flt2 : flt1;

    assert(r + 1 <= AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));

    calc_ab(A, B, C, D, width, height, buf_stride, e, bit_depth, r);
    final_filter(flt, flt_stride, A, B, buf_stride, dgd8, dgd_stride, width,
                 height, highbd);
  }
#endif
}

void apply_selfguided_restoration_avx2(const uint8_t *dat8, int width,
                                       int height, int stride, int eps,
                                       const int *xqd, uint8_t *dst8,
                                       int dst_stride, int32_t *tmpbuf,
                                       int bit_depth, int highbd) {
  int32_t *flt1 = tmpbuf;
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
  assert(width * height <= RESTORATION_TILEPELS_MAX);
  av1_selfguided_restoration_avx2(dat8, width, height, stride, flt1, flt2,
                                  width, &sgr_params[eps], bit_depth, highbd);

  int xq[2];
  decode_xq(xqd, xq);

  __m256i xq0 = _mm256_set1_epi32(xq[0]);
  __m256i xq1 = _mm256_set1_epi32(xq[1]);

  for (int i = 0; i < height; ++i) {
    // Calculate output in batches of 16 pixels
    for (int j = 0; j < width; j += 16) {
      const int k = i * width + j;
      const int m = i * dst_stride + j;

      const uint8_t *dat8ij = dat8 + i * stride + j;
      __m256i ep_0, ep_1;
      __m128i src_0, src_1;
      if (highbd) {
        src_0 = xx_loadu_128(CONVERT_TO_SHORTPTR(dat8ij));
        src_1 = xx_loadu_128(CONVERT_TO_SHORTPTR(dat8ij + 8));
        ep_0 = _mm256_cvtepu16_epi32(src_0);
        ep_1 = _mm256_cvtepu16_epi32(src_1);
      } else {
        src_0 = xx_loadu_128(dat8ij);
        ep_0 = _mm256_cvtepu8_epi32(src_0);
        ep_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(src_0, 8));
      }

      const __m256i u_0 = _mm256_slli_epi32(ep_0, SGRPROJ_RST_BITS);
      const __m256i u_1 = _mm256_slli_epi32(ep_1, SGRPROJ_RST_BITS);

      const __m256i f1_0 = _mm256_sub_epi32(yy_loadu_256(&flt1[k]), u_0);
      const __m256i f1_1 = _mm256_sub_epi32(yy_loadu_256(&flt1[k + 8]), u_1);

      const __m256i f2_0 = _mm256_sub_epi32(yy_loadu_256(&flt2[k]), u_0);
      const __m256i f2_1 = _mm256_sub_epi32(yy_loadu_256(&flt2[k + 8]), u_1);

      const __m256i v_0 =
          _mm256_add_epi32(_mm256_add_epi32(_mm256_mullo_epi32(xq0, f1_0),
                                            _mm256_mullo_epi32(xq1, f2_0)),
                           _mm256_slli_epi32(u_0, SGRPROJ_PRJ_BITS));
      const __m256i v_1 =
          _mm256_add_epi32(_mm256_add_epi32(_mm256_mullo_epi32(xq0, f1_1),
                                            _mm256_mullo_epi32(xq1, f2_1)),
                           _mm256_slli_epi32(u_1, SGRPROJ_PRJ_BITS));

      const __m256i rounding =
          round_for_shift(SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
      const __m256i w_0 = _mm256_srai_epi32(
          _mm256_add_epi32(v_0, rounding), SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
      const __m256i w_1 = _mm256_srai_epi32(
          _mm256_add_epi32(v_1, rounding), SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

      if (highbd) {
        // Pack into 16 bits and clamp to [0, 2^bit_depth)
        // Note that packing into 16 bits messes up the order of the values,
        // so we use a permute function to correct this
        const __m256i tmp = _mm256_packus_epi32(w_0, w_1);
        const __m256i tmp2 = _mm256_permute4x64_epi64(tmp, 0xd8);
        const __m256i max = _mm256_set1_epi16((1 << bit_depth) - 1);
        const __m256i res = _mm256_min_epi16(tmp2, max);
        yy_store_256(CONVERT_TO_SHORTPTR(dst8 + m), res);
      } else {
        // Pack into 8 bits and clamp to [0, 256)
        // Note that each pack messes up the order of the values,
        // so we use a permute function to correct this
        const __m256i tmp = _mm256_packs_epi32(w_0, w_1);
        const __m256i tmp2 = _mm256_permute4x64_epi64(tmp, 0xd8);
        const __m256i res =
            _mm256_packus_epi16(tmp2, tmp2 /* "don't care" value */);
        const __m128i res2 =
            _mm256_castsi256_si128(_mm256_permute4x64_epi64(res, 0xd8));
        xx_store_128(dst8 + m, res2);
      }
    }
  }
}
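
// A scalar sketch of the per-pixel blend performed above (illustrative only,
// not compiled; the helper name is hypothetical): the two filtered planes act
// as signed corrections to the upscaled source pixel, weighted by xq[0] and
// xq[1], then the result is rounded back down. Clamping to the pixel range is
// omitted here; the vector code does it while packing the results.
#if 0
static int32_t apply_one_pixel(int32_t src_px, int32_t f1, int32_t f2,
                               const int xq[2]) {
  const int32_t u = src_px << SGRPROJ_RST_BITS;
  const int32_t v =
      xq[0] * (f1 - u) + xq[1] * (f2 - u) + (u << SGRPROJ_PRJ_BITS);
  return ROUND_POWER_OF_TWO(v, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
}
#endif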