Enable high bit-depth support for tune=vmaf

BD-rate gains
baseline: tune=psnr
                VMAF
Lowres_bd10    -26.9%
Midres_bd10    -33.1%

Change-Id: Idc3e2d1027135f17e5da3ef25325665bc7391d7d
diff --git a/av1/encoder/tune_vmaf.c b/av1/encoder/tune_vmaf.c
index e7f7f93..c8ebe59 100644
--- a/av1/encoder/tune_vmaf.c
+++ b/av1/encoder/tune_vmaf.c
@@ -225,18 +225,12 @@
 
 void av1_vmaf_blk_preprocessing(const AV1_COMP *const cpi,
                                 YV12_BUFFER_CONFIG *const source) {
-  const int use_hbd = source->flags & YV12_FLAG_HIGHBITDEPTH;
-  // TODO(sdeng): Add high bit depth support.
-  if (use_hbd) {
-    printf(
-        "VMAF preprocessing for high bit depth videos is unsupported yet.\n");
-    exit(0);
-  }
-
   aom_clear_system_state();
   const AV1_COMMON *const cm = &cpi->common;
   const int width = source->y_width;
   const int height = source->y_height;
+  const int bit_depth = cpi->td.mb.e_mbd.bd;
+
   YV12_BUFFER_CONFIG source_extended, blurred;
   memset(&blurred, 0, sizeof(blurred));
   memset(&source_extended, 0, sizeof(source_extended));
@@ -284,28 +278,56 @@
       const int block_height = AOMMIN(height - row_offset_y, block_h);
       const int index = col + row * num_cols;
 
-      uint8_t *frame_src_buf =
-          source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
-      uint8_t *frame_blurred_buf =
-          blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
-      uint8_t *blurred_dst = blurred_block.y_buffer;
-      uint8_t *src_dst = source_block.y_buffer;
+      if (bit_depth > 8) {
+        uint16_t *frame_src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
+                                  row_offset_y * source->y_stride +
+                                  col_offset_y;
+        uint16_t *frame_blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
+                                      row_offset_y * blurred.y_stride +
+                                      col_offset_y;
+        uint16_t *blurred_dst = CONVERT_TO_SHORTPTR(blurred_block.y_buffer);
+        uint16_t *src_dst = CONVERT_TO_SHORTPTR(source_block.y_buffer);
 
-      // Copy block from source frame.
-      for (int i = 0; i < block_h; ++i) {
-        for (int j = 0; j < block_w; ++j) {
-          if (i >= block_height || j >= block_width) {
-            src_dst[j] = 0;
-            blurred_dst[j] = 0;
-          } else {
-            src_dst[j] = frame_src_buf[j];
-            blurred_dst[j] = frame_blurred_buf[j];
+        // Copy block from source frame.
+        for (int i = 0; i < block_h; ++i) {
+          for (int j = 0; j < block_w; ++j) {
+            if (i >= block_height || j >= block_width) {
+              src_dst[j] = 0;
+              blurred_dst[j] = 0;
+            } else {
+              src_dst[j] = frame_src_buf[j];
+              blurred_dst[j] = frame_blurred_buf[j];
+            }
           }
+          frame_src_buf += source->y_stride;
+          frame_blurred_buf += blurred.y_stride;
+          src_dst += source_block.y_stride;
+          blurred_dst += blurred_block.y_stride;
         }
-        frame_src_buf += source->y_stride;
-        frame_blurred_buf += blurred.y_stride;
-        src_dst += source_block.y_stride;
-        blurred_dst += blurred_block.y_stride;
+      } else {
+        uint8_t *frame_src_buf =
+            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
+        uint8_t *frame_blurred_buf =
+            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
+        uint8_t *blurred_dst = blurred_block.y_buffer;
+        uint8_t *src_dst = source_block.y_buffer;
+
+        // Copy block from source frame.
+        for (int i = 0; i < block_h; ++i) {
+          for (int j = 0; j < block_w; ++j) {
+            if (i >= block_height || j >= block_width) {
+              src_dst[j] = 0;
+              blurred_dst[j] = 0;
+            } else {
+              src_dst[j] = frame_src_buf[j];
+              blurred_dst[j] = frame_blurred_buf[j];
+            }
+          }
+          frame_src_buf += source->y_stride;
+          frame_blurred_buf += blurred.y_stride;
+          src_dst += source_block.y_stride;
+          blurred_dst += blurred_block.y_stride;
+        }
       }
 
       const double amount_start = AOMMAX(best_frame_unsharp_amount - 0.2, 0.0);
@@ -324,13 +346,25 @@
       const int block_width = AOMMIN(source->y_width - col_offset_y, block_w);
       const int block_height = AOMMIN(source->y_height - row_offset_y, block_h);
       const int index = col + row * num_cols;
-      uint8_t *src_buf =
-          source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
-      uint8_t *blurred_buf =
-          blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
-      unsharp_rect(src_buf, source->y_stride, blurred_buf, blurred.y_stride,
-                   src_buf, source->y_stride, block_width, block_height,
-                   best_unsharp_amounts[index]);
+
+      if (bit_depth > 8) {
+        uint16_t *src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
+                            row_offset_y * source->y_stride + col_offset_y;
+        uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
+                                row_offset_y * blurred.y_stride + col_offset_y;
+        highbd_unsharp_rect(src_buf, source->y_stride, blurred_buf,
+                            blurred.y_stride, src_buf, source->y_stride,
+                            block_width, block_height,
+                            best_unsharp_amounts[index], bit_depth);
+      } else {
+        uint8_t *src_buf =
+            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
+        uint8_t *blurred_buf =
+            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
+        unsharp_rect(src_buf, source->y_stride, blurred_buf, blurred.y_stride,
+                     src_buf, source->y_stride, block_width, block_height,
+                     best_unsharp_amounts[index]);
+      }
     }
   }