Make the SSE used in model-based RD use int64_t

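Also add a skip-mode check to the DNN model: if the predicted RD cost is no
better than that of skipping (rate 0, distortion sse << 4), the skip
estimate is used instead.

The SSE is now computed from the residual in src_diff with
aom_sum_squares_2d_i16(), so callers run av1_subtract_plane() before
invoking the model RD functions.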

Change-Id: Ibc1c6fcbd8608a13baf0f9914436a4db85b0ad72
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ae391c9..cf85a33 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1777,9 +1777,11 @@
   for (plane = plane_from; plane <= plane_to; ++plane) {
     struct macroblock_plane *const p = &x->plane[plane];
     struct macroblockd_plane *const pd = &xd->plane[plane];
-    const BLOCK_SIZE bs =
+    const BLOCK_SIZE plane_bsize =
         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
-    unsigned int sse;
+    const int bw = block_size_wide[plane_bsize];
+    const int bh = block_size_high[plane_bsize];
+    int64_t sse;
     int rate;
     int64_t dist;
 
@@ -1787,14 +1789,15 @@
 
     // TODO(geza): Write direct sse functions that do not compute
     // variance as well.
-    cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
-                       &sse);
+    // cpi->fn_ptr[plane_bsize].vf(p->src.buf, p->src.stride, pd->dst.buf,
+    //                             pd->dst.stride, &sse);
+    sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh);
 
-    if (plane == 0) x->pred_sse[ref] = sse;
+    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
 
     total_sse += sse;
 
-    model_rd_from_sse(cpi, xd, bs, plane, sse, &rate, &dist);
+    model_rd_from_sse(cpi, xd, plane_bsize, plane, sse, &rate, &dist);
 
     rate_sum += rate;
     dist_sum += dist;
@@ -2419,10 +2422,9 @@
 #endif  // CONFIG_COLLECT_RD_STATS == 2
 #endif  // CONFIG_COLLECT_RD_STATS
 
-static void model_rd_with_dnn(const AV1_COMP *const cpi,
-                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
-                              int plane, unsigned int *rsse, int *rate,
-                              int64_t *dist) {
+static void model_rd_with_dnn(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                              BLOCK_SIZE plane_bsize, int plane, int64_t *rsse,
+                              int *rate, int64_t *dist) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
   const int log_numpels = num_pels_log2_lookup[plane_bsize];
@@ -2439,8 +2441,8 @@
   const uint8_t *const src = p->src.buf;
   const int dst_stride = pd->dst.stride;
   const uint8_t *const dst = pd->dst.buf;
-  unsigned int sse;
-  cpi->fn_ptr[plane_bsize].vf(src, src_stride, dst, dst_stride, &sse);
+  const int16_t *const src_diff = p->src_diff;
+  const int64_t sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh);
   const double sse_norm = (double)sse / num_samples;
 
   if (sse == 0) {
@@ -2460,20 +2462,17 @@
     return;
   }
 
-  const int diff_stride = block_size_wide[plane_bsize];
-  const int16_t *const src_diff = p->src_diff;
-
   double sse_norm_arr[4];
   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
-                                   dst_stride, src_diff, diff_stride,
-                                   sse_norm_arr, NULL);
-  const double mean = get_mean(src_diff, diff_stride, bw, bh);
+                                   dst_stride, src_diff, bw, sse_norm_arr,
+                                   NULL);
+  const double mean = get_mean(src_diff, bw, bw, bh);
   const double variance = sse_norm - mean * mean;
   assert(variance >= 0.0);
   const double q_sqr = (double)(q_step * q_step);
   const double q_sqr_by_variance = q_sqr / (variance + 1.0);
   double hor_corr, vert_corr;
-  get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr);
+  get_horver_correlation(src_diff, bw, bw, bh, &hor_corr, &vert_corr);
 
   float features[11];
   features[0] = (float)hor_corr;
@@ -2492,9 +2491,15 @@
   av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_by_variance_f);
   av1_nn_predict(features, &av1_pustats_rate_nnconfig, &rate_f);
   const float dist_f = (float)((double)dist_by_variance_f * (1.0 + variance));
-  const int rate_i = (int)(AOMMAX(0.0, rate_f * (1 << log_numpels)) + 0.5);
-  const int64_t dist_i =
-      (int64_t)(AOMMAX(0.0, dist_f * (1 << log_numpels)) + 0.5);
+  int rate_i = (int)(AOMMAX(0.0, rate_f * (1 << log_numpels)) + 0.5);
+  int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * (1 << log_numpels)) + 0.5);
+
+  // Check if skip is better
+  if (RDCOST(x->rdmult, rate_i, dist_i) >= RDCOST(x->rdmult, 0, (sse << 4))) {
+    dist_i = sse << 4;
+    rate_i = 0;
+  }
+
   if (rate) *rate = rate_i;
   if (dist) *dist = dist_i;
   if (rsse) *rsse = sse;
@@ -2522,7 +2527,7 @@
     struct macroblockd_plane *const pd = &xd->plane[plane];
     const BLOCK_SIZE plane_bsize =
         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
-    unsigned int sse;
+    int64_t sse;
     int rate;
     int64_t dist;
 
@@ -2530,7 +2535,7 @@
 
     model_rd_with_dnn(cpi, x, plane_bsize, plane, &sse, &rate, &dist);
 
-    if (plane == 0) x->pred_sse[ref] = sse;
+    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
 
     total_sse += sse;
     rate_sum += rate;
@@ -3415,6 +3420,7 @@
     }
   }
   // RD estimation.
+  av1_subtract_plane(x, bsize, 0);
   model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &this_rd_stats.rate,
                   &this_rd_stats.dist, &this_rd_stats.skip, &temp_sse, NULL,
                   NULL, NULL);
@@ -7439,6 +7445,7 @@
     *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize,
                                                      this_mode, mi_row, mi_col);
     av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize);
+    av1_subtract_plane(x, bsize, 0);
     model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                     &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
     rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
@@ -7620,17 +7627,15 @@
       get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);
 
   if (!skip_pred) {
-#if DNN_BASED_RD_INTERP_FILTER
-    av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize);
-    for (int plane = 0; plane < num_planes; ++plane)
-      av1_subtract_plane(x, bsize, plane);
-    model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate,
-                             &tmp_dist, &tmp_skip_sb, &tmp_skip_sse, NULL, NULL,
-                             NULL);
-#else
     av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
+    av1_subtract_plane(x, bsize, 0);
+#if DNN_BASED_RD_INTERP_FILTER
+    model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 0, 0, &tmp_rate, &tmp_dist,
+                             &tmp_skip_sb, &tmp_skip_sse, NULL, NULL, NULL);
+#else
     model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &tmp_rate, &tmp_dist, &tmp_skip_sb,
                     &tmp_skip_sse, NULL, NULL, NULL);
+#endif
     if (num_planes > 1) {
       int64_t tmp_y_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate, tmp_dist);
       if (tmp_y_rd > *rd) {
@@ -7640,15 +7645,22 @@
       int tmp_rate_uv, tmp_skip_sb_uv;
       int64_t tmp_dist_uv, tmp_skip_sse_uv;
       av1_build_inter_predictors_sbuv(cm, xd, mi_row, mi_col, orig_dst, bsize);
+      for (int plane = 1; plane < num_planes; ++plane)
+        av1_subtract_plane(x, bsize, plane);
+#if DNN_BASED_RD_INTERP_FILTER
+      model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 1, num_planes - 1,
+                               &tmp_rate_uv, &tmp_dist_uv, &tmp_skip_sb_uv,
+                               &tmp_skip_sse_uv, NULL, NULL, NULL);
+#else
       model_rd_for_sb(cpi, bsize, x, xd, 1, num_planes - 1, &tmp_rate_uv,
                       &tmp_dist_uv, &tmp_skip_sb_uv, &tmp_skip_sse_uv, NULL,
                       NULL, NULL);
+#endif
       tmp_rate += tmp_rate_uv;
       tmp_skip_sb &= tmp_skip_sb_uv;
       tmp_dist += tmp_dist_uv;
       tmp_skip_sse += tmp_skip_sse_uv;
     }
-#endif  // DNN_BASED_RD_INTERP_FILTER
   } else {
     tmp_rate = *rate;
     tmp_dist = *dist;
@@ -8010,6 +8022,7 @@
         av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                   intrapred, bw);
         av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+        av1_subtract_plane(x, bsize, 0);
         model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                         &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
         rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
@@ -8069,6 +8082,7 @@
             mbmi->mv[0].as_int = tmp_mv.as_int;
             av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst,
                                            bsize);
+            av1_subtract_plane(x, bsize, 0);
             model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL,
                             NULL);
@@ -8839,6 +8853,8 @@
         int tmp_rate;
         int64_t tmp_dist;
         av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, &orig_dst, bsize);
+        for (int plane = 0; plane < num_planes; ++plane)
+          av1_subtract_plane(x, bsize, plane);
         model_rd_for_sb(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate,
                         &tmp_dist, &skip_txfm_sb, &skip_sse_sb, plane_rate,
                         plane_sse, plane_dist);