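Misc fixes for RD stats collection

- Compute the SSE with aom_sse()/aom_highbd_sse() on the source and
  dst buffers instead of aom_sum_squares_2d_i16() on src_diff.
- Compute the mean with the new get_diff_mean()/get_highbd_diff_mean()
  helpers, which read the source and dst buffers directly.
- Define get_mean() under CONFIG_COLLECT_RD_STATS == 1.
- Print the mean, correlation and energy distribution stats before
  the CONFIG_COLLECT_INTER_MODE_RD_STATS output.
- Call pick_tx_size_type_yrd() instead of select_tx_type_yrd() in the
  CONFIG_COLLECT_RD_STATS == 3 paths and update the related comment.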

Change-Id: I6fe5018075aced0c7dd824438edff598cd50f414
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index a88b0f4..71394e5 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2347,11 +2347,29 @@
                          blk_row, blk_col, plane_bsize, tx_bsize);
 }
 
-static double get_mean(const int16_t *diff, int stride, int w, int h) {
+static double get_diff_mean(const uint8_t *src, int src_stride,
+                            const uint8_t *dst, int dst_stride, int w, int h) {
   double sum = 0.0;
   for (int j = 0; j < h; ++j) {
     for (int i = 0; i < w; ++i) {
-      sum += diff[j * stride + i];
+      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+      sum += diff;
+    }
+  }
+  assert(w > 0 && h > 0);
+  return sum / (w * h);
+}
+
+static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
+                                   const uint8_t *dst8, int dst_stride, int w,
+                                   int h) {
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
+  double sum = 0.0;
+  for (int j = 0; j < h; ++j) {
+    for (int i = 0; i < w; ++i) {
+      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+      sum += diff;
     }
   }
   assert(w > 0 && h > 0);
@@ -2440,6 +2458,17 @@
 #if CONFIG_COLLECT_RD_STATS
 
 #if CONFIG_COLLECT_RD_STATS == 1
+static double get_mean(const int16_t *diff, int stride, int w, int h) {
+  double sum = 0.0;
+  for (int j = 0; j < h; ++j) {
+    for (int i = 0; i < w; ++i) {
+      sum += diff[j * stride + i];
+    }
+  }
+  assert(w > 0 && h > 0);
+  return sum / (w * h);
+}
+
 static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x,
                                     const RD_STATS *const rd_stats, int blk_row,
                                     int blk_col, BLOCK_SIZE plane_bsize,
@@ -2588,7 +2617,14 @@
   const int16_t *const src_diff = p->src_diff;
   const int shift = (xd->bd - 8);
 
-  int64_t sse = aom_sum_squares_2d_i16(src_diff, diff_stride, bw, bh);
+  int64_t sse;
+  if (is_cur_buf_hbd(xd)) {
+    sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+                         bw, bh);
+  } else {
+    sse =
+        aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw, bh);
+  }
   sse = ROUND_POWER_OF_TWO(sse, shift * 2);
   const double sse_norm = (double)sse / num_samples;
 
@@ -2627,6 +2663,26 @@
   fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
           model_rdcost_norm);
 
+  double mean;
+  if (is_cur_buf_hbd(xd)) {
+    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
+                                pd->dst.stride, bw, bh);
+  } else {
+    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+                         bw, bh);
+  }
+  mean /= (1 << shift);
+  float hor_corr, vert_corr;
+  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
+                                  &vert_corr);
+  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
+
+  double hdist[4] = { 0 }, vdist[4] = { 0 };
+  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
+                               dst_stride, 1, hdist, vdist);
+  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
+          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
+
 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
   if (cpi->sf.inter_mode_rd_model_estimation == 1) {
     assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
@@ -2644,19 +2700,6 @@
   }
 #endif
 
-  double mean = get_mean(src_diff, diff_stride, bw, bh);
-  mean /= (1 << shift);
-  float hor_corr, vert_corr;
-  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
-                                  &vert_corr);
-  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
-
-  double hdist[4] = { 0 }, vdist[4] = { 0 };
-  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
-                               dst_stride, 1, hdist, vdist);
-  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
-          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
-
   fprintf(fout, "\n");
   fclose(fout);
 }
@@ -2708,7 +2751,12 @@
   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                    dst_stride, src_diff, diff_stride,
                                    sse_norm_arr, NULL);
-  double mean = get_mean(src_diff, bw, bw, bh);
+  double mean;
+  if (is_cur_buf_hbd(xd)) {
+    mean = get_highbd_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
+  } else {
+    mean = get_diff_mean(src, src_stride, dst, dst_stride, bw, bh);
+  }
   if (shift) {
     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
     mean /= (1 << shift);
@@ -8160,7 +8208,8 @@
                                     AOM_PLANE_Y, AOM_PLANE_Y);
 #if CONFIG_COLLECT_RD_STATS == 3
       RD_STATS rd_stats_y;
-      select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
+      pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
+                            INT64_MAX);
       PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
 #endif  // CONFIG_COLLECT_RD_STATS == 3
       model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
@@ -8515,7 +8564,7 @@
 
 #if CONFIG_COLLECT_RD_STATS == 3
   RD_STATS rd_stats_y;
-  select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
+  pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
   PrintPredictionUnitStats(cpi, tile_data, x, &rd_stats_y, bsize);
 #endif  // CONFIG_COLLECT_RD_STATS == 3
   model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
@@ -8722,7 +8771,7 @@
    * This function combines y and uv planes' transform search processes
    * together, when the prediction is generated. It first does subtraction to
    * obtain the prediction error. Then it calls
-   * select_tx_type_yrd/super_block_yrd and super_block_uvrd sequentially and
+   * pick_tx_size_type_yrd/super_block_yrd and super_block_uvrd sequentially and
    * handles the early terminations happening in those functions. At the end, it
    * computes the rd_stats/_y/_uv accordingly.
    */