Merge "Merge tag 'v3.1.2' into HEAD"
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index af61fcd..7da0ae0 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -431,8 +431,11 @@
   specialize qw/av1_get_horver_correlation_full sse4_1 avx2 neon/;
 
   add_proto qw/void av1_nn_predict/, " const float *input_nodes, const NN_CONFIG *const nn_config, int reduce_prec, float *const output";
+
+  add_proto qw/void av1_nn_fast_softmax_16/, " const float *input_nodes, float *output";
   if (aom_config("CONFIG_EXCLUDE_SIMD_MISMATCH") ne "yes") {
     specialize qw/av1_nn_predict sse3 neon/;
+    specialize qw/av1_nn_fast_softmax_16 sse3/;
   }
 
   # CNN functions
diff --git a/av1/encoder/ml.c b/av1/encoder/ml.c
index 57228ec..7eb643a 100644
--- a/av1/encoder/ml.c
+++ b/av1/encoder/ml.c
@@ -143,14 +143,44 @@
   // Softmax function is invariant to adding the same constant
   // to all input values, so we subtract the maximum input to avoid
   // possible overflow.
-  float max_inp = input[0];
-  for (int i = 1; i < n; i++) max_inp = AOMMAX(max_inp, input[i]);
+  float max_input = input[0];
+  for (int i = 1; i < n; i++) max_input = AOMMAX(max_input, input[i]);
   float sum_out = 0.0f;
   for (int i = 0; i < n; i++) {
     // Clamp to range [-10.0, 0.0] to prevent FE_UNDERFLOW errors.
-    const float normalized_input = AOMMAX(input[i] - max_inp, -10.0f);
-    output[i] = (float)exp(normalized_input);
+    const float normalized_input = AOMMAX(input[i] - max_input, -10.0f);
+    output[i] = expf(normalized_input);
     sum_out += output[i];
   }
   for (int i = 0; i < n; i++) output[i] /= sum_out;
 }
+
+// Fast approximation of expf(y), assembled directly as IEEE-754 float bits.
+static AOM_INLINE float approx_exp(float y) {
+  // (1 << 23) / ln(2): maps y onto the scale of the float exponent field.
+  const float scale = (1 << 23) / 0.69314718056f;
+  // Offset for the exponent according to the IEEE floating point standard
+  // (127), shifted into the exponent field, less the magic number (60801)
+  // that controls the accuracy of the approximation.
+  const int32_t bias = (127 << 23) - 60801;
+  union {
+    float as_float;
+    int32_t as_int32;
+  } pun;
+  pun.as_int32 = (int32_t)(y * scale) + bias;
+  return pun.as_float;
+}
+
+void av1_nn_fast_softmax_16_c(const float *input, float *output) {
+  enum { kNumClasses = 16 };
+  // Softmax is shift-invariant, so normalize by the largest input first.
+  float peak = input[0];
+  for (int k = 1; k < kNumClasses; ++k) peak = AOMMAX(peak, input[k]);
+  float denom = 0.0f;
+  for (int k = 0; k < kNumClasses; ++k) {
+    // Shifted inputs are at most 0; floor them at -10 before approx_exp().
+    output[k] = approx_exp(AOMMAX(input[k] - peak, -10.0f));
+    denom += output[k];
+  }
+  for (int k = 0; k < kNumClasses; ++k) output[k] /= denom;
+}
diff --git a/av1/encoder/ml.h b/av1/encoder/ml.h
index 62d543d..566f927 100644
--- a/av1/encoder/ml.h
+++ b/av1/encoder/ml.h
@@ -71,6 +71,9 @@
 // output[i] = exp(input[i]) / sum_{k \in [0,n)}(exp(input[k]))
 void av1_nn_softmax(const float *input, float *output, int n);
 
+// A faster but less accurate version of av1_nn_softmax(input, output, 16)
+void av1_nn_fast_softmax_16_c(const float *input, float *output);
+
 // Applies a precision reduction to output of av1_nn_predict to prevent
 // mismatches between C and SIMD implementations.
 void av1_nn_output_prec_reduce(float *const output, int num_output);
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index df724b6..2970cf6 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -1753,6 +1753,18 @@
   for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
 }
 
+static AOM_INLINE bool check_bit_mask(uint16_t mask, int val) {
+  return ((mask >> val) & 1) != 0;  // True iff bit number val of mask is set.
+}
+
+static AOM_INLINE void set_bit_mask(uint16_t *mask, int val) {
+  *mask = (uint16_t)(*mask | (1 << val));  // Turn on bit number val.
+}
+
+static AOM_INLINE void unset_bit_mask(uint16_t *mask, int val) {
+  *mask = (uint16_t)(*mask & ~(1 << val));  // Turn off bit number val.
+}
+
 static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                         int blk_row, int blk_col, TxSetType tx_set_type,
                         TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
@@ -1791,9 +1803,11 @@
   const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
   get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                 vfeatures);
+
   av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                   &hfeatures[hfeatures_num - 1],
                                   &vfeatures[vfeatures_num - 1]);
+
 #if CONFIG_NN_V2
   av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
   av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
@@ -1810,7 +1824,11 @@
     cur_scores_2D[3] = vscores[i] * hscores[3];
   }
 
-  av1_nn_softmax(scores_2D_raw, scores_2D, 16);
+  assert(TX_TYPES == 16);
+  // av1_nn_fast_softmax_16 is hard-coded for exactly 16 classes, so if TX_TYPES
+  // ever changes we will need to change the optimization or use av1_nn_softmax
+  // instead.
+  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D);
 
   const float score_thresh =
       get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);
@@ -1824,24 +1842,30 @@
   // Calculate sum of allowed tx type score and Populate allow bit mask based
   // on score_thresh and allowed_tx_mask
   for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
-    int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
+    int allow_tx_type =
+        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
     if (scores_2D[tx_idx] > max_score && allow_tx_type) {
       max_score = scores_2D[tx_idx];
       max_score_i = tx_idx;
     }
     if (scores_2D[tx_idx] >= score_thresh && allow_tx_type) {
       // Set allow mask based on score_thresh
-      allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
+      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);
 
       // Accumulate score of allowed tx type
       sum_score += scores_2D[tx_idx];
     }
   }
-  if (!((allow_bitmask >> max_score_i) & 0x01)) {
-    // Set allow mask based on tx type with max score
-    allow_bitmask |= (1 << tx_type_table_2D[max_score_i]);
-    sum_score += scores_2D[max_score_i];
+  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
+    // If even the tx_type with max score is pruned, this means that no other
+    // tx_type is feasible. When this happens, we force enable max_score_i and
+    // end the search.
+    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
+    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
+    *allowed_tx_mask = allow_bitmask;
+    return;
   }
+
   // Sort tx type probability of all types
   sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
 
@@ -1859,7 +1883,7 @@
       if (score_ratio > 30.0 && tx_count >= 2) break;
 
       // Calculate cumulative probability of allowed tx types
-      if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
+      if (check_bit_mask(allow_bitmask, tx_type_table_2D[tx_idx])) {
         // Calculate cumulative probability
         temp_score += scores_2D[tx_idx];
 
@@ -1870,8 +1894,9 @@
     }
     // Set remaining tx types as pruned
     for (; tx_idx < TX_TYPES; tx_idx++)
-      allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
+      unset_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);
   }
+
   memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
   *allowed_tx_mask = allow_bitmask;
 }
diff --git a/av1/encoder/x86/ml_sse3.c b/av1/encoder/x86/ml_sse3.c
index 89b1e6a..ab69088 100644
--- a/av1/encoder/x86/ml_sse3.c
+++ b/av1/encoder/x86/ml_sse3.c
@@ -242,3 +242,95 @@
   }
   if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
 }
+
+// Four-lane fast exp(). Based on N. N. Schraudolph. A Fast, Compact
+// Approximation of the Exponential Function. Neural Computation,
+// 11(4):853–862, 1999.
+static AOM_INLINE __m128 approx_exp(__m128 y) {
+  // (1 << 23) / ln(2): maps y onto the scale of the float exponent field.
+  const __m128 scale = _mm_set1_ps((1 << 23) / 0.69314718056f);
+  // Offset for the exponent according to the IEEE floating point standard
+  // (127), shifted into the exponent field, less the magic number (60801)
+  // that controls the accuracy of the approximation.
+  const __m128i bias = _mm_set1_epi32(127 * (1 << 23) - 60801);
+
+  // Convert the scaled input to 32-bit integers, add the bias, and
+  // reinterpret the resulting bit patterns as floats.
+  const __m128i scaled = _mm_cvtps_epi32(_mm_mul_ps(y, scale));
+  const __m128i bits = _mm_add_epi32(scaled, bias);
+  return _mm_castsi128_ps(bits);
+}
+
+// Horizontal maximum of the four lanes of reg. Every lane of the returned
+// vector holds that maximum.
+static AOM_INLINE __m128 reduce_max(__m128 reg) {
+  // 0x4e (01 00 11 10) swaps the 64-bit halves, so each lane is paired with
+  // the lane two positions away.
+  const __m128 half_max = _mm_max_ps(reg, _mm_shuffle_ps(reg, reg, 0x4e));
+
+  // 0xb1 (10 11 00 01) swaps neighbouring lanes, folding in the other pair.
+  const __m128 swapped = _mm_shuffle_ps(half_max, half_max, 0xb1);
+  return _mm_max_ps(half_max, swapped);
+}
+
+// Horizontal sum of the four lanes of reg. Every lane of the returned vector
+// holds that sum.
+static AOM_INLINE __m128 reduce_sum(__m128 reg) {
+  // 0x4e (01 00 11 10) swaps the 64-bit halves, so each lane is added to the
+  // lane two positions away.
+  const __m128 half_sum = _mm_add_ps(reg, _mm_shuffle_ps(reg, reg, 0x4e));
+
+  // 0xb1 (10 11 00 01) swaps neighbouring lanes, folding in the other pair.
+  const __m128 swapped = _mm_shuffle_ps(half_sum, half_sum, 0xb1);
+  return _mm_add_ps(half_sum, swapped);
+}
+
+// SSE3 version of the fast 16-class softmax: output[i] is approx_exp() of the
+// shifted, clipped input[i] divided by the sum of those values over all 16
+// entries.
+void av1_nn_fast_softmax_16_sse3(const float *input, float *output) {
+  // Lower clip bound for the shifted inputs (avoids underflow).
+  const __m128 lower_bound = _mm_set1_ps(-10.0f);
+
+  // Fetch the 16 inputs as four unaligned vectors of four floats.
+  const __m128 src_0 = _mm_loadu_ps(input + 0);
+  const __m128 src_1 = _mm_loadu_ps(input + 4);
+  const __m128 src_2 = _mm_loadu_ps(input + 8);
+  const __m128 src_3 = _mm_loadu_ps(input + 12);
+
+  // Largest of the 16 inputs, replicated into every lane.
+  const __m128 max_01 = _mm_max_ps(src_0, src_1);
+  const __m128 max_23 = _mm_max_ps(src_2, src_3);
+  const __m128 peak = reduce_max(_mm_max_ps(max_01, max_23));
+
+  // Shift by the maximum so every value is <= 0, then clip from below.
+  const __m128 arg_0 = _mm_max_ps(_mm_sub_ps(src_0, peak), lower_bound);
+  const __m128 arg_1 = _mm_max_ps(_mm_sub_ps(src_1, peak), lower_bound);
+  const __m128 arg_2 = _mm_max_ps(_mm_sub_ps(src_2, peak), lower_bound);
+  const __m128 arg_3 = _mm_max_ps(_mm_sub_ps(src_3, peak), lower_bound);
+
+  // Exponentiate each vector.
+  const __m128 exp_0 = approx_exp(arg_0);
+  const __m128 exp_1 = approx_exp(arg_1);
+  const __m128 exp_2 = approx_exp(arg_2);
+  const __m128 exp_3 = approx_exp(arg_3);
+
+  // Denominator: accumulate the vectors in index order, then add across
+  // lanes so every lane carries the total.
+  __m128 denom = _mm_add_ps(exp_0, exp_1);
+  denom = _mm_add_ps(denom, exp_2);
+  denom = _mm_add_ps(denom, exp_3);
+  denom = reduce_sum(denom);
+
+  // Divide to get the probabilities.
+  const __m128 prob_0 = _mm_div_ps(exp_0, denom);
+  const __m128 prob_1 = _mm_div_ps(exp_1, denom);
+  const __m128 prob_2 = _mm_div_ps(exp_2, denom);
+  const __m128 prob_3 = _mm_div_ps(exp_3, denom);
+
+  // Write the 16 results back with unaligned stores.
+  _mm_storeu_ps(output + 0, prob_0);
+  _mm_storeu_ps(output + 4, prob_1);
+  _mm_storeu_ps(output + 8, prob_2);
+  _mm_storeu_ps(output + 12, prob_3);
+}
diff --git a/test/av1_softmax_test.cc b/test/av1_softmax_test.cc
new file mode 100644
index 0000000..c2ab07b
--- /dev/null
+++ b/test/av1_softmax_test.cc
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <cmath>
+#include <iostream>
+#include <memory>
+#include <tuple>
+
+#include "aom/aom_integer.h"
+#include "aom_ports/aom_timer.h"
+#include "av1/encoder/ml.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+#include "config/av1_rtcd.h"
+#include "test/acm_random.h"
+#include "test/register_state_check.h"
+#include "test/util.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+namespace {
+using FastSoftmaxFn = void (*)(const float *const input, float *output);
+using FastSoftmaxTestParams = std::tuple<const FastSoftmaxFn, int>;
+
+// Error thresholds for functional equivalence
+constexpr float kRelEpsilon = 5e-2f;
+constexpr float kAbsEpsilon = 5e-3f;
+
+class FastSoftmaxTest : public ::testing::TestWithParam<FastSoftmaxTestParams> {
+ public:
+  FastSoftmaxTest()  // Param 0: function under test; param 1: class count.
+      : target_fn_{ GET_PARAM(0) }, num_classes_(GET_PARAM(1)),
+        ref_buf_(new float[num_classes_]()),
+        dst_buf_(new float[num_classes_]()), input_(new float[num_classes_]()) {
+  }
+  void RunSoftmaxTest();  // Compares target_fn_ against av1_nn_softmax().
+  void RunSoftmaxSpeedTest(const int run_times);  // Times repeated calls.
+  void FillInputBuf();  // Fills input_ with values derived from rng_.
+
+ private:
+  const FastSoftmaxFn target_fn_;  // Softmax implementation under test.
+  const int num_classes_;          // Element count of each buffer below.
+  std::unique_ptr<float[]> ref_buf_, dst_buf_, input_;  // Zero-filled at start.
+  libaom_test::ACMRandom rng_;  // Random source for FillInputBuf().
+};
+
+void FastSoftmaxTest::FillInputBuf() {
+  const float half_range = static_cast<float>(1 << 30);  // 2^30
+  for (int i = 0; i < num_classes_; ++i)  // Re-centre on 0, scale by 2^-30.
+    input_[i] = (static_cast<float>(rng_.Rand31()) - half_range) / half_range;
+}
+
+// Compares target_fn_ with av1_nn_softmax() on the contents of input_.
+void FastSoftmaxTest::RunSoftmaxTest() {
+  av1_nn_softmax(input_.get(), ref_buf_.get(), num_classes_);
+  target_fn_(input_.get(), dst_buf_.get());
+
+  for (int idx = 0; idx < num_classes_; idx++) {
+    if (ref_buf_[idx] < kAbsEpsilon) {
+      // A relative error is meaningless next to a near-zero reference.
+      ASSERT_LE(dst_buf_[idx], kAbsEpsilon)
+          << "Reference output was near-zero, test output was not";
+    } else {
+      // Use the magnitude so that undershoot is caught as well as overshoot.
+      const float abs_error = fabsf(dst_buf_[idx] - ref_buf_[idx]);
+      ASSERT_LE(abs_error / ref_buf_[idx], kRelEpsilon)
+          << "Excessive relative error between reference and test output";
+      ASSERT_LE(abs_error, kAbsEpsilon)
+          << "Excessive absolute error between reference and test output";
+    }
+  }
+}
+
+// Prints the time taken by run_times back-to-back calls of target_fn_.
+void FastSoftmaxTest::RunSoftmaxSpeedTest(const int run_times) {
+  float *const dst = dst_buf_.get();
+  aom_usec_timer stopwatch;
+  aom_usec_timer_start(&stopwatch);
+  for (int iter = run_times; iter > 0; --iter) target_fn_(input_.get(), dst);
+  aom_usec_timer_mark(&stopwatch);
+  const int64_t elapsed_us = aom_usec_timer_elapsed(&stopwatch);
+  std::cout << "Test with " << num_classes_ << " classes took " << elapsed_us
+            << " us." << std::endl;
+}
+
+TEST_P(FastSoftmaxTest, RandomValues) {
+  FillInputBuf();    // Populate input_ from the random generator.
+  RunSoftmaxTest();  // Check target_fn_ against av1_nn_softmax().
+}
+
+TEST_P(FastSoftmaxTest, DISABLED_Speed) {
+  constexpr int kNumTimes = 1000000;  // Calls made inside the timed loop.
+  RunSoftmaxSpeedTest(kNumTimes);  // input_ keeps its constructor contents.
+}
+
+void AnchorSoftmax16Fn(const float *input, float *output) {
+  av1_nn_softmax(input, output, 16);  // Reference softmax fixed at 16 classes.
+}
+
+const FastSoftmaxTestParams kArrayParams_c[] = {
+  { AnchorSoftmax16Fn, 16 }, { av1_nn_fast_softmax_16_c, 16 }
+};
+INSTANTIATE_TEST_SUITE_P(C, FastSoftmaxTest,
+                         ::testing::ValuesIn(kArrayParams_c));
+
+#if HAVE_SSE3 && !CONFIG_EXCLUDE_SIMD_MISMATCH
+INSTANTIATE_TEST_SUITE_P(
+    SSE3, FastSoftmaxTest,
+    ::testing::Values(FastSoftmaxTestParams(av1_nn_fast_softmax_16_sse3, 16)));
+#endif
+}  // namespace
diff --git a/test/best_encode.sh b/test/best_encode.sh
index fe31a01..d29fdae 100755
--- a/test/best_encode.sh
+++ b/test/best_encode.sh
@@ -29,7 +29,7 @@
     -p 2 \
     --pass=2 \
     --fpf=$f.fpf \
