Add horn-schunck method to the optical flow API.
Change-Id: Iecfd061169bcfeca1bbeb46778a53c1330b7edd4
diff --git a/av1/encoder/optical_flow.c b/av1/encoder/optical_flow.c
index 00f91af..82ae9c5 100644
--- a/av1/encoder/optical_flow.c
+++ b/av1/encoder/optical_flow.c
@@ -16,6 +16,7 @@
#include "av1/encoder/encoder.h"
#include "av1/encoder/mathutils.h"
#include "av1/encoder/optical_flow.h"
+#include "av1/encoder/sparse_linear_solver.h"
#include "av1/encoder/reconinter_enc.h"
#include "aom_mem/aom_mem.h"
@@ -23,6 +24,7 @@
void av1_init_opfl_params(OPFL_PARAMS *opfl_params) {
opfl_params->pyramid_levels = OPFL_PYRAMID_LEVELS;
+ opfl_params->warping_steps = OPFL_WARPING_STEPS;
opfl_params->lk_params = NULL;
}
@@ -568,6 +570,398 @@
aom_free(i_y);
}
+// Warp the src_frame to warper_frame according to mvs.
+// mvs point to src_frame
+static void warp_back_frame(YV12_BUFFER_CONFIG *warped_frame,
+ const YV12_BUFFER_CONFIG *src_frame,
+ const LOCALMV *mvs, int mv_stride) {
+ int w, h;
+ const int fw = src_frame->y_crop_width;
+ const int fh = src_frame->y_crop_height;
+ const int src_fs = src_frame->y_stride, warped_fs = warped_frame->y_stride;
+ const uint8_t *src_buf = src_frame->y_buffer;
+ uint8_t *warped_buf = warped_frame->y_buffer;
+ double temp;
+ for (h = 0; h < fh; h++) {
+ for (w = 0; w < fw; w++) {
+ double cord_x = (double)w + mvs[h * mv_stride + w].col;
+ double cord_y = (double)h + mvs[h * mv_stride + w].row;
+ cord_x = fclamp(cord_x, 0, (double)(fw - 1));
+ cord_y = fclamp(cord_y, 0, (double)(fh - 1));
+ const int floorx = (int)floor(cord_x);
+ const int floory = (int)floor(cord_y);
+ const double fracx = cord_x - (double)floorx;
+ const double fracy = cord_y - (double)floory;
+
+ temp = 0;
+ for (int hh = 0; hh < 2; hh++) {
+ const double weighth = hh ? (fracy) : (1 - fracy);
+ for (int ww = 0; ww < 2; ww++) {
+ const double weightw = ww ? (fracx) : (1 - fracx);
+ int y = floory + hh;
+ int x = floorx + ww;
+ y = clamp(y, 0, fh - 1);
+ x = clamp(x, 0, fw - 1);
+ temp += (double)src_buf[y * src_fs + x] * weightw * weighth;
+ }
+ }
+ warped_buf[h * warped_fs + w] = (uint8_t)round(temp);
+ }
+ }
+}
+
+// Same as warp_back_frame, but using a better interpolation filter.
+static void warp_back_frame_intp(YV12_BUFFER_CONFIG *warped_frame,
+ const YV12_BUFFER_CONFIG *src_frame,
+ const LOCALMV *mvs, int mv_stride) {
+ int w, h;
+ const int fw = src_frame->y_crop_width;
+ const int fh = src_frame->y_crop_height;
+ const int warped_fs = warped_frame->y_stride;
+ uint8_t *warped_buf = warped_frame->y_buffer;
+ const int blk = 2;
+ uint8_t temp_blk[4];
+
+ const int is_intrabc = 0; // Is intra-copied?
+ const int is_high_bitdepth = is_frame_high_bitdepth(src_frame);
+ const int subsampling_x = 0, subsampling_y = 0; // for y-buffer
+ const int_interpfilters interp_filters =
+ av1_broadcast_interp_filter(MULTITAP_SHARP2);
+ const int plane = 0; // y-plane
+ const struct buf_2d ref_buf2 = { NULL, src_frame->y_buffer,
+ src_frame->y_crop_width,
+ src_frame->y_crop_height,
+ src_frame->y_stride };
+ const int bit_depth = src_frame->bit_depth;
+ struct scale_factors scale;
+ av1_setup_scale_factors_for_frame(
+ &scale, src_frame->y_crop_width, src_frame->y_crop_height,
+ src_frame->y_crop_width, src_frame->y_crop_height);
+
+ for (h = 0; h < fh; h++) {
+ for (w = 0; w < fw; w++) {
+ InterPredParams inter_pred_params;
+ av1_init_inter_params(&inter_pred_params, blk, blk, h, w, subsampling_x,
+ subsampling_y, bit_depth, is_high_bitdepth,
+ is_intrabc, &scale, &ref_buf2, interp_filters);
+ inter_pred_params.interp_filter_params[0] =
+ &av1_interp_filter_params_list[interp_filters.as_filters.x_filter];
+ inter_pred_params.interp_filter_params[1] =
+ &av1_interp_filter_params_list[interp_filters.as_filters.y_filter];
+ inter_pred_params.conv_params = get_conv_params(0, plane, bit_depth);
+ MV newmv = { .row = (int16_t)round((mvs[h * mv_stride + w].row) * 8),
+ .col = (int16_t)round((mvs[h * mv_stride + w].col) * 8) };
+ av1_enc_build_one_inter_predictor(temp_blk, blk, &newmv,
+ &inter_pred_params);
+ warped_buf[h * warped_fs + w] = temp_blk[0];
+ }
+ }
+}
+
+#define DERIVATIVE_FILTER_LENGTH 7
+double filter[DERIVATIVE_FILTER_LENGTH] = { -1.0 / 60, 9.0 / 60, -45.0 / 60, 0,
+ 45.0 / 60, -9.0 / 60, 1.0 / 60 };
+
+// Get gradient of the whole frame
+static void get_frame_gradients(const YV12_BUFFER_CONFIG *from_frame,
+ const YV12_BUFFER_CONFIG *to_frame, double *ix,
+ double *iy, double *it, int grad_stride) {
+ int w, h, k, idx;
+ const int fw = from_frame->y_crop_width;
+ const int fh = from_frame->y_crop_height;
+ const int from_fs = from_frame->y_stride, to_fs = to_frame->y_stride;
+ const uint8_t *from_buf = from_frame->y_buffer;
+ const uint8_t *to_buf = to_frame->y_buffer;
+
+ const int lh = DERIVATIVE_FILTER_LENGTH;
+ const int hleft = (lh - 1) / 2;
+
+ for (h = 0; h < fh; h++) {
+ for (w = 0; w < fw; w++) {
+ // x
+ ix[h * grad_stride + w] = 0;
+ for (k = 0; k < lh; k++) {
+ // if we want to make this block dependent, need to extend the
+ // boundaries using other initializations.
+ idx = w + k - hleft;
+ idx = clamp(idx, 0, fw - 1);
+ ix[h * grad_stride + w] += filter[k] * 0.5 *
+ ((double)from_buf[h * from_fs + idx] +
+ (double)to_buf[h * to_fs + idx]);
+ }
+ // y
+ iy[h * grad_stride + w] = 0;
+ for (k = 0; k < lh; k++) {
+ // if we want to make this block dependent, need to extend the
+ // boundaries using other initializations.
+ idx = h + k - hleft;
+ idx = clamp(idx, 0, fh - 1);
+ iy[h * grad_stride + w] += filter[k] * 0.5 *
+ ((double)from_buf[idx * from_fs + w] +
+ (double)to_buf[idx * to_fs + w]);
+ }
+ // t
+ it[h * grad_stride + w] =
+ (double)to_buf[h * to_fs + w] - (double)from_buf[h * from_fs + w];
+ }
+ }
+}
+
+// Solve for linear equations given by the H-S method
+static void solve_horn_schunck(const double *ix, const double *iy,
+ const double *it, int grad_stride, int width,
+ int height, const LOCALMV *init_mvs,
+ int init_mv_stride, LOCALMV *mvs,
+ int mv_stride) {
+ // TODO(bohanli): May just need to allocate the buffers once per optical flow
+ // calculation
+ int *row_pos = aom_calloc(width * height * 28, sizeof(*row_pos));
+ int *col_pos = aom_calloc(width * height * 28, sizeof(*col_pos));
+ double *values = aom_calloc(width * height * 28, sizeof(*values));
+ double *mv_vec = aom_calloc(width * height * 2, sizeof(*mv_vec));
+ double *mv_init_vec = aom_calloc(width * height * 2, sizeof(*mv_init_vec));
+ double *temp_b = aom_calloc(width * height * 2, sizeof(*temp_b));
+ double *b = aom_calloc(width * height * 2, sizeof(*b));
+
+ // the location idx for neighboring pixels, k < 4 are the 4 direct neighbors
+ const int check_locs_y[12] = { 0, 0, -1, 1, -1, -1, 1, 1, 0, 0, -2, 2 };
+ const int check_locs_x[12] = { -1, 1, 0, 0, -1, 1, -1, 1, -2, 2, 0, 0 };
+
+ int h, w, checkh, checkw, k;
+ const int offset = height * width;
+ SPARSE_MTX A;
+ int c = 0;
+ const double lambda = 100;
+
+ for (w = 0; w < width; w++) {
+ for (h = 0; h < height; h++) {
+ mv_init_vec[w * height + h] = init_mvs[h * init_mv_stride + w].col;
+ mv_init_vec[w * height + h + offset] =
+ init_mvs[h * init_mv_stride + w].row;
+ }
+ }
+
+ // get matrix A
+ for (w = 0; w < width; w++) {
+ for (h = 0; h < height; h++) {
+ int center_num_direct = 4;
+ const int center_idx = w * height + h;
+ if (w == 0 || w == width - 1) center_num_direct--;
+ if (h == 0 || h == height - 1) center_num_direct--;
+ // diagonal entry for this row from the center pixel
+ double cor_w = center_num_direct * center_num_direct + center_num_direct;
+ row_pos[c] = center_idx;
+ col_pos[c] = center_idx;
+ values[c] = lambda * cor_w;
+ c++;
+ row_pos[c] = center_idx + offset;
+ col_pos[c] = center_idx + offset;
+ values[c] = lambda * cor_w;
+ c++;
+ // other entries from direct neighbors
+ for (k = 0; k < 4; k++) {
+ checkh = h + check_locs_y[k];
+ checkw = w + check_locs_x[k];
+ if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+ continue;
+ }
+ int this_idx = checkw * height + checkh;
+ int this_num_direct = 4;
+ if (checkw == 0 || checkw == width - 1) this_num_direct--;
+ if (checkh == 0 || checkh == height - 1) this_num_direct--;
+ cor_w = -center_num_direct - this_num_direct;
+ row_pos[c] = center_idx;
+ col_pos[c] = this_idx;
+ values[c] = lambda * cor_w;
+ c++;
+ row_pos[c] = center_idx + offset;
+ col_pos[c] = this_idx + offset;
+ values[c] = lambda * cor_w;
+ c++;
+ }
+ // entries from neighbors on the diagonal corners
+ for (k = 4; k < 8; k++) {
+ checkh = h + check_locs_y[k];
+ checkw = w + check_locs_x[k];
+ if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+ continue;
+ }
+ int this_idx = checkw * height + checkh;
+ cor_w = 2;
+ row_pos[c] = center_idx;
+ col_pos[c] = this_idx;
+ values[c] = lambda * cor_w;
+ c++;
+ row_pos[c] = center_idx + offset;
+ col_pos[c] = this_idx + offset;
+ values[c] = lambda * cor_w;
+ c++;
+ }
+ // entries from neighbors with dist of 2
+ for (k = 8; k < 12; k++) {
+ checkh = h + check_locs_y[k];
+ checkw = w + check_locs_x[k];
+ if (checkh < 0 || checkh >= height || checkw < 0 || checkw >= width) {
+ continue;
+ }
+ int this_idx = checkw * height + checkh;
+ cor_w = 1;
+ row_pos[c] = center_idx;
+ col_pos[c] = this_idx;
+ values[c] = lambda * cor_w;
+ c++;
+ row_pos[c] = center_idx + offset;
+ col_pos[c] = this_idx + offset;
+ values[c] = lambda * cor_w;
+ c++;
+ }
+ }
+ }
+ av1_init_sparse_mtx(row_pos, col_pos, values, c, 2 * width * height,
+ 2 * width * height, &A);
+ // substract init mv part from b
+ av1_mtx_vect_multi_left(&A, mv_init_vec, temp_b, 2 * width * height);
+ for (int i = 0; i < 2 * width * height; i++) {
+ b[i] = -temp_b[i];
+ }
+ av1_free_sparse_mtx_elems(&A);
+
+ // add cross terms to A and modify b with ExEt / EyEt
+ for (w = 0; w < width; w++) {
+ for (h = 0; h < height; h++) {
+ int curidx = w * height + h;
+ // modify b
+ b[curidx] += -ix[h * grad_stride + w] * it[h * grad_stride + w];
+ b[curidx + offset] += -iy[h * grad_stride + w] * it[h * grad_stride + w];
+ // add cross terms to A
+ row_pos[c] = curidx;
+ col_pos[c] = curidx + offset;
+ values[c] = ix[h * grad_stride + w] * iy[h * grad_stride + w];
+ c++;
+ row_pos[c] = curidx + offset;
+ col_pos[c] = curidx;
+ values[c] = ix[h * grad_stride + w] * iy[h * grad_stride + w];
+ c++;
+ }
+ }
+ // Add diagonal terms to A
+ for (int i = 0; i < c; i++) {
+ if (row_pos[i] == col_pos[i]) {
+ if (row_pos[i] < offset) {
+ w = row_pos[i] / height;
+ h = row_pos[i] % height;
+ values[i] += pow(ix[h * grad_stride + w], 2);
+ } else {
+ w = (row_pos[i] - offset) / height;
+ h = (row_pos[i] - offset) % height;
+ values[i] += pow(iy[h * grad_stride + w], 2);
+ }
+ }
+ }
+
+ av1_init_sparse_mtx(row_pos, col_pos, values, c, 2 * width * height,
+ 2 * width * height, &A);
+
+ // solve for the mvs
+ av1_conjugate_gradient_sparse(&A, b, 2 * width * height, mv_vec);
+ // copy mvs
+ for (w = 0; w < width; w++) {
+ for (h = 0; h < height; h++) {
+ mvs[h * mv_stride + w].col = mv_vec[w * height + h];
+ mvs[h * mv_stride + w].row = mv_vec[w * height + h + offset];
+ }
+ }
+ aom_free(row_pos);
+ aom_free(col_pos);
+ aom_free(values);
+ aom_free(mv_vec);
+ aom_free(mv_init_vec);
+ aom_free(b);
+ aom_free(temp_b);
+ av1_free_sparse_mtx_elems(&A);
+}
+
+// Calculate optical flow from from_frame to to_frame using the H-S method.
+void horn_schunck(const YV12_BUFFER_CONFIG *from_frame,
+ const YV12_BUFFER_CONFIG *to_frame, const int level,
+ const int mv_stride, const int mv_height, const int mv_width,
+ const OPFL_PARAMS *opfl_params, LOCALMV *mvs) {
+ // mvs are always on level 0, here we define two new mv arrays that is of size
+ // of this level.
+ const int fw = from_frame->y_crop_width;
+ const int fh = from_frame->y_crop_height;
+ const int factor = (int)pow(2, level);
+ int w, h, k, init_mv_stride;
+ LOCALMV *init_mvs;
+ if (level == 0) {
+ init_mvs = mvs;
+ init_mv_stride = mv_stride;
+ } else {
+ init_mvs = aom_calloc(fw * fh, sizeof(*mvs));
+ init_mv_stride = fw;
+ for (h = 0; h < fh; h++) {
+ for (w = 0; w < fw; w++) {
+ init_mvs[h * init_mv_stride + w].row =
+ mvs[h * factor * mv_stride + w * factor].row / (double)factor;
+ init_mvs[h * init_mv_stride + w].col =
+ mvs[h * factor * mv_stride + w * factor].col / (double)factor;
+ }
+ }
+ }
+ LOCALMV *refine_mvs = aom_calloc(fw * fh, sizeof(*mvs));
+ // temp frame for warping
+ YV12_BUFFER_CONFIG temp_frame;
+ temp_frame.y_buffer =
+ (uint8_t *)aom_calloc(fh * fw, sizeof(*temp_frame.y_buffer));
+ temp_frame.y_crop_height = fh;
+ temp_frame.y_crop_width = fw;
+ temp_frame.y_stride = fw;
+ // gradient buffers
+ double *ix = aom_calloc(fw * fh, sizeof(*ix));
+ double *iy = aom_calloc(fw * fh, sizeof(*iy));
+ double *it = aom_calloc(fw * fh, sizeof(*it));
+ // For each warping step
+ for (k = 0; k < opfl_params->warping_steps; k++) {
+ // warp from_frame with init_mv
+ if (level == 0) {
+ warp_back_frame_intp(&temp_frame, to_frame, init_mvs, init_mv_stride);
+ } else {
+ warp_back_frame(&temp_frame, to_frame, init_mvs, init_mv_stride);
+ }
+ // calculate frame gradients
+ get_frame_gradients(from_frame, &temp_frame, ix, iy, it, fw);
+ // form linear equations and solve mvs
+ solve_horn_schunck(ix, iy, it, fw, fw, fh, init_mvs, init_mv_stride,
+ refine_mvs, fw);
+ // update init_mvs
+ for (h = 0; h < fh; h++) {
+ for (w = 0; w < fw; w++) {
+ init_mvs[h * init_mv_stride + w].col += refine_mvs[h * fw + w].col;
+ init_mvs[h * init_mv_stride + w].row += refine_mvs[h * fw + w].row;
+ }
+ }
+ }
+ // copy back the mvs if needed
+ if (level != 0) {
+ for (h = 0; h < mv_height; h++) {
+ for (w = 0; w < mv_width; w++) {
+ mvs[h * mv_stride + w].row =
+ init_mvs[h / factor * init_mv_stride + w / factor].row *
+ (double)factor;
+ mvs[h * mv_stride + w].col =
+ init_mvs[h / factor * init_mv_stride + w / factor].col *
+ (double)factor;
+ }
+ }
+ }
+ if (level != 0) aom_free(init_mvs);
+ aom_free(refine_mvs);
+ aom_free(temp_frame.y_buffer);
+ aom_free(ix);
+ aom_free(iy);
+ aom_free(it);
+}
+
// Apply optical flow iteratively at each pyramid level
static void pyramid_optical_flow(const YV12_BUFFER_CONFIG *from_frame,
const YV12_BUFFER_CONFIG *to_frame,
@@ -633,6 +1027,11 @@
lucas_kanade(&buffers1[i], &buffers2[i], i, opfl_params->lk_params,
num_ref_corners, ref_corners, buffers1[0].y_crop_width,
bit_depth, mvs);
+ } else if (method == HORN_SCHUNCK) {
+ assert(!is_sparse(opfl_params));
+ horn_schunck(&buffers1[i], &buffers2[i], i, buffers1[0].y_crop_width,
+ buffers1[0].y_crop_height, buffers1[0].y_crop_width,
+ opfl_params, mvs);
}
}
for (int i = 1; i < levels; i++) {
@@ -654,7 +1053,7 @@
// bit_depth:
// opfl_params: contains algorithm-specific parameters.
// mv_filter: MV_FILTER_NONE, MV_FILTER_SMOOTH, or MV_FILTER_MEDIAN.
-// method: LUCAS_KANADE,
+// method: LUCAS_KANADE, HORN_SCHUNCK
// mvs: pointer to MVs. Contains initialization, and modified
// based on optical flow. Must have
// dimensions = from_frame->y_crop_width * from_frame->y_crop_height
diff --git a/av1/encoder/optical_flow.h b/av1/encoder/optical_flow.h
index a54b6a7..2fbe474 100644
--- a/av1/encoder/optical_flow.h
+++ b/av1/encoder/optical_flow.h
@@ -22,7 +22,7 @@
#if CONFIG_OPTICAL_FLOW_API
-typedef enum { LUCAS_KANADE } OPTFLOW_METHOD;
+typedef enum { LUCAS_KANADE, HORN_SCHUNCK } OPTFLOW_METHOD;
typedef enum {
MV_FILTER_NONE,
@@ -39,6 +39,7 @@
// default options for optical flow
#define OPFL_WINDOW_SIZE 15
#define OPFL_PYRAMID_LEVELS 3 // total levels
+#define OPFL_WARPING_STEPS 3
// parameters specific to Lucas-Kanade
typedef struct lk_params {
@@ -49,6 +50,7 @@
// optical flow algorithms
typedef struct opfl_params {
int pyramid_levels;
+ int warping_steps;
LK_PARAMS *lk_params;
int flags;
} OPFL_PARAMS;