Remove deprecated high-bitdepth functions

This unifies the codepath for high-bitdepth transforms and deletes
all calls to the deprecated versions. This required reworking the way
1d configurations are combined in order to support rectangular
transforms.
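
For reviewers, the core of the rectangular rework is a fixed-point
multiply by sqrt(2) on the longer dimension so that the combined 1d
gains stay a power of two. A standalone sketch of that rounding step
(not part of the patch; the constants mirror aom_dsp/txfm_common.h):

    #include <stdint.h>
    #include <stdio.h>

    #define DCT_CONST_BITS 14
    #define ROUND_POWER_OF_TWO(value, n) (((value) + (1 << ((n)-1))) >> (n))

    static const int64_t kSqrt2 = 23170; /* round(sqrt(2) * (1 << 14)) */

    int main(void) {
      const int64_t coeff = 1000;
      /* coeff * kSqrt2 carries 14 fractional bits; round-shift back down. */
      const int64_t scaled = ROUND_POWER_OF_TWO(coeff * kSqrt2, DCT_CONST_BITS);
      printf("%d\n", (int)scaled); /* prints 1414, ~= 1000 * sqrt(2) */
      return 0;
    }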

There is one remaining codepath that calls the deprecated 4x4 hbd
transform from encoder/encodemb.c. I need to take a closer look at
what is happening there, so I am leaving it for a followup; this
change has already grown large enough.

lowres 10 bit: -0.035%
lowres 12 bit: 0.021%

BUG=aomedia:524

Change-Id: I34cdeaed2461ed7942364147cef10d7d21e3779c
diff --git a/aom_dsp/fwd_txfm.h b/aom_dsp/fwd_txfm.h
index 579dbd0..f4dc04a 100644
--- a/aom_dsp/fwd_txfm.h
+++ b/aom_dsp/fwd_txfm.h
@@ -20,10 +20,5 @@
   return result < INT16_MIN ? INT16_MIN : result;
 }
 
-static INLINE tran_high_t fdct_round_shift(tran_high_t input) {
-  tran_high_t rv = ROUND_POWER_OF_TWO(input, DCT_CONST_BITS);
-  return rv;
-}
-
 void aom_fdct32(const tran_high_t *input, tran_high_t *output, int round);
 #endif  // AOM_DSP_FWD_TXFM_H_
diff --git a/aom_dsp/inv_txfm.c b/aom_dsp/inv_txfm.c
index 6e7d8c9..7801f69 100644
--- a/aom_dsp/inv_txfm.c
+++ b/aom_dsp/inv_txfm.c
@@ -1375,6 +1375,8 @@
   }
 }
 
+// TODO(sarahparker) this one still needs to be removed, but that will
+// happen in a followup because of its use in encoder/encodemb.c
 void aom_highbd_idct4_c(const tran_low_t *input, tran_low_t *output, int bd) {
   tran_low_t step[4];
   tran_high_t temp1, temp2;
@@ -1441,869 +1443,4 @@
     dest += dest_stride;
   }
 }
-
-void aom_highbd_idct8_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_low_t step1[8], step2[8];
-  tran_high_t temp1, temp2;
-  // stage 1
-  step1[0] = input[0];
-  step1[2] = input[4];
-  step1[1] = input[2];
-  step1[3] = input[6];
-  temp1 = input[1] * cospi_28_64 - input[7] * cospi_4_64;
-  temp2 = input[1] * cospi_4_64 + input[7] * cospi_28_64;
-  step1[4] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[7] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = input[5] * cospi_12_64 - input[3] * cospi_20_64;
-  temp2 = input[5] * cospi_20_64 + input[3] * cospi_12_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  // stage 2 & stage 3 - even half
-  aom_highbd_idct4_c(step1, step1, bd);
-
-  // stage 2 - odd half
-  step2[4] = HIGHBD_WRAPLOW(step1[4] + step1[5], bd);
-  step2[5] = HIGHBD_WRAPLOW(step1[4] - step1[5], bd);
-  step2[6] = HIGHBD_WRAPLOW(-step1[6] + step1[7], bd);
-  step2[7] = HIGHBD_WRAPLOW(step1[6] + step1[7], bd);
-
-  // stage 3 - odd half
-  step1[4] = step2[4];
-  temp1 = (step2[6] - step2[5]) * cospi_16_64;
-  temp2 = (step2[5] + step2[6]) * cospi_16_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[7] = step2[7];
-
-  // stage 4
-  output[0] = HIGHBD_WRAPLOW(step1[0] + step1[7], bd);
-  output[1] = HIGHBD_WRAPLOW(step1[1] + step1[6], bd);
-  output[2] = HIGHBD_WRAPLOW(step1[2] + step1[5], bd);
-  output[3] = HIGHBD_WRAPLOW(step1[3] + step1[4], bd);
-  output[4] = HIGHBD_WRAPLOW(step1[3] - step1[4], bd);
-  output[5] = HIGHBD_WRAPLOW(step1[2] - step1[5], bd);
-  output[6] = HIGHBD_WRAPLOW(step1[1] - step1[6], bd);
-  output[7] = HIGHBD_WRAPLOW(step1[0] - step1[7], bd);
-}
-
-void aom_highbd_iadst4_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_high_t s0, s1, s2, s3, s4, s5, s6, s7;
-
-  tran_low_t x0 = input[0];
-  tran_low_t x1 = input[1];
-  tran_low_t x2 = input[2];
-  tran_low_t x3 = input[3];
-  (void)bd;
-
-  if (!(x0 | x1 | x2 | x3)) {
-    memset(output, 0, 4 * sizeof(*output));
-    return;
-  }
-
-  s0 = sinpi_1_9 * x0;
-  s1 = sinpi_2_9 * x0;
-  s2 = sinpi_3_9 * x1;
-  s3 = sinpi_4_9 * x2;
-  s4 = sinpi_1_9 * x2;
-  s5 = sinpi_2_9 * x3;
-  s6 = sinpi_4_9 * x3;
-  s7 = (tran_high_t)HIGHBD_WRAPLOW(x0 - x2 + x3, bd);
-
-  s0 = s0 + s3 + s5;
-  s1 = s1 - s4 - s6;
-  s3 = s2;
-  s2 = sinpi_3_9 * s7;
-
-  // 1-D transform scaling factor is sqrt(2).
-  // The overall dynamic range is 14b (input) + 14b (multiplication scaling)
-  // + 1b (addition) = 29b.
-  // Hence the output bit depth is 15b.
-  output[0] = HIGHBD_WRAPLOW(dct_const_round_shift(s0 + s3), bd);
-  output[1] = HIGHBD_WRAPLOW(dct_const_round_shift(s1 + s3), bd);
-  output[2] = HIGHBD_WRAPLOW(dct_const_round_shift(s2), bd);
-  output[3] = HIGHBD_WRAPLOW(dct_const_round_shift(s0 + s1 - s3), bd);
-}
-
-void aom_highbd_iadst8_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_high_t s0, s1, s2, s3, s4, s5, s6, s7;
-
-  tran_low_t x0 = input[7];
-  tran_low_t x1 = input[0];
-  tran_low_t x2 = input[5];
-  tran_low_t x3 = input[2];
-  tran_low_t x4 = input[3];
-  tran_low_t x5 = input[4];
-  tran_low_t x6 = input[1];
-  tran_low_t x7 = input[6];
-  (void)bd;
-
-  if (!(x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7)) {
-    memset(output, 0, 8 * sizeof(*output));
-    return;
-  }
-
-  // stage 1
-  s0 = cospi_2_64 * x0 + cospi_30_64 * x1;
-  s1 = cospi_30_64 * x0 - cospi_2_64 * x1;
-  s2 = cospi_10_64 * x2 + cospi_22_64 * x3;
-  s3 = cospi_22_64 * x2 - cospi_10_64 * x3;
-  s4 = cospi_18_64 * x4 + cospi_14_64 * x5;
-  s5 = cospi_14_64 * x4 - cospi_18_64 * x5;
-  s6 = cospi_26_64 * x6 + cospi_6_64 * x7;
-  s7 = cospi_6_64 * x6 - cospi_26_64 * x7;
-
-  x0 = HIGHBD_WRAPLOW(dct_const_round_shift(s0 + s4), bd);
-  x1 = HIGHBD_WRAPLOW(dct_const_round_shift(s1 + s5), bd);
-  x2 = HIGHBD_WRAPLOW(dct_const_round_shift(s2 + s6), bd);
-  x3 = HIGHBD_WRAPLOW(dct_const_round_shift(s3 + s7), bd);
-  x4 = HIGHBD_WRAPLOW(dct_const_round_shift(s0 - s4), bd);
-  x5 = HIGHBD_WRAPLOW(dct_const_round_shift(s1 - s5), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s2 - s6), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s3 - s7), bd);
-
-  // stage 2
-  s0 = x0;
-  s1 = x1;
-  s2 = x2;
-  s3 = x3;
-  s4 = cospi_8_64 * x4 + cospi_24_64 * x5;
-  s5 = cospi_24_64 * x4 - cospi_8_64 * x5;
-  s6 = -cospi_24_64 * x6 + cospi_8_64 * x7;
-  s7 = cospi_8_64 * x6 + cospi_24_64 * x7;
-
-  x0 = HIGHBD_WRAPLOW(s0 + s2, bd);
-  x1 = HIGHBD_WRAPLOW(s1 + s3, bd);
-  x2 = HIGHBD_WRAPLOW(s0 - s2, bd);
-  x3 = HIGHBD_WRAPLOW(s1 - s3, bd);
-  x4 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 + s6), bd);
-  x5 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 + s7), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 - s6), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 - s7), bd);
-
-  // stage 3
-  s2 = cospi_16_64 * (x2 + x3);
-  s3 = cospi_16_64 * (x2 - x3);
-  s6 = cospi_16_64 * (x6 + x7);
-  s7 = cospi_16_64 * (x6 - x7);
-
-  x2 = HIGHBD_WRAPLOW(dct_const_round_shift(s2), bd);
-  x3 = HIGHBD_WRAPLOW(dct_const_round_shift(s3), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s6), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s7), bd);
-
-  output[0] = HIGHBD_WRAPLOW(x0, bd);
-  output[1] = HIGHBD_WRAPLOW(-x4, bd);
-  output[2] = HIGHBD_WRAPLOW(x6, bd);
-  output[3] = HIGHBD_WRAPLOW(-x2, bd);
-  output[4] = HIGHBD_WRAPLOW(x3, bd);
-  output[5] = HIGHBD_WRAPLOW(-x7, bd);
-  output[6] = HIGHBD_WRAPLOW(x5, bd);
-  output[7] = HIGHBD_WRAPLOW(-x1, bd);
-}
-
-void aom_highbd_idct16_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_low_t step1[16], step2[16];
-  tran_high_t temp1, temp2;
-  (void)bd;
-
-  // stage 1
-  step1[0] = input[0 / 2];
-  step1[1] = input[16 / 2];
-  step1[2] = input[8 / 2];
-  step1[3] = input[24 / 2];
-  step1[4] = input[4 / 2];
-  step1[5] = input[20 / 2];
-  step1[6] = input[12 / 2];
-  step1[7] = input[28 / 2];
-  step1[8] = input[2 / 2];
-  step1[9] = input[18 / 2];
-  step1[10] = input[10 / 2];
-  step1[11] = input[26 / 2];
-  step1[12] = input[6 / 2];
-  step1[13] = input[22 / 2];
-  step1[14] = input[14 / 2];
-  step1[15] = input[30 / 2];
-
-  // stage 2
-  step2[0] = step1[0];
-  step2[1] = step1[1];
-  step2[2] = step1[2];
-  step2[3] = step1[3];
-  step2[4] = step1[4];
-  step2[5] = step1[5];
-  step2[6] = step1[6];
-  step2[7] = step1[7];
-
-  temp1 = step1[8] * cospi_30_64 - step1[15] * cospi_2_64;
-  temp2 = step1[8] * cospi_2_64 + step1[15] * cospi_30_64;
-  step2[8] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[15] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[9] * cospi_14_64 - step1[14] * cospi_18_64;
-  temp2 = step1[9] * cospi_18_64 + step1[14] * cospi_14_64;
-  step2[9] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[14] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[10] * cospi_22_64 - step1[13] * cospi_10_64;
-  temp2 = step1[10] * cospi_10_64 + step1[13] * cospi_22_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[11] * cospi_6_64 - step1[12] * cospi_26_64;
-  temp2 = step1[11] * cospi_26_64 + step1[12] * cospi_6_64;
-  step2[11] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[12] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  // stage 3
-  step1[0] = step2[0];
-  step1[1] = step2[1];
-  step1[2] = step2[2];
-  step1[3] = step2[3];
-
-  temp1 = step2[4] * cospi_28_64 - step2[7] * cospi_4_64;
-  temp2 = step2[4] * cospi_4_64 + step2[7] * cospi_28_64;
-  step1[4] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[7] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = step2[5] * cospi_12_64 - step2[6] * cospi_20_64;
-  temp2 = step2[5] * cospi_20_64 + step2[6] * cospi_12_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  step1[8] = HIGHBD_WRAPLOW(step2[8] + step2[9], bd);
-  step1[9] = HIGHBD_WRAPLOW(step2[8] - step2[9], bd);
-  step1[10] = HIGHBD_WRAPLOW(-step2[10] + step2[11], bd);
-  step1[11] = HIGHBD_WRAPLOW(step2[10] + step2[11], bd);
-  step1[12] = HIGHBD_WRAPLOW(step2[12] + step2[13], bd);
-  step1[13] = HIGHBD_WRAPLOW(step2[12] - step2[13], bd);
-  step1[14] = HIGHBD_WRAPLOW(-step2[14] + step2[15], bd);
-  step1[15] = HIGHBD_WRAPLOW(step2[14] + step2[15], bd);
-
-  // stage 4
-  temp1 = (step1[0] + step1[1]) * cospi_16_64;
-  temp2 = (step1[0] - step1[1]) * cospi_16_64;
-  step2[0] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[1] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = step1[2] * cospi_24_64 - step1[3] * cospi_8_64;
-  temp2 = step1[2] * cospi_8_64 + step1[3] * cospi_24_64;
-  step2[2] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[3] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[4] = HIGHBD_WRAPLOW(step1[4] + step1[5], bd);
-  step2[5] = HIGHBD_WRAPLOW(step1[4] - step1[5], bd);
-  step2[6] = HIGHBD_WRAPLOW(-step1[6] + step1[7], bd);
-  step2[7] = HIGHBD_WRAPLOW(step1[6] + step1[7], bd);
-
-  step2[8] = step1[8];
-  step2[15] = step1[15];
-  temp1 = -step1[9] * cospi_8_64 + step1[14] * cospi_24_64;
-  temp2 = step1[9] * cospi_24_64 + step1[14] * cospi_8_64;
-  step2[9] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[14] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step1[10] * cospi_24_64 - step1[13] * cospi_8_64;
-  temp2 = -step1[10] * cospi_8_64 + step1[13] * cospi_24_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[11] = step1[11];
-  step2[12] = step1[12];
-
-  // stage 5
-  step1[0] = HIGHBD_WRAPLOW(step2[0] + step2[3], bd);
-  step1[1] = HIGHBD_WRAPLOW(step2[1] + step2[2], bd);
-  step1[2] = HIGHBD_WRAPLOW(step2[1] - step2[2], bd);
-  step1[3] = HIGHBD_WRAPLOW(step2[0] - step2[3], bd);
-  step1[4] = step2[4];
-  temp1 = (step2[6] - step2[5]) * cospi_16_64;
-  temp2 = (step2[5] + step2[6]) * cospi_16_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[7] = step2[7];
-
-  step1[8] = HIGHBD_WRAPLOW(step2[8] + step2[11], bd);
-  step1[9] = HIGHBD_WRAPLOW(step2[9] + step2[10], bd);
-  step1[10] = HIGHBD_WRAPLOW(step2[9] - step2[10], bd);
-  step1[11] = HIGHBD_WRAPLOW(step2[8] - step2[11], bd);
-  step1[12] = HIGHBD_WRAPLOW(-step2[12] + step2[15], bd);
-  step1[13] = HIGHBD_WRAPLOW(-step2[13] + step2[14], bd);
-  step1[14] = HIGHBD_WRAPLOW(step2[13] + step2[14], bd);
-  step1[15] = HIGHBD_WRAPLOW(step2[12] + step2[15], bd);
-
-  // stage 6
-  step2[0] = HIGHBD_WRAPLOW(step1[0] + step1[7], bd);
-  step2[1] = HIGHBD_WRAPLOW(step1[1] + step1[6], bd);
-  step2[2] = HIGHBD_WRAPLOW(step1[2] + step1[5], bd);
-  step2[3] = HIGHBD_WRAPLOW(step1[3] + step1[4], bd);
-  step2[4] = HIGHBD_WRAPLOW(step1[3] - step1[4], bd);
-  step2[5] = HIGHBD_WRAPLOW(step1[2] - step1[5], bd);
-  step2[6] = HIGHBD_WRAPLOW(step1[1] - step1[6], bd);
-  step2[7] = HIGHBD_WRAPLOW(step1[0] - step1[7], bd);
-  step2[8] = step1[8];
-  step2[9] = step1[9];
-  temp1 = (-step1[10] + step1[13]) * cospi_16_64;
-  temp2 = (step1[10] + step1[13]) * cospi_16_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = (-step1[11] + step1[12]) * cospi_16_64;
-  temp2 = (step1[11] + step1[12]) * cospi_16_64;
-  step2[11] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[12] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[14] = step1[14];
-  step2[15] = step1[15];
-
-  // stage 7
-  output[0] = HIGHBD_WRAPLOW(step2[0] + step2[15], bd);
-  output[1] = HIGHBD_WRAPLOW(step2[1] + step2[14], bd);
-  output[2] = HIGHBD_WRAPLOW(step2[2] + step2[13], bd);
-  output[3] = HIGHBD_WRAPLOW(step2[3] + step2[12], bd);
-  output[4] = HIGHBD_WRAPLOW(step2[4] + step2[11], bd);
-  output[5] = HIGHBD_WRAPLOW(step2[5] + step2[10], bd);
-  output[6] = HIGHBD_WRAPLOW(step2[6] + step2[9], bd);
-  output[7] = HIGHBD_WRAPLOW(step2[7] + step2[8], bd);
-  output[8] = HIGHBD_WRAPLOW(step2[7] - step2[8], bd);
-  output[9] = HIGHBD_WRAPLOW(step2[6] - step2[9], bd);
-  output[10] = HIGHBD_WRAPLOW(step2[5] - step2[10], bd);
-  output[11] = HIGHBD_WRAPLOW(step2[4] - step2[11], bd);
-  output[12] = HIGHBD_WRAPLOW(step2[3] - step2[12], bd);
-  output[13] = HIGHBD_WRAPLOW(step2[2] - step2[13], bd);
-  output[14] = HIGHBD_WRAPLOW(step2[1] - step2[14], bd);
-  output[15] = HIGHBD_WRAPLOW(step2[0] - step2[15], bd);
-}
-
-void aom_highbd_iadst16_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_high_t s0, s1, s2, s3, s4, s5, s6, s7, s8;
-  tran_high_t s9, s10, s11, s12, s13, s14, s15;
-
-  tran_low_t x0 = input[15];
-  tran_low_t x1 = input[0];
-  tran_low_t x2 = input[13];
-  tran_low_t x3 = input[2];
-  tran_low_t x4 = input[11];
-  tran_low_t x5 = input[4];
-  tran_low_t x6 = input[9];
-  tran_low_t x7 = input[6];
-  tran_low_t x8 = input[7];
-  tran_low_t x9 = input[8];
-  tran_low_t x10 = input[5];
-  tran_low_t x11 = input[10];
-  tran_low_t x12 = input[3];
-  tran_low_t x13 = input[12];
-  tran_low_t x14 = input[1];
-  tran_low_t x15 = input[14];
-  (void)bd;
-
-  if (!(x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 | x11 | x12 |
-        x13 | x14 | x15)) {
-    memset(output, 0, 16 * sizeof(*output));
-    return;
-  }
-
-  // stage 1
-  s0 = x0 * cospi_1_64 + x1 * cospi_31_64;
-  s1 = x0 * cospi_31_64 - x1 * cospi_1_64;
-  s2 = x2 * cospi_5_64 + x3 * cospi_27_64;
-  s3 = x2 * cospi_27_64 - x3 * cospi_5_64;
-  s4 = x4 * cospi_9_64 + x5 * cospi_23_64;
-  s5 = x4 * cospi_23_64 - x5 * cospi_9_64;
-  s6 = x6 * cospi_13_64 + x7 * cospi_19_64;
-  s7 = x6 * cospi_19_64 - x7 * cospi_13_64;
-  s8 = x8 * cospi_17_64 + x9 * cospi_15_64;
-  s9 = x8 * cospi_15_64 - x9 * cospi_17_64;
-  s10 = x10 * cospi_21_64 + x11 * cospi_11_64;
-  s11 = x10 * cospi_11_64 - x11 * cospi_21_64;
-  s12 = x12 * cospi_25_64 + x13 * cospi_7_64;
-  s13 = x12 * cospi_7_64 - x13 * cospi_25_64;
-  s14 = x14 * cospi_29_64 + x15 * cospi_3_64;
-  s15 = x14 * cospi_3_64 - x15 * cospi_29_64;
-
-  x0 = HIGHBD_WRAPLOW(dct_const_round_shift(s0 + s8), bd);
-  x1 = HIGHBD_WRAPLOW(dct_const_round_shift(s1 + s9), bd);
-  x2 = HIGHBD_WRAPLOW(dct_const_round_shift(s2 + s10), bd);
-  x3 = HIGHBD_WRAPLOW(dct_const_round_shift(s3 + s11), bd);
-  x4 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 + s12), bd);
-  x5 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 + s13), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s6 + s14), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s7 + s15), bd);
-  x8 = HIGHBD_WRAPLOW(dct_const_round_shift(s0 - s8), bd);
-  x9 = HIGHBD_WRAPLOW(dct_const_round_shift(s1 - s9), bd);
-  x10 = HIGHBD_WRAPLOW(dct_const_round_shift(s2 - s10), bd);
-  x11 = HIGHBD_WRAPLOW(dct_const_round_shift(s3 - s11), bd);
-  x12 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 - s12), bd);
-  x13 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 - s13), bd);
-  x14 = HIGHBD_WRAPLOW(dct_const_round_shift(s6 - s14), bd);
-  x15 = HIGHBD_WRAPLOW(dct_const_round_shift(s7 - s15), bd);
-
-  // stage 2
-  s0 = x0;
-  s1 = x1;
-  s2 = x2;
-  s3 = x3;
-  s4 = x4;
-  s5 = x5;
-  s6 = x6;
-  s7 = x7;
-  s8 = x8 * cospi_4_64 + x9 * cospi_28_64;
-  s9 = x8 * cospi_28_64 - x9 * cospi_4_64;
-  s10 = x10 * cospi_20_64 + x11 * cospi_12_64;
-  s11 = x10 * cospi_12_64 - x11 * cospi_20_64;
-  s12 = -x12 * cospi_28_64 + x13 * cospi_4_64;
-  s13 = x12 * cospi_4_64 + x13 * cospi_28_64;
-  s14 = -x14 * cospi_12_64 + x15 * cospi_20_64;
-  s15 = x14 * cospi_20_64 + x15 * cospi_12_64;
-
-  x0 = HIGHBD_WRAPLOW(s0 + s4, bd);
-  x1 = HIGHBD_WRAPLOW(s1 + s5, bd);
-  x2 = HIGHBD_WRAPLOW(s2 + s6, bd);
-  x3 = HIGHBD_WRAPLOW(s3 + s7, bd);
-  x4 = HIGHBD_WRAPLOW(s0 - s4, bd);
-  x5 = HIGHBD_WRAPLOW(s1 - s5, bd);
-  x6 = HIGHBD_WRAPLOW(s2 - s6, bd);
-  x7 = HIGHBD_WRAPLOW(s3 - s7, bd);
-  x8 = HIGHBD_WRAPLOW(dct_const_round_shift(s8 + s12), bd);
-  x9 = HIGHBD_WRAPLOW(dct_const_round_shift(s9 + s13), bd);
-  x10 = HIGHBD_WRAPLOW(dct_const_round_shift(s10 + s14), bd);
-  x11 = HIGHBD_WRAPLOW(dct_const_round_shift(s11 + s15), bd);
-  x12 = HIGHBD_WRAPLOW(dct_const_round_shift(s8 - s12), bd);
-  x13 = HIGHBD_WRAPLOW(dct_const_round_shift(s9 - s13), bd);
-  x14 = HIGHBD_WRAPLOW(dct_const_round_shift(s10 - s14), bd);
-  x15 = HIGHBD_WRAPLOW(dct_const_round_shift(s11 - s15), bd);
-
-  // stage 3
-  s0 = x0;
-  s1 = x1;
-  s2 = x2;
-  s3 = x3;
-  s4 = x4 * cospi_8_64 + x5 * cospi_24_64;
-  s5 = x4 * cospi_24_64 - x5 * cospi_8_64;
-  s6 = -x6 * cospi_24_64 + x7 * cospi_8_64;
-  s7 = x6 * cospi_8_64 + x7 * cospi_24_64;
-  s8 = x8;
-  s9 = x9;
-  s10 = x10;
-  s11 = x11;
-  s12 = x12 * cospi_8_64 + x13 * cospi_24_64;
-  s13 = x12 * cospi_24_64 - x13 * cospi_8_64;
-  s14 = -x14 * cospi_24_64 + x15 * cospi_8_64;
-  s15 = x14 * cospi_8_64 + x15 * cospi_24_64;
-
-  x0 = HIGHBD_WRAPLOW(s0 + s2, bd);
-  x1 = HIGHBD_WRAPLOW(s1 + s3, bd);
-  x2 = HIGHBD_WRAPLOW(s0 - s2, bd);
-  x3 = HIGHBD_WRAPLOW(s1 - s3, bd);
-  x4 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 + s6), bd);
-  x5 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 + s7), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s4 - s6), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s5 - s7), bd);
-  x8 = HIGHBD_WRAPLOW(s8 + s10, bd);
-  x9 = HIGHBD_WRAPLOW(s9 + s11, bd);
-  x10 = HIGHBD_WRAPLOW(s8 - s10, bd);
-  x11 = HIGHBD_WRAPLOW(s9 - s11, bd);
-  x12 = HIGHBD_WRAPLOW(dct_const_round_shift(s12 + s14), bd);
-  x13 = HIGHBD_WRAPLOW(dct_const_round_shift(s13 + s15), bd);
-  x14 = HIGHBD_WRAPLOW(dct_const_round_shift(s12 - s14), bd);
-  x15 = HIGHBD_WRAPLOW(dct_const_round_shift(s13 - s15), bd);
-
-  // stage 4
-  s2 = (-cospi_16_64) * (x2 + x3);
-  s3 = cospi_16_64 * (x2 - x3);
-  s6 = cospi_16_64 * (x6 + x7);
-  s7 = cospi_16_64 * (-x6 + x7);
-  s10 = cospi_16_64 * (x10 + x11);
-  s11 = cospi_16_64 * (-x10 + x11);
-  s14 = (-cospi_16_64) * (x14 + x15);
-  s15 = cospi_16_64 * (x14 - x15);
-
-  x2 = HIGHBD_WRAPLOW(dct_const_round_shift(s2), bd);
-  x3 = HIGHBD_WRAPLOW(dct_const_round_shift(s3), bd);
-  x6 = HIGHBD_WRAPLOW(dct_const_round_shift(s6), bd);
-  x7 = HIGHBD_WRAPLOW(dct_const_round_shift(s7), bd);
-  x10 = HIGHBD_WRAPLOW(dct_const_round_shift(s10), bd);
-  x11 = HIGHBD_WRAPLOW(dct_const_round_shift(s11), bd);
-  x14 = HIGHBD_WRAPLOW(dct_const_round_shift(s14), bd);
-  x15 = HIGHBD_WRAPLOW(dct_const_round_shift(s15), bd);
-
-  output[0] = HIGHBD_WRAPLOW(x0, bd);
-  output[1] = HIGHBD_WRAPLOW(-x8, bd);
-  output[2] = HIGHBD_WRAPLOW(x12, bd);
-  output[3] = HIGHBD_WRAPLOW(-x4, bd);
-  output[4] = HIGHBD_WRAPLOW(x6, bd);
-  output[5] = HIGHBD_WRAPLOW(x14, bd);
-  output[6] = HIGHBD_WRAPLOW(x10, bd);
-  output[7] = HIGHBD_WRAPLOW(x2, bd);
-  output[8] = HIGHBD_WRAPLOW(x3, bd);
-  output[9] = HIGHBD_WRAPLOW(x11, bd);
-  output[10] = HIGHBD_WRAPLOW(x15, bd);
-  output[11] = HIGHBD_WRAPLOW(x7, bd);
-  output[12] = HIGHBD_WRAPLOW(x5, bd);
-  output[13] = HIGHBD_WRAPLOW(-x13, bd);
-  output[14] = HIGHBD_WRAPLOW(x9, bd);
-  output[15] = HIGHBD_WRAPLOW(-x1, bd);
-}
-
-void aom_highbd_idct32_c(const tran_low_t *input, tran_low_t *output, int bd) {
-  tran_low_t step1[32], step2[32];
-  tran_high_t temp1, temp2;
-  (void)bd;
-
-  // stage 1
-  step1[0] = input[0];
-  step1[1] = input[16];
-  step1[2] = input[8];
-  step1[3] = input[24];
-  step1[4] = input[4];
-  step1[5] = input[20];
-  step1[6] = input[12];
-  step1[7] = input[28];
-  step1[8] = input[2];
-  step1[9] = input[18];
-  step1[10] = input[10];
-  step1[11] = input[26];
-  step1[12] = input[6];
-  step1[13] = input[22];
-  step1[14] = input[14];
-  step1[15] = input[30];
-
-  temp1 = input[1] * cospi_31_64 - input[31] * cospi_1_64;
-  temp2 = input[1] * cospi_1_64 + input[31] * cospi_31_64;
-  step1[16] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[31] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[17] * cospi_15_64 - input[15] * cospi_17_64;
-  temp2 = input[17] * cospi_17_64 + input[15] * cospi_15_64;
-  step1[17] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[30] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[9] * cospi_23_64 - input[23] * cospi_9_64;
-  temp2 = input[9] * cospi_9_64 + input[23] * cospi_23_64;
-  step1[18] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[29] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[25] * cospi_7_64 - input[7] * cospi_25_64;
-  temp2 = input[25] * cospi_25_64 + input[7] * cospi_7_64;
-  step1[19] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[28] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[5] * cospi_27_64 - input[27] * cospi_5_64;
-  temp2 = input[5] * cospi_5_64 + input[27] * cospi_27_64;
-  step1[20] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[27] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[21] * cospi_11_64 - input[11] * cospi_21_64;
-  temp2 = input[21] * cospi_21_64 + input[11] * cospi_11_64;
-  step1[21] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[26] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[13] * cospi_19_64 - input[19] * cospi_13_64;
-  temp2 = input[13] * cospi_13_64 + input[19] * cospi_19_64;
-  step1[22] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[25] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = input[29] * cospi_3_64 - input[3] * cospi_29_64;
-  temp2 = input[29] * cospi_29_64 + input[3] * cospi_3_64;
-  step1[23] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[24] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  // stage 2
-  step2[0] = step1[0];
-  step2[1] = step1[1];
-  step2[2] = step1[2];
-  step2[3] = step1[3];
-  step2[4] = step1[4];
-  step2[5] = step1[5];
-  step2[6] = step1[6];
-  step2[7] = step1[7];
-
-  temp1 = step1[8] * cospi_30_64 - step1[15] * cospi_2_64;
-  temp2 = step1[8] * cospi_2_64 + step1[15] * cospi_30_64;
-  step2[8] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[15] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[9] * cospi_14_64 - step1[14] * cospi_18_64;
-  temp2 = step1[9] * cospi_18_64 + step1[14] * cospi_14_64;
-  step2[9] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[14] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[10] * cospi_22_64 - step1[13] * cospi_10_64;
-  temp2 = step1[10] * cospi_10_64 + step1[13] * cospi_22_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  temp1 = step1[11] * cospi_6_64 - step1[12] * cospi_26_64;
-  temp2 = step1[11] * cospi_26_64 + step1[12] * cospi_6_64;
-  step2[11] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[12] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  step2[16] = HIGHBD_WRAPLOW(step1[16] + step1[17], bd);
-  step2[17] = HIGHBD_WRAPLOW(step1[16] - step1[17], bd);
-  step2[18] = HIGHBD_WRAPLOW(-step1[18] + step1[19], bd);
-  step2[19] = HIGHBD_WRAPLOW(step1[18] + step1[19], bd);
-  step2[20] = HIGHBD_WRAPLOW(step1[20] + step1[21], bd);
-  step2[21] = HIGHBD_WRAPLOW(step1[20] - step1[21], bd);
-  step2[22] = HIGHBD_WRAPLOW(-step1[22] + step1[23], bd);
-  step2[23] = HIGHBD_WRAPLOW(step1[22] + step1[23], bd);
-  step2[24] = HIGHBD_WRAPLOW(step1[24] + step1[25], bd);
-  step2[25] = HIGHBD_WRAPLOW(step1[24] - step1[25], bd);
-  step2[26] = HIGHBD_WRAPLOW(-step1[26] + step1[27], bd);
-  step2[27] = HIGHBD_WRAPLOW(step1[26] + step1[27], bd);
-  step2[28] = HIGHBD_WRAPLOW(step1[28] + step1[29], bd);
-  step2[29] = HIGHBD_WRAPLOW(step1[28] - step1[29], bd);
-  step2[30] = HIGHBD_WRAPLOW(-step1[30] + step1[31], bd);
-  step2[31] = HIGHBD_WRAPLOW(step1[30] + step1[31], bd);
-
-  // stage 3
-  step1[0] = step2[0];
-  step1[1] = step2[1];
-  step1[2] = step2[2];
-  step1[3] = step2[3];
-
-  temp1 = step2[4] * cospi_28_64 - step2[7] * cospi_4_64;
-  temp2 = step2[4] * cospi_4_64 + step2[7] * cospi_28_64;
-  step1[4] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[7] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = step2[5] * cospi_12_64 - step2[6] * cospi_20_64;
-  temp2 = step2[5] * cospi_20_64 + step2[6] * cospi_12_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-
-  step1[8] = HIGHBD_WRAPLOW(step2[8] + step2[9], bd);
-  step1[9] = HIGHBD_WRAPLOW(step2[8] - step2[9], bd);
-  step1[10] = HIGHBD_WRAPLOW(-step2[10] + step2[11], bd);
-  step1[11] = HIGHBD_WRAPLOW(step2[10] + step2[11], bd);
-  step1[12] = HIGHBD_WRAPLOW(step2[12] + step2[13], bd);
-  step1[13] = HIGHBD_WRAPLOW(step2[12] - step2[13], bd);
-  step1[14] = HIGHBD_WRAPLOW(-step2[14] + step2[15], bd);
-  step1[15] = HIGHBD_WRAPLOW(step2[14] + step2[15], bd);
-
-  step1[16] = step2[16];
-  step1[31] = step2[31];
-  temp1 = -step2[17] * cospi_4_64 + step2[30] * cospi_28_64;
-  temp2 = step2[17] * cospi_28_64 + step2[30] * cospi_4_64;
-  step1[17] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[30] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step2[18] * cospi_28_64 - step2[29] * cospi_4_64;
-  temp2 = -step2[18] * cospi_4_64 + step2[29] * cospi_28_64;
-  step1[18] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[29] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[19] = step2[19];
-  step1[20] = step2[20];
-  temp1 = -step2[21] * cospi_20_64 + step2[26] * cospi_12_64;
-  temp2 = step2[21] * cospi_12_64 + step2[26] * cospi_20_64;
-  step1[21] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[26] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step2[22] * cospi_12_64 - step2[25] * cospi_20_64;
-  temp2 = -step2[22] * cospi_20_64 + step2[25] * cospi_12_64;
-  step1[22] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[25] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[23] = step2[23];
-  step1[24] = step2[24];
-  step1[27] = step2[27];
-  step1[28] = step2[28];
-
-  // stage 4
-  temp1 = (step1[0] + step1[1]) * cospi_16_64;
-  temp2 = (step1[0] - step1[1]) * cospi_16_64;
-  step2[0] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[1] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = step1[2] * cospi_24_64 - step1[3] * cospi_8_64;
-  temp2 = step1[2] * cospi_8_64 + step1[3] * cospi_24_64;
-  step2[2] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[3] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[4] = HIGHBD_WRAPLOW(step1[4] + step1[5], bd);
-  step2[5] = HIGHBD_WRAPLOW(step1[4] - step1[5], bd);
-  step2[6] = HIGHBD_WRAPLOW(-step1[6] + step1[7], bd);
-  step2[7] = HIGHBD_WRAPLOW(step1[6] + step1[7], bd);
-
-  step2[8] = step1[8];
-  step2[15] = step1[15];
-  temp1 = -step1[9] * cospi_8_64 + step1[14] * cospi_24_64;
-  temp2 = step1[9] * cospi_24_64 + step1[14] * cospi_8_64;
-  step2[9] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[14] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step1[10] * cospi_24_64 - step1[13] * cospi_8_64;
-  temp2 = -step1[10] * cospi_8_64 + step1[13] * cospi_24_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[11] = step1[11];
-  step2[12] = step1[12];
-
-  step2[16] = HIGHBD_WRAPLOW(step1[16] + step1[19], bd);
-  step2[17] = HIGHBD_WRAPLOW(step1[17] + step1[18], bd);
-  step2[18] = HIGHBD_WRAPLOW(step1[17] - step1[18], bd);
-  step2[19] = HIGHBD_WRAPLOW(step1[16] - step1[19], bd);
-  step2[20] = HIGHBD_WRAPLOW(-step1[20] + step1[23], bd);
-  step2[21] = HIGHBD_WRAPLOW(-step1[21] + step1[22], bd);
-  step2[22] = HIGHBD_WRAPLOW(step1[21] + step1[22], bd);
-  step2[23] = HIGHBD_WRAPLOW(step1[20] + step1[23], bd);
-
-  step2[24] = HIGHBD_WRAPLOW(step1[24] + step1[27], bd);
-  step2[25] = HIGHBD_WRAPLOW(step1[25] + step1[26], bd);
-  step2[26] = HIGHBD_WRAPLOW(step1[25] - step1[26], bd);
-  step2[27] = HIGHBD_WRAPLOW(step1[24] - step1[27], bd);
-  step2[28] = HIGHBD_WRAPLOW(-step1[28] + step1[31], bd);
-  step2[29] = HIGHBD_WRAPLOW(-step1[29] + step1[30], bd);
-  step2[30] = HIGHBD_WRAPLOW(step1[29] + step1[30], bd);
-  step2[31] = HIGHBD_WRAPLOW(step1[28] + step1[31], bd);
-
-  // stage 5
-  step1[0] = HIGHBD_WRAPLOW(step2[0] + step2[3], bd);
-  step1[1] = HIGHBD_WRAPLOW(step2[1] + step2[2], bd);
-  step1[2] = HIGHBD_WRAPLOW(step2[1] - step2[2], bd);
-  step1[3] = HIGHBD_WRAPLOW(step2[0] - step2[3], bd);
-  step1[4] = step2[4];
-  temp1 = (step2[6] - step2[5]) * cospi_16_64;
-  temp2 = (step2[5] + step2[6]) * cospi_16_64;
-  step1[5] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[6] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[7] = step2[7];
-
-  step1[8] = HIGHBD_WRAPLOW(step2[8] + step2[11], bd);
-  step1[9] = HIGHBD_WRAPLOW(step2[9] + step2[10], bd);
-  step1[10] = HIGHBD_WRAPLOW(step2[9] - step2[10], bd);
-  step1[11] = HIGHBD_WRAPLOW(step2[8] - step2[11], bd);
-  step1[12] = HIGHBD_WRAPLOW(-step2[12] + step2[15], bd);
-  step1[13] = HIGHBD_WRAPLOW(-step2[13] + step2[14], bd);
-  step1[14] = HIGHBD_WRAPLOW(step2[13] + step2[14], bd);
-  step1[15] = HIGHBD_WRAPLOW(step2[12] + step2[15], bd);
-
-  step1[16] = step2[16];
-  step1[17] = step2[17];
-  temp1 = -step2[18] * cospi_8_64 + step2[29] * cospi_24_64;
-  temp2 = step2[18] * cospi_24_64 + step2[29] * cospi_8_64;
-  step1[18] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[29] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step2[19] * cospi_8_64 + step2[28] * cospi_24_64;
-  temp2 = step2[19] * cospi_24_64 + step2[28] * cospi_8_64;
-  step1[19] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[28] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step2[20] * cospi_24_64 - step2[27] * cospi_8_64;
-  temp2 = -step2[20] * cospi_8_64 + step2[27] * cospi_24_64;
-  step1[20] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[27] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = -step2[21] * cospi_24_64 - step2[26] * cospi_8_64;
-  temp2 = -step2[21] * cospi_8_64 + step2[26] * cospi_24_64;
-  step1[21] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[26] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[22] = step2[22];
-  step1[23] = step2[23];
-  step1[24] = step2[24];
-  step1[25] = step2[25];
-  step1[30] = step2[30];
-  step1[31] = step2[31];
-
-  // stage 6
-  step2[0] = HIGHBD_WRAPLOW(step1[0] + step1[7], bd);
-  step2[1] = HIGHBD_WRAPLOW(step1[1] + step1[6], bd);
-  step2[2] = HIGHBD_WRAPLOW(step1[2] + step1[5], bd);
-  step2[3] = HIGHBD_WRAPLOW(step1[3] + step1[4], bd);
-  step2[4] = HIGHBD_WRAPLOW(step1[3] - step1[4], bd);
-  step2[5] = HIGHBD_WRAPLOW(step1[2] - step1[5], bd);
-  step2[6] = HIGHBD_WRAPLOW(step1[1] - step1[6], bd);
-  step2[7] = HIGHBD_WRAPLOW(step1[0] - step1[7], bd);
-  step2[8] = step1[8];
-  step2[9] = step1[9];
-  temp1 = (-step1[10] + step1[13]) * cospi_16_64;
-  temp2 = (step1[10] + step1[13]) * cospi_16_64;
-  step2[10] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[13] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = (-step1[11] + step1[12]) * cospi_16_64;
-  temp2 = (step1[11] + step1[12]) * cospi_16_64;
-  step2[11] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step2[12] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step2[14] = step1[14];
-  step2[15] = step1[15];
-
-  step2[16] = HIGHBD_WRAPLOW(step1[16] + step1[23], bd);
-  step2[17] = HIGHBD_WRAPLOW(step1[17] + step1[22], bd);
-  step2[18] = HIGHBD_WRAPLOW(step1[18] + step1[21], bd);
-  step2[19] = HIGHBD_WRAPLOW(step1[19] + step1[20], bd);
-  step2[20] = HIGHBD_WRAPLOW(step1[19] - step1[20], bd);
-  step2[21] = HIGHBD_WRAPLOW(step1[18] - step1[21], bd);
-  step2[22] = HIGHBD_WRAPLOW(step1[17] - step1[22], bd);
-  step2[23] = HIGHBD_WRAPLOW(step1[16] - step1[23], bd);
-
-  step2[24] = HIGHBD_WRAPLOW(-step1[24] + step1[31], bd);
-  step2[25] = HIGHBD_WRAPLOW(-step1[25] + step1[30], bd);
-  step2[26] = HIGHBD_WRAPLOW(-step1[26] + step1[29], bd);
-  step2[27] = HIGHBD_WRAPLOW(-step1[27] + step1[28], bd);
-  step2[28] = HIGHBD_WRAPLOW(step1[27] + step1[28], bd);
-  step2[29] = HIGHBD_WRAPLOW(step1[26] + step1[29], bd);
-  step2[30] = HIGHBD_WRAPLOW(step1[25] + step1[30], bd);
-  step2[31] = HIGHBD_WRAPLOW(step1[24] + step1[31], bd);
-
-  // stage 7
-  step1[0] = HIGHBD_WRAPLOW(step2[0] + step2[15], bd);
-  step1[1] = HIGHBD_WRAPLOW(step2[1] + step2[14], bd);
-  step1[2] = HIGHBD_WRAPLOW(step2[2] + step2[13], bd);
-  step1[3] = HIGHBD_WRAPLOW(step2[3] + step2[12], bd);
-  step1[4] = HIGHBD_WRAPLOW(step2[4] + step2[11], bd);
-  step1[5] = HIGHBD_WRAPLOW(step2[5] + step2[10], bd);
-  step1[6] = HIGHBD_WRAPLOW(step2[6] + step2[9], bd);
-  step1[7] = HIGHBD_WRAPLOW(step2[7] + step2[8], bd);
-  step1[8] = HIGHBD_WRAPLOW(step2[7] - step2[8], bd);
-  step1[9] = HIGHBD_WRAPLOW(step2[6] - step2[9], bd);
-  step1[10] = HIGHBD_WRAPLOW(step2[5] - step2[10], bd);
-  step1[11] = HIGHBD_WRAPLOW(step2[4] - step2[11], bd);
-  step1[12] = HIGHBD_WRAPLOW(step2[3] - step2[12], bd);
-  step1[13] = HIGHBD_WRAPLOW(step2[2] - step2[13], bd);
-  step1[14] = HIGHBD_WRAPLOW(step2[1] - step2[14], bd);
-  step1[15] = HIGHBD_WRAPLOW(step2[0] - step2[15], bd);
-
-  step1[16] = step2[16];
-  step1[17] = step2[17];
-  step1[18] = step2[18];
-  step1[19] = step2[19];
-  temp1 = (-step2[20] + step2[27]) * cospi_16_64;
-  temp2 = (step2[20] + step2[27]) * cospi_16_64;
-  step1[20] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[27] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = (-step2[21] + step2[26]) * cospi_16_64;
-  temp2 = (step2[21] + step2[26]) * cospi_16_64;
-  step1[21] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[26] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = (-step2[22] + step2[25]) * cospi_16_64;
-  temp2 = (step2[22] + step2[25]) * cospi_16_64;
-  step1[22] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[25] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  temp1 = (-step2[23] + step2[24]) * cospi_16_64;
-  temp2 = (step2[23] + step2[24]) * cospi_16_64;
-  step1[23] = HIGHBD_WRAPLOW(dct_const_round_shift(temp1), bd);
-  step1[24] = HIGHBD_WRAPLOW(dct_const_round_shift(temp2), bd);
-  step1[28] = step2[28];
-  step1[29] = step2[29];
-  step1[30] = step2[30];
-  step1[31] = step2[31];
-
-  // final stage
-  output[0] = HIGHBD_WRAPLOW(step1[0] + step1[31], bd);
-  output[1] = HIGHBD_WRAPLOW(step1[1] + step1[30], bd);
-  output[2] = HIGHBD_WRAPLOW(step1[2] + step1[29], bd);
-  output[3] = HIGHBD_WRAPLOW(step1[3] + step1[28], bd);
-  output[4] = HIGHBD_WRAPLOW(step1[4] + step1[27], bd);
-  output[5] = HIGHBD_WRAPLOW(step1[5] + step1[26], bd);
-  output[6] = HIGHBD_WRAPLOW(step1[6] + step1[25], bd);
-  output[7] = HIGHBD_WRAPLOW(step1[7] + step1[24], bd);
-  output[8] = HIGHBD_WRAPLOW(step1[8] + step1[23], bd);
-  output[9] = HIGHBD_WRAPLOW(step1[9] + step1[22], bd);
-  output[10] = HIGHBD_WRAPLOW(step1[10] + step1[21], bd);
-  output[11] = HIGHBD_WRAPLOW(step1[11] + step1[20], bd);
-  output[12] = HIGHBD_WRAPLOW(step1[12] + step1[19], bd);
-  output[13] = HIGHBD_WRAPLOW(step1[13] + step1[18], bd);
-  output[14] = HIGHBD_WRAPLOW(step1[14] + step1[17], bd);
-  output[15] = HIGHBD_WRAPLOW(step1[15] + step1[16], bd);
-  output[16] = HIGHBD_WRAPLOW(step1[15] - step1[16], bd);
-  output[17] = HIGHBD_WRAPLOW(step1[14] - step1[17], bd);
-  output[18] = HIGHBD_WRAPLOW(step1[13] - step1[18], bd);
-  output[19] = HIGHBD_WRAPLOW(step1[12] - step1[19], bd);
-  output[20] = HIGHBD_WRAPLOW(step1[11] - step1[20], bd);
-  output[21] = HIGHBD_WRAPLOW(step1[10] - step1[21], bd);
-  output[22] = HIGHBD_WRAPLOW(step1[9] - step1[22], bd);
-  output[23] = HIGHBD_WRAPLOW(step1[8] - step1[23], bd);
-  output[24] = HIGHBD_WRAPLOW(step1[7] - step1[24], bd);
-  output[25] = HIGHBD_WRAPLOW(step1[6] - step1[25], bd);
-  output[26] = HIGHBD_WRAPLOW(step1[5] - step1[26], bd);
-  output[27] = HIGHBD_WRAPLOW(step1[4] - step1[27], bd);
-  output[28] = HIGHBD_WRAPLOW(step1[3] - step1[28], bd);
-  output[29] = HIGHBD_WRAPLOW(step1[2] - step1[29], bd);
-  output[30] = HIGHBD_WRAPLOW(step1[1] - step1[30], bd);
-  output[31] = HIGHBD_WRAPLOW(step1[0] - step1[31], bd);
-}
-
 #endif  // CONFIG_HIGHBITDEPTH
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index a5e964a..bb255c4 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -67,4 +67,8 @@
 // 16384 * sqrt(2)
 static const tran_high_t Sqrt2 = 23170;
 
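+// Rounds a value carrying DCT_CONST_BITS fractional bits back down to
+// integer precision, e.g. fdct_round_shift(x * Sqrt2) ~= x * sqrt(2).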
+static INLINE tran_high_t fdct_round_shift(tran_high_t input) {
+  tran_high_t rv = ROUND_POWER_OF_TWO(input, DCT_CONST_BITS);
+  return rv;
+}
 #endif  // AOM_DSP_TXFM_COMMON_H_
diff --git a/aom_dsp/x86/inv_txfm_sse2.c b/aom_dsp/x86/inv_txfm_sse2.c
index be200df..fb4968e 100644
--- a/aom_dsp/x86/inv_txfm_sse2.c
+++ b/aom_dsp/x86/inv_txfm_sse2.c
@@ -3627,108 +3627,4 @@
     }
   }
 }
-
-void aom_highbd_idct8x8_10_add_sse2(const tran_low_t *input, uint8_t *dest8,
-                                    int stride, int bd) {
-  tran_low_t out[8 * 8] = { 0 };
-  tran_low_t *outptr = out;
-  int i, j, test;
-  __m128i inptr[8];
-  __m128i min_input, max_input, temp1, temp2, sign_bits;
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-  const __m128i zero = _mm_set1_epi16(0);
-  const __m128i sixteen = _mm_set1_epi16(16);
-  const __m128i max = _mm_set1_epi16(6201);
-  const __m128i min = _mm_set1_epi16(-6201);
-  int optimised_cols = 0;
-
-  // Load input into __m128i & pack to 16 bits
-  for (i = 0; i < 8; i++) {
-    temp1 = _mm_loadu_si128((const __m128i *)(input + 8 * i));
-    temp2 = _mm_loadu_si128((const __m128i *)(input + 8 * i + 4));
-    inptr[i] = _mm_packs_epi32(temp1, temp2);
-  }
-
-  // Find the min & max for the row transform
-  // only first 4 row has non-zero coefs
-  max_input = _mm_max_epi16(inptr[0], inptr[1]);
-  min_input = _mm_min_epi16(inptr[0], inptr[1]);
-  for (i = 2; i < 4; i++) {
-    max_input = _mm_max_epi16(max_input, inptr[i]);
-    min_input = _mm_min_epi16(min_input, inptr[i]);
-  }
-  max_input = _mm_cmpgt_epi16(max_input, max);
-  min_input = _mm_cmplt_epi16(min_input, min);
-  temp1 = _mm_or_si128(max_input, min_input);
-  test = _mm_movemask_epi8(temp1);
-
-  if (!test) {
-    // Do the row transform
-    aom_idct8_sse2(inptr);
-
-    // Find the min & max for the column transform
-    // N.B. Only first 4 cols contain non-zero coeffs
-    max_input = _mm_max_epi16(inptr[0], inptr[1]);
-    min_input = _mm_min_epi16(inptr[0], inptr[1]);
-    for (i = 2; i < 8; i++) {
-      max_input = _mm_max_epi16(max_input, inptr[i]);
-      min_input = _mm_min_epi16(min_input, inptr[i]);
-    }
-    max_input = _mm_cmpgt_epi16(max_input, max);
-    min_input = _mm_cmplt_epi16(min_input, min);
-    temp1 = _mm_or_si128(max_input, min_input);
-    test = _mm_movemask_epi8(temp1);
-
-    if (test) {
-      // Use fact only first 4 rows contain non-zero coeffs
-      array_transpose_4X8(inptr, inptr);
-      for (i = 0; i < 4; i++) {
-        sign_bits = _mm_cmplt_epi16(inptr[i], zero);
-        temp1 = _mm_unpackhi_epi16(inptr[i], sign_bits);
-        temp2 = _mm_unpacklo_epi16(inptr[i], sign_bits);
-        _mm_storeu_si128((__m128i *)(outptr + 4 * (2 * i + 1)), temp1);
-        _mm_storeu_si128((__m128i *)(outptr + 4 * (2 * i)), temp2);
-      }
-    } else {
-      // Set to use the optimised transform for the column
-      optimised_cols = 1;
-    }
-  } else {
-    // Run the un-optimised row transform
-    for (i = 0; i < 4; ++i) {
-      aom_highbd_idct8_c(input, outptr, bd);
-      input += 8;
-      outptr += 8;
-    }
-  }
-
-  if (optimised_cols) {
-    aom_idct8_sse2(inptr);
-
-    // Final round & shift and Reconstruction and Store
-    {
-      __m128i d[8];
-      for (i = 0; i < 8; i++) {
-        inptr[i] = _mm_add_epi16(inptr[i], sixteen);
-        d[i] = _mm_loadu_si128((const __m128i *)(dest + stride * i));
-        inptr[i] = _mm_srai_epi16(inptr[i], 5);
-        d[i] = clamp_high_sse2(_mm_adds_epi16(d[i], inptr[i]), bd);
-        // Store
-        _mm_storeu_si128((__m128i *)(dest + stride * i), d[i]);
-      }
-    }
-  } else {
-    // Run the un-optimised column transform
-    tran_low_t temp_in[8], temp_out[8];
-    for (i = 0; i < 8; ++i) {
-      for (j = 0; j < 8; ++j) temp_in[j] = out[j * 8 + i];
-      aom_highbd_idct8_c(temp_in, temp_out, bd);
-      for (j = 0; j < 8; ++j) {
-        dest[j * stride + i] = highbd_clip_pixel_add(
-            dest[j * stride + i], ROUND_POWER_OF_TWO(temp_out[j], 5), bd);
-      }
-    }
-  }
-}
-
 #endif  // CONFIG_HIGHBITDEPTH
diff --git a/av1/common/av1_fwd_txfm2d.c b/av1/common/av1_fwd_txfm2d.c
index f8d7b23..5d20e54 100644
--- a/av1/common/av1_fwd_txfm2d.c
+++ b/av1/common/av1_fwd_txfm2d.c
@@ -12,6 +12,7 @@
 #include <assert.h>
 
 #include "./av1_rtcd.h"
+#include "aom_dsp/txfm_common.h"
 #include "av1/common/enums.h"
 #include "av1/common/av1_fwd_txfm1d.h"
 #include "av1/common/av1_fwd_txfm1d_cfg.h"
@@ -41,9 +42,17 @@
                                 const int stride, const TXFM_2D_FLIP_CFG *cfg,
                                 int32_t *buf) {
   int c, r;
-  // TODO(sarahparker) must correct for rectangular transforms in follow up
-  const int txfm_size = cfg->row_cfg->txfm_size;
-  const int8_t *shift = cfg->row_cfg->shift;
+  // Note that when assigning txfm_size_col we use the txfm_size from the
+  // row configuration, and vice versa. This is intentional: for a
+  // rectangular transform, the number of columns equals the txfm_size
+  // stored in the row cfg struct and the number of rows equals the
+  // txfm_size in the col cfg. For square transforms it makes no
+  // difference.
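+  // For example, TX_4X8 (4 columns, 8 rows) pairs a length-4 row cfg
+  // with a length-8 col cfg, giving txfm_size_col == 4 and
+  // txfm_size_row == 8.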
+  const int txfm_size_col = cfg->row_cfg->txfm_size;
+  const int txfm_size_row = cfg->col_cfg->txfm_size;
+  // Take the shift from the larger dimension in the rectangular case.
+  const int8_t *shift =
+      txfm_size_col > txfm_size_row ? cfg->row_cfg->shift : cfg->col_cfg->shift;
   const int8_t *stage_range_col = cfg->col_cfg->stage_range;
   const int8_t *stage_range_row = cfg->row_cfg->stage_range;
   const int8_t *cos_bit_col = cfg->col_cfg->cos_bit;
@@ -53,37 +62,99 @@
 
   // use output buffer as temp buffer
   int32_t *temp_in = output;
-  int32_t *temp_out = output + txfm_size;
+  int32_t *temp_out = output + txfm_size_row;
 
   // Columns
-  for (c = 0; c < txfm_size; ++c) {
+  for (c = 0; c < txfm_size_col; ++c) {
     if (cfg->ud_flip == 0) {
-      for (r = 0; r < txfm_size; ++r) temp_in[r] = input[r * stride + c];
+      for (r = 0; r < txfm_size_row; ++r) temp_in[r] = input[r * stride + c];
     } else {
-      for (r = 0; r < txfm_size; ++r)
+      for (r = 0; r < txfm_size_row; ++r)
         // flip upside down
-        temp_in[r] = input[(txfm_size - r - 1) * stride + c];
+        temp_in[r] = input[(txfm_size_row - r - 1) * stride + c];
     }
-    round_shift_array(temp_in, txfm_size, -shift[0]);
+    round_shift_array(temp_in, txfm_size_row, -shift[0]);
+    // Multiply everything by Sqrt2 on the larger dimension if the
+    // transform is rectangular
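+    // For a 2:1 rectangle the combined 1d gains contain a stray
+    // sqrt(2), so this multiply makes the total scale a power of two
+    // that the final shifts can remove.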
+    if (txfm_size_col > txfm_size_row) {
+      for (r = 0; r < txfm_size_row; ++r)
+        temp_in[r] = (int32_t)fdct_round_shift(temp_in[r] * Sqrt2);
+    }
     txfm_func_col(temp_in, temp_out, cos_bit_col, stage_range_col);
-    round_shift_array(temp_out, txfm_size, -shift[1]);
+    round_shift_array(temp_out, txfm_size_row, -shift[1]);
     if (cfg->lr_flip == 0) {
-      for (r = 0; r < txfm_size; ++r) buf[r * txfm_size + c] = temp_out[r];
+      for (r = 0; r < txfm_size_row; ++r)
+        buf[r * txfm_size_col + c] = temp_out[r];
     } else {
-      for (r = 0; r < txfm_size; ++r)
+      for (r = 0; r < txfm_size_row; ++r)
         // flip from left to right
-        buf[r * txfm_size + (txfm_size - c - 1)] = temp_out[r];
+        buf[r * txfm_size_col + (txfm_size_col - c - 1)] = temp_out[r];
     }
   }
 
   // Rows
-  for (r = 0; r < txfm_size; ++r) {
-    txfm_func_row(buf + r * txfm_size, output + r * txfm_size, cos_bit_row,
-                  stage_range_row);
-    round_shift_array(output + r * txfm_size, txfm_size, -shift[2]);
+  for (r = 0; r < txfm_size_row; ++r) {
+    // Multiply everything by Sqrt2 on the larger dimension if the
+    // transform is rectangular
+    if (txfm_size_row > txfm_size_col) {
+      for (c = 0; c < txfm_size_col; ++c)
+        buf[r * txfm_size_col + c] =
+            (int32_t)fdct_round_shift(buf[r * txfm_size_col + c] * Sqrt2);
+    }
+    txfm_func_row(buf + r * txfm_size_col, output + r * txfm_size_col,
+                  cos_bit_row, stage_range_row);
+    round_shift_array(output + r * txfm_size_col, txfm_size_col, -shift[2]);
   }
 }
 
+void av1_fwd_txfm2d_4x8_c(const int16_t *input, int32_t *output, int stride,
+                          int tx_type, int bd) {
+  int32_t txfm_buf[4 * 8];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_4X8);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
+void av1_fwd_txfm2d_8x4_c(const int16_t *input, int32_t *output, int stride,
+                          int tx_type, int bd) {
+  int32_t txfm_buf[8 * 4];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_8X4);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
+void av1_fwd_txfm2d_8x16_c(const int16_t *input, int32_t *output, int stride,
+                           int tx_type, int bd) {
+  int32_t txfm_buf[8 * 16];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_8X16);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
+void av1_fwd_txfm2d_16x8_c(const int16_t *input, int32_t *output, int stride,
+                           int tx_type, int bd) {
+  int32_t txfm_buf[16 * 8];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_16X8);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
+void av1_fwd_txfm2d_16x32_c(const int16_t *input, int32_t *output, int stride,
+                            int tx_type, int bd) {
+  int32_t txfm_buf[16 * 32];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_16X32);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
+void av1_fwd_txfm2d_32x16_c(const int16_t *input, int32_t *output, int stride,
+                            int tx_type, int bd) {
+  int32_t txfm_buf[32 * 16];
+  TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(tx_type, TX_32X16);
+  (void)bd;
+  fwd_txfm2d_c(input, output, stride, &cfg, txfm_buf);
+}
+
 void av1_fwd_txfm2d_4x4_c(const int16_t *input, int32_t *output, int stride,
                           int tx_type, int bd) {
   int32_t txfm_buf[4 * 4];
@@ -195,8 +266,10 @@
   set_flip_cfg(tx_type, &cfg);
   int tx_type_col = vtx_tab[tx_type];
   int tx_type_row = htx_tab[tx_type];
-  cfg.col_cfg = fwd_txfm_col_cfg_ls[tx_type_col][tx_size];
-  cfg.row_cfg = fwd_txfm_row_cfg_ls[tx_type_row][tx_size];
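+  // Map the (possibly rectangular) tx_size to the square 1d size of each
+  // dimension, e.g. TX_4X8 -> TX_8X8 for the columns and TX_4X4 for the
+  // rows.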
+  int tx_size_col = txsize_vert_map[tx_size];
+  int tx_size_row = txsize_horz_map[tx_size];
+  cfg.col_cfg = fwd_txfm_col_cfg_ls[tx_type_col][tx_size_col];
+  cfg.row_cfg = fwd_txfm_row_cfg_ls[tx_type_row][tx_size_row];
   return cfg;
 }
 
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index e07f994..074f062 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -10,6 +10,7 @@
  */
 
 #include "./av1_rtcd.h"
+#include "aom_dsp/inv_txfm.h"
 #include "av1/common/enums.h"
 #include "av1/common/av1_txfm.h"
 #include "av1/common/av1_inv_txfm1d.h"
@@ -106,10 +107,10 @@
   set_flip_cfg(tx_type, &cfg);
   int tx_type_col = vtx_tab[tx_type];
   int tx_type_row = htx_tab[tx_type];
-  // TODO(sarahparker) this is currently only implemented for
-  // square transforms
-  cfg.col_cfg = inv_txfm_col_cfg_ls[tx_type_col][tx_size];
-  cfg.row_cfg = inv_txfm_row_cfg_ls[tx_type_row][tx_size];
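+  // Map the (possibly rectangular) tx_size to the square 1d size of each
+  // dimension, e.g. TX_4X8 -> TX_8X8 for the columns and TX_4X4 for the
+  // rows.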
+  int tx_size_col = txsize_vert_map[tx_size];
+  int tx_size_row = txsize_horz_map[tx_size];
+  cfg.col_cfg = inv_txfm_col_cfg_ls[tx_type_col][tx_size_col];
+  cfg.row_cfg = inv_txfm_row_cfg_ls[tx_type_row][tx_size_row];
   return cfg;
 }
 
@@ -129,9 +130,17 @@
 static INLINE void inv_txfm2d_add_c(const int32_t *input, int16_t *output,
                                     int stride, TXFM_2D_FLIP_CFG *cfg,
                                     int32_t *txfm_buf) {
-  // TODO(sarahparker) must correct for rectangular transforms in follow up
-  const int txfm_size = cfg->row_cfg->txfm_size;
-  const int8_t *shift = cfg->row_cfg->shift;
+  // Note that when assigning txfm_size_col we use the txfm_size from the
+  // row configuration, and vice versa. This is intentional: for a
+  // rectangular transform, the number of columns equals the txfm_size
+  // stored in the row cfg struct and the number of rows equals the
+  // txfm_size in the col cfg. For square transforms it makes no
+  // difference.
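+  // For example, TX_4X8 (4 columns, 8 rows) pairs a length-4 row cfg
+  // with a length-8 col cfg, giving txfm_size_col == 4 and
+  // txfm_size_row == 8.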
+  const int txfm_size_col = cfg->row_cfg->txfm_size;
+  const int txfm_size_row = cfg->col_cfg->txfm_size;
+  // Take the shift from the larger dimension in the rectangular case.
+  const int8_t *shift =
+      txfm_size_col > txfm_size_row ? cfg->row_cfg->shift : cfg->col_cfg->shift;
   const int8_t *stage_range_col = cfg->col_cfg->stage_range;
   const int8_t *stage_range_row = cfg->row_cfg->stage_range;
   const int8_t *cos_bit_col = cfg->col_cfg->cos_bit;
@@ -139,39 +148,45 @@
   const TxfmFunc txfm_func_col = inv_txfm_type_to_func(cfg->col_cfg->txfm_type);
   const TxfmFunc txfm_func_row = inv_txfm_type_to_func(cfg->row_cfg->txfm_type);
 
-  // txfm_buf's length is  txfm_size * txfm_size + 2 * txfm_size
+  // txfm_buf's length is txfm_size_row * txfm_size_col + 2 * txfm_size_row
   // it is used for intermediate data buffering
   int32_t *temp_in = txfm_buf;
-  int32_t *temp_out = temp_in + txfm_size;
-  int32_t *buf = temp_out + txfm_size;
+  int32_t *temp_out = temp_in + txfm_size_row;
+  int32_t *buf = temp_out + txfm_size_row;
   int32_t *buf_ptr = buf;
   int c, r;
 
   // Rows
-  for (r = 0; r < txfm_size; ++r) {
+  for (r = 0; r < txfm_size_row; ++r) {
     txfm_func_row(input, buf_ptr, cos_bit_row, stage_range_row);
-    round_shift_array(buf_ptr, txfm_size, -shift[0]);
-    input += txfm_size;
-    buf_ptr += txfm_size;
+    round_shift_array(buf_ptr, txfm_size_col, -shift[0]);
+    // Multiply everything by Sqrt2 if the transform is rectangular
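+    // As on the forward side, this keeps the combined 1d gains a
+    // power of two that the shifts can absorb.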
+    if (txfm_size_row != txfm_size_col) {
+      for (c = 0; c < txfm_size_col; ++c)
+        buf_ptr[c] = (int32_t)dct_const_round_shift(buf_ptr[c] * Sqrt2);
+    }
+    input += txfm_size_col;
+    buf_ptr += txfm_size_col;
   }
 
   // Columns
-  for (c = 0; c < txfm_size; ++c) {
+  for (c = 0; c < txfm_size_col; ++c) {
     if (cfg->lr_flip == 0) {
-      for (r = 0; r < txfm_size; ++r) temp_in[r] = buf[r * txfm_size + c];
+      for (r = 0; r < txfm_size_row; ++r)
+        temp_in[r] = buf[r * txfm_size_col + c];
     } else {
       // flip left right
-      for (r = 0; r < txfm_size; ++r)
-        temp_in[r] = buf[r * txfm_size + (txfm_size - c - 1)];
+      for (r = 0; r < txfm_size_row; ++r)
+        temp_in[r] = buf[r * txfm_size_col + (txfm_size_col - c - 1)];
     }
     txfm_func_col(temp_in, temp_out, cos_bit_col, stage_range_col);
-    round_shift_array(temp_out, txfm_size, -shift[1]);
+    round_shift_array(temp_out, txfm_size_row, -shift[1]);
     if (cfg->ud_flip == 0) {
-      for (r = 0; r < txfm_size; ++r) output[r * stride + c] += temp_out[r];
+      for (r = 0; r < txfm_size_row; ++r) output[r * stride + c] += temp_out[r];
     } else {
       // flip upside down
-      for (r = 0; r < txfm_size; ++r)
-        output[r * stride + c] += temp_out[txfm_size - r - 1];
+      for (r = 0; r < txfm_size_row; ++r)
+        output[r * stride + c] += temp_out[txfm_size_row - r - 1];
     }
   }
 }
@@ -185,11 +200,44 @@
   // int16_t*
   TXFM_2D_FLIP_CFG cfg = av1_get_inv_txfm_cfg(tx_type, tx_size);
   inv_txfm2d_add_c(input, (int16_t *)output, stride, &cfg, txfm_buf);
-  // TODO(sarahparker) just using the cfg_row->txfm_size for now because
-  // we are assumint this is only used for square transforms. This will
-  // be adjusted in a follow up
-  clamp_block((int16_t *)output, cfg.row_cfg->txfm_size, stride, 0,
-              (1 << bd) - 1);
+  clamp_block((int16_t *)output, cfg.col_cfg->txfm_size, cfg.row_cfg->txfm_size,
+              stride, 0, (1 << bd) - 1);
+}
+
+void av1_inv_txfm2d_add_4x8_c(const int32_t *input, uint16_t *output,
+                              int stride, int tx_type, int bd) {
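+  // The buffer holds the 4x8 intermediate block plus temp_in/temp_out
+  // scratch space of txfm_size_row (8) elements each; see the layout in
+  // inv_txfm2d_add_c.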
+  int txfm_buf[4 * 8 + 8 + 8];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_4X8, bd);
+}
+
+void av1_inv_txfm2d_add_8x4_c(const int32_t *input, uint16_t *output,
+                              int stride, int tx_type, int bd) {
+  int txfm_buf[8 * 4 + 4 + 4];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_8X4, bd);
+}
+
+void av1_inv_txfm2d_add_8x16_c(const int32_t *input, uint16_t *output,
+                               int stride, int tx_type, int bd) {
+  int txfm_buf[8 * 16 + 16 + 16];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_8X16, bd);
+}
+
+void av1_inv_txfm2d_add_16x8_c(const int32_t *input, uint16_t *output,
+                               int stride, int tx_type, int bd) {
+  int txfm_buf[16 * 8 + 8 + 8];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_16X8, bd);
+}
+
+void av1_inv_txfm2d_add_16x32_c(const int32_t *input, uint16_t *output,
+                                int stride, int tx_type, int bd) {
+  int txfm_buf[16 * 32 + 32 + 32];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_16X32, bd);
+}
+
+void av1_inv_txfm2d_add_32x16_c(const int32_t *input, uint16_t *output,
+                                int stride, int tx_type, int bd) {
+  int txfm_buf[32 * 16 + 16 + 16];
+  inv_txfm2d_add_facade(input, output, stride, txfm_buf, tx_type, TX_32X16, bd);
 }
 
 void av1_inv_txfm2d_add_4x4_c(const int32_t *input, uint16_t *output,
@@ -225,5 +273,5 @@
   // int16_t*
   TXFM_2D_FLIP_CFG cfg = av1_get_inv_txfm_64x64_cfg(tx_type);
   inv_txfm2d_add_c(input, (int16_t *)output, stride, &cfg, txfm_buf);
-  clamp_block((int16_t *)output, 64, stride, 0, (1 << bd) - 1);
+  clamp_block((int16_t *)output, 64, 64, stride, 0, (1 << bd) - 1);
 }
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 6be2be0..fd65fcf 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -255,6 +255,12 @@
 
 if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
   #inv txfm
+  add_proto qw/void av1_inv_txfm2d_add_4x8/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_inv_txfm2d_add_8x4/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_inv_txfm2d_add_8x16/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_inv_txfm2d_add_16x8/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_inv_txfm2d_add_16x32/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_inv_txfm2d_add_32x16/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
   add_proto qw/void av1_inv_txfm2d_add_4x4/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
   specialize qw/av1_inv_txfm2d_add_4x4 sse4_1/;
   add_proto qw/void av1_inv_txfm2d_add_8x8/, "const int32_t *input, uint16_t *output, int stride, int tx_type, int bd";
@@ -405,12 +411,21 @@
 if (aom_config("CONFIG_DPCM_INTRA") eq "yes") {
   @sizes = (4, 8, 16, 32);
   foreach $size (@sizes) {
+    if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
+      add_proto "void", "av1_hbd_dpcm_ft$size", "const int16_t *input, int stride, TX_TYPE_1D tx_type, tran_low_t *output, int dir";
+    }
     add_proto "void", "av1_dpcm_ft$size", "const int16_t *input, int stride, TX_TYPE_1D tx_type, tran_low_t *output";
   }
 }
 
 if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
   #fwd txfm
+  add_proto qw/void av1_fwd_txfm2d_4x8/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_fwd_txfm2d_8x4/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_fwd_txfm2d_8x16/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_fwd_txfm2d_16x8/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_fwd_txfm2d_16x32/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
+  add_proto qw/void av1_fwd_txfm2d_32x16/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
   add_proto qw/void av1_fwd_txfm2d_4x4/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_4x4 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_8x8/, "const int16_t *input, int32_t *output, int stride, int tx_type, int bd";
@@ -484,35 +499,6 @@
   }
 
   # fdct functions
-  add_proto qw/void av1_highbd_fht4x4/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-  specialize qw/av1_highbd_fht4x4 sse4_1/;
-
-  add_proto qw/void av1_highbd_fht4x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht8x4/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht8x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht16x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht16x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht32x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht4x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht16x4/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht8x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht32x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht8x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht16x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
-  add_proto qw/void av1_highbd_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-
   if (aom_config("CONFIG_TX64X64") eq "yes") {
     add_proto qw/void av1_highbd_fht64x64/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   }
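
Each add_proto above yields an av1_-prefixed entry point that resolves to the _c implementation until a specialization is registered. A usage sketch against one of the new rectangular prototypes (reconstruct_4x8 is a hypothetical caller):

    // Assumes av1/common/av1_rtcd.h for the generated prototype.
    static void reconstruct_4x8(const int32_t *coeff, uint16_t *recon,
                                int stride, int tx_type, int bd) {
      // 4-wide by 8-high coefficient block, row-major.
      av1_inv_txfm2d_add_4x8(coeff, recon, stride, tx_type, bd);
    }
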
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index b341cb7..1304e4c 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -120,11 +120,12 @@
 }
 
 // TODO(angiebird): implement SSE
-static INLINE void clamp_block(int16_t *block, int block_size, int stride,
-                               int low, int high) {
+static INLINE void clamp_block(int16_t *block, int block_size_row,
+                               int block_size_col, int stride, int low,
+                               int high) {
   int i, j;
-  for (i = 0; i < block_size; ++i) {
-    for (j = 0; j < block_size; ++j) {
+  for (i = 0; i < block_size_row; ++i) {
+    for (j = 0; j < block_size_col; ++j) {
       block[i * stride + j] = clamp(block[i * stride + j], low, high);
     }
   }
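
Note the argument order in the reworked signature: rows before columns, which is why the call sites above pass cfg.col_cfg->txfm_size (the column length, i.e. the block height) first. A small example with hypothetical values:

    // Clamp a 4-wide by 8-high block stored at a stride of 32 to the
    // 10-bit pixel range.
    static void clamp_example(int16_t *block) {
      clamp_block(block, /*block_size_row=*/8, /*block_size_col=*/4,
                  /*stride=*/32, /*low=*/0, /*high=*/(1 << 10) - 1);
    }
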
diff --git a/av1/common/idct.c b/av1/common/idct.c
index e94598e..4e15969 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -113,107 +113,6 @@
 }
 #endif  // CONFIG_TX64X64
 
-#if CONFIG_HIGHBITDEPTH
-#if CONFIG_EXT_TX
-// TODO(sarahparker) these functions will be removed once the highbitdepth
-// codepath works properly for rectangular transforms. They have almost
-// identical versions in av1_inv_txfm1d.c, but those are currently only
-// being used for square transforms.
-static void highbd_iidtx4_c(const tran_low_t *input, tran_low_t *output,
-                            int bd) {
-  int i;
-  for (i = 0; i < 4; ++i)
-    output[i] = HIGHBD_WRAPLOW(dct_const_round_shift(input[i] * Sqrt2), bd);
-}
-
-static void highbd_iidtx8_c(const tran_low_t *input, tran_low_t *output,
-                            int bd) {
-  int i;
-  (void)bd;
-  for (i = 0; i < 8; ++i) output[i] = input[i] * 2;
-}
-
-static void highbd_iidtx16_c(const tran_low_t *input, tran_low_t *output,
-                             int bd) {
-  int i;
-  for (i = 0; i < 16; ++i)
-    output[i] = HIGHBD_WRAPLOW(dct_const_round_shift(input[i] * 2 * Sqrt2), bd);
-}
-
-static void highbd_iidtx32_c(const tran_low_t *input, tran_low_t *output,
-                             int bd) {
-  int i;
-  (void)bd;
-  for (i = 0; i < 32; ++i) output[i] = input[i] * 4;
-}
-#endif  // CONFIG_EXT_TX
-
-static void highbd_ihalfright32_c(const tran_low_t *input, tran_low_t *output,
-                                  int bd) {
-  int i;
-  tran_low_t inputhalf[16];
-  // Multiply input by sqrt(2)
-  for (i = 0; i < 16; ++i) {
-    inputhalf[i] = HIGHBD_WRAPLOW(dct_const_round_shift(input[i] * Sqrt2), bd);
-  }
-  for (i = 0; i < 16; ++i) {
-    output[i] = input[16 + i] * 4;
-  }
-  aom_highbd_idct16_c(inputhalf, output + 16, bd);
-  // Note overall scaling factor is 4 times orthogonal
-}
-
-#if CONFIG_EXT_TX
-#if CONFIG_TX64X64
-static void highbd_iidtx64_c(const tran_low_t *input, tran_low_t *output,
-                             int bd) {
-  int i;
-  for (i = 0; i < 64; ++i)
-    output[i] = HIGHBD_WRAPLOW(dct_const_round_shift(input[i] * 4 * Sqrt2), bd);
-}
-#endif  // CONFIG_TX64X64
-#endif  // CONFIG_EXT_TX
-
-#if CONFIG_TX64X64
-// For use in lieu of ADST
-static void highbd_ihalfright64_c(const tran_low_t *input, tran_low_t *output,
-                                  int bd) {
-  int i;
-  tran_low_t inputhalf[32];
-  // Multiply input by sqrt(2)
-  for (i = 0; i < 32; ++i) {
-    inputhalf[i] = HIGHBD_WRAPLOW(dct_const_round_shift(input[i] * Sqrt2), bd);
-  }
-  for (i = 0; i < 32; ++i) {
-    output[i] =
-        HIGHBD_WRAPLOW(dct_const_round_shift(input[32 + i] * 4 * Sqrt2), bd);
-  }
-  aom_highbd_idct32_c(inputhalf, output + 32, bd);
-  // Note overall scaling factor is 4 * sqrt(2)  times orthogonal
-}
-
-static void highbd_idct64_col_c(const tran_low_t *input, tran_low_t *output,
-                                int bd) {
-  int32_t in[64], out[64];
-  int i;
-  (void)bd;
-  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
-  av1_idct64_new(in, out, inv_cos_bit_col_dct_64, inv_stage_range_col_dct_64);
-  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
-}
-
-static void highbd_idct64_row_c(const tran_low_t *input, tran_low_t *output,
-                                int bd) {
-  int32_t in[64], out[64];
-  int i;
-  (void)bd;
-  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
-  av1_idct64_new(in, out, inv_cos_bit_row_dct_64, inv_stage_range_row_dct_64);
-  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
-}
-#endif  // CONFIG_TX64X64
-#endif  // CONFIG_HIGHBITDEPTH
-
 // Inverse identity transform and add.
 #if CONFIG_EXT_TX
 static void inv_idtx_add_c(const tran_low_t *input, uint8_t *dest, int stride,
@@ -278,7 +177,7 @@
 #endif  // CONFIG_EXT_TX
 
 #if CONFIG_HIGHBITDEPTH
-#if CONFIG_EXT_TX
+#if CONFIG_EXT_TX && CONFIG_TX64X64
 static void highbd_inv_idtx_add_c(const tran_low_t *input, uint8_t *dest8,
                                   int stride, int bs, int tx_type, int bd) {
   int r, c;
@@ -294,45 +193,7 @@
     }
   }
 }
-
-static void maybe_flip_strides16(uint16_t **dst, int *dstride, tran_low_t **src,
-                                 int *sstride, int tx_type, int sizey,
-                                 int sizex) {
-  // Note that the transpose of src will be added to dst. In order to LR
-  // flip the addends (in dst coordinates), we UD flip the src. To UD flip
-  // the addends, we UD flip the dst.
-  switch (tx_type) {
-    case DCT_DCT:
-    case ADST_DCT:
-    case DCT_ADST:
-    case ADST_ADST:
-    case IDTX:
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST: break;
-    case FLIPADST_DCT:
-    case FLIPADST_ADST:
-    case V_FLIPADST:
-      // flip UD
-      FLIPUD_PTR(*dst, *dstride, sizey);
-      break;
-    case DCT_FLIPADST:
-    case ADST_FLIPADST:
-    case H_FLIPADST:
-      // flip LR
-      FLIPUD_PTR(*src, *sstride, sizex);
-      break;
-    case FLIPADST_FLIPADST:
-      // flip UD
-      FLIPUD_PTR(*dst, *dstride, sizey);
-      // flip LR
-      FLIPUD_PTR(*src, *sstride, sizex);
-      break;
-    default: assert(0); break;
-  }
-}
-#endif  // CONFIG_EXT_TX
+#endif  // CONFIG_EXT_TX && CONFIG_TX64X64
 #endif  // CONFIG_HIGHBITDEPTH
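
The deleted maybe_flip_strides16 handled the FLIPADST variants by repointing a buffer and negating its stride rather than moving data. For reference, the FLIPUD_PTR macro it relies on works along these lines; this is a sketch from memory of the idct header, not a quotation:

    // Point dest at the last row and walk backwards: an upside-down flip
    // without copying any samples.
    #define FLIPUD_PTR(dest, stride, size)       \
      do {                                       \
        (dest) = (dest) + ((size)-1) * (stride); \
        (stride) = -(stride);                    \
      } while (0)
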
 
 void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
@@ -1535,919 +1396,6 @@
 #endif  // CONFIG_TX64X64
 
 #if CONFIG_HIGHBITDEPTH
-void av1_highbd_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest8,
-                                int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_4[] = {
-    { aom_highbd_idct4_c, aom_highbd_idct4_c },    // DCT_DCT
-    { aom_highbd_iadst4_c, aom_highbd_idct4_c },   // ADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst4_c },   // DCT_ADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst4_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst4_c, aom_highbd_idct4_c },   // FLIPADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst4_c },   // DCT_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst4_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst4_c },  // ADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst4_c },  // FLIPADST_ADST
-    { highbd_iidtx4_c, highbd_iidtx4_c },          // IDTX
-    { aom_highbd_idct4_c, highbd_iidtx4_c },       // V_DCT
-    { highbd_iidtx4_c, aom_highbd_idct4_c },       // H_DCT
-    { aom_highbd_iadst4_c, highbd_iidtx4_c },      // V_ADST
-    { highbd_iidtx4_c, aom_highbd_iadst4_c },      // H_ADST
-    { aom_highbd_iadst4_c, highbd_iidtx4_c },      // V_FLIPADST
-    { highbd_iidtx4_c, aom_highbd_iadst4_c },      // H_FLIPADST
-#endif                                             // CONFIG_EXT_TX
-  };
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t tmp[4][4];
-  tran_low_t out[4][4];
-  tran_low_t *outp = &out[0][0];
-  int outstride = 4;
-
-  // inverse transform row vectors
-  for (i = 0; i < 4; ++i) {
-    HIGH_IHT_4[tx_type].rows(input, out[i], bd);
-    input += 4;
-  }
-
-  // transpose
-  for (i = 0; i < 4; i++) {
-    for (j = 0; j < 4; j++) {
-      tmp[j][i] = out[i][j];
-    }
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < 4; ++i) {
-    HIGH_IHT_4[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 4, 4);
-#endif
-
-  // Sum with the destination
-  for (i = 0; i < 4; ++i) {
-    for (j = 0; j < 4; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 4), bd);
-    }
-  }
-}
-
-void av1_highbd_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest8,
-                                int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_4x8[] = {
-    { aom_highbd_idct8_c, aom_highbd_idct4_c },    // DCT_DCT
-    { aom_highbd_iadst8_c, aom_highbd_idct4_c },   // ADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst4_c },   // DCT_ADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst4_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst8_c, aom_highbd_idct4_c },   // FLIPADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst4_c },   // DCT_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst4_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst4_c },  // ADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst4_c },  // FLIPADST_ADST
-    { highbd_iidtx8_c, highbd_iidtx4_c },          // IDTX
-    { aom_highbd_idct8_c, highbd_iidtx4_c },       // V_DCT
-    { highbd_iidtx8_c, aom_highbd_idct4_c },       // H_DCT
-    { aom_highbd_iadst8_c, highbd_iidtx4_c },      // V_ADST
-    { highbd_iidtx8_c, aom_highbd_iadst4_c },      // H_ADST
-    { aom_highbd_iadst8_c, highbd_iidtx4_c },      // V_FLIPADST
-    { highbd_iidtx8_c, aom_highbd_iadst4_c },      // H_FLIPADST
-#endif                                             // CONFIG_EXT_TX
-  };
-  const int n = 4;
-  const int n2 = 8;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[4][8], tmp[4][8], outtmp[4];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n2;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_4x8[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n; ++j) {
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    }
-    input += n;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_4x8[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n2, n);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n2; ++i) {
-    for (j = 0; j < n; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-
-void av1_highbd_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest8,
-                                int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_8x4[] = {
-    { aom_highbd_idct4_c, aom_highbd_idct8_c },    // DCT_DCT
-    { aom_highbd_iadst4_c, aom_highbd_idct8_c },   // ADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst8_c },   // DCT_ADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst8_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst4_c, aom_highbd_idct8_c },   // FLIPADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst8_c },   // DCT_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst8_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst8_c },  // ADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst8_c },  // FLIPADST_ADST
-    { highbd_iidtx4_c, highbd_iidtx8_c },          // IDTX
-    { aom_highbd_idct4_c, highbd_iidtx8_c },       // V_DCT
-    { highbd_iidtx4_c, aom_highbd_idct8_c },       // H_DCT
-    { aom_highbd_iadst4_c, highbd_iidtx8_c },      // V_ADST
-    { highbd_iidtx4_c, aom_highbd_iadst8_c },      // H_ADST
-    { aom_highbd_iadst4_c, highbd_iidtx8_c },      // V_FLIPADST
-    { highbd_iidtx4_c, aom_highbd_iadst8_c },      // H_FLIPADST
-#endif                                             // CONFIG_EXT_TX
-  };
-  const int n = 4;
-  const int n2 = 8;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[8][4], tmp[8][4], outtmp[8];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_8x4[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n2; ++j) {
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    }
-    input += n2;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_8x4[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n, n2);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n; ++i) {
-    for (j = 0; j < n2; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-
-void av1_highbd_iht4x16_64_add_c(const tran_low_t *input, uint8_t *dest8,
-                                 int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_4x16[] = {
-    { aom_highbd_idct16_c, aom_highbd_idct4_c },    // DCT_DCT
-    { aom_highbd_iadst16_c, aom_highbd_idct4_c },   // ADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst4_c },   // DCT_ADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst4_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst16_c, aom_highbd_idct4_c },   // FLIPADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst4_c },   // DCT_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst4_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst4_c },  // ADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst4_c },  // FLIPADST_ADST
-    { highbd_iidtx16_c, highbd_iidtx4_c },          // IDTX
-    { aom_highbd_idct16_c, highbd_iidtx4_c },       // V_DCT
-    { highbd_iidtx16_c, aom_highbd_idct4_c },       // H_DCT
-    { aom_highbd_iadst16_c, highbd_iidtx4_c },      // V_ADST
-    { highbd_iidtx16_c, aom_highbd_iadst4_c },      // H_ADST
-    { aom_highbd_iadst16_c, highbd_iidtx4_c },      // V_FLIPADST
-    { highbd_iidtx16_c, aom_highbd_iadst4_c },      // H_FLIPADST
-#endif                                              // CONFIG_EXT_TX
-  };
-  const int n = 4;
-  const int n4 = 16;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[4][16], tmp[4][16], outtmp[4];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n4;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n4; ++i) {
-    HIGH_IHT_4x16[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n; ++j) tmp[j][i] = outtmp[j];
-    input += n;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n; ++i) HIGH_IHT_4x16[tx_type].cols(tmp[i], out[i], bd);
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n4, n);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n4; ++i) {
-    for (j = 0; j < n; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-
-void av1_highbd_iht16x4_64_add_c(const tran_low_t *input, uint8_t *dest8,
-                                 int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_16x4[] = {
-    { aom_highbd_idct4_c, aom_highbd_idct16_c },    // DCT_DCT
-    { aom_highbd_iadst4_c, aom_highbd_idct16_c },   // ADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst16_c },   // DCT_ADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst16_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst4_c, aom_highbd_idct16_c },   // FLIPADST_DCT
-    { aom_highbd_idct4_c, aom_highbd_iadst16_c },   // DCT_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst16_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst16_c },  // ADST_FLIPADST
-    { aom_highbd_iadst4_c, aom_highbd_iadst16_c },  // FLIPADST_ADST
-    { highbd_iidtx4_c, highbd_iidtx16_c },          // IDTX
-    { aom_highbd_idct4_c, highbd_iidtx16_c },       // V_DCT
-    { highbd_iidtx4_c, aom_highbd_idct16_c },       // H_DCT
-    { aom_highbd_iadst4_c, highbd_iidtx16_c },      // V_ADST
-    { highbd_iidtx4_c, aom_highbd_iadst16_c },      // H_ADST
-    { aom_highbd_iadst4_c, highbd_iidtx16_c },      // V_FLIPADST
-    { highbd_iidtx4_c, aom_highbd_iadst16_c },      // H_FLIPADST
-#endif                                              // CONFIG_EXT_TX
-  };
-  const int n = 4;
-  const int n4 = 16;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[16][4], tmp[16][4], outtmp[16];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_16x4[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n4; ++j) tmp[j][i] = outtmp[j];
-    input += n4;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n4; ++i) {
-    HIGH_IHT_16x4[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n, n4);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n; ++i) {
-    for (j = 0; j < n4; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-
-void av1_highbd_iht8x16_128_add_c(const tran_low_t *input, uint8_t *dest8,
-                                  int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_8x16[] = {
-    { aom_highbd_idct16_c, aom_highbd_idct8_c },    // DCT_DCT
-    { aom_highbd_iadst16_c, aom_highbd_idct8_c },   // ADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst8_c },   // DCT_ADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst8_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst16_c, aom_highbd_idct8_c },   // FLIPADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst8_c },   // DCT_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst8_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst8_c },  // ADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst8_c },  // FLIPADST_ADST
-    { highbd_iidtx16_c, highbd_iidtx8_c },          // IDTX
-    { aom_highbd_idct16_c, highbd_iidtx8_c },       // V_DCT
-    { highbd_iidtx16_c, aom_highbd_idct8_c },       // H_DCT
-    { aom_highbd_iadst16_c, highbd_iidtx8_c },      // V_ADST
-    { highbd_iidtx16_c, aom_highbd_iadst8_c },      // H_ADST
-    { aom_highbd_iadst16_c, highbd_iidtx8_c },      // V_FLIPADST
-    { highbd_iidtx16_c, aom_highbd_iadst8_c },      // H_FLIPADST
-#endif                                              // CONFIG_EXT_TX
-  };
-  const int n = 8;
-  const int n2 = 16;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[8][16], tmp[8][16], outtmp[8];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n2;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_8x16[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n; ++j)
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    input += n;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_8x16[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n2, n);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n2; ++i) {
-    for (j = 0; j < n; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht16x8_128_add_c(const tran_low_t *input, uint8_t *dest8,
-                                  int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_16x8[] = {
-    { aom_highbd_idct8_c, aom_highbd_idct16_c },    // DCT_DCT
-    { aom_highbd_iadst8_c, aom_highbd_idct16_c },   // ADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst16_c },   // DCT_ADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst16_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst8_c, aom_highbd_idct16_c },   // FLIPADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst16_c },   // DCT_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst16_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst16_c },  // ADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst16_c },  // FLIPADST_ADST
-    { highbd_iidtx8_c, highbd_iidtx16_c },          // IDTX
-    { aom_highbd_idct8_c, highbd_iidtx16_c },       // V_DCT
-    { highbd_iidtx8_c, aom_highbd_idct16_c },       // H_DCT
-    { aom_highbd_iadst8_c, highbd_iidtx16_c },      // V_ADST
-    { highbd_iidtx8_c, aom_highbd_iadst16_c },      // H_ADST
-    { aom_highbd_iadst8_c, highbd_iidtx16_c },      // V_FLIPADST
-    { highbd_iidtx8_c, aom_highbd_iadst16_c },      // H_FLIPADST
-#endif                                              // CONFIG_EXT_TX
-  };
-  const int n = 8;
-  const int n2 = 16;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[16][8], tmp[16][8], outtmp[16];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_16x8[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n2; ++j)
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    input += n2;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_16x8[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n, n2);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n; ++i) {
-    for (j = 0; j < n2; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht8x32_256_add_c(const tran_low_t *input, uint8_t *dest8,
-                                  int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_8x32[] = {
-    { aom_highbd_idct32_c, aom_highbd_idct8_c },     // DCT_DCT
-    { highbd_ihalfright32_c, aom_highbd_idct8_c },   // ADST_DCT
-    { aom_highbd_idct32_c, aom_highbd_iadst8_c },    // DCT_ADST
-    { highbd_ihalfright32_c, aom_highbd_iadst8_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { highbd_ihalfright32_c, aom_highbd_idct8_c },   // FLIPADST_DCT
-    { aom_highbd_idct32_c, aom_highbd_iadst8_c },    // DCT_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst8_c },  // FLIPADST_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst8_c },  // ADST_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst8_c },  // FLIPADST_ADST
-    { highbd_iidtx32_c, highbd_iidtx8_c },           // IDTX
-    { aom_highbd_idct32_c, highbd_iidtx8_c },        // V_DCT
-    { highbd_iidtx32_c, aom_highbd_idct8_c },        // H_DCT
-    { highbd_ihalfright32_c, highbd_iidtx8_c },      // V_ADST
-    { highbd_iidtx32_c, aom_highbd_iadst8_c },       // H_ADST
-    { highbd_ihalfright32_c, highbd_iidtx8_c },      // V_FLIPADST
-    { highbd_iidtx32_c, aom_highbd_iadst8_c },       // H_FLIPADST
-#endif                                               // CONFIG_EXT_TX
-  };
-  const int n = 8;
-  const int n4 = 32;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[8][32], tmp[8][32], outtmp[8];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n4;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n4; ++i) {
-    HIGH_IHT_8x32[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n; ++j) tmp[j][i] = outtmp[j];
-    input += n;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n; ++i) HIGH_IHT_8x32[tx_type].cols(tmp[i], out[i], bd);
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n4, n);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n4; ++i) {
-    for (j = 0; j < n; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht32x8_256_add_c(const tran_low_t *input, uint8_t *dest8,
-                                  int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_32x8[] = {
-    { aom_highbd_idct8_c, aom_highbd_idct32_c },     // DCT_DCT
-    { aom_highbd_iadst8_c, aom_highbd_idct32_c },    // ADST_DCT
-    { aom_highbd_idct8_c, highbd_ihalfright32_c },   // DCT_ADST
-    { aom_highbd_iadst8_c, highbd_ihalfright32_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst8_c, aom_highbd_idct32_c },    // FLIPADST_DCT
-    { aom_highbd_idct8_c, highbd_ihalfright32_c },   // DCT_FLIPADST
-    { aom_highbd_iadst8_c, highbd_ihalfright32_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst8_c, highbd_ihalfright32_c },  // ADST_FLIPADST
-    { aom_highbd_iadst8_c, highbd_ihalfright32_c },  // FLIPADST_ADST
-    { highbd_iidtx8_c, highbd_iidtx32_c },           // IDTX
-    { aom_highbd_idct8_c, highbd_iidtx32_c },        // V_DCT
-    { highbd_iidtx8_c, aom_highbd_idct32_c },        // H_DCT
-    { aom_highbd_iadst8_c, highbd_iidtx32_c },       // V_ADST
-    { highbd_iidtx8_c, highbd_ihalfright32_c },      // H_ADST
-    { aom_highbd_iadst8_c, highbd_iidtx32_c },       // V_FLIPADST
-    { highbd_iidtx8_c, highbd_ihalfright32_c },      // H_FLIPADST
-#endif                                               // CONFIG_EXT_TX
-  };
-  const int n = 8;
-  const int n4 = 32;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[32][8], tmp[32][8], outtmp[32];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_32x8[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n4; ++j) tmp[j][i] = outtmp[j];
-    input += n4;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n4; ++i) HIGH_IHT_32x8[tx_type].cols(tmp[i], out[i], bd);
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n, n4);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n; ++i) {
-    for (j = 0; j < n4; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht16x32_512_add_c(const tran_low_t *input, uint8_t *dest8,
-                                   int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_16x32[] = {
-    { aom_highbd_idct32_c, aom_highbd_idct16_c },     // DCT_DCT
-    { highbd_ihalfright32_c, aom_highbd_idct16_c },   // ADST_DCT
-    { aom_highbd_idct32_c, aom_highbd_iadst16_c },    // DCT_ADST
-    { highbd_ihalfright32_c, aom_highbd_iadst16_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { highbd_ihalfright32_c, aom_highbd_idct16_c },   // FLIPADST_DCT
-    { aom_highbd_idct32_c, aom_highbd_iadst16_c },    // DCT_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst16_c },  // FLIPADST_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst16_c },  // ADST_FLIPADST
-    { highbd_ihalfright32_c, aom_highbd_iadst16_c },  // FLIPADST_ADST
-    { highbd_iidtx32_c, highbd_iidtx16_c },           // IDTX
-    { aom_highbd_idct32_c, highbd_iidtx16_c },        // V_DCT
-    { highbd_iidtx32_c, aom_highbd_idct16_c },        // H_DCT
-    { highbd_ihalfright32_c, highbd_iidtx16_c },      // V_ADST
-    { highbd_iidtx32_c, aom_highbd_iadst16_c },       // H_ADST
-    { highbd_ihalfright32_c, highbd_iidtx16_c },      // V_FLIPADST
-    { highbd_iidtx32_c, aom_highbd_iadst16_c },       // H_FLIPADST
-#endif                                                // CONFIG_EXT_TX
-  };
-  const int n = 16;
-  const int n2 = 32;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[16][32], tmp[16][32], outtmp[16];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n2;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_16x32[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n; ++j)
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    input += n;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_16x32[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n2, n);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n2; ++i) {
-    for (j = 0; j < n; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht32x16_512_add_c(const tran_low_t *input, uint8_t *dest8,
-                                   int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_32x16[] = {
-    { aom_highbd_idct16_c, aom_highbd_idct32_c },     // DCT_DCT
-    { aom_highbd_iadst16_c, aom_highbd_idct32_c },    // ADST_DCT
-    { aom_highbd_idct16_c, highbd_ihalfright32_c },   // DCT_ADST
-    { aom_highbd_iadst16_c, highbd_ihalfright32_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst16_c, aom_highbd_idct32_c },    // FLIPADST_DCT
-    { aom_highbd_idct16_c, highbd_ihalfright32_c },   // DCT_FLIPADST
-    { aom_highbd_iadst16_c, highbd_ihalfright32_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst16_c, highbd_ihalfright32_c },  // ADST_FLIPADST
-    { aom_highbd_iadst16_c, highbd_ihalfright32_c },  // FLIPADST_ADST
-    { highbd_iidtx16_c, highbd_iidtx32_c },           // IDTX
-    { aom_highbd_idct16_c, highbd_iidtx32_c },        // V_DCT
-    { highbd_iidtx16_c, aom_highbd_idct32_c },        // H_DCT
-    { aom_highbd_iadst16_c, highbd_iidtx32_c },       // V_ADST
-    { highbd_iidtx16_c, highbd_ihalfright32_c },      // H_ADST
-    { aom_highbd_iadst16_c, highbd_iidtx32_c },       // V_FLIPADST
-    { highbd_iidtx16_c, highbd_ihalfright32_c },      // H_FLIPADST
-#endif                                                // CONFIG_EXT_TX
-  };
-  const int n = 16;
-  const int n2 = 32;
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t out[32][16], tmp[32][16], outtmp[32];
-  tran_low_t *outp = &out[0][0];
-  int outstride = n;
-
-  // inverse transform row vectors, and transpose
-  for (i = 0; i < n; ++i) {
-    HIGH_IHT_32x16[tx_type].rows(input, outtmp, bd);
-    for (j = 0; j < n2; ++j)
-      tmp[j][i] = HIGHBD_WRAPLOW(dct_const_round_shift(outtmp[j] * Sqrt2), bd);
-    input += n2;
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < n2; ++i) {
-    HIGH_IHT_32x16[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, n, n2);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < n; ++i) {
-    for (j = 0; j < n2; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-void av1_highbd_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest8,
-                                int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_8[] = {
-    { aom_highbd_idct8_c, aom_highbd_idct8_c },    // DCT_DCT
-    { aom_highbd_iadst8_c, aom_highbd_idct8_c },   // ADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst8_c },   // DCT_ADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst8_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst8_c, aom_highbd_idct8_c },   // FLIPADST_DCT
-    { aom_highbd_idct8_c, aom_highbd_iadst8_c },   // DCT_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst8_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst8_c },  // ADST_FLIPADST
-    { aom_highbd_iadst8_c, aom_highbd_iadst8_c },  // FLIPADST_ADST
-    { highbd_iidtx8_c, highbd_iidtx8_c },          // IDTX
-    { aom_highbd_idct8_c, highbd_iidtx8_c },       // V_DCT
-    { highbd_iidtx8_c, aom_highbd_idct8_c },       // H_DCT
-    { aom_highbd_iadst8_c, highbd_iidtx8_c },      // V_ADST
-    { highbd_iidtx8_c, aom_highbd_iadst8_c },      // H_ADST
-    { aom_highbd_iadst8_c, highbd_iidtx8_c },      // V_FLIPADST
-    { highbd_iidtx8_c, aom_highbd_iadst8_c },      // H_FLIPADST
-#endif                                             // CONFIG_EXT_TX
-  };
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t tmp[8][8];
-  tran_low_t out[8][8];
-  tran_low_t *outp = &out[0][0];
-  int outstride = 8;
-
-  // inverse transform row vectors
-  for (i = 0; i < 8; ++i) {
-    HIGH_IHT_8[tx_type].rows(input, out[i], bd);
-    input += 8;
-  }
-
-  // transpose
-  for (i = 0; i < 8; i++) {
-    for (j = 0; j < 8; j++) {
-      tmp[j][i] = out[i][j];
-    }
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < 8; ++i) {
-    HIGH_IHT_8[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 8, 8);
-#endif
-
-  // Sum with the destination
-  for (i = 0; i < 8; ++i) {
-    for (j = 0; j < 8; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-
-void av1_highbd_iht16x16_256_add_c(const tran_low_t *input, uint8_t *dest8,
-                                   int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_16[] = {
-    { aom_highbd_idct16_c, aom_highbd_idct16_c },    // DCT_DCT
-    { aom_highbd_iadst16_c, aom_highbd_idct16_c },   // ADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst16_c },   // DCT_ADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst16_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { aom_highbd_iadst16_c, aom_highbd_idct16_c },   // FLIPADST_DCT
-    { aom_highbd_idct16_c, aom_highbd_iadst16_c },   // DCT_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst16_c },  // FLIPADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst16_c },  // ADST_FLIPADST
-    { aom_highbd_iadst16_c, aom_highbd_iadst16_c },  // FLIPADST_ADST
-    { highbd_iidtx16_c, highbd_iidtx16_c },          // IDTX
-    { aom_highbd_idct16_c, highbd_iidtx16_c },       // V_DCT
-    { highbd_iidtx16_c, aom_highbd_idct16_c },       // H_DCT
-    { aom_highbd_iadst16_c, highbd_iidtx16_c },      // V_ADST
-    { highbd_iidtx16_c, aom_highbd_iadst16_c },      // H_ADST
-    { aom_highbd_iadst16_c, highbd_iidtx16_c },      // V_FLIPADST
-    { highbd_iidtx16_c, aom_highbd_iadst16_c },      // H_FLIPADST
-#endif                                               // CONFIG_EXT_TX
-  };
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t tmp[16][16];
-  tran_low_t out[16][16];
-  tran_low_t *outp = &out[0][0];
-  int outstride = 16;
-
-  // inverse transform row vectors
-  for (i = 0; i < 16; ++i) {
-    HIGH_IHT_16[tx_type].rows(input, out[i], bd);
-    input += 16;
-  }
-
-  // transpose
-  for (i = 0; i < 16; i++) {
-    for (j = 0; j < 16; j++) {
-      tmp[j][i] = out[i][j];
-    }
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < 16; ++i) {
-    HIGH_IHT_16[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 16, 16);
-#endif
-
-  // Sum with the destination
-  for (i = 0; i < 16; ++i) {
-    for (j = 0; j < 16; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-
-#if CONFIG_EXT_TX
-static void highbd_iht32x32_1024_add_c(const tran_low_t *input, uint8_t *dest8,
-                                       int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_32[] = {
-    { aom_highbd_idct32_c, aom_highbd_idct32_c },      // DCT_DCT
-    { highbd_ihalfright32_c, aom_highbd_idct32_c },    // ADST_DCT
-    { aom_highbd_idct32_c, highbd_ihalfright32_c },    // DCT_ADST
-    { highbd_ihalfright32_c, highbd_ihalfright32_c },  // ADST_ADST
-    { highbd_ihalfright32_c, aom_highbd_idct32_c },    // FLIPADST_DCT
-    { aom_highbd_idct32_c, highbd_ihalfright32_c },    // DCT_FLIPADST
-    { highbd_ihalfright32_c, highbd_ihalfright32_c },  // FLIPADST_FLIPADST
-    { highbd_ihalfright32_c, highbd_ihalfright32_c },  // ADST_FLIPADST
-    { highbd_ihalfright32_c, highbd_ihalfright32_c },  // FLIPADST_ADST
-    { highbd_iidtx32_c, highbd_iidtx32_c },            // IDTX
-    { aom_highbd_idct32_c, highbd_iidtx32_c },         // V_DCT
-    { highbd_iidtx32_c, aom_highbd_idct32_c },         // H_DCT
-    { highbd_ihalfright32_c, highbd_iidtx32_c },       // V_ADST
-    { highbd_iidtx32_c, highbd_ihalfright32_c },       // H_ADST
-    { highbd_ihalfright32_c, highbd_iidtx32_c },       // V_FLIPADST
-    { highbd_iidtx32_c, highbd_ihalfright32_c },       // H_FLIPADST
-  };
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t tmp[32][32];
-  tran_low_t out[32][32];
-  tran_low_t *outp = &out[0][0];
-  int outstride = 32;
-
-  // inverse transform row vectors
-  for (i = 0; i < 32; ++i) {
-    HIGH_IHT_32[tx_type].rows(input, out[i], bd);
-    input += 32;
-  }
-
-  // transpose
-  for (i = 0; i < 32; i++) {
-    for (j = 0; j < 32; j++) {
-      tmp[j][i] = out[i][j];
-    }
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < 32; ++i) {
-    HIGH_IHT_32[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 32, 32);
-
-  // Sum with the destination
-  for (i = 0; i < 32; ++i) {
-    for (j = 0; j < 32; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 6), bd);
-    }
-  }
-}
-#endif  // CONFIG_EXT_TX
-
-#if CONFIG_TX64X64
-static void highbd_iht64x64_4096_add_c(const tran_low_t *input, uint8_t *dest8,
-                                       int stride, int tx_type, int bd) {
-  static const highbd_transform_2d HIGH_IHT_64[] = {
-    { highbd_idct64_col_c, highbd_idct64_row_c },      // DCT_DCT
-    { highbd_ihalfright64_c, highbd_idct64_row_c },    // ADST_DCT
-    { highbd_idct64_col_c, highbd_ihalfright64_c },    // DCT_ADST
-    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // ADST_ADST
-#if CONFIG_EXT_TX
-    { highbd_ihalfright64_c, highbd_idct64_row_c },    // FLIPADST_DCT
-    { highbd_idct64_col_c, highbd_ihalfright64_c },    // DCT_FLIPADST
-    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // FLIPADST_FLIPADST
-    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // ADST_FLIPADST
-    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // FLIPADST_ADST
-    { highbd_iidtx64_c, highbd_iidtx64_c },            // IDTX
-    { highbd_idct64_col_c, highbd_iidtx64_c },         // V_DCT
-    { highbd_iidtx64_c, highbd_idct64_row_c },         // H_DCT
-    { highbd_ihalfright64_c, highbd_iidtx64_c },       // V_ADST
-    { highbd_iidtx64_c, highbd_ihalfright64_c },       // H_ADST
-    { highbd_ihalfright64_c, highbd_iidtx64_c },       // V_FLIPADST
-    { highbd_iidtx64_c, highbd_ihalfright64_c },       // H_FLIPADST
-#endif                                                 // CONFIG_EXT_TX
-  };
-
-  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
-
-  int i, j;
-  tran_low_t tmp[64][64];
-  tran_low_t out[64][64];
-  tran_low_t *outp = &out[0][0];
-  int outstride = 64;
-
-  // inverse transform row vectors
-  for (i = 0; i < 64; ++i) {
-    HIGH_IHT_64[tx_type].rows(input, out[i], bd);
-    for (j = 0; j < 64; ++j) out[i][j] = ROUND_POWER_OF_TWO(out[i][j], 1);
-    input += 64;
-  }
-
-  // transpose
-  for (i = 0; i < 64; i++) {
-    for (j = 0; j < 64; j++) {
-      tmp[j][i] = out[i][j];
-    }
-  }
-
-  // inverse transform column vectors
-  for (i = 0; i < 64; ++i) {
-    HIGH_IHT_64[tx_type].cols(tmp[i], out[i], bd);
-  }
-
-#if CONFIG_EXT_TX
-  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 64, 64);
-#endif  // CONFIG_EXT_TX
-
-  // Sum with the destination
-  for (i = 0; i < 64; ++i) {
-    for (j = 0; j < 64; ++j) {
-      int d = i * stride + j;
-      int s = j * outstride + i;
-      dest[d] =
-          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
-    }
-  }
-}
-#endif  // CONFIG_TX64X64
-
 // idct
 void av1_highbd_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
                             int eob, int bd) {
@@ -2505,34 +1453,36 @@
     av1_highbd_iwht4x4_add(input, dest, stride, eob, bd);
     return;
   }
-
   switch (tx_type) {
     case DCT_DCT:
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
+      av1_inv_txfm2d_add_4x4(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                             bd);
+      break;
 #if CONFIG_EXT_TX
     case FLIPADST_DCT:
     case DCT_FLIPADST:
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-#endif  // CONFIG_EXT_TX
+      // fallthrough intended
       av1_inv_txfm2d_add_4x4(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
                              bd);
       break;
-#if CONFIG_EXT_TX
+    // Use the C version for any tx type that includes an identity for now.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      // Use C version since DST only exists in C code
-      av1_highbd_iht4x4_16_add_c(input, dest, stride, tx_type, bd);
-      break;
     case IDTX:
-      highbd_inv_idtx_add_c(input, dest, stride, 4, tx_type, bd);
+      // fallthrough intended
+      av1_inv_txfm2d_add_4x4_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+                               tx_type, bd);
       break;
 #endif  // CONFIG_EXT_TX
     default: assert(0); break;
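
The switch now funnels every tx_type into the unified 2D inverse transform; the only distinction left is that identity-bearing types are pinned to the C entry point. Distilled to a sketch (tx_type_has_identity is a hypothetical predicate, not a library function):

    // Sketch only; tran_low_t and CONVERT_TO_SHORTPTR come from the aom
    // headers, and tx_type_has_identity is hypothetical.
    static int tx_type_has_identity(int tx_type);  // e.g. IDTX, V_DCT, ...

    static void dispatch_add_4x4(const tran_low_t *input, uint8_t *dest,
                                 int stride, int tx_type, int bd) {
      uint16_t *dest16 = CONVERT_TO_SHORTPTR(dest);
      if (tx_type_has_identity(tx_type))
        av1_inv_txfm2d_add_4x4_c(input, dest16, stride, tx_type, bd);
      else
        av1_inv_txfm2d_add_4x4(input, dest16, stride, tx_type, bd);
    }
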
@@ -2542,69 +1492,47 @@
 void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
                                  int stride, int eob, int bd, TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht4x8_32_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_4x8_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                           bd);
 }
 
 void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
                                  int stride, int eob, int bd, TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht8x4_32_add_c(input, dest, stride, tx_type, bd);
-}
-
-void av1_highbd_inv_txfm_add_4x16(const tran_low_t *input, uint8_t *dest,
-                                  int stride, int eob, int bd,
-                                  TX_TYPE tx_type) {
-  (void)eob;
-  av1_highbd_iht4x16_64_add_c(input, dest, stride, tx_type, bd);
-}
-
-void av1_highbd_inv_txfm_add_16x4(const tran_low_t *input, uint8_t *dest,
-                                  int stride, int eob, int bd,
-                                  TX_TYPE tx_type) {
-  (void)eob;
-  av1_highbd_iht16x4_64_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_8x4_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                           bd);
 }
 
 static void highbd_inv_txfm_add_8x16(const tran_low_t *input, uint8_t *dest,
                                      int stride, int eob, int bd,
                                      TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht8x16_128_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_8x16_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                            bd);
 }
 
 static void highbd_inv_txfm_add_16x8(const tran_low_t *input, uint8_t *dest,
                                      int stride, int eob, int bd,
                                      TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht16x8_128_add_c(input, dest, stride, tx_type, bd);
-}
-
-void av1_highbd_inv_txfm_add_8x32(const tran_low_t *input, uint8_t *dest,
-                                  int stride, int eob, int bd,
-                                  TX_TYPE tx_type) {
-  (void)eob;
-  av1_highbd_iht8x32_256_add_c(input, dest, stride, tx_type, bd);
-}
-
-void av1_highbd_inv_txfm_add_32x8(const tran_low_t *input, uint8_t *dest,
-                                  int stride, int eob, int bd,
-                                  TX_TYPE tx_type) {
-  (void)eob;
-  av1_highbd_iht32x8_256_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_16x8_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                            bd);
 }
 
 static void highbd_inv_txfm_add_16x32(const tran_low_t *input, uint8_t *dest,
                                       int stride, int eob, int bd,
                                       TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht16x32_512_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_16x32_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                             bd);
 }
 
 static void highbd_inv_txfm_add_32x16(const tran_low_t *input, uint8_t *dest,
                                       int stride, int eob, int bd,
                                       TX_TYPE tx_type) {
   (void)eob;
-  av1_highbd_iht32x16_512_add_c(input, dest, stride, tx_type, bd);
+  av1_inv_txfm2d_add_32x16_c(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                             bd);
 }
 
 static void highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
@@ -2616,31 +1544,34 @@
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
+      av1_inv_txfm2d_add_8x8(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                             bd);
+      break;
 #if CONFIG_EXT_TX
     case FLIPADST_DCT:
     case DCT_FLIPADST:
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-#endif  // CONFIG_EXT_TX
+      // fallthrough intended
       av1_inv_txfm2d_add_8x8(input, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
                              bd);
       break;
-#if CONFIG_EXT_TX
+    // Use the C version for any tx type that includes an identity for now.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      // Use C version since DST only exists in C code
-      av1_highbd_iht8x8_64_add_c(input, dest, stride, tx_type, bd);
-      break;
     case IDTX:
-      highbd_inv_idtx_add_c(input, dest, stride, 8, tx_type, bd);
+      // fallthrough intended
+      av1_inv_txfm2d_add_8x8_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+                               tx_type, bd);
       break;
 #endif  // CONFIG_EXT_TX
-    default: assert(0); break;
+    default: assert(0);
   }
 }
 
@@ -2653,31 +1584,34 @@
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
+      av1_inv_txfm2d_add_16x16(input, CONVERT_TO_SHORTPTR(dest), stride,
+                               tx_type, bd);
+      break;
 #if CONFIG_EXT_TX
     case FLIPADST_DCT:
     case DCT_FLIPADST:
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
-#endif  // CONFIG_EXT_TX
+      // fallthrough intended
       av1_inv_txfm2d_add_16x16(input, CONVERT_TO_SHORTPTR(dest), stride,
                                tx_type, bd);
       break;
-#if CONFIG_EXT_TX
+    // Use the C version for any tx type that includes an identity for now.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      // Use C version since DST only exists in C code
-      av1_highbd_iht16x16_256_add_c(input, dest, stride, tx_type, bd);
-      break;
     case IDTX:
-      highbd_inv_idtx_add_c(input, dest, stride, 16, tx_type, bd);
+      // fallthrough intended
+      av1_inv_txfm2d_add_16x16_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
       break;
 #endif  // CONFIG_EXT_TX
-    default: assert(0); break;
+    default: assert(0);
   }
 }
 
@@ -2687,31 +1621,37 @@
   (void)eob;
   switch (tx_type) {
     case DCT_DCT:
-      av1_inv_txfm2d_add_32x32(input, CONVERT_TO_SHORTPTR(dest), stride,
-                               DCT_DCT, bd);
-      break;
-#if CONFIG_EXT_TX
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
+      av1_inv_txfm2d_add_32x32(input, CONVERT_TO_SHORTPTR(dest), stride,
+                               tx_type, bd);
+      break;
+#if CONFIG_EXT_TX
     case FLIPADST_DCT:
     case DCT_FLIPADST:
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
+      // fallthrough intended
+      av1_inv_txfm2d_add_32x32(input, CONVERT_TO_SHORTPTR(dest), stride,
+                               tx_type, bd);
+      break;
+    // Use the C version for any tx type that includes an identity for now.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      highbd_iht32x32_1024_add_c(input, dest, stride, tx_type, bd);
-      break;
     case IDTX:
-      highbd_inv_idtx_add_c(input, dest, stride, 32, tx_type, bd);
+      // fallthrough intended
+      av1_inv_txfm2d_add_32x32_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
       break;
 #endif  // CONFIG_EXT_TX
-    default: assert(0); break;
+    default: assert(0);
   }
 }
 
@@ -2740,7 +1680,14 @@
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      highbd_iht64x64_4096_add_c(input, dest, stride, tx_type, bd);
+      // TODO(sarahparker)
+      // The 64x64 implementations that existed in lieu of adst, flipadst
+      // and identity have been deleted for simplicity; they will be brought
+      // back in a later change. This shouldn't impact performance since
+      // DCT_DCT is the only extended tx type currently allowed for 64x64,
+      // as dictated by get_ext_tx_set_type in blockd.h.
+      av1_inv_txfm2d_add_64x64_c(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                 DCT_DCT, bd);
       break;
     case IDTX:
       highbd_inv_idtx_add_c(input, dest, stride, 64, tx_type, bd);
@@ -2990,16 +1937,27 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
+// TODO(sarahparker) This is a quick workaround that lets these functions
+// drop the old hbd transforms; it will be cleaned up in a followup.
 void av1_hbd_dpcm_inv_txfm_add_4_c(const tran_low_t *input, int stride,
-                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest) {
+                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                   int dir) {
   assert(tx_type < TX_TYPES_1D);
-  static const highbd_transform_1d IHT[] = { aom_highbd_idct4_c,
-                                             aom_highbd_iadst4_c,
-                                             aom_highbd_iadst4_c,
-                                             highbd_iidtx4_c };
-  const highbd_transform_1d inv_tx = IHT[tx_type];
+  static const TxfmFunc IHT[] = { av1_idct4_new, av1_iadst4_new, av1_iadst4_new,
+                                  av1_iidentity4_c };
+  // Configs in order { horizontal, vertical }, indexed by dir.
+  static const TXFM_1D_CFG *inv_txfm_cfg_ls[TX_TYPES_1D][2] = {
+    { &inv_txfm_1d_row_cfg_dct_4, &inv_txfm_1d_col_cfg_dct_4 },
+    { &inv_txfm_1d_row_cfg_adst_4, &inv_txfm_1d_col_cfg_adst_4 },
+    { &inv_txfm_1d_row_cfg_adst_4, &inv_txfm_1d_col_cfg_adst_4 },
+    { &inv_txfm_1d_cfg_identity_4, &inv_txfm_1d_cfg_identity_4 }
+  };
+
+  const TXFM_1D_CFG *inv_txfm_cfg = inv_txfm_cfg_ls[tx_type][dir];
+  const TxfmFunc inv_tx = IHT[tx_type];
+
   tran_low_t out[4];
-  inv_tx(input, out, bd);
+  inv_tx(input, out, inv_txfm_cfg->cos_bit, inv_txfm_cfg->stage_range);
   for (int i = 0; i < 4; ++i) {
     out[i] = (tran_low_t)dct_const_round_shift(out[i] * Sqrt2);
     dest[i * stride] = highbd_clip_pixel_add(dest[i * stride],
@@ -3008,15 +1966,24 @@
 }
 
 void av1_hbd_dpcm_inv_txfm_add_8_c(const tran_low_t *input, int stride,
-                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest) {
-  static const highbd_transform_1d IHT[] = { aom_highbd_idct8_c,
-                                             aom_highbd_iadst8_c,
-                                             aom_highbd_iadst8_c,
-                                             highbd_iidtx8_c };
+                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                   int dir) {
   assert(tx_type < TX_TYPES_1D);
-  const highbd_transform_1d inv_tx = IHT[tx_type];
+  static const TxfmFunc IHT[] = { av1_idct8_new, av1_iadst8_new,
+                                  av1_iadst8_new, av1_iidentity8_c };
+  // Inverse 1D configs, in order { horizontal, vertical }, indexed by dir
+  static const TXFM_1D_CFG *inv_txfm_cfg_ls[TX_TYPES_1D][2] = {
+    { &inv_txfm_1d_row_cfg_dct_8, &inv_txfm_1d_col_cfg_dct_8 },
+    { &inv_txfm_1d_row_cfg_adst_8, &inv_txfm_1d_col_cfg_adst_8 },
+    { &inv_txfm_1d_row_cfg_adst_8, &inv_txfm_1d_col_cfg_adst_8 },
+    { &inv_txfm_1d_cfg_identity_8, &inv_txfm_1d_cfg_identity_8 }
+  };
+
+  const TXFM_1D_CFG *inv_txfm_cfg = inv_txfm_cfg_ls[tx_type][dir];
+  const TxfmFunc inv_tx = IHT[tx_type];
+
   tran_low_t out[8];
-  inv_tx(input, out, bd);
+  inv_tx(input, out, inv_txfm_cfg->cos_bit, inv_txfm_cfg->stage_range);
   for (int i = 0; i < 8; ++i) {
     dest[i * stride] = highbd_clip_pixel_add(dest[i * stride],
                                              ROUND_POWER_OF_TWO(out[i], 4), bd);
@@ -3024,16 +1991,24 @@
 }
 
 void av1_hbd_dpcm_inv_txfm_add_16_c(const tran_low_t *input, int stride,
-                                    TX_TYPE_1D tx_type, int bd,
-                                    uint16_t *dest) {
+                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                    int dir) {
   assert(tx_type < TX_TYPES_1D);
-  static const highbd_transform_1d IHT[] = { aom_highbd_idct16_c,
-                                             aom_highbd_iadst16_c,
-                                             aom_highbd_iadst16_c,
-                                             highbd_iidtx16_c };
-  const highbd_transform_1d inv_tx = IHT[tx_type];
+  static const TxfmFunc IHT[] = { av1_idct16_new, av1_iadst16_new,
+                                  av1_iadst16_new, av1_iidentity16_c };
+  // Inverse 1D configs, in order { horizontal, vertical }, indexed by dir
+  static const TXFM_1D_CFG *inv_txfm_cfg_ls[TX_TYPES_1D][2] = {
+    { &inv_txfm_1d_row_cfg_dct_16, &inv_txfm_1d_col_cfg_dct_16 },
+    { &inv_txfm_1d_row_cfg_adst_16, &inv_txfm_1d_col_cfg_adst_16 },
+    { &inv_txfm_1d_row_cfg_adst_16, &inv_txfm_1d_col_cfg_adst_16 },
+    { &inv_txfm_1d_cfg_identity_16, &inv_txfm_1d_cfg_identity_16 }
+  };
+
+  const TXFM_1D_CFG *inv_txfm_cfg = inv_txfm_cfg_ls[tx_type][dir];
+  const TxfmFunc inv_tx = IHT[tx_type];
+
   tran_low_t out[16];
-  inv_tx(input, out, bd);
+  inv_tx(input, out, inv_txfm_cfg->cos_bit, inv_txfm_cfg->stage_range);
   for (int i = 0; i < 16; ++i) {
     out[i] = (tran_low_t)dct_const_round_shift(out[i] * Sqrt2);
     dest[i * stride] = highbd_clip_pixel_add(dest[i * stride],
@@ -3042,16 +2017,24 @@
 }
 
 void av1_hbd_dpcm_inv_txfm_add_32_c(const tran_low_t *input, int stride,
-                                    TX_TYPE_1D tx_type, int bd,
-                                    uint16_t *dest) {
+                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                    int dir) {
   assert(tx_type < TX_TYPES_1D);
-  static const highbd_transform_1d IHT[] = { aom_highbd_idct32_c,
-                                             highbd_ihalfright32_c,
-                                             highbd_ihalfright32_c,
-                                             highbd_iidtx32_c };
-  const highbd_transform_1d inv_tx = IHT[tx_type];
+  static const TxfmFunc IHT[] = { av1_idct32_new, av1_iadst32_new,
+                                  av1_iadst32_new, av1_iidentity32_c };
+  // Inverse 1D configs, in order { horizontal, vertical }, indexed by dir
+  static const TXFM_1D_CFG *inv_txfm_cfg_ls[TX_TYPES_1D][2] = {
+    { &inv_txfm_1d_row_cfg_dct_32, &inv_txfm_1d_col_cfg_dct_32 },
+    { &inv_txfm_1d_row_cfg_adst_32, &inv_txfm_1d_col_cfg_adst_32 },
+    { &inv_txfm_1d_row_cfg_adst_32, &inv_txfm_1d_col_cfg_adst_32 },
+    { &inv_txfm_1d_cfg_identity_32, &inv_txfm_1d_cfg_identity_32 }
+  };
+
+  const TXFM_1D_CFG *inv_txfm_cfg = inv_txfm_cfg_ls[tx_type][dir];
+  const TxfmFunc inv_tx = IHT[tx_type];
+
   tran_low_t out[32];
-  inv_tx(input, out, bd);
+  inv_tx(input, out, inv_txfm_cfg->cos_bit, inv_txfm_cfg->stage_range);
   for (int i = 0; i < 32; ++i) {
     dest[i * stride] = highbd_clip_pixel_add(dest[i * stride],
                                              ROUND_POWER_OF_TWO(out[i], 4), bd);
diff --git a/av1/common/idct.h b/av1/common/idct.h
index cf656dc..55f0c65 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -98,16 +98,20 @@
 dpcm_inv_txfm_add_func av1_get_dpcm_inv_txfm_add_func(int tx_length);
 #if CONFIG_HIGHBITDEPTH
 void av1_hbd_dpcm_inv_txfm_add_4_c(const tran_low_t *input, int stride,
-                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest);
+                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                   int dir);
 void av1_hbd_dpcm_inv_txfm_add_8_c(const tran_low_t *input, int stride,
-                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest);
+                                   TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                   int dir);
 void av1_hbd_dpcm_inv_txfm_add_16_c(const tran_low_t *input, int stride,
-                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest);
+                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                    int dir);
 void av1_hbd_dpcm_inv_txfm_add_32_c(const tran_low_t *input, int stride,
-                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest);
+                                    TX_TYPE_1D tx_type, int bd, uint16_t *dest,
+                                    int dir);
 typedef void (*hbd_dpcm_inv_txfm_add_func)(const tran_low_t *input, int stride,
                                            TX_TYPE_1D tx_type, int bd,
-                                           uint16_t *dest);
+                                           uint16_t *dest, int dir);
 hbd_dpcm_inv_txfm_add_func av1_get_hbd_dpcm_inv_txfm_add_func(int tx_length);
 #endif  // CONFIG_HIGHBITDEPTH
 #endif  // CONFIG_DPCM_INTRA
diff --git a/av1/common/x86/av1_fwd_txfm2d_sse4.c b/av1/common/x86/av1_fwd_txfm2d_sse4.c
index 1d7c553..1785509 100644
--- a/av1/common/x86/av1_fwd_txfm2d_sse4.c
+++ b/av1/common/x86/av1_fwd_txfm2d_sse4.c
@@ -40,7 +40,11 @@
                                      const int stride,
                                      const TXFM_2D_FLIP_CFG *cfg,
                                      int32_t *txfm_buf) {
-  // TODO(sarahparker) must correct for rectangular transforms in follow up
+  // TODO(sarahparker) This does not currently support rectangular
+  // transforms and will break unless txfm_size is split out into row and
+  // col sizes. Rectangular transforms only use the C code for now, so this
+  // is ok; it will be corrected once SSE implementations for rectangular
+  // transforms exist.
   const int txfm_size = cfg->row_cfg->txfm_size;
   const int8_t *shift = cfg->row_cfg->shift;
   const int8_t *stage_range_col = cfg->col_cfg->stage_range;
diff --git a/av1/common/x86/av1_txfm1d_sse4.h b/av1/common/x86/av1_txfm1d_sse4.h
index af7afb7..fd0a6ed 100644
--- a/av1/common/x86/av1_txfm1d_sse4.h
+++ b/av1/common/x86/av1_txfm1d_sse4.h
@@ -64,7 +64,7 @@
 // the entire input block can be represented by a grid of 4x4 blocks
 // each 4x4 block can be represented by 4 vertical __m128i
 // we first transpose each 4x4 block internally
-// than transpose the grid
+// then transpose the grid
 static INLINE void transpose_32(int txfm_size, const __m128i *input,
                                 __m128i *output) {
   const int num_per_128 = 4;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 8585c6c..4e02df4 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -562,7 +562,7 @@
       av1_get_hbd_dpcm_inv_txfm_add_func(tx1d_width);
   for (int r = 0; r < tx1d_height; ++r) {
     if (r > 0) memcpy(dst, dst - dst_stride, tx1d_width * sizeof(dst[0]));
-    inverse_tx(dqcoeff, 1, tx_type_1d, bd, dst);
+    inverse_tx(dqcoeff, 1, tx_type_1d, bd, dst, 1);
     dqcoeff += tx1d_width;
     dst += dst_stride;
   }
@@ -590,7 +590,7 @@
       if (c > 0) dst[r * dst_stride] = dst[r * dst_stride - 1];
       tx_buff[r] = dqcoeff[r * tx1d_width];
     }
-    inverse_tx(tx_buff, dst_stride, tx_type_1d, bd, dst);
+    inverse_tx(tx_buff, dst_stride, tx_type_1d, bd, dst, 0);
   }
 }
 #endif  // CONFIG_HIGHBITDEPTH
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index f6b64f0..fcaea59 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1990,75 +1990,10 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
-void av1_highbd_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
-                         int tx_type) {
-  av1_fht4x4_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
-                         int tx_type) {
-  av1_fht4x8_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
-                         int tx_type) {
-  av1_fht8x4_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht8x16_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht16x8_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
-                           int tx_type) {
-  av1_fht16x32_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht32x16_c(const int16_t *input, tran_low_t *output, int stride,
-                           int tx_type) {
-  av1_fht32x16_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht4x16_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht16x4_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht16x4_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht8x32_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht8x32_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht32x8_c(const int16_t *input, tran_low_t *output, int stride,
-                          int tx_type) {
-  av1_fht32x8_c(input, output, stride, tx_type);
-}
-
-void av1_highbd_fht8x8_c(const int16_t *input, tran_low_t *output, int stride,
-                         int tx_type) {
-  av1_fht8x8_c(input, output, stride, tx_type);
-}
-
 void av1_highbd_fwht4x4_c(const int16_t *input, tran_low_t *output,
                           int stride) {
   av1_fwht4x4_c(input, output, stride);
 }
-
-void av1_highbd_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
-                           int tx_type) {
-  av1_fht16x16_c(input, output, stride, tx_type);
-}
 #endif  // CONFIG_HIGHBITDEPTH
 
 void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
@@ -2271,5 +2206,54 @@
   for (int i = 0; i < 32; ++i) temp_in[i] = input[i * stride];
   ft(temp_in, output);
 }
+
+#if CONFIG_HIGHBITDEPTH
+void av1_hbd_dpcm_ft4_c(const int16_t *input, int stride, TX_TYPE_1D tx_type,
+                        tran_low_t *output, int dir) {
+  (void)dir;
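+  // dir is accepted for signature parity with the inverse path but is
+  // currently unused: the forward DPCM transforms still run the old-style
+  // 1D kernels, which take no per-direction config.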
+  assert(tx_type < TX_TYPES_1D);
+  static const transform_1d FHT[] = { fdct4, fadst4, fadst4, fidtx4 };
+  const transform_1d ft = FHT[tx_type];
+  tran_low_t temp_in[4];
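+  // Pre-scale the residuals before the 1D kernel runs; each size applies
+  // its own factor (compare the 8/16/32 variants below), presumably to
+  // keep the outputs on a comparable scale.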
+  for (int i = 0; i < 4; ++i)
+    temp_in[i] = (tran_low_t)fdct_round_shift(input[i * stride] * 4 * Sqrt2);
+  ft(temp_in, output);
+}
+
+void av1_hbd_dpcm_ft8_c(const int16_t *input, int stride, TX_TYPE_1D tx_type,
+                        tran_low_t *output, int dir) {
+  (void)dir;
+  assert(tx_type < TX_TYPES_1D);
+  static const transform_1d FHT[] = { fdct8, fadst8, fadst8, fidtx8 };
+  const transform_1d ft = FHT[tx_type];
+  tran_low_t temp_in[8];
+  for (int i = 0; i < 8; ++i) temp_in[i] = input[i * stride] * 4;
+  ft(temp_in, output);
+}
+
+void av1_hbd_dpcm_ft16_c(const int16_t *input, int stride, TX_TYPE_1D tx_type,
+                         tran_low_t *output, int dir) {
+  (void)dir;
+  assert(tx_type < TX_TYPES_1D);
+  static const transform_1d FHT[] = { fdct16, fadst16, fadst16, fidtx16 };
+  const transform_1d ft = FHT[tx_type];
+  tran_low_t temp_in[16];
+  for (int i = 0; i < 16; ++i)
+    temp_in[i] = (tran_low_t)fdct_round_shift(input[i * stride] * 2 * Sqrt2);
+  ft(temp_in, output);
+}
+
+void av1_hbd_dpcm_ft32_c(const int16_t *input, int stride, TX_TYPE_1D tx_type,
+                         tran_low_t *output, int dir) {
+  (void)dir;
+  assert(tx_type < TX_TYPES_1D);
+  static const transform_1d FHT[] = { fdct32, fhalfright32, fhalfright32,
+                                      fidtx32 };
+  const transform_1d ft = FHT[tx_type];
+  tran_low_t temp_in[32];
+  for (int i = 0; i < 32; ++i) temp_in[i] = input[i * stride];
+  ft(temp_in, output);
+}
+#endif  // CONFIG_HIGHBITDEPTH
 #endif  // CONFIG_DPCM_INTRA
 #endif  // !AV1_DCT_GTEST
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 7c97815..b0945ed 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -1435,6 +1435,24 @@
   }
 }
 
+#if CONFIG_HIGHBITDEPTH
+typedef void (*hbd_dpcm_fwd_tx_func)(const int16_t *input, int stride,
+                                     TX_TYPE_1D tx_type, tran_low_t *output,
+                                     int dir);
+
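+// Forward-direction counterpart of av1_get_hbd_dpcm_inv_txfm_add_func.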
+static hbd_dpcm_fwd_tx_func get_hbd_dpcm_fwd_tx_func(int tx_length) {
+  switch (tx_length) {
+    case 4: return av1_hbd_dpcm_ft4_c;
+    case 8: return av1_hbd_dpcm_ft8_c;
+    case 16: return av1_hbd_dpcm_ft16_c;
+    case 32:
+      return av1_hbd_dpcm_ft32_c;
+    // TODO(huisu): add support for TX_64X64.
+    default: assert(0); return NULL;
+  }
+}
+#endif  // CONFIG_HIGHBITDEPTH
+
 typedef void (*dpcm_fwd_tx_func)(const int16_t *input, int stride,
                                  TX_TYPE_1D tx_type, tran_low_t *output);
 
@@ -1539,7 +1557,7 @@
     int16_t *src_diff, int diff_stride, tran_low_t *coeff, tran_low_t *qcoeff,
     tran_low_t *dqcoeff) {
   const int tx1d_width = tx_size_wide[tx_size];
-  dpcm_fwd_tx_func forward_tx = get_dpcm_fwd_tx_func(tx1d_width);
+  hbd_dpcm_fwd_tx_func forward_tx = get_hbd_dpcm_fwd_tx_func(tx1d_width);
   hbd_dpcm_inv_txfm_add_func inverse_tx =
       av1_get_hbd_dpcm_inv_txfm_add_func(tx1d_width);
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
@@ -1553,7 +1571,7 @@
     // Subtraction.
     for (int c = 0; c < tx1d_width; ++c) src_diff[c] = src[c] - dst[c];
     // Forward transform.
-    forward_tx(src_diff, 1, tx_type_1d, coeff);
+    forward_tx(src_diff, 1, tx_type_1d, coeff, 1);
     // Quantization.
     for (int c = 0; c < tx1d_width; ++c) {
       quantize_scaler(coeff[c], p->zbin[q_idx], p->round[q_idx],
@@ -1562,7 +1580,7 @@
       q_idx = 1;
     }
     // Inverse transform.
-    inverse_tx(dqcoeff, 1, tx_type_1d, bd, dst);
+    inverse_tx(dqcoeff, 1, tx_type_1d, bd, dst, 1);
     // Move to the next row.
     coeff += tx1d_width;
     qcoeff += tx1d_width;
@@ -1580,7 +1598,7 @@
     int16_t *src_diff, int diff_stride, tran_low_t *coeff, tran_low_t *qcoeff,
     tran_low_t *dqcoeff) {
   const int tx1d_height = tx_size_high[tx_size];
-  dpcm_fwd_tx_func forward_tx = get_dpcm_fwd_tx_func(tx1d_height);
+  hbd_dpcm_fwd_tx_func forward_tx = get_hbd_dpcm_fwd_tx_func(tx1d_height);
   hbd_dpcm_inv_txfm_add_func inverse_tx =
       av1_get_hbd_dpcm_inv_txfm_add_func(tx1d_height);
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
@@ -1597,7 +1615,7 @@
     }
     // Forward transform.
     tran_low_t tx_buff[64];
-    forward_tx(src_diff, diff_stride, tx_type_1d, tx_buff);
+    forward_tx(src_diff, diff_stride, tx_type_1d, tx_buff, 0);
     for (int r = 0; r < tx1d_height; ++r) coeff[r * tx1d_width] = tx_buff[r];
     // Quantization.
     for (int r = 0; r < tx1d_height; ++r) {
@@ -1609,7 +1627,7 @@
     }
     // Inverse transform.
     for (int r = 0; r < tx1d_height; ++r) tx_buff[r] = dqcoeff[r * tx1d_width];
-    inverse_tx(tx_buff, dst_stride, tx_type_1d, bd, dst);
+    inverse_tx(tx_buff, dst_stride, tx_type_1d, bd, dst, 0);
     // Move to the next column.
     ++coeff, ++qcoeff, ++dqcoeff, ++src_diff, ++dst, ++src;
   }
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index c57deed..c71cae9 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -201,12 +201,12 @@
     av1_highbd_fwht4x4(src_diff, coeff, diff_stride);
     return;
   }
-
   switch (tx_type) {
     case DCT_DCT:
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_4x4(src_diff, coeff, diff_stride, tx_type, bd);
       break;
 #if CONFIG_EXT_TX
@@ -215,17 +215,20 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_4x4(src_diff, coeff, diff_stride, tx_type, bd);
       break;
+    // For now, use the C version for any transform with an identity component.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      av1_highbd_fht4x4_c(src_diff, coeff, diff_stride, tx_type);
+    case IDTX:
+      // fallthrough intended
+      av1_fwd_txfm2d_4x4_c(src_diff, coeff, diff_stride, tx_type, bd);
       break;
-    case IDTX: av1_fwd_idtx_c(src_diff, coeff, diff_stride, 4, tx_type); break;
 #endif  // CONFIG_EXT_TX
     default: assert(0);
   }
@@ -235,48 +238,42 @@
                                 int diff_stride, TX_TYPE tx_type,
                                 FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht4x8(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_4x8_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_8x4(const int16_t *src_diff, tran_low_t *coeff,
                                 int diff_stride, TX_TYPE tx_type,
                                 FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht8x4(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_8x4_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_8x16(const int16_t *src_diff, tran_low_t *coeff,
                                  int diff_stride, TX_TYPE tx_type,
                                  FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht8x16(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_8x16_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_16x8(const int16_t *src_diff, tran_low_t *coeff,
                                  int diff_stride, TX_TYPE tx_type,
                                  FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht16x8(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_16x8_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_16x32(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TX_TYPE tx_type,
                                   FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht16x32(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_16x32_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_32x16(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TX_TYPE tx_type,
                                   FWD_TXFM_OPT fwd_txfm_opt, const int bd) {
   (void)fwd_txfm_opt;
-  (void)bd;
-  av1_highbd_fht32x16(src_diff, coeff, diff_stride, tx_type);
+  av1_fwd_txfm2d_32x16_c(src_diff, coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff,
@@ -288,6 +285,7 @@
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_8x8(src_diff, coeff, diff_stride, tx_type, bd);
       break;
 #if CONFIG_EXT_TX
@@ -296,18 +294,20 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_8x8(src_diff, coeff, diff_stride, tx_type, bd);
       break;
+    // For now, use the C version for any transform with an identity component.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      // Use C version since DST exists only in C
-      av1_highbd_fht8x8_c(src_diff, coeff, diff_stride, tx_type);
+    case IDTX:
+      // fallthrough intended
+      av1_fwd_txfm2d_8x8_c(src_diff, coeff, diff_stride, tx_type, bd);
       break;
-    case IDTX: av1_fwd_idtx_c(src_diff, coeff, diff_stride, 8, tx_type); break;
 #endif  // CONFIG_EXT_TX
     default: assert(0);
   }
@@ -322,6 +322,7 @@
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_16x16(src_diff, coeff, diff_stride, tx_type, bd);
       break;
 #if CONFIG_EXT_TX
@@ -330,18 +331,20 @@
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
+      // fallthrough intended
       av1_fwd_txfm2d_16x16(src_diff, coeff, diff_stride, tx_type, bd);
       break;
+    // For now, use the C version for any transform with an identity component.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      // Use C version since DST exists only in C
-      av1_highbd_fht16x16_c(src_diff, coeff, diff_stride, tx_type);
+    case IDTX:
+      // fallthrough intended
+      av1_fwd_txfm2d_16x16_c(src_diff, coeff, diff_stride, tx_type, bd);
       break;
-    case IDTX: av1_fwd_idtx_c(src_diff, coeff, diff_stride, 16, tx_type); break;
 #endif  // CONFIG_EXT_TX
     default: assert(0);
   }
@@ -353,28 +356,34 @@
   (void)fwd_txfm_opt;
   switch (tx_type) {
     case DCT_DCT:
-      av1_fwd_txfm2d_32x32(src_diff, coeff, diff_stride, tx_type, bd);
-      break;
-#if CONFIG_EXT_TX
     case ADST_DCT:
     case DCT_ADST:
     case ADST_ADST:
+      // fallthrough intended
+      av1_fwd_txfm2d_32x32(src_diff, coeff, diff_stride, tx_type, bd);
+      break;
+#if CONFIG_EXT_TX
     case FLIPADST_DCT:
     case DCT_FLIPADST:
     case FLIPADST_FLIPADST:
     case ADST_FLIPADST:
     case FLIPADST_ADST:
+      // fallthrough intended
+      av1_fwd_txfm2d_32x32(src_diff, coeff, diff_stride, tx_type, bd);
+      break;
+    // For now, use the C version for any transform with an identity component.
     case V_DCT:
     case H_DCT:
     case V_ADST:
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      av1_highbd_fht32x32_c(src_diff, coeff, diff_stride, tx_type);
+    case IDTX:
+      // fallthrough intended
+      av1_fwd_txfm2d_32x32_c(src_diff, coeff, diff_stride, tx_type, bd);
       break;
-    case IDTX: av1_fwd_idtx_c(src_diff, coeff, diff_stride, 32, tx_type); break;
 #endif  // CONFIG_EXT_TX
-    default: assert(0); break;
+    default: assert(0);
   }
 }
 
@@ -386,7 +395,7 @@
   (void)bd;
   switch (tx_type) {
     case DCT_DCT:
-      av1_highbd_fht64x64(src_diff, coeff, diff_stride, tx_type);
+      av1_fwd_txfm2d_64x64(src_diff, coeff, diff_stride, tx_type, bd);
       break;
 #if CONFIG_EXT_TX
     case ADST_DCT:
@@ -403,7 +412,13 @@
     case H_ADST:
     case V_FLIPADST:
     case H_FLIPADST:
-      av1_highbd_fht64x64(src_diff, coeff, diff_stride, tx_type);
+      // TODO(sarahparker)
+      // The dedicated 64x64 implementations that stood in for adst,
+      // flipadst and identity have been deleted for simplicity; they will
+      // be brought back in a later change. This shouldn't impact
+      // performance, since DCT_DCT is the only extended transform type
+      // currently allowed for 64x64, as dictated by get_ext_tx_set_type
+      // in blockd.h.
+      av1_fwd_txfm2d_64x64_c(src_diff, coeff, diff_stride, DCT_DCT, bd);
       break;
     case IDTX: av1_fwd_idtx_c(src_diff, coeff, diff_stride, 64, tx_type); break;
 #endif  // CONFIG_EXT_TX
diff --git a/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
index b56eed5..fb74068 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -120,18 +120,6 @@
   _mm_store_si128((__m128i *)(output + 3 * 4), res[3]);
 }
 
-// Note:
-//  We implement av1_fwd_txfm2d_4x4(). This function is kept here since
-//  av1_highbd_fht4x4_c() is not removed yet
-void av1_highbd_fht4x4_sse4_1(const int16_t *input, tran_low_t *output,
-                              int stride, int tx_type) {
-  (void)input;
-  (void)output;
-  (void)stride;
-  (void)tx_type;
-  assert(0);
-}
-
 static void fadst4x4_sse4_1(__m128i *in, int bit) {
   const int32_t *cospi = cospi_arr(bit);
   const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
diff --git a/test/av1_fwd_txfm1d_test.cc b/test/av1_fwd_txfm1d_test.cc
index 511e057..b10e84d 100644
--- a/test/av1_fwd_txfm1d_test.cc
+++ b/test/av1_fwd_txfm1d_test.cc
@@ -82,7 +82,7 @@
   int col = 1;
   int block_size = 3;
   int stride = 5;
-  clamp_block(block[row] + col, block_size, stride, -4, 2);
+  clamp_block(block[row] + col, block_size, block_size, stride, -4, 2);
   for (int r = 0; r < stride; r++) {
     for (int c = 0; c < stride; c++) {
       EXPECT_EQ(block[r][c], ref_block[r][c]);
diff --git a/test/dct16x16_test.cc b/test/dct16x16_test.cc
index 89263ce..50e61f6 100644
--- a/test/dct16x16_test.cc
+++ b/test/dct16x16_test.cc
@@ -255,12 +255,20 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
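+// These thin wrappers bind the bit depth so av1_fwd_txfm2d_16x16_c matches
+// the function-pointer signature the test harness expects.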
+void fht16x16_10(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_16x16_c(in, out, stride, tx_type, 10);
+}
+
+void fht16x16_12(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_16x16_c(in, out, stride, tx_type, 12);
+}
+
 void iht16x16_10(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht16x16_256_add_c(in, out, stride, tx_type, 10);
+  av1_inv_txfm2d_add_16x16_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 10);
 }
 
 void iht16x16_12(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht16x16_256_add_c(in, out, stride, tx_type, 12);
+  av1_inv_txfm2d_add_16x16_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 12);
 }
 #endif  // CONFIG_HIGHBITDEPTH
 
@@ -625,9 +633,18 @@
     mask_ = (1 << bit_depth_) - 1;
 #if CONFIG_HIGHBITDEPTH
     switch (bit_depth_) {
-      case AOM_BITS_10: inv_txfm_ref = iht16x16_10; break;
-      case AOM_BITS_12: inv_txfm_ref = iht16x16_12; break;
-      default: inv_txfm_ref = iht16x16_ref; break;
+      case AOM_BITS_10:
+        fwd_txfm_ref = fht16x16_10;
+        inv_txfm_ref = iht16x16_10;
+        break;
+      case AOM_BITS_12:
+        fwd_txfm_ref = fht16x16_12;
+        inv_txfm_ref = iht16x16_12;
+        break;
+      default:
+        fwd_txfm_ref = fht16x16_ref;
+        inv_txfm_ref = iht16x16_ref;
+        break;
     }
 #else
     inv_txfm_ref = iht16x16_ref;
@@ -767,14 +784,14 @@
 INSTANTIATE_TEST_CASE_P(
     C, Trans16x16HT,
     ::testing::Values(
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_10, 0, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_10, 1, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_10, 2, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_10, 3, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_12, 0, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_12, 1, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_12, 2, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht16x16_c, &iht16x16_12, 3, AOM_BITS_12),
+        make_tuple(&fht16x16_10, &iht16x16_10, 0, AOM_BITS_10),
+        make_tuple(&fht16x16_10, &iht16x16_10, 1, AOM_BITS_10),
+        make_tuple(&fht16x16_10, &iht16x16_10, 2, AOM_BITS_10),
+        make_tuple(&fht16x16_10, &iht16x16_10, 3, AOM_BITS_10),
+        make_tuple(&fht16x16_12, &iht16x16_12, 0, AOM_BITS_12),
+        make_tuple(&fht16x16_12, &iht16x16_12, 1, AOM_BITS_12),
+        make_tuple(&fht16x16_12, &iht16x16_12, 2, AOM_BITS_12),
+        make_tuple(&fht16x16_12, &iht16x16_12, 3, AOM_BITS_12),
         make_tuple(&av1_fht16x16_c, &av1_iht16x16_256_add_c, 0, AOM_BITS_8),
         make_tuple(&av1_fht16x16_c, &av1_iht16x16_256_add_c, 1, AOM_BITS_8),
         make_tuple(&av1_fht16x16_c, &av1_iht16x16_256_add_c, 2, AOM_BITS_8),
diff --git a/test/fdct4x4_test.cc b/test/fdct4x4_test.cc
index ed265e8..a2b180a 100644
--- a/test/fdct4x4_test.cc
+++ b/test/fdct4x4_test.cc
@@ -55,6 +55,14 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
+void fht4x4_10(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_4x4_c(in, out, stride, tx_type, 10);
+}
+
+void fht4x4_12(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_4x4_c(in, out, stride, tx_type, 12);
+}
+
 void idct4x4_10(const tran_low_t *in, uint8_t *out, int stride) {
   aom_highbd_idct4x4_16_add_c(in, out, stride, 10);
 }
@@ -64,11 +72,11 @@
 }
 
 void iht4x4_10(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht4x4_16_add_c(in, out, stride, tx_type, 10);
+  av1_inv_txfm2d_add_4x4_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 10);
 }
 
 void iht4x4_12(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht4x4_16_add_c(in, out, stride, tx_type, 12);
+  av1_inv_txfm2d_add_4x4_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 12);
 }
 
 void iwht4x4_10(const tran_low_t *in, uint8_t *out, int stride) {
@@ -143,6 +151,13 @@
     bit_depth_ = GET_PARAM(3);
     mask_ = (1 << bit_depth_) - 1;
     num_coeffs_ = GET_PARAM(4);
+#if CONFIG_HIGHBITDEPTH
+    switch (bit_depth_) {
+      case AOM_BITS_10: fwd_txfm_ref = fht4x4_10; break;
+      case AOM_BITS_12: fwd_txfm_ref = fht4x4_12; break;
+      default: fwd_txfm_ref = fht4x4_ref; break;
+    }
+#endif
   }
   virtual void TearDown() { libaom_test::ClearSystemState(); }
 
@@ -222,16 +237,19 @@
 
 #if CONFIG_HIGHBITDEPTH
 INSTANTIATE_TEST_CASE_P(
+    DISABLED_C, Trans4x4HT,
+    ::testing::Values(make_tuple(&fht4x4_12, &iht4x4_12, 0, AOM_BITS_12, 16),
+                      make_tuple(&fht4x4_12, &iht4x4_12, 1, AOM_BITS_12, 16),
+                      make_tuple(&fht4x4_12, &iht4x4_12, 2, AOM_BITS_12, 16),
+                      make_tuple(&fht4x4_12, &iht4x4_12, 3, AOM_BITS_12, 16)));
+
+INSTANTIATE_TEST_CASE_P(
     C, Trans4x4HT,
     ::testing::Values(
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_10, 0, AOM_BITS_10, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_10, 1, AOM_BITS_10, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_10, 2, AOM_BITS_10, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_10, 3, AOM_BITS_10, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_12, 0, AOM_BITS_12, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_12, 1, AOM_BITS_12, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_12, 2, AOM_BITS_12, 16),
-        make_tuple(&av1_highbd_fht4x4_c, &iht4x4_12, 3, AOM_BITS_12, 16),
+        make_tuple(&fht4x4_10, &iht4x4_10, 0, AOM_BITS_10, 16),
+        make_tuple(&fht4x4_10, &iht4x4_10, 1, AOM_BITS_10, 16),
+        make_tuple(&fht4x4_10, &iht4x4_10, 2, AOM_BITS_10, 16),
+        make_tuple(&fht4x4_10, &iht4x4_10, 3, AOM_BITS_10, 16),
         make_tuple(&av1_fht4x4_c, &av1_iht4x4_16_add_c, 0, AOM_BITS_8, 16),
         make_tuple(&av1_fht4x4_c, &av1_iht4x4_16_add_c, 1, AOM_BITS_8, 16),
         make_tuple(&av1_fht4x4_c, &av1_iht4x4_16_add_c, 2, AOM_BITS_8, 16),
diff --git a/test/fdct8x8_test.cc b/test/fdct8x8_test.cc
index 0e86c70..ff3ce4b 100644
--- a/test/fdct8x8_test.cc
+++ b/test/fdct8x8_test.cc
@@ -87,12 +87,20 @@
 }
 
 #if CONFIG_HIGHBITDEPTH
+void fht8x8_10(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_8x8_c(in, out, stride, tx_type, 10);
+}
+
+void fht8x8_12(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
+  av1_fwd_txfm2d_8x8_c(in, out, stride, tx_type, 12);
+}
+
 void iht8x8_10(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht8x8_64_add_c(in, out, stride, tx_type, 10);
+  av1_inv_txfm2d_add_8x8_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 10);
 }
 
 void iht8x8_12(const tran_low_t *in, uint8_t *out, int stride, int tx_type) {
-  av1_highbd_iht8x8_64_add_c(in, out, stride, tx_type, 12);
+  av1_inv_txfm2d_add_8x8_c(in, CONVERT_TO_SHORTPTR(out), stride, tx_type, 12);
 }
 
 #endif  // CONFIG_HIGHBITDEPTH
@@ -534,6 +542,13 @@
     fwd_txfm_ref = fht8x8_ref;
     bit_depth_ = GET_PARAM(3);
     mask_ = (1 << bit_depth_) - 1;
+#if CONFIG_HIGHBITDEPTH
+    switch (bit_depth_) {
+      case AOM_BITS_10: fwd_txfm_ref = fht8x8_10; break;
+      case AOM_BITS_12: fwd_txfm_ref = fht8x8_12; break;
+      default: fwd_txfm_ref = fht8x8_ref; break;
+    }
+#endif
   }
 
   virtual void TearDown() { libaom_test::ClearSystemState(); }
@@ -606,14 +621,14 @@
     C, FwdTrans8x8HT,
     ::testing::Values(
         make_tuple(&av1_fht8x8_c, &av1_iht8x8_64_add_c, 0, AOM_BITS_8),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_10, 0, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_10, 1, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_10, 2, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_10, 3, AOM_BITS_10),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_12, 0, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_12, 1, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_12, 2, AOM_BITS_12),
-        make_tuple(&av1_highbd_fht8x8_c, &iht8x8_12, 3, AOM_BITS_12),
+        make_tuple(&fht8x8_10, &iht8x8_10, 0, AOM_BITS_10),
+        make_tuple(&fht8x8_10, &iht8x8_10, 1, AOM_BITS_10),
+        make_tuple(&fht8x8_10, &iht8x8_10, 2, AOM_BITS_10),
+        make_tuple(&fht8x8_10, &iht8x8_10, 3, AOM_BITS_10),
+        make_tuple(&fht8x8_12, &iht8x8_12, 0, AOM_BITS_12),
+        make_tuple(&fht8x8_12, &iht8x8_12, 1, AOM_BITS_12),
+        make_tuple(&fht8x8_12, &iht8x8_12, 2, AOM_BITS_12),
+        make_tuple(&fht8x8_12, &iht8x8_12, 3, AOM_BITS_12),
         make_tuple(&av1_fht8x8_c, &av1_iht8x8_64_add_c, 1, AOM_BITS_8),
         make_tuple(&av1_fht8x8_c, &av1_iht8x8_64_add_c, 2, AOM_BITS_8),
         make_tuple(&av1_fht8x8_c, &av1_iht8x8_64_add_c, 3, AOM_BITS_8)));