Move simple_motion_search_based_split to a partition_strategy.h

BUG=aomedia:2343

Change-Id: I98db8765c46b9eb232f706bee0c891097fa56d33
diff --git a/av1/av1.cmake b/av1/av1.cmake
index a590ac3..fb9678a 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -172,6 +172,8 @@
             "${AOM_ROOT}/av1/encoder/ml.h"
             "${AOM_ROOT}/av1/encoder/palette.c"
             "${AOM_ROOT}/av1/encoder/palette.h"
+            "${AOM_ROOT}/av1/encoder/partition_strategy.h"
+            "${AOM_ROOT}/av1/encoder/partition_strategy.c"
             "${AOM_ROOT}/av1/encoder/pass2_strategy.h"
             "${AOM_ROOT}/av1/encoder/pass2_strategy.c"
             "${AOM_ROOT}/av1/encoder/pickcdef.c"
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index c70ede4..c81f9ef 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -55,6 +55,7 @@
 #include "av1/encoder/ethread.h"
 #include "av1/encoder/extend.h"
 #include "av1/encoder/ml.h"
+#include "av1/encoder/partition_strategy.h"
 #include "av1/encoder/partition_model_weights.h"
 #include "av1/encoder/rd.h"
 #include "av1/encoder/rdopt.h"
@@ -71,10 +72,6 @@
                                const MACROBLOCK *const x,
                                const RD_STATS *const rd_stats,
                                unsigned int pb_source_variance);
-static void simple_motion_search(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
-                                 int mi_col, BLOCK_SIZE bsize, int ref,
-                                 MV ref_mv_full, int num_planes,
-                                 int use_subpixel);
 static void firstpass_simple_motion_search_features(
     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
     int mi_col, BLOCK_SIZE bsize, float *features);
@@ -212,25 +209,6 @@
     return BLOCK_8X8;
 }
 
-void av1_simple_motion_sse_var(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
-                               int mi_col, BLOCK_SIZE bsize,
-                               const MV ref_mv_full, int use_subpixel,
-                               unsigned int *sse, unsigned int *var) {
-  MACROBLOCKD *xd = &x->e_mbd;
-  const MV_REFERENCE_FRAME ref =
-      cpi->rc.is_src_frame_alt_ref ? ALTREF_FRAME : LAST_FRAME;
-
-  simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref, ref_mv_full, 1,
-                       use_subpixel);
-
-  const uint8_t *src = x->plane[0].src.buf;
-  const int src_stride = x->plane[0].src.stride;
-  const uint8_t *dst = xd->plane[0].dst.buf;
-  const int dst_stride = xd->plane[0].dst.stride;
-
-  *var = cpi->fn_ptr[bsize].vf(src, src_stride, dst, dst_stride, sse);
-}
-
 static void set_offsets_without_segment_id(const AV1_COMP *const cpi,
                                            const TileInfo *const tile,
                                            MACROBLOCK *const x, int mi_row,
@@ -304,47 +282,6 @@
   }
 }
 
