Fix Overflow Issue in Optical Flow MV Refinement
This change resolves an overflow problem identified in the optical
flow MV refinement process. The issue originates from insufficient
precision during the accumulation of values in the refinement.
Specifically, in the function av1_opfl_mv_refinement_nxn_interp_grad,
the input terms `u`, `v`, and `w` use 16-bit signed precision, so the
accumulated sums `su2`, `sv2`, `suv`, `suw`, and `svw` can exceed the
range of a 32-bit signed integer. The C reference implementation
declares these accumulators as `int64_t`, but the SSE4.1
implementation kept the running sums in 32-bit signed lanes. This
discrepancy between the C code and the intrinsic path leads to the
overflow reported in issue #272.
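To see the failure mode concretely: the old SIMD path computed each pair of
products with `_mm_madd_epi16` and accumulated with `_mm_add_epi32`, both of
which operate on 32-bit lanes. A minimal scalar sketch (illustrative only,
not the codec's code) of why that is not enough when the 16-bit inputs reach
their extremes:

```c
// Sketch of the old path's arithmetic at the int16 extremes. A single
// madd pair already exceeds INT32_MAX, and a full 8x8 block sum needs
// far more than 32 bits.
#include <stdint.h>
#include <stdio.h>

int main(void) {
  const int16_t g = INT16_MIN;  // extreme gradient sample, -2^15
  // _mm_madd_epi16 forms a0*b0 + a1*b1 per 32-bit lane. With all four
  // inputs at -2^15 the true sum is 2^31 = INT32_MAX + 1, the one case
  // the instruction cannot represent (it wraps to INT32_MIN).
  const int64_t one_madd = (int64_t)g * g + (int64_t)g * g;
  printf("madd pair = %lld (INT32_MAX = %d)\n", (long long)one_madd,
         INT32_MAX);
  // Accumulating an 8x8 block of such squares reaches 2^36:
  const int64_t su2 = 64 * ((int64_t)g * g);
  printf("8x8 su2   = %lld\n", (long long)su2);
  return 0;
}
```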
diff --git a/av1/common/x86/optflow_refine_sse4.c b/av1/common/x86/optflow_refine_sse4.c
index 8e09075..71cac3b 100644
--- a/av1/common/x86/optflow_refine_sse4.c
+++ b/av1/common/x86/optflow_refine_sse4.c
@@ -751,30 +751,17 @@
}
#if OPFL_COMBINE_INTERP_GRAD_LS
-static AOM_FORCE_INLINE void calc_mv(const __m128i u2, const __m128i v2,
- const __m128i uv, const __m128i uw,
- const __m128i vw, const int d0,
- const int d1, const int bits,
- const uint8_t block_num,
- const int rls_alpha, int *vx0, int *vy0,
- int *vx1, int *vy1) {
- int64_t su2, suv, sv2, suw, svw;
- if (block_num == 0) {
- su2 = (int64_t)_mm_extract_epi32(u2, 0) + _mm_extract_epi32(u2, 1);
- sv2 = (int64_t)_mm_extract_epi32(v2, 0) + _mm_extract_epi32(v2, 1);
- suv = (int64_t)_mm_extract_epi32(uv, 0) + _mm_extract_epi32(uv, 1);
- suw = (int64_t)_mm_extract_epi32(uw, 0) + _mm_extract_epi32(uw, 1);
- svw = (int64_t)_mm_extract_epi32(vw, 0) + _mm_extract_epi32(vw, 1);
- } else {
- su2 = (int64_t)_mm_extract_epi32(u2, 2) + _mm_extract_epi32(u2, 3);
- sv2 = (int64_t)_mm_extract_epi32(v2, 2) + _mm_extract_epi32(v2, 3);
- suv = (int64_t)_mm_extract_epi32(uv, 2) + _mm_extract_epi32(uv, 3);
- suw = (int64_t)_mm_extract_epi32(uw, 2) + _mm_extract_epi32(uw, 3);
- svw = (int64_t)_mm_extract_epi32(vw, 2) + _mm_extract_epi32(vw, 3);
- }
-
- calc_mv_process(su2, sv2, suv, suw, svw, d0, d1, bits, rls_alpha, vx0, vy0,
- vx1, vy1);
+static AOM_FORCE_INLINE void multiply_and_accum(__m128i a_lo_0, __m128i b_lo_0,
+ __m128i a_hi_0, __m128i b_hi_0,
+ __m128i a_lo1, __m128i b_lo1,
+ __m128i a_hi1, __m128i b_hi1,
+ __m128i *t1, __m128i *t2) {
+ const __m128i reg_lo_0 = _mm_mul_epi32(a_lo_0, b_lo_0);
+ const __m128i reg_hi_0 = _mm_mul_epi32(a_hi_0, b_hi_0);
+ const __m128i reg_lo1 = _mm_mul_epi32(a_lo1, b_lo1);
+ const __m128i reg_hi1 = _mm_mul_epi32(a_hi1, b_hi1);
+ *t1 = _mm_add_epi64(reg_lo_0, reg_lo1);
+ *t2 = _mm_add_epi64(reg_hi_0, reg_hi1);
}
static void opfl_mv_refinement_interp_grad_8x4_sse4_1(
@@ -782,11 +769,18 @@
int gstride, int d0, int d1, int grad_prec_bits, int mv_prec_bits, int *vx0,
int *vy0, int *vx1, int *vy1) {
int bHeight = 4;
- __m128i u2 = _mm_setzero_si128();
- __m128i uv = _mm_setzero_si128();
- __m128i v2 = _mm_setzero_si128();
- __m128i uw = _mm_setzero_si128();
- __m128i vw = _mm_setzero_si128();
+ __m128i u2_lo = _mm_setzero_si128();
+ __m128i uv_lo = _mm_setzero_si128();
+ __m128i v2_lo = _mm_setzero_si128();
+ __m128i uw_lo = _mm_setzero_si128();
+ __m128i vw_lo = _mm_setzero_si128();
+ __m128i u2_hi = _mm_setzero_si128();
+ __m128i uv_hi = _mm_setzero_si128();
+ __m128i v2_hi = _mm_setzero_si128();
+ __m128i uw_hi = _mm_setzero_si128();
+ __m128i vw_hi = _mm_setzero_si128();
+ const int bits = mv_prec_bits + grad_prec_bits;
+ const int rls_alpha = OPFL_RLS_PARAM;
#if OPFL_DOWNSAMP_QUINCUNX
const __m128i even_row =
_mm_set_epi16(0, 0xFFFF, 0, 0xFFFF, 0, 0xFFFF, 0, 0xFFFF);
@@ -797,7 +791,6 @@
__m128i gradX = LoadUnaligned16(gx);
__m128i gradY = LoadUnaligned16(gy);
__m128i pred = LoadUnaligned16(pdiff);
- __m128i reg;
#if OPFL_DOWNSAMP_QUINCUNX
const __m128i gradX1 = LoadUnaligned16(gx + gstride);
const __m128i gradY1 = LoadUnaligned16(gy + gstride);
@@ -809,20 +802,50 @@
pred = _mm_or_si128(_mm_and_si128(pred, even_row),
_mm_and_si128(pred1, odd_row));
#endif
- reg = _mm_madd_epi16(gradX, gradX);
- u2 = _mm_add_epi32(reg, u2);
+ // The gx, gy and pred (i.e. d0*p0 - d1*p1) buffers hold signed 16-bit
+ // values, and in some cases they are filled with extreme values. Hence,
+ // the accumulation here needs to be done at 64-bit precision to avoid
+ // overflow.
+ const __m128i gradX_lo_0 = _mm_cvtepi16_epi32(gradX);
+ const __m128i gradY_lo_0 = _mm_cvtepi16_epi32(gradY);
+ const __m128i pred_lo_0 = _mm_cvtepi16_epi32(pred);
+ const __m128i gradX_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(gradX, 8));
+ const __m128i gradY_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(gradY, 8));
+ const __m128i pred_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(pred, 8));
- reg = _mm_madd_epi16(gradY, gradY);
- v2 = _mm_add_epi32(reg, v2);
+ const __m128i gradX_lo1 = _mm_srli_si128(gradX_lo_0, 4);
+ const __m128i gradX_hi1 = _mm_srli_si128(gradX_hi_0, 4);
+ const __m128i gradY_lo1 = _mm_srli_si128(gradY_lo_0, 4);
+ const __m128i gradY_hi1 = _mm_srli_si128(gradY_hi_0, 4);
+ const __m128i pred_lo1 = _mm_srli_si128(pred_lo_0, 4);
+ const __m128i pred_hi1 = _mm_srli_si128(pred_hi_0, 4);
+ __m128i t1, t2;
- reg = _mm_madd_epi16(gradX, gradY);
- uv = _mm_add_epi32(reg, uv);
+ multiply_and_accum(gradX_lo_0, gradX_lo_0, gradX_hi_0, gradX_hi_0,
+ gradX_lo1, gradX_lo1, gradX_hi1, gradX_hi1, &t1, &t2);
+ u2_lo = _mm_add_epi64(u2_lo, t1);
+ u2_hi = _mm_add_epi64(u2_hi, t2);
- reg = _mm_madd_epi16(gradX, pred);
- uw = _mm_add_epi32(reg, uw);
+ multiply_and_accum(gradY_lo_0, gradY_lo_0, gradY_hi_0, gradY_hi_0,
+ gradY_lo1, gradY_lo1, gradY_hi1, gradY_hi1, &t1, &t2);
+ v2_lo = _mm_add_epi64(v2_lo, t1);
+ v2_hi = _mm_add_epi64(v2_hi, t2);
- reg = _mm_madd_epi16(gradY, pred);
- vw = _mm_add_epi32(reg, vw);
+ multiply_and_accum(gradX_lo_0, gradY_lo_0, gradX_hi_0, gradY_hi_0,
+ gradX_lo1, gradY_lo1, gradX_hi1, gradY_hi1, &t1, &t2);
+ uv_lo = _mm_add_epi64(uv_lo, t1);
+ uv_hi = _mm_add_epi64(uv_hi, t2);
+
+ multiply_and_accum(gradX_lo_0, pred_lo_0, gradX_hi_0, pred_hi_0, gradX_lo1,
+ pred_lo1, gradX_hi1, pred_hi1, &t1, &t2);
+ uw_lo = _mm_add_epi64(uw_lo, t1);
+ uw_hi = _mm_add_epi64(uw_hi, t2);
+
+ multiply_and_accum(gradY_lo_0, pred_lo_0, gradY_hi_0, pred_hi_0, gradY_lo1,
+ pred_lo1, gradY_hi1, pred_hi1, &t1, &t2);
+ vw_lo = _mm_add_epi64(vw_lo, t1);
+ vw_hi = _mm_add_epi64(vw_hi, t2);
+
#if OPFL_DOWNSAMP_QUINCUNX
gx += gstride << 1;
gy += gstride << 1;
@@ -835,14 +858,35 @@
bHeight -= 1;
#endif
} while (bHeight != 0);
+ u2_lo = _mm_add_epi64(u2_lo, _mm_srli_si128(u2_lo, 8));
+ u2_hi = _mm_add_epi64(u2_hi, _mm_srli_si128(u2_hi, 8));
+ v2_lo = _mm_add_epi64(v2_lo, _mm_srli_si128(v2_lo, 8));
+ v2_hi = _mm_add_epi64(v2_hi, _mm_srli_si128(v2_hi, 8));
+ uv_lo = _mm_add_epi64(uv_lo, _mm_srli_si128(uv_lo, 8));
+ uv_hi = _mm_add_epi64(uv_hi, _mm_srli_si128(uv_hi, 8));
+ uw_lo = _mm_add_epi64(uw_lo, _mm_srli_si128(uw_lo, 8));
+ uw_hi = _mm_add_epi64(uw_hi, _mm_srli_si128(uw_hi, 8));
+ vw_lo = _mm_add_epi64(vw_lo, _mm_srli_si128(vw_lo, 8));
+ vw_hi = _mm_add_epi64(vw_hi, _mm_srli_si128(vw_hi, 8));
- const int bits = mv_prec_bits + grad_prec_bits;
- // As processing block size is 4x4, here '(bw * bh >> 4)' can be replaced
- // by 1.
- const int rls_alpha = OPFL_RLS_PARAM;
- calc_mv(u2, v2, uv, uw, vw, d0, d1, bits, 0, rls_alpha, vx0, vy0, vx1, vy1);
- calc_mv(u2, v2, uv, uw, vw, d0, d1, bits, 1, rls_alpha, vx0 + 1, vy0 + 1,
- vx1 + 1, vy1 + 1);
+ int64_t su2, sv2, suv, suw, svw;
+ xx_storel_64(&su2, u2_lo);
+ xx_storel_64(&sv2, v2_lo);
+ xx_storel_64(&suv, uv_lo);
+ xx_storel_64(&suw, uw_lo);
+ xx_storel_64(&svw, vw_lo);
+
+ calc_mv_process(su2, sv2, suv, suw, svw, d0, d1, bits, rls_alpha, vx0, vy0,
+ vx1, vy1);
+
+ xx_storel_64(&su2, u2_hi);
+ xx_storel_64(&sv2, v2_hi);
+ xx_storel_64(&suv, uv_hi);
+ xx_storel_64(&suw, uw_hi);
+ xx_storel_64(&svw, vw_hi);
+
+ calc_mv_process(su2, sv2, suv, suw, svw, d0, d1, bits, rls_alpha, vx0 + 1,
+ vy0 + 1, vx1 + 1, vy1 + 1);
}
static void opfl_mv_refinement_interp_grad_8x8_sse4_1(
@@ -850,6 +894,8 @@
int gstride, int d0, int d1, int grad_prec_bits, int mv_prec_bits, int *vx0,
int *vy0, int *vx1, int *vy1) {
int bHeight = 8;
+ const int rls_alpha = 4 * OPFL_RLS_PARAM;
+ const int bits = mv_prec_bits + grad_prec_bits;
__m128i u2 = _mm_setzero_si128();
__m128i uv = _mm_setzero_si128();
__m128i v2 = _mm_setzero_si128();
@@ -865,7 +911,6 @@
__m128i gradX = LoadUnaligned16(gx);
__m128i gradY = LoadUnaligned16(gy);
__m128i pred = LoadUnaligned16(pdiff);
- __m128i reg;
#if OPFL_DOWNSAMP_QUINCUNX
const __m128i gradX1 = LoadUnaligned16(gx + gstride);
const __m128i gradY1 = LoadUnaligned16(gy + gstride);
@@ -877,20 +922,49 @@
pred = _mm_or_si128(_mm_and_si128(pred, even_row),
_mm_and_si128(pred1, odd_row));
#endif
- reg = _mm_madd_epi16(gradX, gradX);
- u2 = _mm_add_epi32(u2, reg);
+ // The gx, gy and pred (i.e. d0*p0 - d1*p1) buffers hold signed 16-bit
+ // values, and in some cases they are filled with extreme values. Hence,
+ // the accumulation here needs to be done at 64-bit precision to avoid
+ // overflow.
+ const __m128i gradX_lo_0 = _mm_cvtepi16_epi32(gradX);
+ const __m128i gradY_lo_0 = _mm_cvtepi16_epi32(gradY);
+ const __m128i pred_lo_0 = _mm_cvtepi16_epi32(pred);
+ const __m128i gradX_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(gradX, 8));
+ const __m128i gradY_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(gradY, 8));
+ const __m128i pred_hi_0 = _mm_cvtepi16_epi32(_mm_srli_si128(pred, 8));
- reg = _mm_madd_epi16(gradY, gradY);
- v2 = _mm_add_epi32(v2, reg);
+ const __m128i gradX_lo1 = _mm_srli_si128(gradX_lo_0, 4);
+ const __m128i gradX_hi1 = _mm_srli_si128(gradX_hi_0, 4);
+ const __m128i gradY_lo1 = _mm_srli_si128(gradY_lo_0, 4);
+ const __m128i gradY_hi1 = _mm_srli_si128(gradY_hi_0, 4);
+ const __m128i pred_lo1 = _mm_srli_si128(pred_lo_0, 4);
+ const __m128i pred_hi1 = _mm_srli_si128(pred_hi_0, 4);
+ __m128i t1, t2;
- reg = _mm_madd_epi16(gradX, gradY);
- uv = _mm_add_epi32(uv, reg);
+ multiply_and_accum(gradX_lo_0, gradX_lo_0, gradX_hi_0, gradX_hi_0,
+ gradX_lo1, gradX_lo1, gradX_hi1, gradX_hi1, &t1, &t2);
+ t2 = _mm_add_epi64(t1, t2);
+ u2 = _mm_add_epi64(u2, t2);
- reg = _mm_madd_epi16(gradX, pred);
- uw = _mm_add_epi32(uw, reg);
+ multiply_and_accum(gradY_lo_0, gradY_lo_0, gradY_hi_0, gradY_hi_0,
+ gradY_lo1, gradY_lo1, gradY_hi1, gradY_hi1, &t1, &t2);
+ t2 = _mm_add_epi64(t1, t2);
+ v2 = _mm_add_epi64(v2, t2);
- reg = _mm_madd_epi16(gradY, pred);
- vw = _mm_add_epi32(vw, reg);
+ multiply_and_accum(gradX_lo_0, gradY_lo_0, gradX_hi_0, gradY_hi_0,
+ gradX_lo1, gradY_lo1, gradX_hi1, gradY_hi1, &t1, &t2);
+ t2 = _mm_add_epi64(t1, t2);
+ uv = _mm_add_epi64(uv, t2);
+
+ multiply_and_accum(gradX_lo_0, pred_lo_0, gradX_hi_0, pred_hi_0, gradX_lo1,
+ pred_lo1, gradX_hi1, pred_hi1, &t1, &t2);
+ t2 = _mm_add_epi64(t1, t2);
+ uw = _mm_add_epi64(uw, t2);
+
+ multiply_and_accum(gradY_lo_0, pred_lo_0, gradY_hi_0, pred_hi_0, gradY_lo1,
+ pred_lo1, gradY_hi1, pred_hi1, &t1, &t2);
+ t2 = _mm_add_epi64(t1, t2);
+ vw = _mm_add_epi64(vw, t2);
#if OPFL_DOWNSAMP_QUINCUNX
gx += gstride << 1;
gy += gstride << 1;
@@ -904,17 +978,20 @@
#endif
} while (bHeight != 0);
- u2 = _mm_add_epi32(u2, _mm_srli_si128(u2, 8));
- v2 = _mm_add_epi32(v2, _mm_srli_si128(v2, 8));
- uv = _mm_add_epi32(uv, _mm_srli_si128(uv, 8));
- uw = _mm_add_epi32(uw, _mm_srli_si128(uw, 8));
- vw = _mm_add_epi32(vw, _mm_srli_si128(vw, 8));
+ int64_t su2, sv2, suv, suw, svw;
+ u2 = _mm_add_epi64(u2, _mm_srli_si128(u2, 8));
+ v2 = _mm_add_epi64(v2, _mm_srli_si128(v2, 8));
+ uv = _mm_add_epi64(uv, _mm_srli_si128(uv, 8));
+ uw = _mm_add_epi64(uw, _mm_srli_si128(uw, 8));
+ vw = _mm_add_epi64(vw, _mm_srli_si128(vw, 8));
+ xx_storel_64(&su2, u2);
+ xx_storel_64(&sv2, v2);
+ xx_storel_64(&suv, uv);
+ xx_storel_64(&suw, uw);
+ xx_storel_64(&svw, vw);
- const int bits = mv_prec_bits + grad_prec_bits;
- // As processing block size is 8x8, here '(bw * bh >> 4)' can be replaced
- // by 4.
- const int rls_alpha = 4 * OPFL_RLS_PARAM;
- calc_mv(u2, v2, uv, uw, vw, d0, d1, bits, 0, rls_alpha, vx0, vy0, vx1, vy1);
+ calc_mv_process(su2, sv2, suv, suw, svw, d0, d1, bits, rls_alpha, vx0, vy0,
+ vx1, vy1);
}
static AOM_INLINE void opfl_mv_refinement_interp_grad_sse4_1(
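For reference, the widening scheme introduced above can be exercised in
isolation. The sketch below (standalone and illustrative; the helper name is
hypothetical) mirrors what `multiply_and_accum` does per row: widen eight
int16 lanes with `_mm_cvtepi16_epi32`, multiply with `_mm_mul_epi32` so the
products land directly in 64-bit lanes, and fold with `_mm_add_epi64`.

```c
// Standalone sketch of the 64-bit accumulation pattern used in the fix.
#include <smmintrin.h>  // SSE4.1
#include <stdint.h>
#include <stdio.h>

static int64_t sum_of_squares_sse4_1(const int16_t v[8]) {
  const __m128i x = _mm_loadu_si128((const __m128i *)v);
  const __m128i lo = _mm_cvtepi16_epi32(x);                     // lanes 0..3
  const __m128i hi = _mm_cvtepi16_epi32(_mm_srli_si128(x, 8));  // lanes 4..7
  // _mm_mul_epi32 multiplies only lanes 0 and 2, so shift the odd lanes
  // into the even slots to cover them with a second multiply.
  const __m128i lo1 = _mm_srli_si128(lo, 4);
  const __m128i hi1 = _mm_srli_si128(hi, 4);
  __m128i acc = _mm_add_epi64(_mm_mul_epi32(lo, lo), _mm_mul_epi32(lo1, lo1));
  acc = _mm_add_epi64(acc, _mm_mul_epi32(hi, hi));
  acc = _mm_add_epi64(acc, _mm_mul_epi32(hi1, hi1));
  acc = _mm_add_epi64(acc, _mm_srli_si128(acc, 8));  // fold the two lanes
  int64_t s;
  _mm_storel_epi64((__m128i *)&s, acc);
  return s;
}

int main(void) {
  const int16_t v[8] = { INT16_MIN, INT16_MIN, INT16_MIN, INT16_MIN,
                         INT16_MIN, INT16_MIN, INT16_MIN, INT16_MIN };
  // 8 * 2^30 = 2^33, comfortably past INT32_MAX but exact in 64 bits.
  printf("%lld\n", (long long)sum_of_squares_sse4_1(v));
  return 0;
}
```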
diff --git a/test/opt_flow_test.cc b/test/opt_flow_test.cc
index 70950d8..9060908 100644
--- a/test/opt_flow_test.cc
+++ b/test/opt_flow_test.cc
@@ -213,14 +213,14 @@
}
void Randomize(uint16_t *p, int size, int max_bit_range) {
- assert(max_bit_range < 16);
+ assert(max_bit_range <= 16);
for (int i = 0; i < size; ++i) {
p[i] = rnd_.Rand16() & ((1 << max_bit_range) - 1);
}
}
void Randomize(int16_t *p, int size, int max_bit_range) {
- assert(max_bit_range < 16);
+ assert(max_bit_range <= 16);
for (int i = 0; i < size; ++i) {
p[i] = (rnd_.Rand16() & ((1 << max_bit_range) - 1)) -
(1 << (max_bit_range - 1));
@@ -724,9 +724,14 @@
const int d1 = get_relative_dist(&oh_info, cur_frm_idx, ref1_frm_idx);
if (!d0 || !d1) continue;
- RandomInput16(input_, GetParam(), bd);
- RandomInput16(gx_, GetParam(), bd + 1);
- RandomInput16(gy_, GetParam(), bd + 1);
+ // Here, the input corresponds to 'd0*p0 - d1*p1' (where p0 and p1 can
+ // be 12 bits, and d0 and d1 can be >= 5 bits), and gx, gy are gradients
+ // of the input. Since these values are clamped to [INT16_MIN, INT16_MAX],
+ // the tests must exercise that full range. Hence, the input_, gx_ and
+ // gy_ buffers are populated accordingly.
+ RandomInput16(input_, GetParam(), AOMMIN(16, bd + 1));
+ RandomInput16(gx_, GetParam(), AOMMIN(16, bd + 6));
+ RandomInput16(gy_, GetParam(), AOMMIN(16, bd + 6));
TestOptFlowRefine(input_, gx_, gy_, is_speed, d0, d1);
count++;
@@ -742,9 +747,9 @@
const int d1 = RelativeDistExtreme(oh_bits);
if (!d0 || !d1) continue;
- RandomInput16Extreme(input_, GetParam(), bd);
- RandomInput16Extreme(gx_, GetParam(), bd + 1);
- RandomInput16Extreme(gy_, GetParam(), bd + 1);
+ RandomInput16Extreme(input_, GetParam(), AOMMIN(16, bd + 1));
+ RandomInput16Extreme(gx_, GetParam(), AOMMIN(16, bd + 6));
+ RandomInput16Extreme(gy_, GetParam(), AOMMIN(16, bd + 6));
TestOptFlowRefine(input_, gx_, gy_, 0, d0, d1);
count++;
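The widened test ranges follow from a bound that is easy to check by hand:
with 12-bit pixels and 5-bit temporal distances, `d0*p0 - d1*p1` can reach
roughly 18 bits of magnitude before the codec clamps it to int16, so driving
only `bd` bits of range would never hit the clamped extremes. A quick sketch
of the arithmetic (illustrative figures taken from the comment above):

```c
// Worst-case magnitude of d0*p0 - d1*p1 before clamping to int16.
#include <stdint.h>
#include <stdio.h>

int main(void) {
  const int p_max = (1 << 12) - 1;  // 12-bit pixel value
  const int d_max = (1 << 5) - 1;   // 5-bit relative distance
  // Opposite-signed distances maximize the difference: 31*4095 + 31*4095.
  const long long raw = (long long)d_max * p_max + (long long)d_max * p_max;
  printf("unclamped max = %lld, INT16_MAX = %d\n", raw, INT16_MAX);
  return 0;
}
```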