Add comparison between cnn and cdef/restoration.
Change-Id: I08899c776c241156f6df07b95829831545f6037e
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 6981382..cf2759e 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -603,6 +603,9 @@
av1_free_context_buffers(cm);
aom_free_frame_buffer(&cpi->last_frame_uf);
+#if CONFIG_CNN_RESTORATION
+ aom_free_frame_buffer(&cpi->cnn_buffer);
+#endif // CONFIG_CNN_RESTORATION
av1_free_restoration_buffers(cm);
aom_free_frame_buffer(&cpi->trial_frame_rst);
aom_free_frame_buffer(&cpi->scaled_source);
@@ -849,6 +852,15 @@
aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
"Failed to allocate last frame buffer");
+#if CONFIG_CNN_RESTORATION
+ if (aom_realloc_frame_buffer(
+ &cpi->cnn_buffer, cm->width, cm->height, seq_params->subsampling_x,
+ seq_params->subsampling_y, seq_params->use_highbitdepth,
+ cpi->oxcf.border_in_pixels, cm->byte_alignment, NULL, NULL, NULL))
+ aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
+ "Failed to allocate CNN frame buffer");
+#endif // CONFIG_CNN_RESTORATION
+
if (aom_realloc_frame_buffer(
&cpi->trial_frame_rst, cm->superres_upscaled_width,
cm->superres_upscaled_height, seq_params->subsampling_x,
@@ -4340,6 +4352,49 @@
}
}
+static void cdef_restoration_frame(AV1_COMP *cpi, AV1_COMMON *cm,
+ MACROBLOCKD *xd, int use_restoration,
+ int use_cdef) {
+ if (use_restoration)
+ av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 0);
+
+ if (use_cdef) {
+ // Find CDEF parameters
+ av1_cdef_search(&cm->cur_frame->buf, cpi->source, cm, xd,
+ cpi->sf.fast_cdef_search);
+
+ // Apply the filter
+ av1_cdef_frame(&cm->cur_frame->buf, cm, xd);
+ } else {
+ cm->cdef_info.cdef_bits = 0;
+ cm->cdef_info.cdef_strengths[0] = 0;
+ cm->cdef_info.nb_cdef_strengths = 1;
+ cm->cdef_info.cdef_uv_strengths[0] = 0;
+ }
+
+ superres_post_encode(cpi);
+
+ if (use_restoration) {
+ av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 1);
+ av1_pick_filter_restoration(cpi->source, cpi);
+ if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
+ cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
+ cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
+ if (cpi->num_workers > 1)
+ av1_loop_restoration_filter_frame_mt(&cm->cur_frame->buf, cm, 0,
+ cpi->workers, cpi->num_workers,
+ &cpi->lr_row_sync, &cpi->lr_ctxt);
+ else
+ av1_loop_restoration_filter_frame(&cm->cur_frame->buf, cm, 0,
+ &cpi->lr_ctxt);
+ }
+ } else {
+ cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
+ cm->rst_info[1].frame_restoration_type = RESTORE_NONE;
+ cm->rst_info[2].frame_restoration_type = RESTORE_NONE;
+ }
+}
+
static void loopfilter_frame(AV1_COMP *cpi, AV1_COMMON *cm) {
const int num_planes = av1_num_planes(cm);
MACROBLOCKD *xd = &cpi->td.mb.e_mbd;
@@ -4388,70 +4443,52 @@
}
#if CONFIG_CNN_RESTORATION
+ cm->use_cnn = 0;
if (av1_use_cnn(cm)) {
- aom_yv12_copy_y(&cm->cur_frame->buf, &cpi->last_frame_uf);
+ int64_t dgd_error = INT64_MAX;
+ int64_t cnn_error = INT64_MAX;
+ int64_t res_error = INT64_MAX;
+
+ aom_yv12_copy_frame(&cm->cur_frame->buf, &cpi->last_frame_uf,
+ av1_num_planes(cm));
const int plane = AOM_PLANE_Y;
- int64_t dgd_error =
- aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
- cm->seq_params.use_highbitdepth);
+ // Find the error of the plane from source.
+ dgd_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+ cm->seq_params.use_highbitdepth);
av1_encode_restore_cnn(cm);
- int64_t cnn_error =
- aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
- cm->seq_params.use_highbitdepth);
+ // Find the error of the plane from source after applying cnn.
+ cnn_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+ cm->seq_params.use_highbitdepth);
- if (dgd_error > cnn_error) {
+ if (cnn_error < dgd_error)
+ aom_yv12_copy_y(&cm->cur_frame->buf, &cpi->cnn_buffer);
+ aom_yv12_copy_y(&cpi->last_frame_uf, &cm->cur_frame->buf);
+
+ cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
+
+ // Find the error of the plane from source after applying cdef-restoration.
+ res_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+ cm->seq_params.use_highbitdepth);
+ if (cnn_error < res_error && cnn_error < dgd_error) {
+ int num_planes = av1_num_planes(cm);
cm->use_cnn = 1;
- use_cdef = 0;
- use_restoration = 0;
- } else {
- cm->use_cnn = 0;
- aom_yv12_copy_y(&cpi->last_frame_uf, &cm->cur_frame->buf);
+ aom_yv12_copy_y(&cpi->cnn_buffer, &cm->cur_frame->buf);
+ if (num_planes > 1)
+ aom_yv12_copy_u(&cpi->last_frame_uf, &cm->cur_frame->buf);
+ if (num_planes > 2)
+ aom_yv12_copy_v(&cpi->last_frame_uf, &cm->cur_frame->buf);
+ // Since cnn restores better than cdef-restoration, disable the
+ // cdef-restoration instructions.
+ cdef_restoration_frame(cpi, cm, xd, 0, 0);
}
} else {
- cm->use_cnn = 0;
+ cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
}
+#else
+ cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
#endif // CONFIG_CNN_RESTORATION
-
- if (use_restoration)
- av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 0);
-
- if (use_cdef) {
- // Find CDEF parameters
- av1_cdef_search(&cm->cur_frame->buf, cpi->source, cm, xd,
- cpi->sf.fast_cdef_search);
-
- // Apply the filter
- av1_cdef_frame(&cm->cur_frame->buf, cm, xd);
- } else {
- cm->cdef_info.cdef_bits = 0;
- cm->cdef_info.cdef_strengths[0] = 0;
- cm->cdef_info.nb_cdef_strengths = 1;
- cm->cdef_info.cdef_uv_strengths[0] = 0;
- }
-
- superres_post_encode(cpi);
-
- if (use_restoration) {
- av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 1);
- av1_pick_filter_restoration(cpi->source, cpi);
- if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
- cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
- cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
- if (cpi->num_workers > 1)
- av1_loop_restoration_filter_frame_mt(&cm->cur_frame->buf, cm, 0,
- cpi->workers, cpi->num_workers,
- &cpi->lr_row_sync, &cpi->lr_ctxt);
- else
- av1_loop_restoration_filter_frame(&cm->cur_frame->buf, cm, 0,
- &cpi->lr_ctxt);
- }
- } else {
- cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
- cm->rst_info[1].frame_restoration_type = RESTORE_NONE;
- cm->rst_info[2].frame_restoration_type = RESTORE_NONE;
- }
}
static int get_refresh_frame_flags(const AV1_COMP *const cpi) {
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index ce83734..ecd1738 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -766,6 +766,9 @@
YV12_BUFFER_CONFIG last_frame_uf;
YV12_BUFFER_CONFIG trial_frame_rst;
+#if CONFIG_CNN_RESTORATION
+ YV12_BUFFER_CONFIG cnn_buffer;
+#endif // CONFIG_CNN_RESTORATION
// Ambient reconstruction err target for force key frames
int64_t ambient_err;