/*
* Copyright (c) 2025, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <immintrin.h> /* AVX2 */
#include <tmmintrin.h> /* SSSE3 */
#include "aom_dsp/aom_dsp_common.h"
#include "av1/common/intra_matrix.h"
// Multiply an 11-element feature vector with the matrix to generate the 8x8
// prediction.
// A - pointer to matrix
// B - pointer to the 11-element feature vector
// C - 8x8 output prediction
// bd - bit depth
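// Per-row scalar equivalent (a sketch; the unsigned weights appear to carry a
// +(1 << (DIP_BITS - 2)) bias, which the negsum term below cancels):
//   acc = -((B[0] + ... + B[10]) << (DIP_BITS - 2));
//   for (int j = 0; j < 11; ++j) acc += A[i * DIP_COLS + j] * B[j];
//   C[i] = clamp((acc + (DIP_OFFSET >> 2)) >> (DIP_BITS - 2), 0,
//                (1 << bd) - 1);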
void av1_dip_matrix_multiplication_avx2(const uint16_t *A, const uint16_t *B,
uint16_t *C, int bd) {
static const uint16_t mask[16] = { -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, 0, 0, 0, 0, 0 };
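  // The 32-byte load of B below reads 16 lanes even though only the first 11
  // are valid feature elements; the mask zeroes the 5 trailing lanes (B is
  // assumed to point into a buffer large enough for the over-read).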
__m256i in0 = _mm256_lddqu_si256((__m256i *)B);
__m256i in_mask = _mm256_lddqu_si256((__m256i *)mask);
in0 = _mm256_and_si256(in0, in_mask);
// in0 = { B0, B1, B2, B3, B4, B5, B6, B7 | B8, B9, B10, 0, 0, 0, 0, 0 }
__m256i negsum = _mm256_madd_epi16(in0, in_mask);
negsum = _mm256_hadd_epi32(negsum, negsum);
negsum = _mm256_hadd_epi32(negsum, negsum);
negsum = _mm256_slli_epi32(negsum, DIP_BITS - 2);
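  // negsum holds -(B0 + ... + B7) in its low 128-bit lane and -(B8 + B9 + B10)
  // in its high lane, each scaled by 1 << (DIP_BITS - 2); added to the row
  // products below, it subtracts (1 << (DIP_BITS - 2)) * sum(B) from every
  // dot product.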
__m128i offset = _mm_set1_epi32(DIP_OFFSET >> 2);
__m128i maxval = _mm_set1_epi32((1 << bd) - 1);
__m128i zero = _mm_setzero_si128();
for (int i = 0; i < DIP_ROWS; i += 4) {
__m256i row0 = _mm256_lddqu_si256((__m256i *)&A[i * DIP_COLS]);
__m256i row1 = _mm256_lddqu_si256((__m256i *)&A[(i + 1) * DIP_COLS]);
__m256i row2 = _mm256_lddqu_si256((__m256i *)&A[(i + 2) * DIP_COLS]);
__m256i row3 = _mm256_lddqu_si256((__m256i *)&A[(i + 3) * DIP_COLS]);
__m256i m0 = _mm256_madd_epi16(row0, in0);
__m256i m1 = _mm256_madd_epi16(row1, in0);
__m256i m2 = _mm256_madd_epi16(row2, in0);
__m256i m3 = _mm256_madd_epi16(row3, in0);
__m256i m01 = _mm256_hadd_epi32(m0, m1);
__m256i m23 = _mm256_hadd_epi32(m2, m3);
__m256i m0123 = _mm256_hadd_epi32(m01, m23);
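    // m0123: lane 0 holds the low-lane partial dot products for rows i..i+3,
    // lane 1 the matching high-lane partials; adding the two 128-bit halves
    // below completes the four dot products.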
__m256i sum0 = _mm256_add_epi32(m0123, negsum);
__m128i sum0_lo = _mm256_castsi256_si128(sum0);
__m128i sum0_hi = _mm256_extracti128_si256(sum0, 1);
__m128i sum1 = _mm_add_epi32(sum0_lo, sum0_hi);
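    // sum1 now holds the four complete (bias-corrected) dot products for rows
    // i..i+3; round, clamp to [0, (1 << bd) - 1], pack to 16 bits and store
    // them to C[i..i+3].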
sum1 = _mm_add_epi32(sum1, offset);
sum1 = _mm_srai_epi32(sum1, DIP_BITS - 2);
sum1 = _mm_min_epi32(sum1, maxval);
sum1 = _mm_max_epi32(sum1, zero);
__m128i out0 = _mm_packus_epi32(sum1, sum1);
_mm_storeu_si64(&C[i], out0);
}
}
// Processes 8 pixels at a time using SSSE3.
static INLINE void resample_vert_w8_ssse3(uint16_t *dst, const uint16_t *p0_row,
const uint16_t *p1_row, const int w0,
const int w1, const int upy_log2) {
const __m128i p0 = _mm_loadu_si128((const __m128i *)p0_row);
const __m128i p1 = _mm_loadu_si128((const __m128i *)p1_row);
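  // Interleave p0 and p1 so each 16-bit (p0, p1) pair shares a 32-bit lane;
  // the madd below then computes p0 * w0 + p1 * w1 for each pixel.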
const __m128i p0p1_lo = _mm_unpacklo_epi16(p0, p1);
const __m128i p0p1_hi = _mm_unpackhi_epi16(p0, p1);
const __m128i weights = _mm_set_epi16(w1, w0, w1, w0, w1, w0, w1, w0);
const __m128i res_lo = _mm_madd_epi16(p0p1_lo, weights);
const __m128i res_hi = _mm_madd_epi16(p0p1_hi, weights);
const __m128i shift = _mm_cvtsi32_si128(upy_log2);
const __m128i shifted_lo = _mm_sra_epi32(res_lo, shift);
const __m128i shifted_hi = _mm_sra_epi32(res_hi, shift);
const __m128i result = _mm_packus_epi32(shifted_lo, shifted_hi);
_mm_storeu_si128((__m128i *)dst, result);
}
// Processes 16 pixels at a time using AVX2.
static INLINE void resample_vert_w16_avx2(uint16_t *dst, const uint16_t *p0_row,
const uint16_t *p1_row, const int w0,
const int w1, const int upy_log2) {
const __m256i p0 = _mm256_loadu_si256((const __m256i *)p0_row);
const __m256i p1 = _mm256_loadu_si256((const __m256i *)p1_row);
const __m256i p0p1_lo = _mm256_unpacklo_epi16(p0, p1);
const __m256i p0p1_hi = _mm256_unpackhi_epi16(p0, p1);
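  // 256-bit unpack and pack operate per 128-bit lane, so the unpack above and
  // the packus below keep the 16 output pixels in their original order.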
const __m256i weights = _mm256_set_epi16(w1, w0, w1, w0, w1, w0, w1, w0, w1,
w0, w1, w0, w1, w0, w1, w0);
const __m256i res_lo = _mm256_madd_epi16(p0p1_lo, weights);
const __m256i res_hi = _mm256_madd_epi16(p0p1_hi, weights);
const __m128i shift = _mm_cvtsi32_si128(upy_log2);
const __m256i shifted_lo = _mm256_sra_epi32(res_lo, shift);
const __m256i shifted_hi = _mm256_sra_epi32(res_hi, shift);
const __m256i result = _mm256_packus_epi32(shifted_lo, shifted_hi);
_mm256_storeu_si256((__m256i *)dst, result);
}
// Processes 32 pixels at a time using AVX2.
static INLINE void resample_vert_w32_avx2(uint16_t *dst, const uint16_t *p0_row,
const uint16_t *p1_row, const int w0,
const int w1, const int upy_log2) {
resample_vert_w16_avx2(dst, p0_row, p1_row, w0, w1, upy_log2);
resample_vert_w16_avx2(dst + 16, p0_row + 16, p1_row + 16, w0, w1, upy_log2);
}
// Processes 64 pixels at a time using AVX2.
static INLINE void resample_vert_w64_avx2(uint16_t *dst, const uint16_t *p0_row,
const uint16_t *p1_row, const int w0,
const int w1, const int upy_log2) {
resample_vert_w32_avx2(dst, p0_row, p1_row, w0, w1, upy_log2);
resample_vert_w32_avx2(dst + 32, p0_row + 32, p1_row + 32, w0, w1, upy_log2);
}
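// Resample the 8x8 ML prediction in ml_output[] to the (1 << bw_log2) x
// (1 << bh_log2) block dst[], interpolating linearly in each direction and
// using above_row/left_col as boundary anchors; ml_output is read transposed
// when 'transpose' is nonzero.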
void resample_output_avx2(uint16_t *dst, int dst_stride,
const uint16_t *above_row, const uint16_t *left_col,
uint16_t *ml_output, int bw_log2, int bh_log2,
int transpose) {
  // Function-pointer type for the width-specialized vertical resampling
  // kernels.
typedef void (*resample_vert_fn)(uint16_t *dst, const uint16_t *p0_row,
const uint16_t *p1_row, const int w0,
const int w1, const int upy_log2);
  // Up/down-sampling factors: the 8x8 ML prediction is resampled to the
  // (1 << bw_log2) x (1 << bh_log2) block.
  int pred_x = 8;
  int pred_y = 8;
int upx_log2 = bw_log2 - 3;
int upy_log2 = bh_log2 - 3;
int downx_log2 = 0;
int downy_log2 = 0;
if (upx_log2 < 0) {
downx_log2 = -upx_log2;
upx_log2 = 0;
}
if (upy_log2 < 0) {
downy_log2 = -upy_log2;
upy_log2 = 0;
}
int mx = 1 << upx_log2;
int my = 1 << upy_log2;
int downx = 1 << downx_log2;
int downy = 1 << downy_log2;
int bw = 1 << bw_log2;
  // Place ml_output[] samples into dst[] at the bottom-right position of each
  // mx x my cell, subsampling and transposing as needed.
for (int i = 0; i < pred_y >> downy_log2; i++) {
for (int j = 0; j < pred_x >> downx_log2; j++) {
int x = j * mx + (mx - 1);
int y = i * my + (my - 1);
int i1 = i * downy;
int j1 = j * downx;
int ii = transpose ? j1 : i1;
int jj = transpose ? i1 : j1;
dst[y * dst_stride + x] = ml_output[ii * pred_x + jj];
}
}
// Interpolate horizontally.
if (upx_log2 > 2) { // For width > 32
for (int i = 0; i < pred_y >> downy_log2; i++) {
int y = i * my + (my - 1);
int p0 = 0;
int p1 = left_col[y];
for (int j = 0; j < pred_x >> downx_log2; j++) {
int x = j * mx;
p0 = p1;
p1 = dst[y * dst_stride + x + mx - 1];
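        // Fill the mx - 1 samples between the previous anchor p0 and the
        // current anchor p1 with (p0 * (mx - k - 1) + p1 * (k + 1)) >>
        // upx_log2, 8 samples per SIMD iteration.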
const __m256i p0_32 = _mm256_set1_epi32(p0);
const __m256i p1_32 = _mm256_set1_epi32(p1);
const __m256i mx_32 = _mm256_set1_epi32(mx);
const __m128i shift = _mm_cvtsi32_si128(upx_log2);
int k = 0;
for (; k <= mx - 1 - 8; k += 8) {
const __m256i k1_32 = _mm256_setr_epi32(k + 1, k + 2, k + 3, k + 4,
k + 5, k + 6, k + 7, k + 8);
const __m256i w1_32 = k1_32;
const __m256i w0_32 = _mm256_sub_epi32(mx_32, w1_32);
const __m256i p0w0 = _mm256_mullo_epi32(p0_32, w0_32);
const __m256i p1w1 = _mm256_mullo_epi32(p1_32, w1_32);
const __m256i sum = _mm256_add_epi32(p0w0, p1w1);
const __m256i shifted = _mm256_sra_epi32(sum, shift);
const __m128i lo = _mm256_castsi256_si128(shifted);
const __m128i hi = _mm256_extracti128_si256(shifted, 1);
const __m128i packed = _mm_packus_epi32(lo, hi);
_mm_storeu_si128((__m128i *)(dst + y * dst_stride + x + k), packed);
}
// Remainder loop
for (; k < mx - 1; ++k) {
const int k1 = k + 1;
dst[y * dst_stride + x + k] = (p0 * (mx - k1) + p1 * k1) >> upx_log2;
}
}
}
} else { // For width <= 32
for (int i = 0; i < pred_y >> downy_log2; i++) {
int y = i * my + (my - 1);
int p0 = 0;
int p1 = left_col[y];
for (int j = 0; j < pred_x >> downx_log2; j++) {
int x = j * mx;
p0 = p1;
p1 = dst[y * dst_stride + x + mx - 1];
for (int k = 0; k < mx - 1; k++) {
int k1 = k + 1;
dst[y * dst_stride + x + k] =
(p0 * (mx - k1) + (p1 * k1)) >> upx_log2;
}
}
}
}
  // Interpolate vertically (SIMD).
  // A function pointer selects the width-specialized kernel once, outside the
  // row loops.
resample_vert_fn fn = NULL;
switch (bw) {
case 64: fn = resample_vert_w64_avx2; break;
case 32: fn = resample_vert_w32_avx2; break;
case 16: fn = resample_vert_w16_avx2; break;
case 8: fn = resample_vert_w8_ssse3; break;
}
if (fn) {
for (int i = 0; i < pred_y >> downy_log2; i++) {
const int y = i * my;
const uint16_t *p0_row =
(y == 0) ? above_row : &dst[(y - 1) * dst_stride];
const uint16_t *p1_row = &dst[(y + my - 1) * dst_stride];
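      // Blend each of the my - 1 rows between p0_row (the row above this
      // cell, or above_row for the first cell) and p1_row (the cell's anchor
      // row) with weights (my - k - 1, k + 1).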
for (int k = 0; k < my - 1; k++) {
const int k1 = k + 1;
const int w0 = my - k1;
const int w1 = k1;
uint16_t *dst_row = &dst[(y + k) * dst_stride];
fn(dst_row, p0_row, p1_row, w0, w1, upy_log2);
}
}
} else {
    // C fallback for other widths. The reference C code walks column by
    // column; this loop walks row by row instead, which is cache-friendly and
    // easier for the compiler to vectorize.
for (int i = 0; i < pred_y >> downy_log2; i++) {
const int y = i * my;
const uint16_t *p0_row =
(y == 0) ? above_row : &dst[(y - 1) * dst_stride];
const uint16_t *p1_row = &dst[(y + my - 1) * dst_stride];
for (int k = 0; k < my - 1; k++) {
const int k1 = k + 1;
const int w0 = my - k1;
const int w1 = k1;
uint16_t *dst_row = &dst[(y + k) * dst_stride];
for (int x = 0; x < bw; x++) {
const int p0 = p0_row[x];
const int p1 = p1_row[x];
dst_row[x] = (uint16_t)((p0 * w0 + p1 * w1) >> upy_log2);
}
}
}
}
}