Make hbd transforms compatible with 4:1 transforms Change-Id: I7123717b2d11bca826d650c6e6b6ae137476d541
diff --git a/av1/common/av1_fwd_txfm2d.c b/av1/common/av1_fwd_txfm2d.c index 9d5f478..e0910e2 100644 --- a/av1/common/av1_fwd_txfm2d.c +++ b/av1/common/av1_fwd_txfm2d.c
@@ -79,6 +79,15 @@ // for square transforms. const int txfm_size_col = cfg->row_cfg->txfm_size; const int txfm_size_row = cfg->col_cfg->txfm_size; + const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row); + int rect_type2_shift = 0; + if (rect_type == 2 || rect_type == -2) { + const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row); + // For 64x16 / 16x64 shift 3 bits, for 32x8 / 8x32 shift 2 bits, for + // 16x4 / 4x16 shift by 1 bit. + rect_type2_shift = + (txfm_size_max == 64 ? 3 : (txfm_size_max == 32 ? 2 : 1)); + } // Take the shift from the larger dimension in the rectangular case. const int8_t *shift = (txfm_size_col > txfm_size_row) ? cfg->row_cfg->shift : cfg->col_cfg->shift; @@ -108,10 +117,14 @@ } round_shift_array(temp_in, txfm_size_row, -shift[0]); // Multiply everything by Sqrt2 on the larger dimension if the - // transform is rectangular - if (txfm_size_col > txfm_size_row) { + // transform is rectangular and the size difference is a factor of 2. + // If the size difference is a factor of 4, multiply by + // 2^rect_type_2_extra_shift. + if (rect_type == 1) { for (r = 0; r < txfm_size_row; ++r) temp_in[r] = (int32_t)fdct_round_shift(temp_in[r] * Sqrt2); + } else if (rect_type == 2) { + round_shift_array(temp_in, txfm_size_row, -rect_type2_shift); } txfm_func_col(temp_in, temp_out, cos_bit_col, stage_range_col); round_shift_array(temp_out, txfm_size_row, -shift[1]); @@ -128,11 +141,16 @@ // Rows for (r = 0; r < txfm_size_row; ++r) { // Multiply everything by Sqrt2 on the larger dimension if the - // transform is rectangular - if (txfm_size_row > txfm_size_col) { + // transform is rectangular and the size difference is a factor of 2. + // If the size difference is a factor of 4, multiply by 2. + if (rect_type == -1) { for (c = 0; c < txfm_size_col; ++c) buf[r * txfm_size_col + c] = (int32_t)fdct_round_shift(buf[r * txfm_size_col + c] * Sqrt2); + } else if (rect_type == -2) { + for (c = 0; c < txfm_size_col; ++c) + buf[r * txfm_size_col + c] = + buf[r * txfm_size_col + c] * (1 << rect_type2_shift); } txfm_func_row(buf + r * txfm_size_col, output + r * txfm_size_col, cos_bit_row, stage_range_row);
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c index 6450e26..89ec482 100644 --- a/av1/common/av1_inv_txfm2d.c +++ b/av1/common/av1_inv_txfm2d.c
@@ -193,6 +193,14 @@ // for square transforms. const int txfm_size_col = cfg->row_cfg->txfm_size; const int txfm_size_row = cfg->col_cfg->txfm_size; + const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row); + int rect_type2_shift = 0; + if (rect_type == 2 || rect_type == -2) { + const int txfm_size_max = AOMMAX(txfm_size_col, txfm_size_row); + // For 16x4 / 4x16 shift 1 bit, for 32x8 / 8x32 / 64x16 / 16x64 no need + // for any additional shift. + rect_type2_shift = (txfm_size_max == 16 ? 1 : 0); + } // Take the shift from the larger dimension in the rectangular case. const int8_t *shift = (txfm_size_col > txfm_size_row) ? cfg->row_cfg->shift : cfg->col_cfg->shift; @@ -219,10 +227,14 @@ for (r = 0; r < txfm_size_row; ++r) { txfm_func_row(input, buf_ptr, cos_bit_row, stage_range_row); round_shift_array(buf_ptr, txfm_size_col, -shift[0]); - // Multiply everything by Sqrt2 if the transform is rectangular - if (txfm_size_row != txfm_size_col) { + // Multiply everything by Sqrt2 if the transform is rectangular with + // log ratio being 1 or -1, if the log ratio is 2 or -2, multiply by + // 2^rect_type2_shift. + if (abs(rect_type) == 1) { for (c = 0; c < txfm_size_col; ++c) buf_ptr[c] = (int32_t)dct_const_round_shift(buf_ptr[c] * Sqrt2); + } else if (rect_type2_shift) { + round_shift_array(buf_ptr, txfm_size_col, -rect_type2_shift); } input += txfm_size_col; buf_ptr += txfm_size_col;
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h index 44b5bd0..bb0053d 100644 --- a/av1/common/av1_txfm.h +++ b/av1/common/av1_txfm.h
@@ -359,6 +359,22 @@ } #endif // CONFIG_MRC_TX +// Utility function that returns the log of the ratio of the col and row +// sizes. +static INLINE int get_rect_tx_log_ratio(int col, int row) { + if (col == row) return 0; + if (col > row) { + if (col == row * 2) return 1; + if (col == row * 4) return 2; + assert(0 && "Unsupported transform size"); + } else { + if (row == col * 2) return -1; + if (row == col * 4) return -2; + assert(0 && "Unsupported transform size"); + } + return 0; // Invalid +} + void av1_gen_fwd_stage_range(int8_t *stage_range_col, int8_t *stage_range_row, const TXFM_2D_FLIP_CFG *cfg, int bd);