Prune interpolation filter search
Prune the interpolation filter search based on previous frames'
filter selection stats. Calculate the probability of each filter,
and prune the search at the prediction block level. Currently, this
feature is turned on at speed 5.
Borg test (150 frames) results:
        avg_psnr:  ovr_psnr:  ssim:   avg speedups over whole set:
lowres: 0.153      0.163      0.149   1.6%
midres: 0.145      0.158      0.070   1.8%
STATS_CHANGED
Change-Id: I21edcc76ac30b2fd2eec8f5efd7edeb2c7119cbf
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a685152..fc9b4f6 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5507,6 +5507,33 @@
cpi->warped_probs[update_type] =
(cpi->warped_probs[update_type] + new_prob) >> 1;
}
+
+ if (cm->current_frame.frame_type != KEY_FRAME &&
+ cpi->sf.interp_sf.adaptive_interp_filter_search == 2 &&
+ cm->interp_filter == SWITCHABLE) {
+ const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
+
+ for (i = 0; i < SWITCHABLE_FILTER_CONTEXTS; i++) {
+ int sum = 0;
+ int j;
+ int left = 1536;
+
+ for (j = 0; j < SWITCHABLE_FILTERS; j++) {
+ sum += cpi->td.counts->switchable_interp[i][j];
+ }
+
+ for (j = SWITCHABLE_FILTERS - 1; j >= 0; j--) {
+ const int new_prob =
+ sum ? 1536 * cpi->td.counts->switchable_interp[i][j] / sum
+ : (j ? 0 : 1536);
+ int prob =
+ (cpi->switchable_interp_probs[update_type][i][j] + new_prob) >> 1;
+ left -= prob;
+ if (j == 0) prob += left;
+ cpi->switchable_interp_probs[update_type][i][j] = prob;
+ }
+ }
+ }
}
#define CHECK_PRECOMPUTED_REF_FRAME_MAP 0
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index a5373eb..9e85e2b 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -251,6 +251,125 @@
const int default_warped_probs[FRAME_UPDATE_TYPES] = { 64, 64, 64, 64,
64, 64, 64 };
+// TODO(yunqing): the default probs can be trained later for better
+// performance.
+const int default_switchable_interp_probs[FRAME_UPDATE_TYPES]
+ [SWITCHABLE_FILTER_CONTEXTS]
+ [SWITCHABLE_FILTERS] = {
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } },
+ { { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 },
+ { 512, 512, 512 } }
+ };
+
static INLINE void Scale2Ratio(AOM_SCALING mode, int *hr, int *hs) {
switch (mode) {
case NORMAL:
@@ -5030,7 +5149,7 @@
cm->current_frame.frame_type == KEY_FRAME) {
av1_copy(cpi->tx_type_probs, default_tx_type_probs);
- int thr[2][2] = { { 15, 10 }, { 17, 10 } };
+ const int thr[2][2] = { { 15, 10 }, { 17, 10 } };
for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
int kf_arf_update = (f == KF_UPDATE || f == ARF_UPDATE);
cpi->tx_type_probs_thresh[f] =
@@ -5049,6 +5168,16 @@
av1_copy(cpi->warped_probs, default_warped_probs);
}
+ if (cpi->sf.interp_sf.adaptive_interp_filter_search == 2 &&
+ cm->current_frame.frame_type == KEY_FRAME) {
+ av1_copy(cpi->switchable_interp_probs, default_switchable_interp_probs);
+
+ const int thr[7] = { 0, 8, 8, 8, 8, 0, 8 };
+ for (int f = 0; f < FRAME_UPDATE_TYPES; f++) {
+ cpi->switchable_interp_thresh[f] = thr[f];
+ }
+ }
+
// Loop variables
int loop_count = 0;
int loop_at_this_size = 0;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index b2b8638..997c673 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -990,6 +990,9 @@
int warped_probs[FRAME_UPDATE_TYPES];
int tx_type_probs[FRAME_UPDATE_TYPES][TX_SIZES_ALL][TX_TYPES];
int tx_type_probs_thresh[FRAME_UPDATE_TYPES];
+ int switchable_interp_probs[FRAME_UPDATE_TYPES][SWITCHABLE_FILTER_CONTEXTS]
+ [SWITCHABLE_FILTERS];
+ int switchable_interp_thresh[FRAME_UPDATE_TYPES];
// Multi-threading
int num_workers;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 170776d..754e860 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8257,6 +8257,29 @@
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
+ uint16_t interp_filter_search_mask = cpi->interp_filter_search_mask;
+
+ if (cpi->sf.interp_sf.adaptive_interp_filter_search == 2) {
+ const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
+ const int ctx0 = av1_get_pred_context_switchable_interp(xd, 0);
+ const int ctx1 = av1_get_pred_context_switchable_interp(xd, 1);
+ const int *switchable_interp_p0 =
+ cpi->switchable_interp_probs[update_type][ctx0];
+ const int *switchable_interp_p1 =
+ cpi->switchable_interp_probs[update_type][ctx1];
+
+ const int thresh = cpi->switchable_interp_thresh[update_type];
+ for (i = 0; i < SWITCHABLE_FILTERS; i++) {
+ // In the non-dual case, both directions' probs should be identical.
+ assert(switchable_interp_p0[i] == switchable_interp_p1[i]);
+ if (switchable_interp_p0[i] < thresh &&
+ switchable_interp_p1[i] < thresh) {
+ DUAL_FILTER_TYPE filt_type = i + SWITCHABLE_FILTERS * i;
+ reset_interp_filter_allowed_mask(&interp_filter_search_mask, filt_type);
+ }
+ }
+ }
+
// Regular filter evaluation should have been done and hence the same should
// be the winner
assert(x->e_mbd.mi[0]->interp_filters.as_int == filter_sets[0].as_int);
@@ -8273,7 +8296,7 @@
assert(filter_sets[filter_idx].as_filters.x_filter ==
filter_sets[filter_idx].as_filters.y_filter);
if (cpi->sf.interp_sf.adaptive_interp_filter_search &&
- !(get_interp_filter_allowed_mask(cpi->interp_filter_search_mask,
+ !(get_interp_filter_allowed_mask(interp_filter_search_mask,
filter_idx))) {
return;
}
@@ -8308,7 +8331,7 @@
set_interp_filter_allowed_mask(&allowed_interp_mask, SHARP_SHARP);
set_interp_filter_allowed_mask(&allowed_interp_mask, SMOOTH_SMOOTH);
if (cpi->sf.interp_sf.adaptive_interp_filter_search)
- allowed_interp_mask &= cpi->interp_filter_search_mask;
+ allowed_interp_mask &= interp_filter_search_mask;
find_best_interp_rd_facade(x, cpi, tile_data, bsize, orig_dst, rd,
rd_stats_y, rd_stats, switchable_rate, dst_bufs,
@@ -8322,8 +8345,7 @@
assert(filter_sets[i].as_filters.x_filter ==
filter_sets[i].as_filters.y_filter);
if (cpi->sf.interp_sf.adaptive_interp_filter_search &&
- !(get_interp_filter_allowed_mask(cpi->interp_filter_search_mask,
- i))) {
+ !(get_interp_filter_allowed_mask(interp_filter_search_mask, i))) {
continue;
}
interpolation_filter_rd(x, cpi, tile_data, bsize, orig_dst, rd,
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index e54f42c..741ea4c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -539,6 +539,8 @@
sf->inter_sf.disable_smooth_interintra = 1;
sf->lpf_sf.disable_lr_filter = 1;
+
+ sf->interp_sf.adaptive_interp_filter_search = 2;
}
}