Remove some subtract_plane operations

1. remove some subtract_plane
2. add aom_sse for model_rd_for_sb
3. add aom_sse sse4_1 and avx2 code
4. add aom_highbd_sse sse4_1 and avx2 code

Speeds up encoding by about 1.2% with no RD change.

test sequence: BasketballDrill_832x480_50.y4m

test command line: ./aomenc --cpu-used=1 --psnr -D \
 -q --end-usage=vbr --target-bitrate=800 --limit=20 \
 BasketballDrill_832x480_50.y4m -otest.webm

Change-Id: Ibaac08ff21e7f7dcbde58828d8e8c6b9012d7de7
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index 7f22de9..a76f88f 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -158,6 +158,7 @@
               "${AOM_ROOT}/aom_dsp/quantize.c"
               "${AOM_ROOT}/aom_dsp/quantize.h"
               "${AOM_ROOT}/aom_dsp/sad.c"
+              "${AOM_ROOT}/aom_dsp/sse.c"
               "${AOM_ROOT}/aom_dsp/sad_av1.c"
               "${AOM_ROOT}/aom_dsp/sum_squares.c"
               "${AOM_ROOT}/aom_dsp/variance.c"
@@ -197,6 +198,7 @@
               "${AOM_ROOT}/aom_dsp/x86/sad_highbd_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/sad_impl_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/variance_avx2.c"
+              "${AOM_ROOT}/aom_dsp/x86/sse_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/variance_impl_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_sad_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_variance_avx2.c")
@@ -218,6 +220,7 @@
 
   list(APPEND AOM_DSP_ENCODER_INTRIN_SSE4_1
               "${AOM_ROOT}/aom_dsp/x86/highbd_variance_sse4.c"
+              "${AOM_ROOT}/aom_dsp/x86/sse_sse4.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_sad_sse4.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_variance_sse4.c")
 
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 85b2d9b..040ac16 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -569,6 +569,12 @@
   add_proto qw/void aom_highbd_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src_ptr, ptrdiff_t src_stride, const uint8_t *pred_ptr, ptrdiff_t pred_stride, int bd";
   specialize qw/aom_highbd_subtract_block sse2/;
 
+  add_proto qw/int64_t/, "aom_sse", "const uint8_t *a, int a_stride, const uint8_t *b,int b_stride, int width, int height";
+  specialize qw/aom_sse  sse4_1 avx2/;
+
+  add_proto qw/int64_t/, "aom_highbd_sse", "const uint8_t *a8, int a_stride, const uint8_t *b8,int b_stride, int width, int height";
+  specialize qw/aom_highbd_sse  sse4_1 avx2/;
+
   if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
     #
     # Sum of Squares
@@ -578,6 +584,7 @@
 
     add_proto qw/uint64_t aom_sum_squares_i16/, "const int16_t *src, uint32_t N";
     specialize qw/aom_sum_squares_i16 sse2/;
+
   }
 
 
@@ -830,7 +837,6 @@
   specialize qw/aom_highbd_sad16x64x4d sse2/;
   specialize qw/aom_highbd_sad64x16x4d sse2/;
 
-
   #
   # Structured Similarity (SSIM)
   #
