blob: 2ffdb90531c7548418cd222ac5aac008269f2059 [file] [log] [blame] [edit]
/*
* Copyright (c) 2024, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 3-Clause Clear License
* and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
* License was not distributed with this source code in the LICENSE file, you
* can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
* Alliance for Open Media Patent License 1.0 was not distributed with this
* source code in the PATENTS file, you can obtain it at
* aomedia.org/license/patent-license/.
*/
#include "av1/encoder/part_split_prune_tflite.h"
#include <cstdio>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <iostream>
#include "common/tf_lite_includes.h"
#include "av1/encoder/simple_intrapred_tflite_model_128x128.h"
#include "av1/encoder/simple_intrapred_tflite_model_16x16.h"
#include "av1/encoder/simple_intrapred_tflite_model_32x32.h"
#include "av1/encoder/simple_intrapred_tflite_model_64x64.h"
#include "av1/encoder/sms_part_split_prune_tflite_model.h"
#include "av1/encoder/sms_part_none_prune_tflite_model.h"
#include "av1/encoder/sms_part_none_prune_rect_tflite_model.h"
#if HAVE_FEXCEPT
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <fenv.h>
#endif
// Owning handle for an XNNPack delegate. The deleter is
// TfLiteXNNPackDelegateDelete, so the delegate is released through the
// XNNPack API rather than with plain `delete`.
using TfLiteDelegateType =
    std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
// Per-encoder state for the partition-pruning TFLite models. It is created
// lazily by ensure_tflite_init() and released by av2_part_prune_tflite_close().
//
// Member order matters. Members are destroyed in reverse declaration order,
// and TFLite requires a delegate to stay valid for the whole lifetime of every
// interpreter it was applied to. `to_delete` is therefore declared before
// `models`: the interpreters are torn down first and their XNNPack delegates
// afterwards.
struct PartSplitContext {
  // XNNPack delegates applied to the interpreters in `models`. They are kept
  // alive here for as long as those interpreters exist.
  std::vector<TfLiteDelegateType> to_delete;
  // One lazily created interpreter per MODEL_TYPE (null until first use).
  std::unique_ptr<tflite::Interpreter> models[MODEL_COUNT];
  // Set once the "too many input features" warning has been printed for a
  // model, so it is reported only once per context.
  bool old_model[MODEL_COUNT] = {};
  // Per-model input slot mapping, filled by get_input_order().
  uint8_t input_order[MODEL_COUNT][8] = {};  // 8 inputs max
};
std::mutex tfliteMutex;
static std::unique_ptr<tflite::Interpreter> create_interpreter(
unsigned char *model_def, std::vector<TfLiteDelegateType> &to_delete) {
std::lock_guard<std::mutex> lock(tfliteMutex);
tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_ERROR);
tflite::Model *model = (tflite::Model *)tflite::GetModel(model_def);
const int num_threads = 1;
TfLiteXNNPackDelegateOptions xnnpack_options =
TfLiteXNNPackDelegateOptionsDefault();
xnnpack_options.num_threads = AOMMAX(num_threads, 1);
TfLiteDelegateType xnnpack_delegate(
TfLiteXNNPackDelegateCreate(&xnnpack_options),
&TfLiteXNNPackDelegateDelete);
tflite::MutableOpResolver resolver;
RegisterSelectedOps(&resolver);
tflite::InterpreterBuilder builder(model, resolver);
tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()) !=
kTfLiteOk) {
reporter->Report("Failed at modifying graph with XNNPack delegate");
exit(1);
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
reporter->Report("Failed at allocating tensors");
exit(1);
}
to_delete.push_back(std::move(xnnpack_delegate));
return interpreter;
}
struct ModelDef {
unsigned char *model_def;
size_t model_size;
const struct InputNorm input_norm;
MODEL_TYPE type;
const char *var_name;
const char *enum_name;
int part_type;
int n_features;
int model_version;
};
// clang-format off
// MODELDEF(data, type, part_type, n_features, model_version) expands to a
// ModelDef initializer for a model whose inputs are not normalized
// (input_norm.valid is false and its table pointers are NULL). `data` is the
// flatbuffer array, or NULL; model_size is sizeof(data). The names of `data`
// and `type` are stringized into var_name and enum_name, which appear in
// diagnostic messages.
//
// MODELDEF_NORM(...) is identical except that normalization is enabled. The
// per-feature tables come from the companion arrays data##_mean, data##_std
// and data##_std_inv.
#define MODELDEF(data, type, part_type, n_features, model_version) \
  { data, sizeof(data), { false, NULL, NULL, NULL }, \
    type, #data, #type, \
    part_type, n_features, model_version }
