Corner match: Use bidirectional matching

Remember the best match for each feature in both the source and ref
frames. Only generate correspondences for points which are each other's
best matches. This removes a lot of spurious correspondences, which
helps improve the average quality of the generated correspondences.

In addition, rearrange the code to be much more efficient. If there
are n features in each of the source and ref frames, then the matching
loop is run O(n^2) times. Therefore it is beneficial to push as much
work as possible out into pre- and post-processing loops which run
only O(n) times.

Finally, expand patch size from 13x13 to 16x16, as it gains a little
extra quality at minimal cost.

Results @ "good" mode speed 4:

     Compared to      | BDRATE-PSNR | BDRATE-SSIM | Encode time
----------------------+-------------+-------------+-------------
Previous corner match |   -0.044%   |   -0.052%   |   -3.487%
       Disflow        |   +0.003%   |   -0.007%   |   +7.482%

No change to encoder output, as we currently use disflow.

Change-Id: I0abd73717925623eeb3948889eae3ca38c61811d
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 91f5ee9..02081cd 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1789,8 +1789,11 @@
 
   # Flow estimation library
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/double av1_compute_cross_correlation/, "const unsigned char *frame1, int stride1, int x1, int y1, const unsigned char *frame2, int stride2, int x2, int y2";
-    specialize qw/av1_compute_cross_correlation sse4_1 avx2/;
+    add_proto qw/bool aom_compute_mean_stddev/, "const unsigned char *frame, int stride, int x, int y, double *mean, double *one_over_stddev";
+    specialize qw/aom_compute_mean_stddev sse4_1 avx2/;
+
+    add_proto qw/double aom_compute_correlation/, "const unsigned char *frame1, int stride1, int x1, int y1, double mean1, double one_over_stddev1, const unsigned char *frame2, int stride2, int x2, int y2, double mean2, double one_over_stddev2";
+    specialize qw/aom_compute_correlation sse4_1 avx2/;
 
     add_proto qw/void aom_compute_flow_at_point/, "const uint8_t *src, const uint8_t *ref, int x, int y, int width, int height, int stride, double *u, double *v";
     specialize qw/aom_compute_flow_at_point sse4_1 neon/;
diff --git a/aom_dsp/flow_estimation/corner_match.c b/aom_dsp/flow_estimation/corner_match.c
index dd524c1..7b2b9fc 100644
--- a/aom_dsp/flow_estimation/corner_match.c
+++ b/aom_dsp/flow_estimation/corner_match.c
@@ -25,52 +25,76 @@
 
 #define THRESHOLD_NCC 0.75
 
-/* Compute var(frame) * MATCH_SZ_SQ over a MATCH_SZ by MATCH_SZ window of frame,
-   centered at (x, y).
+/* Compute mean and standard deviation of pixels in a window of size
+   MATCH_SZ by MATCH_SZ centered at (x, y).
+   Store results into *mean and *one_over_stddev
+
+   Note: The output of this function is scaled by MATCH_SZ, as in
+   *mean = MATCH_SZ * <true mean> and
+   *one_over_stddev = 1 / (MATCH_SZ * <true stddev>)
+
+   Combined with the fact that we return 1/stddev rather than the standard
+   deviation itself, this allows us to completely avoid divisions in
+   aom_compute_correlation, which is much hotter than this function is.
+
+   Returns true if this feature point is usable, false otherwise.
 */
-static double compute_variance(const unsigned char *frame, int stride, int x,
-                               int y) {
+bool aom_compute_mean_stddev_c(const unsigned char *frame, int stride, int x,
+                               int y, double *mean, double *one_over_stddev) {
   int sum = 0;
   int sumsq = 0;
-  int var;
-  int i, j;
-  for (i = 0; i < MATCH_SZ; ++i)
-    for (j = 0; j < MATCH_SZ; ++j) {
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    for (int j = 0; j < MATCH_SZ; ++j) {
       sum += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
       sumsq += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)] *
                frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
     }
-  var = sumsq * MATCH_SZ_SQ - sum * sum;
-  return (double)var;
+  }
+  *mean = (double)sum / MATCH_SZ;
+  const double variance = sumsq - (*mean) * (*mean);
+  if (variance < MIN_FEATURE_VARIANCE) {
+    *one_over_stddev = 0.0;
+    return false;
+  }
+  *one_over_stddev = 1.0 / sqrt(variance);
+  return true;
 }
 
-/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
-   correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
-   of each image, centered at (x1, y1) and (x2, y2) respectively.
+/* Compute corr(frame1, frame2) over a window of size MATCH_SZ by MATCH_SZ.
+   To save on computation, the mean and (1 divided by the) standard deviation
+   of the window in each frame are precomputed and passed into this function
+   as arguments.
 */
-double av1_compute_cross_correlation_c(const unsigned char *frame1, int stride1,
-                                       int x1, int y1,
-                                       const unsigned char *frame2, int stride2,
-                                       int x2, int y2) {
+double aom_compute_correlation_c(const unsigned char *frame1, int stride1,
+                                 int x1, int y1, double mean1,
+                                 double one_over_stddev1,
+                                 const unsigned char *frame2, int stride2,
+                                 int x2, int y2, double mean2,
+                                 double one_over_stddev2) {
   int v1, v2;
-  int sum1 = 0;
-  int sum2 = 0;
-  int sumsq2 = 0;
   int cross = 0;
-  int var2, cov;
-  int i, j;
-  for (i = 0; i < MATCH_SZ; ++i)
-    for (j = 0; j < MATCH_SZ; ++j) {
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    for (int j = 0; j < MATCH_SZ; ++j) {
       v1 = frame1[(i + y1 - MATCH_SZ_BY2) * stride1 + (j + x1 - MATCH_SZ_BY2)];
       v2 = frame2[(i + y2 - MATCH_SZ_BY2) * stride2 + (j + x2 - MATCH_SZ_BY2)];
-      sum1 += v1;
-      sum2 += v2;
-      sumsq2 += v2 * v2;
       cross += v1 * v2;
     }
-  var2 = sumsq2 * MATCH_SZ_SQ - sum2 * sum2;
-  cov = cross * MATCH_SZ_SQ - sum1 * sum2;
-  return cov / sqrt((double)var2);
+  }
+
+  // Note: In theory, the calculations here "should" be
+  //   covariance = cross / N^2 - mean1 * mean2
+  //   correlation = covariance / (stddev1 * stddev2).
+  //
+  // However, because of the scaling in aom_compute_mean_stddev, the
+  // lines below actually calculate
+  //   covariance * N^2 = cross - (mean1 * N) * (mean2 * N)
+  //   correlation = (covariance * N^2) / ((stddev1 * N) * (stddev2 * N))
+  //
+  // ie. we have removed the need for a division, and still end up with the
+  // correct unscaled correlation (ie, in the range [-1, +1])
+  double covariance = cross - mean1 * mean2;
+  double correlation = covariance * (one_over_stddev1 * one_over_stddev2);
+  return correlation;
 }
 
 static int is_eligible_point(int pointx, int pointy, int width, int height) {
@@ -85,6 +109,15 @@
           (point1y - point2y) * (point1y - point2y)) <= thresh * thresh;
 }
 
