[dist-8x8] Refactor dist_8x8_sub8x8_txfm_rd() function

This allows the existing diff pixels to be used to calculate the SSE
distortion for both intra and inter mode blocks.

Change-Id: Ifa79003dbc08f5a49e3246d350469a32060648cf
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 38818ee..405e8af 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2210,7 +2210,7 @@
 
   int i, j;
   int64_t rd, rd1, rd2;
-  unsigned int tmp1, tmp2;
+  int64_t sse = INT64_MAX, dist = INT64_MAX;
   int qindex = x->qindex;
 
   assert((bw & 0x07) == 0);
@@ -2219,54 +2219,57 @@
   get_txb_dimensions(xd, 0, bsize, 0, 0, bsize, &bw, &bh, &visible_w,
                      &visible_h);
 
-#if CONFIG_HIGHBITDEPTH
-  uint8_t *pred8;
-  DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
-
-  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-    pred8 = CONVERT_TO_BYTEPTR(pred16);
-  else
-    pred8 = (uint8_t *)pred16;
-#else
-  DECLARE_ALIGNED(16, uint8_t, pred8[MAX_SB_SQUARE]);
-#endif  // CONFIG_HIGHBITDEPTH
-
-#if CONFIG_HIGHBITDEPTH
-  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-    for (j = 0; j < bh; j++)
-      for (i = 0; i < bw; i++)
-        CONVERT_TO_SHORTPTR(pred8)[j * bw + i] = pred[j * bw + i];
-  } else {
-#endif
-    for (j = 0; j < bh; j++)
-      for (i = 0; i < bw; i++) pred8[j * bw + i] = (uint8_t)pred[j * bw + i];
-#if CONFIG_HIGHBITDEPTH
-  }
-#endif  // CONFIG_HIGHBITDEPTH
-
-  tmp1 = (unsigned)av1_dist_8x8(cpi, x, src, src_stride, pred8, bw, bsize, bw,
-                                bh, visible_w, visible_h, qindex);
-  tmp2 = (unsigned)av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride, bsize,
-                                bw, bh, visible_w, visible_h, qindex);
+  const int diff_stride = block_size_wide[bsize];
+  const int16_t *diff = p->src_diff;
+  sse = av1_dist_8x8_diff(x, src, src_stride, diff, diff_stride, bw, bh,
+                          visible_w, visible_h, qindex);
+  sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
+  sse *= 16;
 
   if (!is_inter_block(mbmi)) {
-    if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) {
-      assert(args->rd_stats.sse == tmp1 * 16);
-      assert(args->rd_stats.dist == tmp2 * 16);
-    }
-    args->rd_stats.sse = (int64_t)tmp1 * 16;
-    args->rd_stats.dist = (int64_t)tmp2 * 16;
+    dist = av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride, bsize, bw, bh,
+                        visible_w, visible_h, qindex);
+    dist *= 16;
   } else {
-    // For inter mode, the decoded pixels are provided in pd->pred,
-    // while the predicted pixels are in dst.
-    if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) {
-      assert(args->rd_stats.sse == tmp2 * 16);
-      assert(args->rd_stats.dist == tmp1 * 16);
+// For inter mode, the decoded pixels are provided in pd->pred,
+// while the predicted pixels are in dst.
+#if CONFIG_HIGHBITDEPTH
+    uint8_t *pred8;
+    DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
+
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+      pred8 = CONVERT_TO_BYTEPTR(pred16);
+    else
+      pred8 = (uint8_t *)pred16;
+#else
+    DECLARE_ALIGNED(16, uint8_t, pred8[MAX_SB_SQUARE]);
+#endif  // CONFIG_HIGHBITDEPTH
+
+#if CONFIG_HIGHBITDEPTH
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      for (j = 0; j < bh; j++)
+        for (i = 0; i < bw; i++)
+          CONVERT_TO_SHORTPTR(pred8)[j * bw + i] = pred[j * bw + i];
+    } else {
+#endif
+      for (j = 0; j < bh; j++)
+        for (i = 0; i < bw; i++) pred8[j * bw + i] = (uint8_t)pred[j * bw + i];
+#if CONFIG_HIGHBITDEPTH
     }
-    args->rd_stats.sse = (int64_t)tmp2 * 16;
-    args->rd_stats.dist = (int64_t)tmp1 * 16;
+#endif  // CONFIG_HIGHBITDEPTH
+
+    dist = av1_dist_8x8(cpi, x, src, src_stride, pred8, bw, bsize, bw, bh,
+                        visible_w, visible_h, qindex);
+    dist *= 16;
   }
 
+  if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) {
+    assert(args->rd_stats.sse == sse);
+    assert(args->rd_stats.dist == dist);
+  }
+  args->rd_stats.sse = sse;
+  args->rd_stats.dist = dist;
+
   rd1 = RDCOST(x->rdmult, args->rd_stats.rate, args->rd_stats.dist);
   rd2 = RDCOST(x->rdmult, 0, args->rd_stats.sse);
   rd = AOMMIN(rd1, rd2);