/*
 * Copyright (c) 2021, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 3-Clause Clear License
 * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
 * License was not distributed with this source code in the LICENSE file, you
 * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/.  If the
 * Alliance for Open Media Patent License 1.0 was not distributed with this
 * source code in the PATENTS file, you can obtain it at
 * aomedia.org/license/patent-license/.
 */

#ifndef AOM_AV1_ENCODER_MODEL_RD_H_
#define AOM_AV1_ENCODER_MODEL_RD_H_

#include "aom/aom_integer.h"
#include "av1/encoder/block.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/pustats.h"
#include "av1/encoder/rdopt_utils.h"
#include "aom_ports/system_state.h"
#include "config/aom_dsp_rtcd.h"

#ifdef __cplusplus
extern "C" {
#endif

// Model RD type selected for each encoder tool. Every value must be one of
// the ModelRdType enumerators; with the current enum, 1 is MODELRD_CURVFIT
// and 0 would be MODELRD_LEGACY.
// NOTE(review): presumably these are used by callers as indices into
// model_rd_sb_fn[] / model_rd_sse_fn[] -- confirm against the call sites.
#define MODELRD_TYPE_INTERP_FILTER 1
#define MODELRD_TYPE_TX_SEARCH_PRUNE 1
#define MODELRD_TYPE_MASKED_COMPOUND 1
#define MODELRD_TYPE_INTERINTRA 1
#define MODELRD_TYPE_INTRA 1
#define MODELRD_TYPE_MOTION_MODE_RD 1
// Per-tool SSE type: 0 selects the legacy SSE, 1 selects the mean-removed SSE
// (MRSSE).
// Note: model_rd_for_sb() and model_rd_for_sb_with_curvfit() use MRSSE
// whenever cpi->oxcf.tool_cfg.enable_mrsse is set, regardless of the
// use_mrsse argument they receive, so a 0 here does not force legacy SSE in
// that case.
// NOTE(review): presumably these flags are what callers pass as use_mrsse --
// confirm against the call sites.
#if CONFIG_MRSSE
#define SSE_TYPE_INTERP_FILTER 0
#define SSE_TYPE_TX_SEARCH_PRUNE 0
#define SSE_TYPE_MASKED_COMPOUND 0
#define SSE_TYPE_INTERINTRA 0
#define SSE_TYPE_INTRA 0
#define SSE_TYPE_MOTION_MODE_RD 0
#endif  // CONFIG_MRSSE

// Signature of the block-level model RD estimators stored in
// model_rd_sb_fn[] (model_rd_for_sb and model_rd_for_sb_with_curvfit).
//
// The implementations in this file process planes plane_from..plane_to
// (inclusive). out_rate_sum and out_dist_sum are always written and must be
// non-NULL. skip_txfm_sb, skip_sse_sb, plane_rate, plane_sse and plane_dist
// are optional and may be NULL; the three plane_* arrays, when given, are
// indexed by plane number.
// When CONFIG_MRSSE is enabled an extra trailing argument, use_mrsse, asks
// for the mean-removed SSE instead of the legacy SSE.
typedef void (*model_rd_for_sb_type)(const AV1_COMP *const cpi,
                                     BLOCK_SIZE bsize, MACROBLOCK *x,
                                     MACROBLOCKD *xd, int plane_from,
                                     int plane_to, int *out_rate_sum,
                                     int64_t *out_dist_sum, int *skip_txfm_sb,
                                     int64_t *skip_sse_sb, int *plane_rate,
                                     int64_t *plane_sse, int64_t *plane_dist
#if CONFIG_MRSSE
                                     ,
                                     int use_mrsse
#endif  // CONFIG_MRSSE
);
// Signature of the per-plane model RD estimators stored in model_rd_sse_fn[]
// (model_rd_from_sse and model_rd_with_curvfit).
//
// Given the SSE of one plane of a block (`sse`), the plane's block size and
// its sample count, an implementation produces an estimated rate (*rate) and
// distortion (*dist). Which of plane_bsize / num_samples is actually used
// depends on the implementation.
typedef void (*model_rd_from_sse_type)(const AV1_COMP *const cpi,
                                       const MACROBLOCK *const x,
                                       BLOCK_SIZE plane_bsize, int plane,
                                       int64_t sse, int num_samples, int *rate,
                                       int64_t *dist);

// Returns the sum of squared errors between the source buffer of `p` and the
// dst buffer of `pd` over a bw x bh region, as reported by aom_highbd_sse()
// and then rounded down by 2 * (xd->bd - 8) bits.
static int64_t calculate_sse(MACROBLOCKD *const xd,
                             const struct macroblock_plane *p,
                             struct macroblockd_plane *pd, const int bw,
                             const int bh) {
  // A squared error carries twice the bit-depth excess over 8 bits.
  const int norm_bits = (xd->bd - 8) * 2;
  const int64_t raw_sse = aom_highbd_sse(p->src.buf, p->src.stride,
                                         pd->dst.buf, pd->dst.stride, bw, bh);
  return ROUND_POWER_OF_TWO(raw_sse, norm_bits);
}

#if CONFIG_MRSSE
// Calculates the mean-removed SSE (MRSSE) between the source buffer of `p`
// and the dst buffer of `pd` over a bw x bh region, as reported by
// aom_highbd_mrsse() and then rounded down by 2 * (xd->bd - 8) bits.
static int64_t calculate_mrsse(MACROBLOCKD *const xd,
                               const struct macroblock_plane *p,
                               struct macroblockd_plane *pd, const int bw,
                               const int bh) {
  // A squared error carries twice the bit-depth excess over 8 bits.
  const int norm_bits = (xd->bd - 8) * 2;
  const int64_t raw_mrsse = aom_highbd_mrsse(
      p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw, bh);
  return ROUND_POWER_OF_TWO(raw_mrsse, norm_bits);
}

// Common signature of the block SSE helpers (calculate_sse and
// calculate_mrsse), so that one of them can be picked through sse_fn[].
typedef int64_t (*sse_type)(MACROBLOCKD *const xd,
                            const struct macroblock_plane *p,
                            struct macroblockd_plane *pd, const int bw,
                            const int bh);

// Indices into sse_fn[]: SSE is the legacy SSE, MR_SSE the mean-removed SSE.
enum { SSE, MR_SSE, SSE_TYPES } UENUM1BYTE(SSEType);
// Read-only dispatch table indexed by SSEType. It is const, like the other
// dispatch tables in this header, so that each translation unit including the
// header gets an immutable copy that cannot be overwritten by accident.
static const sse_type sse_fn[SSE_TYPES] = { calculate_sse, calculate_mrsse };
#endif  // CONFIG_MRSSE

// Returns the SSE of one plane of the current block.
//
// The plane block size is derived from `bsize` and the plane's chroma
// subsampling, and the region actually measured is the width/height returned
// by get_txb_dimensions() for that plane block. With CONFIG_MRSSE,
// `use_mrsse` selects the mean-removed SSE instead of the legacy SSE.
static AOM_INLINE int64_t compute_sse_plane(MACROBLOCK *x, MACROBLOCKD *xd,
                                            int plane, const BLOCK_SIZE bsize
#if CONFIG_MRSSE
                                            ,
                                            bool use_mrsse
#endif  // CONFIG_MRSSE
) {
  const struct macroblock_plane *const src_plane = &x->plane[plane];
  struct macroblockd_plane *const dst_plane = &xd->plane[plane];
  const BLOCK_SIZE plane_bsize = get_plane_block_size(
      bsize, dst_plane->subsampling_x, dst_plane->subsampling_y);

  int width, height;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                     &width, &height);

#if CONFIG_MRSSE
  // false -> SSE (index 0), true -> MR_SSE (index 1).
  return sse_fn[use_mrsse ? MR_SSE : SSE](xd, src_plane, dst_plane, width,
                                          height);
#else
  return calculate_sse(xd, src_plane, dst_plane, width, height);
#endif  // CONFIG_MRSSE
}

