DNN RD Model simplification and improvement

The number of input features is reduced from 11 to 8.
Coding efficiency when the model is used for interpolation filter
search is now slightly better than the baseline, but the model is
not turned on yet for complexity reasons.

Change-Id: Ib792db7b2a1e5dff0a6c861f7010f7e6473a19f7
diff --git a/av1/encoder/pustats.h b/av1/encoder/pustats.h
index 42a4c59..689f858 100644
--- a/av1/encoder/pustats.h
+++ b/av1/encoder/pustats.h
@@ -18,83 +18,78 @@
 
 #include "av1/encoder/ml.h"
 
-#define NUM_FEATURES 11
+#define NUM_FEATURES_PUSTATS 8
 #define NUM_HIDDEN_LAYERS 2
 #define HIDDEN_LAYERS_0_NODES 12
 #define HIDDEN_LAYERS_1_NODES 10
 #define LOGITS_NODES 1
 
 static const float
-    av1_pustats_rate_hiddenlayer_0_kernel[NUM_FEATURES *
+    av1_pustats_rate_hiddenlayer_0_kernel[NUM_FEATURES_PUSTATS *
                                           HIDDEN_LAYERS_0_NODES] = {
-      21.5067f,  22.6709f,  0.0049f,   0.9288f,  -0.0100f,  0.0060f,   -0.0071f,
-      -0.0085f,  0.0348f,   -0.1273f,  10.1154f, 6.3405f,   7.8589f,   -0.0652f,
-      -4.6352f,  0.0445f,   -3.2748f,  0.1025f,  -0.0385f,  -0.4505f,  1.1320f,
-      3.2634f,   23.2420f,  -7.9056f,  0.0522f,  -18.1555f, 0.0977f,   0.1155f,
-      -0.0138f,  0.0267f,   -0.3992f,  0.2735f,  22.8063f,  35.1043f,  3.8140f,
-      -0.0295f,  0.0771f,   -0.6938f,  0.0302f,  -0.0266f,  0.0989f,   -0.0794f,
-      0.2981f,   33.3333f,  -24.1150f, 1.4986f,  -0.0975f,  -15.3938f, -0.0858f,
-      -0.0845f,  -0.0869f,  -0.0858f,  0.3542f,  0.0155f,   -18.2629f, 9.6688f,
-      -11.9643f, -0.2904f,  -5.3026f,  -0.1011f, -0.1202f,  0.0127f,   -0.0269f,
-      0.3434f,   0.0595f,   16.6800f,  41.4730f, 6.9269f,   -0.0512f,  -1.4540f,
-      0.0468f,   0.0077f,   0.0983f,   0.1265f,  -0.5234f,  0.9477f,   36.6470f,
-      -0.4838f,  -0.2269f,  -0.1143f,  -0.3907f, -0.5005f,  -0.0179f,  -0.1057f,
-      0.1233f,   -0.4412f,  -0.0474f,  0.1140f,  -21.6813f, -0.9077f,  -0.0078f,
-      -3.3306f,  0.0417f,   0.0412f,   0.0427f,  0.0418f,   -0.1699f,  0.0072f,
-      -22.3335f, 16.1203f,  -10.1220f, -0.0019f, 0.0005f,   -0.0054f,  -0.0155f,
-      -0.0302f,  -0.0379f,  0.1276f,   0.1568f,  21.6175f,  12.2919f,  11.0327f,
-      -0.2000f,  -8.6691f,  -0.5593f,  -0.5952f, -0.4203f,  -0.4857f,  -1.1239f,
-      3.1404f,   -13.1098f, -5.9165f,  22.2060f, -0.0312f,  -3.9642f,  -0.0344f,
-      -0.0656f,  -0.0273f,  -0.0465f,  0.1412f,  -6.1974f,  9.3661f,
+      -0.1758f, -0.0499f, -10.0069f, -2.2838f,  -0.3359f,  0.3459f,  -0.3285f,
+      -0.0515f, -0.5417f, 0.2357f,   -0.0575f,  -69.0782f, 0.5348f,  1.4068f,
+      0.2213f,  -1.0490f, -0.0636f,  0.1654f,   1.1002f,   33.4924f, 0.4358f,
+      1.2499f,  0.1143f,  0.0592f,   -1.6335f,  -0.0092f,  1.2207f,  -28.4543f,
+      -0.4973f, 0.4368f,  0.2341f,   -0.1623f,  -3.8986f,  0.1311f,  -1.8789f,
+      -3.9079f, -0.8158f, -0.8420f,  1.4295f,   -2.3629f,  -1.4825f, 0.6498f,
+      -5.3669f, 6.4434f,  1.8393f,   -35.0678f, 3.7459f,   -2.8504f, 2.0502f,
+      -0.1812f, -3.9011f, -1.0155f,  1.8375f,   -1.4517f,  1.3917f,  3.8664f,
+      0.8345f,  -0.3472f, 5.7740f,   -1.1196f,  -0.3264f,  -1.2481f, -0.9284f,
+      -4.9657f, 2.2831f,  0.7337f,   2.3176f,   0.6416f,   0.8804f,  1.9988f,
+      -1.3426f, 1.2728f,  1.2249f,   -0.1551f,  5.6045f,   0.2046f,  -2.1464f,
+      -2.4922f, -0.5334f, 12.1055f,  7.2467f,   -0.0070f,  0.0234f,  0.0021f,
+      0.0215f,  -0.0098f, -0.0682f,  -6.1494f,  -0.3176f,  -1.6069f, -0.2119f,
+      -1.0533f, -0.3566f, 0.5294f,   -0.4335f,  0.1626f,
     };
 
 static const float av1_pustats_rate_hiddenlayer_0_bias[HIDDEN_LAYERS_0_NODES] =
     {
-      -14.3065f, 2.059f,   -62.9916f, -50.1209f, 57.643f,  -59.3737f,
-      -30.4737f, -0.1112f, 72.5427f,  55.402f,   24.9523f, 18.5834f,
+      10.5266f, 5.3268f, -1.0678f, 7.7411f,  8.7164f,  -0.3235f,
+      7.3028f,  9.0874f, -6.4594f, -1.0102f, -1.1146f, 10.8419f,
     };
 
 static const float
     av1_pustats_rate_hiddenlayer_1_kernel[HIDDEN_LAYERS_0_NODES *
                                           HIDDEN_LAYERS_1_NODES] = {
-      0.3883f,  -0.2784f, -0.2850f, 0.4894f,  -2.2450f, 0.4511f,  -0.1969f,
-      -0.0077f, -1.4924f, 0.1138f,  -2.9848f, 1.0211f,  -0.1712f, -0.1952f,
-      -0.4774f, 0.0761f,  -0.3186f, -0.1002f, 0.8663f,  0.5026f,  1.1920f,
-      0.9337f,  0.3911f,  -0.3841f, -0.0037f, 0.7295f,  -0.3183f, 0.1829f,
-      -1.3670f, -0.1046f, 0.6629f,  0.0619f,  -0.1551f, 0.8174f,  2.1521f,
-      -1.3323f, -0.0527f, -0.5772f, 0.2001f,  -0.6270f, -1.0625f, 0.3342f,
-      0.6676f,  0.4605f,  -2.0049f, 0.7781f,  0.0713f,  -0.0824f, -0.4529f,
-      0.1757f,  -0.1338f, -0.2319f, -0.2864f, 0.1248f,  0.3887f,  -0.1676f,
-      1.8422f,  0.6435f,  1.2123f,  -0.5667f, -0.2423f, -0.0314f, 0.2411f,
-      -0.5013f, 0.0422f,  0.2559f,  0.4435f,  -0.1223f, 1.5167f,  0.3939f,
-      1.0898f,  0.0795f,  -0.9251f, -0.0813f, -0.5929f, -0.0741f, 4.0687f,
-      -0.4368f, -0.0984f, 0.0837f,  3.6169f,  0.0662f,  -0.1679f, -0.8090f,
-      -0.2610f, -0.5791f, 0.0642f,  -0.2979f, -0.9036f, 0.2898f,  0.3265f,
-      0.4660f,  -1.6358f, -0.0347f, 0.1087f,  0.0353f,  0.5687f,  -0.5242f,
-      -0.4895f, 0.7693f,  -1.3829f, -0.2244f, -0.2880f, 0.0575f,  2.0563f,
-      -0.2322f, -1.1597f, 1.6125f,  -0.0925f, 1.3540f,  0.1432f,  0.3993f,
-      -0.0303f, -1.1438f, -1.7323f, -0.4329f, 2.9443f,  -0.5724f, 0.0122f,
-      -1.0829f,
+      10.5932f,  2.5192f,  -0.0015f, 5.9479f,   5.2426f,   -0.4091f, 5.3220f,
+      6.0469f,   0.7200f,  3.3241f,  5.5006f,   12.8290f,  -1.6396f, 0.5743f,
+      -0.8370f,  1.9956f,  -4.9270f, -1.5295f,  2.1350f,   -9.4415f, -0.7094f,
+      5.1822f,   19.7287f, -3.0444f, -0.3320f,  0.0031f,   -0.2709f, -0.5249f,
+      0.3281f,   -0.2240f, 0.2225f,  -0.2386f,  -0.4370f,  -0.2438f, -0.4928f,
+      -0.2842f,  -2.1772f, 9.2570f,  -17.6655f, 3.5448f,   -2.8394f, -1.0167f,
+      -0.5115f,  -1.9260f, -0.2111f, -0.7528f,  -1.2387f,  -0.0401f, 5.0716f,
+      -3.3763f,  -0.2898f, -0.4956f, -7.9993f,  0.1526f,   -0.0242f, 0.7354f,
+      6.0432f,   4.8043f,  7.4790f,  -0.6295f,  1.7565f,   3.7197f,  -2.3963f,
+      6.8945f,   2.9717f,  -3.1623f, 3.4241f,   4.4676f,   -1.8154f, -2.9401f,
+      -8.5657f,  -3.0240f, -1.4661f, 8.1145f,   -12.7858f, 3.3624f,  -1.0819f,
+      -4.2856f,  1.1801f,  -0.5587f, -1.6062f,  -1.1813f,  -3.5882f, -0.2490f,
+      -24.9566f, -0.4140f, -0.1113f, 3.5537f,   4.4112f,   0.1367f,  -1.5876f,
+      1.6605f,   1.3903f,  -0.0253f, -2.1419f,  -2.2197f,  -0.7659f, -0.4249f,
+      -0.0424f,  0.1486f,  0.4643f,  -0.9068f,  -0.3619f,  -0.7624f, -0.9132f,
+      -0.4947f,  -0.3527f, -0.5445f, -0.4768f,  -1.7761f,  -1.0686f, 0.5462f,
+      1.3371f,   4.3116f,  0.0777f,  -2.7216f,  -1.8908f,  3.4989f,  7.7269f,
+      -2.7566f,
     };
 
 static const float av1_pustats_rate_hiddenlayer_1_bias[HIDDEN_LAYERS_1_NODES] =
     {
-      -10.3717f, 37.304f,  -36.7221f, -52.7572f, 44.0877f,
-      41.1631f,  36.3299f, -48.6087f, -4.5189f,  13.0611f,
+      13.2435f, -8.5477f, -0.0998f, -1.5131f, -12.0187f,
+      6.1715f,  0.5094f,  7.6433f,  -0.3992f, -1.3555f,
     };
 
 static const float
     av1_pustats_rate_logits_kernel[HIDDEN_LAYERS_1_NODES * LOGITS_NODES] = {
-      0.8362f, 1.0615f, -1.5178f, -1.2959f, 1.3233f,
-      1.4909f, 1.3554f, -0.8626f, -0.618f,  -0.9458f,
+      4.3078f, -17.3497f, 0.0195f,  34.6032f, -5.0127f,
+      5.3079f, 10.0077f,  -13.129f, 0.0087f,  -8.4009f,
     };
 
 static const float av1_pustats_rate_logits_bias[LOGITS_NODES] = {
-  30.6878f,
+  4.5103f,
 };
 
 static const NN_CONFIG av1_pustats_rate_nnconfig = {
-  NUM_FEATURES,                                      // num_inputs
+  NUM_FEATURES_PUSTATS,                              // num_inputs
   LOGITS_NODES,                                      // num_outputs
   NUM_HIDDEN_LAYERS,                                 // num_hidden_layers
   { HIDDEN_LAYERS_0_NODES, HIDDEN_LAYERS_1_NODES },  // num_hidden_nodes
@@ -111,76 +106,71 @@
 };
 
 static const float
-    av1_pustats_dist_hiddenlayer_0_kernel[NUM_FEATURES *
+    av1_pustats_dist_hiddenlayer_0_kernel[NUM_FEATURES_PUSTATS *
                                           HIDDEN_LAYERS_0_NODES] = {
-      0.7770f,   1.0881f,  0.0177f,  0.4939f,  -0.2541f, -0.2672f, -0.1705f,
-      -0.1940f,  -0.6395f, 1.2928f,  3.6240f,  2.4445f,  1.6790f,  0.0265f,
-      0.1897f,   0.1776f,  0.0422f,  0.0197f,  -0.0466f, 0.0462f,  -1.0827f,
-      2.0231f,   1.8044f,  2.7022f,  0.0064f,  0.2255f,  -0.0552f, -0.1010f,
-      -0.0581f,  -0.0781f, 0.2614f,  -3.4085f, 1.7478f,  0.1155f,  -0.1458f,
-      -0.0031f,  -0.1797f, -0.4378f, -0.0539f, 0.0607f,  -0.1347f, -0.3142f,
-      -0.2014f,  -0.4484f, -0.2808f, 1.5913f,  0.0046f,  -0.0610f, -0.6479f,
-      -0.7278f,  -0.5592f, -0.6695f, -0.8120f, 2.9056f,  -1.1501f, 9.3618f,
-      4.2486f,   0.0011f,  -0.1499f, -0.0834f, 0.1282f,  0.0409f,  0.1670f,
-      -0.1398f,  -0.4661f, 13.7700f, 8.2061f,  -0.0685f, 0.0061f,  -0.2951f,
-      0.0169f,   0.0520f,  0.0040f,  0.0374f,  0.0467f,  -0.0107f, 14.2664f,
-      -2.2489f,  -0.2516f, -0.0061f, -0.9921f, 0.1223f,  0.1212f,  0.1199f,
-      0.1185f,   -0.4867f, 0.0325f,  -5.0757f, -8.7853f, 1.0450f,  0.0169f,
-      0.5462f,   0.0051f,  0.1330f,  0.0143f,  0.1429f,  -0.0258f, 0.2769f,
-      -12.8839f, 22.3093f, 1.2761f,  0.0037f,  -1.2459f, -0.0466f, 0.0003f,
-      -0.0464f,  -0.0067f, 0.2361f,  0.0355f,  23.3833f, 10.9218f, 2.6811f,
-      0.0222f,   -1.1055f, 0.1825f,  0.0575f,  0.0114f,  -0.1259f, 0.3148f,
-      -2.0047f,  11.9559f, 5.7375f,  0.8802f,  0.0042f,  -0.2469f, -0.1040f,
-      -1.5679f,  0.1969f,  -0.0184f, 0.0157f,  0.6688f,  3.4492f,
+      -0.2560f, 0.1105f,  -0.8434f, -0.0132f, -8.9371f, -1.1176f, -0.3655f,
+      0.4885f,  1.7518f,  0.4985f,  0.5582f,  -0.3739f, 0.9403f,  0.3874f,
+      0.3265f,  1.7383f,  3.1747f,  0.0285f,  3.3942f,  -0.0123f, 0.5057f,
+      0.1584f,  0.2697f,  4.6151f,  3.6251f,  -0.0121f, -1.0047f, -0.0037f,
+      0.0127f,  0.1935f,  -0.5277f, -2.7144f, 0.0729f,  -0.1457f, -0.0816f,
+      -0.5462f, 0.4738f,  0.3599f,  -0.0564f, 0.0910f,  0.0126f,  -0.0310f,
+      -2.1311f, -0.4666f, -0.0074f, -0.0765f, 0.0287f,  -0.2662f, -0.0999f,
+      -0.2983f, -0.4899f, -0.2314f, 0.2873f,  -0.3614f, 0.1783f,  -0.1210f,
+      0.3569f,  0.5436f,  -8.0536f, -0.0044f, -1.5255f, -0.8247f, -0.4556f,
+      1.9045f,  0.5463f,  0.1102f,  -0.9293f, -0.0185f, -0.8302f, -0.4378f,
+      -0.3531f, -1.3095f, 0.6099f,  0.7977f,  4.1950f,  -0.0067f, -0.2762f,
+      -0.1574f, -0.2149f, 0.6104f,  -1.7053f, 0.1904f,  4.2402f,  -0.2671f,
+      0.8940f,  0.6820f,  0.2241f,  -0.9459f, 1.4571f,  0.5255f,  2.3352f,
+      -0.0806f, 0.5231f,  0.3928f,  0.4146f,  2.0956f,
     };
 
 static const float av1_pustats_dist_hiddenlayer_0_bias[HIDDEN_LAYERS_0_NODES] =
     {
-      4.5051f,  -4.5858f, 1.4693f, 0.f,      3.7968f, -3.6292f,
-      -7.3112f, 10.9743f, 8.027f,  -2.2692f, -8.748f, -1.3689f,
+      1.1597f, 0.0836f, -0.7471f, -0.2439f, -0.0438f, 2.4626f,
+      0.f,     1.1485f, 2.7085f,  -4.7897f, 1.4093f,  -1.657f,
     };
 
 static const float
     av1_pustats_dist_hiddenlayer_1_kernel[HIDDEN_LAYERS_0_NODES *
                                           HIDDEN_LAYERS_1_NODES] = {
-      -0.0182f, -0.0925f, -0.0311f, -0.2962f, 0.1177f,  -0.0027f, -0.2136f,
-      -1.2094f, 0.0935f,  -0.1403f, -0.1477f, -0.0752f, 0.1519f,  -0.4726f,
-      -0.3521f, 0.4199f,  -0.0168f, -0.2927f, -0.2510f, 0.0706f,  -0.2920f,
-      0.2046f,  -0.0400f, -0.2114f, 0.4240f,  -0.7070f, 0.4964f,  0.4471f,
-      0.3841f,  -0.0918f, -0.6140f, 0.6056f,  -0.1123f, 0.3944f,  -0.0178f,
-      -1.7702f, -0.4434f, 0.0560f,  0.1565f,  -0.0793f, -0.0041f, 0.0052f,
-      -0.1843f, 0.2400f,  -0.0605f, 0.3196f,  -0.0286f, -0.0002f, -0.0595f,
-      -0.0493f, -0.2636f, -0.3994f, -0.1871f, -0.3298f, -0.0788f, -1.0685f,
-      0.1900f,  -0.5549f, -0.1350f, -0.0153f, -0.1195f, -0.5874f, 1.0468f,
-      0.0212f,  -0.2306f, -0.2677f, -0.3000f, -1.0702f, -0.1725f, -0.0656f,
-      -0.0226f, 0.0616f,  -0.3453f, 0.0810f,  0.4838f,  -0.3780f, -1.4486f,
-      0.7777f,  -0.0459f, -0.6568f, 0.0589f,  -1.0286f, -0.6001f, 0.0826f,
-      0.4794f,  -0.0586f, -0.1759f, 0.3811f,  -0.1313f, 0.3829f,  -0.0968f,
-      -2.0445f, -0.3566f, -0.1491f, -0.0745f, -0.0202f, 0.0839f,  0.0470f,
-      -0.2432f, 0.3013f,  -0.0743f, -0.3479f, 0.0749f,  -5.2490f, 0.0209f,
-      -0.1653f, -0.0826f, -0.0535f, 0.3225f,  -0.3786f, -0.0104f, 0.3091f,
-      0.3652f,  0.1757f,  -0.3252f, -1.1022f, -0.0574f, -0.4473f, 0.3469f,
-      -0.5539f,
+      -0.5203f, -1.3468f, 0.3865f,  -0.6859f, 0.0058f,  4.0682f,  0.4807f,
+      -0.1380f, 0.6050f,  0.8958f,  0.7748f,  -0.1311f, 1.7317f,  1.1265f,
+      0.0827f,  0.1407f,  -0.3605f, 0.5429f,  0.1880f,  -0.1439f, 0.2837f,
+      1.6477f,  0.0832f,  0.0593f,  -1.8464f, -0.7241f, -1.0672f, -0.3546f,
+      -0.3842f, -2.3637f, 0.2514f,  0.8263f,  -0.1872f, 0.5774f,  -0.3610f,
+      -0.0205f, 1.3977f,  -0.1083f, 0.6923f,  1.3039f,  -0.2870f, 1.0622f,
+      -0.0566f, 0.2697f,  -0.5429f, -0.6193f, 1.7559f,  0.3246f,  1.9159f,
+      0.3744f,  0.0686f,  1.0191f,  -0.4212f, 1.9591f,  -0.0691f, -0.1085f,
+      -1.2034f, 0.0606f,  1.0116f,  0.5565f,  -0.1874f, -0.7898f, 0.4796f,
+      0.2290f,  0.4334f,  -0.5817f, -0.2949f, 0.1367f,  -0.2932f, -1.1265f,
+      0.0133f,  -0.5309f, -3.3191f, 0.0939f,  0.3895f,  -2.5812f, -0.0066f,
+      -3.0063f, -0.2982f, 0.7309f,  -0.2422f, -0.2770f, -0.7152f, 0.1700f,
+      1.9630f,  0.1988f,  0.4194f,  0.8762f,  0.3402f,  0.1051f,  -0.1598f,
+      0.2405f,  0.0392f,  1.1256f,  1.5245f,  0.0950f,  0.2160f,  -0.5023f,
+      0.2584f,  0.2074f,  0.2218f,  0.3966f,  -0.0921f, -0.2435f, -0.4560f,
+      -1.1923f, -0.3716f, -0.3286f, -1.3225f, 0.1896f,  -0.3342f, -0.7888f,
+      -0.4488f, -1.7168f, 0.3341f,  0.1146f,  0.5226f,  0.2610f,  -0.4574f,
+      -0.4164f,
     };
 
 static const float av1_pustats_dist_hiddenlayer_1_bias[HIDDEN_LAYERS_1_NODES] =
     {
-      11.9337f, -0.3681f, -6.1324f,  12.674f,  9.0956f,
-      4.6069f,  -4.4158f, -12.4848f, 10.8473f, 5.7633f,
+      -2.3014f, -2.4292f, 1.3317f, -3.2361f, -1.918f,
+      2.7149f,  -2.5649f, 2.7765f, 2.9617f,  2.7684f,
     };
 
 static const float
     av1_pustats_dist_logits_kernel[HIDDEN_LAYERS_1_NODES * LOGITS_NODES] = {
-      0.3245f,  0.2979f,  -0.157f,  -0.1441f, 0.1413f,
-      -0.7496f, -0.1737f, -0.5322f, 0.0748f,  0.2518f,
+      -0.6868f, -0.6715f, 0.449f,  -1.293f, 0.6214f,
+      0.9894f,  -0.4342f, 0.7002f, 1.4363f, 0.6951f,
     };
 
 static const float av1_pustats_dist_logits_bias[LOGITS_NODES] = {
-  4.6065f,
+  2.3371f,
 };
 
 static const NN_CONFIG av1_pustats_dist_nnconfig = {
-  NUM_FEATURES,                                      // num_inputs
+  NUM_FEATURES_PUSTATS,                              // num_inputs
   LOGITS_NODES,                                      // num_outputs
   NUM_HIDDEN_LAYERS,                                 // num_hidden_layers
   { HIDDEN_LAYERS_0_NODES, HIDDEN_LAYERS_1_NODES },  // num_hidden_nodes
@@ -196,7 +186,6 @@
   },
 };
 
-#undef NUM_FEATURES
 #undef NUM_HIDDEN_LAYERS
 #undef HIDDEN_LAYERS_0_NODES
 #undef HIDDEN_LAYERS_1_NODES
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index fef6d28..aac3bff 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2506,25 +2506,26 @@
     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
     mean /= (1 << shift);
   }
