| /* |
| * Copyright 2020 Google LLC |
| * |
| */ |
| |
| /* |
| * Copyright (c) 2020, Alliance for Open Media. All rights reserved |
| * |
| * This source code is subject to the terms of the BSD 2 Clause License and |
| * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
| * was not distributed with this source code in the LICENSE file, you can |
| * obtain it at www.aomedia.org/license/software. If the Alliance for Open |
| * Media Patent License 1.0 was not distributed with this source code in the |
| * PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
| */ |
| |
| #pragma warning(disable : 3557) |
| #include "restoration.h" |
| |
// Mode-info unit geometry: MI_SIZE x MI_SIZE samples, MI_SIZE == 1 << MI_SIZE_LOG2.
#define MI_SIZE 4
#define MI_SIZE_LOG2 2

// Round2(value, n) computes (value + 2^(n-1)) >> n, i.e. a division by 2^n that
// rounds halves upwards. For n == 0 the added term is 0 and value is unchanged.
#define Round2(value, n) (((value) + (((1 << (n)) >> 1))) >> (n))

// Loop-restoration stripe height in luma rows (main() converts it to plane rows
// with >> subsampling).
#define STRIPE_SIZE 64

// Restoration filter types, compared against the first component of each unit's
// LrTypeSgr entry.
#define RESTORE_NONE 0
#define RESTORE_WIENER 1
#define RESTORE_SGRPROJ 2
#define RESTORE_SWITCHABLE 3
#define RESTORE_SWITCHABLE_TYPES RESTORE_SWITCHABLE
#define RESTORE_TYPES 4

// Wiener fixed-point rounding: the horizontal pass drops InterRound0 bits and the
// vertical pass drops InterRound1 bits. Together they remove 2 * FILTER_BITS.
#define FILTER_BITS 7
#define WIENER_ROUND0_BITS 3
#define InterRound0 WIENER_ROUND0_BITS
#define InterRound1 (2 * FILTER_BITS - WIENER_ROUND0_BITS)
| |
// Shader inputs (SRVs).
ByteAddressBuffer src : register(t0);        // plane samples: 1 byte each, or 2 bytes when hbd is set
ByteAddressBuffer LrTypeSgr : register(t1);  // one int4 per unit: x = type, y/z = SGR weights, w = Sgr_Params index
ByteAddressBuffer LrWiener : register(t2);   // 64 bytes per unit: horizontal taps at +0/+16, vertical taps at +32/+48

// Shader output (UAV).
RWByteAddressBuffer dst : register(u0);      // restored plane samples

// Self-guided filter parameter sets, selected by the w component of a unit's
// LrTypeSgr entry. x/y are the box radii of passes 0/1 (a radius of 0 disables
// that pass). z/w are the matching values passed to the box filters as eps.
cbuffer PlaneRestorationData : register(b0) {
  struct {
    int4 Sgr_Params[16];
  } data;
};
| |
// Location and size of one plane inside `src`.
struct PlaneInfo {
  int stride, offset;  // in bytes
  int width, height;   // in samples
};

// Restoration-unit grid of a plane: Rows x Cols units, each Size samples wide.
// Stride is the distance, in units, between consecutive unit rows in the
// LrTypeSgr / LrWiener buffers.
struct UnitsInfo {
  int Rows, Cols;
  int Size, Stride;
};

// Per-plane parameters; main() reads the entry pl[plane_id].
struct PlaneRestorationData {
  PlaneInfo plane;
  UnitsInfo units;

  int pp_offset;         // byte offset in src used instead of plane.offset for rows outside the current stripe
  int dst_offset;        // byte offset of the plane in dst
  int Lr_buffer_offset;  // added to the unit index before addressing LrTypeSgr / LrWiener
  int subsampling;       // vertical subsampling shift between plane rows and luma rows
  int hbd;               // non-zero: 2 bytes per sample, otherwise 1
  int bit_depth;         // sample bit depth
  int2 pad;              // unused here; NOTE(review): presumably keeps the struct a multiple of 16 bytes
};
cbuffer PlaneRestorationConstBuffer : register(b1) { PlaneRestorationData pl[3]; };
| |
// Dispatch-wide controls.
cbuffer cb_loop_rest_data : register(b2) {
  int do_restoration;  // 0: every tile is copied from src to dst unchanged
  int plane_id;        // index of the pl[] entry being processed
}
| |
// Each thread group filters one WG_WIDTH x WG_HEIGHT tile. Both sizes must be
// powers of two, because main() derives tile coordinates with bit masks.
#define WG_WIDTH 16
#define WG_HEIGHT 4

// Source tile plus filter apron: 3 rows above and below, and 4 columns on each side
// (the column apron matches the 4-sample granularity of the loads in main()).
groupshared int input[WG_HEIGHT + 6][WG_WIDTH + 8];
// Filtered tile. The first WG_WIDTH / 4 threads of each row pack it and write it to dst.
groupshared int output[WG_HEIGHT][WG_WIDTH];
// Wiener: output of the horizontal pass for the tile rows plus 3 rows above and below.
groupshared int intermediate[WG_HEIGHT + 6][WG_WIDTH];

// Self-guided filter: filtered output of passes 0 and 1.
groupshared int flt[2][WG_HEIGHT][WG_WIDTH];
// A/B coefficients for the tile plus a one-sample border; element [1 + i][1 + j]
// belongs to tile position (j, i).
groupshared int A[WG_HEIGHT + 2][WG_WIDTH + 2];
groupshared int B[WG_HEIGHT + 2][WG_WIDTH + 2];

// Source sample at tile-relative position (x, y). Valid for -4 <= x < WG_WIDTH + 4
// and -3 <= y < WG_HEIGHT + 3. The arguments are parenthesized so that expressions
// with lower precedence than '+' index correctly.
#define get_loaded_source_sample(x, y) input[(y) + 3][(x) + 4]
// Self-guided (SGRPROJ) box filters. Both passes share the A/B computation in
// sgr_compute_ab() and differ only in how they weight the 3x3 A/B neighbourhood.
//
// These functions contain group barriers. Every thread of the group must call
// them, with the same w, h, r, eps and bit_depth.

// Computes A (a2) and B (rounded b2) for every tile position plus a one-sample
// border, from the (2r+1) x (2r+1) box of source samples around each position.
// Each thread handles the positions (j, i) with j == lx - 1 (mod WG_WIDTH) and
// i == ly - 1 (mod WG_HEIGHT).
void sgr_compute_ab(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
  // Before A/B are overwritten: every thread must have finished loading `input`,
  // and finished reading A/B from a previous pass.
  GroupMemoryBarrierWithGroupSync();

  const uint n = (2 * r + 1) * (2 * r + 1);
  // round(2^SGRPROJ_RECIP_BITS / n): fixed-point reciprocal of the box area.
  const uint oneOverN = ((1 << SGRPROJ_RECIP_BITS) + (n >> 1)) / n;

  for (int i = ly - 1; i < h + 1; i += WG_HEIGHT) {
    for (int j = lx - 1; j < w + 1; j += WG_WIDTH) {
      // Sum of squares (a) and sum (b) of the samples in the box.
      uint a = 0;
      uint b = 0;
      for (int dy = -r; dy <= r; dy++) {
        for (int dx = -r; dx <= r; dx++) {
          uint c = get_loaded_source_sample(j + dx, i + dy);
          a += c * c;
          b += c;
        }
      }
      a = Round2(a, 2 * (bit_depth - 8));
      uint d = Round2(b, bit_depth - 8);
      uint p = max(0, int(a * n - d * d));
      uint z = Round2(p * eps, SGRPROJ_MTABLE_BITS);  // p*s in documentation
      z = min(z, 255);
      uint a2 = 0;
      if (z >= 255)
        a2 = 256;
      else if (z == 0)
        a2 = 1;
      else
        a2 = ((z << SGRPROJ_SGR_BITS) + (z >> 1)) / (z + 1);
      uint b2 = ((1 << SGRPROJ_SGR_BITS) - a2) * b * oneOverN;
      A[1 + i][1 + j] = a2;
      B[1 + i][1 + j] = Round2(b2, SGRPROJ_RECIP_BITS);
    }
  }

  // Publish A/B before any thread reads its neighbours' coefficients.
  GroupMemoryBarrierWithGroupSync();
}

