Fix daala-dist for var-tx

Var-tx has its own suite of tx size/type RD search functions, which
recursively split a partition into square tx blocks.
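
Schematically, the recursion looks like the sketch below (a hedged
illustration only: search_tx_block and rd_cost_whole_block are invented
names for exposition; the real search lives in select_tx_block() in
av1/encoder/rdopt.c):

    #include <stdint.h>

    /* Hypothetical helper: RD cost of coding the block with one tx. */
    extern int64_t rd_cost_whole_block(int blk_row, int blk_col, int tx_wide);

    static int64_t search_tx_block(int blk_row, int blk_col, int tx_wide) {
      int64_t rd_whole = rd_cost_whole_block(blk_row, blk_col, tx_wide);
      int64_t rd_split = INT64_MAX;
      if (tx_wide > 4) {
        /* Recurse into the four square quadrants. */
        const int half = tx_wide / 2;
        rd_split = search_tx_block(blk_row, blk_col, half) +
                   search_tx_block(blk_row, blk_col + half, half) +
                   search_tx_block(blk_row + half, blk_col, half) +
                   search_tx_block(blk_row + half, blk_col + half, half);
      }
      /* Keep the cheaper of "code whole" vs "split in four". */
      return rd_split < rd_whole ? rd_split : rd_whole;
    }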

Daala-dist requires access to 8x8 blocks of pixels (both decoded and
predicted), since it measures distortion on multiples of 8x8 pixels.
Thus, if a tx block is smaller than 8x8, the search waits until all of
its sub8x8 blocks have been RD searched (with MSE), then replaces the
summed MSE of those 8x8 pixels with the distortion daala-dist computes
for the 8x8 block.

Daala-dist is applied to luma pixels only.
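
In isolation, the replacement step looks roughly like the following
sketch (again hedged: this helper and its parameter list are invented,
while av1_daala_dist, the eob test, and the x16 scaling are taken from
the diff below; `decoded` stands in for the pixels saved into
pd->pred):

    /* Assumes the av1/encoder/rdopt.c context (av1_daala_dist).
       `decoded` holds the reconstruction each sub8x8 tx block saved;
       `dst` is the 8x8 prediction. */
    static int64_t daala_dist_8x8_from_sub4x4(
        const uint8_t *src, int src_stride, const uint8_t *dst,
        int dst_stride, const int16_t *decoded, int decoded_stride,
        const int sub8x8_eob[4], int qm, int use_activity_masking,
        int qindex) {
      uint8_t rec8[8 * 8];
      int row, col, i, j;

      /* Assemble the full 8x8 reconstruction from the four 4x4
         quadrants.  A quadrant with eob == 0 carries no residual, so
         its reconstruction is just the prediction. */
      for (row = 0; row < 2; ++row) {
        for (col = 0; col < 2; ++col) {
          const int eob = sub8x8_eob[row * 2 + col];
          for (j = 0; j < 4; ++j) {
            for (i = 0; i < 4; ++i) {
              const int y = row * 4 + j;
              const int x = col * 4 + i;
              rec8[y * 8 + x] =
                  eob > 0 ? (uint8_t)decoded[y * decoded_stride + x]
                          : dst[y * dst_stride + x];
            }
          }
        }
      }
      /* One daala-dist over the whole 8x8 block replaces the four
         summed 4x4 MSEs; x16 matches the MSE scaling used elsewhere
         in rdopt.c. */
      return av1_daala_dist(src, src_stride, rec8, 8, 8, 8, 8, 8, qm,
                            use_activity_masking, qindex) *
             16;
    }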

Change-Id: Ic4891e89b4ef05cf880aa26781d2d06ccf3142de
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b1e1292..a3c0844 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1821,9 +1821,6 @@
   const struct macroblockd_plane *const pd = &xd->plane[0];
   const struct macroblock_plane *const p = &x->plane[0];
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
-#if CONFIG_PVQ
-  use_activity_masking = x->daala_enc.use_activity_masking;
-#endif  // CONFIG_PVQ
   const int src_stride = p->src.stride;
   const int dst_stride = pd->dst.stride;
   const uint8_t *src = &p->src.buf[0];
@@ -1838,6 +1835,9 @@
   int use_activity_masking = 0;
   unsigned int tmp1, tmp2;
   int qindex = x->qindex;
+#if CONFIG_PVQ
+  use_activity_masking = x->daala_enc.use_activity_masking;
+#endif
 
   assert((bw & 0x07) == 0);
   assert((bh & 0x07) == 0);
@@ -4081,7 +4081,6 @@
   BLOCK_SIZE txm_bsize = txsize_to_bsize[tx_size];
   int bh = block_size_high[txm_bsize];
   int bw = block_size_wide[txm_bsize];
-
   int src_stride = p->src.stride;
   uint8_t *src =
       &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
@@ -4137,6 +4136,23 @@
   av1_inverse_transform_block(xd, dqcoeff, tx_type, tx_size, rec_buffer,
                               MAX_TX_SIZE, eob);
   if (eob > 0) {
+#if CONFIG_DAALA_DIST
+    if (plane == 0 && (bw < 8 && bh < 8)) {
+      // Save this sub8x8 block's decoded luma pixels: the full 8x8
+      // decoded pixels will no longer be available to daala-dist once
+      // the recursive split of BLOCK_8x8 is done.
+      const int pred_stride = block_size_wide[plane_bsize];
+      const int pred_idx = (blk_row * pred_stride + blk_col)
+                           << tx_size_wide_log2[0];
+      int16_t *decoded = &pd->pred[pred_idx];
+      int i, j;
+
+      // TODO(yushin): HBD support
+      for (j = 0; j < bh; j++)
+        for (i = 0; i < bw; i++)
+          decoded[j * pred_stride + i] = rec_buffer[j * MAX_TX_SIZE + i];
+    }
+#endif  // CONFIG_DAALA_DIST
     tmp = pixel_dist(cpi, x, plane, src, src_stride, rec_buffer, MAX_TX_SIZE,
                      blk_row, blk_col, plane_bsize, txm_bsize);
   }
@@ -4249,7 +4265,9 @@
     RD_STATS this_rd_stats;
     int this_cost_valid = 1;
     int64_t tmp_rd = 0;
-
+#if CONFIG_DAALA_DIST
+    int sub8x8_eob[4];
+#endif
     sum_rd_stats.rate =
         av1_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 1);
 
@@ -4265,13 +4283,76 @@
                       depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
                       &this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
                       rd_stats_stack);
-
+#if CONFIG_DAALA_DIST
+      if (plane == 0 && tx_size == TX_8X8) {
+        sub8x8_eob[i] = p->eobs[block];
+      }
+#endif  // CONFIG_DAALA_DIST
       av1_merge_rd_stats(&sum_rd_stats, &this_rd_stats);
 
       tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
+#if !CONFIG_DAALA_DIST
       if (this_rd < tmp_rd) break;
+#endif
       block += sub_step;
     }
+#if CONFIG_DAALA_DIST
+    if (this_cost_valid && plane == 0 && tx_size == TX_8X8) {
+      const int src_stride = p->src.stride;
+      const int dst_stride = pd->dst.stride;
+
+      const uint8_t *src =
+          &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
+      const uint8_t *dst =
+          &pd->dst
+               .buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
+
+      int64_t daala_dist;
+      int qindex = x->qindex;
+      const int pred_stride = block_size_wide[plane_bsize];
+      const int pred_idx = (blk_row * pred_stride + blk_col)
+                           << tx_size_wide_log2[0];
+      int16_t *pred = &pd->pred[pred_idx];
+      int j;
+      int qm = OD_HVS_QM;
+      int use_activity_masking = 0;
+      int row, col;
+
+      DECLARE_ALIGNED(16, uint8_t, pred8[8 * 8]);
+
+#if CONFIG_PVQ
+      use_activity_masking = x->daala_enc.use_activity_masking;
+#endif
+      daala_dist = av1_daala_dist(src, src_stride, dst, dst_stride, 8, 8, 8, 8,
+                                  qm, use_activity_masking, qindex) *
+                   16;
+      sum_rd_stats.sse = daala_dist;
+
+      for (row = 0; row < 2; ++row) {
+        for (col = 0; col < 2; ++col) {
+          int idx = row * 2 + col;
+          int eob = sub8x8_eob[idx];
+
+          if (eob > 0) {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                    pred[(row * 4 + j) * pred_stride + 4 * col + i];
+          } else {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                    dst[(row * 4 + j) * dst_stride + 4 * col + i];
+          }
+        }
+      }
+      daala_dist = av1_daala_dist(src, src_stride, pred8, 8, 8, 8, 8, 8, qm,
+                                  use_activity_masking, qindex) *
+                   16;
+      sum_rd_stats.dist = daala_dist;
+      tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
+    }
+#endif  // CONFIG_DAALA_DIST
     if (this_cost_valid) sum_rd = tmp_rd;
   }