// Legacy model RD estimator: maps the SSE of one plane to an estimated rate
// and distortion.
//
//   cpi          encoder instance; only the simple_model_rd_from_var speed
//                feature is read.
//   x            macroblock; supplies the bit depth and the plane's dequant
//                table.
//   plane_bsize  block size of the plane, used by the Laplacian model.
//   plane        plane index.
//   sse          SSE of the plane.
//   num_samples  unused by this estimator (kept for the common signature).
//   rate, dist   outputs; both must be non-NULL. *dist is returned scaled up
//                by 4 bits, the same scale as the `sse << 4` values used
//                elsewhere in this file.
static AOM_INLINE void model_rd_from_sse(const AV1_COMP *const cpi,
                                         const MACROBLOCK *const x,
                                         BLOCK_SIZE plane_bsize, int plane,
                                         int64_t sse, int num_samples,
                                         int *rate, int64_t *dist) {
  (void)num_samples;
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const int dequant_shift = xd->bd - 5;
  // Effective quantizer step. The QUANT_TABLE_BITS rounding is applied for
  // both models, matching how model_rd_with_curvfit() derives its qstep;
  // previously the Laplacian path skipped it and saw a differently scaled
  // step than the simple model.
  const int quantizer =
      ROUND_POWER_OF_TWO(p->dequant_QTX[1], QUANT_TABLE_BITS) >> dequant_shift;

  // Fast approximate the modelling function.
  if (cpi->sf.rd_sf.simple_model_rd_from_var) {
    const int64_t square_error = sse;
    if (quantizer < 120) {
      // Clamp in 64 bits before narrowing so the result always fits an int.
      *rate = (int)AOMMIN(
          (square_error * (280 - quantizer)) >> (16 - AV1_PROB_COST_SHIFT),
          INT_MAX);
    } else {
      *rate = 0;
    }
    assert(*rate >= 0);
    *dist = (square_error * quantizer) >> 8;
  } else {
    av1_model_rd_from_var_lapndz(sse, num_pels_log2_lookup[plane_bsize],
                                 quantizer, rate, dist);
  }
  *dist <<= 4;
}

