Avoid using MRC_DCT when the mask produced is invalid

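If the MRC mask computed from the prediction is invalid, do not allow
the encoder to select MRC_DCT. Currently a mask is invalid if it is all
1s or all 0s, but these criteria will likely expand in a future patch.
The masked average is now taken over the masked pixels only, instead of
dividing the masked sum by the full 32x32 block size.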

Change-Id: I77230ea8357bfdb2bf1e6338903d44bbf1db22d1
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 269ef57..aa4a76a 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -210,12 +210,20 @@
 }
 
 #if CONFIG_MRC_TX
-static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
-                                int mask_stride, int width, int height) {
+static INLINE int get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
+                               int mask_stride, int width, int height) {
+  int n_masked_vals = 0;  // number of mask entries set to 1; returned
   for (int i = 0; i < height; ++i) {
-    for (int j = 0; j < width; ++j)
+    for (int j = 0; j < width; ++j) {  // mask = 1 where pred > 100
       mask[i * mask_stride + j] = pred[i * pred_stride + j] > 100 ? 1 : 0;
+      n_masked_vals += mask[i * mask_stride + j];
+    }
   }
+  return n_masked_vals;  // callers validate it with is_valid_mrc_mask()
+}
+// Nonzero only if the mask holds both 0s and 1s (not all-0, not all-1).
+static INLINE int is_valid_mrc_mask(int n_masked_vals, int width, int height) {
+  return !(n_masked_vals == 0 || n_masked_vals == (width * height));
 }
 #endif  // CONFIG_MRC_TX
 
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 2e2474b..7fc7b77 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -377,6 +377,10 @@
 #endif                      // CONFIG_SUPERTX
   int8_t seg_id_predicted;  // valid only when temporal_update is enabled
 
+#if CONFIG_MRC_TX
+  int valid_mrc_mask;  // encoder: set by the last MRC_DCT fwd txfm
+#endif  // CONFIG_MRC_TX
+
   // Only for INTRA blocks
   UV_PREDICTION_MODE uv_mode;
 #if CONFIG_PALETTE
diff --git a/av1/common/idct.c b/av1/common/idct.c
index a6543a2..2a52cbf 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -1495,8 +1495,12 @@
   if (eob == 1) {
     aom_idct32x32_1_add_c(input, dest, stride);
   } else {
-    tran_low_t mask[32 * 32];
-    get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+    int mask[32 * 32];  // int, as get_mrc_mask() expects
+    int n_masked_vals =
+        get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+
+    if (!is_valid_mrc_mask(n_masked_vals, 32, 32))  // all-0 or all-1 mask
+      assert(0 && "Invalid MRC mask");  // encoder RD search skips such masks
     if (eob <= quarter)
       // non-zero coeff only in upper-left 8x8
       aom_imrc32x32_34_add_c(input, dest, stride, mask);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index a7e1dfc..377d75f 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1642,6 +1642,11 @@
         !supertx_enabled &&
 #endif  // CONFIG_SUPERTX
         !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+#if CONFIG_MRC_TX
+      if (tx_type == MRC_DCT)  // reflects only the last MRC_DCT fwd txfm
+        assert(mbmi->valid_mrc_mask && "Invalid MRC mask");
+#endif  // CONFIG_MRC_TX
+
       const int eset =
           get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       // eset == 0 should correspond to a set with only DCT_DCT and there
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index f4d8494..2ca4f34 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1092,9 +1092,14 @@
 #if CONFIG_MRC_TX
 static void get_masked_residual32(const int16_t **input, int *input_stride,
                                   const uint8_t *pred, int pred_stride,
-                                  int16_t *masked_input) {
+                                  int16_t *masked_input, int *valid_mask) {
   int mrc_mask[32 * 32];
-  get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  int n_masked_vals = get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  // Invalid mask: keep the unmasked residual; rdopt then skips MRC_DCT.
+  if (!is_valid_mrc_mask(n_masked_vals, 32, 32)) {
+    *valid_mask = 0;
+    return;
+  }
   int32_t sum = 0;
   int16_t avg;
   // Get the masked average of the prediction
@@ -1103,7 +1108,7 @@
       sum += mrc_mask[i * 32 + j] * (*input)[i * (*input_stride) + j];
     }
   }
-  avg = ROUND_POWER_OF_TWO_SIGNED(sum, 10);
+  avg = (sum + (sum < 0 ? -n_masked_vals : n_masked_vals) / 2) / n_masked_vals;
   // Replace all of the unmasked pixels in the prediction with the average
   // of the masked pixels
   for (int i = 0; i < 32; ++i) {
@@ -1113,6 +1118,7 @@
   }
   *input = masked_input;
   *input_stride = 32;
+  *valid_mask = 1;
 }
 #endif  // CONFIG_MRC_TX
 
@@ -2464,7 +2470,7 @@
   if (tx_type == MRC_DCT) {
     int16_t masked_input[32 * 32];
     get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
-                          masked_input);
+                          masked_input, txfm_param->valid_mask);
   }
 #endif  // CONFIG_MRC_TX
 
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b225f9f..f2bb213 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -600,6 +600,9 @@
   txfm_param.is_inter = is_inter_block(mbmi);
   txfm_param.dst = dst;
   txfm_param.stride = dst_stride;
+#if CONFIG_MRC_TX
+  txfm_param.valid_mask = &mbmi->valid_mrc_mask;  // set by MRC_DCT fwd txfm
+#endif  // CONFIG_MRC_TX
 #endif  // CONFIG_MRC_TX || CONFIG_LGT
 #if CONFIG_LGT
   txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3fccb3c..1b86b61 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2159,6 +2159,13 @@
   }
 #endif  // DISABLE_TRELLISQ_SEARCH
 
+#if CONFIG_MRC_TX
+  if (mbmi->tx_type == MRC_DCT && !mbmi->valid_mrc_mask) {  // unusable mask
+    args->exit_early = 1;
+    return;
+  }
+#endif  // CONFIG_MRC_TX
+
   if (!is_inter_block(mbmi)) {
     struct macroblock_plane *const p = &x->plane[plane];
     av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,