cnn: propagate allocation errors

Bug: aomedia:3276
Change-Id: I02f7327923a57c89d31c2160068454f696a7bc2a
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 355fc0a..c9e87e3 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -485,7 +485,7 @@
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
     add_proto qw/void av1_cnn_activate/, " float **input, int channels, int width, int height, int stride, ACTIVATION layer_activation";
     add_proto qw/void av1_cnn_add/, " float **input, int channels, int width, int height, int stride, const float **add";
-    add_proto qw/void av1_cnn_predict/, " const float **input, int in_width, int in_height, int in_stride, const CNN_CONFIG *cnn_config, const CNN_THREAD_DATA *thread_data, CNN_MULTI_OUT *output_struct";
+    add_proto qw/bool av1_cnn_predict/, " const float **input, int in_width, int in_height, int in_stride, const CNN_CONFIG *cnn_config, const CNN_THREAD_DATA *thread_data, CNN_MULTI_OUT *output_struct";
     add_proto qw/void av1_cnn_convolve_no_maxpool_padding_valid/, " const float **input, int in_width, int in_height, int in_stride, const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride, int start_idx, int cstep, int channel_step";
     if (aom_config("CONFIG_EXCLUDE_SIMD_MISMATCH") ne "yes") {
       specialize qw/av1_cnn_convolve_no_maxpool_padding_valid avx2/;
diff --git a/av1/encoder/cnn.c b/av1/encoder/cnn.c
index 599812f..83e2c45 100644
--- a/av1/encoder/cnn.c
+++ b/av1/encoder/cnn.c
@@ -11,6 +11,7 @@
 
 #include <assert.h>
 #include <math.h>
+#include <stdbool.h>
 
 #include "aom_dsp/aom_dsp_common.h"
 #include "av1/common/av1_common_int.h"
@@ -55,13 +56,14 @@
   }
 }
 
-static void realloc_tensor(TENSOR *tensor, int channels, int width,
+static bool realloc_tensor(TENSOR *tensor, int channels, int width,
                            int height) {
   const int newallocsize = channels * width * height;
   if (tensor->allocsize < newallocsize) {
     free_tensor(tensor);
     tensor->buf[0] =
         (float *)aom_malloc(sizeof(*tensor->buf[0]) * newallocsize);
+    if (!tensor->buf[0]) return false;
     tensor->allocsize = newallocsize;
   }
   tensor->width = width;
@@ -70,6 +72,7 @@
   tensor->channels = channels;
   for (int c = 1; c < channels; ++c)
     tensor->buf[c] = &tensor->buf[0][c * width * height];
+  return true;
 }
 
 static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
@@ -115,7 +118,7 @@
 
 // The concatenated tensor goes into dst with first the channels in
 // original dst followed by the channels in the src
-static void concat_tensor(const TENSOR *src, TENSOR *dst) {
+static bool concat_tensor(const TENSOR *src, TENSOR *dst) {
   assert(src->width == dst->width);
   assert(src->height == dst->height);
 
@@ -126,7 +129,7 @@
     TENSOR t;
     init_tensor(&t);
     // allocate new buffers and copy first the dst channels
-    realloc_tensor(&t, channels, dst->width, dst->height);
+    if (!realloc_tensor(&t, channels, dst->width, dst->height)) return false;
     copy_tensor(dst, dst->channels, 0, &t);
     // Swap the tensors and free the old buffers
     swap_tensor(dst, &t);
@@ -136,6 +139,7 @@
     dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
   // Copy the channels in src after the first dst_channels channels.
   copy_tensor(src, src->channels, dst_channels, dst);
+  return true;
 }
 
 int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
@@ -326,7 +330,7 @@
   }
 }
 
-static void copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
+static bool copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
                                            const CNN_LAYER_CONFIG *layer_config,
                                            int branch, TENSOR branch_output[]) {
   const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
@@ -338,11 +342,15 @@
       int copy_channels = branch_config->channels_to_copy > 0
                               ? branch_config->channels_to_copy
                               : layer_active_tensor->channels;
-      realloc_tensor(&branch_output[b], copy_channels,
-                     layer_active_tensor->width, layer_active_tensor->height);
+      if (!realloc_tensor(&branch_output[b], copy_channels,
+                          layer_active_tensor->width,
+                          layer_active_tensor->height)) {
+        return false;
+      }
       copy_tensor(layer_active_tensor, copy_channels, 0, &branch_output[b]);
     }
   }
