Add horn-schunck method to the optical flow API.

Change-Id: Iecfd061169bcfeca1bbeb46778a53c1330b7edd4
diff --git a/av1/encoder/optical_flow.c b/av1/encoder/optical_flow.c
index 00f91af..82ae9c5 100644
--- a/av1/encoder/optical_flow.c
+++ b/av1/encoder/optical_flow.c
@@ -16,6 +16,7 @@
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/mathutils.h"
 #include "av1/encoder/optical_flow.h"
+#include "av1/encoder/sparse_linear_solver.h"
 #include "av1/encoder/reconinter_enc.h"
 #include "aom_mem/aom_mem.h"
 
@@ -23,6 +24,7 @@
 
 void av1_init_opfl_params(OPFL_PARAMS *opfl_params) {
   opfl_params->pyramid_levels = OPFL_PYRAMID_LEVELS;
+  opfl_params->warping_steps = OPFL_WARPING_STEPS;
   opfl_params->lk_params = NULL;
 }
 
@@ -568,6 +570,398 @@
   aom_free(i_y);
 }
 
+// Warp the src_frame to warper_frame according to mvs.
+// mvs point to src_frame
+static void warp_back_frame(YV12_BUFFER_CONFIG *warped_frame,
+                            const YV12_BUFFER_CONFIG *src_frame,
+                            const LOCALMV *mvs, int mv_stride) {
+  int w, h;
+  const int fw = src_frame->y_crop_width;
+  const int fh = src_frame->y_crop_height;
+  const int src_fs = src_frame->y_stride, warped_fs = warped_frame->y_stride;
+  const uint8_t *src_buf = src_frame->y_buffer;
+  uint8_t *warped_buf = warped_frame->y_buffer;
+  double temp;
+  for (h = 0; h < fh; h++) {
+    for (w = 0; w < fw; w++) {
+      double cord_x = (double)w + mvs[h * mv_stride + w].col;
+      double cord_y = (double)h + mvs[h * mv_stride + w].row;
+      cord_x = fclamp(cord_x, 0, (double)(fw - 1));
+      cord_y = fclamp(cord_y, 0, (double)(fh - 1));
+      const int floorx = (int)floor(cord_x);
+      const int floory = (int)floor(cord_y);
+      const double fracx = cord_x - (double)floorx;
+      const double fracy = cord_y - (double)floory;
+
+      temp = 0;
+      for (int hh = 0; hh < 2; hh++) {
+        const double weighth = hh ? (fracy) : (1 - fracy);
+        for (int ww = 0; ww < 2; ww++) {
+          const double weightw = ww ? (fracx) : (1 - fracx);
+          int y = floory + hh;
+          int x = floorx + ww;
+          y = clamp(y, 0, fh - 1);
+          x = clamp(x, 0, fw - 1);
+          temp += (double)src_buf[y * src_fs + x] * weightw * weighth;
+        }
+      }
+      warped_buf[h * warped_fs + w] = (uint8_t)round(temp);
+    }
+  }
+}
+
+// Same as warp_back_frame, but using a better interpolation filter.
+static void warp_back_frame_intp(YV12_BUFFER_CONFIG *warped_frame,
+                                 const YV12_BUFFER_CONFIG *src_frame,
+                                 const LOCALMV *mvs, int mv_stride) {
+  int w, h;
+  const int fw = src_frame->y_crop_width;
+  const int fh = src_frame->y_crop_height;
+  const int warped_fs = warped_frame->y_stride;
+  uint8_t *warped_buf = warped_frame->y_buffer;
+  const int blk = 2;
+  uint8_t temp_blk[4];
+
+  const int is_intrabc = 0;  // Is intra-copied?
+  const int is_high_bitdepth = is_frame_high_bitdepth(src_frame);
+  const int subsampling_x = 0, subsampling_y = 0;  // for y-buffer
+  const int_interpfilters interp_filters =
+      av1_broadcast_interp_filter(MULTITAP_SHARP2);
+  const int plane = 0;  // y-plane
+  const struct buf_2d ref_buf2 = { NULL, src_frame->y_buffer,
+                                   src_frame->y_crop_width,
+                                   src_frame->y_crop_height,
+                                   src_frame->y_stride };
+  const int bit_depth = src_frame->bit_depth;
+  struct scale_factors scale;
+  av1_setup_scale_factors_for_frame(
+      &scale, src_frame->y_crop_width, src_frame->y_crop_height,
+      src_frame->y_crop_width, src_frame->y_crop_height);
+
+  for (h = 0; h < fh; h++) {
+    for (w = 0; w < fw; w++) {
+      InterPredParams inter_pred_params;
+      av1_init_inter_params(&inter_pred_params, blk, blk, h, w, subsampling_x,
+                            subsampling_y, bit_depth, is_high_bitdepth,
+                            is_intrabc, &scale, &ref_buf2, interp_filters);
+      inter_pred_params.interp_filter_params[0] =
+          &av1_interp_filter_params_list[interp_filters.as_filters.x_filter];
+      inter_pred_params.interp_filter_params[1] =
+          &av1_interp_filter_params_list[interp_filters.as_filters.y_filter];
+      inter_pred_params.conv_params = get_conv_params(0, plane, bit_depth);
+      MV newmv = { .row = (int16_t)round((mvs[h * mv_stride + w].row) * 8),
+                   .col = (int16_t)round((mvs[h * mv_stride + w].col) * 8) };
+      av1_enc_build_one_inter_predictor(temp_blk, blk, &newmv,
+                                        &inter_pred_params);
+      warped_buf[h * warped_fs + w] = temp_blk[0];
+    }
+  }
+}
+
+#define DERIVATIVE_FILTER_LENGTH 7
+double filter[DERIVATIVE_FILTER_LENGTH] = { -1.0 / 60, 9.0 / 60,  -45.0 / 60, 0,
+                                            45.0 / 60, -9.0 / 60, 1.0 / 60 };
+
+// Get gradient of the whole frame
+static void get_frame_gradients(const YV12_BUFFER_CONFIG *from_frame,
+                                const YV12_BUFFER_CONFIG *to_frame, double *ix,
+                                double *iy, double *it, int grad_stride) {
+  int w, h, k, idx;
+  const int fw = from_frame->y_crop_width;
+  const int fh = from_frame->y_crop_height;
+  const int from_fs = from_frame->y_stride, to_fs = to_frame->y_stride;
+  const uint8_t *from_buf = from_frame->y_buffer;
+  const uint8_t *to_buf = to_frame->y_buffer;
+
+  const int lh = DERIVATIVE_FILTER_LENGTH;
+  const int hleft = (lh - 1) / 2;
+
+  for (h = 0; h < fh; h++) {
+    for (w = 0; w < fw; w++) {
+      // x
+      ix[h * grad_stride + w] = 0;
+      for (k = 0; k < lh; k++) {
+        // if we want to make this block dependent, need to extend the
+        // boundaries using other initializations.
+        idx = w + k - hleft;
+        idx = clamp(idx, 0, fw - 1);
+        ix[h * grad_stride + w] += filter[k] * 0.5 *
+                                   ((double)from_buf[h * from_fs + idx] +
+                                    (double)to_buf[h * to_fs + idx]);
+      }
+      // y
+      iy[h * grad_stride + w] = 0;
+      for (k = 0; k < lh; k++) {
+        // if we want to make this block dependent, need to extend the
+        // boundaries using other initializations.
+        idx = h + k - hleft;
+        idx = clamp(idx, 0, fh - 1);
+        iy[h * grad_stride + w] += filter[k] * 0.5 *
+                                   ((double)from_buf[idx * from_fs + w] +
+                                    (double)to_buf[idx * to_fs + w]);
+      }
+      // t
+      it[h * grad_stride + w] =
+          (double)to_buf[h * to_fs + w] - (double)from_buf[h * from_fs + w];
+    }
+  }
+}
+
+// Solve for linear equations given by the H-S method
+static void solve_horn_schunck(const double *ix, const double *iy,
+                               const double *it, int grad_stride, int width,
+                               int height, const LOCALMV *init_mvs,
+                               int init_mv_stride, LOCALMV *mvs,
+                               int mv_stride) {
+  // TODO(bohanli): May just need to allocate the buffers once per optical flow
+  // calculation
+  int *row_pos = aom_calloc(width * height * 28, sizeof(*row_pos));
+  int *col_pos = aom_calloc(width * height * 28, sizeof(*col_pos));
+  double *values = aom_calloc(width * height * 28, sizeof(*values));
+  double *mv_vec = aom_calloc(width * height * 2, sizeof(*mv_vec));
+  double *mv_init_vec = aom_calloc(width * height * 2, sizeof(*mv_init_vec));
+  double *temp_b = aom_calloc(width * height * 2, sizeof(*temp_b));
+  double *b = aom_calloc(width * height * 2, sizeof(*b));
+
+  // the location idx for neighboring pixels, k < 4 are the 4 direct neighbors
+  const int check_locs_y[12] = { 0, 0, -1, 1, -1, -1, 1, 1, 0, 0, -2, 2 };
+  const int check_locs_x[12] = { -1, 1, 0, 0, -1, 1, -1, 1, -2, 2, 0, 0 };
+
+  int h, w, checkh, checkw, k;
+  const int offset = height * width;
+  SPARSE_MTX A;
+  int c = 0;
+  const double lambda = 100;
+
+  for (w = 0; w < width; w++) {
+    for (h = 0; h < height; h++) {
+      mv_init_vec[w * height + h] = init_mvs[h * init_mv_stride + w].col;
+      mv_init_vec[w * height + h + offset] =
+          init_mvs[h * init_mv_stride + w].row;
+    }
+  }
+
+  // get matrix A
+  for (w = 0; w < width; w++) {
+    for (h = 0; h < height; h++) {
+      int center_num_direct = 4;
+      const int center_idx = w * height + h;
+      if (w == 0 || w == width - 1) center_num_direct--;
+      if (h == 0 || h == height - 1) center_num_direct--;
+      // diagonal entry for this row from the center pixel
+      double cor_w = center_num_direct * center_num_direct + center_num_direct;
+      row_pos[c] = center_idx;
+      col_pos[c] = center_idx;
+      values[c] = lambda * cor_w;
+      c++;
+      row_pos[c] = center_idx + offset;
+      col_pos[c] = center_idx + offset;
+      values[c] = lambda * cor_w;
+      c++;
+      // other entries from direct neighbors
+      for (k = 0; k < 4; k++) {
+        checkh = h + check_locs_y[k];
+        checkw = w + check_locs_x[k];
+        if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+          continue;
+        }
+        int this_idx = checkw * height + checkh;
+        int this_num_direct = 4;
+        if (checkw == 0 || checkw == width - 1) this_num_direct--;
+        if (checkh == 0 || checkh == height - 1) this_num_direct--;
+        cor_w = -center_num_direct - this_num_direct;
+        row_pos[c] = center_idx;
+        col_pos[c] = this_idx;
+        values[c] = lambda * cor_w;
+        c++;
+        row_pos[c] = center_idx + offset;
+        col_pos[c] = this_idx + offset;
+        values[c] = lambda * cor_w;
+        c++;
+      }
+      // entries from neighbors on the diagonal corners
+      for (k = 4; k < 8; k++) {
+        checkh = h + check_locs_y[k];
+        checkw = w + check_locs_x[k];
+        if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+          continue;
+        }
+        int this_idx = checkw * height + checkh;
+        cor_w = 2;
+        row_pos[c] = center_idx;
+        col_pos[c] = this_idx;
+        values[c] = lambda * cor_w;
+        c++;
+        row_pos[c] = center_idx + offset;
+        col_pos[c] = this_idx + offset;
+        values[c] = lambda * cor_w;
+        c++;
+      }
+      // entries from neighbors with dist of 2
+      for (k = 8; k < 12; k++) {
+        checkh = h + check_locs_y[k];
+        checkw = w + check_locs_x[k];
+        if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+          continue;
+        }
+        int this_idx = checkw * height + checkh;
+        cor_w = 1;
+        row_pos[c] = center_idx;
+        col_pos[c] = this_idx;
+        values[c] = lambda * cor_w;
+        c++;
+        row_pos[c] = center_idx + offset;
+        col_pos[c] = this_idx + offset;
+        values[c] = lambda * cor_w;
+        c++;
+      }
+    }
+  }
+  av1_init_sparse_mtx(row_pos, col_pos, values, c, 2 * width * height,
+                      2 * width * height, &A);
+  // substract init mv part from b
+  av1_mtx_vect_multi_left(&A, mv_init_vec, temp_b, 2 * width * height);
+  for (int i = 0; i < 2 * width * height; i++) {
+    b[i] = -temp_b[i];
+  }
+  av1_free_sparse_mtx_elems(&A);
+
+  // add cross terms to A and modify b with ExEt / EyEt
+  for (w = 0; w < width; w++) {
+    for (h = 0; h < height; h++) {
+      int curidx = w * height + h;
+      // modify b
+      b[curidx] += -ix[h * grad_stride + w] * it[h * grad_stride + w];
+      b[curidx + offset] += -iy[h * grad_stride + w] * it[h * grad_stride + w];
+      // add cross terms to A
+      row_pos[c] = curidx;
+      col_pos[c] = curidx + offset;
+      values[c] = ix[h * grad_stride + w] * iy[h * grad_stride + w];
+      c++;
+      row_pos[c] = curidx + offset;
+      col_pos[c] = curidx;
+      values[c] = ix[h * grad_stride + w] * iy[h * grad_stride + w];
+      c++;
+    }
+  }
+  // Add diagonal terms to A
+  for (int i = 0; i < c; i++) {
+    if (row_pos[i] == col_pos[i]) {
+      if (row_pos[i] < offset) {
+        w = row_pos[i] / height;
+        h = row_pos[i] % height;
+        values[i] += pow(ix[h * grad_stride + w], 2);
+      } else {
+        w = (row_pos[i] - offset) / height;
+        h = (row_pos[i] - offset) % height;
+        values[i] += pow(iy[h * grad_stride + w], 2);
+      }
+    }
+  }
+
+  av1_init_sparse_mtx(row_pos, col_pos, values, c, 2 * width * height,
+                      2 * width * height, &A);
+
+  // solve for the mvs
+  av1_conjugate_gradient_sparse(&A, b, 2 * width * height, mv_vec);
+  // copy mvs
+  for (w = 0; w < width; w++) {
+    for (h = 0; h < height; h++) {
+      mvs[h * mv_stride + w].col = mv_vec[w * height + h];
+      mvs[h * mv_stride + w].row = mv_vec[w * height + h + offset];
+    }
+  }
+  aom_free(row_pos);
+  aom_free(col_pos);
+  aom_free(values);
+  aom_free(mv_vec);
+  aom_free(mv_init_vec);
+  aom_free(b);
+  aom_free(temp_b);
+  av1_free_sparse_mtx_elems(&A);
+}
+
+// Calculate optical flow from from_frame to to_frame using the H-S method.
+void horn_schunck(const YV12_BUFFER_CONFIG *from_frame,
+                  const YV12_BUFFER_CONFIG *to_frame, const int level,
+                  const int mv_stride, const int mv_height, const int mv_width,
+                  const OPFL_PARAMS *opfl_params, LOCALMV *mvs) {
+  // mvs are always on level 0, here we define two new mv arrays that is of size
+  // of this level.
+  const int fw = from_frame->y_crop_width;
+  const int fh = from_frame->y_crop_height;
+  const int factor = (int)pow(2, level);
+  int w, h, k, init_mv_stride;
+  LOCALMV *init_mvs;
+  if (level == 0) {
+    init_mvs = mvs;
+    init_mv_stride = mv_stride;
+  } else {
+    init_mvs = aom_calloc(fw * fh, sizeof(*mvs));
+    init_mv_stride = fw;
+    for (h = 0; h < fh; h++) {
+      for (w = 0; w < fw; w++) {
+        init_mvs[h * init_mv_stride + w].row =
+            mvs[h * factor * mv_stride + w * factor].row / (double)factor;
+        init_mvs[h * init_mv_stride + w].col =
+            mvs[h * factor * mv_stride + w * factor].col / (double)factor;
+      }
+    }
+  }
+  LOCALMV *refine_mvs = aom_calloc(fw * fh, sizeof(*mvs));
+  // temp frame for warping
+  YV12_BUFFER_CONFIG temp_frame;
+  temp_frame.y_buffer =
+      (uint8_t *)aom_calloc(fh * fw, sizeof(*temp_frame.y_buffer));
+  temp_frame.y_crop_height = fh;
+  temp_frame.y_crop_width = fw;
+  temp_frame.y_stride = fw;
+  // gradient buffers
+  double *ix = aom_calloc(fw * fh, sizeof(*ix));
+  double *iy = aom_calloc(fw * fh, sizeof(*iy));
+  double *it = aom_calloc(fw * fh, sizeof(*it));
+  // For each warping step
+  for (k = 0; k < opfl_params->warping_steps; k++) {
+    // warp from_frame with init_mv
+    if (level == 0) {
+      warp_back_frame_intp(&temp_frame, to_frame, init_mvs, init_mv_stride);
+    } else {
+      warp_back_frame(&temp_frame, to_frame, init_mvs, init_mv_stride);
+    }
+    // calculate frame gradients
+    get_frame_gradients(from_frame, &temp_frame, ix, iy, it, fw);
+    // form linear equations and solve mvs
+    solve_horn_schunck(ix, iy, it, fw, fw, fh, init_mvs, init_mv_stride,
+                       refine_mvs, fw);
+    // update init_mvs
+    for (h = 0; h < fh; h++) {
+      for (w = 0; w < fw; w++) {
+        init_mvs[h * init_mv_stride + w].col += refine_mvs[h * fw + w].col;
+        init_mvs[h * init_mv_stride + w].row += refine_mvs[h * fw + w].row;
+      }
+    }
+  }
+  // copy back the mvs if needed
+  if (level != 0) {
+    for (h = 0; h < mv_height; h++) {
+      for (w = 0; w < mv_width; w++) {
+        mvs[h * mv_stride + w].row =
+            init_mvs[h / factor * init_mv_stride + w / factor].row *
+            (double)factor;
+        mvs[h * mv_stride + w].col =
+            init_mvs[h / factor * init_mv_stride + w / factor].col *
+            (double)factor;
+      }
+    }
+  }
+  if (level != 0) aom_free(init_mvs);
+  aom_free(refine_mvs);
+  aom_free(temp_frame.y_buffer);
+  aom_free(ix);
+  aom_free(iy);
+  aom_free(it);
+}
+
 // Apply optical flow iteratively at each pyramid level
 static void pyramid_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
                                  const YV12_BUFFER_CONFIG *to_frame,
