/*
* Copyright (c) 2020, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <vector>
#include "av1/common/cnn_tflite.h"
#include "av1/common/onyxc_int.h"
#include "av1/tflite_models/op_registrations.h"
#include "av1/tflite_models/intra_frame_model/qp22.h"
#include "av1/tflite_models/intra_frame_model/qp32.h"
#include "av1/tflite_models/intra_frame_model/qp43.h"
#include "av1/tflite_models/intra_frame_model/qp53.h"
#include "av1/tflite_models/intra_frame_model/qp63.h"
#include "av1/tflite_models/inter_frame_model/qp68_107.h"
#include "av1/tflite_models/inter_frame_model/qp108_147.h"
#include "av1/tflite_models/inter_frame_model/qp148_191.h"
#include "av1/tflite_models/inter_frame_model/qp192_231.h"
#include "av1/tflite_models/inter_frame_model/qp232_255.h"
#if CONFIG_NN_RECON
#include "av1/tflite_models/intra_txfm_recon_model/tx16x16.h"
#endif // CONFIG_NN_RECON
#include "common/tf_lite_includes.h"
#if CONFIG_CNN_RESTORATION || CONFIG_LOOP_RESTORE_CNN
// Returns the intra-frame TF-lite model based on the qindex.
static const unsigned char *get_intra_model_from_qindex(int qindex) {
if (qindex <= MIN_CNN_Q_INDEX) {
assert(0);
return nullptr;
} else if (qindex < 108) {
return qp22_model_tflite_data;
} else if (qindex < 148) {
return qp32_model_tflite_data;
} else if (qindex < 192) {
return qp43_model_tflite_data;
} else if (qindex < 232) {
return qp53_model_tflite_data;
} else {
return qp63_model_tflite_data;
}
}
// Returns the inter-frame TF-lite model based on the qindex.
static const unsigned char *get_inter_model_from_qindex(int qindex) {
if (qindex <= MIN_CNN_Q_INDEX) {
assert(0);
return nullptr;
} else if (qindex < 108) {
return qp68_107_inter_model_tflite_data;
} else if (qindex < 148) {
return qp108_147_inter_model_tflite_data;
} else if (qindex < 192) {
return qp148_191_inter_model_tflite_data;
} else if (qindex < 232) {
return qp192_231_inter_model_tflite_data;
} else {
return qp232_255_inter_model_tflite_data;
}
}
// Builds and returns the TFlite interpreter.
static std::unique_ptr<tflite::Interpreter> get_tflite_interpreter(
int qindex, int width, int height, int num_threads, int is_intra_only) {
const unsigned char *const model_tflite_data =
is_intra_only ? get_intra_model_from_qindex(qindex)
: get_inter_model_from_qindex(qindex);
auto model = tflite::GetModel(model_tflite_data);
tflite::MutableOpResolver resolver;
RegisterSelectedOpsAllQps(&resolver);
tflite::InterpreterBuilder builder(model, resolver);
// TODO(urvang): Investigate if caching the interpreter object provides
// further speed-up. May still have to re-build the interpreter if qindex
// changes.
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (builder(&interpreter) != kTfLiteOk || interpreter == nullptr) {
    return nullptr;  // Interpreter construction failed.
  }
  interpreter->SetNumThreads(AOMMAX(num_threads, 1));
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
// Dimension order: batch_size, height, width, num_channels.
// Note: height comes before width here!
const std::vector<int> in_out_dims = { 1, height, width, 1 };
// We only need to resize the input tensor. All other tensors (including
// output tensor) will be resized automatically.
if (interpreter->ResizeInputTensor(interpreter->inputs()[0], in_out_dims) !=
kTfLiteOk) {
reporter->Report("Failed at input tensor resize");
return nullptr;
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
reporter->Report("Failed at tensor allocation");
return nullptr;
}
return interpreter;
}
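
// Restores the 8-bit image 'dgd' by adding a CNN-predicted residue, writing
// the result to 'rst'. Returns 1 on success, 0 on failure.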
extern "C" int av1_restore_cnn_img_tflite(int qindex, const uint8_t *dgd,
int width, int height, int dgd_stride,
uint8_t *rst, int rst_stride,
int num_threads, int is_intra_only) {
  std::unique_ptr<tflite::Interpreter> interpreter =
      get_tflite_interpreter(qindex, width, height, num_threads, is_intra_only);
  if (interpreter == nullptr) return 0;  // Interpreter creation failed.
// Prepare input.
const float max_val = 255.0f;
const int in_stride = width;
auto input = interpreter->typed_input_tensor<float>(0);
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
input[r * in_stride + c] =
static_cast<float>(dgd[r * dgd_stride + c]) / max_val;
assert(input[r * in_stride + c] >= 0.0f);
assert(input[r * in_stride + c] <= 1.0f);
}
}
// Invoke TFlite inference.
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
auto status = interpreter->Invoke();
if (status != kTfLiteOk) {
reporter->Report("Failed at interpreter invocation");
return 0;
}
// Use the output to restore 'dgd' and store in 'rst'.
const auto output = interpreter->typed_output_tensor<float>(0);
const int out_stride = width;
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
const int residue =
static_cast<int>(output[r * out_stride + c] * max_val + 0.5);
rst[r * rst_stride + c] = clip_pixel(dgd[r * dgd_stride + c] + residue);
}
}
return 1;
}
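
// High bit-depth counterpart of 'av1_restore_cnn_img_tflite' for
// 'bit_depth'-bit images. Returns 1 on success, 0 on failure.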
extern "C" int av1_restore_cnn_img_tflite_highbd(int qindex,
const uint16_t *dgd, int width,
int height, int dgd_stride,
uint16_t *rst, int rst_stride,
int num_threads, int bit_depth,
int is_intra_only) {
  std::unique_ptr<tflite::Interpreter> interpreter =
      get_tflite_interpreter(qindex, width, height, num_threads, is_intra_only);
  if (interpreter == nullptr) return 0;  // Interpreter creation failed.
// Prepare input.
const auto max_val = static_cast<float>((1 << bit_depth) - 1);
const int in_stride = width;
auto input = interpreter->typed_input_tensor<float>(0);
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
input[r * in_stride + c] =
static_cast<float>(dgd[r * dgd_stride + c]) / max_val;
assert(input[r * in_stride + c] >= 0.0f);
assert(input[r * in_stride + c] <= 1.0f);
}
}
// Invoke TFlite inference.
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
auto status = interpreter->Invoke();
if (status != kTfLiteOk) {
reporter->Report("Failed at interpreter invocation");
return 0;
}
// Use the output to restore 'dgd' and store in 'rst'.
const auto output = interpreter->typed_output_tensor<float>(0);
const int out_stride = width;
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
const int residue =
static_cast<int>(output[r * out_stride + c] * max_val + 0.5);
rst[r * rst_stride + c] =
clip_pixel_highbd(dgd[r * dgd_stride + c] + residue, bit_depth);
}
}
return 1;
}
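
// Applies CNN restoration in place to the current frame (luma plane only).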
extern "C" void av1_restore_cnn_tflite(const AV1_COMMON *cm, int num_threads) {
YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
const int plane_from = AOM_PLANE_Y;
const int plane_to = AOM_PLANE_Y;
const int is_intra_only = frame_is_intra_only(cm);
for (int plane = plane_from; plane <= plane_to; ++plane) {
if (cm->seq_params.use_highbitdepth) {
switch (plane) {
case AOM_PLANE_Y:
av1_restore_cnn_img_tflite_highbd(
cm->base_qindex, CONVERT_TO_SHORTPTR(buf->y_buffer),
buf->y_crop_width, buf->y_crop_height, buf->y_stride,
CONVERT_TO_SHORTPTR(buf->y_buffer), buf->y_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only);
break;
case AOM_PLANE_U:
av1_restore_cnn_img_tflite_highbd(
cm->base_qindex, CONVERT_TO_SHORTPTR(buf->u_buffer),
buf->uv_crop_width, buf->uv_crop_height, buf->uv_stride,
CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only);
break;
case AOM_PLANE_V:
av1_restore_cnn_img_tflite_highbd(
cm->base_qindex, CONVERT_TO_SHORTPTR(buf->v_buffer),
buf->uv_crop_width, buf->uv_crop_height, buf->uv_stride,
            CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only);
break;
default: assert(0 && "Invalid plane index");
}
} else {
assert(cm->seq_params.bit_depth == 8);
switch (plane) {
case AOM_PLANE_Y:
av1_restore_cnn_img_tflite(cm->base_qindex, buf->y_buffer,
buf->y_crop_width, buf->y_crop_height,
buf->y_stride, buf->y_buffer,
buf->y_stride, num_threads, is_intra_only);
break;
case AOM_PLANE_U:
av1_restore_cnn_img_tflite(
cm->base_qindex, buf->u_buffer, buf->uv_crop_width,
buf->uv_crop_height, buf->uv_stride, buf->u_buffer,
buf->uv_stride, num_threads, is_intra_only);
break;
case AOM_PLANE_V:
av1_restore_cnn_img_tflite(
cm->base_qindex, buf->v_buffer, buf->uv_crop_width,
buf->uv_crop_height, buf->uv_stride, buf->v_buffer,
buf->uv_stride, num_threads, is_intra_only);
break;
default: assert(0 && "Invalid plane index");
}
}
}
}
#endif // CONFIG_CNN_RESTORATION || CONFIG_LOOP_RESTORE_CNN
#if CONFIG_NN_RECON
// Builds and returns the TFlite interpreter.
static std::unique_ptr<tflite::Interpreter> get_nn_recon_tflite_interpreter(
int width, int height, int num_threads) {
auto model = tflite::GetModel(tx16x16_tflite);
tflite::MutableOpResolver resolver;
RegisterSelectedOpsAllQps(&resolver);
tflite::InterpreterBuilder builder(model, resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (builder(&interpreter) != kTfLiteOk || interpreter == nullptr) {
    return nullptr;  // Interpreter construction failed.
  }
  interpreter->SetNumThreads(AOMMAX(num_threads, 1));
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
// Dimension order: batch_size, height, width, num_channels.
// Note: height comes before width here!
const std::vector<int> in_out_dims = { 1, height, width, 1 };
// We only need to resize the input tensor. All other tensors (including
// output tensor) will be resized automatically.
if (interpreter->ResizeInputTensor(interpreter->inputs()[0], in_out_dims) !=
kTfLiteOk) {
reporter->Report("Failed at input tensor resize");
return nullptr;
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
reporter->Report("Failed at tensor allocation");
return nullptr;
}
return interpreter;
}
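
// Refines the reconstructed block 'dst' in place by adding a CNN-predicted
// residue. Returns 1 on success, 0 on failure.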
extern "C" int av1_cnn_recon_tflite(uint8_t *dst, int dst_stride, int height,
int width) {
const int num_threads = 1;
  std::unique_ptr<tflite::Interpreter> interpreter =
      get_nn_recon_tflite_interpreter(width, height, num_threads);
  if (interpreter == nullptr) return 0;  // Interpreter creation failed.
// Prepare input.
const int in_stride = width;
auto input = interpreter->typed_input_tensor<float>(0);
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
input[r * in_stride + c] = static_cast<float>(dst[r * dst_stride + c]);
}
}
// Invoke TFlite inference.
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
auto status = interpreter->Invoke();
if (status != kTfLiteOk) {
reporter->Report("Failed at interpreter invocation");
return 0;
}
  // Use the output to refine 'dst' in place.
const auto output = interpreter->typed_output_tensor<float>(0);
const int out_stride = width;
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
const int residue = static_cast<int>(output[r * out_stride + c] + 0.5);
dst[r * dst_stride + c] = clip_pixel(dst[r * dst_stride + c] + residue);
}
}
return 1;
}
#endif // CONFIG_NN_RECON