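Misc fixes for RD stats collection

- Compute the block SSE with aom_sse()/aom_highbd_sse() on the src and
  dst buffers instead of aom_sum_squares_2d_i16() on src_diff.
- Compute the mean difference from the src and dst buffers with the new
  get_diff_mean()/get_highbd_diff_mean() helpers; the int16_t get_mean()
  helper moves under CONFIG_COLLECT_RD_STATS == 1.
- Print the mean, correlation and energy-distribution fields ahead of
  the CONFIG_COLLECT_INTER_MODE_RD_STATS block.
- Call pick_tx_size_type_yrd() instead of select_tx_type_yrd() in the
  CONFIG_COLLECT_RD_STATS == 3 paths.
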
Change-Id: I6fe5018075aced0c7dd824438edff598cd50f414
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index a88b0f4..71394e5 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2347,11 +2347,29 @@
blk_row, blk_col, plane_bsize, tx_bsize);
}
-static double get_mean(const int16_t *diff, int stride, int w, int h) {
+static double get_diff_mean(const uint8_t *src, int src_stride,
+ const uint8_t *dst, int dst_stride, int w, int h) {
double sum = 0.0;
for (int j = 0; j < h; ++j) {
for (int i = 0; i < w; ++i) {
- sum += diff[j * stride + i];
+ const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+ sum += diff;
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
+static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
+ const uint8_t *dst8, int dst_stride, int w,
+ int h) {
+ const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+ const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+ sum += diff;
}
}
assert(w > 0 && h > 0);
@@ -2440,6 +2458,17 @@
#if CONFIG_COLLECT_RD_STATS
#if CONFIG_COLLECT_RD_STATS == 1
+static double get_mean(const int16_t *diff, int stride, int w, int h) {
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ sum += diff[j * stride + i];
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x,
const RD_STATS *const rd_stats, int blk_row,
int blk_col, BLOCK_SIZE plane_bsize,
@@ -2588,7 +2617,14 @@
const int16_t *const src_diff = p->src_diff;
const int shift = (xd->bd - 8);
- int64_t sse = aom_sum_squares_2d_i16(src_diff, diff_stride, bw, bh);
+ int64_t sse;
+ if (is_cur_buf_hbd(xd)) {
+ sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+ bw, bh);
+ } else {
+ sse =
+ aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw, bh);
+ }
sse = ROUND_POWER_OF_TWO(sse, shift * 2);
const double sse_norm = (double)sse / num_samples;
@@ -2627,6 +2663,26 @@
fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
model_rdcost_norm);
+ double mean;
+ if (is_cur_buf_hbd(xd)) {
+ mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
+ pd->dst.stride, bw, bh);
+ } else {
+ mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+ bw, bh);
+ }
+ mean /= (1 << shift);
+ float hor_corr, vert_corr;
+ av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
+ &vert_corr);
+ fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
+
+ double hdist[4] = { 0 }, vdist[4] = { 0 };
+ get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
+ dst_stride, 1, hdist, vdist);
+ fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
+ hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
+
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
if (cpi->sf.inter_mode_rd_model_estimation == 1) {
assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
@@ -2644,19 +2700,6 @@
}
#endif
- double mean = get_mean(src_diff, diff_stride, bw, bh);
- mean /= (1 << shift);
- float hor_corr, vert_corr;
- av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
- &vert_corr);
- fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
-
- double hdist[4] = { 0 }, vdist[4] = { 0 };
- get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
- dst_stride, 1, hdist, vdist);
- fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
- hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
-
fprintf(fout, "\n");
fclose(fout);
}
@@ -2708,7 +2751,12 @@
get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
dst_stride, src_diff, diff_stride,
sse_norm_arr, NULL);
- double mean = get_mean(src_diff, bw, bw, bh);
+ double mean;
+ if (is_cur_buf_hbd(xd)) {
+ mean = get_highbd_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
+ } else {
+ mean = get_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
+ }
if (shift) {
for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
mean /= (1 << shift);
@@ -8160,7 +8208,8 @@
AOM_PLANE_Y, AOM_PLANE_Y);
#if CONFIG_COLLECT_RD_STATS == 3
RD_STATS rd_stats_y;
- select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
+ pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
+ INT64_MAX);
PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
#endif // CONFIG_COLLECT_RD_STATS == 3
model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
@@ -8515,7 +8564,7 @@
#if CONFIG_COLLECT_RD_STATS == 3
RD_STATS rd_stats_y;
- select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
+ pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
#endif // CONFIG_COLLECT_RD_STATS == 3
model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
@@ -8722,7 +8771,7 @@
* This function combines y and uv planes' transform search processes
* together, when the prediction is generated. It first does subtraction to
* obtain the prediction error. Then it calls
- * select_tx_type_yrd/super_block_yrd and super_block_uvrd sequentially and
+ * pick_tx_size_type_yrd/super_block_yrd and super_block_uvrd sequentially and
* handles the early terminations happening in those functions. At the end, it
* computes the rd_stats/_y/_uv accordingly.
*/