Fix enable-tx64 command line option

Fix an issue where the choose_largest_tx_size path does not take
enable_tx64 into account, which effectively disables block sizes >=
BLOCK_64X64 when TX64 is turned off.

This issue probably still persists with chroma 4:4:4 images even after
this commit, so disabling TX64 is now rejected for non-I420 input.

Change-Id: Id030a3424038f36b704e2f0db2f508f9c4bdfc66
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index ed85cf5..140ec38 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -519,6 +519,10 @@
 
   RANGE_CHECK(extra_cfg, tx_size_search_method, 0, 2);
 
+  if (!extra_cfg->enable_tx64 && extra_cfg->tx_size_search_method == 2) {
+    ERROR("TX64 cannot be disabled when search_method is USE_LARGESTALL (2).");
+  }
+
   for (int i = 0; i < MAX_NUM_OPERATING_POINTS; ++i) {
     const int level_idx = extra_cfg->target_seq_level_idx[i];
     if (!is_valid_seq_level_idx(level_idx) && level_idx != SEQ_LEVELS) {
@@ -559,6 +563,10 @@
   if (img->d_w != ctx->cfg.g_w || img->d_h != ctx->cfg.g_h)
     ERROR("Image size must match encoder init configuration size");
 
+  if (img->fmt != AOM_IMG_FMT_I420 && !ctx->extra_cfg.enable_tx64) {
+    ERROR("TX64 can only be disabled on I420 images.");
+  }
+
   return AOM_CODEC_OK;
 }
 
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 76127fd..da80a30 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -451,7 +451,11 @@
   // Used to control the tx size search evaluation for mode processing
   // (normal/winner mode)
   int tx_size_search_method;
-  TX_MODE tx_mode;
+  // This tx_mode_search_type is used internally by the encoder, and is not
+  // written to the bitstream. It determines what kind of tx_mode should be
+  // searched. For example, we might set it to TX_MODE_LARGEST to find a good
+  // candidate, then use TX_MODE_SELECT on it.
+  TX_MODE tx_mode_search_type;
 
   // Used to control aggressiveness of skip flag prediction for mode processing
   // (normal/winner mode)
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index dad696d..0b7ebcf 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -524,7 +524,7 @@
           seg->update_map ? cpi->segmentation_map : cm->last_frame_seg_map;
       mi_addr->segment_id =
           map ? get_segment_id(cm, map, bsize, mi_row, mi_col) : 0;
-      reset_tx_size(x, mi_addr, x->tx_mode);
+      reset_tx_size(x, mi_addr, x->tx_mode_search_type);
     }
     // Else for cyclic refresh mode update the segment map, set the segment id
     // and then update the quantizer.
@@ -5165,8 +5165,6 @@
   cm->coded_lossless = is_coded_lossless(cm, xd);
   cm->all_lossless = cm->coded_lossless && !av1_superres_scaled(cm);
 
-  cm->tx_mode = get_eval_tx_mode(cpi, DEFAULT_EVAL);
-
   // Fix delta q resolution for the moment
   cm->delta_q_info.delta_q_res = 0;
   if (cpi->oxcf.deltaq_mode == DELTA_Q_OBJECTIVE)
@@ -5348,7 +5346,13 @@
   }
 
   // Set the transform size appropriately before bitstream creation
