select_tx_block: function for split trial

Change-Id: Iffe455375ed83e0a13d5b86d4d66dbcd13c9adaf
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index dba6ca0..14b157e 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4388,6 +4388,174 @@
   no_split->tx_type = mbmi->txk_type[txk_type_idx];
 }
 
+static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
+                            int blk_col, int block, TX_SIZE tx_size, int depth,
+                            BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
+                            ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above,
+                            TXFM_CONTEXT *tx_left, RD_STATS *rd_stats,
+                            int64_t ref_best_rd, int *is_cost_valid,
+                            FAST_TX_SEARCH_MODE ftxs_mode,
+                            TXB_RD_INFO_NODE *rd_info_node);
+
+static void try_tx_block_split(
+    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
+    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
+    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
+    FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
+    RD_STATS *split_rd_stats, int64_t *split_rd) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
+  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
+  struct macroblock_plane *const p = &x->plane[0];
+  const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
+  const int bsw = tx_size_wide_unit[sub_txs];
+  const int bsh = tx_size_high_unit[sub_txs];
+  const int sub_step = bsw * bsh;
+  RD_STATS this_rd_stats;
+  int this_cost_valid = 1;
+  int64_t tmp_rd = 0;
+#if CONFIG_DIST_8X8
+  int sub8x8_eob[4] = { 0, 0, 0, 0 };
+  struct macroblockd_plane *const pd = &xd->plane[0];
+#endif
+  split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];
+
+  assert(tx_size < TX_SIZES_ALL);
+
+  int blk_idx = 0;
+  for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+    for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) {
+      const int offsetr = blk_row + r;
+      const int offsetc = blk_col + c;
+      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+      assert(blk_idx < 4);
+      select_tx_block(
+          cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
+          tl, tx_above, tx_left, &this_rd_stats, ref_best_rd - tmp_rd,
+          &this_cost_valid, ftxs_mode,
+          (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
+
+#if CONFIG_DIST_8X8
+      if (!x->using_dist_8x8)
+#endif
+        if (!this_cost_valid) goto LOOP_EXIT;
+#if CONFIG_DIST_8X8
+      if (x->using_dist_8x8 && tx_size == TX_8X8) {
+        sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block];
+      }
+#endif  // CONFIG_DIST_8X8
+      av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
+
+      tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
+#if CONFIG_DIST_8X8
+      if (!x->using_dist_8x8)
+#endif
+        if (no_split_rd < tmp_rd) {
+          this_cost_valid = 0;
+          goto LOOP_EXIT;
+        }
+      block += sub_step;
+    }
+  }
+
+LOOP_EXIT : {}
+
+#if CONFIG_DIST_8X8
+  if (x->using_dist_8x8 && this_cost_valid && tx_size == TX_8X8) {
+    const int src_stride = p->src.stride;
+    const int dst_stride = pd->dst.stride;
+
+    const uint8_t *src =
+        &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
+    const uint8_t *dst =
+        &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
+
+    int64_t dist_8x8;
+    const int qindex = x->qindex;
+    const int pred_stride = block_size_wide[plane_bsize];
+    const int pred_idx = (blk_row * pred_stride + blk_col)
+                         << tx_size_wide_log2[0];
+    const int16_t *pred = &x->pred_luma[pred_idx];
+    int i, j;
+    int row, col;
+
+    uint8_t *pred8;
+    DECLARE_ALIGNED(16, uint16_t, pred8_16[8 * 8]);
+
+    dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride, BLOCK_8X8,
+                            8, 8, 8, 8, qindex) *
+               16;
+
+#ifdef DEBUG_DIST_8X8
+    if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8)
+      assert(sum_rd_stats.sse == dist_8x8);
+#endif  // DEBUG_DIST_8X8
+
+    split_rd_stats->sse = dist_8x8;
+
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+      pred8 = CONVERT_TO_BYTEPTR(pred8_16);
+    else
+      pred8 = (uint8_t *)pred8_16;
+
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      for (row = 0; row < 2; ++row) {
+        for (col = 0; col < 2; ++col) {
+          int idx = row * 2 + col;
+          int eob = sub8x8_eob[idx];
+
+          if (eob > 0) {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                CONVERT_TO_SHORTPTR(pred8)
+                [(row * 4 + j) * 8 + 4 * col + i] =
+                    pred[(row * 4 + j) * pred_stride + 4 * col + i];
+          } else {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                CONVERT_TO_SHORTPTR(pred8)
+                [(row * 4 + j) * 8 + 4 * col + i] = CONVERT_TO_SHORTPTR(
+                    dst)[(row * 4 + j) * dst_stride + 4 * col + i];
+          }
+        }
+      }
+    } else {
+      for (row = 0; row < 2; ++row) {
+        for (col = 0; col < 2; ++col) {
+          int idx = row * 2 + col;
+          int eob = sub8x8_eob[idx];
+
+          if (eob > 0) {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                    (uint8_t)pred[(row * 4 + j) * pred_stride + 4 * col + i];
+          } else {
+            for (j = 0; j < 4; j++)
+              for (i = 0; i < 4; i++)
+                pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                    dst[(row * 4 + j) * dst_stride + 4 * col + i];
+          }
+        }
+      }
+    }
+    dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, pred8, 8, BLOCK_8X8, 8, 8,
+                            8, 8, qindex) *
+               16;
+
+#ifdef DEBUG_DIST_8X8
+    if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8)
+      assert(sum_rd_stats.dist == dist_8x8);
+#endif  // DEBUG_DIST_8X8
+
+    split_rd_stats->dist = dist_8x8;
+    tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
+  }
+#endif  // CONFIG_DIST_8X8
+  if (this_cost_valid) *split_rd = tmp_rd;
+}
+
 // Search for the best tx partition/type for a given luma block.
 static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                             int blk_col, int block, TX_SIZE tx_size, int depth,