// Curve-fit model RD estimator for one plane. The feature handed to
// av1_model_rd_curvfit() is log2(sse_norm / qstep^2), where sse_norm is the
// per-sample SSE. The modelled per-sample rate and distortion are scaled back
// by num_samples, and the result is replaced by the "skip" outcome
// (rate 0, distortion sse << 4) whenever the skip RD cost is not worse.
//
// `rate` and `dist` may each be NULL, in which case that output is dropped.
// `cpi` is unused.
static AOM_INLINE void model_rd_with_curvfit(const AV1_COMP *const cpi,
                                             const MACROBLOCK *const x,
                                             BLOCK_SIZE plane_bsize, int plane,
                                             int64_t sse, int num_samples,
                                             int *rate, int64_t *dist) {
  (void)cpi;
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const int dequant_shift = xd->bd - 5;
  const int normalized_dequant =
      ROUND_POWER_OF_TWO(p->dequant_QTX[1], QUANT_TABLE_BITS) >> dequant_shift;
  // Keep the step at least 1 so the division below is well defined.
  const int qstep = AOMMAX(normalized_dequant, 1);

  // Nothing to code: zero rate and zero distortion.
  if (sse == 0) {
    if (rate) *rate = 0;
    if (dist) *dist = 0;
    return;
  }

  aom_clear_system_state();
  const double sse_norm = (double)sse / num_samples;
  const double qstepsqr = (double)qstep * qstep;
  const double xqr = log2(sse_norm / qstepsqr);
  double rate_per_sample, dist_ratio;
  av1_model_rd_curvfit(plane_bsize, sse_norm, xqr, &rate_per_sample,
                       &dist_ratio);

  const double dist_per_sample = dist_ratio * sse_norm;
  // Negative model outputs are floored at zero, then rounded to nearest.
  int est_rate = (int)(AOMMAX(0.0, rate_per_sample * num_samples) + 0.5);
  int64_t est_dist =
      (int64_t)(AOMMAX(0.0, dist_per_sample * num_samples) + 0.5);
  aom_clear_system_state();

  // Check if skip is better
  const int64_t skip_dist = sse << 4;
  if (est_rate == 0 || RDCOST(x->rdmult, est_rate, est_dist) >=
                           RDCOST(x->rdmult, 0, skip_dist)) {
    est_rate = 0;
    est_dist = skip_dist;
  }

  if (rate) *rate = est_rate;
  if (dist) *dist = est_dist;
}

// Legacy block-level model RD: for each plane in [plane_from, plane_to]
// measures the SSE between the source and dst buffers over the full plane
// block and converts it to rate/distortion with model_rd_from_sse().
//
// Outputs:
//   *out_rate_sum   summed rate, clamped to INT_MAX (required).
//   *out_dist_sum   summed distortion (required).
//   *skip_txfm_sb   1 when the summed SSE is zero (optional).
//   *skip_sse_sb    summed SSE << 4 (optional).
//   plane_rate / plane_sse / plane_dist
//                   per-plane results, indexed by plane (optional).
// Side effect: the luma SSE, clamped to UINT_MAX, is stored in
// x->pred_sse[] for the block's first reference frame.
// With CONFIG_MRSSE the mean-removed SSE is used when either the encoder-wide
// enable_mrsse option or `use_mrsse` is non-zero.
static AOM_INLINE void model_rd_for_sb(const AV1_COMP *const cpi,
                                       BLOCK_SIZE bsize, MACROBLOCK *x,
                                       MACROBLOCKD *xd, int plane_from,
                                       int plane_to, int *out_rate_sum,
                                       int64_t *out_dist_sum, int *skip_txfm_sb,
                                       int64_t *skip_sse_sb, int *plane_rate,
                                       int64_t *plane_sse, int64_t *plane_dist
#if CONFIG_MRSSE
                                       ,
                                       int use_mrsse
#endif  // CONFIG_MRSSE
) {
#if CONFIG_EXT_RECUR_PARTITIONS
  (void)bsize;
#endif  // CONFIG_EXT_RECUR_PARTITIONS
  assert(bsize < BLOCK_SIZES_ALL);

  // Note our transform coeffs are 8 times an orthogonal transform.
  // Hence quantizer step is also 8 times. To get effective quantizer
  // we need to divide by 8 before sending to modeling function.
  const int ref = COMPACT_INDEX0_NRS(xd->mi[0]->ref_frame[0]);
#if CONFIG_MRSSE
  const int sse_fn_idx = cpi->oxcf.tool_cfg.enable_mrsse || use_mrsse;
#endif  // CONFIG_MRSSE

  int64_t rate_acc = 0;
  int64_t dist_acc = 0;
  int64_t sse_acc = 0;

  for (int plane = plane_from; plane <= plane_to; ++plane) {
    // Stop at the first chroma plane when this block is not a chroma
    // reference.
    if (plane != 0 && !xd->is_chroma_ref) break;

    struct macroblock_plane *const src_plane = &x->plane[plane];
    struct macroblockd_plane *const dst_plane = &xd->plane[plane];
#if CONFIG_EXT_RECUR_PARTITIONS
    const BLOCK_SIZE plane_bsize =
        get_mb_plane_block_size(xd, xd->mi[0], plane, dst_plane->subsampling_x,
                                dst_plane->subsampling_y);
#else
    const BLOCK_SIZE plane_bsize = get_plane_block_size(
        bsize, dst_plane->subsampling_x, dst_plane->subsampling_y);
#endif  // CONFIG_EXT_RECUR_PARTITIONS
    assert(plane_bsize < BLOCK_SIZES_ALL);
    const int width = block_size_wide[plane_bsize];
    const int height = block_size_high[plane_bsize];

#if CONFIG_MRSSE
    const int64_t this_sse =
        sse_fn[sse_fn_idx](xd, src_plane, dst_plane, width, height);
#else
    const int64_t this_sse =
        calculate_sse(xd, src_plane, dst_plane, width, height);
#endif  // CONFIG_MRSSE

    int this_rate;
    int64_t this_dist;
    model_rd_from_sse(cpi, x, plane_bsize, plane, this_sse, width * height,
                      &this_rate, &this_dist);

    // Remember the luma prediction error for this reference.
    if (plane == 0) {
      x->pred_sse[ref] = (unsigned int)AOMMIN(this_sse, UINT_MAX);
    }

    sse_acc += this_sse;
    rate_acc += this_rate;
    dist_acc += this_dist;
    if (plane_rate) plane_rate[plane] = this_rate;
    if (plane_sse) plane_sse[plane] = this_sse;
    if (plane_dist) plane_dist[plane] = this_dist;
    assert(rate_acc >= 0);
  }

  if (skip_txfm_sb) *skip_txfm_sb = (sse_acc == 0);
  if (skip_sse_sb) *skip_sse_sb = sse_acc << 4;
  // Clamp in 64 bits before narrowing to the int output.
  *out_rate_sum = (int)AOMMIN(rate_acc, INT_MAX);
  *out_dist_sum = dist_acc;
}

// Curve-fit block-level model RD: for each plane in [plane_from, plane_to]
// measures the SSE between the source and dst buffers and converts it to
// rate/distortion with model_rd_with_curvfit().
//
// Unlike model_rd_for_sb(), the measured region is not the nominal plane
// block: its width/height come from get_visible_dimensions() (with
// CONFIG_E191_OFS_PRED_RES_HANDLE) or get_txb_dimensions() (otherwise).
//
// Outputs:
//   *out_rate_sum   summed rate, clamped to INT_MAX (required).
//   *out_dist_sum   summed distortion (required).
//   *skip_txfm_sb   1 when the summed rate is zero (optional).
//   *skip_sse_sb    summed SSE << 4 (optional).
//   plane_rate / plane_sse / plane_dist
//                   per-plane results, indexed by plane (optional).
// Side effect: the luma SSE, clamped to UINT_MAX, is stored in
// x->pred_sse[] for the block's first reference frame.
// With CONFIG_MRSSE the mean-removed SSE is used when either the encoder-wide
// enable_mrsse option or `use_mrsse` is non-zero.
static AOM_INLINE void model_rd_for_sb_with_curvfit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int *out_rate_sum, int64_t *out_dist_sum,
    int *skip_txfm_sb, int64_t *skip_sse_sb, int *plane_rate,
    int64_t *plane_sse, int64_t *plane_dist
#if CONFIG_MRSSE
    ,
    int use_mrsse
