| /* |
| * Copyright (c) 2022, Alliance for Open Media. All rights reserved |
| * |
| * This source code is subject to the terms of the BSD 2 Clause License and |
| * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
| * was not distributed with this source code in the LICENSE file, you can |
| * obtain it at www.aomedia.org/license/software. If the Alliance for Open |
| * Media Patent License 1.0 was not distributed with this source code in the |
| * PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
| */ |
| |
| #include "config/av1_rtcd.h" |
| |
| #include "av1/encoder/erp_tflite.h" |
| #include "av1/encoder/erp_models.h" |
| #include "av1/encoder/ml.h" |
| |
| #if CONFIG_ERP_TFLITE |
| #include <vector> |
| #include "av1/tflite_models/op_registrations.h" |
| #include "common/tf_lite_includes.h" |
| #endif // CONFIG_ERP_TFLITE |
| |
| #if CONFIG_EXT_RECUR_PARTITIONS |
// Switch-case generators for the per-block-size ERP lookup tables. Each macro
// expands to a single `case BSIZE:` label that returns the HD or non-HD
// variant of the table whose name is built from BSIZE. The expansion reads a
// variable named `is_hd`, which must be in scope where the macro is used.

// Serialized model data: av1_erp_rect[_hd]_<BSIZE>_tflite.
#define MAKE_ERP_MODEL_SWITCH_CASE(BSIZE)             \
  case BSIZE:                                         \
    return (is_hd) ? av1_erp_rect_hd_##BSIZE##_tflite \
                   : av1_erp_rect_##BSIZE##_tflite;

// Address of the NN_CONFIG: av1_erp_rect[_hd]_nn_config_<BSIZE>.
#define MAKE_ERP_DNN_MODEL_SWITCH_CASE(BSIZE)          \
  case BSIZE:                                          \
    return (is_hd) ? &av1_erp_rect_hd_nn_config_##BSIZE \
                   : &av1_erp_rect_nn_config_##BSIZE;

// Feature mean vector: av1_erp_rect[_hd]_feature_mean_<BSIZE>.
#define MAKE_ERP_MEAN_SWITCH_CASE(BSIZE)                 \
  case BSIZE:                                            \
    return (is_hd) ? av1_erp_rect_hd_feature_mean_##BSIZE \
                   : av1_erp_rect_feature_mean_##BSIZE;

// Feature standard-deviation vector: av1_erp_rect[_hd]_feature_std_<BSIZE>.
#define MAKE_ERP_STD_SWITCH_CASE(BSIZE)                 \
  case BSIZE:                                           \
    return (is_hd) ? av1_erp_rect_hd_feature_std_##BSIZE \
                   : av1_erp_rect_feature_std_##BSIZE;
