Improve temporal filtering
Ported the weight calculation from VP9. Considered U and V pixel values
while calculating Y pixel's filter weights, and vice versa.
Borg test result (speed 1; 150 frames for lowres/midres and 60 frames for
hdres):
          avg_psnr:  ovr_psnr:  ssim:
hdres:    -0.140     -0.187     -0.147
midres:   -0.088     -0.084     -0.072
lowres:   -0.028     -0.062     -0.117
STATS_CHANGED
Change-Id: I0a3ed6297431fa56d0e672ede809af6d55a5786f
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 8166e0b..ace585e 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -76,6 +76,312 @@
}
}
+// Converts an accumulated squared-error sum into a filter tap weight
+// (non-local-mean style; ported from VP9 per the commit message).
+//   sum_dist: sum of squared pixel differences over the neighborhood.
+//   index: number of squared-error samples folded into sum_dist.
+//   rounding: (1 << strength) >> 1, added before the strength shift.
+//   strength: right shift applied to the averaged error; a larger
+//             strength shrinks the error term, i.e. filters harder.
+//   filter_weight: per-block weight in [0, 2] scaling the result.
+// Returns a weight in [0, 16 * filter_weight]; larger errors map to
+// smaller weights.
+static INLINE int64_t mod_index(int64_t sum_dist, int index, int rounding,
+ int strength, int filter_weight) {
+ // Average over the neighborhood, scaled by 3 (VP9 heuristic).
+ int64_t mod = (sum_dist * 3) / index;
+ mod += rounding;
+ mod >>= strength;
+
+ // Clamp so the inverted weight below stays non-negative.
+ mod = AOMMIN(16, mod);
+
+ // Invert: small error -> weight near 16, large error -> weight near 0.
+ mod = 16 - mod;
+ mod *= filter_weight;
+
+ return mod;
+}
+
+// Computes the per-pixel squared error between source s and prediction p
+// over a w x h region, packing the results row-major into diff_sse
+// (w * h entries; the caller must size diff_sse accordingly).
+// For 8-bit pixels the squared difference is at most 255^2 = 65025,
+// which fits in uint16_t.
+static INLINE void calculate_squared_errors(const uint8_t *s, int s_stride,
+ const uint8_t *p, int p_stride,
+ uint16_t *diff_sse, unsigned int w,
+ unsigned int h) {
+ int idx = 0;
+ unsigned int i, j;
+
+ for (i = 0; i < h; i++) {
+ for (j = 0; j < w; j++) {
+ const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
+ diff_sse[idx] = diff * diff;
+ idx++;
+ }
+ }
+}
+
+// Temporally filters one luma block and its co-located chroma blocks in a
+// single pass, so that each plane's filter weights also account for the
+// prediction error in the other planes (ported from VP9).
+// For every pixel a non-local-mean style weight is derived from the summed
+// squared error over a 3x3 neighborhood in its own plane plus the
+// co-located samples of the other planes, then folded into the per-pixel
+// accumulator/count arrays.
+//   y_frame1/u_frame1/v_frame1: source planes (y_stride/uv_stride strides).
+//   y_pred/u_pred/v_pred: motion-compensated predictors
+//     (y_buf_stride/uv_buf_stride strides).
+//   block_width/block_height: luma block size; the scratch buffers below
+//     hold 256 entries, matching the 16x16 block used by the caller.
+//   ss_x/ss_y: chroma subsampling shifts.
+//   strength/filter_weight: see mod_index().
+//   *_accumulator/*_count: running weighted pixel sums and weight totals.
+static void apply_temporal_filter(
+ const uint8_t *y_frame1, int y_stride, const uint8_t *y_pred,
+ int y_buf_stride, const uint8_t *u_frame1, const uint8_t *v_frame1,
+ int uv_stride, const uint8_t *u_pred, const uint8_t *v_pred,
+ int uv_buf_stride, unsigned int block_width, unsigned int block_height,
+ int ss_x, int ss_y, int strength, int filter_weight,
+ uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator,
+ uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count) {
+ unsigned int i, j, k, m;
+ int modifier;
+ // Rounding term for the strength shift in mod_index().
+ const int rounding = (1 << strength) >> 1;
+ const unsigned int uv_block_width = block_width >> ss_x;
+ const unsigned int uv_block_height = block_height >> ss_y;
+ // Per-plane squared-error scratch, sized for one 16x16 plane.
+ DECLARE_ALIGNED(16, uint16_t, y_diff_sse[256]);
+ DECLARE_ALIGNED(16, uint16_t, u_diff_sse[256]);
+ DECLARE_ALIGNED(16, uint16_t, v_diff_sse[256]);
+
+ int idx = 0, idy;
+
+ assert(filter_weight >= 0);
+ assert(filter_weight <= 2);
+
+ memset(y_diff_sse, 0, 256 * sizeof(uint16_t));
+ memset(u_diff_sse, 0, 256 * sizeof(uint16_t));
+ memset(v_diff_sse, 0, 256 * sizeof(uint16_t));
+
+ // Calculate diff^2 for each pixel of the 16x16 block.
+ // TODO(yunqing): the following code needs to be optimized.
+ calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride, y_diff_sse,
+ block_width, block_height);
+ calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
+ u_diff_sse, uv_block_width, uv_block_height);
+ calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
+ v_diff_sse, uv_block_width, uv_block_height);
+
+ for (i = 0, k = 0, m = 0; i < block_height; i++) {
+ for (j = 0; j < block_width; j++) {
+ const int pixel_value = y_pred[i * y_buf_stride + j];
+
+ // non-local mean approach
+ int y_index = 0;
+
+ const int uv_r = i >> ss_y;
+ const int uv_c = j >> ss_x;
+ modifier = 0;
+
+ // Sum squared errors over the 3x3 luma neighborhood, clipped at
+ // the block edges; y_index counts the samples actually used.
+ for (idy = -1; idy <= 1; ++idy) {
+ for (idx = -1; idx <= 1; ++idx) {
+ const int row = (int)i + idy;
+ const int col = (int)j + idx;
+
+ if (row >= 0 && row < (int)block_height && col >= 0 &&
+ col < (int)block_width) {
+ modifier += y_diff_sse[row * (int)block_width + col];
+ ++y_index;
+ }
+ }
+ }
+
+ assert(y_index > 0);
+
+ // Add the co-located chroma errors so the luma weight also reflects
+ // chroma misfit; two extra samples join the average.
+ modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
+ modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
+
+ y_index += 2;
+
+ modifier =
+ (int)mod_index(modifier, y_index, rounding, strength, filter_weight);
+
+ y_count[k] += modifier;
+ y_accumulator[k] += modifier * pixel_value;
+
+ ++k;
+
+ // Process chroma component
+ // (once per chroma sample: at the top-left luma position of each
+ // subsampling group).
+ if (!(i & ss_y) && !(j & ss_x)) {
+ const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
+ const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
+
+ // non-local mean approach
+ int cr_index = 0;
+ int u_mod = 0, v_mod = 0;
+ int y_diff = 0;
+
+ // 3x3 chroma neighborhood, clipped at the chroma block edges.
+ for (idy = -1; idy <= 1; ++idy) {
+ for (idx = -1; idx <= 1; ++idx) {
+ const int row = uv_r + idy;
+ const int col = uv_c + idx;
+
+ if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
+ col < (int)uv_block_width) {
+ u_mod += u_diff_sse[row * uv_block_width + col];
+ v_mod += v_diff_sse[row * uv_block_width + col];
+ ++cr_index;
+ }
+ }
+ }
+
+ assert(cr_index > 0);
+
+ // Add the (1 + ss_y) x (1 + ss_x) luma samples co-located with this
+ // chroma sample, so chroma weights also reflect luma misfit.
+ for (idy = 0; idy < 1 + ss_y; ++idy) {
+ for (idx = 0; idx < 1 + ss_x; ++idx) {
+ const int row = (uv_r << ss_y) + idy;
+ const int col = (uv_c << ss_x) + idx;
+ y_diff += y_diff_sse[row * (int)block_width + col];
+ ++cr_index;
+ }
+ }
+
+ u_mod += y_diff;
+ v_mod += y_diff;
+
+ u_mod =
+ (int)mod_index(u_mod, cr_index, rounding, strength, filter_weight);
+ v_mod =
+ (int)mod_index(v_mod, cr_index, rounding, strength, filter_weight);
+
+ u_count[m] += u_mod;
+ u_accumulator[m] += u_mod * u_pixel_value;
+ v_count[m] += v_mod;
+ v_accumulator[m] += v_mod * v_pixel_value;
+
+ ++m;
+ } // Complete YUV pixel
+ }
+ }
+}
+
+// High-bitdepth variant of calculate_squared_errors(): per-pixel squared
+// error between source s and prediction p, packed row-major into diff_sse
+// (w * h entries; the caller must size diff_sse accordingly).
+// NOTE(review): diff is int16_t, which holds differences of pixel values
+// up to 12 bits; it would overflow for full 16-bit samples — confirm the
+// callers never exceed 12-bit depth.
+static INLINE void highbd_calculate_squared_errors(
+ const uint16_t *s, int s_stride, const uint16_t *p, int p_stride,
+ uint32_t *diff_sse, unsigned int w, unsigned int h) {
+ int idx = 0;
+ unsigned int i, j;
+
+ for (i = 0; i < h; i++) {
+ for (j = 0; j < w; j++) {
+ const int16_t diff = s[i * s_stride + j] - p[i * p_stride + j];
+ diff_sse[idx] = diff * diff;
+ idx++;
+ }
+ }
+}
+
+// High-bitdepth counterpart of apply_temporal_filter(): filters one luma
+// block and its co-located chroma blocks together, so each plane's filter
+// weights also account for the error in the other planes (ported from VP9).
+// The uint8_t* plane arguments are CONVERT_TO_SHORTPTR-wrapped
+// high-bitdepth buffers; error sums use uint32_t/int64_t intermediates to
+// hold the larger squared differences.
+//   yf/uf/vf: source planes (y_stride/uv_stride strides).
+//   yp/up/vp: motion-compensated predictors
+//     (y_buf_stride/uv_buf_stride strides).
+//   block_width/block_height: luma block size; the scratch buffers below
+//     hold 256 entries, matching the 16x16 block used by the caller.
+//   ss_x/ss_y: chroma subsampling shifts.
+//   strength/filter_weight: see mod_index().
+//   *_accumulator/*_count: running weighted pixel sums and weight totals.
+static void highbd_apply_temporal_filter(
+ const uint8_t *yf, int y_stride, const uint8_t *yp, int y_buf_stride,
+ const uint8_t *uf, const uint8_t *vf, int uv_stride, const uint8_t *up,
+ const uint8_t *vp, int uv_buf_stride, unsigned int block_width,
+ unsigned int block_height, int ss_x, int ss_y, int strength,
+ int filter_weight, uint32_t *y_accumulator, uint16_t *y_count,
+ uint32_t *u_accumulator, uint16_t *u_count, uint32_t *v_accumulator,
+ uint16_t *v_count) {
+ unsigned int i, j, k, m;
+ int64_t modifier;
+ // Rounding term for the strength shift in mod_index().
+ const int rounding = (1 << strength) >> 1;
+ const unsigned int uv_block_width = block_width >> ss_x;
+ const unsigned int uv_block_height = block_height >> ss_y;
+ // Per-plane squared-error scratch, sized for one 16x16 plane.
+ DECLARE_ALIGNED(16, uint32_t, y_diff_sse[256]);
+ DECLARE_ALIGNED(16, uint32_t, u_diff_sse[256]);
+ DECLARE_ALIGNED(16, uint32_t, v_diff_sse[256]);
+
+ // Unwrap the high-bitdepth buffers to their 16-bit sample pointers.
+ const uint16_t *y_frame1 = CONVERT_TO_SHORTPTR(yf);
+ const uint16_t *u_frame1 = CONVERT_TO_SHORTPTR(uf);
+ const uint16_t *v_frame1 = CONVERT_TO_SHORTPTR(vf);
+ const uint16_t *y_pred = CONVERT_TO_SHORTPTR(yp);
+ const uint16_t *u_pred = CONVERT_TO_SHORTPTR(up);
+ const uint16_t *v_pred = CONVERT_TO_SHORTPTR(vp);
+ int idx = 0, idy;
+
+ assert(filter_weight >= 0);
+ assert(filter_weight <= 2);
+
+ memset(y_diff_sse, 0, 256 * sizeof(uint32_t));
+ memset(u_diff_sse, 0, 256 * sizeof(uint32_t));
+ memset(v_diff_sse, 0, 256 * sizeof(uint32_t));
+
+ // Calculate diff^2 for each pixel of the 16x16 block.
+ // TODO(yunqing): the following code needs to be optimized.
+ highbd_calculate_squared_errors(y_frame1, y_stride, y_pred, y_buf_stride,
+ y_diff_sse, block_width, block_height);
+ highbd_calculate_squared_errors(u_frame1, uv_stride, u_pred, uv_buf_stride,
+ u_diff_sse, uv_block_width, uv_block_height);
+ highbd_calculate_squared_errors(v_frame1, uv_stride, v_pred, uv_buf_stride,
+ v_diff_sse, uv_block_width, uv_block_height);
+
+ for (i = 0, k = 0, m = 0; i < block_height; i++) {
+ for (j = 0; j < block_width; j++) {
+ const int pixel_value = y_pred[i * y_buf_stride + j];
+
+ // non-local mean approach
+ int y_index = 0;
+
+ const int uv_r = i >> ss_y;
+ const int uv_c = j >> ss_x;
+ modifier = 0;
+
+ // Sum squared errors over the 3x3 luma neighborhood, clipped at
+ // the block edges; y_index counts the samples actually used.
+ for (idy = -1; idy <= 1; ++idy) {
+ for (idx = -1; idx <= 1; ++idx) {
+ const int row = (int)i + idy;
+ const int col = (int)j + idx;
+
+ if (row >= 0 && row < (int)block_height && col >= 0 &&
+ col < (int)block_width) {
+ modifier += y_diff_sse[row * (int)block_width + col];
+ ++y_index;
+ }
+ }
+ }
+
+ assert(y_index > 0);
+
+ // Add the co-located chroma errors so the luma weight also reflects
+ // chroma misfit; two extra samples join the average.
+ modifier += u_diff_sse[uv_r * uv_block_width + uv_c];
+ modifier += v_diff_sse[uv_r * uv_block_width + uv_c];
+
+ y_index += 2;
+
+ modifier =
+ mod_index(modifier, y_index, rounding, strength, filter_weight);
+
+ y_count[k] += modifier;
+ y_accumulator[k] += modifier * pixel_value;
+
+ ++k;
+
+ // Process chroma component
+ // (once per chroma sample: at the top-left luma position of each
+ // subsampling group).
+ if (!(i & ss_y) && !(j & ss_x)) {
+ const int u_pixel_value = u_pred[uv_r * uv_buf_stride + uv_c];
+ const int v_pixel_value = v_pred[uv_r * uv_buf_stride + uv_c];
+
+ // non-local mean approach
+ int cr_index = 0;
+ int64_t u_mod = 0, v_mod = 0;
+ int y_diff = 0;
+
+ // 3x3 chroma neighborhood, clipped at the chroma block edges.
+ for (idy = -1; idy <= 1; ++idy) {
+ for (idx = -1; idx <= 1; ++idx) {
+ const int row = uv_r + idy;
+ const int col = uv_c + idx;
+
+ if (row >= 0 && row < (int)uv_block_height && col >= 0 &&
+ col < (int)uv_block_width) {
+ u_mod += u_diff_sse[row * uv_block_width + col];
+ v_mod += v_diff_sse[row * uv_block_width + col];
+ ++cr_index;
+ }
+ }
+ }
+
+ assert(cr_index > 0);
+
+ // Add the (1 + ss_y) x (1 + ss_x) luma samples co-located with this
+ // chroma sample, so chroma weights also reflect luma misfit.
+ for (idy = 0; idy < 1 + ss_y; ++idy) {
+ for (idx = 0; idx < 1 + ss_x; ++idx) {
+ const int row = (uv_r << ss_y) + idy;
+ const int col = (uv_c << ss_x) + idx;
+ y_diff += y_diff_sse[row * (int)block_width + col];
+ ++cr_index;
+ }
+ }
+
+ u_mod += y_diff;
+ v_mod += y_diff;
+
+ u_mod = mod_index(u_mod, cr_index, rounding, strength, filter_weight);
+ v_mod = mod_index(v_mod, cr_index, rounding, strength, filter_weight);
+
+ u_count[m] += u_mod;
+ u_accumulator[m] += u_mod * u_pixel_value;
+ v_count[m] += v_mod;
+ v_accumulator[m] += v_mod * v_pixel_value;
+
+ ++m;
+ } // Complete YUV pixel
+ }
+ }
+}
+
+// Only used in single plane case
void av1_temporal_filter_apply_c(uint8_t *frame1, unsigned int stride,
uint8_t *frame2, unsigned int block_width,
unsigned int block_height, int strength,
@@ -137,6 +443,7 @@
}
}
+// Only used in single plane case
void av1_highbd_temporal_filter_apply_c(
uint8_t *frame1_8, unsigned int stride, uint8_t *frame2_8,
unsigned int block_width, unsigned int block_height, int strength,
@@ -378,31 +685,38 @@
// Apply the filter (YUV)
if (mbd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
int adj_strength = strength + 2 * (mbd->bd - 8);
- av1_highbd_temporal_filter_apply(
- f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
- adj_strength, filter_weight, accumulator, count);
- if (num_planes > 1) {
- av1_highbd_temporal_filter_apply(
- f->u_buffer + mb_uv_offset, f->uv_stride, predictor + 256,
- mb_uv_width, mb_uv_height, adj_strength, filter_weight,
- accumulator + 256, count + 256);
- av1_highbd_temporal_filter_apply(
- f->v_buffer + mb_uv_offset, f->uv_stride, predictor + 512,
- mb_uv_width, mb_uv_height, adj_strength, filter_weight,
+
+ if (num_planes <= 1) {
+ // Single plane case
+ av1_highbd_temporal_filter_apply_c(
+ f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
+ adj_strength, filter_weight, accumulator, count);
+ } else {
+ // Process 3 planes together.
+ highbd_apply_temporal_filter(
+ f->y_buffer + mb_y_offset, f->y_stride, predictor, 16,
+ f->u_buffer + mb_uv_offset, f->v_buffer + mb_uv_offset,
+ f->uv_stride, predictor + 256, predictor + 512, mb_uv_width,
+ 16, 16, mbd->plane[1].subsampling_x,
+ mbd->plane[1].subsampling_y, adj_strength, filter_weight,
+ accumulator, count, accumulator + 256, count + 256,
accumulator + 512, count + 512);
}
} else {
- av1_temporal_filter_apply_c(f->y_buffer + mb_y_offset, f->y_stride,
- predictor, 16, 16, strength,
- filter_weight, accumulator, count);
- if (num_planes > 1) {
+ if (num_planes <= 1) {
+ // Single plane case
av1_temporal_filter_apply_c(
- f->u_buffer + mb_uv_offset, f->uv_stride, predictor + 256,
- mb_uv_width, mb_uv_height, strength, filter_weight,
- accumulator + 256, count + 256);
- av1_temporal_filter_apply_c(
- f->v_buffer + mb_uv_offset, f->uv_stride, predictor + 512,
- mb_uv_width, mb_uv_height, strength, filter_weight,
+ f->y_buffer + mb_y_offset, f->y_stride, predictor, 16, 16,
+ strength, filter_weight, accumulator, count);
+ } else {
+ // Process 3 planes together.
+ apply_temporal_filter(
+ f->y_buffer + mb_y_offset, f->y_stride, predictor, 16,
+ f->u_buffer + mb_uv_offset, f->v_buffer + mb_uv_offset,
+ f->uv_stride, predictor + 256, predictor + 512, mb_uv_width,
+ 16, 16, mbd->plane[1].subsampling_x,
+ mbd->plane[1].subsampling_y, strength, filter_weight,
+ accumulator, count, accumulator + 256, count + 256,
accumulator + 512, count + 512);
}
}