Improve temporal filtering

Ported the weight calculation from VP9. Considered U and V pixel values
while calculating Y pixel's filter weights, and vice versa.

Borg test result (speed 1; 150 frames for lowres/midres and 60 frames for
hdres):
       avg_psnr:  ovr_psnr:  ssim:
hdres:  -0.140    -0.187    -0.147
midres: -0.088    -0.084    -0.072
lowres: -0.028    -0.062    -0.117

STATS_CHANGED

Change-Id: I0a3ed6297431fa56d0e672ede809af6d55a5786f
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 8166e0b..ace585e 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -76,6 +76,312 @@
   }
 }
 
+static INLINE int64_t mod_index(int64_t sum_dist, int index, int rounding,
+                                int strength, int filter_weight) {
+  int64_t mod = (sum_dist * 3) / index;  // index: number of summed samples
+  mod += rounding;
+  mod >>= strength;
+  // Clamp to 16 (assuming AOMMIN is a plain min) so 16 - mod stays >= 0.
+  mod = AOMMIN(16, mod);
+  // Invert: the smaller the scaled error, the larger the weight (16 at most).
+  mod = 16 - mod;
+  mod *= filter_weight;
+  // For non-negative inputs the result lies in [0, 16 * filter_weight].
+  return mod;
+}
+
+static INLINE void calculate_squared_errors(const uint8_t *s, int s_stride,
+                                            const uint8_t *p, int p_stride,
+                                            uint16_t *diff_sse, unsigned int w,
+                                            unsigned int h) {
+  int idx = 0;  // next write position in diff_sse
+  unsigned int i, j;
+  // Store (s - p)^2 per pixel of the w x h block, packed with row length w.
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
+      const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
+      diff_sse[idx] = diff * diff;  // at most 255 * 255, which fits uint16_t
+      idx++;
+    }
+  }
+}
+
+static void apply_temporal_filter(
+    const uint8_t *y_frame1, int y_stride, const uint8_t *y_pred,
+    int y_buf_stride, const uint8_t *u_frame1, const uint8_t *v_frame1,
+    int uv_stride, const uint8_t *u_pred, const uint8_t *v_pred,
+    int uv_buf_stride, unsigned int block_width, unsigned int block_height,
+    int ss_x, int ss_y, int strength, int filter_weight,
+    uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator,
+    uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count) {
+  unsigned int i, j, k, m;  // k / m: luma / chroma output index
+  int modifier;
+  const int rounding = (1 << strength) >> 1;  // half of (1 << strength)
+  const unsigned int uv_block_width = block_width >> ss_x;
+  const unsigned int uv_block_height = block_height >> ss_y;
+  DECLARE_ALIGNED(16, uint16_t, y_diff_sse[256]);
+  DECLARE_ALIGNED(16, uint16_t, u_diff_sse[256]);
+  DECLARE_ALIGNED(16, uint16_t, v_diff_sse[256]);
+  // Adds, per pixel, a weight to *_count and weight * pred to *_accumulator.
+  int idx = 0, idy;
+  // filter_weight scales every weight; only 0, 1 and 2 are accepted.
+  assert(filter_weight >= 0);
+  assert(filter_weight <= 2);
+  // Scratch buffers hold 256 entries each, i.e. at most a 16x16 block.
+  memset(y_diff_sse, 0, 256 * sizeof(uint16_t));
+  memset(u_diff_sse, 0, 256 * sizeof(uint16_t));
+  memset(v_diff_sse, 0, 256 * sizeof(uint16_t));
+
+  // Calculate diff^2 for each pixel of the 16x16 block.
+  // TODO(yunqing): the following code needs to be optimized.
+  calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride, y_diff_sse,
+                           block_width, block_height);
+  calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
+                           u_diff_sse, uv_block_width, uv_block_height);
+  calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
+                           v_diff_sse, uv_block_width, uv_block_height);
+
+  for (i = 0, k = 0, m = 0; i < block_height; i++) {
+    for (j = 0; j < block_width; j++) {
+      const int pixel_value = y_pred[i * y_buf_stride + j];
+
+      // non-local mean approach
+      int y_index = 0;
+
+      const int uv_r = i >> ss_y;
+      const int uv_c = j >> ss_x;
+      modifier = 0;
+      // Sum the squared luma errors of the in-block 3x3 neighbourhood.
+      for (idy = -1; idy <= 1; ++idy) {
+        for (idx = -1; idx <= 1; ++idx) {
+          const int row = (int)i + idy;
+          const int col = (int)j + idx;
+
+          if (row >= 0 && row < (int)block_height && col >= 0 &&
+              col < (int)block_width) {
+            modifier += y_diff_sse[row * (int)block_width + col];
+            ++y_index;
+          }
+        }
+      }
+
+      assert(y_index > 0);
+      // Add the co-located U and V errors, counted as two more samples.
+      modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
+      modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
+
+      y_index += 2;
+
+      modifier =
+          (int)mod_index(modifier, y_index, rounding, strength, filter_weight);
+      // Accumulate the weight and the weighted predicted luma pixel.
+      y_count[k] += modifier;
+      y_accumulator[k] += modifier * pixel_value;
+
+      ++k;
+
+      // Process chroma once per chroma sample (assumes ss_x, ss_y are 0 or 1)
+      if (!(i & ss_y) && !(j & ss_x)) {
+        const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
+        const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
+
+        // non-local mean approach
+        int cr_index = 0;
+        int u_mod = 0, v_mod = 0;
+        int y_diff = 0;
+        // Sum the squared U and V errors of the in-block 3x3 neighbourhood.
+        for (idy = -1; idy <= 1; ++idy) {
+          for (idx = -1; idx <= 1; ++idx) {
+            const int row = uv_r + idy;
+            const int col = uv_c + idx;
+
+            if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
+                col < (int)uv_block_width) {
+              u_mod += u_diff_sse[row * uv_block_width + col];
+              v_mod += v_diff_sse[row * uv_block_width + col];
+              ++cr_index;
+            }
+          }
+        }
+
+        assert(cr_index > 0);
+        // Add luma errors of the (1 + ss_y) x (1 + ss_x) co-located pixels.
+        for (idy = 0; idy < 1 + ss_y; ++idy) {
+          for (idx = 0; idx < 1 + ss_x; ++idx) {
+            const int row = (uv_r << ss_y) + idy;
+            const int col = (uv_c << ss_x) + idx;
+            y_diff += y_diff_sse[row * (int)block_width + col];
+            ++cr_index;
+          }
+        }
+
+        u_mod += y_diff;
+        v_mod += y_diff;
+
+        u_mod =
+            (int)mod_index(u_mod, cr_index, rounding, strength, filter_weight);
+        v_mod =
+            (int)mod_index(v_mod, cr_index, rounding, strength, filter_weight);
+        // Accumulate the weights and the weighted predicted chroma pixels.
+        u_count[m] += u_mod;
+        u_accumulator[m] += u_mod * u_pixel_value;
+        v_count[m] += v_mod;
+        v_accumulator[m] += v_mod * v_pixel_value;
+
+        ++m;
+      }  // Complete YUV pixel
+    }
+  }
+}
+
+static INLINE void highbd_calculate_squared_errors(
+    const uint16_t *s, int s_stride, const uint16_t *p, int p_stride,
+    uint32_t *diff_sse, unsigned int w, unsigned int h) {
+  int idx = 0;  // next write position in diff_sse
+  unsigned int i, j;
+  // Store (s - p)^2 per pixel of the w x h block, packed with row length w.
+  for (i = 0; i < h; i++) {
+    for (j = 0; j < w; j++) {
+      const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
+      diff_sse[idx] = diff * diff;  // NOTE(review): int16_t diff, check range
+      idx++;
+    }
+  }
+}
+
+static void highbd_apply_temporal_filter(
+    const uint8_t *yf, int y_stride, const uint8_t *yp, int y_buf_stride,
+    const uint8_t *uf, const uint8_t *vf, int uv_stride, const uint8_t *up,
+    const uint8_t *vp, int uv_buf_stride, unsigned int block_width,
+    unsigned int block_height, int ss_x, int ss_y, int strength,
+    int filter_weight, uint32_t *y_accumulator, uint16_t *y_count,
+    uint32_t *u_accumulator, uint16_t *u_count, uint32_t *v_accumulator,
+    uint16_t *v_count) {
+  unsigned int i, j, k, m;  // k / m: luma / chroma output index
+  int64_t modifier;         // 64-bit: sums 32-bit squared errors
+  const int rounding = (1 << strength) >> 1;  // half of (1 << strength)
+  const unsigned int uv_block_width = block_width >> ss_x;
+  const unsigned int uv_block_height = block_height >> ss_y;
+  DECLARE_ALIGNED(16, uint32_t, y_diff_sse[256]);
+  DECLARE_ALIGNED(16, uint32_t, u_diff_sse[256]);
+  DECLARE_ALIGNED(16, uint32_t, v_diff_sse[256]);
+  // The uint8_t * inputs are turned into uint16_t * by CONVERT_TO_SHORTPTR.
+  const uint16_t *y_frame1 = CONVERT_TO_SHORTPTR(yf);
+  const uint16_t *u_frame1 = CONVERT_TO_SHORTPTR(uf);
+  const uint16_t *v_frame1 = CONVERT_TO_SHORTPTR(vf);
+  const uint16_t *y_pred = CONVERT_TO_SHORTPTR(yp);
+  const uint16_t *u_pred = CONVERT_TO_SHORTPTR(up);
+  const uint16_t *v_pred = CONVERT_TO_SHORTPTR(vp);
+  int idx = 0, idy;
+  // filter_weight scales every weight; only 0, 1 and 2 are accepted.
+  assert(filter_weight >= 0);
+  assert(filter_weight <= 2);
+  // Scratch buffers hold 256 entries each, i.e. at most a 16x16 block.
+  memset(y_diff_sse, 0, 256 * sizeof(uint32_t));
+  memset(u_diff_sse, 0, 256 * sizeof(uint32_t));
+  memset(v_diff_sse, 0, 256 * sizeof(uint32_t));
+
+  // Calculate diff^2 for each pixel of the 16x16 block.
+  // TODO(yunqing): the following code needs to be optimized.
+  highbd_calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride,
+                                  y_diff_sse, block_width, block_height);
+  highbd_calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
+                                  u_diff_sse, uv_block_width, uv_block_height);
+  highbd_calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
+                                  v_diff_sse, uv_block_width, uv_block_height);
+
+  for (i = 0, k = 0, m = 0; i < block_height; i++) {
+    for (j = 0; j < block_width; j++) {
+      const int pixel_value = y_pred[i * y_buf_stride + j];
+
+      // non-local mean approach
+      int y_index = 0;
+
+      const int uv_r = i >> ss_y;
+      const int uv_c = j >> ss_x;
+      modifier = 0;
+      // Sum the squared luma errors of the in-block 3x3 neighbourhood.
+      for (idy = -1; idy <= 1; ++idy) {
+        for (idx = -1; idx <= 1; ++idx) {
+          const int row = (int)i + idy;
+          const int col = (int)j + idx;
+
+          if (row >= 0 && row < (int)block_height && col >= 0 &&
+              col < (int)block_width) {
+            modifier += y_diff_sse[row * (int)block_width + col];
+            ++y_index;
+          }
+        }
+      }
+
+      assert(y_index > 0);
+      // Add the co-located U and V errors, counted as two more samples.
+      modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
+      modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
+
+      y_index += 2;
+
+      modifier =
+          mod_index(modifier, y_index, rounding, strength, filter_weight);
+      // Accumulate the weight and the weighted predicted luma pixel.
+      y_count[k] += modifier;
+      y_accumulator[k] += modifier * pixel_value;
+
+      ++k;
+
+      // Process chroma once per chroma sample (assumes ss_x, ss_y are 0 or 1)
+      if (!(i & ss_y) && !(j & ss_x)) {
+        const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
+        const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
+
+        // non-local mean approach
+        int cr_index = 0;
+        int64_t u_mod = 0, v_mod = 0;
+        int y_diff = 0;  // NOTE(review): int sum of uint32_t values
+        // Sum the squared U and V errors of the in-block 3x3 neighbourhood.
+        for (idy = -1; idy <= 1; ++idy) {
+          for (idx = -1; idx <= 1; ++idx) {
+            const int row = uv_r + idy;
+            const int col = uv_c + idx;
+
+            if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
+                col < (int)uv_block_width) {
+              u_mod += u_diff_sse[row * uv_block_width + col];
+              v_mod += v_diff_sse[row * uv_block_width + col];
+              ++cr_index;
+            }
+          }
+        }
+
+        assert(cr_index > 0);
+        // Add luma errors of the (1 + ss_y) x (1 + ss_x) co-located pixels.
+        for (idy = 0; idy < 1 + ss_y; ++idy) {
+          for (idx = 0; idx < 1 + ss_x; ++idx) {
+            const int row = (uv_r << ss_y) + idy;
+            const int col = (uv_c << ss_x) + idx;
+            y_diff += y_diff_sse[row * (int)block_width + col];
+            ++cr_index;
+          }
+        }
+
+        u_mod += y_diff;
+        v_mod += y_diff;
+
+        u_mod = mod_index(u_mod, cr_index, rounding, strength, filter_weight);
+        v_mod = mod_index(v_mod, cr_index, rounding, strength, filter_weight);
+        // Accumulate the weights and the weighted predicted chroma pixels.
+        u_count[m] += u_mod;
+        u_accumulator[m] += u_mod * u_pixel_value;
+        v_count[m] += v_mod;
+        v_accumulator[m] += v_mod * v_pixel_value;
+
+        ++m;
+      }  // Complete YUV pixel
+    }
+  }
+}
+
+// Only used in single plane case
 void av1_temporal_filter_apply_c(uint8_t *frame1, unsigned int stride,
                                  uint8_t *frame2, unsigned int block_width,
                                  unsigned int block_height, int strength,
@@ -137,6 +443,7 @@
   }
 }
 
