/*
 * Copyright (c) 2021, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 3-Clause Clear License
 * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
 * License was not distributed with this source code in the LICENSE file, you
 * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
 * Alliance for Open Media Patent License 1.0 was not distributed with this
 * source code in the PATENTS file, you can obtain it at
 * aomedia.org/license/patent-license/.
 */

#ifndef AOM_AV1_ENCODER_ML_H_
#define AOM_AV1_ENCODER_ML_H_

#ifdef __cplusplus
extern "C" {
#endif

#include "config/av1_rtcd.h"

// Upper bound on the number of hidden layers a network configuration may
// describe; it sizes the per-layer arrays in the structures below.
#define NN_MAX_HIDDEN_LAYERS 10
// Upper bound on the number of nodes in a single hidden layer.
#define NN_MAX_NODES_PER_LAYER 128

// Neural network configuration: the layer sizes plus pointers to the weight
// and bias parameters of each layer.
struct NN_CONFIG {
  int num_inputs;         // Number of input nodes, i.e. features.
  int num_outputs;        // Number of output nodes.
  int num_hidden_layers;  // Number of hidden layers, maximum 10.
  // Number of nodes for each hidden layer.
  int num_hidden_nodes[NN_MAX_HIDDEN_LAYERS];
  // Weight parameters, indexed by layer.
  // NOTE(review): the array has NN_MAX_HIDDEN_LAYERS + 1 entries, presumably
  // one per hidden layer plus one for the output layer -- confirm against the
  // implementation.
  const float *weights[NN_MAX_HIDDEN_LAYERS + 1];
  // Bias parameters, indexed by layer (same number of entries as weights).
  const float *bias[NN_MAX_HIDDEN_LAYERS + 1];
};
// Typedef from struct NN_CONFIG to NN_CONFIG is in rtcd_defs

#if CONFIG_NN_V2
39// Fully-connectedly layer configuration
40struct FC_LAYER {
41 const int num_inputs; // Number of input nodes, i.e. features.
42 const int num_outputs; // Number of output nodes.
43
44 float *weights; // Weight parameters.
45 float *bias; // Bias parameters.
46 const ACTIVATION activation; // Activation function.
47
48 float *output; // The output array.
49 float *dY; // Gradient of outputs
50 float *dW; // Gradient of weights.
51 float *db; // Gradient of bias
52};
53
54// NN configure structure V2
55struct NN_CONFIG_V2 {
56 const int num_hidden_layers; // Number of hidden layers, max = 10.
57 FC_LAYER layer[NN_MAX_HIDDEN_LAYERS + 1]; // The layer array
58 const int num_logits; // Number of output nodes.
59 float *logits; // Raw prediction (same as output of final layer)
60 const LOSS loss; // Loss function
61};
62
63// Calculate prediction based on the given input features and neural net config.
64// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
65// layer.
66void av1_nn_predict_v2(const float *features, NN_CONFIG_V2 *nn_config,
Debargha Mukherjeed44f5d12019-06-27 14:56:05 -070067 int reduce_prec, float *output);
#endif  // CONFIG_NN_V2

// Applies the softmax normalization function to the input
// to get a valid probability distribution in the output:
// output[i] = exp(input[i]) / sum_{k \in [0,n)}(exp(input[k]))
//   input:  the n values to normalize.
//   output: receives the n normalized values.
//   n:      number of elements in input and output.
void av1_nn_softmax(const float *input, float *output, int n);

// Applies a precision reduction to output of av1_nn_predict to prevent
// mismatches between C and SIMD implementations.
//   output:     values to reduce; the pointer is non-const, so the array is
//               presumably modified in place -- confirm against the
//               implementation.
//   num_output: presumably the number of elements in output -- confirm.
void av1_nn_output_prec_reduce(float *const output, int num_output);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // AOM_AV1_ENCODER_ML_H_