SSE4.1 and AVX2 highbd_pixel_proj_error

Add SSE4.1 and AVX2 implementations of the high bit-depth
pixel_proj_error function used in the self-guided filter encoder

SSE4.1 unit test speed-ups
r0=0 r1=0: 5.1x
r0=0 r1=1: 2.7x
r0=1 r1=0: 2.7x
r0=1 r1=1: 2.4x

AVX2 unit test speed-ups
r0=0 r1=0: 6.0x
r0=0 r1=1: 3.5x
r0=1 r1=0: 3.5x
r0=1 r1=1: 2.8x

Encoding 15 frames of ducks_take_off_1080p using AVX2:
cpu=5 bd=10 speed-up=3.6%
cpu=5 bd=12 speed-up=4.0%
cpu=3 bd=10 speed-up=2.4%
cpu=3 bd=12 speed-up=3.2%
cpu=2 bd=10 speed-up=1.0%
cpu=2 bd=12 speed-up=2.1%

Change-Id: I3aac63eeb8d1af33c1d6414dd910708d3dad1793
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 1cd6b4a..481d6b8 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -298,6 +298,9 @@
 
   add_proto qw/int64_t av1_lowbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
   specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2/;
+
+  add_proto qw/int64_t av1_highbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
+  specialize qw/av1_highbd_pixel_proj_error sse4_1 avx2/;
 }
 # end encoder functions
 
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 4faa849..2858d16 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -255,88 +255,97 @@
   return err;
 }
 
+int64_t av1_highbd_pixel_proj_error_c(const uint8_t *src8, int width,
+                                      int height, int src_stride,
+                                      const uint8_t *dat8, int dat_stride,
+                                      int32_t *flt0, int flt0_stride,
+                                      int32_t *flt1, int flt1_stride, int xq[2],
+                                      const sgr_params_type *params) {
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+  int i, j;
+  int64_t err = 0;
+  const int32_t half = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
+  if (params->r[0] > 0 && params->r[1] > 0) {
+    int xq0 = xq[0];
+    int xq1 = xq[1];
+    for (i = 0; i < height; ++i) {
+      for (j = 0; j < width; ++j) {
+        const int32_t d = dat[j];
+        const int32_t s = src[j];
+        const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
+        int32_t v0 = flt0[j] - u;
+        int32_t v1 = flt1[j] - u;
+        int32_t v = half;
+        v += xq0 * v0;
+        v += xq1 * v1;
+        const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
+        err += e * e;
+      }
+      dat += dat_stride;
+      flt0 += flt0_stride;
+      flt1 += flt1_stride;
+      src += src_stride;
+    }
+  } else if (params->r[0] > 0 || params->r[1] > 0) {
+    int exq;
+    int32_t *flt;
+    int flt_stride;
+    if (params->r[0] > 0) {
+      exq = xq[0];
+      flt = flt0;
+      flt_stride = flt0_stride;
+    } else {
+      exq = xq[1];
+      flt = flt1;
+      flt_stride = flt1_stride;
+    }
+    for (i = 0; i < height; ++i) {
+      for (j = 0; j < width; ++j) {
+        const int32_t d = dat[j];
+        const int32_t s = src[j];
+        const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
+        int32_t v = half;
+        v += exq * (flt[j] - u);
+        const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
+        err += e * e;
+      }
+      dat += dat_stride;
+      flt += flt_stride;
+      src += src_stride;
+    }
+  } else {
+    for (i = 0; i < height; ++i) {
+      for (j = 0; j < width; ++j) {
+        const int32_t d = dat[j];
+        const int32_t s = src[j];
+        const int32_t e = d - s;
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+    }
+  }
+  return err;
+}
+
 static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                     int src_stride, const uint8_t *dat8,
                                     int dat_stride, int use_highbitdepth,
                                     int32_t *flt0, int flt0_stride,
                                     int32_t *flt1, int flt1_stride, int *xqd,
                                     const sgr_params_type *params) {
-  int i, j;
-  int64_t err = 0;
   int xq[2];
   decode_xq(xqd, xq, params);
   if (!use_highbitdepth) {
-    err = av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
-                                     dat_stride, flt0, flt0_stride, flt1,
-                                     flt1_stride, xq, params);
+    return av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
+                                      dat_stride, flt0, flt0_stride, flt1,
+                                      flt1_stride, xq, params);
   } else {
-    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
-    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
-    const int32_t half = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
-    if (params->r[0] > 0 && params->r[1] > 0) {
-      int xq0 = xq[0];
-      int xq1 = xq[1];
-      for (i = 0; i < height; ++i) {
-        for (j = 0; j < width; ++j) {
-          const int32_t d = dat[j];
-          const int32_t s = src[j];
-          const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
-          int32_t v0 = flt0[j] - u;
-          int32_t v1 = flt1[j] - u;
-          int32_t v = half;
-          v += xq0 * v0;
-          v += xq1 * v1;
-          const int32_t e =
-              (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
-          err += e * e;
-        }
-        dat += dat_stride;
-        flt0 += flt0_stride;
-        flt1 += flt1_stride;
-        src += src_stride;
-      }
-    } else if (params->r[0] > 0 || params->r[1] > 0) {
-      int exq;
-      int32_t *flt;
-      int flt_stride;
-      if (params->r[0] > 0) {
-        exq = xq[0];
-        flt = flt0;
-        flt_stride = flt0_stride;
-      } else {
-        exq = xq[1];
-        flt = flt1;
-        flt_stride = flt1_stride;
-      }
-      for (i = 0; i < height; ++i) {
-        for (j = 0; j < width; ++j) {
-          const int32_t d = dat[j];
-          const int32_t s = src[j];
-          const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
-          int32_t v = half;
-          v += exq * (flt[j] - u);
-          const int32_t e =
-              (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
-          err += e * e;
-        }
-        dat += dat_stride;
-        flt += flt_stride;
-        src += src_stride;
-      }
-    } else {
-      for (i = 0; i < height; ++i) {
-        for (j = 0; j < width; ++j) {
-          const int32_t d = dat[j];
-          const int32_t s = src[j];
-          const int32_t e = d - s;
-          err += e * e;
-        }
-        dat += dat_stride;
-        src += src_stride;
-      }
-    }
+    return av1_highbd_pixel_proj_error(src8, width, height, src_stride, dat8,
+                                       dat_stride, flt0, flt0_stride, flt1,
+                                       flt1_stride, xq, params);
   }
-  return err;
 }
 
 #define USE_SGRPROJ_REFINEMENT_SEARCH 1
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index 579e424..7a63c60 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -621,3 +621,223 @@
   err += sum[0] + sum[1] + sum[2] + sum[3];
   return err;
 }
+
+int64_t av1_highbd_pixel_proj_error_avx2(
+    const uint8_t *src8, int width, int height, int src_stride,
+    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+  int i, j, k;
+  const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+  const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));
+  __m256i sum64 = _mm256_setzero_si256();
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+  int64_t err = 0;
+  if (params->r[0] > 0 && params->r[1] > 0) {  // Both filters are enabled
+    const __m256i xq0 = _mm256_set1_epi32(xq[0]);
+    const __m256i xq1 = _mm256_set1_epi32(xq[1]);
+    for (i = 0; i < height; ++i) {
+      __m256i sum32 = _mm256_setzero_si256();
+      for (j = 0; j <= width - 16; j += 16) {  // Process 16 pixels at a time
+        // Load 16 pixels each from source image and corrupted image
+        const __m256i s0 = yy_loadu_256(src + j);
+        const __m256i d0 = yy_loadu_256(dat + j);
+        // s0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16 (indices)
+
+        // Shift-up each pixel to match filtered image scaling
+        const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
+
+        // Split u0 into two halves and pad each from u16 to i32
+        const __m256i u0l = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(u0));
+        const __m256i u0h =
+            _mm256_cvtepu16_epi32(_mm256_extracti128_si256(u0, 1));
+        // u0h, u0l = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as u32
+
+        // Load 16 pixels from each filtered image
+        const __m256i flt0l = yy_loadu_256(flt0 + j);
+        const __m256i flt0h = yy_loadu_256(flt0 + j + 8);
+        const __m256i flt1l = yy_loadu_256(flt1 + j);
+        const __m256i flt1h = yy_loadu_256(flt1 + j + 8);
+        // flt?l, flt?h = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as u32
+
+        // Subtract shifted corrupt image from each filtered image
+        const __m256i flt0l_subu = _mm256_sub_epi32(flt0l, u0l);
+        const __m256i flt0h_subu = _mm256_sub_epi32(flt0h, u0h);
+        const __m256i flt1l_subu = _mm256_sub_epi32(flt1l, u0l);
+        const __m256i flt1h_subu = _mm256_sub_epi32(flt1h, u0h);
+
+        // Multiply basis vectors by appropriate coefficients
+        const __m256i v0l = _mm256_mullo_epi32(flt0l_subu, xq0);
+        const __m256i v0h = _mm256_mullo_epi32(flt0h_subu, xq0);
+        const __m256i v1l = _mm256_mullo_epi32(flt1l_subu, xq1);
+        const __m256i v1h = _mm256_mullo_epi32(flt1h_subu, xq1);
+
+        // Add together the contributions from the two basis vectors
+        const __m256i vl = _mm256_add_epi32(v0l, v1l);
+        const __m256i vh = _mm256_add_epi32(v0h, v1h);
+
+        // Right-shift v with appropriate rounding
+        const __m256i vrl =
+            _mm256_srai_epi32(_mm256_add_epi32(vl, rounding), shift);
+        const __m256i vrh =
+            _mm256_srai_epi32(_mm256_add_epi32(vh, rounding), shift);
+        // vrh, vrl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0]
+
+        // Saturate each i32 to an i16 then combine both halves
+        // The permute (control=[3 1 2 0]) fixes weird ordering from AVX lanes
+        const __m256i vr =
+            _mm256_permute4x64_epi64(_mm256_packs_epi32(vrl, vrh), 0xd8);
+        // intermediate = [15 14 13 12 7 6 5 4] [11 10 9 8 3 2 1 0]
+        // vr = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0]
+
+        // Add twin-subspace-sgr-filter to corrupt image then subtract source
+        const __m256i e0 = _mm256_sub_epi16(_mm256_add_epi16(vr, d0), s0);
+
+        // Calculate squared error and add adjacent values
+        const __m256i err0 = _mm256_madd_epi16(e0, e0);
+
+        sum32 = _mm256_add_epi32(sum32, err0);
+      }
+
+      const __m256i sum32l =
+          _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+      sum64 = _mm256_add_epi64(sum64, sum32l);
+      const __m256i sum32h =
+          _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels in this row (modulo 16)
+      for (k = j; k < width; ++k) {
+        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt0 += flt0_stride;
+      flt1 += flt1_stride;
+    }
+  } else if (params->r[0] > 0 || params->r[1] > 0) {  // Only one filter enabled
+    const int32_t xq_on = (params->r[0] > 0) ? xq[0] : xq[1];
+    const __m256i xq_active = _mm256_set1_epi32(xq_on);
+    const __m256i xq_inactive =
+        _mm256_set1_epi32(-xq_on * (1 << SGRPROJ_RST_BITS));
+    const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
+    const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
+    for (i = 0; i < height; ++i) {
+      __m256i sum32 = _mm256_setzero_si256();
+      for (j = 0; j <= width - 16; j += 16) {
+        // Load 16 pixels from source image
+        const __m256i s0 = yy_loadu_256(src + j);
+        // s0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16
+
+        // Load 16 pixels from corrupted image and pad each u16 to i32
+        const __m256i d0 = yy_loadu_256(dat + j);
+        const __m256i d0h =
+            _mm256_cvtepu16_epi32(_mm256_extracti128_si256(d0, 1));
+        const __m256i d0l = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(d0));
+        // d0 = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16
+        // d0h, d0l = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+        // Load 16 pixels from the filtered image
+        const __m256i flth = yy_loadu_256(flt + j + 8);
+        const __m256i fltl = yy_loadu_256(flt + j);
+        // flth, fltl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+        const __m256i flth_xq = _mm256_mullo_epi32(flth, xq_active);
+        const __m256i fltl_xq = _mm256_mullo_epi32(fltl, xq_active);
+        const __m256i d0h_xq = _mm256_mullo_epi32(d0h, xq_inactive);
+        const __m256i d0l_xq = _mm256_mullo_epi32(d0l, xq_inactive);
+
+        const __m256i vh = _mm256_add_epi32(flth_xq, d0h_xq);
+        const __m256i vl = _mm256_add_epi32(fltl_xq, d0l_xq);
+
+        // Shift this down with appropriate rounding
+        const __m256i vrh =
+            _mm256_srai_epi32(_mm256_add_epi32(vh, rounding), shift);
+        const __m256i vrl =
+            _mm256_srai_epi32(_mm256_add_epi32(vl, rounding), shift);
+        // vrh, vrl = [15 14 13 12] [11 10 9 8], [7 6 5 4] [3 2 1 0] as i32
+
+        // Saturate each i32 to an i16 then combine both halves
+        // The permute (control=[3 1 2 0]) fixes weird ordering from AVX lanes
+        const __m256i vr =
+            _mm256_permute4x64_epi64(_mm256_packs_epi32(vrl, vrh), 0xd8);
+        // intermediate = [15 14 13 12 7 6 5 4] [11 10 9 8 3 2 1 0] as u16
+        // vr = [15 14 13 12 11 10 9 8] [7 6 5 4 3 2 1 0] as u16
+
+        // Subtract twin-subspace-sgr filtered from source image to get error
+        const __m256i e0 = _mm256_sub_epi16(_mm256_add_epi16(vr, d0), s0);
+
+        // Calculate squared error and add adjacent values
+        const __m256i err0 = _mm256_madd_epi16(e0, e0);
+
+        sum32 = _mm256_add_epi32(sum32, err0);
+      }
+
+      const __m256i sum32l =
+          _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+      sum64 = _mm256_add_epi64(sum64, sum32l);
+      const __m256i sum32h =
+          _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels in this row (modulo 16)
+      for (k = j; k < width; ++k) {
+        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+        int32_t v = xq_on * (flt[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt += flt_stride;
+    }
+  } else {  // Neither filter is enabled
+    for (i = 0; i < height; ++i) {
+      __m256i sum32 = _mm256_setzero_si256();
+      for (j = 0; j <= width - 32; j += 32) {
+        // Load 2x16 u16 from source image
+        const __m256i s0l = yy_loadu_256(src + j);
+        const __m256i s0h = yy_loadu_256(src + j + 16);
+
+        // Load 2x16 u16 from corrupted image
+        const __m256i d0l = yy_loadu_256(dat + j);
+        const __m256i d0h = yy_loadu_256(dat + j + 16);
+
+        // Subtract corrupted image from source image
+        const __m256i diffl = _mm256_sub_epi16(d0l, s0l);
+        const __m256i diffh = _mm256_sub_epi16(d0h, s0h);
+
+        // Square error and add adjacent values
+        const __m256i err0l = _mm256_madd_epi16(diffl, diffl);
+        const __m256i err0h = _mm256_madd_epi16(diffh, diffh);
+
+        sum32 = _mm256_add_epi32(sum32, err0l);
+        sum32 = _mm256_add_epi32(sum32, err0h);
+      }
+
+      const __m256i sum32l =
+          _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum32));
+      sum64 = _mm256_add_epi64(sum64, sum32l);
+      const __m256i sum32h =
+          _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels (modulu 16)
+      for (k = j; k < width; ++k) {
+        const int32_t e = (int32_t)(dat[k]) - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+    }
+  }
+
+  // Sum 4 values from sum64l and sum64h into err
+  int64_t sum[4];
+  yy_storeu_256(sum, sum64);
+  err += sum[0] + sum[1] + sum[2] + sum[3];
+  return err;
+}
diff --git a/av1/encoder/x86/pickrst_sse4.c b/av1/encoder/x86/pickrst_sse4.c
index a067ab6..2326736 100644
--- a/av1/encoder/x86/pickrst_sse4.c
+++ b/av1/encoder/x86/pickrst_sse4.c
@@ -621,3 +621,209 @@
   err += sum[0] + sum[1];
   return err;
 }
