Optimize memory usage in CDEF

Redundant src and ref_coeff buffers are removed
from av1_cdef_search.
Refactored the MSE functions used to compute CDEF distortion.

Resolution    Tile     Memory reduction
                       Single   Multi
                       Thread   Thread
640x360       2x1      ~4%      ~3% (2 threads)
832x480       2x1      ~5%      ~4% (2 threads)
1280x720      2x2      ~4%      ~4% (4 threads)
1920x1080     4x2      ~4%      ~3% (8 threads)

Memory measuring command:
$ command time -v ./aomenc ...

Change-Id: I70aeb28230ae4f654424efbb0d344265aacacea7
diff --git a/av1/encoder/pickcdef.c b/av1/encoder/pickcdef.c
index c8b90e1..a1092fd 100644
--- a/av1/encoder/pickcdef.c
+++ b/av1/encoder/pickcdef.c
@@ -190,21 +190,42 @@
   return best_tot_mse;
 }
 
-static void copy_sb16_16(uint16_t *dst, int dstride, const uint16_t *src,
-                         int src_voffset, int src_hoffset, int sstride,
-                         int vsize, int hsize) {
+typedef void (*copy_fn_t)(uint16_t *dst, int dstride, const void *src,
+                          int src_voffset, int src_hoffset, int sstride,
+                          int vsize, int hsize);
+typedef uint64_t (*compute_cdef_dist_t)(void *dst, int dstride, uint16_t *src,
+                                        cdef_list *dlist, int cdef_count,
+                                        BLOCK_SIZE bsize, int coeff_shift,
+                                        int row, int col);
+
+static void copy_sb16_16_highbd(uint16_t *dst, int dstride, const void *src,
+                                int src_voffset, int src_hoffset, int sstride,
+                                int vsize, int hsize) {
   int r;
-  const uint16_t *base = &src[src_voffset * sstride + src_hoffset];
+  const uint16_t *src16 = CONVERT_TO_SHORTPTR((uint8_t *)src);
+  const uint16_t *base = &src16[src_voffset * sstride + src_hoffset];
   for (r = 0; r < vsize; r++)
     memcpy(dst + r * dstride, base + r * sstride, hsize * sizeof(*base));
 }
 
-static INLINE uint64_t mse_8x8_16bit(uint16_t *dst, int dstride, uint16_t *src,
-                                     int sstride) {
+static void copy_sb16_16(uint16_t *dst, int dstride, const void *src,
+                         int src_voffset, int src_hoffset, int sstride,
+                         int vsize, int hsize) {
+  int r, c;
+  const uint8_t *src8 = (uint8_t *)src;
+  const uint8_t *base = &src8[src_voffset * sstride + src_hoffset];
+  for (r = 0; r < vsize; r++)
+    for (c = 0; c < hsize; c++)
+      dst[r * dstride + c] = (uint16_t)base[r * sstride + c];
+}
+
+static INLINE uint64_t mse_wxh_16bit_highbd(uint16_t *dst, int dstride,
+                                            uint16_t *src, int sstride, int w,
+                                            int h) {
   uint64_t sum = 0;
   int i, j;
-  for (i = 0; i < 8; i++) {
-    for (j = 0; j < 8; j++) {
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
       int e = dst[i * dstride + j] - src[i * sstride + j];
       sum += e * e;
     }
@@ -212,64 +233,72 @@
   return sum;
 }
 
-static INLINE uint64_t mse_4x4_16bit(uint16_t *dst, int dstride, uint16_t *src,
-                                     int sstride) {
+static INLINE uint64_t mse_wxh_16bit(uint8_t *dst, int dstride, uint16_t *src,
+                                     int sstride, int w, int h) {
   uint64_t sum = 0;
   int i, j;
-  for (i = 0; i < 4; i++) {
-    for (j = 0; j < 4; j++) {
-      int e = dst[i * dstride + j] - src[i * sstride + j];
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
+      int e = (uint16_t)dst[i * dstride + j] - src[i * sstride + j];
       sum += e * e;
     }
   }
   return sum;
 }
 
+static INLINE void init_src_params(int *src_stride, int *width, int *height,
+                                   int *width_log2, int *height_log2,
+                                   BLOCK_SIZE bsize) {
+  *src_stride = block_size_wide[bsize];
+  *width = block_size_wide[bsize];
+  *height = block_size_high[bsize];
+  *width_log2 = MI_SIZE_LOG2 + mi_size_wide_log2[bsize];
+  *height_log2 = MI_SIZE_LOG2 + mi_size_high_log2[bsize]; /* high, not wide */
+}
+
 /* Compute MSE only on the blocks we filtered. */
-static uint64_t compute_cdef_dist(uint16_t *dst, int dstride, uint16_t *src,
-                                  cdef_list *dlist, int cdef_count,
-                                  BLOCK_SIZE bsize, int coeff_shift, int pli) {
+static uint64_t compute_cdef_dist_highbd(void *dst, int dstride, uint16_t *src,
+                                         cdef_list *dlist, int cdef_count,
+                                         BLOCK_SIZE bsize, int coeff_shift,
+                                         int row, int col) {
+  assert(bsize == BLOCK_4X4 || bsize == BLOCK_4X8 || bsize == BLOCK_8X4 ||
+         bsize == BLOCK_8X8);
   uint64_t sum = 0;
   int bi, bx, by;
-  if (bsize == BLOCK_8X8) {
-    for (bi = 0; bi < cdef_count; bi++) {
-      by = dlist[bi].by;
-      bx = dlist[bi].bx;
-      if (pli == 0) {
-        sum += mse_8x8_16bit(&dst[(by << 3) * dstride + (bx << 3)], dstride,
-                             &src[bi << (3 + 3)], 8);
+  uint16_t *dst16 = CONVERT_TO_SHORTPTR((uint8_t *)dst);
+  uint16_t *dst_buff = &dst16[row * dstride + col];
+  int src_stride, width, height, width_log2, height_log2;
+  init_src_params(&src_stride, &width, &height, &width_log2, &height_log2,
+                  bsize);
+  for (bi = 0; bi < cdef_count; bi++) {
+    by = dlist[bi].by;
+    bx = dlist[bi].bx;
+    sum += mse_wxh_16bit_highbd(
+        &dst_buff[(by << height_log2) * dstride + (bx << width_log2)], dstride,
+        &src[bi << (height_log2 + width_log2)], src_stride, width, height);
+  }
+  return sum >> 2 * coeff_shift;
+}
 
-      } else {
-        sum += mse_8x8_16bit(&dst[(by << 3) * dstride + (bx << 3)], dstride,
-                             &src[bi << (3 + 3)], 8);
-      }
-    }
-  } else if (bsize == BLOCK_4X8) {
-    for (bi = 0; bi < cdef_count; bi++) {
-      by = dlist[bi].by;
-      bx = dlist[bi].bx;
-      sum += mse_4x4_16bit(&dst[(by << 3) * dstride + (bx << 2)], dstride,
-                           &src[bi << (3 + 2)], 4);
-      sum += mse_4x4_16bit(&dst[((by << 3) + 4) * dstride + (bx << 2)], dstride,
-                           &src[(bi << (3 + 2)) + 4 * 4], 4);
-    }
-  } else if (bsize == BLOCK_8X4) {
-    for (bi = 0; bi < cdef_count; bi++) {
-      by = dlist[bi].by;
-      bx = dlist[bi].bx;
-      sum += mse_4x4_16bit(&dst[(by << 2) * dstride + (bx << 3)], dstride,
-                           &src[bi << (2 + 3)], 8);
-      sum += mse_4x4_16bit(&dst[(by << 2) * dstride + (bx << 3) + 4], dstride,
-                           &src[(bi << (2 + 3)) + 4], 8);
-    }
-  } else {
-    assert(bsize == BLOCK_4X4);
-    for (bi = 0; bi < cdef_count; bi++) {
-      by = dlist[bi].by;
-      bx = dlist[bi].bx;
-      sum += mse_4x4_16bit(&dst[(by << 2) * dstride + (bx << 2)], dstride,
-                           &src[bi << (2 + 2)], 4);
-    }
+static uint64_t compute_cdef_dist(void *dst, int dstride, uint16_t *src,
+                                  cdef_list *dlist, int cdef_count,
+                                  BLOCK_SIZE bsize, int coeff_shift, int row,
+                                  int col) {
+  assert(bsize == BLOCK_4X4 || bsize == BLOCK_4X8 || bsize == BLOCK_8X4 ||
+         bsize == BLOCK_8X8);
+  uint64_t sum = 0;
+  int bi, bx, by;
+  uint8_t *dst8 = (uint8_t *)dst;
+  uint8_t *dst_buff = &dst8[row * dstride + col];
+  int src_stride, width, height, width_log2, height_log2;
+  init_src_params(&src_stride, &width, &height, &width_log2, &height_log2,
+                  bsize);
+  for (bi = 0; bi < cdef_count; bi++) {
+    by = dlist[bi].by;
+    bx = dlist[bi].bx;
+    sum += mse_wxh_16bit(
+        &dst_buff[(by << height_log2) * dstride + (bx << width_log2)], dstride,
+        &src[bi << (height_log2 + width_log2)], src_stride, width, height);
   }
   return sum >> 2 * coeff_shift;
 }
@@ -354,8 +383,6 @@
     return;
   }
 
-  uint16_t *src[3];
-  uint16_t *ref_coeff[3];
   cdef_list dlist[MI_SIZE_128X128 * MI_SIZE_128X128];
   int dir[CDEF_NBLOCKS][CDEF_NBLOCKS] = { { 0 } };
   int var[CDEF_NBLOCKS][CDEF_NBLOCKS] = { { 0 } };
@@ -375,62 +402,32 @@
   mse[0] = aom_malloc(sizeof(**mse) * nvfb * nhfb);
   mse[1] = aom_malloc(sizeof(**mse) * nvfb * nhfb);
 
-  int stride[3];
   int bsize[3];
   int mi_wide_l2[3];
   int mi_high_l2[3];
   int xdec[3];
   int ydec[3];
+  uint8_t *ref_buffer[3] = { ref->y_buffer, ref->u_buffer, ref->v_buffer };
+  int ref_stride[3] = { ref->y_stride, ref->uv_stride, ref->uv_stride };
+
   for (int pli = 0; pli < num_planes; pli++) {
-    uint8_t *ref_buffer;
-    int ref_stride;
-    switch (pli) {
-      case 0:
-        ref_buffer = ref->y_buffer;
-        ref_stride = ref->y_stride;
-        break;
-      case 1:
-        ref_buffer = ref->u_buffer;
-        ref_stride = ref->uv_stride;
-        break;
-      case 2:
-        ref_buffer = ref->v_buffer;
-        ref_stride = ref->uv_stride;
-        break;
-    }
-    src[pli] = aom_memalign(32, sizeof(*src) * mi_params->mi_rows *
-                                    mi_params->mi_cols * MI_SIZE * MI_SIZE);
-    ref_coeff[pli] =
-        aom_memalign(32, sizeof(*ref_coeff) * mi_params->mi_rows *
-                             mi_params->mi_cols * MI_SIZE * MI_SIZE);
     xdec[pli] = xd->plane[pli].subsampling_x;
     ydec[pli] = xd->plane[pli].subsampling_y;
     bsize[pli] = ydec[pli] ? (xdec[pli] ? BLOCK_4X4 : BLOCK_8X4)
                            : (xdec[pli] ? BLOCK_4X8 : BLOCK_8X8);
-    stride[pli] = mi_params->mi_cols << MI_SIZE_LOG2;
     mi_wide_l2[pli] = MI_SIZE_LOG2 - xd->plane[pli].subsampling_x;
     mi_high_l2[pli] = MI_SIZE_LOG2 - xd->plane[pli].subsampling_y;
+  }
 
-    const int frame_height =
-        (mi_params->mi_rows * MI_SIZE) >> xd->plane[pli].subsampling_y;
-    const int frame_width =
-        (mi_params->mi_cols * MI_SIZE) >> xd->plane[pli].subsampling_x;
-    const int plane_sride = stride[pli];
-    const int dst_stride = xd->plane[pli].dst.stride;
-    for (int r = 0; r < frame_height; ++r) {
-      for (int c = 0; c < frame_width; ++c) {
-        if (cm->seq_params.use_highbitdepth) {
-          src[pli][r * plane_sride + c] =
-              CONVERT_TO_SHORTPTR(xd->plane[pli].dst.buf)[r * dst_stride + c];
-          ref_coeff[pli][r * plane_sride + c] =
-              CONVERT_TO_SHORTPTR(ref_buffer)[r * ref_stride + c];
-        } else {
-          src[pli][r * plane_sride + c] =
-              xd->plane[pli].dst.buf[r * dst_stride + c];
-          ref_coeff[pli][r * plane_sride + c] = ref_buffer[r * ref_stride + c];
-        }
-      }
-    }
+  copy_fn_t copy_fn;
+  compute_cdef_dist_t compute_cdef_dist_fn;
+
+  if (cm->seq_params.use_highbitdepth) {
+    copy_fn = copy_sb16_16_highbd;
+    compute_cdef_dist_fn = compute_cdef_dist_highbd;
+  } else {
+    copy_fn = copy_sb16_16;
+    compute_cdef_dist_fn = compute_cdef_dist;
   }
 
   DECLARE_ALIGNED(32, uint16_t, inbuf[CDEF_INBUF_SIZE]);
@@ -482,9 +479,9 @@
       int dirinit = 0;
       for (int pli = 0; pli < num_planes; pli++) {
         for (int i = 0; i < CDEF_INBUF_SIZE; i++) inbuf[i] = CDEF_VERY_LARGE;
-        /* We avoid filtering the pixels for which some of the pixels to average
-           are outside the frame. We could change the filter instead, but it
-           would add special cases for any future vectorization. */
+        /* We avoid filtering the pixels for which some of the pixels to
+           average are outside the frame. We could change the filter instead,
+           but it would add special cases for any future vectorization. */
         const int ysize = (nvb << mi_high_l2[pli]) +
                           CDEF_VBORDER * (fbr + vb_step < nvfb) + yoff;
         const int xsize = (nhb << mi_wide_l2[pli]) +
@@ -495,16 +492,16 @@
           int pri_strength = gi / CDEF_SEC_STRENGTHS;
           if (fast) pri_strength = get_pri_strength(pick_method, pri_strength);
           const int sec_strength = gi % CDEF_SEC_STRENGTHS;
-          copy_sb16_16(&in[(-yoff * CDEF_BSTRIDE - xoff)], CDEF_BSTRIDE,
-                       src[pli], row - yoff, col - xoff, stride[pli], ysize,
-                       xsize);
+          copy_fn(&in[(-yoff * CDEF_BSTRIDE - xoff)], CDEF_BSTRIDE,
+                  xd->plane[pli].dst.buf, row - yoff, col - xoff,
+                  xd->plane[pli].dst.stride, ysize, xsize);
           av1_cdef_filter_fb(
               NULL, tmp_dst, CDEF_BSTRIDE, in, xdec[pli], ydec[pli], dir,
               &dirinit, var, pli, dlist, cdef_count, pri_strength,
               sec_strength + (sec_strength == 3), damping, coeff_shift);
-          const uint64_t curr_mse = compute_cdef_dist(
-              ref_coeff[pli] + row * stride[pli] + col, stride[pli], tmp_dst,
-              dlist, cdef_count, bsize[pli], coeff_shift, pli);
+          const uint64_t curr_mse = compute_cdef_dist_fn(
+              ref_buffer[pli], ref_stride[pli], tmp_dst, dlist, cdef_count,
+              bsize[pli], coeff_shift, row, col);
           if (pli < 2)
             mse[pli][sb_count][gi] = curr_mse;
           else
@@ -586,9 +583,5 @@
 
   aom_free(mse[0]);
   aom_free(mse[1]);
-  for (int pli = 0; pli < num_planes; pli++) {
-    aom_free(src[pli]);
-    aom_free(ref_coeff[pli]);
-  }
   aom_free(sb_index);
 }