Add PVQ high bit depth support.
Change-Id: I4d43d33725a5a0e6fdfa1168d1397cb122366b19
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index f938deb..a5bc755 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1010,34 +1010,6 @@
*out_dist_sum = dist_sum;
}
-#if CONFIG_PVQ
-// Without PVQ, av1_block_error_c() return two kind of errors,
-// 1) reconstruction (i.e. decoded) error and
-// 2) Squared sum of transformed residue (i.e. 'coeff')
-// However, if PVQ is enabled, coeff does not keep the transformed residue
-// but instead a transformed original is kept.
-// Hence, new parameter ref vector (i.e. transformed predicted signal)
-// is required to derive the residue signal,
-// i.e. coeff - ref = residue (all transformed).
-
-// TODO(yushin) : Since 4x4 case does not need ssz, better to refactor into
-// a separate function that does not do the extra computations for ssz.
-static int64_t av1_block_error2_c(const tran_low_t *coeff,
- const tran_low_t *dqcoeff,
- const tran_low_t *ref, intptr_t block_size,
- int64_t *ssz) {
- int64_t error;
-
- // Use the existing sse codes for calculating distortion of decoded signal:
- // i.e. (orig - decoded)^2
- error = av1_block_error_fp(coeff, dqcoeff, block_size);
- // prediction residue^2 = (orig - ref)^2
- *ssz = av1_block_error_fp(coeff, ref, block_size);
-
- return error;
-}
-#endif // CONFIG_PVQ
-
int64_t av1_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
intptr_t block_size, int64_t *ssz) {
int i;
@@ -1089,6 +1061,57 @@
}
#endif // CONFIG_AOM_HIGHBITDEPTH
+#if CONFIG_PVQ
+// Without PVQ, av1_block_error_c() returns two kinds of errors:
+// 1) the reconstruction (i.e. decoded) error, and
+// 2) the squared sum of the transformed residue (i.e. 'coeff').
+// However, if PVQ is enabled, coeff does not hold the transformed residue;
+// the transformed original is kept instead.
+// Hence, a new parameter, the ref vector (i.e. the transformed predicted
+// signal), is required to derive the residue signal,
+// i.e. coeff - ref = residue (all transformed).
+
+#if CONFIG_AOM_HIGHBITDEPTH
+static int64_t av1_highbd_block_error2_c(const tran_low_t *coeff,
+ const tran_low_t *dqcoeff,
+ const tran_low_t *ref,
+ intptr_t block_size, int64_t *ssz,
+ int bd) {
+ int64_t error;  // return value: scaled (coeff vs. dqcoeff) error
+ int64_t sqcoeff;  // written to *ssz: scaled (coeff vs. ref) error
+ int shift = 2 * (bd - 8);  // 0 when bd == 8
+ int rounding = shift > 0 ? 1 << (shift - 1) : 0;  // half of (1 << shift)
+ // Use the existing SSE code to get the distortion of the decoded signal,
+ // i.e. (orig - decoded)^2.
+ // For high bit depth, throw away the ssz output of av1_block_error() until
+ // a 32-bit version of av1_block_error_fp is written.
+ int64_t ssz_trash;
+ error = av1_block_error(coeff, dqcoeff, block_size, &ssz_trash);
+ // Squared prediction residue, i.e. (orig - ref)^2.
+ sqcoeff = av1_block_error(coeff, ref, block_size, &ssz_trash);
+ error = (error + rounding) >> shift;  // round, then drop 'shift' bits
+ sqcoeff = (sqcoeff + rounding) >> shift;  // same scaling as 'error'
+ *ssz = sqcoeff;
+ return error;
+}
+#else
+// TODO(yushin): Since the 4x4 case does not need ssz, it would be better to
+// refactor it into a separate function that skips the extra ssz computation.
+static int64_t av1_block_error2_c(const tran_low_t *coeff,
+ const tran_low_t *dqcoeff,
+ const tran_low_t *ref, intptr_t block_size,
+ int64_t *ssz) {
+ int64_t error;  // return value; the second result goes out through *ssz
+ // Reuse the existing SSE code to get the distortion of the decoded signal,
+ // i.e. (orig - decoded)^2.
+ error = av1_block_error_fp(coeff, dqcoeff, block_size);
+ // Squared prediction residue, i.e. (orig - ref)^2.
+ *ssz = av1_block_error_fp(coeff, ref, block_size);
+ return error;
+}
+#endif // CONFIG_AOM_HIGHBITDEPTH
+#endif // CONFIG_PVQ
+
#if !CONFIG_PVQ || CONFIG_VAR_TX
/* The trailing '0' is a terminator which is used inside av1_cost_coeffs() to
* decide whether to include cost of a trailing EOB node or not (i.e. we
@@ -1260,20 +1283,26 @@
tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
#if CONFIG_PVQ
tran_low_t *ref_coeff = BLOCK_OFFSET(pd->pvq_ref_coeff, block);
-#endif // CONFIG_PVQ
+
#if CONFIG_AOM_HIGHBITDEPTH
const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
+ *out_dist = av1_highbd_block_error2_c(coeff, dqcoeff, ref_coeff,
+ buffer_length, &this_sse, bd) >>
+ shift;
+#else
+ *out_dist = av1_block_error2_c(coeff, dqcoeff, ref_coeff, buffer_length,
+ &this_sse) >>
+ shift;
+#endif // CONFIG_AOM_HIGHBITDEPTH
+#elif CONFIG_AOM_HIGHBITDEPTH
+ const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
*out_dist =
av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse, bd) >>
shift;
-#elif CONFIG_PVQ
- *out_dist = av1_block_error2_c(coeff, dqcoeff, ref_coeff, buffer_length,
- &this_sse) >>
- shift;
#else
*out_dist =
av1_block_error(coeff, dqcoeff, buffer_length, &this_sse) >> shift;
-#endif // CONFIG_AOM_HIGHBITDEPTH
+#endif // CONFIG_PVQ
*out_sse = this_sse >> shift;
} else {
const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
@@ -2637,6 +2666,9 @@
#if CONFIG_AOM_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+#if CONFIG_PVQ
+ od_encode_checkpoint(&x->daala_enc, &pre_buf);
+#endif
for (mode = DC_PRED; mode <= TM_PRED; ++mode) {
int64_t this_rd;
int ratey = 0;
@@ -2664,8 +2696,12 @@
av1_raster_order_to_block_index(tx_size, block_raster_idx);
const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
+#if !CONFIG_PVQ
int16_t *const src_diff = av1_raster_block_offset_int16(
BLOCK_8X8, block_raster_idx, p->src_diff);
+#else
+ int i, j;
+#endif
int skip;
assert(block < 4);
assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
@@ -2676,14 +2712,17 @@
av1_predict_intra_block(
xd, pd->width, pd->height, txsize_to_bsize[tx_size], mode, dst,
dst_stride, dst, dst_stride, col + idx, row + idy, 0);
+#if !CONFIG_PVQ
aom_highbd_subtract_block(tx_height, tx_width, src_diff, 8, src,
src_stride, dst, dst_stride, xd->bd);
+#endif
if (is_lossless) {
TX_TYPE tx_type =
get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
const int coeff_ctx =
combine_entropy_contexts(tempa[idx], templ[idy]);
+#if !CONFIG_PVQ
#if CONFIG_NEW_QUANT
av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
@@ -2705,12 +2744,37 @@
templ[idy + 1] = templ[idy];
}
#endif // CONFIG_EXT_TX
+#else
+ (void)scan_order;
+ av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
+ tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
+
+ ratey += x->rate;
+ skip = x->pvq_skip[0];
+ tempa[idx] = !skip;
+ templ[idy] = !skip;
+ can_skip &= skip;
+#endif
if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
goto next_highbd;
- highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
- dst_stride, p->eobs[block], xd->bd, DCT_DCT,
- 1);
+#if CONFIG_PVQ
+ if (!skip) {
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+ for (j = 0; j < tx_height; j++)
+ for (i = 0; i < tx_width; i++)
+ *CONVERT_TO_SHORTPTR(dst + j * dst_stride + i) = 0;
+ } else {
+ for (j = 0; j < tx_height; j++)
+ for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
+ }
+#endif
+ highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+ dst_stride, p->eobs[block], xd->bd, DCT_DCT,
+ 1);
+#if CONFIG_PVQ
+ }
+#endif
} else {
int64_t dist;
unsigned int tmp;
@@ -2719,6 +2783,7 @@
const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
const int coeff_ctx =
combine_entropy_contexts(tempa[idx], templ[idy]);
+#if !CONFIG_PVQ
#if CONFIG_NEW_QUANT
av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
@@ -2741,9 +2806,34 @@
templ[idy + 1] = templ[idy];
}
#endif // CONFIG_EXT_TX
- highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
- dst_stride, p->eobs[block], xd->bd, tx_type,
- 0);
+#else
+ (void)scan_order;
+
+ av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
+ tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
+ ratey += x->rate;
+ skip = x->pvq_skip[0];
+ tempa[idx] = !skip;
+ templ[idy] = !skip;
+ can_skip &= skip;
+#endif
+#if CONFIG_PVQ
+ if (!skip) {
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+ for (j = 0; j < tx_height; j++)
+ for (i = 0; i < tx_width; i++)
+ *CONVERT_TO_SHORTPTR(dst + j * dst_stride + i) = 0;
+ } else {
+ for (j = 0; j < tx_height; j++)
+ for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
+ }
+#endif
+ highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+ dst_stride, p->eobs[block], xd->bd, tx_type,
+ 0);
+#if CONFIG_PVQ
+ }
+#endif
cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
dist = (int64_t)tmp << 4;
distortion += dist;
@@ -2765,6 +2855,9 @@
*best_mode = mode;
memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
+#if CONFIG_PVQ
+ od_encode_checkpoint(&x->daala_enc, &post_buf);
+#endif
for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
memcpy(best_dst16 + idy * 8,
CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
@@ -2772,10 +2865,17 @@
}
}
next_highbd : {}
+#if CONFIG_PVQ
+ od_encode_rollback(&x->daala_enc, &pre_buf);
+#endif
}
if (best_rd >= rd_thresh) return best_rd;
+#if CONFIG_PVQ
+ od_encode_rollback(&x->daala_enc, &post_buf);
+#endif
+
if (y_skip) *y_skip &= best_can_skip;
for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {