Avoid using MRC_DCT when the mask produced is invalid

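If the mask is invalid, do not allow the encoder to select MRC_DCT.
Currently a mask is invalid if it is all 1s or all 0s, but these
criteria will likely expand in a future patch.

The masked average used to fill unmasked pixels is now computed over
the masked pixels only (sum / number of masked values) instead of
being divided by the full 32x32 pixel count.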

Change-Id: I77230ea8357bfdb2bf1e6338903d44bbf1db22d1
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index a7e1dfc..377d75f 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1642,6 +1642,11 @@
         !supertx_enabled &&
 #endif  // CONFIG_SUPERTX
         !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+#if CONFIG_MRC_TX
+      if (tx_type == MRC_DCT)
+        assert(mbmi->valid_mrc_mask && "Invalid MRC mask");
+#endif  // CONFIG_MRC_TX
+
       const int eset =
           get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       // eset == 0 should correspond to a set with only DCT_DCT and there
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index f4d8494..2ca4f34 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1092,9 +1092,14 @@
 #if CONFIG_MRC_TX
 static void get_masked_residual32(const int16_t **input, int *input_stride,
                                   const uint8_t *pred, int pred_stride,
-                                  int16_t *masked_input) {
+                                  int16_t *masked_input, int *valid_mask) {
   int mrc_mask[32 * 32];
-  get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  int n_masked_vals = get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  // Do not use MRC_DCT if mask is invalid. DCT_DCT will be used instead.
+  if (!is_valid_mrc_mask(n_masked_vals, 32, 32)) {
+    *valid_mask = 0;
+    return;
+  }
   int32_t sum = 0;
   int16_t avg;
   // Get the masked average of the prediction
@@ -1103,7 +1108,7 @@
       sum += mrc_mask[i * 32 + j] * (*input)[i * (*input_stride) + j];
     }
   }
-  avg = ROUND_POWER_OF_TWO_SIGNED(sum, 10);
+  avg = sum / n_masked_vals;
   // Replace all of the unmasked pixels in the prediction with the average
   // of the masked pixels
   for (int i = 0; i < 32; ++i) {
@@ -1113,6 +1118,7 @@
   }
   *input = masked_input;
   *input_stride = 32;
+  *valid_mask = 1;
 }
 #endif  // CONFIG_MRC_TX
 
@@ -2464,7 +2470,7 @@
   if (tx_type == MRC_DCT) {
     int16_t masked_input[32 * 32];
     get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
-                          masked_input);
+                          masked_input, txfm_param->valid_mask);
   }
 #endif  // CONFIG_MRC_TX
 
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b225f9f..f2bb213 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -600,6 +600,9 @@
   txfm_param.is_inter = is_inter_block(mbmi);
   txfm_param.dst = dst;
   txfm_param.stride = dst_stride;
+#if CONFIG_MRC_TX
+  txfm_param.valid_mask = &mbmi->valid_mrc_mask;
+#endif  // CONFIG_MRC_TX
 #endif  // CONFIG_MRC_TX || CONFIG_LGT
 #if CONFIG_LGT
   txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3fccb3c..1b86b61 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2159,6 +2159,13 @@
   }
 #endif  // DISABLE_TRELLISQ_SEARCH
 
+#if CONFIG_MRC_TX
+  if (mbmi->tx_type == MRC_DCT && !mbmi->valid_mrc_mask) {
+    args->exit_early = 1;
+    return;
+  }
+#endif  // CONFIG_MRC_TX
+
   if (!is_inter_block(mbmi)) {
     struct macroblock_plane *const p = &x->plane[plane];
     av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,