Add an ML based partition search breakout feature

Use a neural net model to predict if the encoder should breakout from
partition search after trying PARTITION_NONE. This patch only
implements models for small resolutions(less than 720P).

This feature is enabled for speed 0 and above. Coding loss:
speed 0:
lowres 0.088%   midres 0.073%   objective-1-fast(360P) -0.07%
speed 1:
lowres 0.051%   midres 0.080%   objective-1-fast(360P) 0.01%

Tested encoding speed over 20 lowres sequences, speed gains:
Speed 0:
             QP=35        QP=45       QP=55
average      11.39%      14.65%       20.56%
maximum      29.95%      53.47%       45.93%

Speed 1:
             QP=35        QP=45       QP=55
average       7.52%       7.38%       9.86%
maximum      18.20%      13.44%      25.85%

STATS_CHANGED

Change-Id: Iafff5da49db6ee7d304f69b0cd278160a78cb01f
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 6e7761a..fd87ac1 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3070,6 +3070,69 @@
 #undef FEATURES
 #undef LABELS
 
+#define FEATURES 4
+// ML-based partition search breakout.
+static int ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
+                               const MACROBLOCK *const x,
+                               const RD_STATS *const rd_stats,
+                               unsigned int pb_source_variance) {
+  const NN_CONFIG *nn_config = NULL;
+  int thresh = 0;
+  switch (bsize) {
+    case BLOCK_8X8:
+      nn_config = &av1_partition_breakout_nnconfig_8;
+      thresh = cpi->sf.ml_partition_search_breakout_thresh[0];
+      break;
+    case BLOCK_16X16:
+      nn_config = &av1_partition_breakout_nnconfig_16;
+      thresh = cpi->sf.ml_partition_search_breakout_thresh[1];
+      break;
+    case BLOCK_32X32:
+      nn_config = &av1_partition_breakout_nnconfig_32;
+      thresh = cpi->sf.ml_partition_search_breakout_thresh[2];
+      break;
+    case BLOCK_64X64:
+      nn_config = &av1_partition_breakout_nnconfig_64;
+      thresh = cpi->sf.ml_partition_search_breakout_thresh[3];
+      break;
+    case BLOCK_128X128:
+      nn_config = &av1_partition_breakout_nnconfig_128;
+      thresh = cpi->sf.ml_partition_search_breakout_thresh[4];
+      break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  if (!nn_config || thresh < 0) return 0;
+
+  // Generate feature values.
+  float features[FEATURES];
+  int feature_index = 0;
+  aom_clear_system_state();
+
+  const int num_pels_log2 = num_pels_log2_lookup[bsize];
+  float rate_f = (float)AOMMIN(rd_stats->rate, INT_MAX);
+  rate_f = ((float)x->rdmult / 128.0f / 512.0f / (float)(1 << num_pels_log2)) *
+           rate_f;
+  features[feature_index++] = rate_f;
+
+  const float dist_f =
+      (float)(AOMMIN(rd_stats->dist, INT_MAX) >> num_pels_log2);
+  features[feature_index++] = dist_f;
+
+  features[feature_index++] = (float)pb_source_variance;
+
+  const int dc_q = (int)x->plane[0].dequant_QTX[0];
+  features[feature_index++] = (float)(dc_q * dc_q) / 256.0f;
+  assert(feature_index == FEATURES);
+
+  // Calculate score using the NN model.
+  float score = 0.0f;
+  av1_nn_predict(features, nn_config, &score);
+
+  // Make decision.
+  return (int)(score * 100) >= thresh;
+}
+#undef FEATURES
+
 // TODO(jingning,jimbankoski,rbultje): properly skip partition types that are
 // unlikely to be selected depending on previous rate-distortion optimization
 // results, for encoding speed-up.
@@ -3368,16 +3431,29 @@
         best_rdc = this_rdc;
         if (bsize_at_least_8x8) pc_tree->partitioning = PARTITION_NONE;
 
-        // If all y, u, v transform blocks in this partition are skippable, and
-        // the dist & rate are within the thresholds, the partition search is
-        // terminated for current branch of the partition search tree.
-        // The dist & rate thresholds are set to 0 at speed 0 to disable the
-        // early termination at that speed.
-        if (!x->e_mbd.lossless[xd->mi[0]->segment_id] &&
-            (ctx_none->skippable && best_rdc.dist < dist_breakout_thr &&
-             best_rdc.rate < rate_breakout_thr)) {
-          do_square_split = 0;
-          do_rectangular_split = 0;
+        if ((do_square_split || do_rectangular_split) &&
+            !x->e_mbd.lossless[xd->mi[0]->segment_id] && ctx_none->skippable) {
+          const int use_ml_based_breakout =
+              bsize <= cpi->sf.use_square_partition_only_threshold &&
+              bsize > BLOCK_4X4 && xd->bd == 8;
+          if (use_ml_based_breakout) {
+            if (ml_predict_breakout(cpi, bsize, x, &this_rdc,
+                                    pb_source_variance)) {
+              do_square_split = 0;
+              do_rectangular_split = 0;
+            }
+          }
+
+          // If all y, u, v transform blocks in this partition are skippable,
+          // and the dist & rate are within the thresholds, the partition
+          // search is terminated for current branch of the partition search
+          // tree. The dist & rate thresholds are set to 0 at speed 0 to
+          // disable the early termination at that speed.
+          if (best_rdc.dist < dist_breakout_thr &&
+              best_rdc.rate < rate_breakout_thr) {
+            do_square_split = 0;
+            do_rectangular_split = 0;
+          }
         }
 
 #if CONFIG_FP_MB_STATS
diff --git a/av1/encoder/partition_model_weights.h b/av1/encoder/partition_model_weights.h
index 279d394..5f6b9d0 100644
--- a/av1/encoder/partition_model_weights.h
+++ b/av1/encoder/partition_model_weights.h
@@ -1786,6 +1786,265 @@
 #undef FEATURE_SIZE
 #undef LABEL_SIZE
 
+#define FEATURE_SIZE 4
+static const float
+    av1_partition_breakout_nn_weights_128_layer0[FEATURE_SIZE * 32] = {
+      -0.331785f,  0.068675f,  -0.323814f,  0.033714f,  -0.237835f, 0.166316f,
+      -0.498766f,  -0.545634f, -0.266173f,  -0.476957f, -0.120409f, -0.021042f,
+      0.124056f,   -0.278750f, -0.110120f,  -0.372812f, 4.547939f,  0.097618f,
+      -0.002710f,  -0.064169f, -1.841173f,  -0.403833f, 0.005536f,  0.067188f,
+      -0.434935f,  -0.227421f, -0.000011f,  -0.139961f, -0.174056f, -0.652384f,
+      -0.000015f,  -0.262847f, -3.319706f,  -0.947693f, 0.002981f,  0.016717f,
+      -10.408850f, -0.014568f, -0.000018f,  0.019084f,  1.523383f,  0.074525f,
+      -0.002076f,  -0.020734f, 4.881495f,   0.002799f,  0.000342f,  -0.019623f,
+      1.786154f,   0.037462f,  -0.019037f,  0.052833f,  11.408153f, -0.044602f,
+      0.026155f,   -0.518627f, -0.474499f,  -0.427430f, -0.442733f, -0.011116f,
+      -22.379410f, -0.000549f, -0.001418f,  0.008090f,  -0.295090f, -0.230268f,
+      -0.337278f,  -0.001127f, -0.644282f,  -0.598783f, -0.539417f, -0.003303f,
+      9.189824f,   0.038066f,  -0.004097f,  -0.460045f, -0.308858f, -0.242691f,
+      -0.230835f,  -0.273057f, 0.152226f,   0.179239f,  -0.146382f, -0.004655f,
+      -0.242940f,  -0.718862f, -0.001685f,  -0.214736f, 3.263186f,  0.079463f,
+      -0.003854f,  -0.187461f, -0.599144f,  -0.419808f, -0.000597f, -0.136980f,
+      0.184813f,   -0.319525f, -0.007246f,  0.079709f,  -0.883229f, -0.343748f,
+      -0.000077f,  -0.172214f, -0.548759f,  -0.194674f, -0.144786f, 0.043896f,
+      -0.176364f,  -0.248394f, -0.090215f,  -0.294743f, -0.280980f, -0.181436f,
+      -0.115681f,  -0.071915f, -13.035494f, -0.075623f, 0.017052f,  -0.171152f,
+      5.910803f,   0.128344f,  0.010256f,   -1.073301f, 2.387826f,  0.166183f,
+      -0.007193f,  -0.257836f,
+    };
+
+static const float av1_partition_breakout_nn_bias_128_layer0[32] = {
+  0.115591f,  -0.100178f, -0.165523f, -0.122997f, 11.045759f,  1.034761f,
+  -0.323672f, -0.189087f, 2.850950f,  7.010029f,  -21.447067f, 1.877031f,
+  0.437442f,  5.929414f,  -0.117274f, 4.462253f,  -0.135198f,  -0.145927f,
+  8.727211f,  0.000000f,  -3.532987f, -0.405898f, 11.364439f,  -0.141728f,
+  -5.994947f, -0.362574f, 1.857687f,  -0.100400f, -0.130312f,  0.006080f,
+  0.429660f,  -8.439470f,
+};
+
+static const float av1_partition_breakout_nn_weights_128_layer1[32] = {
+  -0.013738f, 0.022052f,  -0.074437f, -0.211377f, -0.080433f, 0.015543f,
+  0.002091f,  0.014252f,  0.134834f,  0.190263f,  0.244175f,  -0.031747f,
+  0.020068f,  -0.068326f, 0.185471f,  0.660268f,  -0.134898f, -0.010376f,
+  -0.276023f, -0.282921f, -0.022769f, 0.007070f,  -0.186235f, 0.024407f,
+  -0.024837f, 0.005764f,  0.016599f,  -0.040077f, 0.020990f,  0.095054f,
+  -0.039662f, 0.131499f,
+};
+
+static const float av1_partition_breakout_nn_bias_128_layer1[1] = {
+  0.86678213f,
+};
+
+static const NN_CONFIG av1_partition_breakout_nnconfig_128 = {
+  FEATURE_SIZE,  // num_inputs
+  1,             // num_outputs
+  1,             // num_hidden_layers
+  {
+      32,  // num_hidden_nodes
+  },
+  {
+      av1_partition_breakout_nn_weights_128_layer0,
+      av1_partition_breakout_nn_weights_128_layer1,
+  },
+  {
+      av1_partition_breakout_nn_bias_128_layer0,
+      av1_partition_breakout_nn_bias_128_layer1,
+  },
+};
+
+static const float
+    av1_partition_breakout_nn_weights_64_layer0[FEATURE_SIZE * 16] = {
+      0.872892f,  -0.235539f, -0.412159f, -0.142533f, -2.251479f, -0.057073f,
+      -0.001373f, 0.112147f,  5.281734f,  0.060704f,  0.000838f,  -0.961554f,
+      0.244995f,  0.154515f,  -0.292654f, -0.167177f, -3.759112f, -0.486347f,
+      0.003208f,  -0.418226f, 2.618152f,  0.026832f,  0.003988f,  -0.404406f,
+      -0.405434f, 0.102791f,  -0.033406f, -0.029820f, -4.492342f, -0.154291f,
+      0.012947f,  -0.195075f, 0.009311f,  -0.411410f, -0.010986f, -0.554822f,
+      0.160576f,  0.020796f,  -0.457230f, -0.191111f, -7.759542f, -0.065039f,
+      -0.001322f, 0.055691f,  0.291924f,  -0.053076f, -0.148379f, -0.298383f,
+      1.022023f,  -0.033668f, -0.000804f, -0.825778f, -3.902254f, -0.085812f,
+      -0.052520f, -0.035012f, -0.465468f, -0.319231f, -0.497529f, -0.183068f,
+      -2.407131f, -0.062304f, 0.000874f,  0.108786f,
+    };
+
+static const float av1_partition_breakout_nn_bias_64_layer0[16] = {
+  0.081425f,  -14.404084f, 11.511393f, -0.930053f, 1.841889f,  15.020920f,
+  -1.872288f, 5.392535f,   -0.329335f, -0.005358f, 12.600776f, 0.000000f,
+  -0.337413f, 4.492778f,   0.000000f,  17.043072f,
+};
+
+static const float av1_partition_breakout_nn_weights_64_layer1[16] = {
+  -0.465338f, -0.103023f, -0.174808f, -0.005156f, -0.016366f, -0.172494f,
+  0.014185f,  0.067030f,  -0.001939f, -0.175049f, 0.245992f,  -0.181660f,
+  -0.038572f, 0.307899f,  -0.294283f, 0.118323f,
+};
+
+static const float av1_partition_breakout_nn_bias_64_layer1[1] = {
+  -1.33438122f,
+};
+
+static const NN_CONFIG av1_partition_breakout_nnconfig_64 = {
+  FEATURE_SIZE,  // num_inputs
+  1,             // num_outputs
+  1,             // num_hidden_layers
+  {
+      16,  // num_hidden_nodes
+  },
+  {
+      av1_partition_breakout_nn_weights_64_layer0,
+      av1_partition_breakout_nn_weights_64_layer1,
+  },
+  {
+      av1_partition_breakout_nn_bias_64_layer0,
+      av1_partition_breakout_nn_bias_64_layer1,
+  },
+};
+
+static const float
+    av1_partition_breakout_nn_weights_32_layer0[FEATURE_SIZE * 16] = {
+      -4.825528f, -0.145737f, 0.001907f,  0.145415f,  -1.858153f, -0.080744f,
+      0.000601f,  0.211991f,  0.384265f,  -0.043945f, -0.521332f, -0.170622f,
+      -0.046866f, -0.600506f, -0.001216f, -0.332760f, -0.447677f, -0.605844f,
+      -0.121008f, -0.119936f, -0.215739f, -0.269665f, -0.668587f, 0.071318f,
+      -1.202551f, -0.729727f, -0.370084f, 0.088215f,  -1.926800f, -0.086519f,
+      0.000359f,  0.215120f,  0.718749f,  0.022942f,  0.003840f,  -0.176518f,
+      1.213451f,  0.080786f,  0.001557f,  -1.053430f, 0.202698f,  -0.583919f,
+      -0.535512f, -0.239927f, -0.110151f, -0.128832f, -0.441087f, -0.145575f,
+      -0.178518f, -0.585784f, 0.000029f,  -0.833014f, -0.331358f, -0.520297f,
+      -0.088676f, -0.178487f, -1.430755f, 0.022981f,  -0.106931f, 0.015573f,
+      -0.520814f, -0.045386f, -0.443123f, -0.484209f,
+    };
+
+static const float av1_partition_breakout_nn_bias_32_layer0[16] = {
+  11.747026f, -9.337718f, 0.341648f, -0.155847f, -0.104005f, 4.666283f,
+  6.669584f,  16.625504f, 9.885626f, 15.439183f, -0.346080f, 0.000000f,
+  -0.423808f, 0.000000f,  6.352258f, -0.155787f,
+};
+
+static const float av1_partition_breakout_nn_weights_32_layer1[16] = {
+  0.168561f,  -0.122519f, 0.524667f,  0.032474f,  0.059097f,  0.011900f,
+  0.166445f,  0.127256f,  -0.034838f, -0.212586f, -0.317973f, 0.348419f,
+  -0.004171f, 0.157694f,  0.117845f,  0.272115f,
+};
+
+static const float av1_partition_breakout_nn_bias_32_layer1[1] = {
+  0.09049262f,
+};
+
+static const NN_CONFIG av1_partition_breakout_nnconfig_32 = {
+  FEATURE_SIZE,  // num_inputs
+  1,             // num_outputs
+  1,             // num_hidden_layers
+  {
+      16,  // num_hidden_nodes
+  },
+  {
+      av1_partition_breakout_nn_weights_32_layer0,
+      av1_partition_breakout_nn_weights_32_layer1,
+  },
+  {
+      av1_partition_breakout_nn_bias_32_layer0,
+      av1_partition_breakout_nn_bias_32_layer1,
+  },
+};
+
+static const float
+    av1_partition_breakout_nn_weights_16_layer0[FEATURE_SIZE * 16] = {
+      0.209371f,  0.028758f,  0.005764f,  -0.384401f, -0.625777f, -0.005647f,
+      -0.316867f, 0.042985f,  0.127344f,  0.025461f,  0.011465f,  -0.071043f,
+      -0.295977f, -0.076093f, -0.209681f, -0.311653f, -0.147538f, 0.009910f,
+      -0.130997f, -0.012326f, 0.024124f,  -0.323578f, -0.005790f, -0.085664f,
+      -1.575066f, -0.119221f, 0.015018f,  0.187204f,  0.238117f,  0.084924f,
+      -0.004444f, -1.271538f, -0.709860f, -0.006226f, -0.903111f, 0.090573f,
+      -0.278642f, -0.011114f, 0.021162f,  0.081290f,  -0.467486f, -0.040771f,
+      -0.224069f, -0.714390f, -0.281905f, -0.001336f, -0.761212f, -0.060385f,
+      -0.814479f, -0.050450f, -0.003666f, 0.085668f,  -0.272589f, 0.057330f,
+      -0.206540f, -0.303418f, 0.075335f,  -0.180468f, -0.064872f, -0.755948f,
+      -0.509287f, -0.048877f, -0.001512f, 0.077086f,
+    };
+
+static const float av1_partition_breakout_nn_bias_16_layer0[16] = {
+  16.421495f, 4.012273f,  -1.828571f, 0.000000f,  -0.263564f, -0.201972f,
+  6.564987f,  14.651000f, -3.227779f, 2.241833f,  -0.137116f, 0.762876f,
+  5.625762f,  0.615822f,  0.040057f,  16.668884f,
+};
+
+static const float av1_partition_breakout_nn_weights_16_layer1[16] = {
+  -0.096440f, 0.184316f,  -0.021148f, 0.424974f, 0.003743f,  0.006310f,
+  0.046266f,  -0.219224f, -0.087004f, 0.024623f, -0.275798f, 0.120164f,
+  0.269773f,  -0.021105f, -0.146698f, 0.188764f,
+};
+
+static const float av1_partition_breakout_nn_bias_16_layer1[1] = {
+  1.60751927f,
+};
+
+static const NN_CONFIG av1_partition_breakout_nnconfig_16 = {
+  FEATURE_SIZE,  // num_inputs
+  1,             // num_outputs
+  1,             // num_hidden_layers
+  {
+      16,  // num_hidden_nodes
+  },
+  {
+      av1_partition_breakout_nn_weights_16_layer0,
+      av1_partition_breakout_nn_weights_16_layer1,
+  },
+  {
+      av1_partition_breakout_nn_bias_16_layer0,
+      av1_partition_breakout_nn_bias_16_layer1,
+  },
+};
+
+static const float
+    av1_partition_breakout_nn_weights_8_layer0[FEATURE_SIZE * 16] = {
+      -0.255885f, 0.109548f,  -0.111054f, -0.476119f, -1.083031f, -0.342003f,
+      0.048241f,  -0.356013f, -0.085054f, 0.124908f,  0.000084f,  -0.149906f,
+      -0.729829f, 0.133535f,  -0.002125f, 0.207516f,  -0.210163f, -0.567365f,
+      -0.590103f, 0.045308f,  -0.539406f, 0.130550f,  -0.663879f, -0.170549f,
+      0.017587f,  -0.054187f, 0.000550f,  0.038297f,  -0.112891f, -0.012751f,
+      -0.048067f, 0.095564f,  0.079892f,  0.077285f,  -0.749708f, -0.286312f,
+      -0.054334f, 0.132242f,  -0.004152f, -0.209758f, -0.073407f, 0.082306f,
+      -0.001034f, -0.090990f, 0.122823f,  -0.109794f, -0.230066f, -0.391155f,
+      -0.262245f, -0.004744f, -0.232246f, 0.099290f,  -0.637484f, 0.111937f,
+      -0.548556f, -0.598344f, 0.123265f,  -0.281395f, -0.399711f, -0.525671f,
+      -0.596269f, 0.098494f,  -0.005765f, 0.173652f,
+    };
+
+static const float av1_partition_breakout_nn_bias_8_layer0[16] = {
+  0.194141f, -0.111223f, 2.503733f, -7.155602f, -0.695068f, 0.114874f,
+  2.056990f, 5.284306f,  0.639643f, -2.792049f, -2.232339f, -0.232209f,
+  2.336705f, -0.278834f, 0.231905f, 7.954366f,
+};
+
+static const float av1_partition_breakout_nn_weights_8_layer1[16] = {
+  -0.014439f, 0.010171f, 0.048116f,  -0.090659f, -0.081235f, -0.021840f,
+  -0.017360f, 0.031063f, -0.031737f, -0.023439f, -0.037725f, 0.021954f,
+  0.055858f,  0.230970f, -0.056466f, 0.119780f,
+};
+
+static const float av1_partition_breakout_nn_bias_8_layer1[1] = {
+  1.27784479f,
+};
+
+static const NN_CONFIG av1_partition_breakout_nnconfig_8 = {
+  FEATURE_SIZE,  // num_inputs
+  1,             // num_outputs
+  1,             // num_hidden_layers
+  {
+      16,  // num_hidden_nodes
+  },
+  {
+      av1_partition_breakout_nn_weights_8_layer0,
+      av1_partition_breakout_nn_weights_8_layer1,
+  },
+  {
+      av1_partition_breakout_nn_bias_8_layer0,
+      av1_partition_breakout_nn_bias_8_layer1,
+  },
+};
+#undef FEATURE_SIZE
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 1e97bfb..0e5a9c9 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -98,6 +98,15 @@
     sf->use_square_partition_only_threshold = BLOCK_64X64;
   }
 
+  // TODO(huisu@google.com): train models for 720P and above.
+  if (!is_720p_or_larger) {
+    sf->ml_partition_search_breakout_thresh[0] = 200;  // BLOCK_8X8
+    sf->ml_partition_search_breakout_thresh[1] = 250;  // BLOCK_16X16
+    sf->ml_partition_search_breakout_thresh[2] = 300;  // BLOCK_32X32
+    sf->ml_partition_search_breakout_thresh[3] = 500;  // BLOCK_64X64
+    sf->ml_partition_search_breakout_thresh[4] = -1;   // BLOCK_128X128
+  }
+
   if (speed >= 1) {
     if (is_720p_or_larger) {
       sf->use_square_partition_only_threshold = BLOCK_128X128;
@@ -106,6 +115,14 @@
     } else {
       sf->use_square_partition_only_threshold = BLOCK_32X32;
     }
+
+    if (!is_720p_or_larger) {
+      sf->ml_partition_search_breakout_thresh[0] = 200;  // BLOCK_8X8
+      sf->ml_partition_search_breakout_thresh[1] = 250;  // BLOCK_16X16
+      sf->ml_partition_search_breakout_thresh[2] = 300;  // BLOCK_32X32
+      sf->ml_partition_search_breakout_thresh[3] = 300;  // BLOCK_64X64
+      sf->ml_partition_search_breakout_thresh[4] = -1;   // BLOCK_128X128
+    }
   }
 
   if (speed >= 2) {
@@ -476,6 +493,8 @@
   sf->ml_prune_ab_partition = 0;
   sf->ml_prune_4_partition = 0;
   sf->fast_cdef_search = 0;
+  for (i = 0; i < PARTITION_BLOCK_SIZES; ++i)
+    sf->ml_partition_search_breakout_thresh[i] = -1;  // -1 means not enabled.
 
   // Set this at the appropriate speed levels
   sf->use_transform_domain_distortion = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 2f3c574..5a5230d 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -561,6 +561,9 @@
   int64_t partition_search_breakout_dist_thr;
   int partition_search_breakout_rate_thr;
 
+  // Thresholds for ML based partition search breakout.
+  int ml_partition_search_breakout_thresh[PARTITION_BLOCK_SIZES];
+
   // Allow skipping partition search for still image frame
   int allow_partition_search_skip;