Add high bd support Change-Id: I9174fb53bf6443250a917bda5bd2e8d03286fa56
diff --git a/av1/encoder/allintra_vis.c b/av1/encoder/allintra_vis.c index 6a5ed70..c66353b 100644 --- a/av1/encoder/allintra_vis.c +++ b/av1/encoder/allintra_vis.c
@@ -628,8 +628,8 @@ #if CONFIG_TFLITE static int model_predict(BLOCK_SIZE block_size, int num_cols, int num_rows, - uint8_t *y_buffer, int y_stride, float *predicts0, - float *predicts1) { + int bit_depth, uint8_t *y_buffer, int y_stride, + float *predicts0, float *predicts1) { // Create the model and interpreter options. TfLiteModel *model = TfLiteModelCreate(av1_deltaq4_model_file, av1_deltaq4_model_fsize); @@ -674,9 +674,12 @@ uint8_t *buf = y_buffer + row_offset * y_stride + col_offset; int r = row_offset, pos = 0; + const float base = (float)((1 << bit_depth) - 1); while (r < row_offset + (num_mi_h << 2)) { for (int c = 0; c < (num_mi_w << 2); ++c) { - input_data[pos++] = (float)*(buf + c) / 255.0f; + input_data[pos++] = bit_depth > 8 + ? (float)*CONVERT_TO_SHORTPTR(buf + c) / base + : (float)*(buf + c) / base; } buf += y_stride; ++r; @@ -724,15 +727,12 @@ uint8_t *y_buffer = cpi->source->y_buffer; const int y_stride = cpi->source->y_stride; const int block_size = cpi->common.seq_params->sb_size; + const uint32_t bit_depth = cpi->td.mb.e_mbd.bd; const int num_mi_w = mi_size_wide[block_size]; const int num_mi_h = mi_size_high[block_size]; const int num_cols = (mi_params->mi_cols + num_mi_w - 1) / num_mi_w; const int num_rows = (mi_params->mi_rows + num_mi_h - 1) / num_mi_h; - const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH; - - // TODO(sdeng): add highbitdepth support. - (void)use_hbd; // TODO(sdeng): fit a better model_1; disable it at this time. float *mb_delta_q0, *mb_delta_q1, delta_q_avg0 = 0.0f; @@ -741,8 +741,8 @@ CHECK_MEM_ERROR(cm, mb_delta_q1, aom_calloc(num_rows * num_cols, sizeof(float))); - if (model_predict(block_size, num_cols, num_rows, y_buffer, y_stride, - mb_delta_q0, mb_delta_q1)) { + if (model_predict(block_size, num_cols, num_rows, bit_depth, y_buffer, + y_stride, mb_delta_q0, mb_delta_q1)) { aom_internal_error(cm->error, AOM_CODEC_ERROR, "Failed to call TFlite functions."); }