// Pass 0 of the self-guided filter. Writes flt[0][i][j] for this thread's tile
// positions. Only A/B rows with odd tile-relative index contribute: weight 6 on
// the centre column and 5 on the side columns.
void box_filter0(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
  sgr_compute_ab(w, h, r, eps, lx, ly, bit_depth);

  for (int i = ly; i < h; i += WG_HEIGHT) {
    // Odd rows take weights only from row i (6 + 5 + 5 = 16, so shift 4).
    // Even rows take them from rows i - 1 and i + 1 (2 * 16 = 32, so shift 5).
    const int shift = (i & 1) ? 4 : 5;
    for (int j = lx; j < w; j += WG_WIDTH) {
      int a = 0;
      int b = 0;
      for (int dy = -1; dy <= 1; dy++) {
        for (int dx = -1; dx <= 1; dx++) {
          const int weight = ((i + dy) & 1) ? ((dx == 0) ? 6 : 5) : 0;
          a += weight * A[1 + i + dy][1 + j + dx];
          b += weight * B[1 + i + dy][1 + j + dx];
        }
      }
      int v = a * get_loaded_source_sample(j, i) + b;
      flt[0][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
    }
  }
}

// Pass 1 of the self-guided filter. Writes flt[1][i][j] for this thread's tile
// positions. Weight 4 applies to the centre and its edge neighbours and weight 3
// to the corners (5 * 4 + 4 * 3 = 32, so shift 5 on every row).
void box_filter1(int w, int h, int r, int eps, int lx, int ly, int bit_depth) {
  sgr_compute_ab(w, h, r, eps, lx, ly, bit_depth);

  const int shift = 5;
  for (int i = ly; i < h; i += WG_HEIGHT) {
    for (int j = lx; j < w; j += WG_WIDTH) {
      int a = 0;
      int b = 0;
      for (int dy = -1; dy <= 1; dy++) {
        for (int dx = -1; dx <= 1; dx++) {
          const int weight = (dx == 0 || dy == 0) ? 4 : 3;
          a += weight * A[1 + i + dy][1 + j + dx];
          b += weight * B[1 + i + dy][1 + j + dx];
        }
      }
      int v = a * get_loaded_source_sample(j, i) + b;
      flt[1][i][j] = Round2(v, SGRPROJ_SGR_BITS + shift - SGRPROJ_RST_BITS);
    }
  }
}
| |
// Packs four adjacent filtered samples, output[ly][4 * lx .. 4 * lx + 3], and stores
// them to row gy + ly, columns gx + 4 * lx .. gx + 4 * lx + 3 of the destination
// plane. 8-bit samples go into one dword; when hbd is non-zero, 16-bit samples go
// into two dwords. Every thread of the group must have written `output`, followed
// by a group sync, before this is called.
void store_output_quad(int gx, int gy, int lx, int ly, int stride, int dst_offset, int hbd) {
  const int c = lx * 4;
  const int row_base = dst_offset + (gy + ly) * stride;
  if (hbd) {
    dst.Store2(row_base + (gx + c) * 2,
               uint2((output[ly][c + 0] << 0) | (output[ly][c + 1] << 16),
                     (output[ly][c + 2] << 0) | (output[ly][c + 3] << 16)));
  } else {
    dst.Store(row_base + gx + c,
              (output[ly][c + 0] << 0) | (output[ly][c + 1] << 8) | (output[ly][c + 2] << 16) |
                  (output[ly][c + 3] << 24));
  }
}

/*
 * Loop-restoration entry point. Each WG_WIDTH x WG_HEIGHT thread group handles one
 * tile of plane `plane_id`:
 *  - RESTORE_NONE, or do_restoration == 0: src is copied to dst unchanged;
 *  - RESTORE_WIENER: separable 7-tap filter (horizontal pass, then vertical pass);
 *  - RESTORE_SGRPROJ: two self-guided box-filter passes blended with the source;
 *  - any other type: nothing is written for the tile.
 *
 * NOTE(review): the group barriers below sit inside branches on rType. This is only
 * valid because rType depends solely on the tile origin (gx, gy) and on constant-buffer
 * data, so every thread of a group takes the same path. A compiler that cannot prove
 * this uniformity may warn about, or reject, the barriers.
 */
[numthreads(WG_WIDTH, WG_HEIGHT, 1)] void main(uint3 thread : SV_DispatchThreadID) {
  // Tile origin (gx, gy) and this thread's position (lx, ly) inside the tile.
  const int gx = thread.x & (~(WG_WIDTH - 1));
  const int gy = thread.y & (~(WG_HEIGHT - 1));
  const int lx = thread.x & (WG_WIDTH - 1);
  const int ly = thread.y & (WG_HEIGHT - 1);

  const int subsampling = pl[plane_id].subsampling;
  const int bit_depth = pl[plane_id].bit_depth;

  // Stripe containing the tile. Stripes are STRIPE_SIZE luma rows tall and start
  // 8 luma rows above each multiple of STRIPE_SIZE.
  int stripe_id = uint((gy << subsampling) + 8) / STRIPE_SIZE;
  int StripeStartY = ((stripe_id * STRIPE_SIZE) - 8) >> subsampling;
  int StripeEndY = StripeStartY + (STRIPE_SIZE >> subsampling) - 1;

  PlaneInfo plane = pl[plane_id].plane;
  const int dst_offset = pl[plane_id].dst_offset;

  // Restoration unit covering the tile. It depends only on (gx, gy), so every
  // thread of the group reads the same rType.
  UnitsInfo units = pl[plane_id].units;
  int unitCol = min(units.Cols - 1, uint(gx + 1) / units.Size);
  int unitRow = min(units.Rows - 1, uint((gy + 1 + (8 >> subsampling))) / units.Size);
  int unitId = unitRow * units.Stride + unitCol + pl[plane_id].Lr_buffer_offset;
  int4 rType = int4(RESTORE_NONE, 0, 0, 0);
  if (do_restoration) rType = LrTypeSgr.Load4(unitId * 16);

  // Pass-through: every fourth thread copies four samples (one dword of 8-bit
  // samples, or two dwords of 16-bit samples).
  if (rType.x == RESTORE_NONE || !do_restoration) {
    if ((thread.x & 3) == 0) {
      if (pl[plane_id].hbd) {
        uint2 input_char = src.Load2(plane.offset + plane.stride * thread.y + (thread.x << 1));
        dst.Store2(dst_offset + plane.stride * thread.y + (thread.x << 1), input_char);
      } else {
        uint input_char = src.Load(plane.offset + plane.stride * thread.y + thread.x);
        dst.Store(dst_offset + plane.stride * thread.y + thread.x, input_char);
      }
    }
    return;
  }

  // Load the tile and its apron into `input`, four samples per thread:
  //  - rows are clamped to the stripe (+/- 2 rows) and then to the plane;
  //  - rows outside the stripe are read relative to pp_offset;
  //  - columns beyond either plane edge replicate the edge sample, because the
  //    per-sample shift is clamped.
  for (int y = ly; y < WG_HEIGHT + 6; y += WG_HEIGHT)
    for (int x = lx; x < WG_WIDTH / 4 + 2; x += WG_WIDTH) {
      int c_y = clamp(y + gy - 3, StripeStartY - 2, StripeEndY + 2);
      c_y = clamp(c_y, 0, plane.height - 1);
      int nc_x = (x << 2) + gx - 4;
      int c_x = clamp(nc_x, 0, plane.width - 1) & (~3);
      int offset = (c_y < StripeStartY || c_y > StripeEndY) ? pl[plane_id].pp_offset : plane.offset;

      if (pl[plane_id].hbd) {
        int shift_max = nc_x < 0 ? 0 : (plane.width - c_x - 1) * 16;
        int shift_min = (nc_x - c_x) >= 4 ? shift_max : 0;
        int block_offset = c_y * plane.stride + (c_x << 1);

        uint2 input_char = src.Load2(offset + block_offset);
        uint4 shift = uint4(clamp(0, shift_min, shift_max), clamp(16, shift_min, shift_max),
                            clamp(32, shift_min, shift_max), clamp(48, shift_min, shift_max));
        input[y][x * 4 + 0] = (shift.x > 16 ? (input_char.y >> (shift.x - 32)) : (input_char.x >> shift.x)) & 0xffff;
        input[y][x * 4 + 1] = (shift.y > 16 ? (input_char.y >> (shift.y - 32)) : (input_char.x >> shift.y)) & 0xffff;
        input[y][x * 4 + 2] = (shift.z > 16 ? (input_char.y >> (shift.z - 32)) : (input_char.x >> shift.z)) & 0xffff;
        input[y][x * 4 + 3] = (shift.w > 16 ? (input_char.y >> (shift.w - 32)) : (input_char.x >> shift.w)) & 0xffff;
      } else {
        int shift_max = nc_x < 0 ? 0 : (plane.width - c_x - 1) * 8;
        int shift_min = (nc_x - c_x) >= 4 ? shift_max : 0;
        int block_offset = c_y * plane.stride + c_x;
        uint input_char = src.Load(offset + block_offset);
        input[y][x * 4 + 0] = (input_char >> clamp(0, shift_min, shift_max)) & 255;
        input[y][x * 4 + 1] = (input_char >> clamp(8, shift_min, shift_max)) & 255;
        input[y][x * 4 + 2] = (input_char >> clamp(16, shift_min, shift_max)) & 255;
        input[y][x * 4 + 3] = (input_char >> clamp(24, shift_min, shift_max)) & 255;
      }
    }

  // The filters read samples that other threads loaded. GroupMemoryBarrier() alone
  // does not wait for the rest of the group to reach the barrier, so it cannot
  // guarantee those loads are done. A group sync is required.
  GroupMemoryBarrierWithGroupSync();

  if (rType.x == RESTORE_WIENER) {
    // Upper clamp for the horizontally filtered values.
    int limit = (1 << (bit_depth + 1 + FILTER_BITS - InterRound0)) - 1;

    // Per-unit taps: horizontal at +0/+16, vertical at +32/+48 (64 bytes per unit).
    int Lr_offset = unitId * 64;
    int4 hfilter0 = LrWiener.Load4(Lr_offset + 0);
    int4 hfilter1 = LrWiener.Load4(Lr_offset + 16);

    // Horizontal pass over the tile rows and the 3-row apron above and below.
    // The centre sample is also added scaled by 1 << 7. NOTE(review): this presumably
    // means the stored taps exclude the implicit 1 << FILTER_BITS centre weight.
    for (int r = ly; r < WG_HEIGHT + 6; r += WG_HEIGHT) {
      int s = (input[r][lx + 4] << 7) + (1 << (bit_depth + 7 - 1));

      s += hfilter0.x * input[r][lx + 1];
      s += hfilter0.y * input[r][lx + 2];
      s += hfilter0.z * input[r][lx + 3];
      s += hfilter0.w * input[r][lx + 4];
      s += hfilter1.x * input[r][lx + 5];
      s += hfilter1.y * input[r][lx + 6];
      s += hfilter1.z * input[r][lx + 7];

      int v = Round2(s, InterRound0);
      intermediate[r][lx] = clamp(v, 0, limit);
    }

    // The vertical pass reads `intermediate` rows written by other threads.
    GroupMemoryBarrierWithGroupSync();

    int4 vfilter0 = LrWiener.Load4(Lr_offset + 32);
    int4 vfilter1 = LrWiener.Load4(Lr_offset + 48);

    int s = (intermediate[ly + 3][lx] << 7) - (1 << (bit_depth + InterRound1 - 1));

    s += vfilter0.x * intermediate[ly + 0][lx];
    s += vfilter0.y * intermediate[ly + 1][lx];
    s += vfilter0.z * intermediate[ly + 2][lx];
    s += vfilter0.w * intermediate[ly + 3][lx];
    s += vfilter1.x * intermediate[ly + 4][lx];
    s += vfilter1.y * intermediate[ly + 5][lx];
    s += vfilter1.z * intermediate[ly + 6][lx];
    int v = Round2(s, InterRound1);
    output[ly][lx] = clamp(v, 0, (1 << bit_depth) - 1);

    // Each packed store reads samples produced by four different threads.
    GroupMemoryBarrierWithGroupSync();
    if (lx < WG_WIDTH / 4) store_output_quad(gx, gy, lx, ly, plane.stride, dst_offset, pl[plane_id].hbd);
  } else if (rType.x == RESTORE_SGRPROJ) {
    // Projection weights. NOTE(review): the derived weight w2 scales flt[1], while
    // rType.z scales the unfiltered sample. Confirm this matches how the host packs
    // the two coded weights.
    const int w0 = rType.y;
    const int w1 = rType.z;
    const int w2 = (1 << SGRPROJ_PRJ_BITS) - w0 - w1;

    // Parameter set: box radii and eps values for the two passes.
    int r0 = data.Sgr_Params[rType.w].x;
    int r1 = data.Sgr_Params[rType.w].y;
    int eps0 = data.Sgr_Params[rType.w].z;
    int eps1 = data.Sgr_Params[rType.w].w;

    box_filter0(WG_WIDTH, WG_HEIGHT, r0, eps0, lx, ly, bit_depth);
    box_filter1(WG_WIDTH, WG_HEIGHT, r1, eps1, lx, ly, bit_depth);

    // Blend. A pass with radius 0 is disabled and contributes the source sample instead.
    int u = input[ly + 3][lx + 4] << SGRPROJ_RST_BITS;
    int v = w1 * u;
    v += w0 * (r0 ? flt[0][ly][lx] : u);
    v += w2 * (r1 ? flt[1][ly][lx] : u);
    int s = Round2(v, (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS));
    output[ly][lx] = clamp(s, 0, (1 << bit_depth) - 1);

    // Each packed store reads samples produced by four different threads.
    GroupMemoryBarrierWithGroupSync();
    if (lx < WG_WIDTH / 4) store_output_quad(gx, gy, lx, ly, plane.stride, dst_offset, pl[plane_id].hbd);
  }
}