Refactor sub8x8 tx size RD for daala-dist

For a tx size RD search with partition size >= 8x8 and tx size < 8x8, the
daala-dist function is now applied once to the whole partition, after all tx
blocks are encoded, instead of to each 8x8 sub-block of the partition.

Change-Id: I27d9e2960aa641f550096e32ebcdf8dfb4de79a6
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 8ca465c..7212709 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1562,16 +1562,6 @@
   // TODO(jingning): temporarily enabled only for luma component
   rd = AOMMIN(rd1, rd2);
 
-#if CONFIG_DAALA_DIST
-  if (plane == 0 && plane_bsize >= BLOCK_8X8 &&
-      (tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4)) {
-    this_rd_stats.dist = 0;
-    this_rd_stats.sse = 0;
-    rd = 0;
-    x->rate_4x4[block] = this_rd_stats.rate;
-  }
-#endif  // CONFIG_DAALA_DIST
-
 #if !CONFIG_PVQ
   this_rd_stats.skip &= !x->plane[plane].eobs[block];
 #else
@@ -1581,111 +1571,74 @@
 
   args->this_rd += rd;
 
-  if (args->this_rd > args->best_rd) {
-    args->exit_early = 1;
-    return;
+#if CONFIG_DAALA_DIST
+  if (!(plane == 0 && plane_bsize >= BLOCK_8X8 &&
+        (tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4))) {
+#endif
+    if (args->this_rd > args->best_rd) {
+      args->exit_early = 1;
+      return;
+    }
+#if CONFIG_DAALA_DIST
   }
+#endif
 }
 
 #if CONFIG_DAALA_DIST
