Move tpl-based MV generation outside of av1_single_motion_search
Change-Id: I3b6e23c766b494f23e18cc34850983c22fab85ee
diff --git a/av1/encoder/motion_search_facade.c b/av1/encoder/motion_search_facade.c
index 059c8f2..9834b19 100644
--- a/av1/encoder/motion_search_facade.c
+++ b/av1/encoder/motion_search_facade.c
@@ -44,6 +44,77 @@
cpi->oxcf.speed <= 2;
}
+// Iterate through the tpl and collect the mvs to be used as candidates
+static INLINE void get_mv_candidate_from_tpl(const AV1_COMP *const cpi,
+ const MACROBLOCK *x,
+ BLOCK_SIZE bsize, int ref,
+ cand_mv_t *cand, int *cand_count,
+ int *total_cand_weight) {
+ const SuperBlockEnc *sb_enc = &x->sb_enc;
+ if (!sb_enc->tpl_data_count) {
+ return;
+ }
+
+ const AV1_COMMON *cm = &cpi->common;
+ const MACROBLOCKD *xd = &x->e_mbd;
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+
+ const BLOCK_SIZE tpl_bsize =
+ convert_length_to_bsize(cpi->tpl_data.tpl_bsize_1d);
+ const int tplw = mi_size_wide[tpl_bsize];
+ const int tplh = mi_size_high[tpl_bsize];
+ const int nw = mi_size_wide[bsize] / tplw;
+ const int nh = mi_size_high[bsize] / tplh;
+
+ if (nw >= 1 && nh >= 1) {
+ const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
+ const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
+ const int start = of_h / tplh * sb_enc->tpl_stride + of_w / tplw;
+ int valid = 1;
+
+ // Assign large weight to start_mv, so it is always tested.
+ cand[0].weight = nw * nh;
+
+ for (int k = 0; k < nh; k++) {
+ for (int l = 0; l < nw; l++) {
+ const int_mv mv =
+ sb_enc
+ ->tpl_mv[start + k * sb_enc->tpl_stride + l][ref - LAST_FRAME];
+ if (mv.as_int == INVALID_MV) {
+ valid = 0;
+ break;
+ }
+
+ const FULLPEL_MV fmv = { GET_MV_RAWPEL(mv.as_mv.row),
+ GET_MV_RAWPEL(mv.as_mv.col) };
+ int unique = 1;
+ for (int m = 0; m < *cand_count; m++) {
+ if (RIGHT_SHIFT_MV(fmv.row) == RIGHT_SHIFT_MV(cand[m].fmv.row) &&
+ RIGHT_SHIFT_MV(fmv.col) == RIGHT_SHIFT_MV(cand[m].fmv.col)) {
+ unique = 0;
+ cand[m].weight++;
+ break;
+ }
+ }
+
+ if (unique) {
+ cand[*cand_count].fmv = fmv;
+ cand[*cand_count].weight = 1;
+ (*cand_count)++;
+ }
+ }
+ if (!valid) break;
+ }
+
+ if (valid) {
+ *total_cand_weight = 2 * nh * nw;
+ if (*cand_count > 2)
+ qsort(cand, *cand_count, sizeof(cand[0]), &compare_weight);
+ }
+ }
+}
+
void av1_single_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
BLOCK_SIZE bsize, int ref_idx, int *rate_mv,
int search_range, inter_mode_info *mode_info,
@@ -104,60 +175,7 @@
if (!cpi->sf.mv_sf.full_pixel_search_level &&
mbmi->motion_mode == SIMPLE_TRANSLATION) {
- SuperBlockEnc *sb_enc = &x->sb_enc;
- if (sb_enc->tpl_data_count) {
- const BLOCK_SIZE tpl_bsize =
- convert_length_to_bsize(cpi->tpl_data.tpl_bsize_1d);
- const int tplw = mi_size_wide[tpl_bsize];
- const int tplh = mi_size_high[tpl_bsize];
- const int nw = mi_size_wide[bsize] / tplw;
- const int nh = mi_size_high[bsize] / tplh;
-
- if (nw >= 1 && nh >= 1) {
- const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
- const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
- const int start = of_h / tplh * sb_enc->tpl_stride + of_w / tplw;
- int valid = 1;
-
- // Assign large weight to start_mv, so it is always tested.
- cand[0].weight = nw * nh;
-
- for (int k = 0; k < nh; k++) {
- for (int l = 0; l < nw; l++) {
- const int_mv mv = sb_enc->tpl_mv[start + k * sb_enc->tpl_stride + l]
- [ref - LAST_FRAME];
- if (mv.as_int == INVALID_MV) {
- valid = 0;
- break;
- }
-
- const FULLPEL_MV fmv = { GET_MV_RAWPEL(mv.as_mv.row),
- GET_MV_RAWPEL(mv.as_mv.col) };
- int unique = 1;
- for (int m = 0; m < cnt; m++) {
- if (RIGHT_SHIFT_MV(fmv.row) == RIGHT_SHIFT_MV(cand[m].fmv.row) &&
- RIGHT_SHIFT_MV(fmv.col) == RIGHT_SHIFT_MV(cand[m].fmv.col)) {
- unique = 0;
- cand[m].weight++;
- break;
- }
- }
-
- if (unique) {
- cand[cnt].fmv = fmv;
- cand[cnt].weight = 1;
- cnt++;
- }
- }
- if (!valid) break;
- }
-
- if (valid) {
- total_weight = 2 * nh * nw;
- if (cnt > 2) qsort(cand, cnt, sizeof(cand[0]), &compare_weight);
- }
- }
- }
+ get_mv_candidate_from_tpl(cpi, x, bsize, ref, cand, &cnt, &total_weight);
}
// Further reduce the search range.
@@ -190,9 +208,9 @@
switch (mbmi->motion_mode) {
case SIMPLE_TRANSLATION: {
+ // Perform a search with the top 2 candidates
int sum_weight = 0;
-
- for (int m = 0; m < cnt; m++) {
+ for (int m = 0; m < AOMMIN(2, cnt); m++) {
FULLPEL_MV smv = cand[m].fmv;
FULLPEL_MV this_best_mv, this_second_best_mv;
@@ -207,7 +225,7 @@
}
sum_weight += cand[m].weight;
- if (m >= 2 || 4 * sum_weight > 3 * total_weight) break;
+ if (4 * sum_weight > 3 * total_weight) break;
}
} break;
case OBMC_CAUSAL: