Account for sqrt(2) scaling in the range computation of 1:2 transforms

The cos bit tables and the 2D transform algorithms are also adjusted
accordingly to stay within the 32-bit limit.

Change-Id: I9048f3d3689ff1ef1bb84888ed9f43cdc4371411
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index 9714572..9cfe1ea 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -261,21 +261,25 @@
   const int txfm_size_row = cfg->col_cfg->txfm_size;
   const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
   if (txfm_size_col == txfm_size_row) assert(rect_type == 0);
-  int rect_type2_shift = 0;
-  if (rect_type == 2 || rect_type == -2) {
-    const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
-    // For 16x4 / 4x16 shift 1 bit, for 32x8 / 8x32 / 64x16 / 16x64 no need
-    // for any additional shift.
-    rect_type2_shift = (txfm_size_max == 16 ? 1 : 0);
-  }
+  int rect_type_shift = 0;
   // Take the shift from the larger dimension in the rectangular case.
   const int8_t *shift = (txfm_size_col > txfm_size_row) ? cfg->row_cfg->shift
                                                         : cfg->col_cfg->shift;
+
   int shift1 = shift[1];
-  while (rect_type2_shift > 0 && shift1 < 0) {
-    shift1++;
-    rect_type2_shift--;
+  if (rect_type == 1 || rect_type == -1) {
+    rect_type_shift = 1;
+  } else if (rect_type == 2 || rect_type == -2) {
+    const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
+    // For 16x4 / 4x16 shift 1 bit, for 32x8 / 8x32 / 64x16 / 16x64 no need
+    // for any additional shift.
+    rect_type_shift = (txfm_size_max == 16 ? 1 : 0);
   }
+  while (rect_type_shift > 0 && shift1 < 0) {
+    shift1++;
+    rect_type_shift--;
+  }
+
   // i < MAX_TXFM_STAGE_NUM will mute above array bounds warning
   for (int i = 0; i < cfg->row_cfg->stage_num && i < MAX_TXFM_STAGE_NUM; ++i) {
     stage_range_row[i] = cfg->row_cfg->stage_range[i] + fwd_shift + bd + 1;
@@ -283,7 +287,7 @@
   // i < MAX_TXFM_STAGE_NUM will mute above array bounds warning
   for (int i = 0; i < cfg->col_cfg->stage_num && i < MAX_TXFM_STAGE_NUM; ++i) {
     stage_range_col[i] = cfg->col_cfg->stage_range[i] + fwd_shift + shift[0] +
-                         bd + 1 + rect_type2_shift;
+                         bd + 1 + rect_type_shift;
   }
 }
 
@@ -304,8 +308,15 @@
                                                         : cfg->col_cfg->shift;
   const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
   int rect_type2_shift = 0;
+  int rect_type1_shift = 0;
   int shift1 = shift[1];
-  if (rect_type == 2 || rect_type == -2) {
+  if (rect_type == 1 || rect_type == -1) {
+    rect_type1_shift = 1;
+    while (rect_type1_shift > 0 && shift1 < 0) {
+      shift1++;
+      rect_type1_shift--;
+    }
+  } else if (rect_type == 2 || rect_type == -2) {
     const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
     // For 16x4 / 4x16 shift 1 bit, for 32x8 / 8x32 / 64x16 / 16x64 no need
     // for any additional shift.
@@ -341,7 +352,7 @@
     // Multiply everything by Sqrt2 if the transform is rectangular with
     // log ratio being 1 or -1, if the log ratio is 2 or -2, multiply by
     // 2^rect_type2_shift.
-    if (abs(rect_type) == 1) {
+    if (rect_type1_shift == 1) {
       for (c = 0; c < txfm_size_col; ++c)
         buf_ptr[c] = (int32_t)dct_const_round_shift(buf_ptr[c] * Sqrt2);
     } else if (rect_type2_shift) {
@@ -362,6 +373,11 @@
         temp_in[r] = buf[r * txfm_size_col + (txfm_size_col - c - 1)];
     }
     txfm_func_col(temp_in, temp_out, cos_bit_col, stage_range_col);
+    if (abs(rect_type) == 1 && rect_type1_shift == 0) {
+      for (r = 0; r < txfm_size_row; ++r) {
+        temp_out[r] = (int32_t)dct_const_round_shift(temp_out[r] * InvSqrt2);
+      }
+    }
     av1_round_shift_array(temp_out, txfm_size_row, -shift1);
     if (cfg->ud_flip == 0) {
       for (r = 0; r < txfm_size_row; ++r) {
diff --git a/av1/encoder/av1_fwd_txfm1d_cfg.h b/av1/encoder/av1_fwd_txfm1d_cfg.h
index e5e0bf3..1d3841c 100644
--- a/av1/encoder/av1_fwd_txfm1d_cfg.h
+++ b/av1/encoder/av1_fwd_txfm1d_cfg.h
@@ -134,8 +134,8 @@
 static const int8_t fwd_stage_range_row_adst_4x8[6] =
     ARRAYOFFSET6(3, 0, 0, 1, 2, 2, 2);
 static const int8_t fwd_stage_range_row_idx_4x8[1] = { 4 };
-static const int8_t fwd_cos_bit_row_dct_4x8[6] = { 13, 13, 13, 13 };
-static const int8_t fwd_cos_bit_row_adst_4x8[6] = { 13, 13, 13, 13, 13, 13 };
+static const int8_t fwd_cos_bit_row_dct_4x8[6] = { 13, 12, 12, 12 };
+static const int8_t fwd_cos_bit_row_adst_4x8[6] = { 13, 13, 12, 12, 12, 12 };
 
 //  ---------------- 8x4 1D constants -----------------------
 #define fwd_shift_8x4 fwd_shift_8
@@ -144,9 +144,9 @@
 static const int8_t fwd_stage_range_row_adst_8x4[8] =
     ARRAYOFFSET8(2, 0, 0, 1, 2, 2, 3, 3, 3);
 static const int8_t fwd_stage_range_row_idx_8x4[1] = { 3 };
-static const int8_t fwd_cos_bit_row_dct_8x4[6] = { 13, 13, 13, 13, 13, 13 };
+static const int8_t fwd_cos_bit_row_dct_8x4[6] = { 13, 13, 12, 12, 12, 12 };
 static const int8_t fwd_cos_bit_row_adst_8x4[8] = { 13, 13, 13, 13,
-                                                    13, 13, 13, 13 };
+                                                    12, 12, 12, 12 };
 
 //  ---------------- 8x16 1D constants -----------------------
 #define fwd_shift_8x16 fwd_shift_16
@@ -155,9 +155,9 @@
 static const int8_t fwd_stage_range_row_adst_8x16[8] =
     ARRAYOFFSET8(4, 0, 0, 1, 2, 2, 3, 3, 3);
 static const int8_t fwd_stage_range_row_idx_8x16[1] = { 5 };
-static const int8_t fwd_cos_bit_row_dct_8x16[6] = { 13, 13, 12, 12, 12, 12 };
-static const int8_t fwd_cos_bit_row_adst_8x16[8] = { 13, 13, 13, 13,
-                                                     12, 12, 12, 12 };
+static const int8_t fwd_cos_bit_row_dct_8x16[6] = { 12, 12, 11, 11, 11, 11 };
+static const int8_t fwd_cos_bit_row_adst_8x16[8] = { 12, 12, 12, 12,
+                                                     11, 11, 11, 11 };
 
 //  ---------------- 16x8 1D constants -----------------------
 #define fwd_shift_16x8 fwd_shift_16
@@ -166,10 +166,10 @@
 static const int8_t fwd_stage_range_row_adst_16x8[10] =
     ARRAYOFFSET10(3, 0, 0, 1, 2, 2, 3, 3, 4, 4, 4);
 static const int8_t fwd_stage_range_row_idx_16x8[1] = { 5 };
-static const int8_t fwd_cos_bit_row_dct_16x8[8] = { 13, 13, 13, 12,
-                                                    12, 12, 12, 12 };
-static const int8_t fwd_cos_bit_row_adst_16x8[10] = { 13, 13, 13, 13, 12,
-                                                      12, 12, 12, 12, 12 };
+static const int8_t fwd_cos_bit_row_dct_16x8[8] = { 12, 12, 12, 11,
+                                                    11, 11, 11, 11 };
+static const int8_t fwd_cos_bit_row_adst_16x8[10] = { 12, 12, 12, 12, 12,
+                                                      12, 11, 11, 11, 11 };
 
 //  ---------------- 16x32 1D constants -----------------------
 #define fwd_shift_16x32 fwd_shift_32
@@ -178,10 +178,10 @@
 static const int8_t fwd_stage_range_row_adst_16x32[10] =
     ARRAYOFFSET10(5, 0, 0, 1, 2, 2, 3, 3, 4, 4, 4);
 static const int8_t fwd_stage_range_row_idx_16x32[1] = { 7 };
-static const int8_t fwd_cos_bit_row_dct_16x32[8] = { 13, 13, 13, 12,
-                                                     12, 12, 12, 12 };
+static const int8_t fwd_cos_bit_row_dct_16x32[8] = { 12, 12, 12, 11,
+                                                     11, 11, 11, 11 };
 static const int8_t fwd_cos_bit_row_adst_16x32[10] = { 12, 12, 12, 12, 12,
-                                                       12, 12, 12, 12, 12 };
+                                                       12, 11, 11, 11, 11 };
 
 //  ---------------- 32x16 1D constants -----------------------
 #define fwd_shift_32x16 fwd_shift_32
@@ -190,10 +190,10 @@
 static const int8_t fwd_stage_range_row_adst_32x16[12] =
     ARRAYOFFSET12(4, 0, 0, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5);
 static const int8_t fwd_stage_range_row_idx_32x16[1] = { 6 };
-static const int8_t fwd_cos_bit_row_dct_32x16[10] = { 12, 12, 12, 12, 12,
-                                                      12, 12, 12, 12, 12 };
+static const int8_t fwd_cos_bit_row_dct_32x16[10] = { 12, 12, 12, 12, 11,
+                                                      11, 11, 11, 11, 11 };
 static const int8_t fwd_cos_bit_row_adst_32x16[12] = { 12, 12, 12, 12, 12, 12,
-                                                       12, 12, 12, 12, 12, 12 };
+                                                       12, 12, 11, 11, 11, 11 };
 
 //  ---------------- 32x64 1D constants -----------------------
 #define fwd_shift_32x64 fwd_shift_64
diff --git a/av1/encoder/av1_fwd_txfm2d.c b/av1/encoder/av1_fwd_txfm2d.c
index 149bc99..176dd41 100644
--- a/av1/encoder/av1_fwd_txfm2d.c
+++ b/av1/encoder/av1_fwd_txfm2d.c
@@ -64,22 +64,28 @@
   }
 
   const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
-  int rect_shift = 0;
+  int rect_type_shift = 0;
+
   int shift2 = shift[2];
-  if (rect_type == 2 || rect_type == -2) {
+  if (rect_type == 1 || rect_type == -1) {
+    rect_type_shift = 1;
+  } else if (rect_type == 2 || rect_type == -2) {
     const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
+
     // For 64x16 / 16x64 / 32x8 / 8x32 shift 2 bits, and
     // For 16x4 / 4x16 shift by 1 bit.
-    rect_shift = (txfm_size_max >= 32) ? 2 : 1;
+    rect_type_shift = (txfm_size_max >= 32) ? 2 : 1;
   }
-  while (rect_shift > 0 && shift2 < 0) {
+
+  while (rect_type_shift > 0 && shift2 < 0) {
     shift2++;
-    rect_shift--;
+    rect_type_shift--;
   }
+
   // i < MAX_TXFM_STAGE_NUM will mute above array bounds warning
   for (int i = 0; i < cfg->row_cfg->stage_num && i < MAX_TXFM_STAGE_NUM; ++i) {
     stage_range_row[i] = cfg->row_cfg->stage_range[i] + shift[0] + shift[1] +
-                         bd + 1 + rect_shift;
+                         bd + 1 + rect_type_shift;
   }
 }
 
@@ -100,8 +106,15 @@
                                                         : cfg->col_cfg->shift;
   int shift2 = shift[2];
   const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  int rect_type1_shift = 0;
   int rect_type2_shift = 0;
-  if (rect_type == 2 || rect_type == -2) {
+  if (rect_type == 1 || rect_type == -1) {
+    rect_type1_shift = 1;
+    while (rect_type1_shift > 0 && shift2 < 0) {
+      shift2++;
+      rect_type1_shift--;
+    }
+  } else if (rect_type == 2 || rect_type == -2) {
     const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row);
     // For 64x16 / 16x64 / 32x8 / 8x32 shift 2 bits, and
     // For 16x4 / 4x16 shift by 1 bit.
@@ -141,10 +154,10 @@
     // transform is rectangular and the size difference is a factor of 2.
     // If the size difference is a factor of 4, multiply by
     // 2^rect_type_2_extra_shift.
-    if (rect_type == 1) {
+    if (abs(rect_type) == 1 && rect_type1_shift == 1) {
       for (r = 0; r < txfm_size_row; ++r)
         temp_out[r] = (int32_t)fdct_round_shift(temp_out[r] * Sqrt2);
-    } else if (rect_type == 2) {
+    } else if (abs(rect_type) == 2) {
       av1_round_shift_array(temp_out, txfm_size_row, -rect_type2_shift);
     }
     av1_round_shift_array(temp_out, txfm_size_row, -shift[1]);
@@ -160,20 +173,13 @@
 
   // Rows
   for (r = 0; r < txfm_size_row; ++r) {
-    // Multiply everything by Sqrt2 on the larger dimension if the
-    // transform is rectangular and the size difference is a factor of 2.
-    // If the size difference is a factor of 4, multiply by 2.
-    if (rect_type == -1) {
-      for (c = 0; c < txfm_size_col; ++c)
-        buf[r * txfm_size_col + c] =
-            (int32_t)fdct_round_shift(buf[r * txfm_size_col + c] * Sqrt2);
-    } else if (rect_type == -2) {
-      for (c = 0; c < txfm_size_col; ++c)
-        buf[r * txfm_size_col + c] =
-            buf[r * txfm_size_col + c] * (1 << rect_type2_shift);
-    }
     txfm_func_row(buf + r * txfm_size_col, output + r * txfm_size_col,
                   cos_bit_row, stage_range_row);
+    if (abs(rect_type) == 1 && rect_type1_shift == 0) {
+      for (c = 0; c < txfm_size_col; ++c)
+        output[r * txfm_size_col + c] =
+            (int32_t)fdct_round_shift(output[r * txfm_size_col + c] * InvSqrt2);
+    }
     av1_round_shift_array(output + r * txfm_size_col, txfm_size_col, -shift2);
   }
 }
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 2a6a308..feeb45b 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -150,9 +150,9 @@
 
     param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_4X8, 3.2, 0.58));
     param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X4, 3.2, 0.58));
