Speed up VMAF calculations for VMAF RDO tuning

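Compute the baseline score and all per-block scores with a single
multi-frame VMAF call (frames are supplied through a callback) instead
of running VMAF once per block.

Test: perf stat -e instructions:u ./aomenc red_kayak_480p.y4m
 --limit=30 -o output --tune=vmaf_without_preprocessing --cpu-used=1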

VMAF baseline
 1,302,547,677,300      instructions:u
     233.921052058 seconds time elapsed
     219.757056000 seconds user
      75.185427000 seconds sys

PSNR baseline
   692,598,329,226      instructions:u
      91.892158067 seconds time elapsed
      91.697262000 seconds user
       0.184010000 seconds sys

New VMAF
 1,276,643,191,938      instructions:u
      88.573175585 seconds time elapsed
     218.622893000 seconds user
       2.224477000 seconds sys

Change-Id: I6d2815a9443ef1d1a157a225a3a6644809acd5f8
diff --git a/aom_dsp/vmaf.c b/aom_dsp/vmaf.c
index 82bde1f..71e4222 100644
--- a/aom_dsp/vmaf.c
+++ b/aom_dsp/vmaf.c
@@ -95,7 +95,7 @@
   double vmaf_score;
   const int ret = compute_vmaf(
       &vmaf_score, (char *)"yuv420p", frame_width, frame_height, read_frame,
-      /*user_data=*/&user_data, (char *)model_path,
+      /*user_data=*/user_data, (char *)model_path,
       /*log_path=*/"vmaf_scores.xml", /*log_fmt=*/NULL, /*disable_clip=*/0,
       /*disable_avx=*/0, /*enable_transform=*/0,
       /*phone_model=*/0, /*do_psnr=*/0, /*do_ssim=*/0,
diff --git a/av1/encoder/tune_vmaf.c b/av1/encoder/tune_vmaf.c
index 0b0bd00..b5beb84 100644
--- a/av1/encoder/tune_vmaf.c
+++ b/av1/encoder/tune_vmaf.c
@@ -32,6 +32,24 @@
   }
 }
 
+// dst = source + amount * (source - blurred), clamped to [0, 255], per pixel.
+static AOM_INLINE void unsharp_rect_float(const float *source,
+                                          int source_stride,
+                                          const uint8_t *blurred,
+                                          int blurred_stride, float *dst,
+                                          int dst_stride, int w, int h,
+                                          float amount) {
+  for (int r = 0; r < h; ++r) {
+    for (int c = 0; c < w; ++c) {
+      const float v = source[c] + amount * (source[c] - (float)blurred[c]);
+      dst[c] = v < 0.0f ? 0.0f : (v > 255.0f ? 255.0f : v);
+    }
+    source += source_stride;
+    blurred += blurred_stride;
+    dst += dst_stride;
+  }
+}
+
 static AOM_INLINE void unsharp(const YV12_BUFFER_CONFIG *source,
                                const YV12_BUFFER_CONFIG *blurred,
                                const YV12_BUFFER_CONFIG *dst, double amount) {
@@ -282,7 +300,7 @@
 }
 
 // TODO(sdeng): replace it with the SIMD version.
-static AOM_INLINE double image_mse_c(const uint8_t *src, int src_stride,
+static AOM_INLINE double image_sse_c(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int w,
                                      int h) {
   double accum = 0.0;
@@ -297,7 +315,71 @@
     }
   }
 
-  return accum / (double)(w * h);
+  return accum;
+}
+
+typedef struct FrameData {
+  const YV12_BUFFER_CONFIG *source, *blurred;  // original and blurred frames
+  int block_w, block_h, num_rows, num_cols, row, col;  // block grid + cursor
+} FrameData;
+
+// VMAF frame callback: fills ref/main with the source luma; after the first
+// call, main also gets one block (in raster order) replaced by the blurred
+// frame. Returns 0 per frame supplied, 2 once every block has been served.
+static int update_frame(float *ref_data, float *main_data, float *temp_data,
+                        int stride, void *user_data) {
+  FrameData *frames = (FrameData *)user_data;
+  const int width = frames->source->y_width;
+  const int height = frames->source->y_height;
+  const int row = frames->row;
+  const int col = frames->col;
+  const int num_rows = frames->num_rows;
+  const int num_cols = frames->num_cols;
+  const int block_w = frames->block_w;
+  const int block_h = frames->block_h;
+  const YV12_BUFFER_CONFIG *source = frames->source;
+  const YV12_BUFFER_CONFIG *blurred = frames->blurred;
+  (void)temp_data;
+  stride /= (int)sizeof(*ref_data);  // presumably bytes -> float elements
+
+  for (int i = 0; i < height; ++i) {
+    float *ref, *main;
+    uint8_t *src;
+    ref = ref_data + i * stride;
+    main = main_data + i * stride;
+    src = source->y_buffer + i * source->y_stride;
+    for (int j = 0; j < width; ++j) {
+      ref[j] = main[j] = (float)src[j];
+    }
+  }
+  if (row < 0 && col < 0) {  // negative cursor: main stays equal to ref
+    frames->row = 0;
+    frames->col = 0;
+    return 0;
+  } else if (row < num_rows && col < num_cols) {
+    // Overwrite this block of main; amount -1 yields the blurred pixels.
+    const int row_offset = row * block_h;
+    const int col_offset = col * block_w;
+    const int block_width = AOMMIN(width - col_offset, block_w);
+    const int block_height = AOMMIN(height - row_offset, block_h);
+
+    float *main_buf = main_data + col_offset + row_offset * stride;
+    float *ref_buf = ref_data + col_offset + row_offset * stride;
+    uint8_t *blurred_buf =
+        blurred->y_buffer + row_offset * blurred->y_stride + col_offset;
+
+    unsharp_rect_float(ref_buf, stride, blurred_buf, blurred->y_stride,
+                       main_buf, stride, block_width, block_height, -1.0f);
+
+    frames->col++;  // advance the cursor in raster order
+    if (frames->col >= num_cols) {
+      frames->col = 0;
+      frames->row++;
+    }
+    return 0;
+  } else {
+    return 2;  // cursor is past the last block
+  }
 }
 
 void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
@@ -325,23 +407,27 @@
   }
 
   aom_clear_system_state();
-  YV12_BUFFER_CONFIG fake_recon, blurred;
-  memset(&fake_recon, 0, sizeof(fake_recon));
+  YV12_BUFFER_CONFIG blurred;
   memset(&blurred, 0, sizeof(blurred));
-  aom_alloc_frame_buffer(&fake_recon, y_width, y_height, 1, 1,
-                         cm->seq_params.use_highbitdepth,
-                         cpi->oxcf.border_in_pixels, cm->byte_alignment);
   aom_alloc_frame_buffer(&blurred, y_width, y_height, 1, 1,
                          cm->seq_params.use_highbitdepth,
                          cpi->oxcf.border_in_pixels, cm->byte_alignment);
-
   gaussian_blur(cpi, cpi->source, &blurred);
 
-  // baseline vmaf
-  double baseline_mse = 0.0, baseline_vmaf = 0.0;
-  aom_calc_vmaf(cpi->oxcf.vmaf_model_path, cpi->source, cpi->source,
-                &baseline_vmaf);
-  av1_copy_and_extend_frame(cpi->source, &fake_recon);
+  double *scores = aom_malloc(sizeof(*scores) * (num_rows * num_cols + 1));
+  memset(scores, 0, sizeof(*scores) * (num_rows * num_cols + 1));
+  FrameData frame_data;
+  frame_data.source = cpi->source;
+  frame_data.blurred = &blurred;
+  frame_data.block_w = block_w;
+  frame_data.block_h = block_h;
+  frame_data.num_rows = num_rows;
+  frame_data.num_cols = num_cols;
+  frame_data.row = -1;
+  frame_data.col = -1;
+  aom_calc_vmaf_multi_frame(&frame_data, cpi->oxcf.vmaf_model_path,
+                            update_frame, y_width, y_height, scores);
+  const double baseline_mse = 0.0, baseline_vmaf = scores[0];
 
   // Loop through each 'block_size' block.
   for (int row = 0; row < num_rows; ++row) {
@@ -356,26 +442,21 @@
           y_buffer + row_offset_y * y_stride + col_offset_y;
       uint8_t *const blurred_buf =
           blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
-      uint8_t *const fybuf = fake_recon.y_buffer +
-                             row_offset_y * fake_recon.y_stride + col_offset_y;
 
       const int block_width = AOMMIN(y_width - col_offset_y, block_w);
       const int block_height = AOMMIN(y_height - row_offset_y, block_h);
 
-      // Set the blurred block.
-      unsharp_rect(orig_buf, y_stride, blurred_buf, blurred.y_stride, fybuf,
-                   fake_recon.y_stride, block_width, block_height, -1.0);
+      const double vmaf = scores[index + 1];
+      const double dvmaf = baseline_vmaf - vmaf;
 
-      double vmaf, mse;
-      aom_calc_vmaf(cpi->oxcf.vmaf_model_path, cpi->source, &fake_recon, &vmaf);
-      mse = image_mse_c(y_buffer, y_stride, fake_recon.y_buffer,
-                        fake_recon.y_stride, y_width, y_height);
+      const double mse =
+          image_sse_c(orig_buf, y_stride, blurred_buf, blurred.y_stride,
+                      block_width, block_height) /
+          (double)(y_width * y_height);
+      const double dmse = mse - baseline_mse;
 
       double weight = 0.0;
-      const double dvmaf = baseline_vmaf - vmaf;
-      const double dmse = mse - baseline_mse;
       const double eps = 0.01 / (num_rows * num_cols);
-
       if (dvmaf < eps || dmse < eps) {
         weight = 1.0;
       } else {
@@ -385,15 +466,11 @@
       // Normalize it with a data fitted model.
       weight = 6.0 * (1.0 - exp(-0.05 * weight)) + 0.8;
       cpi->vmaf_rdmult_scaling_factors[index] = weight;
-
-      // Reset blurred block.
-      unsharp_rect(orig_buf, y_stride, blurred_buf, blurred.y_stride, fybuf,
-                   fake_recon.y_stride, block_width, block_height, 0.0);
     }
   }
 
-  aom_free_frame_buffer(&fake_recon);
   aom_free_frame_buffer(&blurred);
+  aom_free(scores);
   aom_clear_system_state();
   (void)xd;
 }