+typedef struct {
+  int x;
+  int y;
+  double mean;
+  double one_over_stddev;
+  int best_match_idx;
+  double best_match_corr;
+} PointInfo;
+
 static int determine_correspondence(const unsigned char *src,
                                     const int *src_corners, int num_src_corners,
                                     const unsigned char *ref,
@@ -92,44 +125,108 @@
                                     int width, int height, int src_stride,
                                     int ref_stride,
                                     Correspondence *correspondences) {
-  // TODO(sarahparker) Improve this to include 2-way match
-  int i, j;
+  PointInfo *src_point_info = NULL;
+  PointInfo *ref_point_info = NULL;
   int num_correspondences = 0;
-  for (i = 0; i < num_src_corners; ++i) {
-    double best_match_ncc = 0.0;
-    double template_norm;
-    int best_match_j = -1;
-    if (!is_eligible_point(src_corners[2 * i], src_corners[2 * i + 1], width,
-                           height))
+
+  src_point_info =
+      (PointInfo *)aom_calloc(num_src_corners, sizeof(*src_point_info));
+  if (!src_point_info) {
+    goto finished;
+  }
+
+  ref_point_info =
+      (PointInfo *)aom_calloc(num_ref_corners, sizeof(*ref_point_info));
+  if (!ref_point_info) {
+    goto finished;
+  }
+
+  // First pass (linear):
+  // Filter corner lists and compute per-patch means and standard deviations,
+  // for the src and ref frames independently
+  int src_point_count = 0;
+  for (int i = 0; i < num_src_corners; i++) {
+    int src_x = src_corners[2 * i];
+    int src_y = src_corners[2 * i + 1];
+    if (!is_eligible_point(src_x, src_y, width, height)) continue;
+
+    PointInfo *point = &src_point_info[src_point_count];
+    point->x = src_x;
+    point->y = src_y;
+    point->best_match_corr = THRESHOLD_NCC;
+    if (!aom_compute_mean_stddev(src, src_stride, src_x, src_y, &point->mean,
+                                 &point->one_over_stddev))
       continue;
-    for (j = 0; j < num_ref_corners; ++j) {
-      double match_ncc;
-      if (!is_eligible_point(ref_corners[2 * j], ref_corners[2 * j + 1], width,
-                             height))
+    src_point_count++;
+  }
+  if (src_point_count == 0) {
+    goto finished;
+  }
+
+  int ref_point_count = 0;
+  for (int j = 0; j < num_ref_corners; j++) {
+    int ref_x = ref_corners[2 * j];
+    int ref_y = ref_corners[2 * j + 1];
+    if (!is_eligible_point(ref_x, ref_y, width, height)) continue;
+
+    PointInfo *point = &ref_point_info[ref_point_count];
+    point->x = ref_x;
+    point->y = ref_y;
+    point->best_match_corr = THRESHOLD_NCC;
+    if (!aom_compute_mean_stddev(ref, ref_stride, ref_x, ref_y, &point->mean,
+                                 &point->one_over_stddev))
+      continue;
+    ref_point_count++;
+  }
+  if (ref_point_count == 0) {
+    goto finished;
+  }
+
+  // Second pass (quadratic):
+  // For each pair of points, compute correlation, and use this to determine
+  // the best match of each corner, in both directions
+  for (int i = 0; i < src_point_count; ++i) {
+    PointInfo *src_point = &src_point_info[i];
+    for (int j = 0; j < ref_point_count; ++j) {
+      PointInfo *ref_point = &ref_point_info[j];
+      if (!is_eligible_distance(src_point->x, src_point->y, ref_point->x,
+                                ref_point->y, width, height))
         continue;
-      if (!is_eligible_distance(src_corners[2 * i], src_corners[2 * i + 1],
-                                ref_corners[2 * j], ref_corners[2 * j + 1],
-                                width, height))
-        continue;
-      match_ncc = av1_compute_cross_correlation(
-          src, src_stride, src_corners[2 * i], src_corners[2 * i + 1], ref,
-          ref_stride, ref_corners[2 * j], ref_corners[2 * j + 1]);
-      if (match_ncc > best_match_ncc) {
-        best_match_ncc = match_ncc;
-        best_match_j = j;
+
+      double corr = aom_compute_correlation(
+          src, src_stride, src_point->x, src_point->y, src_point->mean,
+          src_point->one_over_stddev, ref, ref_stride, ref_point->x,
+          ref_point->y, ref_point->mean, ref_point->one_over_stddev);
+
+      if (corr > src_point->best_match_corr) {
+        src_point->best_match_idx = j;
+        src_point->best_match_corr = corr;
+      }
+      if (corr > ref_point->best_match_corr) {
+        ref_point->best_match_idx = i;
+        ref_point->best_match_corr = corr;
       }
     }
-    // Note: We want to test if the best correlation is >= THRESHOLD_NCC,
-    // but need to account for the normalization in
-    // av1_compute_cross_correlation.
-    template_norm = compute_variance(src, src_stride, src_corners[2 * i],
-                                     src_corners[2 * i + 1]);
-    if (best_match_ncc > THRESHOLD_NCC * sqrt(template_norm)) {
-      // Apply refinement
-      const int sx = src_corners[2 * i];
-      const int sy = src_corners[2 * i + 1];
-      const int rx = ref_corners[2 * best_match_j];
-      const int ry = ref_corners[2 * best_match_j + 1];
+  }
+
+  // Third pass (linear):
+  // Scan through source corners, generating a correspondence for each corner
+  // iff ref_best_match[src_best_match[i]] == i
+  // Then refine the generated correspondences using optical flow
+  for (int i = 0; i < src_point_count; i++) {
+    PointInfo *point = &src_point_info[i];
+
+    // Skip corners which were not matched, or which didn't find
+    // a good enough match
+    if (point->best_match_corr < THRESHOLD_NCC) continue;
+
+    PointInfo *match_point = &ref_point_info[point->best_match_idx];
+    if (match_point->best_match_idx == i) {
+      // Refine match using optical flow and store
+      const int sx = point->x;
+      const int sy = point->y;
+      const int rx = match_point->x;
+      const int ry = match_point->y;
       double u = (double)(rx - sx);
       double v = (double)(ry - sy);
 
@@ -139,13 +236,18 @@
       aom_compute_flow_at_point(src, ref, patch_tl_x, patch_tl_y, width, height,
                                 src_stride, &u, &v);
 
-      correspondences[num_correspondences].x = (double)sx;
-      correspondences[num_correspondences].y = (double)sy;
-      correspondences[num_correspondences].rx = (double)sx + u;
-      correspondences[num_correspondences].ry = (double)sy + v;
+      Correspondence *correspondence = &correspondences[num_correspondences];
+      correspondence->x = (double)sx;
+      correspondence->y = (double)sy;
+      correspondence->rx = (double)sx + u;
+      correspondence->ry = (double)sy + v;
       num_correspondences++;
     }
   }
+
+finished:
+  aom_free(src_point_info);
+  aom_free(ref_point_info);
   return num_correspondences;
 }
 
diff --git a/aom_dsp/flow_estimation/corner_match.h b/aom_dsp/flow_estimation/corner_match.h
index 4435d2c..99507dc 100644
--- a/aom_dsp/flow_estimation/corner_match.h
+++ b/aom_dsp/flow_estimation/corner_match.h
@@ -25,10 +25,16 @@
 extern "C" {
 #endif
 
-#define MATCH_SZ 13
+#define MATCH_SZ 16
 #define MATCH_SZ_BY2 ((MATCH_SZ - 1) / 2)
 #define MATCH_SZ_SQ (MATCH_SZ * MATCH_SZ)
 
+// Minimum threshold for the variance of a patch, in order for it to be
+// considered useful for matching.
+// This is evaluated against the scaled variance MATCH_SZ_SQ * sigma^2,
+// so a setting of 1 * MATCH_SZ_SQ corresponds to an unscaled variance of 1
+#define MIN_FEATURE_VARIANCE (1 * MATCH_SZ_SQ)
+
 bool av1_compute_global_motion_feature_match(
     TransformationType type, YV12_BUFFER_CONFIG *src, YV12_BUFFER_CONFIG *ref,
     int bit_depth, MotionModel *motion_models, int num_motion_models,
diff --git a/aom_dsp/flow_estimation/x86/corner_match_avx2.c b/aom_dsp/flow_estimation/x86/corner_match_avx2.c
index 87c76fa..ff69ae7 100644
--- a/aom_dsp/flow_estimation/x86/corner_match_avx2.c
+++ b/aom_dsp/flow_estimation/x86/corner_match_avx2.c
@@ -17,64 +17,112 @@
 #include "aom_ports/mem.h"
 #include "aom_dsp/flow_estimation/corner_match.h"
 
-DECLARE_ALIGNED(16, static const uint8_t,
-                byte_mask[16]) = { 255, 255, 255, 255, 255, 255, 255, 255,
-                                   255, 255, 255, 255, 255, 0,   0,   0 };
-#if MATCH_SZ != 13
-#error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13"
+DECLARE_ALIGNED(32, static const uint16_t, ones_array[16]) = { 1, 1, 1, 1, 1, 1,
+                                                               1, 1, 1, 1, 1, 1,
+                                                               1, 1, 1, 1 };
+
+#if MATCH_SZ != 16
+#error "Need to apply pixel mask in corner_match_avx2.c if MATCH_SZ != 16"
 #endif
 
-/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
-correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
-of each image, centered at (x1, y1) and (x2, y2) respectively.
-*/
-double av1_compute_cross_correlation_avx2(const unsigned char *frame1,
-                                          int stride1, int x1, int y1,
-                                          const unsigned char *frame2,
-                                          int stride2, int x2, int y2) {
-  int i, stride1_i = 0, stride2_i = 0;
-  __m256i temp1, sum_vec, sumsq2_vec, cross_vec, v, v1_1, v2_1;
-  const __m128i mask = _mm_load_si128((__m128i *)byte_mask);
-  const __m256i zero = _mm256_setzero_si256();
-  __m128i v1, v2;
+/* Compute mean and standard deviation of pixels in a window of size
+   MATCH_SZ by MATCH_SZ centered at (x, y).
+   Store results into *mean and *one_over_stddev
 
-  sum_vec = zero;
-  sumsq2_vec = zero;
-  cross_vec = zero;
+   Note: The output of this function is scaled by MATCH_SZ, as in
+   *mean = MATCH_SZ * <true mean> and
+   *one_over_stddev = 1 / (MATCH_SZ * <true stddev>)
+
+   Combined with the fact that we return 1/stddev rather than the standard
+   deviation itself, this allows us to completely avoid divisions in
+   aom_compute_correlation, which is much hotter than this function is.
+
+   Returns true if this feature point is usable, false otherwise.
+*/
+bool aom_compute_mean_stddev_avx2(const unsigned char *frame, int stride, int x,
+                                  int y, double *mean,
+                                  double *one_over_stddev) {
+  __m256i sum_vec = _mm256_setzero_si256();
+  __m256i sumsq_vec = _mm256_setzero_si256();
+
+  frame += (y - MATCH_SZ_BY2) * stride + (x - MATCH_SZ_BY2);
+
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    const __m256i v = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)frame));
+
+    sum_vec = _mm256_add_epi16(sum_vec, v);
+    sumsq_vec = _mm256_add_epi32(sumsq_vec, _mm256_madd_epi16(v, v));
+
+    frame += stride;
+  }
+
+  // Reduce sum_vec and sumsq_vec into single values
+  // Start by reducing each vector to 8x32-bit values, hadd() to perform 8
+  // additions, sum vertically to do 4 more, then the last 2 in scalar code.
+  const __m256i ones = _mm256_load_si256((__m256i *)ones_array);
+  const __m256i partial_sum = _mm256_madd_epi16(sum_vec, ones);
+  const __m256i tmp_8x32 = _mm256_hadd_epi32(partial_sum, sumsq_vec);
+  const __m128i tmp_4x32 = _mm_add_epi32(_mm256_extracti128_si256(tmp_8x32, 0),
+                                         _mm256_extracti128_si256(tmp_8x32, 1));
+  const int sum =
+      _mm_extract_epi32(tmp_4x32, 0) + _mm_extract_epi32(tmp_4x32, 1);
+  const int sumsq =
+      _mm_extract_epi32(tmp_4x32, 2) + _mm_extract_epi32(tmp_4x32, 3);
+
+  *mean = (double)sum / MATCH_SZ;
+  const double variance = sumsq - (*mean) * (*mean);
+  if (variance < MIN_FEATURE_VARIANCE) {
+    *one_over_stddev = 0.0;
+    return false;
+  }
+  *one_over_stddev = 1.0 / sqrt(variance);
+  return true;
+}
+
+/* Compute corr(frame1, frame2) over a window of size MATCH_SZ by MATCH_SZ.
+   To save on computation, the mean and (1 divided by the) standard deviation
+   of the window in each frame are precomputed and passed into this function
+   as arguments.
+*/
+double aom_compute_correlation_avx2(const unsigned char *frame1, int stride1,
+                                    int x1, int y1, double mean1,
+                                    double one_over_stddev1,
+                                    const unsigned char *frame2, int stride2,
+                                    int x2, int y2, double mean2,
+                                    double one_over_stddev2) {
+  __m256i cross_vec = _mm256_setzero_si256();
 
   frame1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
   frame2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
 
-  for (i = 0; i < MATCH_SZ; ++i) {
-    v1 = _mm_and_si128(_mm_loadu_si128((__m128i *)&frame1[stride1_i]), mask);
-    v1_1 = _mm256_cvtepu8_epi16(v1);
-    v2 = _mm_and_si128(_mm_loadu_si128((__m128i *)&frame2[stride2_i]), mask);
-    v2_1 = _mm256_cvtepu8_epi16(v2);
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    const __m256i v1 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)frame1));
+    const __m256i v2 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)frame2));
 
