| /* |
| * Copyright (c) 2020, Alliance for Open Media. All rights reserved |
| * |
| * This source code is subject to the terms of the BSD 2 Clause License and |
| * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
| * was not distributed with this source code in the LICENSE file, you can |
| * obtain it at www.aomedia.org/license/software. If the Alliance for Open |
| * Media Patent License 1.0 was not distributed with this source code in the |
| * PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
| */ |
| |
| #include <vector> |
| |
| #include "aom_dsp/binary_codes_writer.h" |
| #include "av1/common/av1_common_int.h" |
| #include "av1/common/cnn_tflite.h" |
| #include "av1/tflite_models/op_registrations.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp0_90.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp91_120.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp121_145.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp146_175.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp176_205.h" |
| #include "av1/tflite_models/intra_frame_model/uv_qp206_255.h" |
| #include "av1/tflite_models/intra_frame_model/qp0_90.h" |
| #include "av1/tflite_models/intra_frame_model/qp91_120.h" |
| #include "av1/tflite_models/intra_frame_model/qp121_145.h" |
| #include "av1/tflite_models/intra_frame_model/qp146_175.h" |
| #include "av1/tflite_models/intra_frame_model/qp176_205.h" |
| #include "av1/tflite_models/intra_frame_model/qp206_255.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp0_90.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp91_120.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp121_145.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp146_175.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp176_205.h" |
| #include "av1/tflite_models/inter_frame_model/uv_qp206_255.h" |
| #include "av1/tflite_models/inter_frame_model/qp0_90.h" |
| #include "av1/tflite_models/inter_frame_model/qp91_120.h" |
| #include "av1/tflite_models/inter_frame_model/qp121_145.h" |
| #include "av1/tflite_models/inter_frame_model/qp146_175.h" |
| #include "av1/tflite_models/inter_frame_model/qp176_205.h" |
| #include "av1/tflite_models/inter_frame_model/qp206_255.h" |
| #if CONFIG_EXT_SUPERRES |
| #include "av1/tflite_models/inter_frame_model/sr5by4ra_1_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr5by4ra_2_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr5by4ra_3_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr3by2ra_1_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr3by2ra_2_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr3by2ra_3_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr7by4ra_1_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr7by4ra_2_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr7by4ra_3_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr2by1ra_1_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr2by1ra_2_tflite.h" |
| #include "av1/tflite_models/inter_frame_model/sr2by1ra_3_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr5by4ai_1_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr5by4ai_2_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr5by4ai_3_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr3by2ai_1_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr3by2ai_2_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr3by2ai_3_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr7by4ai_1_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr7by4ai_2_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr7by4ai_3_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr2by1ai_1_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr2by1ai_2_tflite.h" |
| #include "av1/tflite_models/intra_frame_model/sr2by1ai_3_tflite.h" |
| #endif // CONFIG_EXT_SUPERRES |
| |
| #if CONFIG_CNN_GUIDED_QUADTREE |
| #include "av1/tflite_models/inter_frame_model/qp0_90_quadtree.h" |
| #include "av1/tflite_models/inter_frame_model/qp91_120_quadtree.h" |
| #include "av1/tflite_models/inter_frame_model/qp121_145_quadtree.h" |
| #include "av1/tflite_models/inter_frame_model/qp146_175_quadtree.h" |
| #include "av1/tflite_models/inter_frame_model/qp176_205_quadtree.h" |
| #include "av1/tflite_models/inter_frame_model/qp206_255_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp0_90_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp91_120_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp121_145_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp146_175_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp176_205_quadtree.h" |
| #include "av1/tflite_models/intra_frame_model/qp206_255_quadtree.h" |
#endif // CONFIG_CNN_GUIDED_QUADTREE
| |
| #include "common/tf_lite_includes.h" |
| |
| #if CONFIG_CNN_RESTORATION |
| |
| #define USE_XNNPACK 0 |
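
// Set to 1 to run inference through the XNNPack delegate (see
// get_tflite_xnnpack_delegate() below); this assumes a TF-lite build with
// XNNPack support linked in.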
| |
// Returns the intra-frame TF-lite model based on qindex, superres
// denominator, plane type and cnn_index.
| static const unsigned char *get_intra_model_from_qindex(int qindex, |
| int superres_denom, |
| int is_luma, |
| int cnn_index) { |
| if (qindex <= MIN_CNN_Q_INDEX) { |
| assert(0); |
| return nullptr; |
| } |
| #if CONFIG_EXT_SUPERRES |
| assert(superres_denom == SCALE_NUMERATOR || superres_denom == 10 || |
| superres_denom == 12 || superres_denom == 14 || superres_denom == 16); |
| #else |
| assert(superres_denom == SCALE_NUMERATOR); |
| #endif // CONFIG_EXT_SUPERRES |
| #if CONFIG_CNN_GUIDED_QUADTREE |
  if (superres_denom == SCALE_NUMERATOR) {  // Quadtree models: no-superres only.
| if (is_luma) { |
| if (qindex <= 90) { |
| return (cnn_index == 0) ? qp0_90_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_quadtree_model_tflite_data |
| : qp121_145_quadtree_model_tflite_data; |
| } else if (qindex <= 120) { |
| return (cnn_index == 0) ? qp91_120_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp0_90_quadtree_model_tflite_data |
| : qp121_145_quadtree_model_tflite_data; |
| } else if (qindex <= 145) { |
| return (cnn_index == 0) ? qp121_145_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_quadtree_model_tflite_data |
| : qp146_175_quadtree_model_tflite_data; |
| } else if (qindex <= 175) { |
| return (cnn_index == 0) ? qp146_175_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp121_145_quadtree_model_tflite_data |
| : qp176_205_quadtree_model_tflite_data; |
| } else if (qindex <= 205) { |
| return (cnn_index == 0) ? qp176_205_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp146_175_quadtree_model_tflite_data |
| : qp206_255_quadtree_model_tflite_data; |
| } else { |
| return (cnn_index == 0) ? qp206_255_quadtree_model_tflite_data |
| : (cnn_index == 1) ? qp176_205_quadtree_model_tflite_data |
| : qp146_175_quadtree_model_tflite_data; |
| } |
| } |
| } |
| #endif // CONFIG_CNN_GUIDED_QUADTREE |
| |
| if (superres_denom == SCALE_NUMERATOR) { |
| if (is_luma) { |
| if (qindex < 91) { |
| return (cnn_index == 0) ? qp0_90_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_model_tflite_data |
| : qp121_145_model_tflite_data; |
| } else if (qindex < 121) { |
| return (cnn_index == 0) ? qp91_120_model_tflite_data |
| : (cnn_index == 1) ? qp0_90_model_tflite_data |
| : qp121_145_model_tflite_data; |
| } else if (qindex < 146) { |
| return (cnn_index == 0) ? qp121_145_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_model_tflite_data |
| : qp146_175_model_tflite_data; |
| } else if (qindex < 176) { |
| return (cnn_index == 0) ? qp146_175_model_tflite_data |
| : (cnn_index == 1) ? qp121_145_model_tflite_data |
| : qp176_205_model_tflite_data; |
| } else if (qindex < 206) { |
| return (cnn_index == 0) ? qp176_205_model_tflite_data |
| : (cnn_index == 1) ? qp146_175_model_tflite_data |
| : qp206_255_model_tflite_data; |
| } else { |
| return (cnn_index == 0) ? qp206_255_model_tflite_data |
| : (cnn_index == 1) ? qp176_205_model_tflite_data |
| : qp146_175_model_tflite_data; |
| } |
| } else { |
| assert(cnn_index == 0); |
| if (qindex < 91) { |
| return uv_qp0_90_model_tflite_data; |
| } else if (qindex < 121) { |
| return uv_qp91_120_model_tflite_data; |
| } else if (qindex < 146) { |
| return uv_qp121_145_model_tflite_data; |
| } else if (qindex < 176) { |
| return uv_qp146_175_model_tflite_data; |
| } else if (qindex < 206) { |
| return uv_qp176_205_model_tflite_data; |
| } else { |
| return uv_qp206_255_model_tflite_data; |
| } |
| } |
| } |
| #if CONFIG_EXT_SUPERRES |
| assert(is_luma); |
| #if SELECT_CNN_FOR_SUPERRES |
| switch (superres_denom) { |
| case 10: |
| return (cnn_index == 0) ? sr5by4ai_1_tflite |
| : (cnn_index == 1) ? sr5by4ai_2_tflite |
| : sr5by4ai_3_tflite; |
| case 12: |
| return (cnn_index == 0) ? sr3by2ai_1_tflite |
| : (cnn_index == 1) ? sr3by2ai_2_tflite |
| : sr3by2ai_3_tflite; |
| case 14: |
| return (cnn_index == 0) ? sr7by4ai_1_tflite |
| : (cnn_index == 1) ? sr7by4ai_2_tflite |
| : sr7by4ai_3_tflite; |
| case 16: |
| return (cnn_index == 0) ? sr2by1ai_1_tflite |
| : (cnn_index == 1) ? sr2by1ai_2_tflite |
| : sr2by1ai_3_tflite; |
| default: assert(0); return nullptr; |
| } |
| #else // SELECT_CNN_FOR_SUPERRES |
| switch (superres_denom) { |
| case 10: |
| if (qindex < 120) |
| return sr5by4ai_1_tflite; |
| else if (qindex < 180) |
| return sr5by4ai_2_tflite; |
| else |
| return sr5by4ai_3_tflite; |
| case 12: |
| if (qindex < 120) |
| return sr3by2ai_1_tflite; |
| else if (qindex < 180) |
| return sr3by2ai_2_tflite; |
| else |
| return sr3by2ai_3_tflite; |
| case 14: |
| if (qindex < 120) |
| return sr7by4ai_1_tflite; |
| else if (qindex < 180) |
| return sr7by4ai_2_tflite; |
| else |
| return sr7by4ai_3_tflite; |
| case 16: |
| if (qindex < 120) |
| return sr2by1ai_1_tflite; |
| else if (qindex < 180) |
| return sr2by1ai_2_tflite; |
| else |
| return sr2by1ai_3_tflite; |
| default: assert(0); return nullptr; |
| } |
| #endif // SELECT_CNN_FOR_SUPERRES |
| #endif // CONFIG_EXT_SUPERRES |
| return nullptr; |
| } |
| |
// Returns the inter-frame TF-lite model based on qindex, superres
// denominator, plane type and cnn_index.
| static const unsigned char *get_inter_model_from_qindex(int qindex, |
| int superres_denom, |
| int is_luma, |
| int cnn_index) { |
| if (qindex <= MIN_CNN_Q_INDEX) { |
| assert(0); |
| return nullptr; |
| } |
| #if CONFIG_EXT_SUPERRES |
| assert(superres_denom == SCALE_NUMERATOR || superres_denom == 10 || |
| superres_denom == 12 || superres_denom == 14 || superres_denom == 16); |
| #else |
| assert(superres_denom == SCALE_NUMERATOR); |
| #endif // CONFIG_EXT_SUPERRES |
| #if CONFIG_CNN_GUIDED_QUADTREE |
  if (superres_denom == SCALE_NUMERATOR) {  // Quadtree models: no-superres only.
| if (is_luma) { |
| if (qindex <= 90) { |
| return (cnn_index == 0) ? qp0_90_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_quadtree_inter_model_tflite_data |
| : qp121_145_quadtree_inter_model_tflite_data; |
| } else if (qindex <= 120) { |
| return (cnn_index == 0) ? qp91_120_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp0_90_quadtree_inter_model_tflite_data |
| : qp121_145_quadtree_inter_model_tflite_data; |
| } else if (qindex <= 145) { |
| return (cnn_index == 0) ? qp121_145_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_quadtree_inter_model_tflite_data |
| : qp146_175_quadtree_inter_model_tflite_data; |
| } else if (qindex <= 175) { |
| return (cnn_index == 0) ? qp146_175_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp121_145_quadtree_inter_model_tflite_data |
| : qp176_205_quadtree_inter_model_tflite_data; |
| } else if (qindex <= 205) { |
| return (cnn_index == 0) ? qp176_205_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp146_175_quadtree_inter_model_tflite_data |
| : qp206_255_quadtree_inter_model_tflite_data; |
| } else { |
| return (cnn_index == 0) ? qp206_255_quadtree_inter_model_tflite_data |
| : (cnn_index == 1) ? qp176_205_quadtree_inter_model_tflite_data |
| : qp146_175_quadtree_inter_model_tflite_data; |
| } |
| } |
| } |
| #endif // CONFIG_CNN_GUIDED_QUADTREE |
| |
| if (superres_denom == SCALE_NUMERATOR) { |
| if (is_luma) { |
| if (qindex < 91) { |
| return (cnn_index == 0) ? qp0_90_inter_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_inter_model_tflite_data |
| : qp121_145_inter_model_tflite_data; |
| } else if (qindex < 121) { |
| return (cnn_index == 0) ? qp91_120_inter_model_tflite_data |
| : (cnn_index == 1) ? qp0_90_inter_model_tflite_data |
| : qp121_145_inter_model_tflite_data; |
| } else if (qindex < 146) { |
| return (cnn_index == 0) ? qp121_145_inter_model_tflite_data |
| : (cnn_index == 1) ? qp91_120_inter_model_tflite_data |
| : qp146_175_inter_model_tflite_data; |
| } else if (qindex < 176) { |
| return (cnn_index == 0) ? qp146_175_inter_model_tflite_data |
| : (cnn_index == 1) ? qp121_145_inter_model_tflite_data |
| : qp176_205_inter_model_tflite_data; |
| } else if (qindex < 206) { |
| return (cnn_index == 0) ? qp176_205_inter_model_tflite_data |
| : (cnn_index == 1) ? qp146_175_inter_model_tflite_data |
| : qp206_255_inter_model_tflite_data; |
| } else { |
| return (cnn_index == 0) ? qp206_255_inter_model_tflite_data |
| : (cnn_index == 1) ? qp176_205_inter_model_tflite_data |
| : qp146_175_inter_model_tflite_data; |
| } |
| } else { |
| assert(cnn_index == 0); |
| if (qindex < 91) { |
| return uv_qp0_90_inter_model_tflite_data; |
| } else if (qindex < 121) { |
| return uv_qp91_120_inter_model_tflite_data; |
| } else if (qindex < 146) { |
| return uv_qp121_145_inter_model_tflite_data; |
| } else if (qindex < 176) { |
| return uv_qp146_175_inter_model_tflite_data; |
| } else if (qindex < 206) { |
| return uv_qp176_205_inter_model_tflite_data; |
| } else { |
| return uv_qp206_255_inter_model_tflite_data; |
| } |
| } |
| } |
| |
| #if CONFIG_EXT_SUPERRES |
| assert(is_luma); |
| switch (superres_denom) { |
| case 10: |
| if (qindex < 120) |
| return sr5by4ra_1_tflite; |
| else if (qindex < 180) |
| return sr5by4ra_2_tflite; |
| else |
| return sr5by4ra_3_tflite; |
| case 12: |
| if (qindex < 120) |
| return sr3by2ra_1_tflite; |
| else if (qindex < 180) |
| return sr3by2ra_2_tflite; |
| else |
| return sr3by2ra_3_tflite; |
| case 14: |
| if (qindex < 120) |
| return sr7by4ra_1_tflite; |
| else if (qindex < 180) |
| return sr7by4ra_2_tflite; |
| else |
| return sr7by4ra_3_tflite; |
| case 16: |
| if (qindex < 120) |
| return sr2by1ra_1_tflite; |
| else if (qindex < 180) |
| return sr2by1ra_2_tflite; |
| else |
| return sr2by1ra_3_tflite; |
| default: assert(0); return nullptr; |
| } |
| #endif // CONFIG_EXT_SUPERRES |
| return nullptr; |
| } |
| |
| #if USE_XNNPACK |
| static TfLiteDelegate *get_tflite_xnnpack_delegate(int num_threads) { |
| TfLiteXNNPackDelegateOptions xnnpack_options = |
| TfLiteXNNPackDelegateOptionsDefault(); |
| xnnpack_options.num_threads = AOMMAX(num_threads, 1); |
| return TfLiteXNNPackDelegateCreate(&xnnpack_options); |
| } |
| #endif // USE_XNNPACK |
| |
| // Builds and returns the TFlite interpreter. |
| static std::unique_ptr<tflite::Interpreter> get_tflite_interpreter( |
| int qindex, int superres_denom, int width, int height, int num_threads, |
| int is_intra_only, int is_luma, int cnn_index |
| #if USE_XNNPACK |
| , |
| TfLiteDelegate *xnnpack_delegate |
| #endif // USE_XNNPACK |
| ) { |
| const unsigned char *const model_tflite_data = |
| is_intra_only ? get_intra_model_from_qindex(qindex, superres_denom, |
| is_luma, cnn_index) |
| : get_inter_model_from_qindex(qindex, superres_denom, |
| is_luma, cnn_index); |
| auto model = tflite::GetModel(model_tflite_data); |
| tflite::MutableOpResolver resolver; |
| RegisterSelectedOpsAllQps(&resolver); |
| tflite::InterpreterBuilder builder(model, resolver); |
| // TODO(urvang): Investigate if caching the interpreter object provides |
| // further speed-up. May still have to re-build the interpreter if qindex |
| // changes. |
| std::unique_ptr<tflite::Interpreter> interpreter; |
| builder(&interpreter); |
| interpreter->SetNumThreads(AOMMAX(num_threads, 1)); |
| tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter(); |
| |
| // Dimension order: batch_size, height, width, num_channels. |
| // Note: height comes before width here! |
| const std::vector<int> in_out_dims = { 1, height, width, 1 }; |
| // We only need to resize the input tensor. All other tensors (including |
| // output tensor) will be resized automatically. |
| if (interpreter->ResizeInputTensor(interpreter->inputs()[0], in_out_dims) != |
| kTfLiteOk) { |
| reporter->Report("Failed at input tensor resize"); |
| return nullptr; |
| } |
| |
| if (interpreter->AllocateTensors() != kTfLiteOk) { |
| reporter->Report("Failed at tensor allocation"); |
| return nullptr; |
| } |
| |
| #if USE_XNNPACK |
| if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate) != kTfLiteOk) { |
| reporter->Report("Failed at modifying graph with XNNPack delegate"); |
| return nullptr; |
| } |
| #endif // USE_XNNPACK |
| |
| return interpreter; |
| } |
| |
| extern "C" int av1_restore_cnn_img_tflite_highbd( |
| int qindex, int superres_denom, const uint16_t *dgd, int width, int height, |
| int dgd_stride, uint16_t *rst, int rst_stride, int num_threads, |
| int bit_depth, int is_intra_only, int is_luma, int cnn_index) { |
  // Pad width and height to multiples of 8, so the network can downscale the
  // input by a factor of 8 on each axis.
  const int padding_width = (int)ceil(width / 8.0) * 8;
  const int padding_height = (int)ceil(height / 8.0) * 8;
| #if USE_XNNPACK |
| TfLiteDelegate *xnnpack_delegate = get_tflite_xnnpack_delegate(num_threads); |
| #endif // USE_XNNPACK |
| std::unique_ptr<tflite::Interpreter> interpreter = get_tflite_interpreter( |
| qindex, superres_denom, padding_width, padding_height, num_threads, |
| is_intra_only, is_luma, cnn_index |
| #if USE_XNNPACK |
| , |
| xnnpack_delegate |
| #endif // USE_XNNPACK |
| ); |
| |
| // Prepare input. |
| const auto max_val = static_cast<float>((1 << bit_depth) - 1); |
| const int in_stride = padding_width; |
| auto input = interpreter->typed_input_tensor<float>(0); |
| for (int r = 0; r < padding_height; ++r) { |
| for (int c = 0; c < padding_width; ++c) { |
| if (r < height && c < width) { |
| input[r * in_stride + c] = |
| static_cast<float>(dgd[r * dgd_stride + c]) / max_val; |
| assert(input[r * in_stride + c] >= 0.0f); |
| assert(input[r * in_stride + c] <= 1.0f); |
| } else { |
        // Pad by reflecting valid pixels across the right/bottom boundary.
        // (Alternative: pad with zeros, i.e. input[r * in_stride + c] = 0.)
        int w_copy_idx = c;
        if (c >= width) {
          w_copy_idx = 2 * width - c - 1;
        }
        int h_copy_idx = r;
        if (r >= height) {
          h_copy_idx = 2 * height - r - 1;
        }
| input[r * in_stride + c] = input[h_copy_idx * in_stride + w_copy_idx]; |
| } |
| } |
| } |
| |
| // Invoke TFlite inference. |
| tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter(); |
| auto status = interpreter->Invoke(); |
| if (status != kTfLiteOk) { |
| reporter->Report("Failed at interpreter invocation"); |
| return 0; |
| } |
| |
| // Use the output to restore 'dgd' and store in 'rst'. |
| const auto output = interpreter->typed_output_tensor<float>(0); |
| const int out_stride = width; |
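  // Note: 'rst' may alias 'dgd' (callers restore in place); this is safe
  // because each 'dgd' pixel is read before the corresponding 'rst' pixel is
  // written.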
| for (int r = 0; r < height; ++r) { |
| for (int c = 0; c < width; ++c) { |
| const int residue = |
| static_cast<int>(output[r * out_stride + c] * max_val + 0.5); |
| rst[r * rst_stride + c] = |
| clip_pixel_highbd(dgd[r * dgd_stride + c] + residue, bit_depth); |
| } |
| } |
| |
| interpreter.reset(); |
| |
| #if USE_XNNPACK |
| // IMPORTANT: release the interpreter before destroying the delegate. |
| TfLiteXNNPackDelegateDelete(xnnpack_delegate); |
| #endif // USE_XNNPACK |
| |
| return 1; |
| } |
| |
| extern "C" void av1_restore_cnn_tflite(const AV1_COMMON *cm, int num_threads, |
| const int apply_cnn[MAX_MB_PLANE], |
| const int cnn_indices[MAX_MB_PLANE]) { |
| YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf; |
| const int is_intra_only = frame_is_intra_only(cm); |
| for (int plane = 0; plane < av1_num_planes(cm); ++plane) { |
| if (!apply_cnn[plane]) continue; |
| const int is_luma = (plane == AOM_PLANE_Y); |
| const int cnn_index = cnn_indices[plane]; |
| assert(cnn_index >= 0 && |
| cnn_index < av1_num_cnn_indices_for_plane(cm, plane)); |
| switch (plane) { |
| case AOM_PLANE_Y: |
| av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->y_buffer), buf->y_crop_width, |
| buf->y_crop_height, buf->y_stride, |
| CONVERT_TO_SHORTPTR(buf->y_buffer), buf->y_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| break; |
| case AOM_PLANE_U: |
| av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| break; |
| case AOM_PLANE_V: |
| av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| break; |
| default: assert(0 && "Invalid plane index"); |
| } |
| } |
| } |
| |
| #if CONFIG_CNN_GUIDED_QUADTREE |
| |
| // ------------------- Guided Quadtree: Common -------------------------------// |
| |
| // Given single-channel input in 'dgd', generate intermediate 2-channel CNN |
| // output 'interm'. |
| static int generate_interm_guided_restoration( |
| const uint16_t *dgd, int dgd_stride, int qindex, int superres_denom, |
| int width, int height, int num_threads, int is_intra_only, int is_luma, |
| int cnn_index, int bit_depth, |
| std::vector<std::vector<std::vector<double>>> &interm) { |
  // Pad width and height to multiples of 16, so the network can downscale the
  // input 4 times (by a factor of 2 each time).
  const int padding_width = (int)ceil(width * 1.0 / 16) * 16;
  const int padding_height = (int)ceil(height * 1.0 / 16) * 16;
| #if USE_XNNPACK |
| TfLiteDelegate *xnnpack_delegate = get_tflite_xnnpack_delegate(num_threads); |
| #endif // USE_XNNPACK |
| std::unique_ptr<tflite::Interpreter> interpreter = get_tflite_interpreter( |
| qindex, superres_denom, padding_width, padding_height, num_threads, |
| is_intra_only, is_luma, cnn_index |
| #if USE_XNNPACK |
| , |
| xnnpack_delegate |
| #endif // USE_XNNPACK |
| ); |
| |
| // Prepare input. |
| const auto max_val = static_cast<float>((1 << bit_depth) - 1); |
| const int in_stride = padding_width; |
| auto input = interpreter->typed_input_tensor<float>(0); |
| for (int r = 0; r < padding_height; ++r) { |
| for (int c = 0; c < padding_width; ++c) { |
| if (r < height && c < width) { |
| input[r * in_stride + c] = |
| static_cast<float>(dgd[r * dgd_stride + c]) / max_val; |
| assert(input[r * in_stride + c] >= 0.0f); |
| assert(input[r * in_stride + c] <= 1.0f); |
| } else { |
| input[r * in_stride + c] = |
| static_cast<float>(dgd[AOMMIN(r, height - 1) * dgd_stride + |
| AOMMIN(c, width - 1)]) / |
| max_val; |
| } |
| } |
| } |
| |
| // Invoke TFlite inference. |
| tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter(); |
| auto status = interpreter->Invoke(); |
| if (status != kTfLiteOk) { |
| reporter->Report("Failed at interpreter invocation"); |
| return 0; |
| } |
| |
| // Store the output in 'interm'. |
| const auto output = interpreter->typed_output_tensor<float>(0); |
| const int out_stride = padding_width; |
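  // The model output has 2 interleaved channels per pixel, so the row stride
  // in floats is 2 * out_stride.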
| |
| for (int r = 0; r < height; ++r) { |
| for (int c = 0; c < width; ++c) { |
| interm[r][c][0] = output[r * 2 * out_stride + c * 2] * max_val; |
| interm[r][c][1] = output[r * 2 * out_stride + c * 2 + 1] * max_val; |
| } |
| } |
| |
| // Cleanup. |
| interpreter.reset(); |
| #if USE_XNNPACK |
| // IMPORTANT: release the interpreter before destroying the delegate. |
| TfLiteXNNPackDelegateDelete(xnnpack_delegate); |
| #endif // USE_XNNPACK |
| return 1; |
| } |
| |
| // Get unit width and height based on max size and partition type. |
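// For example, with a 64x64 max unit: GUIDED_QT_NONE keeps 64x64,
// GUIDED_QT_HORZ gives two 64x32 units, GUIDED_QT_VERT gives two 32x64 units,
// and GUIDED_QT_SPLIT gives four 32x32 units.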
| static void get_unit_size(int max_unit_width, int max_unit_height, |
| GuidedQuadTreePartitionType partition_type, |
| int *unit_width, int *unit_height) { |
| assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES); |
| *unit_width = |
| (partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ) |
| ? max_unit_width |
| : max_unit_width >> 1; |
| *unit_height = |
| (partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT) |
| ? max_unit_height |
| : max_unit_height >> 1; |
| } |
| |
| // ------------------- Guided Quadtree: Encoder ------------------------------// |
| |
// Given 2-channel intermediate output 'interm', degraded frame 'dgd' and
// source frame 'src', generates the single-channel output 'out' and the
// corresponding linear combination weight pairs 'A'.
// Assumes that the `width x height` area is combined using units of size
// `unit_width x unit_height`.
| static void generate_linear_combination( |
| const std::vector<std::vector<std::vector<double>>> &interm, |
| const uint16_t *src, int src_stride, const uint16_t *dgd, int dgd_stride, |
| int start_row, int end_row, int start_col, int end_col, int unit_width, |
| int unit_height, const int *quadtset, int rdmult, const int *norestorecost, |
| int bit_depth, std::vector<std::vector<uint16_t>> &out, |
| std::vector<std::pair<int, int>> &A) { |
| const int scale0 = quadtset[0]; |
| const int scale1 = quadtset[1]; |
| const int A0_min = quadtset[2]; |
| const int A1_min = quadtset[3]; |
| |
| for (int row = start_row; row < end_row; row += unit_height) { |
| const int this_start_row = row; |
| const int this_end_row = AOMMIN(row + unit_height, end_row); |
| for (int col = start_col; col < end_col; col += unit_width) { |
| const int this_start_col = col; |
| const int this_end_col = AOMMIN(col + unit_width, end_col); |
| const int num_pixels = |
| (this_end_row - this_start_row) * (this_end_col - this_start_col); |
| |
      // Flatten the residual (src - dgd) and the two intermediate CNN
      // channels over this unit.
      std::vector<int> sub_r_flatten;
      std::vector<double> sub_r0;
      std::vector<double> sub_r1;
      sub_r_flatten.reserve(num_pixels);
      sub_r0.reserve(num_pixels);
      sub_r1.reserve(num_pixels);
      for (int i = this_start_row; i < this_end_row; i++) {
        for (int j = this_start_col; j < this_end_col; j++) {
          sub_r_flatten.push_back(src[i * src_stride + j] -
                                  dgd[i * dgd_stride + j]);
          sub_r0.push_back(interm[i][j][0]);
          sub_r1.push_back(interm[i][j][1]);
        }
      }
      assert((int)sub_r_flatten.size() == num_pixels);
      assert((int)sub_r0.size() == num_pixels);
      assert((int)sub_r1.size() == num_pixels);
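
      // Least-squares fit: model the residual r = src - dgd as a linear
      // combination r ~= x0 * r0 + x1 * r1 of the two intermediate channels,
      // i.e. solve x = (R^T * R)^-1 * R^T * r in closed form for the 2x2
      // case below. The weights are then stored in fixed point as
      // A0 = x0 * scale0 and A1 = x1 * scale1, matching the reconstruction
      // dgd + A0 * r0 / scale0 + A1 * r1 / scale1 applied later.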
| |
| // Get R. |
| std::vector<std::vector<double>> R(num_pixels, std::vector<double>(2)); |
| for (int i = 0; i < num_pixels; i++) { |
| R[i][0] = sub_r0[i]; |
| R[i][1] = sub_r1[i]; |
| } |
| |
| // Get R^T. |
| std::vector<std::vector<double>> R_T(2, std::vector<double>(num_pixels)); |
| for (int i = 0; i < num_pixels; i++) { |
| R_T[0][i] = sub_r0[i]; |
| R_T[1][i] = sub_r1[i]; |
| } |
| |
| // Get R^T * R. |
| double R_TDotR[2][2] = { 0 }; |
| for (int i = 0; i < 2; i++) { |
| for (int j = 0; j < 2; j++) { |
| for (int k = 0; k < num_pixels; k++) { |
| R_TDotR[i][j] += R_T[i][k] * R[k][j]; |
| } |
| } |
| } |
| |
| // Get (R^T * R)^-1. |
| const double value_R_TDotR = |
| R_TDotR[0][0] * R_TDotR[1][1] - R_TDotR[0][1] * R_TDotR[1][0]; |
| |
| double R_TDotR_inver[2][2] = { |
| { R_TDotR[1][1] / value_R_TDotR, -1 * R_TDotR[0][1] / value_R_TDotR }, |
| { -1 * R_TDotR[1][0] / value_R_TDotR, R_TDotR[0][0] / value_R_TDotR } |
| }; |
| |
| // Get (R^T * R)^-1 * R^T. |
| std::vector<std::vector<double>> mid(2, std::vector<double>(num_pixels)); |
| for (int j = 0; j < num_pixels; j++) { |
| mid[0][j] = |
| R_TDotR_inver[0][0] * R_T[0][j] + R_TDotR_inver[0][1] * R_T[1][j]; |
| mid[1][j] = |
| R_TDotR_inver[1][0] * R_T[0][j] + R_TDotR_inver[1][1] * R_T[1][j]; |
| } |
| |
| // Compute A = (R^T * R)^-1 * R^T * residual. |
| double A0 = 0; |
| double A1 = 0; |
| for (int i = 0; i < num_pixels; i++) { |
| A0 += mid[0][i] * sub_r_flatten[i]; |
| A1 += mid[1][i] * sub_r_flatten[i]; |
| } |
| A0 = A0 * scale0; |
| A1 = A1 * scale1; |
| |
      // Do a finer search for the best A0, A1 pair amongst four options:
      // (1) A0_floor = floor(A0), A1_floor = floor(A1)
      // (2) A0_floor + 1, A1_floor
      // (3) A0_floor, A1_floor + 1
      // (4) A0_floor + 1, A1_floor + 1
      const bool do_finer_search = true;
      if (do_finer_search) {
        double bestA0 = 0;
        double bestA1 = 0;
        double cost;
        // Baseline: RD cost of leaving this unit unrestored.
        int64_t err = 0;
        for (int i = this_start_row; i < this_end_row; i++) {
          for (int j = this_start_col; j < this_end_col; j++) {
            const int diff = src[i * src_stride + j] - dgd[i * dgd_stride + j];
            err += diff * diff;
          }
        }
        double bestcost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
            rdmult, norestorecost[1] >> 4, err, bit_depth);

        // Clamp the floored weights to the valid range, then evaluate the
        // four candidates in the order listed above.
        double flrA0 = floor(A0);
        double flrA1 = floor(A1);
        flrA0 = AOMMIN(AOMMAX(flrA0, A0_min), A0_min + GUIDED_A_RANGE);
        flrA1 = AOMMIN(AOMMAX(flrA1, A1_min), A1_min + GUIDED_A_RANGE);
        for (int dA1 = 0; dA1 <= 1; ++dA1) {
          if (dA1 == 1 && flrA1 >= A1_min + GUIDED_A_RANGE) continue;
          for (int dA0 = 0; dA0 <= 1; ++dA0) {
            if (dA0 == 1 && flrA0 >= A0_min + GUIDED_A_RANGE) continue;
            A0 = flrA0 + dA0;
            A1 = flrA1 + dA1;
            err = 0;
            for (int i = this_start_row; i < this_end_row; i++) {
              for (int j = this_start_col; j < this_end_col; j++) {
                int rest = int(round(dgd[i * dgd_stride + j] +
                                     A0 * interm[i][j][0] / scale0 +
                                     A1 * interm[i][j][1] / scale1));
                rest = clip_pixel_highbd(rest, bit_depth);
                const int diff = src[i * src_stride + j] - rest;
                err += diff * diff;
              }
            }
            // Approximate RD cost assuming GUIDED_A_PAIR_BITS bits per
            // a0, a1 pair.
            cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
                rdmult,
                (norestorecost[0] +
                 (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
                    4,
                err, bit_depth);
            if (cost < bestcost) {
              bestA0 = A0;
              bestA1 = A1;
              bestcost = cost;
            }
          }
        }
        A0 = bestA0;
        A1 = bestA1;
      } else {
        A0 = round(A0);
        A1 = round(A1);
        A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE);
        A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE);
      }
| |
| A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE); |
| A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE); |
| A.emplace_back((int)A0, (int)A1); |
| for (int i = this_start_row; i < this_end_row; i++) { |
| for (int j = this_start_col; j < this_end_col; j++) { |
| const int out_unclipped = int(round(dgd[i * dgd_stride + j] + |
| A0 * interm[i][j][0] / scale0 + |
| A1 * interm[i][j][1] / scale1)); |
| out[i - start_row][j - start_col] = |
| clip_pixel_highbd(out_unclipped, bit_depth); |
| } |
| } |
| } |
| } |
| #ifndef NDEBUG |
| const auto num_units_row = |
| (size_t)ceil((double)(end_row - start_row) / unit_height); |
| const auto num_units_col = |
| (size_t)ceil((double)(end_col - start_col) / unit_width); |
| assert(A.size() == num_units_row * num_units_col); |
| #endif // NDEBUG |
| } |
| |
| // Computes SSE between 'rst' and 'src'. |
| static int64_t compute_sse(const std::vector<std::vector<uint16_t>> &rst, |
| const uint16_t *src, int src_stride, int start_row, |
| int end_row, int start_col, int end_col) { |
| int64_t sse = 0; |
| for (int r = start_row; r < end_row; ++r) { |
| for (int c = start_col; c < end_col; ++c) { |
| const uint16_t this_rst = rst[r - start_row][c - start_col]; |
| const uint16_t this_src = src[r * src_stride + c]; |
| const int64_t diff = (int64_t)(this_rst - this_src); |
| sse += diff * diff; |
| } |
| } |
| return sse; |
| } |
| |
| // Computes bitrate for the given weight parameters. |
| static int compute_rate(const std::vector<std::pair<int, int>> &A, |
| const std::pair<int, int> &prev_A, const int *quadtset, |
| const int *norestorecosts) { |
| const int A0_min = quadtset[2]; |
| const int A1_min = quadtset[3]; |
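  // Note: despite the name, 'num_bits' accumulates costs in entropy-coder
  // units of (1 << AV1_PROB_COST_SHIFT) per bit, consistent with the
  // 'norestorecosts' inputs.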
| int num_bits = 0; |
| int ref0 = AOMMIN(AOMMAX(prev_A.first - A0_min, 0), GUIDED_A_RANGE); |
| int ref1 = AOMMIN(AOMMAX(prev_A.second - A1_min, 0), GUIDED_A_RANGE); |
| for (auto &this_A : A) { |
| if (this_A.first == 0 && this_A.second == 0) { |
| num_bits += norestorecosts[1]; |
| } else { |
| num_bits += norestorecosts[0]; |
| num_bits += (aom_count_primitive_refsubexpfin( |
| GUIDED_A_NUM_VALUES, 1, ref0, this_A.first - A0_min) + |
| aom_count_primitive_refsubexpfin( |
| GUIDED_A_NUM_VALUES, 1, ref1, this_A.second - A1_min)) |
| << AV1_PROB_COST_SHIFT; |
| } |
| ref0 = AOMMIN(AOMMAX(this_A.first - A0_min, 0), GUIDED_A_RANGE); |
| ref1 = AOMMIN(AOMMAX(this_A.second - A1_min, 0), GUIDED_A_RANGE); |
| } |
| return num_bits; |
| } |
| |
| // Given 2-channel intermediate output in 'interm' as well as 'src' and 'dgd' |
| // buffers, tries the given partition type on a single quadtree unit. Outputs |
| // the RDCost in 'this_rdcost' and restored unit in 'out'. |
| static void try_one_partition( |
| const std::vector<std::vector<std::vector<double>>> &interm, |
| GuidedQuadTreePartitionType partition_type, const uint16_t *src, |
| int src_stride, const uint16_t *dgd, int dgd_stride, int start_row, |
| int end_row, int start_col, int end_col, int max_unit_width, |
| int max_unit_height, const int *quadtset, int rdmult, |
| const std::pair<int, int> &prev_A, const int *quad_split_costs, |
| const int *binary_split_costs, const int *norestorecosts, int bit_depth, |
    bool is_horz_partitioning_allowed, bool is_vert_partitioning_allowed,
| double *this_rdcost, std::vector<std::vector<uint16_t>> &out, |
| std::vector<std::pair<int, int>> &A) { |
| assert(IMPLIES( |
| !is_horz_partitioning_allowed, |
| partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT)); |
| assert(IMPLIES( |
| !is_vert_partitioning_allowed, |
| partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ)); |
| // Get unit width and height based on partition type. |
| int unit_width; |
| int unit_height; |
| get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width, |
| &unit_height); |
| |
| // Compute restored unit, a0 and a1. |
| generate_linear_combination(interm, src, src_stride, dgd, dgd_stride, |
| start_row, end_row, start_col, end_col, |
| unit_width, unit_height, quadtset, rdmult, |
| norestorecosts, bit_depth, out, A); |
| assert(IMPLIES(partition_type == GUIDED_QT_NONE, A.size() == 1)); |
| assert(IMPLIES(partition_type == GUIDED_QT_HORZ, A.size() == 2)); |
| assert(IMPLIES(partition_type == GUIDED_QT_VERT, A.size() == 2)); |
| assert(IMPLIES(partition_type == GUIDED_QT_SPLIT, A.size() == 4)); |
| |
| // Compute SSE. |
| const int64_t sse = |
| compute_sse(out, src, src_stride, start_row, end_row, start_col, end_col); |
| |
| // Compute Rate. |
| const int a_signaling_cost = |
| compute_rate(A, prev_A, quadtset, norestorecosts); |
| // Partition signaling cost depending on 1, 2 or 4 possible partition types. |
| const int partition_signaling_cost = |
| is_horz_partitioning_allowed && is_vert_partitioning_allowed |
| ? quad_split_costs[partition_type] |
| : (is_horz_partitioning_allowed || is_vert_partitioning_allowed) |
| ? binary_split_costs[partition_type] |
| : 0; |
| const int bitrate = a_signaling_cost + partition_signaling_cost; |
| |
| // Compute RDCost. |
| *this_rdcost = |
| RDCOST_DBL_WITH_NATIVE_BD_DIST(rdmult, bitrate >> 4, sse, bit_depth); |
| } |
| |
// Given intermediate restoration 'interm', source 'src' and degraded frame
// 'dgd', computes the best partitioning out of NONE, SPLIT, HORZ and VERT
// based on RD cost for the max_unit_width x max_unit_height unit starting at
// 'start_row' and 'start_col'. The split decisions are stored in 'split' and
// the a0, a1 pairs are stored in 'A'.
| static void select_quadtree_partitioning( |
| const std::vector<std::vector<std::vector<double>>> &interm, |
| const uint16_t *src, int src_stride, int start_row, int start_col, |
| int width, int height, int quadtree_max_size, int max_unit_width, |
| int max_unit_height, const int *quadtset, int rdmult, |
| const std::pair<int, int> &prev_A, const int *quad_split_costs, |
| const int *binary_split_costs, const int norestorecosts[2], int bit_depth, |
| const uint16_t *dgd, int dgd_stride, std::vector<int> &split, |
| std::vector<std::pair<int, int>> &A, double *rdcost) { |
| const int end_row = AOMMIN(start_row + max_unit_height, height); |
| const int end_col = AOMMIN(start_col + max_unit_width, width); |
| // Check for special cases near boundary. |
| const bool is_horz_partitioning_allowed = |
| (max_unit_height >= quadtree_max_size); |
| const bool is_vert_partitioning_allowed = |
| (max_unit_width >= quadtree_max_size); |
| const bool is_split_partitioning_allowed = |
| is_horz_partitioning_allowed && is_vert_partitioning_allowed; |
| |
| auto best_rdcost = DBL_MAX; |
| std::vector<std::pair<int, int>> best_A; |
| std::vector<std::vector<uint16_t>> best_out( |
| max_unit_height, std::vector<uint16_t>(max_unit_width)); |
| GuidedQuadTreePartitionType best_partition_type = GUIDED_QT_INVALID; |
| |
| for (int type = 0; type < GUIDED_QT_TYPES; ++type) { |
| const auto this_partition_type = (GuidedQuadTreePartitionType)type; |
| // Check for special cases near boundary. |
| if (!is_horz_partitioning_allowed && |
| (this_partition_type == GUIDED_QT_HORZ)) { |
| continue; |
| } |
| if (!is_vert_partitioning_allowed && |
| (this_partition_type == GUIDED_QT_VERT)) { |
| continue; |
| } |
| if (!is_split_partitioning_allowed && |
| (this_partition_type == GUIDED_QT_SPLIT)) { |
| continue; |
| } |
| // Try this partition type. |
| double this_rdcost; |
| std::vector<std::pair<int, int>> this_A; |
| std::vector<std::vector<uint16_t>> this_out( |
| max_unit_height, std::vector<uint16_t>(max_unit_width)); |
| try_one_partition( |
| interm, this_partition_type, src, src_stride, dgd, dgd_stride, |
| start_row, end_row, start_col, end_col, max_unit_width, max_unit_height, |
| quadtset, rdmult, prev_A, quad_split_costs, binary_split_costs, |
| norestorecosts, bit_depth, is_horz_partitioning_allowed, |
| is_vert_partitioning_allowed, &this_rdcost, this_out, this_A); |
| if (this_rdcost < best_rdcost) { |
| best_rdcost = this_rdcost; |
| best_A = this_A; |
| best_out = this_out; |
| best_partition_type = this_partition_type; |
| } |
| } |
| |
| // Save RDCost. |
| *rdcost = best_rdcost; |
| |
| // Save a0, a1 pairs. |
| for (auto &a0a1 : best_A) { |
| A.push_back(a0a1); |
| } |
| |
| // Save split decision. |
| if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) { |
| // Nothing should be added to 'split' array. |
| assert(best_partition_type == GUIDED_QT_NONE); |
| return; |
| } |
| assert(best_partition_type >= 0 && best_partition_type < GUIDED_QT_TYPES); |
| split.push_back(best_partition_type); |
| } |
| |
| static void apply_quadtree_partitioning( |
| const std::vector<std::vector<std::vector<double>>> &interm, int start_row, |
| int start_col, int width, int height, int quadtree_max_size, |
| int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth, |
| const std::vector<int> &split, size_t &split_index, |
| const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd, |
| int dgd_stride); |
| |
| // Top-level function to apply guided restoration on encoder side. |
| static int restore_cnn_quadtree_encode_img_tflite_highbd( |
| YV12_BUFFER_CONFIG *source_frame, AV1_COMMON *cm, int superres_denom, |
| int rdmult, const int *quad_split_costs, const int *binary_split_costs, |
| int (*norestorecosts)[2], int num_threads, int bit_depth, int is_intra_only, |
| int is_luma, int cnn_index, QUADInfo *quad_info, double *rdcost) { |
| YV12_BUFFER_CONFIG *dgd_buf = &cm->cur_frame->buf; |
| uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd_buf->y_buffer); |
| const int dgd_stride = dgd_buf->y_stride; |
| const int qindex = cm->quant_params.base_qindex; |
| const int width = cm->superres_upscaled_width; |
| const int height = cm->superres_upscaled_height; |
| |
| // Get 2-channel intermediate restoration. |
| std::vector<std::vector<std::vector<double>>> interm( |
| height, std::vector<std::vector<double>>(width, std::vector<double>(2))); |
| if (!generate_interm_guided_restoration( |
| dgd, dgd_stride, qindex, superres_denom, width, height, num_threads, |
| is_intra_only, is_luma, cnn_index, bit_depth, interm)) { |
| return 0; |
| } |
| |
| // Initialization. |
| const uint16_t *src = CONVERT_TO_SHORTPTR(source_frame->y_buffer); |
| const int src_stride = source_frame->y_stride; |
| const int *quadtset = get_quadparm_from_qindex( |
| qindex, superres_denom, is_intra_only, is_luma, cnn_index); |
| const int A0_min = quadtset[2]; |
| const int A1_min = quadtset[3]; |
| const int norestore_ctx = |
| get_guided_norestore_ctx(qindex, superres_denom, is_intra_only); |
| const int null_norestorecosts[2] = { 0, 0 }; |
| const int *this_norestorecosts = |
| norestore_ctx == -1 ? null_norestorecosts : norestorecosts[norestore_ctx]; |
| |
| // Try all possible quadtree unit sizes. |
| int best_unit_index = -1; |
| std::vector<int> best_split; // selected partitioning options. |
| std::vector<std::pair<int, int>> best_A; // selected a0, a1 weight pairs. |
| double best_rdcost_total = DBL_MAX; |
| for (int this_unit_index = 0; this_unit_index < GUIDED_QT_UNIT_SIZES; |
| ++this_unit_index) { |
| const int quadtree_max_size = |
| quad_tree_get_unit_size(width, height, this_unit_index); |
| // For each quadtree unit, compute the best partitioning out of |
| // NONE, SPLIT, HORZ and VERT based on RD cost. |
| std::vector<int> this_split; // selected partitioning options. |
| std::vector<std::pair<int, int>> this_A; // selected a0, a1 weight pairs. |
| double this_rdcost_total = 0.0; |
| // Previous a0, a1 pair is mid-point of the range by default. |
| std::pair<int, int> prev_A = |
| std::make_pair(GUIDED_A_MID + A0_min, GUIDED_A_MID + A1_min); |
| const int ext_size = quadtree_max_size * 3 / 2; |
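    // A boundary unit absorbs the remaining extent when that remainder is
    // smaller than 1.5x the unit size, instead of creating an extra small
    // unit.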
| for (int row = 0; row < height;) { |
| const int remaining_height = height - row; |
| const int this_unit_height = |
| (remaining_height < ext_size) ? remaining_height : quadtree_max_size; |
| for (int col = 0; col < width;) { |
| const int remaining_width = width - col; |
| const int this_unit_width = |
| (remaining_width < ext_size) ? remaining_width : quadtree_max_size; |
| double this_rdcost; |
| select_quadtree_partitioning( |
| interm, src, src_stride, row, col, width, height, quadtree_max_size, |
| this_unit_width, this_unit_height, quadtset, rdmult, prev_A, |
| quad_split_costs, binary_split_costs, this_norestorecosts, |
| bit_depth, dgd, dgd_stride, this_split, this_A, &this_rdcost); |
| // updates. |
| this_rdcost_total += this_rdcost; |
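        // The last a0, a1 pair of this unit becomes the prediction reference
        // for coding the next unit's weights (see compute_rate()).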
| prev_A = this_A.back(); |
| col += this_unit_width; |
| } |
| row += this_unit_height; |
| } |
| // Update best options. |
| if (this_rdcost_total < best_rdcost_total) { |
| best_unit_index = this_unit_index; |
| best_split = this_split; |
| best_A = this_A; |
| best_rdcost_total = this_rdcost_total; |
| } |
| } |
| |
| // Fill in the best options. |
| quad_info->unit_index = best_unit_index; |
| quad_info->split_info_length = (int)best_split.size(); |
| quad_info->unit_info_length = (int)best_A.size(); |
| av1_alloc_quadtree_struct(cm, quad_info); |
| for (unsigned int i = 0; i < best_split.size(); ++i) { |
| quad_info->split_info[i].split = best_split[i]; |
| } |
| for (unsigned int i = 0; i < best_A.size(); ++i) { |
| quad_info->unit_info[i].xqd[0] = best_A[i].first; |
| quad_info->unit_info[i].xqd[1] = best_A[i].second; |
| } |
| *rdcost = best_rdcost_total; |
| |
| // Apply guided restoration to 'dgd' using best options above. |
| size_t split_index = 0; |
| size_t A_index = 0; |
| const int quadtree_max_size = quad_info->unit_size; |
| const int ext_size = quadtree_max_size * 3 / 2; |
| for (int row = 0; row < height;) { |
| const int remaining_height = height - row; |
| const int this_unit_height = |
| (remaining_height < ext_size) ? remaining_height : quadtree_max_size; |
| for (int col = 0; col < width;) { |
| const int remaining_width = width - col; |
| const int this_unit_width = |
| (remaining_width < ext_size) ? remaining_width : quadtree_max_size; |
| apply_quadtree_partitioning( |
| interm, row, col, width, height, quadtree_max_size, this_unit_width, |
| this_unit_height, quadtset, bit_depth, best_split, split_index, |
| best_A, A_index, dgd, dgd_stride); |
| col += this_unit_width; |
| } |
| row += this_unit_height; |
| } |
| |
| return 1; |
| } |
| |
| extern "C" int av1_restore_cnn_quadtree_encode_tflite( |
| struct AV1Common *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT, |
| int *quad_split_costs, int *binary_split_costs, int (*norestorecosts)[2], |
| int num_threads, const int apply_cnn[MAX_MB_PLANE], |
| const int cnn_indices[MAX_MB_PLANE], QUADInfo *quad_info, double *rdcost) { |
| YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf; |
| const int is_intra_only = frame_is_intra_only(cm); |
| for (int plane = 0; plane < av1_num_planes(cm); ++plane) { |
| if (!apply_cnn[plane]) continue; |
| const int is_luma = (plane == AOM_PLANE_Y); |
| const int cnn_index = cnn_indices[plane]; |
| assert(cnn_index >= 0 && |
| cnn_index < av1_num_cnn_indices_for_plane(cm, plane)); |
| int ret = 1; |
| switch (plane) { |
| case AOM_PLANE_Y: |
| ret = restore_cnn_quadtree_encode_img_tflite_highbd( |
| source_frame, cm, cm->superres_scale_denominator, RDMULT, |
| quad_split_costs, binary_split_costs, norestorecosts, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index, |
| quad_info, rdcost); |
| if (ret == 0) return ret; |
| break; |
| case AOM_PLANE_U: |
| ret = av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| if (ret == 0) return ret; |
| break; |
| case AOM_PLANE_V: |
| ret = av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| if (ret == 0) return ret; |
| break; |
| default: assert(0 && "Invalid plane index"); return 0; |
| } |
| } |
| return 1; |
| } |
| |
| // ------------------- Guided Quadtree: Decoder ------------------------------// |
| |
| // Given the 2-channel intermediate output in 'interm' and weight parameters, |
| // restores one quadtree unit in 'dgd'. |
| static void apply_linear_combination( |
| const std::vector<std::vector<std::vector<double>>> &interm, int start_row, |
| int end_row, int start_col, int end_col, int unit_width, int unit_height, |
| const int *quadtset, int bit_depth, |
| const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd, |
| int dgd_stride) { |
| // Get scale parameters. |
| const int scale0 = quadtset[0]; |
| const int scale1 = quadtset[1]; |
| |
| for (int row = start_row; row < end_row; row += unit_height) { |
| const int this_start_row = row; |
| const int this_end_row = AOMMIN(row + unit_height, end_row); |
| for (int col = start_col; col < end_col; col += unit_width) { |
| const int this_start_col = col; |
| const int this_end_col = AOMMIN(col + unit_width, end_col); |
| |
| // Get weight parameters for this unit. |
| const auto this_A = A[A_index++]; |
| const int a0 = this_A.first; |
| const int a1 = this_A.second; |
| |
| // Restore this unit. |
| for (int r = this_start_row; r < this_end_row; ++r) { |
| for (int c = this_start_col; c < this_end_col; ++c) { |
| const int dgd_unclipped = int(round(dgd[r * dgd_stride + c] + |
| a0 * interm[r][c][0] / scale0 + |
| a1 * interm[r][c][1] / scale1)); |
| dgd[r * dgd_stride + c] = clip_pixel_highbd(dgd_unclipped, bit_depth); |
| } |
| } |
| } |
| } |
| } |
| |
// Given intermediate restoration 'interm', quadtree partitioning info 'split'
// and weight parameters 'A', restores the unit starting at 'start_row' and
// 'start_col' inside 'dgd'.
| static void apply_quadtree_partitioning( |
| const std::vector<std::vector<std::vector<double>>> &interm, int start_row, |
| int start_col, int width, int height, int quadtree_max_size, |
| int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth, |
| const std::vector<int> &split, size_t &split_index, |
| const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd, |
| int dgd_stride) { |
| const int end_row = AOMMIN(start_row + max_unit_height, height); |
| const int end_col = AOMMIN(start_col + max_unit_width, width); |
| // Check for special cases near boundary. |
| const bool is_horz_partitioning_allowed = |
| (max_unit_height >= quadtree_max_size); |
| const bool is_vert_partitioning_allowed = |
| (max_unit_width >= quadtree_max_size); |
| |
| // Get partition type. |
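  // If neither split direction is allowed (a small boundary unit), the
  // encoder signaled no flag and GUIDED_QT_NONE is implied.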
| GuidedQuadTreePartitionType partition_type = GUIDED_QT_NONE; |
| if (is_horz_partitioning_allowed || is_vert_partitioning_allowed) { |
| partition_type = (GuidedQuadTreePartitionType)split[split_index++]; |
| } |
| assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES); |
| |
| // Get unit width and height based on partition type. |
| int unit_width; |
| int unit_height; |
| get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width, |
| &unit_height); |
| |
| // Compute restored unit, a0 and a1 with given A parameters. |
| apply_linear_combination(interm, start_row, end_row, start_col, end_col, |
| unit_width, unit_height, quadtset, bit_depth, A, |
| A_index, dgd, dgd_stride); |
| } |
| |
| // Top-level function to apply guided restoration on decoder side. |
| static int restore_cnn_quadtree_decode_img_tflite_highbd( |
| AV1_COMMON *cm, int superres_denom, int num_threads, int bit_depth, |
| int is_intra_only, int is_luma, int cnn_index) { |
| YV12_BUFFER_CONFIG *dgd_buf = &cm->cur_frame->buf; |
| uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd_buf->y_buffer); |
| const int dgd_stride = dgd_buf->y_stride; |
| const int qindex = cm->quant_params.base_qindex; |
| const int width = cm->superres_upscaled_width; |
| const int height = cm->superres_upscaled_height; |
| |
| // Get 2-channel intermediate restoration. |
| std::vector<std::vector<std::vector<double>>> interm( |
| height, std::vector<std::vector<double>>(width, std::vector<double>(2))); |
| if (!generate_interm_guided_restoration( |
| dgd, dgd_stride, qindex, superres_denom, width, height, num_threads, |
| is_intra_only, is_luma, cnn_index, bit_depth, interm)) { |
| return 0; |
| } |
| |
| // Get quadtree params. |
| const QUADInfo *const quad_info = &cm->cnn_quad_info; |
| const int quadtree_max_size = quad_info->unit_size; |
| const int *quadtset = get_quadparm_from_qindex( |
| qindex, superres_denom, is_intra_only, is_luma, cnn_index); |
| |
| // Get partitioning types. |
| std::vector<int> split; |
| split.reserve(quad_info->split_info_length); |
| for (int i = 0; i < quad_info->split_info_length; ++i) { |
| split.push_back(quad_info->split_info[i].split); |
| } |
| |
| // Get a0,a1 pairs. |
| std::vector<std::pair<int, int>> A; |
| A.reserve(quad_info->unit_info_length); |
| for (int i = 0; i < quad_info->unit_info_length; ++i) { |
| A.emplace_back(quad_info->unit_info[i].xqd[0], |
| quad_info->unit_info[i].xqd[1]); |
| } |
| |
| // For each quadtree unit, apply given quadtree partitioning. |
| size_t split_index = 0; |
| size_t A_index = 0; |
| const int ext_size = quadtree_max_size * 3 / 2; |
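  // As on the encoder side, a boundary unit absorbs a remainder smaller than
  // 1.5x the unit size.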
| for (int row = 0; row < height;) { |
| const int remaining_height = height - row; |
| const int this_unit_height = |
| (remaining_height < ext_size) ? remaining_height : quadtree_max_size; |
| for (int col = 0; col < width;) { |
| const int remaining_width = width - col; |
| const int this_unit_width = |
| (remaining_width < ext_size) ? remaining_width : quadtree_max_size; |
| apply_quadtree_partitioning(interm, row, col, width, height, |
| quadtree_max_size, this_unit_width, |
| this_unit_height, quadtset, bit_depth, split, |
| split_index, A, A_index, dgd, dgd_stride); |
| col += this_unit_width; |
| } |
| row += this_unit_height; |
| } |
| assert(split_index == split.size()); |
| assert(A_index == A.size()); |
| return 1; |
| } |
| |
| extern "C" int av1_restore_cnn_quadtree_decode_tflite( |
| struct AV1Common *cm, int num_threads, int use_quadtree, |
| const int apply_cnn[MAX_MB_PLANE], const int cnn_indices[MAX_MB_PLANE]) { |
| YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf; |
| const int is_intra_only = frame_is_intra_only(cm); |
| for (int plane = 0; plane < av1_num_planes(cm); ++plane) { |
| if (!apply_cnn[plane]) continue; |
| const int is_luma = (plane == AOM_PLANE_Y); |
| if (is_luma && !use_quadtree) continue; |
| const int cnn_index = cnn_indices[plane]; |
| assert(cnn_index >= 0 && |
| cnn_index < av1_num_cnn_indices_for_plane(cm, plane)); |
| int ret = 1; |
| switch (plane) { |
| case AOM_PLANE_Y: |
| ret = restore_cnn_quadtree_decode_img_tflite_highbd( |
| cm, cm->superres_scale_denominator, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| if (ret == 0) return ret; |
| break; |
| case AOM_PLANE_U: |
| ret = av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->u_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| if (ret == 0) return ret; |
| break; |
| case AOM_PLANE_V: |
| ret = av1_restore_cnn_img_tflite_highbd( |
| cm->quant_params.base_qindex, cm->superres_scale_denominator, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_crop_width, |
| buf->uv_crop_height, buf->uv_stride, |
| CONVERT_TO_SHORTPTR(buf->v_buffer), buf->uv_stride, num_threads, |
| cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index); |
| if (ret == 0) return ret; |
| break; |
| default: assert(0 && "Invalid plane index"); return 0; |
| } |
| } |
| return 1; |
| } |
| #endif // CONFIG_CNN_GUIDED_QUADTREE |
| |
| #endif // CONFIG_CNN_RESTORATION |