-  cm->tx_mode = get_eval_tx_mode(cpi, WINNER_MODE_EVAL);
+  const MODE_EVAL_TYPE eval_type = cpi->sf.enable_winner_mode_for_tx_size_srch
+                                       ? WINNER_MODE_EVAL
+                                       : DEFAULT_EVAL;
+  const TX_SIZE_SEARCH_METHOD tx_search_type =
+      cpi->tx_size_search_methods[eval_type];
+  assert(cpi->oxcf.enable_tx64 || tx_search_type != USE_LARGESTALL);
+  cm->tx_mode = select_tx_mode(cpi, tx_search_type);
 
   if (cpi->sf.tx_type_search.prune_tx_type_using_stats) {
     const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
@@ -5759,8 +5763,9 @@
 
   if (!dry_run) {
     if (av1_allow_intrabc(cm) && is_intrabc_block(mbmi)) td->intrabc_used = 1;
-    if (x->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id] &&
-        mbmi->sb_type > BLOCK_4X4 && !(is_inter && (mbmi->skip || seg_skip))) {
+    if (x->tx_mode_search_type == TX_MODE_SELECT &&
+        !xd->lossless[mbmi->segment_id] && mbmi->sb_type > BLOCK_4X4 &&
+        !(is_inter && (mbmi->skip || seg_skip))) {
       if (is_inter) {
         tx_partition_count_update(cm, x, bsize, mi_row, mi_col, td->counts,
                                   tile_data->allow_update_cdf);
@@ -5790,7 +5795,7 @@
         if (xd->lossless[mbmi->segment_id]) {
           intra_tx_size = TX_4X4;
         } else {
-          intra_tx_size = tx_size_from_tx_mode(bsize, x->tx_mode);
+          intra_tx_size = tx_size_from_tx_mode(bsize, x->tx_mode_search_type);
         }
       } else {
         intra_tx_size = mbmi->tx_size;
@@ -5805,9 +5810,9 @@
     }
   }
 
-  if (x->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type) &&
-      is_inter && !(mbmi->skip || seg_skip) &&
-      !xd->lossless[mbmi->segment_id]) {
+  if (x->tx_mode_search_type == TX_MODE_SELECT &&
+      block_signals_txsize(mbmi->sb_type) && is_inter &&
+      !(mbmi->skip || seg_skip) && !xd->lossless[mbmi->segment_id]) {
     if (dry_run) tx_partition_set_contexts(cm, xd, bsize, mi_row, mi_col);
   } else {
     TX_SIZE tx_size = mbmi->tx_size;
@@ -5816,7 +5821,7 @@
       if (xd->lossless[mbmi->segment_id]) {
         tx_size = TX_4X4;
       } else {
-        tx_size = tx_size_from_tx_mode(bsize, x->tx_mode);
+        tx_size = tx_size_from_tx_mode(bsize, x->tx_mode_search_type);
       }
     } else {
       tx_size = (bsize > BLOCK_4X4) ? tx_size : TX_4X4;
diff --git a/av1/encoder/nonrd_pickmode.c b/av1/encoder/nonrd_pickmode.c
index c2e9a4e..12b1cf2 100644
--- a/av1/encoder/nonrd_pickmode.c
+++ b/av1/encoder/nonrd_pickmode.c
@@ -502,10 +502,10 @@
                                  unsigned int sse) {
   MACROBLOCKD *const xd = &x->e_mbd;
   TX_SIZE tx_size;
-  if (x->tx_mode == TX_MODE_SELECT) {
+  if (x->tx_mode_search_type == TX_MODE_SELECT) {
     if (sse > (var << 2))
       tx_size = AOMMIN(max_txsize_lookup[bsize],
-                       tx_mode_to_biggest_tx_size[x->tx_mode]);
+                       tx_mode_to_biggest_tx_size[x->tx_mode_search_type]);
     else
       tx_size = TX_8X8;
 
@@ -516,7 +516,7 @@
       tx_size = TX_16X16;
   } else {
     tx_size = AOMMIN(max_txsize_lookup[bsize],
-                     tx_mode_to_biggest_tx_size[x->tx_mode]);
+                     tx_mode_to_biggest_tx_size[x->tx_mode_search_type]);
   }
   if (bsize > BLOCK_32X32) tx_size = TX_16X16;
   return AOMMIN(tx_size, TX_16X16);
@@ -1557,9 +1557,10 @@
 #endif
     init_mbmi(mi, this_mode, ref_frame, NONE_FRAME, cm);
 
-    mi->tx_size = AOMMIN(AOMMIN(max_txsize_lookup[bsize],
-                                tx_mode_to_biggest_tx_size[x->tx_mode]),
-                         TX_16X16);
+    mi->tx_size =
+        AOMMIN(AOMMIN(max_txsize_lookup[bsize],
+                      tx_mode_to_biggest_tx_size[x->tx_mode_search_type]),
+               TX_16X16);
     memset(mi->inter_tx_size, mi->tx_size, sizeof(mi->inter_tx_size));
     memset(xd->tx_type_map, DCT_DCT,
            sizeof(xd->tx_type_map[0]) * ctx->num_4x4_blk);
@@ -1901,7 +1902,7 @@
     PRED_BUFFER *const best_pred = best_pickmode.best_pred;
     TX_SIZE intra_tx_size =
         AOMMIN(AOMMIN(max_txsize_lookup[bsize],
-                      tx_mode_to_biggest_tx_size[x->tx_mode]),
+                      tx_mode_to_biggest_tx_size[x->tx_mode_search_type]),
                TX_16X16);
 
     if (reuse_inter_pred && best_pred != NULL) {
diff --git a/av1/encoder/rd.h b/av1/encoder/rd.h
index 4bb7aa9..cc1484f 100644
--- a/av1/encoder/rd.h
+++ b/av1/encoder/rd.h
@@ -51,9 +51,10 @@
 #define SWITCHABLE_INTERP_RATE_FACTOR 1
 
 enum {
-  // Default initialization
+  // Default initialization when we are not using the winner mode framework,
+  // e.g. intrabc.
   DEFAULT_EVAL = 0,
-  // Initialization for default mode evaluation
+  // Initialization for selecting winner mode
   MODE_EVAL,
   // Initialization for winner mode evaluation
   WINNER_MODE_EVAL,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 10d955f..6120387 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3821,7 +3821,8 @@
 static int tx_size_cost(const MACROBLOCK *const x, BLOCK_SIZE bsize,
                         TX_SIZE tx_size) {
   assert(bsize == x->e_mbd.mi[0]->sb_type);
-  if (x->tx_mode != TX_MODE_SELECT || !block_signals_txsize(bsize)) return 0;
+  if (x->tx_mode_search_type != TX_MODE_SELECT || !block_signals_txsize(bsize))
+    return 0;
 
   const int32_t tx_size_cat = bsize_to_tx_size_cat(bsize);
   const int depth = tx_size_to_depth(tx_size, bsize);
@@ -3840,8 +3841,8 @@
   const int skip_ctx = av1_get_skip_context(xd);
   int s0, s1;
   const int is_inter = is_inter_block(mbmi);
-  const int tx_select =
-      x->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type);
+  const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT &&
+                        block_signals_txsize(mbmi->sb_type);
   int ctx = txfm_partition_context(
       xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
   const int r_tx_size =
@@ -3920,7 +3921,35 @@
                                               BLOCK_SIZE bs) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
-  mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode);
+  mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode_search_type);
+
+  // If tx64 is not enabled, we need to go down to the next available size
+  if (!cpi->oxcf.enable_tx64) {
+    static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
+      TX_4X4,    // 4x4 transform
+      TX_8X8,    // 8x8 transform
+      TX_16X16,  // 16x16 transform
+      TX_32X32,  // 32x32 transform
+      TX_32X32,  // 64x64 transform
+      TX_4X8,    // 4x8 transform
+      TX_8X4,    // 8x4 transform
+      TX_8X16,   // 8x16 transform
+      TX_16X8,   // 16x8 transform
+      TX_16X32,  // 16x32 transform
+      TX_32X16,  // 32x16 transform
+      TX_32X32,  // 32x64 transform
+      TX_32X32,  // 64x32 transform
+      TX_4X16,   // 4x16 transform
+      TX_16X4,   // 16x4 transform
+      TX_8X32,   // 8x32 transform
+      TX_32X8,   // 32x8 transform
+      TX_16X32,  // 16x64 transform
+      TX_32X16,  // 64x16 transform
+    };
+
+    mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
+  }
+
   const int skip_ctx = av1_get_skip_context(xd);
   int s0, s1;
 