-    v = _mm256_insertf128_si256(_mm256_castsi128_si256(v1), v2, 1);
-    sumsq2_vec = _mm256_add_epi32(sumsq2_vec, _mm256_madd_epi16(v2_1, v2_1));
+    cross_vec = _mm256_add_epi32(cross_vec, _mm256_madd_epi16(v1, v2));
 
-    sum_vec = _mm256_add_epi16(sum_vec, _mm256_sad_epu8(v, zero));
-    cross_vec = _mm256_add_epi32(cross_vec, _mm256_madd_epi16(v1_1, v2_1));
-    stride1_i += stride1;
-    stride2_i += stride2;
+    frame1 += stride1;
+    frame2 += stride2;
   }
-  __m256i sum_vec1 = _mm256_srli_si256(sum_vec, 8);
-  sum_vec = _mm256_add_epi32(sum_vec, sum_vec1);
-  int sum1_acc = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_vec));
-  int sum2_acc = _mm256_extract_epi32(sum_vec, 4);
 
-  __m256i unp_low = _mm256_unpacklo_epi64(sumsq2_vec, cross_vec);
-  __m256i unp_hig = _mm256_unpackhi_epi64(sumsq2_vec, cross_vec);
-  temp1 = _mm256_add_epi32(unp_low, unp_hig);
+  // Sum cross_vec into a single value
+  const __m128i tmp = _mm_add_epi32(_mm256_extracti128_si256(cross_vec, 0),
+                                    _mm256_extracti128_si256(cross_vec, 1));
+  const int cross = _mm_extract_epi32(tmp, 0) + _mm_extract_epi32(tmp, 1) +
+                    _mm_extract_epi32(tmp, 2) + _mm_extract_epi32(tmp, 3);
 
