Fix daala-dist for cb4x4

For sub8x8 partitions, av1_daala_dist() is now applied in
rd_pick_partition() instead of in the sub8x8 mode decision functions.

BD-rate change from daala-dist with '--disable-var-tx'
(AWCY, objective-1-fast, high delay mode):

   PSNR | PSNR Cb | PSNR Cr | PSNR HVS |    SSIM | MS SSIM | CIEDE 2000
15.1558 | 12.9585 | 14.4662 |  -3.8651 | -1.7102 | -9.2956 |    10.8686

BD-rate change in MSE probe mode:

  PSNR | PSNR Cb | PSNR Cr | PSNR HVS |   SSIM | MS SSIM | CIEDE 2000
0.0429 |  0.0435 |  0.1651 |  -0.0415 | 0.0850 |  0.0122 |     0.0546

Change-Id: I3b2ea916d41c48e433eb641adf44552e4725c198
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 39cd27d..b23acce 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -575,10 +575,9 @@
   return sum;
 }
 
-static int64_t av1_daala_dist(const uint8_t *src, int src_stride,
-                              const uint8_t *dst, int dst_stride, int bsw,
-                              int bsh, int qm, int use_activity_masking,
-                              int qindex) {
+int64_t av1_daala_dist(const uint8_t *src, int src_stride, const uint8_t *dst,
+                       int dst_stride, int bsw, int bsh, int qm,
+                       int use_activity_masking, int qindex) {
   int i, j;
   int64_t d;
   DECLARE_ALIGNED(16, od_coeff, orig[MAX_TX_SQUARE]);
@@ -1577,7 +1576,7 @@
   rd = AOMMIN(rd1, rd2);
 
 #if CONFIG_DAALA_DIST
-  if (plane == 0 &&
+  if (plane == 0 && plane_bsize >= BLOCK_8X8 &&
       (tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4)) {
     this_rd_stats.dist = 0;
     this_rd_stats.sse = 0;
@@ -1615,6 +1614,9 @@
   int use_activity_masking = 0;
 
   (void)tx_size;
+
+  assert(plane == 0);
+  assert(plane_bsize >= BLOCK_8X8);
 #if CONFIG_PVQ
   use_activity_masking = x->daala_enc.use_activity_masking;
 #endif  // CONFIG_PVQ
@@ -1674,10 +1676,15 @@
 
   {
     const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
+    const uint8_t txw_unit = tx_size_wide_unit[tx_size];
+    const uint8_t txh_unit = tx_size_high_unit[tx_size];
+    const int step = txw_unit * txh_unit;
+    int offset_h = tx_size_high_unit[TX_4X4];
     // The rate of the current 8x8 block is the sum of four 4x4 blocks in it.
-    this_rd_stats.rate = x->rate_4x4[block - max_blocks_wide - 1] +
-                         x->rate_4x4[block - max_blocks_wide] +
-                         x->rate_4x4[block - 1] + x->rate_4x4[block];
+    this_rd_stats.rate =
+        x->rate_4x4[block - max_blocks_wide * offset_h - step] +
+        x->rate_4x4[block - max_blocks_wide * offset_h] +
+        x->rate_4x4[block - step] + x->rate_4x4[block];
   }
   rd1 = RDCOST(x->rdmult, x->rddiv, this_rd_stats.rate, this_rd_stats.dist);
   rd2 = RDCOST(x->rdmult, x->rddiv, 0, this_rd_stats.sse);
@@ -1714,10 +1721,10 @@
   av1_get_entropy_contexts(bsize, tx_size, pd, args.t_above, args.t_left);
 
 #if CONFIG_DAALA_DIST
-  if (plane == 0 &&
+  if (plane == 0 && bsize >= BLOCK_8X8 &&
       (tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4))
-    av1_foreach_8x8_transformed_block_in_plane(
-        xd, bsize, plane, block_rd_txfm, block_8x8_rd_txfm_daala_dist, &args);
+    av1_foreach_8x8_transformed_block_in_yplane(
+        xd, bsize, block_rd_txfm, block_8x8_rd_txfm_daala_dist, &args);
   else
 #endif  // CONFIG_DAALA_DIST
     av1_foreach_transformed_block_in_plane(xd, bsize, plane, block_rd_txfm,
@@ -9498,6 +9505,9 @@
       rd_cost->dist = dist_y + dist_uv;
     }
     rd_cost->rdcost = RDCOST(x->rdmult, x->rddiv, rd_cost->rate, rd_cost->dist);
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+    rd_cost->dist_y = dist_y;
+#endif
   } else {
     rd_cost->rate = INT_MAX;
   }
@@ -10234,6 +10244,10 @@
     int compmode_cost = 0;
     int rate2 = 0, rate_y = 0, rate_uv = 0;
     int64_t distortion2 = 0, distortion_y = 0, distortion_uv = 0;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+    int64_t distortion2_y = 0;
+    int64_t total_sse_y = INT64_MAX;
+#endif
     int skippable = 0;
     int this_skip2 = 0;
     int64_t total_sse = INT64_MAX;
@@ -10575,6 +10589,9 @@
       if (mbmi->mode != DC_PRED && mbmi->mode != TM_PRED)
         rate2 += intra_cost_penalty;
       distortion2 = distortion_y + distortion_uv;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+      if (bsize < BLOCK_8X8) distortion2_y = distortion_y;
+#endif
     } else {
       int_mv backup_ref_mv[2];
 
@@ -10668,6 +10685,9 @@
         total_sse = rd_stats.sse;
         rate_y = rd_stats_y.rate;
         rate_uv = rd_stats_uv.rate;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+        if (bsize < BLOCK_8X8) distortion2_y = rd_stats_y.dist;
+#endif
       }
 
 // TODO(jingning): This needs some refactoring to improve code quality
@@ -10877,6 +10897,12 @@
             tmp_ref_rd = tmp_alt_rd;
             backup_mbmi = *mbmi;
             backup_skip = x->skip;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+            if (bsize < BLOCK_8X8) {
+              total_sse_y = tmp_rd_stats_y.sse;
+              distortion2_y = tmp_rd_stats_y.dist;
+            }
+#endif
 #if CONFIG_VAR_TX
             for (i = 0; i < MAX_MB_PLANE; ++i)
               memcpy(x->blk_skip_drl[i], x->blk_skip[i],
@@ -10950,6 +10976,9 @@
           this_skip2 = 1;
           rate_y = 0;
           rate_uv = 0;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+          if (bsize < BLOCK_8X8) distortion2_y = total_sse_y;
+#endif
         }
       } else {
         // Add in the cost of the no skip flag.
@@ -11039,7 +11068,9 @@
         best_rate_y = rate_y + av1_cost_bit(av1_get_skip_prob(cm, xd),
                                             this_skip2 || skippable);
         best_rate_uv = rate_uv;
-
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+        if (bsize < BLOCK_8X8) rd_cost->dist_y = distortion2_y;
+#endif
 #if CONFIG_VAR_TX
         for (i = 0; i < MAX_MB_PLANE; ++i)
           memcpy(ctx->blk_skip[i], x->blk_skip[i],
@@ -11167,6 +11198,9 @@
       rd_cost->rate +=
           (rd_stats_y.rate + rd_stats_uv.rate - best_rate_y - best_rate_uv);
       rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+      if (bsize < BLOCK_8X8) rd_cost->dist_y = rd_stats_y.dist;
+#endif
       rd_cost->rdcost =
           RDCOST(x->rdmult, x->rddiv, rd_cost->rate, rd_cost->dist);
       best_skip2 = skip_blk;
@@ -11686,7 +11720,9 @@
   rd_cost->rate = rate2;
   rd_cost->dist = distortion2;
   rd_cost->rdcost = this_rd;
-
+#if CONFIG_DAALA_DIST && CONFIG_CB4X4
+  if (bsize < BLOCK_8X8) rd_cost->dist_y = distortion2;
+#endif
   if (this_rd >= best_rd_so_far) {
     rd_cost->rate = INT_MAX;
     rd_cost->rdcost = INT64_MAX;