-    --best \
+    --good \
     --cpu-used=0 \
     --target-bitrate=$b \
     --auto-alt-ref=1 \
@@ -48,8 +48,7 @@
     --maxsection-pct=800 \
     --psnr \
     --arnr-maxframes=7 \
-    --arnr-strength=3 \
-    --arnr-type=3
+    --arnr-strength=3
 else
   # No first-pass file found, do 2-pass encode
   aomenc \
@@ -58,7 +57,7 @@
     -p 2 \
     --pass=1 \
     --fpf=$f.fpf \
-    --best \
+    --good \
     --cpu-used=0 \
     --target-bitrate=$b \
     --auto-alt-ref=1 \
@@ -79,7 +78,7 @@
     -p 2 \
     --pass=2 \
     --fpf=$f.fpf \
-    --best \
+    --good \
     --cpu-used=0 \
     --target-bitrate=$b \
     --auto-alt-ref=1 \
@@ -98,6 +97,5 @@
     --maxsection-pct=800 \
     --psnr \
     --arnr-maxframes=7 \
-    --arnr-strength=3 \
-    --arnr-type=3
+    --arnr-strength=3
 fi
diff --git a/test/test.cmake b/test/test.cmake
index 94d1350..80af114 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -248,6 +248,7 @@
               "${AOM_ROOT}/test/av1_inv_txfm2d_test.cc"
               "${AOM_ROOT}/test/av1_nn_predict_test.cc"
               "${AOM_ROOT}/test/av1_round_shift_array_test.cc"
+              "${AOM_ROOT}/test/av1_softmax_test.cc"
               "${AOM_ROOT}/test/av1_txfm_test.cc"
               "${AOM_ROOT}/test/av1_txfm_test.h"
               "${AOM_ROOT}/test/av1_wedge_utils_test.cc"