blob: 5103df11fa09c3396fbd8b1838b9b239402c7c74 [file] [log] [blame]
/*
* Copyright 2020 Google LLC
*
*/
/*
* Copyright (c) 2020, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
// Scale factors that turn the subblock-unit coordinates stored in a block
// record into pixel coordinates (x is multiplied by SubblockW, y by SubblockH).
#define SubblockW 4
#define SubblockH 4
// Largest pixel value (1023 = 2^10 - 1). Used both as the mask applied to each
// packed 16-bit pixel and as the upper clamp bound after adding the residual.
#define PixelMax 1023
// Constant buffer bound at b0. The entry point in this file only reads
// cb_planes; the remaining members are not referenced here.
// NOTE(review): the unused members are presumably shared with other intra
// shaders that use the same buffer layout - confirm before changing them.
cbuffer IntraDataCommon : register(b0) {
// Per-plane addressing data, indexed by the 2-bit plane id from the block
// record. As used by main():
//   x = stride of the output frame rows,
//   y = base offset of the plane inside dst_frame,
//   z = stride of the residual rows,
//   w = base offset of the plane inside residuals.
// All four are combined with a byte x-offset and passed to ByteAddressBuffer
// loads/stores, so they are used as byte quantities.
int4 cb_planes[3];
int4 cb_flags;
int4 cb_filter[5][8][2];
int4 cb_mode_params_lut[16][7];
int4 cb_sm_weight_arrays[128];
};
// Per-dispatch constants bound at b1.
cbuffer PSSLIntraSRT : register(b1) {
// Number of active work items; threads whose id is >= this value exit at once.
uint cb_wi_count;
// Added to the per-thread block number to form the index into pred_blocks.
int cb_pass_offset;
// log2 of the block width and height, measured in subblocks. Together they
// determine how many consecutive threads share one block record.
uint cb_width_log2;
uint cb_height_log2;
// Subtracted from the output address when the source pixels are read from
// palette_buf instead of dst_frame.
uint cb_fb_base_offset;
// Not referenced by the entry point in this file.
int r[5];
};
// Block records, read at (block index << 4), i.e. one record every 16 bytes.
// main() reads the first two 32-bit words of each record.
ByteAddressBuffer pred_blocks : register(t0);
// Residual data; each 32-bit word is interpreted as two signed 16-bit values.
ByteAddressBuffer residuals : register(t1);
// Alternative pixel source used when a block's palette flag (bit 3 of the
// second record word) is set.
ByteAddressBuffer palette_buf : register(t2);
// Output frame. It is read as the pixel source when the palette flag is clear
// and is always the destination of the final store.
RWByteAddressBuffer dst_frame : register(u0);
// Adds a pair of signed 16-bit residuals (packed in one 32-bit word) to a pair
// of pixels packed the same way. Each pixel is masked to PixelMax before the
// addition, each sum is clamped to [0, PixelMax], and the two results are
// packed back into the low and high halves of the returned word.
uint AddResidualPair(uint packed_pixels, int packed_residuals) {
  // Shifting left then right on a signed int sign-extends the low half.
  const int low_sum = (int)(packed_pixels & PixelMax) + ((packed_residuals << 16) >> 16);
  // The arithmetic right shift yields the sign-extended high half.
  const int high_sum = (int)((packed_pixels >> 16) & PixelMax) + (packed_residuals >> 16);
  return clamp(low_sum, 0, PixelMax) | (clamp(high_sum, 0, PixelMax) << 16);
}

// Entry point: each thread processes one subblock of SubblockH rows, loading
// and storing 8 bytes (four packed 16-bit pixels) per row.
//
// The low (cb_width_log2 + cb_height_log2) bits of the thread id select the
// subblock inside a block; the remaining bits plus cb_pass_offset select the
// block record. Record layout as decoded here:
//   word 0: bits 0-15 = block x, bits 16-31 = block y (both in subblock units)
//   word 1: bits 0-1 = plane, bit 2 = add residuals, bit 3 = read palette_buf
// For every row the source pixels are fetched (from palette_buf or from
// dst_frame), optionally combined with the residuals, and written to dst_frame.
[numthreads(64, 1, 1)] void main(uint3 thread
                                 : SV_DispatchThreadID) {
  if (thread.x >= cb_wi_count) return;

  const int log2_w = cb_width_log2;
  const int log2_h = cb_height_log2;
  const int log2_area = log2_w + log2_h;
  const int sub_index = thread.x & ((1 << log2_area) - 1);
  const int record_index = cb_pass_offset + (thread.x >> log2_area);

  const uint2 record = pred_blocks.Load2(record_index << 4);
  const int plane_id = record.y & 3;
  const bool add_residual = (record.y & 4) != 0;
  const bool from_palette = (record.y & 8) != 0;

  // Position of this subblock inside the block, in subblock units.
  const int sub_col = sub_index & ((1 << log2_w) - 1);
  const int sub_row = sub_index >> log2_w;

  int col = SubblockW * ((record.x & 0xffff) + sub_col);
  const int row0 = SubblockH * ((record.x >> 16) + sub_row);
  // Two bytes per pixel: convert the pixel column into a byte offset.
  col <<= 1;

  const int4 plane_info = cb_planes[plane_id];
  const int dst_stride = plane_info.x;
  const int dst_base = plane_info.y + col + row0 * dst_stride;
  const int res_pitch = plane_info.z;
  const int res_base = plane_info.w + col + row0 * res_pitch;

  for (int row = 0; row < SubblockH; ++row) {
    const uint dst_addr = dst_base + row * dst_stride;

    uint2 px;
    if (from_palette) {
      px = palette_buf.Load2(dst_addr - cb_fb_base_offset);
    } else {
      px = dst_frame.Load2(dst_addr);
    }

    if (add_residual) {
      const int2 res = residuals.Load2(res_base + row * res_pitch);
      px.x = AddResidualPair(px.x, res.x);
      px.y = AddResidualPair(px.y, res.y);
    }

    dst_frame.Store2(dst_addr, px);
  }
}