Add high level funct to restore a part of a frame

Adds a high level functionality to restore a part of a frame.

Change-Id: I1ada6a0c153e88374927f3ab1ec6073cfd684616
diff --git a/av1/common/cnn.c b/av1/common/cnn.c
index cb4314c..4986a62 100644
--- a/av1/common/cnn.c
+++ b/av1/common/cnn.c
@@ -741,8 +741,8 @@
   }
 }
 
-void av1_restore_cnn(uint8_t *dgd, int width, int height, int stride,
-                     const CNN_CONFIG *cnn_config) {
+static void av1_restore_cnn(uint8_t *dgd, int width, int height, int stride,
+                            const CNN_CONFIG *cnn_config) {
   const float max_val = 255.0;
   int out_width = 0;
   int out_height = 0;
@@ -806,8 +806,9 @@
   aom_free(output);
 }
 
-void av1_restore_cnn_highbd(uint16_t *dgd, int width, int height, int stride,
-                            const CNN_CONFIG *cnn_config, int bit_depth) {
+static void av1_restore_cnn_highbd(uint16_t *dgd, int width, int height,
+                                   int stride, const CNN_CONFIG *cnn_config,
+                                   int bit_depth) {
   const float max_val = (float)((1 << bit_depth) - 1);
   int out_width = 0;
   int out_height = 0;
@@ -872,6 +873,69 @@
   aom_free(output);
 }
 
+void av1_restore_cnn_plane_part(AV1_COMMON *cm, const CNN_CONFIG *cnn_config,
+                                int plane, int start_x, int start_y, int width,
+                                int height) {
+  YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
+
+  assert(start_x >= 0 && start_x + width <= buf->y_crop_width);
+  assert(start_y >= 0 && start_y + height <= buf->y_crop_height);
+
+  int offset = 0, part_width = 0, part_height = 0;
+  switch (plane) {
+    case AOM_PLANE_Y:
+      part_width = width;
+      part_height = height;
+      offset = start_y * buf->y_stride + start_x;
+      break;
+    case AOM_PLANE_U:
+    case AOM_PLANE_V:
+      part_width = width >> buf->subsampling_x;
+      part_height = height >> buf->subsampling_y;
+      offset = (start_y >> buf->subsampling_y) * buf->uv_stride +
+               (start_x >> buf->subsampling_x);
+      break;
+    default: assert(0 && "Invalid plane index");
+  }
+  if (cm->seq_params.use_highbitdepth) {
+    switch (plane) {
+      case AOM_PLANE_Y:
+        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->y_buffer + offset),
+                               part_width, part_height, buf->y_stride,
+                               cnn_config, cm->seq_params.bit_depth);
+        break;
+      case AOM_PLANE_U:
+        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->u_buffer + offset),
+                               part_width, part_height, buf->uv_stride,
+                               cnn_config, cm->seq_params.bit_depth);
+        break;
+      case AOM_PLANE_V:
+        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->v_buffer + offset),
+                               part_width, part_height, buf->uv_stride,
+                               cnn_config, cm->seq_params.bit_depth);
+        break;
+      default: assert(0 && "Invalid plane index");
+    }
+  } else {
+    assert(cm->seq_params.bit_depth == 8);
+    switch (plane) {
+      case AOM_PLANE_Y:
+        av1_restore_cnn(buf->y_buffer + offset, part_width, part_height,
+                        buf->y_stride, cnn_config);
+        break;
+      case AOM_PLANE_U:
+        av1_restore_cnn(buf->u_buffer + offset, part_width, part_height,
+                        buf->uv_stride, cnn_config);
+        break;
+      case AOM_PLANE_V:
+        av1_restore_cnn(buf->v_buffer + offset, part_width, part_height,
+                        buf->uv_stride, cnn_config);
+        break;
+      default: assert(0 && "Invalid plane index");
+    }
+  }
+}
+
 void av1_restore_cnn_plane(AV1_COMMON *cm, const CNN_CONFIG *cnn_config,
                            int plane) {
   YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
diff --git a/av1/common/cnn.h b/av1/common/cnn.h
index 6ebd5e4..7304de1 100644
--- a/av1/common/cnn.h
+++ b/av1/common/cnn.h
@@ -113,12 +113,13 @@
 void av1_find_cnn_output_size(int in_width, int in_height,
                               const CNN_CONFIG *cnn_config, int *out_width,
                               int *out_height);
-void av1_restore_cnn(uint8_t *dgd, int width, int height, int stride,
-                     const CNN_CONFIG *cnn_config);
-void av1_restore_cnn_highbd(uint16_t *dgd, int width, int height, int stride,
-                            const CNN_CONFIG *cnn_config, int bit_depth);
+
 void av1_restore_cnn_plane(struct AV1Common *cm, const CNN_CONFIG *cnn_config,
                            int plane);
+void av1_restore_cnn_plane_part(struct AV1Common *cm,
+                                const CNN_CONFIG *cnn_config, int plane,
+                                int start_x, int start_y, int width,
+                                int height);
 
 #ifdef __cplusplus
 }  // extern "C"