av1_fwd_txfm2d_neon.c: Use switch for small square problem sizes

For the smallest problem sizes we have a significant overhead from
needing to load and store the transform intermediate vectors. Avoiding
the kernel lookup and calling the kernels directly significantly
improves performance in these cases.

Benchmarking on a Neoverse N2 machine with Clang 16 and GCC 12, the
speed tests show a geomean reduction of ~22% in reported times for the
4x4 case and ~4.6% for the 8x8 case.

Change-Id: I5907fab09e40b1cea57c446fd7e604ae911ceae8
diff --git a/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c b/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c
index a17a41a..d70f5a5 100644
--- a/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c
+++ b/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c
@@ -1598,44 +1598,6 @@
                                             int32_t *output, int stride,
                                             int cos_bit);
 
-static const col_transform_1d_lbd_4_neon col_txfm4x4_arr[TX_TYPES] = {
-  fdct4x4_col_neon,       // DCT_DCT
-  fadst4x4_col_neon,      // ADST_DCT
-  fdct4x4_col_neon,       // DCT_ADST
-  fadst4x4_col_neon,      // ADST_ADST
-  fadst4x4_col_neon,      // FLIPADST_DCT
-  fdct4x4_col_neon,       // DCT_FLIPADST
-  fadst4x4_col_neon,      // FLIPADST_FLIPADST
-  fadst4x4_col_neon,      // ADST_FLIPADST
-  fadst4x4_col_neon,      // FLIPADST_ADST
-  fidentity4x4_col_neon,  // IDTX
-  fdct4x4_col_neon,       // V_DCT
-  fidentity4x4_col_neon,  // H_DCT
-  fadst4x4_col_neon,      // V_ADST
-  fidentity4x4_col_neon,  // H_ADST
-  fadst4x4_col_neon,      // V_FLIPADST
-  fidentity4x4_col_neon   // H_FLIPADST
-};
-
-static const row_transform_1d_lbd_4_neon row_txfm4x4_arr[TX_TYPES] = {
-  fdct4x4_row_neon,       // DCT_DCT
-  fdct4x4_row_neon,       // ADST_DCT
-  fadst4x4_row_neon,      // DCT_ADST
-  fadst4x4_row_neon,      // ADST_ADST
-  fdct4x4_row_neon,       // FLIPADST_DCT
-  fadst4x4_row_neon,      // DCT_FLIPADST
-  fadst4x4_row_neon,      // FLIPADST_FLIPADST
-  fadst4x4_row_neon,      // ADST_FLIPADST
-  fadst4x4_row_neon,      // FLIPADST_ADST
-  fidentity4x4_row_neon,  // IDTX
-  fidentity4x4_row_neon,  // V_DCT
-  fdct4x4_row_neon,       // H_DCT
-  fidentity4x4_row_neon,  // V_ADST
-  fadst4x4_row_neon,      // H_ADST
-  fidentity4x4_row_neon,  // V_FLIPADST
-  fadst4x4_row_neon       // H_FLIPADST
-};
-
 static const col_transform_1d_lbd_4_neon col_txfm4x8_arr[TX_TYPES] = {
   fdct4x8_col_neon,       // DCT_DCT
   fadst4x8_col_neon,      // ADST_DCT
@@ -1943,21 +1905,96 @@
 static void lowbd_fwd_txfm2d_4x4_neon(const int16_t *input, int32_t *output,
                                       int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  int16x4_t buf0[4], buf1[4];
-  const col_transform_1d_lbd_4_neon col_txfm = col_txfm4x4_arr[tx_type];
-  const row_transform_1d_lbd_4_neon row_txfm = row_txfm4x4_arr[tx_type];
   int ud_flip, lr_flip;
-
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   ud_adjust_input_and_stride(ud_flip, &input, &stride, 4);
-  col_txfm(input, buf0, stride, 13);
-  transpose_arrays_s16_4x4(buf0, buf1);
 
-  if (lr_flip) {
-    flip_buf_4_neon(buf1, buf0, 4);
-    row_txfm(buf0, output, 4, 13);
-  } else {
-    row_txfm(buf1, output, 4, 13);
+  int16x4_t buf0[4], buf1[4];
+  switch (tx_type) {
+    case DCT_DCT:
+      fdct4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fdct4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case ADST_DCT:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fdct4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case DCT_ADST:
+      fdct4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fadst4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case ADST_ADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fadst4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case FLIPADST_DCT:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fdct4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case DCT_FLIPADST:
+      fdct4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      flip_buf_4_neon(buf1, buf0, 4);
+      fadst4x4_row_neon(buf0, output, 4, 13);
+      break;
+    case FLIPADST_FLIPADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      flip_buf_4_neon(buf1, buf0, 4);
+      fadst4x4_row_neon(buf0, output, 4, 13);
+      break;
+    case ADST_FLIPADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      flip_buf_4_neon(buf1, buf0, 4);
+      fadst4x4_row_neon(buf0, output, 4, 13);
+      break;
+    case FLIPADST_ADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fadst4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case IDTX:
+      fidentity4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fidentity4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case V_DCT:
+      fdct4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fidentity4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case H_DCT:
+      fidentity4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fdct4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case V_ADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fidentity4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case H_ADST:
+      fidentity4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fadst4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case V_FLIPADST:
+      fadst4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      fidentity4x4_row_neon(buf1, output, 4, 13);
+      break;
+    case H_FLIPADST:
+      fidentity4x4_col_neon(input, buf0, stride, 13);
+      transpose_arrays_s16_4x4(buf0, buf1);
+      flip_buf_4_neon(buf1, buf0, 4);
+      fadst4x4_row_neon(buf0, output, 4, 13);
+      break;
   }
 }
 
@@ -2040,22 +2077,113 @@
 static void lowbd_fwd_txfm2d_8x8_neon(const int16_t *input, int32_t *output,
                                       int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  int16x8_t buf0[8], buf1[8];
-  const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x8_arr[tx_type];
-  const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x8_arr[tx_type];
   int ud_flip, lr_flip;
-
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   ud_adjust_input_and_stride(ud_flip, &input, &stride, 8);
-  col_txfm(input, buf0, stride, 13);
-  shift_right_1_round_s16_x8(buf0, buf0, 8);
-  transpose_arrays_s16_8x8(buf0, buf1);
 
-  if (lr_flip) {
-    flip_buf_8_neon(buf1, buf0, 8);
-    row_txfm(buf0, output, 8, 13);
-  } else {
-    row_txfm(buf1, output, 8, 13);
+  int16x8_t buf0[8], buf1[8];
+
+  switch (tx_type) {
+    case DCT_DCT:
+      fdct8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fdct8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case ADST_DCT:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fdct8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case DCT_ADST:
+      fdct8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fadst8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case ADST_ADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fadst8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case FLIPADST_DCT:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fdct8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case DCT_FLIPADST:
+      fdct8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      flip_buf_8_neon(buf1, buf0, 8);
+      fadst8x8_row_neon(buf0, output, 8, 13);
+      break;
+    case FLIPADST_FLIPADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      flip_buf_8_neon(buf1, buf0, 8);
+      fadst8x8_row_neon(buf0, output, 8, 13);
+      break;
+    case ADST_FLIPADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      flip_buf_8_neon(buf1, buf0, 8);
+      fadst8x8_row_neon(buf0, output, 8, 13);
+      break;
+    case FLIPADST_ADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fadst8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case IDTX:
+      fidentity8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fidentity8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case V_DCT:
+      fdct8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fidentity8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case H_DCT:
+      fidentity8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fdct8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case V_ADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fidentity8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case H_ADST:
+      fidentity8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fadst8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case V_FLIPADST:
+      fadst8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      fidentity8x8_row_neon(buf1, output, 8, 13);
+      break;
+    case H_FLIPADST:
+      fidentity8x8_col_neon(input, buf0, stride, 13);
+      shift_right_1_round_s16_x8(buf0, buf0, 8);
+      transpose_arrays_s16_8x8(buf0, buf1);
+      flip_buf_8_neon(buf1, buf0, 8);
+      fadst8x8_row_neon(buf0, output, 8, 13);
+      break;
   }
 }