+  return true;
 }
 
 // CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
@@ -892,10 +900,11 @@
   }
 }
 
-void av1_cnn_predict_c(const float **input, int in_width, int in_height,
+bool av1_cnn_predict_c(const float **input, int in_width, int in_height,
                        int in_stride, const CNN_CONFIG *cnn_config,
                        const CNN_THREAD_DATA *thread_data,
                        CNN_MULTI_OUT *output_struct) {
+  bool success = false;
   TENSOR tensor1[CNN_MAX_BRANCHES] = { { 0 } };
   TENSOR tensor2[CNN_MAX_BRANCHES] = { { 0 } };
 
@@ -938,8 +947,10 @@
                                    &o_height);
     const int output_num = layer_config->output_num;
     if (output_num == -1) {  // Non-output layer
-      realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
-                     o_height);
+      if (!realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
+                          o_height)) {
+        goto Error;
+      }
     } else {  // Output layer
       free_tensor(&tensor2[branch]);
       assign_tensor(&tensor2[branch], output[output_num],
@@ -953,8 +964,10 @@
                    !(branch_config->branches_to_combine & (1 << branch))));
 
     if (layer_config->branch_copy_type == BRANCH_INPUT) {
-      copy_active_tensor_to_branches(&tensor1[branch], layer_config, branch,
-                                     tensor2);
+      if (!copy_active_tensor_to_branches(&tensor1[branch], layer_config,
+                                          branch, tensor2)) {
+        goto Error;
+      }
     }
     // Check consistency of input and output channels
     assert(tensor1[branch].channels == layer_config->in_channels);
@@ -981,8 +994,10 @@
     }
 
     if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
-      copy_active_tensor_to_branches(&tensor2[branch], layer_config, branch,
-                                     tensor2);
+      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
+                                          branch, tensor2)) {
+        goto Error;
+      }
     }
 
     // Add tensors from other branches if needed
@@ -1018,7 +1033,7 @@
           if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
             assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
             assert(tensor2[b].channels > 0);
-            concat_tensor(&tensor2[b], &tensor2[branch]);
+            if (!concat_tensor(&tensor2[b], &tensor2[branch])) goto Error;
           }
         }
       } else {  // Output layer
@@ -1048,20 +1063,25 @@
     }
 
     if (layer_config->branch_copy_type == BRANCH_COMBINED) {
-      copy_active_tensor_to_branches(&tensor2[branch], layer_config, branch,
-                                     tensor2);
+      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
+                                          branch, tensor2)) {
+        goto Error;
+      }
     }
   }
 
+  success = true;
+Error:
   for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
     free_tensor(&tensor1[b]);
     free_tensor(&tensor2[b]);
   }
+  return success;
 }
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
+bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
                                    int stride, const CNN_CONFIG *cnn_config,
                                    const CNN_THREAD_DATA *thread_data,
                                    CNN_MULTI_OUT *output) {
@@ -1073,6 +1093,7 @@
   float *inputs[CNN_MAX_CHANNELS];
   float *input_ =
       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
+  if (!input_) return false;
   const int in_stride = in_width;
 
   for (int c = 0; c < in_channels; ++c) {
@@ -1107,15 +1128,16 @@
           input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
     }
   }
-  av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
-                  cnn_config, thread_data, output);
+  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
+                                 in_stride, cnn_config, thread_data, output);
 
   aom_free(input_);
+  return success;
 }
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
+bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
                                           int stride,
                                           const CNN_CONFIG *cnn_config,
                                           const CNN_THREAD_DATA *thread_data,
@@ -1129,6 +1151,7 @@
   float *inputs[CNN_MAX_CHANNELS];
   float *input_ =
       (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
+  if (!input_) return false;
   const int in_stride = in_width;
 
   for (int c = 0; c < in_channels; ++c) {
@@ -1164,15 +1187,16 @@
     }
   }
 
-  av1_cnn_predict((const float **)inputs, in_width, in_height, in_stride,
-                  cnn_config, thread_data, output);
+  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
+                                 in_stride, cnn_config, thread_data, output);
 
   aom_free(input_);
+  return success;
 }
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
+bool av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
                          const CNN_CONFIG *cnn_config,
                          const CNN_THREAD_DATA *thread_data, float **output,
                          int out_stride) {
@@ -1184,13 +1208,13 @@
   CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
                                   .output_strides = output_strides,
                                   .output_buffer = output };
