Add high bd support

Change-Id: I9174fb53bf6443250a917bda5bd2e8d03286fa56
diff --git a/av1/encoder/allintra_vis.c b/av1/encoder/allintra_vis.c
index 6a5ed70..c66353b 100644
--- a/av1/encoder/allintra_vis.c
+++ b/av1/encoder/allintra_vis.c
@@ -628,8 +628,8 @@
 
 #if CONFIG_TFLITE
 static int model_predict(BLOCK_SIZE block_size, int num_cols, int num_rows,
-                         uint8_t *y_buffer, int y_stride, float *predicts0,
-                         float *predicts1) {
+                         int bit_depth, uint8_t *y_buffer, int y_stride,
+                         float *predicts0, float *predicts1) {
   // Create the model and interpreter options.
   TfLiteModel *model =
       TfLiteModelCreate(av1_deltaq4_model_file, av1_deltaq4_model_fsize);
@@ -674,9 +674,12 @@
 
       uint8_t *buf = y_buffer + row_offset * y_stride + col_offset;
       int r = row_offset, pos = 0;
+      const float base = (float)((1 << bit_depth) - 1);
       while (r < row_offset + (num_mi_h << 2)) {
         for (int c = 0; c < (num_mi_w << 2); ++c) {
-          input_data[pos++] = (float)*(buf + c) / 255.0f;
+          input_data[pos++] = bit_depth > 8
+                                  ? (float)*CONVERT_TO_SHORTPTR(buf + c) / base
+                                  : (float)*(buf + c) / base;
         }
         buf += y_stride;
         ++r;
@@ -724,15 +727,12 @@
   uint8_t *y_buffer = cpi->source->y_buffer;
   const int y_stride = cpi->source->y_stride;
   const int block_size = cpi->common.seq_params->sb_size;
+  const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;
 
   const int num_mi_w = mi_size_wide[block_size];
   const int num_mi_h = mi_size_high[block_size];
   const int num_cols = (mi_params->mi_cols + num_mi_w - 1) / num_mi_w;
   const int num_rows = (mi_params->mi_rows + num_mi_h - 1) / num_mi_h;
-  const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
-
-  // TODO(sdeng): add highbitdepth support.
-  (void)use_hbd;
 
   // TODO(sdeng): fit a better model_1; disable it at this time.
   float *mb_delta_q0, *mb_delta_q1, delta_q_avg0 = 0.0f;
@@ -741,8 +741,8 @@
   CHECK_MEM_ERROR(cm, mb_delta_q1,
                   aom_calloc(num_rows * num_cols, sizeof(float)));
 
-  if (model_predict(block_size, num_cols, num_rows, y_buffer, y_stride,
-                    mb_delta_q0, mb_delta_q1)) {
+  if (model_predict(block_size, num_cols, num_rows, bit_depth, y_buffer,
+                    y_stride, mb_delta_q0, mb_delta_q1)) {
     aom_internal_error(cm->error, AOM_CODEC_ERROR,
                        "Failed to call TFlite functions.");
   }