Add comparison between cnn and cdef/restoration.

Change-Id: I08899c776c241156f6df07b95829831545f6037e
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 6981382..cf2759e 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -603,6 +603,9 @@
   av1_free_context_buffers(cm);
 
   aom_free_frame_buffer(&cpi->last_frame_uf);
+#if CONFIG_CNN_RESTORATION
+  aom_free_frame_buffer(&cpi->cnn_buffer);
+#endif  // CONFIG_CNN_RESTORATION
   av1_free_restoration_buffers(cm);
   aom_free_frame_buffer(&cpi->trial_frame_rst);
   aom_free_frame_buffer(&cpi->scaled_source);
@@ -849,6 +852,15 @@
     aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
                        "Failed to allocate last frame buffer");
 
+#if CONFIG_CNN_RESTORATION
+  if (aom_realloc_frame_buffer(
+          &cpi->cnn_buffer, cm->width, cm->height, seq_params->subsampling_x,
+          seq_params->subsampling_y, seq_params->use_highbitdepth,
+          cpi->oxcf.border_in_pixels, cm->byte_alignment, NULL, NULL, NULL))
+    aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
+                       "Failed to allocate CNN frame buffer");
+#endif  // CONFIG_CNN_RESTORATION
+
   if (aom_realloc_frame_buffer(
           &cpi->trial_frame_rst, cm->superres_upscaled_width,
           cm->superres_upscaled_height, seq_params->subsampling_x,
@@ -4340,6 +4352,49 @@
   }
 }
 
+static void cdef_restoration_frame(AV1_COMP *cpi, AV1_COMMON *cm,
+                                   MACROBLOCKD *xd, int use_restoration,
+                                   int use_cdef) {
+  if (use_restoration)
+    av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 0);
+
+  if (use_cdef) {
+    // Find CDEF parameters
+    av1_cdef_search(&cm->cur_frame->buf, cpi->source, cm, xd,
+                    cpi->sf.fast_cdef_search);
+
+    // Apply the filter
+    av1_cdef_frame(&cm->cur_frame->buf, cm, xd);
+  } else {
+    cm->cdef_info.cdef_bits = 0;
+    cm->cdef_info.cdef_strengths[0] = 0;
+    cm->cdef_info.nb_cdef_strengths = 1;
+    cm->cdef_info.cdef_uv_strengths[0] = 0;
+  }
+
+  superres_post_encode(cpi);
+
+  if (use_restoration) {
+    av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 1);
+    av1_pick_filter_restoration(cpi->source, cpi);
+    if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
+        cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
+        cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
+      if (cpi->num_workers > 1)
+        av1_loop_restoration_filter_frame_mt(&cm->cur_frame->buf, cm, 0,
+                                             cpi->workers, cpi->num_workers,
+                                             &cpi->lr_row_sync, &cpi->lr_ctxt);
+      else
+        av1_loop_restoration_filter_frame(&cm->cur_frame->buf, cm, 0,
+                                          &cpi->lr_ctxt);
+    }
+  } else {
+    cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
+    cm->rst_info[1].frame_restoration_type = RESTORE_NONE;
+    cm->rst_info[2].frame_restoration_type = RESTORE_NONE;
+  }
+}
+
 static void loopfilter_frame(AV1_COMP *cpi, AV1_COMMON *cm) {
   const int num_planes = av1_num_planes(cm);
   MACROBLOCKD *xd = &cpi->td.mb.e_mbd;
@@ -4388,70 +4443,52 @@
   }
 
 #if CONFIG_CNN_RESTORATION
+  cm->use_cnn = 0;
   if (av1_use_cnn(cm)) {
-    aom_yv12_copy_y(&cm->cur_frame->buf, &cpi->last_frame_uf);
+    int64_t dgd_error = INT64_MAX;
+    int64_t cnn_error = INT64_MAX;
+    int64_t res_error = INT64_MAX;
+
+    aom_yv12_copy_frame(&cm->cur_frame->buf, &cpi->last_frame_uf,
+                        av1_num_planes(cm));
     const int plane = AOM_PLANE_Y;
-    int64_t dgd_error =
-        aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
-                          cm->seq_params.use_highbitdepth);
+    // Find the error of the plane from source.
+    dgd_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+                                  cm->seq_params.use_highbitdepth);
 
     av1_encode_restore_cnn(cm);
 