-  __m128i low_sumsq = _mm256_castsi256_si128(temp1);
-  low_sumsq = _mm_add_epi32(low_sumsq, _mm256_extractf128_si256(temp1, 1));
-  low_sumsq = _mm_add_epi32(low_sumsq, _mm_srli_epi64(low_sumsq, 32));
-  int sumsq2_acc = _mm_cvtsi128_si32(low_sumsq);
-  int cross_acc = _mm_extract_epi32(low_sumsq, 2);
-
-  int var2 = sumsq2_acc * MATCH_SZ_SQ - sum2_acc * sum2_acc;
-  int cov = cross_acc * MATCH_SZ_SQ - sum1_acc * sum2_acc;
-  return cov / sqrt((double)var2);
+  // Note: In theory, the calculations here "should" be
+  //   covariance = cross / N^2 - mean1 * mean2
+  //   correlation = covariance / (stddev1 * stddev2).
+  //
+  // However, because of the scaling in aom_compute_mean_stddev, the
+  // lines below actually calculate
+  //   covariance * N^2 = cross - (mean1 * N) * (mean2 * N)
+  //   correlation = (covariance * N^2) / ((stddev1 * N) * (stddev2 * N))
+  //
+  // ie. we have removed the need for a division, and still end up with the
+  // correct unscaled correlation (ie, in the range [-1, +1])
+  const double covariance = cross - mean1 * mean2;
+  const double correlation = covariance * (one_over_stddev1 * one_over_stddev2);
+  return correlation;
 }
diff --git a/aom_dsp/flow_estimation/x86/corner_match_sse4.c b/aom_dsp/flow_estimation/x86/corner_match_sse4.c
index b3cb5bc..bff7db6 100644
--- a/aom_dsp/flow_estimation/x86/corner_match_sse4.c
+++ b/aom_dsp/flow_estimation/x86/corner_match_sse4.c
@@ -21,84 +21,125 @@
 #include "aom_ports/mem.h"
 #include "aom_dsp/flow_estimation/corner_match.h"
 
