Update adst4 range

Serialize the adst4 operations
Update stage range accordingly
Change the cos_bit precision accordingly.
Correct 4x8/8x4 inv_start_range

BUG=aomedia:1271

Change-Id: I10bc91585a61d790decdc24cb91659102e043620
diff --git a/av1/common/av1_inv_txfm1d.c b/av1/common/av1_inv_txfm1d.c
index 9c27eab..cd2593e 100644
--- a/av1/common/av1_inv_txfm1d.c
+++ b/av1/common/av1_inv_txfm1d.c
@@ -12,8 +12,22 @@
 #include <stdlib.h>
 #include "aom_dsp/inv_txfm.h"
 #include "av1/common/av1_inv_txfm1d.h"
-#if CONFIG_COEFFICIENT_RANGE_CHECKING
 
+int32_t range_check_value(int32_t value, int8_t bit) {
+#if CONFIG_COEFFICIENT_RANGE_CHECKING
+  const int64_t maxValue = (1LL << (bit - 1)) - 1;
+  const int64_t minValue = -(1LL << (bit - 1));
+  if (value < minValue || value > maxValue) {
+    fprintf(stderr, "coeff out of bit range, value: %d bit %d\n", value, bit);
+    assert(0);
+  }
+#else
+  (void)bit;
+#endif
+  return value;
+}
+
+#if CONFIG_COEFFICIENT_RANGE_CHECKING
 void range_check_func(int32_t stage, const int32_t *input, const int32_t *buf,
                       int32_t size, int8_t bit) {
   const int64_t maxValue = (1LL << (bit - 1)) - 1;
@@ -72,6 +86,15 @@
 }
 #endif
 
+int32_t clamp_value(int32_t value, int8_t bit) {
+  if (bit <= 16) {
+    const int32_t maxValue = (1 << 15) - 1;
+    const int32_t minValue = -(1 << 15);
+    return clamp(value, minValue, maxValue);
+  }
+  return value;
+}
+
 // TODO(angiebird): Make 1-d txfm functions static
 void av1_idct4_new(const int32_t *input, int32_t *output, int8_t cos_bit,
                    const int8_t *stage_range) {
@@ -757,30 +780,47 @@
     return;
   }
 
-  s0 = sinpi[1] * x0;
-  s1 = sinpi[2] * x0;
-  s2 = sinpi[3] * x1;
-  s3 = sinpi[4] * x2;
-  s4 = sinpi[1] * x2;
-  s5 = sinpi[2] * x3;
-  s6 = sinpi[4] * x3;
-  s7 = x0 - x2 + x3;
+  // stage 1
+  s0 = range_check_value(sinpi[1] * x0, stage_range[1]);
+  s1 = range_check_value(sinpi[2] * x0, stage_range[1]);
+  s2 = range_check_value(sinpi[3] * x1, stage_range[1]);
+  s3 = range_check_value(sinpi[4] * x2, stage_range[1]);
+  s4 = range_check_value(sinpi[1] * x2, stage_range[1]);
+  s5 = range_check_value(sinpi[2] * x3, stage_range[1]);
+  s6 = range_check_value(sinpi[4] * x3, stage_range[1]);
+  s7 = clamp_value(x0 - x2, stage_range[1]);
 
