Enable high bit-depth support for tune=vmaf_without_preprocessing
BD-rate gains
baseline: tune=psnr
PSNR SSIM VMAF
Lowres_bd10 3.39 3.34 -5.06
Midres_bd10 2.94 3.19 -4.24
Change-Id: Idda56b4bda462ea0b365dc47c6e37d1e50baf1aa
diff --git a/aom_dsp/vmaf.c b/aom_dsp/vmaf.c
index 801fee0..3a012e7 100644
--- a/aom_dsp/vmaf.c
+++ b/aom_dsp/vmaf.c
@@ -15,6 +15,7 @@
#include <stdlib.h>
#include <string.h>
+#include "aom_dsp/blend.h"
#include "aom_dsp/vmaf.h"
#include "aom_ports/system_state.h"
@@ -22,6 +23,7 @@
const YV12_BUFFER_CONFIG *source;
const YV12_BUFFER_CONFIG *distorted;
int frame_set;
+ int bit_depth;
} FrameData;
static void vmaf_fatal_error(const char *message) {
@@ -32,8 +34,8 @@
// A callback function used to pass data to VMAF.
// Returns 0 after reading a frame.
// Returns 2 when there is no more frame to read.
-static int read_frame_8bd(float *ref_data, float *main_data, float *temp_data,
- int stride, void *user_data) {
+static int read_frame(float *ref_data, float *main_data, float *temp_data,
+ int stride, void *user_data) {
FrameData *frames = (FrameData *)user_data;
if (!frames->frame_set) {
@@ -41,23 +43,46 @@
const int height = frames->source->y_height;
assert(width == frames->distorted->y_width);
assert(height == frames->distorted->y_height);
- uint8_t *ref_ptr = frames->source->y_buffer;
- uint8_t *main_ptr = frames->distorted->y_buffer;
- for (int row = 0; row < height; ++row) {
- for (int col = 0; col < width; ++col) {
- ref_data[col] = (float)ref_ptr[col];
- }
- ref_ptr += frames->source->y_stride;
- ref_data += stride / sizeof(*ref_data);
- }
+ if (frames->bit_depth > 8) {
+ const float scale_factor = 1.0f / (float)(1 << (frames->bit_depth - 8));
+ uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(frames->source->y_buffer);
+ uint16_t *main_ptr = CONVERT_TO_SHORTPTR(frames->distorted->y_buffer);
- for (int row = 0; row < height; ++row) {
- for (int col = 0; col < width; ++col) {
- main_data[col] = (float)main_ptr[col];
+ for (int row = 0; row < height; ++row) {
+ for (int col = 0; col < width; ++col) {
+ ref_data[col] = scale_factor * (float)ref_ptr[col];
+ }
+ ref_ptr += frames->source->y_stride;
+ ref_data += stride / sizeof(*ref_data);
}
- main_ptr += frames->distorted->y_stride;
- main_data += stride / sizeof(*main_data);
+
+ for (int row = 0; row < height; ++row) {
+ for (int col = 0; col < width; ++col) {
+ main_data[col] = scale_factor * (float)main_ptr[col];
+ }
+ main_ptr += frames->distorted->y_stride;
+ main_data += stride / sizeof(*main_data);
+ }
+ } else {
+ uint8_t *ref_ptr = frames->source->y_buffer;
+ uint8_t *main_ptr = frames->distorted->y_buffer;
+
+ for (int row = 0; row < height; ++row) {
+ for (int col = 0; col < width; ++col) {
+ ref_data[col] = (float)ref_ptr[col];
+ }
+ ref_ptr += frames->source->y_stride;
+ ref_data += stride / sizeof(*ref_data);
+ }
+
+ for (int row = 0; row < height; ++row) {
+ for (int col = 0; col < width; ++col) {
+ main_data[col] = (float)main_ptr[col];
+ }
+ main_ptr += frames->distorted->y_stride;
+ main_data += stride / sizeof(*main_data);
+ }
}
frames->frame_set = 1;
return 0;
@@ -68,39 +93,39 @@
}
void aom_calc_vmaf(const char *model_path, const YV12_BUFFER_CONFIG *source,
- const YV12_BUFFER_CONFIG *distorted, double *vmaf) {
+ const YV12_BUFFER_CONFIG *distorted, const int bit_depth,
+ double *const vmaf) {
aom_clear_system_state();
const int width = source->y_width;
const int height = source->y_height;
- FrameData frames = { source, distorted, 0 };
+ FrameData frames = { source, distorted, 0, bit_depth };
+ char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
double vmaf_score;
- int (*read_frame)(float *reference_data, float *distorted_data,
- float *temp_data, int stride, void *s);
- read_frame = read_frame_8bd;
const int ret =
- compute_vmaf(&vmaf_score, (char *)"yuv420p", width, height, read_frame,
+ compute_vmaf(&vmaf_score, fmt, width, height, read_frame,
/*user_data=*/&frames, (char *)model_path,
/*log_path=*/NULL, /*log_fmt=*/NULL, /*disable_clip=*/1,
/*disable_avx=*/0, /*enable_transform=*/0,
/*phone_model=*/0, /*do_psnr=*/0, /*do_ssim=*/0,
/*do_ms_ssim=*/0, /*pool_method=*/NULL, /*n_thread=*/0,
/*n_subsample=*/1, /*enable_conf_interval=*/0);
+ if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");
aom_clear_system_state();
*vmaf = vmaf_score;
- if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");
}
void aom_calc_vmaf_multi_frame(
void *user_data, const char *model_path,
int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
int stride_byte, void *user_data),
- int frame_width, int frame_height, double *vmaf) {
+ int frame_width, int frame_height, int bit_depth, double *vmaf) {
aom_clear_system_state();
+ char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
double vmaf_score;
const int ret = compute_vmaf(
- &vmaf_score, (char *)"yuv420p", frame_width, frame_height, read_frame,
+ &vmaf_score, fmt, frame_width, frame_height, read_frame,
/*user_data=*/user_data, (char *)model_path,
/*log_path=*/"vmaf_scores.xml", /*log_fmt=*/NULL, /*disable_clip=*/0,
/*disable_avx=*/0, /*enable_transform=*/0,
diff --git a/aom_dsp/vmaf.h b/aom_dsp/vmaf.h
index 186edeb..fb8bf46 100644
--- a/aom_dsp/vmaf.h
+++ b/aom_dsp/vmaf.h
@@ -15,12 +15,13 @@
#include "aom_scale/yv12config.h"
void aom_calc_vmaf(const char *model_path, const YV12_BUFFER_CONFIG *source,
- const YV12_BUFFER_CONFIG *distorted, double *vmaf);
+ const YV12_BUFFER_CONFIG *distorted, int bit_depth,
+ double *vmaf);
void aom_calc_vmaf_multi_frame(
void *user_data, const char *model_path,
int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
int stride_byte, void *user_data),
- int frame_width, int frame_height, double *vmaf);
+ int frame_width, int frame_height, int bit_depth, double *vmaf);
#endif // AOM_AOM_DSP_VMAF_H_
diff --git a/av1/encoder/tune_vmaf.c b/av1/encoder/tune_vmaf.c
index 2f64606..24f03d4 100644
--- a/av1/encoder/tune_vmaf.c
+++ b/av1/encoder/tune_vmaf.c
@@ -11,6 +11,7 @@
#include "av1/encoder/tune_vmaf.h"
+#include "aom_dsp/psnr.h"
#include "aom_dsp/vmaf.h"
#include "aom_ports/system_state.h"
#include "av1/encoder/extend.h"
@@ -33,24 +34,6 @@
}
}
-static AOM_INLINE void unsharp_rect_float(const float *source,
- int source_stride,
- const uint8_t *blurred,
- int blurred_stride, float *dst,
- int dst_stride, int w, int h,
- float amount) {
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; ++j) {
- dst[j] = source[j] + amount * (source[j] - (float)blurred[j]);
- if (dst[j] < 0.0f) dst[j] = 0.0f;
- if (dst[j] > 255.0) dst[j] = 255.0f;
- }
- source += source_stride;
- blurred += blurred_stride;
- dst += dst_stride;
- }
-}
-
static AOM_INLINE void unsharp(const YV12_BUFFER_CONFIG *source,
const YV12_BUFFER_CONFIG *blurred,
const YV12_BUFFER_CONFIG *dst, double amount) {
@@ -67,20 +50,15 @@
const YV12_BUFFER_CONFIG *source,
const YV12_BUFFER_CONFIG *dst) {
const AV1_COMMON *cm = &cpi->common;
- const ThreadData *td = &cpi->td;
- const MACROBLOCK *x = &td->mb;
- const MACROBLOCKD *xd = &x->e_mbd;
-
+ const int bit_depth = cpi->td.mb.e_mbd.bd;
const int block_size = BLOCK_128X128;
-
const int num_mi_w = mi_size_wide[block_size];
const int num_mi_h = mi_size_high[block_size];
const int num_cols = (cm->mi_cols + num_mi_w - 1) / num_mi_w;
const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
int row, col;
- const int use_hbd = source->flags & YV12_FLAG_HIGHBITDEPTH;
- ConvolveParams conv_params = get_conv_params(0, 0, xd->bd);
+ ConvolveParams conv_params = get_conv_params(0, 0, bit_depth);
InterpFilterParams filter = { .filter_ptr = gauss_filter,
.taps = 8,
.subpel_shifts = 0,
@@ -99,11 +77,11 @@
uint8_t *dst_buf =
dst->y_buffer + row_offset_y * dst->y_stride + col_offset_y;
- if (use_hbd) {
+ if (bit_depth > 8) {
av1_highbd_convolve_2d_sr(
CONVERT_TO_SHORTPTR(src_buf), source->y_stride,
CONVERT_TO_SHORTPTR(dst_buf), dst->y_stride, num_mi_w << 2,
- num_mi_h << 2, &filter, &filter, 0, 0, &conv_params, xd->bd);
+ num_mi_h << 2, &filter, &filter, 0, 0, &conv_params, bit_depth);
} else {
av1_convolve_2d_sr(src_buf, source->y_stride, dst_buf, dst->y_stride,
num_mi_w << 2, num_mi_h << 2, &filter, &filter, 0, 0,
@@ -173,7 +151,7 @@
best_vmaf = approx_vmaf;
unsharp_amount += step_size;
unsharp(source, blurred, &sharpened, unsharp_amount);
- aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, &sharpened, &new_vmaf);
+ aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, &sharpened, 8, &new_vmaf);
const double sharpened_var = frame_average_variance(cpi, &sharpened);
approx_vmaf =
baseline_variance / sharpened_var * (new_vmaf - baseline_vmaf);
@@ -340,28 +318,9 @@
aom_clear_system_state();
}
-// TODO(sdeng): replace it with the SIMD version.
-static AOM_INLINE double image_sse_c(const uint8_t *src, int src_stride,
- const uint8_t *ref, int ref_stride, int w,
- int h) {
- double accum = 0.0;
- int i, j;
-
- for (i = 0; i < h; ++i) {
- for (j = 0; j < w; ++j) {
- double img1px = src[i * src_stride + j];
- double img2px = ref[i * ref_stride + j];
-
- accum += (img1px - img2px) * (img1px - img2px);
- }
- }
-
- return accum;
-}
-
typedef struct FrameData {
const YV12_BUFFER_CONFIG *source, *blurred;
- int block_w, block_h, num_rows, num_cols, row, col;
+ int block_w, block_h, num_rows, num_cols, row, col, bit_depth;
} FrameData;
// A callback function used to pass data to VMAF.
@@ -380,17 +339,27 @@
const int block_h = frames->block_h;
const YV12_BUFFER_CONFIG *source = frames->source;
const YV12_BUFFER_CONFIG *blurred = frames->blurred;
+ const int bit_depth = frames->bit_depth;
+ const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
(void)temp_data;
stride /= (int)sizeof(*ref_data);
for (int i = 0; i < height; ++i) {
float *ref, *main;
- uint8_t *src;
ref = ref_data + i * stride;
main = main_data + i * stride;
- src = source->y_buffer + i * source->y_stride;
- for (int j = 0; j < width; ++j) {
- ref[j] = main[j] = (float)src[j];
+ if (bit_depth == 8) {
+ uint8_t *src;
+ src = source->y_buffer + i * source->y_stride;
+ for (int j = 0; j < width; ++j) {
+ ref[j] = main[j] = (float)src[j];
+ }
+ } else {
+ uint16_t *src;
+ src = CONVERT_TO_SHORTPTR(source->y_buffer) + i * source->y_stride;
+ for (int j = 0; j < width; ++j) {
+ ref[j] = main[j] = scale_factor * (float)src[j];
+ }
}
}
if (row < 0 && col < 0) {
@@ -405,12 +374,27 @@
const int block_height = AOMMIN(height - row_offset, block_h);
float *main_buf = main_data + col_offset + row_offset * stride;
- float *ref_buf = ref_data + col_offset + row_offset * stride;
- uint8_t *blurred_buf =
- blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
-
- unsharp_rect_float(ref_buf, stride, blurred_buf, blurred->y_stride,
- main_buf, stride, block_width, block_height, -1.0f);
+ if (bit_depth == 8) {
+ uint8_t *blurred_buf =
+ blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
+ for (int i = 0; i < block_height; ++i) {
+ for (int j = 0; j < block_width; ++j) {
+ main_buf[j] = (float)blurred_buf[j];
+ }
+ main_buf += stride;
+ blurred_buf += blurred->y_stride;
+ }
+ } else {
+ uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred->y_buffer) +
+ row_offset * blurred->y_stride + col_offset;
+ for (int i = 0; i < block_height; ++i) {
+ for (int j = 0; j < block_width; ++j) {
+ main_buf[j] = scale_factor * (float)blurred_buf[j];
+ }
+ main_buf += stride;
+ blurred_buf += blurred->y_stride;
+ }
+ }
frames->col++;
if (frames->col >= num_cols) {
@@ -425,9 +409,6 @@
void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
AV1_COMMON *cm = &cpi->common;
- ThreadData *td = &cpi->td;
- MACROBLOCK *x = &td->mb;
- MACROBLOCKD *xd = &x->e_mbd;
uint8_t *const y_buffer = cpi->source->y_buffer;
const int y_stride = cpi->source->y_stride;
const int y_width = cpi->source->y_width;
@@ -440,12 +421,7 @@
const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
const int block_w = num_mi_w << 2;
const int block_h = num_mi_h << 2;
- const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
- // TODO(sdeng): Add high bit depth support.
- if (use_hbd) {
- printf("VMAF RDO for high bit depth videos is unsupported yet.\n");
- exit(0);
- }
+ const int bit_depth = cpi->td.mb.e_mbd.bd;
aom_clear_system_state();
YV12_BUFFER_CONFIG blurred;
@@ -466,9 +442,10 @@
frame_data.num_cols = num_cols;
frame_data.row = -1;
frame_data.col = -1;
+ frame_data.bit_depth = bit_depth;
aom_calc_vmaf_multi_frame(&frame_data, cpi->oxcf.vmaf_model_path,
- update_frame, y_width, y_height, scores);
- const double baseline_mse = 0.0, baseline_vmaf = scores[0];
+ update_frame, y_width, y_height, bit_depth, scores);
+ const double baseline_vmaf = scores[0];
// Loop through each 'block_size' block.
for (int row = 0; row < num_rows; ++row) {
@@ -484,24 +461,20 @@
uint8_t *const blurred_buf =
blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
- const int block_width = AOMMIN(y_width - col_offset_y, block_w);
- const int block_height = AOMMIN(y_height - row_offset_y, block_h);
-
const double vmaf = scores[index + 1];
const double dvmaf = baseline_vmaf - vmaf;
+ unsigned int sse;
+ cpi->fn_ptr[block_size].vf(orig_buf, y_stride, blurred_buf,
+ blurred.y_stride, &sse);
- const double mse =
- image_sse_c(orig_buf, y_stride, blurred_buf, blurred.y_stride,
- block_width, block_height) /
- (double)(y_width * y_height);
- const double dmse = mse - baseline_mse;
+ const double mse = (double)sse / (double)(y_width * y_height);
- double weight = 0.0;
+ double weight;
const double eps = 0.01 / (num_rows * num_cols);
- if (dvmaf < eps || dmse < eps) {
+ if (dvmaf < eps || mse < eps) {
weight = 1.0;
} else {
- weight = dmse / dvmaf;
+ weight = mse / dvmaf;
}
// Normalize it with a data fitted model.
@@ -513,7 +486,6 @@
aom_free_frame_buffer(&blurred);
aom_free(scores);
aom_clear_system_state();
- (void)xd;
}
void av1_set_vmaf_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
@@ -550,7 +522,25 @@
aom_clear_system_state();
}
-// TODO(sdeng): replace it with the SIMD version.
+// TODO(sdeng): replace them with the SIMD versions.
+static AOM_INLINE double highbd_image_sad_c(const uint16_t *src, int src_stride,
+ const uint16_t *ref, int ref_stride,
+ int w, int h) {
+ double accum = 0.0;
+ int i, j;
+
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
+ double img1px = src[i * src_stride + j];
+ double img2px = ref[i * ref_stride + j];
+
+ accum += fabs(img1px - img2px);
+ }
+ }
+
+ return accum / (double)(h * w);
+}
+
static AOM_INLINE double image_sad_c(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride, int w,
int h) {
@@ -576,6 +566,7 @@
const int y_width = cur->y_width;
const int y_height = cur->y_height;
YV12_BUFFER_CONFIG blurred_cur, blurred_last, blurred_next;
+ const int bit_depth = cpi->td.mb.e_mbd.bd;
memset(&blurred_cur, 0, sizeof(blurred_cur));
memset(&blurred_last, 0, sizeof(blurred_last));
@@ -596,13 +587,30 @@
if (next) gaussian_blur(cpi, next, &blurred_next);
double motion1, motion2 = 65536.0;
- motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
- blurred_last.y_buffer, blurred_last.y_stride, y_width,
- y_height);
- if (next) {
- motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
- blurred_next.y_buffer, blurred_next.y_stride, y_width,
+
+ if (bit_depth > 8) {
+ const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
+ motion1 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
+ blurred_cur.y_stride,
+ CONVERT_TO_SHORTPTR(blurred_last.y_buffer),
+ blurred_last.y_stride, y_width, y_height) *
+ scale_factor;
+ if (next) {
+ motion2 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
+ blurred_cur.y_stride,
+ CONVERT_TO_SHORTPTR(blurred_next.y_buffer),
+ blurred_next.y_stride, y_width, y_height) *
+ scale_factor;
+ }
+ } else {
+ motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
+ blurred_last.y_buffer, blurred_last.y_stride, y_width,
y_height);
+ if (next) {
+ motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
+ blurred_next.y_buffer, blurred_next.y_stride,
+ y_width, y_height);
+ }
}
aom_free_frame_buffer(&blurred_cur);
@@ -616,16 +624,14 @@
// observation: when the motion score becomes higher, the VMAF score of the
// same source and distorted frames would become higher.
int av1_get_vmaf_base_qindex(const AV1_COMP *const cpi, int current_qindex) {
- const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
- if (use_hbd) {
- printf("Tune for VMAF for high bit depth videos is unsupported yet.\n");
- exit(0);
- }
const AV1_COMMON *const cm = &cpi->common;
if (cm->current_frame.frame_number == 0 || cpi->oxcf.pass == 1) {
return current_qindex;
}
- const double approx_sse = cpi->last_frame_ysse;
+ const int bit_depth = cpi->td.mb.e_mbd.bd;
+ const double approx_sse =
+ cpi->last_frame_ysse /
+ (double)((1 << (bit_depth - 8)) * (1 << (bit_depth - 8)));
const double approx_dvmaf = cpi->last_frame_bvmaf - cpi->last_frame_vmaf;
const double sse_threshold =
0.01 * cpi->source->y_width * cpi->source->y_height;
@@ -643,7 +649,6 @@
av1_lookahead_peek(cpi->lookahead, src_index, cpi->compressor_stage);
cur_buf = &cur_entry->img;
}
-
assert(cur_buf);
const struct lookahead_entry *last_entry =
@@ -676,11 +681,14 @@
void av1_update_vmaf_curve(AV1_COMP *cpi, YV12_BUFFER_CONFIG *source,
YV12_BUFFER_CONFIG *recon) {
- aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, source,
+ const int bit_depth = cpi->td.mb.e_mbd.bd;
+ aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, source, bit_depth,
&cpi->last_frame_bvmaf);
- aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, recon,
+ aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, recon, bit_depth,
&cpi->last_frame_vmaf);
- cpi->last_frame_ysse =
- image_sse_c(source->y_buffer, source->y_stride, recon->y_buffer,
- recon->y_stride, source->y_width, source->y_height);
+ if (bit_depth > 8) {
+ cpi->last_frame_ysse = (double)aom_highbd_get_y_sse(source, recon);
+ } else {
+ cpi->last_frame_ysse = (double)aom_get_y_sse(source, recon);
+ }
}