Adjustment to the optical flow API.

Mostly refactoring: helper functions become static or inline, and the
public functions are renamed with the av1_ prefix.

Also apply a smoothing filter to the input motion field before the
optical flow calculation.

Change-Id: I180900ca86f880680fa67668462b94d6703e0f77
diff --git a/av1/encoder/optical_flow.c b/av1/encoder/optical_flow.c
index eed1def..00f91af 100644
--- a/av1/encoder/optical_flow.c
+++ b/av1/encoder/optical_flow.c
@@ -21,6 +21,15 @@
 
 #if CONFIG_OPTICAL_FLOW_API
 
+void av1_init_opfl_params(OPFL_PARAMS *opfl_params) {
+  opfl_params->pyramid_levels = OPFL_PYRAMID_LEVELS;
+  opfl_params->lk_params = NULL;
+}
+
+void av1_init_lk_params(LK_PARAMS *lk_params) {
+  lk_params->window_size = OPFL_WINDOW_SIZE;
+}
+
 // Helper function to determine whether a frame is encoded with high bit-depth.
 static INLINE int is_frame_high_bitdepth(const YV12_BUFFER_CONFIG *frame) {
   return (frame->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
@@ -31,20 +40,16 @@
   return (opfl_params->flags & OPFL_FLAG_SPARSE) ? 1 : 0;
 }
 
-typedef struct LOCALMV {
-  double row;
-  double col;
-} LOCALMV;
-
-void gradients_over_window(const YV12_BUFFER_CONFIG *frame,
-                           const YV12_BUFFER_CONFIG *ref_frame,
-                           const double x_coord, const double y_coord,
-                           const int window_size, const int bit_depth,
-                           double *ix, double *iy, double *it, LOCALMV *mv);
+static void gradients_over_window(const YV12_BUFFER_CONFIG *frame,
+                                  const YV12_BUFFER_CONFIG *ref_frame,
+                                  const double x_coord, const double y_coord,
+                                  const int window_size, const int bit_depth,
+                                  double *ix, double *iy, double *it,
+                                  LOCALMV *mv);
 
 // coefficients for bilinear interpolation on unit square
-int pixel_interp(const double x, const double y, const double b00,
-                 const double b01, const double b10, const double b11) {
+static int pixel_interp(const double x, const double y, const double b00,
+                        const double b01, const double b10, const double b11) {
   const int xint = (int)x;
   const int yint = (int)y;
   const double xdec = x - xint;
@@ -57,10 +62,12 @@
   int interp = (int)round(a * b00 + b * b01 + c * b10 + d * b11);
   return interp;
 }
+
 // bilinear interpolation to find subpixel values
-int get_subpixels(const YV12_BUFFER_CONFIG *frame, int *pred, const int w,
-                  const int h, LOCALMV mv, const double x_coord,
-                  const double y_coord) {
+static AOM_INLINE int get_subpixels(const YV12_BUFFER_CONFIG *frame, int *pred,
+                                    const int w, const int h, LOCALMV mv,
+                                    const double x_coord,
+                                    const double y_coord) {
   double left = x_coord + mv.row;
   double top = y_coord + mv.col;
   const int fromedge = 2;
@@ -92,10 +99,11 @@
   }
   return 0;
 }
+
 // Scharr filter to compute spatial gradient
-void spatial_gradient(const YV12_BUFFER_CONFIG *frame, const int x_coord,
-                      const int y_coord, const int direction,
-                      double *derivative) {
+static void spatial_gradient(const YV12_BUFFER_CONFIG *frame, const int x_coord,
+                             const int y_coord, const int direction,
+                             double *derivative) {
   double *filter;
   // Scharr filters
   double gx[9] = { -3, 0, 3, -10, 0, 10, -3, 0, 3 };
@@ -117,12 +125,13 @@
   // normalization scaling factor for scharr
   *derivative = d / 32.0;
 }
+
 // Determine the spatial gradient at subpixel locations
 // For example, when reducing images for pyramidal LK,
 // corners found in original image may be at subpixel locations.
-void gradient_interp(double *fullpel_deriv, const double x_coord,
-                     const double y_coord, const int w, const int h,
-                     double *derivative) {
+static void gradient_interp(double *fullpel_deriv, const double x_coord,
+                            const double y_coord, const int w, const int h,
+                            double *derivative) {
   const int xint = (int)x_coord;
   const int yint = (int)y_coord;
   double interp;
@@ -138,17 +147,15 @@
   *derivative = interp;
 }
 
-void temporal_gradient(const YV12_BUFFER_CONFIG *frame,
-                       const YV12_BUFFER_CONFIG *frame2, const double x_coord,
-                       const double y_coord, const int bit_depth,
-                       double *derivative, LOCALMV *mv) {
-  // TODO(any): this is a roundabout way of enforcing build_one_inter_pred
-  // to use the 8-tap filter (instead of lower). it would be more
-  // efficient to apply the filter only at 1 pixel instead of 25 pixels.
-  const int w = 5;
-  const int h = 5;
-  uint8_t pred1[25];
-  uint8_t pred2[25];
+static void temporal_gradient(const YV12_BUFFER_CONFIG *frame,
+                              const YV12_BUFFER_CONFIG *frame2,
+                              const double x_coord, const double y_coord,
+                              const int bit_depth, double *derivative,
+                              LOCALMV *mv) {
+  const int w = 2;
+  const int h = 2;
+  uint8_t pred1[4];
+  uint8_t pred2[4];
 
   const int y = (int)y_coord;
   const int x = (int)x_coord;
@@ -170,6 +177,10 @@
   av1_init_inter_params(&inter_pred_params, w, h, y, x, subsampling_x,
                         subsampling_y, bit_depth, is_high_bitdepth, is_intrabc,
                         &scale, &ref_buf2, interp_filters);
+  inter_pred_params.interp_filter_params[0] =
+      &av1_interp_filter_params_list[interp_filters.as_filters.x_filter];
+  inter_pred_params.interp_filter_params[1] =
+      &av1_interp_filter_params_list[interp_filters.as_filters.y_filter];
   inter_pred_params.conv_params = get_conv_params(0, plane, bit_depth);
   MV newmv = { .row = (int16_t)round((mv->row + xdec) * 8),
                .col = (int16_t)round((mv->col + ydec) * 8) };
@@ -179,6 +190,10 @@
   av1_init_inter_params(&inter_pred_params, w, h, y, x, subsampling_x,
                         subsampling_y, bit_depth, is_high_bitdepth, is_intrabc,
                         &scale, &ref_buf1, interp_filters);
+  inter_pred_params.interp_filter_params[0] =
+      &av1_interp_filter_params_list[interp_filters.as_filters.x_filter];
+  inter_pred_params.interp_filter_params[1] =
+      &av1_interp_filter_params_list[interp_filters.as_filters.y_filter];
   inter_pred_params.conv_params = get_conv_params(0, plane, bit_depth);
   MV zeroMV = { .row = (int16_t)round(xdec * 8),
                 .col = (int16_t)round(ydec * 8) };
@@ -186,15 +201,17 @@
 
   *derivative = pred2[0] - pred1[0];
 }
+
 // Numerical differentiate over window_size x window_size surrounding (x,y)
 // location. Alters ix, iy, it to contain numerical partial derivatives
-void gradients_over_window(const YV12_BUFFER_CONFIG *frame,
-                           const YV12_BUFFER_CONFIG *ref_frame,
-                           const double x_coord, const double y_coord,
-                           const int window_size, const int bit_depth,
-                           double *ix, double *iy, double *it, LOCALMV *mv) {
-  const double left = x_coord - window_size / 2;
-  const double top = y_coord - window_size / 2;
+static void gradients_over_window(const YV12_BUFFER_CONFIG *frame,
+                                  const YV12_BUFFER_CONFIG *ref_frame,
+                                  const double x_coord, const double y_coord,
+                                  const int window_size, const int bit_depth,
+                                  double *ix, double *iy, double *it,
+                                  LOCALMV *mv) {
+  const double left = x_coord - window_size / 2.0;
+  const double top = y_coord - window_size / 2.0;
   // gradient operators need pixel before and after (start at 1)
   const double x_start = AOMMAX(1, left);
   const double y_start = AOMMAX(1, top);
@@ -204,8 +221,8 @@
   double deriv_y;
   double deriv_t;
 
-  const double x_end = AOMMIN(x_coord + window_size / 2, frame_width - 2);
-  const double y_end = AOMMIN(y_coord + window_size / 2, frame_height - 2);
+  const double x_end = AOMMIN(x_coord + window_size / 2.0, frame_width - 2);
+  const double y_end = AOMMIN(y_coord + window_size / 2.0, frame_height - 2);
   const int xs = (int)AOMMAX(1, x_start - 1);
   const int ys = (int)AOMMAX(1, y_start - 1);
   const int xe = (int)AOMMIN(x_end + 2, frame_width - 2);
@@ -251,7 +268,7 @@
 
 // To compute eigenvalues of 2x2 matrix: Solve for lambda where
 // Determinant(matrix - lambda*identity) == 0
-void eigenvalues_2x2(const double *matrix, double *eig) {
+static void eigenvalues_2x2(const double *matrix, double *eig) {
   const double a = 1;
   const double b = -1 * matrix[0] - matrix[3];
   const double c = -1 * matrix[1] * matrix[2] + matrix[0] * matrix[3];
@@ -266,11 +283,12 @@
     eig[1] = tmp;
   }
 }
+
 // Shi-Tomasi corner detection criteria
-double corner_score(const YV12_BUFFER_CONFIG *frame_to_filter,
-                    const YV12_BUFFER_CONFIG *ref_frame, const int x,
-                    const int y, double *i_x, double *i_y, double *i_t,
-                    const int n, const int bit_depth) {
+static double corner_score(const YV12_BUFFER_CONFIG *frame_to_filter,
+                           const YV12_BUFFER_CONFIG *ref_frame, const int x,
+                           const int y, double *i_x, double *i_y, double *i_t,
+                           const int n, const int bit_depth) {
   double eig[2];
   LOCALMV mv = { .row = 0, .col = 0 };
   // TODO(any): technically, ref_frame and i_t are not used by corner score
@@ -286,11 +304,13 @@
   eigenvalues_2x2(M, eig);
   return fabs(eig[0]);
 }
+
 // Finds corners in frame_to_filter
 // For less strict requirements (i.e. more corners), decrease threshold
-int detect_corners(const YV12_BUFFER_CONFIG *frame_to_filter,
-                   const YV12_BUFFER_CONFIG *ref_frame, const int maxcorners,
-                   int *ref_corners, const int bit_depth) {
+static int detect_corners(const YV12_BUFFER_CONFIG *frame_to_filter,
+                          const YV12_BUFFER_CONFIG *ref_frame,
+                          const int maxcorners, int *ref_corners,
+                          const int bit_depth) {
   const int frame_height = frame_to_filter->y_crop_height;
   const int frame_width = frame_to_filter->y_crop_width;
   // TODO(any): currently if maxcorners is decreased, then it only means
@@ -343,10 +363,11 @@
   }
   return countcorners;
 }
+
 // weights is an nxn matrix. weights is filled with a gaussian function,
 // with independent variable: distance from the center point.
-void gaussian(const double sigma, const int n, const int normalize,
-              double *weights) {
+static void gaussian(const double sigma, const int n, const int normalize,
+                     double *weights) {
   double total_weight = 0;
   for (int j = 0; j < n; j++) {
     for (int i = 0; i < n; i++) {
@@ -362,17 +383,19 @@
     }
   }
 }
-double convolve(const double *filter, const int *img, const int size) {
+
+static double convolve(const double *filter, const int *img, const int size) {
   double result = 0;
   for (int i = 0; i < size; i++) {
     result += filter[i] * img[i];
   }
   return result;
 }
+
 // Applies a Gaussian low-pass smoothing filter to produce
 // a corresponding lower resolution image with halved dimensions
-void reduce(uint8_t *img, int height, int width, int stride,
-            uint8_t *reduced_img) {
+static void reduce(uint8_t *img, int height, int width, int stride,
+                   uint8_t *reduced_img) {
   const int new_width = width / 2;
   const int window_size = 5;
   const double gaussian_filter[25] = {
@@ -399,13 +422,16 @@
         }
       }
       reduced_img[(y / 2) * new_width + (x / 2)] = (uint8_t)convolve(
-          gaussian_filter, img_section, (int)pow(window_size, 2));
+          gaussian_filter, img_section, window_size * window_size);
     }
   }
 }
-int cmpfunc(const void *a, const void *b) { return (*(int *)a - *(int *)b); }
-void filter_mvs(const MV_FILTER_TYPE mv_filter, const int frame_height,
-                const int frame_width, LOCALMV *localmvs, MV *mvs) {
+
+static int cmpfunc(const void *a, const void *b) {
+  return (*(int *)a - *(int *)b);
+}
+static void filter_mvs(const MV_FILTER_TYPE mv_filter, const int frame_height,
+                       const int frame_width, LOCALMV *localmvs, MV *mvs) {
   const int n = 5;  // window size
   // for smoothing filter
   const double gaussian_filter[25] = {
@@ -421,56 +447,51 @@
     for (int y = 0; y < frame_height; y++) {
       for (int x = 0; x < frame_width; x++) {
         int center_idx = y * frame_width + x;
-        if (fabs(localmvs[center_idx].row) > 0 ||
-            fabs(localmvs[center_idx].col) > 0) {
-          int i = 0;
-          double filtered_row = 0;
-          double filtered_col = 0;
-          for (int yy = y - n / 2; yy <= y + n / 2; yy++) {
-            for (int xx = x - n / 2; xx <= x + n / 2; xx++) {
-              int yvalue = yy + y;
-              int xvalue = xx + x;
-              // copied pixels outside the boundary
-              if (yvalue < 0) yvalue = 0;
-              if (xvalue < 0) xvalue = 0;
-              if (yvalue >= frame_height) yvalue = frame_height - 1;
-              if (xvalue >= frame_width) xvalue = frame_width - 1;
-              int index = yvalue * frame_width + xvalue;
-              if (mv_filter == MV_FILTER_SMOOTH) {
-                filtered_row += mvs[index].row * gaussian_filter[i];
-                filtered_col += mvs[index].col * gaussian_filter[i];
-              } else if (mv_filter == MV_FILTER_MEDIAN) {
-                mvrows[i] = mvs[index].row;
-                mvcols[i] = mvs[index].col;
-              }
-              i++;
+        int i = 0;
+        double filtered_row = 0;
+        double filtered_col = 0;
+        for (int yy = y - n / 2; yy <= y + n / 2; yy++) {
+          for (int xx = x - n / 2; xx <= x + n / 2; xx++) {
+            int yvalue = yy;
+            int xvalue = xx;
+            // copied pixels outside the boundary
+            if (yvalue < 0) yvalue = 0;
+            if (xvalue < 0) xvalue = 0;
+            if (yvalue >= frame_height) yvalue = frame_height - 1;
+            if (xvalue >= frame_width) xvalue = frame_width - 1;
+            int index = yvalue * frame_width + xvalue;
+            if (mv_filter == MV_FILTER_SMOOTH) {
+              filtered_row += mvs[index].row * gaussian_filter[i];
+              filtered_col += mvs[index].col * gaussian_filter[i];
+            } else if (mv_filter == MV_FILTER_MEDIAN) {
+              mvrows[i] = mvs[index].row;
+              mvcols[i] = mvs[index].col;
             }
+            i++;
           }
-
-          MV mv = mvs[center_idx];
-          if (mv_filter == MV_FILTER_SMOOTH) {
-            mv.row = (int16_t)filtered_row;
-            mv.col = (int16_t)filtered_col;
-          } else if (mv_filter == MV_FILTER_MEDIAN) {
-            qsort(mvrows, 25, sizeof(mv.row), cmpfunc);
-            qsort(mvcols, 25, sizeof(mv.col), cmpfunc);
-            mv.row = mvrows[25 / 2];
-            mv.col = mvcols[25 / 2];
-          }
-          LOCALMV localmv = { .row = ((double)mv.row) / 8,
-                              .col = ((double)mv.row) / 8 };
-          localmvs[y * frame_width + x] = localmv;
-          // if mvs array is immediately updated here, then the result may
-          // propagate to other pixels.
         }
+
+        MV mv = mvs[center_idx];
+        if (mv_filter == MV_FILTER_SMOOTH) {
+          mv.row = (int16_t)filtered_row;
+          mv.col = (int16_t)filtered_col;
+        } else if (mv_filter == MV_FILTER_MEDIAN) {
+          qsort(mvrows, 25, sizeof(mv.row), cmpfunc);
+          qsort(mvcols, 25, sizeof(mv.col), cmpfunc);
+          mv.row = mvrows[25 / 2];
+          mv.col = mvcols[25 / 2];
+        }
+        LOCALMV localmv = { .row = ((double)mv.row) / 8,
+                            .col = ((double)mv.row) / 8 };
+        localmvs[y * frame_width + x] = localmv;
+        // if mvs array is immediately updated here, then the result may
+        // propagate to other pixels.
       }
     }
     for (int i = 0; i < frame_height * frame_width; i++) {
-      if (fabs(localmvs[i].row) > 0 || fabs(localmvs[i].col) > 0) {
-        MV mv = { .row = (int16_t)round(8 * localmvs[i].row),
-                  .col = (int16_t)round(8 * localmvs[i].col) };
-        mvs[i] = mv;
-      }
+      MV mv = { .row = (int16_t)round(8 * localmvs[i].row),
+                .col = (int16_t)round(8 * localmvs[i].col) };
+      mvs[i] = mv;
     }
   }
 }
@@ -478,20 +499,20 @@
 // Computes optical flow at a single pyramid level,
 // using Lucas-Kanade algorithm.
 // Modifies mvs array.
-void lucas_kanade(const YV12_BUFFER_CONFIG *frame_to_filter,
-                  const YV12_BUFFER_CONFIG *ref_frame, const int level,
-                  const LK_PARAMS *lk_params, const int num_ref_corners,
-                  int *ref_corners, const int highres_frame_width,
-                  const int bit_depth, LOCALMV *mvs) {
+static void lucas_kanade(const YV12_BUFFER_CONFIG *from_frame,
+                         const YV12_BUFFER_CONFIG *to_frame, const int level,
+                         const LK_PARAMS *lk_params, const int num_ref_corners,
+                         int *ref_corners, const int mv_stride,
+                         const int bit_depth, LOCALMV *mvs) {
   assert(lk_params->window_size > 0 && lk_params->window_size % 2 == 0);
   const int n = lk_params->window_size;
   // algorithm is sensitive to window size
-  double *i_x = (double *)aom_malloc(n * n * sizeof(double));
-  double *i_y = (double *)aom_malloc(n * n * sizeof(double));
-  double *i_t = (double *)aom_malloc(n * n * sizeof(double));
+  double *i_x = (double *)aom_malloc(n * n * sizeof(*i_x));
+  double *i_y = (double *)aom_malloc(n * n * sizeof(*i_y));
+  double *i_t = (double *)aom_malloc(n * n * sizeof(*i_t));
   const int expand_multiplier = (int)pow(2, level);
   double sigma = 0.2 * n;
-  double *weights = (double *)aom_malloc(n * n * sizeof(double));
+  double *weights = (double *)aom_malloc(n * n * sizeof(*weights));
   // normalizing doesn't really affect anything since it's applied
   // to every component of M and b
   gaussian(sigma, n, 0, weights);
@@ -500,7 +521,7 @@
     const double y_coord = 1.0 * ref_corners[i * 2 + 1] / expand_multiplier;
     int highres_x = ref_corners[i * 2];
     int highres_y = ref_corners[i * 2 + 1];
-    int mv_idx = highres_y * (highres_frame_width) + highres_x;
+    int mv_idx = highres_y * (mv_stride) + highres_x;
     LOCALMV mv_old = mvs[mv_idx];
     mv_old.row = mv_old.row / expand_multiplier;
     mv_old.col = mv_old.col / expand_multiplier;
@@ -511,8 +532,8 @@
       i_y[j] = 0;
       i_t[j] = 0;
     }
-    gradients_over_window(frame_to_filter, ref_frame, x_coord, y_coord, n,
-                          bit_depth, i_x, i_y, i_t, &mv_old);
+    gradients_over_window(from_frame, to_frame, x_coord, y_coord, n, bit_depth,
+                          i_x, i_y, i_t, &mv_old);
     double Mres1[1] = { 0 }, Mres2[1] = { 0 }, Mres3[1] = { 0 };
     double bres1[1] = { 0 }, bres2[1] = { 0 };
     for (int j = 0; j < n * n; j++) {
@@ -548,10 +569,11 @@
 }
 
 // Apply optical flow iteratively at each pyramid level
-void pyramid_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
-                          const YV12_BUFFER_CONFIG *to_frame,
-                          const int bit_depth, const OPFL_PARAMS *opfl_params,
-                          const OPTFLOW_METHOD method, LOCALMV *mvs) {
+static void pyramid_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
+                                 const YV12_BUFFER_CONFIG *to_frame,
+                                 const int bit_depth,
+                                 const OPFL_PARAMS *opfl_params,
+                                 const OPTFLOW_METHOD method, LOCALMV *mvs) {
   assert(opfl_params->pyramid_levels > 0 &&
          opfl_params->pyramid_levels <= MAX_PYRAMID_LEVELS);
   int levels = opfl_params->pyramid_levels;
@@ -565,17 +587,16 @@
   uint8_t *images2[MAX_PYRAMID_LEVELS];
   images1[0] = from_frame->y_buffer;
   images2[0] = to_frame->y_buffer;
-  YV12_BUFFER_CONFIG *buffers1 =
-      aom_malloc(levels * sizeof(YV12_BUFFER_CONFIG));
-  YV12_BUFFER_CONFIG *buffers2 =
-      aom_malloc(levels * sizeof(YV12_BUFFER_CONFIG));
+  YV12_BUFFER_CONFIG *buffers1 = aom_malloc(levels * sizeof(*buffers1));
+  YV12_BUFFER_CONFIG *buffers2 = aom_malloc(levels * sizeof(*buffers2));
   buffers1[0] = *from_frame;
   buffers2[0] = *to_frame;
   int fw = frame_width;
   int fh = frame_height;
   for (int i = 1; i < levels; i++) {
-    images1[i] = (uint8_t *)aom_calloc(fh / 2 * fw / 2, sizeof(uint8_t));
-    images2[i] = (uint8_t *)aom_calloc(fh / 2 * fw / 2, sizeof(uint8_t));
+    // TODO(bohanli): may need to extend buffers for better interpolation SIMD
+    images1[i] = (uint8_t *)aom_calloc(fh / 2 * fw / 2, sizeof(*images1[i]));
+    images2[i] = (uint8_t *)aom_calloc(fh / 2 * fw / 2, sizeof(*images2[i]));
     int stride;
     if (i == 1)
       stride = from_frame->y_stride;
@@ -597,10 +618,11 @@
     buffers2[i] = b;
   }
   // Compute corners for specific frame
-  int maxcorners = from_frame->y_crop_width * from_frame->y_crop_height;
-  int *ref_corners = aom_malloc(maxcorners * 2 * sizeof(int));
+  int *ref_corners = NULL;
   int num_ref_corners = 0;
   if (is_sparse(opfl_params)) {
+    int maxcorners = from_frame->y_crop_width * from_frame->y_crop_height;
+    ref_corners = aom_malloc(maxcorners * 2 * sizeof(*ref_corners));
     num_ref_corners = detect_corners(from_frame, to_frame, maxcorners,
                                      ref_corners, bit_depth);
   }
@@ -618,6 +640,8 @@
     aom_free(images2[i]);
   }
   aom_free(ref_corners);
+  aom_free(buffers1);
+  aom_free(buffers2);
 }
 // Computes optical flow by applying algorithm at
 // multiple pyramid levels of images (lower-resolution, smoothed images)
@@ -634,12 +658,12 @@
 //   mvs: pointer to MVs. Contains initialization, and modified
 //   based on optical flow. Must have
 //   dimensions = from_frame->y_crop_width * from_frame->y_crop_height
-void optical_flow(const YV12_BUFFER_CONFIG *from_frame,
-                  const YV12_BUFFER_CONFIG *to_frame, const int from_frame_idx,
-                  const int to_frame_idx, const int bit_depth,
-                  const OPFL_PARAMS *opfl_params,
-                  const MV_FILTER_TYPE mv_filter, const OPTFLOW_METHOD method,
-                  MV *mvs) {
+void av1_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
+                      const YV12_BUFFER_CONFIG *to_frame,
+                      const int from_frame_idx, const int to_frame_idx,
+                      const int bit_depth, const OPFL_PARAMS *opfl_params,
+                      const MV_FILTER_TYPE mv_filter,
+                      const OPTFLOW_METHOD method, MV *mvs) {
   const int frame_height = from_frame->y_crop_height;
   const int frame_width = from_frame->y_crop_width;
   // TODO(any): deal with the case where frames are not of the same dimensions
@@ -657,7 +681,11 @@
   }
 
   // Initialize double mvs based on input parameter mvs array
-  LOCALMV *localmvs = aom_malloc(frame_height * frame_width * sizeof(LOCALMV));
+  LOCALMV *localmvs =
+      aom_malloc(frame_height * frame_width * sizeof(*localmvs));
+
+  filter_mvs(MV_FILTER_SMOOTH, frame_height, frame_width, localmvs, mvs);
+
   for (int i = 0; i < frame_width * frame_height; i++) {
     MV mv = mvs[i];
     LOCALMV localmv = { .row = ((double)mv.row) / 8,
@@ -672,18 +700,13 @@
   for (int j = 0; j < frame_height; j++) {
     for (int i = 0; i < frame_width; i++) {
       int idx = j * frame_width + i;
-      int new_x = (int)(localmvs[idx].row + i);
-      int new_y = (int)(localmvs[idx].col + j);
-      if ((fabs(localmvs[idx].row) >= 0.125 ||
-           fabs(localmvs[idx].col) >= 0.125)) {
-        // if mv points outside of frame (lost feature), keep old mv.
-        if (new_x < frame_width && new_x >= 0 && new_y < frame_height &&
-            new_y >= 0) {
-          MV mv = { .row = (int16_t)round(8 * localmvs[idx].row),
-                    .col = (int16_t)round(8 * localmvs[idx].col) };
-          mvs[idx] = mv;
-        }
+      if (j + localmvs[idx].row < 0 || j + localmvs[idx].row >= frame_height ||
+          i + localmvs[idx].col < 0 || i + localmvs[idx].col >= frame_width) {
+        continue;
       }
+      MV mv = { .row = (int16_t)round(8 * localmvs[idx].row),
+                .col = (int16_t)round(8 * localmvs[idx].col) };
+      mvs[idx] = mv;
     }
   }
 
diff --git a/av1/encoder/optical_flow.h b/av1/encoder/optical_flow.h
index 9b7cd62..a54b6a7 100644
--- a/av1/encoder/optical_flow.h
+++ b/av1/encoder/optical_flow.h
@@ -12,6 +12,8 @@
 #ifndef AOM_AV1_ENCODER_OPTICAL_FLOW_H_
 #define AOM_AV1_ENCODER_OPTICAL_FLOW_H_
 
+#include "aom_scale/yv12config.h"
+#include "av1/common/mv.h"
 #include "config/aom_config.h"
 
 #ifdef __cplusplus
@@ -28,6 +30,11 @@
   MV_FILTER_MEDIAN
 } MV_FILTER_TYPE;
 
+typedef struct LOCALMV {
+  double row;
+  double col;
+} LOCALMV;
+
 #define MAX_PYRAMID_LEVELS 5
 // default options for optical flow
 #define OPFL_WINDOW_SIZE 15
@@ -48,21 +55,16 @@
 
 #define OPFL_FLAG_SPARSE 1
 
-void init_opfl_params(OPFL_PARAMS *opfl_params) {
-  opfl_params->pyramid_levels = OPFL_PYRAMID_LEVELS;
-  opfl_params->lk_params = NULL;
-}
+void av1_init_opfl_params(OPFL_PARAMS *opfl_params);
 
-void init_lk_params(LK_PARAMS *lk_params) {
-  lk_params->window_size = OPFL_WINDOW_SIZE;
-}
+void av1_init_lk_params(LK_PARAMS *lk_params);
 
-void optical_flow(const YV12_BUFFER_CONFIG *from_frame,
-                  const YV12_BUFFER_CONFIG *to_frame, const int from_frame_idx,
-                  const int to_frame_idx, const int bit_depth,
-                  const OPFL_PARAMS *opfl_params,
-                  const MV_FILTER_TYPE mv_filter, const OPTFLOW_METHOD method,
-                  MV *mvs);
+void av1_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
+                      const YV12_BUFFER_CONFIG *to_frame,
+                      const int from_frame_idx, const int to_frame_idx,
+                      const int bit_depth, const OPFL_PARAMS *opfl_params,
+                      const MV_FILTER_TYPE mv_filter,
+                      const OPTFLOW_METHOD method, MV *mvs);
 #endif
 
 #ifdef __cplusplus