Added SIMD code for intra_matrix (DIP) prediction.
Add SIMD code for av1_dip_matrix_multiplication().
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 6859013..5ff0ed9 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -113,6 +113,8 @@
APPEND AOM_AV1_COMMON_SOURCES "${AOM_ROOT}/av1/common/intra_matrix.c"
"${AOM_ROOT}/av1/common/intra_matrix.h"
"${AOM_ROOT}/av1/common/intra_dip.cc" "${AOM_ROOT}/av1/common/intra_dip.h")
+ list(APPEND AOM_AV1_COMMON_INTRIN_AVX2
+ "${AOM_ROOT}/av1/common/x86/intra_matrix_avx2.c")
endif()
if(CONFIG_AV1_ENCODER)
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 6597d9b..1c9a2b4 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -323,6 +323,12 @@
"uint8_t* weights, uint16_t *dst, ptrdiff_t stride, uint16_t* second_pred, ptrdiff_t second_stride, int bw, int bh";
}
+# Data-driven intra prediction (DIP)
+if (aom_config("CONFIG_DIP") eq "yes") {
+ add_proto qw/void av1_dip_matrix_multiplication/, "const uint16_t *A, const uint16_t *B, uint16_t *C, int bd";
+ specialize qw/av1_dip_matrix_multiplication avx2/
+}
+
# build compound seg mask functions
add_proto qw/void av1_build_compound_diffwtd_mask_highbd/, "uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint16_t *src0, int src0_stride, const uint16_t *src1, int src1_stride, int h, int w, int bd";
specialize qw/av1_build_compound_diffwtd_mask_highbd ssse3 avx2/;
diff --git a/av1/common/intra_matrix.c b/av1/common/intra_matrix.c
index fbbfc2a..84ce187 100644
--- a/av1/common/intra_matrix.c
+++ b/av1/common/intra_matrix.c
@@ -15,14 +15,8 @@
#include "intra_dip.h"
#include "intra_matrix.h"
-#define ROWS 64
-#define COLS 11
-#define BITS 12
-#define OFFSET (1 << (12 - 1))
-#define SCALE 4
-
-static const uint16_t
- av1_intra_matrix_weights[INTRA_DIP_MODE_CNT][ROWS][COLS] = {
+const uint16_t
+ av1_intra_matrix_weights[INTRA_DIP_MODE_CNT][DIP_ROWS][DIP_COLS] = {
{
{ 776, 1714, 1077, 998, 1043, 1437, 1157, 1027, 1027, 1011, 1023 },
{ 845, 1519, 1598, 892, 1088, 988, 1219, 1079, 1039, 1000, 1022 },
@@ -426,17 +420,17 @@
// B - pointer to feature vector
// C - 8x8 output prediction
// bd - bit depth
-static void matrix_multiplication(const uint16_t *A, const uint16_t *B,
- uint16_t *C, int bd) {
+void av1_dip_matrix_multiplication_c(const uint16_t *A, const uint16_t *B,
+ uint16_t *C, int bd) {
int sum = 0;
- for (int j = 0; j < COLS; j++) sum += B[j];
+ for (int j = 0; j < DIP_FEATURES; j++) sum += B[j];
- for (int i = 0; i < ROWS; i++) {
+ for (int i = 0; i < DIP_ROWS; i++) {
int c = 0;
- for (int j = 0; j < COLS; j++) {
- c += SCALE * A[i * COLS + j] * B[j];
+ for (int j = 0; j < DIP_FEATURES; j++) {
+ c += DIP_SCALE * A[i * DIP_COLS + j] * B[j];
}
- c = ((c + OFFSET) >> BITS) - sum;
+ c = ((c + DIP_OFFSET) >> DIP_BITS) - sum;
C[i] = clip_pixel_highbd(c, bd);
}
}
@@ -447,5 +441,5 @@
assert(mode >= 0 && mode < INTRA_DIP_MODE_CNT);
const uint16_t *A = &av1_intra_matrix_weights[mode][0][0];
- matrix_multiplication(A, input, output, bd);
+ av1_dip_matrix_multiplication(A, input, output, bd);
}
diff --git a/av1/common/intra_matrix.h b/av1/common/intra_matrix.h
index ab83d92..c428186 100644
--- a/av1/common/intra_matrix.h
+++ b/av1/common/intra_matrix.h
@@ -9,5 +9,14 @@
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
+#define DIP_ROWS 64
+#define DIP_COLS 16
+#define DIP_BITS 12
+#define DIP_OFFSET (1 << (12 - 1))
+#define DIP_SCALE 4
+#define DIP_FEATURES 11
+
+extern const uint16_t av1_intra_matrix_weights[][DIP_ROWS][DIP_COLS];
+
void av1_intra_matrix_pred(const uint16_t *input, int mode, uint16_t *output,
int bd);
diff --git a/av1/common/x86/intra_matrix_avx2.c b/av1/common/x86/intra_matrix_avx2.c
new file mode 100644
index 0000000..bdcfdc3
--- /dev/null
+++ b/av1/common/x86/intra_matrix_avx2.c
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2025, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <immintrin.h> /* AVX2 */
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "av1/common/intra_matrix.h"
+
+// Multiply 11 element feature vector with matrix to generate 8x8 prediction.
+// A - pointer to matrix
+// B - pointer to feature vector
+// C - 8x8 output prediction
+// bd - bit depth
+void av1_dip_matrix_multiplication_avx2(const uint16_t *A, const uint16_t *B,
+ uint16_t *C, int bd) {
+ static const uint16_t mask[16] = { -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, -1, 0, 0, 0, 0, 0 };
+
+ __m256i in0 = _mm256_lddqu_si256((__m256i *)B);
+ __m256i in_mask = _mm256_lddqu_si256((__m256i *)mask);
+ in0 = _mm256_and_si256(in0, in_mask);
+ // in0 = { B0, B1, B2, B3, B4, B5, B6, B7 | B8, B9, B10, 0, 0, 0, 0, 0 }
+ __m256i negsum = _mm256_madd_epi16(in0, in_mask);
+ negsum = _mm256_hadd_epi32(negsum, negsum);
+ negsum = _mm256_hadd_epi32(negsum, negsum);
+ negsum = _mm256_slli_epi32(negsum, DIP_BITS - 2);
+ __m128i offset = _mm_set1_epi32(DIP_OFFSET >> 2);
+ __m128i maxval = _mm_set1_epi32((1 << bd) - 1);
+ __m128i zero = _mm_setzero_si128();
+
+ for (int i = 0; i < DIP_ROWS; i += 4) {
+ __m256i row0 = _mm256_lddqu_si256((__m256i *)&A[i * DIP_COLS]);
+ __m256i row1 = _mm256_lddqu_si256((__m256i *)&A[(i + 1) * DIP_COLS]);
+ __m256i row2 = _mm256_lddqu_si256((__m256i *)&A[(i + 2) * DIP_COLS]);
+ __m256i row3 = _mm256_lddqu_si256((__m256i *)&A[(i + 3) * DIP_COLS]);
+ __m256i m0 = _mm256_madd_epi16(row0, in0);
+ __m256i m1 = _mm256_madd_epi16(row1, in0);
+ __m256i m2 = _mm256_madd_epi16(row2, in0);
+ __m256i m3 = _mm256_madd_epi16(row3, in0);
+ __m256i m01 = _mm256_hadd_epi32(m0, m1);
+ __m256i m23 = _mm256_hadd_epi32(m2, m3);
+ __m256i m0123 = _mm256_hadd_epi32(m01, m23);
+ __m256i sum0 = _mm256_add_epi32(m0123, negsum);
+ __m128i sum0_lo = _mm256_castsi256_si128(sum0);
+ __m128i sum0_hi = _mm256_extracti128_si256(sum0, 1);
+ __m128i sum1 = _mm_add_epi32(sum0_lo, sum0_hi);
+ sum1 = _mm_add_epi32(sum1, offset);
+ sum1 = _mm_srai_epi32(sum1, DIP_BITS - 2);
+ sum1 = _mm_min_epi32(sum1, maxval);
+ sum1 = _mm_max_epi32(sum1, zero);
+ __m128i out0 = _mm_packus_epi32(sum1, sum1);
+ _mm_storeu_si64(&C[i], out0);
+ }
+}
diff --git a/test/intra_matrix_test.cc b/test/intra_matrix_test.cc
new file mode 100644
index 0000000..bdb3193
--- /dev/null
+++ b/test/intra_matrix_test.cc
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2025, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+#include "test/register_state_check.h"
+#include "test/function_equivalence_test.h"
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+#include "config/av1_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "av1/common/enums.h"
+#include "av1/common/intra_dip.h"
+#include "av1/common/intra_matrix.h"
+
+using libaom_test::FunctionEquivalenceTest;
+
+namespace {
+
+template <typename F, typename T>
+class IntraMatrixTest : public FunctionEquivalenceTest<F> {
+ protected:
+ static const int kIterations = 1000000;
+ static const int kBufSize = 8 * 8;
+
+ virtual ~IntraMatrixTest() {}
+
+ virtual void Execute(T *dip_tst) = 0;
+
+ void Common() {
+ dip_ref_ = &dip_ref_data_[0];
+ dip_tst_ = &dip_tst_data_[0];
+
+ Execute(dip_tst_);
+
+ for (int r = 0; r < kBufSize; ++r) {
+ ASSERT_EQ(dip_ref_[r], dip_tst_[r]);
+ }
+ }
+
+ T dip_arr_[DIP_ROWS * DIP_COLS];
+ T dip_feat_[DIP_COLS];
+
+ T dip_ref_data_[kBufSize];
+ T dip_tst_data_[kBufSize];
+
+ T *dip_ref_;
+ T *dip_tst_;
+};
+
+//////////////////////////////////////////////////////////////////////////////
+// High bit-depth version
+//////////////////////////////////////////////////////////////////////////////
+
+typedef void (*IMHB)(const uint16_t *A, const uint16_t *B, uint16_t *C, int bd);
+typedef libaom_test::FuncParam<IMHB> IntraMatrixTestFuncsHBD;
+
+class IntraMatrixTestHB : public IntraMatrixTest<IMHB, uint16_t> {
+ protected:
+ void Execute(uint16_t *dip_tst) {
+ params_.ref_func(dip_arr_, dip_feat_, dip_ref_, bit_depth_);
+ ASM_REGISTER_STATE_CHECK(
+ params_.tst_func(dip_arr_, dip_feat_, dip_tst, bit_depth_));
+ }
+ int bit_depth_;
+};
+
+TEST_P(IntraMatrixTestHB, RandomValues) {
+ for (int iter = 0; iter < kIterations && !HasFatalFailure(); ++iter) {
+ switch (rng_(3)) {
+ case 0: bit_depth_ = 8; break;
+ case 1: bit_depth_ = 10; break;
+ default: bit_depth_ = 12; break;
+ }
+ const int hi = 1 << bit_depth_;
+
+ for (int i = 0; i < 16; ++i) {
+ dip_feat_[i] = rng_(hi);
+ }
+ int mode = iter % INTRA_DIP_MODE_CNT;
+ for (int r = 0; r < DIP_ROWS; ++r) {
+ for (int c = 0; c < DIP_FEATURES; ++c) {
+ dip_arr_[r * DIP_COLS + c] = av1_intra_matrix_weights[mode][r][c];
+ }
+ }
+
+ Common();
+ }
+}
+
+#if HAVE_AVX2
+INSTANTIATE_TEST_SUITE_P(AVX2, IntraMatrixTestHB,
+ ::testing::Values(IntraMatrixTestFuncsHBD(
+ av1_dip_matrix_multiplication_c,
+ av1_dip_matrix_multiplication_avx2)));
+#endif // HAVE_AVX2
+
+// Speed tests
+
+TEST_P(IntraMatrixTestHB, DISABLED_Speed) {
+ const int test_count = 10000000;
+ bit_depth_ = 12;
+ const int hi = 1 << bit_depth_;
+ for (int i = 0; i < 16; ++i) {
+ dip_feat_[i] = rng_(hi);
+ }
+ for (int r = 0; r < 64; ++r) {
+ for (int c = 0; c < 11; ++c) {
+ dip_arr_[r * 16 + c] = av1_intra_matrix_weights[0][r][c];
+ }
+ }
+ dip_tst_ = &dip_tst_data_[0];
+ for (int iter = 0; iter < test_count; ++iter) {
+ ASM_REGISTER_STATE_CHECK(
+ params_.tst_func(dip_arr_, dip_feat_, dip_tst_, bit_depth_));
+ }
+}
+
+} // namespace
diff --git a/test/test.cmake b/test/test.cmake
index be7ff2b..33fb16d 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -104,6 +104,11 @@
"${AOM_ROOT}/test/simd_cmp_impl.h"
"${AOM_ROOT}/test/simd_impl.h")
+ if(CONFIG_DIP)
+ list(APPEND AOM_UNIT_TEST_COMMON_SOURCES
+ "${AOM_ROOT}/test/intra_matrix_test.cc")
+ endif()
+
if(CONFIG_ACCOUNTING)
list(APPEND AOM_UNIT_TEST_COMMON_SOURCES
"${AOM_ROOT}/test/accounting_test.cc")