blob: 840547c1862cde08b3ed48d70e5ecf178f03ef1e [file] [log] [blame] [edit]
/*
* Copyright (c) 2022, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include "config/av1_rtcd.h"
#include "av1/encoder/erp_ml.h"
#include "av1/encoder/erp_models.h"
#include "av1/encoder/ml.h"
// Helpers for the per-block-size lookup switches below. Each macro expands to
// a `case bsize:` label that returns either the `_hd_` or the non-HD variant of
// the corresponding table. The choice is made by a variable named `is_hd`,
// which must be in scope wherever the macro is expanded.
//
// (An unused tflite-model variant of these macros was removed; nothing in this
// file expanded it.)
#define MAKE_ERP_DNN_MODEL_SWITCH_CASE(bsize) \
  case bsize: \
    return is_hd ? &av1_erp_rect_hd_nn_config_##bsize \
                 : &av1_erp_rect_nn_config_##bsize;
#define MAKE_ERP_MEAN_SWITCH_CASE(bsize) \
  case bsize: \
    return is_hd ? av1_erp_rect_hd_feature_mean_##bsize \
                 : av1_erp_rect_feature_mean_##bsize;
#define MAKE_ERP_STD_SWITCH_CASE(bsize) \
  case bsize: \
    return is_hd ? av1_erp_rect_hd_feature_std_##bsize \
                 : av1_erp_rect_feature_std_##bsize;
// Returns the ERP rectangular-partition NN config for |bsize|. When |is_hd| is
// true, the `_hd_` config is returned instead of the default one. Block sizes
// without a model trip an assert (debug builds) and yield NULL.
static const NN_CONFIG *get_dnn_model(BLOCK_SIZE bsize, bool is_hd) {
  switch (bsize) {
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_8X8)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_8X16)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X8)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X16)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_16X32)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X16)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X32)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_32X64)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X32)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X64)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_64X128)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_128X64)
    MAKE_ERP_DNN_MODEL_SWITCH_CASE(BLOCK_128X128)
    default: break;
  }
  // Only reached for block sizes that have no trained model.
  assert(0 && "Invalid block size!\n");
  return NULL;
}
// Returns the per-feature mean table used to normalize the ERP model input for
// |bsize|. When |is_hd| is true, the `_hd_` table is returned instead. Block
// sizes without a table trip an assert (debug builds) and yield NULL.
static const float *get_mean(BLOCK_SIZE bsize, bool is_hd) {
  switch (bsize) {
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_8X8)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_8X16)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X8)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X16)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_16X32)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X16)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X32)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_32X64)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X32)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X64)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_64X128)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_128X64)
    MAKE_ERP_MEAN_SWITCH_CASE(BLOCK_128X128)
    default: break;
  }
  // Only reached for block sizes that have no mean table.
  assert(0 && "Invalid block size!\n");
  return NULL;
}
// Returns the per-feature standard-deviation table used to normalize the ERP
// model input for |bsize|. When |is_hd| is true, the `_hd_` table is returned
// instead. Block sizes without a table trip an assert (debug builds) and yield
// NULL.
static const float *get_std(BLOCK_SIZE bsize, bool is_hd) {
  switch (bsize) {
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_8X8)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_8X16)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X8)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X16)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_16X32)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X16)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X32)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_32X64)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X32)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X64)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_64X128)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_128X64)
    MAKE_ERP_STD_SWITCH_CASE(BLOCK_128X128)
    default: break;
  }
  // Only reached for block sizes that have no std table.
  assert(0 && "Invalid block size!\n");
  return NULL;
}
// The lookup helpers above are done with the switch-case macros; undefine all
// of them so they cannot leak into the rest of this translation unit.
// (#undef of a name that is not defined is a harmless no-op.)
#undef MAKE_ERP_MODEL_SWITCH_CASE
#undef MAKE_ERP_DNN_MODEL_SWITCH_CASE
#undef MAKE_ERP_MEAN_SWITCH_CASE
#undef MAKE_ERP_STD_SWITCH_CASE
// Z-score normalizes |num_features| values: dst[i] = (src[i] - mean[i]) / std[i].
// A feature whose std is at or below a small bound is treated as constant and
// written as 0.0f, which avoids dividing by (near) zero.
static inline void normalize(float *features_dst, const float *features_src,
                             const float *mean, const float *std,
                             size_t num_features) {
  static const float kMinStd = 0.00001f;
  for (size_t i = 0; i != num_features; ++i) {
    const float sigma = std[i];
    // Keep the "<=" test so a NaN sigma still takes the divide path.
    features_dst[i] =
        (sigma <= kMinStd) ? 0.0f : (features_src[i] - mean[i]) / sigma;
  }
}
// Returns the probability threshold below which a rectangular split is pruned
// for |bsize|. Block sizes are grouped by their larger dimension
// (128, 64, 32, 16, 8), and each group shares one threshold; |is_hd| selects
// the HD row of the table. Unsupported sizes assert (debug builds) and return
// 0.0f. Since the compared probabilities come from av1_nn_softmax (presumably
// non-negative), a 0.0f threshold effectively disables pruning.
static float get_erp_prune_thresh(BLOCK_SIZE bsize, bool is_hd) {
  static const float threshes[2][5] = {
    // Non-hd
    {
        // 128, 64, 32, 16, 8
        0.00889f,
        0.00268f,
        0.01480f,
        0.03531f,
        0.04103f,
    },
    // HD
    {
        // 128, 64, 32, 16, 8
        0.01911f,
        0.00327f,
        0.00520f,
        0.01669f,
        0.00176f,
    },
  };
  switch (bsize) {
    case BLOCK_128X128:
    case BLOCK_128X64:
    case BLOCK_64X128: return threshes[is_hd][0];
    case BLOCK_64X64:
    case BLOCK_64X32:
    case BLOCK_32X64: return threshes[is_hd][1];
    case BLOCK_32X32:
    case BLOCK_32X16:
    case BLOCK_16X32: return threshes[is_hd][2];
    case BLOCK_16X16:
    case BLOCK_16X8:
    case BLOCK_8X16: return threshes[is_hd][3];
    case BLOCK_8X8: return threshes[is_hd][4];
    default:
      assert(0 && "Unexpected block size in erp pruning model!\n");
      return 0.0f;
  }
}

// Runs the ERP rectangular-partition model for |bsize| on |features|.
// |features| must hold ERP_NUM_FEATURES (19) raw values; they are normalized
// with the per-size mean/std tables before inference.
//
// The model produces ERP_NUM_CLASSES (3) outputs, which are turned into
// probabilities with av1_nn_softmax. probs[1] is compared against the
// threshold to decide *prune_horz, and probs[2] to decide *prune_vert.
//
// The outputs are only ever set to true and never cleared, so callers must
// initialize them (typically to false) beforehand.
//
// Returns 1 when the model ran. Returns 0, without touching the outputs, when
// no model exists for |bsize|. The original code would dereference NULL in
// release builds in that case.
int av1_erp_prune_rect(BLOCK_SIZE bsize, bool is_hd, const float *features,
                       bool *prune_horz, bool *prune_vert) {
  enum { ERP_NUM_FEATURES = 19, ERP_NUM_CLASSES = 3 };

  const NN_CONFIG *nn_config = get_dnn_model(bsize, is_hd);
  const float *mean = get_mean(bsize, is_hd);
  const float *std = get_std(bsize, is_hd);
  if (nn_config == NULL || mean == NULL || std == NULL) {
    // Unsupported block size: the getters already assert in debug builds.
    return 0;
  }

  // Prepare input.
  float input[ERP_NUM_FEATURES];
  normalize(input, features, mean, std, ERP_NUM_FEATURES);

  // Run the network and convert its raw outputs to probabilities.
  float output[ERP_NUM_CLASSES];
  av1_nn_predict(input, nn_config, 1, output);
  float probs[ERP_NUM_CLASSES];
  av1_nn_softmax(output, probs, ERP_NUM_CLASSES);

  const float thresh = get_erp_prune_thresh(bsize, is_hd);
  if (probs[1] < thresh) {
    *prune_horz = true;
  }
  if (probs[2] < thresh) {
    *prune_vert = true;
  }
  return 1;
}