Handle non-multiple-of-4 widths in SSE4.1 self-guided filter

Adjust the vectorized filter so that it can handle tile widths
which are not a multiple of 4, so we do not have to fall back
to the C version of the filter.

Negligible speed impact for tiles with widths which are multiples
of 4, and greatly improves speed on tiles with non-multiple-of-4
widths.

Change-Id: Iae9d14f812c52c6f66910d27da1d8e98930df7ba
diff --git a/av1/common/x86/selfguided_sse4.c b/av1/common/x86/selfguided_sse4.c
index 9ec89e2..41dbbaf 100644
--- a/av1/common/x86/selfguided_sse4.c
+++ b/av1/common/x86/selfguided_sse4.c
@@ -41,7 +41,7 @@
                                 x_by_xplus1[_mm_extract_epi32(z, 1)],
                                 x_by_xplus1[_mm_extract_epi32(z, 0)]);
 
-  _mm_store_si128((__m128i *)&A[idx], a_res);
+  _mm_storeu_si128((__m128i *)&A[idx], a_res);
 
   __m128i rounding_res = _mm_set1_epi32((1 << SGRPROJ_RECIP_BITS) >> 1);
   __m128i a_complement = _mm_sub_epi32(_mm_set1_epi32(SGRPROJ_SGR), a_res);
@@ -50,7 +50,7 @@
   __m128i b_res =
       _mm_srli_epi32(_mm_add_epi32(b_int, rounding_res), SGRPROJ_RECIP_BITS);
 
-  _mm_store_si128((__m128i *)&B[idx], b_res);
+  _mm_storeu_si128((__m128i *)&B[idx], b_res);
 }
 
 static void selfguided_restoration_1(uint8_t *src, int width, int height,
@@ -59,8 +59,11 @@
   int i, j;
 
   // Vertical sum
-  assert(!(width & 3));
-  for (j = 0; j < width; j += 4) {
+  // When the width is not a multiple of 4, we know that 'stride' is rounded up
+  // to a multiple of 4. So it is safe for this loop to calculate extra columns
+  // at the right-hand edge of the frame.
+  int width_extend = (width + 3) & ~3;
+  for (j = 0; j < width_extend; j += 4) {
     __m128i a, b, x, y, x2, y2;
     __m128i sum, sum_sq, tmp;
 
@@ -140,12 +143,20 @@
     s = _mm_set1_epi32(sgrproj_mtable[eps - 1][3 * h - 1]);
 
     // Re-align a1 and b1 so that they start at index i * buf_stride + 3
-    a1 = _mm_alignr_epi8(a2, a1, 12);
-    b1 = _mm_alignr_epi8(b2, b1, 12);
-    a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + 7]);
-    b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + 7]);
+    a2 = _mm_alignr_epi8(a2, a1, 12);
+    b2 = _mm_alignr_epi8(b2, b1, 12);
 