-  av1_cnn_predict_img_multi_out(dgd, width, height, stride, cnn_config,
-                                thread_data, &output_struct);
+  return av1_cnn_predict_img_multi_out(dgd, width, height, stride, cnn_config,
+                                       thread_data, &output_struct);
 }
 
 // Assume output already has proper allocation
 // Assume input image buffers all have same resolution and strides
-void av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
+bool av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
                                 int stride, const CNN_CONFIG *cnn_config,
                                 const CNN_THREAD_DATA *thread_data,
                                 int bit_depth, float **output, int out_stride) {
@@ -1202,6 +1226,7 @@
   CNN_MULTI_OUT output_struct = { .output_channels = output_chs,
                                   .output_strides = output_strides,
                                   .output_buffer = output };
-  av1_cnn_predict_img_multi_out_highbd(dgd, width, height, stride, cnn_config,
-                                       thread_data, bit_depth, &output_struct);
+  return av1_cnn_predict_img_multi_out_highbd(dgd, width, height, stride,
+                                              cnn_config, thread_data,
+                                              bit_depth, &output_struct);
 }
diff --git a/av1/encoder/cnn.h b/av1/encoder/cnn.h
index 3b55aa0..1a6c03a4c 100644
--- a/av1/encoder/cnn.h
+++ b/av1/encoder/cnn.h
@@ -17,6 +17,7 @@
 #endif
 
 #include <math.h>
+#include <stdbool.h>
 
 #include "aom_util/aom_thread.h"
 #include "config/av1_rtcd.h"
@@ -174,11 +175,11 @@
 
 // Prediction functions from set of input image buffers. This function supports
 // CNN with multiple outputs.
-void av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
+bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
                                    int stride, const CNN_CONFIG *cnn_config,
                                    const CNN_THREAD_DATA *thread_data,
                                    struct CNN_MULTI_OUT *output);
-void av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
+bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
                                           int stride,
                                           const CNN_CONFIG *cnn_config,
                                           const CNN_THREAD_DATA *thread_data,
@@ -186,11 +187,11 @@
 
 // Prediction functions from set of input image buffers. This function only
 // supports a single output.
-void av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
+bool av1_cnn_predict_img(uint8_t **dgd, int width, int height, int stride,
                          const CNN_CONFIG *cnn_config,
                          const CNN_THREAD_DATA *thread_data, float **output,
                          int out_stride);
-void av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
+bool av1_cnn_predict_img_highbd(uint16_t **dgd, int width, int height,
                                 int stride, const CNN_CONFIG *cnn_config,
                                 const CNN_THREAD_DATA *thread_data,
                                 int bit_depth, float **output, int out_stride);
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 86d4b9b..c4024b4 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -200,14 +200,22 @@
         CONVERT_TO_SHORTPTR(x->plane[AOM_PLANE_Y].src.buf) - stride - 1
       };
 
-      av1_cnn_predict_img_multi_out_highbd(image, width, height, stride,
-                                           cnn_config, &thread_data, bit_depth,
-                                           &output);
+      if (!av1_cnn_predict_img_multi_out_highbd(image, width, height, stride,
+                                                cnn_config, &thread_data,
+                                                bit_depth, &output)) {
+        aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
+                           "Error allocating CNN data");
+        return;
+      }
     } else {
       uint8_t *image[1] = { x->plane[AOM_PLANE_Y].src.buf - stride - 1 };
 
-      av1_cnn_predict_img_multi_out(image, width, height, stride, cnn_config,
-                                    &thread_data, &output);
+      if (!av1_cnn_predict_img_multi_out(image, width, height, stride,
+                                         cnn_config, &thread_data, &output)) {
+        aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
+                           "Error allocating CNN data");
+        return;
+      }
     }
 
     part_info->cnn_output_valid = 1;
diff --git a/test/cnn_test.cc b/test/cnn_test.cc
index 2468700..0b92197 100644
--- a/test/cnn_test.cc
+++ b/test/cnn_test.cc
@@ -88,8 +88,8 @@
 
     av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
                              out_heights, not_used);
-    av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
-                    thread_data, output);
+    ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
+                                cnn_config, thread_data, output));
 
     int channel_offset = 0;
     for (int output_idx = 0; output_idx < num_outputs; output_idx++) {