-// A simplified version of set_offsets meant to be used for
-// simple_motion_search.
-static void set_offsets_for_motion_search(const AV1_COMP *const cpi,
-                                          MACROBLOCK *const x, int mi_row,
-                                          int mi_col, BLOCK_SIZE bsize) {
-  const AV1_COMMON *const cm = &cpi->common;
-  const int num_planes = av1_num_planes(cm);
-  MACROBLOCKD *const xd = &x->e_mbd;
-  const int mi_width = mi_size_wide[bsize];
-  const int mi_height = mi_size_high[bsize];
-
-  set_mode_info_offsets(cpi, x, xd, mi_row, mi_col);
-
-  // Set up destination pointers.
-  av1_setup_dst_planes(xd->plane, bsize, &cm->cur_frame->buf, mi_row, mi_col, 0,
-                       num_planes);
-
-  // Set up limit values for MV components.
-  // Mv beyond the range do not produce new/different prediction block.
-  x->mv_limits.row_min =
-      -(((mi_row + mi_height) * MI_SIZE) + AOM_INTERP_EXTEND);
-  x->mv_limits.col_min = -(((mi_col + mi_width) * MI_SIZE) + AOM_INTERP_EXTEND);
-  x->mv_limits.row_max = (cm->mi_rows - mi_row) * MI_SIZE + AOM_INTERP_EXTEND;
-  x->mv_limits.col_max = (cm->mi_cols - mi_col) * MI_SIZE + AOM_INTERP_EXTEND;
-
-  set_plane_n4(xd, mi_width, mi_height, num_planes);
-
-  // Set up distance of MB to edge of frame in 1/8th pel units.
-  assert(!(mi_col & (mi_width - 1)) && !(mi_row & (mi_height - 1)));
-  xd->mb_to_top_edge = -((mi_row * MI_SIZE) * 8);
-  xd->mb_to_bottom_edge = ((cm->mi_rows - mi_height - mi_row) * MI_SIZE) * 8;
-  xd->mb_to_left_edge = -((mi_col * MI_SIZE) * 8);
-  xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
-
-  // Set up source buffers.
-  av1_setup_src_planes(x, cpi->source, mi_row, mi_col, num_planes, bsize);
-
-  // R/D setup.
-  x->rdmult = cpi->rd.RDMULT;
-}
-
 static void update_filter_type_count(uint8_t allow_update_cdf,
                                      FRAME_COUNTS *counts,
                                      const MACROBLOCKD *xd,
@@ -2237,104 +2174,6 @@
   return is_active_v_edge;
 }
 
-// Performs a motion search in SIMPLE_TRANSLATION mode using reference frame
-// ref. Note that this sets the offset of mbmi, so we will need to reset it
-// after calling this function.
-static void simple_motion_search(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
-                                 int mi_col, BLOCK_SIZE bsize, int ref,
-                                 MV ref_mv_full, int num_planes,
-                                 int use_subpixel) {
-  assert(num_planes == 1 &&
-         "Currently simple_motion_search only supports luma plane");
-  assert(!frame_is_intra_only(&cpi->common) &&
-         "Simple motion search only enabled for non-key frames");
-  AV1_COMMON *const cm = &cpi->common;
-  MACROBLOCKD *xd = &x->e_mbd;
-
-  set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
-
-  MB_MODE_INFO *mbmi = xd->mi[0];
-  mbmi->sb_type = bsize;
-  mbmi->ref_frame[0] = ref;
-  mbmi->ref_frame[1] = NONE_FRAME;
-  mbmi->motion_mode = SIMPLE_TRANSLATION;
-
-  const YV12_BUFFER_CONFIG *yv12 = get_ref_frame_yv12_buf(cm, ref);
-  const YV12_BUFFER_CONFIG *scaled_ref_frame =
-      av1_get_scaled_ref_frame(cpi, ref);
-  struct buf_2d backup_yv12;
-  // ref_mv is used to code the motion vector. ref_mv_full is the initial point.
-  // ref_mv is in units of 1/8 pel whereas ref_mv_full is in units of pel.
-  MV ref_mv = { 0, 0 };
-  const int step_param = cpi->mv_step_param;
-  const MvLimits tmp_mv_limits = x->mv_limits;
-  const SEARCH_METHODS search_methods = NSTEP;
-  const int do_mesh_search = 0;
-  const int sadpb = x->sadperbit16;
-  int cost_list[5];
-  const int ref_idx = 0;
-  int var;
-
-  if (scaled_ref_frame) {
-    backup_yv12 = xd->plane[AOM_PLANE_Y].pre[ref_idx];
-    av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
-                         num_planes);
-  } else {
-    av1_setup_pre_planes(xd, ref_idx, yv12, mi_row, mi_col,
-                         get_ref_scale_factors(cm, ref), num_planes);
-  }
-
-  // This overwrites the mv_limits so we will need to restore it later.
-  av1_set_mv_search_range(&x->mv_limits, &ref_mv);
-  var = av1_full_pixel_search(
-      cpi, x, bsize, &ref_mv_full, step_param, search_methods, do_mesh_search,
-      sadpb, cond_cost_list(cpi, cost_list), &ref_mv, INT_MAX, 1,
-      mi_col * MI_SIZE, mi_row * MI_SIZE, 0, &cpi->ss_cfg[SS_CFG_SRC]);
-  // Restore
-  x->mv_limits = tmp_mv_limits;
-
-  const int use_subpel_search =
-      var < INT_MAX && !cpi->common.cur_frame_force_integer_mv && use_subpixel;
-  if (use_subpel_search) {
-    int not_used = 0;
-    if (cpi->sf.use_accurate_subpel_search) {
-      const int pw = block_size_wide[bsize];
-      const int ph = block_size_high[bsize];
-      cpi->find_fractional_mv_step(
-          x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
-          x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
-          cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-          x->nmv_vec_cost, x->mv_cost_stack, &not_used, &x->pred_sse[ref], NULL,
-          NULL, 0, 0, pw, ph, cpi->sf.use_accurate_subpel_search, 1);
-    } else {
-      cpi->find_fractional_mv_step(
-          x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
-          x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
-          cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-          x->nmv_vec_cost, x->mv_cost_stack, &not_used, &x->pred_sse[ref], NULL,
-          NULL, 0, 0, 0, 0, 0, 1);
-    }
-  } else {
-    // Manually convert from units of pixel to 1/8-pixels if we are not doing
-    // subpel search
-    x->best_mv.as_mv.row *= 8;
-    x->best_mv.as_mv.col *= 8;
-  }
-
-  mbmi->mv[0].as_mv = x->best_mv.as_mv;
-
-  // Get a copy of the prediction output
-  set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
-  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
-                                AOM_PLANE_Y, AOM_PLANE_Y);
-
-  aom_clear_system_state();
-
-  if (scaled_ref_frame) {
-    xd->plane[AOM_PLANE_Y].pre[ref_idx] = backup_yv12;
-  }
-}
-
 static INLINE void store_pred_mv(MACROBLOCK *x, PICK_MODE_CONTEXT *ctx) {
   memcpy(ctx->pred_mv, x->pred_mv, sizeof(x->pred_mv));
 }
