Documentation for transform type RDO

BUG=aomedia:2617

Change-Id: I66de4213d82b3a03b8674d83f8e06628ed01343f
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 277ad54..be27670 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -1079,7 +1079,7 @@
                                int block, int blk_row, int blk_col,
                                BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                                const TXB_CTX *const txb_ctx, int skip_trellis,
-                               TX_TYPE best_tx_type, TX_TYPE last_tx_type,
+                               TX_TYPE best_tx_type, int do_quant,
                                int *rate_cost, uint16_t best_eob) {
   const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
@@ -1088,11 +1088,9 @@
   if (!is_inter && best_eob &&
       (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
        blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
-    // intra mode needs decoded result such that the next transform block
-    // can use it for prediction.
-    // if the last search tx_type is the best tx_type, we don't need to
-    // do this again
-    if (best_tx_type != last_tx_type) {
+    // If the quantized coefficients are already stored in the dqcoeff buffer
+    // (do_quant == 0), we don't need to do transform and quantization again.
+    if (do_quant) {
       TxfmParam txfm_param_intra;
       QUANT_PARAM quant_param_intra;
       av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
@@ -1259,32 +1257,30 @@
 static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
 static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100
 
-static INLINE int is_intra_hash_match(
-    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int blk_row, int blk_col,
-    BLOCK_SIZE plane_bsize, TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
-    TXB_RD_INFO **intra_txb_rd_info, int within_border,
-    const int tx_type_map_idx, uint16_t *cur_joint_ctx) {
-  const AV1_COMMON *cm = &cpi->common;
+static INLINE int is_intra_hash_match(const AV1_COMP *cpi, MACROBLOCK *x,
+                                      int plane, int blk_row, int blk_col,
+                                      BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+                                      const TXB_CTX *const txb_ctx,
+                                      TXB_RD_INFO **intra_txb_rd_info,
+                                      const int tx_type_map_idx,
+                                      uint16_t *cur_joint_ctx) {
   MACROBLOCKD *xd = &x->e_mbd;
-  MB_MODE_INFO *mbmi = xd->mi[0];
-  const int is_inter = is_inter_block(mbmi);
-  if (within_border && cpi->sf.tx_sf.use_intra_txb_hash &&
-      frame_is_intra_only(cm) && !is_inter && plane == 0 &&
-      tx_size_wide[tx_size] == tx_size_high[tx_size]) {
-    const uint32_t intra_hash =
-        get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
-    const int intra_hash_idx =
-        find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
-    *intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
-    *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
-    if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
-        x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
-      xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
-      const TX_TYPE ref_tx_type =
-          av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
-                          cpi->common.features.reduced_tx_set_used);
-      return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
-    }
+  assert(cpi->sf.tx_sf.use_intra_txb_hash &&
+         frame_is_intra_only(&cpi->common) && !is_inter_block(xd->mi[0]) &&
+         plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size]);
+  const uint32_t intra_hash =
+      get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
+  const int intra_hash_idx =
+      find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
+  *intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
+  *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
+  if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
+      x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
+    xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
+    const TX_TYPE ref_tx_type =
+        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
+                        cpi->common.features.reduced_tx_set_used);
+    return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
   }
   return 0;
 }
@@ -2082,13 +2078,15 @@
   return cost;
 }
 
-static void search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
-                            int block, int blk_row, int blk_col,
-                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
-                            const TXB_CTX *const txb_ctx,
-                            FAST_TX_SEARCH_MODE ftxs_mode,
-                            int use_fast_coef_costing, int skip_trellis,
-                            int64_t ref_best_rd, RD_STATS *best_rd_stats) {
+// Search for the best transform type for a given transform block.
+// This function can be used for both inter and intra, both luma and chroma.
+static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                           int block, int blk_row, int blk_col,
+                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+                           const TXB_CTX *const txb_ctx,
+                           FAST_TX_SEARCH_MODE ftxs_mode,
+                           int use_fast_coef_costing, int skip_trellis,
+                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {
   const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
   struct macroblockd_plane *const pd = &xd->plane[plane];
@@ -2096,7 +2094,6 @@
   int64_t best_rd = INT64_MAX;
   uint16_t best_eob = 0;
   TX_TYPE best_tx_type = DCT_DCT;
-  TX_TYPE last_tx_type = TX_TYPES;
   int rate_cost = 0;
   // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
   // of the best tx_type
@@ -2105,78 +2102,86 @@
   tran_low_t *best_dqcoeff = this_dqcoeff;
   const int tx_type_map_idx =
       plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
-  int perform_block_coeff_opt = 0;
   av1_invalid_rd_stats(best_rd_stats);
 
-  TXB_RD_INFO *intra_txb_rd_info = NULL;
-  uint16_t cur_joint_ctx = 0;
-  const int mi_row = xd->mi_row;
-  const int mi_col = xd->mi_col;
-  const int within_border =
-      mi_row >= xd->tile.mi_row_start &&
-      (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
-      mi_col >= xd->tile.mi_col_start &&
-      (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
-
   skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
                                    DRY_RUN_NORMAL);
-  if (is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize, tx_size,
-                          txb_ctx, &intra_txb_rd_info, within_border,
-                          tx_type_map_idx, &cur_joint_ctx)) {
-    best_rd_stats->rate = intra_txb_rd_info->rate;
-    best_rd_stats->dist = intra_txb_rd_info->dist;
-    best_rd_stats->sse = intra_txb_rd_info->sse;
-    best_rd_stats->skip = intra_txb_rd_info->eob == 0;
-    x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
-    x->plane[plane].txb_entropy_ctx[block] = intra_txb_rd_info->txb_entropy_ctx;
-    best_eob = intra_txb_rd_info->eob;
-    best_tx_type = intra_txb_rd_info->tx_type;
-    perform_block_coeff_opt = intra_txb_rd_info->perform_block_coeff_opt;
-    skip_trellis |= !perform_block_coeff_opt;
-    update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
-    recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
-                txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
-                best_eob);
-    pd->dqcoeff = orig_dqcoeff;
-    return;
+
+  // Hashing based speed feature for intra block. If the hash of the residue
+  // is found in the hash table, use the previous RD search results stored in
+  // the table and terminate early.
+  TXB_RD_INFO *intra_txb_rd_info = NULL;
+  uint16_t cur_joint_ctx = 0;
+  const int is_inter = is_inter_block(mbmi);
+  const int use_intra_txb_hash =
+      cpi->sf.tx_sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
+      !is_inter && plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size];
+  if (use_intra_txb_hash) {
+    const int mi_row = xd->mi_row;
+    const int mi_col = xd->mi_col;
+    const int within_border =
+        mi_row >= xd->tile.mi_row_start &&
+        (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
+        mi_col >= xd->tile.mi_col_start &&
+        (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
+    if (within_border &&
+        is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize,
+                            tx_size, txb_ctx, &intra_txb_rd_info,
+                            tx_type_map_idx, &cur_joint_ctx)) {
+      best_rd_stats->rate = intra_txb_rd_info->rate;
+      best_rd_stats->dist = intra_txb_rd_info->dist;
+      best_rd_stats->sse = intra_txb_rd_info->sse;
+      best_rd_stats->skip = intra_txb_rd_info->eob == 0;
+      x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
+      x->plane[plane].txb_entropy_ctx[block] =
+          intra_txb_rd_info->txb_entropy_ctx;
+      best_eob = intra_txb_rd_info->eob;
+      best_tx_type = intra_txb_rd_info->tx_type;
+      skip_trellis |= !intra_txb_rd_info->perform_block_coeff_opt;
+      update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
+      recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+                  txb_ctx, skip_trellis, best_tx_type, 1, &rate_cost, best_eob);
+      pd->dqcoeff = orig_dqcoeff;
+      return;
+    }
   }
 
   uint8_t best_txb_ctx = 0;
+  // txk_allowed = TX_TYPES: >1 tx types are allowed
+  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
   TX_TYPE txk_allowed = TX_TYPES;
   int txk_map[TX_TYPES] = {
     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
   };
-  uint16_t allowed_tx_mask =
+  // Bit mask to indicate which transform types are allowed in the RD search.
+  const uint16_t allowed_tx_mask =
       get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                   txb_ctx, ftxs_mode, ref_best_rd, &txk_allowed, txk_map);
 
