Vectorize new highpass filter for loop-restoration

Change-Id: Ibe5d4933f599456cb496f636de244694bc786a4c
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index d4e0023..347b1de 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -786,8 +786,8 @@
   add_proto qw/void av1_selfguided_restoration/, "uint8_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int r, int eps, int32_t *tmpbuf";
   specialize qw/av1_selfguided_restoration sse4_1/;
 
-  add_proto qw/void av1_highpass_filter/, "uint8_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int r, int eps, int32_t *tmpbuf";
-  specialize qw/av1_highpass_filter/;
+  add_proto qw/void av1_highpass_filter/, "uint8_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int r, int eps";
+  specialize qw/av1_highpass_filter sse4_1/;
 
   if (aom_config("CONFIG_AOM_HIGHBITDEPTH") eq "yes") {
     add_proto qw/void apply_selfguided_restoration_highbd/, "uint16_t *dat, int width, int height, int stride, int bit_depth, int eps, int *xqd, uint16_t *dst, int dst_stride, int32_t *tmpbuf";
@@ -796,8 +796,8 @@
     add_proto qw/void av1_selfguided_restoration_highbd/, "uint16_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int bit_depth, int r, int eps, int32_t *tmpbuf";
     specialize qw/av1_selfguided_restoration_highbd sse4_1/;
 
-    add_proto qw/void av1_highpass_filter_highbd/, "uint16_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int r, int eps, int32_t *tmpbuf";
-    specialize qw/av1_highpass_filter_highbd/;
+    add_proto qw/void av1_highpass_filter_highbd/, "uint16_t *dgd, int width, int height, int stride, int32_t *dst, int dst_stride, int r, int eps";
+    specialize qw/av1_highpass_filter_highbd sse4_1/;
   }
 }
 
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index f05d5a1..7ecf01d 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -750,108 +750,96 @@
                                       tmpbuf);
 }
 
-#if USE_HIGHPASS_IN_SGRPROJ
-void av1_highpass_filter_internal(int32_t *A, int width, int height, int stride,
-                                  int corner, int edge, int32_t *tmpbuf) {
-  const int center = (1 << SGRPROJ_RST_BITS) - 4 * (corner + edge);
+void av1_highpass_filter_c(uint8_t *dgd, int width, int height, int stride,
+                           int32_t *dst, int dst_stride, int corner, int edge) {
   int i, j;
-  int buf_stride = ((width + 3) & ~3) + 16;
+  const int center = (1 << SGRPROJ_RST_BITS) - 4 * (corner + edge);
 
   i = 0;
   j = 0;
   {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] + edge * (A[k + 1] + A[k + stride] + A[k] * 2) +
-                corner * (A[k + stride + 1] + A[k + 1] + A[k + stride] + A[k]);
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k + 1] + dgd[k + stride] + dgd[k] * 2) +
+        corner * (dgd[k + stride + 1] + dgd[k + 1] + dgd[k + stride] + dgd[k]);
   }
   i = 0;
   j = width - 1;
   {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] + edge * (A[k - 1] + A[k + stride] + A[k] * 2) +
-                corner * (A[k + stride - 1] + A[k - 1] + A[k + stride] + A[k]);
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k - 1] + dgd[k + stride] + dgd[k] * 2) +
+        corner * (dgd[k + stride - 1] + dgd[k - 1] + dgd[k + stride] + dgd[k]);
   }
   i = height - 1;
   j = 0;
   {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] + edge * (A[k + 1] + A[k - stride] + A[k] * 2) +
-                corner * (A[k - stride + 1] + A[k + 1] + A[k - stride] + A[k]);
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k + 1] + dgd[k - stride] + dgd[k] * 2) +
+        corner * (dgd[k - stride + 1] + dgd[k + 1] + dgd[k - stride] + dgd[k]);
   }
   i = height - 1;
   j = width - 1;
   {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] + edge * (A[k - 1] + A[k - stride] + A[k] * 2) +
-                corner * (A[k - stride - 1] + A[k - 1] + A[k - stride] + A[k]);
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k - 1] + dgd[k - stride] + dgd[k] * 2) +
+        corner * (dgd[k - stride - 1] + dgd[k - 1] + dgd[k - stride] + dgd[k]);
   }
   i = 0;
   for (j = 1; j < width - 1; ++j) {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] =
