Add high bd support
Change-Id: I9174fb53bf6443250a917bda5bd2e8d03286fa56
diff --git a/av1/encoder/allintra_vis.c b/av1/encoder/allintra_vis.c
index 6a5ed70..c66353b 100644
--- a/av1/encoder/allintra_vis.c
+++ b/av1/encoder/allintra_vis.c
@@ -628,8 +628,8 @@
#if CONFIG_TFLITE
static int model_predict(BLOCK_SIZE block_size, int num_cols, int num_rows,
- uint8_t *y_buffer, int y_stride, float *predicts0,
- float *predicts1) {
+ int bit_depth, uint8_t *y_buffer, int y_stride,
+ float *predicts0, float *predicts1) {
// Create the model and interpreter options.
TfLiteModel *model =
TfLiteModelCreate(av1_deltaq4_model_file, av1_deltaq4_model_fsize);
@@ -674,9 +674,12 @@
uint8_t *buf = y_buffer + row_offset * y_stride + col_offset;
int r = row_offset, pos = 0;
+ const float base = (float)((1 << bit_depth) - 1);
while (r < row_offset + (num_mi_h << 2)) {
for (int c = 0; c < (num_mi_w << 2); ++c) {
- input_data[pos++] = (float)*(buf + c) / 255.0f;
+ input_data[pos++] = bit_depth > 8
+ ? (float)*CONVERT_TO_SHORTPTR(buf + c) / base
+ : (float)*(buf + c) / base;
}
buf += y_stride;
++r;
@@ -724,15 +727,12 @@
uint8_t *y_buffer = cpi->source->y_buffer;
const int y_stride = cpi->source->y_stride;
const int block_size = cpi->common.seq_params->sb_size;
+ const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;
const int num_mi_w = mi_size_wide[block_size];
const int num_mi_h = mi_size_high[block_size];
const int num_cols = (mi_params->mi_cols + num_mi_w - 1) / num_mi_w;
const int num_rows = (mi_params->mi_rows + num_mi_h - 1) / num_mi_h;
- const int use_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
-
- // TODO(sdeng): add highbitdepth support.
- (void)use_hbd;
// TODO(sdeng): fit a better model_1; disable it at this time.
float *mb_delta_q0, *mb_delta_q1, delta_q_avg0 = 0.0f;
@@ -741,8 +741,8 @@
CHECK_MEM_ERROR(cm, mb_delta_q1,
aom_calloc(num_rows * num_cols, sizeof(float)));
- if (model_predict(block_size, num_cols, num_rows, y_buffer, y_stride,
- mb_delta_q0, mb_delta_q1)) {
+ if (model_predict(block_size, num_cols, num_rows, bit_depth, y_buffer,
+ y_stride, mb_delta_q0, mb_delta_q1)) {
aom_internal_error(cm->error, AOM_CODEC_ERROR,
"Failed to call TFlite functions.");
}