#define MODELDEF_NORM(data, type, part_type, n_features, model_version) \
  { data, sizeof(data), { true, data##_mean, data##_std, data##_std_inv }, \
    type, #data, #type, \
    part_type, n_features, model_version }
// clang-format on
// Model table, indexed directly by MODEL_TYPE (callers use models[model_type]).
// Entries must therefore stay in MODEL_TYPE enumerator order; each entry's
// `type` field records the enumerator it is meant to match. Entry 0 is the
// MODEL_OTHER placeholder, which has no flatbuffer.
// `model_version` selects the input layout in av2_part_prune_tflite_exec():
// 0 feeds one input tensor; 6 feeds four tensors built from FEATURE_INTER_*.
// NOTE(review): get_model_part_type() bounds-checks against MODEL_COUNT, which
// assumes this table has exactly MODEL_COUNT entries; consider a static_assert.
const ModelDef models[] = {
  MODELDEF(NULL, MODEL_OTHER, PT_INVAL, 0, 0),
  MODELDEF(a3_qp96_128_160_luma_BLOCK_128X128_intra_tflite, MODEL_128X128,
           PT_SPLIT, 37, 0),
  MODELDEF(a3_qp96_128_160_luma_BLOCK_64X64_intra_tflite, MODEL_64X64, PT_SPLIT,
           37, 0),
  MODELDEF(a3_qp96_128_160_luma_BLOCK_32X32_intra_tflite, MODEL_32X32, PT_SPLIT,
           37, 0),
  MODELDEF(a3_qp96_128_160_luma_BLOCK_16X16_intra_tflite, MODEL_16X16, PT_SPLIT,
           37, 0),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs12_110,
                MODEL_INTER_NONE_64X64_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs12_135,
                MODEL_INTER_NONE_64X64_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs9_110,
                MODEL_INTER_NONE_32X32_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs9_135,
                MODEL_INTER_NONE_32X32_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs6_110,
                MODEL_INTER_NONE_16X16_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs6_135,
                MODEL_INTER_NONE_16X16_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs11_110,
                MODEL_INTER_NONE_BS11_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs11_135,
                MODEL_INTER_NONE_BS11_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs10_110,
                MODEL_INTER_NONE_BS10_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs10_135,
                MODEL_INTER_NONE_BS10_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs8_110,
                MODEL_INTER_NONE_BS8_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs8_135,
                MODEL_INTER_NONE_BS8_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs7_110,
                MODEL_INTER_NONE_BS7_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs7_135,
                MODEL_INTER_NONE_BS7_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs5_110,
                MODEL_INTER_NONE_BS5_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs5_135,
                MODEL_INTER_NONE_BS5_135, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs4_110,
                MODEL_INTER_NONE_BS4_110, PT_NONE, 66, 6),
  MODELDEF_NORM(sms_part_none_prune_tflite_model_bs4_135,
                MODEL_INTER_NONE_BS4_135, PT_NONE, 66, 6),
  MODELDEF(sms_part_split_prune_tflite_model_bs12, MODEL_INTER_SPLIT_64X64,
           PT_SPLIT, 31, 0),
  MODELDEF(sms_part_split_prune_tflite_model_bs9, MODEL_INTER_SPLIT_32X32,
           PT_SPLIT, 31, 0),
  MODELDEF(sms_part_split_prune_tflite_model_bs6, MODEL_INTER_SPLIT_16X16,
           PT_SPLIT, 31, 0),
  MODELDEF(sms_part_split_prune_tflite_model_bs3, MODEL_INTER_SPLIT_8X8,
           PT_SPLIT, 31, 0),
};
// Maps each model input to a logical slot by its tensor name (see `order_lut`;
// both "serving_default_input:0" and "serving_default_input_1:0" map to slot
// 0). It then records in `order[slot]` the input's POSITION in
// interpreter->inputs().
//
// The position, not the tensor index inputs()[pos], is what
// tflite::Interpreter::typed_input_tensor() expects, so callers can fetch slot
// k with typed_input_tensor<float>(order[k]). A tensor index would only work
// when input tensor indices happen to equal their positions.
//
// `order` must have room for 8 entries; slots the model does not name stay 0.
// An unnamed input, or a name missing from `order_lut`, is fatal: the name is
// printed and the process exits with status -1.
static void get_input_order(tflite::Interpreter *interpreter, uint8_t *order) {
  static const std::unordered_map<std::string, uint8_t> order_lut = {
    { "serving_default_input:0", 0 },   { "serving_default_input_1:0", 0 },
    { "serving_default_input_2:0", 1 }, { "serving_default_input_3:0", 2 },
    { "serving_default_input_4:0", 3 },
  };
  memset(order, 0, 8 * sizeof(order[0]));
  const auto &inputs = interpreter->inputs();
  for (size_t pos = 0; pos < inputs.size(); ++pos) {
    const TfLiteTensor *input_tensor = interpreter->tensor(inputs[pos]);
    const char *name = input_tensor != nullptr ? input_tensor->name : nullptr;
    // Building a std::string from a null pointer is undefined behaviour, so an
    // unnamed tensor is rejected before the table lookup.
    const auto slot = name != nullptr ? order_lut.find(name) : order_lut.end();
    if (slot == order_lut.end()) {
      std::cout << "Model with unsupported input name: "
                << (name != nullptr ? name : "(null)") << std::endl;
      exit(-1);
    }
    order[slot->second] = static_cast<uint8_t>(pos);
  }
}
static void ensure_tflite_init(void **context, MODEL_TYPE model_type) {
assert(model_type != MODEL_OTHER);
if (*context == nullptr) *context = new PartSplitContext();
PartSplitContext *ctx = (PartSplitContext *)*context;
ModelDef def = models[model_type];
if (!ctx->models[model_type]) {
if (def.model_def != NULL) {
ctx->models[model_type] =
create_interpreter(def.model_def, ctx->to_delete);
get_input_order(ctx->models[model_type].get(),
&ctx->input_order[model_type][0]);
} else {
printf("\x1b[91mUsing undefined model: %s(%d)\x1b[0m\n",
models[model_type].enum_name, model_type);
}
}
}
// Copies the input-normalization parameters of `model_type` from the model
// table into `*norm`. Always returns 0.
extern "C" int av2_model_input_norm(MODEL_TYPE model_type,
                                    struct InputNorm *norm) {
  assert(model_type != MODEL_OTHER);
  const struct InputNorm &table_norm = models[model_type].input_norm;
  *norm = table_norm;
  return 0;
}
// Bracket the TFLite Invoke() call. In debug builds with fenv support:
// - FLOATING_POINT_DISABLE_EXCEPTIONS disables trapping on FE_UNDERFLOW and
//   FE_OVERFLOW. It stores the value returned by fedisableexcept() (the
//   previously enabled set) in a local `float_excepts`.
// - FLOATING_POINT_RESTORE_EXCEPTIONS re-enables that saved set.
// The first macro declares a local variable, so each macro may appear at most
// once per scope, and both must appear in the same scope. In other builds both
// expand to nothing.
// NOTE(review): fedisableexcept() returns -1 on failure, and that value would
// then be passed to feenableexcept(). Confirm this cannot happen on supported
// targets.
#if HAVE_FEXCEPT && CONFIG_DEBUG
#define FLOATING_POINT_DISABLE_EXCEPTIONS \
  const int float_excepts = fedisableexcept(FE_UNDERFLOW | FE_OVERFLOW);