+
+int64_t av1_highbd_pixel_proj_error_sse4_1(
+    const uint8_t *src8, int width, int height, int src_stride,
+    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+  int i, j, k;
+  const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
+  const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
+  __m128i sum64 = _mm_setzero_si128();
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
+  int64_t err = 0;
+  if (params->r[0] > 0 && params->r[1] > 0) {  // Both filters are enabled
+    const __m128i xq0 = _mm_set1_epi32(xq[0]);
+    const __m128i xq1 = _mm_set1_epi32(xq[1]);
+
+    for (i = 0; i < height; ++i) {
+      __m128i sum32 = _mm_setzero_si128();
+      for (j = 0; j <= width - 8; j += 8) {
+        // Load 8x pixels from source image
+        const __m128i s0 = xx_loadu_128(src + j);
+        // s0 = [7 6 5 4 3 2 1 0] as i16 (indices of src[])
+
+        // Load 8x pixels from corrupted image
+        const __m128i d0 = xx_loadu_128(dat + j);
+        // d0 = [7 6 5 4 3 2 1 0] as i16 (indices of dat[])
+
+        // Shift each pixel value up by SGRPROJ_RST_BITS
+        const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS);
+
+        // Split u0 into two halves and pad each from u16 to i32
+        const __m128i u0l = _mm_cvtepu16_epi32(u0);
+        const __m128i u0h = _mm_cvtepu16_epi32(_mm_srli_si128(u0, 8));
+        // u0h = [7 6 5 4] as i32, u0l = [3 2 1 0] as i32, all dat[] indices
+
+        // Load 8 pixels from first and second filtered images
+        const __m128i flt0l = xx_loadu_128(flt0 + j);
+        const __m128i flt0h = xx_loadu_128(flt0 + j + 4);
+        const __m128i flt1l = xx_loadu_128(flt1 + j);
+        const __m128i flt1h = xx_loadu_128(flt1 + j + 4);
+        // flt0 = [7 6 5 4] [3 2 1 0] as i32 (indices of flt0+j)
+        // flt1 = [7 6 5 4] [3 2 1 0] as i32 (indices of flt1+j)
+
+        // Subtract shifted corrupt image from each filtered image
+        // This gives our two basis vectors for the projection
+        const __m128i flt0l_subu = _mm_sub_epi32(flt0l, u0l);
+        const __m128i flt0h_subu = _mm_sub_epi32(flt0h, u0h);
+        const __m128i flt1l_subu = _mm_sub_epi32(flt1l, u0l);
+        const __m128i flt1h_subu = _mm_sub_epi32(flt1h, u0h);
+        // flt?h_subu = [ f[7]-u[7] f[6]-u[6] f[5]-u[5] f[4]-u[4] ] as i32
+        // flt?l_subu = [ f[3]-u[3] f[2]-u[2] f[1]-u[1] f[0]-u[0] ] as i32
+
+        // Multiply each basis vector by the corresponding coefficient
+        const __m128i v0l = _mm_mullo_epi32(flt0l_subu, xq0);
+        const __m128i v0h = _mm_mullo_epi32(flt0h_subu, xq0);
+        const __m128i v1l = _mm_mullo_epi32(flt1l_subu, xq1);
+        const __m128i v1h = _mm_mullo_epi32(flt1h_subu, xq1);
+
+        // Add together the contribution from each scaled basis vector
+        const __m128i vl = _mm_add_epi32(v0l, v1l);
+        const __m128i vh = _mm_add_epi32(v0h, v1h);
+
+        // Right-shift v with appropriate rounding
+        const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);
+        const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);
+
+        // Saturate each i32 value to i16 and combine lower and upper halves
+        const __m128i vr = _mm_packs_epi32(vrl, vrh);
+
+        // Add twin-subspace-sgr-filter to corrupt image then subtract source
+        const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);
+
+        // Calculate squared error and add adjacent values
+        const __m128i err0 = _mm_madd_epi16(e0, e0);
+
+        sum32 = _mm_add_epi32(sum32, err0);
+      }
+
+      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+      sum64 = _mm_add_epi64(sum64, sum32l);
+      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+      sum64 = _mm_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels in this row (modulo 8)
+      for (k = j; k < width; ++k) {
+        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt0 += flt0_stride;
+      flt1 += flt1_stride;
+    }
+  } else if (params->r[0] > 0 || params->r[1] > 0) {  // Only one filter enabled
+    const int32_t xq_on = (params->r[0] > 0) ? xq[0] : xq[1];
+    const __m128i xq_active = _mm_set1_epi32(xq_on);
+    const __m128i xq_inactive =
+        _mm_set1_epi32(-xq_on * (1 << SGRPROJ_RST_BITS));
+    const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
+    const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
+    for (i = 0; i < height; ++i) {
+      __m128i sum32 = _mm_setzero_si128();
+      for (j = 0; j <= width - 8; j += 8) {
+        // Load 8x pixels from source image
+        const __m128i s0 = xx_loadu_128(src + j);
+        // s0 = [7 6 5 4 3 2 1 0] as u16 (indices of src[])
+
+        // Load 8x pixels from corrupted image and pad each u16 to i32
+        const __m128i d0 = xx_loadu_128(dat + j);
+        const __m128i d0h = _mm_cvtepu16_epi32(_mm_srli_si128(d0, 8));
+        const __m128i d0l = _mm_cvtepu16_epi32(d0);
+        // d0h, d0l = [7 6 5 4], [3 2 1 0] as u32 (indices of dat[])
+
+        // Load 8 pixels from the filtered image
+        const __m128i flth = xx_loadu_128(flt + j + 4);
+        const __m128i fltl = xx_loadu_128(flt + j);
+        // flth, fltl = [7 6 5 4], [3 2 1 0] as i32 (indices of flt+j)
+
+        const __m128i flth_xq = _mm_mullo_epi32(flth, xq_active);
+        const __m128i fltl_xq = _mm_mullo_epi32(fltl, xq_active);
+        const __m128i d0h_xq = _mm_mullo_epi32(d0h, xq_inactive);
+        const __m128i d0l_xq = _mm_mullo_epi32(d0l, xq_inactive);
+
+        const __m128i vh = _mm_add_epi32(flth_xq, d0h_xq);
+        const __m128i vl = _mm_add_epi32(fltl_xq, d0l_xq);
+        // vh = [ xq0(f[7]-d[7]) xq0(f[6]-d[6]) xq0(f[5]-d[5]) xq0(f[4]-d[4]) ]
+        // vl = [ xq0(f[3]-d[3]) xq0(f[2]-d[2]) xq0(f[1]-d[1]) xq0(f[0]-d[0]) ]
+
+        // Shift this down with appropriate rounding
+        const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);
+        const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);
+
+        // Saturate vr0 and vr1 from i32 to i16 then pack together
+        const __m128i vr = _mm_packs_epi32(vrl, vrh);
+
+        // Subtract twin-subspace-sgr filtered from source image to get error
+        const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);
+
+        // Calculate squared error and add adjacent values
+        const __m128i err0 = _mm_madd_epi16(e0, e0);
+
+        sum32 = _mm_add_epi32(sum32, err0);
+      }
+
+      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+      sum64 = _mm_add_epi64(sum64, sum32l);
+      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+      sum64 = _mm_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels in this row (modulo 8)
+      for (k = j; k < width; ++k) {
+        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+        int32_t v = xq_on * (flt[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt += flt_stride;
+    }
+  } else {  // Neither filter is enabled
+    for (i = 0; i < height; ++i) {
+      __m128i sum32 = _mm_setzero_si128();
+      for (j = 0; j <= width - 16; j += 16) {
+        // Load 2x8 u16 from source image
+        const __m128i s0 = xx_loadu_128(src + j);
+        const __m128i s1 = xx_loadu_128(src + j + 8);
+        // Load 2x8 u16 from corrupted image
+        const __m128i d0 = xx_loadu_128(dat + j);
+        const __m128i d1 = xx_loadu_128(dat + j + 8);
+
+        // Subtract corrupted image from source image
+        const __m128i diff0 = _mm_sub_epi16(d0, s0);
+        const __m128i diff1 = _mm_sub_epi16(d1, s1);
+
+        // Square error and add adjacent values
+        const __m128i err0 = _mm_madd_epi16(diff0, diff0);
+        const __m128i err1 = _mm_madd_epi16(diff1, diff1);
+
+        sum32 = _mm_add_epi32(sum32, err0);
+        sum32 = _mm_add_epi32(sum32, err1);
+      }
+
+      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
+      sum64 = _mm_add_epi64(sum64, sum32l);
+      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
+      sum64 = _mm_add_epi64(sum64, sum32h);
+
+      // Process remaining pixels (modulu 8)
+      for (k = j; k < width; ++k) {
+        const int32_t e = (int32_t)(dat[k]) - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+    }
+  }
+
+  // Sum 4 values from sum64l and sum64h into err
+  int64_t sum[2];
+  xx_storeu_128(sum, sum64);
+  err += sum[0] + sum[1];
+  return err;
+}
diff --git a/test/pickrst_test.cc b/test/pickrst_test.cc
index 040e8e8..68b6621 100644
--- a/test/pickrst_test.cc
+++ b/test/pickrst_test.cc
@@ -23,7 +23,7 @@
 
 #define MAX_DATA_BLOCK 384
 
-namespace {
+namespace pickrst_test_lowbd {
 static const int kIterations = 100;
 
 typedef int64_t (*lowbd_pixel_proj_error_func)(
@@ -184,4 +184,169 @@
                         ::testing::Values(av1_lowbd_pixel_proj_error_avx2));
 #endif  // HAVE_AVX2
 
-}  // namespace
+}  // namespace pickrst_test_lowbd
+
+namespace pickrst_test_highbd {
+static const int kIterations = 100;
+
+typedef int64_t (*highbd_pixel_proj_error_func)(
+    const uint8_t *src8, int width, int height, int src_stride,
+    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params);
+
+typedef libaom_test::FuncParam<highbd_pixel_proj_error_func> TestFuncs;
+
+////////////////////////////////////////////////////////////////////////////////
+// High bit-depth
+////////////////////////////////////////////////////////////////////////////////
+
+typedef ::testing::tuple<const highbd_pixel_proj_error_func>
+    PixelProjErrorTestParam;
+
+class PixelProjHighbdErrorTest
+    : public ::testing::TestWithParam<PixelProjErrorTestParam> {
+ public:
+  virtual void SetUp() {
+    target_func_ = GET_PARAM(0);
+    src_ =
+        (uint16_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*src_));
+    dgd_ =
+        (uint16_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*dgd_));
+    flt0_ =
+        (int32_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*flt0_));
+    flt1_ =
+        (int32_t *)aom_malloc(MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*flt1_));
+  }
+  virtual void TearDown() {
+    aom_free(src_);
+    aom_free(dgd_);
+    aom_free(flt0_);
+    aom_free(flt1_);
+  }
+  void runPixelProjErrorTest(int32_t run_times);
+  void runPixelProjErrorTest_ExtremeValues();
+
+ private:
+  highbd_pixel_proj_error_func target_func_;
+  ACMRandom rng_;
+  uint16_t *src_;
+  uint16_t *dgd_;
+  int32_t *flt0_;
+  int32_t *flt1_;
+};
+
+void PixelProjHighbdErrorTest::runPixelProjErrorTest(int32_t run_times) {
+  int h_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+  int v_end = run_times != 1 ? 128 : (rng_.Rand16() % MAX_DATA_BLOCK) + 1;
+  const int dgd_stride = MAX_DATA_BLOCK;
+  const int src_stride = MAX_DATA_BLOCK;
+  const int flt0_stride = MAX_DATA_BLOCK;
+  const int flt1_stride = MAX_DATA_BLOCK;
+  sgr_params_type params;
+  int xq[2];
+  const int iters = run_times == 1 ? kIterations : 4;
+  for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+    int64_t err_ref = 0, err_test = 1;
+    for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+      dgd_[i] = rng_.Rand16() % (1 << 12);
+      src_[i] = rng_.Rand16() % (1 << 12);
+      flt0_[i] = rng_.Rand15Signed();
+      flt1_[i] = rng_.Rand15Signed();
+    }
+    xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+    xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+    params.r[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+    params.r[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+    params.s[0] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter % 2);
+    params.s[1] = run_times == 1 ? (rng_.Rand8() % MAX_RADIUS) : (iter / 2);
+    uint8_t *dgd8 = CONVERT_TO_BYTEPTR(dgd_);
+    uint8_t *src8 = CONVERT_TO_BYTEPTR(src_);
+
+    aom_usec_timer timer;
+    aom_usec_timer_start(&timer);
+    for (int i = 0; i < run_times; ++i) {
+      err_ref = av1_highbd_pixel_proj_error_c(
+          src8, h_end, v_end, src_stride, dgd8, dgd_stride, flt0_, flt0_stride,
+          flt1_, flt1_stride, xq, &params);
+    }
+    aom_usec_timer_mark(&timer);
+    const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+    aom_usec_timer_start(&timer);
+    for (int i = 0; i < run_times; ++i) {
+      err_test =
+          target_func_(src8, h_end, v_end, src_stride, dgd8, dgd_stride, flt0_,
+                       flt0_stride, flt1_, flt1_stride, xq, &params);
+    }
+    aom_usec_timer_mark(&timer);
+    const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
+    if (run_times > 10) {
+      printf("r0 %d r1 %d %3dx%-3d:%7.2f/%7.2fns (%3.2f)\n", params.r[0],
+             params.r[1], h_end, v_end, time1, time2, time1 / time2);
+    }
+    ASSERT_EQ(err_ref, err_test);
+  }
+}
+
+void PixelProjHighbdErrorTest::runPixelProjErrorTest_ExtremeValues() {
+  const int h_start = 0;
+  int h_end = 192;
+  const int v_start = 0;
+  int v_end = 192;
+  const int dgd_stride = MAX_DATA_BLOCK;
+  const int src_stride = MAX_DATA_BLOCK;
+  const int flt0_stride = MAX_DATA_BLOCK;
+  const int flt1_stride = MAX_DATA_BLOCK;
+  sgr_params_type params;
+  int xq[2];
+  const int iters = kIterations;
+  for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
+    int64_t err_ref = 0, err_test = 1;
+    for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
+      dgd_[i] = 0;
+      src_[i] = (1 << 12) - 1;
+      flt0_[i] = rng_.Rand15Signed();
+      flt1_[i] = rng_.Rand15Signed();
+    }
+    xq[0] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+    xq[1] = rng_.Rand8() % (1 << SGRPROJ_PRJ_BITS);
+    params.r[0] = rng_.Rand8() % MAX_RADIUS;
+    params.r[1] = rng_.Rand8() % MAX_RADIUS;
+    params.s[0] = rng_.Rand8() % MAX_RADIUS;
+    params.s[1] = rng_.Rand8() % MAX_RADIUS;
+    uint8_t *dgd8 = CONVERT_TO_BYTEPTR(dgd_);
+    uint8_t *src8 = CONVERT_TO_BYTEPTR(src_);
+
+    err_ref = av1_highbd_pixel_proj_error_c(
+        src8, h_end - h_start, v_end - v_start, src_stride, dgd8, dgd_stride,
+        flt0_, flt0_stride, flt1_, flt1_stride, xq, &params);
+
+    err_test = target_func_(src8, h_end - h_start, v_end - v_start, src_stride,
+                            dgd8, dgd_stride, flt0_, flt0_stride, flt1_,
+                            flt1_stride, xq, &params);
+
+    ASSERT_EQ(err_ref, err_test);
+  }
+}
+
+TEST_P(PixelProjHighbdErrorTest, RandomValues) { runPixelProjErrorTest(1); }
+
+TEST_P(PixelProjHighbdErrorTest, ExtremeValues) {
+  runPixelProjErrorTest_ExtremeValues();
+}
+
+TEST_P(PixelProjHighbdErrorTest, DISABLED_Speed) {
+  runPixelProjErrorTest(200000);
+}
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_CASE_P(SSE4_1, PixelProjHighbdErrorTest,
+                        ::testing::Values(av1_highbd_pixel_proj_error_sse4_1));
+#endif  // HAVE_SSE4_1
+
+#if HAVE_AVX2
+
+INSTANTIATE_TEST_CASE_P(AVX2, PixelProjHighbdErrorTest,
+                        ::testing::Values(av1_highbd_pixel_proj_error_avx2));
+#endif  // HAVE_AVX2
+
+}  // namespace pickrst_test_highbd