-    for (j = 4; j < width - 4; j += 4) {
+    // Note: When the width is not a multiple of 4, this loop may end up
+    // writing to the last 4 columns of the frame, potentially with incorrect
+    // values (especially for r=2 and r=3).
+    // This is fine, since we fix up those values in the block after this
+    // loop, and in exchange we never have more than four values to
+    // write / fix up after this loop finishes.
+    for (j = 4; j < width_extend - 4; j += 4) {
+      a1 = a2;
+      b1 = b2;
+      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 3]);
+      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 3]);
       /* Loop invariant: At this point,
          a1 = original A[i * buf_stride + j - 1 : i * buf_stride + j + 3]
          a2 = original A[i * buf_stride + j + 3 : i * buf_stride + j + 7]
@@ -157,12 +168,38 @@
                                                 _mm_alignr_epi8(a2, a1, 8)));
       calc_block(sum_, sum_sq_, n, one_over_n, s, bit_depth, i * buf_stride + j,
                  A, B);
-
-      a1 = a2;
-      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 7]);
-      b1 = b2;
-      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 7]);
     }
+    __m128i a3 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 3]);
+    __m128i b3 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 3]);
+
+    j = width - 4;
+    switch (width % 4) {
+      case 0:
+        a1 = a2;
+        b1 = b2;
+        a2 = a3;
+        b2 = b3;
+        break;
+      case 1:
+        a1 = _mm_alignr_epi8(a2, a1, 4);
+        b1 = _mm_alignr_epi8(b2, b1, 4);
+        a2 = _mm_alignr_epi8(a3, a2, 4);
+        b2 = _mm_alignr_epi8(b3, b2, 4);
+        break;
+      case 2:
+        a1 = _mm_alignr_epi8(a2, a1, 8);
+        b1 = _mm_alignr_epi8(b2, b1, 8);
+        a2 = _mm_alignr_epi8(a3, a2, 8);
+        b2 = _mm_alignr_epi8(b3, b2, 8);
+        break;
+      case 3:
+        a1 = _mm_alignr_epi8(a2, a1, 12);
+        b1 = _mm_alignr_epi8(b2, b1, 12);
+        a2 = _mm_alignr_epi8(a3, a2, 12);
+        b2 = _mm_alignr_epi8(b3, b2, 12);
+        break;
+    }
+
     // Zero out the data loaded from "off the edge" of the array
     __m128i zero = _mm_setzero_si128();
     a2 = _mm_blend_epi16(a2, zero, 0xfc);
@@ -189,8 +226,8 @@
   int i, j;
 
   // Vertical sum
-  assert(!(width & 3));
-  for (j = 0; j < width; j += 4) {
+  int width_extend = (width + 3) & ~3;
+  for (j = 0; j < width_extend; j += 4) {
     __m128i a, b, c, c2, x, y, x2, y2;
     __m128i sum, sum_sq, tmp;
 
@@ -290,16 +327,18 @@
                B);
 
     // Re-align a1 and b1 so that they start at index i * buf_stride + 2
-    a1 = _mm_alignr_epi8(a2, a1, 8);
-    b1 = _mm_alignr_epi8(b2, b1, 8);
-    a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + 6]);
-    b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + 6]);
+    a2 = _mm_alignr_epi8(a2, a1, 8);
+    b2 = _mm_alignr_epi8(b2, b1, 8);
 
     n = _mm_set1_epi32(5 * h);
     one_over_n = _mm_set1_epi32(one_by_x[5 * h - 1]);
     s = _mm_set1_epi32(sgrproj_mtable[eps - 1][5 * h - 1]);
 
-    for (j = 4; j < width - 4; j += 4) {
+    for (j = 4; j < width_extend - 4; j += 4) {
+      a1 = a2;
+      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 2]);
+      b1 = b2;
+      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 2]);
       /* Loop invariant: At this point,
          a1 = original A[i * buf_stride + j - 2 : i * buf_stride + j + 2]
          a2 = original A[i * buf_stride + j + 2 : i * buf_stride + j + 6]
@@ -316,12 +355,40 @@
 
       calc_block(sum_, sum_sq_, n, one_over_n, s, bit_depth, i * buf_stride + j,
                  A, B);
-
-      a1 = a2;
-      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 6]);
-      b1 = b2;
-      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 6]);
     }
+    // If the width is not a multiple of 4, we need to reset j to width - 4
+    // and adjust a1, a2, b1, b2 so that the loop invariant above is maintained
+    __m128i a3 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 2]);
+    __m128i b3 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 2]);
+
+    j = width - 4;
+    switch (width % 4) {
+      case 0:
+        a1 = a2;
+        b1 = b2;
+        a2 = a3;
+        b2 = b3;
+        break;
+      case 1:
+        a1 = _mm_alignr_epi8(a2, a1, 4);
+        b1 = _mm_alignr_epi8(b2, b1, 4);
+        a2 = _mm_alignr_epi8(a3, a2, 4);
+        b2 = _mm_alignr_epi8(b3, b2, 4);
+        break;
+      case 2:
+        a1 = _mm_alignr_epi8(a2, a1, 8);
+        b1 = _mm_alignr_epi8(b2, b1, 8);
+        a2 = _mm_alignr_epi8(a3, a2, 8);
+        b2 = _mm_alignr_epi8(b3, b2, 8);
+        break;
+      case 3:
+        a1 = _mm_alignr_epi8(a2, a1, 12);
+        b1 = _mm_alignr_epi8(b2, b1, 12);
+        a2 = _mm_alignr_epi8(a3, a2, 12);
+        b2 = _mm_alignr_epi8(b3, b2, 12);
+        break;
+    }
+
     // Zero out the data loaded from "off the edge" of the array
     __m128i zero = _mm_setzero_si128();
     a2 = _mm_blend_epi16(a2, zero, 0xf0);
@@ -353,8 +420,8 @@
   int i, j;
 
   // Vertical sum over 7-pixel regions, 4 columns at a time
-  assert(!(width & 3));
-  for (j = 0; j < width; j += 4) {
+  int width_extend = (width + 3) & ~3;
+  for (j = 0; j < width_extend; j += 4) {
     __m128i a, b, c, d, x, y, x2, y2;
     __m128i sum, sum_sq, tmp, tmp2;
 
@@ -476,16 +543,18 @@
                B);
 
     // Re-align a1 and b1 so that they start at index i * buf_stride + 1
-    a1 = _mm_alignr_epi8(a2, a1, 4);
-    b1 = _mm_alignr_epi8(b2, b1, 4);
-    a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + 5]);
-    b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + 5]);
+    a2 = _mm_alignr_epi8(a2, a1, 4);
+    b2 = _mm_alignr_epi8(b2, b1, 4);
 
     n = _mm_set1_epi32(7 * h);
     one_over_n = _mm_set1_epi32(one_by_x[7 * h - 1]);
     s = _mm_set1_epi32(sgrproj_mtable[eps - 1][7 * h - 1]);
 
-    for (j = 4; j < width - 4; j += 4) {
+    for (j = 4; j < width_extend - 4; j += 4) {
+      a1 = a2;
+      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 1]);
+      b1 = b2;
+      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 1]);
       __m128i a3 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 5]);
       __m128i b3 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 5]);
       /* Loop invariant: At this point,
@@ -509,12 +578,38 @@
 
       calc_block(sum_, sum_sq_, n, one_over_n, s, bit_depth, i * buf_stride + j,
                  A, B);
-
-      a1 = a2;
-      a2 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 5]);
-      b1 = b2;
-      b2 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 5]);
     }
+    __m128i a3 = _mm_loadu_si128((__m128i *)&A[i * buf_stride + j + 1]);
+    __m128i b3 = _mm_loadu_si128((__m128i *)&B[i * buf_stride + j + 1]);
+
+    j = width - 4;
+    switch (width % 4) {
+      case 0:
+        a1 = a2;
+        b1 = b2;
+        a2 = a3;
+        b2 = b3;
+        break;
+      case 1:
+        a1 = _mm_alignr_epi8(a2, a1, 4);
+        b1 = _mm_alignr_epi8(b2, b1, 4);
+        a2 = _mm_alignr_epi8(a3, a2, 4);
+        b2 = _mm_alignr_epi8(b3, b2, 4);
+        break;
+      case 2:
+        a1 = _mm_alignr_epi8(a2, a1, 8);
+        b1 = _mm_alignr_epi8(b2, b1, 8);
+        a2 = _mm_alignr_epi8(a3, a2, 8);
+        b2 = _mm_alignr_epi8(b3, b2, 8);
+        break;
+      case 3:
+        a1 = _mm_alignr_epi8(a2, a1, 12);
+        b1 = _mm_alignr_epi8(b2, b1, 12);
+        a2 = _mm_alignr_epi8(a3, a2, 12);
+        b2 = _mm_alignr_epi8(b3, b2, 12);
+        break;
+    }
+
     // Zero out the data loaded from "off the edge" of the array
     __m128i zero = _mm_setzero_si128();
     a2 = _mm_blend_epi16(a2, zero, 0xc0);
@@ -775,14 +870,10 @@
   int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
   int i, j;
   assert(width * height <= RESTORATION_TILEPELS_MAX);
-  // The SSE4.1 code currently only supports tiles which are a multiple of 4
-  // pixels wide (but has no height restriction). If this is not the case,
-  // we fall back to the C version.
-  // Similarly, highbitdepth mode is not fully supported yet, so drop back
-  // to the C code in that case.
-  // TODO(david.barker): Allow non-multiple-of-4 widths and bit_depth > 8
-  // in the SSE4.1 code.
-  if ((width & 3) || bit_depth != 8) {
+  // The SSE4.1 code does not currently support highbitdepth, so drop back
+  // to the C filter in that case.
+  // TODO(david.barker): Allow bit_depth > 8 in the SSE4.1 code.
+  if (bit_depth != 8) {
     apply_selfguided_restoration_c(dat, width, height, stride, bit_depth, eps,
                                    xqd, dst, dst_stride, tmpbuf);
     return;
diff --git a/test/selfguided_filter_test.cc b/test/selfguided_filter_test.cc
index 5577254..be02701 100644
--- a/test/selfguided_filter_test.cc
+++ b/test/selfguided_filter_test.cc
@@ -62,7 +62,7 @@
     };
     // Fix a parameter set, since the speed depends slightly on r.
     // Change this to test different combinations of values of r.
-    int eps = 4;
+    int eps = 15;
 
     av1_loop_restoration_precal();
 
@@ -84,7 +84,7 @@
 
   void RunCorrectnessTest() {
     const int w = 256, h = 256, stride = 672, out_stride = 672;
-    const int NUM_ITERS = 250;
+    const int NUM_ITERS = 81;
     int i, j, k;
 
     uint8_t *input = new uint8_t[stride * h];
@@ -98,8 +98,8 @@
     av1_loop_restoration_precal();
 
     for (i = 0; i < NUM_ITERS; ++i) {
-      for (j = 0; i < h; ++i)
-        for (k = 0; j < w; ++j) input[j * stride + k] = rnd.Rand16() & 0xFF;
+      for (j = 0; j < h; ++j)
+        for (k = 0; k < w; ++k) input[j * stride + k] = rnd.Rand16() & 0xFF;
 
       int xqd[2] = {
         SGRPROJ_PRJ_MIN0 +
@@ -109,12 +109,16 @@
       };
       int eps = rnd.PseudoUniform(1 << SGRPROJ_PARAMS_BITS);
 
-      apply_selfguided_restoration(input, w, h, stride, 8, eps, xqd, output,
-                                   out_stride, tmpbuf);
-      apply_selfguided_restoration_c(input, w, h, stride, 8, eps, xqd, output2,
-                                     out_stride, tmpbuf);
-      for (j = 0; j < h; ++j)
-        for (k = 0; k < w; ++k)
+      // Test various tile sizes around 256x256
+      int test_w = w + 4 - (i / 9);
+      int test_h = h + 4 - (i % 9);
+
+      apply_selfguided_restoration(input, test_w, test_h, stride, 8, eps, xqd,
+                                   output, out_stride, tmpbuf);
+      apply_selfguided_restoration_c(input, test_w, test_h, stride, 8, eps, xqd,
+                                     output2, out_stride, tmpbuf);
+      for (j = 0; j < test_h; ++j)
+        for (k = 0; k < test_w; ++k)
           ASSERT_EQ(output[j * out_stride + k], output2[j * out_stride + k]);
     }