-        center * A[k] + edge * (A[k - 1] + A[k + stride] + A[k + 1] + A[k]) +
-        corner * (A[k + stride - 1] + A[k + stride + 1] + A[k - 1] + A[k + 1]);
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - 1] + dgd[k + stride] + dgd[k + 1] + dgd[k]) +
+             corner * (dgd[k + stride - 1] + dgd[k + stride + 1] + dgd[k - 1] +
+                       dgd[k + 1]);
   }
   i = height - 1;
   for (j = 1; j < width - 1; ++j) {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] =
-        center * A[k] + edge * (A[k - 1] + A[k - stride] + A[k + 1] + A[k]) +
-        corner * (A[k - stride - 1] + A[k - stride + 1] + A[k - 1] + A[k + 1]);
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - 1] + dgd[k - stride] + dgd[k + 1] + dgd[k]) +
+             corner * (dgd[k - stride - 1] + dgd[k - stride + 1] + dgd[k - 1] +
+                       dgd[k + 1]);
   }
   j = 0;
   for (i = 1; i < height - 1; ++i) {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] +
-                edge * (A[k - stride] + A[k + 1] + A[k + stride] + A[k]) +
-                corner * (A[k + stride + 1] + A[k - stride + 1] +
-                          A[k - stride] + A[k + stride]);
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - stride] + dgd[k + 1] + dgd[k + stride] + dgd[k]) +
+             corner * (dgd[k + stride + 1] + dgd[k - stride + 1] +
+                       dgd[k - stride] + dgd[k + stride]);
   }
   j = width - 1;
   for (i = 1; i < height - 1; ++i) {
     const int k = i * stride + j;
-    const int l = i * buf_stride + j;
-    tmpbuf[l] = center * A[k] +
-                edge * (A[k - stride] + A[k - 1] + A[k + stride] + A[k]) +
-                corner * (A[k + stride - 1] + A[k - stride - 1] +
-                          A[k - stride] + A[k + stride]);
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k]) +
+             corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                       dgd[k - stride] + dgd[k + stride]);
   }
   for (i = 1; i < height - 1; ++i) {
     for (j = 1; j < width - 1; ++j) {
       const int k = i * stride + j;
-      const int l = i * buf_stride + j;
-      tmpbuf[l] = center * A[k] +
-                  edge * (A[k - stride] + A[k - 1] + A[k + stride] + A[k + 1]) +
-                  corner * (A[k + stride - 1] + A[k - stride - 1] +
-                            A[k - stride + 1] + A[k + stride + 1]);
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k + 1]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride + 1] + dgd[k + stride + 1]);
     }
   }
-  for (i = 0; i < height; ++i) {
-    memcpy(A + stride * i, tmpbuf + buf_stride * i, sizeof(*A) * width);
-  }
 }
 
-void av1_highpass_filter_c(uint8_t *dgd, int width, int height, int stride,
-                           int32_t *dst, int dst_stride, int corner, int edge,
-                           int32_t *tmpbuf) {
-  int i, j;
-  for (i = 0; i < height; ++i) {
-    for (j = 0; j < width; ++j) {
-      dst[i * dst_stride + j] = dgd[i * stride + j];
-    }
-  }
-  av1_highpass_filter_internal(dst, width, height, dst_stride, corner, edge,
-                               tmpbuf);
-}
-#endif  // USE_HIGHPASS_IN_SGRPROJ
-
 void apply_selfguided_restoration_c(uint8_t *dat, int width, int height,
                                     int stride, int eps, int *xqd, uint8_t *dst,
                                     int dst_stride, int32_t *tmpbuf) {
@@ -863,7 +851,7 @@
   assert(width * height <= RESTORATION_TILEPELS_MAX);
 #if USE_HIGHPASS_IN_SGRPROJ
   av1_highpass_filter_c(dat, width, height, stride, flt1, width,
-                        sgr_params[eps].corner, sgr_params[eps].edge, tmpbuf2);
+                        sgr_params[eps].corner, sgr_params[eps].edge);
 #else
   av1_selfguided_restoration_c(dat, width, height, stride, flt1, width,
                                sgr_params[eps].r1, sgr_params[eps].e1, tmpbuf2);