@@ -4458,155 +4626,10 @@
   RD_STATS split_rd_stats;
   av1_init_rd_stats(&split_rd_stats);
   if (try_split) {
-    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsw = tx_size_wide_unit[sub_txs];
-    const int bsh = tx_size_high_unit[sub_txs];
-    const int sub_step = bsw * bsh;
-    RD_STATS this_rd_stats;
-    int this_cost_valid = 1;
-    int64_t tmp_rd = 0;
-#if CONFIG_DIST_8X8
-    int sub8x8_eob[4] = { 0, 0, 0, 0 };
-    struct macroblockd_plane *const pd = &xd->plane[0];
-#endif
-    split_rd_stats.rate = x->txfm_partition_cost[ctx][1];
-
-    assert(tx_size < TX_SIZES_ALL);
-
-    ref_best_rd = AOMMIN(no_split.rd, ref_best_rd);
-
-    int blk_idx = 0;
-    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
-      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) {
-        const int offsetr = blk_row + r;
-        const int offsetc = blk_col + c;
-        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-        assert(blk_idx < 4);
-        select_tx_block(
-            cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize,
-            ta, tl, tx_above, tx_left, &this_rd_stats, ref_best_rd - tmp_rd,
-            &this_cost_valid, ftxs_mode,
-            (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
-
-#if CONFIG_DIST_8X8
-        if (!x->using_dist_8x8)
-#endif
-          if (!this_cost_valid) goto LOOP_EXIT;
-#if CONFIG_DIST_8X8
-        if (x->using_dist_8x8 && tx_size == TX_8X8) {
-          sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block];
-        }
-#endif  // CONFIG_DIST_8X8
-        av1_merge_rd_stats(&split_rd_stats, &this_rd_stats);
-
-        tmp_rd = RDCOST(x->rdmult, split_rd_stats.rate, split_rd_stats.dist);
-#if CONFIG_DIST_8X8
-        if (!x->using_dist_8x8)
-#endif
-          if (no_split.rd < tmp_rd) {
-            this_cost_valid = 0;
-            goto LOOP_EXIT;
-          }
-        block += sub_step;
-      }
-    }
-
-  LOOP_EXIT : {}
-
-#if CONFIG_DIST_8X8
-    if (x->using_dist_8x8 && this_cost_valid && tx_size == TX_8X8) {
-      const int src_stride = p->src.stride;
-      const int dst_stride = pd->dst.stride;
-
-      const uint8_t *src =
-          &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
-      const uint8_t *dst =
-          &pd->dst
-               .buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
-
-      int64_t dist_8x8;
-      const int qindex = x->qindex;
-      const int pred_stride = block_size_wide[plane_bsize];
-      const int pred_idx = (blk_row * pred_stride + blk_col)
-                           << tx_size_wide_log2[0];
-      const int16_t *pred = &x->pred_luma[pred_idx];
-      int i, j;
-      int row, col;
-
-      uint8_t *pred8;
-      DECLARE_ALIGNED(16, uint16_t, pred8_16[8 * 8]);
-
-      dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride,
-                              BLOCK_8X8, 8, 8, 8, 8, qindex) *
-                 16;
-
-#ifdef DEBUG_DIST_8X8
-      if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8)
-        assert(sum_rd_stats.sse == dist_8x8);
-#endif  // DEBUG_DIST_8X8
-
-      split_rd_stats.sse = dist_8x8;
-
-      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-        pred8 = CONVERT_TO_BYTEPTR(pred8_16);
-      else
-        pred8 = (uint8_t *)pred8_16;
-
-      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-        for (row = 0; row < 2; ++row) {
-          for (col = 0; col < 2; ++col) {
-            int idx = row * 2 + col;
-            int eob = sub8x8_eob[idx];
-
-            if (eob > 0) {
-              for (j = 0; j < 4; j++)
-                for (i = 0; i < 4; i++)
-                  CONVERT_TO_SHORTPTR(pred8)
-                  [(row * 4 + j) * 8 + 4 * col + i] =
-                      pred[(row * 4 + j) * pred_stride + 4 * col + i];
-            } else {
-              for (j = 0; j < 4; j++)
-                for (i = 0; i < 4; i++)
-                  CONVERT_TO_SHORTPTR(pred8)
-                  [(row * 4 + j) * 8 + 4 * col + i] = CONVERT_TO_SHORTPTR(
-                      dst)[(row * 4 + j) * dst_stride + 4 * col + i];
-            }
-          }
-        }
-      } else {
-        for (row = 0; row < 2; ++row) {
-          for (col = 0; col < 2; ++col) {
-            int idx = row * 2 + col;
-            int eob = sub8x8_eob[idx];
-
-            if (eob > 0) {
-              for (j = 0; j < 4; j++)
-                for (i = 0; i < 4; i++)
-                  pred8[(row * 4 + j) * 8 + 4 * col + i] =
-                      (uint8_t)pred[(row * 4 + j) * pred_stride + 4 * col + i];
-            } else {
-              for (j = 0; j < 4; j++)
-                for (i = 0; i < 4; i++)
-                  pred8[(row * 4 + j) * 8 + 4 * col + i] =
-                      dst[(row * 4 + j) * dst_stride + 4 * col + i];
-            }
-          }
-        }
-      }
-      dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, pred8, 8, BLOCK_8X8, 8,
-                              8, 8, 8, qindex) *
-                 16;
-
-#ifdef DEBUG_DIST_8X8
-      if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8)
-        assert(sum_rd_stats.dist == dist_8x8);
-#endif  // DEBUG_DIST_8X8
-
-      split_rd_stats.dist = dist_8x8;
-      tmp_rd = RDCOST(x->rdmult, split_rd_stats.rate, split_rd_stats.dist);
-    }
-#endif  // CONFIG_DIST_8X8
-    if (this_cost_valid) split_rd = tmp_rd;
+    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
+                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
+                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
+                       rd_info_node, &split_rd_stats, &split_rd);
   }
 
 #if COLLECT_TX_SIZE_DATA