@@ -3987,7 +4016,7 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
-  const int tx_select = x->tx_mode == TX_MODE_SELECT;
+  const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT;
   int start_tx;
   int depth, init_depth;
 
@@ -3997,7 +4026,8 @@
                                        is_inter_block(mbmi), &cpi->sf,
                                        x->tx_size_search_method);
   } else {
-    const TX_SIZE chosen_tx_size = tx_size_from_tx_mode(bs, x->tx_mode);
+    const TX_SIZE chosen_tx_size =
+        tx_size_from_tx_mode(bs, x->tx_mode_search_type);
     start_tx = chosen_tx_size;
     init_depth = MAX_TX_DEPTH;
   }
@@ -4431,7 +4461,7 @@
   RD_STATS this_rd_stats;
   int row, col;
   int64_t temp_sse, this_rd;
-  TX_SIZE tx_size = tx_size_from_tx_mode(bsize, x->tx_mode);
+  TX_SIZE tx_size = tx_size_from_tx_mode(bsize, x->tx_mode_search_type);
   const int stepr = tx_size_high_unit[tx_size];
   const int stepc = tx_size_wide_unit[tx_size];
   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
@@ -9469,7 +9499,8 @@
 
   // cost and distortion
   av1_subtract_plane(x, bsize, 0);
