sadx4d highbd avx2 code refactoring

This change facilitates further sad4d code improvements.
No performance change is expected.

Change-Id: Ifeb0b3ce9fbdd5ac7a8ef252fe23bb3bb5905f5f
diff --git a/aom_dsp/x86/sad_highbd_avx2.c b/aom_dsp/x86/sad_highbd_avx2.c
index eba442c..2cff2e6 100644
--- a/aom_dsp/x86/sad_highbd_avx2.c
+++ b/aom_dsp/x86/sad_highbd_avx2.c
@@ -37,532 +37,247 @@
   return (unsigned int)_mm_cvtsi128_si32(lo128);
 }
 
-unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride) {
-  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
-  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
-
-  // first 4 rows
-  __m256i s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
-  __m256i s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-  __m256i s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
-  __m256i s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
-
-  __m256i r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
-  __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-  __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
-  __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
-
-  __m256i u0 = _mm256_sub_epi16(s0, r0);
-  __m256i u1 = _mm256_sub_epi16(s1, r1);
-  __m256i u2 = _mm256_sub_epi16(s2, r2);
-  __m256i u3 = _mm256_sub_epi16(s3, r3);
-  __m256i zero = _mm256_setzero_si256();
-  __m256i sum0, sum1;
-
-  u0 = _mm256_abs_epi16(u0);
-  u1 = _mm256_abs_epi16(u1);
-  u2 = _mm256_abs_epi16(u2);
-  u3 = _mm256_abs_epi16(u3);
-
-  sum0 = _mm256_add_epi16(u0, u1);
-  sum0 = _mm256_add_epi16(sum0, u2);
-  sum0 = _mm256_add_epi16(sum0, u3);
-
-  // second 4 rows
-  src_ptr += src_stride << 2;
-  ref_ptr += ref_stride << 2;
-  s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
-  s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-  s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
-  s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
-
-  r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
-  r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-  r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
-  r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
-
-  u0 = _mm256_sub_epi16(s0, r0);
-  u1 = _mm256_sub_epi16(s1, r1);
-  u2 = _mm256_sub_epi16(s2, r2);
-  u3 = _mm256_sub_epi16(s3, r3);
-
-  u0 = _mm256_abs_epi16(u0);
-  u1 = _mm256_abs_epi16(u1);
-  u2 = _mm256_abs_epi16(u2);
-  u3 = _mm256_abs_epi16(u3);
-
-  sum1 = _mm256_add_epi16(u0, u1);
-  sum1 = _mm256_add_epi16(sum1, u2);
-  sum1 = _mm256_add_epi16(sum1, u3);
-
-  // find out the SAD
-  s0 = _mm256_unpacklo_epi16(sum0, zero);
-  s1 = _mm256_unpackhi_epi16(sum0, zero);
-  r0 = _mm256_unpacklo_epi16(sum1, zero);
-  r1 = _mm256_unpackhi_epi16(sum1, zero);
-  s0 = _mm256_add_epi32(s0, s1);
-  r0 = _mm256_add_epi32(r0, r1);
-  sum0 = _mm256_add_epi32(s0, r0);
-  // 8 32-bit summation
-
-  return (unsigned int)get_sad_from_mm256_epi32(&sum0);
-}
-
-unsigned int aom_highbd_sad16x16_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
-  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
-  __m256i s0, s1, s2, s3, r0, r1, r2, r3, u0, u1, u2, u3;
-  __m256i sum0;
-  __m256i sum = _mm256_setzero_si256();
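+// Core helper shared by all block sizes: takes four pairs of 16-pixel
+// source/reference vectors in s[]/r[], forms the absolute 16-bit
+// differences, widens the row sums to 32 bits and accumulates them into
+// *sad_acc.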
+static INLINE void highbd_sad16x4_core_avx2(__m256i *s, __m256i *r,
+                                            __m256i *sad_acc) {
   const __m256i zero = _mm256_setzero_si256();
-  int row = 0;
-
-  // Loop for every 4 rows
-  while (row < 16) {
-    s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
-    s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-    s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
-    s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
-
-    r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
-    r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-    r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
-    r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
-
-    u0 = _mm256_sub_epi16(s0, r0);
-    u1 = _mm256_sub_epi16(s1, r1);
-    u2 = _mm256_sub_epi16(s2, r2);
-    u3 = _mm256_sub_epi16(s3, r3);
-
-    u0 = _mm256_abs_epi16(u0);
-    u1 = _mm256_abs_epi16(u1);
-    u2 = _mm256_abs_epi16(u2);
-    u3 = _mm256_abs_epi16(u3);
-
-    sum0 = _mm256_add_epi16(u0, u1);
-    sum0 = _mm256_add_epi16(sum0, u2);
-    sum0 = _mm256_add_epi16(sum0, u3);
-
-    s0 = _mm256_unpacklo_epi16(sum0, zero);
-    s1 = _mm256_unpackhi_epi16(sum0, zero);
-    sum = _mm256_add_epi32(sum, s0);
-    sum = _mm256_add_epi32(sum, s1);
-    // 8 32-bit summation
-
-    row += 4;
-    src_ptr += src_stride << 2;
-    ref_ptr += ref_stride << 2;
+  int i;
+  for (i = 0; i < 4; i++) {
+    s[i] = _mm256_sub_epi16(s[i], r[i]);
+    s[i] = _mm256_abs_epi16(s[i]);
   }
-  return get_sad_from_mm256_epi32(&sum);
-}
-
-static void sad32x4(const uint16_t *src_ptr, int src_stride,
-                    const uint16_t *ref_ptr, int ref_stride,
-                    const uint16_t *sec_ptr, __m256i *sad_acc) {
-  __m256i s0, s1, s2, s3, r0, r1, r2, r3;
-  const __m256i zero = _mm256_setzero_si256();
-  int row_sections = 0;
-
-  while (row_sections < 2) {
-    s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
-    s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
-    s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-    s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));
-
-    r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
-    r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
-    r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-    r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));
-
-    if (sec_ptr) {
-      r0 = _mm256_avg_epu16(r0, _mm256_loadu_si256((const __m256i *)sec_ptr));
-      r1 = _mm256_avg_epu16(
-          r1, _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
-      r2 = _mm256_avg_epu16(
-          r2, _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
-      r3 = _mm256_avg_epu16(
-          r3, _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
-    }
-    s0 = _mm256_sub_epi16(s0, r0);
-    s1 = _mm256_sub_epi16(s1, r1);
-    s2 = _mm256_sub_epi16(s2, r2);
-    s3 = _mm256_sub_epi16(s3, r3);
-
-    s0 = _mm256_abs_epi16(s0);
-    s1 = _mm256_abs_epi16(s1);
-    s2 = _mm256_abs_epi16(s2);
-    s3 = _mm256_abs_epi16(s3);
-
-    s0 = _mm256_add_epi16(s0, s1);
-    s0 = _mm256_add_epi16(s0, s2);
-    s0 = _mm256_add_epi16(s0, s3);
-
-    r0 = _mm256_unpacklo_epi16(s0, zero);
-    r1 = _mm256_unpackhi_epi16(s0, zero);
-
-    r0 = _mm256_add_epi32(r0, r1);
-    *sad_acc = _mm256_add_epi32(*sad_acc, r0);
-
-    row_sections += 1;
-    src_ptr += src_stride << 1;
-    ref_ptr += ref_stride << 1;
-    if (sec_ptr) sec_ptr += 32 << 1;
-  }
-}
-
-unsigned int aom_highbd_sad32x16_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  __m256i sad = _mm256_setzero_si256();
-  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
-  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  const int left_shift = 2;
-  int row_section = 0;
-
-  while (row_section < 4) {
-    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
-    srcp += src_stride << left_shift;
-    refp += ref_stride << left_shift;
-    row_section += 1;
-  }
-  return get_sad_from_mm256_epi32(&sad);
-}
-
-unsigned int aom_highbd_sad32x8_avx2(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride) {
-  __m256i sad = _mm256_setzero_si256();
-  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
-  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  const int left_shift = 2;
-  int row_section = 0;
-
-  while (row_section < 2) {
-    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
-    srcp += src_stride << left_shift;
-    refp += ref_stride << left_shift;
-    row_section += 1;
-  }
-  return get_sad_from_mm256_epi32(&sad);
-}
-
-unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 4;
-  ref += ref_stride << 4;
-  sum += aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
-unsigned int aom_highbd_sad32x32_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad32x16_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 4;
-  ref += ref_stride << 4;
-  sum += aom_highbd_sad32x16_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
-unsigned int aom_highbd_sad32x64_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad32x32_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 5;
-  ref += ref_stride << 5;
-  sum += aom_highbd_sad32x32_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
-static void sad64x2(const uint16_t *src_ptr, int src_stride,
-                    const uint16_t *ref_ptr, int ref_stride,
-                    const uint16_t *sec_ptr, __m256i *sad_acc) {
-  __m256i s[8], r[8];
-  const __m256i zero = _mm256_setzero_si256();
-
-  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
-  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
-  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
-  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
-  s[4] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-  s[5] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));
-  s[6] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 32));
-  s[7] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 48));
-
-  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
-  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
-  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
-  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
-  r[4] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-  r[5] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));
-  r[6] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 32));
-  r[7] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 48));
-
-  if (sec_ptr) {
-    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
-    r[1] = _mm256_avg_epu16(
-        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
-    r[2] = _mm256_avg_epu16(
-        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
-    r[3] = _mm256_avg_epu16(
-        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
-    r[4] = _mm256_avg_epu16(
-        r[4], _mm256_loadu_si256((const __m256i *)(sec_ptr + 64)));
-    r[5] = _mm256_avg_epu16(
-        r[5], _mm256_loadu_si256((const __m256i *)(sec_ptr + 80)));
-    r[6] = _mm256_avg_epu16(
-        r[6], _mm256_loadu_si256((const __m256i *)(sec_ptr + 96)));
-    r[7] = _mm256_avg_epu16(
-        r[7], _mm256_loadu_si256((const __m256i *)(sec_ptr + 112)));
-  }
-
-  s[0] = _mm256_sub_epi16(s[0], r[0]);
-  s[1] = _mm256_sub_epi16(s[1], r[1]);
-  s[2] = _mm256_sub_epi16(s[2], r[2]);
-  s[3] = _mm256_sub_epi16(s[3], r[3]);
-  s[4] = _mm256_sub_epi16(s[4], r[4]);
-  s[5] = _mm256_sub_epi16(s[5], r[5]);
-  s[6] = _mm256_sub_epi16(s[6], r[6]);
-  s[7] = _mm256_sub_epi16(s[7], r[7]);
-
-  s[0] = _mm256_abs_epi16(s[0]);
-  s[1] = _mm256_abs_epi16(s[1]);
-  s[2] = _mm256_abs_epi16(s[2]);
-  s[3] = _mm256_abs_epi16(s[3]);
-  s[4] = _mm256_abs_epi16(s[4]);
-  s[5] = _mm256_abs_epi16(s[5]);
-  s[6] = _mm256_abs_epi16(s[6]);
-  s[7] = _mm256_abs_epi16(s[7]);
 
   s[0] = _mm256_add_epi16(s[0], s[1]);
   s[0] = _mm256_add_epi16(s[0], s[2]);
   s[0] = _mm256_add_epi16(s[0], s[3]);
 
-  s[4] = _mm256_add_epi16(s[4], s[5]);
-  s[4] = _mm256_add_epi16(s[4], s[6]);
-  s[4] = _mm256_add_epi16(s[4], s[7]);
-
   r[0] = _mm256_unpacklo_epi16(s[0], zero);
   r[1] = _mm256_unpackhi_epi16(s[0], zero);
-  r[2] = _mm256_unpacklo_epi16(s[4], zero);
-  r[3] = _mm256_unpackhi_epi16(s[4], zero);
 
   r[0] = _mm256_add_epi32(r[0], r[1]);
-  r[0] = _mm256_add_epi32(r[0], r[2]);
-  r[0] = _mm256_add_epi32(r[0], r[3]);
   *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
 }
 
-unsigned int aom_highbd_sad64x32_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  __m256i sad = _mm256_setzero_si256();
-  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
-  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  const int left_shift = 1;
-  int row_section = 0;
-
-  while (row_section < 16) {
-    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
-    srcp += src_stride << left_shift;
-    refp += ref_stride << left_shift;
-    row_section += 1;
-  }
-  return get_sad_from_mm256_epi32(&sad);
-}
-
-unsigned int aom_highbd_sad64x16_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  __m256i sad = _mm256_setzero_si256();
-  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
-  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  const int left_shift = 1;
-  int row_section = 0;
-
-  while (row_section < 8) {
-    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
-    srcp += src_stride << left_shift;
-    refp += ref_stride << left_shift;
-    row_section += 1;
-  }
-  return get_sad_from_mm256_epi32(&sad);
-}
-
-unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 5;
-  ref += ref_stride << 5;
-  sum += aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
-static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
-                     const uint16_t *sec_ptr, __m256i *sad_acc) {
-  __m256i s[8], r[8];
-  const __m256i zero = _mm256_setzero_si256();
-
-  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
-  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
-  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
-  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
-  s[4] = _mm256_loadu_si256((const __m256i *)(src_ptr + 64));
-  s[5] = _mm256_loadu_si256((const __m256i *)(src_ptr + 80));
-  s[6] = _mm256_loadu_si256((const __m256i *)(src_ptr + 96));
-  s[7] = _mm256_loadu_si256((const __m256i *)(src_ptr + 112));
-
-  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
-  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
-  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
-  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
-  r[4] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 64));
-  r[5] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 80));
-  r[6] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 96));
-  r[7] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 112));
-
-  if (sec_ptr) {
-    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
-    r[1] = _mm256_avg_epu16(
-        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
-    r[2] = _mm256_avg_epu16(
-        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
-    r[3] = _mm256_avg_epu16(
-        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
-    r[4] = _mm256_avg_epu16(
-        r[4], _mm256_loadu_si256((const __m256i *)(sec_ptr + 64)));
-    r[5] = _mm256_avg_epu16(
-        r[5], _mm256_loadu_si256((const __m256i *)(sec_ptr + 80)));
-    r[6] = _mm256_avg_epu16(
-        r[6], _mm256_loadu_si256((const __m256i *)(sec_ptr + 96)));
-    r[7] = _mm256_avg_epu16(
-        r[7], _mm256_loadu_si256((const __m256i *)(sec_ptr + 112)));
-  }
-
-  s[0] = _mm256_sub_epi16(s[0], r[0]);
-  s[1] = _mm256_sub_epi16(s[1], r[1]);
-  s[2] = _mm256_sub_epi16(s[2], r[2]);
-  s[3] = _mm256_sub_epi16(s[3], r[3]);
-  s[4] = _mm256_sub_epi16(s[4], r[4]);
-  s[5] = _mm256_sub_epi16(s[5], r[5]);
-  s[6] = _mm256_sub_epi16(s[6], r[6]);
-  s[7] = _mm256_sub_epi16(s[7], r[7]);
-
-  s[0] = _mm256_abs_epi16(s[0]);
-  s[1] = _mm256_abs_epi16(s[1]);
-  s[2] = _mm256_abs_epi16(s[2]);
-  s[3] = _mm256_abs_epi16(s[3]);
-  s[4] = _mm256_abs_epi16(s[4]);
-  s[5] = _mm256_abs_epi16(s[5]);
-  s[6] = _mm256_abs_epi16(s[6]);
-  s[7] = _mm256_abs_epi16(s[7]);
-
-  s[0] = _mm256_add_epi16(s[0], s[1]);
-  s[0] = _mm256_add_epi16(s[0], s[2]);
-  s[0] = _mm256_add_epi16(s[0], s[3]);
-
-  s[4] = _mm256_add_epi16(s[4], s[5]);
-  s[4] = _mm256_add_epi16(s[4], s[6]);
-  s[4] = _mm256_add_epi16(s[4], s[7]);
-
-  r[0] = _mm256_unpacklo_epi16(s[0], zero);
-  r[1] = _mm256_unpackhi_epi16(s[0], zero);
-  r[2] = _mm256_unpacklo_epi16(s[4], zero);
-  r[3] = _mm256_unpackhi_epi16(s[4], zero);
-
-  r[0] = _mm256_add_epi32(r[0], r[1]);
-  r[0] = _mm256_add_epi32(r[0], r[2]);
-  r[0] = _mm256_add_epi32(r[0], r[3]);
-  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
-}
-
-unsigned int aom_highbd_sad128x64_avx2(const uint8_t *src, int src_stride,
-                                       const uint8_t *ref, int ref_stride) {
-  __m256i sad = _mm256_setzero_si256();
-  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
-  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  int row = 0;
-  while (row < 64) {
-    sad128x1(srcp, refp, NULL, &sad);
-    srcp += src_stride;
-    refp += ref_stride;
-    row += 1;
-  }
-  return get_sad_from_mm256_epi32(&sad);
-}
-
-unsigned int aom_highbd_sad64x128_avx2(const uint8_t *src, int src_stride,
-                                       const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad64x64_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 6;
-  ref += ref_stride << 6;
-  sum += aom_highbd_sad64x64_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
-unsigned int aom_highbd_sad128x128_avx2(const uint8_t *src, int src_stride,
-                                        const uint8_t *ref, int ref_stride) {
-  uint32_t sum = aom_highbd_sad128x64_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << 6;
-  ref += ref_stride << 6;
-  sum += aom_highbd_sad128x64_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
 // If sec_ptr = 0, calculate regular SAD. Otherwise, calculate average SAD.
 static INLINE void sad16x4(const uint16_t *src_ptr, int src_stride,
                            const uint16_t *ref_ptr, int ref_stride,
                            const uint16_t *sec_ptr, __m256i *sad_acc) {
-  __m256i s0, s1, s2, s3, r0, r1, r2, r3;
-  const __m256i zero = _mm256_setzero_si256();
+  __m256i s[4], r[4];
+  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
+  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
+  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
+  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
 
-  s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
-  s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
-  s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
-  s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
-
-  r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
-  r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
-  r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
-  r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
+  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
+  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
+  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
+  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
 
   if (sec_ptr) {
-    r0 = _mm256_avg_epu16(r0, _mm256_loadu_si256((const __m256i *)sec_ptr));
-    r1 = _mm256_avg_epu16(r1,
-                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
-    r2 = _mm256_avg_epu16(r2,
-                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
-    r3 = _mm256_avg_epu16(r3,
-                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
+    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
+    r[1] = _mm256_avg_epu16(
+        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
+    r[2] = _mm256_avg_epu16(
+        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
+    r[3] = _mm256_avg_epu16(
+        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
   }
-
-  s0 = _mm256_sub_epi16(s0, r0);
-  s1 = _mm256_sub_epi16(s1, r1);
-  s2 = _mm256_sub_epi16(s2, r2);
-  s3 = _mm256_sub_epi16(s3, r3);
-
-  s0 = _mm256_abs_epi16(s0);
-  s1 = _mm256_abs_epi16(s1);
-  s2 = _mm256_abs_epi16(s2);
-  s3 = _mm256_abs_epi16(s3);
-
-  s0 = _mm256_add_epi16(s0, s1);
-  s0 = _mm256_add_epi16(s0, s2);
-  s0 = _mm256_add_epi16(s0, s3);
-
-  r0 = _mm256_unpacklo_epi16(s0, zero);
-  r1 = _mm256_unpackhi_epi16(s0, zero);
-
-  r0 = _mm256_add_epi32(r0, r1);
-  *sad_acc = _mm256_add_epi32(*sad_acc, r0);
+  highbd_sad16x4_core_avx2(s, r, sad_acc);
 }
 
-unsigned int aom_highbd_sad16x4_avx2(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride) {
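+// Computes the SAD of a 16-wide, N-row block, four rows per iteration.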
+static AOM_FORCE_INLINE unsigned int aom_highbd_sad16xN_avx2(int N,
+                                                             const uint8_t *src,
+                                                             int src_stride,
+                                                             const uint8_t *ref,
+                                                             int ref_stride) {
+  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
+  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
+  int i;
+  __m256i sad = _mm256_setzero_si256();
+  for (i = 0; i < N; i += 4) {
+    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
+    src_ptr += src_stride << 2;
+    ref_ptr += ref_stride << 2;
+  }
+  return (unsigned int)get_sad_from_mm256_epi32(&sad);
+}
+
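+// Accumulates the SAD of four rows of a 32-wide block, processed as two
+// 2-row sections. A non-NULL sec_ptr averages the reference with the
+// second predictor first (average SAD).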
+static void sad32x4(const uint16_t *src_ptr, int src_stride,
+                    const uint16_t *ref_ptr, int ref_stride,
+                    const uint16_t *sec_ptr, __m256i *sad_acc) {
+  __m256i s[4], r[4];
+  int row_sections = 0;
+
+  while (row_sections < 2) {
+    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
+    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
+    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
+    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));
+
+    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
+    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
+    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
+    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));
+
+    if (sec_ptr) {
+      r[0] =
+          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
+      r[1] = _mm256_avg_epu16(
+          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
+      r[2] = _mm256_avg_epu16(
+          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
+      r[3] = _mm256_avg_epu16(
+          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
+      sec_ptr += 32 << 1;
+    }
+    highbd_sad16x4_core_avx2(s, r, sad_acc);
+
+    row_sections += 1;
+    src_ptr += src_stride << 1;
+    ref_ptr += ref_stride << 1;
+  }
+}
+
+static AOM_FORCE_INLINE unsigned int aom_highbd_sad32xN_avx2(int N,
+                                                             const uint8_t *src,
+                                                             int src_stride,
+                                                             const uint8_t *ref,
+                                                             int ref_stride) {
   __m256i sad = _mm256_setzero_si256();
   uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
   uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
-  sad16x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+  const int left_shift = 2;
+  int i;
+
+  for (i = 0; i < N; i += 4) {
+    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+  }
   return get_sad_from_mm256_epi32(&sad);
 }
 
+static void sad64x2(const uint16_t *src_ptr, int src_stride,
+                    const uint16_t *ref_ptr, int ref_stride,
+                    const uint16_t *sec_ptr, __m256i *sad_acc) {
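+  // Accumulates the SAD of two 64-pixel rows, one row per iteration. A
+  // non-NULL sec_ptr averages the reference with the second predictor
+  // first (average SAD).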
+  __m256i s[4], r[4];
+  int i;
+  for (i = 0; i < 2; i++) {
+    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
+    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
+    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
+    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
+
+    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
+    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
+    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
+    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
+    if (sec_ptr) {
+      r[0] =
+          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
+      r[1] = _mm256_avg_epu16(
+          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
+      r[2] = _mm256_avg_epu16(
+          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
+      r[3] = _mm256_avg_epu16(
+          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
+      sec_ptr += 64;
+    }
+    highbd_sad16x4_core_avx2(s, r, sad_acc);
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  }
+}
+
+static AOM_FORCE_INLINE unsigned int aom_highbd_sad64xN_avx2(int N,
+                                                             const uint8_t *src,
+                                                             int src_stride,
+                                                             const uint8_t *ref,
+                                                             int ref_stride) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  const int left_shift = 1;
+  int i;
+  for (i = 0; i < N; i += 2) {
+    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
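+// Accumulates the SAD of one 128-pixel row, handled as two 64-pixel
+// halves. A non-NULL sec_ptr averages the reference with the second
+// predictor first (average SAD).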
+static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
+                     const uint16_t *sec_ptr, __m256i *sad_acc) {
+  __m256i s[4], r[4];
+  int i;
+  for (i = 0; i < 2; i++) {
+    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
+    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
+    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
+    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
+    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
+    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
+    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
+    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
+    if (sec_ptr) {
+      r[0] =
+          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
+      r[1] = _mm256_avg_epu16(
+          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
+      r[2] = _mm256_avg_epu16(
+          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
+      r[3] = _mm256_avg_epu16(
+          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
+      sec_ptr += 64;
+    }
+    highbd_sad16x4_core_avx2(s, r, sad_acc);
+    src_ptr += 64;
+    ref_ptr += 64;
+  }
+}
+
+static AOM_FORCE_INLINE unsigned int aom_highbd_sad128xN_avx2(
+    int N, const uint8_t *src, int src_stride, const uint8_t *ref,
+    int ref_stride) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  int row = 0;
+  while (row < N) {
+    sad128x1(srcp, refp, NULL, &sad);
+    srcp += src_stride;
+    refp += ref_stride;
+    row++;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
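+// Instantiates the public aom_highbd_sadMxN_avx2() functions from the
+// width-specific helpers above.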
+#define highbd_sadMxN_avx2(m, n)                                            \
+  unsigned int aom_highbd_sad##m##x##n##_avx2(                              \
+      const uint8_t *src, int src_stride, const uint8_t *ref,               \
+      int ref_stride) {                                                     \
+    return aom_highbd_sad##m##xN_avx2(n, src, src_stride, ref, ref_stride); \
+  }
+
+highbd_sadMxN_avx2(16, 4);
+highbd_sadMxN_avx2(16, 8);
+highbd_sadMxN_avx2(16, 16);
+highbd_sadMxN_avx2(16, 32);
+highbd_sadMxN_avx2(16, 64);
+
+highbd_sadMxN_avx2(32, 8);
+highbd_sadMxN_avx2(32, 16);
+highbd_sadMxN_avx2(32, 32);
+highbd_sadMxN_avx2(32, 64);
+
+highbd_sadMxN_avx2(64, 16);
+highbd_sadMxN_avx2(64, 32);
+highbd_sadMxN_avx2(64, 64);
+highbd_sadMxN_avx2(64, 128);
+
+highbd_sadMxN_avx2(128, 64);
+highbd_sadMxN_avx2(128, 128);
+
 unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
@@ -621,16 +336,6 @@
   return sum;
 }
 
-unsigned int aom_highbd_sad16x64_avx2(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride) {
-  const int left_shift = 5;
-  uint32_t sum = aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
-  src += src_stride << left_shift;
-  ref += ref_stride << left_shift;
-  sum += aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
-  return sum;
-}
-
 unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
@@ -816,7 +521,7 @@
 }
 
 // SAD 4D
-// Combine 4 __m256i vectors to uint32_t result[4]
+// Combine 4 __m256i input vectors v into uint32_t result[4]
 static INLINE void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                                uint32_t *res) {
   __m256i u0, u1, u2, u3;
@@ -871,386 +576,124 @@
   s[3] = _mm256_setzero_si256();
 }
 
-void aom_highbd_sad16x4x4d_avx2(const uint8_t *src, int src_stride,
-                                const uint8_t *const ref_array[],
-                                int ref_stride, uint32_t *sad_array) {
-  __m256i sad_vec[4];
-  const uint16_t *refp[4];
-  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
-  const uint16_t *srcp;
-  int i;
-
-  init_sad(sad_vec);
-  convert_pointers(ref_array, refp);
-
-  for (i = 0; i < 4; ++i) {
-    srcp = keep;
-    sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
-  }
-  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
-}
-
-void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride,
-                                const uint8_t *const ref_array[],
-                                int ref_stride, uint32_t *sad_array) {
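+// Computes the SAD of a 16-wide, N-row block against each of the four
+// references in ref_array[], four rows per iteration.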
+static AOM_FORCE_INLINE void aom_highbd_sad16xNx4d_avx2(
+    int N, const uint8_t *src, int src_stride, const uint8_t *const ref_array[],
+    int ref_stride, uint32_t *sad_array) {
   __m256i sad_vec[4];
   const uint16_t *refp[4];
   const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
   const uint16_t *srcp;
   const int shift_for_4_rows = 2;
-  int i;
+  int i, j;
 
   init_sad(sad_vec);
   convert_pointers(ref_array, refp);
 
   for (i = 0; i < 4; ++i) {
     srcp = keep;
-    sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
-    srcp += src_stride << shift_for_4_rows;
-    refp[i] += ref_stride << shift_for_4_rows;
-    sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
-  }
-  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
-}
-
-void aom_highbd_sad16x16x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first8rows[4];
-  uint32_t second8rows[4];
-  const uint8_t *ref[4];
-  const int shift_for_8_rows = 3;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad16x8x4d_avx2(src, src_stride, ref, ref_stride, first8rows);
-  src += src_stride << shift_for_8_rows;
-  ref[0] += ref_stride << shift_for_8_rows;
-  ref[1] += ref_stride << shift_for_8_rows;
-  ref[2] += ref_stride << shift_for_8_rows;
-  ref[3] += ref_stride << shift_for_8_rows;
-  aom_highbd_sad16x8x4d_avx2(src, src_stride, ref, ref_stride, second8rows);
-  sad_array[0] = first8rows[0] + second8rows[0];
-  sad_array[1] = first8rows[1] + second8rows[1];
-  sad_array[2] = first8rows[2] + second8rows[2];
-  sad_array[3] = first8rows[3] + second8rows[3];
-}
-
-void aom_highbd_sad16x32x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 4;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad16x16x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad16x16x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad16x64x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 5;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad32x8x4d_avx2(const uint8_t *src, int src_stride,
-                                const uint8_t *const ref_array[],
-                                int ref_stride, uint32_t *sad_array) {
-  __m256i sad_vec[4];
-  const uint16_t *refp[4];
-  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
-  const uint16_t *srcp;
-  const int shift_for_4_rows = 2;
-  int i;
-  int rows_section;
-
-  init_sad(sad_vec);
-  convert_pointers(ref_array, refp);
-
-  for (i = 0; i < 4; ++i) {
-    srcp = keep;
-    rows_section = 0;
-    while (rows_section < 2) {
-      sad32x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
+    for (j = 0; j < N; j += 4) {
+      sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
       srcp += src_stride << shift_for_4_rows;
       refp[i] += ref_stride << shift_for_4_rows;
-      rows_section++;
     }
   }
   get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
 }
 
-void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
+static AOM_FORCE_INLINE void aom_highbd_sad32xNx4d_avx2(
+    int N, const uint8_t *src, int src_stride, const uint8_t *const ref_array[],
+    int ref_stride, uint32_t *sad_array) {
   __m256i sad_vec[4];
   const uint16_t *refp[4];
   const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
   const uint16_t *srcp;
   const int shift_for_4_rows = 2;
-  int i;
-  int rows_section;
+  int i, r;
 
   init_sad(sad_vec);
   convert_pointers(ref_array, refp);
 
   for (i = 0; i < 4; ++i) {
     srcp = keep;
-    rows_section = 0;
-    while (rows_section < 4) {
+    for (r = 0; r < N; r += 4) {
       sad32x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
       srcp += src_stride << shift_for_4_rows;
       refp[i] += ref_stride << shift_for_4_rows;
-      rows_section++;
     }
   }
   get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
 }
 
-void aom_highbd_sad32x32x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 4;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad32x16x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad32x16x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad32x64x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 5;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad32x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad32x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad64x16x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
+static AOM_FORCE_INLINE void aom_highbd_sad64xNx4d_avx2(
+    int N, const uint8_t *src, int src_stride, const uint8_t *const ref_array[],
+    int ref_stride, uint32_t *sad_array) {
   __m256i sad_vec[4];
   const uint16_t *refp[4];
   const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
   const uint16_t *srcp;
   const int shift_for_rows = 1;
-  int i;
-  int rows_section;
+  int i, r;
 
   init_sad(sad_vec);
   convert_pointers(ref_array, refp);
 
   for (i = 0; i < 4; ++i) {
     srcp = keep;
-    rows_section = 0;
-    while (rows_section < 8) {
+    for (r = 0; r < N; r += 2) {
       sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
       srcp += src_stride << shift_for_rows;
       refp[i] += ref_stride << shift_for_rows;
-      rows_section++;
     }
   }
   get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
 }
 
-void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
+static AOM_FORCE_INLINE void aom_highbd_sad128xNx4d_avx2(
+    int N, const uint8_t *src, int src_stride, const uint8_t *const ref_array[],
+    int ref_stride, uint32_t *sad_array) {
   __m256i sad_vec[4];
   const uint16_t *refp[4];
   const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
   const uint16_t *srcp;
-  const int shift_for_rows = 1;
-  int i;
-  int rows_section;
+  int i, r;
 
   init_sad(sad_vec);
   convert_pointers(ref_array, refp);
 
   for (i = 0; i < 4; ++i) {
     srcp = keep;
-    rows_section = 0;
-    while (rows_section < 16) {
-      sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
-      srcp += src_stride << shift_for_rows;
-      refp[i] += ref_stride << shift_for_rows;
-      rows_section++;
-    }
-  }
-  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
-}
-
-void aom_highbd_sad64x64x4d_avx2(const uint8_t *src, int src_stride,
-                                 const uint8_t *const ref_array[],
-                                 int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 5;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad64x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad64x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad64x128x4d_avx2(const uint8_t *src, int src_stride,
-                                  const uint8_t *const ref_array[],
-                                  int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 6;
-
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
-
-  aom_highbd_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
-
-void aom_highbd_sad128x64x4d_avx2(const uint8_t *src, int src_stride,
-                                  const uint8_t *const ref_array[],
-                                  int ref_stride, uint32_t *sad_array) {
-  __m256i sad_vec[4];
-  const uint16_t *refp[4];
-  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
-  const uint16_t *srcp;
-  int i;
-  int rows_section;
-
-  init_sad(sad_vec);
-  convert_pointers(ref_array, refp);
-
-  for (i = 0; i < 4; ++i) {
-    srcp = keep;
-    rows_section = 0;
-    while (rows_section < 64) {
+    for (r = 0; r < N; r++) {
       sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
       srcp += src_stride;
       refp[i] += ref_stride;
-      rows_section++;
     }
   }
   get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
 }
 
-void aom_highbd_sad128x128x4d_avx2(const uint8_t *src, int src_stride,
-                                   const uint8_t *const ref_array[],
-                                   int ref_stride, uint32_t *sad_array) {
-  uint32_t first_half[4];
-  uint32_t second_half[4];
-  const uint8_t *ref[4];
-  const int shift_for_rows = 6;
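+// Instantiates the public aom_highbd_sadMxNx4d_avx2() functions from the
+// width-specific helpers above.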
+#define highbd_sadMxNx4d_avx2(m, n)                                          \
+  void aom_highbd_sad##m##x##n##x4d_avx2(                                    \
+      const uint8_t *src, int src_stride, const uint8_t *const ref_array[],  \
+      int ref_stride, uint32_t *sad_array) {                                 \
+    aom_highbd_sad##m##xNx4d_avx2(n, src, src_stride, ref_array, ref_stride, \
+                                  sad_array);                                \
+  }
 
-  ref[0] = ref_array[0];
-  ref[1] = ref_array[1];
-  ref[2] = ref_array[2];
-  ref[3] = ref_array[3];
+highbd_sadMxNx4d_avx2(16, 4);
+highbd_sadMxNx4d_avx2(16, 8);
+highbd_sadMxNx4d_avx2(16, 16);
+highbd_sadMxNx4d_avx2(16, 32);
+highbd_sadMxNx4d_avx2(16, 64);
 
-  aom_highbd_sad128x64x4d_avx2(src, src_stride, ref, ref_stride, first_half);
-  src += src_stride << shift_for_rows;
-  ref[0] += ref_stride << shift_for_rows;
-  ref[1] += ref_stride << shift_for_rows;
-  ref[2] += ref_stride << shift_for_rows;
-  ref[3] += ref_stride << shift_for_rows;
-  aom_highbd_sad128x64x4d_avx2(src, src_stride, ref, ref_stride, second_half);
-  sad_array[0] = first_half[0] + second_half[0];
-  sad_array[1] = first_half[1] + second_half[1];
-  sad_array[2] = first_half[2] + second_half[2];
-  sad_array[3] = first_half[3] + second_half[3];
-}
+highbd_sadMxNx4d_avx2(32, 8);
+highbd_sadMxNx4d_avx2(32, 16);
+highbd_sadMxNx4d_avx2(32, 32);
+highbd_sadMxNx4d_avx2(32, 64);
+
+highbd_sadMxNx4d_avx2(64, 16);
+highbd_sadMxNx4d_avx2(64, 32);
+highbd_sadMxNx4d_avx2(64, 64);
+highbd_sadMxNx4d_avx2(64, 128);
+
+highbd_sadMxNx4d_avx2(128, 64);
+highbd_sadMxNx4d_avx2(128, 128);