-  s0 = s0 + s3 + s5;
-  s1 = s1 - s4 - s6;
-  s3 = s2;
-  s2 = sinpi[3] * s7;
+  // stage 2
+  s7 = clamp_value(s7 + x3, stage_range[2]);
+
+  // stage 3
+  s0 = range_check_value(s0 + s3, stage_range[3] + bit);
+  s1 = range_check_value(s1 - s4, stage_range[3] + bit);
+  s3 = range_check_value(s2, stage_range[3] + bit);
+  s2 = range_check_value(sinpi[3] * s7, stage_range[3] + bit);
+
+  // stage 4
+  s0 = range_check_value(s0 + s5, stage_range[4] + bit);
+  s1 = range_check_value(s1 - s6, stage_range[4] + bit);
+
+  // stage 5
+  x0 = range_check_value(s0 + s3, stage_range[5] + bit);
+  x1 = range_check_value(s1 + s3, stage_range[5] + bit);
+  x2 = range_check_value(s2, stage_range[5] + bit);
+  x3 = range_check_value(s0 + s1, stage_range[5] + bit);
+
+  // stage 6
+  x3 = range_check_value(x3 - s3, stage_range[6] + bit);
 
   // 1-D transform scaling factor is sqrt(2).
   // The overall dynamic range is 14b (input) + 14b (multiplication scaling)
   // + 1b (addition) = 29b.
   // Hence the output bit depth is 15b.
-  stage = 3;
-  output[0] = round_shift(s0 + s3, bit);
-  output[1] = round_shift(s1 + s3, bit);
-  output[2] = round_shift(s2, bit);
-  output[3] = round_shift(s0 + s1 - s3, bit);
-  apply_range(stage, input, output, size, stage_range[stage]);
+  output[0] = round_shift(x0, bit);
+  output[1] = round_shift(x1, bit);
+  output[2] = round_shift(x2, bit);
+  output[3] = round_shift(x3, bit);
+  apply_range(6, input, output, size, stage_range[6]);
 }
 
 void av1_iadst8_new(const int32_t *input, int32_t *output, int8_t cos_bit,
diff --git a/av1/common/av1_inv_txfm1d_cfg.h b/av1/common/av1_inv_txfm1d_cfg.h
index 4211b41..4cb4fa6 100644
--- a/av1/common/av1_inv_txfm1d_cfg.h
+++ b/av1/common/av1_inv_txfm1d_cfg.h
@@ -22,8 +22,8 @@
 #if CONFIG_TX64X64
   7,    // 64x64 transform
 #endif  // CONFIG_TX64X64
-  6,    // 4x8 transform
-  6,    // 8x4 transform
+  5,    // 4x8 transform
+  5,    // 8x4 transform
   6,    // 8x16 transform
   6,    // 16x8 transform
   6,    // 16x32 transform
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index b97050d..c854fea 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -58,8 +58,8 @@
 #if CONFIG_TX64X64
 static const int8_t inv_shift_64x64[2] = { -2, -4 };
 #endif
-static const int8_t inv_shift_4x8[2] = { -1, -3 };
-static const int8_t inv_shift_8x4[2] = { -1, -3 };
+static const int8_t inv_shift_4x8[2] = { 0, -4 };
+static const int8_t inv_shift_8x4[2] = { 0, -4 };
 static const int8_t inv_shift_8x16[2] = { -1, -4 };
 static const int8_t inv_shift_16x8[2] = { -1, -4 };
 static const int8_t inv_shift_16x32[2] = { -1, -4 };
@@ -104,13 +104,15 @@
 
 const int8_t inv_cos_bit_row[MAX_TXWH_IDX /*txw_idx*/]
                             [MAX_TXWH_IDX /*txh_idx*/] = {
-                              { 13, 13, 13, 0, 0 },
+                              { 13, 13, 12, 0, 0 },
                               { 13, 13, 12, 12, 0 },
                               { 12, 12, 12, 12, 12 },
                               { 0, 12, 12, 12, 12 },
                               { 0, 0, 12, 12, 12 }
                             };
 
+const int8_t iadst4_range[7] = { 0, 1, 0, 0, 0, 0, 0 };
+
 void av1_get_inv_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size,
                           TXFM_2D_FLIP_CFG *cfg) {
   assert(cfg != NULL);
@@ -127,7 +129,13 @@
   cfg->cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   cfg->cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
   cfg->txfm_type_col = av1_txfm_type_ls[txh_idx][tx_type_1d_col];
+  if (cfg->txfm_type_col == TXFM_TYPE_ADST4) {
+    memcpy(cfg->stage_range_col, iadst4_range, sizeof(iadst4_range));
+  }
   cfg->txfm_type_row = av1_txfm_type_ls[txw_idx][tx_type_1d_row];
+  if (cfg->txfm_type_row == TXFM_TYPE_ADST4) {
+    memcpy(cfg->stage_range_row, iadst4_range, sizeof(iadst4_range));
+  }
   cfg->stage_num_col = av1_txfm_stage_num_list[cfg->txfm_type_col];
   cfg->stage_num_row = av1_txfm_stage_num_list[cfg->txfm_type_row];
 }
