Avoid using MRC_DCT when the mask produced is invalid
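If the MRC mask is invalid, do not allow the encoder to select MRC_DCT.
Currently a mask is considered invalid if it is all 1s or all 0s; these
criteria will likely be expanded in a future patch. The average used to
fill the unmasked residual pixels is now taken over the masked pixels
only, rather than over the whole 32x32 block.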
Change-Id: I77230ea8357bfdb2bf1e6338903d44bbf1db22d1
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 269ef57..aa4a76a 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -210,12 +210,20 @@
}
#if CONFIG_MRC_TX
-static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
- int mask_stride, int width, int height) {
+static INLINE int get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
+ int mask_stride, int width, int height) {
+ int n_masked_vals = 0;
for (int i = 0; i < height; ++i) {
- for (int j = 0; j < width; ++j)
+ for (int j = 0; j < width; ++j) {
mask[i * mask_stride + j] = pred[i * pred_stride + j] > 100 ? 1 : 0;
+ n_masked_vals += mask[i * mask_stride + j];
+ }
}
+ return n_masked_vals;  // Number of mask entries set to 1.
+}
+// A mask is usable only if it is neither all 0s nor all 1s.
+static INLINE int is_valid_mrc_mask(int n_masked_vals, int width, int height) {
+ return !(n_masked_vals == 0 || n_masked_vals == (width * height));
}
#endif // CONFIG_MRC_TX
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 2e2474b..7fc7b77 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -377,6 +377,10 @@
#endif // CONFIG_SUPERTX
int8_t seg_id_predicted; // valid only when temporal_update is enabled
+#if CONFIG_MRC_TX
+ int valid_mrc_mask;  // Set by the MRC forward transform; 0 if mask invalid.
+#endif // CONFIG_MRC_TX
+
// Only for INTRA blocks
UV_PREDICTION_MODE uv_mode;
#if CONFIG_PALETTE
diff --git a/av1/common/idct.c b/av1/common/idct.c
index a6543a2..2a52cbf 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -1495,8 +1495,12 @@
if (eob == 1) {
aom_idct32x32_1_add_c(input, dest, stride);
} else {
- tran_low_t mask[32 * 32];
- get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+ int mask[32 * 32];
+ int n_masked_vals =
+ get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+ // MRC_DCT is only expected when the mask is neither all 0s nor all 1s.
+ assert(is_valid_mrc_mask(n_masked_vals, 32, 32) && "Invalid MRC mask");
+ (void)n_masked_vals;  // Only read by assert(); avoids NDEBUG warnings.
if (eob <= quarter)
// non-zero coeff only in upper-left 8x8
aom_imrc32x32_34_add_c(input, dest, stride, mask);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index a7e1dfc..377d75f 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1642,6 +1642,11 @@
!supertx_enabled &&
#endif // CONFIG_SUPERTX
!segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+#if CONFIG_MRC_TX
+ // MRC_DCT must never be signaled for a block whose MRC mask was invalid.
+ assert((tx_type != MRC_DCT || mbmi->valid_mrc_mask) && "Invalid MRC mask");
+#endif // CONFIG_MRC_TX
+
const int eset =
get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
// eset == 0 should correspond to a set with only DCT_DCT and there
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index f4d8494..2ca4f34 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1092,9 +1092,14 @@
#if CONFIG_MRC_TX
static void get_masked_residual32(const int16_t **input, int *input_stride,
const uint8_t *pred, int pred_stride,
- int16_t *masked_input) {
+ int16_t *masked_input, int *valid_mask) {
int mrc_mask[32 * 32];
- get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+ int n_masked_vals = get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+ // Invalid mask: keep the unmasked residual and flag it so MRC_DCT is rejected.
+ if (!is_valid_mrc_mask(n_masked_vals, 32, 32)) {
+ *valid_mask = 0;
+ return;
+ }
int32_t sum = 0;
int16_t avg;
// Get the masked average of the prediction
@@ -1103,7 +1108,8 @@
sum += mrc_mask[i * 32 + j] * (*input)[i * (*input_stride) + j];
}
}
- avg = ROUND_POWER_OF_TWO_SIGNED(sum, 10);
+ // n_masked_vals > 0 here; round half away from zero like the old shift did.
+ avg = (sum + (sum < 0 ? -n_masked_vals : n_masked_vals) / 2) / n_masked_vals;
// Replace all of the unmasked pixels in the prediction with the average
// of the masked pixels
for (int i = 0; i < 32; ++i) {
@@ -1113,6 +1119,7 @@
}
*input = masked_input;
*input_stride = 32;
+ *valid_mask = 1;
}
#endif // CONFIG_MRC_TX
@@ -2464,7 +2471,8 @@
- if (tx_type == MRC_DCT) {
- int16_t masked_input[32 * 32];
+ // Must outlive the if: get_masked_residual32() may point input into it.
+ int16_t masked_input[32 * 32];
+ if (tx_type == MRC_DCT) {
get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
- masked_input);
+ masked_input, txfm_param->valid_mask);
}
#endif // CONFIG_MRC_TX
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b225f9f..f2bb213 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -600,6 +600,9 @@
txfm_param.is_inter = is_inter_block(mbmi);
txfm_param.dst = dst;
txfm_param.stride = dst_stride;
+#if CONFIG_MRC_TX
+ txfm_param.valid_mask = &mbmi->valid_mrc_mask;  // Written by MRC fwd txfm.
+#endif // CONFIG_MRC_TX
#endif // CONFIG_MRC_TX || CONFIG_LGT
#if CONFIG_LGT
txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3fccb3c..1b86b61 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2159,6 +2159,13 @@
}
#endif // DISABLE_TRELLISQ_SEARCH
+#if CONFIG_MRC_TX
+ if (mbmi->tx_type == MRC_DCT && !mbmi->valid_mrc_mask) {
+ args->exit_early = 1;  // MRC_DCT must not be chosen with an invalid mask.
+ return;
+ }
+#endif // CONFIG_MRC_TX
+
if (!is_inter_block(mbmi)) {
struct macroblock_plane *const p = &x->plane[plane];
av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,