[CFL] SSSE3/AVX2 versions of cfl_build_prediction_hbd

Includes unit tests for conformance and speed.

SSSE3/CFLPredictHBDTest:
4x4: C time = 1436 us, SIMD time = 358 us (~4x)
8x8: C time = 4821 us, SIMD time = 598 us (~8.1x)
16x16: C time = 18528 us, SIMD time = 1793 us (~10x)
32x32: C time = 72998 us, SIMD time = 6400 us (~11x)

AVX2/CFLPredictHBDTest:
4x4: C time = 1436 us, SIMD time = 398 us (~3.6x)
8x8: C time = 4924 us, SIMD time = 644 us (~7.6x)
16x16: C time = 18624 us, SIMD time = 1617 us (~12x)
32x32: C time = 73509 us, SIMD time = 3635 us (~20x)

Change-Id: Icbcfefbf165facdbd77c9b3861af2bbf464254a0
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 9d79d26..0e18faf 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -591,6 +591,9 @@
 
   add_proto qw/cfl_predict_lbd_fn get_predict_lbd_fn/, "TX_SIZE tx_size";
   specialize qw/get_predict_lbd_fn ssse3 avx2/;
+
+  add_proto qw/cfl_predict_hbd_fn get_predict_hbd_fn/, "TX_SIZE tx_size";
+  specialize qw/get_predict_hbd_fn ssse3 avx2/;
 }
 
 1;