@@ -3304,107 +3143,6 @@
 }
 #undef FEATURES
 
-// Performs a simple_motion_search with a single reference frame and extract
-// the variance of residues. Here features is assumed to be a length 6 array.
-// After this function is called, we will store the following in to features:
-// features[0] = log(1 + dc_q**2/256)
-// features[1] = log(1 + variance_of_residue)
-// for i in [2, 3, 4, 5]:
-//  features[i] = log(1 + variance_of_residue_in_block[i]/variance_of_residue)
-static void get_res_var_features(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
-                                 int mi_col, BLOCK_SIZE bsize,
-                                 float *features) {
-  // TODO(chiyotsai@google.com): The data this model trained on did not also use
-  // SIMPLE_TRANSLATION to build the inter_predictor. Retraining and tuning the
-  // model with the correct data should give better performance.
-  assert(mi_size_wide[bsize] == mi_size_high[bsize]);
-
-  MACROBLOCKD *xd = &x->e_mbd;
-
-  // Perform a single motion search in Y_PLANE to make a prediction
-  const int use_subpixel = 0;
-
-  // Start getting the features
-  int f_idx = 0;
-
-  // Q_INDEX
-  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
-  aom_clear_system_state();
-  features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
-
-  // VARIANCE
-  unsigned int sse = 0;
-  unsigned int var = 0;
-  const MV ref_mv_full = { .row = 0, .col = 0 };
-  av1_simple_motion_sse_var(cpi, x, mi_row, mi_col, bsize, ref_mv_full,
-                            use_subpixel, &sse, &var);
-  aom_clear_system_state();
-  features[f_idx++] = logf(1.0f + (float)var);
-
-  // Regional
-  const uint8_t *src = x->plane[0].src.buf;
-  const int src_stride = x->plane[0].src.stride;
-  const uint8_t *dst = xd->plane[0].dst.buf;
-  const int dst_stride = xd->plane[0].dst.stride;
-  const int bw = block_size_wide[bsize];
-  const int bh = block_size_high[bsize];
-  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
-  int r_idx = 0;
-  for (r_idx = 0; r_idx < 4; r_idx++) {
-    const int x_idx = (r_idx & 1) * bw / 2;
-    const int y_idx = (r_idx >> 1) * bh / 2;
-    const int src_offset = y_idx * src_stride + x_idx;
-    const int dst_offset = y_idx * dst_stride + x_idx;
-    const unsigned int sub_var = cpi->fn_ptr[subsize].vf(
-        src + src_offset, src_stride, dst + dst_offset, dst_stride, &sse);
-    aom_clear_system_state();
-    const float var_ratio = (1.0f + (float)sub_var) / (4.0f + (float)var);
-    features[f_idx++] = var_ratio;
-  }
-}
-
-static void simple_motion_search_based_split(
-    AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
-    BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
-    int *partition_vert_allowed, int *do_rectangular_split) {
-  const NN_CONFIG *nn_config = NULL;
-  float split_only_thresh = 0.0f;
-  if (bsize == BLOCK_128X128) {
-    nn_config = &av1_simple_motion_search_based_split_nn_config_128;
-    split_only_thresh = av1_simple_motion_search_based_split_thresh_128;
-  } else if (bsize == BLOCK_64X64) {
-    nn_config = &av1_simple_motion_search_based_split_nn_config_64;
-    split_only_thresh = av1_simple_motion_search_based_split_thresh_64;
-  } else if (bsize == BLOCK_32X32) {
-    nn_config = &av1_simple_motion_search_based_split_nn_config_32;
-    split_only_thresh = av1_simple_motion_search_based_split_thresh_32;
-  } else if (bsize == BLOCK_16X16) {
-    nn_config = &av1_simple_motion_search_based_split_nn_config_16;
-    split_only_thresh = av1_simple_motion_search_based_split_thresh_16;
-  } else if (bsize == BLOCK_8X8) {
-    // Disable BLOCK_8X8 for now
-#if !CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
-    nn_config = &av1_simple_motion_search_based_split_nn_config_8;
-    split_only_thresh = av1_simple_motion_search_based_split_thresh_8;
-#endif
-  } else {
-    assert(0 && "Unexpected block size in simple_motion_based_split");
-  }
-  if (nn_config) {
-    float features[6] = { 0 };
-    float score = 0;
-    get_res_var_features(cpi, x, mi_row, mi_col, bsize, features);
-    av1_nn_predict(features, nn_config, &score);
-
-    if (score > split_only_thresh) {
-      *partition_none_allowed = 0;
-      *partition_horz_allowed = 0;
-      *partition_vert_allowed = 0;
-      *do_rectangular_split = 0;
-    }
-  }
-}
-
 // Given a list of ref frames in refs, performs simple_motion_search on each of
 // the refs and returns the ref with the smallest sse. Returns -1 if none of the
 // ref in the list is available. Also stores the best sse and var in best_sse,
