blob: 29998bd0de9b09a167ff4c10988e52fa9d8ac9fe [file] [log] [blame] [edit]
/*
* Copyright (c) 2024, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 3-Clause Clear License
* and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
* License was not distributed with this source code in the LICENSE file, you
* can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
* Alliance for Open Media Patent License 1.0 was not distributed with this
* source code in the PATENTS file, you can obtain it at
* aomedia.org/license/patent-license/.
*/
#include "av1/encoder/part_split_prune_tflite.h"
#include <cstdio>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <iostream>
#include "common/tf_lite_includes.h"
#include "av1/encoder/simple_intrapred_tflite_model_128x128.h"
#include "av1/encoder/simple_intrapred_tflite_model_16x16.h"
#include "av1/encoder/simple_intrapred_tflite_model_32x32.h"
#include "av1/encoder/simple_intrapred_tflite_model_64x64.h"
#include "av1/encoder/sms_part_split_prune_tflite_model.h"
#include "av1/encoder/sms_part_none_prune_tflite_model.h"
#include "av1/encoder/sms_part_none_prune_rect_tflite_model.h"
#if HAVE_FEXCEPT
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <fenv.h>
#endif
typedef std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
TfLiteDelegateType;
struct Context {
std::unique_ptr<tflite::Interpreter> models[MODEL_COUNT];
bool old_model[MODEL_COUNT];
uint8_t input_order[MODEL_COUNT][8]; // 8 inputs max
std::vector<TfLiteDelegateType> to_delete;
};
std::mutex tfliteMutex;
static std::unique_ptr<tflite::Interpreter> create_interpreter(
unsigned char *model_def, std::vector<TfLiteDelegateType> &to_delete) {
std::lock_guard<std::mutex> lock(tfliteMutex);
tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_ERROR);
tflite::Model *model = (tflite::Model *)tflite::GetModel(model_def);
const int num_threads = 1;
TfLiteXNNPackDelegateOptions xnnpack_options =
TfLiteXNNPackDelegateOptionsDefault();
xnnpack_options.num_threads = AOMMAX(num_threads, 1);
TfLiteDelegateType xnnpack_delegate(
TfLiteXNNPackDelegateCreate(&xnnpack_options),
&TfLiteXNNPackDelegateDelete);
tflite::MutableOpResolver resolver;
RegisterSelectedOps(&resolver);
tflite::InterpreterBuilder builder(model, resolver);
tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()) !=
kTfLiteOk) {
reporter->Report("Failed at modifying graph with XNNPack delegate");
exit(1);
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
reporter->Report("Failed at allocating tensors");
exit(1);
}
to_delete.push_back(std::move(xnnpack_delegate));
return interpreter;
}
struct ModelDef {
unsigned char *model_def;
size_t model_size;
const struct InputSpec input_spec;
MODEL_TYPE type;
const char *var_name;
const char *enum_name;
int part_type;
int n_features;
int model_version;
};
#define MODELDEF(data, type, part_type, n_features, model_version) \
{ data, sizeof(data), {false, NULL, NULL, NULL}, type, #data, #type, part_type, n_features, model_version }
#define MODELDEF_INS(data, type, part_type, n_features, model_version) \
{ data, sizeof(data), {true, data##_mean, data##_std, data##_std_inv}, type, #data, #type, part_type, n_features, model_version }
enum {
PT_INVAL = -1, PT_NONE = 0, PT_SPLIT, PT_VERT, PT_HORZ
};
const ModelDef models[] = {
MODELDEF(NULL, MODEL_OTHER, PT_INVAL, 0, 0),
MODELDEF(a3_qp96_128_160_luma_BLOCK_128X128_intra_tflite, MODEL_128X128, PT_SPLIT, 39, 0),
MODELDEF(a3_qp96_128_160_luma_BLOCK_64X64_intra_tflite, MODEL_64X64, PT_SPLIT, 39, 0),
MODELDEF(a3_qp96_128_160_luma_BLOCK_32X32_intra_tflite, MODEL_32X32, PT_SPLIT, 39, 0),
MODELDEF(a3_qp96_128_160_luma_BLOCK_16X16_intra_tflite, MODEL_16X16, PT_SPLIT, 39, 0),
MODELDEF(sms_part_split_prune_tflite_model_bs12, MODEL_INTER_SPLIT_64X64, PT_SPLIT, 66, 0),
MODELDEF(sms_part_split_prune_tflite_model_bs9, MODEL_INTER_SPLIT_32X32, PT_SPLIT, 66, 0),
MODELDEF(sms_part_split_prune_tflite_model_bs6, MODEL_INTER_SPLIT_16X16, PT_SPLIT, 66, 0),
MODELDEF(sms_part_split_prune_tflite_model_bs3, MODEL_INTER_SPLIT_8X8, PT_SPLIT, 66, 0),
/*
MODELDEF(NULL, MODEL_INTER_NONE_64X64_160, PT_NONE, 66, 0),
MODELDEF(NULL, MODEL_INTER_NONE_32X32_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs12_110, MODEL_INTER_NONE_64X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs12_135, MODEL_INTER_NONE_64X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs9_110, MODEL_INTER_NONE_32X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs9_135, MODEL_INTER_NONE_32X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs6_110, MODEL_INTER_NONE_16X16_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs6_135, MODEL_INTER_NONE_16X16_135, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_16X16_160, PT_NONE, 66, 0),
MODELDEF(NULL, MODEL_INTER_NONE_8X8_110, PT_NONE, 66, 0),
MODELDEF(NULL, MODEL_INTER_NONE_8X8_135, PT_NONE, 66, 0),
MODELDEF(NULL, MODEL_INTER_NONE_8X8_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_110, MODEL_INTER_NONE_64X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_135, MODEL_INTER_NONE_64X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_160, MODEL_INTER_NONE_64X32_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_185, MODEL_INTER_NONE_64X32_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_210, MODEL_INTER_NONE_64X32_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_235, MODEL_INTER_NONE_64X32_235, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS11_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_110, MODEL_INTER_NONE_32X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_135, MODEL_INTER_NONE_32X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_160, MODEL_INTER_NONE_32X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_185, MODEL_INTER_NONE_32X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_210, MODEL_INTER_NONE_32X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs10_235, MODEL_INTER_NONE_32X64_235, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS10_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs8_110, MODEL_INTER_NONE_BS8_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs8_135, MODEL_INTER_NONE_BS8_135, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS8_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs7_110, MODEL_INTER_NONE_BS7_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs7_135, MODEL_INTER_NONE_BS7_135, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS7_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs5_110, MODEL_INTER_NONE_BS5_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs5_135, MODEL_INTER_NONE_BS5_135, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS5_160, PT_NONE, 66, 0),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs4_110, MODEL_INTER_NONE_BS4_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_bs4_135, MODEL_INTER_NONE_BS4_135, PT_NONE, 66, 6),
MODELDEF(NULL, MODEL_INTER_NONE_BS4_160, PT_NONE, 66, 0),
*/
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_110, MODEL_INTER_NONE_8X8_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_135, MODEL_INTER_NONE_8X8_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_160, MODEL_INTER_NONE_8X8_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_185, MODEL_INTER_NONE_8X8_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_210, MODEL_INTER_NONE_8X8_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X8_235, MODEL_INTER_NONE_8X8_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_110, MODEL_INTER_NONE_8X16_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_135, MODEL_INTER_NONE_8X16_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_160, MODEL_INTER_NONE_8X16_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_185, MODEL_INTER_NONE_8X16_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_210, MODEL_INTER_NONE_8X16_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X16_235, MODEL_INTER_NONE_8X16_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_110, MODEL_INTER_NONE_16X8_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_135, MODEL_INTER_NONE_16X8_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_160, MODEL_INTER_NONE_16X8_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_185, MODEL_INTER_NONE_16X8_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_210, MODEL_INTER_NONE_16X8_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X8_235, MODEL_INTER_NONE_16X8_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_110, MODEL_INTER_NONE_16X16_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_135, MODEL_INTER_NONE_16X16_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_160, MODEL_INTER_NONE_16X16_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_185, MODEL_INTER_NONE_16X16_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_210, MODEL_INTER_NONE_16X16_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X16_235, MODEL_INTER_NONE_16X16_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_110, MODEL_INTER_NONE_16X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_135, MODEL_INTER_NONE_16X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_160, MODEL_INTER_NONE_16X32_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_185, MODEL_INTER_NONE_16X32_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_210, MODEL_INTER_NONE_16X32_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X32_235, MODEL_INTER_NONE_16X32_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_110, MODEL_INTER_NONE_32X16_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_135, MODEL_INTER_NONE_32X16_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_160, MODEL_INTER_NONE_32X16_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_185, MODEL_INTER_NONE_32X16_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_210, MODEL_INTER_NONE_32X16_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X16_235, MODEL_INTER_NONE_32X16_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_110, MODEL_INTER_NONE_32X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_135, MODEL_INTER_NONE_32X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_160, MODEL_INTER_NONE_32X32_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_185, MODEL_INTER_NONE_32X32_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_210, MODEL_INTER_NONE_32X32_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X32_235, MODEL_INTER_NONE_32X32_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_110, MODEL_INTER_NONE_32X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_135, MODEL_INTER_NONE_32X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_160, MODEL_INTER_NONE_32X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_185, MODEL_INTER_NONE_32X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_210, MODEL_INTER_NONE_32X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X64_235, MODEL_INTER_NONE_32X64_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_110, MODEL_INTER_NONE_64X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_135, MODEL_INTER_NONE_64X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_160, MODEL_INTER_NONE_64X32_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_185, MODEL_INTER_NONE_64X32_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_210, MODEL_INTER_NONE_64X32_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X32_235, MODEL_INTER_NONE_64X32_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_110, MODEL_INTER_NONE_64X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_135, MODEL_INTER_NONE_64X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_160, MODEL_INTER_NONE_64X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_185, MODEL_INTER_NONE_64X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_210, MODEL_INTER_NONE_64X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X64_235, MODEL_INTER_NONE_64X64_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_110, MODEL_INTER_NONE_64X128_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_135, MODEL_INTER_NONE_64X128_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_160, MODEL_INTER_NONE_64X128_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_185, MODEL_INTER_NONE_64X128_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_210, MODEL_INTER_NONE_64X128_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X128_235, MODEL_INTER_NONE_64X128_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_110, MODEL_INTER_NONE_128X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_135, MODEL_INTER_NONE_128X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_160, MODEL_INTER_NONE_128X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_185, MODEL_INTER_NONE_128X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_210, MODEL_INTER_NONE_128X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X64_235, MODEL_INTER_NONE_128X64_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_110, MODEL_INTER_NONE_128X128_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_135, MODEL_INTER_NONE_128X128_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_160, MODEL_INTER_NONE_128X128_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_185, MODEL_INTER_NONE_128X128_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_210, MODEL_INTER_NONE_128X128_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X128_235, MODEL_INTER_NONE_128X128_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_110, MODEL_INTER_NONE_128X256_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_135, MODEL_INTER_NONE_128X256_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_160, MODEL_INTER_NONE_128X256_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_185, MODEL_INTER_NONE_128X256_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_210, MODEL_INTER_NONE_128X256_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_128X256_235, MODEL_INTER_NONE_128X256_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_110, MODEL_INTER_NONE_256X128_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_135, MODEL_INTER_NONE_256X128_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_160, MODEL_INTER_NONE_256X128_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_185, MODEL_INTER_NONE_256X128_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_210, MODEL_INTER_NONE_256X128_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X128_235, MODEL_INTER_NONE_256X128_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_110, MODEL_INTER_NONE_256X256_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_135, MODEL_INTER_NONE_256X256_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_160, MODEL_INTER_NONE_256X256_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_185, MODEL_INTER_NONE_256X256_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_210, MODEL_INTER_NONE_256X256_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_256X256_235, MODEL_INTER_NONE_256X256_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_110, MODEL_INTER_NONE_8X32_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_135, MODEL_INTER_NONE_8X32_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_160, MODEL_INTER_NONE_8X32_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_185, MODEL_INTER_NONE_8X32_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_210, MODEL_INTER_NONE_8X32_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X32_235, MODEL_INTER_NONE_8X32_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_110, MODEL_INTER_NONE_32X8_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_135, MODEL_INTER_NONE_32X8_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_160, MODEL_INTER_NONE_32X8_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_185, MODEL_INTER_NONE_32X8_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_210, MODEL_INTER_NONE_32X8_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_32X8_235, MODEL_INTER_NONE_32X8_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_110, MODEL_INTER_NONE_16X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_135, MODEL_INTER_NONE_16X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_160, MODEL_INTER_NONE_16X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_185, MODEL_INTER_NONE_16X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_210, MODEL_INTER_NONE_16X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_16X64_235, MODEL_INTER_NONE_16X64_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_110, MODEL_INTER_NONE_64X16_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_135, MODEL_INTER_NONE_64X16_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_160, MODEL_INTER_NONE_64X16_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_185, MODEL_INTER_NONE_64X16_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_210, MODEL_INTER_NONE_64X16_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X16_235, MODEL_INTER_NONE_64X16_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_110, MODEL_INTER_NONE_8X64_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_135, MODEL_INTER_NONE_8X64_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_160, MODEL_INTER_NONE_8X64_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_185, MODEL_INTER_NONE_8X64_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_210, MODEL_INTER_NONE_8X64_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_8X64_235, MODEL_INTER_NONE_8X64_235, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_110, MODEL_INTER_NONE_64X8_110, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_135, MODEL_INTER_NONE_64X8_135, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_160, MODEL_INTER_NONE_64X8_160, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_185, MODEL_INTER_NONE_64X8_185, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_210, MODEL_INTER_NONE_64X8_210, PT_NONE, 66, 6),
MODELDEF_INS(sms_part_none_prune_tflite_model_64X8_235, MODEL_INTER_NONE_64X8_235, PT_NONE, 66, 6),
};
static void dump_model(ModelDef *def) {
char buf[128];
sprintf(buf, "model_%d_%s.tflite", def->type, def->enum_name);
printf("Storing %s\n", buf);
FILE *fp = fopen(buf, "wb");
fwrite(def->model_def, 1, def->model_size, fp);
fclose(fp);
}
static void dump_input_spec(ModelDef *def) {
if (!def->input_spec.valid)
return;
char buf[128];
sprintf(buf, "model_%d_%s_input_spec.py", def->type, def->enum_name);
printf("Storing %s\n", buf);
FILE *fp = fopen(buf, "wb");
fprintf(fp, "MEAN = [");
for (int i = 0; i < def->n_features; i++)
fprintf(fp, "%e, ", def->input_spec.mean[i]);
fprintf(fp, "]\nSTD = [");
for (int i = 0; i < def->n_features; i++)
fprintf(fp, "%e, ", def->input_spec.std[i]);
fprintf(fp, "]\n");
fclose(fp);
}
static void get_input_order(tflite::Interpreter* interpreter, uint8_t* order) {
static const std::unordered_map<std::string, uint8_t> order_lut = {
{"serving_default_input:0", 0},
{"serving_default_input_1:0", 0},
{"serving_default_input_2:0", 1},
{"serving_default_input_3:0", 2},
{"serving_default_input_4:0", 3},
};
memset(order, 0, 8 * sizeof(order[0]));
for (int i = 0; i < interpreter->inputs().size(); ++i) {
int input_index = interpreter->inputs()[i];
const TfLiteTensor* input_tensor = interpreter->tensor(input_index);
try {
auto value = order_lut.at(input_tensor->name);
order[value] = input_index;
} catch (const std::out_of_range& oor) {
std::cout << "Model with unsupported input name: "
<< input_tensor->name << std::endl;
exit(-1);
}
}
printf("input order: %d,%d,%d,%d\n", order[0], order[1], order[2], order[3]);
}
static void ensure_tflite_init(void **context, MODEL_TYPE model_type) {
assert(model_type != MODEL_OTHER);
if (*context == nullptr) *context = new Context();
Context *ctx = (Context *)*context;
ModelDef def = models[model_type];
if (!ctx->models[model_type]) {
if (def.model_def != NULL) {
ctx->models[model_type] = create_interpreter(def.model_def, ctx->to_delete);
get_input_order(ctx->models[model_type].get(), &ctx->input_order[model_type][0]);
dump_model(&def);
dump_input_spec(&def);
} else {
printf("\x1b[91mUsing undefined model: %s(%d)\x1b[0m\n",
models[model_type].enum_name, model_type);
}
}
}
extern "C" int av2_model_input_spec(MODEL_TYPE model_type,
struct InputSpec *input_spec) {
assert(model_type != MODEL_OTHER);
ModelDef def = models[model_type];
*input_spec = def.input_spec;
return 0;
}
#if HAVE_FEXCEPT && CONFIG_DEBUG
#define FLOATING_POINT_DISABLE_EXCEPTIONS \
const int float_excepts = fedisableexcept(FE_UNDERFLOW | FE_OVERFLOW);
#define FLOATING_POINT_RESTORE_EXCEPTIONS feenableexcept(float_excepts);
#else
#define FLOATING_POINT_DISABLE_EXCEPTIONS
#define FLOATING_POINT_RESTORE_EXCEPTIONS
#endif // HAVE_FEXCEPT && CONFIG_DEBUG
// Simple intra ML TFLite based inference
static inline float norm(const float* input, int feature, struct InputSpec spec) {
return spec.valid ?
(input[feature] - spec.mean[feature]) * spec.invstd[feature] :
input[feature];
}
extern "C" int av2_part_prune_tflite_exec(void **context, const float *ml_input,
float *ml_output, MODEL_TYPE model_type) {
assert(model_type != MODEL_OTHER);
ensure_tflite_init(context, model_type);
Context *ctx = (Context *)*context;
tflite::Interpreter *interpreter = ctx->models[model_type].get();
tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());
int model_version = models[model_type].model_version;
struct InputSpec input_spec = models[model_type].input_spec;
int input_len = models[model_type].n_features;
int output_len = 1;
if (model_version == 0) {
const TfLiteTensor *input_tensor = interpreter->input_tensor(0);
int num_input_features = input_tensor->dims->data[1];
if (num_input_features > input_len) {
printf("\x1b[91mERROR:\x1b[0m Not enough input features: %d>%d\n",
num_input_features, input_len);
exit(1);
}
if (num_input_features != input_len && !ctx->old_model[model_type]) {
printf("\x1b[95mWARN:\x1b[0m Too many input features for model %s: %d<%d"
" (is it an old model?)\n", get_model_name(model_type),
num_input_features, input_len);
ctx->old_model[model_type] = true;
}
float *input = interpreter->typed_input_tensor<float>(0);
if (input_spec.valid) {
for (int i = 0; i < num_input_features; i++) {
input[i] = (ml_input[i] - input_spec.mean[i]) * input_spec.invstd[i];
}
} else {
for (int i = 0; i < num_input_features; i++) {
input[i] = ml_input[i];
}
}
} else if (model_version == 6) {
uint8_t *input_order = &ctx->input_order[model_type][0];
float *input0 = interpreter->typed_input_tensor<float>(input_order[0]);
float *input1 = interpreter->typed_input_tensor<float>(input_order[1]);
float *input2 = interpreter->typed_input_tensor<float>(input_order[2]);
float *input3 = interpreter->typed_input_tensor<float>(input_order[3]);
input0[0] = norm(ml_input, FEATURE_INTER_RD_MULT, input_spec);
input0[1] = norm(ml_input, FEATURE_INTER_SWITCH, input_spec);
input0[2] = norm(ml_input, FEATURE_INTER_PART_T, input_spec);
input1[0] = norm(ml_input, FEATURE_INTER_FULL_PSNR, input_spec);
input1[1] = norm(ml_input, FEATURE_INTER_SQ_0_PSNR, input_spec);
input1[2] = norm(ml_input, FEATURE_INTER_SQ_1_PSNR, input_spec);
input1[3] = norm(ml_input, FEATURE_INTER_SQ_2_PSNR, input_spec);
input1[4] = norm(ml_input, FEATURE_INTER_SQ_3_PSNR, input_spec);
input1[5] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_MAX, input_spec);
input1[6] = norm(ml_input, FEATURE_INTER_SQ_0_Q_COEFF_MAX, input_spec);
input1[7] = norm(ml_input, FEATURE_INTER_SQ_1_Q_COEFF_MAX, input_spec);
input1[8] = norm(ml_input, FEATURE_INTER_SQ_2_Q_COEFF_MAX, input_spec);
input1[9] = norm(ml_input, FEATURE_INTER_SQ_3_Q_COEFF_MAX, input_spec);
input1[10] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_NONZ, input_spec);
input1[11] = norm(ml_input, FEATURE_INTER_SQ_0_Q_COEFF_NONZ, input_spec);
input1[12] = norm(ml_input, FEATURE_INTER_SQ_1_Q_COEFF_NONZ, input_spec);
input1[13] = norm(ml_input, FEATURE_INTER_SQ_2_Q_COEFF_NONZ, input_spec);
input1[14] = norm(ml_input, FEATURE_INTER_SQ_3_Q_COEFF_NONZ, input_spec);
input1[15] = norm(ml_input, FEATURE_INTER_FULL_LOG_SATDQ, input_spec);
input1[16] = norm(ml_input, FEATURE_INTER_SQ_0_LOG_SATDQ, input_spec);
input1[17] = norm(ml_input, FEATURE_INTER_SQ_1_LOG_SATDQ, input_spec);
input1[18] = norm(ml_input, FEATURE_INTER_SQ_2_LOG_SATDQ, input_spec);
input1[19] = norm(ml_input, FEATURE_INTER_SQ_3_LOG_SATDQ, input_spec);
input2[0] = norm(ml_input, FEATURE_INTER_FULL_PSNR, input_spec);
input2[1] = norm(ml_input, FEATURE_INTER_HOR_0_PSNR, input_spec);
input2[2] = norm(ml_input, FEATURE_INTER_HOR_1_PSNR, input_spec);
input2[3] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_MAX, input_spec);
input2[4] = norm(ml_input, FEATURE_INTER_HOR_0_Q_COEFF_MAX, input_spec);
input2[5] = norm(ml_input, FEATURE_INTER_HOR_1_Q_COEFF_MAX, input_spec);
input2[6] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_NONZ, input_spec);
input2[7] = norm(ml_input, FEATURE_INTER_HOR_0_Q_COEFF_NONZ, input_spec);
input2[8] = norm(ml_input, FEATURE_INTER_HOR_1_Q_COEFF_NONZ, input_spec);
input2[9] = norm(ml_input, FEATURE_INTER_FULL_LOG_SATDQ, input_spec);
input2[10] = norm(ml_input, FEATURE_INTER_HOR_0_LOG_SATDQ, input_spec);
input2[11] = norm(ml_input, FEATURE_INTER_HOR_1_LOG_SATDQ, input_spec);
input3[0] = norm(ml_input, FEATURE_INTER_FULL_PSNR, input_spec);
input3[1] = norm(ml_input, FEATURE_INTER_VER_0_PSNR, input_spec);
input3[2] = norm(ml_input, FEATURE_INTER_VER_1_PSNR, input_spec);
input3[3] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_MAX, input_spec);
input3[4] = norm(ml_input, FEATURE_INTER_VER_0_Q_COEFF_MAX, input_spec);
input3[5] = norm(ml_input, FEATURE_INTER_VER_1_Q_COEFF_MAX, input_spec);
input3[6] = norm(ml_input, FEATURE_INTER_FULL_Q_COEFF_NONZ, input_spec);
input3[7] = norm(ml_input, FEATURE_INTER_VER_0_Q_COEFF_NONZ, input_spec);
input3[8] = norm(ml_input, FEATURE_INTER_VER_1_Q_COEFF_NONZ, input_spec);
input3[9] = norm(ml_input, FEATURE_INTER_FULL_LOG_SATDQ, input_spec);
input3[10] = norm(ml_input, FEATURE_INTER_VER_0_LOG_SATDQ, input_spec);
input3[11] = norm(ml_input, FEATURE_INTER_VER_1_LOG_SATDQ, input_spec);
}
FLOATING_POINT_DISABLE_EXCEPTIONS
auto status = interpreter->Invoke();
FLOATING_POINT_RESTORE_EXCEPTIONS
if (status != kTfLiteOk) {
reporter->Report("Failed at invoke");
exit(1);
}
float *output = interpreter->typed_output_tensor<float>(0);
for (int i = 0; i < output_len; i++) {
ml_output[i] = output[i];
}
return 0;
}
extern "C" void av2_part_prune_tflite_close(void **context) {
Context *ctx = (Context *)*context;
if (ctx != nullptr) delete ctx;
*context = nullptr;
}
extern "C" const char *get_model_name(MODEL_TYPE type) {
if (type >= MODEL_COUNT) {
return "NA";
}
return models[type].enum_name;
}
extern "C" int get_model_part_type(MODEL_TYPE type) {
if (type >= MODEL_COUNT) {
return PT_INVAL;
}
return models[type].part_type;
}
extern "C" int get_model_n_features(MODEL_TYPE type) {
if (type >= MODEL_COUNT) {
return 0;
}
return models[type].n_features;
}