diff --git a/av1/common/cfl.c b/av1/common/cfl.c
index 10a8b9f..d0e69d5 100644
--- a/av1/common/cfl.c
+++ b/av1/common/cfl.c
@@ -186,8 +186,10 @@
 }
 
 static void cfl_build_prediction_hbd(const int16_t *pred_buf_q3, uint16_t *dst,
-                                     int dst_stride, int width, int height,
+                                     int dst_stride, TX_SIZE tx_size,
                                      int alpha_q3, int bit_depth) {
+  const int height = tx_size_high[tx_size];
+  const int width = tx_size_wide[tx_size];
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       dst[i] = clip_pixel_highbd(
@@ -234,6 +236,11 @@
   return cfl_build_prediction_lbd;
 }
 
+cfl_predict_hbd_fn get_predict_hbd_fn_c(TX_SIZE tx_size) {
+  (void)tx_size;
+  return cfl_build_prediction_hbd;
+}
+
 void cfl_predict_block(MACROBLOCKD *const xd, uint8_t *dst, int dst_stride,
                        TX_SIZE tx_size, int plane) {
   CFL_CTX *const cfl = &xd->cfl;
@@ -244,13 +251,12 @@
 
   const int alpha_q3 =
       cfl_idx_to_alpha(mbmi->cfl_alpha_idx, mbmi->cfl_alpha_signs, plane - 1);
-  const int width = tx_size_wide[tx_size];
-  const int height = tx_size_high[tx_size];
-  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
+  assert((tx_size_high[tx_size] - 1) * CFL_BUF_LINE + tx_size_wide[tx_size] <=
+         CFL_BUF_SQUARE);
   if (get_bitdepth_data_path_index(xd)) {
     uint16_t *dst_16 = CONVERT_TO_SHORTPTR(dst);
-    cfl_build_prediction_hbd(cfl->pred_buf_q3, dst_16, dst_stride, width,
-                             height, alpha_q3, xd->bd);
+    get_predict_hbd_fn(tx_size)(cfl->pred_buf_q3, dst_16, dst_stride, tx_size,
+                                alpha_q3, xd->bd);
     return;
   }
   get_predict_lbd_fn(tx_size)(cfl->pred_buf_q3, dst, dst_stride, tx_size,
diff --git a/av1/common/cfl.h b/av1/common/cfl.h
index ccec482..2b6050f 100644
--- a/av1/common/cfl.h
+++ b/av1/common/cfl.h
@@ -21,6 +21,10 @@
                                    int dst_stride, TX_SIZE tx_size,
                                    int alpha_q3);
 
+typedef void (*cfl_predict_hbd_fn)(const int16_t *pred_buf_q3, uint16_t *dst,
+                                   int dst_stride, TX_SIZE tx_size,
+                                   int alpha_q3, int bd);
+
 static INLINE int is_cfl_allowed(const MB_MODE_INFO *mbmi) {
   const BLOCK_SIZE bsize = mbmi->sb_type;
   assert(bsize < BLOCK_SIZES_ALL);
diff --git a/av1/common/cfl_avx2.c b/av1/common/cfl_avx2.c
index e7983f5..1428d71 100644
--- a/av1/common/cfl_avx2.c
+++ b/av1/common/cfl_avx2.c
@@ -100,9 +100,8 @@
   return subsample_lbd[sub_y & 1][sub_x & 1];
 }
 
-static INLINE __m256i predict_lbd_unclipped(const __m256i *input,
-                                            __m256i alpha_q12,
-                                            __m256i alpha_sign, __m256i dc_q0) {
+static INLINE __m256i predict_unclipped(const __m256i *input, __m256i alpha_q12,
+                                        __m256i alpha_sign, __m256i dc_q0) {
   __m256i ac_q3 = _mm256_loadu_si256(input);
   __m256i ac_sign = _mm256_sign_epi16(alpha_sign, ac_q3);
   __m256i scaled_luma_q0 =
@@ -119,12 +118,12 @@
   const __m256i alpha_q12 = _mm256_slli_epi16(_mm256_abs_epi16(alpha_sign), 9);
   const __m256i dc_q0 = _mm256_set1_epi16(*dst);
   do {
-    __m256i res = predict_lbd_unclipped((__m256i *)pred_buf_q3, alpha_q12,
-                                        alpha_sign, dc_q0);
+    __m256i res =
+        predict_unclipped((__m256i *)pred_buf_q3, alpha_q12, alpha_sign, dc_q0);
     __m256i next = res;
     if (width == 32)
-      next = predict_lbd_unclipped((__m256i *)(pred_buf_q3 + 16), alpha_q12,
-                                   alpha_sign, dc_q0);
+      next = predict_unclipped((__m256i *)(pred_buf_q3 + 16), alpha_q12,
+                               alpha_sign, dc_q0);
     res = _mm256_packus_epi16(res, next);
     if (width == 4) {
       *(int32_t *)dst = _mm256_extract_epi32(res, 0);
@@ -146,30 +145,85 @@
   } while (pred_buf_q3 < row_end);
 }
 
-static void cfl_predict_lbd_4(const int16_t *pred_buf_q3, uint8_t *dst,
-                              int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 4);
+static __m256i highbd_max_epi16(int bd) {
+  const __m256i neg_one = _mm256_set1_epi16(-1);
+  // (1 << bd) - 1 => -(-1 << bd) -1 => -1 - (-1 << bd) => -1 ^ (-1 << bd)
+  return _mm256_xor_si256(_mm256_slli_epi16(neg_one, bd), neg_one);
 }
 
-static void cfl_predict_lbd_8(const int16_t *pred_buf_q3, uint8_t *dst,
-                              int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 8);
+static __m256i highbd_clamp_epi16(__m256i u, __m256i zero, __m256i max) {
+  return _mm256_max_epi16(_mm256_min_epi16(u, max), zero);
 }
 
-static void cfl_predict_lbd_16(const int16_t *pred_buf_q3, uint8_t *dst,
-                               int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 16);
+static INLINE void cfl_predict_hbd_x(const int16_t *pred_buf_q3, uint16_t *dst,
+                                     int dst_stride, TX_SIZE tx_size,
+                                     int alpha_q3, int bd, int width) {
+  const int16_t *row_end = pred_buf_q3 + tx_size_high[tx_size] * CFL_BUF_LINE;
+  const __m256i alpha_sign = _mm256_set1_epi16(alpha_q3);
+  const __m256i alpha_q12 = _mm256_slli_epi16(_mm256_abs_epi16(alpha_sign), 9);
+  const __m256i dc_q0 = _mm256_loadu_si256((__m256i *)dst);
+  const __m256i max = highbd_max_epi16(bd);
+  const __m256i zero = _mm256_setzero_si256();
+  do {
+    __m256i res =
+        predict_unclipped((__m256i *)pred_buf_q3, alpha_q12, alpha_sign, dc_q0);
+    res = highbd_clamp_epi16(res, zero, max);
+    if (width == 4)
+#ifdef __x86_64__
+      *(int64_t *)dst = _mm256_extract_epi64(res, 0);
+#else
+      _mm_storel_epi64((__m128i *)dst, _mm256_castsi256_si128(res));
+#endif
+    else if (width == 8)
+      _mm_storeu_si128((__m128i *)dst, _mm256_castsi256_si128(res));
+    else
+      _mm256_storeu_si256((__m256i *)dst, res);
+    if (width == 32) {
+      res = predict_unclipped((__m256i *)(pred_buf_q3 + 16), alpha_q12,
+                              alpha_sign, dc_q0);
+      res = highbd_clamp_epi16(res, zero, max);
+      _mm256_storeu_si256((__m256i *)(dst + 16), res);
+    }
+    dst += dst_stride;
+    pred_buf_q3 += CFL_BUF_LINE;
+  } while (pred_buf_q3 < row_end);
 }
 
-static void cfl_predict_lbd_32(const int16_t *pred_buf_q3, uint8_t *dst,
-                               int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 32);
-}
+#define CFL_PREDICT_LBD_X(width)                                               \
+  static void cfl_predict_lbd_##width(const int16_t *pred_buf_q3,              \
+                                      uint8_t *dst, int dst_stride,            \
+                                      TX_SIZE tx_size, int alpha_q3) {         \
+    cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, width); \
+  }
+
+CFL_PREDICT_LBD_X(4)
+CFL_PREDICT_LBD_X(8)
+CFL_PREDICT_LBD_X(16)
+CFL_PREDICT_LBD_X(32)
+
+#define CFL_PREDICT_HBD_X(width)                                               \
+  static void cfl_predict_hbd_##width(const int16_t *pred_buf_q3,              \
+                                      uint16_t *dst, int dst_stride,           \
+                                      TX_SIZE tx_size, int alpha_q3, int bd) { \
+    cfl_predict_hbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, bd,     \
+                      width);                                                  \
+  }
+
+CFL_PREDICT_HBD_X(4)
+CFL_PREDICT_HBD_X(8)
+CFL_PREDICT_HBD_X(16)
+CFL_PREDICT_HBD_X(32)
 
 cfl_predict_lbd_fn get_predict_lbd_fn_avx2(TX_SIZE tx_size) {
   static const cfl_predict_lbd_fn predict_lbd[4] = {
     cfl_predict_lbd_4, cfl_predict_lbd_8, cfl_predict_lbd_16, cfl_predict_lbd_32
   };
-  const int width_log2 = tx_size_wide_log2[tx_size];
-  return predict_lbd[(width_log2 - 2) & 3];
+  return predict_lbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
+}
+
+cfl_predict_hbd_fn get_predict_hbd_fn_avx2(TX_SIZE tx_size) {
+  static const cfl_predict_hbd_fn predict_hbd[4] = {
+    cfl_predict_hbd_4, cfl_predict_hbd_8, cfl_predict_hbd_16, cfl_predict_hbd_32
+  };
+  return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
 }
diff --git a/av1/common/cfl_ssse3.c b/av1/common/cfl_ssse3.c
index 020471e..a05135f 100644
--- a/av1/common/cfl_ssse3.c
+++ b/av1/common/cfl_ssse3.c
@@ -90,9 +90,8 @@
   return subsample_lbd[sub_y & 1][sub_x & 1];
 }
 
-static INLINE __m128i predict_lbd_unclipped(const __m128i *input,
-                                            __m128i alpha_q12,
-                                            __m128i alpha_sign, __m128i dc_q0) {
+static INLINE __m128i predict_unclipped(const __m128i *input, __m128i alpha_q12,
+                                        __m128i alpha_sign, __m128i dc_q0) {
   __m128i ac_q3 = _mm_loadu_si128(input);
   __m128i ac_sign = _mm_sign_epi16(alpha_sign, ac_q3);
   __m128i scaled_luma_q0 = _mm_mulhrs_epi16(_mm_abs_epi16(ac_q3), alpha_q12);
@@ -108,8 +107,8 @@
   const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9);
   const __m128i dc_q0 = _mm_set1_epi16(*dst);
   do {
-    __m128i res = predict_lbd_unclipped((__m128i *)(pred_buf_q3), alpha_q12,
-                                        alpha_sign, dc_q0);
+    __m128i res = predict_unclipped((__m128i *)(pred_buf_q3), alpha_q12,
+                                    alpha_sign, dc_q0);
     if (width < 16) {
       res = _mm_packus_epi16(res, res);
       if (width == 4)
@@ -117,15 +116,15 @@
       else
         _mm_storel_epi64((__m128i *)dst, res);
     } else {
-      __m128i next = predict_lbd_unclipped((__m128i *)(pred_buf_q3 + 8),
-                                           alpha_q12, alpha_sign, dc_q0);
+      __m128i next = predict_unclipped((__m128i *)(pred_buf_q3 + 8), alpha_q12,
+                                       alpha_sign, dc_q0);
       res = _mm_packus_epi16(res, next);
       _mm_storeu_si128((__m128i *)dst, res);
       if (width == 32) {
-        res = predict_lbd_unclipped((__m128i *)(pred_buf_q3 + 16), alpha_q12,
-                                    alpha_sign, dc_q0);
-        next = predict_lbd_unclipped((__m128i *)(pred_buf_q3 + 24), alpha_q12,
-                                     alpha_sign, dc_q0);
+        res = predict_unclipped((__m128i *)(pred_buf_q3 + 16), alpha_q12,
+                                alpha_sign, dc_q0);
+        next = predict_unclipped((__m128i *)(pred_buf_q3 + 24), alpha_q12,
+                                 alpha_sign, dc_q0);
         res = _mm_packus_epi16(res, next);
         _mm_storeu_si128((__m128i *)(dst + 16), res);
       }
@@ -135,30 +134,91 @@
   } while (dst < row_end);
 }
 
-static void cfl_predict_lbd_4(const int16_t *pred_buf_q3, uint8_t *dst,
-                              int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 4);
+static INLINE __m128i highbd_max_epi16(int bd) {
+  const __m128i neg_one = _mm_set1_epi16(-1);
+  // (1 << bd) - 1 => -(-1 << bd) -1 => -1 - (-1 << bd) => -1 ^ (-1 << bd)
+  return _mm_xor_si128(_mm_slli_epi16(neg_one, bd), neg_one);
 }
 
-static void cfl_predict_lbd_8(const int16_t *pred_buf_q3, uint8_t *dst,
-                              int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 8);
+static INLINE __m128i highbd_clamp_epi16(__m128i u, __m128i zero, __m128i max) {
+  return _mm_max_epi16(_mm_min_epi16(u, max), zero);
 }
 
-static void cfl_predict_lbd_16(const int16_t *pred_buf_q3, uint8_t *dst,
-                               int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 16);
+static INLINE void cfl_predict_hbd(__m128i *dst, __m128i *src,
+                                   __m128i alpha_q12, __m128i alpha_sign,
+                                   __m128i dc_q0, __m128i zero, __m128i max) {
+  __m128i res = predict_unclipped(src, alpha_q12, alpha_sign, dc_q0);
+  _mm_storeu_si128(dst, highbd_clamp_epi16(res, zero, max));
 }
 
-static void cfl_predict_lbd_32(const int16_t *pred_buf_q3, uint8_t *dst,
-                               int dst_stride, TX_SIZE tx_size, int alpha_q3) {
-  cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, 32);
+static INLINE void cfl_predict_hbd_x(const int16_t *pred_buf_q3, uint16_t *dst,
+                                     int dst_stride, TX_SIZE tx_size,
+                                     int alpha_q3, int bd, int width) {
+  uint16_t *row_end = dst + tx_size_high[tx_size] * dst_stride;
+  const __m128i alpha_sign = _mm_set1_epi16(alpha_q3);
+  const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9);
+  const __m128i dc_q0 = width == 4 ? _mm_loadl_epi64((__m128i *)dst)
+                                   : _mm_load_si128((__m128i *)dst);
+  const __m128i max = highbd_max_epi16(bd);
+  const __m128i zero = _mm_setzero_si128();
+  do {
+    if (width == 4) {
+      __m128i res = predict_unclipped((__m128i *)(pred_buf_q3), alpha_q12,
+                                      alpha_sign, dc_q0);
+      _mm_storel_epi64((__m128i *)dst, highbd_clamp_epi16(res, zero, max));
+    } else {
+      cfl_predict_hbd((__m128i *)dst, (__m128i *)pred_buf_q3, alpha_q12,
+                      alpha_sign, dc_q0, zero, max);
+    }
+    if (width >= 16)
+      cfl_predict_hbd((__m128i *)(dst + 8), (__m128i *)(pred_buf_q3 + 8),
+                      alpha_q12, alpha_sign, dc_q0, zero, max);
+    if (width == 32) {
+      cfl_predict_hbd((__m128i *)(dst + 16), (__m128i *)(pred_buf_q3 + 16),
+                      alpha_q12, alpha_sign, dc_q0, zero, max);
+      cfl_predict_hbd((__m128i *)(dst + 24), (__m128i *)(pred_buf_q3 + 24),
+                      alpha_q12, alpha_sign, dc_q0, zero, max);
+    }
+    dst += dst_stride;
+    pred_buf_q3 += CFL_BUF_LINE;
+  } while (dst < row_end);
 }
 