diff --git a/aom_dsp/sse.c b/aom_dsp/sse.c
new file mode 100644
index 0000000..2493948
--- /dev/null
+++ b/aom_dsp/sse.c
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+/* Sum the difference between every corresponding element of the buffers. */
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+
+int64_t aom_sse_c(const uint8_t *a, int a_stride, const uint8_t *b,
+                  int b_stride, int width, int height) {
+  int y, x;
+  int64_t sse = 0;
+
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const int32_t diff = abs(a[x] - b[x]);
+      sse += diff * diff;
+    }
+
+    a += a_stride;
+    b += b_stride;
+  }
+  return sse;
+}
+
+int64_t aom_highbd_sse_c(const uint8_t *a8, int a_stride, const uint8_t *b8,
+                         int b_stride, int width, int height) {
+  int y, x;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const int32_t diff = (int32_t)(a[x]) - (int32_t)(b[x]);
+      sse += diff * diff;
+    }
+
+    a += a_stride;
+    b += b_stride;
+  }
+  return sse;
+}
diff --git a/aom_dsp/x86/sse_avx2.c b/aom_dsp/x86/sse_avx2.c
new file mode 100644
index 0000000..305dde5
--- /dev/null
+++ b/aom_dsp/x86/sse_avx2.c
@@ -0,0 +1,250 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+#include <smmintrin.h>
+#include <immintrin.h>
+
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom_ports/mem.h"
+#include "aom_dsp/x86/synonyms.h"
+#include "aom_dsp/x86/synonyms_avx2.h"
+
+static INLINE void sse_w32_avx2(__m256i *sum, const uint8_t *a,
+                                const uint8_t *b) {
+  const __m256i v_a0 = yy_loadu_256(a);
+  const __m256i v_b0 = yy_loadu_256(b);
+  const __m256i v_a00_w = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_a0));
+  const __m256i v_a01_w =
+      _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_a0, 1));
+  const __m256i v_b00_w = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_b0));
+  const __m256i v_b01_w =
+      _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_b0, 1));
+  const __m256i v_d00_w = _mm256_sub_epi16(v_a00_w, v_b00_w);
+  const __m256i v_d01_w = _mm256_sub_epi16(v_a01_w, v_b01_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d00_w, v_d00_w));
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d01_w, v_d01_w));
+}
+
+static INLINE int64_t summary_all_avx2(const __m256i *sum_all) {
+  int64_t sum;
+  const __m256i sum0_4x64 =
+      _mm256_cvtepu32_epi64(_mm256_castsi256_si128(*sum_all));
+  const __m256i sum1_4x64 =
+      _mm256_cvtepu32_epi64(_mm256_extracti128_si256(*sum_all, 1));
+  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
+  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
+                                         _mm256_extracti128_si256(sum_4x64, 1));
+  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
+
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+int64_t aom_sse_avx2(const uint8_t *a, int a_stride, const uint8_t *b,
+                     int b_stride, int width, int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  __m256i sum = _mm256_setzero_si256();
+  switch (width) {
+    case 4:
+      do {
+        const __m128i v_a0 = xx_loadl_32(a);
+        const __m128i v_a1 = xx_loadl_32(a + a_stride);
+        const __m128i v_a2 = xx_loadl_32(a + a_stride * 2);
+        const __m128i v_a3 = xx_loadl_32(a + a_stride * 3);
+        const __m128i v_b0 = xx_loadl_32(b);
+        const __m128i v_b1 = xx_loadl_32(b + b_stride);
+        const __m128i v_b2 = xx_loadl_32(b + b_stride * 2);
+        const __m128i v_b3 = xx_loadl_32(b + b_stride * 3);
+        const __m128i v_a0123 = _mm_unpacklo_epi64(
+            _mm_unpacklo_epi32(v_a0, v_a1), _mm_unpacklo_epi32(v_a2, v_a3));
+        const __m128i v_b0123 = _mm_unpacklo_epi64(
+            _mm_unpacklo_epi32(v_b0, v_b1), _mm_unpacklo_epi32(v_b2, v_b3));
+        const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123);
+        const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123);
+        const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 2;
+        b += b_stride << 2;
+        y += 4;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 8:
+      do {
+        const __m128i v_a0 = xx_loadl_64(a);
+        const __m128i v_a1 = xx_loadl_64(a + a_stride);
+        const __m128i v_b0 = xx_loadl_64(b);
+        const __m128i v_b1 = xx_loadl_64(b + b_stride);
+        const __m256i v_a_w =
+            _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1));
+        const __m256i v_b_w =
+            _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1));
+        const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 16:
+      do {
+        const __m128i v_a0 = xx_loadu_128(a);
+        const __m128i v_b0 = xx_loadu_128(b);
+        const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0);
+        const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0);
+        const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w));
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 32:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 64:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        sse_w32_avx2(&sum, a + 32, b + 32);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 128:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        sse_w32_avx2(&sum, a + 32, b + 32);
+        sse_w32_avx2(&sum, a + 64, b + 64);
+        sse_w32_avx2(&sum, a + 96, b + 96);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    default: break;
+  }
+
+  return sse;
+}
+
+static INLINE void highbd_sse_w16_avx2(__m256i *sum, const uint16_t *a,
+                                       const uint16_t *b) {
+  const __m256i v_a_w = yy_loadu_256(a);
+  const __m256i v_b_w = yy_loadu_256(b);
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t aom_highbd_sse_avx2(const uint8_t *a8, int a_stride, const uint8_t *b8,
+                            int b_stride, int width, int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  __m256i sum = _mm256_setzero_si256();
+  switch (width) {
+    case 4:
+      do {
+        const __m128i v_a0 = xx_loadl_64(a);
+        const __m128i v_a1 = xx_loadl_64(a + a_stride);
+        const __m128i v_a2 = xx_loadl_64(a + a_stride * 2);
+        const __m128i v_a3 = xx_loadl_64(a + a_stride * 3);
+        const __m128i v_b0 = xx_loadl_64(b);
+        const __m128i v_b1 = xx_loadl_64(b + b_stride);
+        const __m128i v_b2 = xx_loadl_64(b + b_stride * 2);
+        const __m128i v_b3 = xx_loadl_64(b + b_stride * 3);
+        const __m256i v_a_w = yy_set_m128i(_mm_unpacklo_epi64(v_a0, v_a1),
+                                           _mm_unpacklo_epi64(v_a2, v_a3));
+        const __m256i v_b_w = yy_set_m128i(_mm_unpacklo_epi64(v_b0, v_b1),
+                                           _mm_unpacklo_epi64(v_b2, v_b3));
+        const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 2;
+        b += b_stride << 2;
+        y += 4;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 8:
+      do {
+        const __m256i v_a_w = yy_loadu2_128(a + a_stride, a);
+        const __m256i v_b_w = yy_loadu2_128(b + b_stride, b);
+        const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 16:
+      do {
+        highbd_sse_w16_avx2(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 32:
+      do {
+        highbd_sse_w16_avx2(&sum, a, b);
+        highbd_sse_w16_avx2(&sum, a + 16, b + 16);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 64:
+      do {
+        highbd_sse_w16_avx2(&sum, a, b);
+        highbd_sse_w16_avx2(&sum, a + 16 * 1, b + 16 * 1);
+        highbd_sse_w16_avx2(&sum, a + 16 * 2, b + 16 * 2);
+        highbd_sse_w16_avx2(&sum, a + 16 * 3, b + 16 * 3);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 128:
+      do {
+        highbd_sse_w16_avx2(&sum, a, b);
+        highbd_sse_w16_avx2(&sum, a + 16 * 1, b + 16 * 1);
+        highbd_sse_w16_avx2(&sum, a + 16 * 2, b + 16 * 2);
+        highbd_sse_w16_avx2(&sum, a + 16 * 3, b + 16 * 3);
+        highbd_sse_w16_avx2(&sum, a + 16 * 4, b + 16 * 4);
+        highbd_sse_w16_avx2(&sum, a + 16 * 5, b + 16 * 5);
+        highbd_sse_w16_avx2(&sum, a + 16 * 6, b + 16 * 6);
+        highbd_sse_w16_avx2(&sum, a + 16 * 7, b + 16 * 7);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    default: break;
+  }
+  return sse;
+}
diff --git a/aom_dsp/x86/sse_sse4.c b/aom_dsp/x86/sse_sse4.c
new file mode 100644
index 0000000..8b5af84
--- /dev/null
+++ b/aom_dsp/x86/sse_sse4.c
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <assert.h>
+#include <smmintrin.h>
+
+#include "config/aom_config.h"
+
+#include "aom_ports/mem.h"
+#include "aom/aom_integer.h"
+#include "aom_dsp/x86/synonyms.h"
+
+static INLINE int64_t summary_all_sse4(const __m128i *sum_all) {
+  int64_t sum;
+  const __m128i sum0 = _mm_cvtepu32_epi64(*sum_all);
+  const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum_all, 8));
+  const __m128i sum_2x64 = _mm_add_epi64(sum0, sum1);
+  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+static INLINE void sse_w16_sse4_1(__m128i *sum, const uint8_t *a,
+                                  const uint8_t *b) {
+  const __m128i v_a0 = xx_loadu_128(a);
+  const __m128i v_b0 = xx_loadu_128(b);
+  const __m128i v_a00_w = _mm_cvtepu8_epi16(v_a0);
+  const __m128i v_a01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_a0, 8));
+  const __m128i v_b00_w = _mm_cvtepu8_epi16(v_b0);
+  const __m128i v_b01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_b0, 8));
+  const __m128i v_d00_w = _mm_sub_epi16(v_a00_w, v_b00_w);
+  const __m128i v_d01_w = _mm_sub_epi16(v_a01_w, v_b01_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d00_w, v_d00_w));
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w));
+}
+
+int64_t aom_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b,
+                       int b_stride, int width, int height) {
+  int y = 0;
+  int64_t sse = 0;
+  __m128i sum = _mm_setzero_si128();
+  switch (width) {
+    case 4:
+      do {
+        const __m128i v_a0 = xx_loadl_32(a);
+        const __m128i v_a1 = xx_loadl_32(a + a_stride);
+        const __m128i v_b0 = xx_loadl_32(b);
+        const __m128i v_b1 = xx_loadl_32(b + b_stride);
+        const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1));
+        const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1));
+        const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+        sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 8:
+      do {
+        const __m128i v_a0 = xx_loadl_64(a);
+        const __m128i v_b0 = xx_loadl_64(b);
+        const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0);
+        const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0);
+        const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+        sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w));
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 16:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 32:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16, b + 16);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 64:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
+        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
+        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 128:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
+        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
+        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
+        sse_w16_sse4_1(&sum, a + 16 * 4, b + 16 * 4);
+        sse_w16_sse4_1(&sum, a + 16 * 5, b + 16 * 5);
+        sse_w16_sse4_1(&sum, a + 16 * 6, b + 16 * 6);
+        sse_w16_sse4_1(&sum, a + 16 * 7, b + 16 * 7);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    default: break;
+  }
+
+  return sse;
+}
+
+static INLINE void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a,
+                                        const uint16_t *b) {
+  const __m128i v_a_w = xx_loadu_128(a);
+  const __m128i v_b_w = xx_loadu_128(b);
+  const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t aom_highbd_sse_sse4_1(const uint8_t *a8, int a_stride,
+                              const uint8_t *b8, int b_stride, int width,
+                              int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  __m128i sum = _mm_setzero_si128();
+  switch (width) {
+    case 4:
+      do {
+        const __m128i v_a0 = xx_loadl_64(a);
+        const __m128i v_a1 = xx_loadl_64(a + a_stride);
+        const __m128i v_b0 = xx_loadl_64(b);
+        const __m128i v_b1 = xx_loadl_64(b + b_stride);
+        const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1);
+        const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1);
+        const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+        sum = _mm_add_epi32(sum, _mm_madd_epi16(v_d_w, v_d_w));
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 8:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 16:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        highbd_sse_w8_sse4_1(&sum, a + 8, b + 8);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 32:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 64:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 4, b + 8 * 4);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 5, b + 8 * 5);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 6, b + 8 * 6);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 7, b + 8 * 7);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 128:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 1, b + 8 * 1);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 2, b + 8 * 2);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 3, b + 8 * 3);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 4, b + 8 * 4);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 5, b + 8 * 5);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 6, b + 8 * 6);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 7, b + 8 * 7);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 8, b + 8 * 8);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 9, b + 8 * 9);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 10, b + 8 * 10);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 11, b + 8 * 11);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 12, b + 8 * 12);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 13, b + 8 * 13);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 14, b + 8 * 14);
+        highbd_sse_w8_sse4_1(&sum, a + 8 * 15, b + 8 * 15);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    default: break;
+  }
+  return sse;
+}
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 0a7b40f..d82833b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1874,9 +1874,13 @@
 
     if (x->skip_chroma_rd && plane) continue;
 
-    // TODO(geza): Write direct sse functions that do not compute
-    // variance as well.
-    sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh);
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
+                           pd->dst.stride, bw, bh);
+    } else {
+      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
+                    bh);
+    }
     sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
 
     model_rd_from_sse(cpi, x, plane_bsize, plane, sse, &rate, &dist);
@@ -2624,11 +2628,16 @@
 
     int bw, bh;
     const struct macroblock_plane *const p = &x->plane[plane];
-    const int diff_stride = block_size_wide[plane_bsize];
     const int shift = (xd->bd - 8);
     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                        &bw, &bh);
-    sse = aom_sum_squares_2d_i16(p->src_diff, diff_stride, bw, bh);
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
+                           pd->dst.stride, bw, bh);
+    } else {
+      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
+                    bh);
+    }
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
 
     model_rd_with_dnn(cpi, x, plane_bsize, plane, sse, &rate, &dist);
@@ -2725,11 +2734,16 @@
 
     int bw, bh;
     const struct macroblock_plane *const p = &x->plane[plane];
-    const int diff_stride = block_size_wide[plane_bsize];
     const int shift = (xd->bd - 8);
     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                        &bw, &bh);
-    sse = aom_sum_squares_2d_i16(p->src_diff, diff_stride, bw, bh);
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
+                           pd->dst.stride, bw, bh);
+    } else {
+      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
+                    bh);
+    }
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
 
     model_rd_with_surffit(cpi, x, plane_bsize, plane, sse, &rate, &dist);