+// Only used in single plane case
 void av1_highbd_temporal_filter_apply_c(
     uint8_t *frame1_8, unsigned int stride, uint8_t *frame2_8,
     unsigned int block_width, unsigned int block_height, int strength,
@@ -378,31 +685,38 @@
           // Apply the filter (YUV)
           if (mbd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
             int adj_strength = strength + 2 * (mbd->bd - 8);
-            av1_highbd_temporal_filter_apply(
-                f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
-                adj_strength, filter_weight, accumulator, count);
-            if (num_planes > 1) {
-              av1_highbd_temporal_filter_apply(
-                  f->u_buffer + mb_uv_offset, f->uv_stride, predictor + 256,
-                  mb_uv_width, mb_uv_height, adj_strength, filter_weight,
-                  accumulator + 256, count + 256);
-              av1_highbd_temporal_filter_apply(
-                  f->v_buffer + mb_uv_offset, f->uv_stride, predictor + 512,
-                  mb_uv_width, mb_uv_height, adj_strength, filter_weight,
+
+            if (num_planes <= 1) {
+              // Single plane case
+              av1_highbd_temporal_filter_apply_c(
+                  f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
+                  adj_strength, filter_weight, accumulator, count);
+            } else {
+              // Process 3 planes together.
+              highbd_apply_temporal_filter(
+                  f->y_buffer + mb_y_offset, f->y_stride, predictor, 16,
+                  f->u_buffer + mb_uv_offset, f->v_buffer + mb_uv_offset,
+                  f->uv_stride, predictor + 256, predictor + 512, mb_uv_width,
+                  16, 16, mbd->plane[1].subsampling_x,
+                  mbd->plane[1].subsampling_y, adj_strength, filter_weight,
+                  accumulator, count, accumulator + 256, count + 256,
                   accumulator + 512, count + 512);
             }
           } else {
-            av1_temporal_filter_apply_c(f->y_buffer + mb_y_offset, f->y_stride,
-                                        predictor, 16, 16, strength,
-                                        filter_weight, accumulator, count);
-            if (num_planes > 1) {
+            if (num_planes <= 1) {
+              // Single plane case
               av1_temporal_filter_apply_c(
-                  f->u_buffer + mb_uv_offset, f->uv_stride, predictor + 256,
-                  mb_uv_width, mb_uv_height, strength, filter_weight,
-                  accumulator + 256, count + 256);
-              av1_temporal_filter_apply_c(
-                  f->v_buffer + mb_uv_offset, f->uv_stride, predictor + 512,
-                  mb_uv_width, mb_uv_height, strength, filter_weight,
+                  f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
+                  strength, filter_weight, accumulator, count);
+            } else {
+              // Process 3 planes together.
+              apply_temporal_filter(
+                  f->y_buffer + mb_y_offset, f->y_stride, predictor, 16,
+                  f->u_buffer + mb_uv_offset, f->v_buffer + mb_uv_offset,
+                  f->uv_stride, predictor + 256, predictor + 512, mb_uv_width,
+                  16, 16, mbd->plane[1].subsampling_x,
+                  mbd->plane[1].subsampling_y, strength, filter_weight,
+                  accumulator, count, accumulator + 256, count + 256,
                   accumulator + 512, count + 512);
             }
           }