Additional acceleration of the loop restoration filter.

Change-Id: Icc25eceda25b94cc643fcde6572f3a2ae3bcc7d0
diff --git a/libav1/dx/shaders/loop_restoration.hlsl b/libav1/dx/shaders/loop_restoration.hlsl
index 989b417..e912fc4 100644
--- a/libav1/dx/shaders/loop_restoration.hlsl
+++ b/libav1/dx/shaders/loop_restoration.hlsl
@@ -88,45 +88,46 @@
 groupshared int B[WG_HEIGHT + 2][WG_WIDTH + 2];
 
 #define get_loaded_source_sample(x, y) input[y + 3][x + 4]
-void box_filter0(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
+int box_filter0(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
   uint n = (2 * r + 1) * (2 * r + 1);
-  int i;
-  for (i = ly - 1; i < h + 1; i += WG_HEIGHT) {
-    for (int j = lx - 1; j < w + 1; j += WG_WIDTH) {
-      uint a = 0;
-      uint b = 0;
-      for (int dy = -r; dy <= r; dy++) {
-        for (int dx = -r; dx <= r; dx++) {
-          uint c = get_loaded_source_sample(j + dx, i + dy);
-          a += c * c;
-          b += c;
-        }
+  int id = ly * WG_WIDTH + lx;
+  for (int idx = id; idx < (h + 2)*(w + 2); idx += WG_HEIGHT * WG_WIDTH) {
+    int i = idx / (w + 2) - 1;
+    int j = idx % (w + 2) - 1;
+    uint a = 0;
+    uint b = 0;
+    for (int dy = -2; dy <= 2; dy++) {
+      for (int dx = -2; dx <= 2; dx++) {
+        uint c = get_loaded_source_sample(j + dx, i + dy);
+        a += c * c;
+        b += c;
       }
-      a = Round2(a, 2 * (bit_depth - 8));
-      uint d = Round2(b, bit_depth - 8);
-      uint p = max(0, int(a * n - d * d));
-      uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS);  // p*s in documentation
-      z = min(z, 255);
-      // int a2 = x_by_xplus1[z];
-      uint a2 = 0;
-      if (z >= 255)
-        a2 = 256;
-      else if (z == 0)
-        a2 = 1;
-      else
-        a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
-      uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
-      uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
-      A[1 + i][1 + j] = a2;
-      B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
     }
+    a = Round2(a, 2 * (bit_depth - 8));
+    uint d = Round2(b, bit_depth - 8);
+    uint p = max(0, int(a * n - d * d));
+    uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS);  // p*s in documentation
+    z = min(z, 255);
+    uint a2 = 0;
+    if (z >= 255)
+      a2 = 256;
+    else if (z == 0)
+      a2 = 1;
+    else
+      a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
+    uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
+    uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
+    A[1 + i][1 + j] = a2;
+    B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
   }
-  for (i = ly; i < h; i += WG_HEIGHT) {
+  //for (int i = ly; i < h; i += WG_HEIGHT) {
+  { int i = ly;
     int shift = 5;  // -((1 - stage) * (i & 1));
     if (i & 1) {
       shift = 4;
     }
-    for (int j = lx; j < w; j += WG_WIDTH) {
+    //for (int j = lx; j < w; j += WG_WIDTH) {
+    { int j = lx;
       int a = 0;
       int b = 0;
       for (int dy = -1; dy <= 1; dy++) {
@@ -142,46 +143,48 @@
         }
       }
       int v = a * get_loaded_source_sample(j, i) + b;
-      flt[0][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
+      return Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
     }
   }
 }
 
-void box_filter1(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
+int box_filter1(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
   uint n = (2 * r + 1) * (2 * r + 1);
-  int i;
-  for (i = ly - 1; i < h + 1; i += WG_HEIGHT) {
-    for (int j = lx - 1; j < w + 1; j += WG_WIDTH) {
-      uint a = 0;
-      uint b = 0;
-      for (int dy = -r; dy <= r; dy++) {
-        for (int dx = -r; dx <= r; dx++) {
-          uint c = get_loaded_source_sample(j + dx, i + dy);
-          a += c * c;
-          b += c;
-        }
+  int id = ly * WG_WIDTH + lx;
+  for (int idx = id; idx < (h + 2)*(w + 2); idx += WG_HEIGHT * WG_WIDTH) {
+    int i = idx / (w + 2) - 1;
+    int j = idx % (w + 2) - 1;
+    uint a = 0;
+    uint b = 0;
+    for (int dy = -1; dy <= 1; dy++) {
+      for (int dx = -1; dx <= 1; dx++) {
+        uint c = get_loaded_source_sample(j + dx, i + dy);
+        a += c * c;
+        b += c;
       }
-      a = Round2(a, 2 * (bit_depth - 8));
-      uint d = Round2(b, bit_depth - 8);
-      uint p = max(0, int(a * n - d * d));
-      uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS);  // p*s in documentation
-      z = min(z, 255);
-      uint a2 = 0;
-      if (z >= 255)
-        a2 = 256;
-      else if (z == 0)
-        a2 = 1;
-      else
-        a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
-      uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
-      uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
-      A[1 + i][1 + j] = a2;
-      B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
     }
+    a = Round2(a, 2 * (bit_depth - 8));
+    uint d = Round2(b, bit_depth - 8);
+    uint p = max(0, int(a * n - d * d));
+    uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS);  // p*s in documentation
+    z = min(z, 255);
+    uint a2 = 0;
+    if (z >= 255)
+      a2 = 256;
+    else if (z == 0)
+      a2 = 1;
+    else
+      a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
+    uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
+    uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
+    A[1 + i][1 + j] = a2;
+    B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
   }
-  for (i = ly; i < h; i += WG_HEIGHT) {
+  //for (i = ly; i < h; i += WG_HEIGHT) {
+  { int i = ly;
     int shift = 5;  // -((1 - stage) * (i & 1));
-    for (int j = lx; j < w; j += WG_WIDTH) {
+    //for (int j = lx; j < w; j += WG_WIDTH) {
+    { int j = lx;
       int a = 0;
       int b = 0;
       for (int dy = -1; dy <= 1; dy++) {
@@ -193,7 +196,7 @@
         }
       }
       int v = a * get_loaded_source_sample(j, i) + b;
-      flt[1][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
+      return Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
     }
   }
 }
@@ -329,13 +332,19 @@
     int eps0 = data.Sgr_Params[rType.w].z;
     int eps1 = data.Sgr_Params[rType.w].w;
 
-    box_filter0(WG_WIDTH, WG_HEIGHT, r0, eps0, lx, ly, bit_depth);
-    box_filter1(WG_WIDTH, WG_HEIGHT, r1, eps1, lx, ly, bit_depth);
-
     int u = input[ly + 3][lx + 4] << SGRPROJ_RST_BITS;
     int v = w1 * u;
-    v += w0 * (r0 ? flt[0][ly][lx] : u);
-    v += w2 * (r1 ? flt[1][ly][lx] : u);
+    if (r0) {
+        v += w0 * box_filter0(WG_WIDTH, WG_HEIGHT, r0, eps0, lx, ly, bit_depth);
+    } else {
+        v += w0 * u;
+    }
+    if (r1) {
+        v += w2 * box_filter1(WG_WIDTH, WG_HEIGHT, r1, eps1, lx, ly, bit_depth);
+    } else {
+        v += w2 * u;
+    }
+    
     int s = Round2(v, (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS));
     output[ly][lx] = clamp(s, 0, (1 << bit_depth) - 1);
     if (lx < WG_WIDTH / 4) {