Separate compound_type functions from rdopt.c

Created compound_type.c and compound_type.h to
improve the modularity of rdopt.c.

compound_type.c: holds the compound type and inter_intra
related functions.
compound_type.h: holds the compound type and inter_intra
related data structures, definitions and enums.
Change-Id: I174f6f47eedd0fd101275ec56337d7bb213228f2
diff --git a/av1/av1.cmake b/av1/av1.cmake
index f599574..ff1cad2 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -130,6 +130,8 @@
"${AOM_ROOT}/av1/encoder/block.h"
"${AOM_ROOT}/av1/encoder/cnn.c"
"${AOM_ROOT}/av1/encoder/cnn.h"
+ "${AOM_ROOT}/av1/encoder/compound_type.c"
+ "${AOM_ROOT}/av1/encoder/compound_type.h"
"${AOM_ROOT}/av1/encoder/context_tree.c"
"${AOM_ROOT}/av1/encoder/context_tree.h"
"${AOM_ROOT}/av1/encoder/corner_detect.c"
diff --git a/av1/encoder/compound_type.c b/av1/encoder/compound_type.c
new file mode 100644
index 0000000..b5dca37
--- /dev/null
+++ b/av1/encoder/compound_type.c
@@ -0,0 +1,1435 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "av1/common/pred_common.h"
+#include "av1/encoder/compound_type.h"
+#include "av1/encoder/model_rd.h"
+#include "av1/encoder/motion_search.h"
+#include "av1/encoder/rdopt_utils.h"
+#include "av1/encoder/reconinter_enc.h"
+#include "av1/encoder/tx_search.h"
+
+typedef int64_t (*pick_interinter_mask_type)(
+ const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
+ const uint8_t *const p0, const uint8_t *const p1,
+ const int16_t *const residual1, const int16_t *const diff10);
+
+// Checks if the characteristics of the current search match those of a
+// previously evaluated case
+static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
+ const MACROBLOCK *const x,
+ const COMP_RD_STATS *st,
+ const MB_MODE_INFO *const mi,
+ int32_t *comp_rate, int64_t *comp_dist,
+ int32_t *comp_model_rate,
+ int64_t *comp_model_dist, int *comp_rs2) {
+  // TODO(ranjit): Ensure that the compound type search always uses the
+  // regular filter and check if the following check can be removed
+  // Check if the interp filter matches with the previous case
+ if (st->filter.as_int != mi->interp_filters.as_int) return 0;
+
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ // Match MV and reference indices
+ for (int i = 0; i < 2; ++i) {
+ if ((st->ref_frames[i] != mi->ref_frame[i]) ||
+ (st->mv[i].as_int != mi->mv[i].as_int)) {
+ return 0;
+ }
+ const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
+ if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
+ }
+
+  // Copy out the stored stats for COMPOUND_AVERAGE and COMPOUND_DISTWTD
+ for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+ comp_type++) {
+ comp_rate[comp_type] = st->rate[comp_type];
+ comp_dist[comp_type] = st->dist[comp_type];
+ comp_model_rate[comp_type] = st->model_rate[comp_type];
+ comp_model_dist[comp_type] = st->model_dist[comp_type];
+ comp_rs2[comp_type] = st->comp_rs2[comp_type];
+ }
+
+ // For compound wedge/segment, reuse data only if NEWMV is not present in
+ // either of the directions
+ if ((!have_newmv_in_inter_mode(mi->mode) &&
+ !have_newmv_in_inter_mode(st->mode)) ||
+ (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search)) {
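+    // COMPOUND_WEDGE and COMPOUND_DIFFWTD are adjacent enum values, so each
+    // two-element memcpy below copies the stats for both masked types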
+ memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
+ sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
+ memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
+ sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
+ memcpy(&comp_model_rate[COMPOUND_WEDGE], &st->model_rate[COMPOUND_WEDGE],
+ sizeof(comp_model_rate[COMPOUND_WEDGE]) * 2);
+ memcpy(&comp_model_dist[COMPOUND_WEDGE], &st->model_dist[COMPOUND_WEDGE],
+ sizeof(comp_model_dist[COMPOUND_WEDGE]) * 2);
+ memcpy(&comp_rs2[COMPOUND_WEDGE], &st->comp_rs2[COMPOUND_WEDGE],
+ sizeof(comp_rs2[COMPOUND_WEDGE]) * 2);
+ }
+ return 1;
+}
+
+// Checks if a similar compound type search case has been evaluated earlier.
+// If found, returns the relevant rd data
+static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
+ const MACROBLOCK *x,
+ const MB_MODE_INFO *const mbmi,
+ int32_t *comp_rate, int64_t *comp_dist,
+ int32_t *comp_model_rate,
+ int64_t *comp_model_dist, int *comp_rs2,
+ int *match_index) {
+ for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
+ if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
+ comp_dist, comp_model_rate, comp_model_dist,
+ comp_rs2)) {
+ *match_index = j;
+ return 1;
+ }
+ }
+  return 0; // no matching record found
+}
+
+static INLINE bool enable_wedge_search(MACROBLOCK *const x,
+ const AV1_COMP *const cpi) {
+ // Enable wedge search if source variance and edge strength are above
+ // the thresholds.
+ return x->source_variance >
+ cpi->sf.inter_sf.disable_wedge_search_var_thresh &&
+ x->edge_strength > cpi->sf.inter_sf.disable_wedge_search_edge_thresh;
+}
+
+static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
+ const AV1_COMP *const cpi) {
+ return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interinter_wedge &&
+ !cpi->sf.inter_sf.disable_interinter_wedge;
+}
+
+static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
+ const AV1_COMP *const cpi) {
+ return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interintra_wedge &&
+ !cpi->sf.inter_sf.disable_wedge_interintra_search;
+}
+
+static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
+ const BLOCK_SIZE bsize, const uint8_t *pred0,
+ int stride0, const uint8_t *pred1,
+ int stride1) {
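+  // Maps each block size to its quarter size so that the per-quadrant
+  // variances can be computed with the matching variance function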
+ static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
+ // 4X4
+ BLOCK_INVALID,
+ // 4X8, 8X4, 8X8
+ BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
+ // 8X16, 16X8, 16X16
+ BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
+ // 16X32, 32X16, 32X32
+ BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
+ // 32X64, 64X32, 64X64
+ BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
+ // 64x128, 128x64, 128x128
+ BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
+ // 4X16, 16X4, 8X32
+ BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
+ // 32X8, 16X64, 64X16
+ BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
+ };
+ const struct macroblock_plane *const p = &x->plane[0];
+ const uint8_t *src = p->src.buf;
+ int src_stride = p->src.stride;
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ const int bw_by2 = bw >> 1;
+ const int bh_by2 = bh >> 1;
+ uint32_t esq[2][2];
+ int64_t tl, br;
+
+ const BLOCK_SIZE f_index = split_qtr[bsize];
+ assert(f_index != BLOCK_INVALID);
+
+ if (is_cur_buf_hbd(&x->e_mbd)) {
+ pred0 = CONVERT_TO_BYTEPTR(pred0);
+ pred1 = CONVERT_TO_BYTEPTR(pred1);
+ }
+
+  // Residual variance computation over the relevant quadrants in order to
+ // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
+ // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
+ // The 2nd and 3rd quadrants cancel out in TL + BR
+ // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
+ // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
+ // for all codebooks; experiment with other quadrant combinations for
+ // 0, 90 and 135 degrees also.
+ cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
+ cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
+ pred0 + bh_by2 * stride0 + bw_by2, stride0,
+ &esq[0][1]);
+ cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
+ cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
+                          pred1 + bh_by2 * stride1 + bw_by2, stride1,
+ &esq[1][1]);
+
+ tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
+ br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
+ return (tl + br > 0);
+}
+
+// Choose the best wedge index and sign
+static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
+ const BLOCK_SIZE bsize, const uint8_t *const p0,
+ const int16_t *const residual1,
+ const int16_t *const diff10,
+ int8_t *const best_wedge_sign,
+ int8_t *const best_wedge_index) {
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ const struct buf_2d *const src = &x->plane[0].src;
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ const int N = bw * bh;
+ assert(N >= 64);
+ int rate;
+ int64_t dist;
+ int64_t rd, best_rd = INT64_MAX;
+ int8_t wedge_index;
+ int8_t wedge_sign;
+ const int8_t wedge_types = get_wedge_types_lookup(bsize);
+ const uint8_t *mask;
+ uint64_t sse;
+ const int hbd = is_cur_buf_hbd(xd);
+ const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
+
+ DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]); // src - pred0
+#if CONFIG_AV1_HIGHBITDEPTH
+ if (hbd) {
+ aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
+ CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
+ } else {
+ aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
+ }
+#else
+ (void)hbd;
+ aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
+#endif
+
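+  // sign_limit = (SSE(residual0) - SSE(residual1)) * (max wedge weight) / 2.
+  // av1_wedge_sign_from_residuals() compares the mask-weighted sum of the
+  // per-pixel delta squares against this limit to decide the wedge sign.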
+ int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
+ (int64_t)aom_sum_squares_i16(residual1, N)) *
+ (1 << WEDGE_WEIGHT_BITS) / 2;
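+  // residual0 is no longer needed, so its buffer is reused to hold the
+  // per-pixel difference of squared residuals used by the wedge search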
+ int16_t *ds = residual0;
+
+ av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
+
+ for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
+ mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
+
+ wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
+
+ mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
+ sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
+ sse = ROUND_POWER_OF_TWO(sse, bd_round);
+
+ model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
+ &rate, &dist);
+
+ rate += x->wedge_idx_cost[bsize][wedge_index];
+ rd = RDCOST(x->rdmult, rate, dist);
+
+ if (rd < best_rd) {
+ *best_wedge_index = wedge_index;
+ *best_wedge_sign = wedge_sign;
+ best_rd = rd;
+ }
+ }
+
+ return best_rd -
+ RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
+}
+
+// Choose the best wedge index for the specified sign
+static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi,
+ const MACROBLOCK *const x,
+ const BLOCK_SIZE bsize,
+ const int16_t *const residual1,
+ const int16_t *const diff10,
+ const int8_t wedge_sign,
+ int8_t *const best_wedge_index) {
+ const MACROBLOCKD *const xd = &x->e_mbd;
+
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ const int N = bw * bh;
+ assert(N >= 64);
+ int rate;
+ int64_t dist;
+ int64_t rd, best_rd = INT64_MAX;
+ int8_t wedge_index;
+ const int8_t wedge_types = get_wedge_types_lookup(bsize);
+ const uint8_t *mask;
+ uint64_t sse;
+ const int hbd = is_cur_buf_hbd(xd);
+ const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
+ for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
+ mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
+ sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
+ sse = ROUND_POWER_OF_TWO(sse, bd_round);
+
+ model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
+ &rate, &dist);
+ rate += x->wedge_idx_cost[bsize][wedge_index];
+ rd = RDCOST(x->rdmult, rate, dist);
+
+ if (rd < best_rd) {
+ *best_wedge_index = wedge_index;
+ best_rd = rd;
+ }
+ }
+ return best_rd -
+ RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
+}
+
+static int64_t pick_interinter_wedge(
+ const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
+ const uint8_t *const p0, const uint8_t *const p1,
+ const int16_t *const residual1, const int16_t *const diff10) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int bw = block_size_wide[bsize];
+
+ int64_t rd;
+ int8_t wedge_index = -1;
+ int8_t wedge_sign = 0;
+
+ assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+ assert(cpi->common.seq_params.enable_masked_compound);
+
+ if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
+ wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
+ rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
+ &wedge_index);
+ } else {
+ rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
+ &wedge_index);
+ }
+
+ mbmi->interinter_comp.wedge_sign = wedge_sign;
+ mbmi->interinter_comp.wedge_index = wedge_index;
+ return rd;
+}
+
+static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
+ MACROBLOCK *const x, const BLOCK_SIZE bsize,
+ const uint8_t *const p0,
+ const uint8_t *const p1,
+ const int16_t *const residual1,
+ const int16_t *const diff10) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ const int N = 1 << num_pels_log2_lookup[bsize];
+ int rate;
+ int64_t dist;
+ DIFFWTD_MASK_TYPE cur_mask_type;
+ int64_t best_rd = INT64_MAX;
+ DIFFWTD_MASK_TYPE best_mask_type = 0;
+ const int hbd = is_cur_buf_hbd(xd);
+ const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
+ DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
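+  // The DIFFWTD_38 mask is built directly in xd->seg_mask while its inverse
+  // goes into the local buffer, so the winning mask has to be copied back
+  // only when DIFFWTD_38_INV is chosen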
+ uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
+ // try each mask type and its inverse
+ for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
+ // build mask and inverse
+ if (hbd)
+ av1_build_compound_diffwtd_mask_highbd(
+ tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
+ CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
+ else
+ av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
+ p0, bw, p1, bw, bh, bw);
+
+ // compute rd for mask
+ uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
+ tmp_mask[cur_mask_type], N);
+ sse = ROUND_POWER_OF_TWO(sse, bd_round);
+
+ model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
+ &rate, &dist);
+ const int64_t rd0 = RDCOST(x->rdmult, rate, dist);
+
+ if (rd0 < best_rd) {
+ best_mask_type = cur_mask_type;
+ best_rd = rd0;
+ }
+ }
+ mbmi->interinter_comp.mask_type = best_mask_type;
+ if (best_mask_type == DIFFWTD_38_INV) {
+ memcpy(xd->seg_mask, seg_mask, N * 2);
+ }
+ return best_rd;
+}
+
+static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
+ const MACROBLOCK *const x,
+ const BLOCK_SIZE bsize,
+ const uint8_t *const p0,
+ const uint8_t *const p1) {
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ assert(av1_is_wedge_used(bsize));
+ assert(cpi->common.seq_params.enable_interintra_compound);
+
+ const struct buf_2d *const src = &x->plane[0].src;
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1
+ DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0
+#if CONFIG_AV1_HIGHBITDEPTH
+ if (is_cur_buf_hbd(xd)) {
+ aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
+ CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
+ aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
+ CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
+ } else {
+ aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
+ aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
+ }
+#else
+ aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
+ aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
+#endif
+ int8_t wedge_index = -1;
+ int64_t rd =
+ pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index);
+
+ mbmi->interintra_wedge_index = wedge_index;
+ return rd;
+}
+
+static AOM_INLINE void get_inter_predictors_masked_compound(
+ MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
+ int16_t *residual1, int16_t *diff10, int *strides) {
+ MACROBLOCKD *xd = &x->e_mbd;
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ // get inter predictors to use for masked compound modes
+ av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
+ strides);
+ av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
+ strides);
+ const struct buf_2d *const src = &x->plane[0].src;
+#if CONFIG_AV1_HIGHBITDEPTH
+ if (is_cur_buf_hbd(xd)) {
+ aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
+ CONVERT_TO_BYTEPTR(*preds1), bw, xd->bd);
+ aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
+ bw, CONVERT_TO_BYTEPTR(*preds0), bw, xd->bd);
+ } else {
+ aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
+ bw);
+ aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
+ }
+#else
+ aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
+ aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
+#endif
+}
+
+// Computes the rd cost for the given interintra mode and updates the best
+// mode and rd found so far if this one is better
+static INLINE void compute_best_interintra_mode(
+ const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
+ MACROBLOCK *const x, const int *const interintra_mode_cost,
+ const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
+ INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
+ INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
+ const AV1_COMMON *const cm = &cpi->common;
+ int rate, skip_txfm_sb;
+ int64_t dist, skip_sse_sb;
+ const int bw = block_size_wide[bsize];
+ mbmi->interintra_mode = interintra_mode;
+ int rmode = interintra_mode_cost[interintra_mode];
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+ model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
+ &skip_txfm_sb, &skip_sse_sb, NULL,
+ NULL, NULL);
+ int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
+ if (rd < *best_interintra_rd) {
+ *best_interintra_rd = rd;
+ *best_interintra_mode = mbmi->interintra_mode;
+ }
+}
+
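+// Estimates the luma rd cost of the block using a low-complexity transform
+// search, bounded by ref_best_rd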
+static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
+ MACROBLOCK *x, int64_t ref_best_rd,
+ RD_STATS *rd_stats) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ if (ref_best_rd < 0) return INT64_MAX;
+ av1_subtract_plane(x, bs, 0);
+ x->rd_model = LOW_TXFM_RD;
+ int skip_trellis = cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
+ NO_ESTIMATE_YRD_TRELLIS_OPT;
+ const int64_t rd =
+ txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs, max_txsize_rect_lookup[bs],
+ FTXS_NONE, skip_trellis);
+ x->rd_model = FULL_TXFM_RD;
+ if (rd != INT64_MAX) {
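+    // Account for the cost of signaling the skip flag: the rate collapses to
+    // the skip cost when the whole block is skipped; otherwise the no-skip
+    // cost is added to the transform rate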
+ const int skip_ctx = av1_get_skip_context(xd);
+ if (rd_stats->skip) {
+ const int s1 = x->skip_cost[skip_ctx][1];
+ rd_stats->rate = s1;
+ } else {
+ const int s0 = x->skip_cost[skip_ctx][0];
+ rd_stats->rate += s0;
+ }
+ }
+ return rd;
+}
+
+// Computes the rd_threshold and total_mode_rate
+static AOM_INLINE int64_t compute_total_rate_and_rd_thresh(
+ MACROBLOCK *const x, int *rate_mv, int *total_mode_rate, BLOCK_SIZE bsize,
+ int64_t ref_best_rd, int rmode) {
+ const int is_wedge_used = av1_is_wedge_used(bsize);
+ const int64_t rd_thresh = get_rd_thresh_from_best_rd(
+ ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
+ INTER_INTRA_RD_THRESH_SCALE);
+ const int rwedge = is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0;
+ *total_mode_rate = *rate_mv + rmode + rwedge;
+ const int64_t mode_rd = RDCOST(x->rdmult, *total_mode_rate, 0);
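+  // Return the rd budget left for the transform search after paying the mode
+  // signaling cost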
+ return (rd_thresh - mode_rd);
+}
+
+// Computes the best wedge interintra mode
+static AOM_INLINE int64_t compute_best_wedge_interintra(
+ const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
+ MACROBLOCK *const x, const int *const interintra_mode_cost,
+ const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
+ int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
+ const AV1_COMMON *const cm = &cpi->common;
+ const int bw = block_size_wide[bsize];
+ int64_t best_interintra_rd_wedge = INT64_MAX;
+ int64_t best_total_rd = INT64_MAX;
+ uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
+ for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
+ mbmi->interintra_mode = mode;
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
+ const int rate_overhead =
+ interintra_mode_cost[mode] +
+ x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
+ const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
+ if (total_rd < best_total_rd) {
+ best_total_rd = total_rd;
+ best_interintra_rd_wedge = rd;
+ *best_mode = mbmi->interintra_mode;
+ *best_wedge_index = mbmi->interintra_wedge_index;
+ }
+ }
+ return best_interintra_rd_wedge;
+}
+
+int handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+ BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
+ HandleInterModeArgs *args, int64_t ref_best_rd,
+ int *rate_mv, int *tmp_rate2,
+ const BUFFER_SET *orig_dst) {
+ const int try_smooth_interintra = cpi->oxcf.enable_smooth_interintra &&
+ !cpi->sf.inter_sf.disable_smooth_interintra;
+ const int try_wedge_interintra =
+ av1_is_wedge_used(bsize) && enable_wedge_interintra_search(x, cpi);
+ if (!try_smooth_interintra && !try_wedge_interintra) return -1;
+
+ const AV1_COMMON *const cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ int64_t rd = INT64_MAX;
+ const int bw = block_size_wide[bsize];
+ DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
+ DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
+ uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
+ uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
+ const int *const interintra_mode_cost =
+ x->interintra_mode_cost[size_group_lookup[bsize]];
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+
+ // Single reference inter prediction
+ mbmi->ref_frame[1] = NONE_FRAME;
+ xd->plane[0].dst.buf = tmp_buf;
+ xd->plane[0].dst.stride = bw;
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ const int num_planes = av1_num_planes(cm);
+
+ // Restore the buffers for intra prediction
+ restore_dst_buf(xd, *orig_dst, num_planes);
+ mbmi->ref_frame[1] = INTRA_FRAME;
+ INTERINTRA_MODE best_interintra_mode =
+ args->inter_intra_mode[mbmi->ref_frame[0]];
+
+ // Compute smooth_interintra
+ int64_t best_interintra_rd_nowedge = INT64_MAX;
+ if (try_smooth_interintra) {
+ mbmi->use_wedge_interintra = 0;
+ int interintra_mode_reuse = 1;
+ if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
+ best_interintra_mode == INTERINTRA_MODES) {
+ interintra_mode_reuse = 0;
+ int64_t best_interintra_rd = INT64_MAX;
+ for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
+ ++cur_mode) {
+ if ((!cpi->oxcf.enable_smooth_intra ||
+ cpi->sf.intra_sf.disable_smooth_intra) &&
+ cur_mode == II_SMOOTH_PRED)
+ continue;
+ compute_best_interintra_mode(cpi, mbmi, xd, x, interintra_mode_cost,
+ orig_dst, intrapred, tmp_buf,
+ &best_interintra_mode, &best_interintra_rd,
+ cur_mode, bsize);
+ }
+ args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
+ }
+ assert(IMPLIES(!cpi->oxcf.enable_smooth_interintra ||
+ cpi->sf.inter_sf.disable_smooth_interintra,
+ best_interintra_mode != II_SMOOTH_PRED));
+ int rmode = interintra_mode_cost[best_interintra_mode];
+ // Recompute prediction if required
+ if (interintra_mode_reuse || best_interintra_mode != INTERINTRA_MODES - 1) {
+ mbmi->interintra_mode = best_interintra_mode;
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+ }
+
+ // Compute rd cost for best smooth_interintra
+ RD_STATS rd_stats;
+ int total_mode_rate;
+ const int64_t rd_thresh = compute_total_rate_and_rd_thresh(
+ x, rate_mv, &total_mode_rate, bsize, ref_best_rd, rmode);
+ rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
+ if (rd != INT64_MAX) {
+ rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
+ } else {
+ return -1;
+ }
+ best_interintra_rd_nowedge = rd;
+ // Return early if best_interintra_rd_nowedge not good enough
+ if (ref_best_rd < INT64_MAX &&
+ (best_interintra_rd_nowedge >> INTER_INTRA_RD_THRESH_SHIFT) *
+ INTER_INTRA_RD_THRESH_SCALE >
+ ref_best_rd) {
+ return -1;
+ }
+ }
+
+ // Compute wedge interintra
+ int64_t best_interintra_rd_wedge = INT64_MAX;
+ if (try_wedge_interintra) {
+ mbmi->use_wedge_interintra = 1;
+ if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
+ // Exhaustive search of all wedge and mode combinations.
+ int best_mode = 0;
+ int best_wedge_index = 0;
+ best_interintra_rd_wedge = compute_best_wedge_interintra(
+ cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_,
+ tmp_buf_, &best_mode, &best_wedge_index, bsize);
+ mbmi->interintra_mode = best_mode;
+ mbmi->interintra_wedge_index = best_wedge_index;
+ if (best_mode != INTERINTRA_MODES - 1) {
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ }
+ } else if (!try_smooth_interintra) {
+ if (best_interintra_mode == INTERINTRA_MODES) {
+ mbmi->interintra_mode = INTERINTRA_MODES - 1;
+ best_interintra_mode = INTERINTRA_MODES - 1;
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ // Pick wedge mask based on INTERINTRA_MODES - 1
+ best_interintra_rd_wedge =
+ pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
+ // Find the best interintra mode for the chosen wedge mask
+ for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
+ ++cur_mode) {
+ compute_best_interintra_mode(
+ cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
+ tmp_buf, &best_interintra_mode, &best_interintra_rd_wedge,
+ cur_mode, bsize);
+ }
+ args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
+ mbmi->interintra_mode = best_interintra_mode;
+
+ // Recompute prediction if required
+ if (best_interintra_mode != INTERINTRA_MODES - 1) {
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ }
+ } else {
+ // Pick wedge mask for the best interintra mode (reused)
+ mbmi->interintra_mode = best_interintra_mode;
+ av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
+ intrapred, bw);
+ best_interintra_rd_wedge =
+ pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
+ }
+ } else {
+ // Pick wedge mask for the best interintra mode from smooth_interintra
+ best_interintra_rd_wedge =
+ pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
+ }
+
+ const int rate_overhead =
+ interintra_mode_cost[mbmi->interintra_mode] +
+ x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
+ x->wedge_interintra_cost[bsize][1];
+ best_interintra_rd_wedge += RDCOST(x->rdmult, rate_overhead + *rate_mv, 0);
+
+ const int_mv mv0 = mbmi->mv[0];
+ int_mv tmp_mv = mv0;
+ rd = INT64_MAX;
+ int tmp_rate_mv = 0;
+ // Refine motion vector for NEWMV case.
+ if (have_newmv_in_inter_mode(mbmi->mode)) {
+ int rate_sum, skip_txfm_sb;
+ int64_t dist_sum, skip_sse_sb;
+    // Get the inverse of the wedge mask
+ const uint8_t *mask =
+ av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
+ compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, intrapred,
+ mask, bw, &tmp_rate_mv, 0);
+ if (mbmi->mv[0].as_int != tmp_mv.as_int) {
+ mbmi->mv[0].as_int = tmp_mv.as_int;
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+ cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
+ &skip_sse_sb, NULL, NULL, NULL);
+ rd =
+ RDCOST(x->rdmult, tmp_rate_mv + rate_overhead + rate_sum, dist_sum);
+ }
+ }
+ if (rd >= best_interintra_rd_wedge) {
+ tmp_mv.as_int = mv0.as_int;
+ tmp_rate_mv = *rate_mv;
+ av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+ }
+ // Evaluate closer to true rd
+ RD_STATS rd_stats;
+ const int64_t mode_rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv, 0);
+ const int64_t tmp_rd_thresh = best_interintra_rd_nowedge - mode_rd;
+ rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
+ if (rd != INT64_MAX) {
+ rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv + rd_stats.rate,
+ rd_stats.dist);
+ } else {
+ if (best_interintra_rd_nowedge == INT64_MAX) return -1;
+ }
+ best_interintra_rd_wedge = rd;
+ if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
+ mbmi->mv[0].as_int = tmp_mv.as_int;
+ *tmp_rate2 += tmp_rate_mv - *rate_mv;
+ *rate_mv = tmp_rate_mv;
+ } else {
+ mbmi->use_wedge_interintra = 0;
+ mbmi->interintra_mode = best_interintra_mode;
+ mbmi->mv[0].as_int = mv0.as_int;
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ }
+ }
+
+ if (best_interintra_rd_nowedge == INT64_MAX &&
+ best_interintra_rd_wedge == INT64_MAX) {
+ return -1;
+ }
+
+ if (num_planes > 1) {
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_U, num_planes - 1);
+ }
+ return 0;
+}
+
+static void alloc_compound_type_rd_buffers_no_check(
+ CompoundTypeRdBuffers *const bufs) {
+ bufs->pred0 =
+ (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0));
+ bufs->pred1 =
+ (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1));
+ bufs->residual1 =
+ (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->residual1));
+ bufs->diff10 =
+ (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->diff10));
+ bufs->tmp_best_mask_buf = (uint8_t *)aom_malloc(
+ 2 * MAX_SB_SQUARE * sizeof(*bufs->tmp_best_mask_buf));
+}
+
+// Computes the valid compound_types to be evaluated
+static INLINE int compute_valid_comp_types(
+ MACROBLOCK *x, const AV1_COMP *const cpi, int *try_average_and_distwtd_comp,
+ BLOCK_SIZE bsize, int masked_compound_used, int mode_search_mask,
+ COMPOUND_TYPE *valid_comp_types) {
+ const AV1_COMMON *cm = &cpi->common;
+ int valid_type_count = 0;
+ int comp_type, valid_check;
+ int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
+
+ const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
+ const int try_distwtd_comp =
+ ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
+ cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
+ cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
+ *try_average_and_distwtd_comp = try_average_comp && try_distwtd_comp;
+
+ // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
+ for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+ comp_type++) {
+ valid_check =
+ (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
+ if (!*try_average_and_distwtd_comp && valid_check &&
+ is_interinter_compound_used(comp_type, bsize))
+ valid_comp_types[valid_type_count++] = comp_type;
+ }
+ // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
+ if (masked_compound_used) {
+ // enable_masked_type[0] corresponds to COMPOUND_WEDGE
+ // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
+ enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
+ enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
+ for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
+ comp_type++) {
+ if ((mode_search_mask & (1 << comp_type)) &&
+ is_interinter_compound_used(comp_type, bsize) &&
+ enable_masked_type[comp_type - COMPOUND_WEDGE])
+ valid_comp_types[valid_type_count++] = comp_type;
+ }
+ }
+ return valid_type_count;
+}
+
+// Calculates the cost for compound type mask
+static INLINE void calc_masked_type_cost(MACROBLOCK *x, BLOCK_SIZE bsize,
+ int comp_group_idx_ctx,
+ int comp_index_ctx,
+ int masked_compound_used,
+ int *masked_type_cost) {
+ av1_zero_array(masked_type_cost, COMPOUND_TYPES);
+ // Account for group index cost when wedge and/or diffwtd prediction are
+ // enabled
+ if (masked_compound_used) {
+ // Compound group index of average and distwtd is 0
+ // Compound group index of wedge and diffwtd is 1
+ masked_type_cost[COMPOUND_AVERAGE] +=
+ x->comp_group_idx_cost[comp_group_idx_ctx][0];
+ masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
+ masked_type_cost[COMPOUND_WEDGE] +=
+ x->comp_group_idx_cost[comp_group_idx_ctx][1];
+ masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
+ }
+
+ // Compute the cost to signal compound index/type
+ masked_type_cost[COMPOUND_AVERAGE] += x->comp_idx_cost[comp_index_ctx][1];
+ masked_type_cost[COMPOUND_DISTWTD] += x->comp_idx_cost[comp_index_ctx][0];
+ masked_type_cost[COMPOUND_WEDGE] += x->compound_type_cost[bsize][0];
+ masked_type_cost[COMPOUND_DIFFWTD] += x->compound_type_cost[bsize][1];
+}
+
+// Updates mbmi structure with the relevant compound type info
+static INLINE void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
+ COMPOUND_TYPE cur_type) {
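+  // comp_group_idx 1 signals a masked compound (wedge/diffwtd). When it is 0,
+  // compound_idx 0 selects distance-weighted and 1 selects simple averaging.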
+ mbmi->interinter_comp.type = cur_type;
+ mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
+ mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
+}
+
+// When a match is found, populate the compound type data, calculate the rd
+// cost using the stored stats and update the mbmi appropriately.
+static INLINE int populate_reuse_comp_type_data(
+ const MACROBLOCK *x, MB_MODE_INFO *mbmi,
+ BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
+ int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
+ int match_index) {
+ const int winner_comp_type =
+ x->comp_rd_stats[match_index].interinter_comp.type;
+ if (comp_rate[winner_comp_type] == INT_MAX)
+ return best_type_stats->best_compmode_interinter_cost;
+ update_mbmi_for_compound_type(mbmi, winner_comp_type);
+ mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
+ *rd = RDCOST(
+ x->rdmult,
+ comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
+ comp_dist[winner_comp_type]);
+ mbmi->mv[0].as_int = cur_mv[0].as_int;
+ mbmi->mv[1].as_int = cur_mv[1].as_int;
+ return comp_rs2[winner_comp_type];
+}
+
+// Updates rd cost and relevant compound type data for the best compound type
+static INLINE void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
+ BEST_COMP_TYPE_STATS *best_type_stats,
+ int64_t best_rd_cur,
+ int64_t comp_model_rd_cur, int rs2) {
+ *rd = best_rd_cur;
+ best_type_stats->comp_best_model_rd = comp_model_rd_cur;
+ best_type_stats->best_compound_data = mbmi->interinter_comp;
+ best_type_stats->best_compmode_interinter_cost = rs2;
+}
+
+// Updates best_mv for masked compound types
+static INLINE void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
+ int_mv *best_mv, int_mv *cur_mv,
+ const COMPOUND_TYPE cur_type,
+ int *best_tmp_rate_mv, int tmp_rate_mv,
+ const SPEED_FEATURES *const sf) {
+ if (cur_type == COMPOUND_WEDGE ||
+ (sf->inter_sf.enable_interinter_diffwtd_newmv_search &&
+ cur_type == COMPOUND_DIFFWTD)) {
+ *best_tmp_rate_mv = tmp_rate_mv;
+ best_mv[0].as_int = mbmi->mv[0].as_int;
+ best_mv[1].as_int = mbmi->mv[1].as_int;
+ } else {
+ best_mv[0].as_int = cur_mv[0].as_int;
+ best_mv[1].as_int = cur_mv[1].as_int;
+ }
+}
+
+// Chooses the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based on the
+// modeled cost
+static int find_best_avg_distwtd_comp_type(MACROBLOCK *x, int *comp_model_rate,
+ int64_t *comp_model_dist,
+ int rate_mv, int64_t *best_rd) {
+ int64_t est_rd[2];
+ est_rd[COMPOUND_AVERAGE] =
+ RDCOST(x->rdmult, comp_model_rate[COMPOUND_AVERAGE] + rate_mv,
+ comp_model_dist[COMPOUND_AVERAGE]);
+ est_rd[COMPOUND_DISTWTD] =
+ RDCOST(x->rdmult, comp_model_rate[COMPOUND_DISTWTD] + rate_mv,
+ comp_model_dist[COMPOUND_DISTWTD]);
+ int best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
+ ? COMPOUND_AVERAGE
+ : COMPOUND_DISTWTD;
+ *best_rd = est_rd[best_type];
+ return best_type;
+}
+
+static INLINE void save_comp_rd_search_stat(
+ MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
+ const int64_t *comp_dist, const int32_t *comp_model_rate,
+ const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
+ const int offset = x->comp_rd_stats_idx;
+ if (offset < MAX_COMP_RD_STATS) {
+ COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
+ memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
+ memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
+ memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
+ memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
+ memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
+ memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
+ memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
+ rd_stats->mode = mbmi->mode;
+ rd_stats->filter = mbmi->interp_filters;
+ rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ for (int i = 0; i < 2; ++i) {
+ const WarpedMotionParams *const wm =
+ &xd->global_motion[mbmi->ref_frame[i]];
+ rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
+ }
+ memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
+ sizeof(rd_stats->interinter_comp));
+ ++x->comp_rd_stats_idx;
+ }
+}
+
+static INLINE int get_interinter_compound_mask_rate(
+ const MACROBLOCK *const x, const MB_MODE_INFO *const mbmi) {
+ const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
+ // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
+ if (compound_type == COMPOUND_WEDGE) {
+ return av1_is_wedge_used(mbmi->sb_type)
+ ? av1_cost_literal(1) +
+ x->wedge_idx_cost[mbmi->sb_type]
+ [mbmi->interinter_comp.wedge_index]
+ : 0;
+ } else {
+ assert(compound_type == COMPOUND_DIFFWTD);
+ return av1_cost_literal(1);
+ }
+}
+
+// Takes a backup of rate, distortion and model_rd for future reuse
+static INLINE void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
+ int64_t *comp_dist, int32_t *comp_model_rate,
+ int64_t *comp_model_dist, int rate_sum,
+ int64_t dist_sum, RD_STATS *rd_stats,
+ int *comp_rs2, int rs2) {
+ comp_rate[cur_type] = rd_stats->rate;
+ comp_dist[cur_type] = rd_stats->dist;
+ comp_model_rate[cur_type] = rate_sum;
+ comp_model_dist[cur_type] = dist_sum;
+ comp_rs2[cur_type] = rs2;
+}
+
+static int64_t masked_compound_type_rd(
+ const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
+ const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
+ int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
+ uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
+ int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
+ int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
+ int64_t *comp_model_dist, const int64_t comp_best_model_rd,
+ int64_t *const comp_model_rd_cur, int *comp_rs2) {
+ const AV1_COMMON *const cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ int64_t best_rd_cur = INT64_MAX;
+ int64_t rd = INT64_MAX;
+ const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
+ // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
+ assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
+ int rate_sum, tmp_skip_txfm_sb;
+ int64_t dist_sum, tmp_skip_sse_sb;
+ pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
+ pick_interinter_seg };
+
+ // TODO(any): Save pred and mask calculation as well into records. However
+ // this may increase memory requirements as compound segment mask needs to be
+ // stored in each record.
+ if (*calc_pred_masked_compound) {
+ get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
+ diff10, strides);
+ *calc_pred_masked_compound = 0;
+ }
+ if (cpi->sf.inter_sf.prune_wedge_pred_diff_based &&
+ compound_type == COMPOUND_WEDGE) {
+ unsigned int sse;
+ if (is_cur_buf_hbd(xd))
+ (void)cpi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
+ CONVERT_TO_BYTEPTR(*preds1), *strides, &sse);
+ else
+ (void)cpi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides, &sse);
+ const unsigned int mse =
+ ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
+ // If two predictors are very similar, skip wedge compound mode search
+ if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
+ *comp_model_rd_cur = INT64_MAX;
+ return INT64_MAX;
+ }
+ }
+ // Function pointer to pick the appropriate mask
+ // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
+ // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
+ best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
+ cpi, x, bsize, *preds0, *preds1, residual1, diff10);
+ *rs2 += get_interinter_compound_mask_rate(x, mbmi);
+ best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
+
+  // Although the true rate_mv might differ after motion search, it is
+  // unlikely to be the best mode considering the transform rd cost and other
+  // mode overhead costs
+ int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
+ if (mode_rd > rd_thresh) {
+ *comp_model_rd_cur = INT64_MAX;
+ return INT64_MAX;
+ }
+
+  // Compute the cost if a matching record is not found; otherwise reuse data
+ if (comp_rate[compound_type] == INT_MAX) {
+ // Check whether new MV search for wedge is to be done
+ int wedge_newmv_search =
+ have_newmv_in_inter_mode(this_mode) &&
+ (compound_type == COMPOUND_WEDGE) &&
+ (!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
+ int diffwtd_newmv_search =
+ cpi->sf.inter_sf.enable_interinter_diffwtd_newmv_search &&
+ compound_type == COMPOUND_DIFFWTD &&
+ have_newmv_in_inter_mode(this_mode);
+
+ // Search for new MV if needed and build predictor
+ if (wedge_newmv_search) {
+ *out_rate_mv =
+ interinter_compound_motion_search(cpi, x, cur_mv, bsize, this_mode);
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ } else if (diffwtd_newmv_search) {
+ *out_rate_mv =
+ interinter_compound_motion_search(cpi, x, cur_mv, bsize, this_mode);
+ // we need to update the mask according to the new motion vector
+ CompoundTypeRdBuffers tmp_buf;
+ int64_t tmp_rd = INT64_MAX;
+ alloc_compound_type_rd_buffers_no_check(&tmp_buf);
+
+ uint8_t *tmp_preds0[1] = { tmp_buf.pred0 };
+ uint8_t *tmp_preds1[1] = { tmp_buf.pred1 };
+
+ get_inter_predictors_masked_compound(x, bsize, tmp_preds0, tmp_preds1,
+ tmp_buf.residual1, tmp_buf.diff10,
+ strides);
+
+ tmp_rd = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
+ cpi, x, bsize, *tmp_preds0, *tmp_preds1, tmp_buf.residual1,
+ tmp_buf.diff10);
+ // we can reuse rs2 here
+ tmp_rd += RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
+
+ if (tmp_rd >= best_rd_cur) {
+ // restore the motion vector
+ mbmi->mv[0].as_int = cur_mv[0].as_int;
+ mbmi->mv[1].as_int = cur_mv[1].as_int;
+ *out_rate_mv = rate_mv;
+ av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
+ strides, preds1, strides);
+ } else {
+        // Build the final prediction using the updated mv
+ av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, tmp_preds0,
+ strides, tmp_preds1, strides);
+ }
+ av1_release_compound_type_rd_buffers(&tmp_buf);
+ } else {
+ *out_rate_mv = rate_mv;
+ av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
+ preds1, strides);
+ }
+ // Get the RD cost from model RD
+ model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+ cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
+ &tmp_skip_sse_sb, NULL, NULL, NULL);
+ rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
+ *comp_model_rd_cur = rd;
+    // Revert to the original MV if the model rd with the refined MV is worse
+ if (wedge_newmv_search) {
+ if (rd >= best_rd_cur) {
+ mbmi->mv[0].as_int = cur_mv[0].as_int;
+ mbmi->mv[1].as_int = cur_mv[1].as_int;
+ *out_rate_mv = rate_mv;
+ av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
+ strides, preds1, strides);
+ *comp_model_rd_cur = best_rd_cur;
+ }
+ }
+ if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
+ (*comp_model_rd_cur > comp_best_model_rd) &&
+ comp_best_model_rd != INT64_MAX) {
+ *comp_model_rd_cur = INT64_MAX;
+ return INT64_MAX;
+ }
+ // Compute RD cost for the current type
+ RD_STATS rd_stats;
+ const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
+ const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
+ rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
+ if (rd != INT64_MAX) {
+ rd =
+ RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
+ // Backup rate and distortion for future reuse
+ backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
+ comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
+ *rs2);
+ }
+ } else {
+ // Reuse data as matching record is found
+ assert(comp_dist[compound_type] != INT64_MAX);
+ // When disable_interinter_wedge_newmv_search is set, motion refinement is
+ // disabled. Hence rate and distortion can be reused in this case as well
+ assert(IMPLIES(have_newmv_in_inter_mode(this_mode),
+ cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
+ assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
+ assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
+ *out_rate_mv = rate_mv;
+ // Calculate RD cost based on stored stats
+ rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
+ comp_dist[compound_type]);
+ // Recalculate model rdcost with the updated rate
+ *comp_model_rd_cur =
+ RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
+ comp_model_dist[compound_type]);
+ }
+ return rd;
+}
+
+int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
+ int_mv *cur_mv, int mode_search_mask,
+ int masked_compound_used, const BUFFER_SET *orig_dst,
+ const BUFFER_SET *tmp_dst,
+ const CompoundTypeRdBuffers *buffers, int *rate_mv,
+ int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
+ int *is_luma_interp_done, int64_t rd_thresh) {
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ MB_MODE_INFO *mbmi = xd->mi[0];
+ const PREDICTION_MODE this_mode = mbmi->mode;
+ const int bw = block_size_wide[bsize];
+ int rs2;
+ int_mv best_mv[2];
+ int best_tmp_rate_mv = *rate_mv;
+ BEST_COMP_TYPE_STATS best_type_stats;
+ // Initializing BEST_COMP_TYPE_STATS
+ best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
+ best_type_stats.best_compmode_interinter_cost = 0;
+ best_type_stats.comp_best_model_rd = INT64_MAX;
+
+ uint8_t *preds0[1] = { buffers->pred0 };
+ uint8_t *preds1[1] = { buffers->pred1 };
+ int strides[1] = { bw };
+ int tmp_rate_mv;
+ const int num_pix = 1 << num_pels_log2_lookup[bsize];
+ const int mask_len = 2 * num_pix * sizeof(uint8_t);
+ COMPOUND_TYPE cur_type;
+ // Local array to store the mask cost for different compound types
+ int masked_type_cost[COMPOUND_TYPES];
+
+ int calc_pred_masked_compound = 1;
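+  // Per-compound-type caches of rd stats; INT_MAX / INT64_MAX mark entries
+  // that have not been computed yet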
+ int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
+ INT64_MAX };
+ int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
+ int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
+ int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
+ INT_MAX };
+ int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
+ INT64_MAX };
+ int match_index = 0;
+ const int match_found =
+ find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
+ comp_model_dist, comp_rs2, &match_index);
+ best_mv[0].as_int = cur_mv[0].as_int;
+ best_mv[1].as_int = cur_mv[1].as_int;
+ *rd = INT64_MAX;
+ int rate_sum, tmp_skip_txfm_sb;
+ int64_t dist_sum, tmp_skip_sse_sb;
+
+ // Local array to store the valid compound types to be evaluated in the core
+ // loop
+ COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
+ COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
+ };
+ int valid_type_count = 0;
+ int try_average_and_distwtd_comp = 0;
+ // compute_valid_comp_types() returns the number of valid compound types to be
+ // evaluated and populates the same in the local array valid_comp_types[].
+ // It also sets the flag 'try_average_and_distwtd_comp'
+ valid_type_count = compute_valid_comp_types(
+ x, cpi, &try_average_and_distwtd_comp, bsize, masked_compound_used,
+ mode_search_mask, valid_comp_types);
+
+ // The following context indices are independent of compound type
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+
+ // Populates masked_type_cost local array for the 4 compound types
+ calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
+ masked_compound_used, masked_type_cost);
+
+ int64_t comp_model_rd_cur = INT64_MAX;
+ int64_t best_rd_cur = INT64_MAX;
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+
+ // If the match is found, calculate the rd cost using the
+ // stored stats and update the mbmi appropriately.
+ if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
+ return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
+ comp_rate, comp_dist, comp_rs2,
+ rate_mv, rd, match_index);
+ }
+ // Special handling if both compound_average and compound_distwtd
+ // are to be searched. In this case, first estimate between the two
+ // modes and then call estimate_yrd_for_sb() only for the better of
+ // the two.
+ if (try_average_and_distwtd_comp) {
+ int est_rate[2];
+ int64_t est_dist[2], est_rd;
+ COMPOUND_TYPE best_type;
+    // Since the modeled rate and dist are stored separately, compute the
+    // better of COMPOUND_AVERAGE and COMPOUND_DISTWTD using the stored
+    // stats.
+ if ((comp_model_rate[COMPOUND_AVERAGE] != INT_MAX) &&
+ comp_model_rate[COMPOUND_DISTWTD] != INT_MAX) {
+ // Choose the better of the COMPOUND_AVERAGE,
+ // COMPOUND_DISTWTD on modeled cost.
+ best_type = find_best_avg_distwtd_comp_type(
+ x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
+ update_mbmi_for_compound_type(mbmi, best_type);
+ if (comp_rate[best_type] != INT_MAX)
+ best_rd_cur = RDCOST(
+ x->rdmult,
+ masked_type_cost[best_type] + *rate_mv + comp_rate[best_type],
+ comp_dist[best_type]);
+ comp_model_rd_cur = est_rd;
+ // Update stats for best compound type
+ if (best_rd_cur < *rd) {
+ update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
+ comp_model_rd_cur, masked_type_cost[best_type]);
+ }
+ restore_dst_buf(xd, *tmp_dst, 1);
+ } else {
+ // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
+ for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+ comp_type++) {
+ update_mbmi_for_compound_type(mbmi, comp_type);
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ model_rd_sb_fn[MODELRD_CURVFIT](
+ cpi, bsize, x, xd, 0, 0, &est_rate[comp_type], &est_dist[comp_type],
+ NULL, NULL, NULL, NULL, NULL);
+ est_rate[comp_type] += masked_type_cost[comp_type];
+ comp_model_rate[comp_type] = est_rate[comp_type];
+ comp_model_dist[comp_type] = est_dist[comp_type];
+ if (comp_type == COMPOUND_AVERAGE) {
+ *is_luma_interp_done = 1;
+ restore_dst_buf(xd, *tmp_dst, 1);
+ }
+ }
+ // Choose the better of the two based on modeled cost and call
+ // estimate_yrd_for_sb() for that one.
+ best_type = find_best_avg_distwtd_comp_type(
+ x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
+ update_mbmi_for_compound_type(mbmi, best_type);
+ if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
+ rs2 = masked_type_cost[best_type];
+ RD_STATS est_rd_stats;
+ const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+ const int64_t est_rd_ =
+ estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+
+ if (est_rd_ != INT64_MAX) {
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+ est_rd_stats.dist);
+ // Backup rate and distortion for future reuse
+ backup_stats(best_type, comp_rate, comp_dist, comp_model_rate,
+ comp_model_dist, est_rate[best_type], est_dist[best_type],
+ &est_rd_stats, comp_rs2, rs2);
+ comp_model_rd_cur = est_rd;
+ }
+ if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+ // Update stats for best compound type
+ if (best_rd_cur < *rd) {
+ update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
+ comp_model_rd_cur, rs2);
+ }
+ }
+ }
+
+ // If COMPOUND_AVERAGE is not valid, use the spare buffer
+ if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+
+ // Loop over valid compound types
+ for (int i = 0; i < valid_type_count; i++) {
+ cur_type = valid_comp_types[i];
+ comp_model_rd_cur = INT64_MAX;
+ tmp_rate_mv = *rate_mv;
+ best_rd_cur = INT64_MAX;
+
+ // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
+ if (cur_type < COMPOUND_WEDGE) {
+ update_mbmi_for_compound_type(mbmi, cur_type);
+ rs2 = masked_type_cost[cur_type];
+ const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
+ if (mode_rd < ref_best_rd) {
+        // Compute the cost if a matching record is not found; otherwise reuse
+        // the stored data
+ if (comp_rate[cur_type] == INT_MAX) {
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
+
+ // Compute RD cost for the current type
+ RD_STATS est_rd_stats;
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+ const int64_t est_rd =
+ estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+ if (est_rd != INT64_MAX) {
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+ est_rd_stats.dist);
+ model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+ cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
+ &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
+ comp_model_rd_cur =
+ RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
+
+ // Backup rate and distortion for future reuse
+ backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
+ comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
+ comp_rs2, rs2);
+ }
+ } else {
+ // Calculate RD cost based on stored stats
+ assert(comp_dist[cur_type] != INT64_MAX);
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
+ comp_dist[cur_type]);
+ // Recalculate model rdcost with the updated rate
+ comp_model_rd_cur =
+ RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
+ comp_model_dist[cur_type]);
+ }
+ }
+      // Use the spare buffer for the next compound type to be tried
+ if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+ } else {
+ // Handle masked compound types
+ update_mbmi_for_compound_type(mbmi, cur_type);
+ rs2 = masked_type_cost[cur_type];
+ // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD if approximated cost is
+ // within threshold
+ int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
+ cpi->max_comp_type_rd_threshold_mul);
+
+ if (approx_rd < ref_best_rd) {
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
+ best_rd_cur = masked_compound_type_rd(
+ cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+ &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
+ strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
+ comp_rate, comp_dist, comp_model_rate, comp_model_dist,
+ best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2);
+ }
+ }
+ // Update stats for best compound type
+ if (best_rd_cur < *rd) {
+ update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
+ comp_model_rd_cur, rs2);
+ if (masked_compound_used && cur_type >= COMPOUND_WEDGE) {
+ memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
+ if (have_newmv_in_inter_mode(this_mode))
+ update_mask_best_mv(mbmi, best_mv, cur_mv, cur_type,
+ &best_tmp_rate_mv, tmp_rate_mv, &cpi->sf);
+ }
+ }
+ // reset to original mvs for next iteration
+ mbmi->mv[0].as_int = cur_mv[0].as_int;
+ mbmi->mv[1].as_int = cur_mv[1].as_int;
+ }
+ if (mbmi->interinter_comp.type != best_type_stats.best_compound_data.type) {
+ mbmi->comp_group_idx =
+ (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
+ mbmi->compound_idx =
+ !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
+ mbmi->interinter_comp = best_type_stats.best_compound_data;
+ memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
+ }
+ if (have_newmv_in_inter_mode(this_mode)) {
+ mbmi->mv[0].as_int = best_mv[0].as_int;
+ mbmi->mv[1].as_int = best_mv[1].as_int;
+ if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
+ rd_stats->rate += best_tmp_rate_mv - *rate_mv;
+ *rate_mv = best_tmp_rate_mv;
+ }
+ }
+ restore_dst_buf(xd, *orig_dst, 1);
+ if (!match_found)
+ save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
+ comp_model_dist, cur_mv, comp_rs2);
+ return best_type_stats.best_compmode_interinter_cost;
+}
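The approx_rd gate in the loop above admits the expensive COMPOUND_WEDGE / COMPOUND_DIFFWTD evaluation only while the running best RD leaves headroom below ref_best_rd. Below is a standalone sketch of that test, assuming illustrative div/mul values (in the encoder they are the speed-feature-driven max_comp_type_rd_threshold_div/_mul fields of AV1_COMP):

    #include <stdint.h>
    #include <stdio.h>

    /* Models the masked-compound gate above; div/mul values are assumed. */
    static int should_eval_masked_type(int64_t best_rd_so_far,
                                       int64_t ref_best_rd,
                                       int64_t div, int64_t mul) {
      const int64_t approx_rd = (best_rd_so_far / div) * mul;
      return approx_rd < ref_best_rd;
    }

    int main(void) {
      /* With div=16, mul=17 the gate passes only while the running best RD
       * keeps roughly a 1/17 margin below the reference best RD. */
      printf("%d\n", should_eval_masked_type(1500, 1700, 16, 17)); /* 1 */
      printf("%d\n", should_eval_masked_type(1700, 1700, 16, 17)); /* 0 */
      return 0;
    }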
diff --git a/av1/encoder/compound_type.h b/av1/encoder/compound_type.h
new file mode 100644
index 0000000..a10e664
--- /dev/null
+++ b/av1/encoder/compound_type.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AV1_ENCODER_COMPOUND_TYPE_H_
+#define AOM_AV1_ENCODER_COMPOUND_TYPE_H_
+
+#include "av1/encoder/encoder.h"
+#include "av1/encoder/interp_search.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Structure to store the search stats for the best compound type
+typedef struct {
+ INTERINTER_COMPOUND_DATA best_compound_data;
+ int64_t comp_best_model_rd;
+ int best_compmode_interinter_cost;
+} BEST_COMP_TYPE_STATS;
+
+int handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
+ BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
+ HandleInterModeArgs *args, int64_t ref_best_rd,
+ int *rate_mv, int *tmp_rate2,
+ const BUFFER_SET *orig_dst);
+
+int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
+ int_mv *cur_mv, int mode_search_mask,
+ int masked_compound_used, const BUFFER_SET *orig_dst,
+ const BUFFER_SET *tmp_dst,
+ const CompoundTypeRdBuffers *buffers, int *rate_mv,
+ int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
+ int *is_luma_interp_done, int64_t rd_thresh);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_ENCODER_COMPOUND_TYPE_H_
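For orientation, the update protocol behind BEST_COMP_TYPE_STATS (see update_best_info() in the rdopt.c hunk below) reduces to: start from COMPOUND_AVERAGE with an infinite model RD, then overwrite whenever a candidate beats the running best. A minimal sketch, using simplified stand-in types rather than the libaom definitions:

    #include <stdint.h>

    typedef int CompTypeSketch; /* stand-in for COMPOUND_TYPE */
    typedef struct {
      CompTypeSketch best_type;
      int64_t comp_best_model_rd;
      int best_compmode_interinter_cost;
    } BestCompStatsSketch;

    /* Mirrors the guard-and-overwrite pattern around update_best_info(). */
    static void maybe_update_best(BestCompStatsSketch *s, int64_t *best_rd,
                                  CompTypeSketch cur_type, int64_t cur_rd,
                                  int64_t cur_model_rd, int rs2) {
      if (cur_rd < *best_rd) {
        *best_rd = cur_rd;
        s->best_type = cur_type;
        s->comp_best_model_rd = cur_model_rd;
        s->best_compmode_interinter_cost = rs2;
      }
    }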
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index d41d971..3d2ef4c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -44,6 +44,7 @@
#include "av1/encoder/aq_variance.h"
#include "av1/encoder/av1_quantize.h"
#include "av1/encoder/cost.h"
+#include "av1/encoder/compound_type.h"
#include "av1/encoder/encodemb.h"
#include "av1/encoder/encodemv.h"
#include "av1/encoder/encoder.h"
@@ -72,13 +73,6 @@
MV_REFERENCE_FRAME ref_frame[2];
} MODE_DEFINITION;
-// Structure to store the compound type related stats for best compound type
-typedef struct {
- INTERINTER_COMPOUND_DATA best_compound_data;
- int64_t comp_best_model_rd;
- int best_compmode_interinter_cost;
-} BEST_COMP_TYPE_STATS;
-
#define LAST_NEW_MV_INDEX 6
// This array defines the mapping from the enums in THR_MODES to the actual
// prediction modes and reference frames
@@ -552,20 +546,6 @@
MV_REFERENCE_FRAME single_rd_order[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
} InterModeSearchState;
-static void alloc_compound_type_rd_buffers_no_check(
- CompoundTypeRdBuffers *const bufs) {
- bufs->pred0 =
- (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0));
- bufs->pred1 =
- (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1));
- bufs->residual1 =
- (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->residual1));
- bufs->diff10 =
- (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->diff10));
- bufs->tmp_best_mask_buf = (uint8_t *)aom_malloc(
- 2 * MAX_SB_SQUARE * sizeof(*bufs->tmp_best_mask_buf));
-}
-
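The `_no_check` variant above allocates without validating the returned pointers (presumably left to its caller). A sketch of what a checked counterpart could look like, with C11 aligned_alloc()/malloc() as stand-ins for aom_memalign()/aom_malloc() and a placeholder MAX_SB_SQUARE_ constant:

    #include <stdint.h>
    #include <stdlib.h>

    #define MAX_SB_SQUARE_ (128 * 128) /* placeholder for MAX_SB_SQUARE */

    typedef struct {
      uint8_t *pred0, *pred1, *tmp_best_mask_buf;
      int16_t *residual1, *diff10;
    } RdBufsSketch;

    static void free_bufs(RdBufsSketch *b) {
      free(b->pred0);
      free(b->pred1);
      free(b->residual1);
      free(b->diff10);
      free(b->tmp_best_mask_buf);
    }

    /* Returns 1 on success, 0 after releasing any partial allocation. */
    static int alloc_bufs_checked(RdBufsSketch *b) {
      b->pred0 = aligned_alloc(16, 2 * MAX_SB_SQUARE_ * sizeof(*b->pred0));
      b->pred1 = aligned_alloc(16, 2 * MAX_SB_SQUARE_ * sizeof(*b->pred1));
      b->residual1 = aligned_alloc(32, MAX_SB_SQUARE_ * sizeof(*b->residual1));
      b->diff10 = aligned_alloc(32, MAX_SB_SQUARE_ * sizeof(*b->diff10));
      b->tmp_best_mask_buf =
          malloc(2 * MAX_SB_SQUARE_ * sizeof(*b->tmp_best_mask_buf));
      if (!b->pred0 || !b->pred1 || !b->residual1 || !b->diff10 ||
          !b->tmp_best_mask_buf) {
        free_bufs(b);
        return 0;
      }
      return 1;
    }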
void av1_inter_mode_data_init(TileDataEnc *tile_data) {
for (int i = 0; i < BLOCK_SIZES_ALL; ++i) {
InterModeRdModel *md = &tile_data->inter_mode_rd_models[i];
@@ -1353,32 +1333,6 @@
return n;
}
-static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
- MACROBLOCK *x, int64_t ref_best_rd,
- RD_STATS *rd_stats) {
- MACROBLOCKD *const xd = &x->e_mbd;
- if (ref_best_rd < 0) return INT64_MAX;
- av1_subtract_plane(x, bs, 0);
- x->rd_model = LOW_TXFM_RD;
- int skip_trellis = cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
- NO_ESTIMATE_YRD_TRELLIS_OPT;
- const int64_t rd =
- txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs, max_txsize_rect_lookup[bs],
- FTXS_NONE, skip_trellis);
- x->rd_model = FULL_TXFM_RD;
- if (rd != INT64_MAX) {
- const int skip_ctx = av1_get_skip_context(xd);
- if (rd_stats->skip) {
- const int s1 = x->skip_cost[skip_ctx][1];
- rd_stats->rate = s1;
- } else {
- const int s0 = x->skip_cost[skip_ctx][0];
- rd_stats->rate += s0;
- }
- }
- return rd;
-}
-
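Callers of estimate_yrd_for_sb() follow a budgeting convention visible throughout this patch: pass the smaller of the running best RD and an external threshold, minus the mode cost already committed, so an exhausted budget trips the `ref_best_rd < 0` early-out above. A toy illustration with made-up values:

    #include <stdint.h>
    #include <stdio.h>

    static int64_t min64(int64_t a, int64_t b) { return a < b ? a : b; }

    int main(void) {
      const int64_t best_rd = 9000, rd_thresh = 8000, mode_rd = 8500;
      /* Mirrors tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd. */
      const int64_t budget = min64(best_rd, rd_thresh) - mode_rd;
      /* budget == -500: the estimate would return INT64_MAX immediately. */
      printf("%lld\n", (long long)budget);
      return 0;
    }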
// Return the rate cost for luma prediction mode info. of intra blocks.
static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x,
const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
@@ -3205,22 +3159,6 @@
}
}
-static INLINE int get_interinter_compound_mask_rate(
- const MACROBLOCK *const x, const MB_MODE_INFO *const mbmi) {
- const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
- // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
- if (compound_type == COMPOUND_WEDGE) {
- return av1_is_wedge_used(mbmi->sb_type)
- ? av1_cost_literal(1) +
- x->wedge_idx_cost[mbmi->sb_type]
- [mbmi->interinter_comp.wedge_index]
- : 0;
- } else {
- assert(compound_type == COMPOUND_DIFFWTD);
- return av1_cost_literal(1);
- }
-}
-
static INLINE int mv_check_bounds(const MvLimits *mv_limits, const MV *mv) {
return (mv->row >> 3) < mv_limits->row_min ||
(mv->row >> 3) > mv_limits->row_max ||
@@ -3469,525 +3407,6 @@
xd->mb_to_bottom_edge + RIGHT_BOTTOM_MARGIN);
}
-static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
- const BLOCK_SIZE bsize, const uint8_t *pred0,
- int stride0, const uint8_t *pred1,
- int stride1) {
- static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
- // 4X4
- BLOCK_INVALID,
- // 4X8, 8X4, 8X8
- BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
- // 8X16, 16X8, 16X16
- BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
- // 16X32, 32X16, 32X32
- BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
- // 32X64, 64X32, 64X64
- BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
- // 64x128, 128x64, 128x128
- BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
- // 4X16, 16X4, 8X32
- BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
- // 32X8, 16X64, 64X16
- BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
- };
- const struct macroblock_plane *const p = &x->plane[0];
- const uint8_t *src = p->src.buf;
- int src_stride = p->src.stride;
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- const int bw_by2 = bw >> 1;
- const int bh_by2 = bh >> 1;
- uint32_t esq[2][2];
- int64_t tl, br;
-
- const BLOCK_SIZE f_index = split_qtr[bsize];
- assert(f_index != BLOCK_INVALID);
-
- if (is_cur_buf_hbd(&x->e_mbd)) {
- pred0 = CONVERT_TO_BYTEPTR(pred0);
- pred1 = CONVERT_TO_BYTEPTR(pred1);
- }
-
- // Residual variance computation over relevant quadrants in order to
- // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
- // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
- // The 2nd and 3rd quadrants cancel out in TL + BR
- // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
- // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
- // for all codebooks; experiment with other quadrant combinations for
- // 0, 90 and 135 degrees also.
- cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
- cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
- pred0 + bh_by2 * stride0 + bw_by2, stride0,
- &esq[0][1]);
- cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
- cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
- pred1 + bh_by2 * stride1 + bw_by2, stride1,
- &esq[1][1]);
-
- tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
- br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
- return (tl + br > 0);
-}
-
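Because the 2nd and 3rd quadrants cancel in TL + BR, the sign decision above needs only four per-quadrant error values. A self-contained illustration of that reduction (the esq inputs are fabricated):

    #include <stdint.h>
    #include <stdio.h>

    /* tl and br reproduce the final statements of estimate_wedge_sign(). */
    static int estimate_sign(uint32_t esq_p0_tl, uint32_t esq_p0_br,
                             uint32_t esq_p1_tl, uint32_t esq_p1_br) {
      const int64_t tl = (int64_t)esq_p0_tl - (int64_t)esq_p1_tl;
      const int64_t br = (int64_t)esq_p1_br - (int64_t)esq_p0_br;
      return (tl + br) > 0;
    }

    int main(void) {
      /* pred1 fits the top-left better and pred0 the bottom-right: sign 1. */
      printf("%d\n", estimate_sign(900, 200, 300, 800)); /* prints 1 */
      return 0;
    }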
-// Choose the best wedge index and sign
-static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
- const BLOCK_SIZE bsize, const uint8_t *const p0,
- const int16_t *const residual1,
- const int16_t *const diff10,
- int8_t *const best_wedge_sign,
- int8_t *const best_wedge_index) {
- const MACROBLOCKD *const xd = &x->e_mbd;
- const struct buf_2d *const src = &x->plane[0].src;
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- const int N = bw * bh;
- assert(N >= 64);
- int rate;
- int64_t dist;
- int64_t rd, best_rd = INT64_MAX;
- int8_t wedge_index;
- int8_t wedge_sign;
- const int8_t wedge_types = get_wedge_types_lookup(bsize);
- const uint8_t *mask;
- uint64_t sse;
- const int hbd = is_cur_buf_hbd(xd);
- const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
-
- DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]); // src - pred0
-#if CONFIG_AV1_HIGHBITDEPTH
- if (hbd) {
- aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
- CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
- } else {
- aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
- }
-#else
- (void)hbd;
- aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
-#endif
-
- int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
- (int64_t)aom_sum_squares_i16(residual1, N)) *
- (1 << WEDGE_WEIGHT_BITS) / 2;
- int16_t *ds = residual0;
-
- av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
-
- for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
- mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
-
- wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
-
- mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
- sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
- sse = ROUND_POWER_OF_TWO(sse, bd_round);
-
- model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
- &rate, &dist);
-
- rate += x->wedge_idx_cost[bsize][wedge_index];
- rd = RDCOST(x->rdmult, rate, dist);
-
- if (rd < best_rd) {
- *best_wedge_index = wedge_index;
- *best_wedge_sign = wedge_sign;
- best_rd = rd;
- }
- }
-
- return best_rd -
- RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
-}
-
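The sign_limit above scales the difference between the two residual energies by half the mask weight range; av1_wedge_sign_from_residuals() then compares mask-weighted delta-squares against it. A standalone sketch of the limit computation, with WEDGE_BITS_ as a placeholder for WEDGE_WEIGHT_BITS:

    #include <stdint.h>
    #include <stddef.h>

    #define WEDGE_BITS_ 6 /* placeholder for WEDGE_WEIGHT_BITS */

    static int64_t sum_sq(const int16_t *r, size_t n) {
      int64_t s = 0;
      for (size_t i = 0; i < n; ++i) s += (int64_t)r[i] * r[i];
      return s;
    }

    /* Mirrors: (sum(res0^2) - sum(res1^2)) * (1 << WEDGE_WEIGHT_BITS) / 2 */
    static int64_t wedge_sign_limit(const int16_t *res0, const int16_t *res1,
                                    size_t n) {
      return (sum_sq(res0, n) - sum_sq(res1, n)) * (1 << WEDGE_BITS_) / 2;
    }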
-// Choose the best wedge index for the specified sign
-static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi,
- const MACROBLOCK *const x,
- const BLOCK_SIZE bsize,
- const int16_t *const residual1,
- const int16_t *const diff10,
- const int8_t wedge_sign,
- int8_t *const best_wedge_index) {
- const MACROBLOCKD *const xd = &x->e_mbd;
-
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- const int N = bw * bh;
- assert(N >= 64);
- int rate;
- int64_t dist;
- int64_t rd, best_rd = INT64_MAX;
- int8_t wedge_index;
- const int8_t wedge_types = get_wedge_types_lookup(bsize);
- const uint8_t *mask;
- uint64_t sse;
- const int hbd = is_cur_buf_hbd(xd);
- const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
- for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
- mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
- sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
- sse = ROUND_POWER_OF_TWO(sse, bd_round);
-
- model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
- &rate, &dist);
- rate += x->wedge_idx_cost[bsize][wedge_index];
- rd = RDCOST(x->rdmult, rate, dist);
-
- if (rd < best_rd) {
- *best_wedge_index = wedge_index;
- best_rd = rd;
- }
- }
- return best_rd -
- RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
-}
-
-static int64_t pick_interinter_wedge(
- const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
- const uint8_t *const p0, const uint8_t *const p1,
- const int16_t *const residual1, const int16_t *const diff10) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int bw = block_size_wide[bsize];
-
- int64_t rd;
- int8_t wedge_index = -1;
- int8_t wedge_sign = 0;
-
- assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
- assert(cpi->common.seq_params.enable_masked_compound);
-
- if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
- wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
- rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
- &wedge_index);
- } else {
- rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
- &wedge_index);
- }
-
- mbmi->interinter_comp.wedge_sign = wedge_sign;
- mbmi->interinter_comp.wedge_index = wedge_index;
- return rd;
-}
-
-static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
- MACROBLOCK *const x, const BLOCK_SIZE bsize,
- const uint8_t *const p0,
- const uint8_t *const p1,
- const int16_t *const residual1,
- const int16_t *const diff10) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- const int N = 1 << num_pels_log2_lookup[bsize];
- int rate;
- int64_t dist;
- DIFFWTD_MASK_TYPE cur_mask_type;
- int64_t best_rd = INT64_MAX;
- DIFFWTD_MASK_TYPE best_mask_type = 0;
- const int hbd = is_cur_buf_hbd(xd);
- const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
- DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
- uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
- // try each mask type and its inverse
- for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
- // build mask and inverse
- if (hbd)
- av1_build_compound_diffwtd_mask_highbd(
- tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
- CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
- else
- av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
- p0, bw, p1, bw, bh, bw);
-
- // compute rd for mask
- uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
- tmp_mask[cur_mask_type], N);
- sse = ROUND_POWER_OF_TWO(sse, bd_round);
-
- model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
- &rate, &dist);
- const int64_t rd0 = RDCOST(x->rdmult, rate, dist);
-
- if (rd0 < best_rd) {
- best_mask_type = cur_mask_type;
- best_rd = rd0;
- }
- }
- mbmi->interinter_comp.mask_type = best_mask_type;
- if (best_mask_type == DIFFWTD_38_INV) {
- memcpy(xd->seg_mask, seg_mask, N * 2);
- }
- return best_rd;
-}
-
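pick_interinter_seg() writes mask type 0 straight into the live xd->seg_mask and the inverse type into a local scratch buffer, so a copy-back happens only when the inverse wins. The same two-buffer trick, reduced (the mask builder and cost function are fake stand-ins):

    #include <stdint.h>
    #include <string.h>

    #define N_ 64 /* toy mask size */

    static void build_mask(uint8_t *dst, int mask_type) {
      memset(dst, mask_type ? 26 : 38, N_); /* stand-in for the two types */
    }

    static int64_t fake_cost(const uint8_t *mask) { /* stand-in for model RD */
      int64_t s = 0;
      for (int i = 0; i < N_; ++i) s += mask[i];
      return s;
    }

    /* Returns the winning mask type, leaving the winner in live_mask. */
    static int pick_mask(uint8_t *live_mask) {
      uint8_t scratch[N_];
      uint8_t *tmp[2] = { live_mask, scratch };
      int64_t best_rd = INT64_MAX;
      int best = 0;
      for (int t = 0; t < 2; ++t) {
        build_mask(tmp[t], t);
        const int64_t rd = fake_cost(tmp[t]);
        if (rd < best_rd) { best_rd = rd; best = t; }
      }
      if (best == 1) memcpy(live_mask, scratch, N_); /* inverse mask won */
      return best;
    }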
-static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
- const MACROBLOCK *const x,
- const BLOCK_SIZE bsize,
- const uint8_t *const p0,
- const uint8_t *const p1) {
- const MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- assert(av1_is_wedge_used(bsize));
- assert(cpi->common.seq_params.enable_interintra_compound);
-
- const struct buf_2d *const src = &x->plane[0].src;
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1
- DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
- CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
- aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
- CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
- } else {
- aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
- aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
- }
-#else
- aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
- aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
-#endif
- int8_t wedge_index = -1;
- int64_t rd =
- pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index);
-
- mbmi->interintra_wedge_index = wedge_index;
- return rd;
-}
-
-static AOM_INLINE void get_inter_predictors_masked_compound(
- MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
- int16_t *residual1, int16_t *diff10, int *strides) {
- MACROBLOCKD *xd = &x->e_mbd;
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- // get inter predictors to use for masked compound modes
- av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
- strides);
- av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
- strides);
- const struct buf_2d *const src = &x->plane[0].src;
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
- CONVERT_TO_BYTEPTR(*preds1), bw, xd->bd);
- aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
- bw, CONVERT_TO_BYTEPTR(*preds0), bw, xd->bd);
- } else {
- aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
- bw);
- aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
- }
-#else
- aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
- aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
-#endif
-}
-
-// Backs up rate, distortion and model_rd for future reuse
-static INLINE void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
- int64_t *comp_dist, int32_t *comp_model_rate,
- int64_t *comp_model_dist, int rate_sum,
- int64_t dist_sum, RD_STATS *rd_stats,
- int *comp_rs2, int rs2) {
- comp_rate[cur_type] = rd_stats->rate;
- comp_dist[cur_type] = rd_stats->dist;
- comp_model_rate[cur_type] = rate_sum;
- comp_model_dist[cur_type] = dist_sum;
- comp_rs2[cur_type] = rs2;
-}
-
-static int64_t masked_compound_type_rd(
- const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
- const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
- int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
- uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
- int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
- int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
- int64_t *comp_model_dist, const int64_t comp_best_model_rd,
- int64_t *const comp_model_rd_cur, int *comp_rs2) {
- const AV1_COMMON *const cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- int64_t best_rd_cur = INT64_MAX;
- int64_t rd = INT64_MAX;
- const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
- // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
- assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
- int rate_sum, tmp_skip_txfm_sb;
- int64_t dist_sum, tmp_skip_sse_sb;
- pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
- pick_interinter_seg };
-
- // TODO(any): Save pred and mask calculation as well into records. However
- // this may increase memory requirements as compound segment mask needs to be
- // stored in each record.
- if (*calc_pred_masked_compound) {
- get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
- diff10, strides);
- *calc_pred_masked_compound = 0;
- }
- if (cpi->sf.inter_sf.prune_wedge_pred_diff_based &&
- compound_type == COMPOUND_WEDGE) {
- unsigned int sse;
- if (is_cur_buf_hbd(xd))
- (void)cpi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
- CONVERT_TO_BYTEPTR(*preds1), *strides, &sse);
- else
- (void)cpi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides, &sse);
- const unsigned int mse =
- ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
- // If two predictors are very similar, skip wedge compound mode search
- if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
- *comp_model_rd_cur = INT64_MAX;
- return INT64_MAX;
- }
- }
- // Function pointer to pick the appropriate mask
- // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
- // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
- best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
- cpi, x, bsize, *preds0, *preds1, residual1, diff10);
- *rs2 += get_interinter_compound_mask_rate(x, mbmi);
- best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
-
- // Although the true rate_mv might differ after motion search, this mode is
- // unlikely to be the best considering the transform rd cost and other mode
- // overhead costs
- int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
- if (mode_rd > rd_thresh) {
- *comp_model_rd_cur = INT64_MAX;
- return INT64_MAX;
- }
-
- // Compute the cost if no matching record is found; otherwise reuse data
- if (comp_rate[compound_type] == INT_MAX) {
- // Check whether new MV search for wedge is to be done
- int wedge_newmv_search =
- have_newmv_in_inter_mode(this_mode) &&
- (compound_type == COMPOUND_WEDGE) &&
- (!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
- int diffwtd_newmv_search =
- cpi->sf.inter_sf.enable_interinter_diffwtd_newmv_search &&
- compound_type == COMPOUND_DIFFWTD &&
- have_newmv_in_inter_mode(this_mode);
-
- // Search for new MV if needed and build predictor
- if (wedge_newmv_search) {
- *out_rate_mv =
- interinter_compound_motion_search(cpi, x, cur_mv, bsize, this_mode);
- const int mi_row = xd->mi_row;
- const int mi_col = xd->mi_col;
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- } else if (diffwtd_newmv_search) {
- *out_rate_mv =
- interinter_compound_motion_search(cpi, x, cur_mv, bsize, this_mode);
- // we need to update the mask according to the new motion vector
- CompoundTypeRdBuffers tmp_buf;
- int64_t tmp_rd = INT64_MAX;
- alloc_compound_type_rd_buffers_no_check(&tmp_buf);
-
- uint8_t *tmp_preds0[1] = { tmp_buf.pred0 };
- uint8_t *tmp_preds1[1] = { tmp_buf.pred1 };
-
- get_inter_predictors_masked_compound(x, bsize, tmp_preds0, tmp_preds1,
- tmp_buf.residual1, tmp_buf.diff10,
- strides);
-
- tmp_rd = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
- cpi, x, bsize, *tmp_preds0, *tmp_preds1, tmp_buf.residual1,
- tmp_buf.diff10);
- // we can reuse rs2 here
- tmp_rd += RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
-
- if (tmp_rd >= best_rd_cur) {
- // restore the motion vector
- mbmi->mv[0].as_int = cur_mv[0].as_int;
- mbmi->mv[1].as_int = cur_mv[1].as_int;
- *out_rate_mv = rate_mv;
- av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
- strides, preds1, strides);
- } else {
- // Build the final prediction using the updated mv
- av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, tmp_preds0,
- strides, tmp_preds1, strides);
- }
- av1_release_compound_type_rd_buffers(&tmp_buf);
- } else {
- *out_rate_mv = rate_mv;
- av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
- preds1, strides);
- }
- // Get the RD cost from model RD
- model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
- cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
- &tmp_skip_sse_sb, NULL, NULL, NULL);
- rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
- *comp_model_rd_cur = rd;
- // Fall back to the previous best if the new-MV result is worse
- if (wedge_newmv_search) {
- if (rd >= best_rd_cur) {
- mbmi->mv[0].as_int = cur_mv[0].as_int;
- mbmi->mv[1].as_int = cur_mv[1].as_int;
- *out_rate_mv = rate_mv;
- av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
- strides, preds1, strides);
- *comp_model_rd_cur = best_rd_cur;
- }
- }
- if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
- (*comp_model_rd_cur > comp_best_model_rd) &&
- comp_best_model_rd != INT64_MAX) {
- *comp_model_rd_cur = INT64_MAX;
- return INT64_MAX;
- }
- // Compute RD cost for the current type
- RD_STATS rd_stats;
- const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
- const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
- rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
- if (rd != INT64_MAX) {
- rd =
- RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
- // Back up rate and distortion for future reuse
- backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
- comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
- *rs2);
- }
- } else {
- // Reuse data as matching record is found
- assert(comp_dist[compound_type] != INT64_MAX);
- // When disable_interinter_wedge_newmv_search is set, motion refinement is
- // disabled. Hence rate and distortion can be reused in this case as well
- assert(IMPLIES(have_newmv_in_inter_mode(this_mode),
- cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
- assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
- assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
- *out_rate_mv = rate_mv;
- // Calculate RD cost based on stored stats
- rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
- comp_dist[compound_type]);
- // Recalculate model rdcost with the updated rate
- *comp_model_rd_cur =
- RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
- comp_model_dist[compound_type]);
- }
- return rd;
-}
-
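One prune in masked_compound_type_rd() deserves a callout: when the two single-reference predictors are nearly identical, no wedge mask can usefully separate them, so the wedge search is skipped for a per-pixel MSE below 8 (or below 64 when no NEWMV is involved). A standalone version of that test:

    #include <stdint.h>

    /* Thresholds 8 and 64 are taken from the code above; the rounding shift
     * reproduces ROUND_POWER_OF_TWO(sse, num_pels_log2). */
    static int skip_wedge_search(uint32_t sse, int num_pels_log2,
                                 int has_newmv) {
      const uint32_t mse = (uint32_t)(((uint64_t)sse +
                                       (1u << (num_pels_log2 - 1))) >>
                                      num_pels_log2);
      return mse < 8 || (!has_newmv && mse < 64);
    }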
/* If the current mode shares the same mv with other modes with higher cost,
* skip this mode. */
static int skip_repeated_mv(const AV1_COMMON *const cm,
@@ -4157,421 +3576,6 @@
return 0;
}
-// Checks if characteristics of search match
-static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
- const MACROBLOCK *const x,
- const COMP_RD_STATS *st,
- const MB_MODE_INFO *const mi,
- int32_t *comp_rate, int64_t *comp_dist,
- int32_t *comp_model_rate,
- int64_t *comp_model_dist, int *comp_rs2) {
- // TODO(ranjit): Ensure that compound type search use regular filter always
- // and check if following check can be removed
- // Check if interp filter matches with previous case
- if (st->filter.as_int != mi->interp_filters.as_int) return 0;
-
- const MACROBLOCKD *const xd = &x->e_mbd;
- // Match MV and reference indices
- for (int i = 0; i < 2; ++i) {
- if ((st->ref_frames[i] != mi->ref_frame[i]) ||
- (st->mv[i].as_int != mi->mv[i].as_int)) {
- return 0;
- }
- const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
- if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
- }
-
- // Store the stats for COMPOUND_AVERAGE and COMPOUND_DISTWTD
- for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
- comp_type++) {
- comp_rate[comp_type] = st->rate[comp_type];
- comp_dist[comp_type] = st->dist[comp_type];
- comp_model_rate[comp_type] = st->model_rate[comp_type];
- comp_model_dist[comp_type] = st->model_dist[comp_type];
- comp_rs2[comp_type] = st->comp_rs2[comp_type];
- }
-
- // For compound wedge/segment, reuse data only if NEWMV is not present in
- // either of the directions
- if ((!have_newmv_in_inter_mode(mi->mode) &&
- !have_newmv_in_inter_mode(st->mode)) ||
- (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search)) {
- memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
- sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
- memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
- sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
- memcpy(&comp_model_rate[COMPOUND_WEDGE], &st->model_rate[COMPOUND_WEDGE],
- sizeof(comp_model_rate[COMPOUND_WEDGE]) * 2);
- memcpy(&comp_model_dist[COMPOUND_WEDGE], &st->model_dist[COMPOUND_WEDGE],
- sizeof(comp_model_dist[COMPOUND_WEDGE]) * 2);
- memcpy(&comp_rs2[COMPOUND_WEDGE], &st->comp_rs2[COMPOUND_WEDGE],
- sizeof(comp_rs2[COMPOUND_WEDGE]) * 2);
- }
- return 1;
-}
-
-// Checks if a similar compound type search case was evaluated earlier
-// If found, returns relevant rd data
-static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
- const MACROBLOCK *x,
- const MB_MODE_INFO *const mbmi,
- int32_t *comp_rate, int64_t *comp_dist,
- int32_t *comp_model_rate,
- int64_t *comp_model_dist, int *comp_rs2,
- int *match_index) {
- for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
- if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
- comp_dist, comp_model_rate, comp_model_dist,
- comp_rs2)) {
- *match_index = j;
- return 1;
- }
- }
- return 0; // no match result found
-}
-
-static INLINE void save_comp_rd_search_stat(
- MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
- const int64_t *comp_dist, const int32_t *comp_model_rate,
- const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
- const int offset = x->comp_rd_stats_idx;
- if (offset < MAX_COMP_RD_STATS) {
- COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
- memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
- memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
- memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
- memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
- memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
- memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
- memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
- rd_stats->mode = mbmi->mode;
- rd_stats->filter = mbmi->interp_filters;
- rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
- const MACROBLOCKD *const xd = &x->e_mbd;
- for (int i = 0; i < 2; ++i) {
- const WarpedMotionParams *const wm =
- &xd->global_motion[mbmi->ref_frame[i]];
- rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
- }
- memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
- sizeof(rd_stats->interinter_comp));
- ++x->comp_rd_stats_idx;
- }
-}
-
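find_comp_rd_in_stats() and save_comp_rd_search_stat() together form a small bounded cache: linear lookup over the records written so far, append only while capacity remains. The same pattern, reduced to integer keys and payloads:

    #include <stdint.h>

    #define MAX_STATS_ 64 /* stand-in for MAX_COMP_RD_STATS */

    typedef struct { uint64_t key; int64_t rd; } StatSketch;
    typedef struct { StatSketch recs[MAX_STATS_]; int idx; } CacheSketch;

    /* Linear search over valid records, as in find_comp_rd_in_stats(). */
    static int find_stat(const CacheSketch *c, uint64_t key, int64_t *rd) {
      for (int j = 0; j < c->idx; ++j) {
        if (c->recs[j].key == key) { *rd = c->recs[j].rd; return 1; }
      }
      return 0; /* no match found */
    }

    /* Append-only save, as in save_comp_rd_search_stat(). */
    static void save_stat(CacheSketch *c, uint64_t key, int64_t rd) {
      if (c->idx < MAX_STATS_) {
        c->recs[c->idx].key = key;
        c->recs[c->idx].rd = rd;
        ++c->idx;
      }
    }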
-static INLINE bool enable_wedge_search(MACROBLOCK *const x,
- const AV1_COMP *const cpi) {
- // Enable wedge search if source variance and edge strength are above
- // the thresholds.
- return x->source_variance >
- cpi->sf.inter_sf.disable_wedge_search_var_thresh &&
- x->edge_strength > cpi->sf.inter_sf.disable_wedge_search_edge_thresh;
-}
-
-static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
- const AV1_COMP *const cpi) {
- return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interinter_wedge &&
- !cpi->sf.inter_sf.disable_interinter_wedge;
-}
-
-static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
- const AV1_COMP *const cpi) {
- return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interintra_wedge &&
- !cpi->sf.inter_sf.disable_wedge_interintra_search;
-}
-
-// Computes the rd cost for the given interintra mode and updates the running
-// best mode and rd
-static INLINE void compute_best_interintra_mode(
- const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
- MACROBLOCK *const x, const int *const interintra_mode_cost,
- const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
- INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
- INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
- const AV1_COMMON *const cm = &cpi->common;
- int rate, skip_txfm_sb;
- int64_t dist, skip_sse_sb;
- const int bw = block_size_wide[bsize];
- mbmi->interintra_mode = interintra_mode;
- int rmode = interintra_mode_cost[interintra_mode];
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
- model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
- &skip_txfm_sb, &skip_sse_sb, NULL,
- NULL, NULL);
- int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
- if (rd < *best_interintra_rd) {
- *best_interintra_rd = rd;
- *best_interintra_mode = mbmi->interintra_mode;
- }
-}
-
-// Computes the best wedge interintra mode
-static AOM_INLINE int64_t compute_best_wedge_interintra(
- const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
- MACROBLOCK *const x, const int *const interintra_mode_cost,
- const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
- int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
- const AV1_COMMON *const cm = &cpi->common;
- const int bw = block_size_wide[bsize];
- int64_t best_interintra_rd_wedge = INT64_MAX;
- int64_t best_total_rd = INT64_MAX;
- uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
- for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
- mbmi->interintra_mode = mode;
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
- const int rate_overhead =
- interintra_mode_cost[mode] +
- x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
- const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
- if (total_rd < best_total_rd) {
- best_total_rd = total_rd;
- best_interintra_rd_wedge = rd;
- *best_mode = mbmi->interintra_mode;
- *best_wedge_index = mbmi->interintra_wedge_index;
- }
- }
- return best_interintra_rd_wedge;
-}
-
-// Computes the rd_threshold and total_mode_rate
-static AOM_INLINE int64_t compute_total_rate_and_rd_thresh(
- MACROBLOCK *const x, int *rate_mv, int *total_mode_rate, BLOCK_SIZE bsize,
- int64_t ref_best_rd, int rmode) {
- const int is_wedge_used = av1_is_wedge_used(bsize);
- const int64_t rd_thresh = get_rd_thresh_from_best_rd(
- ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
- INTER_INTRA_RD_THRESH_SCALE);
- const int rwedge = is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0;
- *total_mode_rate = *rate_mv + rmode + rwedge;
- const int64_t mode_rd = RDCOST(x->rdmult, *total_mode_rate, 0);
- return (rd_thresh - mode_rd);
-}
-
-static int handle_inter_intra_mode(const AV1_COMP *const cpi,
- MACROBLOCK *const x, BLOCK_SIZE bsize,
- MB_MODE_INFO *mbmi,
- HandleInterModeArgs *args,
- int64_t ref_best_rd, int *rate_mv,
- int *tmp_rate2, const BUFFER_SET *orig_dst) {
- const int try_smooth_interintra = cpi->oxcf.enable_smooth_interintra &&
- !cpi->sf.inter_sf.disable_smooth_interintra;
- const int try_wedge_interintra =
- av1_is_wedge_used(bsize) && enable_wedge_interintra_search(x, cpi);
- if (!try_smooth_interintra && !try_wedge_interintra) return -1;
-
- const AV1_COMMON *const cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- int64_t rd = INT64_MAX;
- const int bw = block_size_wide[bsize];
- DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
- DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
- uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
- uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
- const int *const interintra_mode_cost =
- x->interintra_mode_cost[size_group_lookup[bsize]];
- const int mi_row = xd->mi_row;
- const int mi_col = xd->mi_col;
-
- // Single reference inter prediction
- mbmi->ref_frame[1] = NONE_FRAME;
- xd->plane[0].dst.buf = tmp_buf;
- xd->plane[0].dst.stride = bw;
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- const int num_planes = av1_num_planes(cm);
-
- // Restore the buffers for intra prediction
- restore_dst_buf(xd, *orig_dst, num_planes);
- mbmi->ref_frame[1] = INTRA_FRAME;
- INTERINTRA_MODE best_interintra_mode =
- args->inter_intra_mode[mbmi->ref_frame[0]];
-
- // Compute smooth_interintra
- int64_t best_interintra_rd_nowedge = INT64_MAX;
- if (try_smooth_interintra) {
- mbmi->use_wedge_interintra = 0;
- int interintra_mode_reuse = 1;
- if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
- best_interintra_mode == INTERINTRA_MODES) {
- interintra_mode_reuse = 0;
- int64_t best_interintra_rd = INT64_MAX;
- for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
- ++cur_mode) {
- if ((!cpi->oxcf.enable_smooth_intra ||
- cpi->sf.intra_sf.disable_smooth_intra) &&
- cur_mode == II_SMOOTH_PRED)
- continue;
- compute_best_interintra_mode(cpi, mbmi, xd, x, interintra_mode_cost,
- orig_dst, intrapred, tmp_buf,
- &best_interintra_mode, &best_interintra_rd,
- cur_mode, bsize);
- }
- args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
- }
- assert(IMPLIES(!cpi->oxcf.enable_smooth_interintra ||
- cpi->sf.inter_sf.disable_smooth_interintra,
- best_interintra_mode != II_SMOOTH_PRED));
- int rmode = interintra_mode_cost[best_interintra_mode];
- // Recompute prediction if required
- if (interintra_mode_reuse || best_interintra_mode != INTERINTRA_MODES - 1) {
- mbmi->interintra_mode = best_interintra_mode;
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
- }
-
- // Compute rd cost for best smooth_interintra
- RD_STATS rd_stats;
- int total_mode_rate;
- const int64_t rd_thresh = compute_total_rate_and_rd_thresh(
- x, rate_mv, &total_mode_rate, bsize, ref_best_rd, rmode);
- rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
- if (rd != INT64_MAX) {
- rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
- } else {
- return -1;
- }
- best_interintra_rd_nowedge = rd;
- // Return early if best_interintra_rd_nowedge is not good enough
- if (ref_best_rd < INT64_MAX &&
- (best_interintra_rd_nowedge >> INTER_INTRA_RD_THRESH_SHIFT) *
- INTER_INTRA_RD_THRESH_SCALE >
- ref_best_rd) {
- return -1;
- }
- }
-
- // Compute wedge interintra
- int64_t best_interintra_rd_wedge = INT64_MAX;
- if (try_wedge_interintra) {
- mbmi->use_wedge_interintra = 1;
- if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
- // Exhaustive search of all wedge and mode combinations.
- int best_mode = 0;
- int best_wedge_index = 0;
- best_interintra_rd_wedge = compute_best_wedge_interintra(
- cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_,
- tmp_buf_, &best_mode, &best_wedge_index, bsize);
- mbmi->interintra_mode = best_mode;
- mbmi->interintra_wedge_index = best_wedge_index;
- if (best_mode != INTERINTRA_MODES - 1) {
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- }
- } else if (!try_smooth_interintra) {
- if (best_interintra_mode == INTERINTRA_MODES) {
- mbmi->interintra_mode = INTERINTRA_MODES - 1;
- best_interintra_mode = INTERINTRA_MODES - 1;
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- // Pick wedge mask based on INTERINTRA_MODES - 1
- best_interintra_rd_wedge =
- pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
- // Find the best interintra mode for the chosen wedge mask
- for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
- ++cur_mode) {
- compute_best_interintra_mode(
- cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
- tmp_buf, &best_interintra_mode, &best_interintra_rd_wedge,
- cur_mode, bsize);
- }
- args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
- mbmi->interintra_mode = best_interintra_mode;
-
- // Recompute prediction if required
- if (best_interintra_mode != INTERINTRA_MODES - 1) {
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- }
- } else {
- // Pick wedge mask for the best interintra mode (reused)
- mbmi->interintra_mode = best_interintra_mode;
- av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
- intrapred, bw);
- best_interintra_rd_wedge =
- pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
- }
- } else {
- // Pick wedge mask for the best interintra mode from smooth_interintra
- best_interintra_rd_wedge =
- pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
- }
-
- const int rate_overhead =
- interintra_mode_cost[mbmi->interintra_mode] +
- x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
- x->wedge_interintra_cost[bsize][1];
- best_interintra_rd_wedge += RDCOST(x->rdmult, rate_overhead + *rate_mv, 0);
-
- const int_mv mv0 = mbmi->mv[0];
- int_mv tmp_mv = mv0;
- rd = INT64_MAX;
- int tmp_rate_mv = 0;
- // Refine motion vector for NEWMV case.
- if (have_newmv_in_inter_mode(mbmi->mode)) {
- int rate_sum, skip_txfm_sb;
- int64_t dist_sum, skip_sse_sb;
- // get negative of mask
- const uint8_t *mask =
- av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
- compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, intrapred,
- mask, bw, &tmp_rate_mv, 0);
- if (mbmi->mv[0].as_int != tmp_mv.as_int) {
- mbmi->mv[0].as_int = tmp_mv.as_int;
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
- cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
- &skip_sse_sb, NULL, NULL, NULL);
- rd =
- RDCOST(x->rdmult, tmp_rate_mv + rate_overhead + rate_sum, dist_sum);
- }
- }
- if (rd >= best_interintra_rd_wedge) {
- tmp_mv.as_int = mv0.as_int;
- tmp_rate_mv = *rate_mv;
- av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
- }
- // Evaluate closer to true rd
- RD_STATS rd_stats;
- const int64_t mode_rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv, 0);
- const int64_t tmp_rd_thresh = best_interintra_rd_nowedge - mode_rd;
- rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
- if (rd != INT64_MAX) {
- rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv + rd_stats.rate,
- rd_stats.dist);
- } else {
- if (best_interintra_rd_nowedge == INT64_MAX) return -1;
- }
- best_interintra_rd_wedge = rd;
- if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
- mbmi->mv[0].as_int = tmp_mv.as_int;
- *tmp_rate2 += tmp_rate_mv - *rate_mv;
- *rate_mv = tmp_rate_mv;
- } else {
- mbmi->use_wedge_interintra = 0;
- mbmi->interintra_mode = best_interintra_mode;
- mbmi->mv[0].as_int = mv0.as_int;
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- }
- }
-
- if (best_interintra_rd_nowedge == INT64_MAX &&
- best_interintra_rd_wedge == INT64_MAX) {
- return -1;
- }
-
- if (num_planes > 1) {
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_U, num_planes - 1);
- }
- return 0;
-}
-
// If number of valid neighbours is 1,
// 1) ROTZOOM parameters can be obtained reliably (2 parameters from
// one neighbouring MV)
@@ -5164,428 +4168,6 @@
return cost;
}
-// Calculates the signaling cost for each compound type
-static INLINE void calc_masked_type_cost(MACROBLOCK *x, BLOCK_SIZE bsize,
- int comp_group_idx_ctx,
- int comp_index_ctx,
- int masked_compound_used,
- int *masked_type_cost) {
- av1_zero_array(masked_type_cost, COMPOUND_TYPES);
- // Account for group index cost when wedge and/or diffwtd prediction are
- // enabled
- if (masked_compound_used) {
- // Compound group index of average and distwtd is 0
- // Compound group index of wedge and diffwtd is 1
- masked_type_cost[COMPOUND_AVERAGE] +=
- x->comp_group_idx_cost[comp_group_idx_ctx][0];
- masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
- masked_type_cost[COMPOUND_WEDGE] +=
- x->comp_group_idx_cost[comp_group_idx_ctx][1];
- masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
- }
-
- // Compute the cost to signal compound index/type
- masked_type_cost[COMPOUND_AVERAGE] += x->comp_idx_cost[comp_index_ctx][1];
- masked_type_cost[COMPOUND_DISTWTD] += x->comp_idx_cost[comp_index_ctx][0];
- masked_type_cost[COMPOUND_WEDGE] += x->compound_type_cost[bsize][0];
- masked_type_cost[COMPOUND_DIFFWTD] += x->compound_type_cost[bsize][1];
-}
-
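When masked compounds are enabled, each type's signaling cost composes a group-index cost with either a compound-index cost or a compound-type cost. A worked example with fabricated bit costs (the real values come from entropy-coder CDF tables):

    #include <stdio.h>

    int main(void) {
      const int group_cost[2] = { 100, 400 }; /* comp_group_idx = 0 / 1 */
      const int idx_cost[2] = { 250, 150 };   /* compound_idx = 0 / 1 */
      const int type_cost[2] = { 300, 350 };  /* wedge / diffwtd */
      const int avg = group_cost[0] + idx_cost[1];      /* 250 */
      const int distwtd = group_cost[0] + idx_cost[0];  /* 350 */
      const int wedge = group_cost[1] + type_cost[0];   /* 700 */
      const int diffwtd = group_cost[1] + type_cost[1]; /* 750 */
      printf("avg=%d distwtd=%d wedge=%d diffwtd=%d\n", avg, distwtd, wedge,
             diffwtd);
      return 0;
    }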
-// Updates mbmi structure with the relevant compound type info
-static INLINE void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
- COMPOUND_TYPE cur_type) {
- mbmi->interinter_comp.type = cur_type;
- mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
- mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
-}
-
-// When a match is found, populate the compound type data, calculate the rd
-// cost using the stored stats, and update the mbmi appropriately.
-static INLINE int populate_reuse_comp_type_data(
- const MACROBLOCK *x, MB_MODE_INFO *mbmi,
- BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
- int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
- int match_index) {
- const int winner_comp_type =
- x->comp_rd_stats[match_index].interinter_comp.type;
- if (comp_rate[winner_comp_type] == INT_MAX)
- return best_type_stats->best_compmode_interinter_cost;
- update_mbmi_for_compound_type(mbmi, winner_comp_type);
- mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
- *rd = RDCOST(
- x->rdmult,
- comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
- comp_dist[winner_comp_type]);
- mbmi->mv[0].as_int = cur_mv[0].as_int;
- mbmi->mv[1].as_int = cur_mv[1].as_int;
- return comp_rs2[winner_comp_type];
-}
-
-// Updates rd cost and relevant compound type data for the best compound type
-static INLINE void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
- BEST_COMP_TYPE_STATS *best_type_stats,
- int64_t best_rd_cur,
- int64_t comp_model_rd_cur, int rs2) {
- *rd = best_rd_cur;
- best_type_stats->comp_best_model_rd = comp_model_rd_cur;
- best_type_stats->best_compound_data = mbmi->interinter_comp;
- best_type_stats->best_compmode_interinter_cost = rs2;
-}
-
-// Updates best_mv for masked compound types
-static INLINE void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
- int_mv *best_mv, int_mv *cur_mv,
- const COMPOUND_TYPE cur_type,
- int *best_tmp_rate_mv, int tmp_rate_mv,
- const SPEED_FEATURES *const sf) {
- if (cur_type == COMPOUND_WEDGE ||
- (sf->inter_sf.enable_interinter_diffwtd_newmv_search &&
- cur_type == COMPOUND_DIFFWTD)) {
- *best_tmp_rate_mv = tmp_rate_mv;
- best_mv[0].as_int = mbmi->mv[0].as_int;
- best_mv[1].as_int = mbmi->mv[1].as_int;
- } else {
- best_mv[0].as_int = cur_mv[0].as_int;
- best_mv[1].as_int = cur_mv[1].as_int;
- }
-}
-
-// Computes the valid compound_types to be evaluated
-static INLINE int compute_valid_comp_types(
- MACROBLOCK *x, const AV1_COMP *const cpi, int *try_average_and_distwtd_comp,
- BLOCK_SIZE bsize, int masked_compound_used, int mode_search_mask,
- COMPOUND_TYPE *valid_comp_types) {
- const AV1_COMMON *cm = &cpi->common;
- int valid_type_count = 0;
- int comp_type, valid_check;
- int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
-
- const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
- const int try_distwtd_comp =
- ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
- cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
- cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
- *try_average_and_distwtd_comp = try_average_comp && try_distwtd_comp;
-
- // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
- for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
- comp_type++) {
- valid_check =
- (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
- if (!*try_average_and_distwtd_comp && valid_check &&
- is_interinter_compound_used(comp_type, bsize))
- valid_comp_types[valid_type_count++] = comp_type;
- }
- // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
- if (masked_compound_used) {
- // enable_masked_type[0] corresponds to COMPOUND_WEDGE
- // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
- enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
- enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
- for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
- comp_type++) {
- if ((mode_search_mask & (1 << comp_type)) &&
- is_interinter_compound_used(comp_type, bsize) &&
- enable_masked_type[comp_type - COMPOUND_WEDGE])
- valid_comp_types[valid_type_count++] = comp_type;
- }
- }
- return valid_type_count;
-}
-
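mode_search_mask carries one bit per COMPOUND_TYPE value and is tested with `mode_search_mask & (1 << comp_type)`, as in compute_valid_comp_types() above. A tiny demonstration of the convention (the enum is a stand-in matching the 0..3 ordering of the array initializers):

    #include <stdio.h>

    enum { AVG_ = 0, DISTWTD_, WEDGE_, DIFFWTD_, TYPES_ };

    int main(void) {
      const int mode_search_mask = (1 << AVG_) | (1 << WEDGE_);
      for (int t = 0; t < TYPES_; ++t) {
        if (mode_search_mask & (1 << t)) printf("evaluate type %d\n", t);
      }
      return 0;
    }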
-// Choose the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based on
-// modeled cost
-static int find_best_avg_distwtd_comp_type(MACROBLOCK *x, int *comp_model_rate,
- int64_t *comp_model_dist,
- int rate_mv, int64_t *best_rd) {
- int64_t est_rd[2];
- est_rd[COMPOUND_AVERAGE] =
- RDCOST(x->rdmult, comp_model_rate[COMPOUND_AVERAGE] + rate_mv,
- comp_model_dist[COMPOUND_AVERAGE]);
- est_rd[COMPOUND_DISTWTD] =
- RDCOST(x->rdmult, comp_model_rate[COMPOUND_DISTWTD] + rate_mv,
- comp_model_dist[COMPOUND_DISTWTD]);
- int best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
- ? COMPOUND_AVERAGE
- : COMPOUND_DISTWTD;
- *best_rd = est_rd[best_type];
- return best_type;
-}
-
-static int compound_type_rd(
- const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int_mv *cur_mv,
- int mode_search_mask, int masked_compound_used, const BUFFER_SET *orig_dst,
- const BUFFER_SET *tmp_dst, const CompoundTypeRdBuffers *buffers,
- int *rate_mv, int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
- int *is_luma_interp_done, int64_t rd_thresh) {
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- MB_MODE_INFO *mbmi = xd->mi[0];
- const PREDICTION_MODE this_mode = mbmi->mode;
- const int bw = block_size_wide[bsize];
- int rs2;
- int_mv best_mv[2];
- int best_tmp_rate_mv = *rate_mv;
- BEST_COMP_TYPE_STATS best_type_stats;
- // Initializing BEST_COMP_TYPE_STATS
- best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
- best_type_stats.best_compmode_interinter_cost = 0;
- best_type_stats.comp_best_model_rd = INT64_MAX;
-
- uint8_t *preds0[1] = { buffers->pred0 };
- uint8_t *preds1[1] = { buffers->pred1 };
- int strides[1] = { bw };
- int tmp_rate_mv;
- const int num_pix = 1 << num_pels_log2_lookup[bsize];
- const int mask_len = 2 * num_pix * sizeof(uint8_t);
- COMPOUND_TYPE cur_type;
- // Local array to store the mask cost for different compound types
- int masked_type_cost[COMPOUND_TYPES];
-
- int calc_pred_masked_compound = 1;
- int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
- INT64_MAX };
- int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
- int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
- int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
- INT_MAX };
- int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
- INT64_MAX };
- int match_index = 0;
- const int match_found =
- find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
- comp_model_dist, comp_rs2, &match_index);
- best_mv[0].as_int = cur_mv[0].as_int;
- best_mv[1].as_int = cur_mv[1].as_int;
- *rd = INT64_MAX;
- int rate_sum, tmp_skip_txfm_sb;
- int64_t dist_sum, tmp_skip_sse_sb;
-
- // Local array to store the valid compound types to be evaluated in the core
- // loop
- COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
- COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
- };
- int valid_type_count = 0;
- int try_average_and_distwtd_comp = 0;
- // compute_valid_comp_types() returns the number of valid compound types to
- // be evaluated, stores them in the local array valid_comp_types[], and sets
- // the flag 'try_average_and_distwtd_comp'.
- valid_type_count = compute_valid_comp_types(
- x, cpi, &try_average_and_distwtd_comp, bsize, masked_compound_used,
- mode_search_mask, valid_comp_types);
-
- // The following context indices are independent of compound type
- const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
- const int comp_index_ctx = get_comp_index_context(cm, xd);
-
- // Populates masked_type_cost local array for the 4 compound types
- calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
- masked_compound_used, masked_type_cost);
-
- int64_t comp_model_rd_cur = INT64_MAX;
- int64_t best_rd_cur = INT64_MAX;
- const int mi_row = xd->mi_row;
- const int mi_col = xd->mi_col;
-
- // If a match is found, calculate the rd cost using the stored stats and
- // update the mbmi appropriately.
- if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
- return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
- comp_rate, comp_dist, comp_rs2,
- rate_mv, rd, match_index);
- }
- // Special handling if both compound_average and compound_distwtd are to be
- // searched. In this case, first compare the two using modeled RD and then
- // call estimate_yrd_for_sb() only for the better of the two.
- if (try_average_and_distwtd_comp) {
- int est_rate[2];
- int64_t est_dist[2], est_rd;
- COMPOUND_TYPE best_type;
- // Since modelled rate and dist are stored separately, compute the better
- // of COMPOUND_AVERAGE and COMPOUND_DISTWTD using the stored stats.
- if ((comp_model_rate[COMPOUND_AVERAGE] != INT_MAX) &&
- comp_model_rate[COMPOUND_DISTWTD] != INT_MAX) {
- // Choose the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based on
- // modeled cost.
- best_type = find_best_avg_distwtd_comp_type(
- x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
- update_mbmi_for_compound_type(mbmi, best_type);
- if (comp_rate[best_type] != INT_MAX)
- best_rd_cur = RDCOST(
- x->rdmult,
- masked_type_cost[best_type] + *rate_mv + comp_rate[best_type],
- comp_dist[best_type]);
- comp_model_rd_cur = est_rd;
- // Update stats for best compound type
- if (best_rd_cur < *rd) {
- update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
- comp_model_rd_cur, masked_type_cost[best_type]);
- }
- restore_dst_buf(xd, *tmp_dst, 1);
- } else {
- // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
- for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
- comp_type++) {
- update_mbmi_for_compound_type(mbmi, comp_type);
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- model_rd_sb_fn[MODELRD_CURVFIT](
- cpi, bsize, x, xd, 0, 0, &est_rate[comp_type], &est_dist[comp_type],
- NULL, NULL, NULL, NULL, NULL);
- est_rate[comp_type] += masked_type_cost[comp_type];
- comp_model_rate[comp_type] = est_rate[comp_type];
- comp_model_dist[comp_type] = est_dist[comp_type];
- if (comp_type == COMPOUND_AVERAGE) {
- *is_luma_interp_done = 1;
- restore_dst_buf(xd, *tmp_dst, 1);
- }
- }
- // Choose the better of the two based on modeled cost and call
- // estimate_yrd_for_sb() for that one.
- best_type = find_best_avg_distwtd_comp_type(
- x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
- update_mbmi_for_compound_type(mbmi, best_type);
- if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
- rs2 = masked_type_cost[best_type];
- RD_STATS est_rd_stats;
- const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
- const int64_t est_rd_ =
- estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
-
- if (est_rd_ != INT64_MAX) {
- best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
- est_rd_stats.dist);
- // Back up rate and distortion for future reuse
- backup_stats(best_type, comp_rate, comp_dist, comp_model_rate,
- comp_model_dist, est_rate[best_type], est_dist[best_type],
- &est_rd_stats, comp_rs2, rs2);
- comp_model_rd_cur = est_rd;
- }
- if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
- // Update stats for best compound type
- if (best_rd_cur < *rd) {
- update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
- comp_model_rd_cur, rs2);
- }
- }
- }
-
- // If COMPOUND_AVERAGE is not valid, use the spare buffer
- if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
-
- // Loop over valid compound types
- for (int i = 0; i < valid_type_count; i++) {
- cur_type = valid_comp_types[i];
- comp_model_rd_cur = INT64_MAX;
- tmp_rate_mv = *rate_mv;
- best_rd_cur = INT64_MAX;
-
- // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
- if (cur_type < COMPOUND_WEDGE) {
- update_mbmi_for_compound_type(mbmi, cur_type);
- rs2 = masked_type_cost[cur_type];
- const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
- if (mode_rd < ref_best_rd) {
- // Compute the cost if no matching record exists; otherwise reuse data
- if (comp_rate[cur_type] == INT_MAX) {
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
-
- // Compute RD cost for the current type
- RD_STATS est_rd_stats;
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
- const int64_t est_rd =
- estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
- if (est_rd != INT64_MAX) {
- best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
- est_rd_stats.dist);
- model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
- cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
- &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
- comp_model_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
-
- // Back up rate and distortion for future reuse
- backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
- comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
- comp_rs2, rs2);
- }
- } else {
- // Calculate RD cost based on stored stats
- assert(comp_dist[cur_type] != INT64_MAX);
- best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
- comp_dist[cur_type]);
- // Recalculate model rdcost with the updated rate
- comp_model_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
- comp_model_dist[cur_type]);
- }
- }
- // Use the spare buffer for the next compound type evaluation
- if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
- } else {
- // Handle masked compound types
- update_mbmi_for_compound_type(mbmi, cur_type);
- rs2 = masked_type_cost[cur_type];
- // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD only if the approximate
- // cost is within the threshold
- int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
- cpi->max_comp_type_rd_threshold_mul);
-
- if (approx_rd < ref_best_rd) {
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
- best_rd_cur = masked_compound_type_rd(
- cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
- &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
- strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
- comp_rate, comp_dist, comp_model_rate, comp_model_dist,
- best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2);
- }
- }
- // Update stats for best compound type
- if (best_rd_cur < *rd) {
- update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
- comp_model_rd_cur, rs2);
- if (masked_compound_used && cur_type >= COMPOUND_WEDGE) {
- memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
- if (have_newmv_in_inter_mode(this_mode))
- update_mask_best_mv(mbmi, best_mv, cur_mv, cur_type,
- &best_tmp_rate_mv, tmp_rate_mv, &cpi->sf);
- }
- }
- // reset to original mvs for next iteration
- mbmi->mv[0].as_int = cur_mv[0].as_int;
- mbmi->mv[1].as_int = cur_mv[1].as_int;
- }
- if (mbmi->interinter_comp.type != best_type_stats.best_compound_data.type) {
- mbmi->comp_group_idx =
- (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
- mbmi->compound_idx =
- !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
- mbmi->interinter_comp = best_type_stats.best_compound_data;
- memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
- }
- if (have_newmv_in_inter_mode(this_mode)) {
- mbmi->mv[0].as_int = best_mv[0].as_int;
- mbmi->mv[1].as_int = best_mv[1].as_int;
- if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
- rd_stats->rate += best_tmp_rate_mv - *rate_mv;
- *rate_mv = best_tmp_rate_mv;
- }
- }
- restore_dst_buf(xd, *orig_dst, 1);
- if (!match_found)
- save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
- comp_model_dist, cur_mv, comp_rs2);
- return best_type_stats.best_compmode_interinter_cost;
-}
-
static INLINE int is_single_newmv_valid(const HandleInterModeArgs *const args,
const MB_MODE_INFO *const mbmi,
PREDICTION_MODE this_mode) {
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 020d39f..0a5010b 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -173,11 +173,6 @@
void av1_inter_mode_data_init(struct TileDataEnc *tile_data);
void av1_inter_mode_data_fit(TileDataEnc *tile_data, int rdmult);
-typedef int64_t (*pick_interinter_mask_type)(
- const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
- const uint8_t *const p0, const uint8_t *const p1,
- const int16_t *const residual1, const int16_t *const diff10);
-
static INLINE int av1_encoder_get_relative_dist(const OrderHintInfo *oh, int a,
int b) {
if (!oh->enable_order_hint) return 0;