Add flag to enable and disable rect-tx

Added an enable-rect-tx flag by changing logic in tx_search. Most
recently, fixed logic for conditionals for checking for transforms

Change-Id: I42243edad79e0b97c6d852dbf2ec3112fef2fb3b
diff --git a/aom/aomcx.h b/aom/aomcx.h
index ac69d46..3ab1dd2 100644
--- a/aom/aomcx.h
+++ b/aom/aomcx.h
@@ -855,7 +855,17 @@
    */
   AV1E_SET_ENABLE_FLIP_IDTX = 81,
 
-  /* Note: enum value 82 unused */
+  /*!\brief Codec control function to turn on / off rectangular transforms, int
+   * parameter
+   *
+   * This will enable or disable usage of rectangular transforms. NOTE:
+   * Rectangular transforms only enabled when corresponding rectangular
+   * partitions are.
+   *
+   * - 0 = disable
+   * - 1 = enable (default)
+   */
+  AV1E_SET_ENABLE_RECT_TX = 82,
 
   /*!\brief Codec control function to turn on / off dist-wtd compound mode
    * at sequence level, int parameter
@@ -1563,6 +1573,9 @@
 AOM_CTRL_USE_TYPE(AV1E_SET_ENABLE_FLIP_IDTX, int)
 #define AOM_CTRL_AV1E_SET_ENABLE_FLIP_IDTX
 
+AOM_CTRL_USE_TYPE(AV1E_SET_ENABLE_RECT_TX, int)
+#define AOM_CTRL_AV1E_SET_ENABLE_RECT_TX
+
 AOM_CTRL_USE_TYPE(AV1E_SET_ENABLE_DIST_WTD_COMP, int)
 #define AOM_CTRL_AV1E_SET_ENABLE_DIST_WTD_COMP
 
diff --git a/apps/aomenc.c b/apps/aomenc.c
index 4de1752..9814190 100644
--- a/apps/aomenc.c
+++ b/apps/aomenc.c
@@ -491,6 +491,9 @@
             "including FLIPADST_DCT, DCT_FLIPADST, FLIPADST_FLIPADST, "
             "ADST_FLIPADST, FLIPADST_ADST, IDTX, V_DCT, H_DCT, V_ADST, "
             "H_ADST, V_FLIPADST, H_FLIPADST");
+static const arg_def_t enable_rect_tx =
+    ARG_DEF(NULL, "enable-rect-tx", 1,
+            "Enable rectangular transform (0: false, 1: true (default))");
 static const arg_def_t enable_dist_wtd_comp =
     ARG_DEF(NULL, "enable-dist-wtd-comp", 1,
             "Enable distance-weighted compound "
@@ -858,6 +861,7 @@
                                        &enable_order_hint,
                                        &enable_tx64,
                                        &enable_flip_idtx,
+                                       &enable_rect_tx,
                                        &enable_dist_wtd_comp,
                                        &enable_masked_comp,
                                        &enable_onesided_comp,
@@ -965,6 +969,7 @@
                                         AV1E_SET_ENABLE_ORDER_HINT,
                                         AV1E_SET_ENABLE_TX64,
                                         AV1E_SET_ENABLE_FLIP_IDTX,
+                                        AV1E_SET_ENABLE_RECT_TX,
                                         AV1E_SET_ENABLE_DIST_WTD_COMP,
                                         AV1E_SET_ENABLE_MASKED_COMP,
                                         AV1E_SET_ENABLE_ONESIDED_COMP,
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index 253dd71..5f204a7 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -103,8 +103,9 @@
   int enable_order_hint;         // enable order hint for sequence
   int enable_tx64;               // enable 64-pt transform usage for sequence
   int enable_flip_idtx;          // enable flip and identity transform types
-  int enable_dist_wtd_comp;      // enable dist wtd compound for sequence
-  int max_reference_frames;      // maximum number of references per frame
+  int enable_rect_tx;        // enable rectangular transform usage for sequence
+  int enable_dist_wtd_comp;  // enable dist wtd compound for sequence
+  int max_reference_frames;  // maximum number of references per frame
   int enable_reduced_reference_set;  // enable reduced set of references
   int enable_ref_frame_mvs;          // sequence level
   int allow_ref_frame_mvs;           // frame level
@@ -225,6 +226,7 @@
   1,                            // frame order hint
   1,                            // enable 64-pt transform usage
   1,                            // enable flip and identity transform
+  1,                            // enable rectangular transform usage
   1,                            // dist-wtd compound
   7,                            // max_reference_frames
   0,                            // enable_reduced_reference_set
@@ -1013,6 +1015,7 @@
   // Set transform size/type configuration.
   txfm_cfg->enable_tx64 = extra_cfg->enable_tx64;
   txfm_cfg->enable_flip_idtx = extra_cfg->enable_flip_idtx;
+  txfm_cfg->enable_rect_tx = extra_cfg->enable_rect_tx;
   txfm_cfg->reduced_tx_type_set = extra_cfg->reduced_tx_type_set;
   txfm_cfg->use_intra_dct_only = extra_cfg->use_intra_dct_only;
   txfm_cfg->use_inter_dct_only = extra_cfg->use_inter_dct_only;
@@ -1481,6 +1484,13 @@
   return update_extra_cfg(ctx, &extra_cfg);
 }
 
+static aom_codec_err_t ctrl_set_enable_rect_tx(aom_codec_alg_priv_t *ctx,
+                                               va_list args) {
+  struct av1_extracfg extra_cfg = ctx->extra_cfg;
+  extra_cfg.enable_rect_tx = CAST(AV1E_SET_ENABLE_RECT_TX, args);
+  return update_extra_cfg(ctx, &extra_cfg);
+}
+
 static aom_codec_err_t ctrl_set_enable_dist_wtd_comp(aom_codec_alg_priv_t *ctx,
                                                      va_list args) {
   struct av1_extracfg extra_cfg = ctx->extra_cfg;
@@ -2804,6 +2814,7 @@
   { AV1E_SET_ENABLE_ORDER_HINT, ctrl_set_enable_order_hint },
   { AV1E_SET_ENABLE_TX64, ctrl_set_enable_tx64 },
   { AV1E_SET_ENABLE_FLIP_IDTX, ctrl_set_enable_flip_idtx },
+  { AV1E_SET_ENABLE_RECT_TX, ctrl_set_enable_rect_tx },
   { AV1E_SET_ENABLE_DIST_WTD_COMP, ctrl_set_enable_dist_wtd_comp },
   { AV1E_SET_MAX_REFERENCE_FRAMES, ctrl_set_max_reference_frames },
   { AV1E_SET_REDUCED_REFERENCE_SET, ctrl_set_enable_reduced_reference_set },
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 84e0dce..db15ba0 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -266,6 +266,10 @@
    */
   bool enable_flip_idtx;
   /*!
+   * Flag to indicate if rectangular transform should be enabled.
+   */
+  bool enable_rect_tx;
+  /*!
    * Flag to indicate whether or not to use a default reduced set for ext-tx
    * rather than the potential full set of 16 transforms.
    */
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 2bcaa95..21e6f6d 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -2469,7 +2469,6 @@
   return is_part_allowed;
 }
 
