// blob: 88d98f5bc4337c49ec907a5d821b0eff7f3047d7 [file] [log] [blame]
/*
* Copyright 2020 Google LLC
*
*/
/*
* Copyright (c) 2020, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
// Constants shared by the intra-prediction kernels (bound at b0).
cbuffer IntraDataCommon : register(b0) {
// Per-plane layout, indexed by the plane id decoded from the block record.
// As used by main() in this file:
//   .x = byte stride of the plane inside dst_frame
//   .y = byte offset of the plane inside dst_frame
//   .z = byte stride of the plane inside residuals
//   .w = byte offset of the plane inside residuals
int4 cb_planes[3];
// Not referenced by the kernel in this file.
int4 cb_flags;
// Filter taps, indexed as [mode][wi][half]. main() reads .xyzw of half 0 and
// .xyz of half 1, i.e. seven taps per output sample (wi selects the sample).
int4 cb_filter[5][8][2];
// Not referenced by the kernel in this file.
int4 cb_mode_params_lut[16][7];
// Not referenced by the kernel in this file.
int4 cb_sm_weight_arrays[128];
};
// Per-dispatch parameters (bound at b1).
cbuffer PSSLIntraSRT : register(b1) {
// Threads whose SV_DispatchThreadID.x is >= this value exit immediately.
uint cb_wi_count;
// Added to (thread.x >> 3) to obtain the index of the block record to load
// from pred_blocks.
int cb_pass_offset;
// Not referenced by the kernel in this file.
int4 cb_counts0;
// Not referenced by the kernel in this file.
int4 cb_counts1;
};
// Block records; main() reads three dwords from each 16-byte record.
ByteAddressBuffer pred_blocks : register(t0);
// Residual data; main() reads one dword at a time and treats it as two signed
// 16-bit values (low half / high half).
ByteAddressBuffer residuals : register(t1);
// Not referenced by the kernel in this file.
ByteAddressBuffer wedge_mask : register(t2);
// Destination frame. Read for neighbouring samples and written with the
// prediction (and optionally prediction + residual). Each dword holds two
// samples in its low and high 16 bits; main() masks them to 10 bits.
RWByteAddressBuffer dst_frame : register(u0);
// Thread-group height; also the number of block rows processed concurrently.
#define ROWS 8
// Scratch for the eight samples produced by one step of one row:
// mem[thread.y][wi], where entries 0..3 are packed into the upper output row
// and entries 4..7 into the lower output row.
groupshared int mem[ROWS][8];
// Four samples above each 4-sample column; overwritten with the lower output
// row of a step so the next block row can use it as its "above" samples.
groupshared int4 loc_above[8];
// Left-neighbour sample per pixel row (two entries per block row).
groupshared int loc_left[32];
// Top-left (corner) sample per block row.
groupshared int loc_corner[18];
// Recursive filter prediction of one block per thread group, followed by an
// optional residual add.
//
// The group is 8 x ROWS threads. wi = thread.x & 7 selects one of the eight
// samples of a 4x2 output step; thread.y selects the block row. All threads of
// a group share the same (thread.x >> 3) and therefore the same block record.
//
// NOTE(review): thread.y comes from SV_DispatchThreadID and is used directly
// to index mem[ROWS], so the dispatch presumably uses a single group in Y --
// confirm against the host code.
// NOTE(review): GroupMemoryBarrier() is a memory fence without execution
// synchronisation. The kernel appears to rely on all 64 threads of the group
// executing in lockstep (single 64-wide wave) -- confirm for the target GPU.
[numthreads(8, ROWS, 1)] void main(uint3 thread
: SV_DispatchThreadID) {
if (thread.x >= cb_wi_count) return;
// Block records are 16 bytes apart; only the first three dwords are read.
uint3 block = pred_blocks.Load3((cb_pass_offset + (thread.x >> 3)) << 4);
const int wi = thread.x & 7;
// x is a byte offset within the row (units of 8 bytes), y is a pixel row
// (units of 4 rows); both are combined with the plane stride below.
int x = 8 * (block.x & 0xffff);
int y = 4 * (block.x >> 16);
// Documented layout of the block record (start bit, bit count, field).
// NOTE(review): this kernel only reads 2 bits of bw_log (block.y & 3), and
// takes the height and the mode from block.z rather than from block.y.
// block.y bits:
// 0 3 bw_log
// 3 2 plane
// 5 1 non skip
// 6 4 mode
// all modes except intra_bc:
// 10 6 above_available
// 16 6 left_available
// dir mode params:
// 22 2 upsample
// 24 2 edge_filter above
// 26 2 edge_filter left
// 28 3 mode_angle
// 31 1 inter_intra?
// CFL:
// 22 4 alpha
//
// block.z
// intra_bc - mv
// inter-intra - coef. table indexes;
// filter - mode_info0 = txh | (filter_mode << 4);
// block.w - reserved; (prob. used for sorting);
// bw: block width in 4-sample (8-byte) columns.
const int bw = 1 << (block.y & 3);
// bh: block height in 2-row steps.
const int bh = 2 << (block.z & 3);
const int plane = (block.y >> 3) & 3;
// Filter mode, used as the first index into cb_filter.
const int mode = (block.z >> 4) & 7;
const int stride = cb_planes[plane].x;
const int offset = cb_planes[plane].y + x + y * stride;
// Number of available 4-sample groups in the row above.
const int above_available = (block.y >> 10) & 63;
// Number of available pixel rows to the left (field is stored in units of 4).
const int left_available = ((block.y >> 16) & 63) << 2;
// Stage 1: fetch the row above the block into loc_above (one int4 per column).
if (wi < bw) {
// Default when neither neighbour exists: the high half of pixels.y is 0x1ff
// (511) and gets replicated to all four samples below.
uint2 pixels = uint2(0, 0x01ff0000);
if (above_available)
pixels = dst_frame.Load2(offset - stride + min(above_available - 1, wi) * 8);
else if (left_available)
pixels.y = dst_frame.Load(offset - 4);
// Past the available range: replicate the last (high-half) sample of
// pixels.y into all four positions.
if (wi >= above_available) {
pixels.y = (pixels.y & 0x03ff0000) | (pixels.y >> 16);
pixels.x = pixels.y;
}
loc_above[wi].x = (pixels.x >> 0) & 1023;
loc_above[wi].y = (pixels.x >> 16) & 1023;
loc_above[wi].z = (pixels.y >> 0) & 1023;
loc_above[wi].w = (pixels.y >> 16) & 1023;
}
GroupMemoryBarrier();
// Stage 2: fetch the left column (one sample per pixel row) and the per-row
// corner samples. Rows beyond left_available reuse the last available row.
int row;
for (row = wi; row < bh * 2; row += 8) {
const int l_addr = offset - 4 + min(row, left_available - 1) * stride;
// Fallbacks: first above sample if only above exists, otherwise 513.
int left = left_available ? (dst_frame.Load(l_addr) >> 16) : above_available ? loc_above[0].x : 513;
loc_left[row] = left;
// The left sample of an odd pixel row is the corner of the next block row.
if (row & 1) loc_corner[((row + 1) >> 1)] = left;
}
// Corner of the first block row: the real top-left sample when both
// neighbours exist, else the first above sample, else 512.
if (wi == 0) {
loc_corner[0] = (left_available && above_available) ? (dst_frame.Load(offset - 4 - stride) >> 16)
: (left_available || above_available) ? loc_above[0].x : 512;
}
GroupMemoryBarrier();
// Seven taps for the output sample handled by this wi.
int4 filter0 = cb_filter[mode][wi][0];
int3 filter1 = cb_filter[mode][wi][1].xyz;
// Stage 3: prediction. Each thread.y owns block rows thread.y, thread.y + ROWS,
// ... The column loop starts at a negative index and skips those iterations,
// so a row with a larger thread.y reaches a given column later; presumably
// this lets the row above publish loc_above[col] first -- confirm.
// NOTE(review): when bh > ROWS, a thread's second row starts right after its
// first one finishes; whether the row above it (handled by thread.y == ROWS-1)
// has already written loc_above[col] for narrow blocks is not evident here.
for (row = thread.y; row < bh; row += ROWS)
for (int col = -(int)thread.y - 1; col < bw; ++col) {
if (col < 0) {
continue;
}
// Neighbours of the 4x2 step: corner, four above, two left.
int p0 = loc_corner[row];
int4 p14 = loc_above[col];
int p5 = loc_left[2 * row];
int p6 = loc_left[2 * row + 1];
int pixel = p0 * filter0.x + p14.x * filter0.y + p14.y * filter0.z + p14.z * filter0.w + p14.w * filter1.x +
p5 * filter1.y + p6 * filter1.z;
// Round, drop 4 fractional bits and clamp to the 10-bit range.
pixel = clamp((pixel + 8) >> 4, 0, 1023);
mem[thread.y][wi] = pixel;
GroupMemoryBarrier();
// Threads 0 and 1 each gather one output row (four samples), update the
// neighbour caches and write the row to the frame.
if (wi < 2) {
int4 pix;
pix.x = mem[thread.y][wi * 4 + 0];
pix.y = mem[thread.y][wi * 4 + 1];
pix.z = mem[thread.y][wi * 4 + 2];
pix.w = mem[thread.y][wi * 4 + 3];
// Last sample of each output row becomes the left neighbour of the next column.
loc_left[2 * row + wi] = pix.w;
if (wi == 1) {
// The lower output row replaces the "above" samples of this column, and
// the old above's last sample becomes the corner for the next column.
loc_above[col] = pix;
loc_corner[row] = p14.w;
}
// Pack four samples into two dwords (two 16-bit samples per dword).
uint2 pixel4;
pixel4.x = pix.x;
pixel4.x |= pix.y << 16;
pixel4.y = pix.z;
pixel4.y |= pix.w << 16;
const int addr = offset + col * 8 + (row * 2 + wi) * stride;
dst_frame.Store2(addr, pixel4);
}
GroupMemoryBarrier();
}
// Stage 4: optional residual add, enabled by bit 5 of block.y.
// NOTE(review): this reads dst_frame dwords stored by other threads in stage 3
// with only group-shared fences in between -- confirm that device-memory
// visibility is guaranteed on the target.
if (block.y & (1 << 5)) {
const int res_stride = cb_planes[plane].z;
const int res_offset = cb_planes[plane].w + x + y * res_stride;
// NOTE(review): the three lines below look like leftover notes about an
// iteration-count formula that the loop below does not use.
// 4x4: (11-wi)/8 = 1 1 1 1 0 0 0 0
// 8x8: (23-wi)/8= 2 2 2 2 2 2 2 2
// 16x16: (2*4*8+7 - wi) / 8 = 8 8 8 8 8 8 8 8
// Flat index of this thread within the 8 x ROWS group.
const int wi2 = wi + thread.y * 8;
// The block covers 2*bw dwords per row and 2*bh rows, i.e. 4*bw*bh dwords.
for (int i = wi2; i < (4 * bw * bh); i += (8 * ROWS)) {
// wx: dword column inside the block, wy: pixel row inside the block.
const int wx = i & ((bw << 1) - 1);
const int wy = i >> ((block.y & 3) + 1);
const int addr = offset + 4 * wx + stride * wy;
uint pix = dst_frame.Load(addr);
int r = residuals.Load(res_offset + 4 * wx + res_stride * wy);
// Low residual is sign-extended via (<< 16) >> 16, high via arithmetic >> 16;
// each sum is clamped to 0..1023 and repacked into its 16-bit half.
uint result = (clamp((int)((pix >> 0) & 1023) + (int)((r.x << 16) >> 16), 0, 1023) << 0) |
(clamp((int)((pix >> 16) & 1023) + (int)(r.x >> 16), 0, 1023) << 16);
dst_frame.Store(addr, result);
}
}
}