@@ -3444,8 +3182,8 @@
 
     if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref]) {
       unsigned int curr_sse = 0, curr_var = 0;
-      simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref,
-                           mv_ref_fulls[ref], num_planes, use_subpixel);
+      av1_simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref,
+                               mv_ref_fulls[ref], num_planes, use_subpixel);
       curr_var = cpi->fn_ptr[bsize].vf(
           x->plane[0].src.buf, x->plane[0].src.stride, xd->plane[0].dst.buf,
           xd->plane[0].dst.stride, &curr_sse);
@@ -4161,7 +3899,7 @@
       !av1_superres_scaled(cm);
 
   if (try_split_only) {
-    simple_motion_search_based_split(
+    av1_simple_motion_search_based_split(
         cpi, x, mi_row, mi_col, bsize, &partition_none_allowed,
         &partition_horz_allowed, &partition_vert_allowed,
         &do_rectangular_split);
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index d4a4d34..c9573bd 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -19,6 +19,7 @@
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/mem.h"
+#include "aom_ports/system_state.h"
 
 #include "av1/common/common.h"
 #include "av1/common/mvref_common.h"
@@ -28,6 +29,7 @@
 #include "av1/encoder/encoder.h"
 #include "av1/encoder/encodemv.h"
 #include "av1/encoder/mcomp.h"
+#include "av1/encoder/partition_strategy.h"
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/reconinter_enc.h"
 
@@ -3064,3 +3066,117 @@
   lower_mv_precision(bestmv, allow_hp, 0);
   return besterr;
 }
+
+void av1_simple_motion_search(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
+                              int mi_col, BLOCK_SIZE bsize, int ref,
+                              MV ref_mv_full, int num_planes,
+                              int use_subpixel) {
+  assert(num_planes == 1 &&
+         "Currently simple_motion_search only supports luma plane");
+  assert(!frame_is_intra_only(&cpi->common) &&
+         "Simple motion search only enabled for non-key frames");
+  AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *xd = &x->e_mbd;
+
+  set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
+
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  mbmi->sb_type = bsize;
+  mbmi->ref_frame[0] = ref;
+  mbmi->ref_frame[1] = NONE_FRAME;
+  mbmi->motion_mode = SIMPLE_TRANSLATION;
+
+  const YV12_BUFFER_CONFIG *yv12 = get_ref_frame_yv12_buf(cm, ref);
+  const YV12_BUFFER_CONFIG *scaled_ref_frame =
+      av1_get_scaled_ref_frame(cpi, ref);
+  struct buf_2d backup_yv12;
+  // ref_mv is used to code the motion vector. ref_mv_full is the initial point.
+  // ref_mv is in units of 1/8 pel whereas ref_mv_full is in units of pel.
+  MV ref_mv = { 0, 0 };
+  const int step_param = cpi->mv_step_param;
+  const MvLimits tmp_mv_limits = x->mv_limits;
+  const SEARCH_METHODS search_methods = NSTEP;
+  const int do_mesh_search = 0;
+  const int sadpb = x->sadperbit16;
+  int cost_list[5];
+  const int ref_idx = 0;
+  int var;
+
+  if (scaled_ref_frame) {
+    backup_yv12 = xd->plane[AOM_PLANE_Y].pre[ref_idx];
+    av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
+                         num_planes);
+  } else {
+    av1_setup_pre_planes(xd, ref_idx, yv12, mi_row, mi_col,
+                         get_ref_scale_factors(cm, ref), num_planes);
+  }
+
+  // This overwrites the mv_limits so we will need to restore it later.
+  av1_set_mv_search_range(&x->mv_limits, &ref_mv);
+  var = av1_full_pixel_search(
+      cpi, x, bsize, &ref_mv_full, step_param, search_methods, do_mesh_search,
+      sadpb, cond_cost_list(cpi, cost_list), &ref_mv, INT_MAX, 1,
+      mi_col * MI_SIZE, mi_row * MI_SIZE, 0, &cpi->ss_cfg[SS_CFG_SRC]);
+  // Restore
+  x->mv_limits = tmp_mv_limits;
+
+  const int use_subpel_search =
+      var < INT_MAX && !cpi->common.cur_frame_force_integer_mv && use_subpixel;
+  if (use_subpel_search) {
+    int not_used = 0;
+    if (cpi->sf.use_accurate_subpel_search) {
+      const int pw = block_size_wide[bsize];
+      const int ph = block_size_high[bsize];
+      cpi->find_fractional_mv_step(
+          x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
+          x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
+          cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
+          x->nmv_vec_cost, x->mv_cost_stack, &not_used, &x->pred_sse[ref], NULL,
+          NULL, 0, 0, pw, ph, cpi->sf.use_accurate_subpel_search, 1);
+    } else {
+      cpi->find_fractional_mv_step(
+          x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
+          x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
+          cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
+          x->nmv_vec_cost, x->mv_cost_stack, &not_used, &x->pred_sse[ref], NULL,
+          NULL, 0, 0, 0, 0, 0, 1);
+    }
+  } else {
+    // Manually convert from units of pixel to 1/8-pixels if we are not doing
+    // subpel search
+    x->best_mv.as_mv.row *= 8;
+    x->best_mv.as_mv.col *= 8;
+  }
+
+  mbmi->mv[0].as_mv = x->best_mv.as_mv;
+
+  // Get a copy of the prediction output
+  set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
+  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
+                                AOM_PLANE_Y, AOM_PLANE_Y);
+
+  aom_clear_system_state();
+
+  if (scaled_ref_frame) {
+    xd->plane[AOM_PLANE_Y].pre[ref_idx] = backup_yv12;
+  }
+}
+
+void av1_simple_motion_sse_var(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
+                               int mi_col, BLOCK_SIZE bsize,
+                               const MV ref_mv_full, int use_subpixel,
+                               unsigned int *sse, unsigned int *var) {
+  MACROBLOCKD *xd = &x->e_mbd;
+  const MV_REFERENCE_FRAME ref =
+      cpi->rc.is_src_frame_alt_ref ? ALTREF_FRAME : LAST_FRAME;
+
+  av1_simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref, ref_mv_full, 1,
+                           use_subpixel);
+
+  const uint8_t *src = x->plane[0].src.buf;
+  const int src_stride = x->plane[0].src.stride;
+  const uint8_t *dst = xd->plane[0].dst.buf;
+  const int dst_stride = xd->plane[0].dst.stride;
+
+  *var = cpi->fn_ptr[bsize].vf(src, src_stride, dst, dst_stride, sse);
+}
diff --git a/av1/encoder/mcomp.h b/av1/encoder/mcomp.h
index 4131a36..71547da 100644
--- a/av1/encoder/mcomp.h
+++ b/av1/encoder/mcomp.h
@@ -13,6 +13,7 @@
 #define AOM_AV1_ENCODER_MCOMP_H_
 
 #include "av1/encoder/block.h"