+#define CFL_PREDICT_LBD_X(width)                                               \
+  static void cfl_predict_lbd_##width(const int16_t *pred_buf_q3,              \
+                                      uint8_t *dst, int dst_stride,            \
+                                      TX_SIZE tx_size, int alpha_q3) {         \
+    cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, width); \
+  }
+
+CFL_PREDICT_LBD_X(4)
+CFL_PREDICT_LBD_X(8)
+CFL_PREDICT_LBD_X(16)
+CFL_PREDICT_LBD_X(32)
+
+#define CFL_PREDICT_HBD_X(width)                                               \
+  static void cfl_predict_hbd_##width(const int16_t *pred_buf_q3,              \
+                                      uint16_t *dst, int dst_stride,           \
+                                      TX_SIZE tx_size, int alpha_q3, int bd) { \
+    cfl_predict_hbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, bd,     \
+                      width);                                                  \
+  }
+
+CFL_PREDICT_HBD_X(4)
+CFL_PREDICT_HBD_X(8)
+CFL_PREDICT_HBD_X(16)
+CFL_PREDICT_HBD_X(32)
+
 cfl_predict_lbd_fn get_predict_lbd_fn_ssse3(TX_SIZE tx_size) {
   static const cfl_predict_lbd_fn predict_lbd[4] = {
     cfl_predict_lbd_4, cfl_predict_lbd_8, cfl_predict_lbd_16, cfl_predict_lbd_32
   };
-  const int width_log2 = tx_size_wide_log2[tx_size];
-  return predict_lbd[(width_log2 - 2) & 3];
+  return predict_lbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
+}
+
+cfl_predict_hbd_fn get_predict_hbd_fn_ssse3(TX_SIZE tx_size) {
+  static const cfl_predict_hbd_fn predict_hbd[4] = {
+    cfl_predict_hbd_4, cfl_predict_hbd_8, cfl_predict_hbd_16, cfl_predict_hbd_32
+  };
+  return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
 }
