Add PVQ high bit depth support.

With CONFIG_AOM_HIGHBITDEPTH, PVQ keeps operating at its 8-bit
reference scale: the AV1 quantizers and the coefficients handed to
od_pvq_encode() are shifted down by (bit depth - 8), and the
dequantized output is shifted back up. Pixel copies read through
CONVERT_TO_SHORTPTR() when YV12_FLAG_HIGHBITDEPTH is set, a high bit
depth variant of the PVQ block error function is added, and
av1_pvq_encode_helper() now takes the MACROBLOCK so it can read the
bit depth.

Change-Id: I4d43d33725a5a0e6fdfa1168d1397cb122366b19
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 6ddb06f..be65c53 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4895,8 +4895,18 @@
     int segment_id = 0;
     int rdmult = set_segment_rdmult(cpi, &td->mb, segment_id);
     int qindex = av1_get_qindex(&cm->seg, segment_id, cm->base_qindex);
-    int64_t q_ac = av1_ac_quant(qindex, 0, cpi->common.bit_depth);
-    int64_t q_dc = av1_dc_quant(qindex, 0, cpi->common.bit_depth);
+#if CONFIG_AOM_HIGHBITDEPTH
+    const int quantizer_shift = td->mb.e_mbd.bd - 8;
+#else
+    const int quantizer_shift = 0;
+#endif  // CONFIG_AOM_HIGHBITDEPTH
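+    // av1_ac_quant()/av1_dc_quant() return values scaled up for high bit
+    // depths, so shift back down by (bit depth - 8) to keep the PVQ lambda
+    // at its 8-bit reference scale; clamp at 1 to stay a valid quantizer.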
+    int64_t q_ac = OD_MAXI(
+        1, av1_ac_quant(qindex, 0, cpi->common.bit_depth) >> quantizer_shift);
+    int64_t q_dc = OD_MAXI(
+        1, av1_dc_quant(qindex, 0, cpi->common.bit_depth) >> quantizer_shift);
     /* td->mb.daala_enc.pvq_norm_lambda = OD_PVQ_LAMBDA; */
     td->mb.daala_enc.pvq_norm_lambda =
         (double)rdmult * (64 / 16) / (q_ac * q_ac * (1 << RDDIV_BITS));
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 32b9986..d165af5 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -453,6 +453,7 @@
   return final_eob;
 }
 
+#if !CONFIG_PVQ
 #if CONFIG_AOM_HIGHBITDEPTH
 typedef enum QUANT_FUNC {
   QUANT_FUNC_LOWBD = 0,
@@ -473,7 +474,7 @@
       { NULL, NULL }
     };
 
-#elif !CONFIG_PVQ
+#else
 
 typedef enum QUANT_FUNC {
   QUANT_FUNC_LOWBD = 0,
@@ -492,7 +493,8 @@
 #endif  // CONFIG_NEW_QUANT
                                          { NULL }
                                        };
-#endif
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+#endif  // CONFIG_PVQ
 
 void av1_xform_quant(const AV1_COMMON *cm, MACROBLOCK *x, int plane, int block,
                      int blk_row, int blk_col, BLOCK_SIZE plane_bsize,
@@ -570,10 +572,22 @@
 
   // transform block size in pixels
   tx_blk_size = tx_size_wide[tx_size];
-
-  for (j = 0; j < tx_blk_size; j++)
-    for (i = 0; i < tx_blk_size; i++)
-      src_int16[diff_stride * j + i] = src[src_stride * j + i];
+#if CONFIG_AOM_HIGHBITDEPTH
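+  // High bit depth buffers store uint16 samples behind a converted pointer,
+  // so read them via CONVERT_TO_SHORTPTR() when filling the int16 buffer.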
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++)
+        src_int16[diff_stride * j + i] =
+            CONVERT_TO_SHORTPTR(src)[src_stride * j + i];
+  } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++)
+        src_int16[diff_stride * j + i] = src[src_stride * j + i];
+#if CONFIG_AOM_HIGHBITDEPTH
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif
 
 #if CONFIG_PVQ || CONFIG_DAALA_DIST
@@ -583,12 +595,22 @@
   // transform block size in pixels
   tx_blk_size = tx_size_wide[tx_size];
 
-  // copy uint8 orig and predicted block to int16 buffer
-  // in order to use existing VP10 transform functions
-  for (j = 0; j < tx_blk_size; j++)
-    for (i = 0; i < tx_blk_size; i++) {
-      pred[diff_stride * j + i] = dst[dst_stride * j + i];
-    }
+// Copy the uint8/uint16 original and predicted blocks into int16 buffers
+// so the existing AV1 transform functions can be reused.
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++)
+        pred[diff_stride * j + i] =
+            CONVERT_TO_SHORTPTR(dst)[dst_stride * j + i];
+  } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++)
+        pred[diff_stride * j + i] = dst[dst_stride * j + i];
+#if CONFIG_AOM_HIGHBITDEPTH
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif
 
   (void)ctx;
@@ -597,6 +619,7 @@
   fwd_txfm_param.tx_size = tx_size;
   fwd_txfm_param.lossless = xd->lossless[mbmi->segment_id];
 
+#if !CONFIG_PVQ
 #if CONFIG_AOM_HIGHBITDEPTH
   fwd_txfm_param.bd = xd->bd;
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
@@ -612,8 +635,6 @@
     return;
   }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-
-#if !CONFIG_PVQ
   fwd_txfm(src_diff, coeff, diff_stride, &fwd_txfm_param);
   if (xform_quant_idx != AV1_XFORM_QUANT_SKIP_QUANT) {
     if (LIKELY(!x->skip_block)) {
@@ -623,16 +644,25 @@
       av1_quantize_skip(tx2d_size, qcoeff, dqcoeff, eob);
     }
   }
-#else   // #if !CONFIG_PVQ
-
+#else  // #if !CONFIG_PVQ
   (void)xform_quant_idx;
-  fwd_txfm(src_int16, coeff, diff_stride, &fwd_txfm_param);
-  fwd_txfm(pred, ref_coeff, diff_stride, &fwd_txfm_param);
+#if CONFIG_AOM_HIGHBITDEPTH
+  fwd_txfm_param.bd = xd->bd;
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    highbd_fwd_txfm(src_int16, coeff, diff_stride, &fwd_txfm_param);
+    highbd_fwd_txfm(pred, ref_coeff, diff_stride, &fwd_txfm_param);
+  } else {
+#endif
+    fwd_txfm(src_int16, coeff, diff_stride, &fwd_txfm_param);
+    fwd_txfm(pred, ref_coeff, diff_stride, &fwd_txfm_param);
+#if CONFIG_AOM_HIGHBITDEPTH
+  }
+#endif
 
   // PVQ for inter mode block
   if (!x->skip_block) {
     PVQ_SKIP_TYPE ac_dc_coded =
-        av1_pvq_encode_helper(&x->daala_enc,
+        av1_pvq_encode_helper(x,
                               coeff,        // target original vector
                               ref_coeff,    // reference vector
                               dqcoeff,      // de-quantized vector
@@ -844,12 +874,22 @@
       // transform block size in pixels
       tx_blk_size = tx_size_wide[tx_size];
 
-      // Since av1 does not have separate function which does inverse transform
-      // but av1_inv_txfm_add_*x*() also does addition of predicted image to
-      // inverse transformed image,
-      // pass blank dummy image to av1_inv_txfm_add_*x*(), i.e. set dst as zeros
-      for (j = 0; j < tx_blk_size; j++)
-        for (i = 0; i < tx_blk_size; i++) dst[j * pd->dst.stride + i] = 0;
+// av1 has no function that performs only the inverse transform;
+// av1_inv_txfm_add_*x*() also adds the predicted image to the
+// inverse-transformed one, so pass it a blank dummy image,
+// i.e. set dst to zeros.
+#if CONFIG_AOM_HIGHBITDEPTH
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+        for (j = 0; j < tx_blk_size; j++)
+          for (i = 0; i < tx_blk_size; i++)
+            CONVERT_TO_SHORTPTR(dst)[j * pd->dst.stride + i] = 0;
+      } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        for (j = 0; j < tx_blk_size; j++)
+          for (i = 0; i < tx_blk_size; i++) dst[j * pd->dst.stride + i] = 0;
+#if CONFIG_AOM_HIGHBITDEPTH
+      }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
     }
 #endif  // !CONFIG_PVQ
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -1108,23 +1148,36 @@
   // transform block size in pixels
   tx_blk_size = tx_size_wide[tx_size];
 
-  // Since av1 does not have separate function which does inverse transform
-  // but av1_inv_txfm_add_*x*() also does addition of predicted image to
-  // inverse transformed image,
-  // pass blank dummy image to av1_inv_txfm_add_*x*(), i.e. set dst as zeros
-
-  for (j = 0; j < tx_blk_size; j++)
-    for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+// av1 has no function that performs only the inverse transform;
+// av1_inv_txfm_add_*x*() also adds the predicted image to the
+// inverse-transformed one, so pass it a blank dummy image,
+// i.e. set dst to zeros.
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++)
+        CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i] = 0;
+  } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+    for (j = 0; j < tx_blk_size; j++)
+      for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+#if CONFIG_AOM_HIGHBITDEPTH
+  }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 
   inv_txfm_param.tx_type = tx_type;
   inv_txfm_param.tx_size = tx_size;
   inv_txfm_param.eob = *eob;
   inv_txfm_param.lossless = xd->lossless[mbmi->segment_id];
 #if CONFIG_AOM_HIGHBITDEPTH
-#error
-
-#else
-  av1_inv_txfm_add(dqcoeff, dst, dst_stride, &inv_txfm_param);
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    inv_txfm_param.bd = xd->bd;
+    av1_highbd_inv_txfm_add(dqcoeff, dst, dst_stride, &inv_txfm_param);
+  } else {
+#endif
+    av1_inv_txfm_add(dqcoeff, dst, dst_stride, &inv_txfm_param);
+#if CONFIG_AOM_HIGHBITDEPTH
+  }
 #endif
 #endif  // #if !CONFIG_PVQ
 
@@ -1165,14 +1218,17 @@
 }
 
 #if CONFIG_PVQ
-PVQ_SKIP_TYPE av1_pvq_encode_helper(
-    daala_enc_ctx *daala_enc, tran_low_t *const coeff, tran_low_t *ref_coeff,
-    tran_low_t *const dqcoeff, uint16_t *eob, const int16_t *quant, int plane,
-    int tx_size, TX_TYPE tx_type, int *rate, int speed, PVQ_INFO *pvq_info) {
+PVQ_SKIP_TYPE av1_pvq_encode_helper(MACROBLOCK *x, tran_low_t *const coeff,
+                                    tran_low_t *ref_coeff,
+                                    tran_low_t *const dqcoeff, uint16_t *eob,
+                                    const int16_t *quant, int plane,
+                                    int tx_size, TX_TYPE tx_type, int *rate,
+                                    int speed, PVQ_INFO *pvq_info) {
   const int tx_blk_size = tx_size_wide[tx_size];
+  daala_enc_ctx *daala_enc = &x->daala_enc;
   PVQ_SKIP_TYPE ac_dc_coded;
-  /*TODO(tterribe): Handle CONFIG_AOM_HIGHBITDEPTH.*/
   int coeff_shift = 3 - get_tx_scale(tx_size);
+  int hbd_downshift = 0;
   int rounding_mask;
   int pvq_dc_quant;
   int use_activity_masking = daala_enc->use_activity_masking;
@@ -1189,16 +1245,23 @@
   DECLARE_ALIGNED(16, int32_t, ref_int32[OD_TXSIZE_MAX * OD_TXSIZE_MAX]);
   DECLARE_ALIGNED(16, int32_t, out_int32[OD_TXSIZE_MAX * OD_TXSIZE_MAX]);
 
-  assert(OD_COEFF_SHIFT >= 3);
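+  // hbd_downshift scales the quantizers and input coefficients from the
+  // high bit depth range down to the 8-bit range PVQ operates in.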
+#if CONFIG_AOM_HIGHBITDEPTH
+  hbd_downshift = x->e_mbd.bd - 8;
+#endif
+
+  assert(OD_COEFF_SHIFT >= 4);
   // DC quantizer for PVQ
   if (use_activity_masking)
     pvq_dc_quant =
-        OD_MAXI(1, (quant[0] << (OD_COEFF_SHIFT - 3)) *
+        OD_MAXI(1, (quant[0] << (OD_COEFF_SHIFT - 3) >> hbd_downshift) *
                            daala_enc->state
                                .pvq_qm_q4[plane][od_qm_get_index(tx_size, 0)] >>
                        4);
   else
-    pvq_dc_quant = OD_MAXI(1, quant[0] << (OD_COEFF_SHIFT - 3));
+    pvq_dc_quant =
+        OD_MAXI(1, quant[0] << (OD_COEFF_SHIFT - 3) >> hbd_downshift);
 
   *eob = 0;
 
@@ -1217,8 +1278,10 @@
   // copy int16 inputs to int32
   for (i = 0; i < tx_blk_size * tx_blk_size; i++) {
     ref_int32[i] =
-        AOM_SIGNED_SHL(ref_coeff_pvq[i], OD_COEFF_SHIFT - coeff_shift);
-    in_int32[i] = AOM_SIGNED_SHL(coeff_pvq[i], OD_COEFF_SHIFT - coeff_shift);
+        AOM_SIGNED_SHL(ref_coeff_pvq[i], OD_COEFF_SHIFT - coeff_shift) >>
+        hbd_downshift;
+    in_int32[i] = AOM_SIGNED_SHL(coeff_pvq[i], OD_COEFF_SHIFT - coeff_shift) >>
+                  hbd_downshift;
   }
 
   if (abs(in_int32[0] - ref_int32[0]) < pvq_dc_quant * 141 / 256) { /* 0.55 */
@@ -1227,17 +1290,20 @@
     out_int32[0] = OD_DIV_R0(in_int32[0] - ref_int32[0], pvq_dc_quant);
   }
 
-  ac_dc_coded = od_pvq_encode(
-      daala_enc, ref_int32, in_int32, out_int32,
-      quant[0] << (OD_COEFF_SHIFT - 3),  // scale/quantizer
-      quant[1] << (OD_COEFF_SHIFT - 3),  // scale/quantizer
-      plane, tx_size, OD_PVQ_BETA[use_activity_masking][plane][tx_size],
-      OD_ROBUST_STREAM,
-      0,        // is_keyframe,
-      0, 0, 0,  // q_scaling, bx, by,
-      daala_enc->state.qm + off, daala_enc->state.qm_inv + off,
-      speed,  // speed
-      pvq_info);
+  ac_dc_coded =
+      od_pvq_encode(daala_enc, ref_int32, in_int32, out_int32,
+                    OD_MAXI(1, quant[0] << (OD_COEFF_SHIFT - 3) >>
+                                   hbd_downshift),  // scale/quantizer
+                    OD_MAXI(1, quant[1] << (OD_COEFF_SHIFT - 3) >>
+                                   hbd_downshift),  // scale/quantizer
+                    plane,
+                    tx_size, OD_PVQ_BETA[use_activity_masking][plane][tx_size],
+                    OD_ROBUST_STREAM,
+                    0,        // is_keyframe,
+                    0, 0, 0,  // q_scaling, bx, by,
+                    daala_enc->state.qm + off, daala_enc->state.qm_inv + off,
+                    speed,  // speed
+                    pvq_info);
 
   // Encode residue of DC coeff, if required.
   if (!has_dc_skip || out_int32[0]) {
@@ -1260,6 +1326,9 @@
   assert(OD_COEFF_SHIFT > coeff_shift);
   rounding_mask = (1 << (OD_COEFF_SHIFT - coeff_shift - 1)) - 1;
   for (i = 0; i < tx_blk_size * tx_blk_size; i++) {
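+    // Undo hbd_downshift so dqcoeff is returned at the original high bit
+    // depth scale before rounding back down to coefficient precision.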
+    out_int32[i] = AOM_SIGNED_SHL(out_int32[i], hbd_downshift);
     dqcoeff_pvq[i] = (out_int32[i] + (out_int32[i] < 0) + rounding_mask) >>
                      (OD_COEFF_SHIFT - coeff_shift);
   }
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index f093b3a..96aa5b2 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -72,10 +72,12 @@
                                   const int mi_col);
 
 #if CONFIG_PVQ
-PVQ_SKIP_TYPE av1_pvq_encode_helper(
-    daala_enc_ctx *daala_enc, tran_low_t *const coeff, tran_low_t *ref_coeff,
-    tran_low_t *const dqcoeff, uint16_t *eob, const int16_t *quant, int plane,
-    int tx_size, TX_TYPE tx_type, int *rate, int speed, PVQ_INFO *pvq_info);
+PVQ_SKIP_TYPE av1_pvq_encode_helper(MACROBLOCK *x, tran_low_t *const coeff,
+                                    tran_low_t *ref_coeff,
+                                    tran_low_t *const dqcoeff, uint16_t *eob,
+                                    const int16_t *quant, int plane,
+                                    int tx_size, TX_TYPE tx_type, int *rate,
+                                    int speed, PVQ_INFO *pvq_info);
 
 void av1_store_pvq_enc_info(PVQ_INFO *pvq_info, int *qg, int *theta,
                             int *max_theta, int *k, od_coeff *y, int nb_bands,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index f938deb..a5bc755 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1010,34 +1010,6 @@
   *out_dist_sum = dist_sum;
 }
 
-#if CONFIG_PVQ
-// Without PVQ, av1_block_error_c() return two kind of errors,
-// 1) reconstruction (i.e. decoded) error and
-// 2) Squared sum of transformed residue (i.e. 'coeff')
-// However, if PVQ is enabled, coeff does not keep the transformed residue
-// but instead a transformed original is kept.
-// Hence, new parameter ref vector (i.e. transformed predicted signal)
-// is required to derive the residue signal,
-// i.e. coeff - ref = residue (all transformed).
-
-// TODO(yushin) : Since 4x4 case does not need ssz, better to refactor into
-// a separate function that does not do the extra computations for ssz.
-static int64_t av1_block_error2_c(const tran_low_t *coeff,
-                                  const tran_low_t *dqcoeff,
-                                  const tran_low_t *ref, intptr_t block_size,
-                                  int64_t *ssz) {
-  int64_t error;
-
-  // Use the existing sse codes for calculating distortion of decoded signal:
-  // i.e. (orig - decoded)^2
-  error = av1_block_error_fp(coeff, dqcoeff, block_size);
-  // prediction residue^2 = (orig - ref)^2
-  *ssz = av1_block_error_fp(coeff, ref, block_size);
-
-  return error;
-}
-#endif  // CONFIG_PVQ
-
 int64_t av1_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
                           intptr_t block_size, int64_t *ssz) {
   int i;
@@ -1089,6 +1061,59 @@
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
+#if CONFIG_PVQ
+// Without PVQ, av1_block_error_c() returns two kinds of errors:
+// 1) the reconstruction (i.e. decoded) error and
+// 2) the squared sum of the transformed residue (i.e. 'coeff').
+// However, when PVQ is enabled, coeff holds the transformed original
+// rather than the transformed residue.
+// Hence a new parameter, the ref vector (i.e. the transformed predicted
+// signal), is required to derive the residue:
+// coeff - ref = residue (all in the transform domain).
+
+#if CONFIG_AOM_HIGHBITDEPTH
+static int64_t av1_highbd_block_error2_c(const tran_low_t *coeff,
+                                         const tran_low_t *dqcoeff,
+                                         const tran_low_t *ref,
+                                         intptr_t block_size, int64_t *ssz,
+                                         int bd) {
+  int64_t error;
+  int64_t sqcoeff;
+  int shift = 2 * (bd - 8);
+  int rounding = shift > 0 ? 1 << (shift - 1) : 0;
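+  // A squared error carries twice the extra precision of the input samples,
+  // hence the normalizing shift by 2 * (bd - 8), with rounding.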
+  // Use the existing SSE code to calculate the distortion of the decoded
+  // signal, i.e. (orig - decoded)^2.
+  // For high bit depth, throw away ssz until a 32-bit version of
+  // av1_block_error_fp is written.
+  int64_t ssz_trash;
+  error = av1_block_error(coeff, dqcoeff, block_size, &ssz_trash);
+  // prediction residue^2 = (orig - ref)^2
+  sqcoeff = av1_block_error(coeff, ref, block_size, &ssz_trash);
+  error = (error + rounding) >> shift;
+  sqcoeff = (sqcoeff + rounding) >> shift;
+  *ssz = sqcoeff;
+  return error;
+}
+#else
+// TODO(yushin): Since the 4x4 case does not need ssz, it would be better to
+// refactor this into a separate function that skips the extra ssz work.
+static int64_t av1_block_error2_c(const tran_low_t *coeff,
+                                  const tran_low_t *dqcoeff,
+                                  const tran_low_t *ref, intptr_t block_size,
+                                  int64_t *ssz) {
+  int64_t error;
+  // Use the existing SSE code to calculate the distortion of the decoded
+  // signal, i.e. (orig - decoded)^2.
+  error = av1_block_error_fp(coeff, dqcoeff, block_size);
+  // prediction residue^2 = (orig - ref)^2
+  *ssz = av1_block_error_fp(coeff, ref, block_size);
+  return error;
+}
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+#endif  // CONFIG_PVQ
+
 #if !CONFIG_PVQ || CONFIG_VAR_TX
 /* The trailing '0' is a terminator which is used inside av1_cost_coeffs() to
  * decide whether to include cost of a trailing EOB node or not (i.e. we
@@ -1260,20 +1283,26 @@
     tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
 #if CONFIG_PVQ
     tran_low_t *ref_coeff = BLOCK_OFFSET(pd->pvq_ref_coeff, block);
-#endif  // CONFIG_PVQ
+
 #if CONFIG_AOM_HIGHBITDEPTH
     const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
+    *out_dist = av1_highbd_block_error2_c(coeff, dqcoeff, ref_coeff,
+                                          buffer_length, &this_sse, bd) >>
+                shift;
+#else
+    *out_dist = av1_block_error2_c(coeff, dqcoeff, ref_coeff, buffer_length,
+                                   &this_sse) >>
+                shift;
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+#elif CONFIG_AOM_HIGHBITDEPTH
+    const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
     *out_dist =
         av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse, bd) >>
         shift;
-#elif CONFIG_PVQ
-    *out_dist = av1_block_error2_c(coeff, dqcoeff, ref_coeff, buffer_length,
-                                   &this_sse) >>
-                shift;
 #else
     *out_dist =
         av1_block_error(coeff, dqcoeff, buffer_length, &this_sse) >> shift;
-#endif  // CONFIG_AOM_HIGHBITDEPTH
+#endif  // CONFIG_PVQ
     *out_sse = this_sse >> shift;
   } else {
     const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
@@ -2637,6 +2666,11 @@
 
 #if CONFIG_AOM_HIGHBITDEPTH
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+#if CONFIG_PVQ
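+    // PVQ encoding advances the daala entropy coder state, so checkpoint it
+    // here and roll back after each candidate mode is evaluated.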
+    od_encode_checkpoint(&x->daala_enc, &pre_buf);
+#endif
     for (mode = DC_PRED; mode <= TM_PRED; ++mode) {
       int64_t this_rd;
       int ratey = 0;
@@ -2664,8 +2696,12 @@
               av1_raster_order_to_block_index(tx_size, block_raster_idx);
           const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
           uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
+#if !CONFIG_PVQ
           int16_t *const src_diff = av1_raster_block_offset_int16(
               BLOCK_8X8, block_raster_idx, p->src_diff);
+#else
+          int i, j;
+#endif
           int skip;
           assert(block < 4);
           assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
@@ -2676,14 +2712,17 @@
           av1_predict_intra_block(
               xd, pd->width, pd->height, txsize_to_bsize[tx_size], mode, dst,
               dst_stride, dst, dst_stride, col + idx, row + idy, 0);
+#if !CONFIG_PVQ
           aom_highbd_subtract_block(tx_height, tx_width, src_diff, 8, src,
                                     src_stride, dst, dst_stride, xd->bd);
+#endif
           if (is_lossless) {
             TX_TYPE tx_type =
                 get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
             const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
                 combine_entropy_contexts(tempa[idx], templ[idy]);
+#if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
                             tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
@@ -2705,12 +2744,40 @@
               templ[idy + 1] = templ[idy];
             }
 #endif  // CONFIG_EXT_TX
+#else
+            (void)scan_order;
 
+            av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
+
+            ratey += x->rate;
+            skip = x->pvq_skip[0];
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+            can_skip &= skip;
+#endif
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
               goto next_highbd;
-            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], xd->bd, DCT_DCT,
-                                 1);
+#if CONFIG_PVQ
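+            // In the PVQ path dqcoeff holds the full reconstruction in the
+            // transform domain, so pass a blank dummy image to the inverse
+            // transform add stage, i.e. zero dst first.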
+            if (!skip) {
+              if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+                for (j = 0; j < tx_height; j++)
+                  for (i = 0; i < tx_width; i++)
+                    *CONVERT_TO_SHORTPTR(dst + j * dst_stride + i) = 0;
+              } else {
+                for (j = 0; j < tx_height; j++)
+                  for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
+              }
+#endif
+              highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                   dst_stride, p->eobs[block], xd->bd, DCT_DCT,
+                                   1);
+#if CONFIG_PVQ
+            }
+#endif
           } else {
             int64_t dist;
             unsigned int tmp;
@@ -2719,6 +2783,7 @@
             const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
                 combine_entropy_contexts(tempa[idx], templ[idy]);
+#if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
                             tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
@@ -2741,9 +2806,37 @@
               templ[idy + 1] = templ[idy];
             }
 #endif  // CONFIG_EXT_TX
-            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], xd->bd, tx_type,
-                                 0);
+#else
+            (void)scan_order;
+
+            av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
+            ratey += x->rate;
+            skip = x->pvq_skip[0];
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+            can_skip &= skip;
+#endif
+#if CONFIG_PVQ
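+            // In the PVQ path dqcoeff holds the full reconstruction in the
+            // transform domain, so pass a blank dummy image to the inverse
+            // transform add stage, i.e. zero dst first.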
+            if (!skip) {
+              if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+                for (j = 0; j < tx_height; j++)
+                  for (i = 0; i < tx_width; i++)
+                    *CONVERT_TO_SHORTPTR(dst + j * dst_stride + i) = 0;
+              } else {
+                for (j = 0; j < tx_height; j++)
+                  for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
+              }
+#endif
+              highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                   dst_stride, p->eobs[block], xd->bd, tx_type,
+                                   0);
+#if CONFIG_PVQ
+            }
+#endif
             cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
             dist = (int64_t)tmp << 4;
             distortion += dist;
@@ -2765,6 +2855,9 @@
         *best_mode = mode;
         memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
         memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
+#if CONFIG_PVQ
+        od_encode_checkpoint(&x->daala_enc, &post_buf);
+#endif
         for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
           memcpy(best_dst16 + idy * 8,
                  CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
@@ -2772,10 +2865,17 @@
         }
       }
     next_highbd : {}
+#if CONFIG_PVQ
+      od_encode_rollback(&x->daala_enc, &pre_buf);
+#endif
     }
 
     if (best_rd >= rd_thresh) return best_rd;
 
+#if CONFIG_PVQ
+    od_encode_rollback(&x->daala_enc, &post_buf);
+#endif
+
     if (y_skip) *y_skip &= best_can_skip;
 
     for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {