Inline choose_tx_size_fix_type into choose_tx_size_type_from_rd

This CL actually makes the code more complicated, but it will allow
us to break away from the framework of searching tx_size with a fixed
tx_type.

I will find a way to simplify the code later.

Change-Id: Iae933a40d0c7eb9ec65b34ebfd9d543423f304aa
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 404a908..ae58a76 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1929,6 +1929,9 @@
     // transforms should be considered for pruning
     prune = prune_tx_types(cpi, bs, x, xd, -1);
 
+#if CONFIG_REF_MV
+  if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) return 1;
+#endif  // CONFIG_REF_MV
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
   if (!is_rect_tx(tx_size)) return 1;
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -1965,97 +1968,6 @@
   return 0;
 }
 
-static int64_t choose_tx_size_fix_type(const AV1_COMP *const cpi, BLOCK_SIZE bs,
-                                       MACROBLOCK *x, RD_STATS *rd_stats,
-                                       int64_t ref_best_rd, TX_TYPE tx_type
-#if CONFIG_PVQ
-                                       ,
-                                       od_rollback_buffer buf
-#endif  // CONFIG_PVQ
-                                       ) {
-  const AV1_COMMON *const cm = &cpi->common;
-  MACROBLOCKD *const xd = &x->e_mbd;
-  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
-  int64_t rd = INT64_MAX;
-  int n;
-  int start_tx, end_tx;
-  int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
-  TX_SIZE best_tx_size = max_tx_size;
-  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
-  const int is_inter = is_inter_block(mbmi);
-#if CONFIG_EXT_TX
-#if CONFIG_RECT_TX
-  int evaluate_rect_tx = 0;
-#endif  // CONFIG_RECT_TX
-  int ext_tx_set;
-#endif  // CONFIG_EXT_TX
-
-  if (tx_select) {
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-    evaluate_rect_tx = is_rect_tx_allowed(xd, mbmi);
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
-    start_tx = max_tx_size;
-    end_tx = (max_tx_size >= TX_32X32) ? TX_8X8 : TX_4X4;
-  } else {
-    const TX_SIZE chosen_tx_size =
-        tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-    evaluate_rect_tx = is_rect_tx(chosen_tx_size);
-    assert(IMPLIES(evaluate_rect_tx, is_rect_tx_allowed(xd, mbmi)));
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
-    start_tx = chosen_tx_size;
-    end_tx = chosen_tx_size;
-  }
-
-  av1_invalid_rd_stats(rd_stats);
-
-  mbmi->tx_type = tx_type;
-
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-  if (evaluate_rect_tx) {
-    const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
-    RD_STATS this_rd_stats;
-    ext_tx_set =
-        get_ext_tx_set(rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
-    if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
-        (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
-      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
-                    rect_tx_size);
-      best_tx_size = rect_tx_size;
-      best_rd = rd;
-      *rd_stats = this_rd_stats;
-    }
-  }
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
-
-  last_rd = INT64_MAX;
-  for (n = start_tx; n >= end_tx; --n) {
-    RD_STATS this_rd_stats;
-    if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
-    rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
-#if CONFIG_PVQ
-    od_encode_rollback(&x->daala_enc, &buf);
-#endif  // CONFIG_PVQ
-    // Early termination in transform size search.
-    if (cpi->sf.tx_size_search_breakout &&
-        (rd == INT64_MAX ||
-         (this_rd_stats.skip == 1 && tx_type != DCT_DCT && n < start_tx) ||
-         (n < (int)max_tx_size && rd > last_rd)))
-      break;
-
-    last_rd = rd;
-    if (rd < best_rd) {
-      best_tx_size = n;
-      best_rd = rd;
-      *rd_stats = this_rd_stats;
-    }
-  }
-  mbmi->tx_size = best_tx_size;
-
-  return best_rd;
-}
-
 #if CONFIG_EXT_INTER
 static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
                                    MACROBLOCK *x, int *r, int64_t *d, int *s,