-    int64_t cnn_error =
-        aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
-                          cm->seq_params.use_highbitdepth);
+    // Find the error of the plane from source after applying cnn.
+    cnn_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+                                  cm->seq_params.use_highbitdepth);
 
-    if (dgd_error > cnn_error) {
+    if (cnn_error < dgd_error)
+      aom_yv12_copy_y(&cm->cur_frame->buf, &cpi->cnn_buffer);
+    aom_yv12_copy_y(&cpi->last_frame_uf, &cm->cur_frame->buf);
+
+    cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
+
+    // Find the error of the plane from source after applying cdef-restoration.
+    res_error = aom_get_sse_plane(cpi->source, &cm->cur_frame->buf, plane,
+                                  cm->seq_params.use_highbitdepth);
+    if (cnn_error < res_error && cnn_error < dgd_error) {
+      int num_planes = av1_num_planes(cm);
       cm->use_cnn = 1;
-      use_cdef = 0;
-      use_restoration = 0;
-    } else {
-      cm->use_cnn = 0;
-      aom_yv12_copy_y(&cpi->last_frame_uf, &cm->cur_frame->buf);
+      aom_yv12_copy_y(&cpi->cnn_buffer, &cm->cur_frame->buf);
+      if (num_planes > 1)
+        aom_yv12_copy_u(&cpi->last_frame_uf, &cm->cur_frame->buf);
+      if (num_planes > 2)
+        aom_yv12_copy_v(&cpi->last_frame_uf, &cm->cur_frame->buf);
+      // Since cnn restores better than cdef-restoration, disable the
+      // cdef-restoration instructions.
+      cdef_restoration_frame(cpi, cm, xd, 0, 0);
     }
   } else {
-    cm->use_cnn = 0;
+    cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
   }
+#else
+  cdef_restoration_frame(cpi, cm, xd, use_cdef, use_restoration);
 #endif  // CONFIG_CNN_RESTORATION
-
-  if (use_restoration)
-    av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 0);
-
-  if (use_cdef) {
-    // Find CDEF parameters
-    av1_cdef_search(&cm->cur_frame->buf, cpi->source, cm, xd,
-                    cpi->sf.fast_cdef_search);
-
-    // Apply the filter
-    av1_cdef_frame(&cm->cur_frame->buf, cm, xd);
-  } else {
-    cm->cdef_info.cdef_bits = 0;
-    cm->cdef_info.cdef_strengths[0] = 0;
-    cm->cdef_info.nb_cdef_strengths = 1;
-    cm->cdef_info.cdef_uv_strengths[0] = 0;
-  }
-
-  superres_post_encode(cpi);
-
-  if (use_restoration) {
-    av1_loop_restoration_save_boundary_lines(&cm->cur_frame->buf, cm, 1);
-    av1_pick_filter_restoration(cpi->source, cpi);
-    if (cm->rst_info[0].frame_restoration_type != RESTORE_NONE ||
-        cm->rst_info[1].frame_restoration_type != RESTORE_NONE ||
-        cm->rst_info[2].frame_restoration_type != RESTORE_NONE) {
-      if (cpi->num_workers > 1)
-        av1_loop_restoration_filter_frame_mt(&cm->cur_frame->buf, cm, 0,
-                                             cpi->workers, cpi->num_workers,
-                                             &cpi->lr_row_sync, &cpi->lr_ctxt);
-      else
-        av1_loop_restoration_filter_frame(&cm->cur_frame->buf, cm, 0,
-                                          &cpi->lr_ctxt);
-    }
-  } else {
-    cm->rst_info[0].frame_restoration_type = RESTORE_NONE;
-    cm->rst_info[1].frame_restoration_type = RESTORE_NONE;
-    cm->rst_info[2].frame_restoration_type = RESTORE_NONE;
-  }
 }
 
 static int get_refresh_frame_flags(const AV1_COMP *const cpi) {
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index ce83734..ecd1738 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -766,6 +766,9 @@
 
   YV12_BUFFER_CONFIG last_frame_uf;
   YV12_BUFFER_CONFIG trial_frame_rst;
+#if CONFIG_CNN_RESTORATION
+  YV12_BUFFER_CONFIG cnn_buffer;
+#endif  // CONFIG_CNN_RESTORATION
 
   // Ambient reconstruction err target for force key frames
   int64_t ambient_err;