@@ -1041,15 +1029,93 @@
 
 void av1_highpass_filter_highbd_c(uint16_t *dgd, int width, int height,
                                   int stride, int32_t *dst, int dst_stride,
-                                  int corner, int edge, int32_t *tmpbuf) {
+                                  int corner, int edge) {
   int i, j;
-  for (i = 0; i < height; ++i) {
-    for (j = 0; j < width; ++j) {
-      dst[i * dst_stride + j] = dgd[i * stride + j];
+  const int center = (1 << SGRPROJ_RST_BITS) - 4 * (corner + edge);
+
+  i = 0;
+  j = 0;
+  {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k + 1] + dgd[k + stride] + dgd[k] * 2) +
+        corner * (dgd[k + stride + 1] + dgd[k + 1] + dgd[k + stride] + dgd[k]);
+  }
+  i = 0;
+  j = width - 1;
+  {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k - 1] + dgd[k + stride] + dgd[k] * 2) +
+        corner * (dgd[k + stride - 1] + dgd[k - 1] + dgd[k + stride] + dgd[k]);
+  }
+  i = height - 1;
+  j = 0;
+  {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k + 1] + dgd[k - stride] + dgd[k] * 2) +
+        corner * (dgd[k - stride + 1] + dgd[k + 1] + dgd[k - stride] + dgd[k]);
+  }
+  i = height - 1;
+  j = width - 1;
+  {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] =
+        center * dgd[k] + edge * (dgd[k - 1] + dgd[k - stride] + dgd[k] * 2) +
+        corner * (dgd[k - stride - 1] + dgd[k - 1] + dgd[k - stride] + dgd[k]);
+  }
+  i = 0;
+  for (j = 1; j < width - 1; ++j) {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - 1] + dgd[k + stride] + dgd[k + 1] + dgd[k]) +
+             corner * (dgd[k + stride - 1] + dgd[k + stride + 1] + dgd[k - 1] +
+                       dgd[k + 1]);
+  }
+  i = height - 1;
+  for (j = 1; j < width - 1; ++j) {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - 1] + dgd[k - stride] + dgd[k + 1] + dgd[k]) +
+             corner * (dgd[k - stride - 1] + dgd[k - stride + 1] + dgd[k - 1] +
+                       dgd[k + 1]);
+  }
+  j = 0;
+  for (i = 1; i < height - 1; ++i) {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - stride] + dgd[k + 1] + dgd[k + stride] + dgd[k]) +
+             corner * (dgd[k + stride + 1] + dgd[k - stride + 1] +
+                       dgd[k - stride] + dgd[k + stride]);
+  }
+  j = width - 1;
+  for (i = 1; i < height - 1; ++i) {
+    const int k = i * stride + j;
+    const int l = i * dst_stride + j;
+    dst[l] = center * dgd[k] +
+             edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k]) +
+             corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                       dgd[k - stride] + dgd[k + stride]);
+  }
+  for (i = 1; i < height - 1; ++i) {
+    for (j = 1; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k + 1]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride + 1] + dgd[k + stride + 1]);
     }
   }
-  av1_highpass_filter_internal(dst, width, height, dst_stride, corner, edge,
-                               tmpbuf);
 }
 
 void apply_selfguided_restoration_highbd_c(uint16_t *dat, int width, int height,
@@ -1064,8 +1130,7 @@
   assert(width * height <= RESTORATION_TILEPELS_MAX);
 #if USE_HIGHPASS_IN_SGRPROJ
   av1_highpass_filter_highbd_c(dat, width, height, stride, flt1, width,
-                               sgr_params[eps].corner, sgr_params[eps].edge,
-                               tmpbuf2);
+                               sgr_params[eps].corner, sgr_params[eps].edge);
 #else
   av1_selfguided_restoration_highbd_c(dat, width, height, stride, flt1, width,
                                       bit_depth, sgr_params[eps].r1,
diff --git a/av1/common/x86/selfguided_sse4.c b/av1/common/x86/selfguided_sse4.c
index 9c89271..b61991f 100644
--- a/av1/common/x86/selfguided_sse4.c
+++ b/av1/common/x86/selfguided_sse4.c
@@ -883,6 +883,167 @@
   }
 }
 