#endif  // CONFIG_MRSSE
) {

#if CONFIG_EXT_RECUR_PARTITIONS
  (void)bsize;
#endif  // CONFIG_EXT_RECUR_PARTITIONS

  // Note our transform coeffs are 8 times an orthogonal transform.
  // Hence quantizer step is also 8 times. To get effective quantizer
  // we need to divide by 8 before sending to modeling function.
  const int ref = COMPACT_INDEX0_NRS(xd->mi[0]->ref_frame[0]);

  int64_t rate_sum = 0;
  int64_t dist_sum = 0;
  int64_t total_sse = 0;
#if CONFIG_MRSSE
  const int sse_fn_idx = cpi->oxcf.tool_cfg.enable_mrsse || use_mrsse;
#endif  // CONFIG_MRSSE

  for (int plane = plane_from; plane <= plane_to; ++plane) {
    // Stop at the first chroma plane when this block is not a chroma
    // reference.
    if (plane && !xd->is_chroma_ref) break;
    struct macroblockd_plane *const pd = &xd->plane[plane];
#if CONFIG_EXT_RECUR_PARTITIONS
    const BLOCK_SIZE plane_bsize = get_mb_plane_block_size(
        xd, xd->mi[0], plane, pd->subsampling_x, pd->subsampling_y);
#else
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
#endif  // CONFIG_EXT_RECUR_PARTITIONS
    assert(plane_bsize < BLOCK_SIZES_ALL);
    int64_t dist, sse;
    int rate;
    int bw, bh;
    const struct macroblock_plane *const p = &x->plane[plane];
#if CONFIG_E191_OFS_PRED_RES_HANDLE
    const AV1_COMMON *const cm = &cpi->common;
    const int block_width = block_size_wide[plane_bsize];
    const int block_height = block_size_high[plane_bsize];
    const int is_border_block =
        get_visible_dimensions(xd, plane, 0, 0, block_width, block_height,
                               cm->width, cm->height, &bw, &bh);
#else
    get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                       &bw, &bh);
#endif  // CONFIG_E191_OFS_PRED_RES_HANDLE
#if CONFIG_MRSSE
    sse = sse_fn[sse_fn_idx](xd, p, pd, bw, bh);
#else
#if CONFIG_E191_OFS_PRED_RES_HANDLE
    const int shift = xd->bd - 8;
    // Border blocks go through the plain C implementation.
    // NOTE(review): presumably because their visible size may not be one the
    // dispatched aom_highbd_sse() handles -- confirm.
    if (!is_border_block)
      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
                           pd->dst.stride, bw, bh);
    else
      sse = aom_highbd_sse_c(p->src.buf, p->src.stride, pd->dst.buf,
                             pd->dst.stride, bw, bh);

    sse = ROUND_POWER_OF_TWO(sse, shift * 2);
#else
    sse = calculate_sse(xd, p, pd, bw, bh);
#endif  // CONFIG_E191_OFS_PRED_RES_HANDLE
#endif  // CONFIG_MRSSE
    model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
                          &dist);

    // Remember the luma prediction error for this reference.
    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);

    total_sse += sse;
    rate_sum += rate;
    dist_sum += dist;

    if (plane_rate) plane_rate[plane] = rate;
    if (plane_sse) plane_sse[plane] = sse;
    if (plane_dist) plane_dist[plane] = dist;
  }

  if (skip_txfm_sb) *skip_txfm_sb = rate_sum == 0;
  if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
  // The per-plane rates are summed in 64 bits and the total may exceed
  // INT_MAX; clamp before narrowing, as model_rd_for_sb() does, instead of
  // relying on implementation-defined out-of-range conversion.
  rate_sum = AOMMIN(rate_sum, INT_MAX);
  *out_rate_sum = (int)rate_sum;
  *out_dist_sum = dist_sum;
}

// Available model RD estimators. The enumerator values are the indices of the
// two dispatch tables below; MODELRD_TYPES is the number of estimators.
enum { MODELRD_LEGACY, MODELRD_CURVFIT, MODELRD_TYPES } UENUM1BYTE(ModelRdType);

// Block-level estimators, indexed by ModelRdType:
// MODELRD_LEGACY -> model_rd_for_sb,
// MODELRD_CURVFIT -> model_rd_for_sb_with_curvfit.
static const model_rd_for_sb_type model_rd_sb_fn[MODELRD_TYPES] = {
  model_rd_for_sb, model_rd_for_sb_with_curvfit
};

// Per-plane SSE-to-RD estimators, indexed by ModelRdType:
// MODELRD_LEGACY -> model_rd_from_sse,
// MODELRD_CURVFIT -> model_rd_with_curvfit.
static const model_rd_from_sse_type model_rd_sse_fn[MODELRD_TYPES] = {
  model_rd_from_sse, model_rd_with_curvfit
};

#ifdef __cplusplus
}  // extern "C"
#endif
#endif  // AOM_AV1_ENCODER_MODEL_RD_H_
