Make hbd transforms compatible with 4:1 transforms

Change-Id: I7123717b2d11bca826d650c6e6b6ae137476d541
diff --git a/av1/common/av1_fwd_txfm2d.c b/av1/common/av1_fwd_txfm2d.c
index 9d5f478..e0910e2 100644
--- a/av1/common/av1_fwd_txfm2d.c
+++ b/av1/common/av1_fwd_txfm2d.c
@@ -79,6 +79,15 @@
   // for square transforms.
   const int txfm_size_col = cfg->row_cfg->txfm_size;
   const int txfm_size_row = cfg->col_cfg->txfm_size;
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  int rect_type2_shift = 0;
+  if (rect_type == 2 || rect_type == -2) {
+    const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
+    // For 64x16 / 16x64 shift 3 bits, for 32x8 / 8x32 shift 2 bits, for
+    // 16x4 / 4x16 shift by 1 bit.
+    rect_type2_shift =
+        (txfm_size_max == 64 ? 3 : (txfm_size_max == 32 ? 2 : 1));
+  }
   // Take the shift from the larger dimension in the rectangular case.
   const int8_t *shift = (txfm_size_col > txfm_size_row) ? cfg->row_cfg->shift
                                                         : cfg->col_cfg->shift;
@@ -108,10 +117,14 @@
     }
     round_shift_array(temp_in, txfm_size_row, -shift[0]);
     // Multiply everything by Sqrt2 on the larger dimension if the
-    // transform is rectangular
-    if (txfm_size_col > txfm_size_row) {
+    // transform is rectangular and the size difference is a factor of 2.
+    // If the size difference is a factor of 4, multiply by
+    // 2^rect_type_2_extra_shift.
+    if (rect_type == 1) {
       for (r = 0; r < txfm_size_row; ++r)
         temp_in[r] = (int32_t)fdct_round_shift(temp_in[r] * Sqrt2);
+    } else if (rect_type == 2) {
+      round_shift_array(temp_in, txfm_size_row, -rect_type2_shift);
     }
     txfm_func_col(temp_in, temp_out, cos_bit_col, stage_range_col);
     round_shift_array(temp_out, txfm_size_row, -shift[1]);
@@ -128,11 +141,16 @@
   // Rows
   for (r = 0; r < txfm_size_row; ++r) {
     // Multiply everything by Sqrt2 on the larger dimension if the
-    // transform is rectangular
-    if (txfm_size_row > txfm_size_col) {
+    // transform is rectangular and the size difference is a factor of 2.
+    // If the size difference is a factor of 4, multiply by 2.
+    if (rect_type == -1) {
       for (c = 0; c < txfm_size_col; ++c)
         buf[r * txfm_size_col + c] =
             (int32_t)fdct_round_shift(buf[r * txfm_size_col + c] * Sqrt2);
+    } else if (rect_type == -2) {
+      for (c = 0; c < txfm_size_col; ++c)
+        buf[r * txfm_size_col + c] =
+            buf[r * txfm_size_col + c] * (1 << rect_type2_shift);
     }
     txfm_func_row(buf + r * txfm_size_col, output + r * txfm_size_col,
                   cos_bit_row, stage_range_row);
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index 6450e26..89ec482 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -193,6 +193,14 @@
   // for square transforms.
   const int txfm_size_col = cfg->row_cfg->txfm_size;
   const int txfm_size_row = cfg->col_cfg->txfm_size;
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  int rect_type2_shift = 0;
+  if (rect_type == 2 || rect_type == -2) {
+    const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
+    // For 16x4 / 4x16 shift 1 bit, for 32x8 / 8x32 / 64x16 / 16x64 no need
+    // for any additional shift.
+    rect_type2_shift = (txfm_size_max == 16 ? 1 : 0);
+  }
   // Take the shift from the larger dimension in the rectangular case.
   const int8_t *shift = (txfm_size_col > txfm_size_row) ? cfg->row_cfg->shift
                                                         : cfg->col_cfg->shift;
@@ -219,10 +227,14 @@
   for (r = 0; r < txfm_size_row; ++r) {
     txfm_func_row(input, buf_ptr, cos_bit_row, stage_range_row);
     round_shift_array(buf_ptr, txfm_size_col, -shift[0]);
-    // Multiply everything by Sqrt2 if the transform is rectangular
-    if (txfm_size_row != txfm_size_col) {
+    // Multiply everything by Sqrt2 if the transform is rectangular with
+    // log ratio being 1 or -1, if the log ratio is 2 or -2, multiply by
+    // 2^rect_type2_shift.
+    if (abs(rect_type) == 1) {
       for (c = 0; c < txfm_size_col; ++c)
         buf_ptr[c] = (int32_t)dct_const_round_shift(buf_ptr[c] * Sqrt2);
+    } else if (rect_type2_shift) {
+      round_shift_array(buf_ptr, txfm_size_col, -rect_type2_shift);
     }
     input += txfm_size_col;
     buf_ptr += txfm_size_col;
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 44b5bd0..bb0053d 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -359,6 +359,22 @@
 }
 #endif  // CONFIG_MRC_TX
 
+// Utility function that returns the log of the ratio of the col and row
+// sizes.
+static INLINE int get_rect_tx_log_ratio(int col, int row) {
+  if (col == row) return 0;
+  if (col > row) {
+    if (col == row * 2) return 1;
+    if (col == row * 4) return 2;
+    assert(0 && "Unsupported transform size");
+  } else {
+    if (row == col * 2) return -1;
+    if (row == col * 4) return -2;
+    assert(0 && "Unsupported transform size");
+  }
+  return 0;  // Invalid
+}
+
 void av1_gen_fwd_stage_range(int8_t *stage_range_col, int8_t *stage_range_row,
                              const TXFM_2D_FLIP_CFG *cfg, int bd);