Add sse2 four-point sad_avg functions
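
The new aom_sadMxNx4d_avg() functions are the x4d counterpart of
aom_sadMxN_avg(): for each of the four reference candidates, the
reference block is first averaged with second_pred, then the SAD of
that average against src is computed. C fallbacks, SSE2 versions for
all block sizes, and SADx4AvgTest coverage are included.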

Change-Id: Ie5c3e0ccfca813257a19521ee9874012710eca1e
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index b2b4b19..6b31d02 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -822,6 +822,7 @@
   foreach (@block_sizes) {
     ($w, $h) = @$_;
     add_proto qw/void/, "aom_sad${w}x${h}x4d", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array";
+    add_proto qw/void/, "aom_sad${w}x${h}x4d_avg", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, const uint8_t *second_pred, uint32_t *sad_array";
   }
 
   specialize qw/aom_sad128x128x4d avx2          sse2/;
@@ -853,6 +854,30 @@
   specialize qw/aom_sad32x8x4d  sse2/;
   specialize qw/aom_sad64x16x4d sse2/;
 
+  specialize qw/aom_sad128x128x4d_avg sse2/;
+  specialize qw/aom_sad128x64x4d_avg  sse2/;
+  specialize qw/aom_sad64x128x4d_avg  sse2/;
+  specialize qw/aom_sad64x64x4d_avg   sse2/;
+  specialize qw/aom_sad64x32x4d_avg   sse2/;
+  specialize qw/aom_sad64x16x4d_avg   sse2/;
+  specialize qw/aom_sad32x64x4d_avg   sse2/;
+  specialize qw/aom_sad32x32x4d_avg   sse2/;
+  specialize qw/aom_sad32x16x4d_avg   sse2/;
+  specialize qw/aom_sad32x8x4d_avg    sse2/;
+  specialize qw/aom_sad16x64x4d_avg   sse2/;
+  specialize qw/aom_sad16x32x4d_avg   sse2/;
+  specialize qw/aom_sad16x16x4d_avg   sse2/;
+  specialize qw/aom_sad16x8x4d_avg    sse2/;
+
+  specialize qw/aom_sad8x16x4d_avg    sse2/;
+  specialize qw/aom_sad8x8x4d_avg     sse2/;
+  specialize qw/aom_sad8x4x4d_avg     sse2/;
+  specialize qw/aom_sad4x16x4d_avg    sse2/;
+  specialize qw/aom_sad4x8x4d_avg     sse2/;
+  specialize qw/aom_sad4x4x4d_avg     sse2/;
+
+  specialize qw/aom_sad16x4x4d_avg    sse2/;
+  specialize qw/aom_sad8x32x4d_avg    sse2/;
   #
   # Multi-block SAD, comparing a reference to N independent blocks
   #
diff --git a/aom_dsp/sad.c b/aom_dsp/sad.c
index 166a17a..8ddc683 100644
--- a/aom_dsp/sad.c
+++ b/aom_dsp/sad.c
@@ -64,15 +64,26 @@
   }
 
 // Calculate sad against 4 reference locations and store each in sad_array
-#define sadMxNx4D(m, n)                                                    \
-  void aom_sad##m##x##n##x4d_c(const uint8_t *src, int src_stride,         \
-                               const uint8_t *const ref_array[],           \
-                               int ref_stride, uint32_t *sad_array) {      \
-    int i;                                                                 \
-    for (i = 0; i < 4; ++i) {                                              \
-      sad_array[i] =                                                       \
-          aom_sad##m##x##n##_c(src, src_stride, ref_array[i], ref_stride); \
-    }                                                                      \
+#define sadMxNx4D(m, n)                                                      \
+  void aom_sad##m##x##n##x4d_c(const uint8_t *src, int src_stride,           \
+                               const uint8_t *const ref_array[],             \
+                               int ref_stride, uint32_t *sad_array) {        \
+    int i;                                                                   \
+    for (i = 0; i < 4; ++i) {                                                \
+      sad_array[i] =                                                         \
+          aom_sad##m##x##n##_c(src, src_stride, ref_array[i], ref_stride);   \
+    }                                                                        \
+  }                                                                          \
+  void aom_sad##m##x##n##x4d_avg_c(                                          \
+      const uint8_t *src, int src_stride, const uint8_t *const ref_array[],  \
+      int ref_stride, const uint8_t *second_pred, uint32_t *sad_array) {     \
+    int i;                                                                   \
+    for (i = 0; i < 4; ++i) {                                                \
+      sad_array[i] = aom_sad##m##x##n##_avg_c(src, src_stride, ref_array[i], \
+                                              ref_stride, second_pred);      \
+    }                                                                        \
   }
 
 // 128x128
diff --git a/aom_dsp/x86/sad4d_sse2.asm b/aom_dsp/x86/sad4d_sse2.asm
index 55a8569..a904374 100644
--- a/aom_dsp/x86/sad4d_sse2.asm
+++ b/aom_dsp/x86/sad4d_sse2.asm
@@ -15,15 +15,98 @@
 
 SECTION .text
 
-; PROCESS_4x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_4x2x4 5-6 0
-  movd                  m0, [srcq +%2]
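+; AVG_4x2x4 ref_pair_a, ref_pair_b
+; Average two xmm registers, each holding a packed pair of 4x2
+; reference blocks, with 8 bytes of second_predq duplicated into both
+; qword lanes, then advance second_predq past the 8 bytes consumed.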
+%macro AVG_4x2x4 2
+  movh                  m2, [second_predq]
+  movlhps               m2, m2
+  pavgb                 %1, m2
+  pavgb                 %2, m2
+  lea                   second_predq, [second_predq+8]
+%endmacro
+
+; 'mflag' affects a lot of how the code works.
+;
+; When 'mflag' is false, 'src_strideq' resides in a register and
+; [srcq + src_strideq + offset] is a legal operand, so we simply use
+; that form to access src memory and do not bother updating 'srcq'
+; on every line. We only update 'srcq' every two lines, with a
+; compact LEA instruction like [srcq+src_strideq*2].
+;
+; When 'mflag' is true, 'src_strideq' resides in memory. We cannot
+; use the form above to access memory, so we have to update 'srcq'
+; at each line break. As each macro invocation processes two parts
+; (first, second) together, the second part may also sit on the next
+; line, which means we may also need to add one 'src_strideq' to
+; 'srcq' before processing the second part.
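+;
+; Illustration only (not emitted code):
+;   mflag == 0:  movd  m0, [srcq + src_strideq + 4]  ; stride in register
+;   mflag == 1:  add   srcq, src_strideq             ; stride is r1mp
+;                movd  m0, [srcq + 4]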
+
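+; first_extraline, in_line_offset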
+%macro HANDLE_FIRST_OFFSET 2
+  %define first_offset %2
+  %if mflag == 0 && %1 == 1
+    %define first_offset (src_strideq + %2)
+  %endif
+%endmacro
+
+; first_extraline, second_extraline, in_line_offset
+%macro HANDLE_SECOND_OFFSET 3
+  %define second_offset %3
+  %if mflag && %1 == 0 && %2 == 1
+    add srcq, src_strideq
+  %endif
+  %if mflag == 0 && %2 == 1
+    %define second_offset (src_strideq + %3)
+  %endif
+%endmacro
+
+; Notes for line_ending:
+; 0 -- not a line ending
+; 1 -- line ending of an odd line (line numbers start from one)
+; 2 -- line ending of an even line
+; This is specifically designed for the case where src_strideq is a
+; memory operand: we then cannot accomplish the complex address
+; calculation with LEA and fall back to a simple ADD instruction at
+; each line ending.
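+; For example, the SADNXN4D loop below always passes line_ending=2, so
+; each two-row invocation advances src and all four refs by two strides.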
+%macro ADVANCE_END_OF_LINE 1
+  %if mflag
+    add srcq, src_strideq
+  %endif
+  %if mflag == 0 && %1 == 2
+    lea                 srcq, [srcq +src_strideq*2]
+  %endif
+
+  %if %1 == 2
+    lea                ref1q, [ref1q+ref_strideq*2]
+    lea                ref2q, [ref2q+ref_strideq*2]
+    lea                ref3q, [ref3q+ref_strideq*2]
+    lea                ref4q, [ref4q+ref_strideq*2]
+  %endif
+%endmacro
+
+; Note that the second_offset of src is an in-line offset,
+; so it is always less than src_stride.
+; PROCESS_4x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;               {first, second}_extraline, line_ending
+%macro PROCESS_4x2x4 9
+  HANDLE_FIRST_OFFSET   %7, %2
+  movd                  m0, [srcq + first_offset]
+  HANDLE_SECOND_OFFSET  %7, %8, %4
 %if %1 == 1
   movd                  m6, [ref1q+%3]
   movd                  m4, [ref2q+%3]
   movd                  m7, [ref3q+%3]
   movd                  m5, [ref4q+%3]
-  movd                  m1, [srcq +%4]
+
+  movd                  m1, [srcq + second_offset]
   movd                  m2, [ref1q+%5]
   punpckldq             m0, m1
   punpckldq             m6, m2
@@ -36,6 +106,9 @@
   movlhps               m0, m0
   movlhps               m6, m4
   movlhps               m7, m5
+%if %6 == 1
+  AVG_4x2x4             m6, m7
+%endif
   psadbw                m6, m0
   psadbw                m7, m0
 %else
@@ -51,38 +124,50 @@
   movd                  m4, [ref4q+%3]
   movd                  m5, [ref4q+%5]
   punpckldq             m4, m5
-  movd                  m5, [srcq +%4]
+  movd                  m5, [srcq + second_offset]
   punpckldq             m0, m5
   movlhps               m0, m0
   movlhps               m1, m2
   movlhps               m3, m4
+%if %6 == 1
+  AVG_4x2x4             m1, m3
+%endif
   psadbw                m1, m0
   psadbw                m3, m0
   paddd                 m6, m1
   paddd                 m7, m3
 %endif
-%if %6 == 1
-  lea                 srcq, [srcq +src_strideq*2]
-  lea                ref1q, [ref1q+ref_strideq*2]
-  lea                ref2q, [ref2q+ref_strideq*2]
-  lea                ref3q, [ref3q+ref_strideq*2]
-  lea                ref4q, [ref4q+ref_strideq*2]
+%if %9 > 0
+  ADVANCE_END_OF_LINE %9
 %endif
 %endmacro
 
-; PROCESS_8x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_8x2x4 5-6 0
-  movh                  m0, [srcq +%2]
+; PROCESS_8x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;               {first,second}_extraline, line_ending
+%macro PROCESS_8x2x4 9
+  HANDLE_FIRST_OFFSET   %7, %2
+  movh                  m0, [srcq + first_offset]
+  HANDLE_SECOND_OFFSET  %7, %8, %4
 %if %1 == 1
   movh                  m4, [ref1q+%3]
   movh                  m5, [ref2q+%3]
   movh                  m6, [ref3q+%3]
   movh                  m7, [ref4q+%3]
-  movhps                m0, [srcq +%4]
+  movhps                m0, [srcq + second_offset]
   movhps                m4, [ref1q+%5]
   movhps                m5, [ref2q+%5]
   movhps                m6, [ref3q+%5]
   movhps                m7, [ref4q+%5]
+%if %6 == 1
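+  ; each xmm register packs two 8-px rows, so a single mmsize load of
+  ; second_predq covers both rows of the second predictor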
+  movu                  m3, [second_predq]
+  pavgb                 m4, m3
+  pavgb                 m5, m3
+  pavgb                 m6, m3
+  pavgb                 m7, m3
+  lea                   second_predq, [second_predq+mmsize]
+%endif
   psadbw                m4, m0
   psadbw                m5, m0
   psadbw                m6, m0
@@ -90,105 +173,152 @@
 %else
   movh                  m1, [ref1q+%3]
   movh                  m2, [ref2q+%3]
-  movh                  m3, [ref3q+%3]
-  movhps                m0, [srcq +%4]
+  movhps                m0, [srcq + second_offset]
   movhps                m1, [ref1q+%5]
   movhps                m2, [ref2q+%5]
-  movhps                m3, [ref3q+%5]
+%if %6 == 1
+  movu                  m3, [second_predq]
+  pavgb                 m1, m3
+  pavgb                 m2, m3
+%endif
   psadbw                m1, m0
   psadbw                m2, m0
-  psadbw                m3, m0
   paddd                 m4, m1
-  movh                  m1, [ref4q+%3]
-  movhps                m1, [ref4q+%5]
   paddd                 m5, m2
-  paddd                 m6, m3
-  psadbw                m1, m0
-  paddd                 m7, m1
-%endif
+
+  movh                  m1, [ref3q+%3]
+  movhps                m1, [ref3q+%5]
+  movh                  m2, [ref4q+%3]
+  movhps                m2, [ref4q+%5]
 %if %6 == 1
-  lea                 srcq, [srcq +src_strideq*2]
-  lea                ref1q, [ref1q+ref_strideq*2]
-  lea                ref2q, [ref2q+ref_strideq*2]
-  lea                ref3q, [ref3q+ref_strideq*2]
-  lea                ref4q, [ref4q+ref_strideq*2]
+  pavgb                 m1, m3
+  pavgb                 m2, m3
+  lea                   second_predq, [second_predq+mmsize]
+%endif
+  psadbw                m1, m0
+  psadbw                m2, m0
+  paddd                 m6, m1
+  paddd                 m7, m2
+%endif
+%if %9 > 0
+  ADVANCE_END_OF_LINE %9
 %endif
 %endmacro
 
-; PROCESS_16x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_16x2x4 5-6 0
+; PROCESS_16x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;                {first,second}_extraline, line_ending
+%macro PROCESS_16x2x4 9
   ; 1st 16 px
-  mova                  m0, [srcq +%2]
+  HANDLE_FIRST_OFFSET   %7, %2
+  mova                  m0, [srcq + first_offset]
+  HANDLE_SECOND_OFFSET  %7, %8, %4
 %if %1 == 1
   movu                  m4, [ref1q+%3]
   movu                  m5, [ref2q+%3]
   movu                  m6, [ref3q+%3]
   movu                  m7, [ref4q+%3]
+%if %6 == 1
+  movu                  m3, [second_predq]
+  pavgb                 m4, m3
+  pavgb                 m5, m3
+  pavgb                 m6, m3
+  pavgb                 m7, m3
+  lea                   second_predq, [second_predq+mmsize]
+%endif
   psadbw                m4, m0
   psadbw                m5, m0
   psadbw                m6, m0
   psadbw                m7, m0
-%else
+%else ; %1 == 1
   movu                  m1, [ref1q+%3]
   movu                  m2, [ref2q+%3]
-  movu                  m3, [ref3q+%3]
+%if %6 == 1
+  movu                  m3, [second_predq]
+  pavgb                 m1, m3
+  pavgb                 m2, m3
+%endif
   psadbw                m1, m0
   psadbw                m2, m0
-  psadbw                m3, m0
   paddd                 m4, m1
-  movu                  m1, [ref4q+%3]
   paddd                 m5, m2
-  paddd                 m6, m3
-  psadbw                m1, m0
-  paddd                 m7, m1
+
+  movu                  m1, [ref3q+%3]
+  movu                  m2, [ref4q+%3]
+%if %6 == 1
+  pavgb                 m1, m3
+  pavgb                 m2, m3
+  lea                   second_predq, [second_predq+mmsize]
 %endif
+  psadbw                m1, m0
+  psadbw                m2, m0
+  paddd                 m6, m1
+  paddd                 m7, m2
+%endif ; %1 == 1
 
   ; 2nd 16 px
-  mova                  m0, [srcq +%4]
+  mova                  m0, [srcq + second_offset]
   movu                  m1, [ref1q+%5]
   movu                  m2, [ref2q+%5]
-  movu                  m3, [ref3q+%5]
-  psadbw                m1, m0
-  psadbw                m2, m0
-  psadbw                m3, m0
-  paddd                 m4, m1
-  movu                  m1, [ref4q+%5]
-  paddd                 m5, m2
-  paddd                 m6, m3
+
 %if %6 == 1
-  lea                 srcq, [srcq +src_strideq*2]
-  lea                ref1q, [ref1q+ref_strideq*2]
-  lea                ref2q, [ref2q+ref_strideq*2]
-  lea                ref3q, [ref3q+ref_strideq*2]
-  lea                ref4q, [ref4q+ref_strideq*2]
+  movu                  m3, [second_predq]
+  pavgb                 m1, m3
+  pavgb                 m2, m3
 %endif
   psadbw                m1, m0
-  paddd                 m7, m1
+  psadbw                m2, m0
+  paddd                 m4, m1
+  paddd                 m5, m2
+
+  movu                  m1, [ref3q+%5]
+  movu                  m2, [ref4q+%5]
+
+%if %9 > 0
+  ADVANCE_END_OF_LINE %9
+%endif
+
+%if %6 == 1
+  pavgb                 m1, m3
+  pavgb                 m2, m3
+  lea                   second_predq, [second_predq+mmsize]
+%endif
+  psadbw                m1, m0
+  psadbw                m2, m0
+  paddd                 m6, m1
+  paddd                 m7, m2
 %endmacro
 
-; PROCESS_32x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_32x2x4 5-6 0
-  PROCESS_16x2x4 %1, %2, %3, %2 + 16, %3 + 16
-  PROCESS_16x2x4  0, %4, %5, %4 + 16, %5 + 16, %6
+; PROCESS_32x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;                {first,second}_extraline, line_ending
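+; The first sub-call covers both halves of the first row; its
+; line_ending (%8 - %7) only bumps srcq, and only when mflag is set
+; and the second row is an extraline. The second sub-call performs
+; the real end-of-line advance (%9).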
+%macro PROCESS_32x2x4 9
+  PROCESS_16x2x4 %1, %2, %3, %2 + 16, %3 + 16, %6, %7, %7, %8 - %7
+  PROCESS_16x2x4  0, %4, %5, %4 + 16, %5 + 16, %6, %8, %8, %9
 %endmacro
 
-; PROCESS_64x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_64x2x4 5-6 0
-  PROCESS_32x2x4 %1, %2, %3, %2 + 32, %3 + 32
-  PROCESS_32x2x4  0, %4, %5, %4 + 32, %5 + 32, %6
+; PROCESS_64x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;                {first,second}_extraline, line_ending
+%macro PROCESS_64x2x4 9
+  PROCESS_32x2x4 %1, %2, %3, %2 + 32, %3 + 32, %6, %7, %7, %8 - %7
+  PROCESS_32x2x4  0, %4, %5, %4 + 32, %5 + 32, %6, %8, %8, %9
 %endmacro
 
-; PROCESS_128x2x4 first, off_{first,second}_{src,ref}, advance_at_end
-%macro PROCESS_128x2x4 5-6 0
-  PROCESS_64x2x4 %1, %2, %3, %2 + 64, %3 + 64
-  PROCESS_64x2x4  0, %4, %5, %4 + 64, %5 + 64, %6
+; PROCESS_128x2x4 first, off_{first,second}_{src,ref}, do_avg,
+;                 {first,second}_extraline, line_ending
+%macro PROCESS_128x2x4 9
+  PROCESS_64x2x4 %1, %2, %3, %2 + 64, %3 + 64, %6, %7, %7, %8 - %7
+  PROCESS_64x2x4  0, %4, %5, %4 + 64, %5 + 64, %6, %8, %8, %9
 %endmacro
 
 ; void aom_sadNxNx4d_sse2(uint8_t *src,    int src_stride,
 ;                         uint8_t *ref[4], int ref_stride,
 ;                         uint32_t res[4]);
 ; where NxN = 64x64, 32x32, 16x16, 16x8, 8x16, 8x8, 8x4, 4x8 and 4x4
-%macro SADNXN4D 2
+%macro SADNXN4D 2-3 0
+%if %3 == 0
 %if UNIX64
 cglobal sad%1x%2x4d, 5, 8, 8, src, src_stride, ref1, ref_stride, \
                               res, ref2, ref3, ref4
@@ -196,18 +322,43 @@
 cglobal sad%1x%2x4d, 4, 7, 8, src, src_stride, ref1, ref_stride, \
                               ref2, ref3, ref4
 %endif
+%else ; avg
+
+%if UNIX64
+cglobal sad%1x%2x4d_avg, 6, 10, 8, src, src_stride, ref1, ref_stride, \
+                                   second_pred, res, ref2, ref3, ref4
+%else
+cglobal sad%1x%2x4d_avg, 5, 7, 8, src, ref4, ref1, ref_stride, \
+                                  second_pred, ref2, ref3
+  %define src_strideq r1mp
+  %define src_strided r1mp
+%endif
+%endif
+
+  %define mflag ((1 - UNIX64) & %3)
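+  ; mflag is 1 only for the avg variant on non-UNIX64 ABIs, where
+  ; src_strideq had to be redefined above as the memory operand r1mp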
   movsxdifnidn src_strideq, src_strided
   movsxdifnidn ref_strideq, ref_strided
+
   mov                ref2q, [ref1q+gprsize*1]
   mov                ref3q, [ref1q+gprsize*2]
   mov                ref4q, [ref1q+gprsize*3]
   mov                ref1q, [ref1q+gprsize*0]
 
-  PROCESS_%1x2x4 1, 0, 0, src_strideq, ref_strideq, 1
+  PROCESS_%1x2x4 1, 0, 0, 0, ref_strideq, %3, 0, 1, 2
 %rep (%2-4)/2
-  PROCESS_%1x2x4 0, 0, 0, src_strideq, ref_strideq, 1
+  PROCESS_%1x2x4 0, 0, 0, 0, ref_strideq, %3, 0, 1, 2
 %endrep
-  PROCESS_%1x2x4 0, 0, 0, src_strideq, ref_strideq, 0
+  PROCESS_%1x2x4 0, 0, 0, 0, ref_strideq, %3, 0, 1, 2
+
+%if %3 == 0
+  %define resultq r4
+  %define resultmp r4mp
+%else
+  %define resultq r5
+  %define resultmp r5mp
+%endif
 
 %if %1 > 4
   pslldq                m5, 4
@@ -218,16 +367,16 @@
   mova                  m7, m6
   punpcklqdq            m4, m6
   punpckhqdq            m5, m7
-  movifnidn             r4, r4mp
   paddd                 m4, m5
-  movu                [r4], m4
+  movifnidn             resultq, resultmp
+  movu                [resultq], m4
   RET
 %else
-  movifnidn             r4, r4mp
   pshufd            m6, m6, 0x08
   pshufd            m7, m7, 0x08
-  movq              [r4+0], m6
-  movq              [r4+8], m7
+  movifnidn             resultq, resultmp
+  movq              [resultq+0], m6
+  movq              [resultq+8], m7
   RET
 %endif
 %endmacro
@@ -255,3 +404,25 @@
 SADNXN4D 32,  8
 SADNXN4D 16, 64
 SADNXN4D 64, 16
+SADNXN4D 128, 128, 1
+SADNXN4D 128, 64, 1
+SADNXN4D 64, 128, 1
+SADNXN4D 64, 64, 1
+SADNXN4D 64, 32, 1
+SADNXN4D 32, 64, 1
+SADNXN4D 32, 32, 1
+SADNXN4D 32, 16, 1
+SADNXN4D 16, 32, 1
+SADNXN4D 16, 16, 1
+SADNXN4D 16,  8, 1
+SADNXN4D  8, 16, 1
+SADNXN4D  8,  8, 1
+SADNXN4D  8,  4, 1
+SADNXN4D  4,  8, 1
+SADNXN4D  4,  4, 1
+SADNXN4D  4, 16, 1
+SADNXN4D 16,  4, 1
+SADNXN4D  8, 32, 1
+SADNXN4D 32,  8, 1
+SADNXN4D 16, 64, 1
+SADNXN4D 64, 16, 1
diff --git a/test/sad_test.cc b/test/sad_test.cc
index 70dc969..0bdbf37 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -60,6 +60,12 @@
                              uint32_t *sad_array);
 typedef std::tuple<int, int, SadMxNx4Func, int> SadMxNx4Param;
 
+typedef void (*SadMxNx4AvgFunc)(const uint8_t *src_ptr, int src_stride,
+                                const uint8_t *const ref_ptr[], int ref_stride,
+                                const uint8_t *second_pred,
+                                uint32_t *sad_array);
+typedef std::tuple<int, int, SadMxNx4AvgFunc, int> SadMxNx4AvgParam;
+
 using libaom_test::ACMRandom;
 
 namespace {
@@ -339,6 +345,42 @@
   }
 };
 
+class SADx4AvgTest : public ::testing::WithParamInterface<SadMxNx4AvgParam>,
+                     public SADTestBase {
+ public:
+  SADx4AvgTest() : SADTestBase(GET_PARAM(0), GET_PARAM(1), GET_PARAM(3)) {}
+
+ protected:
+  void SADs(unsigned int *results) {
+    const uint8_t *references[] = { GetReference(0), GetReference(1),
+                                    GetReference(2), GetReference(3) };
+
+    ASM_REGISTER_STATE_CHECK(GET_PARAM(2)(source_data_, source_stride_,
+                                          references, reference_stride_,
+                                          second_pred_, results));
+  }
+
+  void CheckSADs() {
+    unsigned int reference_sad, exp_sad[4];
+
+    SADs(exp_sad);
+    for (int block = 0; block < 4; ++block) {
+      reference_sad = ReferenceSADavg(block);
+
+      EXPECT_EQ(reference_sad, exp_sad[block]) << "block " << block;
+    }
+  }
+
+  void SpeedSAD() {
+    int test_count = 200000;
+    unsigned int exp_sad[4];
+    while (test_count > 0) {
+      SADs(exp_sad);
+      test_count -= 1;
+    }
+  }
+};
+
 class SADTest : public ::testing::WithParamInterface<SadMxNParam>,
                 public SADTestBase {
  public:
@@ -814,6 +856,69 @@
 
 using std::make_tuple;
 
+#if SPEED_TEST
+TEST_P(SADx4AvgTest, Speed) {
+  int tmp_stride = reference_stride_;
+  reference_stride_ >>= 1;
+  FillRandom(source_data_, source_stride_);
+  FillRandom(GetReference(0), reference_stride_);
+  FillRandom(GetReference(1), reference_stride_);
+  FillRandom(GetReference(2), reference_stride_);
+  FillRandom(GetReference(3), reference_stride_);
+  FillRandom(second_pred_, width_);
+  SpeedSAD();
+  reference_stride_ = tmp_stride;
+}
+#endif
+
+TEST_P(SADx4AvgTest, MaxRef) {
+  FillConstant(source_data_, source_stride_, 0);
+  FillConstant(GetReference(0), reference_stride_, mask_);
+  FillConstant(GetReference(1), reference_stride_, mask_);
+  FillConstant(GetReference(2), reference_stride_, mask_);
+  FillConstant(GetReference(3), reference_stride_, mask_);
+  FillConstant(second_pred_, width_, 0);
+  CheckSADs();
+}
+
+TEST_P(SADx4AvgTest, MaxSrc) {
+  FillConstant(source_data_, source_stride_, mask_);
+  FillConstant(GetReference(0), reference_stride_, 0);
+  FillConstant(GetReference(1), reference_stride_, 0);
+  FillConstant(GetReference(2), reference_stride_, 0);
+  FillConstant(GetReference(3), reference_stride_, 0);
+  FillConstant(second_pred_, width_, 0);
+  CheckSADs();
+}
+
+TEST_P(SADx4AvgTest, ShortRef) {
+  int tmp_stride = reference_stride_;
+  reference_stride_ >>= 1;
+  FillRandom(source_data_, source_stride_);
+  FillRandom(GetReference(0), reference_stride_);
+  FillRandom(GetReference(1), reference_stride_);
+  FillRandom(GetReference(2), reference_stride_);
+  FillRandom(GetReference(3), reference_stride_);
+  FillRandom(second_pred_, width_);
+  CheckSADs();
+  reference_stride_ = tmp_stride;
+}
+
+TEST_P(SADx4AvgTest, UnalignedRef) {
+  // The reference frame, but not the source frame, may be unaligned for
+  // certain types of searches.
+  int tmp_stride = reference_stride_;
+  reference_stride_ -= 1;
+  FillRandom(source_data_, source_stride_);
+  FillRandom(GetReference(0), reference_stride_);
+  FillRandom(GetReference(1), reference_stride_);
+  FillRandom(GetReference(2), reference_stride_);
+  FillRandom(GetReference(3), reference_stride_);
+  FillRandom(second_pred_, width_);
+  CheckSADs();
+  reference_stride_ = tmp_stride;
+}
+
 //------------------------------------------------------------------------------
 // C functions
 const SadMxNParam c_tests[] = {
@@ -1175,6 +1280,33 @@
 };
 INSTANTIATE_TEST_SUITE_P(C, SADx4Test, ::testing::ValuesIn(x4d_c_tests));
 
+const SadMxNx4AvgParam x4d_avg_c_tests[] = {
+  make_tuple(128, 128, &aom_sad128x128x4d_avg_c, -1),
+  make_tuple(128, 64, &aom_sad128x64x4d_avg_c, -1),
+  make_tuple(64, 128, &aom_sad64x128x4d_avg_c, -1),
+  make_tuple(64, 64, &aom_sad64x64x4d_avg_c, -1),
+  make_tuple(64, 32, &aom_sad64x32x4d_avg_c, -1),
+  make_tuple(32, 64, &aom_sad32x64x4d_avg_c, -1),
+  make_tuple(32, 32, &aom_sad32x32x4d_avg_c, -1),
+  make_tuple(32, 16, &aom_sad32x16x4d_avg_c, -1),
+  make_tuple(16, 32, &aom_sad16x32x4d_avg_c, -1),
+  make_tuple(16, 16, &aom_sad16x16x4d_avg_c, -1),
+  make_tuple(16, 8, &aom_sad16x8x4d_avg_c, -1),
+  make_tuple(8, 16, &aom_sad8x16x4d_avg_c, -1),
+  make_tuple(8, 8, &aom_sad8x8x4d_avg_c, -1),
+  make_tuple(8, 4, &aom_sad8x4x4d_avg_c, -1),
+  make_tuple(4, 8, &aom_sad4x8x4d_avg_c, -1),
+  make_tuple(4, 4, &aom_sad4x4x4d_avg_c, -1),
+  make_tuple(64, 16, &aom_sad64x16x4d_avg_c, -1),
+  make_tuple(16, 64, &aom_sad16x64x4d_avg_c, -1),
+  make_tuple(32, 8, &aom_sad32x8x4d_avg_c, -1),
+  make_tuple(8, 32, &aom_sad8x32x4d_avg_c, -1),
+  make_tuple(16, 4, &aom_sad16x4x4d_avg_c, -1),
+  make_tuple(4, 16, &aom_sad4x16x4d_avg_c, -1),
+};
+INSTANTIATE_TEST_SUITE_P(C, SADx4AvgTest,
+                         ::testing::ValuesIn(x4d_avg_c_tests));
+
 //------------------------------------------------------------------------------
 // ARM functions
 #if HAVE_NEON
@@ -1472,6 +1603,33 @@
 #endif
 };
 INSTANTIATE_TEST_SUITE_P(SSE2, SADx4Test, ::testing::ValuesIn(x4d_sse2_tests));
+
+const SadMxNx4AvgParam x4d_avg_sse2_tests[] = {
+  make_tuple(128, 128, &aom_sad128x128x4d_avg_sse2, -1),
+  make_tuple(128, 64, &aom_sad128x64x4d_avg_sse2, -1),
+  make_tuple(64, 128, &aom_sad64x128x4d_avg_sse2, -1),
+  make_tuple(64, 64, &aom_sad64x64x4d_avg_sse2, -1),
+  make_tuple(64, 32, &aom_sad64x32x4d_avg_sse2, -1),
+  make_tuple(32, 64, &aom_sad32x64x4d_avg_sse2, -1),
+  make_tuple(32, 32, &aom_sad32x32x4d_avg_sse2, -1),
+  make_tuple(32, 16, &aom_sad32x16x4d_avg_sse2, -1),
+  make_tuple(16, 32, &aom_sad16x32x4d_avg_sse2, -1),
+  make_tuple(16, 16, &aom_sad16x16x4d_avg_sse2, -1),
+  make_tuple(16, 8, &aom_sad16x8x4d_avg_sse2, -1),
+  make_tuple(8, 16, &aom_sad8x16x4d_avg_sse2, -1),
+  make_tuple(8, 8, &aom_sad8x8x4d_avg_sse2, -1),
+  make_tuple(8, 4, &aom_sad8x4x4d_avg_sse2, -1),
+  make_tuple(4, 8, &aom_sad4x8x4d_avg_sse2, -1),
+  make_tuple(4, 4, &aom_sad4x4x4d_avg_sse2, -1),
+  make_tuple(64, 16, &aom_sad64x16x4d_avg_sse2, -1),
+  make_tuple(16, 64, &aom_sad16x64x4d_avg_sse2, -1),
+  make_tuple(32, 8, &aom_sad32x8x4d_avg_sse2, -1),
+  make_tuple(8, 32, &aom_sad8x32x4d_avg_sse2, -1),
+  make_tuple(16, 4, &aom_sad16x4x4d_avg_sse2, -1),
+  make_tuple(4, 16, &aom_sad4x16x4d_avg_sse2, -1),
+};
+INSTANTIATE_TEST_SUITE_P(SSE2, SADx4AvgTest,
+                         ::testing::ValuesIn(x4d_avg_sse2_tests));
 #endif  // HAVE_SSE2
 
 #if HAVE_SSSE3