Additional acceleration of the loop restoration filter.
Change-Id: Icc25eceda25b94cc643fcde6572f3a2ae3bcc7d0
diff --git a/libav1/dx/shaders/loop_restoration.hlsl b/libav1/dx/shaders/loop_restoration.hlsl
index 989b417..e912fc4 100644
--- a/libav1/dx/shaders/loop_restoration.hlsl
+++ b/libav1/dx/shaders/loop_restoration.hlsl
@@ -88,45 +88,46 @@
groupshared int B[WG_HEIGHT + 2][WG_WIDTH + 2];
#define get_loaded_source_sample(x, y) input[y + 3][x + 4]
-void box_filter0(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
+int box_filter0(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
uint n = (2 * r + 1) * (2 * r + 1);
- int i;
- for (i = ly - 1; i < h + 1; i += WG_HEIGHT) {
- for (int j = lx - 1; j < w + 1; j += WG_WIDTH) {
- uint a = 0;
- uint b = 0;
- for (int dy = -r; dy <= r; dy++) {
- for (int dx = -r; dx <= r; dx++) {
- uint c = get_loaded_source_sample(j + dx, i + dy);
- a += c * c;
- b += c;
- }
+ int id = ly * WG_WIDTH + lx;
+ for (int idx = id; idx < (h + 2)*(w + 2); idx += WG_HEIGHT * WG_WIDTH) {
+ int i = idx / (w + 2) - 1;
+ int j = idx % (w + 2) - 1;
+ uint a = 0;
+ uint b = 0;
+ for (int dy = -2; dy <= 2; dy++) {
+ for (int dx = -2; dx <= 2; dx++) {
+ uint c = get_loaded_source_sample(j + dx, i + dy);
+ a += c * c;
+ b += c;
}
- a = Round2(a, 2 * (bit_depth - 8));
- uint d = Round2(b, bit_depth - 8);
- uint p = max(0, int(a * n - d * d));
- uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS); // p*s in documentation
- z = min(z, 255);
- // int a2 = x_by_xplus1[z];
- uint a2 = 0;
- if (z >= 255)
- a2 = 256;
- else if (z == 0)
- a2 = 1;
- else
- a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
- uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
- uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
- A[1 + i][1 + j] = a2;
- B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
}
+ a = Round2(a, 2 * (bit_depth - 8));
+ uint d = Round2(b, bit_depth - 8);
+ uint p = max(0, int(a * n - d * d));
+ uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS); // p*s in documentation
+ z = min(z, 255);
+ uint a2 = 0;
+ if (z >= 255)
+ a2 = 256;
+ else if (z == 0)
+ a2 = 1;
+ else
+ a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
+ uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
+ uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
+ A[1 + i][1 + j] = a2;
+ B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
}
- for (i = ly; i < h; i += WG_HEIGHT) {
+ //for (int i = ly; i < h; i += WG_HEIGHT) {
+ { int i = ly;
int shift = 5; // -((1 - stage) * (i & 1));
if (i & 1) {
shift = 4;
}
- for (int j = lx; j < w; j += WG_WIDTH) {
+ //for (int j = lx; j < w; j += WG_WIDTH) {
+ { int j = lx;
int a = 0;
int b = 0;
for (int dy = -1; dy <= 1; dy++) {
@@ -142,46 +143,48 @@
}
}
int v = a * get_loaded_source_sample(j, i) + b;
- flt[0][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
+ return Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
}
}
}
-void box_filter1(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
+int box_filter1(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
uint n = (2 * r + 1) * (2 * r + 1);
- int i;
- for (i = ly - 1; i < h + 1; i += WG_HEIGHT) {
- for (int j = lx - 1; j < w + 1; j += WG_WIDTH) {
- uint a = 0;
- uint b = 0;
- for (int dy = -r; dy <= r; dy++) {
- for (int dx = -r; dx <= r; dx++) {
- uint c = get_loaded_source_sample(j + dx, i + dy);
- a += c * c;
- b += c;
- }
+ int id = ly * WG_WIDTH + lx;
+ for (int idx = id; idx < (h + 2)*(w + 2); idx += WG_HEIGHT * WG_WIDTH) {
+ int i = idx / (w + 2) - 1;
+ int j = idx % (w + 2) - 1;
+ uint a = 0;
+ uint b = 0;
+ for (int dy = -1; dy <= 1; dy++) {
+ for (int dx = -1; dx <= 1; dx++) {
+ uint c = get_loaded_source_sample(j + dx, i + dy);
+ a += c * c;
+ b += c;
}
- a = Round2(a, 2 * (bit_depth - 8));
- uint d = Round2(b, bit_depth - 8);
- uint p = max(0, int(a * n - d * d));
- uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS); // p*s in documentation
- z = min(z, 255);
- uint a2 = 0;
- if (z >= 255)
- a2 = 256;
- else if (z == 0)
- a2 = 1;
- else
- a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
- uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
- uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
- A[1 + i][1 + j] = a2;
- B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
}
+ a = Round2(a, 2 * (bit_depth - 8));
+ uint d = Round2(b, bit_depth - 8);
+ uint p = max(0, int(a * n - d * d));
+ uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS); // p*s in documentation
+ z = min(z, 255);
+ uint a2 = 0;
+ if (z >= 255)
+ a2 = 256;
+ else if (z == 0)
+ a2 = 1;
+ else
+ a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
+ uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;
+ uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
+ A[1 + i][1 + j] = a2;
+ B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
}
- for (i = ly; i < h; i += WG_HEIGHT) {
+ //for (i = ly; i < h; i += WG_HEIGHT) {
+ { int i = ly;
int shift = 5; // -((1 - stage) * (i & 1));
- for (int j = lx; j < w; j += WG_WIDTH) {
+ //for (int j = lx; j < w; j += WG_WIDTH) {
+ { int j = lx;
int a = 0;
int b = 0;
for (int dy = -1; dy <= 1; dy++) {
@@ -193,7 +196,7 @@
}
}
int v = a * get_loaded_source_sample(j, i) + b;
- flt[1][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
+ return Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
}
}
}
@@ -329,13 +332,19 @@
int eps0 = data.Sgr_Params[rType.w].z;
int eps1 = data.Sgr_Params[rType.w].w;
- box_filter0(WG_WIDTH, WG_HEIGHT, r0, eps0, lx, ly, bit_depth);
- box_filter1(WG_WIDTH, WG_HEIGHT, r1, eps1, lx, ly, bit_depth);
-
int u = input[ly + 3][lx + 4] << SGRPROJ_RST_BITS;
int v = w1 * u;
- v += w0 * (r0 ? flt[0][ly][lx] : u);
- v += w2 * (r1 ? flt[1][ly][lx] : u);
+ if (r0) {
+ v += w0 * box_filter0(WG_WIDTH, WG_HEIGHT, r0, eps0, lx, ly, bit_depth);
+ } else {
+ v += w0 * u;
+ }
+ if (r1) {
+ v += w2 * box_filter1(WG_WIDTH, WG_HEIGHT, r1, eps1, lx, ly, bit_depth);
+ } else {
+ v += w2 * u;
+ }
+
int s = Round2(v, (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS));
output[ly][lx] = clamp(s, 0, (1 << bit_depth) - 1);
if (lx < WG_WIDTH / 4) {