-DECLARE_ALIGNED(16, static const uint8_t,
-                byte_mask[16]) = { 255, 255, 255, 255, 255, 255, 255, 255,
-                                   255, 255, 255, 255, 255, 0,   0,   0 };
-#if MATCH_SZ != 13
-#error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13"
+DECLARE_ALIGNED(16, static const uint16_t, ones_array[8]) = { 1, 1, 1, 1,
+                                                              1, 1, 1, 1 };
+
+#if MATCH_SZ != 16
+#error "Need to apply pixel mask in corner_match_sse4.c if MATCH_SZ != 16"
 #endif
 
-/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
-   correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
-   of each image, centered at (x1, y1) and (x2, y2) respectively.
-*/
-double av1_compute_cross_correlation_sse4_1(const unsigned char *frame1,
-                                            int stride1, int x1, int y1,
-                                            const unsigned char *frame2,
-                                            int stride2, int x2, int y2) {
-  int i;
-  // 2 16-bit partial sums in lanes 0, 4 (== 2 32-bit partial sums in lanes 0,
-  // 2)
-  __m128i sum1_vec = _mm_setzero_si128();
-  __m128i sum2_vec = _mm_setzero_si128();
-  // 4 32-bit partial sums of squares
-  __m128i sumsq2_vec = _mm_setzero_si128();
-  __m128i cross_vec = _mm_setzero_si128();
+/* Compute mean and standard deviation of pixels in a window of size
+   MATCH_SZ by MATCH_SZ centered at (x, y).
+   Store results into *mean and *one_over_stddev
 
-  const __m128i mask = _mm_load_si128((__m128i *)byte_mask);
-  const __m128i zero = _mm_setzero_si128();
+   Note: The output of this function is scaled by MATCH_SZ, as in
+   *mean = MATCH_SZ * <true mean> and
+   *one_over_stddev = 1 / (MATCH_SZ * <true stddev>)
+
+   Combined with the fact that we return 1/stddev rather than the standard
+   deviation itself, this allows us to completely avoid divisions in
+   aom_compute_correlation, which is much hotter than this function is.
+
+   Returns true if this feature point is usable, false otherwise.
+*/
+bool aom_compute_mean_stddev_sse4_1(const unsigned char *frame, int stride,
+                                    int x, int y, double *mean,
+                                    double *one_over_stddev) {
+  // 8 16-bit partial sums of pixels
+  // Each lane sums at most 2*MATCH_SZ pixels, which can have values up to 255,
+  // and is therefore at most 2*MATCH_SZ*255, which is > 2^8 but < 2^16.
+  // Thus this value is safe to store in 16 bits.
+  __m128i sum_vec = _mm_setzero_si128();
+
+  // 8 32-bit partial sums of squares
+  __m128i sumsq_vec_l = _mm_setzero_si128();
+  __m128i sumsq_vec_r = _mm_setzero_si128();
+
+  frame += (y - MATCH_SZ_BY2) * stride + (x - MATCH_SZ_BY2);
+
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    const __m128i v = _mm_loadu_si128((__m128i *)frame);
+    const __m128i v_l = _mm_cvtepu8_epi16(v);
+    const __m128i v_r = _mm_cvtepu8_epi16(_mm_srli_si128(v, 8));
+
+    sum_vec = _mm_add_epi16(sum_vec, _mm_add_epi16(v_l, v_r));
+    sumsq_vec_l = _mm_add_epi32(sumsq_vec_l, _mm_madd_epi16(v_l, v_l));
+    sumsq_vec_r = _mm_add_epi32(sumsq_vec_r, _mm_madd_epi16(v_r, v_r));
+
+    frame += stride;
+  }
+
+  // Reduce sum_vec and sumsq_vec into single values
+  // Start by reducing each vector to 4x32-bit values, hadd() to perform four
+  // additions, then perform the last two additions in scalar code.
+  const __m128i ones = _mm_load_si128((__m128i *)ones_array);
+  const __m128i partial_sum = _mm_madd_epi16(sum_vec, ones);
+  const __m128i partial_sumsq = _mm_add_epi32(sumsq_vec_l, sumsq_vec_r);
+  const __m128i tmp = _mm_hadd_epi32(partial_sum, partial_sumsq);
+  const int sum = _mm_extract_epi32(tmp, 0) + _mm_extract_epi32(tmp, 1);
+  const int sumsq = _mm_extract_epi32(tmp, 2) + _mm_extract_epi32(tmp, 3);
+
+  *mean = (double)sum / MATCH_SZ;
+  const double variance = sumsq - (*mean) * (*mean);
+  if (variance < MIN_FEATURE_VARIANCE) {
+    *one_over_stddev = 0.0;
+    return false;
+  }
+  *one_over_stddev = 1.0 / sqrt(variance);
+  return true;
+}
+
+/* Compute corr(frame1, frame2) over a window of size MATCH_SZ by MATCH_SZ.
+   To save on computation, the mean and (1 divided by the) standard deviation
+   of the window in each frame are precomputed and passed into this function
+   as arguments.
+*/
+double aom_compute_correlation_sse4_1(const unsigned char *frame1, int stride1,
+                                      int x1, int y1, double mean1,
+                                      double one_over_stddev1,
+                                      const unsigned char *frame2, int stride2,
+                                      int x2, int y2, double mean2,
+                                      double one_over_stddev2) {
+  // 8 32-bit partial sums of products
+  __m128i cross_vec_l = _mm_setzero_si128();
+  __m128i cross_vec_r = _mm_setzero_si128();
 
   frame1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
   frame2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
 
-  for (i = 0; i < MATCH_SZ; ++i) {
-    const __m128i v1 =
-        _mm_and_si128(_mm_loadu_si128((__m128i *)&frame1[i * stride1]), mask);
-    const __m128i v2 =
-        _mm_and_si128(_mm_loadu_si128((__m128i *)&frame2[i * stride2]), mask);
-
-    // Using the 'sad' intrinsic here is a bit faster than adding
-    // v1_l + v1_r and v2_l + v2_r, plus it avoids the need for a 16->32 bit
-    // conversion step later, for a net speedup of ~10%
-    sum1_vec = _mm_add_epi16(sum1_vec, _mm_sad_epu8(v1, zero));
-    sum2_vec = _mm_add_epi16(sum2_vec, _mm_sad_epu8(v2, zero));
+  for (int i = 0; i < MATCH_SZ; ++i) {
+    const __m128i v1 = _mm_loadu_si128((__m128i *)frame1);
+    const __m128i v2 = _mm_loadu_si128((__m128i *)frame2);
 
     const __m128i v1_l = _mm_cvtepu8_epi16(v1);
     const __m128i v1_r = _mm_cvtepu8_epi16(_mm_srli_si128(v1, 8));
     const __m128i v2_l = _mm_cvtepu8_epi16(v2);
     const __m128i v2_r = _mm_cvtepu8_epi16(_mm_srli_si128(v2, 8));
 
-    sumsq2_vec = _mm_add_epi32(
-        sumsq2_vec,
-        _mm_add_epi32(_mm_madd_epi16(v2_l, v2_l), _mm_madd_epi16(v2_r, v2_r)));
-    cross_vec = _mm_add_epi32(
-        cross_vec,
-        _mm_add_epi32(_mm_madd_epi16(v1_l, v2_l), _mm_madd_epi16(v1_r, v2_r)));
+    cross_vec_l = _mm_add_epi32(cross_vec_l, _mm_madd_epi16(v1_l, v2_l));
+    cross_vec_r = _mm_add_epi32(cross_vec_r, _mm_madd_epi16(v1_r, v2_r));
+
+    frame1 += stride1;
+    frame2 += stride2;
   }
 
-  // Now we can treat the four registers (sum1_vec, sum2_vec, sumsq2_vec,
-  // cross_vec)
-  // as holding 4 32-bit elements each, which we want to sum horizontally.
-  // We do this by transposing and then summing vertically.
-  __m128i tmp_0 = _mm_unpacklo_epi32(sum1_vec, sum2_vec);
-  __m128i tmp_1 = _mm_unpackhi_epi32(sum1_vec, sum2_vec);
-  __m128i tmp_2 = _mm_unpacklo_epi32(sumsq2_vec, cross_vec);
-  __m128i tmp_3 = _mm_unpackhi_epi32(sumsq2_vec, cross_vec);
+  // Sum cross_vec into a single value
+  const __m128i tmp = _mm_add_epi32(cross_vec_l, cross_vec_r);
+  const int cross = _mm_extract_epi32(tmp, 0) + _mm_extract_epi32(tmp, 1) +
+                    _mm_extract_epi32(tmp, 2) + _mm_extract_epi32(tmp, 3);
 
-  __m128i tmp_4 = _mm_unpacklo_epi64(tmp_0, tmp_2);
-  __m128i tmp_5 = _mm_unpackhi_epi64(tmp_0, tmp_2);
-  __m128i tmp_6 = _mm_unpacklo_epi64(tmp_1, tmp_3);
-  __m128i tmp_7 = _mm_unpackhi_epi64(tmp_1, tmp_3);
-
-  __m128i res =
-      _mm_add_epi32(_mm_add_epi32(tmp_4, tmp_5), _mm_add_epi32(tmp_6, tmp_7));
-
-  int sum1 = _mm_extract_epi32(res, 0);
-  int sum2 = _mm_extract_epi32(res, 1);
-  int sumsq2 = _mm_extract_epi32(res, 2);
-  int cross = _mm_extract_epi32(res, 3);
-
-  int var2 = sumsq2 * MATCH_SZ_SQ - sum2 * sum2;
-  int cov = cross * MATCH_SZ_SQ - sum1 * sum2;
-  return cov / sqrt((double)var2);
+  // Note: In theory, the calculations here "should" be
+  //   covariance = cross / N^2 - mean1 * mean2
+  //   correlation = covariance / (stddev1 * stddev2).
+  //
+  // However, because of the scaling in aom_compute_mean_stddev, the
+  // lines below actually calculate
+  //   covariance * N^2 = cross - (mean1 * N) * (mean2 * N)
+  //   correlation = (covariance * N^2) / ((stddev1 * N) * (stddev2 * N))
+  //
+  // ie. we have removed the need for a division, and still end up with the
+  // correct unscaled correlation (ie, in the range [-1, +1])
+  const double covariance = cross - mean1 * mean2;
+  const double correlation = covariance * (one_over_stddev1 * one_over_stddev2);
+  return correlation;
 }
diff --git a/test/corner_match_test.cc b/test/corner_match_test.cc
index 9733732..895c8ad 100644
--- a/test/corner_match_test.cc
+++ b/test/corner_match_test.cc
@@ -27,13 +27,19 @@
 
 using libaom_test::ACMRandom;
 
-typedef double (*ComputeCrossCorrFunc)(const unsigned char *im1, int stride1,
-                                       int x1, int y1, const unsigned char *im2,
-                                       int stride2, int x2, int y2);
+typedef bool (*ComputeMeanStddevFunc)(const unsigned char *frame, int stride,
+                                      int x, int y, double *mean,
+                                      double *one_over_stddev);
+typedef double (*ComputeCorrFunc)(const unsigned char *frame1, int stride1,
+                                  int x1, int y1, double mean1,
+                                  double one_over_stddev1,
+                                  const unsigned char *frame2, int stride2,
+                                  int x2, int y2, double mean2,
+                                  double one_over_stddev2);
 
 using std::make_tuple;
 using std::tuple;
-typedef tuple<int, ComputeCrossCorrFunc> CornerMatchParam;
+typedef tuple<int, ComputeMeanStddevFunc, ComputeCorrFunc> CornerMatchParam;
 
 class AV1CornerMatchTest : public ::testing::TestWithParam<CornerMatchParam> {
  public:
@@ -41,8 +47,11 @@
   void SetUp() override;
 
  protected:
-  void RunCheckOutput(int run_times);
-  ComputeCrossCorrFunc target_func;
+  void GenerateInput(uint8_t *input1, uint8_t *input2, int w, int h, int mode);
+  void RunCheckOutput();
+  void RunSpeedTest();
+  ComputeMeanStddevFunc target_compute_mean_stddev_func;
+  ComputeCorrFunc target_compute_corr_func;
 
   libaom_test::ACMRandom rnd_;
 };
@@ -51,13 +60,87 @@
 AV1CornerMatchTest::~AV1CornerMatchTest() = default;
 void AV1CornerMatchTest::SetUp() {
   rnd_.Reset(ACMRandom::DeterministicSeed());
-  target_func = GET_PARAM(1);
+  target_compute_mean_stddev_func = GET_PARAM(1);
+  target_compute_corr_func = GET_PARAM(2);
 }
 
-void AV1CornerMatchTest::RunCheckOutput(int run_times) {
+void AV1CornerMatchTest::GenerateInput(uint8_t *input1, uint8_t *input2, int w,
+                                       int h, int mode) {
+  if (mode == 0) {
+    for (int i = 0; i < h; ++i)
+      for (int j = 0; j < w; ++j) {
+        input1[i * w + j] = rnd_.Rand8();
+        input2[i * w + j] = rnd_.Rand8();
+      }
+  } else if (mode == 1) {
+    for (int i = 0; i < h; ++i)
+      for (int j = 0; j < w; ++j) {
+        int v = rnd_.Rand8();
+        input1[i * w + j] = v;
+        input2[i * w + j] = (v / 2) + (rnd_.Rand8() & 15);
+      }
+  }
+}
+
+void AV1CornerMatchTest::RunCheckOutput() {
   const int w = 128, h = 128;
-  const int num_iters = 10000;
-  int i, j;
+  const int num_iters = 1000;
+
+  std::unique_ptr<uint8_t[]> input1(new (std::nothrow) uint8_t[w * h]);
+  std::unique_ptr<uint8_t[]> input2(new (std::nothrow) uint8_t[w * h]);
+  ASSERT_NE(input1, nullptr);
+  ASSERT_NE(input2, nullptr);
+
+  // Test the two extreme cases:
+  // i) Random data, should have correlation close to 0
+  // ii) Linearly related data + noise, should have correlation close to 1
+  int mode = GET_PARAM(0);
+  GenerateInput(&input1[0], &input2[0], w, h, mode);
+
+  for (int i = 0; i < num_iters; ++i) {
+    int x1 = MATCH_SZ_BY2 + rnd_.PseudoUniform(w + 1 - MATCH_SZ);
+    int y1 = MATCH_SZ_BY2 + rnd_.PseudoUniform(h + 1 - MATCH_SZ);
+    int x2 = MATCH_SZ_BY2 + rnd_.PseudoUniform(w + 1 - MATCH_SZ);
+    int y2 = MATCH_SZ_BY2 + rnd_.PseudoUniform(h + 1 - MATCH_SZ);
+
+    double c_mean1, c_one_over_stddev1, c_mean2, c_one_over_stddev2;
+    bool c_valid1 = aom_compute_mean_stddev_c(input1.get(), w, x1, y1, &c_mean1,
+                                              &c_one_over_stddev1);
+    bool c_valid2 = aom_compute_mean_stddev_c(input2.get(), w, x2, y2, &c_mean2,
+                                              &c_one_over_stddev2);
+
+    double simd_mean1, simd_one_over_stddev1, simd_mean2, simd_one_over_stddev2;
+    bool simd_valid1 = target_compute_mean_stddev_func(
+        input1.get(), w, x1, y1, &simd_mean1, &simd_one_over_stddev1);
+    bool simd_valid2 = target_compute_mean_stddev_func(
+        input2.get(), w, x2, y2, &simd_mean2, &simd_one_over_stddev2);
+
+    // Run the correlation calculation even if one of the "valid" flags is
+    // false, i.e. if one of the patches doesn't have enough variance. This is
+    // safe because any potential division by 0 is caught in
+    // aom_compute_mean_stddev(), and one_over_stddev is set to 0 instead.
+    // This causes aom_compute_correlation() to return 0, without causing a
+    // division by 0.
+    const double c_corr = aom_compute_correlation_c(
+        input1.get(), w, x1, y1, c_mean1, c_one_over_stddev1, input2.get(), w,
+        x2, y2, c_mean2, c_one_over_stddev2);
+    const double simd_corr = target_compute_corr_func(
+        input1.get(), w, x1, y1, c_mean1, c_one_over_stddev1, input2.get(), w,
+        x2, y2, c_mean2, c_one_over_stddev2);
+
+    ASSERT_EQ(simd_valid1, c_valid1);
+    ASSERT_EQ(simd_valid2, c_valid2);
+    ASSERT_EQ(simd_mean1, c_mean1);
+    ASSERT_EQ(simd_one_over_stddev1, c_one_over_stddev1);
+    ASSERT_EQ(simd_mean2, c_mean2);
+    ASSERT_EQ(simd_one_over_stddev2, c_one_over_stddev2);
+    ASSERT_EQ(simd_corr, c_corr);
+  }
+}
+
+void AV1CornerMatchTest::RunSpeedTest() {
+  const int w = 16, h = 16;
+  const int num_iters = 1000000;
   aom_usec_timer ref_timer, test_timer;
 
   std::unique_ptr<uint8_t[]> input1(new (std::nothrow) uint8_t[w * h]);
@@ -69,76 +152,82 @@
   // i) Random data, should have correlation close to 0
   // ii) Linearly related data + noise, should have correlation close to 1
   int mode = GET_PARAM(0);
-  if (mode == 0) {
-    for (i = 0; i < h; ++i)
-      for (j = 0; j < w; ++j) {
-        input1[i * w + j] = rnd_.Rand8();
-        input2[i * w + j] = rnd_.Rand8();
-      }
-  } else if (mode == 1) {
-    for (i = 0; i < h; ++i)
-      for (j = 0; j < w; ++j) {
-        int v = rnd_.Rand8();
-        input1[i * w + j] = v;
-        input2[i * w + j] = (v / 2) + (rnd_.Rand8() & 15);
-      }
+  GenerateInput(&input1[0], &input2[0], w, h, mode);
+
+  // Time aom_compute_mean_stddev()
+  double c_mean1, c_one_over_stddev1, c_mean2, c_one_over_stddev2;
+  aom_usec_timer_start(&ref_timer);
+  for (int i = 0; i < num_iters; i++) {
+    aom_compute_mean_stddev_c(input1.get(), w, 0, 0, &c_mean1,
+                              &c_one_over_stddev1);
+    aom_compute_mean_stddev_c(input2.get(), w, 0, 0, &c_mean2,
+                              &c_one_over_stddev2);
   }
+  aom_usec_timer_mark(&ref_timer);
+  int elapsed_time_c = static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
 
-  for (i = 0; i < num_iters; ++i) {
-    int x1 = MATCH_SZ_BY2 + rnd_.PseudoUniform(w - 2 * MATCH_SZ_BY2);
-    int y1 = MATCH_SZ_BY2 + rnd_.PseudoUniform(h - 2 * MATCH_SZ_BY2);
-    int x2 = MATCH_SZ_BY2 + rnd_.PseudoUniform(w - 2 * MATCH_SZ_BY2);
-    int y2 = MATCH_SZ_BY2 + rnd_.PseudoUniform(h - 2 * MATCH_SZ_BY2);
-
-    double res_c = av1_compute_cross_correlation_c(input1.get(), w, x1, y1,
-                                                   input2.get(), w, x2, y2);
-    double res_simd =
-        target_func(input1.get(), w, x1, y1, input2.get(), w, x2, y2);
-
-    if (run_times > 1) {
-      aom_usec_timer_start(&ref_timer);
-      for (j = 0; j < run_times; j++) {
-        av1_compute_cross_correlation_c(input1.get(), w, x1, y1, input2.get(),
-                                        w, x2, y2);
-      }
-      aom_usec_timer_mark(&ref_timer);
-      const int elapsed_time_c =
-          static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
-
-      aom_usec_timer_start(&test_timer);
-      for (j = 0; j < run_times; j++) {
-        target_func(input1.get(), w, x1, y1, input2.get(), w, x2, y2);
-      }
-      aom_usec_timer_mark(&test_timer);
-      const int elapsed_time_simd =
-          static_cast<int>(aom_usec_timer_elapsed(&test_timer));
-
-      printf(
-          "c_time=%d \t simd_time=%d \t "
-          "gain=%d\n",
-          elapsed_time_c, elapsed_time_simd,
-          (elapsed_time_c / elapsed_time_simd));
-    } else {
-      ASSERT_EQ(res_simd, res_c);
-    }
+  double simd_mean1, simd_one_over_stddev1, simd_mean2, simd_one_over_stddev2;
+  aom_usec_timer_start(&test_timer);
+  for (int i = 0; i < num_iters; i++) {
+    target_compute_mean_stddev_func(input1.get(), w, 0, 0, &simd_mean1,
+                                    &simd_one_over_stddev1);
+    target_compute_mean_stddev_func(input2.get(), w, 0, 0, &simd_mean2,
+                                    &simd_one_over_stddev2);
   }
+  aom_usec_timer_mark(&test_timer);
+  int elapsed_time_simd = static_cast<int>(aom_usec_timer_elapsed(&test_timer));
+
+  printf(
+      "aom_compute_mean_stddev(): c_time=%6d   simd_time=%6d   "
+      "gain=%.3f\n",
+      elapsed_time_c, elapsed_time_simd,
+      (elapsed_time_c / (double)elapsed_time_simd));
+
+  // Time aom_compute_correlation
+  aom_usec_timer_start(&ref_timer);
+  for (int i = 0; i < num_iters; i++) {
+    aom_compute_correlation_c(input1.get(), w, 0, 0, c_mean1,
+                              c_one_over_stddev1, input2.get(), w, 0, 0,
+                              c_mean2, c_one_over_stddev2);
+  }
+  aom_usec_timer_mark(&ref_timer);
+  elapsed_time_c = static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
+
+  aom_usec_timer_start(&test_timer);
+  for (int i = 0; i < num_iters; i++) {
+    target_compute_corr_func(input1.get(), w, 0, 0, c_mean1, c_one_over_stddev1,
+                             input2.get(), w, 0, 0, c_mean2,
+                             c_one_over_stddev2);
+  }
+  aom_usec_timer_mark(&test_timer);
+  elapsed_time_simd = static_cast<int>(aom_usec_timer_elapsed(&test_timer));
+
+  printf(
+      "aom_compute_correlation(): c_time=%6d   simd_time=%6d   "
+      "gain=%.3f\n",
+      elapsed_time_c, elapsed_time_simd,
+      (elapsed_time_c / (double)elapsed_time_simd));
 }
 
-TEST_P(AV1CornerMatchTest, CheckOutput) { RunCheckOutput(1); }
-TEST_P(AV1CornerMatchTest, DISABLED_Speed) { RunCheckOutput(100000); }
+TEST_P(AV1CornerMatchTest, CheckOutput) { RunCheckOutput(); }
+TEST_P(AV1CornerMatchTest, DISABLED_Speed) { RunSpeedTest(); }
 
 #if HAVE_SSE4_1
 INSTANTIATE_TEST_SUITE_P(
     SSE4_1, AV1CornerMatchTest,
-    ::testing::Values(make_tuple(0, &av1_compute_cross_correlation_sse4_1),
-                      make_tuple(1, &av1_compute_cross_correlation_sse4_1)));
+    ::testing::Values(make_tuple(0, &aom_compute_mean_stddev_sse4_1,
+                                 &aom_compute_correlation_sse4_1),
+                      make_tuple(1, &aom_compute_mean_stddev_sse4_1,
+                                 &aom_compute_correlation_sse4_1)));
 #endif
 
 #if HAVE_AVX2
 INSTANTIATE_TEST_SUITE_P(
     AVX2, AV1CornerMatchTest,
-    ::testing::Values(make_tuple(0, &av1_compute_cross_correlation_avx2),
-                      make_tuple(1, &av1_compute_cross_correlation_avx2)));
+    ::testing::Values(make_tuple(0, &aom_compute_mean_stddev_avx2,
+                                 &aom_compute_correlation_avx2),
+                      make_tuple(1, &aom_compute_mean_stddev_avx2,
+                                 &aom_compute_correlation_avx2)));
 #endif
 }  // namespace AV1CornerMatch