-  const double variance = sse_norm - mean * mean;
-  assert(variance >= 0.0);
+  double sse_norm_sum = 0.0, sse_frac_arr[3];
+  for (int k = 0; k < 4; ++k) sse_norm_sum += sse_norm_arr[k];
+  for (int k = 0; k < 3; ++k)
+    sse_frac_arr[k] =
+        sse_norm_sum > 0.0 ? sse_norm_arr[k] / sse_norm_sum : 0.25;
   const double q_sqr = (double)(q_step * q_step);
   const double q_sqr_by_sse_norm = q_sqr / (sse_norm + 1.0);
+  const double mean_sqr_by_sse_norm = mean * mean / (sse_norm + 1.0);
   double hor_corr, vert_corr;
   get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr);
 
-  float features[11];
+  float features[NUM_FEATURES_PUSTATS];
   features[0] = (float)hor_corr;
   features[1] = (float)log_numpels;
-  features[2] = (float)q_sqr;
+  features[2] = (float)mean_sqr_by_sse_norm;
   features[3] = (float)q_sqr_by_sse_norm;
-  features[4] = (float)sse_norm_arr[0];
-  features[5] = (float)sse_norm_arr[1];
-  features[6] = (float)sse_norm_arr[2];
-  features[7] = (float)sse_norm_arr[3];
-  features[8] = (float)sse_norm;
-  features[9] = (float)variance;
-  features[10] = (float)vert_corr;
+  features[4] = (float)sse_frac_arr[0];
+  features[5] = (float)sse_frac_arr[1];
+  features[6] = (float)sse_frac_arr[2];
+  features[7] = (float)vert_corr;
 
   float rate_f, dist_by_sse_norm_f;
   av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_by_sse_norm_f);
@@ -2564,6 +2565,7 @@
 
   x->pred_sse[ref] = 0;
 
+  aom_clear_system_state();
   for (int plane = plane_from; plane <= plane_to; ++plane) {
     struct macroblockd_plane *const pd = &xd->plane[plane];
     const BLOCK_SIZE plane_bsize =
@@ -2586,6 +2588,7 @@
     if (plane_sse) plane_sse[plane] = sse;
     if (plane_dist) plane_dist[plane] = dist;
   }
+  aom_clear_system_state();
 
   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;