@@ -2826,11 +2840,18 @@
 
     int bw, bh;
     const struct macroblock_plane *const p = &x->plane[plane];
-    const int diff_stride = block_size_wide[plane_bsize];
     const int shift = (xd->bd - 8);
     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                        &bw, &bh);
-    sse = aom_sum_squares_2d_i16(p->src_diff, diff_stride, bw, bh);
+
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
+                           pd->dst.stride, bw, bh);
+    } else {
+      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
+                    bh);
+    }
+
     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
     model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, &rate, &dist);
 
@@ -2876,7 +2897,13 @@
 
     if (x->skip_chroma_rd && plane) continue;
 
-    sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh);
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
+                           pd->dst.stride, bw, bh);
+    } else {
+      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
+                    bh);
+    }
     sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
 
     RD_STATS rd_stats;
@@ -3667,7 +3694,6 @@
     }
   }
   // RD estimation.
-  av1_subtract_plane(x, bsize, 0);
   model_rd_fn[MODELRD_LEGACY](cpi, bsize, x, xd, 0, 0, mi_row, mi_col,
                               &this_rd_stats.rate, &this_rd_stats.dist,
                               &this_rd_stats.skip, &temp_sse, NULL, NULL, NULL);
@@ -7646,7 +7672,6 @@
     *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize,
                                                      this_mode, mi_row, mi_col);
     av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize);
