Rework the high-level API for the CNN lib

Change-Id: Icbe081c03afdb721410f5f492a91c3d0465ca178
diff --git a/av1/common/cnn.c b/av1/common/cnn.c
index 4986a62..95c5917 100644
--- a/av1/common/cnn.c
+++ b/av1/common/cnn.c
@@ -741,14 +741,14 @@
   }
 }
 
-static void av1_restore_cnn(uint8_t *dgd, int width, int height, int stride,
-                            const CNN_CONFIG *cnn_config) {
+// Assumes the output buffer has already been allocated by the caller.
+void av1_cnn_predict_img(uint8_t *dgd, int width, int height, int stride,
+                         const CNN_CONFIG *cnn_config, float **output,
+                         int out_stride) {
   const float max_val = 255.0;
   int out_width = 0;
   int out_height = 0;
   av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height);
-  assert(out_width == width);
-  assert(out_height == height);
 
   int in_width = width + 2 * cnn_config->ext_width;
   int in_height = height + 2 * cnn_config->ext_height;
@@ -757,8 +757,6 @@
   float *input =
       input_ + cnn_config->ext_height * in_stride + cnn_config->ext_width;
 
-  float *output = (float *)aom_malloc(width * height * sizeof(*output));
-  const int out_stride = width;
   if (cnn_config->strict_bounds) {
     for (int i = 0; i < height; ++i)
       for (int j = 0; j < width; ++j)
@@ -785,9 +783,71 @@
            ++j)
         input[i * in_stride + j] = (float)dgd[i * stride + j] / max_val;
   }
-
   av1_cnn_predict((const float **)&input_, in_width, in_height, in_stride,
-                  cnn_config, &output, out_stride);
+                  cnn_config, output, out_stride);
+  aom_free(input_);
+}
+
+// Runs the CNN on a high bit-depth image; output must be allocated by caller.
+void av1_cnn_predict_img_highbd(uint16_t *dgd, int width, int height,
+                                int stride, const CNN_CONFIG *cnn_config,
+                                int bit_depth, float **output, int out_stride) {
+  const float max_val = (float)((1 << bit_depth) - 1);
+  int out_width = 0, out_height = 0;
+  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height);
+  // Padded working buffer: ext_width/ext_height extra samples on every side.
+  int in_width = width + 2 * cnn_config->ext_width;
+  int in_height = height + 2 * cnn_config->ext_height;
+  float *input_ = (float *)aom_malloc(in_width * in_height * sizeof(*input_));
+  const int in_stride = in_width;
+  float *input =
+      input_ + cnn_config->ext_height * in_stride + cnn_config->ext_width;
+  // Scale samples by 1 / max_val; borders are replicated or read from dgd.
+  if (cnn_config->strict_bounds) {
+    for (int i = 0; i < height; ++i)
+      for (int j = 0; j < width; ++j)
+        input[i * in_stride + j] = (float)dgd[i * stride + j] / max_val;
+    // replicate the edge columns into the left and right borders
+    for (int i = 0; i < height; ++i) {
+      for (int j = -cnn_config->ext_width; j < 0; ++j)
+        input[i * in_stride + j] = input[i * in_stride];
+      for (int j = width; j < width + cnn_config->ext_width; ++j)
+        input[i * in_stride + j] = input[i * in_stride + width - 1];
+    }
+    // replicate the first/last padded rows into the top and bottom borders
+    for (int i = -cnn_config->ext_height; i < 0; ++i)
+      memcpy(&input[i * in_stride - cnn_config->ext_width],
+             &input[-cnn_config->ext_width], in_width * sizeof(*input));
+    for (int i = height; i < height + cnn_config->ext_height; ++i)
+      memcpy(&input[i * in_stride - cnn_config->ext_width],
+             &input[(height - 1) * in_stride - cnn_config->ext_width],
+             in_width * sizeof(*input));
+  } else {
+    for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
+         ++i)
+      for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
+           ++j)
+        input[i * in_stride + j] = (float)dgd[i * stride + j] / max_val;
+  }
+  // Pass the padded buffer and its dimensions, as av1_cnn_predict_img() does.
+  av1_cnn_predict((const float **)&input_, in_width, in_height, in_stride,
+                  cnn_config, output, out_stride);
+  aom_free(input_);
+}
+
+void av1_restore_cnn_img(uint8_t *dgd, int width, int height, int stride,
+                         const CNN_CONFIG *cnn_config) {
+  const float max_val = 255;
+  int out_width = 0;
+  int out_height = 0;
+  av1_find_cnn_output_size(width, height, cnn_config, &out_width, &out_height);
+  assert(out_width == width);
+  assert(out_height == height);
+
+  const int out_stride = width;
+  float *output = (float *)aom_malloc(width * height * sizeof(*output));
+  av1_cnn_predict_img(dgd, width, height, stride, cnn_config, &output,
+                      out_stride);
 
   if (cnn_config->is_residue) {
     for (int i = 0; i < height; ++i)
@@ -801,14 +861,12 @@
         dgd[i * stride + j] =
             clip_pixel((int)(output[i * out_stride + j] * max_val + 0.5));
   }
-
-  aom_free(input_);
   aom_free(output);
 }
 
-static void av1_restore_cnn_highbd(uint16_t *dgd, int width, int height,
-                                   int stride, const CNN_CONFIG *cnn_config,
-                                   int bit_depth) {
+void av1_restore_cnn_img_highbd(uint16_t *dgd, int width, int height,
+                                int stride, const CNN_CONFIG *cnn_config,
+                                int bit_depth) {
   const float max_val = (float)((1 << bit_depth) - 1);
   int out_width = 0;
   int out_height = 0;
@@ -816,44 +874,10 @@
   assert(out_width == width);
   assert(out_height == height);
 
-  int in_width = width + 2 * cnn_config->ext_width;
-  int in_height = height + 2 * cnn_config->ext_height;
-  float *input_ = (float *)aom_malloc(in_width * in_height * sizeof(*input_));
-  const int in_stride = in_width;
-  float *input =
-      input_ + cnn_config->ext_height * in_stride + cnn_config->ext_width;
-
   float *output = (float *)aom_malloc(width * height * sizeof(*output));
   const int out_stride = width;
-  if (cnn_config->strict_bounds) {
-    for (int i = 0; i < height; ++i)
-      for (int j = 0; j < width; ++j)
-        input[i * in_stride + j] = (float)dgd[i * stride + j] / max_val;
-    // extend left and right
-    for (int i = 0; i < height; ++i) {
-      for (int j = -cnn_config->ext_width; j < 0; ++j)
-        input[i * in_stride + j] = input[i * in_stride];
-      for (int j = width; j < width + cnn_config->ext_width; ++j)
-        input[i * in_stride + j] = input[i * in_stride + width - 1];
-    }
-    // extend top and bottom
-    for (int i = -cnn_config->ext_height; i < 0; ++i)
-      memcpy(&input[i * in_stride - cnn_config->ext_width],
-             &input[-cnn_config->ext_width], in_width * sizeof(*input));
-    for (int i = height; i < height + cnn_config->ext_height; ++i)
-      memcpy(&input[i * in_stride - cnn_config->ext_width],
-             &input[(height - 1) * in_stride - cnn_config->ext_width],
-             in_width * sizeof(*input));
-  } else {
-    for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
-         ++i)
-      for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
-           ++j)
-        input[i * in_stride + j] = (float)dgd[i * stride + j] / max_val;
-  }
-
-  av1_cnn_predict((const float **)&input, width, height, in_stride, cnn_config,
-                  &output, out_stride);
+  av1_cnn_predict_img_highbd(dgd, width, height, stride, cnn_config, bit_depth,
+                             &output, out_stride);
 
   if (cnn_config->is_residue) {
     for (int i = 0; i < height; ++i)
@@ -868,8 +892,6 @@
         dgd[i * stride + j] = clip_pixel_highbd(
             (int)(output[i * out_stride + j] * max_val + 0.5), bit_depth);
   }
-
-  aom_free(input_);
   aom_free(output);
 }
 
@@ -900,19 +922,19 @@
   if (cm->seq_params.use_highbitdepth) {
     switch (plane) {
       case AOM_PLANE_Y:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->y_buffer + offset),
-                               part_width, part_height, buf->y_stride,
-                               cnn_config, cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->y_buffer + offset),
+                                   part_width, part_height, buf->y_stride,
+                                   cnn_config, cm->seq_params.bit_depth);
         break;
       case AOM_PLANE_U:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->u_buffer + offset),
-                               part_width, part_height, buf->uv_stride,
-                               cnn_config, cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->u_buffer + offset),
+                                   part_width, part_height, buf->uv_stride,
+                                   cnn_config, cm->seq_params.bit_depth);
         break;
       case AOM_PLANE_V:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->v_buffer + offset),
-                               part_width, part_height, buf->uv_stride,
-                               cnn_config, cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->v_buffer + offset),
+                                   part_width, part_height, buf->uv_stride,
+                                   cnn_config, cm->seq_params.bit_depth);
         break;
       default: assert(0 && "Invalid plane index");
     }
@@ -920,16 +942,16 @@
     assert(cm->seq_params.bit_depth == 8);
     switch (plane) {
       case AOM_PLANE_Y:
-        av1_restore_cnn(buf->y_buffer + offset, part_width, part_height,
-                        buf->y_stride, cnn_config);
+        av1_restore_cnn_img(buf->y_buffer + offset, part_width, part_height,
+                            buf->y_stride, cnn_config);
         break;
       case AOM_PLANE_U:
-        av1_restore_cnn(buf->u_buffer + offset, part_width, part_height,
-                        buf->uv_stride, cnn_config);
+        av1_restore_cnn_img(buf->u_buffer + offset, part_width, part_height,
+                            buf->uv_stride, cnn_config);
         break;
       case AOM_PLANE_V:
-        av1_restore_cnn(buf->v_buffer + offset, part_width, part_height,
-                        buf->uv_stride, cnn_config);
+        av1_restore_cnn_img(buf->v_buffer + offset, part_width, part_height,
+                            buf->uv_stride, cnn_config);
         break;
       default: assert(0 && "Invalid plane index");
     }