diff --git a/test/cfl_test.cc b/test/cfl_test.cc
index 7a04952..7ad7a67 100644
--- a/test/cfl_test.cc
+++ b/test/cfl_test.cc
@@ -50,12 +50,16 @@
 
 typedef cfl_predict_lbd_fn (*get_predict_fn)(TX_SIZE tx_size);
 
+typedef cfl_predict_hbd_fn (*get_predict_fn_hbd)(TX_SIZE tx_size);
+
 typedef std::tr1::tuple<int, int, subtract_fn> subtract_param;
 
 typedef std::tr1::tuple<int, int, get_subsample_fn> subsample_param;
 
 typedef std::tr1::tuple<TX_SIZE, get_predict_fn> predict_param;
 
+typedef std::tr1::tuple<TX_SIZE, get_predict_fn_hbd> predict_param_hbd;
+
 static void assertFaster(int ref_elapsed_time, int elapsed_time) {
   EXPECT_GT(ref_elapsed_time, elapsed_time)
       << "Error: CFLSubtractSpeedTest, SIMD slower than C." << std::endl
@@ -151,6 +155,40 @@
   }
 };
 
+class CFLPredictHBDTest : public ::testing::TestWithParam<predict_param_hbd> {
+ public:
+  virtual ~CFLPredictHBDTest() {}
+  virtual void SetUp() { predict = GET_PARAM(1); }
+
+ protected:
+  int Width() const { return tx_size_wide[GET_PARAM(0)]; }
+  int Height() const { return tx_size_high[GET_PARAM(0)]; }
+  TX_SIZE Tx_size() const { return GET_PARAM(0); }
+  DECLARE_ALIGNED(32, uint16_t, chroma_pels_ref[CFL_BUF_SQUARE]);
+  DECLARE_ALIGNED(32, int16_t, sub_luma_pels_ref[CFL_BUF_SQUARE]);
+  DECLARE_ALIGNED(32, uint16_t, chroma_pels[CFL_BUF_SQUARE]);
+  DECLARE_ALIGNED(32, int16_t, sub_luma_pels[CFL_BUF_SQUARE]);
+  get_predict_fn_hbd predict;
+  int bd;
+  int alpha_q3;
+  uint8_t dc;
+  void init(int width, int height) {
+    ACMRandom rnd(ACMRandom::DeterministicSeed());
+    bd = 12;
+    alpha_q3 = rnd(33) - 16;
+    dc = rnd(1 << bd);
+    for (int j = 0; j < height; j++) {
+      for (int i = 0; i < width; i++) {
+        chroma_pels[j * CFL_BUF_LINE + i] = dc;
+        chroma_pels_ref[j * CFL_BUF_LINE + i] = dc;
+        sub_luma_pels_ref[j * CFL_BUF_LINE + i] =
+            sub_luma_pels[j * CFL_BUF_LINE + i] =
+                rnd(1 << bd) - (1 << (bd - 1));
+      }
+    }
+  }
+};
+
 TEST_P(CFLSubtractTest, SubtractTest) {
   const int width = Width();
   const int height = Height();
@@ -296,6 +334,58 @@
   assertFaster(ref_elapsed_time, elapsed_time);
 }
 
