Add specific path for horizontal dirs in cdef_find_dir_neon
compute_directions_neon was used for both the vertical and the horizontal
directions, with a rotation of the input block between the two calls. Remove
the rotation and add a dedicated function that computes the horizontal
directions directly.
This gives an uplift of around 9% for cdef_find_dir with both Clang and GCC.
Change-Id: Id68f4a6a43ecf0cf830e73e2366a16bc2a381bda
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 4261592..30a108e 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -196,6 +196,23 @@
#endif
}
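+// Sum each of the four int16x8_t vectors and return the four totals as the
+// lanes of the result: lane i holds the horizontal sum of sum[i].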
+static INLINE int32x4_t horizontal_add_4d_s16x8(const int16x8_t sum[4]) {
+#if AOM_ARCH_AARCH64
+ const int16x8_t a0 = vpaddq_s16(sum[0], sum[1]);
+ const int16x8_t a1 = vpaddq_s16(sum[2], sum[3]);
+ const int16x8_t b0 = vpaddq_s16(a0, a1);
+ return vpaddlq_s16(b0);
+#else
+ const int16x4_t a0 = vadd_s16(vget_low_s16(sum[0]), vget_high_s16(sum[0]));
+ const int16x4_t a1 = vadd_s16(vget_low_s16(sum[1]), vget_high_s16(sum[1]));
+ const int16x4_t a2 = vadd_s16(vget_low_s16(sum[2]), vget_high_s16(sum[2]));
+ const int16x4_t a3 = vadd_s16(vget_low_s16(sum[3]), vget_high_s16(sum[3]));
+ const int16x4_t b0 = vpadd_s16(a0, a1);
+ const int16x4_t b1 = vpadd_s16(a2, a3);
+ return vpaddlq_s16(vcombine_s16(b0, b1));
+#endif
+}
+
static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) {
#if AOM_ARCH_AARCH64
return vaddv_u32(a);
diff --git a/av1/common/arm/cdef_block_neon.c b/av1/common/arm/cdef_block_neon.c
index 45d58aa..68a292b 100644
--- a/av1/common/arm/cdef_block_neon.c
+++ b/av1/common/arm/cdef_block_neon.c
@@ -115,9 +115,8 @@
return cost;
}
-// This function is called a first time to compute the cost along directions 4,
-// 5, 6, 7, and then a second time on a rotated block to compute directions
-// 0, 1, 2, 3. (0 means 45-degree up-right, 2 is horizontal, and so on.)
+// This function computes the cost along directions 4, 5, 6, 7. (4 is diagonal
+// down-right, 6 is vertical).
//
// For each direction the lines are shifted so that we can perform a
// basic sum on each vector element. For example, direction 5 is "south by
@@ -147,8 +146,8 @@
// two of them to compute each half of the new configuration, and pad the empty
// spaces with zeros. Similar shifting is done for other directions, except
// direction 6 which is straightforward as it's the vertical direction.
-static INLINE uint32x4_t compute_directions_neon(int16x8_t lines[8],
- uint32_t cost[4]) {
+static INLINE uint32x4_t compute_vert_directions_neon(int16x8_t lines[8],
+ uint32_t cost[4]) {
const int16x8_t zero = vdupq_n_s16(0);
// Partial sums for lines 0 and 1.
@@ -227,46 +226,157 @@
return costs[0];
}
-static INLINE int64x2_t ziplo_s64(int32x4_t a, int32x4_t b) {
- return vcombine_s64(vget_low_s64(vreinterpretq_s64_s32(a)),
- vget_low_s64(vreinterpretq_s64_s32(b)));
+static INLINE uint32x4_t fold_mul_and_sum_pairwise_neon(int16x8_t partiala,
+ int16x8_t partialb,
+ int16x8_t partialc,
+ uint32x4_t const0) {
+  // Reverse 16-bit lanes 0-5 of partialc (lanes 6 and 7 stay in place) so that
+  // sums over the same number of pixels end up in the same lane as partiala.
+  // pattern = { 10 11 8 9 6 7 4 5 2 3 0 1 12 13 14 15 }.
+ uint8x16_t pattern = vreinterpretq_u8_u64(
+ vcombine_u64(vcreate_u64((uint64_t)0x05040706 << 32 | 0x09080b0a),
+ vcreate_u64((uint64_t)0x0f0e0d0c << 32 | 0x01000302)));
+
+#if AOM_ARCH_AARCH64
+ partialc =
+ vreinterpretq_s16_s8(vqtbl1q_s8(vreinterpretq_s8_s16(partialc), pattern));
+#else
+ int8x8x2_t p = { { vget_low_s8(vreinterpretq_s8_s16(partialc)),
+ vget_high_s8(vreinterpretq_s8_s16(partialc)) } };
+ int8x8_t shuffle_hi = vtbl2_s8(p, vget_high_s8(vreinterpretq_s8_u8(pattern)));
+ int8x8_t shuffle_lo = vtbl2_s8(p, vget_low_s8(vreinterpretq_s8_u8(pattern)));
+ partialc = vreinterpretq_s16_s8(vcombine_s8(shuffle_lo, shuffle_hi));
+#endif
+
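+  // Pairwise-add adjacent 16-bit lanes, widening the partial sums to 32 bits,
+  // then square them. partialb is weighted by 105 (each of its bins sums 8
+  // pixels); partiala and the reversed partialc use the per-lane weights in
+  // const0.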
+ int32x4_t partiala_s32 = vpaddlq_s16(partiala);
+ int32x4_t partialb_s32 = vpaddlq_s16(partialb);
+ int32x4_t partialc_s32 = vpaddlq_s16(partialc);
+
+ partiala_s32 = vmulq_s32(partiala_s32, partiala_s32);
+ partialb_s32 = vmulq_s32(partialb_s32, partialb_s32);
+ partialc_s32 = vmulq_s32(partialc_s32, partialc_s32);
+
+ partiala_s32 = vaddq_s32(partiala_s32, partialc_s32);
+
+ uint32x4_t cost = vmulq_n_u32(vreinterpretq_u32_s32(partialb_s32), 105);
+ cost = vmlaq_u32(cost, vreinterpretq_u32_s32(partiala_s32), const0);
+ return cost;
}
-static INLINE int64x2_t ziphi_s64(int32x4_t a, int32x4_t b) {
- return vcombine_s64(vget_high_s64(vreinterpretq_s64_s32(a)),
- vget_high_s64(vreinterpretq_s64_s32(b)));
-}
+// This function computes the cost along directions 0, 1, 2, 3. (0 means
+// 45-degree up-right, 2 is horizontal).
+//
+// For directions 1 and 3 ("east northeast" and "east southeast") the shifted
+// lines need three vectors instead of two. For direction 1, for example, we
+// need to compute the sums along each line i below:
+// 0 0 1 1 2 2 3 3
+// 1 1 2 2 3 3 4 4
+// 2 2 3 3 4 4 5 5
+// 3 3 4 4 5 5 6 6
+// 4 4 5 5 6 6 7 7
+// 5 5 6 6 7 7 8 8
+// 6 6 7 7 8 8 9 9
+// 7 7 8 8 9 9 10 10
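+// (Each number is the index of the partial sum that pixel contributes to.)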
+//
+// Which means we need the following configuration:
+// 0 0 1 1 2 2 3 3
+//     1 1 2 2 3 3 4 4
+//         2 2 3 3 4 4 5 5
+//             3 3 4 4 5 5 6 6
+//                 4 4 5 5 6 6 7 7
+//                     5 5 6 6 7 7 8 8
+//                         6 6 7 7 8 8 9 9
+//                             7 7 8 8 9 9 10 10
+//
+// Three vectors are needed to compute this, as well as some extra pairwise
+// additions.
+static uint32x4_t compute_horiz_directions_neon(int16x8_t lines[8],
+ uint32_t cost[4]) {
+ const int16x8_t zero = vdupq_n_s16(0);
-// Transpose and reverse the order of the lines -- equivalent to a 90-degree
-// counter-clockwise rotation of the pixels.
-static INLINE void array_reverse_transpose_8x8_neon(int16x8_t *in,
- int16x8_t *res) {
- const int32x4_t tr0_0 = vreinterpretq_s32_s16(vzipq_s16(in[0], in[1]).val[0]);
- const int32x4_t tr0_1 = vreinterpretq_s32_s16(vzipq_s16(in[2], in[3]).val[0]);
- const int32x4_t tr0_2 = vreinterpretq_s32_s16(vzipq_s16(in[0], in[1]).val[1]);
- const int32x4_t tr0_3 = vreinterpretq_s32_s16(vzipq_s16(in[2], in[3]).val[1]);
- const int32x4_t tr0_4 = vreinterpretq_s32_s16(vzipq_s16(in[4], in[5]).val[0]);
- const int32x4_t tr0_5 = vreinterpretq_s32_s16(vzipq_s16(in[6], in[7]).val[0]);
- const int32x4_t tr0_6 = vreinterpretq_s32_s16(vzipq_s16(in[4], in[5]).val[1]);
- const int32x4_t tr0_7 = vreinterpretq_s32_s16(vzipq_s16(in[6], in[7]).val[1]);
+  // Compute the partial sums for the diagonal directions (0, 1, 3); the
+  // horizontal direction 2 is handled separately below.
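+  // The partialNa/b/c vectors hold the columns of the shifted configuration
+  // for direction N; directions 1 and 3 spill into a third vector because
+  // their lines are shifted by two columns per row.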
+ // Partial sums for lines 0 and 1.
+ int16x8_t partial0a = lines[0];
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[1], 7));
+ int16x8_t partial0b = vextq_s16(lines[1], zero, 7);
+ int16x8_t partial1a = vaddq_s16(lines[0], vextq_s16(zero, lines[1], 6));
+ int16x8_t partial1b = vextq_s16(lines[1], zero, 6);
+ int16x8_t partial3a = vextq_s16(lines[0], zero, 2);
+ partial3a = vaddq_s16(partial3a, vextq_s16(lines[1], zero, 4));
+ int16x8_t partial3b = vextq_s16(zero, lines[0], 2);
+ partial3b = vaddq_s16(partial3b, vextq_s16(zero, lines[1], 4));
- const int32x4_t tr1_0 = vzipq_s32(tr0_0, tr0_1).val[0];
- const int32x4_t tr1_1 = vzipq_s32(tr0_4, tr0_5).val[0];
- const int32x4_t tr1_2 = vzipq_s32(tr0_0, tr0_1).val[1];
- const int32x4_t tr1_3 = vzipq_s32(tr0_4, tr0_5).val[1];
- const int32x4_t tr1_4 = vzipq_s32(tr0_2, tr0_3).val[0];
- const int32x4_t tr1_5 = vzipq_s32(tr0_6, tr0_7).val[0];
- const int32x4_t tr1_6 = vzipq_s32(tr0_2, tr0_3).val[1];
- const int32x4_t tr1_7 = vzipq_s32(tr0_6, tr0_7).val[1];
+ // Partial sums for lines 2 and 3.
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[2], 6));
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[3], 5));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[2], zero, 6));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[3], zero, 5));
+ partial1a = vaddq_s16(partial1a, vextq_s16(zero, lines[2], 4));
+ partial1a = vaddq_s16(partial1a, vextq_s16(zero, lines[3], 2));
+ partial1b = vaddq_s16(partial1b, vextq_s16(lines[2], zero, 4));
+ partial1b = vaddq_s16(partial1b, vextq_s16(lines[3], zero, 2));
+ partial3a = vaddq_s16(partial3a, vextq_s16(lines[2], zero, 6));
+ partial3b = vaddq_s16(partial3b, vextq_s16(zero, lines[2], 6));
+ partial3b = vaddq_s16(partial3b, lines[3]);
- res[7] = vreinterpretq_s16_s64(ziplo_s64(tr1_0, tr1_1));
- res[6] = vreinterpretq_s16_s64(ziphi_s64(tr1_0, tr1_1));
- res[5] = vreinterpretq_s16_s64(ziplo_s64(tr1_2, tr1_3));
- res[4] = vreinterpretq_s16_s64(ziphi_s64(tr1_2, tr1_3));
- res[3] = vreinterpretq_s16_s64(ziplo_s64(tr1_4, tr1_5));
- res[2] = vreinterpretq_s16_s64(ziphi_s64(tr1_4, tr1_5));
- res[1] = vreinterpretq_s16_s64(ziplo_s64(tr1_6, tr1_7));
- res[0] = vreinterpretq_s16_s64(ziphi_s64(tr1_6, tr1_7));
+ // Partial sums for lines 4 and 5.
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[4], 4));
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[5], 3));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[4], zero, 4));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[5], zero, 3));
+ partial1b = vaddq_s16(partial1b, lines[4]);
+ partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[5], 6));
+ int16x8_t partial1c = vextq_s16(lines[5], zero, 6);
+ partial3b = vaddq_s16(partial3b, vextq_s16(lines[4], zero, 2));
+ partial3b = vaddq_s16(partial3b, vextq_s16(lines[5], zero, 4));
+ int16x8_t partial3c = vextq_s16(zero, lines[4], 2);
+ partial3c = vaddq_s16(partial3c, vextq_s16(zero, lines[5], 4));
+
+ // Partial sums for lines 6 and 7.
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[6], 2));
+ partial0a = vaddq_s16(partial0a, vextq_s16(zero, lines[7], 1));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[6], zero, 2));
+ partial0b = vaddq_s16(partial0b, vextq_s16(lines[7], zero, 1));
+ partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[6], 4));
+ partial1b = vaddq_s16(partial1b, vextq_s16(zero, lines[7], 2));
+ partial1c = vaddq_s16(partial1c, vextq_s16(lines[6], zero, 4));
+ partial1c = vaddq_s16(partial1c, vextq_s16(lines[7], zero, 2));
+ partial3b = vaddq_s16(partial3b, vextq_s16(lines[6], zero, 6));
+ partial3c = vaddq_s16(partial3c, vextq_s16(zero, lines[6], 6));
+ partial3c = vaddq_s16(partial3c, lines[7]);
+
+ // Special case for direction 2 as it's just a sum along each line.
+ int16x8_t lines03[4] = { lines[0], lines[1], lines[2], lines[3] };
+ int16x8_t lines47[4] = { lines[4], lines[5], lines[6], lines[7] };
+ int32x4_t partial2a = horizontal_add_4d_s16x8(lines03);
+ int32x4_t partial2b = horizontal_add_4d_s16x8(lines47);
+
+ uint32x4_t partial2a_u32 =
+ vreinterpretq_u32_s32(vmulq_s32(partial2a, partial2a));
+ uint32x4_t partial2b_u32 =
+ vreinterpretq_u32_s32(vmulq_s32(partial2b, partial2b));
+
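+  // The weights are 840 / N, where N is the number of pixels feeding each
+  // partial sum (840 is the LCM of 1..8). const0/const1 cover direction 0
+  // (1 to 8 pixels per sum), const2 covers directions 1 and 3 (2, 4, 6 or 8
+  // pixels per sum), and direction 2 always sums 8 pixels, hence the factor
+  // 105 applied below.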
+ uint32x4_t const0 = vreinterpretq_u32_u64(
+ vcombine_u64(vcreate_u64((uint64_t)420 << 32 | 840),
+ vcreate_u64((uint64_t)210 << 32 | 280)));
+ uint32x4_t const1 = vreinterpretq_u32_u64(
+ vcombine_u64(vcreate_u64((uint64_t)140 << 32 | 168),
+ vcreate_u64((uint64_t)105 << 32 | 120)));
+ uint32x4_t const2 = vreinterpretq_u32_u64(
+ vcombine_u64(vcreate_u64((uint64_t)210 << 32 | 420),
+ vcreate_u64((uint64_t)105 << 32 | 140)));
+
+ uint32x4_t costs[4];
+ costs[0] = fold_mul_and_sum_neon(partial0a, partial0b, const0, const1);
+ costs[1] =
+ fold_mul_and_sum_pairwise_neon(partial1a, partial1b, partial1c, const2);
+ costs[2] = vaddq_u32(partial2a_u32, partial2b_u32);
+ costs[2] = vmulq_n_u32(costs[2], 105);
+ costs[3] =
+ fold_mul_and_sum_pairwise_neon(partial3c, partial3b, partial3a, const2);
+
+ costs[0] = horizontal_add_4d_u32x4(costs);
+ vst1q_u32(cost, costs[0]);
+ return costs[0];
}
int cdef_find_dir_neon(const uint16_t *img, int stride, int32_t *var,
@@ -282,12 +392,10 @@
}
// Compute "mostly vertical" directions.
- uint32x4_t cost47 = compute_directions_neon(lines, cost + 4);
-
- array_reverse_transpose_8x8_neon(lines, lines);
+ uint32x4_t cost47 = compute_vert_directions_neon(lines, cost + 4);
// Compute "mostly horizontal" directions.
- uint32x4_t cost03 = compute_directions_neon(lines, cost);
+ uint32x4_t cost03 = compute_horiz_directions_neon(lines, cost);
// Find max cost as well as its index to get best_dir.
// The max cost needs to be propagated in the whole vector to find its