diff --git a/av1/common/av1_txfm.c b/av1/common/av1_txfm.c
index 50b3787..d3ffe4e 100644
--- a/av1/common/av1_txfm.c
+++ b/av1/common/av1_txfm.c
@@ -44,7 +44,7 @@
   8,   // TXFM_TYPE_DCT16
   10,  // TXFM_TYPE_DCT32
   12,  // TXFM_TYPE_DCT64
-  6,   // TXFM_TYPE_ADST4
+  7,   // TXFM_TYPE_ADST4
   8,   // TXFM_TYPE_ADST8
   10,  // TXFM_TYPE_ADST16
   12,  // TXFM_TYPE_ADST32
diff --git a/av1/encoder/av1_fwd_txfm1d.c b/av1/encoder/av1_fwd_txfm1d.c
index c385097..11beeb3 100644
--- a/av1/encoder/av1_fwd_txfm1d.c
+++ b/av1/encoder/av1_fwd_txfm1d.c
@@ -12,14 +12,16 @@
 #include <stdlib.h>
 #include "aom_dsp/inv_txfm.h"
 #include "av1/encoder/av1_fwd_txfm1d.h"
-#if CONFIG_COEFFICIENT_RANGE_CHECKING
+int32_t range_check_value(int32_t value, int8_t bit);
 
+#if CONFIG_COEFFICIENT_RANGE_CHECKING
 void range_check_func(int32_t stage, const int32_t *input, const int32_t *buf,
                       int32_t size, int8_t bit);
 
 #define range_check(stage, input, buf, size, bit) \
   range_check_func(stage, input, buf, size, bit)
-#else
+#else  // CONFIG_COEFFICIENT_RANGE_CHECKING
+
 #define range_check(stage, input, buf, size, bit) \
   {                                               \
     (void)stage;                                  \
@@ -28,7 +30,7 @@
     (void)size;                                   \
     (void)bit;                                    \
   }