-  if (x->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+  if (x->tx_mode_search_type == TX_MODE_SELECT &&
+      !xd->lossless[mbmi->segment_id]) {
     pick_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
 #if CONFIG_COLLECT_RD_STATS == 2
     PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
@@ -11705,8 +11736,9 @@
 
     // Set up tx_size related variables for skip-specific loop filtering.
     search_state->best_mbmode.tx_size =
-        block_signals_txsize(bsize) ? tx_size_from_tx_mode(bsize, x->tx_mode)
-                                    : max_txsize_rect_lookup[bsize];
+        block_signals_txsize(bsize)
+            ? tx_size_from_tx_mode(bsize, x->tx_mode_search_type)
+            : max_txsize_rect_lookup[bsize];
     memset(search_state->best_mbmode.inter_tx_size,
            search_state->best_mbmode.tx_size,
            sizeof(search_state->best_mbmode.inter_tx_size));
@@ -11788,7 +11820,8 @@
           av1_build_obmc_inter_predictors_sb(cm, xd);
 
         av1_subtract_plane(x, bsize, 0);
-        if (x->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+        if (x->tx_mode_search_type == TX_MODE_SELECT &&
+            !xd->lossless[mbmi->segment_id]) {
           pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
           assert(rd_stats_y.rate != INT_MAX);
         } else {
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 381432f..e953c8c 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -211,24 +211,13 @@
 static TX_MODE select_tx_mode(
     const AV1_COMP *cpi, const TX_SIZE_SEARCH_METHOD tx_size_search_method) {
   if (cpi->common.coded_lossless) return ONLY_4X4;
-  if (tx_size_search_method == USE_LARGESTALL)
+  if (tx_size_search_method == USE_LARGESTALL) {
     return TX_MODE_LARGEST;
-  else if (tx_size_search_method == USE_FULL_RD ||
-           tx_size_search_method == USE_FAST_RD)
+  } else {
+    assert(tx_size_search_method == USE_FULL_RD ||
+           tx_size_search_method == USE_FAST_RD);
     return TX_MODE_SELECT;
-  else
-    return cpi->common.tx_mode;
-}
-
-static INLINE TX_MODE get_eval_tx_mode(const AV1_COMP *cpi,
-                                       MODE_EVAL_TYPE eval_type) {
-  TX_MODE tx_mode;
-  if (cpi->sf.enable_winner_mode_for_tx_size_srch)
-    tx_mode = select_tx_mode(cpi, cpi->tx_size_search_methods[eval_type]);
-  else
-    tx_mode = select_tx_mode(cpi, cpi->tx_size_search_methods[DEFAULT_EVAL]);
-
-  return tx_mode;
+  }
 }
 
 static INLINE void set_tx_size_search_method(
@@ -242,7 +231,7 @@
     else
       x->tx_size_search_method = cpi->tx_size_search_methods[MODE_EVAL];
   }
-  x->tx_mode = select_tx_mode(cpi, x->tx_size_search_method);
+  x->tx_mode_search_type = select_tx_mode(cpi, x->tx_size_search_method);
 }
 
 static INLINE void set_tx_type_prune(const SPEED_FEATURES *sf, MACROBLOCK *x,