@@ -2256,46 +2168,123 @@
 static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
                                         MACROBLOCK *x, RD_STATS *rd_stats,
                                         int64_t ref_best_rd, BLOCK_SIZE bs) {
+  // Searches tx_size and tx_type jointly by rd cost (txfm_yrd()). The best
+  // pair is stored in mbmi->tx_size / mbmi->tx_type, and *rd_stats receives
+  // its stats.
+  const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int64_t rd = INT64_MAX;
-  int64_t best_rd = INT64_MAX;
-  TX_SIZE best_tx = max_txsize_lookup[bs];
-  TX_TYPE tx_type, best_tx_type = DCT_DCT;
-
+  int n;
+  int start_tx, end_tx;
+  int64_t best_rd = INT64_MAX;
+  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
+  TX_SIZE best_tx_size = max_tx_size;
+  TX_TYPE tx_type, best_tx_type = DCT_DCT;
+  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+  const int is_inter = is_inter_block(mbmi);
+  // Early-termination state of the size search, kept per tx_type: the rd
+  // cost at the previously evaluated (larger) size, and whether the size
+  // search for that tx_type has already been stopped.
+  int64_t last_rd[TX_TYPES];
+  int size_search_done[TX_TYPES];
+#if CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
+  // Only DCT_DCT is searched for sub-8x8 inter blocks.
+  const int dct_only = mbmi->sb_type < BLOCK_8X8 && is_inter;
+#else
+  const int dct_only = 0;
+#endif  // CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
 #if CONFIG_PVQ
   od_rollback_buffer buf;
-#endif  // CONFIG_PVQ
-  av1_invalid_rd_stats(rd_stats);
-
-#if CONFIG_PVQ
   od_encode_checkpoint(&x->daala_enc, &buf);
 #endif  // CONFIG_PVQ
-  for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
-    RD_STATS this_rd_stats;
+
+  av1_invalid_rd_stats(rd_stats);
+
+  if (tx_select) {
+    start_tx = max_tx_size;
+    end_tx = (max_tx_size >= TX_32X32) ? TX_8X8 : TX_4X4;
+  } else {
+    const TX_SIZE chosen_tx_size =
+        tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
+    start_tx = chosen_tx_size;
+    end_tx = chosen_tx_size;
+  }
+
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  int evaluate_rect_tx;
+  if (tx_select) {
+    evaluate_rect_tx = is_rect_tx_allowed(xd, mbmi);
+  } else {
+    // tx_mode fixes the size; search it here only if it is rectangular.
+    evaluate_rect_tx = is_rect_tx(start_tx);
+    assert(IMPLIES(evaluate_rect_tx, is_rect_tx_allowed(xd, mbmi)));
+  }
+  if (evaluate_rect_tx) {
+    const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
+    const int ext_tx_set =
+        get_ext_tx_set(rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
+    for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+      RD_STATS this_rd_stats;
+      if (dct_only && tx_type != DCT_DCT) break;
 #if CONFIG_REF_MV
-    if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
+      if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
 #endif  // CONFIG_REF_MV
-    rd =
-        choose_tx_size_fix_type(cpi, bs, x, &this_rd_stats, ref_best_rd, tx_type
-#if CONFIG_PVQ
-                                ,
-                                buf
-#endif  // CONFIG_PVQ
-                                );
-    if (rd < best_rd) {
-      best_rd = rd;
-      *rd_stats = this_rd_stats;
-      best_tx_type = tx_type;
-      best_tx = mbmi->tx_size;
-    }
-#if CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
-    const int is_inter = is_inter_block(mbmi);
-    if (mbmi->sb_type < BLOCK_8X8 && is_inter) break;
-#endif  // CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
+      if (is_inter ? !ext_tx_used_inter[ext_tx_set][tx_type]
+                   : !ext_tx_used_intra[ext_tx_set][tx_type])
+        continue;
+      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
+                    rect_tx_size);
+#if CONFIG_PVQ
+      od_encode_rollback(&x->daala_enc, &buf);
+#endif  // CONFIG_PVQ
+      if (rd < best_rd) {
+        best_tx_type = tx_type;
+        best_tx_size = rect_tx_size;
+        best_rd = rd;
+        *rd_stats = this_rd_stats;
+      }
+    }
+  }
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
+
+  for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+    last_rd[tx_type] = INT64_MAX;
+    size_search_done[tx_type] = 0;
+  }
+
+  for (n = start_tx; n >= end_tx; --n) {
+    for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+      RD_STATS this_rd_stats;
+      if (dct_only && tx_type != DCT_DCT) break;
+      if (size_search_done[tx_type]) continue;
+      if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
+      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
+#if CONFIG_PVQ
+      od_encode_rollback(&x->daala_enc, &buf);
+#endif  // CONFIG_PVQ
+      // Early termination in transform size search. The state is kept per
+      // tx_type, so this only stops the search of smaller sizes for the
+      // current tx_type, exactly like the former per-type size loop did.
+      if (cpi->sf.tx_size_search_breakout &&
+          (rd == INT64_MAX ||
+           (this_rd_stats.skip == 1 && tx_type != DCT_DCT && n < start_tx) ||
+           (n < (int)max_tx_size && rd > last_rd[tx_type]))) {
+        size_search_done[tx_type] = 1;
+        continue;
+      }
+
+      last_rd[tx_type] = rd;
+      if (rd < best_rd) {
+        best_tx_type = tx_type;
+        best_tx_size = n;
+        best_rd = rd;
+        *rd_stats = this_rd_stats;
+      }
+    }
   }
 
-  mbmi->tx_size = best_tx;
+  mbmi->tx_size = best_tx_size;
   mbmi->tx_type = best_tx_type;
 
 #if CONFIG_VAR_TX
@@ -2307,7 +2296,7 @@
 #endif  // !CONFIG_EXT_TX
 #if CONFIG_PVQ
   if (best_rd != INT64_MAX) {
-    txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs, best_tx_type, best_tx);
+    txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs, best_tx_type, best_tx_size);
   }
 #endif  // CONFIG_PVQ
 }