-#endif
+#endif  // CONFIG_COEFFICIENT_RANGE_CHECKING
 
 void av1_fdct4_new(const int32_t *input, int32_t *output, int8_t cos_bit,
                    const int8_t *stage_range) {
@@ -692,13 +694,13 @@
 
 void av1_fadst4_new(const int32_t *input, int32_t *output, int8_t cos_bit,
                     const int8_t *stage_range) {
-  (void)cos_bit;
-  (void)stage_range;
   int bit = cos_bit;
   const int32_t *sinpi = sinpi_arr(bit);
   int32_t x0, x1, x2, x3;
   int32_t s0, s1, s2, s3, s4, s5, s6, s7;
 
+  // stage 0
+  range_check(0, input, input, 4, stage_range[0]);
   x0 = input[0];
   x1 = input[1];
   x2 = input[2];
@@ -709,30 +711,44 @@
     return;
   }
 
-  s0 = sinpi[1] * x0;
-  s1 = sinpi[4] * x0;
-  s2 = sinpi[2] * x1;
-  s3 = sinpi[1] * x1;
-  s4 = sinpi[3] * x2;
-  s5 = sinpi[4] * x3;
-  s6 = sinpi[2] * x3;
-  s7 = x0 + x1 - x3;
+  // stage 1
+  s0 = range_check_value(sinpi[1] * x0, bit + stage_range[1]);
+  s1 = range_check_value(sinpi[4] * x0, bit + stage_range[1]);
+  s2 = range_check_value(sinpi[2] * x1, bit + stage_range[1]);
+  s3 = range_check_value(sinpi[1] * x1, bit + stage_range[1]);
+  s4 = range_check_value(sinpi[3] * x2, bit + stage_range[1]);
+  s5 = range_check_value(sinpi[4] * x3, bit + stage_range[1]);
+  s6 = range_check_value(sinpi[2] * x3, bit + stage_range[1]);
+  s7 = range_check_value(x0 + x1, stage_range[1]);
 
-  x0 = s0 + s2 + s5;
-  x1 = sinpi[3] * s7;
-  x2 = s1 - s3 + s6;
-  x3 = s4;
+  // stage 2
+  s7 = range_check_value(s7 - x3, stage_range[2]);
 
-  s0 = x0 + x3;
-  s1 = x1;
-  s2 = x2 - x3;
-  s3 = x2 - x0 + x3;
+  // stage 3
+  x0 = range_check_value(s0 + s2, bit + stage_range[3]);
+  x1 = range_check_value(sinpi[3] * s7, bit + stage_range[3]);
+  x2 = range_check_value(s1 - s3, bit + stage_range[3]);
+  x3 = range_check_value(s4, bit + stage_range[3]);
+
+  // stage 4
+  x0 = range_check_value(x0 + s5, bit + stage_range[4]);
+  x2 = range_check_value(x2 + s6, bit + stage_range[4]);
+
+  // stage 5
+  s0 = range_check_value(x0 + x3, bit + stage_range[5]);
+  s1 = range_check_value(x1, bit + stage_range[5]);
+  s2 = range_check_value(x2 - x3, bit + stage_range[5]);
+  s3 = range_check_value(x2 - x0, bit + stage_range[5]);
+
+  // stage 6
+  s3 = range_check_value(s3 + x3, bit + stage_range[6]);
 
   // 1-D transform scaling factor is sqrt(2).
   output[0] = round_shift(s0, bit);
   output[1] = round_shift(s1, bit);
   output[2] = round_shift(s2, bit);
   output[3] = round_shift(s3, bit);
+  range_check(6, input, output, 4, stage_range[6]);
 }
 
 void av1_fadst8_new(const int32_t *input, int32_t *output, int8_t cos_bit,
diff --git a/av1/encoder/av1_fwd_txfm2d.c b/av1/encoder/av1_fwd_txfm2d.c
index c325c64..eb110fa 100644
--- a/av1/encoder/av1_fwd_txfm2d.c
+++ b/av1/encoder/av1_fwd_txfm2d.c
@@ -474,7 +474,7 @@
 
 const int8_t fwd_cos_bit_row[MAX_TXWH_IDX /*txw_idx*/]
                             [MAX_TXWH_IDX /*txh_idx*/] = {
-                              { 13, 13, 13, 0, 0 },
+                              { 13, 13, 12, 0, 0 },
                               { 13, 13, 13, 12, 0 },
                               { 13, 13, 12, 13, 12 },
                               { 0, 12, 13, 12, 11 },
@@ -488,7 +488,7 @@
 const int8_t fdct64_range_mult2[12] = { 0,  2,  4,  6,  8,  10,
                                         11, 11, 11, 11, 11, 11 };
 
-const int8_t fadst4_range_mult2[6] = { 0, 0, 1, 3, 3, 3 };
+const int8_t fadst4_range_mult2[7] = { 0, 2, 4, 3, 3, 3, 3 };
 const int8_t fadst8_range_mult2[8] = { 0, 0, 1, 3, 3, 5, 5, 5 };
 const int8_t fadst16_range_mult2[10] = { 0, 0, 1, 3, 3, 5, 5, 7, 7, 7 };
 const int8_t fadst32_range_mult2[12] = { 0, 0, 1, 3, 3, 5, 5, 7, 7, 9, 9, 9 };