-    av1_subtract_plane(x, bsize, 0);
     model_rd_fn[MODELRD_LEGACY](cpi, bsize, x, xd, 0, 0, mi_row, mi_col,
                                 &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
                                 &tmp_skip_sse_sb, NULL, NULL, NULL);
@@ -7844,7 +7869,6 @@
   if (skip_pred != cpi->default_interp_skip_flags) {
     if (skip_pred != DEFAULT_LUMA_INTERP_SKIP_FLAG) {
       av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
-      av1_subtract_plane(x, bsize, 0);
 #if CONFIG_COLLECT_RD_STATS == 3
       RD_STATS rd_stats_y;
       select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
@@ -7871,7 +7895,6 @@
         }
         av1_build_inter_predictors_sbp(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                        plane);
-        av1_subtract_plane(x, bsize, plane);
         model_rd_fn[MODELRD_TYPE_INTERP_FILTER](
             cpi, bsize, x, xd, plane, plane, mi_row, mi_col, &tmp_rate_uv,
             &tmp_dist_uv, &tmp_skip_sb_uv, &tmp_skip_sse_uv, NULL, NULL, NULL);
@@ -8074,8 +8097,7 @@
       get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);
   if (!skip_build_pred)
     av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize);
-  for (int plane = 0; plane < num_planes; ++plane)
-    av1_subtract_plane(x, bsize, plane);
+
 #if CONFIG_COLLECT_RD_STATS == 3
   RD_STATS rd_stats_y;
   select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
@@ -8552,7 +8574,6 @@
         av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                   intrapred, bw);
         av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
-        av1_subtract_plane(x, bsize, 0);
         model_rd_fn[MODELRD_LEGACY](cpi, bsize, x, xd, 0, 0, mi_row, mi_col,
                                     &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
                                     &tmp_skip_sse_sb, NULL, NULL, NULL);
@@ -8610,7 +8631,6 @@
             mbmi->mv[0].as_int = tmp_mv.as_int;
             av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst,
                                            bsize);
-            av1_subtract_plane(x, bsize, 0);
             model_rd_fn[MODELRD_LEGACY](cpi, bsize, x, xd, 0, 0, mi_row, mi_col,
                                         &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
                                         &tmp_skip_sse_sb, NULL, NULL, NULL);