@@ -942,22 +964,22 @@
   if (cm->seq_params.use_highbitdepth) {
     switch (plane) {
       case AOM_PLANE_Y:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->y_buffer),
-                               buf->y_crop_width, buf->y_crop_height,
-                               buf->y_stride, cnn_config,
-                               cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->y_buffer),
+                                   buf->y_crop_width, buf->y_crop_height,
+                                   buf->y_stride, cnn_config,
+                                   cm->seq_params.bit_depth);
         break;
       case AOM_PLANE_U:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->u_buffer),
-                               buf->uv_crop_width, buf->uv_crop_height,
-                               buf->uv_stride, cnn_config,
-                               cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->u_buffer),
+                                   buf->uv_crop_width, buf->uv_crop_height,
+                                   buf->uv_stride, cnn_config,
+                                   cm->seq_params.bit_depth);
         break;
       case AOM_PLANE_V:
-        av1_restore_cnn_highbd(CONVERT_TO_SHORTPTR(buf->v_buffer),
-                               buf->uv_crop_width, buf->uv_crop_height,
-                               buf->uv_stride, cnn_config,
-                               cm->seq_params.bit_depth);
+        av1_restore_cnn_img_highbd(CONVERT_TO_SHORTPTR(buf->v_buffer),
+                                   buf->uv_crop_width, buf->uv_crop_height,
+                                   buf->uv_stride, cnn_config,
+                                   cm->seq_params.bit_depth);
         break;
       default: assert(0 && "Invalid plane index");
     }
@@ -965,16 +987,16 @@
     assert(cm->seq_params.bit_depth == 8);
     switch (plane) {
       case AOM_PLANE_Y:
-        av1_restore_cnn(buf->y_buffer, buf->y_crop_width, buf->y_crop_height,
-                        buf->y_stride, cnn_config);
+        av1_restore_cnn_img(buf->y_buffer, buf->y_crop_width,
+                            buf->y_crop_height, buf->y_stride, cnn_config);
         break;
       case AOM_PLANE_U:
-        av1_restore_cnn(buf->u_buffer, buf->uv_crop_width, buf->uv_crop_height,
-                        buf->uv_stride, cnn_config);
+        av1_restore_cnn_img(buf->u_buffer, buf->uv_crop_width,
+                            buf->uv_crop_height, buf->uv_stride, cnn_config);
         break;
       case AOM_PLANE_V:
-        av1_restore_cnn(buf->v_buffer, buf->uv_crop_width, buf->uv_crop_height,
-                        buf->uv_stride, cnn_config);
+        av1_restore_cnn_img(buf->v_buffer, buf->uv_crop_width,
+                            buf->uv_crop_height, buf->uv_stride, cnn_config);
         break;
       default: assert(0 && "Invalid plane index");
     }
diff --git a/av1/common/cnn.h b/av1/common/cnn.h
index 7304de1..51495fd 100644
--- a/av1/common/cnn.h
+++ b/av1/common/cnn.h
@@ -110,10 +110,29 @@
   CNN_LAYER_CONFIG layer_config[CNN_MAX_LAYERS];
 };
 
+// Finds the CNN output dimensions for the given input dimensions.
 void av1_find_cnn_output_size(int in_width, int in_height,
                               const CNN_CONFIG *cnn_config, int *out_width,
                               int *out_height);
 
+// Prediction from an image buffer; output must be allocated by the caller.
+void av1_cnn_predict_img(uint8_t *dgd, int width, int height, int stride,
+                         const CNN_CONFIG *cnn_config, float **output,
+                         int out_stride);
+void av1_cnn_predict_img_highbd(uint16_t *dgd, int width, int height,
+                                int stride, const CNN_CONFIG *cnn_config,
+                                int bit_depth, float **output, int out_stride);
+
+// Restoration functions that modify the input image buffer in place.
+// These internally call av1_cnn_predict_img() / av1_cnn_predict_img_highbd().
+void av1_restore_cnn_img(uint8_t *dgd, int width, int height, int stride,
+                         const CNN_CONFIG *cnn_config);
+void av1_restore_cnn_img_highbd(uint16_t *dgd, int width, int height,
+                                int stride, const CNN_CONFIG *cnn_config,
+                                int bit_depth);
+
+// Restoration functions that work on the current frame buffer in AV1_COMMON
+// directly for convenience.
 void av1_restore_cnn_plane(struct AV1Common *cm, const CNN_CONFIG *cnn_config,
                            int plane);
 void av1_restore_cnn_plane_part(struct AV1Common *cm,