// blob: 233a184b7d305ed35dbd33cb61c2aa32bf9bc459 [file] [log] [blame] [edit]
#include "av1/encoder/intra_dip_mode_prune_tflite.h"
#include <cstdio>
#include <memory>
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include <mutex>
#include "common/tf_lite_includes.h"
#if HAVE_FEXCEPT
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <fenv.h>
#endif
// State kept behind the opaque `void *context` handle used by the
// intra_dip_mode_prune_* entry points in this file.
struct DipContext {
  // Wrapper around the serialized model the interpreter was built from.
  // Declared BEFORE `interpreter` on purpose: members are destroyed in reverse
  // declaration order, so the interpreter is torn down while the model is
  // still alive. TFLite requires a FlatBufferModel to outlive every
  // interpreter created from it.
  std::unique_ptr<tflite::FlatBufferModel> model;
  // Interpreter for the currently loaded model; null until the first call to
  // create_interpreter().
  std::unique_ptr<tflite::Interpreter> interpreter;
  // Input descriptors and values; run_inference() copies these into the
  // interpreter's input tensors.
  dip_pruning_inputs dip_pruning_in;
  // Index of the model the interpreter was built for; -1 until one is loaded.
  int model_index = -1;
};
// Held by create_interpreter() for the whole duration of model loading and
// interpreter construction.
// NOTE(review): presumably this serializes TFLite setup that is not
// thread-safe (e.g. the global log-severity setting) — confirm.
std::mutex dip_prune_mutex;
// Builds a TFLite interpreter from the serialized model in `model_def`
// (`model_len` bytes), allocates its tensors, and stores both the interpreter
// and the FlatBufferModel wrapper in `context`.
//
// The whole function runs under dip_prune_mutex. Any failure is reported on
// stderr and terminates the process with exit(1).
static void create_interpreter(DipContext *context,
                               const unsigned char *model_def, int model_len) {
  std::lock_guard<std::mutex> lock(dip_prune_mutex);
  tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_ERROR);

  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(
          reinterpret_cast<const char *>(model_def), model_len);
  // BuildFromBuffer returns nullptr on failure; dereferencing it for the
  // InterpreterBuilder below would be undefined behaviour.
  if (!model) {
    std::cerr << "Failed to load DIP model." << std::endl;
    exit(1);
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) !=
          kTfLiteOk ||
      !interpreter) {
    std::cerr << "Failed to build interpreter for DIP model." << std::endl;
    exit(1);
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    std::cerr << "Failed to allocate tensors for DIP model." << std::endl;
    exit(1);
  }

  // Assign the interpreter first: this destroys any previously held
  // interpreter before the model it was built from is replaced.
  context->interpreter = std::move(interpreter);
  context->model = std::move(model);
}
static void ensure_tflite_init(void **context, int model_index) {
if (*context == nullptr) {
*context = new DipContext();
}
DipContext *ctx = (DipContext *)*context;
if (!ctx->interpreter || model_index != ctx->model_index) {
ctx->model_index = model_index;
const uint8_t *model_bytes = NULL;
int model_len = -1;
switch (model_index) {
case 0:
model_bytes = dip_pruning_tflite_qp85;
model_len = sizeof(dip_pruning_tflite_qp85) /
sizeof(dip_pruning_tflite_qp85[0]);
break;
case 1:
model_bytes = dip_pruning_tflite_qp110;
model_len = sizeof(dip_pruning_tflite_qp110) /
sizeof(dip_pruning_tflite_qp110[0]);
break;
case 2:
model_bytes = dip_pruning_tflite_qp135;
model_len = sizeof(dip_pruning_tflite_qp135) /
sizeof(dip_pruning_tflite_qp135[0]);
break;
case 3:
model_bytes = dip_pruning_tflite_qp160;
model_len = sizeof(dip_pruning_tflite_qp160) /
sizeof(dip_pruning_tflite_qp160[0]);
break;
case 4:
model_bytes = dip_pruning_tflite_qp185;
model_len = sizeof(dip_pruning_tflite_qp185) /
sizeof(dip_pruning_tflite_qp185[0]);
break;
case 5:
model_bytes = dip_pruning_tflite_qp210;
model_len = sizeof(dip_pruning_tflite_qp210) /
sizeof(dip_pruning_tflite_qp210[0]);
break;
default:
fprintf(stderr, "Bad DIP pruning model index: %d\n", model_index);
exit(1);
}
create_interpreter(ctx, (const unsigned char *)model_bytes, model_len);
ctx->dip_pruning_in.inputs[0].name = "extra_features";
ctx->dip_pruning_in.inputs[0].size = 19;
ctx->dip_pruning_in.inputs[0].orig_index = 0;
ctx->dip_pruning_in.inputs[0].tflite_index = 3;
ctx->dip_pruning_in.inputs[1].name = "source_pixels";
ctx->dip_pruning_in.inputs[1].size = 64;
ctx->dip_pruning_in.inputs[1].orig_index = 1;
ctx->dip_pruning_in.inputs[1].tflite_index = 0;
ctx->dip_pruning_in.inputs[2].name = "dip_features";
ctx->dip_pruning_in.inputs[2].size = 11;
ctx->dip_pruning_in.inputs[2].orig_index = 2;
ctx->dip_pruning_in.inputs[2].tflite_index = 1;
ctx->dip_pruning_in.inputs[3].name = "block_size";
ctx->dip_pruning_in.inputs[3].size = 2;
ctx->dip_pruning_in.inputs[3].orig_index = 3;
ctx->dip_pruning_in.inputs[3].tflite_index = 4;
ctx->dip_pruning_in.inputs[4].name = "dip_model_rds";
ctx->dip_pruning_in.inputs[4].size = 12;
ctx->dip_pruning_in.inputs[4].orig_index = 4;
ctx->dip_pruning_in.inputs[4].tflite_index = 2;
}
}
// Helpers for bracketing a region of code in which floating-point underflow
// and overflow exceptions are turned off.
//
// When both HAVE_FEXCEPT and CONFIG_DEBUG are set:
//   FLOATING_POINT_DISABLE_EXCEPTIONS declares a local `float_excepts` holding
//   the value returned by fedisableexcept(FE_UNDERFLOW | FE_OVERFLOW), and
//   FLOATING_POINT_RESTORE_EXCEPTIONS passes that value to feenableexcept().
// Both macros must therefore be used in the same scope, DISABLE first.
// In every other configuration both macros expand to nothing.
#if HAVE_FEXCEPT && CONFIG_DEBUG
#define FLOATING_POINT_DISABLE_EXCEPTIONS \
const int float_excepts = fedisableexcept(FE_UNDERFLOW | FE_OVERFLOW);
#define FLOATING_POINT_RESTORE_EXCEPTIONS feenableexcept(float_excepts);
#else
#define FLOATING_POINT_DISABLE_EXCEPTIONS
#define FLOATING_POINT_RESTORE_EXCEPTIONS
#endif // HAVE_FEXCEPT && CONFIG_DEBUG
static std::vector<float> run_inference(void **context) {
DipContext *ctx = (DipContext *)*context;
tflite::Interpreter *interpreter = ctx->interpreter.get();
for (int i = 0; i < DIP_PRUNING_NUM_INPUTS; i++) {
int tflite_index = ctx->dip_pruning_in.inputs[i].tflite_index;
float *tensor_input = interpreter->typed_input_tensor<float>(tflite_index);
memcpy(tensor_input, ctx->dip_pruning_in.inputs[i].values,
ctx->dip_pruning_in.inputs[i].size * sizeof(float));
}
FLOATING_POINT_DISABLE_EXCEPTIONS
if (interpreter->Invoke() != kTfLiteOk) {
std::cerr << "Failed to run DIP pruning inference." << std::endl;
exit(1);
}
FLOATING_POINT_RESTORE_EXCEPTIONS
float *output = interpreter->typed_output_tensor<float>(0);
size_t output_size =
interpreter->tensor(interpreter->outputs()[0])->bytes / sizeof(float);
std::vector<float> output_data(output, output + output_size);
return output_data;
}
// Runs DIP pruning inference for the model selected by `qp` and writes the
// first output value to `*output`. Always returns 0.
//
// The context behind `*context` is created or switched to the right model as
// needed. The inputs must already have been filled in through the structure
// returned by intra_dip_mode_prune_get_inputs().
extern "C" int intra_dip_mode_prune_tflite(void **context, float *output,
                                           int qp) {
  const int model_index = intra_dip_mode_prune_get_model_index(qp);
  ensure_tflite_init(context, model_index);
  const std::vector<float> scores = run_inference(context);
  // Only the first element of the output tensor is handed back.
  memcpy(output, scores.data(), sizeof(float));
  return 0;
}
// Maps `qp` to the index of the nearest entry among the first six values of
// DIP_PRUNING_QPS. Ties go to the lower index.
//
// Returns -1 when the nearest entry is index 5, which signals that pruning
// should not be run. Note that this covers every QP that is closer to
// DIP_PRUNING_QPS[5] than to DIP_PRUNING_QPS[4], not only QPs at or above it.
// TODO(comc): Re-train QP=210 model.
extern "C" int intra_dip_mode_prune_get_model_index(int qp) {
  int best_index = 0;
  int best_distance = 10000;
  for (int idx = 0; idx < 6; ++idx) {
    const int distance = abs(qp - DIP_PRUNING_QPS[idx]);
    if (distance >= best_distance) continue;
    best_distance = distance;
    best_index = idx;
  }
  return (best_index == 5) ? -1 : best_index;
}
// Returns a pointer to the input structure owned by the context behind
// `*context`, creating the context or switching it to the model selected by
// `qp` first if necessary. The pointer refers to storage inside the context.
extern "C" dip_pruning_inputs *intra_dip_mode_prune_get_inputs(void **context,
                                                               int qp) {
  const int model_index = intra_dip_mode_prune_get_model_index(qp);
  ensure_tflite_init(context, model_index);
  DipContext *const state = static_cast<DipContext *>(*context);
  return &state->dip_pruning_in;
}
// Subsamples a `width` x `height` block of 16-bit samples down to 8x8 floats.
//
// The block is split into an 8x8 grid of (width / 8) x (height / 8) cells and
// the top-left sample of each cell is taken. Each sample is divided by
// (1 << bd) - 1, the largest value representable at bit depth `bd`.
//
// input:  top-left sample of the block; `stride` is the distance between
//         rows, counted in samples.
// output: receives 64 floats in row-major order.
extern "C" void intra_dip_mode_prune_normalize_and_resize_8x8(
    const uint16_t *input, size_t stride, int bd, size_t width, size_t height,
    float *output) {
  const float max_value = static_cast<float>((1 << bd) - 1);
  const size_t col_step = width / 8;
  const size_t row_step = (height / 8) * stride;

  float *dst = output;
  for (size_t row = 0; row < 8; ++row) {
    const uint16_t *const src_row = input + row * row_step;
    for (size_t col = 0; col < 8; ++col) {
      *dst++ = static_cast<float>(src_row[col * col_step]) / max_value;
    }
  }
}
extern "C" void intra_dip_mode_prune_close(void **context) {
DipContext *ctx = (DipContext *)*context;
if (ctx != nullptr) delete ctx;
*context = nullptr;
}