#define FLOATING_POINT_RESTORE_EXCEPTIONS feenableexcept(float_excepts);
#else
#define FLOATING_POINT_DISABLE_EXCEPTIONS
#define FLOATING_POINT_RESTORE_EXCEPTIONS
#endif  // HAVE_FEXCEPT && CONFIG_DEBUG
// TFLite-based inference for the partition-pruning models (intra and inter).
// Returns input[feature]. When `norm.valid` is set, the value is first
// standardized with the table's per-feature mean and inverse standard
// deviation; otherwise it is returned unchanged.
static inline float norm(const float *input, int feature,
                         struct InputNorm norm) {
  const float raw = input[feature];
  if (!norm.valid) return raw;
  return (float)((raw - norm.mean[feature]) * norm.invstd[feature]);
}
extern "C" int av2_part_prune_tflite_exec(void **context, const float *ml_input,
float *ml_output,
MODEL_TYPE model_type) {
assert(model_type != MODEL_OTHER);
ensure_tflite_init(context, model_type);
PartSplitContext *ctx = (PartSplitContext *)*context;
tflite::Interpreter *interpreter = ctx->models[model_type].get();
tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());
int model_version = models[model_type].model_version;
struct InputNorm input_norm = models[model_type].input_norm;
int input_len = models[model_type].n_features;
int output_len = 1;
if (model_version == 0) {
const TfLiteTensor *input_tensor = interpreter->input_tensor(0);
int num_input_features = input_tensor->dims->data[1];
if (num_input_features > input_len) {
printf("\x1b[91mERROR:\x1b[0m Not enough input features: %d>%d\n",
num_input_features, input_len);
exit(1);
}
if (num_input_features != input_len && !ctx->old_model[model_type]) {
printf(
"\x1b[95mWARN:\x1b[0m Too many input features for model %s: %d<%d"
" (is it an old model?)\n",
models[model_type].var_name, num_input_features, input_len);
ctx->old_model[model_type] = true;
}
float *input = interpreter->typed_input_tensor<float>(0);
if (input_norm.valid) {
for (int i = 0; i < num_input_features; i++) {
input[i] =
(float)((ml_input[i] - input_norm.mean[i]) * input_norm.invstd[i]);
}
} else {
for (int i = 0; i < num_input_features; i++) {
input[i] = ml_input[i];
}
}
} else if (model_version == 6) {
uint8_t *input_order = &ctx->input_order[model_type][0];
float *input0 = interpreter->typed_input_tensor<float>(input_order[0]);
float *input1 = interpreter->typed_input_tensor<float>(input_order[1]);
float *input2 = interpreter->typed_input_tensor<float>(input_order[2]);
float *input3 = interpreter->typed_input_tensor<float>(input_order[3]);
int kV6Inp0[] = { FEATURE_INTER_RD_MULT, FEATURE_INTER_SWITCH,
FEATURE_INTER_PART_T };
int kV6Inp1[] = {
FEATURE_INTER_FULL_PSNR, FEATURE_INTER_SQ_0_PSNR,
FEATURE_INTER_SQ_1_PSNR, FEATURE_INTER_SQ_2_PSNR,
FEATURE_INTER_SQ_3_PSNR, FEATURE_INTER_FULL_Q_COEFF_MAX,
FEATURE_INTER_SQ_0_Q_COEFF_MAX, FEATURE_INTER_SQ_1_Q_COEFF_MAX,
FEATURE_INTER_SQ_2_Q_COEFF_MAX, FEATURE_INTER_SQ_3_Q_COEFF_MAX,
FEATURE_INTER_FULL_Q_COEFF_NONZ, FEATURE_INTER_SQ_0_Q_COEFF_NONZ,
FEATURE_INTER_SQ_1_Q_COEFF_NONZ, FEATURE_INTER_SQ_2_Q_COEFF_NONZ,
FEATURE_INTER_SQ_3_Q_COEFF_NONZ, FEATURE_INTER_FULL_LOG_SATDQ,
FEATURE_INTER_SQ_0_LOG_SATDQ, FEATURE_INTER_SQ_1_LOG_SATDQ,
FEATURE_INTER_SQ_2_LOG_SATDQ, FEATURE_INTER_SQ_3_LOG_SATDQ
};
int kV6Inp2[] = {
FEATURE_INTER_FULL_PSNR, FEATURE_INTER_HOR_0_PSNR,
FEATURE_INTER_HOR_1_PSNR, FEATURE_INTER_FULL_Q_COEFF_MAX,
FEATURE_INTER_HOR_0_Q_COEFF_MAX, FEATURE_INTER_HOR_1_Q_COEFF_MAX,
FEATURE_INTER_FULL_Q_COEFF_NONZ, FEATURE_INTER_HOR_0_Q_COEFF_NONZ,
FEATURE_INTER_HOR_1_Q_COEFF_NONZ, FEATURE_INTER_FULL_LOG_SATDQ,
FEATURE_INTER_HOR_0_LOG_SATDQ, FEATURE_INTER_HOR_1_LOG_SATDQ
};
int kV6Inp3[] = {
FEATURE_INTER_FULL_PSNR, FEATURE_INTER_VER_0_PSNR,
FEATURE_INTER_VER_1_PSNR, FEATURE_INTER_FULL_Q_COEFF_MAX,
FEATURE_INTER_VER_0_Q_COEFF_MAX, FEATURE_INTER_VER_1_Q_COEFF_MAX,
FEATURE_INTER_FULL_Q_COEFF_NONZ, FEATURE_INTER_VER_0_Q_COEFF_NONZ,
FEATURE_INTER_VER_1_Q_COEFF_NONZ, FEATURE_INTER_FULL_LOG_SATDQ,
FEATURE_INTER_VER_0_LOG_SATDQ, FEATURE_INTER_VER_1_LOG_SATDQ
};
for (size_t i = 0; i < sizeof(kV6Inp0) / sizeof(kV6Inp0[0]); i++)
input0[i] = norm(ml_input, kV6Inp0[i], input_norm);
for (size_t i = 0; i < sizeof(kV6Inp1) / sizeof(kV6Inp1[0]); i++)
input1[i] = norm(ml_input, kV6Inp1[i], input_norm);
for (size_t i = 0; i < sizeof(kV6Inp2) / sizeof(kV6Inp2[0]); i++)
input2[i] = norm(ml_input, kV6Inp2[i], input_norm);
for (size_t i = 0; i < sizeof(kV6Inp3) / sizeof(kV6Inp3[0]); i++)
input3[i] = norm(ml_input, kV6Inp3[i], input_norm);
}
FLOATING_POINT_DISABLE_EXCEPTIONS
auto status = interpreter->Invoke();
FLOATING_POINT_RESTORE_EXCEPTIONS
if (status != kTfLiteOk) {
reporter->Report("Failed at invoke");
exit(1);
}
float *output = interpreter->typed_output_tensor<float>(0);
for (int i = 0; i < output_len; i++) {
ml_output[i] = output[i];
}
return 0;
}
extern "C" void av2_part_prune_tflite_close(void **context) {
PartSplitContext *ctx = (PartSplitContext *)*context;
if (ctx != nullptr) delete ctx;
*context = nullptr;
}
// Returns the partition type (PT_*) recorded in the model table for `type`,
// or PT_INVAL when `type` lies outside [0, MODEL_COUNT).
extern "C" int get_model_part_type(MODEL_TYPE type) {
  // C callers can pass any int; reject negative values as well as values that
  // are too large, so models[] is never indexed out of bounds.
  const int index = static_cast<int>(type);
  if (index < 0 || index >= MODEL_COUNT) {
    return PT_INVAL;
  }
  return models[index].part_type;
}