-    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X16, 6.5, 1));
-    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X8, 6.5, 1));
-    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X32, 50, 7));
+    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X16, 15, 1.5));
+    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X8, 15, 1.5));
+    param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X32, 55, 7));
     param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_32X16, 30, 7));
 
     param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_4X16, 5, 0.7));
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 5f85148..d8ddd2e 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -156,8 +156,8 @@
 
     param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_4X8, 2, 0.016));
     param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_8X4, 2, 0.016));
-    param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_8X16, 2, 0.033));
-    param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_16X8, 2, 0.033));
+    param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_8X16, 2, 0.2));
+    param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_16X8, 2, 0.2));
     param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_16X32, 3, 0.4));
     param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_32X16, 3, 0.4));
 
@@ -168,8 +168,8 @@
 
 #if CONFIG_TX64X64
     if (tx_type == DCT_DCT) {  // Other types not supported by these tx sizes.
-      param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_32X64, 4, 0.38));
-      param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_64X32, 4, 0.38));
+      param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_32X64, 5, 0.38));
+      param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_64X32, 5, 0.38));
       param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_16X64, 3, 0.38));
       param_list.push_back(AV1InvTxfm2dParam(tx_type, TX_64X16, 3, 0.38));
     }
diff --git a/test/av1_txfm_test.cc b/test/av1_txfm_test.cc
index b7710d2..dc11009 100644
--- a/test/av1_txfm_test.cc
+++ b/test/av1_txfm_test.cc
@@ -330,6 +330,14 @@
     // make sure there is no overflow while doing half_btf()
     EXPECT_LE(stage_range[i] + cos_bit[i], high_range);
     EXPECT_LE(stage_range[i + 1] + cos_bit[i], high_range);
+    if (stage_range[i] + cos_bit[i] > high_range) {
+      std::cout << i;
+      assert(0);
+    }
+    if (stage_range[i + 1] + cos_bit[i] > high_range) {
+      std::cout << i;
+      assert(0);
+    }
   }
 }
 }  // namespace libaom_test