+void av1_highpass_filter_sse4_1(uint8_t *dgd, int width, int height, int stride,
+                                int32_t *dst, int dst_stride, int corner,
+                                int edge) {
+  int i, j;
+  const int center = (1 << SGRPROJ_RST_BITS) - 4 * (corner + edge);
+
+  {
+    i = 0;
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k + 1] + dgd[k + stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k + stride + 1] + dgd[k + 1] + dgd[k + stride] + dgd[k]);
+    }
+    for (j = 1; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] = center * dgd[k] +
+               edge * (dgd[k - 1] + dgd[k + stride] + dgd[k + 1] + dgd[k]) +
+               corner * (dgd[k + stride - 1] + dgd[k + stride + 1] +
+                         dgd[k - 1] + dgd[k + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k - 1] + dgd[k + stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k + stride - 1] + dgd[k - 1] + dgd[k + stride] + dgd[k]);
+    }
+  }
+  {
+    i = height - 1;
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k + 1] + dgd[k - stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k - stride + 1] + dgd[k + 1] + dgd[k - stride] + dgd[k]);
+    }
+    for (j = 1; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] = center * dgd[k] +
+               edge * (dgd[k - 1] + dgd[k - stride] + dgd[k + 1] + dgd[k]) +
+               corner * (dgd[k - stride - 1] + dgd[k - stride + 1] +
+                         dgd[k - 1] + dgd[k + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k - 1] + dgd[k - stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k - stride - 1] + dgd[k - 1] + dgd[k - stride] + dgd[k]);
+    }
+  }
+  __m128i center_ = _mm_set1_epi16(center);
+  __m128i edge_ = _mm_set1_epi16(edge);
+  __m128i corner_ = _mm_set1_epi16(corner);
+  for (i = 1; i < height - 1; ++i) {
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k + 1] + dgd[k + stride] + dgd[k]) +
+          corner * (dgd[k + stride + 1] + dgd[k - stride + 1] +
+                    dgd[k - stride] + dgd[k + stride]);
+    }
+    // Process in units of 8 pixels at a time.
+    for (j = 1; j < width - 8; j += 8) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+
+      __m128i a = _mm_loadu_si128((__m128i *)&dgd[k - stride - 1]);
+      __m128i b = _mm_loadu_si128((__m128i *)&dgd[k - 1]);
+      __m128i c = _mm_loadu_si128((__m128i *)&dgd[k + stride - 1]);
+
+      __m128i tl = _mm_cvtepu8_epi16(a);
+      __m128i tr = _mm_cvtepu8_epi16(_mm_srli_si128(a, 8));
+      __m128i cl = _mm_cvtepu8_epi16(b);
+      __m128i cr = _mm_cvtepu8_epi16(_mm_srli_si128(b, 8));
+      __m128i bl = _mm_cvtepu8_epi16(c);
+      __m128i br = _mm_cvtepu8_epi16(_mm_srli_si128(c, 8));
+
+      __m128i x = _mm_alignr_epi8(cr, cl, 2);
+      __m128i y = _mm_add_epi16(_mm_add_epi16(_mm_alignr_epi8(tr, tl, 2), cl),
+                                _mm_add_epi16(_mm_alignr_epi8(br, bl, 2),
+                                              _mm_alignr_epi8(cr, cl, 4)));
+      __m128i z = _mm_add_epi16(_mm_add_epi16(tl, bl),
+                                _mm_add_epi16(_mm_alignr_epi8(tr, tl, 4),
+                                              _mm_alignr_epi8(br, bl, 4)));
+
+      __m128i res = _mm_add_epi16(_mm_mullo_epi16(x, center_),
+                                  _mm_add_epi16(_mm_mullo_epi16(y, edge_),
+                                                _mm_mullo_epi16(z, corner_)));
+
+      _mm_storeu_si128((__m128i *)&dst[l], _mm_cvtepi16_epi32(res));
+      _mm_storeu_si128((__m128i *)&dst[l + 4],
+                       _mm_cvtepi16_epi32(_mm_srli_si128(res, 8)));
+    }
+    // If there are enough pixels left in this row, do another batch of 4
+    // pixels.
+    for (; j < width - 4; j += 4) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+
+      __m128i a = _mm_loadl_epi64((__m128i *)&dgd[k - stride - 1]);
+      __m128i b = _mm_loadl_epi64((__m128i *)&dgd[k - 1]);
+      __m128i c = _mm_loadl_epi64((__m128i *)&dgd[k + stride - 1]);
+
+      __m128i tl = _mm_cvtepu8_epi16(a);
+      __m128i cl = _mm_cvtepu8_epi16(b);
+      __m128i bl = _mm_cvtepu8_epi16(c);
+
+      __m128i x = _mm_srli_si128(cl, 2);
+      __m128i y = _mm_add_epi16(
+          _mm_add_epi16(_mm_srli_si128(tl, 2), cl),
+          _mm_add_epi16(_mm_srli_si128(bl, 2), _mm_srli_si128(cl, 4)));
+      __m128i z = _mm_add_epi16(
+          _mm_add_epi16(tl, bl),
+          _mm_add_epi16(_mm_srli_si128(tl, 4), _mm_srli_si128(bl, 4)));
+
+      __m128i res = _mm_add_epi16(_mm_mullo_epi16(x, center_),
+                                  _mm_add_epi16(_mm_mullo_epi16(y, edge_),
+                                                _mm_mullo_epi16(z, corner_)));
+
+      _mm_storeu_si128((__m128i *)&dst[l], _mm_cvtepi16_epi32(res));
+    }
+    // Handle any leftover pixels
+    for (; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k + 1]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride + 1] + dgd[k + stride + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride] + dgd[k + stride]);
+    }
+  }
+}
+
 void apply_selfguided_restoration_sse4_1(uint8_t *dat, int width, int height,
                                          int stride, int eps, int *xqd,
                                          uint8_t *dst, int dst_stride,
@@ -894,8 +1055,8 @@
   int i, j;
   assert(width * height <= RESTORATION_TILEPELS_MAX);
 #if USE_HIGHPASS_IN_SGRPROJ
-  av1_highpass_filter_c(dat, width, height, stride, flt1, width,
-                        sgr_params[eps].corner, sgr_params[eps].edge, tmpbuf2);
+  av1_highpass_filter_sse4_1(dat, width, height, stride, flt1, width,
+                             sgr_params[eps].corner, sgr_params[eps].edge);
 #else
     av1_selfguided_restoration_sse4_1(dat, width, height, stride, flt1, width,
                                       sgr_params[eps].r1, sgr_params[eps].e1,
@@ -1427,6 +1588,137 @@
   }
 }
 