-static void block_8x8_rd_txfm_daala_dist(int plane, int block, int blk_row,
-                                         int blk_col, BLOCK_SIZE plane_bsize,
-                                         TX_SIZE tx_size, void *arg) {
-  struct rdcost_block_args *args = arg;
-  MACROBLOCK *const x = args->x;
+static void daala_dist_sub8x8_txfm_rd(MACROBLOCK *x, BLOCK_SIZE bsize,
+                                      struct rdcost_block_args *args) {
   MACROBLOCKD *const xd = &x->e_mbd;
+  const struct macroblockd_plane *const pd = &xd->plane[0];
+  const struct macroblock_plane *const p = &x->plane[0];
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
-  int64_t rd, rd1, rd2;
-  RD_STATS this_rd_stats;
-  int qm = OD_HVS_QM;
-  int use_activity_masking = 0;
-
-  (void)tx_size;
-
-  assert(plane == 0);
-  assert(plane_bsize >= BLOCK_8X8);
 #if CONFIG_PVQ
   use_activity_masking = x->daala_enc.use_activity_masking;
 #endif  // CONFIG_PVQ
-  av1_init_rd_stats(&this_rd_stats);
+  const int src_stride = p->src.stride;
+  const int dst_stride = pd->dst.stride;
+  const int pred_stride = block_size_wide[bsize];
+  const uint8_t *src = &p->src.buf[0];
+  const uint8_t *dst = &pd->dst.buf[0];
+  const int16_t *pred = &pd->pred[0];
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
 
-  if (args->exit_early) return;
+  int i, j;
+  int64_t rd, rd1, rd2;
+  int qm = OD_HVS_QM;
+  int use_activity_masking = 0;
+  unsigned int tmp1, tmp2;
+  int qindex = x->qindex;
 
-  {
-    const struct macroblock_plane *const p = &x->plane[plane];
-    struct macroblockd_plane *const pd = &xd->plane[plane];
+  assert((bw & 0x07) == 0);
+  assert((bh & 0x07) == 0);
 
-    const int src_stride = p->src.stride;
-    const int dst_stride = pd->dst.stride;
-    const int diff_stride = block_size_wide[plane_bsize];
+  DECLARE_ALIGNED(16, uint8_t, pred8[MAX_SB_SQUARE]);
 
-    const uint8_t *src =
-        &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
-    const uint8_t *dst =
-        &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
+  for (j = 0; j < bh; j++)
+    for (i = 0; i < bw; i++) pred8[j * bw + i] = pred[j * pred_stride + i];
 
-    unsigned int tmp1, tmp2;
-    int qindex = x->qindex;
-    const int pred_stride = block_size_wide[plane_bsize];
-    const int pred_idx = (blk_row * pred_stride + blk_col)
-                         << tx_size_wide_log2[0];
-    int16_t *pred = &pd->pred[pred_idx];
-    int i, j;
-    const int tx_blk_size = 8;
+  tmp1 = av1_daala_dist(src, src_stride, pred8, bw, bw, bh, qm,
+                        use_activity_masking, qindex);
+  tmp2 = av1_daala_dist(src, src_stride, dst, dst_stride, bw, bh, qm,
+                        use_activity_masking, qindex);
 
-    DECLARE_ALIGNED(16, uint8_t, pred8[8 * 8]);
-
-    for (j = 0; j < tx_blk_size; j++)
-      for (i = 0; i < tx_blk_size; i++)
-        pred8[j * tx_blk_size + i] = pred[j * diff_stride + i];
-
-    tmp1 = av1_daala_dist(src, src_stride, pred8, tx_blk_size, 8, 8, qm,
-                          use_activity_masking, qindex);
-    tmp2 = av1_daala_dist(src, src_stride, dst, dst_stride, 8, 8, qm,
-                          use_activity_masking, qindex);
-
-    if (!is_inter_block(mbmi)) {
-      this_rd_stats.sse = (int64_t)tmp1 * 16;
-      this_rd_stats.dist = (int64_t)tmp2 * 16;
-    } else {
-      // For inter mode, the decoded pixels are provided in pd->pred,
-      // while the predicted pixels are in dst.
-      this_rd_stats.sse = (int64_t)tmp2 * 16;
-      this_rd_stats.dist = (int64_t)tmp1 * 16;
-    }
+  if (!is_inter_block(mbmi)) {
+    args->rd_stats.sse = (int64_t)tmp1 * 16;
+    args->rd_stats.dist = (int64_t)tmp2 * 16;
+  } else {
+    // For inter mode, the decoded pixels are provided in pd->pred,
+    // while the predicted pixels are in dst.
+    args->rd_stats.sse = (int64_t)tmp2 * 16;
+    args->rd_stats.dist = (int64_t)tmp1 * 16;
   }
 
-  rd = RDCOST(x->rdmult, x->rddiv, 0, this_rd_stats.dist);
-  if (args->this_rd + rd > args->best_rd) {
-    args->exit_early = 1;
-    return;
-  }
-
-  {
-    const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
-    const uint8_t txw_unit = tx_size_wide_unit[tx_size];
-    const uint8_t txh_unit = tx_size_high_unit[tx_size];
-    const int step = txw_unit * txh_unit;
-    int offset_h = tx_size_high_unit[TX_4X4];
-    // The rate of the current 8x8 block is the sum of four 4x4 blocks in it.
-    this_rd_stats.rate =
-        x->rate_4x4[block - max_blocks_wide * offset_h - step] +
-        x->rate_4x4[block - max_blocks_wide * offset_h] +
-        x->rate_4x4[block - step] + x->rate_4x4[block];
-  }
-  rd1 = RDCOST(x->rdmult, x->rddiv, this_rd_stats.rate, this_rd_stats.dist);
-  rd2 = RDCOST(x->rdmult, x->rddiv, 0, this_rd_stats.sse);
+  rd1 = RDCOST(x->rdmult, x->rddiv, args->rd_stats.rate, args->rd_stats.dist);
+  rd2 = RDCOST(x->rdmult, x->rddiv, 0, args->rd_stats.sse);
   rd = AOMMIN(rd1, rd2);
 
-  args->rd_stats.dist += this_rd_stats.dist;
-  args->rd_stats.sse += this_rd_stats.sse;
-
-  args->this_rd += rd;
-
-  if (args->this_rd > args->best_rd) {
-    args->exit_early = 1;
-    return;
-  }
+  args->rd_stats.rdcost = rd;
+  args->this_rd = rd;
 }
 #endif  // CONFIG_DAALA_DIST
 
@@ -1707,15 +1660,13 @@
 
   av1_get_entropy_contexts(bsize, tx_size, pd, args.t_above, args.t_left);
 
+  av1_foreach_transformed_block_in_plane(xd, bsize, plane, block_rd_txfm,
+                                         &args);
 #if CONFIG_DAALA_DIST
-  if (plane == 0 && bsize >= BLOCK_8X8 &&
+  if (!args.exit_early && plane == 0 && bsize >= BLOCK_8X8 &&
       (tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4))
-    av1_foreach_8x8_transformed_block_in_yplane(
-        xd, bsize, block_rd_txfm, block_8x8_rd_txfm_daala_dist, &args);
-  else
-#endif  // CONFIG_DAALA_DIST
-    av1_foreach_transformed_block_in_plane(xd, bsize, plane, block_rd_txfm,
-                                           &args);
+    daala_dist_sub8x8_txfm_rd(x, bsize, &args);
+#endif
 
   if (args.exit_early) {
     av1_invalid_rd_stats(rd_stats);