| |
| #if CONFIG_ERP_TFLITE |
| static const unsigned char *get_model_data(BLOCK_SIZE bsize, bool is_hd) { |
| switch (bsize) { |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_128X128) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_128X64) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_64X128) |
| |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_64X64) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_64X32) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_32X64) |
| |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_32X32) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_32X16) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_16X32) |
| |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_16X16) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_16X8) |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_8X16) |
| |
| MAKE_ERP_MODEL_SWITCH_CASE(BLOCK_8X8) |
| |
| default: assert(0 && "Invalid block size!\n"); return NULL; |
| } |
| } |
| |
| static std::unique_ptr<tflite::Interpreter> get_tflite_interpreter( |
| BLOCK_SIZE bsize, bool is_hd) { |
| const unsigned char *const model_tflite_data = get_model_data(bsize, is_hd); |
| auto model = tflite::GetModel(model_tflite_data); |
| tflite::MutableOpResolver resolver; |
| RegisterSelectedOpsAllQps(&resolver); |
| tflite::InterpreterBuilder builder(model, resolver); |
| std::unique_ptr<tflite::Interpreter> interpreter; |
| builder(&interpreter); |
| interpreter->SetNumThreads(1); |
| tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter(); |
| |
| // Dimension order: batch_size, feature_size |
| const std::vector<int> in_out_dims = { 1, 19 }; |
| |
| if (interpreter->AllocateTensors() != kTfLiteOk) { |
| reporter->Report("Failed at tensor allocation"); |
| return nullptr; |
| } |
| |
| return interpreter; |
| } |
| #else |
| static const NN_CONFIG *get_dnn_model(BLOCK_SIZE bsize, bool is_hd) { |
| switch (bsize) { |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_128X128) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_128X64) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X128) |
| |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X64) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X32) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X64) |
| |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X32) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X16) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X32) |
| |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X16) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X8) |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_8X16) |
| |
| MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_8X8) |
| |
| default: assert(0 && "Invalid block size!\n"); return NULL; |
| } |
| } |
| #endif // CONFIG_ERP_TFLITE |
| |
| static const float *get_mean(BLOCK_SIZE bsize, bool is_hd) { |
| switch (bsize) { |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_128X128) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_128X64) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X128) |
| |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X64) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X32) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X64) |
| |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X32) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X16) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X32) |
| |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X16) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X8) |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_8X16) |
| |
| MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_8X8) |
| |
| default: assert(0 && "Invalid block size!\n"); return NULL; |
| } |
| } |
| |
| static const float *get_std(BLOCK_SIZE bsize, bool is_hd) { |
| switch (bsize) { |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_128X128) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_128X64) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X128) |
| |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X64) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X32) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X64) |
| |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X32) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X16) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X32) |
| |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X16) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X8) |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_8X16) |
| |
| MAKE_ERP_STD_SWITCH_CASE(BLOCK_8X8) |
| |
| default: assert(0 && "Invalid block size!\n"); return NULL; |
| } |
| } |
// The switch-case helper macros are only needed by the lookup functions above.
// Undefine all four of them, not just the model one, so none leak into the
// rest of this translation unit.
#undef MAKE_ERP_MODEL_SWITCH_CASE
#undef MAKE_ERP_DNN_MODEL_SWITCH_CASE
#undef MAKE_ERP_MEAN_SWITCH_CASE
#undef MAKE_ERP_STD_SWITCH_CASE
| |
// Standardizes |num_features| values: features_dst[i] =
// (features_src[i] - mean[i]) / std[i].
// A feature whose std is at or below a small epsilon is treated as constant
// and mapped to 0 rather than divided by a near-zero value.
static inline void normalize(float *features_dst, const float *features_src,
                             const float *mean, const float *std,
                             size_t num_features) {
  constexpr float kMinStd = 0.00001f;
  for (size_t i = 0; i != num_features; ++i) {
    const float sigma = std[i];
    // Written as `<=` (not `>`) so a NaN sigma still takes the division path.
    features_dst[i] =
        (sigma <= kMinStd) ? 0.0f : (features_src[i] - mean[i]) / sigma;
  }
}
| |
| extern "C" int av1_erp_prune_rect(BLOCK_SIZE bsize, bool is_hd, |
| const float *features, bool *prune_horz, |
| bool *prune_vert) { |
| #if CONFIG_ERP_TFLITE |
| std::unique_ptr<tflite::Interpreter> interpreter = |
| get_tflite_interpreter(bsize, is_hd); |
| |
| // Prepare input. |
| float *input = interpreter->typed_input_tensor<float>(0); |
| const float *mean = get_mean(bsize, is_hd); |
| const float *std = get_std(bsize, is_hd); |
| normalize(input, features, mean, std, 19); |
| |
| // Invoke TFlite inference. |
| const float *output; |
| tflite::ErrorReporter *reporter = tflite::DefaultErrorReporter(); |
| auto status = interpreter->Invoke(); |
| if (status != kTfLiteOk) { |
| reporter->Report("Failed at interpreter invocation"); |
| return 0; |
| } |
| output = interpreter->typed_output_tensor<float>(0); |
| interpreter.reset(); |
| #else |
| // Prepare input. |
| float input[19]; |
| const float *mean = get_mean(bsize, is_hd); |
| const float *std = get_std(bsize, is_hd); |
| normalize(input, features, mean, std, 19); |
| |
| // Call nn config |
| float output[3]; |
| const NN_CONFIG *nn_config = get_dnn_model(bsize, is_hd); |
| av1_nn_predict(input, nn_config, 1, output); |
| #endif // CONFIG_ERP_TFLITE |
| |
| float probs[3]; |
| av1_nn_softmax(output, probs, 3); |
| |
| static const float threshes[2][5] = { |
| // Non-hd |
| { |
| // 128, 64, 32, 16, 8 |
| 0.00889f, |
| 0.00268f, |
| 0.01480f, |
| 0.03531f, |
| 0.04103f, |
| }, |
| // HD |
| { |
| // 128, 64, 32, 16, 8 |
| 0.01911f, |
| 0.00327f, |
| 0.00520f, |
| 0.01669f, |
| 0.00176f, |
| }, |
| }; |
| |
| float thresh = 0.0f; |
| switch (bsize) { |
| case BLOCK_128X128: |
| case BLOCK_128X64: |
| case BLOCK_64X128: thresh = threshes[is_hd][0]; break; |
| case BLOCK_64X64: |
| case BLOCK_64X32: |
| case BLOCK_32X64: thresh = threshes[is_hd][1]; break; |
| case BLOCK_32X32: |
| case BLOCK_32X16: |
| case BLOCK_16X32: thresh = threshes[is_hd][2]; break; |
| case BLOCK_16X16: |
| case BLOCK_16X8: |
| case BLOCK_8X16: thresh = threshes[is_hd][3]; break; |
| case BLOCK_8X8: thresh = threshes[is_hd][4]; break; |
| default: |
| assert("Unexpected block size in erp pruning model!\n"); |
| thresh = 0.0f; |
| } |
| |
| if (probs[1] < thresh) { |
| *prune_horz = true; |
| } |
| if (probs[2] < thresh) { |
| *prune_vert = true; |
| } |
| |
| return 1; |
| } |
| #endif // CONFIG_EXT_RECUR_PARTITIONS |