Add NEON version of av1_nn_predict function

SpeedUp

size            gain

10x64x16        (2.61)
12x12x1         (2.52)
12x24x1         (2.96)
12x32x1         (3.05)
18x24x4         (1.99)
18x32x4         (2.02)
4x16x1          (2.36)
8x16x1          (2.58)
8x16x4          (2.97)
8x24x1          (2.75)
8x32x1          (2.86)
8x64x1          (2.98)
9x32x3          (1.77)
4x8x4           (2.65)

via NEON/NnPredictTest.DISABLED_Speed

Change-Id: Id93a998538475fd808ac20c80218b05b911fec62
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 36e3816..4bef55e 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -378,6 +378,7 @@
 
 list(APPEND AOM_AV1_ENCODER_INTRIN_NEON
             "${AOM_ROOT}/av1/encoder/arm/neon/quantize_neon.c"
+            "${AOM_ROOT}/av1/encoder/arm/neon/ml_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/picksrt_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/rdopt_neon.c"
             "${AOM_ROOT}/av1/encoder/arm/neon/av1_error_neon.c"
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index b47e1c0..2a94ef3 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -406,7 +406,7 @@
   specialize qw/av1_get_horver_correlation_full sse4_1 avx2 neon/;
 
   add_proto qw/void av1_nn_predict/, " const float *input_nodes, const NN_CONFIG *const nn_config, int reduce_prec, float *const output";
-  specialize qw/av1_nn_predict sse3/;
+  specialize qw/av1_nn_predict sse3 neon/;
 }
 # end encoder functions
 