-  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
-  int64_t block_sse = 0;
-  unsigned int block_mse_q8 = UINT_MAX;
-  block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize,
-                              &block_mse_q8);
+  unsigned int block_mse_q8;
+  int64_t block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
+                                      txsize_to_bsize[tx_size], &block_mse_q8);
   assert(block_mse_q8 != UINT_MAX);
   if (is_cur_buf_hbd(xd)) {
     block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
     block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
   }
   block_sse *= 16;
-
   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
-
   // Use mse / qstep^2 based threshold logic to take decision of R-D
   // optimization of coeffs. For smaller residuals, coeff optimization
   // would be helpful. For larger residuals, R-D optimization may not be
   // effective.
   // TODO(any): Experiment with variance and mean based thresholds
-  perform_block_coeff_opt =
+  const int perform_block_coeff_opt =
       ((uint64_t)block_mse_q8 <=
        (uint64_t)x->coeff_opt_dist_threshold * qstep * qstep);
-
   skip_trellis |= !perform_block_coeff_opt;
 
-  // Tranform domain distortion is accurate for higher residuals.
+  // Flag to indicate if distortion should be calculated in transform domain or
+  // not during iterating through transform type candidates.
+  // Transform domain distortion is accurate for higher residuals.
   // TODO(any): Experiment with variance and mean based thresholds
   int use_transform_domain_distortion =
       (x->use_transform_domain_distortion > 0) &&