+TEST_P(CFLPredictHBDTest, PredictHBDTest) {
+  const int width = Width();
+  const int height = Height();
+  const TX_SIZE tx_size = Tx_size();
+
+  for (int it = 0; it < NUM_ITERATIONS; it++) {
+    init(width, height);
+    predict(tx_size)(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size,
+                     alpha_q3, bd);
+    get_predict_hbd_fn_c(tx_size)(sub_luma_pels_ref, chroma_pels_ref,
+                                  CFL_BUF_LINE, tx_size, alpha_q3, bd);
+    for (int j = 0; j < height; j++) {
+      for (int i = 0; i < width; i++) {
+        ASSERT_EQ(chroma_pels_ref[j * CFL_BUF_LINE + i],
+                  chroma_pels[j * CFL_BUF_LINE + i]);
+      }
+    }
+  }
+}
+
+TEST_P(CFLPredictHBDTest, DISABLED_PredictHBDSpeedTest) {
+  const int width = Width();
+  const int height = Height();
+  const TX_SIZE tx_size = Tx_size();
+
+  aom_usec_timer ref_timer;
+  aom_usec_timer timer;
+
+  init(width, height);
+  cfl_predict_hbd_fn predict_impl = get_predict_hbd_fn_c(tx_size);
+  aom_usec_timer_start(&ref_timer);
+
+  for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
+    predict_impl(sub_luma_pels_ref, chroma_pels_ref, CFL_BUF_LINE, tx_size,
+                 alpha_q3, bd);
+  }
+  aom_usec_timer_mark(&ref_timer);
+  int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer);
+
+  predict_impl = predict(tx_size);
+  aom_usec_timer_start(&timer);
+  for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
+    predict_impl(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size, alpha_q3,
+                 bd);
+  }
+  aom_usec_timer_mark(&timer);
+  int elapsed_time = (int)aom_usec_timer_elapsed(&timer);
+
+  printSpeed(ref_elapsed_time, elapsed_time, width, height);
+  assertFaster(ref_elapsed_time, elapsed_time);
+}
+
 #if HAVE_SSE2
 const subtract_param subtract_sizes_sse2[] = { ALL_CFL_SIZES(
     av1_cfl_subtract_sse2) };
