Improve classified rate model
Fixes a bug in rate model generation and updates the model.
lowres (20 frames, end-usage q): Av PSNR -0.075%
midres (20 frames, end-usage q): Av PSNR -0.081%
Also changes the interface to the model to pass in sse for subsequent
use in classified distortion model.
STATS_CHANGED
Change-Id: I270b9728f9e34bc76b24ae3bd7af87fd6a0cc701
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 5152b66..3554389 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -1273,8 +1273,9 @@
0.007205, 0.007205, 0.007203, 0.004341, 0.004340, 0.004338,
};
-void av1_model_rd_surffit(BLOCK_SIZE bsize, double xm, double yl,
- double *rate_f, double *dist_f) {
+void av1_model_rd_surffit(BLOCK_SIZE bsize, double sse_norm, double xm,
+ double yl, double *rate_f, double *dist_f) {
+ (void)sse_norm;
const double x_start = -0.5;
const double x_end = 16.5;
const double x_step = 1.0;
@@ -1283,7 +1284,7 @@
const double y_step = 1.0;
const double epsilon = 1e-6;
const int stride = (int)rint((x_end - x_start) / x_step) + 1;
- const int cat = bsize_model_cat_lookup[bsize];
+ const int rcat = bsize_model_cat_lookup[bsize];
(void)y_end;
xm = AOMMAX(xm, x_start + x_step + epsilon);
@@ -1301,7 +1302,7 @@
const double yo = y - yi;
const double xo = x - xi;
- const double *prate = &interp_rgrid_surf[cat][(yi - 1) * stride + (xi - 1)];
+ const double *prate = &interp_rgrid_surf[rcat][(yi - 1) * stride + (xi - 1)];
const double *pdist = &interp_dgrid_surf[(yi - 1) * stride + (xi - 1)];
*rate_f = interp_bicubic(prate, stride, xo, yo);
*dist_f = interp_bicubic(pdist, stride, xo, yo);
@@ -1311,62 +1312,62 @@
{
0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 23.801499, 28.387688, 33.388795, 42.298282,
41.525408, 51.597692, 49.566271, 54.632979, 60.321507,
67.730678, 75.766165, 85.324032, 96.600012, 120.839562,
173.917577, 255.974908, 354.107573, 458.063476, 562.345966,
668.568424, 772.072881, 878.598490, 982.202274, 1082.708946,
1188.037853, 1287.702240, 1395.588773, 1490.825830, 1584.231230,
1691.386090, 1766.822555, 1869.630904, 1926.743565, 2002.949495,
- 2047.431137, 2138.486068, 2154.743767, 2209.242472, 2278.252010,
- 2298.028834, 2302.326180, 2293.979995, 2275.826226, 2250.700821,
- 2221.439725, 2190.878887, 2161.854252, 2137.201768, 2119.757381,
- 2112.357039, 2117.836689, 2139.032277, 2178.779750, 2239.915056,
- },
- {
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 11.561347, 12.578139, 14.205101, 16.770584, 19.094853,
- 21.330863, 23.298907, 26.901921, 34.501017, 57.891733,
- 112.234763, 194.853189, 288.302032, 380.499422, 472.625309,
- 560.226809, 647.928463, 734.155122, 817.489721, 906.265783,
- 999.260562, 1094.489206, 1197.062998, 1293.296825, 1378.926484,
- 1472.760990, 1552.663779, 1635.196884, 1692.451951, 1759.741063,
- 1822.162720, 1916.515921, 1966.686071, 2031.647506, 2031.381029,
- 2067.971335, 2203.662704, 2500.257936, 3019.559830, 3823.371186,
- 4973.494802, 6531.733478, 8559.890013, 11119.767206, 14273.167855,
- 18081.894761, 22607.750723, 27912.538538, 34058.061008, 41106.120930,
+ 2047.431137, 2138.486068, 2154.743767, 2209.242472, 2277.593051,
+ 2290.996432, 2307.452938, 2343.567091, 2397.654644, 2469.425868,
+ 2558.591037, 2664.860422, 2787.944296, 2927.552932, 3083.396602,
+ 3255.185579, 3442.630134, 3645.440541, 3863.327072, 4096.000000,
},
{
0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 8.998436, 9.439592, 9.731837, 10.865931,
+ 11.561347, 12.578139, 14.205101, 16.770584, 19.094853,
+ 21.330863, 23.298907, 26.901921, 34.501017, 57.891733,
+ 112.234763, 194.853189, 288.302032, 380.499422, 472.625309,
+ 560.226809, 647.928463, 734.155122, 817.489721, 906.265783,
+ 999.260562, 1094.489206, 1197.062998, 1293.296825, 1378.926484,
+ 1472.760990, 1552.663779, 1635.196884, 1692.451951, 1759.741063,
+ 1822.162720, 1916.515921, 1966.686071, 2031.647506, 2033.700134,
+ 2087.847688, 2161.688858, 2242.536028, 2334.023491, 2436.337802,
+ 2549.665519, 2674.193198, 2810.107395, 2957.594666, 3116.841567,
+ 3288.034655, 3471.360486, 3667.005616, 3875.156602, 4096.000000,
+ },
+ {
0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 2.377584, 2.557185, 2.732445, 2.851114,
3.281800, 3.765589, 4.342578, 5.145582, 5.611038,
6.642238, 7.945977, 11.800522, 17.346624, 37.501413,
87.216800, 165.860942, 253.865564, 332.039345, 408.518863,
478.120452, 547.268590, 616.067676, 680.022540, 753.863541,
834.529973, 919.489191, 1008.264989, 1092.230318, 1173.971886,
1249.514122, 1330.510941, 1399.523249, 1466.923387, 1530.533471,
- 1586.515722, 1695.197774, 1746.648696, 1837.136959, 1909.056910,
- 1974.948082, 2063.374132, 2178.496387, 2324.476176, 2505.474827,
- 2725.653666, 2989.174023, 3300.197225, 3662.884600, 4081.397476,
- 4559.897180, 5102.545042, 5713.502387, 6396.930546, 7156.990844,
+ 1586.515722, 1695.197774, 1746.648696, 1837.136959, 1909.075485,
+ 1975.074651, 2060.159200, 2155.335095, 2259.762505, 2373.710437,
+ 2497.447898, 2631.243895, 2775.367434, 2930.087523, 3095.673170,
+ 3272.393380, 3460.517161, 3660.313520, 3872.051464, 4096.000000,
},
{
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
- 0.614483, 0.842937, 1.050824, 1.326663, 1.717750,
- 2.530591, 3.582302, 6.995373, 9.973335, 24.042464,
- 56.598240, 113.680735, 180.018689, 231.050567, 266.101082,
- 294.957934, 323.326511, 349.434429, 380.443211, 408.171987,
- 441.214916, 475.716772, 512.900000, 551.186939, 592.364455,
- 624.527378, 661.940693, 679.185473, 724.800679, 764.781792,
- 873.050019, 950.299001, 939.292954, 1052.406153, 1030.816617,
- 1086.316710, 1275.467594, 1671.923018, 2349.336727, 3381.362469,
- 4841.653990, 6803.865037, 9341.649358, 12528.660698, 16438.552805,
- 21144.979426, 26721.594308, 33242.051197, 40780.003840, 49409.105984,
+ 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
+ 0.000000, 0.296997, 0.342545, 0.403097, 0.472889,
+ 0.614483, 0.842937, 1.050824, 1.326663, 1.717750,
+ 2.530591, 3.582302, 6.995373, 9.973335, 24.042464,
+ 56.598240, 113.680735, 180.018689, 231.050567, 266.101082,
+ 294.957934, 323.326511, 349.434429, 380.443211, 408.171987,
+ 441.214916, 475.716772, 512.900000, 551.186939, 592.364455,
+ 624.527378, 661.940693, 679.185473, 724.800679, 764.781792,
+ 873.050019, 950.299001, 939.292954, 1052.406153, 1033.893184,
+ 1112.182406, 1219.174326, 1337.296681, 1471.648357, 1622.492809,
+ 1790.093491, 1974.713858, 2176.617364, 2396.067465, 2633.327614,
+ 2888.661266, 3162.331876, 3454.602899, 3765.737789, 4096.000000,
},
};
@@ -1383,13 +1384,14 @@
0.000000, 0.000000,
};
-void av1_model_rd_curvfit(BLOCK_SIZE bsize, double xqr, double *rate_f,
- double *distbysse_f) {
+void av1_model_rd_curvfit(BLOCK_SIZE bsize, double sse_norm, double xqr,
+ double *rate_f, double *distbysse_f) {
+ (void)sse_norm;
const double x_start = -15.5;
const double x_end = 16.5;
const double x_step = 0.5;
const double epsilon = 1e-6;
- const int cat = bsize_model_cat_lookup[bsize];
+ const int rcat = bsize_model_cat_lookup[bsize];
(void)x_end;
xqr = AOMMAX(xqr, x_start + x_step + epsilon);
@@ -1400,9 +1402,9 @@
assert(xi > 0);
- const double *prate = &interp_rgrid_curv[cat][(xi - 1)];
- const double *pdist = &interp_dgrid_curv[(xi - 1)];
+ const double *prate = &interp_rgrid_curv[rcat][(xi - 1)];
*rate_f = interp_cubic(prate, xo);
+ const double *pdist = &interp_dgrid_curv[(xi - 1)];
*distbysse_f = interp_cubic(pdist, xo);
}
diff --git a/av1/encoder/rd.h b/av1/encoder/rd.h
index e29a1d5..350eeb6 100644
--- a/av1/encoder/rd.h
+++ b/av1/encoder/rd.h
@@ -656,10 +656,10 @@
void av1_model_rd_from_var_lapndz(int64_t var, unsigned int n,
unsigned int qstep, int *rate, int64_t *dist);
-void av1_model_rd_curvfit(BLOCK_SIZE bsize, double xqr, double *rate_f,
- double *distbysse_f);
-void av1_model_rd_surffit(BLOCK_SIZE bsize, double xm, double yl,
+void av1_model_rd_curvfit(BLOCK_SIZE bsize, double sse_norm, double xqr,
double *rate_f, double *distbysse_f);
+void av1_model_rd_surffit(BLOCK_SIZE bsize, double sse_norm, double xm,
+ double yl, double *rate_f, double *distbysse_f);
int av1_get_switchable_rate(const AV1_COMMON *const cm, MACROBLOCK *x,
const MACROBLOCKD *xd);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index f8a4eb6..eba1827 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2833,7 +2833,8 @@
const double yl = log(sse_norm / qstepsqr) / log(2.0);
double rate_f, dist_by_sse_norm_f;
- av1_model_rd_surffit(plane_bsize, xm, yl, &rate_f, &dist_by_sse_norm_f);
+ av1_model_rd_surffit(plane_bsize, sse_norm, xm, yl, &rate_f,
+ &dist_by_sse_norm_f);
const double dist_f = dist_by_sse_norm_f * sse_norm;
int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
@@ -2937,7 +2938,8 @@
const double xqr = log(sse_norm / qstepsqr) / log(2.0);
double rate_f, dist_by_sse_norm_f;
- av1_model_rd_curvfit(plane_bsize, xqr, &rate_f, &dist_by_sse_norm_f);
+ av1_model_rd_curvfit(plane_bsize, sse_norm, xqr, &rate_f,
+ &dist_by_sse_norm_f);
const double dist_f = dist_by_sse_norm_f * sse_norm;
int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);