-// Rectangular partition types search function.
 static void rectangular_partition_search(
     AV1_COMP *const cpi, ThreadData *td, TileDataEnc *tile_data,
     TokenExtra **tp, MACROBLOCK *x, PC_TREE *pc_tree,
@@ -2892,7 +2891,7 @@
           part4_search_allowed[cur_part[i]]))
       continue;
     // Loop over split partitions.
-    // Get reactnagular partitions winner info of split partitions.
+    // Get rectangular partitions winner info of split partitions.
     for (int idx = 0; idx < SUB_PARTITIONS_SPLIT; idx++)
       num_child_rect_win[i] +=
           (part_search_state->split_part_rect_win[idx].rect_part_win[i]) ? 1
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index df577c3..7a6e321 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -2640,8 +2640,10 @@
                                          mbmi->sb_type, tx_size);
   struct macroblock_plane *const p = &x->plane[0];
 
-  const int try_no_split =
-      cpi->oxcf.txfm_cfg.enable_tx64 || txsize_sqr_up_map[tx_size] != TX_64X64;
+  const int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
+                            txsize_sqr_up_map[tx_size] != TX_64X64) &&
+                           (cpi->oxcf.txfm_cfg.enable_rect_tx ||
+                            tx_size_wide[tx_size] == tx_size_high[tx_size]);
   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
 
@@ -2723,7 +2725,7 @@
   mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
 
   // If tx64 is not enabled, we need to go down to the next available size
-  if (!cpi->oxcf.txfm_cfg.enable_tx64) {
+  if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
     static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
       TX_4X4,    // 4x4 transform
       TX_8X8,    // 8x8 transform
@@ -2745,8 +2747,56 @@
       TX_16X32,  // 16x64 transform
       TX_32X16,  // 64x16 transform
     };
-
     mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
+  } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
+             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
+    static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
+      TX_4X4,    // 4x4 transform
+      TX_8X8,    // 8x8 transform
+      TX_16X16,  // 16x16 transform
+      TX_32X32,  // 32x32 transform
+      TX_64X64,  // 64x64 transform
+      TX_4X4,    // 4x8 transform
+      TX_4X4,    // 8x4 transform
+      TX_8X8,    // 8x16 transform
+      TX_8X8,    // 16x8 transform
+      TX_16X16,  // 16x32 transform
+      TX_16X16,  // 32x16 transform
+      TX_32X32,  // 32x64 transform
+      TX_32X32,  // 64x32 transform
+      TX_4X4,    // 4x16 transform
+      TX_4X4,    // 16x4 transform
+      TX_8X8,    // 8x32 transform
+      TX_8X8,    // 32x8 transform
+      TX_16X16,  // 16x64 transform
+      TX_16X16,  // 64x16 transform
+    };
+    mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
+  } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
+             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
+    static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
+      TX_4X4,    // 4x4 transform
+      TX_8X8,    // 8x8 transform
+      TX_16X16,  // 16x16 transform
+      TX_32X32,  // 32x32 transform
+      TX_32X32,  // 64x64 transform
+      TX_4X4,    // 4x8 transform
+      TX_4X4,    // 8x4 transform
+      TX_8X8,    // 8x16 transform
+      TX_8X8,    // 16x8 transform
+      TX_16X16,  // 16x32 transform
+      TX_16X16,  // 32x16 transform
+      TX_32X32,  // 32x64 transform
+      TX_32X32,  // 64x32 transform
+      TX_4X4,    // 4x16 transform
+      TX_4X4,    // 16x4 transform
+      TX_8X8,    // 8x32 transform
+      TX_8X8,    // 32x8 transform
+      TX_16X16,  // 16x64 transform
+      TX_16X16,  // 64x16 transform
+    };
+
+    mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
   }
 
   const int skip_ctx = av1_get_skip_txfm_context(xd);
@@ -2818,8 +2868,10 @@
   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
   for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
        depth++, tx_size = sub_tx_size_map[tx_size]) {
-    if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
-        txsize_sqr_up_map[tx_size] == TX_64X64) {
+    if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
+         txsize_sqr_up_map[tx_size] == TX_64X64) ||
+        (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
+         tx_size_wide[tx_size] != tx_size_high[tx_size])) {
       continue;
     }