ext-inter: Use joint_motion_search for masked compounds

Add functions which take both components of a masked compound and
compute the resulting SAD/SSE. Extend joint_motion_search to understand
masked compounds, and use it to evaluate NEW_NEWMV modes.

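For reference, the per-pixel model shared by the new SAD/SSE helpers,
as a simplified single-row sketch (the function name and flat layout
here are illustrative only; the real helpers operate on strided 2D
blocks and also come in high-bitdepth variants):

    #include <stdint.h>

    /* Blend two predictions a and b with a 6-bit mask (64 selects a
     * fully, as in AOM_BLEND_A64) and sum the absolute error against
     * the source. The blend already rounds by 6 bits, so the SAD needs
     * no further normalization. invert_mask swaps the roles of the two
     * predictions, mirroring how ref and second_pred are treated. */
    static unsigned int masked_compound_sad_row(const uint8_t *src,
                                                const uint8_t *a,
                                                const uint8_t *b,
                                                const uint8_t *mask,
                                                int n, int invert_mask) {
      unsigned int sad = 0;
      int i;
      for (i = 0; i < n; i++) {
        const uint8_t p0 = invert_mask ? b[i] : a[i];
        const uint8_t p1 = invert_mask ? a[i] : b[i];
        const int pred = (mask[i] * p0 + (64 - mask[i]) * p1 + 32) >> 6;
        const int diff = pred - src[i];
        sad += (unsigned int)(diff < 0 ? -diff : diff);
      }
      return sad;
    }

The variance helpers accumulate diff and diff * diff over the block in
the same way, and the bitdepth-specific wrappers rescale the results to
an 8-bit basis.
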
Change-Id: I782199a20d119a6c61c6567df157508125ac7ce7
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 86f6beb..aad23bc 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -740,6 +740,7 @@
     ($w, $h) = @$_;
     add_proto qw/unsigned int/, "aom_masked_sad${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *mask, int mask_stride";
     specialize "aom_masked_sad${w}x${h}", qw/ssse3/;
+    add_proto qw/unsigned int/, "aom_masked_compound_sad${w}x${h}", "const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask";
   }
 
   if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
@@ -747,6 +748,8 @@
       ($w, $h) = @$_;
       add_proto qw/unsigned int/, "aom_highbd_masked_sad${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *mask, int mask_stride";
       specialize "aom_highbd_masked_sad${w}x${h}", qw/ssse3/;
+
+      add_proto qw/unsigned int/, "aom_highbd_masked_compound_sad${w}x${h}", "const uint8_t *src8, int src_stride, const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, const uint8_t *msk, int msk_stride, int invert_mask";
     }
   }
 }
@@ -1049,6 +1052,9 @@
     add_proto qw/unsigned int/, "aom_masked_sub_pixel_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, int xoffset, int  yoffset, const uint8_t *ref_ptr, int ref_stride, const uint8_t *mask, int mask_stride, unsigned int *sse";
     specialize "aom_masked_variance${w}x${h}", qw/ssse3/;
     specialize "aom_masked_sub_pixel_variance${w}x${h}", qw/ssse3/;
+
+    add_proto qw/unsigned int/, "aom_masked_compound_variance${w}x${h}", "const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *m, int m_stride, int invert_mask, unsigned int *sse";
+    add_proto qw/unsigned int/, "aom_masked_compound_sub_pixel_variance${w}x${h}", "const uint8_t *src, int src_stride, int xoffset, int yoffset, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask, unsigned int *sse";
   }
 
   if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
@@ -1059,6 +1065,9 @@
         add_proto qw/unsigned int/, "aom_highbd${bd}masked_sub_pixel_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, int xoffset, int  yoffset, const uint8_t *ref_ptr, int ref_stride, const uint8_t *m, int m_stride, unsigned int *sse";
         specialize "aom_highbd${bd}masked_variance${w}x${h}", qw/ssse3/;
         specialize "aom_highbd${bd}masked_sub_pixel_variance${w}x${h}", qw/ssse3/;
+
+        add_proto qw/unsigned int/, "aom_highbd${bd}masked_compound_variance${w}x${h}", "const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *m, int m_stride, int invert_mask, unsigned int *sse";
+        add_proto qw/unsigned int/, "aom_highbd${bd}masked_compound_sub_pixel_variance${w}x${h}", "const uint8_t *src, int src_stride, int xoffset, int yoffset, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask, unsigned int *sse";
       }
     }
   }
@@ -1501,6 +1510,15 @@
 
 }  # CONFIG_HIGHBITDEPTH
 
+if (aom_config("CONFIG_EXT_INTER") eq "yes") {
+  add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
+  add_proto qw/void aom_comp_mask_upsampled_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
+  if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
+    add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
+    add_proto qw/void aom_highbd_comp_mask_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
+  }
+}
+
 }  # CONFIG_ENCODERS
 
 1;
diff --git a/aom_dsp/sad.c b/aom_dsp/sad.c
index 3e10705..e7f31a1 100644
--- a/aom_dsp/sad.c
+++ b/aom_dsp/sad.c
@@ -16,6 +16,7 @@
 
 #include "aom/aom_integer.h"
 #include "aom_ports/mem.h"
+#include "aom_dsp/blend.h"
 
 /* Sum the difference between every corresponding element of the buffers. */
 static INLINE unsigned int sad(const uint8_t *a, int a_stride, const uint8_t *b,
@@ -329,12 +330,49 @@
   return sad;
 }
 
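+/* Blend the two predictions a and b with a 6-bit mask, then sum the
+ * absolute differences between the blended prediction and the source. */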
+static INLINE unsigned int masked_compound_sad(const uint8_t *src,
+                                               int src_stride, const uint8_t *a,
+                                               int a_stride, const uint8_t *b,
+                                               int b_stride, const uint8_t *m,
+                                               int m_stride, int width,
+                                               int height) {
+  int y, x;
+  unsigned int sad = 0;
+
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const uint8_t pred = AOM_BLEND_A64(m[x], a[x], b[x]);
+      sad += abs(pred - src[x]);
+    }
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+  }
+
+  return sad;
+}
+
 #define MASKSADMxN(m, n)                                                      \
   unsigned int aom_masked_sad##m##x##n##_c(                                   \
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
       const uint8_t *msk, int msk_stride) {                                   \
     return masked_sad(src, src_stride, ref, ref_stride, msk, msk_stride, m,   \
                       n);                                                     \
+  }                                                                           \
+  unsigned int aom_masked_compound_sad##m##x##n##_c(                          \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
+      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
+      int invert_mask) {                                                      \
+    if (!invert_mask)                                                         \
+      return masked_compound_sad(src, src_stride, ref, ref_stride,            \
+                                 second_pred, m, msk, msk_stride, m, n);      \
+    else                                                                      \
+      return masked_compound_sad(src, src_stride, second_pred, m, ref,        \
+                                 ref_stride, msk, msk_stride, m, n);          \
   }
 
 /* clang-format off */
@@ -381,12 +419,50 @@
   return sad;
 }
 
+static INLINE unsigned int highbd_masked_compound_sad(
+    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
+    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, int width,
+    int height) {
+  int y, x;
+  unsigned int sad = 0;
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const uint16_t pred = AOM_BLEND_A64(m[x], a[x], b[x]);
+      sad += abs(pred - src[x]);
+    }
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+  }
+
+  return sad;
+}
+
 #define HIGHBD_MASKSADMXN(m, n)                                               \
   unsigned int aom_highbd_masked_sad##m##x##n##_c(                            \
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
       const uint8_t *msk, int msk_stride) {                                   \
     return highbd_masked_sad(src, src_stride, ref, ref_stride, msk,           \
                              msk_stride, m, n);                               \
+  }                                                                           \
+  unsigned int aom_highbd_masked_compound_sad##m##x##n##_c(                   \
+      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
+      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,        \
+      int msk_stride, int invert_mask) {                                      \
+    if (!invert_mask)                                                         \
+      return highbd_masked_compound_sad(src8, src_stride, ref8, ref_stride,   \
+                                        second_pred8, m, msk, msk_stride, m,  \
+                                        n);                                   \
+    else                                                                      \
+      return highbd_masked_compound_sad(src8, src_stride, second_pred8, m,    \
+                                        ref8, ref_stride, msk, msk_stride, m, \
+                                        n);                                   \
   }
 
 #if CONFIG_EXT_PARTITION
diff --git a/aom_dsp/variance.c b/aom_dsp/variance.c
index 9fc0db7..90d0622 100644
--- a/aom_dsp/variance.c
+++ b/aom_dsp/variance.c
@@ -18,6 +18,7 @@
 
 #include "aom_dsp/variance.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/blend.h"
 
 uint32_t aom_get4x4sse_cs_c(const uint8_t *a, int a_stride, const uint8_t *b,
                             int b_stride) {
@@ -672,6 +673,51 @@
 #endif  // CONFIG_HIGHBITDEPTH
 
 #if CONFIG_AV1 && CONFIG_EXT_INTER
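+/* Form a masked compound prediction from pred and ref. invert_mask selects
+ * which of the two predictions the mask weights. */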
+void aom_comp_mask_pred_c(uint8_t *comp_pred, const uint8_t *pred, int width,
+                          int height, const uint8_t *ref, int ref_stride,
+                          const uint8_t *mask, int mask_stride,
+                          int invert_mask) {
+  int i, j;
+
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      if (!invert_mask)
+        comp_pred[j] = AOM_BLEND_A64(mask[j], ref[j], pred[j]);
+      else
+        comp_pred[j] = AOM_BLEND_A64(mask[j], pred[j], ref[j]);
+    }
+    comp_pred += width;
+    pred += width;
+    ref += ref_stride;
+    mask += mask_stride;
+  }
+}
+
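+/* As above, but ref points into a buffer upsampled 8x in each dimension, so
+ * advance by 8 pixels per output pixel and 8 rows per output row. */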
+void aom_comp_mask_upsampled_pred_c(uint8_t *comp_pred, const uint8_t *pred,
+                                    int width, int height, const uint8_t *ref,
+                                    int ref_stride, const uint8_t *mask,
+                                    int mask_stride, int invert_mask) {
+  int i, j;
+  int stride = ref_stride << 3;
+
+  for (i = 0; i < height; i++) {
+    for (j = 0; j < width; j++) {
+      if (!invert_mask)
+        comp_pred[j] = AOM_BLEND_A64(mask[j], ref[(j << 3)], pred[j]);
+      else
+        comp_pred[j] = AOM_BLEND_A64(mask[j], pred[j], ref[(j << 3)]);
+    }
+    comp_pred += width;
+    pred += width;
+    ref += stride;
+    mask += mask_stride;
+  }
+}
+
 void masked_variance(const uint8_t *a, int a_stride, const uint8_t *b,
                      int b_stride, const uint8_t *m, int m_stride, int w, int h,
                      unsigned int *sse, int *sum) {
@@ -696,13 +742,55 @@
   *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 12);
 }
 
-#define MASK_VAR(W, H)                                                       \
-  unsigned int aom_masked_variance##W##x##H##_c(                             \
-      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,        \
-      const uint8_t *m, int m_stride, unsigned int *sse) {                   \
-    int sum;                                                                 \
-    masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H, sse, &sum); \
-    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));            \
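+/* Accumulate the sum and sum of squares of the difference between the source
+ * and the masked compound prediction formed by blending a and b. */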
+void masked_compound_variance(const uint8_t *src, int src_stride,
+                              const uint8_t *a, int a_stride, const uint8_t *b,
+                              int b_stride, const uint8_t *m, int m_stride,
+                              int w, int h, unsigned int *sse, int *sum) {
+  int i, j;
+
+  int64_t sum64 = 0;
+  uint64_t sse64 = 0;
+
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
+      const uint8_t pred = AOM_BLEND_A64(m[j], a[j], b[j]);
+      const int diff = pred - src[j];
+      sum64 += diff;
+      sse64 += diff * diff;
+    }
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+  }
+  *sum = (int)sum64;
+  *sse = (uint32_t)sse64;
+}
+
+#define MASK_VAR(W, H)                                                        \
+  unsigned int aom_masked_variance##W##x##H##_c(                              \
+      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,         \
+      const uint8_t *m, int m_stride, unsigned int *sse) {                    \
+    int sum;                                                                  \
+    masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H, sse, &sum);  \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));             \
+  }                                                                           \
+                                                                              \
+  unsigned int aom_masked_compound_variance##W##x##H##_c(                     \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
+      const uint8_t *second_pred, const uint8_t *m, int m_stride,             \
+      int invert_mask, unsigned int *sse) {                                   \
+    int sum;                                                                  \
+    if (!invert_mask)                                                         \
+      masked_compound_variance(src, src_stride, ref, ref_stride, second_pred, \
+                               W, m, m_stride, W, H, sse, &sum);              \
+    else                                                                      \
+      masked_compound_variance(src, src_stride, second_pred, W, ref,          \
+                               ref_stride, m, m_stride, W, H, sse, &sum);     \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));             \
   }
 
 #define MASK_SUBPIX_VAR(W, H)                                                 \
@@ -720,6 +808,25 @@
                                                                               \
     return aom_masked_variance##W##x##H##_c(temp2, W, dst, dst_stride, msk,   \
                                             msk_stride, sse);                 \
+  }                                                                           \
+                                                                              \
+  unsigned int aom_masked_compound_sub_pixel_variance##W##x##H##_c(           \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
+      const uint8_t *msk, int msk_stride, int invert_mask,                    \
+      unsigned int *sse) {                                                    \
+    uint16_t fdata3[(H + 1) * W];                                             \
+    uint8_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint8_t, temp3[H * W]);                               \
+                                                                              \
+    var_filter_block2d_bil_first_pass(src, fdata3, src_stride, 1, H + 1, W,   \
+                                      bilinear_filters_2t[xoffset]);          \
+    var_filter_block2d_bil_second_pass(fdata3, temp2, W, W, H, W,             \
+                                       bilinear_filters_2t[yoffset]);         \
+                                                                              \
+    aom_comp_mask_pred(temp3, second_pred, W, H, temp2, W, msk, msk_stride,   \
+                       invert_mask);                                          \
+    return aom_variance##W##x##H##_c(temp3, W, ref, ref_stride, sse);         \
   }
 
 MASK_VAR(4, 4)
@@ -773,6 +880,51 @@
 #endif  // CONFIG_EXT_PARTITION
 
 #if CONFIG_HIGHBITDEPTH
+void aom_highbd_comp_mask_pred_c(uint16_t *comp_pred, const uint8_t *pred8,
+                                 int width, int height, const uint8_t *ref8,
+                                 int ref_stride, const uint8_t *mask,
+                                 int mask_stride, int invert_mask) {
+  int i, j;
+  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
+  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      if (!invert_mask)
+        comp_pred[j] = AOM_BLEND_A64(mask[j], ref[j], pred[j]);
+      else
+        comp_pred[j] = AOM_BLEND_A64(mask[j], pred[j], ref[j]);
+    }
+    comp_pred += width;
+    pred += width;
+    ref += ref_stride;
+    mask += mask_stride;
+  }
+}
+
+void aom_highbd_comp_mask_upsampled_pred_c(uint16_t *comp_pred,
+                                           const uint8_t *pred8, int width,
+                                           int height, const uint8_t *ref8,
+                                           int ref_stride, const uint8_t *mask,
+                                           int mask_stride, int invert_mask) {
+  int i, j;
+  int stride = ref_stride << 3;
+
+  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
+  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      if (!invert_mask)
+        comp_pred[j] = AOM_BLEND_A64(mask[j], ref[j << 3], pred[j]);
+      else
+        comp_pred[j] = AOM_BLEND_A64(mask[j], pred[j], ref[j << 3]);
+    }
+    comp_pred += width;
+    pred += width;
+    ref += stride;
+    mask += mask_stride;
+  }
+}
+
 void highbd_masked_variance64(const uint8_t *a8, int a_stride,
                               const uint8_t *b8, int b_stride, const uint8_t *m,
                               int m_stride, int w, int h, uint64_t *sse,
@@ -835,85 +987,271 @@
   *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
 }
 
-#define HIGHBD_MASK_VAR(W, H)                                                \
-  unsigned int aom_highbd_masked_variance##W##x##H##_c(                      \
-      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,        \
-      const uint8_t *m, int m_stride, unsigned int *sse) {                   \
-    int sum;                                                                 \
-    highbd_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H, sse, \
-                           &sum);                                            \
-    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));            \
-  }                                                                          \
-                                                                             \
-  unsigned int aom_highbd_10_masked_variance##W##x##H##_c(                   \
-      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,        \
-      const uint8_t *m, int m_stride, unsigned int *sse) {                   \
-    int sum;                                                                 \
-    int64_t var;                                                             \
-    highbd_10_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H,   \
-                              sse, &sum);                                    \
-    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                \
-    return (var >= 0) ? (uint32_t)var : 0;                                   \
-  }                                                                          \
-                                                                             \
-  unsigned int aom_highbd_12_masked_variance##W##x##H##_c(                   \
-      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,        \
-      const uint8_t *m, int m_stride, unsigned int *sse) {                   \
-    int sum;                                                                 \
-    int64_t var;                                                             \
-    highbd_12_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H,   \
-                              sse, &sum);                                    \
-    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                \
-    return (var >= 0) ? (uint32_t)var : 0;                                   \
+void highbd_masked_compound_variance64(const uint8_t *src8, int src_stride,
+                                       const uint8_t *a8, int a_stride,
+                                       const uint8_t *b8, int b_stride,
+                                       const uint8_t *m, int m_stride, int w,
+                                       int h, uint64_t *sse, int64_t *sum) {
+  int i, j;
+  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+
+  *sum = 0;
+  *sse = 0;
+
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
+      const uint16_t pred = AOM_BLEND_A64(m[j], a[j], b[j]);
+      const int diff = pred - src[j];
+      *sum += (int64_t)diff;
+      *sse += (int64_t)diff * diff;
+    }
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+  }
+}
+
+void highbd_masked_compound_variance(const uint8_t *src8, int src_stride,
+                                     const uint8_t *a8, int a_stride,
+                                     const uint8_t *b8, int b_stride,
+                                     const uint8_t *m, int m_stride, int w,
+                                     int h, unsigned int *sse, int *sum) {
+  int64_t sum64;
+  uint64_t sse64;
+  highbd_masked_compound_variance64(src8, src_stride, a8, a_stride, b8,
+                                    b_stride, m, m_stride, w, h, &sse64,
+                                    &sum64);
+  *sum = (int)sum64;
+  *sse = (unsigned int)sse64;
+}
+
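+/* The 10- and 12-bit wrappers below scale the accumulated statistics back to
+ * an 8-bit basis: sums by 2 and 4 bits, squared sums by 4 and 8 bits. */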
+void highbd_10_masked_compound_variance(const uint8_t *src8, int src_stride,
+                                        const uint8_t *a8, int a_stride,
+                                        const uint8_t *b8, int b_stride,
+                                        const uint8_t *m, int m_stride, int w,
+                                        int h, unsigned int *sse, int *sum) {
+  int64_t sum64;
+  uint64_t sse64;
+  highbd_masked_compound_variance64(src8, src_stride, a8, a_stride, b8,
+                                    b_stride, m, m_stride, w, h, &sse64,
+                                    &sum64);
+  *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
+  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
+}
+
+void highbd_12_masked_compound_variance(const uint8_t *src8, int src_stride,
+                                        const uint8_t *a8, int a_stride,
+                                        const uint8_t *b8, int b_stride,
+                                        const uint8_t *m, int m_stride, int w,
+                                        int h, unsigned int *sse, int *sum) {
+  int64_t sum64;
+  uint64_t sse64;
+  highbd_masked_compound_variance64(src8, src_stride, a8, a_stride, b8,
+                                    b_stride, m, m_stride, w, h, &sse64,
+                                    &sum64);
+  *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
+  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
+}
+
+#define HIGHBD_MASK_VAR(W, H)                                                  \
+  unsigned int aom_highbd_masked_variance##W##x##H##_c(                        \
+      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,          \
+      const uint8_t *m, int m_stride, unsigned int *sse) {                     \
+    int sum;                                                                   \
+    highbd_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H, sse,   \
+                           &sum);                                              \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));              \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_10_masked_variance##W##x##H##_c(                     \
+      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,          \
+      const uint8_t *m, int m_stride, unsigned int *sse) {                     \
+    int sum;                                                                   \
+    int64_t var;                                                               \
+    highbd_10_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H,     \
+                              sse, &sum);                                      \
+    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                  \
+    return (var >= 0) ? (uint32_t)var : 0;                                     \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_12_masked_variance##W##x##H##_c(                     \
+      const uint8_t *a, int a_stride, const uint8_t *b, int b_stride,          \
+      const uint8_t *m, int m_stride, unsigned int *sse) {                     \
+    int sum;                                                                   \
+    int64_t var;                                                               \
+    highbd_12_masked_variance(a, a_stride, b, b_stride, m, m_stride, W, H,     \
+                              sse, &sum);                                      \
+    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                  \
+    return (var >= 0) ? (uint32_t)var : 0;                                     \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_masked_compound_variance##W##x##H##_c(               \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
+      const uint8_t *second_pred, const uint8_t *m, int m_stride,              \
+      int invert_mask, unsigned int *sse) {                                    \
+    int sum;                                                                   \
+    if (!invert_mask)                                                          \
+      highbd_masked_compound_variance(src, src_stride, ref, ref_stride,        \
+                                      second_pred, W, m, m_stride, W, H, sse,  \
+                                      &sum);                                   \
+    else                                                                       \
+      highbd_masked_compound_variance(src, src_stride, second_pred, W, ref,    \
+                                      ref_stride, m, m_stride, W, H, sse,      \
+                                      &sum);                                   \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));              \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_10_masked_compound_variance##W##x##H##_c(            \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
+      const uint8_t *second_pred, const uint8_t *m, int m_stride,              \
+      int invert_mask, unsigned int *sse) {                                    \
+    int sum;                                                                   \
+    if (!invert_mask)                                                          \
+      highbd_10_masked_compound_variance(src, src_stride, ref, ref_stride,     \
+                                         second_pred, W, m, m_stride, W, H,    \
+                                         sse, &sum);                           \
+    else                                                                       \
+      highbd_10_masked_compound_variance(src, src_stride, second_pred, W, ref, \
+                                         ref_stride, m, m_stride, W, H, sse,   \
+                                         &sum);                                \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));              \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_12_masked_compound_variance##W##x##H##_c(            \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
+      const uint8_t *second_pred, const uint8_t *m, int m_stride,              \
+      int invert_mask, unsigned int *sse) {                                    \
+    int sum;                                                                   \
+    if (!invert_mask)                                                          \
+      highbd_12_masked_compound_variance(src, src_stride, ref, ref_stride,     \
+                                         second_pred, W, m, m_stride, W, H,    \
+                                         sse, &sum);                           \
+    else                                                                       \
+      highbd_12_masked_compound_variance(src, src_stride, second_pred, W, ref, \
+                                         ref_stride, m, m_stride, W, H, sse,   \
+                                         &sum);                                \
+    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));              \
   }
 
-#define HIGHBD_MASK_SUBPIX_VAR(W, H)                                          \
-  unsigned int aom_highbd_masked_sub_pixel_variance##W##x##H##_c(             \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride, \
-      unsigned int *sse) {                                                    \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    return aom_highbd_masked_variance##W##x##H##_c(                           \
-        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse); \
-  }                                                                           \
-                                                                              \
-  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_c(          \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride, \
-      unsigned int *sse) {                                                    \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    return aom_highbd_10_masked_variance##W##x##H##_c(                        \
-        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse); \
-  }                                                                           \
-                                                                              \
-  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_c(          \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride, \
-      unsigned int *sse) {                                                    \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    return aom_highbd_12_masked_variance##W##x##H##_c(                        \
-        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse); \
+#define HIGHBD_MASK_SUBPIX_VAR(W, H)                                           \
+  unsigned int aom_highbd_masked_sub_pixel_variance##W##x##H##_c(              \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride,  \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    return aom_highbd_masked_variance##W##x##H##_c(                            \
+        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse);  \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_c(           \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride,  \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    return aom_highbd_10_masked_variance##W##x##H##_c(                         \
+        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse);  \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_c(           \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, const uint8_t *msk, int msk_stride,  \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    return aom_highbd_12_masked_variance##W##x##H##_c(                         \
+        CONVERT_TO_BYTEPTR(temp2), W, dst, dst_stride, msk, msk_stride, sse);  \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_masked_compound_sub_pixel_variance##W##x##H##_c(     \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,          \
+      const uint8_t *msk, int msk_stride, int invert_mask,                     \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+                                CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
+                                invert_mask);                                  \
+                                                                               \
+    return aom_highbd_8_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,    \
+                                              ref, ref_stride, sse);           \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_10_masked_compound_sub_pixel_variance##W##x##H##_c(  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,          \
+      const uint8_t *msk, int msk_stride, int invert_mask,                     \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+                                CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
+                                invert_mask);                                  \
+                                                                               \
+    return aom_highbd_10_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,   \
+                                               ref, ref_stride, sse);          \
+  }                                                                            \
+                                                                               \
+  unsigned int aom_highbd_12_masked_compound_sub_pixel_variance##W##x##H##_c(  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,          \
+      const uint8_t *msk, int msk_stride, int invert_mask,                     \
+      unsigned int *sse) {                                                     \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+                                CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
+                                invert_mask);                                  \
+                                                                               \
+    return aom_highbd_12_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,   \
+                                               ref, ref_stride, sse);          \
   }
 
 HIGHBD_MASK_VAR(4, 4)
diff --git a/aom_dsp/variance.h b/aom_dsp/variance.h
index 7c925cf..adcf8b4 100644
--- a/aom_dsp/variance.h
+++ b/aom_dsp/variance.h
@@ -66,6 +66,19 @@
     const uint8_t *src, int src_stride, int xoffset, int yoffset,
     const uint8_t *ref, int ref_stride, const uint8_t *msk, int msk_stride,
     unsigned int *sse);
+
+typedef unsigned int (*aom_masked_compound_sad_fn_t)(
+    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
+    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
+    int invert_mask);
+typedef unsigned int (*aom_masked_compound_variance_fn_t)(
+    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
+    const uint8_t *second_pred, const uint8_t *m, int m_stride, int invert_mask,
+    unsigned int *sse);
+typedef unsigned int (*aom_masked_compound_subpixvariance_fn_t)(
+    const uint8_t *src, int src_stride, int xoffset, int yoffset,
+    const uint8_t *ref, int ref_stride, const uint8_t *second_pred,
+    const uint8_t *msk, int msk_stride, int invert_mask, unsigned int *sse);
 #endif  // CONFIG_AV1 && CONFIG_EXT_INTER
 
 #if CONFIG_AV1 && CONFIG_MOTION_VAR
@@ -96,6 +109,10 @@
   aom_masked_sad_fn_t msdf;
   aom_masked_variance_fn_t mvf;
   aom_masked_subpixvariance_fn_t msvf;
+
+  aom_masked_compound_sad_fn_t mcsdf;
+  aom_masked_compound_variance_fn_t mcvf;
+  aom_masked_compound_subpixvariance_fn_t mcsvf;
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR
   aom_obmc_sad_fn_t osdf;
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 84e4e93..bf1c333 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -1175,10 +1175,13 @@
 MAKE_BFP_SAD4D_WRAPPER(aom_highbd_sad4x4x4d)
 
 #if CONFIG_EXT_INTER
-#define HIGHBD_MBFP(BT, MSDF, MVF, MSVF) \
-  cpi->fn_ptr[BT].msdf = MSDF;           \
-  cpi->fn_ptr[BT].mvf = MVF;             \
-  cpi->fn_ptr[BT].msvf = MSVF;
+#define HIGHBD_MBFP(BT, MSDF, MVF, MSVF, MCSDF, MCVF, MCSVF) \
+  cpi->fn_ptr[BT].msdf = MSDF;                               \
+  cpi->fn_ptr[BT].mvf = MVF;                                 \
+  cpi->fn_ptr[BT].msvf = MSVF;                               \
+  cpi->fn_ptr[BT].mcsdf = MCSDF;                             \
+  cpi->fn_ptr[BT].mcvf = MCVF;                               \
+  cpi->fn_ptr[BT].mcsvf = MCSVF;
 
 #define MAKE_MBFP_SAD_WRAPPER(fnname)                                          \
   static unsigned int fnname##_bits8(                                          \
@@ -1199,10 +1202,40 @@
            4;                                                                  \
   }
 
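+/* SAD scales linearly with bit depth, so the 10- and 12-bit wrappers shift
+ * the result down by 2 and 4 bits respectively. */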
+#define MAKE_MBFP_COMPOUND_SAD_WRAPPER(fnname)                           \
+  static unsigned int fnname##_bits8(                                    \
+      const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, \
+      int ref_stride, const uint8_t *second_pred_ptr, const uint8_t *m,  \
+      int m_stride, int invert_mask) {                                   \
+    return fnname(src_ptr, source_stride, ref_ptr, ref_stride,           \
+                  second_pred_ptr, m, m_stride, invert_mask);            \
+  }                                                                      \
+  static unsigned int fnname##_bits10(                                   \
+      const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, \
+      int ref_stride, const uint8_t *second_pred_ptr, const uint8_t *m,  \
+      int m_stride, int invert_mask) {                                   \
+    return fnname(src_ptr, source_stride, ref_ptr, ref_stride,           \
+                  second_pred_ptr, m, m_stride, invert_mask) >>          \
+           2;                                                            \
+  }                                                                      \
+  static unsigned int fnname##_bits12(                                   \
+      const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, \
+      int ref_stride, const uint8_t *second_pred_ptr, const uint8_t *m,  \
+      int m_stride, int invert_mask) {                                   \
+    return fnname(src_ptr, source_stride, ref_ptr, ref_stride,           \
+                  second_pred_ptr, m, m_stride, invert_mask) >>          \
+           4;                                                            \
+  }
+
 #if CONFIG_EXT_PARTITION
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad128x128)
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad128x64)
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad64x128)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad128x128)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad128x64)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad64x128)
 #endif  // CONFIG_EXT_PARTITION
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad64x64)
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad64x32)
@@ -1217,6 +1250,19 @@
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad8x4)
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad4x8)
 MAKE_MBFP_SAD_WRAPPER(aom_highbd_masked_sad4x4)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad64x64)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad64x32)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad32x64)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad32x32)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad32x16)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad16x32)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad16x16)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad16x8)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad8x16)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad8x8)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad8x4)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad4x8)
+MAKE_MBFP_COMPOUND_SAD_WRAPPER(aom_highbd_masked_compound_sad4x4)
 #endif  // CONFIG_EXT_INTER
 
 #if CONFIG_MOTION_VAR
@@ -1383,53 +1429,101 @@
 #if CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_128X128, aom_highbd_masked_sad128x128_bits8,
                     aom_highbd_masked_variance128x128,
-                    aom_highbd_masked_sub_pixel_variance128x128)
+                    aom_highbd_masked_sub_pixel_variance128x128,
+                    aom_highbd_masked_compound_sad128x128_bits8,
+                    aom_highbd_masked_compound_variance128x128,
+                    aom_highbd_masked_compound_sub_pixel_variance128x128)
         HIGHBD_MBFP(BLOCK_128X64, aom_highbd_masked_sad128x64_bits8,
                     aom_highbd_masked_variance128x64,
-                    aom_highbd_masked_sub_pixel_variance128x64)
+                    aom_highbd_masked_sub_pixel_variance128x64,
+                    aom_highbd_masked_compound_sad128x64_bits8,
+                    aom_highbd_masked_compound_variance128x64,
+                    aom_highbd_masked_compound_sub_pixel_variance128x64)
         HIGHBD_MBFP(BLOCK_64X128, aom_highbd_masked_sad64x128_bits8,
                     aom_highbd_masked_variance64x128,
-                    aom_highbd_masked_sub_pixel_variance64x128)
+                    aom_highbd_masked_sub_pixel_variance64x128,
+                    aom_highbd_masked_compound_sad64x128_bits8,
+                    aom_highbd_masked_compound_variance64x128,
+                    aom_highbd_masked_compound_sub_pixel_variance64x128)
 #endif  // CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_64X64, aom_highbd_masked_sad64x64_bits8,
                     aom_highbd_masked_variance64x64,
-                    aom_highbd_masked_sub_pixel_variance64x64)
+                    aom_highbd_masked_sub_pixel_variance64x64,
+                    aom_highbd_masked_compound_sad64x64_bits8,
+                    aom_highbd_masked_compound_variance64x64,
+                    aom_highbd_masked_compound_sub_pixel_variance64x64)
         HIGHBD_MBFP(BLOCK_64X32, aom_highbd_masked_sad64x32_bits8,
                     aom_highbd_masked_variance64x32,
-                    aom_highbd_masked_sub_pixel_variance64x32)
+                    aom_highbd_masked_sub_pixel_variance64x32,
+                    aom_highbd_masked_compound_sad64x32_bits8,
+                    aom_highbd_masked_compound_variance64x32,
+                    aom_highbd_masked_compound_sub_pixel_variance64x32)
         HIGHBD_MBFP(BLOCK_32X64, aom_highbd_masked_sad32x64_bits8,
                     aom_highbd_masked_variance32x64,
-                    aom_highbd_masked_sub_pixel_variance32x64)
+                    aom_highbd_masked_sub_pixel_variance32x64,
+                    aom_highbd_masked_compound_sad32x64_bits8,
+                    aom_highbd_masked_compound_variance32x64,
+                    aom_highbd_masked_compound_sub_pixel_variance32x64)
         HIGHBD_MBFP(BLOCK_32X32, aom_highbd_masked_sad32x32_bits8,
                     aom_highbd_masked_variance32x32,
-                    aom_highbd_masked_sub_pixel_variance32x32)
+                    aom_highbd_masked_sub_pixel_variance32x32,
+                    aom_highbd_masked_compound_sad32x32_bits8,
+                    aom_highbd_masked_compound_variance32x32,
+                    aom_highbd_masked_compound_sub_pixel_variance32x32)
         HIGHBD_MBFP(BLOCK_32X16, aom_highbd_masked_sad32x16_bits8,
                     aom_highbd_masked_variance32x16,
-                    aom_highbd_masked_sub_pixel_variance32x16)
+                    aom_highbd_masked_sub_pixel_variance32x16,
+                    aom_highbd_masked_compound_sad32x16_bits8,
+                    aom_highbd_masked_compound_variance32x16,
+                    aom_highbd_masked_compound_sub_pixel_variance32x16)
         HIGHBD_MBFP(BLOCK_16X32, aom_highbd_masked_sad16x32_bits8,
                     aom_highbd_masked_variance16x32,
-                    aom_highbd_masked_sub_pixel_variance16x32)
+                    aom_highbd_masked_sub_pixel_variance16x32,
+                    aom_highbd_masked_compound_sad16x32_bits8,
+                    aom_highbd_masked_compound_variance16x32,
+                    aom_highbd_masked_compound_sub_pixel_variance16x32)
         HIGHBD_MBFP(BLOCK_16X16, aom_highbd_masked_sad16x16_bits8,
                     aom_highbd_masked_variance16x16,
-                    aom_highbd_masked_sub_pixel_variance16x16)
+                    aom_highbd_masked_sub_pixel_variance16x16,
+                    aom_highbd_masked_compound_sad16x16_bits8,
+                    aom_highbd_masked_compound_variance16x16,
+                    aom_highbd_masked_compound_sub_pixel_variance16x16)
         HIGHBD_MBFP(BLOCK_8X16, aom_highbd_masked_sad8x16_bits8,
                     aom_highbd_masked_variance8x16,
-                    aom_highbd_masked_sub_pixel_variance8x16)
+                    aom_highbd_masked_sub_pixel_variance8x16,
+                    aom_highbd_masked_compound_sad8x16_bits8,
+                    aom_highbd_masked_compound_variance8x16,
+                    aom_highbd_masked_compound_sub_pixel_variance8x16)
         HIGHBD_MBFP(BLOCK_16X8, aom_highbd_masked_sad16x8_bits8,
                     aom_highbd_masked_variance16x8,
-                    aom_highbd_masked_sub_pixel_variance16x8)
+                    aom_highbd_masked_sub_pixel_variance16x8,
+                    aom_highbd_masked_compound_sad16x8_bits8,
+                    aom_highbd_masked_compound_variance16x8,
+                    aom_highbd_masked_compound_sub_pixel_variance16x8)
         HIGHBD_MBFP(BLOCK_8X8, aom_highbd_masked_sad8x8_bits8,
                     aom_highbd_masked_variance8x8,
-                    aom_highbd_masked_sub_pixel_variance8x8)
+                    aom_highbd_masked_sub_pixel_variance8x8,
+                    aom_highbd_masked_compound_sad8x8_bits8,
+                    aom_highbd_masked_compound_variance8x8,
+                    aom_highbd_masked_compound_sub_pixel_variance8x8)
         HIGHBD_MBFP(BLOCK_4X8, aom_highbd_masked_sad4x8_bits8,
                     aom_highbd_masked_variance4x8,
-                    aom_highbd_masked_sub_pixel_variance4x8)
+                    aom_highbd_masked_sub_pixel_variance4x8,
+                    aom_highbd_masked_compound_sad4x8_bits8,
+                    aom_highbd_masked_compound_variance4x8,
+                    aom_highbd_masked_compound_sub_pixel_variance4x8)
         HIGHBD_MBFP(BLOCK_8X4, aom_highbd_masked_sad8x4_bits8,
                     aom_highbd_masked_variance8x4,
-                    aom_highbd_masked_sub_pixel_variance8x4)
+                    aom_highbd_masked_sub_pixel_variance8x4,
+                    aom_highbd_masked_compound_sad8x4_bits8,
+                    aom_highbd_masked_compound_variance8x4,
+                    aom_highbd_masked_compound_sub_pixel_variance8x4)
         HIGHBD_MBFP(BLOCK_4X4, aom_highbd_masked_sad4x4_bits8,
                     aom_highbd_masked_variance4x4,
-                    aom_highbd_masked_sub_pixel_variance4x4)
+                    aom_highbd_masked_sub_pixel_variance4x4,
+                    aom_highbd_masked_compound_sad4x4_bits8,
+                    aom_highbd_masked_compound_variance4x4,
+                    aom_highbd_masked_compound_sub_pixel_variance4x4)
 #endif  // CONFIG_EXT_INTER
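
The _bits8/_bits10/_bits12 entries above are thin SAD wrappers that rescale
high-bitdepth SADs to an 8-bit working range before they are weighed against
mv costs. The wrapper generator itself is outside this hunk, so the following
one-size sketch of the assumed pattern is illustrative only (10-bit SADs
scaled by >> 2, 12-bit by >> 4, matching the encoder's other SAD wrappers):

// Illustrative wrapper (not part of this patch): forward to the plain
// masked compound SAD and rescale a 10-bit result to the 8-bit range.
static unsigned int aom_highbd_masked_compound_sad8x8_bits10(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask) {
  return aom_highbd_masked_compound_sad8x8(src, src_stride, ref, ref_stride,
                                           second_pred, msk, msk_stride,
                                           invert_mask) >>
         2;
}
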
 #if CONFIG_MOTION_VAR
 #if CONFIG_EXT_PARTITION
@@ -1606,53 +1698,101 @@
 #if CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_128X128, aom_highbd_masked_sad128x128_bits10,
                     aom_highbd_10_masked_variance128x128,
-                    aom_highbd_10_masked_sub_pixel_variance128x128)
+                    aom_highbd_10_masked_sub_pixel_variance128x128,
+                    aom_highbd_masked_compound_sad128x128_bits10,
+                    aom_highbd_10_masked_compound_variance128x128,
+                    aom_highbd_10_masked_compound_sub_pixel_variance128x128)
         HIGHBD_MBFP(BLOCK_128X64, aom_highbd_masked_sad128x64_bits10,
                     aom_highbd_10_masked_variance128x64,
-                    aom_highbd_10_masked_sub_pixel_variance128x64)
+                    aom_highbd_10_masked_sub_pixel_variance128x64,
+                    aom_highbd_masked_compound_sad128x64_bits10,
+                    aom_highbd_10_masked_compound_variance128x64,
+                    aom_highbd_10_masked_compound_sub_pixel_variance128x64)
         HIGHBD_MBFP(BLOCK_64X128, aom_highbd_masked_sad64x128_bits10,
                     aom_highbd_10_masked_variance64x128,
-                    aom_highbd_10_masked_sub_pixel_variance64x128)
+                    aom_highbd_10_masked_sub_pixel_variance64x128,
+                    aom_highbd_masked_compound_sad64x128_bits10,
+                    aom_highbd_10_masked_compound_variance64x128,
+                    aom_highbd_10_masked_compound_sub_pixel_variance64x128)
 #endif  // CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_64X64, aom_highbd_masked_sad64x64_bits10,
                     aom_highbd_10_masked_variance64x64,
-                    aom_highbd_10_masked_sub_pixel_variance64x64)
+                    aom_highbd_10_masked_sub_pixel_variance64x64,
+                    aom_highbd_masked_compound_sad64x64_bits10,
+                    aom_highbd_10_masked_compound_variance64x64,
+                    aom_highbd_10_masked_compound_sub_pixel_variance64x64)
         HIGHBD_MBFP(BLOCK_64X32, aom_highbd_masked_sad64x32_bits10,
                     aom_highbd_10_masked_variance64x32,
-                    aom_highbd_10_masked_sub_pixel_variance64x32)
+                    aom_highbd_10_masked_sub_pixel_variance64x32,
+                    aom_highbd_masked_compound_sad64x32_bits10,
+                    aom_highbd_10_masked_compound_variance64x32,
+                    aom_highbd_10_masked_compound_sub_pixel_variance64x32)
         HIGHBD_MBFP(BLOCK_32X64, aom_highbd_masked_sad32x64_bits10,
                     aom_highbd_10_masked_variance32x64,
-                    aom_highbd_10_masked_sub_pixel_variance32x64)
+                    aom_highbd_10_masked_sub_pixel_variance32x64,
+                    aom_highbd_masked_compound_sad32x64_bits10,
+                    aom_highbd_10_masked_compound_variance32x64,
+                    aom_highbd_10_masked_compound_sub_pixel_variance32x64)
         HIGHBD_MBFP(BLOCK_32X32, aom_highbd_masked_sad32x32_bits10,
                     aom_highbd_10_masked_variance32x32,
-                    aom_highbd_10_masked_sub_pixel_variance32x32)
+                    aom_highbd_10_masked_sub_pixel_variance32x32,
+                    aom_highbd_masked_compound_sad32x32_bits10,
+                    aom_highbd_10_masked_compound_variance32x32,
+                    aom_highbd_10_masked_compound_sub_pixel_variance32x32)
         HIGHBD_MBFP(BLOCK_32X16, aom_highbd_masked_sad32x16_bits10,
                     aom_highbd_10_masked_variance32x16,
-                    aom_highbd_10_masked_sub_pixel_variance32x16)
+                    aom_highbd_10_masked_sub_pixel_variance32x16,
+                    aom_highbd_masked_compound_sad32x16_bits10,
+                    aom_highbd_10_masked_compound_variance32x16,
+                    aom_highbd_10_masked_compound_sub_pixel_variance32x16)
         HIGHBD_MBFP(BLOCK_16X32, aom_highbd_masked_sad16x32_bits10,
                     aom_highbd_10_masked_variance16x32,
-                    aom_highbd_10_masked_sub_pixel_variance16x32)
+                    aom_highbd_10_masked_sub_pixel_variance16x32,
+                    aom_highbd_masked_compound_sad16x32_bits10,
+                    aom_highbd_10_masked_compound_variance16x32,
+                    aom_highbd_10_masked_compound_sub_pixel_variance16x32)
         HIGHBD_MBFP(BLOCK_16X16, aom_highbd_masked_sad16x16_bits10,
                     aom_highbd_10_masked_variance16x16,
-                    aom_highbd_10_masked_sub_pixel_variance16x16)
+                    aom_highbd_10_masked_sub_pixel_variance16x16,
+                    aom_highbd_masked_compound_sad16x16_bits10,
+                    aom_highbd_10_masked_compound_variance16x16,
+                    aom_highbd_10_masked_compound_sub_pixel_variance16x16)
         HIGHBD_MBFP(BLOCK_8X16, aom_highbd_masked_sad8x16_bits10,
                     aom_highbd_10_masked_variance8x16,
-                    aom_highbd_10_masked_sub_pixel_variance8x16)
+                    aom_highbd_10_masked_sub_pixel_variance8x16,
+                    aom_highbd_masked_compound_sad8x16_bits10,
+                    aom_highbd_10_masked_compound_variance8x16,
+                    aom_highbd_10_masked_compound_sub_pixel_variance8x16)
         HIGHBD_MBFP(BLOCK_16X8, aom_highbd_masked_sad16x8_bits10,
                     aom_highbd_10_masked_variance16x8,
-                    aom_highbd_10_masked_sub_pixel_variance16x8)
+                    aom_highbd_10_masked_sub_pixel_variance16x8,
+                    aom_highbd_masked_compound_sad16x8_bits10,
+                    aom_highbd_10_masked_compound_variance16x8,
+                    aom_highbd_10_masked_compound_sub_pixel_variance16x8)
         HIGHBD_MBFP(BLOCK_8X8, aom_highbd_masked_sad8x8_bits10,
                     aom_highbd_10_masked_variance8x8,
-                    aom_highbd_10_masked_sub_pixel_variance8x8)
+                    aom_highbd_10_masked_sub_pixel_variance8x8,
+                    aom_highbd_masked_compound_sad8x8_bits10,
+                    aom_highbd_10_masked_compound_variance8x8,
+                    aom_highbd_10_masked_compound_sub_pixel_variance8x8)
         HIGHBD_MBFP(BLOCK_4X8, aom_highbd_masked_sad4x8_bits10,
                     aom_highbd_10_masked_variance4x8,
-                    aom_highbd_10_masked_sub_pixel_variance4x8)
+                    aom_highbd_10_masked_sub_pixel_variance4x8,
+                    aom_highbd_masked_compound_sad4x8_bits10,
+                    aom_highbd_10_masked_compound_variance4x8,
+                    aom_highbd_10_masked_compound_sub_pixel_variance4x8)
         HIGHBD_MBFP(BLOCK_8X4, aom_highbd_masked_sad8x4_bits10,
                     aom_highbd_10_masked_variance8x4,
-                    aom_highbd_10_masked_sub_pixel_variance8x4)
+                    aom_highbd_10_masked_sub_pixel_variance8x4,
+                    aom_highbd_masked_compound_sad8x4_bits10,
+                    aom_highbd_10_masked_compound_variance8x4,
+                    aom_highbd_10_masked_compound_sub_pixel_variance8x4)
         HIGHBD_MBFP(BLOCK_4X4, aom_highbd_masked_sad4x4_bits10,
                     aom_highbd_10_masked_variance4x4,
-                    aom_highbd_10_masked_sub_pixel_variance4x4)
+                    aom_highbd_10_masked_sub_pixel_variance4x4,
+                    aom_highbd_masked_compound_sad4x4_bits10,
+                    aom_highbd_10_masked_compound_variance4x4,
+                    aom_highbd_10_masked_compound_sub_pixel_variance4x4)
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR
 #if CONFIG_EXT_PARTITION
@@ -1829,53 +1969,101 @@
 #if CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_128X128, aom_highbd_masked_sad128x128_bits12,
                     aom_highbd_12_masked_variance128x128,
-                    aom_highbd_12_masked_sub_pixel_variance128x128)
+                    aom_highbd_12_masked_sub_pixel_variance128x128,
+                    aom_highbd_masked_compound_sad128x128_bits12,
+                    aom_highbd_12_masked_compound_variance128x128,
+                    aom_highbd_12_masked_compound_sub_pixel_variance128x128)
         HIGHBD_MBFP(BLOCK_128X64, aom_highbd_masked_sad128x64_bits12,
                     aom_highbd_12_masked_variance128x64,
-                    aom_highbd_12_masked_sub_pixel_variance128x64)
+                    aom_highbd_12_masked_sub_pixel_variance128x64,
+                    aom_highbd_masked_compound_sad128x64_bits12,
+                    aom_highbd_12_masked_compound_variance128x64,
+                    aom_highbd_12_masked_compound_sub_pixel_variance128x64)
         HIGHBD_MBFP(BLOCK_64X128, aom_highbd_masked_sad64x128_bits12,
                     aom_highbd_12_masked_variance64x128,
-                    aom_highbd_12_masked_sub_pixel_variance64x128)
+                    aom_highbd_12_masked_sub_pixel_variance64x128,
+                    aom_highbd_masked_compound_sad64x128_bits12,
+                    aom_highbd_12_masked_compound_variance64x128,
+                    aom_highbd_12_masked_compound_sub_pixel_variance64x128)
 #endif  // CONFIG_EXT_PARTITION
         HIGHBD_MBFP(BLOCK_64X64, aom_highbd_masked_sad64x64_bits12,
                     aom_highbd_12_masked_variance64x64,
-                    aom_highbd_12_masked_sub_pixel_variance64x64)
+                    aom_highbd_12_masked_sub_pixel_variance64x64,
+                    aom_highbd_masked_compound_sad64x64_bits12,
+                    aom_highbd_12_masked_compound_variance64x64,
+                    aom_highbd_12_masked_compound_sub_pixel_variance64x64)
         HIGHBD_MBFP(BLOCK_64X32, aom_highbd_masked_sad64x32_bits12,
                     aom_highbd_12_masked_variance64x32,
-                    aom_highbd_12_masked_sub_pixel_variance64x32)
+                    aom_highbd_12_masked_sub_pixel_variance64x32,
+                    aom_highbd_masked_compound_sad64x32_bits12,
+                    aom_highbd_12_masked_compound_variance64x32,
+                    aom_highbd_12_masked_compound_sub_pixel_variance64x32)
         HIGHBD_MBFP(BLOCK_32X64, aom_highbd_masked_sad32x64_bits12,
                     aom_highbd_12_masked_variance32x64,
-                    aom_highbd_12_masked_sub_pixel_variance32x64)
+                    aom_highbd_12_masked_sub_pixel_variance32x64,
+                    aom_highbd_masked_compound_sad32x64_bits12,
+                    aom_highbd_12_masked_compound_variance32x64,
+                    aom_highbd_12_masked_compound_sub_pixel_variance32x64)
         HIGHBD_MBFP(BLOCK_32X32, aom_highbd_masked_sad32x32_bits12,
                     aom_highbd_12_masked_variance32x32,
-                    aom_highbd_12_masked_sub_pixel_variance32x32)
+                    aom_highbd_12_masked_sub_pixel_variance32x32,
+                    aom_highbd_masked_compound_sad32x32_bits12,
+                    aom_highbd_12_masked_compound_variance32x32,
+                    aom_highbd_12_masked_compound_sub_pixel_variance32x32)
         HIGHBD_MBFP(BLOCK_32X16, aom_highbd_masked_sad32x16_bits12,
                     aom_highbd_12_masked_variance32x16,
-                    aom_highbd_12_masked_sub_pixel_variance32x16)
+                    aom_highbd_12_masked_sub_pixel_variance32x16,
+                    aom_highbd_masked_compound_sad32x16_bits12,
+                    aom_highbd_12_masked_compound_variance32x16,
+                    aom_highbd_12_masked_compound_sub_pixel_variance32x16)
         HIGHBD_MBFP(BLOCK_16X32, aom_highbd_masked_sad16x32_bits12,
                     aom_highbd_12_masked_variance16x32,
-                    aom_highbd_12_masked_sub_pixel_variance16x32)
+                    aom_highbd_12_masked_sub_pixel_variance16x32,
+                    aom_highbd_masked_compound_sad16x32_bits12,
+                    aom_highbd_12_masked_compound_variance16x32,
+                    aom_highbd_12_masked_compound_sub_pixel_variance16x32)
         HIGHBD_MBFP(BLOCK_16X16, aom_highbd_masked_sad16x16_bits12,
                     aom_highbd_12_masked_variance16x16,
-                    aom_highbd_12_masked_sub_pixel_variance16x16)
+                    aom_highbd_12_masked_sub_pixel_variance16x16,
+                    aom_highbd_masked_compound_sad16x16_bits12,
+                    aom_highbd_12_masked_compound_variance16x16,
+                    aom_highbd_12_masked_compound_sub_pixel_variance16x16)
         HIGHBD_MBFP(BLOCK_8X16, aom_highbd_masked_sad8x16_bits12,
                     aom_highbd_12_masked_variance8x16,
-                    aom_highbd_12_masked_sub_pixel_variance8x16)
+                    aom_highbd_12_masked_sub_pixel_variance8x16,
+                    aom_highbd_masked_compound_sad8x16_bits12,
+                    aom_highbd_12_masked_compound_variance8x16,
+                    aom_highbd_12_masked_compound_sub_pixel_variance8x16)
         HIGHBD_MBFP(BLOCK_16X8, aom_highbd_masked_sad16x8_bits12,
                     aom_highbd_12_masked_variance16x8,
-                    aom_highbd_12_masked_sub_pixel_variance16x8)
+                    aom_highbd_12_masked_sub_pixel_variance16x8,
+                    aom_highbd_masked_compound_sad16x8_bits12,
+                    aom_highbd_12_masked_compound_variance16x8,
+                    aom_highbd_12_masked_compound_sub_pixel_variance16x8)
         HIGHBD_MBFP(BLOCK_8X8, aom_highbd_masked_sad8x8_bits12,
                     aom_highbd_12_masked_variance8x8,
-                    aom_highbd_12_masked_sub_pixel_variance8x8)
+                    aom_highbd_12_masked_sub_pixel_variance8x8,
+                    aom_highbd_masked_compound_sad8x8_bits12,
+                    aom_highbd_12_masked_compound_variance8x8,
+                    aom_highbd_12_masked_compound_sub_pixel_variance8x8)
         HIGHBD_MBFP(BLOCK_4X8, aom_highbd_masked_sad4x8_bits12,
                     aom_highbd_12_masked_variance4x8,
-                    aom_highbd_12_masked_sub_pixel_variance4x8)
+                    aom_highbd_12_masked_sub_pixel_variance4x8,
+                    aom_highbd_masked_compound_sad4x8_bits12,
+                    aom_highbd_12_masked_compound_variance4x8,
+                    aom_highbd_12_masked_compound_sub_pixel_variance4x8)
         HIGHBD_MBFP(BLOCK_8X4, aom_highbd_masked_sad8x4_bits12,
                     aom_highbd_12_masked_variance8x4,
-                    aom_highbd_12_masked_sub_pixel_variance8x4)
+                    aom_highbd_12_masked_sub_pixel_variance8x4,
+                    aom_highbd_masked_compound_sad8x4_bits12,
+                    aom_highbd_12_masked_compound_variance8x4,
+                    aom_highbd_12_masked_compound_sub_pixel_variance8x4)
         HIGHBD_MBFP(BLOCK_4X4, aom_highbd_masked_sad4x4_bits12,
                     aom_highbd_12_masked_variance4x4,
-                    aom_highbd_12_masked_sub_pixel_variance4x4)
+                    aom_highbd_12_masked_sub_pixel_variance4x4,
+                    aom_highbd_masked_compound_sad4x4_bits12,
+                    aom_highbd_12_masked_compound_variance4x4,
+                    aom_highbd_12_masked_compound_sub_pixel_variance4x4)
 #endif  // CONFIG_EXT_INTER
 
 #if CONFIG_MOTION_VAR
@@ -2464,45 +2652,80 @@
 #endif  // CONFIG_MOTION_VAR
 
 #if CONFIG_EXT_INTER
-#define MBFP(BT, MSDF, MVF, MSVF) \
-  cpi->fn_ptr[BT].msdf = MSDF;    \
-  cpi->fn_ptr[BT].mvf = MVF;      \
-  cpi->fn_ptr[BT].msvf = MSVF;
+#define MBFP(BT, MSDF, MVF, MSVF, MCSDF, MCVF, MCSVF) \
+  cpi->fn_ptr[BT].msdf = MSDF;                        \
+  cpi->fn_ptr[BT].mvf = MVF;                          \
+  cpi->fn_ptr[BT].msvf = MSVF;                        \
+  cpi->fn_ptr[BT].mcsdf = MCSDF;                      \
+  cpi->fn_ptr[BT].mcvf = MCVF;                        \
+  cpi->fn_ptr[BT].mcsvf = MCSVF;
 
 #if CONFIG_EXT_PARTITION
   MBFP(BLOCK_128X128, aom_masked_sad128x128, aom_masked_variance128x128,
-       aom_masked_sub_pixel_variance128x128)
+       aom_masked_sub_pixel_variance128x128, aom_masked_compound_sad128x128,
+       aom_masked_compound_variance128x128,
+       aom_masked_compound_sub_pixel_variance128x128)
   MBFP(BLOCK_128X64, aom_masked_sad128x64, aom_masked_variance128x64,
-       aom_masked_sub_pixel_variance128x64)
+       aom_masked_sub_pixel_variance128x64, aom_masked_compound_sad128x64,
+       aom_masked_compound_variance128x64,
+       aom_masked_compound_sub_pixel_variance128x64)
   MBFP(BLOCK_64X128, aom_masked_sad64x128, aom_masked_variance64x128,
-       aom_masked_sub_pixel_variance64x128)
+       aom_masked_sub_pixel_variance64x128, aom_masked_compound_sad64x128,
+       aom_masked_compound_variance64x128,
+       aom_masked_compound_sub_pixel_variance64x128)
 #endif  // CONFIG_EXT_PARTITION
   MBFP(BLOCK_64X64, aom_masked_sad64x64, aom_masked_variance64x64,
-       aom_masked_sub_pixel_variance64x64)
+       aom_masked_sub_pixel_variance64x64, aom_masked_compound_sad64x64,
+       aom_masked_compound_variance64x64,
+       aom_masked_compound_sub_pixel_variance64x64)
   MBFP(BLOCK_64X32, aom_masked_sad64x32, aom_masked_variance64x32,
-       aom_masked_sub_pixel_variance64x32)
+       aom_masked_sub_pixel_variance64x32, aom_masked_compound_sad64x32,
+       aom_masked_compound_variance64x32,
+       aom_masked_compound_sub_pixel_variance64x32)
   MBFP(BLOCK_32X64, aom_masked_sad32x64, aom_masked_variance32x64,
-       aom_masked_sub_pixel_variance32x64)
+       aom_masked_sub_pixel_variance32x64, aom_masked_compound_sad32x64,
+       aom_masked_compound_variance32x64,
+       aom_masked_compound_sub_pixel_variance32x64)
   MBFP(BLOCK_32X32, aom_masked_sad32x32, aom_masked_variance32x32,
-       aom_masked_sub_pixel_variance32x32)
+       aom_masked_sub_pixel_variance32x32, aom_masked_compound_sad32x32,
+       aom_masked_compound_variance32x32,
+       aom_masked_compound_sub_pixel_variance32x32)
   MBFP(BLOCK_32X16, aom_masked_sad32x16, aom_masked_variance32x16,
-       aom_masked_sub_pixel_variance32x16)
+       aom_masked_sub_pixel_variance32x16, aom_masked_compound_sad32x16,
+       aom_masked_compound_variance32x16,
+       aom_masked_compound_sub_pixel_variance32x16)
   MBFP(BLOCK_16X32, aom_masked_sad16x32, aom_masked_variance16x32,
-       aom_masked_sub_pixel_variance16x32)
+       aom_masked_sub_pixel_variance16x32, aom_masked_compound_sad16x32,
+       aom_masked_compound_variance16x32,
+       aom_masked_compound_sub_pixel_variance16x32)
   MBFP(BLOCK_16X16, aom_masked_sad16x16, aom_masked_variance16x16,
-       aom_masked_sub_pixel_variance16x16)
+       aom_masked_sub_pixel_variance16x16, aom_masked_compound_sad16x16,
+       aom_masked_compound_variance16x16,
+       aom_masked_compound_sub_pixel_variance16x16)
   MBFP(BLOCK_16X8, aom_masked_sad16x8, aom_masked_variance16x8,
-       aom_masked_sub_pixel_variance16x8)
+       aom_masked_sub_pixel_variance16x8, aom_masked_compound_sad16x8,
+       aom_masked_compound_variance16x8,
+       aom_masked_compound_sub_pixel_variance16x8)
   MBFP(BLOCK_8X16, aom_masked_sad8x16, aom_masked_variance8x16,
-       aom_masked_sub_pixel_variance8x16)
+       aom_masked_sub_pixel_variance8x16, aom_masked_compound_sad8x16,
+       aom_masked_compound_variance8x16,
+       aom_masked_compound_sub_pixel_variance8x16)
   MBFP(BLOCK_8X8, aom_masked_sad8x8, aom_masked_variance8x8,
-       aom_masked_sub_pixel_variance8x8)
+       aom_masked_sub_pixel_variance8x8, aom_masked_compound_sad8x8,
+       aom_masked_compound_variance8x8,
+       aom_masked_compound_sub_pixel_variance8x8)
   MBFP(BLOCK_4X8, aom_masked_sad4x8, aom_masked_variance4x8,
-       aom_masked_sub_pixel_variance4x8)
+       aom_masked_sub_pixel_variance4x8, aom_masked_compound_sad4x8,
+       aom_masked_compound_variance4x8,
+       aom_masked_compound_sub_pixel_variance4x8)
   MBFP(BLOCK_8X4, aom_masked_sad8x4, aom_masked_variance8x4,
-       aom_masked_sub_pixel_variance8x4)
+       aom_masked_sub_pixel_variance8x4, aom_masked_compound_sad8x4,
+       aom_masked_compound_variance8x4,
+       aom_masked_compound_sub_pixel_variance8x4)
   MBFP(BLOCK_4X4, aom_masked_sad4x4, aom_masked_variance4x4,
-       aom_masked_sub_pixel_variance4x4)
+       aom_masked_sub_pixel_variance4x4, aom_masked_compound_sad4x4,
+       aom_masked_compound_variance4x4,
+       aom_masked_compound_sub_pixel_variance4x4)
 #endif  // CONFIG_EXT_INTER
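
For orientation, here is a minimal C reference for the mcsdf slot wired up
above, shown for 8x8 only (the committed reference and SIMD code live outside
this hunk). It assumes the 6-bit A64 mask blend from aom_dsp/blend.h: the
mask weights one component, and invert_mask selects the complementary
weighting, so a single mask buffer serves both references.

#include <stdlib.h>
#include "aom_dsp/blend.h"  /* AOM_BLEND_A64: (a*v0 + (64-a)*v1 + 32) >> 6 */

unsigned int aom_masked_compound_sad8x8_c(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred,
                                          const uint8_t *msk, int msk_stride,
                                          int invert_mask) {
  int r, c;
  unsigned int sad = 0;
  for (r = 0; r < 8; ++r) {
    for (c = 0; c < 8; ++c) {
      /* Blend the two components under the mask, then accumulate SAD. */
      const int p = invert_mask ? AOM_BLEND_A64(msk[c], second_pred[c], ref[c])
                                : AOM_BLEND_A64(msk[c], ref[c], second_pred[c]);
      sad += abs(p - src[c]);
    }
    src += src_stride;
    ref += ref_stride;
    second_pred += 8; /* second_pred is assumed packed at block width */
    msk += msk_stride;
  }
  return sad;
}
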
 
 #if CONFIG_HIGHBITDEPTH
diff --git a/av1/encoder/mbgraph.c b/av1/encoder/mbgraph.c
index 8578611..3f5daeb 100644
--- a/av1/encoder/mbgraph.c
+++ b/av1/encoder/mbgraph.c
@@ -52,11 +52,14 @@
   {
     int distortion;
     unsigned int sse;
-    cpi->find_fractional_mv_step(x, ref_mv, cpi->common.allow_high_precision_mv,
-                                 x->errorperbit, &v_fn_ptr, 0,
-                                 mv_sf->subpel_iters_per_step,
-                                 cond_cost_list(cpi, cost_list), NULL, NULL,
-                                 &distortion, &sse, NULL, 0, 0, 0);
+    cpi->find_fractional_mv_step(
+        x, ref_mv, cpi->common.allow_high_precision_mv, x->errorperbit,
+        &v_fn_ptr, 0, mv_sf->subpel_iters_per_step,
+        cond_cost_list(cpi, cost_list), NULL, NULL, &distortion, &sse, NULL,
+#if CONFIG_EXT_INTER
+        NULL, 0, 0,
+#endif
+        0, 0, 0);
   }
 
 #if CONFIG_EXT_INTER
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index 27a7509..cbdfc8f 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -176,6 +176,33 @@
 }
 
 /* checks if (r, c) has better score than previous best */
+#if CONFIG_EXT_INTER
+#define CHECK_BETTER(v, r, c)                                              \
+  if (c >= minc && c <= maxc && r >= minr && r <= maxr) {                  \
+    MV this_mv = { r, c };                                                 \
+    v = mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost, error_per_bit);     \
+    if (second_pred == NULL)                                               \
+      thismse = vfp->svf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r),   \
+                         src_address, src_stride, &sse);                   \
+    else if (mask)                                                         \
+      thismse = vfp->mcsvf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r), \
+                           src_address, src_stride, second_pred, mask,     \
+                           mask_stride, invert_mask, &sse);                \
+    else                                                                   \
+      thismse = vfp->svaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r),  \
+                          src_address, src_stride, &sse, second_pred);     \
+    v += thismse;                                                          \
+    if (v < besterr) {                                                     \
+      besterr = v;                                                         \
+      br = r;                                                              \
+      bc = c;                                                              \
+      *distortion = thismse;                                               \
+      *sse1 = sse;                                                         \
+    }                                                                      \
+  } else {                                                                 \
+    v = INT_MAX;                                                           \
+  }
+#else
 #define CHECK_BETTER(v, r, c)                                             \
   if (c >= minc && c <= maxc && r >= minr && r <= maxr) {                 \
     MV this_mv = { r, c };                                                \
@@ -197,6 +224,7 @@
   } else {                                                                \
     v = INT_MAX;                                                          \
   }
+#endif
 
 #define CHECK_BETTER0(v, r, c) CHECK_BETTER(v, r, c)
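
Spelled out as straight-line code, the three-way dispatch that the EXT_INTER
variant of CHECK_BETTER introduces looks as follows. The helper itself is
hypothetical; the function-pointer slots and argument orders are exactly
those used in the macro above.

#include "aom_dsp/variance.h"  /* aom_variance_fn_ptr_t */

// Illustrative expansion of the EXT_INTER CHECK_BETTER dispatch:
// single prediction -> svf, masked compound -> mcsvf, averaged -> svaf.
static unsigned int subpel_distortion(
    const aom_variance_fn_ptr_t *vfp, const uint8_t *pre_ptr, int y_stride,
    int sp_c, int sp_r, const uint8_t *src_address, int src_stride,
    const uint8_t *second_pred, const uint8_t *mask, int mask_stride,
    int invert_mask, unsigned int *sse) {
  if (second_pred == NULL)  /* single reference */
    return vfp->svf(pre_ptr, y_stride, sp_c, sp_r, src_address, src_stride,
                    sse);
  if (mask)  /* compound built by mask blend */
    return vfp->mcsvf(pre_ptr, y_stride, sp_c, sp_r, src_address, src_stride,
                      second_pred, mask, mask_stride, invert_mask, sse);
  /* compound built by averaging */
  return vfp->svaf(pre_ptr, y_stride, sp_c, sp_r, src_address, src_stride, sse,
                   second_pred);
}
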
 
@@ -206,6 +234,26 @@
 }
 
 /* checks if (r, c) has better score than previous best */
+#if CONFIG_EXT_INTER
+#define CHECK_BETTER1(v, r, c)                                               \
+  if (c >= minc && c <= maxc && r >= minr && r <= maxr) {                    \
+    MV this_mv = { r, c };                                                   \
+    thismse = upsampled_pref_error(                                          \
+        xd, vfp, src_address, src_stride, upre(y, y_stride, r, c), y_stride, \
+        second_pred, mask, mask_stride, invert_mask, w, h, &sse);            \
+    v = mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost, error_per_bit);       \
+    v += thismse;                                                            \
+    if (v < besterr) {                                                       \
+      besterr = v;                                                           \
+      br = r;                                                                \
+      bc = c;                                                                \
+      *distortion = thismse;                                                 \
+      *sse1 = sse;                                                           \
+    }                                                                        \
+  } else {                                                                   \
+    v = INT_MAX;                                                             \
+  }
+#else
 #define CHECK_BETTER1(v, r, c)                                         \
   if (c >= minc && c <= maxc && r >= minr && r <= maxr) {              \
     MV this_mv = { r, c };                                             \
@@ -224,6 +272,7 @@
   } else {                                                             \
     v = INT_MAX;                                                       \
   }
+#endif
 
 #define FIRST_LEVEL_CHECKS                                       \
   {                                                              \
@@ -327,20 +376,36 @@
     const MACROBLOCKD *xd, const MV *bestmv, const MV *ref_mv,
     int error_per_bit, const aom_variance_fn_ptr_t *vfp,
     const uint8_t *const src, const int src_stride, const uint8_t *const y,
-    int y_stride, const uint8_t *second_pred, int w, int h, int offset,
-    int *mvjcost, int *mvcost[2], unsigned int *sse1, int *distortion) {
+    int y_stride, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int offset, int *mvjcost, int *mvcost[2], unsigned int *sse1,
+    int *distortion) {
   unsigned int besterr;
 #if CONFIG_HIGHBITDEPTH
   if (second_pred != NULL) {
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
       DECLARE_ALIGNED(16, uint16_t, comp_pred16[MAX_SB_SQUARE]);
-      aom_highbd_comp_avg_pred(comp_pred16, second_pred, w, h, y + offset,
-                               y_stride);
+#if CONFIG_EXT_INTER
+      if (mask)
+        aom_highbd_comp_mask_pred(comp_pred16, second_pred, w, h, y + offset,
+                                  y_stride, mask, mask_stride, invert_mask);
+      else
+#endif
+        aom_highbd_comp_avg_pred(comp_pred16, second_pred, w, h, y + offset,
+                                 y_stride);
       besterr =
           vfp->vf(CONVERT_TO_BYTEPTR(comp_pred16), w, src, src_stride, sse1);
     } else {
       DECLARE_ALIGNED(16, uint8_t, comp_pred[MAX_SB_SQUARE]);
-      aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+#if CONFIG_EXT_INTER
+      if (mask)
+        aom_comp_mask_pred(comp_pred, second_pred, w, h, y + offset, y_stride,
+                           mask, mask_stride, invert_mask);
+      else
+#endif
+        aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
       besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
     }
   } else {
@@ -352,7 +417,13 @@
   (void)xd;
   if (second_pred != NULL) {
     DECLARE_ALIGNED(16, uint8_t, comp_pred[MAX_SB_SQUARE]);
-    aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+#if CONFIG_EXT_INTER
+    if (mask)
+      aom_comp_mask_pred(comp_pred, second_pred, w, h, y + offset, y_stride,
+                         mask, mask_stride, invert_mask);
+    else
+#endif
+      aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
     besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
   } else {
     besterr = vfp->vf(y + offset, y_stride, src, src_stride, sse1);
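
setup_center_error now builds the compound predictor with aom_comp_mask_pred
when a mask is present. Only the call sites appear in this diff; a plausible
reference under the same A64 blend assumption, where 'ref' is the
motion-compensated block at y + offset and 'pred' is the fixed second
component, would be:

#include "aom_dsp/blend.h"

void aom_comp_mask_pred_c(uint8_t *comp_pred, const uint8_t *pred, int width,
                          int height, const uint8_t *ref, int ref_stride,
                          const uint8_t *mask, int mask_stride,
                          int invert_mask) {
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      /* invert_mask swaps which component the mask weights. */
      comp_pred[j] = invert_mask ? AOM_BLEND_A64(mask[j], pred[j], ref[j])
                                 : AOM_BLEND_A64(mask[j], ref[j], pred[j]);
    }
    comp_pred += width;
    pred += width;
    ref += ref_stride;
    mask += mask_stride;
  }
}
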
@@ -391,12 +462,19 @@
     MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
     const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
     int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
-    unsigned int *sse1, const uint8_t *second_pred, int w, int h,
-    int use_upsampled_ref) {
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   SETUP_SUBPEL_SEARCH;
-  besterr = setup_center_error(
-      xd, bestmv, ref_mv, error_per_bit, vfp, src_address, src_stride, y,
-      y_stride, second_pred, w, h, offset, mvjcost, mvcost, sse1, distortion);
+  besterr =
+      setup_center_error(xd, bestmv, ref_mv, error_per_bit, vfp, src_address,
+                         src_stride, y, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                         mask, mask_stride, invert_mask,
+#endif
+                         w, h, offset, mvjcost, mvcost, sse1, distortion);
   (void)halfiters;
   (void)quarteriters;
   (void)eighthiters;
@@ -457,14 +535,21 @@
     MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
     const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
     int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
-    unsigned int *sse1, const uint8_t *second_pred, int w, int h,
-    int use_upsampled_ref) {
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   SETUP_SUBPEL_SEARCH;
   (void)use_upsampled_ref;
 
-  besterr = setup_center_error(
-      xd, bestmv, ref_mv, error_per_bit, vfp, src_address, src_stride, y,
-      y_stride, second_pred, w, h, offset, mvjcost, mvcost, sse1, distortion);
+  besterr =
+      setup_center_error(xd, bestmv, ref_mv, error_per_bit, vfp, src_address,
+                         src_stride, y, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                         mask, mask_stride, invert_mask,
+#endif
+                         w, h, offset, mvjcost, mvcost, sse1, distortion);
   if (cost_list && cost_list[0] != INT_MAX && cost_list[1] != INT_MAX &&
       cost_list[2] != INT_MAX && cost_list[3] != INT_MAX &&
       cost_list[4] != INT_MAX && is_cost_list_wellbehaved(cost_list)) {
@@ -519,14 +604,21 @@
     MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
     const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
     int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
-    unsigned int *sse1, const uint8_t *second_pred, int w, int h,
-    int use_upsampled_ref) {
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   SETUP_SUBPEL_SEARCH;
   (void)use_upsampled_ref;
 
-  besterr = setup_center_error(
-      xd, bestmv, ref_mv, error_per_bit, vfp, src_address, src_stride, y,
-      y_stride, second_pred, w, h, offset, mvjcost, mvcost, sse1, distortion);
+  besterr =
+      setup_center_error(xd, bestmv, ref_mv, error_per_bit, vfp, src_address,
+                         src_stride, y, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                         mask, mask_stride, invert_mask,
+#endif
+                         w, h, offset, mvjcost, mvcost, sse1, distortion);
   if (cost_list && cost_list[0] != INT_MAX && cost_list[1] != INT_MAX &&
       cost_list[2] != INT_MAX && cost_list[3] != INT_MAX &&
       cost_list[4] != INT_MAX) {
@@ -612,17 +704,29 @@
                                 const aom_variance_fn_ptr_t *vfp,
                                 const uint8_t *const src, const int src_stride,
                                 const uint8_t *const y, int y_stride,
-                                const uint8_t *second_pred, int w, int h,
-                                unsigned int *sse) {
+                                const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+                                const uint8_t *mask, int mask_stride,
+                                int invert_mask,
+#endif
+                                int w, int h, unsigned int *sse) {
   unsigned int besterr;
 #if CONFIG_HIGHBITDEPTH
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
     DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
-    if (second_pred != NULL)
-      aom_highbd_comp_avg_upsampled_pred(pred16, second_pred, w, h, y,
-                                         y_stride);
-    else
+    if (second_pred != NULL) {
+#if CONFIG_EXT_INTER
+      if (mask)
+        aom_highbd_comp_mask_upsampled_pred(pred16, second_pred, w, h, y,
+                                            y_stride, mask, mask_stride,
+                                            invert_mask);
+      else
+#endif
+        aom_highbd_comp_avg_upsampled_pred(pred16, second_pred, w, h, y,
+                                           y_stride);
+    } else {
       aom_highbd_upsampled_pred(pred16, w, h, y, y_stride);
+    }
 
     besterr = vfp->vf(CONVERT_TO_BYTEPTR(pred16), w, src, src_stride, sse);
   } else {
@@ -631,10 +735,17 @@
   DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
   (void)xd;
 #endif  // CONFIG_HIGHBITDEPTH
-    if (second_pred != NULL)
-      aom_comp_avg_upsampled_pred(pred, second_pred, w, h, y, y_stride);
-    else
+    if (second_pred != NULL) {
+#if CONFIG_EXT_INTER
+      if (mask)
+        aom_comp_mask_upsampled_pred(pred, second_pred, w, h, y, y_stride, mask,
+                                     mask_stride, invert_mask);
+      else
+#endif
+        aom_comp_avg_upsampled_pred(pred, second_pred, w, h, y, y_stride);
+    } else {
       aom_upsampled_pred(pred, w, h, y, y_stride);
+    }
 
     besterr = vfp->vf(pred, w, src, src_stride, sse);
 #if CONFIG_HIGHBITDEPTH
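
aom_comp_mask_upsampled_pred is the upsampled-reference analogue of the
helper above. The diff shows only its call sites; a sketch, assuming it
decomposes as upsample-then-blend, with aom_upsampled_pred taking the same
arguments as in the unmasked branch:

#include "aom_dsp/blend.h"

void aom_comp_mask_upsampled_pred_c(uint8_t *comp_pred, const uint8_t *pred,
                                    int width, int height, const uint8_t *ref,
                                    int ref_stride, const uint8_t *mask,
                                    int mask_stride, int invert_mask) {
  int i, j;
  /* Upsample the reference into comp_pred first... */
  aom_upsampled_pred(comp_pred, width, height, ref, ref_stride);
  /* ...then apply the same mask blend in place. */
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      comp_pred[j] = invert_mask
                         ? AOM_BLEND_A64(mask[j], pred[j], comp_pred[j])
                         : AOM_BLEND_A64(mask[j], comp_pred[j], pred[j]);
    }
    comp_pred += width;
    pred += width;
    mask += mask_stride;
  }
}
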
@@ -647,23 +758,32 @@
     const MACROBLOCKD *xd, const MV *bestmv, const MV *ref_mv,
     int error_per_bit, const aom_variance_fn_ptr_t *vfp,
     const uint8_t *const src, const int src_stride, const uint8_t *const y,
-    int y_stride, const uint8_t *second_pred, int w, int h, int offset,
-    int *mvjcost, int *mvcost[2], unsigned int *sse1, int *distortion) {
-  unsigned int besterr = upsampled_pref_error(
-      xd, vfp, src, src_stride, y + offset, y_stride, second_pred, w, h, sse1);
+    int y_stride, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int offset, int *mvjcost, int *mvcost[2], unsigned int *sse1,
+    int *distortion) {
+  unsigned int besterr = upsampled_pref_error(xd, vfp, src, src_stride,
+                                              y + offset, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                                              mask, mask_stride, invert_mask,
+#endif
+                                              w, h, sse1);
   *distortion = besterr;
   besterr += mv_err_cost(bestmv, ref_mv, mvjcost, mvcost, error_per_bit);
   return besterr;
 }
 
-int av1_find_best_sub_pixel_tree(MACROBLOCK *x, const MV *ref_mv, int allow_hp,
-                                 int error_per_bit,
-                                 const aom_variance_fn_ptr_t *vfp,
-                                 int forced_stop, int iters_per_step,
-                                 int *cost_list, int *mvjcost, int *mvcost[2],
-                                 int *distortion, unsigned int *sse1,
-                                 const uint8_t *second_pred, int w, int h,
-                                 int use_upsampled_ref) {
+int av1_find_best_sub_pixel_tree(
+    MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
+    const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
+    int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   const uint8_t *const src_address = x->plane[0].src.buf;
   const int src_stride = x->plane[0].src.stride;
   const MACROBLOCKD *xd = &x->e_mbd;
@@ -700,12 +820,19 @@
   if (use_upsampled_ref)
     besterr = upsampled_setup_center_error(
         xd, bestmv, ref_mv, error_per_bit, vfp, src_address, src_stride, y,
-        y_stride, second_pred, w, h, (offset * 8), mvjcost, mvcost, sse1,
-        distortion);
+        y_stride, second_pred,
+#if CONFIG_EXT_INTER
+        mask, mask_stride, invert_mask,
+#endif
+        w, h, (offset * 8), mvjcost, mvcost, sse1, distortion);
   else
-    besterr = setup_center_error(
-        xd, bestmv, ref_mv, error_per_bit, vfp, src_address, src_stride, y,
-        y_stride, second_pred, w, h, offset, mvjcost, mvcost, sse1, distortion);
+    besterr =
+        setup_center_error(xd, bestmv, ref_mv, error_per_bit, vfp, src_address,
+                           src_stride, y, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                           mask, mask_stride, invert_mask,
+#endif
+                           w, h, offset, mvjcost, mvcost, sse1, distortion);
 
   (void)cost_list;  // to silence compiler warning
 
@@ -721,14 +848,23 @@
           const uint8_t *const pre_address = y + tr * y_stride + tc;
 
           thismse = upsampled_pref_error(xd, vfp, src_address, src_stride,
-                                         pre_address, y_stride, second_pred, w,
-                                         h, &sse);
+                                         pre_address, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                                         mask, mask_stride, invert_mask,
+#endif
+                                         w, h, &sse);
         } else {
           const uint8_t *const pre_address =
               y + (tr >> 3) * y_stride + (tc >> 3);
           if (second_pred == NULL)
             thismse = vfp->svf(pre_address, y_stride, sp(tc), sp(tr),
                                src_address, src_stride, &sse);
+#if CONFIG_EXT_INTER
+          else if (mask)
+            thismse = vfp->mcsvf(pre_address, y_stride, sp(tc), sp(tr),
+                                 src_address, src_stride, second_pred, mask,
+                                 mask_stride, invert_mask, &sse);
+#endif
           else
             thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
                                 src_address, src_stride, &sse, second_pred);
@@ -760,15 +896,24 @@
       if (use_upsampled_ref) {
         const uint8_t *const pre_address = y + tr * y_stride + tc;
 
-        thismse =
-            upsampled_pref_error(xd, vfp, src_address, src_stride, pre_address,
-                                 y_stride, second_pred, w, h, &sse);
+        thismse = upsampled_pref_error(xd, vfp, src_address, src_stride,
+                                       pre_address, y_stride, second_pred,
+#if CONFIG_EXT_INTER
+                                       mask, mask_stride, invert_mask,
+#endif
+                                       w, h, &sse);
       } else {
         const uint8_t *const pre_address = y + (tr >> 3) * y_stride + (tc >> 3);
 
         if (second_pred == NULL)
           thismse = vfp->svf(pre_address, y_stride, sp(tc), sp(tr), src_address,
                              src_stride, &sse);
+#if CONFIG_EXT_INTER
+        else if (mask)
+          thismse = vfp->mcsvf(pre_address, y_stride, sp(tc), sp(tr),
+                               src_address, src_stride, second_pred, mask,
+                               mask_stride, invert_mask, &sse);
+#endif
         else
           thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
                               src_address, src_stride, &sse, second_pred);
@@ -1232,6 +1377,27 @@
                      : 0);
 }
 
+#if CONFIG_EXT_INTER
+int av1_get_mvpred_mask_var(const MACROBLOCK *x, const MV *best_mv,
+                            const MV *center_mv, const uint8_t *second_pred,
+                            const uint8_t *mask, int mask_stride,
+                            int invert_mask, const aom_variance_fn_ptr_t *vfp,
+                            int use_mvcost) {
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  const struct buf_2d *const what = &x->plane[0].src;
+  const struct buf_2d *const in_what = &xd->plane[0].pre[0];
+  const MV mv = { best_mv->row * 8, best_mv->col * 8 };
+  unsigned int unused;
+
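+  // With zero sub-pel offsets, mcsvf reduces to a full-pel masked compound
+  // variance, so no separate full-pel variance hook is needed here.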
+  return vfp->mcsvf(what->buf, what->stride, 0, 0,
+                    get_buf_from_mv(in_what, best_mv), in_what->stride,
+                    second_pred, mask, mask_stride, invert_mask, &unused) +
+         (use_mvcost ? mv_err_cost(&mv, center_mv, x->nmvjointcost, x->mvcost,
+                                   x->errorperbit)
+                     : 0);
+}
+#endif
+
 int av1_hex_search(MACROBLOCK *x, MV *start_mv, int search_param,
                    int sad_per_bit, int do_init_search, int *cost_list,
                    const aom_variance_fn_ptr_t *vfp, int use_mvcost,
@@ -2199,6 +2365,10 @@
 // mode.
 int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
                              const aom_variance_fn_ptr_t *fn_ptr,
+#if CONFIG_EXT_INTER
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask,
+#endif
                              const MV *center_mv, const uint8_t *second_pred) {
   const MV neighbors[8] = { { -1, 0 },  { 0, -1 }, { 0, 1 },  { 1, 0 },
                             { -1, -1 }, { 1, -1 }, { -1, 1 }, { 1, 1 } };
@@ -2212,10 +2382,18 @@
 
   clamp_mv(best_mv, x->mv_limits.col_min, x->mv_limits.col_max,
            x->mv_limits.row_min, x->mv_limits.row_max);
-  best_sad =
-      fn_ptr->sdaf(what->buf, what->stride, get_buf_from_mv(in_what, best_mv),
-                   in_what->stride, second_pred) +
-      mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
+#if CONFIG_EXT_INTER
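+  // A masked compound must be scored against the mask-blended predictor
+  // (mcsdf); plain compound prediction keeps the averaging SAD (sdaf).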
+  if (mask)
+    best_sad = fn_ptr->mcsdf(what->buf, what->stride,
+                             get_buf_from_mv(in_what, best_mv), in_what->stride,
+                             second_pred, mask, mask_stride, invert_mask) +
+               mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
+  else
+#endif
+    best_sad =
+        fn_ptr->sdaf(what->buf, what->stride, get_buf_from_mv(in_what, best_mv),
+                     in_what->stride, second_pred) +
+        mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
 
   for (i = 0; i < search_range; ++i) {
     int best_site = -1;
@@ -2225,9 +2403,17 @@
                       best_mv->col + neighbors[j].col };
 
       if (is_mv_in(&x->mv_limits, &mv)) {
-        unsigned int sad =
-            fn_ptr->sdaf(what->buf, what->stride, get_buf_from_mv(in_what, &mv),
-                         in_what->stride, second_pred);
+        unsigned int sad;
+#if CONFIG_EXT_INTER
+        if (mask)
+          sad = fn_ptr->mcsdf(what->buf, what->stride,
+                              get_buf_from_mv(in_what, &mv), in_what->stride,
+                              second_pred, mask, mask_stride, invert_mask);
+        else
+#endif
+          sad = fn_ptr->sdaf(what->buf, what->stride,
+                             get_buf_from_mv(in_what, &mv), in_what->stride,
+                             second_pred);
         if (sad < best_sad) {
           sad += mvsad_err_cost(x, &mv, &fcenter_mv, error_per_bit);
           if (sad < best_sad) {
@@ -3453,15 +3639,21 @@
   (void)thismse;           \
   (void)cost_list;
 // Return the maximum MV.
-int av1_return_max_sub_pixel_mv(MACROBLOCK *x, const MV *ref_mv, int allow_hp,
-                                int error_per_bit,
-                                const aom_variance_fn_ptr_t *vfp,
-                                int forced_stop, int iters_per_step,
-                                int *cost_list, int *mvjcost, int *mvcost[2],
-                                int *distortion, unsigned int *sse1,
-                                const uint8_t *second_pred, int w, int h,
-                                int use_upsampled_ref) {
+int av1_return_max_sub_pixel_mv(
+    MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
+    const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
+    int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   COMMON_MV_TEST;
+#if CONFIG_EXT_INTER
+  (void)mask;
+  (void)mask_stride;
+  (void)invert_mask;
+#endif
   (void)minr;
   (void)minc;
   bestmv->row = maxr;
@@ -3473,17 +3665,23 @@
   return besterr;
 }
 // Return the minimum MV.
-int av1_return_min_sub_pixel_mv(MACROBLOCK *x, const MV *ref_mv, int allow_hp,
-                                int error_per_bit,
-                                const aom_variance_fn_ptr_t *vfp,
-                                int forced_stop, int iters_per_step,
-                                int *cost_list, int *mvjcost, int *mvcost[2],
-                                int *distortion, unsigned int *sse1,
-                                const uint8_t *second_pred, int w, int h,
-                                int use_upsampled_ref) {
+int av1_return_min_sub_pixel_mv(
+    MACROBLOCK *x, const MV *ref_mv, int allow_hp, int error_per_bit,
+    const aom_variance_fn_ptr_t *vfp, int forced_stop, int iters_per_step,
+    int *cost_list, int *mvjcost, int *mvcost[2], int *distortion,
+    unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref) {
   COMMON_MV_TEST;
   (void)maxr;
   (void)maxc;
+#if CONFIG_EXT_INTER
+  (void)mask;
+  (void)mask_stride;
+  (void)invert_mask;
+#endif
   bestmv->row = minr;
   bestmv->col = minc;
   besterr = 0;
diff --git a/av1/encoder/mcomp.h b/av1/encoder/mcomp.h
index 8465860..eb989e8 100644
--- a/av1/encoder/mcomp.h
+++ b/av1/encoder/mcomp.h
@@ -58,6 +58,13 @@
 int av1_get_mvpred_av_var(const MACROBLOCK *x, const MV *best_mv,
                           const MV *center_mv, const uint8_t *second_pred,
                           const aom_variance_fn_ptr_t *vfp, int use_mvcost);
+#if CONFIG_EXT_INTER
+int av1_get_mvpred_mask_var(const MACROBLOCK *x, const MV *best_mv,
+                            const MV *center_mv, const uint8_t *second_pred,
+                            const uint8_t *mask, int mask_stride,
+                            int invert_mask, const aom_variance_fn_ptr_t *vfp,
+                            int use_mvcost);
+#endif
 
 struct AV1_COMP;
 struct SPEED_FEATURES;
@@ -91,8 +98,11 @@
     const aom_variance_fn_ptr_t *vfp,
     int forced_stop,  // 0 - full, 1 - qtr only, 2 - half only
     int iters_per_step, int *cost_list, int *mvjcost, int *mvcost[2],
-    int *distortion, unsigned int *sse1, const uint8_t *second_pred, int w,
-    int h, int use_upsampled_ref);
+    int *distortion, unsigned int *sse1, const uint8_t *second_pred,
+#if CONFIG_EXT_INTER
+    const uint8_t *mask, int mask_stride, int invert_mask,
+#endif
+    int w, int h, int use_upsampled_ref);
 
 extern fractional_mv_step_fp av1_find_best_sub_pixel_tree;
 extern fractional_mv_step_fp av1_find_best_sub_pixel_tree_pruned;
@@ -113,6 +123,10 @@
 
 int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
                              const aom_variance_fn_ptr_t *fn_ptr,
+#if CONFIG_EXT_INTER
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask,
+#endif
                              const MV *center_mv, const uint8_t *second_pred);
 
 struct AV1_COMP;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index a4aff3c..63e594a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5452,7 +5452,8 @@
                                 BLOCK_SIZE bsize, int_mv *frame_mv, int mi_row,
                                 int mi_col,
 #if CONFIG_EXT_INTER
-                                int_mv *ref_mv_sub8x8[2],
+                                int_mv *ref_mv_sub8x8[2], const uint8_t *mask,
+                                int mask_stride,
 #endif  // CONFIG_EXT_INTER
                                 int *rate_mv, const int block) {
   const AV1_COMMON *const cm = &cpi->common;
@@ -5618,10 +5619,21 @@
     // Small-range full-pixel motion search.
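+    // 'id' doubles as invert_mask: component 0 is weighted by the mask
+    // itself, component 1 by its complement, so both components share one
+    // mask buffer.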
     bestsme =
         av1_refining_search_8p_c(x, sadpb, search_range, &cpi->fn_ptr[bsize],
+#if CONFIG_EXT_INTER
+                                 mask, mask_stride, id,
+#endif
                                  &ref_mv[id].as_mv, second_pred);
-    if (bestsme < INT_MAX)
-      bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
-                                      second_pred, &cpi->fn_ptr[bsize], 1);
+    if (bestsme < INT_MAX) {
+#if CONFIG_EXT_INTER
+      if (mask)
+        bestsme = av1_get_mvpred_mask_var(x, best_mv, &ref_mv[id].as_mv,
+                                          second_pred, mask, mask_stride, id,
+                                          &cpi->fn_ptr[bsize], 1);
+      else
+#endif
+        bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
+                                        second_pred, &cpi->fn_ptr[bsize], 1);
+    }
 
     x->mv_limits = tmp_mv_limits;
 
@@ -5654,7 +5666,11 @@
             x, &ref_mv[id].as_mv, cpi->common.allow_high_precision_mv,
             x->errorperbit, &cpi->fn_ptr[bsize], 0,
             cpi->sf.mv.subpel_iters_per_step, NULL, x->nmvjointcost, x->mvcost,
-            &dis, &sse, second_pred, pw, ph, 1);
+            &dis, &sse, second_pred,
+#if CONFIG_EXT_INTER
+            mask, mask_stride, id,
+#endif
+            pw, ph, 1);
 
         // Restore the reference frames.
         pd->pre[0] = backup_pred;
@@ -5664,7 +5680,11 @@
             x, &ref_mv[id].as_mv, cpi->common.allow_high_precision_mv,
             x->errorperbit, &cpi->fn_ptr[bsize], 0,
             cpi->sf.mv.subpel_iters_per_step, NULL, x->nmvjointcost, x->mvcost,
-            &dis, &sse, second_pred, pw, ph, 0);
+            &dis, &sse, second_pred,
+#if CONFIG_EXT_INTER
+            mask, mask_stride, id,
+#endif
+            pw, ph, 0);
       }
     }
 
@@ -6060,8 +6080,11 @@
                   cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL, pw, ph,
-                  1);
+                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  pw, ph, 1);
 
               if (try_second) {
                 int this_var;
@@ -6088,7 +6111,11 @@
                       cpi->sf.mv.subpel_iters_per_step,
                       cond_cost_list(cpi, cost_list), x->nmvjointcost,
                       x->mvcost, &distortion, &x->pred_sse[mbmi->ref_frame[0]],
-                      NULL, pw, ph, 1);
+                      NULL,
+#if CONFIG_EXT_INTER
+                      NULL, 0, 0,
+#endif
+                      pw, ph, 1);
                   if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
                   x->best_mv.as_mv = best_mv;
                 }
@@ -6103,7 +6130,11 @@
                   cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL, 0, 0, 0);
+                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  0, 0, 0);
             }
 
 // save motion search result for use in compound prediction
@@ -6165,7 +6196,7 @@
             joint_motion_search(cpi, x, bsize, frame_mv[this_mode], mi_row,
                                 mi_col,
 #if CONFIG_EXT_INTER
-                                bsi->ref_mv,
+                                bsi->ref_mv, NULL, 0,
 #endif  // CONFIG_EXT_INTER
                                 &rate_mv, index);
 #if CONFIG_EXT_INTER
@@ -6958,8 +6989,11 @@
               x, &ref_mv, cm->allow_high_precision_mv, x->errorperbit,
               &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, pw, ph,
-              1);
+              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+              NULL, 0, 0,
+#endif
+              pw, ph, 1);
 
           if (try_second) {
             const int minc =
@@ -6983,7 +7017,11 @@
                   &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &dis, &x->pred_sse[ref], NULL, pw, ph, 1);
+                  &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  pw, ph, 1);
               if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
               x->best_mv.as_mv = best_mv;
             }
@@ -6996,8 +7034,11 @@
               x, &ref_mv, cm->allow_high_precision_mv, x->errorperbit,
               &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, 0, 0,
-              0);
+              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+              NULL, 0, 0,
+#endif
+              0, 0, 0);
         }
 #if CONFIG_MOTION_VAR
         break;
@@ -7161,7 +7202,7 @@
 }
 
 static void do_masked_motion_search_indexed(
-    const AV1_COMP *const cpi, MACROBLOCK *x,
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
     const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
     int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int which) {
   // NOTE: 'which' selects the search: 0 - component 0 only,
   // 1 - component 1 only, 2 - a joint search over both components.
@@ -7173,11 +7214,21 @@
 
   mask = av1_get_compound_type_mask(comp_data, sb_type);
 
-  if (which == 0 || which == 2)
-    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[0], &rate_mv[0], 0);
+  if (which == 2) {
+    int_mv frame_mv[TOTAL_REFS_PER_FRAME];
+    MV_REFERENCE_FRAME rf[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
+    assert(bsize >= BLOCK_8X8 || CONFIG_CB4X4);
 
-  if (which == 1 || which == 2) {
+    frame_mv[rf[0]].as_int = cur_mv[0].as_int;
+    frame_mv[rf[1]].as_int = cur_mv[1].as_int;
+    joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL, mask,
+                        mask_stride, rate_mv, 0);
+    tmp_mv[0].as_int = frame_mv[rf[0]].as_int;
+    tmp_mv[1].as_int = frame_mv[rf[1]].as_int;
+  } else if (which == 0) {
+    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
+                            &tmp_mv[0], rate_mv, 0);
+  } else if (which == 1) {
 // get the negative mask
 #if CONFIG_COMPOUND_SEGMENT
     uint8_t inv_mask_buf[2 * MAX_SB_SQUARE];
@@ -7188,7 +7239,7 @@
     mask = av1_get_compound_type_mask_inverse(comp_data, sb_type);
 #endif  // CONFIG_COMPOUND_SEGMENT
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[1], &rate_mv[1], 1);
+                            &tmp_mv[1], rate_mv, 1);
   }
 }
 #endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
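
A note on the which == 2 path above: joint_motion_search is handed the
original mask and toggles invert_mask per component, so no inverse mask has
to be materialised for the joint search. That works because swapping the
blend operands is identical to inverting the mask. A hypothetical sanity
check of the identity, assuming the usual 0..64 A64 mask range:

#include <assert.h>
#include "aom_dsp/blend.h"

// Not part of the patch: blending with operands swapped under mask m
// produces exactly the same pixel as keeping the operand order under
// the complementary mask 64 - m.
static void check_invert_mask_identity(uint8_t m, uint8_t p0, uint8_t p1) {
  const int swapped = AOM_BLEND_A64(m, p1, p0);
  const int inverted = AOM_BLEND_A64(AOM_BLEND_A64_MAX_ALPHA - m, p0, p1);
  assert(swapped == inverted);
}
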
@@ -7665,15 +7716,13 @@
   }
 }
 
-static int interinter_compound_motion_search(const AV1_COMP *const cpi,
-                                             MACROBLOCK *x,
-                                             const BLOCK_SIZE bsize,
-                                             const int this_mode, int mi_row,
-                                             int mi_col) {
+static int interinter_compound_motion_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
+    const BLOCK_SIZE bsize, const int this_mode, int mi_row, int mi_col) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int_mv tmp_mv[2];
-  int rate_mvs[2], tmp_rate_mv = 0;
+  int tmp_rate_mv = 0;
   const INTERINTER_COMPOUND_DATA compound_data = {
 #if CONFIG_WEDGE
     mbmi->wedge_index,
@@ -7686,20 +7735,17 @@
     mbmi->interinter_compound_type
   };
   if (this_mode == NEW_NEWMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 2);
-    tmp_rate_mv = rate_mvs[0] + rate_mvs[1];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 2);
     mbmi->mv[0].as_int = tmp_mv[0].as_int;
     mbmi->mv[1].as_int = tmp_mv[1].as_int;
   } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 0);
-    tmp_rate_mv = rate_mvs[0];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 0);
     mbmi->mv[0].as_int = tmp_mv[0].as_int;
   } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 1);
-    tmp_rate_mv = rate_mvs[1];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 1);
     mbmi->mv[1].as_int = tmp_mv[1].as_int;
   }
   return tmp_rate_mv;
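
Replacing rate_mvs[2] with a single tmp_rate_mv follows from the joint search's contract: joint_motion_search is assumed to accumulate the signalling cost of both refined vectors into its one rate_mv output, roughly as in the fragment below, so there is nothing left for the caller to sum. av1_mv_bit_cost is the existing helper; the ref_mv lookup shown is illustrative only.

    /* Fragment of the assumed rate accumulation inside joint_motion_search. */
    *rate_mv = 0;
    for (i = 0; i < 2; ++i)
      *rate_mv += av1_mv_bit_cost(&frame_mv[refs[i]].as_mv, &ref_mv[i].as_mv,
                                  x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);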
@@ -7726,8 +7772,8 @@
 
   if (have_newmv_in_inter_mode(this_mode) &&
       use_masked_motion_search(compound_type)) {
-    *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
-                                                     mi_row, mi_col);
+    *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize,
+                                                     this_mode, mi_row, mi_col);
     av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize);
     model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                     &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
@@ -7823,8 +7869,8 @@
       frame_mv[refs[1]].as_int = single_newmv[refs[1]].as_int;
 
       if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
-        joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL,
-                            rate_mv, 0);
+        joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL, NULL,
+                            0, rate_mv, 0);
       } else {
         *rate_mv = 0;
         for (i = 0; i < 2; ++i) {
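
joint_motion_search itself now takes a mask pointer and stride just ahead of rate_mv; the long-standing non-masked caller above passes NULL, 0 and behaves as before. With a mask present, each candidate MV pair is presumably scored with the new masked-compound SAD/SSE instead of the plain averaging metrics. Direct _c-level calls are shown below for clarity; whether the encoder reaches them through the fn_ptr table is an assumption:

    /* Illustrative scoring of one 16x16 candidate.  'invert' marks which
     * side of the compound is currently being refined. */
    unsigned int err;
    if (mask)
      err = aom_masked_compound_sad16x16(src, src_stride, ref, ref_stride,
                                         second_pred, mask, mask_stride,
                                         invert);
    else
      err = aom_sad16x16_avg(src, src_stride, ref, ref_stride, second_pred);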
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 98feb87..7508de8 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -297,8 +297,11 @@
   bestsme = cpi->find_fractional_mv_step(
       x, &best_ref_mv1, cpi->common.allow_high_precision_mv, x->errorperbit,
       &cpi->fn_ptr[BLOCK_16X16], 0, mv_sf->subpel_iters_per_step,
-      cond_cost_list(cpi, cost_list), NULL, NULL, &distortion, &sse, NULL, 0, 0,
-      0);
+      cond_cost_list(cpi, cost_list), NULL, NULL, &distortion, &sse, NULL,
+#if CONFIG_EXT_INTER
+      NULL, 0, 0,
+#endif
+      0, 0, 0);
 
   x->e_mbd.mi[0]->bmi[0].as_mv[0] = x->best_mv;
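
The temporal filter hunk is the same mechanical threading as the mcomp call sites: no mask is ever used there, so the guarded arguments are NULL, 0, 0. For reference, the semantics the new helpers are expected to implement, written out as an unoptimized model from the prototypes rather than copied from the library: blend the two single predictions with the 6-bit mask exactly as the reconstruction path does, then measure the blended prediction against the source.

    #include <stdint.h>
    #include <stdlib.h>

    /* Reference model of a masked-compound SAD for a w x h block.  Assumes
     * 64-weight masks and a second_pred buffer packed at the block width;
     * invert_mask is modeled as complementing the mask, which is equivalent
     * to swapping the two predictions. */
    static unsigned int masked_compound_sad_model(
        const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
        const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
        int invert_mask, int w, int h) {
      unsigned int sad = 0;
      for (int r = 0; r < h; ++r) {
        for (int c = 0; c < w; ++c) {
          const int m = invert_mask ? 64 - msk[c] : msk[c];
          const int pred = (m * ref[c] + (64 - m) * second_pred[c] + 32) >> 6;
          sad += abs(src[c] - pred);
        }
        src += src_stride;
        ref += ref_stride;
        second_pred += w;
        msk += msk_stride;
      }
      return sad;
    }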