+
 #include "aom_dsp/variance.h"
 
 #ifdef __cplusplus
@@ -161,6 +162,19 @@
                                   int mi_row, int mi_col, int *pts0,
                                   int *pts_inref0, int total_samples);
 
+// Performs a motion search in SIMPLE_TRANSLATION mode using reference frame
+// ref. Note that this sets the offset of mbmi, so we will need to reset it
+// after calling this function.
+void av1_simple_motion_search(struct AV1_COMP *const cpi, MACROBLOCK *x,
+                              int mi_row, int mi_col, BLOCK_SIZE bsize, int ref,
+                              MV ref_mv_full, int num_planes, int use_subpixel);
+
+// Performs a simple motion search to calculate the sse and var of the residue
+void av1_simple_motion_sse_var(struct AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
+                               int mi_col, BLOCK_SIZE bsize,
+                               const MV ref_mv_full, int use_subpixel,
+                               unsigned int *sse, unsigned int *var);
+
 static INLINE void av1_set_fractional_mv(int_mv *fractional_best_mv) {
   for (int z = 0; z < 3; z++) {
     fractional_best_mv[z].as_int = INVALID_MV;
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
new file mode 100644
index 0000000..1f587cc
--- /dev/null
+++ b/av1/encoder/partition_strategy.c
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2019, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "aom_ports/system_state.h"
+
+#include "av1/common/enums.h"
+#include "av1/common/reconinter.h"
+
+#include "av1/encoder/encoder.h"
+#include "av1/encoder/partition_model_weights.h"
+#include "av1/encoder/partition_strategy.h"
+
+// Performs a simple_motion_search with a single reference frame and extract
+// the variance of residues. Here features is assumed to be a length 6 array.
+// After this function is called, we will store the following in to features:
+// features[0] = log(1 + dc_q**2/256)
+// features[1] = log(1 + variance_of_residue)
+// for i in [2, 3, 4, 5]:
+//  features[i] = log(1 + variance_of_residue_in_block[i]/variance_of_residue)
+static void get_res_var_features(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
+                                 int mi_col, BLOCK_SIZE bsize,
+                                 float *features) {
+  // TODO(chiyotsai@google.com): The data this model trained on did not also use
+  // SIMPLE_TRANSLATION to build the inter_predictor. Retraining and tuning the
+  // model with the correct data should give better performance.
+  assert(mi_size_wide[bsize] == mi_size_high[bsize]);
+
+  MACROBLOCKD *xd = &x->e_mbd;
+
+  // Perform a single motion search in Y_PLANE to make a prediction
+  const int use_subpixel = 0;
+
+  // Start getting the features
+  int f_idx = 0;
+
+  // Q_INDEX
+  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
+  aom_clear_system_state();
+  features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
+
+  // VARIANCE
+  unsigned int sse = 0;
+  unsigned int var = 0;
+  const MV ref_mv_full = { .row = 0, .col = 0 };
+  av1_simple_motion_sse_var(cpi, x, mi_row, mi_col, bsize, ref_mv_full,
+                            use_subpixel, &sse, &var);
+  aom_clear_system_state();
+  features[f_idx++] = logf(1.0f + (float)var);
+
+  // Regional
+  const uint8_t *src = x->plane[0].src.buf;
+  const int src_stride = x->plane[0].src.stride;
+  const uint8_t *dst = xd->plane[0].dst.buf;
+  const int dst_stride = xd->plane[0].dst.stride;
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
+  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+  int r_idx = 0;
+  for (r_idx = 0; r_idx < 4; r_idx++) {
+    const int x_idx = (r_idx & 1) * bw / 2;
+    const int y_idx = (r_idx >> 1) * bh / 2;
+    const int src_offset = y_idx * src_stride + x_idx;
+    const int dst_offset = y_idx * dst_stride + x_idx;
+    const unsigned int sub_var = cpi->fn_ptr[subsize].vf(
+        src + src_offset, src_stride, dst + dst_offset, dst_stride, &sse);
+    aom_clear_system_state();
+    const float var_ratio = (1.0f + (float)sub_var) / (4.0f + (float)var);
+    features[f_idx++] = var_ratio;
+  }
+}
+
+void av1_simple_motion_search_based_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
+    BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
+    int *partition_vert_allowed, int *do_rectangular_split) {
+  const NN_CONFIG *nn_config = NULL;
+  float split_only_thresh = 0.0f;
+  if (bsize == BLOCK_128X128) {
+    nn_config = &av1_simple_motion_search_based_split_nn_config_128;
+    split_only_thresh = av1_simple_motion_search_based_split_thresh_128;
+  } else if (bsize == BLOCK_64X64) {
+    nn_config = &av1_simple_motion_search_based_split_nn_config_64;
+    split_only_thresh = av1_simple_motion_search_based_split_thresh_64;
+  } else if (bsize == BLOCK_32X32) {
+    nn_config = &av1_simple_motion_search_based_split_nn_config_32;
+    split_only_thresh = av1_simple_motion_search_based_split_thresh_32;
+  } else if (bsize == BLOCK_16X16) {
+    nn_config = &av1_simple_motion_search_based_split_nn_config_16;
+    split_only_thresh = av1_simple_motion_search_based_split_thresh_16;
+  } else if (bsize == BLOCK_8X8) {
+    // Disable BLOCK_8X8 for now
+#if !CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
+    nn_config = &av1_simple_motion_search_based_split_nn_config_8;
+    split_only_thresh = av1_simple_motion_search_based_split_thresh_8;
+#endif
+  } else {
+    assert(0 && "Unexpected block size in simple_motion_based_split");
+  }
+  if (nn_config) {
+    float features[6] = { 0 };
+    float score = 0;
+    get_res_var_features(cpi, x, mi_row, mi_col, bsize, features);
+    av1_nn_predict(features, nn_config, &score);
+
+    if (score > split_only_thresh) {
+      *partition_none_allowed = 0;
+      *partition_horz_allowed = 0;
+      *partition_vert_allowed = 0;
+      *do_rectangular_split = 0;
+    }
+  }
+}
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
new file mode 100644
index 0000000..8ec32c5
--- /dev/null
+++ b/av1/encoder/partition_strategy.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2019, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AV1_ENCODER_PARTITION_STRATEGY_H_
+#define AOM_AV1_ENCODER_PARTITION_STRATEGY_H_
+
+#include "av1/encoder/encodeframe.h"
+#include "av1/encoder/encodemb.h"
+#include "av1/encoder/encoder.h"
+
+// Performs a simple_motion_search with a single reference frame and extract
+// the variance of residues. Then use the features to determine whether we want
+// to go straight to splitting without trying PARTITION_NONE
+void av1_simple_motion_search_based_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
+    BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
+    int *partition_vert_allowed, int *do_rectangular_split);
+
+// A simplified version of set_offsets meant to be used for
+// simple_motion_search.
+static INLINE void set_offsets_for_motion_search(const AV1_COMP *const cpi,
+                                                 MACROBLOCK *const x,
+                                                 int mi_row, int mi_col,
+                                                 BLOCK_SIZE bsize) {
+  const AV1_COMMON *const cm = &cpi->common;
+  const int num_planes = av1_num_planes(cm);
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int mi_width = mi_size_wide[bsize];
+  const int mi_height = mi_size_high[bsize];
+
+  set_mode_info_offsets(cpi, x, xd, mi_row, mi_col);
+
+  // Set up destination pointers.
+  av1_setup_dst_planes(xd->plane, bsize, &cm->cur_frame->buf, mi_row, mi_col, 0,
+                       num_planes);
+
+  // Set up limit values for MV components.
+  // Mv beyond the range do not produce new/different prediction block.
+  x->mv_limits.row_min =
+      -(((mi_row + mi_height) * MI_SIZE) + AOM_INTERP_EXTEND);
+  x->mv_limits.col_min = -(((mi_col + mi_width) * MI_SIZE) + AOM_INTERP_EXTEND);
+  x->mv_limits.row_max = (cm->mi_rows - mi_row) * MI_SIZE + AOM_INTERP_EXTEND;
+  x->mv_limits.col_max = (cm->mi_cols - mi_col) * MI_SIZE + AOM_INTERP_EXTEND;
+
+  set_plane_n4(xd, mi_width, mi_height, num_planes);
+
+  // Set up distance of MB to edge of frame in 1/8th pel units.
+  assert(!(mi_col & (mi_width - 1)) && !(mi_row & (mi_height - 1)));
+  xd->mb_to_top_edge = -((mi_row * MI_SIZE) * 8);
+  xd->mb_to_bottom_edge = ((cm->mi_rows - mi_height - mi_row) * MI_SIZE) * 8;
+  xd->mb_to_left_edge = -((mi_col * MI_SIZE) * 8);
+  xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
+
+  // Set up source buffers.
+  av1_setup_src_planes(x, cpi->source, mi_row, mi_col, num_planes, bsize);
+
+  // R/D setup.
+  x->rdmult = cpi->rd.RDMULT;
+}
+#endif  // AOM_AV1_ENCODER_PARTITION_STRATEGY_H_
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index a2c8a57..7ba1b18 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -116,10 +116,6 @@
 unsigned int av1_high_get_sby_perpixel_variance(const struct AV1_COMP *cpi,
                                                 const struct buf_2d *ref,
                                                 BLOCK_SIZE bs, int bd);
-void av1_simple_motion_sse_var(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
-                               int mi_col, BLOCK_SIZE bsize, MV ref_mv_full,
-                               int use_subpixel, unsigned int *sse,
-                               unsigned int *var);
 
 void av1_rd_pick_inter_mode_sb(struct AV1_COMP *cpi,
                                struct TileDataEnc *tile_data,