Add txfm functions corresponding to MRC_DCT
MRC_DCT uses a mask based on the prediction signal to modify the
residual before applying DCT_DCT. This adds all necessary functions
to perform this transform and makes the prediction signal available
to the 32x32 txfm functions so the mask can be created. Mask
generation is still being experimented with, so get_mrc_mask() in this
patch is only a placeholder (a fixed threshold on the prediction).
This patch has no impact on performance.
Change-Id: Ie3772f528e82103187a85c91cf00bb291dba328a
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 1304e4c..269ef57 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -209,6 +209,16 @@
}
}
+#if CONFIG_MRC_TX
+// Placeholder mask generation for MRC_DCT: writes a binary mask in which an
+// entry is 1 where the prediction sample is strictly greater than 100 and 0
+// otherwise.
+//   pred, pred_stride: prediction samples and the distance between its rows
+//   mask, mask_stride: output mask and the distance between its rows
+//   width, height:     size of the region to process
+static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
+                                int mask_stride, int width, int height) {
+  for (int row = 0; row < height; ++row) {
+    const uint8_t *pred_row = pred + row * pred_stride;
+    int *mask_row = mask + row * mask_stride;
+    // A relational expression already evaluates to 0 or 1.
+    for (int col = 0; col < width; ++col) mask_row[col] = pred_row[col] > 100;
+  }
+}
+#endif // CONFIG_MRC_TX
+
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 63dfdb0..09f01d3 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -1463,6 +1463,35 @@
aom_idct16x16_256_add(input, dest, stride);
}
+#if CONFIG_MRC_TX
+// Selects and calls the 32x32 inverse kernel used for MRC_DCT.
+//   input:      transform coefficients, forwarded to the kernel
+//   dest:       destination buffer, forwarded to the kernel
+//   stride:     row stride of |dest|
+//   txfm_param: supplies eob, the buffer the mask is derived from
+//               (dst/stride) and, with CONFIG_ADAPT_SCAN, the eob thresholds
+// With eob == 1 the plain aom_idct32x32_1_add_c() kernel is used and no mask
+// is built; otherwise a mask is generated with get_mrc_mask() and passed to
+// one of the aom_imrc32x32_*_add_c() kernels chosen by eob.
+static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                            const TxfmParam *txfm_param) {
+#if CONFIG_ADAPT_SCAN
+  const int16_t half = txfm_param->eob_threshold[0];
+  const int16_t quarter = txfm_param->eob_threshold[1];
+#else
+  const int16_t half = 135;
+  const int16_t quarter = 34;
+#endif
+
+  const int eob = txfm_param->eob;
+  if (eob == 1) {
+    aom_idct32x32_1_add_c(input, dest, stride);
+  } else {
+    // get_mrc_mask() writes through an int *, so the buffer has to be int; a
+    // tran_low_t buffer is only compatible when tran_low_t happens to be int.
+    // NOTE(review): the aom_imrc32x32_*_add_c kernels are declared outside
+    // this file; presumably they take the same int mask - confirm.
+    int mask[32 * 32];
+    get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+    if (eob <= quarter) {
+      // non-zero coeff only in upper-left 8x8
+      aom_imrc32x32_34_add_c(input, dest, stride, mask);
+    } else if (eob <= half) {
+      // non-zero coeff only in upper-left 16x16
+      aom_imrc32x32_135_add_c(input, dest, stride, mask);
+    } else {
+      aom_imrc32x32_1024_add_c(input, dest, stride, mask);
+    }
+  }
+}
+#endif // CONFIG_MRC_TX
+
static void idct32x32_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
#if CONFIG_ADAPT_SCAN
@@ -1486,24 +1515,6 @@
aom_idct32x32_1024_add(input, dest, stride);
}
-#if CONFIG_MRC_TX
-static void get_masked_residual32_inv(const tran_low_t *input, uint8_t *dest,
- tran_low_t *output) {
- // placeholder for bitmask creation, in the future it
- // will likely be made based on dest
- (void)dest;
- memcpy(output, input, 32 * 32 * sizeof(*input));
-}
-
-static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
- const TxfmParam *param) {
- // placeholder mrc tx function
- tran_low_t masked_input[32 * 32];
- get_masked_residual32_inv(input, dest, masked_input);
- idct32x32_add(input, dest, stride, param);
-}
-#endif // CONFIG_MRC_TX
-
#if CONFIG_TX64X64
static void idct64x64_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
@@ -2200,10 +2211,12 @@
#endif // CONFIG_PVQ
TxfmParam txfm_param;
init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
-#if CONFIG_LGT
+#if CONFIG_LGT || CONFIG_MRC_TX
txfm_param.dst = dst;
- txfm_param.mode = mode;
txfm_param.stride = stride;
+#endif // CONFIG_LGT || CONFIG_MRC_TX
+#if CONFIG_LGT
+ txfm_param.mode = mode;
#endif
const int is_hbd = get_bitdepth_data_path_index(xd);
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index 2ffc656..850b84c 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1064,17 +1064,29 @@
}
#if CONFIG_MRC_TX
-static void get_masked_residual32_fwd(const tran_low_t *input,
- tran_low_t *output) {
- // placeholder for future bitmask creation
- memcpy(output, input, 32 * 32 * sizeof(*input));
-}
-
-static void fmrc32(const tran_low_t *input, tran_low_t *output) {
- // placeholder for mrc_dct, this just performs regular dct
- tran_low_t masked_input[32 * 32];
- get_masked_residual32_fwd(input, masked_input);
- fdct32(masked_input, output);
+// Prepares the 32x32 residual for MRC_DCT. A mask is derived from the
+// prediction with get_mrc_mask(); residual samples whose mask entry is 0 are
+// replaced by the average of the samples whose mask entry is non-zero.
+//   input, input_stride: in:  the residual block and its row stride
+//                        out: redirected to |masked_input| with stride 32
+//   pred, pred_stride:   prediction block the mask is built from
+//   masked_input:        caller-provided 32x32 buffer receiving the result;
+//                        it must stay alive for as long as *input is used
+static void get_masked_residual32(const int16_t **input, int *input_stride,
+                                  const uint8_t *pred, int pred_stride,
+                                  int16_t *masked_input) {
+  int mrc_mask[32 * 32];
+  get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  const int16_t *const src = *input;
+  const int src_stride = *input_stride;
+  int32_t sum = 0;
+  int n_masked = 0;
+  // Sum and count the residual samples selected by the mask
+  for (int i = 0; i < 32; ++i) {
+    for (int j = 0; j < 32; ++j) {
+      if (mrc_mask[i * 32 + j]) {
+        sum += src[i * src_stride + j];
+        ++n_masked;
+      }
+    }
+  }
+  // Rounded (half away from zero) average over the selected samples only.
+  // Dividing by the full block size of 1024 would pull the fill value towards
+  // zero whenever part of the block is masked out. An empty mask yields 0.
+  int16_t avg = 0;
+  if (n_masked > 0) {
+    const int32_t half = n_masked / 2;
+    avg = (int16_t)(sum < 0 ? -((-sum + half) / n_masked)
+                            : (sum + half) / n_masked);
+  }
+  // Replace every sample the mask does not select with that average
+  for (int i = 0; i < 32; ++i) {
+    for (int j = 0; j < 32; ++j) {
+      masked_input[i * 32 + j] =
+          mrc_mask[i * 32 + j] ? src[i * src_stride + j] : avg;
+    }
+  }
+  *input = masked_input;
+  *input_stride = 32;
+}
#endif // CONFIG_MRC_TX
@@ -2387,7 +2399,7 @@
{ fidtx32, fhalfright32 }, // H_FLIPADST
#endif
#if CONFIG_MRC_TX
- { fmrc32, fmrc32 }, // MRC_TX
+ { fdct32, fdct32 }, // MRC_TX
#endif // CONFIG_MRC_TX
};
const transform_2d ht = FHT[tx_type];
@@ -2400,6 +2412,14 @@
maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
#endif
+#if CONFIG_MRC_TX
+  // get_masked_residual32() redirects |input| to this buffer, and |input| is
+  // read by the transform loops that follow. The buffer therefore has to be
+  // declared in the enclosing scope: declared inside the if-block, its
+  // lifetime would end at the closing brace and leave |input| dangling.
+  int16_t masked_input[32 * 32];
+  if (tx_type == MRC_DCT) {
+    get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
+                          masked_input);
+  }
+#endif // CONFIG_MRC_TX
+
// Columns
for (i = 0; i < 32; ++i) {
for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b532c13..5b91cec 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -538,7 +538,7 @@
TxfmParam txfm_param;
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
uint8_t *dst;
const int dst_stride = pd->dst.stride;
#if CONFIG_PVQ || CONFIG_DIST_8X8
@@ -601,7 +601,7 @@
#endif // CONFIG_HIGHBITDEPTH
#endif
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
#if CONFIG_PVQ || CONFIG_DIST_8X8
pred = &pd->pred[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
@@ -623,17 +623,19 @@
}
#endif // CONFIG_HIGHBITDEPTH
#endif // CONFIG_PVQ || CONFIG_DIST_8X8
-#endif // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#endif // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
(void)ctx;
txfm_param.tx_type = tx_type;
txfm_param.tx_size = tx_size;
txfm_param.lossless = xd->lossless[mbmi->segment_id];
-#if CONFIG_LGT
- txfm_param.is_inter = is_inter_block(mbmi);
+#if CONFIG_MRC_TX || CONFIG_LGT
txfm_param.dst = dst;
txfm_param.stride = dst_stride;
+#endif // CONFIG_MRC_TX || CONFIG_LGT
+#if CONFIG_LGT
+ txfm_param.is_inter = is_inter_block(mbmi);
txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
#endif