Enable high bit-depth support for tune=vmaf_without_preprocessing

BD-rate gains
basline: tune=psnr
               PSNR      SSIM      VMAF
Lowres_bd10    3.39      3.34     -5.06
Midres_bd10    2.94      3.19     -4.24

Change-Id: Idda56b4bda462ea0b365dc47c6e37d1e50baf1aa
diff --git a/aom_dsp/vmaf.c b/aom_dsp/vmaf.c
index 801fee0..3a012e7 100644
--- a/aom_dsp/vmaf.c
+++ b/aom_dsp/vmaf.c
@@ -15,6 +15,7 @@
 #include <stdlib.h>
 #include <string.h>
 
+#include "aom_dsp/blend.h"
 #include "aom_dsp/vmaf.h"
 #include "aom_ports/system_state.h"
 
@@ -22,6 +23,7 @@
   const YV12_BUFFER_CONFIG *source;
   const YV12_BUFFER_CONFIG *distorted;
   int frame_set;
+  int bit_depth;
 } FrameData;
 
 static void vmaf_fatal_error(const char *message) {
@@ -32,8 +34,8 @@
 // A callback function used to pass data to VMAF.
 // Returns 0 after reading a frame.
 // Returns 2 when there is no more frame to read.
-static int read_frame_8bd(float *ref_data, float *main_data, float *temp_data,
-                          int stride, void *user_data) {
+static int read_frame(float *ref_data, float *main_data, float *temp_data,
+                      int stride, void *user_data) {
   FrameData *frames = (FrameData *)user_data;
 
   if (!frames->frame_set) {
@@ -41,23 +43,46 @@
     const int height = frames->source->y_height;
     assert(width == frames->distorted->y_width);
     assert(height == frames->distorted->y_height);
-    uint8_t *ref_ptr = frames->source->y_buffer;
-    uint8_t *main_ptr = frames->distorted->y_buffer;
 
-    for (int row = 0; row < height; ++row) {
-      for (int col = 0; col < width; ++col) {
-        ref_data[col] = (float)ref_ptr[col];
-      }
-      ref_ptr += frames->source->y_stride;
-      ref_data += stride / sizeof(*ref_data);
-    }
+    if (frames->bit_depth > 8) {
+      const float scale_factor = 1.0f / (float)(1 << (frames->bit_depth - 8));
+      uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(frames->source->y_buffer);
+      uint16_t *main_ptr = CONVERT_TO_SHORTPTR(frames->distorted->y_buffer);
 
-    for (int row = 0; row < height; ++row) {
-      for (int col = 0; col < width; ++col) {
-        main_data[col] = (float)main_ptr[col];
+      for (int row = 0; row < height; ++row) {
+        for (int col = 0; col < width; ++col) {
+          ref_data[col] = scale_factor * (float)ref_ptr[col];
+        }
+        ref_ptr += frames->source->y_stride;
+        ref_data += stride / sizeof(*ref_data);
       }
-      main_ptr += frames->distorted->y_stride;
-      main_data += stride / sizeof(*main_data);
+
+      for (int row = 0; row < height; ++row) {
+        for (int col = 0; col < width; ++col) {
+          main_data[col] = scale_factor * (float)main_ptr[col];
+        }
+        main_ptr += frames->distorted->y_stride;
+        main_data += stride / sizeof(*main_data);
+      }
+    } else {
+      uint8_t *ref_ptr = frames->source->y_buffer;
+      uint8_t *main_ptr = frames->distorted->y_buffer;
+
+      for (int row = 0; row < height; ++row) {
+        for (int col = 0; col < width; ++col) {
+          ref_data[col] = (float)ref_ptr[col];
+        }
+        ref_ptr += frames->source->y_stride;
+        ref_data += stride / sizeof(*ref_data);
+      }
+
+      for (int row = 0; row < height; ++row) {
+        for (int col = 0; col < width; ++col) {
+          main_data[col] = (float)main_ptr[col];
+        }
+        main_ptr += frames->distorted->y_stride;
+        main_data += stride / sizeof(*main_data);
+      }
     }
     frames->frame_set = 1;
     return 0;
@@ -68,39 +93,39 @@
 }
 
 void aom_calc_vmaf(const char *model_path, const YV12_BUFFER_CONFIG *source,
-                   const YV12_BUFFER_CONFIG *distorted, double *vmaf) {
+                   const YV12_BUFFER_CONFIG *distorted, const int bit_depth,
+                   double *const vmaf) {
   aom_clear_system_state();
   const int width = source->y_width;
   const int height = source->y_height;
-  FrameData frames = { source, distorted, 0 };
+  FrameData frames = { source, distorted, 0, bit_depth };
+  char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
   double vmaf_score;
-  int (*read_frame)(float *reference_data, float *distorted_data,
-                    float *temp_data, int stride, void *s);
-  read_frame = read_frame_8bd;
   const int ret =
-      compute_vmaf(&vmaf_score, (char *)"yuv420p", width, height, read_frame,
+      compute_vmaf(&vmaf_score, fmt, width, height, read_frame,
                    /*user_data=*/&frames, (char *)model_path,
                    /*log_path=*/NULL, /*log_fmt=*/NULL, /*disable_clip=*/1,
                    /*disable_avx=*/0, /*enable_transform=*/0,
                    /*phone_model=*/0, /*do_psnr=*/0, /*do_ssim=*/0,
                    /*do_ms_ssim=*/0, /*pool_method=*/NULL, /*n_thread=*/0,
                    /*n_subsample=*/1, /*enable_conf_interval=*/0);
+  if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");
 
   aom_clear_system_state();
   *vmaf = vmaf_score;
-  if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");
 }
 
 void aom_calc_vmaf_multi_frame(
     void *user_data, const char *model_path,
     int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
                       int stride_byte, void *user_data),
-    int frame_width, int frame_height, double *vmaf) {
+    int frame_width, int frame_height, int bit_depth, double *vmaf) {
   aom_clear_system_state();
 
+  char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
   double vmaf_score;
   const int ret = compute_vmaf(
-      &vmaf_score, (char *)"yuv420p", frame_width, frame_height, read_frame,
+      &vmaf_score, fmt, frame_width, frame_height, read_frame,
       /*user_data=*/user_data, (char *)model_path,
       /*log_path=*/"vmaf_scores.xml", /*log_fmt=*/NULL, /*disable_clip=*/0,
       /*disable_avx=*/0, /*enable_transform=*/0,
diff --git a/aom_dsp/vmaf.h b/aom_dsp/vmaf.h
index 186edeb..fb8bf46 100644
--- a/aom_dsp/vmaf.h
+++ b/aom_dsp/vmaf.h
@@ -15,12 +15,13 @@
 #include "aom_scale/yv12config.h"
 
 void aom_calc_vmaf(const char *model_path, const YV12_BUFFER_CONFIG *source,
-                   const YV12_BUFFER_CONFIG *distorted, double *vmaf);
+                   const YV12_BUFFER_CONFIG *distorted, int bit_depth,
+                   double *vmaf);
 
 void aom_calc_vmaf_multi_frame(
     void *user_data, const char *model_path,
     int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
                       int stride_byte, void *user_data),
-    int frame_width, int frame_height, double *vmaf);
+    int frame_width, int frame_height, int bit_depth, double *vmaf);
 
 #endif  // AOM_AOM_DSP_VMAF_H_
diff --git a/av1/encoder/tune_vmaf.c b/av1/encoder/tune_vmaf.c
index 2f64606..24f03d4 100644
--- a/av1/encoder/tune_vmaf.c
+++ b/av1/encoder/tune_vmaf.c
@@ -11,6 +11,7 @@
 
 #include "av1/encoder/tune_vmaf.h"
 
+#include "aom_dsp/psnr.h"
 #include "aom_dsp/vmaf.h"
 #include "aom_ports/system_state.h"
 #include "av1/encoder/extend.h"
@@ -33,24 +34,6 @@
   }
 }
 
-static AOM_INLINE void unsharp_rect_float(const float *source,
-                                          int source_stride,
-                                          const uint8_t *blurred,
-                                          int blurred_stride, float *dst,
-                                          int dst_stride, int w, int h,
-                                          float amount) {
-  for (int i = 0; i < h; ++i) {
-    for (int j = 0; j < w; ++j) {
-      dst[j] = source[j] + amount * (source[j] - (float)blurred[j]);
-      if (dst[j] < 0.0f) dst[j] = 0.0f;
-      if (dst[j] > 255.0) dst[j] = 255.0f;
-    }
-    source += source_stride;
-    blurred += blurred_stride;
-    dst += dst_stride;
-  }
-}
-
 static AOM_INLINE void unsharp(const YV12_BUFFER_CONFIG *source,
                                const YV12_BUFFER_CONFIG *blurred,
                                const YV12_BUFFER_CONFIG *dst, double amount) {
@@ -67,20 +50,15 @@
                                      const YV12_BUFFER_CONFIG *source,
                                      const YV12_BUFFER_CONFIG *dst) {
   const AV1_COMMON *cm = &cpi->common;
-  const ThreadData *td = &cpi->td;
-  const MACROBLOCK *x = &td->mb;
-  const MACROBLOCKD *xd = &x->e_mbd;
-
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
   const int block_size = BLOCK_128X128;
-
   const int num_mi_w = mi_size_wide[block_size];
   const int num_mi_h = mi_size_high[block_size];
   const int num_cols = (cm->mi_cols + num_mi_w - 1) / num_mi_w;
   const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
   int row, col;
-  const int use_hbd = source->flags & YV12_FLAG_HIGHBITDEPTH;
 
-  ConvolveParams conv_params = get_conv_params(0, 0, xd->bd);
+  ConvolveParams conv_params = get_conv_params(0, 0, bit_depth);
   InterpFilterParams filter = { .filter_ptr = gauss_filter,
                                 .taps = 8,
                                 .subpel_shifts = 0,
@@ -99,11 +77,11 @@
       uint8_t *dst_buf =
           dst->y_buffer + row_offset_y * dst->y_stride + col_offset_y;
 
-      if (use_hbd) {
+      if (bit_depth > 8) {
         av1_highbd_convolve_2d_sr(
             CONVERT_TO_SHORTPTR(src_buf), source->y_stride,
             CONVERT_TO_SHORTPTR(dst_buf), dst->y_stride, num_mi_w << 2,
-            num_mi_h << 2, &filter, &filter, 0, 0, &conv_params, xd->bd);
+            num_mi_h << 2, &filter, &filter, 0, 0, &conv_params, bit_depth);
       } else {
         av1_convolve_2d_sr(src_buf, source->y_stride, dst_buf, dst->y_stride,
                            num_mi_w << 2, num_mi_h << 2, &filter, &filter, 0, 0,
@@ -173,7 +151,7 @@
     best_vmaf = approx_vmaf;
     unsharp_amount += step_size;
     unsharp(source, blurred, &sharpened, unsharp_amount);
-    aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, &sharpened, &new_vmaf);
+    aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, &sharpened, 8, &new_vmaf);
     const double sharpened_var = frame_average_variance(cpi, &sharpened);
     approx_vmaf =
         baseline_variance / sharpened_var * (new_vmaf - baseline_vmaf);
@@ -340,28 +318,9 @@
   aom_clear_system_state();
 }
 
-// TODO(sdeng): replace it with the SIMD version.
-static AOM_INLINE double image_sse_c(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride, int w,
-                                     int h) {
-  double accum = 0.0;
-  int i, j;
-
-  for (i = 0; i < h; ++i) {
-    for (j = 0; j < w; ++j) {
-      double img1px = src[i * src_stride + j];
-      double img2px = ref[i * ref_stride + j];
-
-      accum += (img1px - img2px) * (img1px - img2px);
-    }
-  }
-
-  return accum;
-}
-
 typedef struct FrameData {
   const YV12_BUFFER_CONFIG *source, *blurred;
-  int block_w, block_h, num_rows, num_cols, row, col;
+  int block_w, block_h, num_rows, num_cols, row, col, bit_depth;
 } FrameData;
 
 // A callback function used to pass data to VMAF.
@@ -380,17 +339,27 @@
   const int block_h = frames->block_h;
   const YV12_BUFFER_CONFIG *source = frames->source;
   const YV12_BUFFER_CONFIG *blurred = frames->blurred;
+  const int bit_depth = frames->bit_depth;
+  const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
   (void)temp_data;
   stride /= (int)sizeof(*ref_data);
 
   for (int i = 0; i < height; ++i) {
     float *ref, *main;
-    uint8_t *src;
     ref = ref_data + i * stride;
     main = main_data + i * stride;
-    src = source->y_buffer + i * source->y_stride;
-    for (int j = 0; j < width; ++j) {
-      ref[j] = main[j] = (float)src[j];
+    if (bit_depth == 8) {
+      uint8_t *src;
+      src = source->y_buffer + i * source->y_stride;
+      for (int j = 0; j < width; ++j) {
+        ref[j] = main[j] = (float)src[j];
+      }
+    } else {
+      uint16_t *src;
+      src = CONVERT_TO_SHORTPTR(source->y_buffer) + i * source->y_stride;
+      for (int j = 0; j < width; ++j) {
+        ref[j] = main[j] = scale_factor * (float)src[j];
+      }
     }
   }
   if (row < 0 && col < 0) {
@@ -405,12 +374,27 @@
     const int block_height = AOMMIN(height - row_offset, block_h);
 
     float *main_buf = main_data + col_offset + row_offset * stride;
-    float *ref_buf = ref_data + col_offset + row_offset * stride;
-    uint8_t *blurred_buf =
-        blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
-
-    unsharp_rect_float(ref_buf, stride, blurred_buf, blurred->y_stride,
-                       main_buf, stride, block_width, block_height, -1.0f);
+    if (bit_depth == 8) {
+      uint8_t *blurred_buf =
+          blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
+      for (int i = 0; i < block_height; ++i) {
+        for (int j = 0; j < block_width; ++j) {
+          main_buf[j] = (float)blurred_buf[j];
+        }
+        main_buf += stride;
+        blurred_buf += blurred->y_stride;
+      }
+    } else {
+      uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred->y_buffer) +
+                              row_offset * blurred->y_stride + col_offset;
+      for (int i = 0; i < block_height; ++i) {
+        for (int j = 0; j < block_width; ++j) {
+          main_buf[j] = scale_factor * (float)blurred_buf[j];
+        }
+        main_buf += stride;
+        blurred_buf += blurred->y_stride;
+      }
+    }
 
     frames->col++;
     if (frames->col >= num_cols) {
@@ -425,9 +409,6 @@
 
 void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
   AV1_COMMON *cm = &cpi->common;
-  ThreadData *td = &cpi->td;
-  MACROBLOCK *x = &td->mb;
-  MACROBLOCKD *xd = &x->e_mbd;
   uint8_t *const y_buffer = cpi->source->y_buffer;
   const int y_stride = cpi->source->y_stride;
   const int y_width = cpi->source->y_width;
@@ -440,12 +421,7 @@
   const int num_rows = (cm->mi_rows + num_mi_h - 1) / num_mi_h;
   const int block_w = num_mi_w << 2;
   const int block_h = num_mi_h << 2;
-  const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
-  // TODO(sdeng): Add high bit depth support.
-  if (use_hbd) {
-    printf("VMAF RDO for high bit depth videos is unsupported yet.\n");
-    exit(0);
-  }
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
 
   aom_clear_system_state();
   YV12_BUFFER_CONFIG blurred;
@@ -466,9 +442,10 @@
   frame_data.num_cols = num_cols;
   frame_data.row = -1;
   frame_data.col = -1;
+  frame_data.bit_depth = bit_depth;
   aom_calc_vmaf_multi_frame(&frame_data, cpi->oxcf.vmaf_model_path,
-                            update_frame, y_width, y_height, scores);
-  const double baseline_mse = 0.0, baseline_vmaf = scores[0];
+                            update_frame, y_width, y_height, bit_depth, scores);
+  const double baseline_vmaf = scores[0];
 
   // Loop through each 'block_size' block.
   for (int row = 0; row < num_rows; ++row) {
@@ -484,24 +461,20 @@
       uint8_t *const blurred_buf =
           blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
 
-      const int block_width = AOMMIN(y_width - col_offset_y, block_w);
-      const int block_height = AOMMIN(y_height - row_offset_y, block_h);
-
       const double vmaf = scores[index + 1];
       const double dvmaf = baseline_vmaf - vmaf;
+      unsigned int sse;
+      cpi->fn_ptr[block_size].vf(orig_buf, y_stride, blurred_buf,
+                                 blurred.y_stride, &sse);
 
-      const double mse =
-          image_sse_c(orig_buf, y_stride, blurred_buf, blurred.y_stride,
-                      block_width, block_height) /
-          (double)(y_width * y_height);
-      const double dmse = mse - baseline_mse;
+      const double mse = (double)sse / (double)(y_width * y_height);
 
-      double weight = 0.0;
+      double weight;
       const double eps = 0.01 / (num_rows * num_cols);
-      if (dvmaf < eps || dmse < eps) {
+      if (dvmaf < eps || mse < eps) {
         weight = 1.0;
       } else {
-        weight = dmse / dvmaf;
+        weight = mse / dvmaf;
       }
 
       // Normalize it with a data fitted model.
@@ -513,7 +486,6 @@
   aom_free_frame_buffer(&blurred);
   aom_free(scores);
   aom_clear_system_state();
-  (void)xd;
 }
 
 void av1_set_vmaf_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
@@ -550,7 +522,25 @@
   aom_clear_system_state();
 }
 
-// TODO(sdeng): replace it with the SIMD version.
+// TODO(sdeng): replace them with the SIMD versions.
+static AOM_INLINE double highbd_image_sad_c(const uint16_t *src, int src_stride,
+                                            const uint16_t *ref, int ref_stride,
+                                            int w, int h) {
+  double accum = 0.0;
+  int i, j;
+
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; ++j) {
+      double img1px = src[i * src_stride + j];
+      double img2px = ref[i * ref_stride + j];
+
+      accum += fabs(img1px - img2px);
+    }
+  }
+
+  return accum / (double)(h * w);
+}
+
 static AOM_INLINE double image_sad_c(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int w,
                                      int h) {
@@ -576,6 +566,7 @@
   const int y_width = cur->y_width;
   const int y_height = cur->y_height;
   YV12_BUFFER_CONFIG blurred_cur, blurred_last, blurred_next;
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
 
   memset(&blurred_cur, 0, sizeof(blurred_cur));
   memset(&blurred_last, 0, sizeof(blurred_last));
@@ -596,13 +587,30 @@
   if (next) gaussian_blur(cpi, next, &blurred_next);
 
   double motion1, motion2 = 65536.0;
-  motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
-                        blurred_last.y_buffer, blurred_last.y_stride, y_width,
-                        y_height);
-  if (next) {
-    motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
-                          blurred_next.y_buffer, blurred_next.y_stride, y_width,
+
+  if (bit_depth > 8) {
+    const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
+    motion1 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
+                                 blurred_cur.y_stride,
+                                 CONVERT_TO_SHORTPTR(blurred_last.y_buffer),
+                                 blurred_last.y_stride, y_width, y_height) *
+              scale_factor;
+    if (next) {
+      motion2 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
+                                   blurred_cur.y_stride,
+                                   CONVERT_TO_SHORTPTR(blurred_next.y_buffer),
+                                   blurred_next.y_stride, y_width, y_height) *
+                scale_factor;
+    }
+  } else {
+    motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
+                          blurred_last.y_buffer, blurred_last.y_stride, y_width,
                           y_height);
+    if (next) {
+      motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
+                            blurred_next.y_buffer, blurred_next.y_stride,
+                            y_width, y_height);
+    }
   }
 
   aom_free_frame_buffer(&blurred_cur);