+void av1_highpass_filter_highbd_sse4_1(uint16_t *dgd, int width, int height,
+                                       int stride, int32_t *dst, int dst_stride,
+                                       int corner, int edge) {
+  int i, j;
+  const int center = (1 << SGRPROJ_RST_BITS) - 4 * (corner + edge);
+
+  {
+    i = 0;
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k + 1] + dgd[k + stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k + stride + 1] + dgd[k + 1] + dgd[k + stride] + dgd[k]);
+    }
+    for (j = 1; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] = center * dgd[k] +
+               edge * (dgd[k - 1] + dgd[k + stride] + dgd[k + 1] + dgd[k]) +
+               corner * (dgd[k + stride - 1] + dgd[k + stride + 1] +
+                         dgd[k - 1] + dgd[k + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k - 1] + dgd[k + stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k + stride - 1] + dgd[k - 1] + dgd[k + stride] + dgd[k]);
+    }
+  }
+  __m128i center_ = _mm_set1_epi32(center);
+  __m128i edge_ = _mm_set1_epi32(edge);
+  __m128i corner_ = _mm_set1_epi32(corner);
+  for (i = 1; i < height - 1; ++i) {
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k + 1] + dgd[k + stride] + dgd[k]) +
+          corner * (dgd[k + stride + 1] + dgd[k - stride + 1] +
+                    dgd[k - stride] + dgd[k + stride]);
+    }
+    // Process 4 pixels at a time
+    for (j = 1; j < width - 4; j += 4) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+
+      __m128i a = _mm_loadu_si128((__m128i *)&dgd[k - stride - 1]);
+      __m128i b = _mm_loadu_si128((__m128i *)&dgd[k - 1]);
+      __m128i c = _mm_loadu_si128((__m128i *)&dgd[k + stride - 1]);
+
+      __m128i tl = _mm_cvtepu16_epi32(a);
+      __m128i tr = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8));
+      __m128i cl = _mm_cvtepu16_epi32(b);
+      __m128i cr = _mm_cvtepu16_epi32(_mm_srli_si128(b, 8));
+      __m128i bl = _mm_cvtepu16_epi32(c);
+      __m128i br = _mm_cvtepu16_epi32(_mm_srli_si128(c, 8));
+
+      __m128i x = _mm_alignr_epi8(cr, cl, 4);
+      __m128i y = _mm_add_epi32(_mm_add_epi32(_mm_alignr_epi8(tr, tl, 4), cl),
+                                _mm_add_epi32(_mm_alignr_epi8(br, bl, 4),
+                                              _mm_alignr_epi8(cr, cl, 8)));
+      __m128i z = _mm_add_epi32(_mm_add_epi32(tl, bl),
+                                _mm_add_epi32(_mm_alignr_epi8(tr, tl, 8),
+                                              _mm_alignr_epi8(br, bl, 8)));
+
+      __m128i res = _mm_add_epi32(_mm_mullo_epi32(x, center_),
+                                  _mm_add_epi32(_mm_mullo_epi32(y, edge_),
+                                                _mm_mullo_epi32(z, corner_)));
+
+      _mm_storeu_si128((__m128i *)&dst[l], res);
+    }
+    // Handle any leftover pixels
+    for (; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k + 1]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride + 1] + dgd[k + stride + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] +
+          edge * (dgd[k - stride] + dgd[k - 1] + dgd[k + stride] + dgd[k]) +
+          corner * (dgd[k + stride - 1] + dgd[k - stride - 1] +
+                    dgd[k - stride] + dgd[k + stride]);
+    }
+  }
+  {
+    i = height - 1;
+    j = 0;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k + 1] + dgd[k - stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k - stride + 1] + dgd[k + 1] + dgd[k - stride] + dgd[k]);
+    }
+    for (j = 1; j < width - 1; ++j) {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] = center * dgd[k] +
+               edge * (dgd[k - 1] + dgd[k - stride] + dgd[k + 1] + dgd[k]) +
+               corner * (dgd[k - stride - 1] + dgd[k - stride + 1] +
+                         dgd[k - 1] + dgd[k + 1]);
+    }
+    j = width - 1;
+    {
+      const int k = i * stride + j;
+      const int l = i * dst_stride + j;
+      dst[l] =
+          center * dgd[k] + edge * (dgd[k - 1] + dgd[k - stride] + dgd[k] * 2) +
+          corner *
+              (dgd[k - stride - 1] + dgd[k - 1] + dgd[k - stride] + dgd[k]);
+    }
+  }
+}
+
 void apply_selfguided_restoration_highbd_sse4_1(
     uint16_t *dat, int width, int height, int stride, int bit_depth, int eps,
     int *xqd, uint16_t *dst, int dst_stride, int32_t *tmpbuf) {
@@ -1437,9 +1729,9 @@
   int i, j;
   assert(width * height <= RESTORATION_TILEPELS_MAX);
 #if USE_HIGHPASS_IN_SGRPROJ
-  av1_highpass_filter_highbd_c(dat, width, height, stride, flt1, width,
-                               sgr_params[eps].corner, sgr_params[eps].edge,
-                               tmpbuf2);
+  av1_highpass_filter_highbd_sse4_1(dat, width, height, stride, flt1, width,
+                                    sgr_params[eps].corner,
+                                    sgr_params[eps].edge);
 #else
   av1_selfguided_restoration_highbd_sse4_1(dat, width, height, stride, flt1,
                                            width, bit_depth, sgr_params[eps].r1,
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 55950e1..a9abbe0 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -295,8 +295,7 @@
       uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
 #if USE_HIGHPASS_IN_SGRPROJ
       av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
-                                 sgr_params[ep].corner, sgr_params[ep].edge,
-                                 tmpbuf2);
+                                 sgr_params[ep].corner, sgr_params[ep].edge);
 #else
       av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                         width, bit_depth, sgr_params[ep].r1,
@@ -309,7 +308,7 @@
 #endif
 #if USE_HIGHPASS_IN_SGRPROJ
       av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
-                          sgr_params[ep].corner, sgr_params[ep].edge, tmpbuf2);
+                          sgr_params[ep].corner, sgr_params[ep].edge);
 #else
     av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                                sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);