@@ -633,6 +1027,11 @@
       lucas_kanade(&buffers1[i], &buffers2[i], i, opfl_params->lk_params,
                    num_ref_corners, ref_corners, buffers1[0].y_crop_width,
                    bit_depth, mvs);
+    } else if (method == HORN_SCHUNCK) {
+      assert(!is_sparse(opfl_params));
+      horn_schunck(&buffers1[i], &buffers2[i], i, buffers1[0].y_crop_width,
+                   buffers1[0].y_crop_height, buffers1[0].y_crop_width,
+                   opfl_params, mvs);
     }
   }
   for (int i = 1; i < levels; i++) {
@@ -654,7 +1053,7 @@
 //   bit_depth:
 //   opfl_params: contains algorithm-specific parameters.
 //   mv_filter: MV_FILTER_NONE, MV_FILTER_SMOOTH, or MV_FILTER_MEDIAN.
-//   method: LUCAS_KANADE,
+//   method: LUCAS_KANADE, HORN_SCHUNCK
 //   mvs: pointer to MVs. Contains initialization, and modified
 //   based on optical flow. Must have
 //   dimensions = from_frame->y_crop_width * from_frame->y_crop_height
diff --git a/av1/encoder/optical_flow.h b/av1/encoder/optical_flow.h
index a54b6a7..2fbe474 100644
--- a/av1/encoder/optical_flow.h
+++ b/av1/encoder/optical_flow.h
@@ -22,7 +22,7 @@
 
 #if CONFIG_OPTICAL_FLOW_API
 
-typedef enum { LUCAS_KANADE } OPTFLOW_METHOD;
+typedef enum { LUCAS_KANADE, HORN_SCHUNCK } OPTFLOW_METHOD;
 
 typedef enum {
   MV_FILTER_NONE,
@@ -39,6 +39,7 @@
 // default options for optical flow
 #define OPFL_WINDOW_SIZE 15
 #define OPFL_PYRAMID_LEVELS 3  // total levels
+#define OPFL_WARPING_STEPS 3
 
 // parameters specific to Lucas-Kanade
 typedef struct lk_params {
@@ -49,6 +50,7 @@
 // optical flow algorithms
 typedef struct opfl_params {
   int pyramid_levels;
+  int warping_steps;
   LK_PARAMS *lk_params;
   int flags;
 } OPFL_PARAMS;