@@ -312,11 +402,17 @@
 const predict_param predict_sizes_ssse3[] = { ALL_CFL_TX_SIZES(
     get_predict_lbd_fn_ssse3) };
 
+const predict_param_hbd predict_sizes_hbd_ssse3[] = { ALL_CFL_TX_SIZES(
+    get_predict_hbd_fn_ssse3) };
+
 INSTANTIATE_TEST_CASE_P(SSSE3, CFLSubsampleTest,
                         ::testing::ValuesIn(subsample_sizes_ssse3));
 
 INSTANTIATE_TEST_CASE_P(SSSE3, CFLPredictTest,
                         ::testing::ValuesIn(predict_sizes_ssse3));
+
+INSTANTIATE_TEST_CASE_P(SSSE3, CFLPredictHBDTest,
+                        ::testing::ValuesIn(predict_sizes_hbd_ssse3));
 #endif
 
 #if HAVE_AVX2
@@ -329,6 +425,9 @@
 const predict_param predict_sizes_avx2[] = { ALL_CFL_TX_SIZES(
     get_predict_lbd_fn_avx2) };
 
+const predict_param_hbd predict_sizes_hbd_avx2[] = { ALL_CFL_TX_SIZES(
+    get_predict_hbd_fn_avx2) };
+
 INSTANTIATE_TEST_CASE_P(AVX2, CFLSubtractTest,
                         ::testing::ValuesIn(subtract_sizes_avx2));
 
@@ -337,5 +436,8 @@
 
 INSTANTIATE_TEST_CASE_P(AVX2, CFLPredictTest,
                         ::testing::ValuesIn(predict_sizes_avx2));
+
+INSTANTIATE_TEST_CASE_P(AVX2, CFLPredictHBDTest,
+                        ::testing::ValuesIn(predict_sizes_hbd_avx2));
 #endif
 }  // namespace