diff --git a/av1/encoder/arm/neon/ml_neon.c b/av1/encoder/arm/neon/ml_neon.c
new file mode 100644
index 0000000..fcff3a9
--- /dev/null
+++ b/av1/encoder/arm/neon/ml_neon.c
@@ -0,0 +1,338 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <stdbool.h>
+#include <assert.h>
+#include <arm_neon.h>
+
+#include "config/av1_rtcd.h"
+#include "av1/encoder/ml.h"
+
+static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
+                         const float32x4_t *zero) {
+  *out_h = vmaxq_f32(*out_h, *zero);
+  *out_l = vmaxq_f32(*out_l, *zero);
+}
+
+static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
+  *x = vmaxq_f32(*x, *zero);
+}
+
+#define CLAMP_0(x) (x = x > 0 ? x : 0)
+
+static void nn_propagate_8to1(int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes, bool output_layer) {
+  const float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t vadd = zero;
+  float total = *layer_bias;
+
+  for (int in = 0; in < num_inputs; in += 8) {
+    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
+    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
+
+    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
+    const float32x4_t weights_l = vld1q_f32(&weights[in]);
+
+    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
+    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
+  }
+#if defined(__aarch64__)
+  total += vaddvq_f32(vadd);
+#else
+  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
+  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
+  total += vget_lane_f32(vadd_lo, 0);
+#endif
+
+  if (!output_layer) CLAMP_0(total);
+  *output_nodes = total;
+}
+
+static void nn_propagate_xto1(int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes) {
+  float32x4_t vadd = vdupq_n_f32(0);
+
+  float total = *layer_bias;
+  int j = num_inputs;
+  int in = 0;
+  while (j > 7) {
+    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
+    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
+
+    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
+    const float32x4_t weights_l = vld1q_f32(&weights[in]);
+
+    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
+    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
+    in += 8;
+    j -= 8;
+  }
+
+#if defined(__aarch64__)
+  total += vaddvq_f32(vadd);
+
+#else
+  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
+  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
+  total += vget_lane_f32(vadd_lo, 0);
+#endif
+  for (; in < num_inputs; in++) total += weights[in] * inputs[in];
+
+  *output_nodes = CLAMP_0(total);
+}
+
+static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
+                               const float *const weights,
+                               const float *layer_bias,
+                               float *const output_nodes) {
+  float total = *layer_bias;
+#if defined(__aarch64__)
+  const float32x4_t v_inputs = vld1q_f32(inputs);
+  const float32x4_t v_weights = vld1q_f32(weights);
+  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
+  total += vaddvq_f32(vadd);
+  int in = 4;
+#else
+  int in = 0;
+#endif
+  for (; in < num_inputs; in++) total += weights[in] * inputs[in];
+
+  *output_nodes = CLAMP_0(total);
+}
+
+static void nn_propagate_4to1(int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes, bool output_layer) {
+  const float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t vadd = zero;
+  float total = *layer_bias;
+
+  for (int in = 0; in < num_inputs; in += 4) {
+    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
+    const float32x4_t v_weights = vld1q_f32(&weights[in]);
+    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
+  }
+
+#if defined(__aarch64__)
+  total += vaddvq_f32(vadd);
+#else
+  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
+  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
+  total += vget_lane_f32(vadd_lo, 0);
+#endif
+
+  if (!output_layer) CLAMP_0(total);
+  *output_nodes = total;
+}
+
+static void nn_propagate_4to4(int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes, bool output_layer) {
+  float32x4_t outputs = vld1q_f32(layer_bias);
+  const float32x4_t zero = vdupq_n_f32(0);
+
+  float32x4_t mul0[2] = { zero, zero };
+  float32x4_t mul1[2] = { zero, zero };
+  for (int in = 0; in < num_inputs; in += 4) {
+    const float32x4_t v_input = vld1q_f32(&inputs[in]);
+
+    for (int i = 0; i < 2; i++) {
+      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
+      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
+      const float32x4_t weight1 =
+          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
+      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
+    }
+  }
+  for (int i = 0; i < 2; i++)
+#if defined(__aarch64__)
+    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
+  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
+#else
+    mul0[i] =
+        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
+                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
+  const float32x4_t hh =
+      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
+                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
+#endif
+
+  outputs = vaddq_f32(outputs, hh);
+  if (!output_layer) nn_activate4(&outputs, &zero);
+  vst1q_f32(output_nodes, outputs);
+}
+
+static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes, bool output_layer) {
+  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
+  float32x4_t out_l = vld1q_f32(layer_bias);
+  const float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t mul0[4] = { zero, zero, zero, zero };
+  float32x4_t mul1[4] = { zero, zero, zero, zero };
+
+  for (int in = 0; in < num_inputs; in += 4) {
+    const float32x4_t v_input = vld1q_f32(&inputs[in]);
+    for (int i = 0; i < 4; i++) {
+      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
+      const float32x4_t weight1 =
+          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
+      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
+      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
+    }
+  }
+  for (int i = 0; i < 4; i++)
+#if defined(__aarch64__)
+    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
+  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
+  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
+#else
+    mul0[i] =
+        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
+                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
+  const float32x4_t hh0 =
+      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
+                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
+  const float32x4_t hh1 =
+      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
+                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
+#endif
+
+  out_h = vaddq_f32(out_h, hh1);
+  out_l = vaddq_f32(out_l, hh0);
+
+  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
+  vst1q_f32(&output_nodes[4], out_h);
+  vst1q_f32(output_nodes, out_l);
+}
+
+static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
+                              const float *const weights,
+                              const float *layer_bias,
+                              float *const output_nodes, bool output_layer) {
+  float32x4_t outputs = vld1q_f32(layer_bias);
+  const float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t add[4] = { zero, zero, zero, zero };
+  for (int in = 0; in < num_inputs; in += 8) {
+    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
+    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
+
+    for (int i = 0; i < 4; i++) {
+      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
+      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
+      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
+      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
+    }
+  }
+#if defined(__aarch64__)
+  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
+  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
+  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
+#else
+  const float32x4_t hadd_h =
+      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
+                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
+  const float32x4_t hadd_l =
+      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
+                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
+  const float32x4_t haddhadd =
+      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
+                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
+#endif
+
+  outputs = vaddq_f32(outputs, haddhadd);
+  if (!output_layer) nn_activate4(&outputs, &zero);
+  vst1q_f32(output_nodes, outputs);
+}
+
+// Calculate prediction based on the given input features and neural net config.
+// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
+// layer.
+void av1_nn_predict_neon(const float *input_nodes,
+                         const NN_CONFIG *const nn_config, int reduce_prec,
+                         float *const output) {
+  float buf[2][NN_MAX_NODES_PER_LAYER];
+  int buf_index = 0;
+  int num_inputs = nn_config->num_inputs;
+  // Hidden layers, except the final iteration is the output layer.
+  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
+    const float *layer_weights = nn_config->weights[layer];
+    const float *layer_bias = nn_config->bias[layer];
+    bool output_layer = (layer == nn_config->num_hidden_layers);
+    float *const output_nodes = output_layer ? output : buf[buf_index];
+    const int num_outputs = output_layer ? nn_config->num_outputs
+                                         : nn_config->num_hidden_nodes[layer];
+
+    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
+      for (int out = 0; out < num_outputs; out += 8) {
+        nn_propagate_4to8(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out], output_layer);
+      }
+    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
+      for (int out = 0; out < num_outputs; out += 4) {
+        nn_propagate_8to4(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out], output_layer);
+      }
+    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
+      for (int out = 0; out < num_outputs; out += 4) {
+        nn_propagate_4to4(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out], output_layer);
+      }
+    } else if (num_inputs % 8 == 0) {
+      for (int out = 0; out < num_outputs; out++) {
+        nn_propagate_8to1(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out], output_layer);
+      }
+    } else if (num_inputs % 4 == 0) {
+      for (int out = 0; out < num_outputs; out++) {
+        nn_propagate_4to1(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out], output_layer);
+      }
+    } else if (num_inputs > 8) {
+      for (int out = 0; out < num_outputs; out++) {
+        nn_propagate_xto1(num_inputs, input_nodes,
+                          &layer_weights[out * num_inputs], &layer_bias[out],
+                          &output_nodes[out]);
+      }
+    } else if (num_inputs >= 4) {
+      for (int out = 0; out < num_outputs; out++) {
+        nn_propagate_xsto1(num_inputs, input_nodes,
+                           &layer_weights[out * num_inputs], &layer_bias[out],
+                           &output_nodes[out]);
+      }
+    } else {
+      for (int node = 0; node < num_outputs; ++node) {
+        float val = layer_bias[node];
+        for (int i = 0; i < num_inputs; ++i)
+          val += layer_weights[node * num_inputs + i] * input_nodes[i];
+        // ReLU as activation function.
+        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
+        output_nodes[node] = val;
+      }
+    }
+    input_nodes = output_nodes;
+    num_inputs = num_outputs;
+    buf_index = 1 - buf_index;
+  }
+  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
+}
diff --git a/test/av1_nn_predict_test.cc b/test/av1_nn_predict_test.cc
index c03cba8..cdf0417 100644
--- a/test/av1_nn_predict_test.cc
+++ b/test/av1_nn_predict_test.cc
@@ -214,4 +214,9 @@
                          ::testing::Values(av1_nn_predict_sse3));
 #endif
 
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(NEON, NnPredictTest,
+                         ::testing::Values(av1_nn_predict_neon));
+#endif
+
 }  // namespace