[CFL] Faster AVX2 Average Subtract

Based on the observation that for small blocks AVX2 does not outperform
SSE2, we call the SSE2 code for block widths 4 and 8.

For widths 16 and 32, the AVX2 version is optimized by:
  * Summing over two rows per iteration of the summing loop;
  * Operating over the full 256-bit registers in the summing loop;
  * Using more accumulators to reduce accumulator dependencies in the loop;
  * Leveraging chained hadd calls in the fill function (sketch below).
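
As a point of reference for the last item: _mm256_hadd_epi32 performs its
pair-wise adds independently within each 128-bit lane, which is why the
fill function needs a cross-lane permute between hadd calls. A scalar
model of the intrinsic (illustration only, not part of the patch):

  // Scalar model of _mm256_hadd_epi32(a, b): pair-wise horizontal adds,
  // computed independently within each 128-bit lane.
  static void hadd_epi32_model(const int32_t a[8], const int32_t b[8],
                               int32_t out[8]) {
    // low 128-bit lane: pairs from a, then pairs from b
    out[0] = a[0] + a[1];
    out[1] = a[2] + a[3];
    out[2] = b[0] + b[1];
    out[3] = b[2] + b[3];
    // high 128-bit lane: same pattern on the upper halves
    out[4] = a[4] + a[5];
    out[5] = a[6] + a[7];
    out[6] = b[4] + b[5];
    out[7] = b[6] + b[7];
  }

With a == b, as in fill_sum_epi32, each call halves the number of distinct
partial sums per lane, so three chained hadd calls plus one cross-lane
permute reduce the eight lanes to a single broadcast sum.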

AVX2/CFLSubAvgTest
4x4: C time = 384 us, SIMD time = 153 us (~2.5x)
8x8: C time = 805 us, SIMD time = 229 us (~3.5x)
16x16: C time = 2757 us, SIMD time = 775 us (~3.6x)
32x32: C time = 10035 us, SIMD time = 2524 us (~4x)
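
For context, a minimal scalar sketch of the subtract-average operation that
these timings measure (illustration only, not part of the patch; the helper
name is hypothetical). CFL_BUF_LINE is the prediction buffer stride from
av1/common/cfl.h, and round_offset/num_pel_log2 match the CFL_SUB_AVG_X
instantiations below:

  // Assumes int16_t from <stdint.h> and CFL_BUF_LINE from av1/common/cfl.h.
  // avg = (sum of all width*height samples + round_offset) >> num_pel_log2
  static void subtract_average_sketch(int16_t *pred_buf_q3, int width,
                                      int height, int round_offset,
                                      int num_pel_log2) {
    int sum = 0;
    for (int j = 0; j < height; j++) {
      for (int i = 0; i < width; i++) {
        sum += pred_buf_q3[j * CFL_BUF_LINE + i];
      }
    }
    // e.g. 16x16 uses round_offset == 128 and num_pel_log2 == 8
    const int avg = (sum + round_offset) >> num_pel_log2;
    for (int j = 0; j < height; j++) {
      for (int i = 0; i < width; i++) {
        pred_buf_q3[j * CFL_BUF_LINE + i] -= avg;
      }
    }
  }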

Change-Id: I683994026c1f1626828e90949cd0bd911b46ed5e
diff --git a/av1/common/cfl.h b/av1/common/cfl.h
index 3c1494b..a76a27c 100644
--- a/av1/common/cfl.h
+++ b/av1/common/cfl.h
@@ -127,12 +127,17 @@
   assert(0);
 }
 
-#define CFL_SUB_AVG_X(arch, width, height, round_offset, num_pel_log2)        \
-  static void subtract_average_##width##x##height##_x(int16_t *pred_buf_q3) { \
-    subtract_average_##arch(pred_buf_q3, width, height, round_offset,         \
-                            num_pel_log2);                                    \
+// Declare a size-specific wrapper for the size-generic function. The compiler
+// will inline the size-generic function here; the advantage is that the size
+// is a compile-time constant, which allows for loop unrolling and other
+// constant-propagation optimizations.
+#define CFL_SUB_AVG_X(arch, width, height, round_offset, num_pel_log2)      \
+  void subtract_average_##width##x##height##_##arch(int16_t *pred_buf_q3) { \
+    subtract_average_##arch(pred_buf_q3, width, height, round_offset,       \
+                            num_pel_log2);                                  \
   }
 
+// Declare size-specific wrappers for all valid CfL sizes.
 #define CFL_SUB_AVG_FN(arch)                                                \
   CFL_SUB_AVG_X(arch, 4, 4, 8, 4)                                           \
   CFL_SUB_AVG_X(arch, 4, 8, 16, 5)                                          \
@@ -150,25 +155,25 @@
   CFL_SUB_AVG_X(arch, 32, 32, 512, 10)                                      \
   cfl_subtract_average_fn get_subtract_average_fn_##arch(TX_SIZE tx_size) { \
     static const cfl_subtract_average_fn sub_avg[TX_SIZES_ALL] = {          \
-      subtract_average_4x4_x,    /* 4x4 */                                  \
-      subtract_average_8x8_x,    /* 8x8 */                                  \
-      subtract_average_16x16_x,  /* 16x16 */                                \
-      subtract_average_32x32_x,  /* 32x32 */                                \
-      cfl_subtract_average_null, /* 64x64 (invalid CFL size) */             \
-      subtract_average_4x8_x,    /* 4x8 */                                  \
-      subtract_average_8x4_x,    /* 8x4 */                                  \
-      subtract_average_8x16_x,   /* 8x16 */                                 \
-      subtract_average_16x8_x,   /* 16x8 */                                 \
-      subtract_average_16x32_x,  /* 16x32 */                                \
-      subtract_average_32x16_x,  /* 32x16 */                                \
-      cfl_subtract_average_null, /* 32x64 (invalid CFL size) */             \
-      cfl_subtract_average_null, /* 64x32 (invalid CFL size) */             \
-      subtract_average_4x16_x,   /* 4x16 (invalid CFL size) */              \
-      subtract_average_16x4_x,   /* 16x4 (invalid CFL size) */              \
-      subtract_average_8x32_x,   /* 8x32 (invalid CFL size) */              \
-      subtract_average_32x8_x,   /* 32x8 (invalid CFL size) */              \
-      cfl_subtract_average_null, /* 16x64 (invalid CFL size) */             \
-      cfl_subtract_average_null, /* 64x16 (invalid CFL size) */             \
+      subtract_average_4x4_##arch,   /* 4x4 */                              \
+      subtract_average_8x8_##arch,   /* 8x8 */                              \
+      subtract_average_16x16_##arch, /* 16x16 */                            \
+      subtract_average_32x32_##arch, /* 32x32 */                            \
+      cfl_subtract_average_null,     /* 64x64 (invalid CFL size) */         \
+      subtract_average_4x8_##arch,   /* 4x8 */                              \
+      subtract_average_8x4_##arch,   /* 8x4 */                              \
+      subtract_average_8x16_##arch,  /* 8x16 */                             \
+      subtract_average_16x8_##arch,  /* 16x8 */                             \
+      subtract_average_16x32_##arch, /* 16x32 */                            \
+      subtract_average_32x16_##arch, /* 32x16 */                            \
+      cfl_subtract_average_null,     /* 32x64 (invalid CFL size) */         \
+      cfl_subtract_average_null,     /* 64x32 (invalid CFL size) */         \
+      subtract_average_4x16_##arch,  /* 4x16 */                             \
+      subtract_average_16x4_##arch,  /* 16x4 */                             \
+      subtract_average_8x32_##arch,  /* 8x32 */                             \
+      subtract_average_32x8_##arch,  /* 32x8 */                             \
+      cfl_subtract_average_null,     /* 16x64 (invalid CFL size) */         \
+      cfl_subtract_average_null,     /* 64x16 (invalid CFL size) */         \
     };                                                                      \
     /* Modulo TX_SIZES_ALL to ensure that an attacker won't be able to */   \
     /* index the function pointer array out of bounds. */                   \
diff --git a/av1/common/x86/cfl_avx2.c b/av1/common/x86/cfl_avx2.c
index d7ecad1..775d3ff 100644
--- a/av1/common/x86/cfl_avx2.c
+++ b/av1/common/x86/cfl_avx2.c
@@ -169,64 +169,70 @@
   return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
 }
 
-static INLINE __m256i fill_sum_epi32(__m256i l0) {
-  l0 = _mm256_add_epi32(l0, _mm256_shuffle_epi32(l0, _MM_SHUFFLE(1, 0, 3, 2)));
-  return _mm256_add_epi32(l0,
-                          _mm256_shuffle_epi32(l0, _MM_SHUFFLE(2, 3, 0, 1)));
+// Returns a vector where all the (32-bit) elements are the sum of all the
+// lanes in a.
+static INLINE __m256i fill_sum_epi32(__m256i a) {
+  // Given that a == [A, B, C, D, E, F, G, H]
+  a = _mm256_hadd_epi32(a, a);
+  // Given that A' == A + B, C' == C + D, E' == E + F, G' == G + H
+  // a == [A', C', A', C', E', G', E', G']
+  a = _mm256_permute4x64_epi64(a, _MM_SHUFFLE(3, 1, 2, 0));
+  // a == [A', C', E', G', A', C', E', G']
+  a = _mm256_hadd_epi32(a, a);
+  // Given that A'' == A' + C' and E'' == E' + G'
+  // a == [A'', E'', A'', E'', A'', E'', A'', E'']
+  return _mm256_hadd_epi32(a, a);
+  // Given that A''' == A'' + E''
+  // a == [A''', A''', A''', A''', A''', A''', A''', A''']
+}
+
+static INLINE __m256i _mm256_addl_epi16(__m256i a) {
+  return _mm256_add_epi32(_mm256_unpacklo_epi16(a, _mm256_setzero_si256()),
+                          _mm256_unpackhi_epi16(a, _mm256_setzero_si256()));
 }
 
 static INLINE void subtract_average_avx2(int16_t *pred_buf, int width,
                                          int height, int round_offset,
                                          int num_pel_log2) {
-  const __m256i zeros = _mm256_setzero_si256();
+  // Use SSE2 version for smaller widths
+  assert(width == 16 || width == 32);
   __m256i *row = (__m256i *)pred_buf;
   const __m256i *const end = row + height * CFL_BUF_LINE_I256;
-  const int step = CFL_BUF_LINE_I256 * (1 + (width == 8) + 3 * (width == 4));
-  union {
-    __m256i v;
-    int32_t i32[8];
-  } sum;
-  sum.v = zeros;
+  // To maximize usage of the AVX2 registers, we sum two rows per loop
+  // iteration
+  const int step = 2 * CFL_BUF_LINE_I256;
+  __m256i sum = _mm256_setzero_si256();
+
+  // For width 32, we use a second sum accumulator to reduce accumulator
+  // dependencies in the loop.
+  __m256i sum2;
+  if (width == 32) sum2 = _mm256_setzero_si256();
   do {
-    if (width == 4) {
-      __m256i l0 = _mm256_loadu_si256(row);
-      __m256i l1 = _mm256_loadu_si256(row + CFL_BUF_LINE_I256);
-      __m256i l2 = _mm256_loadu_si256(row + 2 * CFL_BUF_LINE_I256);
-      __m256i l3 = _mm256_loadu_si256(row + 3 * CFL_BUF_LINE_I256);
-
-      __m256i t0 = _mm256_add_epi16(l0, l1);
-      __m256i t1 = _mm256_add_epi16(l2, l3);
-
-      sum.v = _mm256_add_epi32(
-          sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(t0, zeros),
-                                  _mm256_unpacklo_epi16(t1, zeros)));
-    } else {
-      __m256i l0;
-      if (width == 8) {
-        l0 = _mm256_add_epi16(_mm256_loadu_si256(row),
-                              _mm256_loadu_si256(row + CFL_BUF_LINE_I256));
-      } else {
-        l0 = _mm256_loadu_si256(row);
-        l0 = _mm256_add_epi16(l0, _mm256_permute2x128_si256(l0, l0, 1));
-      }
-      sum.v = _mm256_add_epi32(
-          sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(l0, zeros),
-                                  _mm256_unpackhi_epi16(l0, zeros)));
-      if (width == 32) {
-        l0 = _mm256_loadu_si256(row + 1);
-        l0 = _mm256_add_epi16(l0, _mm256_permute2x128_si256(l0, l0, 1));
-        sum.v = _mm256_add_epi32(
-            sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(l0, zeros),
-                                    _mm256_unpackhi_epi16(l0, zeros)));
-      }
+    // Add top row to the bottom row
+    __m256i l0 = _mm256_add_epi16(_mm256_loadu_si256(row),
+                                  _mm256_loadu_si256(row + CFL_BUF_LINE_I256));
+    sum = _mm256_add_epi32(sum, _mm256_addl_epi16(l0));
+    if (width == 32) { /* Don't worry, this "if" gets optimized out. */
+      // Add the second part of the top row to the second part of the bottom row
+      __m256i l1 =
+          _mm256_add_epi16(_mm256_loadu_si256(row + 1),
+                           _mm256_loadu_si256(row + 1 + CFL_BUF_LINE_I256));
+      // Store the sum of the second half of the row in the second
+      // accumulator
+      sum2 = _mm256_add_epi32(sum2, _mm256_addl_epi16(l1));
     }
   } while ((row += step) < end);
+  // Combine both sum accumulators
+  if (width == 32) sum = _mm256_add_epi32(sum, sum2);
 
-  sum.v = fill_sum_epi32(sum.v);
+  // Fill all the lanes with the total sum
+  __m256i fill = fill_sum_epi32(sum);
 
-  __m256i avg_epi16 =
-      _mm256_set1_epi16((sum.i32[0] + round_offset) >> num_pel_log2);
+  __m256i avg_epi16 = _mm256_srli_epi32(
+      _mm256_add_epi32(fill, _mm256_set1_epi32(round_offset)), num_pel_log2);
+  avg_epi16 = _mm256_packs_epi32(avg_epi16, avg_epi16);
 
+  // Subtract the average and store the result
   row = (__m256i *)pred_buf;
   do {
     _mm256_storeu_si256(row,
@@ -238,4 +244,40 @@
   } while ((row += CFL_BUF_LINE_I256) < end);
 }
 
-CFL_SUB_AVG_FN(avx2)
+// Declare wrappers for AVX2 sizes
+CFL_SUB_AVG_X(avx2, 16, 4, 32, 6)
+CFL_SUB_AVG_X(avx2, 16, 8, 64, 7)
+CFL_SUB_AVG_X(avx2, 16, 16, 128, 8)
+CFL_SUB_AVG_X(avx2, 16, 32, 256, 9)
+CFL_SUB_AVG_X(avx2, 32, 8, 128, 8)
+CFL_SUB_AVG_X(avx2, 32, 16, 256, 9)
+CFL_SUB_AVG_X(avx2, 32, 32, 512, 10)
+
+// Based on the observation that for small blocks AVX2 does not outperform
+// SSE2, we call the SSE2 code for block widths 4 and 8.
+cfl_subtract_average_fn get_subtract_average_fn_avx2(TX_SIZE tx_size) {
+  static const cfl_subtract_average_fn sub_avg[TX_SIZES_ALL] = {
+    subtract_average_4x4_sse2,   /* 4x4 */
+    subtract_average_8x8_sse2,   /* 8x8 */
+    subtract_average_16x16_avx2, /* 16x16 */
+    subtract_average_32x32_avx2, /* 32x32 */
+    cfl_subtract_average_null,   /* 64x64 (invalid CFL size) */
+    subtract_average_4x8_sse2,   /* 4x8 */
+    subtract_average_8x4_sse2,   /* 8x4 */
+    subtract_average_8x16_sse2,  /* 8x16 */
+    subtract_average_16x8_avx2,  /* 16x8 */
+    subtract_average_16x32_avx2, /* 16x32 */
+    subtract_average_32x16_avx2, /* 32x16 */
+    cfl_subtract_average_null,   /* 32x64 (invalid CFL size) */
+    cfl_subtract_average_null,   /* 64x32 (invalid CFL size) */
+    subtract_average_4x16_sse2,  /* 4x16 */
+    subtract_average_16x4_avx2,  /* 16x4 */
+    subtract_average_8x32_sse2,  /* 8x32 */
+    subtract_average_32x8_avx2,  /* 32x8 */
+    cfl_subtract_average_null,   /* 16x64 (invalid CFL size) */
+    cfl_subtract_average_null,   /* 64x16 (invalid CFL size) */
+  };
+  // Modulo TX_SIZES_ALL to ensure that an attacker won't be able to
+  // index the function pointer array out of bounds.
+  return sub_avg[tx_size % TX_SIZES_ALL];
+}
diff --git a/av1/common/x86/cfl_simd.h b/av1/common/x86/cfl_simd.h
index 17aaf15..3e75cb4 100644
--- a/av1/common/x86/cfl_simd.h
+++ b/av1/common/x86/cfl_simd.h
@@ -60,3 +60,14 @@
                                    int16_t *output_q3);
 void subsample_lbd_420_16x32_ssse3(const uint8_t *input, int input_stride,
                                    int16_t *output_q3);
+
+// The SSE2 versions are optimal for width == 4, so we reuse them in AVX2.
+void subtract_average_4x4_sse2(int16_t *pred_buf_q3);
+void subtract_average_4x8_sse2(int16_t *pred_buf_q3);
+void subtract_average_4x16_sse2(int16_t *pred_buf_q3);
+
+// The SSE2 versions are optimal for width == 8, so we reuse them in AVX2.
+void subtract_average_8x4_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x8_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x16_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x32_sse2(int16_t *pred_buf_q3);