/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>
#include "./aom_dsp_rtcd.h"

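// AVX2 SAD (sum of absolute differences) kernels for the wide block sizes
// (64x128, 128x64, 128x128): plain SAD, four-reference (x4d) SAD, and
// second-prediction averaging (avg) SAD, composed from 32x32/64x64 helpers.

// 32x32 SAD. Each iteration loads two 32-byte rows of source and reference;
// _mm256_sad_epu8 produces one 64-bit partial sum per 64-bit lane, and the
// accumulator is reduced horizontally after the loop.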
static unsigned int sad32x32(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  __m256i s1, s2, r1, r2;
  __m256i sum = _mm256_setzero_si256();
  __m128i sum_i128;
  int i;

  for (i = 0; i < 16; ++i) {
    r1 = _mm256_loadu_si256((__m256i const *)ref_ptr);
    r2 = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride));
    s1 = _mm256_sad_epu8(r1, _mm256_loadu_si256((__m256i const *)src_ptr));
    s2 = _mm256_sad_epu8(
        r2, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));
    sum = _mm256_add_epi32(sum, _mm256_add_epi32(s1, s2));
    ref_ptr += ref_stride << 1;
    src_ptr += src_stride << 1;
  }

  sum = _mm256_add_epi32(sum, _mm256_srli_si256(sum, 8));
  sum_i128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 1),
                           _mm256_castsi256_si128(sum));
  return _mm_cvtsi128_si32(sum_i128);
}

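// 64x32 SAD: two 32x32 SADs side by side (the right half starts 32 bytes in).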
static unsigned int sad64x32(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  unsigned int half_width = 32;
  uint32_t sum = sad32x32(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += half_width;
  ref_ptr += half_width;
  sum += sad32x32(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}

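// 64x64 SAD: two 64x32 SADs stacked vertically (advance 32 rows).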
static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride << 5;
  ref_ptr += ref_stride << 5;
  sum += sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}

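// 128x64 SAD: two 64x64 SADs side by side.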
unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  unsigned int half_width = 64;
  uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += half_width;
  ref_ptr += half_width;
  sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}

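// 64x128 SAD: two 64x64 SADs stacked vertically (advance 64 rows).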
unsigned int aom_sad64x128_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride << 6;
  ref_ptr += ref_stride << 6;
  sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}

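// 128x128 SAD: two 128x64 SADs stacked vertically.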
unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride,
                                 const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride << 6;
  ref_ptr += ref_stride << 6;
  sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}

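// Four-reference helper: SAD of one 64x64 source block against each of the
// four reference candidates, with the four 32-bit sums returned packed in an
// __m128i so the halves of larger blocks combine with a single vector add.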
static void sad64x64x4d(const uint8_t *src, int src_stride,
                        const uint8_t *const ref[4], int ref_stride,
                        __m128i *res) {
  uint32_t sum[4];
  aom_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, sum);
  *res = _mm_loadu_si128((const __m128i *)sum);
}

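// 64x128 x4d: stack two 64x64 x4d results vertically.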
void aom_sad64x128x4d_avx2(const uint8_t *src, int src_stride,
                           const uint8_t *const ref[4], int ref_stride,
                           uint32_t res[4]) {
  __m128i sum0, sum1;
  const uint8_t *rf[4];

  rf[0] = ref[0];
  rf[1] = ref[1];
  rf[2] = ref[2];
  rf[3] = ref[3];
  sad64x64x4d(src, src_stride, rf, ref_stride, &sum0);
  src += src_stride << 6;
  rf[0] += ref_stride << 6;
  rf[1] += ref_stride << 6;
  rf[2] += ref_stride << 6;
  rf[3] += ref_stride << 6;
  sad64x64x4d(src, src_stride, rf, ref_stride, &sum1);
  sum0 = _mm_add_epi32(sum0, sum1);
  _mm_storeu_si128((__m128i *)res, sum0);
}

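// 128x64 x4d: place two 64x64 x4d results side by side.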
void aom_sad128x64x4d_avx2(const uint8_t *src, int src_stride,
                           const uint8_t *const ref[4], int ref_stride,
                           uint32_t res[4]) {
  __m128i sum0, sum1;
  unsigned int half_width = 64;
  const uint8_t *rf[4];

  rf[0] = ref[0];
  rf[1] = ref[1];
  rf[2] = ref[2];
  rf[3] = ref[3];
  sad64x64x4d(src, src_stride, rf, ref_stride, &sum0);
  src += half_width;
  rf[0] += half_width;
  rf[1] += half_width;
  rf[2] += half_width;
  rf[3] += half_width;
  sad64x64x4d(src, src_stride, rf, ref_stride, &sum1);
  sum0 = _mm_add_epi32(sum0, sum1);
  _mm_storeu_si128((__m128i *)res, sum0);
}

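// 128x128 x4d: stack two 128x64 x4d results vertically.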
void aom_sad128x128x4d_avx2(const uint8_t *src, int src_stride,
                            const uint8_t *const ref[4], int ref_stride,
                            uint32_t res[4]) {
  const uint8_t *rf[4];
  uint32_t sum0[4];
  uint32_t sum1[4];

  rf[0] = ref[0];
  rf[1] = ref[1];
  rf[2] = ref[2];
  rf[3] = ref[3];
  aom_sad128x64x4d_avx2(src, src_stride, rf, ref_stride, sum0);
  src += src_stride << 6;
  rf[0] += ref_stride << 6;
  rf[1] += ref_stride << 6;
  rf[2] += ref_stride << 6;
  rf[3] += ref_stride << 6;
  aom_sad128x64x4d_avx2(src, src_stride, rf, ref_stride, sum1);
  res[0] = sum0[0] + sum1[0];
  res[1] = sum0[1] + sum1[1];
  res[2] = sum0[2] + sum1[2];
  res[3] = sum0[3] + sum1[3];
}

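// Width-64 avg SAD over h rows: each 64-byte reference row is averaged with
// the corresponding second_pred row (_mm256_avg_epu8 rounds the average up)
// before being compared against the source row.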
static unsigned int sad_w64_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *ref_ptr, int ref_stride,
                                     const int h, const uint8_t *second_pred,
                                     const int second_pred_stride) {
  int i, res;
  __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
  __m256i sum_sad = _mm256_setzero_si256();
  __m256i sum_sad_h;
  __m128i sum_sad128;
  for (i = 0; i < h; i++) {
    ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
    ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));
    ref1_reg = _mm256_avg_epu8(
        ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));
    ref2_reg = _mm256_avg_epu8(
        ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32)));
    sad1_reg = _mm256_sad_epu8(
        ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
    sad2_reg = _mm256_sad_epu8(
        ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));
    sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
    ref_ptr += ref_stride;
    src_ptr += src_stride;
    second_pred += second_pred_stride;
  }
  sum_sad_h = _mm256_srli_si256(sum_sad, 8);
  sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
  sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
  sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
  res = _mm_cvtsi128_si32(sum_sad128);

  return res;
}

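// 64x128 avg SAD: two stacked 64x64 halves; second_pred is a contiguous
// 64-wide buffer (stride 64), so the lower half starts 64 << 6 bytes in.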
unsigned int aom_sad64x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  uint32_t sum = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                                  second_pred, 64);
  src_ptr += src_stride << 6;
  ref_ptr += ref_stride << 6;
  second_pred += 64 << 6;
  sum += sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                          second_pred, 64);
  return sum;
}

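// 128x64 avg SAD: two side-by-side 64-wide halves of a contiguous
// 128-wide second_pred buffer (stride 128).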
unsigned int aom_sad128x64_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  unsigned int half_width = 64;
  uint32_t sum = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                                  second_pred, 128);
  src_ptr += half_width;
  ref_ptr += half_width;
  second_pred += half_width;
  sum += sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                          second_pred, 128);
  return sum;
}

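// 128x128 avg SAD: two stacked 128x64 halves; the lower half of the
// contiguous 128-wide second_pred buffer starts 128 << 6 bytes in.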
unsigned int aom_sad128x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *ref_ptr, int ref_stride,
                                     const uint8_t *second_pred) {
  uint32_t sum = aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr,
                                        ref_stride, second_pred);
  src_ptr += src_stride << 6;
  ref_ptr += ref_stride << 6;
  second_pred += 128 << 6;
  sum += aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride,
                                second_pred);
  return sum;
}