@@ -616,16 +624,14 @@
 // observation: when the motion score becomes higher, the VMAF score of the
 // same source and distorted frames would become higher.
 int av1_get_vmaf_base_qindex(const AV1_COMP *const cpi, int current_qindex) {
-  const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
-  if (use_hbd) {
-    printf("Tune for VMAF for high bit depth videos is unsupported yet.\n");
-    exit(0);
-  }
   const AV1_COMMON *const cm = &cpi->common;
   if (cm->current_frame.frame_number == 0 || cpi->oxcf.pass == 1) {
     return current_qindex;
   }
-  const double approx_sse = cpi->last_frame_ysse;
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
+  const double approx_sse =
+      cpi->last_frame_ysse /
+      (double)((1 << (bit_depth - 8)) * (1 << (bit_depth - 8)));
   const double approx_dvmaf = cpi->last_frame_bvmaf - cpi->last_frame_vmaf;
   const double sse_threshold =
       0.01 * cpi->source->y_width * cpi->source->y_height;
@@ -643,7 +649,6 @@
         av1_lookahead_peek(cpi->lookahead, src_index, cpi->compressor_stage);
     cur_buf = &cur_entry->img;
   }
-
   assert(cur_buf);
 
   const struct lookahead_entry *last_entry =
@@ -676,11 +681,14 @@
 
 void av1_update_vmaf_curve(AV1_COMP *cpi, YV12_BUFFER_CONFIG *source,
                            YV12_BUFFER_CONFIG *recon) {
-  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, source,
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
+  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, source, bit_depth,
                 &cpi->last_frame_bvmaf);
-  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, recon,
+  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, source, recon, bit_depth,
                 &cpi->last_frame_vmaf);
-  cpi->last_frame_ysse =
-      image_sse_c(source->y_buffer, source->y_stride, recon->y_buffer,
-                  recon->y_stride, source->y_width, source->y_height);
+  if (bit_depth > 8) {
+    cpi->last_frame_ysse = (double)aom_highbd_get_y_sse(source, recon);
+  } else {
+    cpi->last_frame_ysse = (double)aom_get_y_sse(source, recon);
+  }
 }