blob: af054f77cd69b269e09b467d0b30746c049f8491 [file] [log] [blame]
/*
* Copyright (c) 2020, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <vector>
#include "av1/common/av1_common_int.h"
#include "av1/common/cnn_tflite.h"
#include "av1/tflite_models/op_registrations.h"
#include "av1/tflite_models/intra_frame_model/uv_qp12.h"
#include "av1/tflite_models/intra_frame_model/uv_qp22.h"
#include "av1/tflite_models/intra_frame_model/uv_qp32.h"
#include "av1/tflite_models/intra_frame_model/uv_qp43.h"
#include "av1/tflite_models/intra_frame_model/uv_qp53.h"
#include "av1/tflite_models/intra_frame_model/uv_qp63.h"
#include "av1/tflite_models/intra_frame_model/qp0_90.h"
#include "av1/tflite_models/intra_frame_model/qp91_120.h"
#include "av1/tflite_models/intra_frame_model/qp121_145.h"
#include "av1/tflite_models/intra_frame_model/qp146_175.h"
#include "av1/tflite_models/intra_frame_model/qp176_205.h"
#include "av1/tflite_models/intra_frame_model/qp206_255.h"
#include "av1/tflite_models/inter_frame_model/uv_qp0_67.h"
#include "av1/tflite_models/inter_frame_model/uv_qp68_107.h"
#include "av1/tflite_models/inter_frame_model/uv_qp108_147.h"
#include "av1/tflite_models/inter_frame_model/uv_qp148_191.h"
#include "av1/tflite_models/inter_frame_model/uv_qp192_231.h"
#include "av1/tflite_models/inter_frame_model/uv_qp232_255.h"
#include "av1/tflite_models/inter_frame_model/qp0_90.h"
#include "av1/tflite_models/inter_frame_model/qp91_120.h"
#include "av1/tflite_models/inter_frame_model/qp121_145.h"
#include "av1/tflite_models/inter_frame_model/qp146_175.h"
#include "av1/tflite_models/inter_frame_model/qp176_205.h"
#include "av1/tflite_models/inter_frame_model/qp206_255.h"
#include "common/tf_lite_includes.h"
#if CONFIG_CNN_RESTORATION
// Selects the intra-frame TF-lite model for the given qindex.
//
// A non-zero 'is_luma' picks from the luma models, zero from the chroma (uv)
// models. Each plane type has six models covering consecutive qindex ranges.
// Callers are expected to pass a qindex above MIN_CNN_Q_INDEX: anything else
// trips the assert in debug builds and yields nullptr when asserts are
// compiled out.
static const unsigned char *get_intra_model_from_qindex(int qindex,
                                                        int is_luma) {
  const bool qindex_supported = qindex > MIN_CNN_Q_INDEX;
  if (!qindex_supported) {
    assert(0);
    return nullptr;
  }

  if (!is_luma) {
    // Chroma models: each comparison is the exclusive upper bound of a range.
    if (qindex < 68) return uv_qp12_model_tflite_data;
    if (qindex < 108) return uv_qp22_model_tflite_data;
    if (qindex < 148) return uv_qp32_model_tflite_data;
    if (qindex < 192) return uv_qp43_model_tflite_data;
    if (qindex < 232) return uv_qp53_model_tflite_data;
    return uv_qp63_model_tflite_data;
  }

  // Luma models: each comparison is the exclusive upper bound of a range.
  if (qindex < 91) return qp0_90_model_tflite_data;
  if (qindex < 121) return qp91_120_model_tflite_data;
  if (qindex < 146) return qp121_145_model_tflite_data;
  if (qindex < 176) return qp146_175_model_tflite_data;
  if (qindex < 206) return qp176_205_model_tflite_data;
  return qp206_255_model_tflite_data;
}
// Selects the inter-frame TF-lite model for the given qindex.
//
// A non-zero 'is_luma' picks from the luma models, zero from the chroma (uv)
// models. Each plane type has six models covering consecutive qindex ranges.
// Callers are expected to pass a qindex above MIN_CNN_Q_INDEX: anything else
// trips the assert in debug builds and yields nullptr when asserts are
// compiled out.
static const unsigned char *get_inter_model_from_qindex(int qindex,
                                                        int is_luma) {
  const bool qindex_supported = qindex > MIN_CNN_Q_INDEX;
  if (!qindex_supported) {
    assert(0);
    return nullptr;
  }

  if (!is_luma) {
    // Chroma models: each comparison is the exclusive upper bound of a range.
    if (qindex < 68) return uv_qp0_67_inter_model_tflite_data;
    if (qindex < 108) return uv_qp68_107_inter_model_tflite_data;
    if (qindex < 148) return uv_qp108_147_inter_model_tflite_data;
    if (qindex < 192) return uv_qp148_191_inter_model_tflite_data;
    if (qindex < 232) return uv_qp192_231_inter_model_tflite_data;
    return uv_qp232_255_inter_model_tflite_data;
  }

  // Luma models: each comparison is the exclusive upper bound of a range.
  if (qindex < 91) return qp0_90_inter_model_tflite_data;
  if (qindex < 121) return qp91_120_inter_model_tflite_data;
  if (qindex < 146) return qp121_145_inter_model_tflite_data;
  if (qindex < 176) return qp146_175_inter_model_tflite_data;
  if (qindex < 206) return qp176_205_inter_model_tflite_data;
  return qp206_255_inter_model_tflite_data;
}
// Creates an XNNPack delegate starting from the default options, with the
// thread count set to 'num_threads' clamped to a minimum of 1. The callers in
// this file release the returned delegate with TfLiteXNNPackDelegateDelete().
static TfLiteDelegate *get_tflite_xnnpack_delegate(int num_threads) {
  auto delegate_options = TfLiteXNNPackDelegateOptionsDefault();
  const int thread_count = AOMMAX(num_threads, 1);
  delegate_options.num_threads = thread_count;
  return TfLiteXNNPackDelegateCreate(&delegate_options);
}
// Builds and returns the TFlite interpreter.
static std::unique_ptr<tflite::Interpreter> get_tflite_interpreter(
int qindex, int width, int height, int num_threads, int is_intra_only,
int is_luma, TfLiteDelegate *xnnpack_delegate) {
const unsigned char *const model_tflite_data =
is_intra_only ? get_intra_model_from_qindex(qindex, is_luma)
: get_inter_model_from_qindex(qindex, is_luma);
auto model = tflite::GetModel(model_tflite_data);
tflite::MutableOpResolver resolver;
RegisterSelectedOpsAllQps(&resolver);
tflite::InterpreterBuilder builder(model, resolver);
// TODO(urvang): Investigate if caching the interpreter object provides
// further speed-up. May still have to re-build the interpreter if qindex
// changes.
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
interpreter->SetNumThreads(AOMMAX(num_threads, 1));
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
// Dimension order: batch_size, height, width, num_channels.
// Note: height comes before width here!
const std::vector<int> in_out_dims = { 1, height, width, 1 };
// We only need to resize the input tensor. All other tensors (including
// output tensor) will be resized automatically.
if (interpreter->ResizeInputTensor(interpreter->inputs()[0], in_out_dims) !=
kTfLiteOk) {
reporter->Report("Failed at input tensor resize");
return nullptr;
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
reporter->Report("Failed at tensor allocation");
return nullptr;
}
if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate) != kTfLiteOk) {
reporter->Report("Failed at modifying graph with XNNPack delegate");
return nullptr;
}
return interpreter;
}
extern "C" int av1_restore_cnn_img_tflite(int qindex, const uint8_t *dgd,
int width, int height, int dgd_stride,
uint8_t *rst, int rst_stride,
int num_threads, int is_intra_only,
int is_luma) {
TfLiteDelegate *xnnpack_delegate = get_tflite_xnnpack_delegate(num_threads);
std::unique_ptr<tflite::Interpreter> interpreter =
get_tflite_interpreter(qindex, width, height, num_threads, is_intra_only,
is_luma, xnnpack_delegate);
// Prepare input.
const float max_val = 255.0f;
const int in_stride = width;
auto input = interpreter->typed_input_tensor<float>(0);
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
input[r * in_stride + c] =
static_cast<float>(dgd[r * dgd_stride + c]) / max_val;
assert(input[r * in_stride + c] >= 0.0f);
assert(input[r * in_stride + c] <= 1.0f);
}
}
// Invoke TFlite inference.
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
auto status = interpreter->Invoke();
if (status != kTfLiteOk) {
reporter->Report("Failed at interpreter invocation");
return 0;
}
// Use the output to restore 'dgd' and store in 'rst'.
const auto output = interpreter->typed_output_tensor<float>(0);
const int out_stride = width;
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
const int residue =
static_cast<int>(output[r * out_stride + c] * max_val + 0.5);
rst[r * rst_stride + c] = clip_pixel(dgd[r * dgd_stride + c] + residue);
}
}
// IMPORTANT: release the interpreter before destroying the delegate.
interpreter.reset();
TfLiteXNNPackDelegateDelete(xnnpack_delegate);
return 1;
}
extern "C" int av1_restore_cnn_img_tflite_highbd(
int qindex, const uint16_t *dgd, int width, int height, int dgd_stride,
uint16_t *rst, int rst_stride, int num_threads, int bit_depth,
int is_intra_only, int is_luma) {
TfLiteDelegate *xnnpack_delegate = get_tflite_xnnpack_delegate(num_threads);
std::unique_ptr<tflite::Interpreter> interpreter =
get_tflite_interpreter(qindex, width, height, num_threads, is_intra_only,
is_luma, xnnpack_delegate);
// Prepare input.
const auto max_val = static_cast<float>((1 << bit_depth) - 1);
const int in_stride = width;
auto input = interpreter->typed_input_tensor<float>(0);
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
input[r * in_stride + c] =
static_cast<float>(dgd[r * dgd_stride + c]) / max_val;
assert(input[r * in_stride + c] >= 0.0f);
assert(input[r * in_stride + c] <= 1.0f);
}
}
// Invoke TFlite inference.
tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter();
auto status = interpreter->Invoke();
if (status != kTfLiteOk) {
reporter->Report("Failed at interpreter invocation");
return 0;
}
// Use the output to restore 'dgd' and store in 'rst'.
const auto output = interpreter->typed_output_tensor<float>(0);
const int out_stride = width;
for (int r = 0; r < height; ++r) {
for (int c = 0; c < width; ++c) {
const int residue =
static_cast<int>(output[r * out_stride + c] * max_val + 0.5);
rst[r * rst_stride + c] =
clip_pixel_highbd(dgd[r * dgd_stride + c] + residue, bit_depth);
}
}
// IMPORTANT: release the interpreter before destroying the delegate.
interpreter.reset();
TfLiteXNNPackDelegateDelete(xnnpack_delegate);
return 1;
}
extern "C" void av1_restore_cnn_tflite(const AV1_COMMON *cm, int num_threads,
const int apply_cnn[MAX_MB_PLANE]) {
YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
const int is_intra_only = frame_is_intra_only(cm);
for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
if (!apply_cnn[plane]) continue;
const int is_luma = (plane == AOM_PLANE_Y);
if (cm->seq_params.use_highbitdepth) {
switch (plane) {
case AOM_PLANE_Y:
av1_restore_cnn_img_tflite_highbd(
cm->quant_params.base_qindex, CONVERT_TO_SHORTPTR(buf->y_buffer),
buf->y_crop_width, buf->y_crop_height, buf->y_stride,
CONVERT_TO_SHORTPTR(buf->y_buffer), buf->y_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only, is_luma);
break;
case AOM_PLANE_U:
av1_restore_cnn_img_tflite_highbd(
cm->quant_params.base_qindex, CONVERT_TO_SHORTPTR(buf->u_buffer),
buf->uv_crop_width, buf->uv_crop_height, buf->uv_stride,
CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only, is_luma);
break;
case AOM_PLANE_V:
av1_restore_cnn_img_tflite_highbd(
cm->quant_params.base_qindex, CONVERT_TO_SHORTPTR(buf->v_buffer),
buf->uv_crop_width, buf->uv_crop_height, buf->uv_stride,
CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_stride, num_threads,
cm->seq_params.bit_depth, is_intra_only, is_luma);
break;
default: assert(0 && "Invalid plane index");
}
} else {
assert(cm->seq_params.bit_depth == 8);
switch (plane) {
case AOM_PLANE_Y:
av1_restore_cnn_img_tflite(
cm->quant_params.base_qindex, buf->y_buffer, buf->y_crop_width,
buf->y_crop_height, buf->y_stride, buf->y_buffer, buf->y_stride,
num_threads, is_intra_only, is_luma);
break;
case AOM_PLANE_U:
av1_restore_cnn_img_tflite(
cm->quant_params.base_qindex, buf->u_buffer, buf->uv_crop_width,
buf->uv_crop_height, buf->uv_stride, buf->u_buffer,
buf->uv_stride, num_threads, is_intra_only, is_luma);
break;
case AOM_PLANE_V:
av1_restore_cnn_img_tflite(
cm->quant_params.base_qindex, buf->v_buffer, buf->uv_crop_width,
buf->uv_crop_height, buf->uv_stride, buf->v_buffer,
buf->uv_stride, num_threads, is_intra_only, is_luma);
break;
default: assert(0 && "Invalid plane index");
}
}
}
}
#endif // CONFIG_CNN_RESTORATION