@@ -2185,6 +2190,9 @@
       // Therefore transform domain distortion is not valid for these
       // transform sizes.
       txsize_sqr_up_map[tx_size] != TX_64X64;
+  // Flag to indicate if an extra calculation of distortion in the pixel domain
+  // should be performed at the end, after the best transform type has been
+  // decided.
   int calc_pixel_domain_distortion_final =
       x->use_transform_domain_distortion == 1 &&
       use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
@@ -2203,6 +2211,7 @@
                                : AV1_XFORM_QUANT_FP,
                   cpi->use_quant_b_adapt, &quant_param);
 
+  // Iterate through all transform type candidates.
   for (int idx = 0; idx < TX_TYPES; ++idx) {
     const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
     if (!(allowed_tx_mask & (1 << tx_type))) continue;
@@ -2218,6 +2227,7 @@
     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                     &quant_param);
 
+    // Calculate rate cost of quantized coefficients.
     if (quant_param.use_optimize_b) {
       if (cpi->sf.rd_sf.optimize_b_precheck && best_rd < INT64_MAX &&
           eobs_ptr[block] >= 4) {
@@ -2238,10 +2248,11 @@
                                   cm->features.reduced_tx_set_used);
     }
 
-    // If rd cost based on coeff rate is more than best_rd, skip the calculation
-    // of distortion
-    int64_t tmp_rd = RDCOST(x->rdmult, rate_cost, 0);
-    if (tmp_rd > best_rd) continue;
+    // If rd cost based on coeff rate alone is already more than best_rd,
+    // skip the distortion calculation and move on to the next tx type.
+    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
+
+    // Calculate distortion.
     if (eobs_ptr[block] == 0) {
       // When eob is 0, pixel domain distortion is more efficient and accurate.
       this_rd_stats.dist = this_rd_stats.sse = block_sse;
@@ -2279,6 +2290,7 @@
         if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
           this_rd_stats.dist = tx_domain_dist;
       } else {
+        assert(sse_diff < INT64_MAX);
         this_rd_stats.dist += sse_diff;
       }
       this_rd_stats.sse = block_sse;
@@ -2295,9 +2307,7 @@
       best_tx_type = tx_type;
       best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
       best_eob = x->plane[plane].eobs[block];
-      last_tx_type = best_tx_type;
-
-      // Swap qcoeff and dqcoeff buffers
+      // Swap dqcoeff buffers
       tran_low_t *const tmp_dqcoeff = best_dqcoeff;
       best_dqcoeff = pd->dqcoeff;
       pd->dqcoeff = tmp_dqcoeff;
@@ -2345,6 +2355,8 @@
     }
 #endif  // COLLECT_TX_SIZE_DATA
 
+    // If the current best RD cost is much worse than the reference RD cost,
+    // terminate early.
     if (cpi->sf.tx_sf.adaptive_txb_search_level) {
       if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
           ref_best_rd) {
@@ -2352,8 +2364,8 @@
       }
     }
 
-    // Skip transform type search when we found the block has been quantized to
-    // all zero and at the same time, it has better rdcost than doing transform.
+    // Terminate transform type search if the block has been quantized to
+    // all zero.
     if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
   }
 
@@ -2364,6 +2376,9 @@
   x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
   x->plane[plane].eobs[block] = best_eob;
 
+  // Point dqcoeff to the quantized coefficients corresponding to the best
+  // transform type so that transform and quantization can be skipped, e.g. in
+  // the final pixel domain distortion calculation and recon_intra().
   pd->dqcoeff = best_dqcoeff;
 
   if (calc_pixel_domain_distortion_final && best_eob) {
@@ -2384,9 +2399,10 @@
     if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
   }
 
+  // Intra mode needs decoded pixels such that the next transform block
+  // can use them for prediction.
   recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
-              txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
-              best_eob);
+              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
   pd->dqcoeff = orig_dqcoeff;
 }
 
@@ -2426,9 +2442,9 @@
 
   RD_STATS this_rd_stats;
   const int skip_trellis = 0;
-  search_txk_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
-                  txb_ctx, ftxs_mode, 0, skip_trellis, ref_rdcost,
-                  &this_rd_stats);
+  search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
+                 txb_ctx, ftxs_mode, 0, skip_trellis, ref_rdcost,
+                 &this_rd_stats);
 
   av1_merge_rd_stats(rd_stats, &this_rd_stats);
 
@@ -2806,10 +2822,10 @@
   }
   TXB_CTX txb_ctx;
   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
-  search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
-                  &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
-                  args->skip_trellis, args->best_rd - args->this_rd,
-                  &this_rd_stats);
+  search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+                 &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
+                 args->skip_trellis, args->best_rd - args->this_rd,
+                 &this_rd_stats);
 
   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
     assert(!is_inter || plane_bsize < BLOCK_8X8);