Merge tag 'v3.6.1' into HEAD

libaom v3.6.1 release

2023-05-08 v3.6.1
  This release includes several bug fixes. This release is ABI
  compatible with the last release. See
  https://aomedia.googlesource.com/aom/+log/v3.6.0..v3.6.1 for all the
  commits in this release.

  - Bug Fixes
    * aomedia:2871: Guard the support of the 7.x and 8.x levels for AV1
      under the CONFIG_CWG_C013 config flag, and only output the 7.x and
      8.x levels when explicitly requested.
    * aomedia:3382: Choose sb_size by ppi instead of svc.
    * aomedia:3384: Fix fullpel search limits.
    * aomedia:3388: Replace left shift of xq_active by multiplication
      (see the illustrative sketch after this list).
    * aomedia:3389: Fix MV clamping in av1_mv_pred.
    * aomedia:3390: set_ld_layer_depth: cap max_layer_depth to
      MAX_ARF_LAYERS.
    * aomedia:3418: Fix MV clamping in av1_int_pro_motion_estimation.
    * aomedia:3429: Move lpf thread data init to lpf_pipeline_mt_init().
    * b:266719111: Fix undefined behavior in Arm Neon code.
    * b:269840681: nonrd_opt: align scan tables.
    * rtc: Fix is_key_frame setting in variance partition.
    * Build: Fix build with clang-cl and Visual Studio.
    * Build: Fix module definition file for MinGW/MSYS.
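
  For illustration of the aomedia:3388 item above: left-shifting a signed
  value that is negative, or shifting far enough to overflow the signed
  result, is undefined behavior in C, while the equivalent widened
  multiplication is well defined. The sketch below shows the general
  pattern only; it is not the actual libaom change, and the names are
  illustrative.

      #include <stdint.h>

      /* Illustrative only: scale a signed quantity by a power of two. */
      static int64_t scale_activity(int32_t xq_active, int shift) {
        /* Risky form: if xq_active is negative, `xq_active << shift` is
         * undefined behavior, and the shift can also overflow int32_t. */
        /* return xq_active << shift; */

        /* Safer form: widen first, then multiply by the power of two. */
        return (int64_t)xq_active * ((int64_t)1 << shift);
      }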

Change-Id: I98f8033273b1ba51ad59cd532b9ce64bf419ebc4
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 87d88fa..6fb7362 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,7 +11,7 @@
 if(CONFIG_TFLITE)
   cmake_minimum_required(VERSION 3.11)
 else()
-  cmake_minimum_required(VERSION 3.7)
+  cmake_minimum_required(VERSION 3.9)
 endif()
 
 set(AOM_ROOT "${CMAKE_CURRENT_SOURCE_DIR}")
@@ -41,6 +41,13 @@
   endif()
 endif()
 
+if(MSVC AND MSVC_VERSION LESS 1920)
+  message(
+    WARNING
+      "MSVC versions prior to 2019 (v16) are not supported and may generate"
+      " incorrect code!")
+endif()
+
 # Library version info. Update LT_CURRENT, LT_REVISION and LT_AGE when making a
 # public release by following the guidelines in the libtool document:
 # https://www.gnu.org/software/libtool/manual/libtool.html#Updating-version-info
@@ -210,13 +217,9 @@
 include_directories(${AOM_ROOT} ${AOM_CONFIG_DIR} ${AOM_ROOT}/apps
                     ${AOM_ROOT}/common ${AOM_ROOT}/examples ${AOM_ROOT}/stats)
 
-if(CONFIG_RUNTIME_CPU_DETECT AND ANDROID_NDK)
-  include_directories(${ANDROID_NDK}/sources/android/cpufeatures)
-endif()
-
 # Targets
 add_library(aom_version ${AOM_VERSION_SOURCES})
-add_dummy_source_file_to_target(aom_version c)
+add_no_op_source_file_to_target(aom_version c)
 add_custom_command(OUTPUT "${AOM_CONFIG_DIR}/config/aom_version.h"
                    COMMAND ${CMAKE_COMMAND} ARGS
                            -DAOM_CONFIG_DIR=${AOM_CONFIG_DIR}
@@ -270,11 +273,34 @@
   set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_encoder_stats)
 endif()
 
-add_library(aom ${AOM_SOURCES} $<TARGET_OBJECTS:aom_rtcd>)
-if(BUILD_SHARED_LIBS)
-  add_library(aom_static STATIC ${AOM_SOURCES} $<TARGET_OBJECTS:aom_rtcd>)
-  set_target_properties(aom_static PROPERTIES OUTPUT_NAME aom)
+# Xcode generator cannot take a library composed solely of objects. See
+# https://gitlab.kitware.com/cmake/cmake/-/issues/17500
+if(XCODE)
+  set(target_objs_aom ${AOM_SOURCES})
+else()
+  add_library(aom_obj OBJECT ${AOM_SOURCES})
+  set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_obj)
+  set(target_objs_aom $<TARGET_OBJECTS:aom_obj>)
+endif()
+add_library(aom ${target_objs_aom} $<TARGET_OBJECTS:aom_rtcd>)
 
+if(BUILD_SHARED_LIBS)
+  add_library(aom_static STATIC ${target_objs_aom} $<TARGET_OBJECTS:aom_rtcd>)
+  if(MSVC OR (WIN32 AND NOT MINGW))
+    # Fix race condition on the export library file between the two versions.
+    # Affects MSVC in all three flavors (stock, Clang/CL, LLVM-- the latter sets
+    # MSVC and MINGW both to FALSE).
+    set_target_properties(aom PROPERTIES ARCHIVE_OUTPUT_NAME "aom_shared")
+  endif()
+  # Tests and examples rely on having unexported AV1 symbols visible. This is
+  # not the case on Windows platforms due to enforcement of the module
+  # definition file.
+  set(AOM_LIB aom_static)
+else()
+  set(AOM_LIB aom)
+endif()
+
+if(BUILD_SHARED_LIBS)
   if(NOT MSVC)
     # Extract version string and set VERSION/SOVERSION for the aom target.
     extract_version_string("${AOM_CONFIG_DIR}/config/aom_version.h"
@@ -304,7 +330,7 @@
   endif()
 endif()
 
-if(CONFIG_AV1_RC_RTC AND CONFIG_AV1_ENCODER AND NOT BUILD_SHARED_LIBS)
+if(CONFIG_AV1_ENCODER AND NOT CONFIG_REALTIME_ONLY AND NOT BUILD_SHARED_LIBS)
   list(APPEND AOM_AV1_RC_SOURCES "${AOM_ROOT}/av1/ratectrl_rtc.h"
               "${AOM_ROOT}/av1/ratectrl_rtc.cc")
   add_library(aom_av1_rc ${AOM_AV1_RC_SOURCES})
@@ -312,33 +338,13 @@
   if(NOT WIN32 AND NOT APPLE)
     target_link_libraries(aom_av1_rc ${AOM_LIB_LINK_TYPE} m)
   endif()
-endif()
-
-if(CONFIG_AV1_ENCODER AND NOT CONFIG_REALTIME_ONLY AND NOT BUILD_SHARED_LIBS)
-  list(APPEND AOM_AV1_RC_QMODE_SOURCES
-              "${AOM_ROOT}/av1/qmode_rc/ratectrl_qmode_interface.h"
-              "${AOM_ROOT}/av1/qmode_rc/ratectrl_qmode_interface.cc"
-              "${AOM_ROOT}/av1/qmode_rc/reference_manager.h"
-              "${AOM_ROOT}/av1/qmode_rc/reference_manager.cc"
-              "${AOM_ROOT}/av1/qmode_rc/ratectrl_qmode.h"
-              "${AOM_ROOT}/av1/qmode_rc/ratectrl_qmode.cc"
-              "${AOM_ROOT}/av1/qmode_rc/ducky_encode.h"
-              "${AOM_ROOT}/av1/qmode_rc/ducky_encode.cc")
-  add_library(av1_rc_qmode ${AOM_AV1_RC_QMODE_SOURCES})
-  target_link_libraries(av1_rc_qmode ${AOM_LIB_LINK_TYPE} aom)
-  if(NOT MSVC AND NOT APPLE)
-    target_link_libraries(av1_rc_qmode ${AOM_LIB_LINK_TYPE} m)
-  endif()
-  set_target_properties(av1_rc_qmode PROPERTIES LINKER_LANGUAGE CXX)
+  set_target_properties(aom_av1_rc PROPERTIES LINKER_LANGUAGE CXX)
 endif()
 
 # List of object and static library targets.
 set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_rtcd aom_mem aom_scale aom)
-if(CONFIG_AV1_RC_RTC AND CONFIG_AV1_ENCODER AND NOT BUILD_SHARED_LIBS)
-  set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_av1_rc)
-endif()
 if(CONFIG_AV1_ENCODER AND NOT CONFIG_REALTIME_ONLY AND NOT BUILD_SHARED_LIBS)
-  set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} av1_rc_qmode)
+  set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_av1_rc)
 endif()
 if(BUILD_SHARED_LIBS)
   set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_static)
@@ -356,19 +362,22 @@
 setup_av1_targets()
 
 # Make all library targets depend on aom_rtcd to make sure it builds first.
-foreach(aom_lib ${AOM_LIB_TARGETS})
-  if(NOT "${aom_lib}" STREQUAL "aom_rtcd")
-    add_dependencies(${aom_lib} aom_rtcd)
+foreach(aom_tgt ${AOM_LIB_TARGETS})
+  if(NOT "${aom_tgt}" STREQUAL "aom_rtcd")
+    add_dependencies(${aom_tgt} aom_rtcd)
   endif()
 endforeach()
 
-# Generate C/C++ stub files containing the function usage_exit(). Users of the
+# Generate C/C++ files containing the function usage_exit(). Users of the
 # aom_common_app_util library must define this function. This is a convenience
 # to allow omission of the function from applications that might want to use
 # other pieces of the util support without defining usage_exit().
-file(WRITE "${AOM_GEN_SRC_DIR}/usage_exit.c" "void usage_exit(void) {}")
+file(WRITE "${AOM_GEN_SRC_DIR}/usage_exit.c"
+     "#include <stdlib.h>\n\n#include \"common/tools_common.h\"\n\n"
+     "void usage_exit(void) { exit(EXIT_FAILURE); }\n")
 file(WRITE "${AOM_GEN_SRC_DIR}/usage_exit.cc"
-     "extern \"C\" void usage_exit(void) {}")
+     "#include <stdlib.h>\n\n#include \"common/tools_common.h\"\n\n"
+     "extern \"C\" void usage_exit(void) { exit(EXIT_FAILURE); }\n")
 
 #
 # Application and application support targets.
@@ -727,7 +736,7 @@
 endif()
 
 foreach(aom_app ${AOM_APP_TARGETS})
-  target_link_libraries(${aom_app} ${AOM_LIB_LINK_TYPE} aom)
+  target_link_libraries(${aom_app} ${AOM_LIB_LINK_TYPE} ${AOM_LIB})
 endforeach()
 
 if(ENABLE_EXAMPLES OR ENABLE_TESTS OR ENABLE_TOOLS)
@@ -795,7 +804,7 @@
     # here, it really is the Xcode generator's fault, or just a deficiency in
     # Xcode itself.
     foreach(aom_app ${AOM_APP_TARGETS})
-      add_dummy_source_file_to_target("${aom_app}" "cc")
+      add_no_op_source_file_to_target("${aom_app}" "cc")
     endforeach()
   endif()
 endif()
@@ -935,7 +944,7 @@
 get_cmake_property(all_cmake_vars VARIABLES)
 foreach(var ${all_cmake_vars})
   if("${var}" MATCHES "SOURCES$\|_INTRIN_\|_ASM_"
-     AND NOT "${var}" MATCHES "_APP_\|DOXYGEN\|LIBWEBM\|LIBYUV\|_PKG_\|TEST")
+     AND NOT "${var}" MATCHES "DOXYGEN\|LIBYUV\|_PKG_\|TEST")
     list(APPEND aom_source_vars ${var})
   endif()
 endforeach()
diff --git a/README.md b/README.md
index 0d51080d..d7b66e0 100644
--- a/README.md
+++ b/README.md
@@ -217,27 +217,26 @@
 ### Microsoft Visual Studio builds {#microsoft-visual-studio-builds}
 
 Building the AV1 codec library in Microsoft Visual Studio is supported. Visual
-Studio 2017 (15.0) or later is required. The following example demonstrates
+Studio 2019 (16.0) or later is required. The following example demonstrates
 generating projects and a solution for the Microsoft IDE:
 
 ~~~
     # This does not require a bash shell; Command Prompt (cmd.exe) is fine.
     # This assumes the build host is a Windows x64 computer.
 
-    # To build with Visual Studio 2019 for the x64 target:
+    # To create a Visual Studio 2022 solution for the x64 target:
+    $ cmake path/to/aom -G "Visual Studio 17 2022"
+
+    # To create a Visual Studio 2022 solution for the 32-bit x86 target:
+    $ cmake path/to/aom -G "Visual Studio 17 2022" -A Win32
+
+    # To create a Visual Studio 2019 solution for the x64 target:
     $ cmake path/to/aom -G "Visual Studio 16 2019"
-    $ cmake --build .
 
-    # To build with Visual Studio 2019 for the 32-bit x86 target:
+    # To create a Visual Studio 2019 solution for the 32-bit x86 target:
     $ cmake path/to/aom -G "Visual Studio 16 2019" -A Win32
-    $ cmake --build .
 
-    # To build with Visual Studio 2017 for the x64 target:
-    $ cmake path/to/aom -G "Visual Studio 15 2017" -T host=x64 -A x64
-    $ cmake --build .
-
-    # To build with Visual Studio 2017 for the 32-bit x86 target:
-    $ cmake path/to/aom -G "Visual Studio 15 2017" -T host=x64
+    # To build the solution:
     $ cmake --build .
 ~~~
 
@@ -575,12 +574,19 @@
 `Generate Password` Password link at the top of the page. You’ll be given
 instructions for creating a cookie to use with our Git repos.
 
+You must also have a Gerrit account associated with your Google account. To do
+this, visit the [Gerrit review server](https://aomedia-review.googlesource.com)
+and click "Sign in" (top right).
+
 ### Contributor agreement {#contributor-agreement}
 
 You will be required to execute a
 [contributor agreement](http://aomedia.org/license) to ensure that the AOMedia
 Project has the right to distribute your changes.
 
+Note: If you are pushing changes on behalf of an Alliance for Open Media member
+organization, this step is not necessary.
+
 ### Testing your code {#testing-your-code}
 
 The testing basics are covered in the [testing section](#testing-the-av1-codec)
diff --git a/aom/aom_codec.h b/aom/aom_codec.h
index 6a9fb7b..d5b8790 100644
--- a/aom/aom_codec.h
+++ b/aom/aom_codec.h
@@ -417,19 +417,21 @@
  * \param[in]    ctx     Pointer to this instance's context.
  *
  */
-const char *aom_codec_error(aom_codec_ctx_t *ctx);
+const char *aom_codec_error(const aom_codec_ctx_t *ctx);
 
 /*!\brief Retrieve detailed error information for codec context
  *
  * Returns a human readable string providing detailed information about
- * the last error.
+ * the last error. The returned string is only valid until the next
+ * aom_codec_* function call (except aom_codec_error and
+ * aom_codec_error_detail) on the codec context.
  *
  * \param[in]    ctx     Pointer to this instance's context.
  *
  * \retval NULL
  *     No detailed information is available.
  */
-const char *aom_codec_error_detail(aom_codec_ctx_t *ctx);
+const char *aom_codec_error_detail(const aom_codec_ctx_t *ctx);
 
 /* REQUIRED FUNCTIONS
  *
@@ -444,9 +446,11 @@
  * \param[in] ctx   Pointer to this instance's context
  *
  * \retval #AOM_CODEC_OK
- *     The codec algorithm initialized.
- * \retval #AOM_CODEC_MEM_ERROR
- *     Memory allocation failed.
+ *     The codec instance has been destroyed.
+ * \retval #AOM_CODEC_INVALID_PARAM
+ *     ctx is a null pointer.
+ * \retval #AOM_CODEC_ERROR
+ *     Codec context not initialized.
  */
 aom_codec_err_t aom_codec_destroy(aom_codec_ctx_t *ctx);
 
diff --git a/aom/aom_decoder.h b/aom/aom_decoder.h
index 5ce7c7b..f3f11d8 100644
--- a/aom/aom_decoder.h
+++ b/aom/aom_decoder.h
@@ -113,7 +113,7 @@
  * \param[in]    ver     ABI version number. Must be set to
  *                       AOM_DECODER_ABI_VERSION
  * \retval #AOM_CODEC_OK
- *     The decoder algorithm initialized.
+ *     The decoder algorithm has been initialized.
  * \retval #AOM_CODEC_MEM_ERROR
  *     Memory allocation failed.
  */
diff --git a/aom/aom_encoder.h b/aom/aom_encoder.h
index c0efe79..e3d8d29 100644
--- a/aom/aom_encoder.h
+++ b/aom/aom_encoder.h
@@ -903,7 +903,7 @@
 
 /*!\brief Initialize an encoder instance
  *
- * Initializes a encoder context using the given interface. Applications
+ * Initializes an encoder context using the given interface. Applications
  * should call the aom_codec_enc_init convenience macro instead of this
  * function directly, to ensure that the ABI version number parameter
  * is properly initialized.
@@ -912,6 +912,9 @@
  * is not thread safe and should be guarded with a lock if being used
  * in a multithreaded context.
  *
+ * If aom_codec_enc_init_ver() fails, it is not necessary to call
+ * aom_codec_destroy() on the encoder context.
+ *
  * \param[in]    ctx     Pointer to this instance's context.
  * \param[in]    iface   Pointer to the algorithm interface to use.
  * \param[in]    cfg     Configuration to use, if known.
@@ -919,7 +922,7 @@
  * \param[in]    ver     ABI version number. Must be set to
  *                       AOM_ENCODER_ABI_VERSION
  * \retval #AOM_CODEC_OK
- *     The decoder algorithm initialized.
+ *     The encoder algorithm has been initialized.
  * \retval #AOM_CODEC_MEM_ERROR
  *     Memory allocation failed.
  */
@@ -1024,6 +1027,10 @@
  * \param[in]    img       Image data to encode, NULL to flush.
  *                         Encoding sample values outside the range
  *                         [0..(1<<img->bit_depth)-1] is undefined behavior.
+ *                         Note: Although img is declared as a const pointer,
+ *                         if AV1E_SET_DENOISE_NOISE_LEVEL is set to a nonzero
+ *                         value aom_codec_encode() modifies (denoises) the
+ *                         samples in img->planes[i].
  * \param[in]    pts       Presentation time stamp, in timebase units. If img
  *                         is NULL, pts is ignored.
  * \param[in]    duration  Duration to show frame, in timebase units. If img
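
The note added above says a failed aom_codec_enc_init_ver() needs no matching
aom_codec_destroy(), and the error strings remain safe to query. A minimal
sketch of that flow, assuming the AV1 encoder interface from aomcx.h; the
AOM_USAGE_REALTIME usage and the rest of the configuration are placeholders:

    #include <stdio.h>

    #include "aom/aomcx.h"

    static int init_av1_encoder(aom_codec_ctx_t *ctx, int width, int height) {
      aom_codec_iface_t *iface = aom_codec_av1_cx();
      aom_codec_enc_cfg_t cfg;
      if (aom_codec_enc_config_default(iface, &cfg, AOM_USAGE_REALTIME) !=
          AOM_CODEC_OK) {
        return -1;
      }
      cfg.g_w = width;
      cfg.g_h = height;
      if (aom_codec_enc_init(ctx, iface, &cfg, 0) != AOM_CODEC_OK) {
        // No aom_codec_destroy() is needed after a failed init; the error
        // strings can still be read from ctx.
        const char *detail = aom_codec_error_detail(ctx);
        fprintf(stderr, "init failed: %s\n",
                detail ? detail : aom_codec_error(ctx));
        return -1;
      }
      return 0;
    }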
diff --git a/aom/aomcx.h b/aom/aomcx.h
index 906cf2a..8887e9a 100644
--- a/aom/aomcx.h
+++ b/aom/aomcx.h
@@ -1481,6 +1481,38 @@
    */
   AV1E_ENABLE_SB_QP_SWEEP = 158,
 
+  /*!\brief Codec control to set quantizer for the next frame, int parameter.
+   *
+   * - Valid range [0, 63]
+   *
+   * This will turn off cyclic refresh. Only applicable to 1-pass.
+   */
+  AV1E_SET_QUANTIZER_ONE_PASS = 159,
+
+  /*!\brief Codec control to enable the rate distribution guided delta
+   * quantization in all intra mode, unsigned int parameter
+   *
+   * - 0 = disable (default)
+   * - 1 = enable
+   *
+   * \attention This feature requires --deltaq-mode=3, as well as an input
+   *            file containing the rate distribution for each 16x16 block,
+   *            passed in by --rate-distribution-info=rate_distribution.txt.
+   */
+  AV1E_ENABLE_RATE_GUIDE_DELTAQ = 160,
+
+  /*!\brief Codec control to set the input file for rate distribution used
+   * in all intra mode, const char * parameter
+   * The input should be the name of a text file, which
+   * contains (rows x cols) float values separated by space.
+   * Each float value represent the number of bits for each 16x16 block.
+   * rows = (frame_height + 15) / 16
+   * cols = (frame_width + 15) / 16
+   *
+   * \attention This feature requires --enable-rate-guide-deltaq=1.
+   */
+  AV1E_SET_RATE_DISTRIBUTION_INFO = 161,
+
   // Any new encoder control IDs should be added above.
   // Maximum allowed encoder control ID is 229.
   // No encoder control ID should be added below.
@@ -1497,7 +1529,9 @@
   AOME_THREEFOUR = 3,
   AOME_ONEFOUR = 4,
   AOME_ONEEIGHT = 5,
-  AOME_ONETWO = 6
+  AOME_ONETWO = 6,
+  AOME_TWOTHREE = 7,
+  AOME_ONETHREE = 8
 } AOM_SCALING_MODE;
 
 /*!\brief Max number of segments
@@ -1579,6 +1613,7 @@
   AOM_TUNE_VMAF_MAX_GAIN = 6,
   AOM_TUNE_VMAF_NEG_MAX_GAIN = 7,
   AOM_TUNE_BUTTERAUGLI = 8,
+  AOM_TUNE_VMAF_SALIENCY_MAP = 9,
 } aom_tune_metric;
 
 /*!\brief Distortion metric to use for RD optimization.
@@ -1608,7 +1643,12 @@
   int temporal_layer_id; /**< Temporal layer ID */
 } aom_svc_layer_id_t;
 
-/*!brief Parameter type for SVC */
+/*!\brief Parameter type for SVC
+ *
+ * In the arrays of size AOM_MAX_LAYERS, the index for spatial layer `sl` and
+ * temporal layer `tl` is sl * number_temporal_layers + tl.
+ *
+ */
 typedef struct aom_svc_params {
   int number_spatial_layers;                 /**< Number of spatial layers */
   int number_temporal_layers;                /**< Number of temporal layers */
@@ -1616,7 +1656,7 @@
   int min_quantizers[AOM_MAX_LAYERS];        /**< Min Q for each layer */
   int scaling_factor_num[AOM_MAX_SS_LAYERS]; /**< Scaling factor-numerator */
   int scaling_factor_den[AOM_MAX_SS_LAYERS]; /**< Scaling factor-denominator */
-  /*! Target bitrate for each layer */
+  /*! Target bitrate for each layer, in kilobits per second */
   int layer_target_bitrate[AOM_MAX_LAYERS];
   /*! Frame rate factor for each temporal layer */
   int framerate_factor[AOM_MAX_TS_LAYERS];
@@ -2103,6 +2143,15 @@
 AOM_CTRL_USE_TYPE(AV1E_ENABLE_SB_QP_SWEEP, unsigned int)
 #define AOM_CTRL_AV1E_ENABLE_SB_QP_SWEEP
 
+AOM_CTRL_USE_TYPE(AV1E_SET_QUANTIZER_ONE_PASS, int)
+#define AOM_CTRL_AV1E_SET_QUANTIZER_ONE_PASS
+
+AOM_CTRL_USE_TYPE(AV1E_ENABLE_RATE_GUIDE_DELTAQ, unsigned int)
+#define AOM_CTRL_AV1E_ENABLE_RATE_GUIDE_DELTAQ
+
+AOM_CTRL_USE_TYPE(AV1E_SET_RATE_DISTRIBUTION_INFO, const char *)
+#define AOM_CTRL_AV1E_SET_RATE_DISTRIBUTION_INFO
+
 /*!\endcond */
 /*! @} - end defgroup aom_encoder */
 #ifdef __cplusplus
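
A minimal sketch of driving the new AV1E_SET_QUANTIZER_ONE_PASS control
declared above through aom_codec_control(), on an encoder context that has
already been initialized; the value passed in is an example, and the helper
name is illustrative:

    #include <stdio.h>

    #include "aom/aomcx.h"

    /* Request a fixed quantizer for the next frame in 1-pass encoding.
     * Valid range per the control's documentation is [0, 63]; setting it
     * also turns off cyclic refresh. */
    static int set_next_frame_q(aom_codec_ctx_t *ctx, int q) {
      if (aom_codec_control(ctx, AV1E_SET_QUANTIZER_ONE_PASS, q) !=
          AOM_CODEC_OK) {
        fprintf(stderr, "AV1E_SET_QUANTIZER_ONE_PASS failed: %s\n",
                aom_codec_error(ctx));
        return -1;
      }
      return 0;
    }

For example, call set_next_frame_q(&ctx, 30) before the corresponding
aom_codec_encode() call.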
diff --git a/aom/src/aom_codec.c b/aom/src/aom_codec.c
index bc2039a..4e75fcb 100644
--- a/aom/src/aom_codec.c
+++ b/aom/src/aom_codec.c
@@ -52,12 +52,12 @@
   return "Unrecognized error code";
 }
 
-const char *aom_codec_error(aom_codec_ctx_t *ctx) {
+const char *aom_codec_error(const aom_codec_ctx_t *ctx) {
   return (ctx) ? aom_codec_err_to_string(ctx->err)
                : aom_codec_err_to_string(AOM_CODEC_INVALID_PARAM);
 }
 
-const char *aom_codec_error_detail(aom_codec_ctx_t *ctx) {
+const char *aom_codec_error_detail(const aom_codec_ctx_t *ctx) {
   if (ctx && ctx->err)
     return ctx->priv ? ctx->priv->err_detail : ctx->err_detail;
 
@@ -81,7 +81,7 @@
 }
 
 aom_codec_caps_t aom_codec_get_caps(aom_codec_iface_t *iface) {
-  return (iface) ? iface->caps : 0;
+  return iface ? iface->caps : 0;
 }
 
 aom_codec_err_t aom_codec_control(aom_codec_ctx_t *ctx, int ctrl_id, ...) {
diff --git a/aom/src/aom_encoder.c b/aom/src/aom_encoder.c
index 6ec2f34..f9fe2fe 100644
--- a/aom/src/aom_encoder.c
+++ b/aom/src/aom_encoder.c
@@ -80,6 +80,10 @@
     res = ctx->iface->init(ctx);
 
     if (res) {
+      // IMPORTANT: ctx->priv->err_detail must be null or point to a string
+      // that remains valid after ctx->priv is destroyed, such as a C string
+      // literal. This makes it safe to call aom_codec_error_detail() after
+      // aom_codec_enc_init_ver() failed.
       ctx->err_detail = ctx->priv ? ctx->priv->err_detail : NULL;
       aom_codec_destroy(ctx);
     }
@@ -92,7 +96,6 @@
                                              aom_codec_enc_cfg_t *cfg,
                                              unsigned int usage) {
   aom_codec_err_t res;
-  int i;
 
   if (!iface || !cfg)
     res = AOM_CODEC_INVALID_PARAM;
@@ -101,22 +104,20 @@
   else {
     res = AOM_CODEC_INVALID_PARAM;
 
-    for (i = 0; i < iface->enc.cfg_count; ++i) {
+    for (int i = 0; i < iface->enc.cfg_count; ++i) {
       if (iface->enc.cfgs[i].g_usage == usage) {
         *cfg = iface->enc.cfgs[i];
         res = AOM_CODEC_OK;
+        /* default values */
+        memset(&cfg->encoder_cfg, 0, sizeof(cfg->encoder_cfg));
+        cfg->encoder_cfg.super_block_size = 0;  // Dynamic
+        cfg->encoder_cfg.max_partition_size = 128;
+        cfg->encoder_cfg.min_partition_size = 4;
+        cfg->encoder_cfg.disable_trellis_quant = 3;
         break;
       }
     }
   }
-  /* default values */
-  if (cfg) {
-    memset(&cfg->encoder_cfg, 0, sizeof(cfg->encoder_cfg));
-    cfg->encoder_cfg.super_block_size = 0;  // Dynamic
-    cfg->encoder_cfg.max_partition_size = 128;
-    cfg->encoder_cfg.min_partition_size = 4;
-    cfg->encoder_cfg.disable_trellis_quant = 3;
-  }
   return res;
 }
 
@@ -138,8 +139,10 @@
   const int float_excepts =           \
       feenableexcept(FE_DIVBYZERO | FE_UNDERFLOW | FE_OVERFLOW);
 #define FLOATING_POINT_RESTORE_EXCEPTIONS \
-  fedisableexcept(FE_ALL_EXCEPT);         \
-  feenableexcept(float_excepts);
+  if (float_excepts != -1) {              \
+    fedisableexcept(FE_ALL_EXCEPT);       \
+    feenableexcept(float_excepts);        \
+  }
 #else
 #define FLOATING_POINT_SET_EXCEPTIONS
 #define FLOATING_POINT_RESTORE_EXCEPTIONS
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index c5c2db7..f9dbc55f 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -117,7 +117,8 @@
             "${AOM_ROOT}/aom_dsp/arm/highbd_intrapred_neon.c"
             "${AOM_ROOT}/aom_dsp/arm/intrapred_neon.c"
             "${AOM_ROOT}/aom_dsp/arm/subtract_neon.c"
-            "${AOM_ROOT}/aom_dsp/arm/blend_a64_mask_neon.c")
+            "${AOM_ROOT}/aom_dsp/arm/blend_a64_mask_neon.c"
+            "${AOM_ROOT}/aom_dsp/arm/avg_pred_neon.c")
 
 if(CONFIG_AV1_HIGHBITDEPTH)
   list(APPEND AOM_DSP_COMMON_INTRIN_SSE2
@@ -176,7 +177,7 @@
 
   # Flow estimation library
   if(NOT CONFIG_REALTIME_ONLY)
-    list(APPEND AOM_DSP_ENCODER_SOURCES
+    list(APPEND AOM_DSP_ENCODER_SOURCES "${AOM_ROOT}/aom_dsp/pyramid.c"
                 "${AOM_ROOT}/aom_dsp/flow_estimation/corner_detect.c"
                 "${AOM_ROOT}/aom_dsp/flow_estimation/corner_match.c"
                 "${AOM_ROOT}/aom_dsp/flow_estimation/disflow.c"
@@ -184,7 +185,8 @@
                 "${AOM_ROOT}/aom_dsp/flow_estimation/ransac.c")
 
     list(APPEND AOM_DSP_ENCODER_INTRIN_SSE4_1
-                "${AOM_ROOT}/aom_dsp/flow_estimation/x86/corner_match_sse4.c")
+                "${AOM_ROOT}/aom_dsp/flow_estimation/x86/corner_match_sse4.c"
+                "${AOM_ROOT}/aom_dsp/flow_estimation/x86/disflow_sse4.c")
 
     list(APPEND AOM_DSP_ENCODER_INTRIN_AVX2
                 "${AOM_ROOT}/aom_dsp/flow_estimation/x86/corner_match_avx2.c")
@@ -208,7 +210,8 @@
               "${AOM_ROOT}/aom_dsp/x86/quantize_x86.h"
               "${AOM_ROOT}/aom_dsp/x86/blk_sse_sum_sse2.c"
               "${AOM_ROOT}/aom_dsp/x86/sum_squares_sse2.c"
-              "${AOM_ROOT}/aom_dsp/x86/variance_sse2.c")
+              "${AOM_ROOT}/aom_dsp/x86/variance_sse2.c"
+              "${AOM_ROOT}/aom_dsp/x86/jnt_sad_sse2.c")
 
   list(APPEND AOM_DSP_ENCODER_ASM_SSSE3_X86_64
               "${AOM_ROOT}/aom_dsp/x86/fwd_txfm_ssse3_x86_64.asm"
@@ -245,8 +248,7 @@
               "${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.c"
               "${AOM_ROOT}/aom_dsp/x86/quantize_ssse3.c"
               "${AOM_ROOT}/aom_dsp/x86/variance_impl_ssse3.c"
-              "${AOM_ROOT}/aom_dsp/x86/jnt_variance_ssse3.c"
-              "${AOM_ROOT}/aom_dsp/x86/jnt_sad_ssse3.c")
+              "${AOM_ROOT}/aom_dsp/x86/jnt_variance_ssse3.c")
 
   list(APPEND AOM_DSP_ENCODER_INTRIN_SSE4_1
               "${AOM_ROOT}/aom_dsp/x86/avg_intrin_sse4.c"
@@ -254,12 +256,15 @@
               "${AOM_ROOT}/aom_dsp/x86/obmc_sad_sse4.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_variance_sse4.c")
 
-  list(APPEND AOM_DSP_ENCODER_INTRIN_NEON "${AOM_ROOT}/aom_dsp/arm/sad4d_neon.c"
+  list(APPEND AOM_DSP_ENCODER_INTRIN_NEON
+              "${AOM_ROOT}/aom_dsp/arm/sad4d_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/sad_neon.c"
+              "${AOM_ROOT}/aom_dsp/arm/masked_sad_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/subpel_variance_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/variance_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/hadamard_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/avg_neon.c"
+              "${AOM_ROOT}/aom_dsp/arm/obmc_variance_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/sse_neon.c"
               "${AOM_ROOT}/aom_dsp/arm/sum_squares_neon.c")
 
@@ -322,8 +327,8 @@
 function(setup_aom_dsp_targets)
   add_library(aom_dsp_common OBJECT ${AOM_DSP_COMMON_SOURCES})
   list(APPEND AOM_LIB_TARGETS aom_dsp_common)
-  create_dummy_source_file("aom_av1" "c" "dummy_source_file")
-  add_library(aom_dsp OBJECT "${dummy_source_file}")
+  create_no_op_source_file("aom_av1" "c" "no_op_source_file")
+  add_library(aom_dsp OBJECT "${no_op_source_file}")
   target_sources(aom PRIVATE $<TARGET_OBJECTS:aom_dsp_common>)
   if(BUILD_SHARED_LIBS)
     target_sources(aom_static PRIVATE $<TARGET_OBJECTS:aom_dsp_common>)
@@ -331,8 +336,8 @@
   list(APPEND AOM_LIB_TARGETS aom_dsp)
 
   # Not all generators support libraries consisting only of object files. Add a
-  # dummy source file to the aom_dsp target.
-  add_dummy_source_file_to_target("aom_dsp" "c")
+  # source file to the aom_dsp target.
+  add_no_op_source_file_to_target("aom_dsp" "c")
 
   if(CONFIG_AV1_DECODER)
     add_library(aom_dsp_decoder OBJECT ${AOM_DSP_DECODER_SOURCES})
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index b3f8ec7..76827cd 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -16,8 +16,8 @@
 
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
-#include "av1/common/enums.h"
 #include "av1/common/blockd.h"
+#include "av1/common/enums.h"
 
 EOF
 }
@@ -86,104 +86,104 @@
 }
 
 specialize qw/aom_dc_top_predictor_4x4 neon sse2/;
-specialize qw/aom_dc_top_predictor_4x8 sse2/;
-specialize qw/aom_dc_top_predictor_4x16 sse2/;
-specialize qw/aom_dc_top_predictor_8x4 sse2/;
+specialize qw/aom_dc_top_predictor_4x8 neon sse2/;
+specialize qw/aom_dc_top_predictor_4x16 neon sse2/;
+specialize qw/aom_dc_top_predictor_8x4 neon sse2/;
 specialize qw/aom_dc_top_predictor_8x8 neon sse2/;
-specialize qw/aom_dc_top_predictor_8x16 sse2/;
-specialize qw/aom_dc_top_predictor_8x32 sse2/;
-specialize qw/aom_dc_top_predictor_16x4 sse2/;
-specialize qw/aom_dc_top_predictor_16x8 sse2/;
+specialize qw/aom_dc_top_predictor_8x16 neon sse2/;
+specialize qw/aom_dc_top_predictor_8x32 neon sse2/;
+specialize qw/aom_dc_top_predictor_16x4 neon sse2/;
+specialize qw/aom_dc_top_predictor_16x8 neon sse2/;
 specialize qw/aom_dc_top_predictor_16x16 neon sse2/;
-specialize qw/aom_dc_top_predictor_16x32 sse2/;
-specialize qw/aom_dc_top_predictor_16x64 sse2/;
-specialize qw/aom_dc_top_predictor_32x8 sse2/;
-specialize qw/aom_dc_top_predictor_32x16 sse2 avx2/;
+specialize qw/aom_dc_top_predictor_16x32 neon sse2/;
+specialize qw/aom_dc_top_predictor_16x64 neon sse2/;
+specialize qw/aom_dc_top_predictor_32x8 neon sse2/;
+specialize qw/aom_dc_top_predictor_32x16 neon sse2 avx2/;
 specialize qw/aom_dc_top_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_dc_top_predictor_32x64 sse2 avx2/;
-specialize qw/aom_dc_top_predictor_64x16 sse2 avx2/;
-specialize qw/aom_dc_top_predictor_64x32 sse2 avx2/;
-specialize qw/aom_dc_top_predictor_64x64 sse2 avx2/;
+specialize qw/aom_dc_top_predictor_32x64 neon sse2 avx2/;
+specialize qw/aom_dc_top_predictor_64x16 neon sse2 avx2/;
+specialize qw/aom_dc_top_predictor_64x32 neon sse2 avx2/;
+specialize qw/aom_dc_top_predictor_64x64 neon sse2 avx2/;
 
 specialize qw/aom_dc_left_predictor_4x4 neon sse2/;
-specialize qw/aom_dc_left_predictor_4x8 sse2/;
-specialize qw/aom_dc_left_predictor_4x16 sse2/;
-specialize qw/aom_dc_left_predictor_8x4 sse2/;
+specialize qw/aom_dc_left_predictor_4x8 neon sse2/;
+specialize qw/aom_dc_left_predictor_4x16 neon sse2/;
+specialize qw/aom_dc_left_predictor_8x4 neon sse2/;
 specialize qw/aom_dc_left_predictor_8x8 neon sse2/;
-specialize qw/aom_dc_left_predictor_8x16 sse2/;
-specialize qw/aom_dc_left_predictor_8x32 sse2/;
-specialize qw/aom_dc_left_predictor_16x4 sse2/;
-specialize qw/aom_dc_left_predictor_16x8 sse2/;
+specialize qw/aom_dc_left_predictor_8x16 neon sse2/;
+specialize qw/aom_dc_left_predictor_8x32 neon sse2/;
+specialize qw/aom_dc_left_predictor_16x4 neon sse2/;
+specialize qw/aom_dc_left_predictor_16x8 neon sse2/;
 specialize qw/aom_dc_left_predictor_16x16 neon sse2/;
-specialize qw/aom_dc_left_predictor_16x32 sse2/;
-specialize qw/aom_dc_left_predictor_16x64 sse2/;
-specialize qw/aom_dc_left_predictor_32x8 sse2/;
-specialize qw/aom_dc_left_predictor_32x16 sse2 avx2/;
+specialize qw/aom_dc_left_predictor_16x32 neon sse2/;
+specialize qw/aom_dc_left_predictor_16x64 neon sse2/;
+specialize qw/aom_dc_left_predictor_32x8 neon sse2/;
+specialize qw/aom_dc_left_predictor_32x16 neon sse2 avx2/;
 specialize qw/aom_dc_left_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_dc_left_predictor_32x64 sse2 avx2/;
-specialize qw/aom_dc_left_predictor_64x16 sse2 avx2/;
-specialize qw/aom_dc_left_predictor_64x32 sse2 avx2/;
-specialize qw/aom_dc_left_predictor_64x64 sse2 avx2/;
+specialize qw/aom_dc_left_predictor_32x64 neon sse2 avx2/;
+specialize qw/aom_dc_left_predictor_64x16 neon sse2 avx2/;
+specialize qw/aom_dc_left_predictor_64x32 neon sse2 avx2/;
+specialize qw/aom_dc_left_predictor_64x64 neon sse2 avx2/;
 
 specialize qw/aom_dc_128_predictor_4x4 neon sse2/;
-specialize qw/aom_dc_128_predictor_4x8 sse2/;
-specialize qw/aom_dc_128_predictor_4x16 sse2/;
-specialize qw/aom_dc_128_predictor_8x4 sse2/;
+specialize qw/aom_dc_128_predictor_4x8 neon sse2/;
+specialize qw/aom_dc_128_predictor_4x16 neon sse2/;
+specialize qw/aom_dc_128_predictor_8x4 neon sse2/;
 specialize qw/aom_dc_128_predictor_8x8 neon sse2/;
-specialize qw/aom_dc_128_predictor_8x16 sse2/;
-specialize qw/aom_dc_128_predictor_8x32 sse2/;
-specialize qw/aom_dc_128_predictor_16x4 sse2/;
-specialize qw/aom_dc_128_predictor_16x8 sse2/;
+specialize qw/aom_dc_128_predictor_8x16 neon sse2/;
+specialize qw/aom_dc_128_predictor_8x32 neon sse2/;
+specialize qw/aom_dc_128_predictor_16x4 neon sse2/;
+specialize qw/aom_dc_128_predictor_16x8 neon sse2/;
 specialize qw/aom_dc_128_predictor_16x16 neon sse2/;
-specialize qw/aom_dc_128_predictor_16x32 sse2/;
-specialize qw/aom_dc_128_predictor_16x64 sse2/;
-specialize qw/aom_dc_128_predictor_32x8 sse2/;
-specialize qw/aom_dc_128_predictor_32x16 sse2 avx2/;
+specialize qw/aom_dc_128_predictor_16x32 neon sse2/;
+specialize qw/aom_dc_128_predictor_16x64 neon sse2/;
+specialize qw/aom_dc_128_predictor_32x8 neon sse2/;
+specialize qw/aom_dc_128_predictor_32x16 neon sse2 avx2/;
 specialize qw/aom_dc_128_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_dc_128_predictor_32x64 sse2 avx2/;
-specialize qw/aom_dc_128_predictor_64x16 sse2 avx2/;
-specialize qw/aom_dc_128_predictor_64x32 sse2 avx2/;
-specialize qw/aom_dc_128_predictor_64x64 sse2 avx2/;
+specialize qw/aom_dc_128_predictor_32x64 neon sse2 avx2/;
+specialize qw/aom_dc_128_predictor_64x16 neon sse2 avx2/;
+specialize qw/aom_dc_128_predictor_64x32 neon sse2 avx2/;
+specialize qw/aom_dc_128_predictor_64x64 neon sse2 avx2/;
 
 specialize qw/aom_v_predictor_4x4 neon sse2/;
-specialize qw/aom_v_predictor_4x8 sse2/;
-specialize qw/aom_v_predictor_4x16 sse2/;
-specialize qw/aom_v_predictor_8x4 sse2/;
+specialize qw/aom_v_predictor_4x8 neon sse2/;
+specialize qw/aom_v_predictor_4x16 neon sse2/;
+specialize qw/aom_v_predictor_8x4 neon sse2/;
 specialize qw/aom_v_predictor_8x8 neon sse2/;
-specialize qw/aom_v_predictor_8x16 sse2/;
-specialize qw/aom_v_predictor_8x32 sse2/;
-specialize qw/aom_v_predictor_16x4 sse2/;
-specialize qw/aom_v_predictor_16x8 sse2/;
+specialize qw/aom_v_predictor_8x16 neon sse2/;
+specialize qw/aom_v_predictor_8x32 neon sse2/;
+specialize qw/aom_v_predictor_16x4 neon sse2/;
+specialize qw/aom_v_predictor_16x8 neon sse2/;
 specialize qw/aom_v_predictor_16x16 neon sse2/;
-specialize qw/aom_v_predictor_16x32 sse2/;
-specialize qw/aom_v_predictor_16x64 sse2/;
-specialize qw/aom_v_predictor_32x8 sse2/;
-specialize qw/aom_v_predictor_32x16 sse2 avx2/;
+specialize qw/aom_v_predictor_16x32 neon sse2/;
+specialize qw/aom_v_predictor_16x64 neon sse2/;
+specialize qw/aom_v_predictor_32x8 neon sse2/;
+specialize qw/aom_v_predictor_32x16 neon sse2 avx2/;
 specialize qw/aom_v_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_v_predictor_32x64 sse2 avx2/;
-specialize qw/aom_v_predictor_64x16 sse2 avx2/;
-specialize qw/aom_v_predictor_64x32 sse2 avx2/;
-specialize qw/aom_v_predictor_64x64 sse2 avx2/;
+specialize qw/aom_v_predictor_32x64 neon sse2 avx2/;
+specialize qw/aom_v_predictor_64x16 neon sse2 avx2/;
+specialize qw/aom_v_predictor_64x32 neon sse2 avx2/;
+specialize qw/aom_v_predictor_64x64 neon sse2 avx2/;
 
 specialize qw/aom_h_predictor_4x4 neon sse2/;
-specialize qw/aom_h_predictor_4x8 sse2/;
-specialize qw/aom_h_predictor_4x16 sse2/;
-specialize qw/aom_h_predictor_8x4 sse2/;
+specialize qw/aom_h_predictor_4x8 neon sse2/;
+specialize qw/aom_h_predictor_4x16 neon sse2/;
+specialize qw/aom_h_predictor_8x4 neon sse2/;
 specialize qw/aom_h_predictor_8x8 neon sse2/;
-specialize qw/aom_h_predictor_8x16 sse2/;
-specialize qw/aom_h_predictor_8x32 sse2/;
-specialize qw/aom_h_predictor_16x4 sse2/;
-specialize qw/aom_h_predictor_16x8 sse2/;
+specialize qw/aom_h_predictor_8x16 neon sse2/;
+specialize qw/aom_h_predictor_8x32 neon sse2/;
+specialize qw/aom_h_predictor_16x4 neon sse2/;
+specialize qw/aom_h_predictor_16x8 neon sse2/;
 specialize qw/aom_h_predictor_16x16 neon sse2/;
-specialize qw/aom_h_predictor_16x32 sse2/;
-specialize qw/aom_h_predictor_16x64 sse2/;
-specialize qw/aom_h_predictor_32x8 sse2/;
-specialize qw/aom_h_predictor_32x16 sse2/;
+specialize qw/aom_h_predictor_16x32 neon sse2/;
+specialize qw/aom_h_predictor_16x64 neon sse2/;
+specialize qw/aom_h_predictor_32x8 neon sse2/;
+specialize qw/aom_h_predictor_32x16 neon sse2/;
 specialize qw/aom_h_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_h_predictor_32x64 sse2/;
-specialize qw/aom_h_predictor_64x16 sse2/;
-specialize qw/aom_h_predictor_64x32 sse2/;
-specialize qw/aom_h_predictor_64x64 sse2/;
+specialize qw/aom_h_predictor_32x64 neon sse2/;
+specialize qw/aom_h_predictor_64x16 neon sse2/;
+specialize qw/aom_h_predictor_64x32 neon sse2/;
+specialize qw/aom_h_predictor_64x64 neon sse2/;
 
 specialize qw/aom_paeth_predictor_4x4 ssse3 neon/;
 specialize qw/aom_paeth_predictor_4x8 ssse3 neon/;
@@ -268,24 +268,24 @@
 # TODO(yunqingwang): optimize rectangular DC_PRED to replace division
 # by multiply and shift.
 specialize qw/aom_dc_predictor_4x4 neon sse2/;
-specialize qw/aom_dc_predictor_4x8 sse2/;
-specialize qw/aom_dc_predictor_4x16 sse2/;
-specialize qw/aom_dc_predictor_8x4 sse2/;
+specialize qw/aom_dc_predictor_4x8 neon sse2/;
+specialize qw/aom_dc_predictor_4x16 neon sse2/;
+specialize qw/aom_dc_predictor_8x4 neon sse2/;
 specialize qw/aom_dc_predictor_8x8 neon sse2/;
-specialize qw/aom_dc_predictor_8x16 sse2/;
-specialize qw/aom_dc_predictor_8x32 sse2/;
-specialize qw/aom_dc_predictor_16x4 sse2/;
-specialize qw/aom_dc_predictor_16x8 sse2/;
+specialize qw/aom_dc_predictor_8x16 neon sse2/;
+specialize qw/aom_dc_predictor_8x32 neon sse2/;
+specialize qw/aom_dc_predictor_16x4 neon sse2/;
+specialize qw/aom_dc_predictor_16x8 neon sse2/;
 specialize qw/aom_dc_predictor_16x16 neon sse2/;
-specialize qw/aom_dc_predictor_16x32 sse2/;
-specialize qw/aom_dc_predictor_16x64 sse2/;
-specialize qw/aom_dc_predictor_32x8 sse2/;
-specialize qw/aom_dc_predictor_32x16 sse2 avx2/;
+specialize qw/aom_dc_predictor_16x32 neon sse2/;
+specialize qw/aom_dc_predictor_16x64 neon sse2/;
+specialize qw/aom_dc_predictor_32x8 neon sse2/;
+specialize qw/aom_dc_predictor_32x16 neon sse2 avx2/;
 specialize qw/aom_dc_predictor_32x32 neon sse2 avx2/;
-specialize qw/aom_dc_predictor_32x64 sse2 avx2/;
-specialize qw/aom_dc_predictor_64x64 sse2 avx2/;
-specialize qw/aom_dc_predictor_64x32 sse2 avx2/;
-specialize qw/aom_dc_predictor_64x16 sse2 avx2/;
+specialize qw/aom_dc_predictor_32x64 neon sse2 avx2/;
+specialize qw/aom_dc_predictor_64x64 neon sse2 avx2/;
+specialize qw/aom_dc_predictor_64x32 neon sse2 avx2/;
+specialize qw/aom_dc_predictor_64x16 neon sse2 avx2/;
 if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
   specialize qw/aom_highbd_v_predictor_4x4 sse2 neon/;
   specialize qw/aom_highbd_v_predictor_4x8 sse2 neon/;
@@ -607,12 +607,16 @@
     add_proto qw/void aom_fdct4x4_lp/, "const int16_t *input, int16_t *output, int stride";
     specialize qw/aom_fdct4x4_lp neon sse2/;
 
-    add_proto qw/void aom_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
-    specialize qw/aom_fdct8x8 neon sse2/, "$ssse3_x86_64";
-    # High bit depth
-    if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
-      add_proto qw/void aom_highbd_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
-      specialize qw/aom_highbd_fdct8x8 sse2/;
+    if (aom_config("CONFIG_INTERNAL_STATS") eq "yes") {
+      # 8x8 DCT transform for psnr-hvs. Unlike other transforms, it isn't
+      # compatible with av1 scan orders because it does two transposes.
+      add_proto qw/void aom_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
+      specialize qw/aom_fdct8x8 neon sse2/, "$ssse3_x86_64";
+      # High bit depth
+      if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
+        add_proto qw/void aom_highbd_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
+        specialize qw/aom_highbd_fdct8x8 sse2/;
+      }
     }
     # FFT/IFFT (float) only used for denoising (and noise power spectral density estimation)
     add_proto qw/void aom_fft2x2_float/, "const float *input, float *temp, float *output";
@@ -834,43 +838,31 @@
   specialize qw/aom_sad16x64_avg        sse2 neon/;
   specialize qw/aom_sad64x16_avg        sse2 neon/;
 
-  specialize qw/aom_dist_wtd_sad128x128_avg ssse3/;
-  specialize qw/aom_dist_wtd_sad128x64_avg  ssse3/;
-  specialize qw/aom_dist_wtd_sad64x128_avg  ssse3/;
-  specialize qw/aom_dist_wtd_sad64x64_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad64x32_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad32x64_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad32x32_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad32x16_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad16x32_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad16x16_avg   ssse3/;
-  specialize qw/aom_dist_wtd_sad16x8_avg    ssse3/;
-  specialize qw/aom_dist_wtd_sad8x16_avg    ssse3/;
-  specialize qw/aom_dist_wtd_sad8x8_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad8x4_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad4x8_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad4x4_avg     ssse3/;
+  specialize qw/aom_dist_wtd_sad128x128_avg sse2/;
+  specialize qw/aom_dist_wtd_sad128x64_avg  sse2/;
+  specialize qw/aom_dist_wtd_sad64x128_avg  sse2/;
+  specialize qw/aom_dist_wtd_sad64x64_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad64x32_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad32x64_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad32x32_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad32x16_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad16x32_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad16x16_avg   sse2/;
+  specialize qw/aom_dist_wtd_sad16x8_avg    sse2/;
+  specialize qw/aom_dist_wtd_sad8x16_avg    sse2/;
+  specialize qw/aom_dist_wtd_sad8x8_avg     sse2/;
+  specialize qw/aom_dist_wtd_sad8x4_avg     sse2/;
+  specialize qw/aom_dist_wtd_sad4x8_avg     sse2/;
+  specialize qw/aom_dist_wtd_sad4x4_avg     sse2/;
 
-  specialize qw/aom_dist_wtd_sad4x16_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad16x4_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad8x32_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad32x8_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad16x64_avg    ssse3/;
-  specialize qw/aom_dist_wtd_sad64x16_avg    ssse3/;
-
-  add_proto qw/unsigned int/, "aom_sad4xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-  add_proto qw/unsigned int/, "aom_sad8xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-  add_proto qw/unsigned int/, "aom_sad16xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-  add_proto qw/unsigned int/, "aom_sad32xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-  add_proto qw/unsigned int/, "aom_sad64xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-  add_proto qw/unsigned int/, "aom_sad128xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
-
-  specialize qw/aom_sad4xh   sse2/;
-  specialize qw/aom_sad8xh   sse2/;
-  specialize qw/aom_sad16xh  sse2/;
-  specialize qw/aom_sad32xh  sse2/;
-  specialize qw/aom_sad64xh  sse2/;
-  specialize qw/aom_sad128xh sse2/;
+  if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
+    specialize qw/aom_dist_wtd_sad4x16_avg     sse2/;
+    specialize qw/aom_dist_wtd_sad16x4_avg     sse2/;
+    specialize qw/aom_dist_wtd_sad8x32_avg     sse2/;
+    specialize qw/aom_dist_wtd_sad32x8_avg     sse2/;
+    specialize qw/aom_dist_wtd_sad16x64_avg    sse2/;
+    specialize qw/aom_dist_wtd_sad64x16_avg    sse2/;
+  }
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     foreach (@encoder_block_sizes) {
@@ -957,7 +949,7 @@
   foreach (@encoder_block_sizes) {
     ($w, $h) = @$_;
     add_proto qw/unsigned int/, "aom_masked_sad${w}x${h}", "const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask";
-    specialize "aom_masked_sad${w}x${h}", qw/ssse3 avx2/;
+    specialize "aom_masked_sad${w}x${h}", qw/ssse3 avx2 neon/;
   }
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
@@ -998,7 +990,6 @@
     ($w, $h) = @$_;
     add_proto qw/void/, "aom_sad${w}x${h}x4d", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[4], int ref_stride, uint32_t sad_array[4]";
     add_proto qw/void/, "aom_sad${w}x${h}x3d", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[4], int ref_stride, uint32_t sad_array[4]";
-    add_proto qw/void/, "aom_sad${w}x${h}x4d_avg", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[4], int ref_stride, const uint8_t *second_pred, uint32_t sad_array[4]";
     add_proto qw/void/, "aom_sad_skip_${w}x${h}x4d", "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[4], int ref_stride, uint32_t sad_array[4]";
     add_proto qw/void/, "aom_masked_sad${w}x${h}x4d", "const uint8_t *src, int src_stride, const uint8_t *ref[4], int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask, unsigned sads[4]";
   }
@@ -1067,37 +1058,6 @@
   specialize qw/aom_sad32x8x3d    avx2/;
   specialize qw/aom_sad16x64x3d   avx2/;
 
-  if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    specialize qw/aom_sad128x128x4d_avg sse2/;
-    specialize qw/aom_sad128x64x4d_avg  sse2/;
-    specialize qw/aom_sad64x128x4d_avg  sse2/;
-    specialize qw/aom_sad64x64x4d_avg   sse2/;
-    specialize qw/aom_sad64x32x4d_avg   sse2/;
-    specialize qw/aom_sad64x16x4d_avg   sse2/;
-    specialize qw/aom_sad32x64x4d_avg   sse2/;
-    specialize qw/aom_sad32x32x4d_avg   sse2/;
-    specialize qw/aom_sad32x16x4d_avg   sse2/;
-    specialize qw/aom_sad32x8x4d_avg    sse2/;
-    specialize qw/aom_sad16x64x4d_avg   sse2/;
-    specialize qw/aom_sad16x32x4d_avg   sse2/;
-    specialize qw/aom_sad16x16x4d_avg   sse2/;
-    specialize qw/aom_sad16x8x4d_avg    sse2/;
-
-    specialize qw/aom_sad8x16x4d_avg    sse2/;
-    specialize qw/aom_sad8x8x4d_avg     sse2/;
-    specialize qw/aom_sad8x4x4d_avg     sse2/;
-    specialize qw/aom_sad4x16x4d_avg    sse2/;
-    specialize qw/aom_sad4x8x4d_avg     sse2/;
-    specialize qw/aom_sad4x4x4d_avg     sse2/;
-
-    specialize qw/aom_sad4x32x4d_avg    sse2/;
-    specialize qw/aom_sad4x16x4d_avg    sse2/;
-    specialize qw/aom_sad16x4x4d_avg    sse2/;
-    specialize qw/aom_sad8x32x4d_avg    sse2/;
-    specialize qw/aom_sad32x8x4d_avg    sse2/;
-    specialize qw/aom_sad64x16x4d_avg   sse2/;
-  }
-
   specialize qw/aom_masked_sad128x128x4d  ssse3/;
   specialize qw/aom_masked_sad128x64x4d   ssse3/;
   specialize qw/aom_masked_sad64x128x4d   ssse3/;
@@ -1214,7 +1174,7 @@
   specialize qw/aom_avg_8x8_quad avx2 sse2 neon/;
 
   add_proto qw/void aom_minmax_8x8/, "const uint8_t *s, int p, const uint8_t *d, int dp, int *min, int *max";
-  specialize qw/aom_minmax_8x8 sse2/;
+  specialize qw/aom_minmax_8x8 sse2 neon/;
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     add_proto qw/unsigned int aom_highbd_avg_8x8/, "const uint8_t *, int p";
@@ -1238,7 +1198,7 @@
  # hadamard transform and satd for implementing temporal dependency model
   #
   add_proto qw/void aom_hadamard_4x4/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
-  specialize qw/aom_hadamard_4x4 sse2/;
+  specialize qw/aom_hadamard_4x4 sse2 neon/;
 
   add_proto qw/void aom_hadamard_8x8/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
   specialize qw/aom_hadamard_8x8 sse2 neon/;
@@ -1247,7 +1207,7 @@
   specialize qw/aom_hadamard_16x16 avx2 sse2 neon/;
 
   add_proto qw/void aom_hadamard_32x32/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
-  specialize qw/aom_hadamard_32x32 avx2 sse2/;
+  specialize qw/aom_hadamard_32x32 avx2 sse2 neon/;
 
   add_proto qw/void aom_hadamard_lp_8x8/, "const int16_t *src_diff, ptrdiff_t src_stride, int16_t *coeff";
   specialize qw/aom_hadamard_lp_8x8 sse2 neon/;
@@ -1258,9 +1218,6 @@
   add_proto qw/void aom_hadamard_lp_8x8_dual/, "const int16_t *src_diff, ptrdiff_t src_stride, int16_t *coeff";
   specialize qw/aom_hadamard_lp_8x8_dual sse2 avx2 neon/;
 
-  add_proto qw/void aom_pixel_scale/, "const int16_t *src_diff, ptrdiff_t src_stride, int16_t *coeff, int log_scale, int h8, int w8";
-  specialize qw/aom_pixel_scale sse2/;
-
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     add_proto qw/void aom_highbd_hadamard_8x8/, "const int16_t *src_diff, ptrdiff_t src_stride, tran_low_t *coeff";
     specialize qw/aom_highbd_hadamard_8x8 avx2/;
@@ -1299,17 +1256,11 @@
   #
   # Specialty Variance
   #
-  add_proto qw/void aom_get16x16var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-  add_proto qw/void aom_get8x8var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-
-  specialize qw/aom_get16x16var                neon/;
-  specialize qw/aom_get8x8var             sse2 neon/;
-
   add_proto qw/void aom_get_var_sse_sum_8x8_quad/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, uint32_t *sse8x8, int *sum8x8, unsigned int *tot_sse, int *tot_sum, uint32_t *var8x8";
   specialize qw/aom_get_var_sse_sum_8x8_quad        avx2 sse2 neon/;
 
   add_proto qw/void aom_get_var_sse_sum_16x16_dual/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, uint32_t *sse16x16, unsigned int *tot_sse, int *tot_sum, uint32_t *var16x16";
-  specialize qw/aom_get_var_sse_sum_16x16_dual        avx2/;
+  specialize qw/aom_get_var_sse_sum_16x16_dual        avx2 sse2 neon/;
 
   add_proto qw/unsigned int aom_mse16x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
   add_proto qw/unsigned int aom_mse16x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
@@ -1323,9 +1274,6 @@
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     foreach $bd (8, 10, 12) {
-      add_proto qw/void/, "aom_highbd_${bd}_get16x16var", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-      add_proto qw/void/, "aom_highbd_${bd}_get8x8var", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-
       add_proto qw/unsigned int/, "aom_highbd_${bd}_mse16x16", "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
       add_proto qw/unsigned int/, "aom_highbd_${bd}_mse16x8", "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
       add_proto qw/unsigned int/, "aom_highbd_${bd}_mse8x16", "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
@@ -1340,10 +1288,7 @@
   #
   #
   add_proto qw/unsigned int aom_get_mb_ss/, "const int16_t *";
-  add_proto qw/unsigned int aom_get4x4sse_cs/, "const unsigned char *src_ptr, int source_stride, const unsigned char *ref_ptr, int ref_stride";
-
   specialize qw/aom_get_mb_ss sse2/;
-  specialize qw/aom_get4x4sse_cs neon/;
 
   #
   # Variance / Subpixel Variance / Subpixel Avg Variance
@@ -1522,7 +1467,7 @@
   foreach (@encoder_block_sizes) {
     ($w, $h) = @$_;
     add_proto qw/unsigned int/, "aom_masked_sub_pixel_variance${w}x${h}", "const uint8_t *src, int src_stride, int xoffset, int yoffset, const uint8_t *ref, int ref_stride, const uint8_t *second_pred, const uint8_t *msk, int msk_stride, int invert_mask, unsigned int *sse";
-    specialize "aom_masked_sub_pixel_variance${w}x${h}", qw/ssse3/;
+    specialize "aom_masked_sub_pixel_variance${w}x${h}", qw/ssse3 neon/;
   }
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
@@ -1543,8 +1488,8 @@
       ($w, $h) = @$_;
       add_proto qw/unsigned int/, "aom_obmc_variance${w}x${h}", "const uint8_t *pre, int pre_stride, const int32_t *wsrc, const int32_t *mask, unsigned int *sse";
       add_proto qw/unsigned int/, "aom_obmc_sub_pixel_variance${w}x${h}", "const uint8_t *pre, int pre_stride, int xoffset, int yoffset, const int32_t *wsrc, const int32_t *mask, unsigned int *sse";
-      specialize "aom_obmc_variance${w}x${h}", qw/sse4_1 avx2/;
-      specialize "aom_obmc_sub_pixel_variance${w}x${h}", q/sse4_1/;
+      specialize "aom_obmc_variance${w}x${h}", qw/sse4_1 avx2 neon/;
+      specialize "aom_obmc_sub_pixel_variance${w}x${h}", qw/sse4_1 neon/;
     }
 
     if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
@@ -1602,6 +1547,7 @@
   # Comp Avg
   #
   add_proto qw/void aom_comp_avg_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride";
+  specialize qw/aom_comp_avg_pred avx2 neon/;
 
   add_proto qw/void aom_dist_wtd_comp_avg_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const DIST_WTD_COMP_PARAMS *jcp_param";
   specialize qw/aom_dist_wtd_comp_avg_pred ssse3/;
@@ -1737,15 +1683,6 @@
     add_proto qw/unsigned int aom_highbd_8_variance4x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
     add_proto qw/unsigned int aom_highbd_8_variance4x4/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
 
-    add_proto qw/void aom_highbd_8_get16x16var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-    add_proto qw/void aom_highbd_8_get8x8var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-
-    add_proto qw/void aom_highbd_10_get16x16var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-    add_proto qw/void aom_highbd_10_get8x8var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-
-    add_proto qw/void aom_highbd_12_get16x16var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-    add_proto qw/void aom_highbd_12_get8x8var/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse, int *sum";
-
     add_proto qw/unsigned int aom_highbd_8_mse16x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
     specialize qw/aom_highbd_8_mse16x16 sse2/;
 
@@ -2028,7 +1965,7 @@
 
 
   add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
-  specialize qw/aom_comp_mask_pred ssse3 avx2/;
+  specialize qw/aom_comp_mask_pred ssse3 avx2 neon/;
 
   if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
     add_proto qw/void aom_highbd_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
@@ -2037,8 +1974,11 @@
 
   # Flow estimation library
   if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") {
-    add_proto qw/double av1_compute_cross_correlation/, "unsigned char *im1, int stride1, int x1, int y1, unsigned char *im2, int stride2, int x2, int y2";
+    add_proto qw/double av1_compute_cross_correlation/, "const unsigned char *frame1, int stride1, int x1, int y1, const unsigned char *frame2, int stride2, int x2, int y2";
     specialize qw/av1_compute_cross_correlation sse4_1 avx2/;
+
+    add_proto qw/void aom_compute_flow_at_point/, "const uint8_t *src, const uint8_t *ref, int x, int y, int width, int height, int stride, double *u, double *v";
+    specialize qw/aom_compute_flow_at_point sse4_1/;
   }
 
 }  # CONFIG_AV1_ENCODER
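
For reference, aom_comp_avg_pred (given AVX2 and NEON specializations above)
produces the rounded average of two predictors; the scalar sketch below is
equivalent to what the vrhaddq_u8-based NEON path added later in this patch
computes, with pred and comp_pred packed at a stride equal to width:

    #include <stdint.h>

    /* Scalar sketch of the rounded average behind aom_comp_avg_pred:
     * comp_pred[i] = (pred[i] + ref[i] + 1) >> 1 for every pixel. */
    static void comp_avg_pred_sketch(uint8_t *comp_pred, const uint8_t *pred,
                                     int width, int height, const uint8_t *ref,
                                     int ref_stride) {
      for (int i = 0; i < height; ++i) {
        for (int j = 0; j < width; ++j) {
          comp_pred[j] = (uint8_t)((pred[j] + ref[j] + 1) >> 1);
        }
        comp_pred += width;
        pred += width;
        ref += ref_stride;
      }
    }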
diff --git a/aom_dsp/arm/avg_neon.c b/aom_dsp/arm/avg_neon.c
index 991fd3f..bac50ca 100644
--- a/aom_dsp/arm/avg_neon.c
+++ b/aom_dsp/arm/avg_neon.c
@@ -216,3 +216,57 @@
 #endif
 }
 #endif  // CONFIG_AV1_HIGHBITDEPTH
+
+void aom_minmax_8x8_neon(const uint8_t *a, int a_stride, const uint8_t *b,
+                         int b_stride, int *min, int *max) {
+  // Load and concatenate.
+  const uint8x16_t a01 = load_u8_8x2(a + 0 * a_stride, a_stride);
+  const uint8x16_t a23 = load_u8_8x2(a + 2 * a_stride, a_stride);
+  const uint8x16_t a45 = load_u8_8x2(a + 4 * a_stride, a_stride);
+  const uint8x16_t a67 = load_u8_8x2(a + 6 * a_stride, a_stride);
+
+  const uint8x16_t b01 = load_u8_8x2(b + 0 * b_stride, b_stride);
+  const uint8x16_t b23 = load_u8_8x2(b + 2 * b_stride, b_stride);
+  const uint8x16_t b45 = load_u8_8x2(b + 4 * b_stride, b_stride);
+  const uint8x16_t b67 = load_u8_8x2(b + 6 * b_stride, b_stride);
+
+  // Absolute difference.
+  const uint8x16_t ab01_diff = vabdq_u8(a01, b01);
+  const uint8x16_t ab23_diff = vabdq_u8(a23, b23);
+  const uint8x16_t ab45_diff = vabdq_u8(a45, b45);
+  const uint8x16_t ab67_diff = vabdq_u8(a67, b67);
+
+  // Max values between the Q vectors.
+  const uint8x16_t ab0123_max = vmaxq_u8(ab01_diff, ab23_diff);
+  const uint8x16_t ab4567_max = vmaxq_u8(ab45_diff, ab67_diff);
+  const uint8x16_t ab0123_min = vminq_u8(ab01_diff, ab23_diff);
+  const uint8x16_t ab4567_min = vminq_u8(ab45_diff, ab67_diff);
+
+  const uint8x16_t ab07_max = vmaxq_u8(ab0123_max, ab4567_max);
+  const uint8x16_t ab07_min = vminq_u8(ab0123_min, ab4567_min);
+
+#if defined(__aarch64__)
+  *min = *max = 0;  // Clear high bits
+  *((uint8_t *)max) = vmaxvq_u8(ab07_max);
+  *((uint8_t *)min) = vminvq_u8(ab07_min);
+#else
+  // Split into 64-bit vectors and execute pairwise min/max.
+  uint8x8_t ab_max = vmax_u8(vget_high_u8(ab07_max), vget_low_u8(ab07_max));
+  uint8x8_t ab_min = vmin_u8(vget_high_u8(ab07_min), vget_low_u8(ab07_min));
+
+  // Enough runs of vpmax/min propagate the max/min values to every position.
+  ab_max = vpmax_u8(ab_max, ab_max);
+  ab_min = vpmin_u8(ab_min, ab_min);
+
+  ab_max = vpmax_u8(ab_max, ab_max);
+  ab_min = vpmin_u8(ab_min, ab_min);
+
+  ab_max = vpmax_u8(ab_max, ab_max);
+  ab_min = vpmin_u8(ab_min, ab_min);
+
+  *min = *max = 0;  // Clear high bits
+  // Store directly to avoid costly neon->gpr transfer.
+  vst1_lane_u8((uint8_t *)max, ab_max, 0);
+  vst1_lane_u8((uint8_t *)min, ab_min, 0);
+#endif
+}
diff --git a/aom_dsp/arm/avg_pred_neon.c b/aom_dsp/arm/avg_pred_neon.c
new file mode 100644
index 0000000..04e0904
--- /dev/null
+++ b/aom_dsp/arm/avg_pred_neon.c
@@ -0,0 +1,171 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+
+#include "config/aom_dsp_rtcd.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/blend.h"
+
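+// Average two predictors with rounding:
+// comp_pred[i] = ROUND_POWER_OF_TWO(pred[i] + ref[i], 1).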
+void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
+                            int height, const uint8_t *ref, int ref_stride) {
+  if (width > 8) {
+    do {
+      const uint8_t *pred_ptr = pred;
+      const uint8_t *ref_ptr = ref;
+      uint8_t *comp_pred_ptr = comp_pred;
+      int w = width;
+
+      do {
+        const uint8x16_t p = vld1q_u8(pred_ptr);
+        const uint8x16_t r = vld1q_u8(ref_ptr);
+        const uint8x16_t avg = vrhaddq_u8(p, r);
+
+        vst1q_u8(comp_pred_ptr, avg);
+
+        ref_ptr += 16;
+        pred_ptr += 16;
+        comp_pred_ptr += 16;
+        w -= 16;
+      } while (w != 0);
+
+      ref += ref_stride;
+      pred += width;
+      comp_pred += width;
+    } while (--height != 0);
+  } else if (width == 8) {
+    int h = height / 2;
+
+    do {
+      const uint8x16_t p = vld1q_u8(pred);
+      const uint8x16_t r = load_u8_8x2(ref, ref_stride);
+      const uint8x16_t avg = vrhaddq_u8(p, r);
+
+      vst1q_u8(comp_pred, avg);
+
+      ref += 2 * ref_stride;
+      pred += 16;
+      comp_pred += 16;
+    } while (--h != 0);
+  } else {
+    int h = height / 4;
+    assert(width == 4);
+
+    do {
+      const uint8x16_t p = vld1q_u8(pred);
+      const uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
+      const uint8x16_t avg = vrhaddq_u8(p, r);
+
+      vst1q_u8(comp_pred, avg);
+
+      ref += 4 * ref_stride;
+      pred += 16;
+      comp_pred += 16;
+    } while (--h != 0);
+  }
+}
+
+void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
+                             int height, const uint8_t *ref, int ref_stride,
+                             const uint8_t *mask, int mask_stride,
+                             int invert_mask) {
+  const uint8_t *src0 = invert_mask ? pred : ref;
+  const uint8_t *src1 = invert_mask ? ref : pred;
+  const int src_stride0 = invert_mask ? width : ref_stride;
+  const int src_stride1 = invert_mask ? ref_stride : width;
+
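+  // Each output sample is an A64 blend of the two sources:
+  //   comp_pred = ROUND_POWER_OF_TWO(mask * src0 +
+  //       (AOM_BLEND_A64_MAX_ALPHA - mask) * src1, AOM_BLEND_A64_ROUND_BITS).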
+  if (width > 8) {
+    const uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+    do {
+      const uint8_t *src0_ptr = src0;
+      const uint8_t *src1_ptr = src1;
+      const uint8_t *mask_ptr = mask;
+      uint8_t *comp_pred_ptr = comp_pred;
+      int w = width;
+
+      do {
+        const uint8x16_t s0 = vld1q_u8(src0_ptr);
+        const uint8x16_t s1 = vld1q_u8(src1_ptr);
+        const uint8x16_t m0 = vld1q_u8(mask_ptr);
+
+        uint8x16_t m0_inv = vsubq_u8(max_alpha, m0);
+        uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(m0));
+        uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(m0));
+        blend_u16_lo =
+            vmlal_u8(blend_u16_lo, vget_low_u8(s1), vget_low_u8(m0_inv));
+        blend_u16_hi =
+            vmlal_u8(blend_u16_hi, vget_high_u8(s1), vget_high_u8(m0_inv));
+
+        uint8x8_t blend_u8_lo =
+            vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS);
+        uint8x8_t blend_u8_hi =
+            vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS);
+        uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi);
+
+        vst1q_u8(comp_pred_ptr, blend_u8);
+
+        src0_ptr += 16;
+        src1_ptr += 16;
+        mask_ptr += 16;
+        comp_pred_ptr += 16;
+        w -= 16;
+      } while (w != 0);
+
+      src0 += src_stride0;
+      src1 += src_stride1;
+      mask += mask_stride;
+      comp_pred += width;
+    } while (--height != 0);
+  } else if (width == 8) {
+    const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+
+    do {
+      const uint8x8_t s0 = vld1_u8(src0);
+      const uint8x8_t s1 = vld1_u8(src1);
+      const uint8x8_t m0 = vld1_u8(mask);
+
+      uint8x8_t m0_inv = vsub_u8(max_alpha, m0);
+      uint16x8_t blend_u16 = vmull_u8(s0, m0);
+      blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);
+      uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+      vst1_u8(comp_pred, blend_u8);
+
+      src0 += src_stride0;
+      src1 += src_stride1;
+      mask += mask_stride;
+      comp_pred += 8;
+    } while (--height != 0);
+  } else {
+    const uint8x8_t max_alpha = vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA);
+    int h = height / 2;
+    assert(width == 4);
+
+    do {
+      const uint8x8_t s0 = load_unaligned_u8(src0, src_stride0);
+      const uint8x8_t s1 = load_unaligned_u8(src1, src_stride1);
+      const uint8x8_t m0 = load_unaligned_u8(mask, mask_stride);
+
+      uint8x8_t m0_inv = vsub_u8(max_alpha, m0);
+      uint16x8_t blend_u16 = vmull_u8(s0, m0);
+      blend_u16 = vmlal_u8(blend_u16, s1, m0_inv);
+      uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+      vst1_u8(comp_pred, blend_u8);
+
+      src0 += 2 * src_stride0;
+      src1 += 2 * src_stride1;
+      mask += 2 * mask_stride;
+      comp_pred += 8;
+    } while (--h != 0);
+  }
+}
diff --git a/aom_dsp/arm/blend_a64_mask_neon.c b/aom_dsp/arm/blend_a64_mask_neon.c
index f11d57e..c3ee0b7 100644
--- a/aom_dsp/arm/blend_a64_mask_neon.c
+++ b/aom_dsp/arm/blend_a64_mask_neon.c
@@ -86,19 +86,21 @@
                              const int16x8_t vec_round_bits) {
   int16x8_t src0_0, src0_1;
   int16x8_t src1_0, src1_1;
-  uint64x2_t tu0 = vdupq_n_u64(0), tu1 = vdupq_n_u64(0), tu2 = vdupq_n_u64(0),
-             tu3 = vdupq_n_u64(0);
+  uint16x8_t tu0 = vdupq_n_u16(0);
+  uint16x8_t tu1 = vdupq_n_u16(0);
+  uint16x8_t tu2 = vdupq_n_u16(0);
+  uint16x8_t tu3 = vdupq_n_u16(0);
   int16x8_t mask0_1, mask2_3;
   int16x8_t res0, res1;
 
   load_unaligned_u16_4x4(src0, src0_stride, &tu0, &tu1);
   load_unaligned_u16_4x4(src1, src1_stride, &tu2, &tu3);
 
-  src0_0 = vreinterpretq_s16_u64(tu0);
-  src0_1 = vreinterpretq_s16_u64(tu1);
+  src0_0 = vreinterpretq_s16_u16(tu0);
+  src0_1 = vreinterpretq_s16_u16(tu1);
 
-  src1_0 = vreinterpretq_s16_u64(tu2);
-  src1_1 = vreinterpretq_s16_u64(tu3);
+  src1_0 = vreinterpretq_s16_u16(tu2);
+  src1_1 = vreinterpretq_s16_u16(tu3);
 
   mask0_1 = vcombine_s16(mask0, mask1);
   mask2_3 = vcombine_s16(mask2, mask3);
@@ -150,9 +152,10 @@
   assert(IS_POWER_OF_TWO(h));
   assert(IS_POWER_OF_TWO(w));
 
-  uint8x8_t s0, s1, s2, s3;
-  uint32x2_t tu0 = vdup_n_u32(0), tu1 = vdup_n_u32(0), tu2 = vdup_n_u32(0),
-             tu3 = vdup_n_u32(0);
+  uint8x8_t s0 = vdup_n_u8(0);
+  uint8x8_t s1 = vdup_n_u8(0);
+  uint8x8_t s2 = vdup_n_u8(0);
+  uint8x8_t s3 = vdup_n_u8(0);
   uint8x16_t t0, t1, t2, t3, t4, t5, t6, t7;
   int16x8_t mask0, mask1, mask2, mask3;
   int16x8_t mask4, mask5, mask6, mask7;
@@ -197,10 +200,10 @@
       } while (i < h);
     } else {
       do {
-        load_unaligned_u8_4x4(mask_tmp, mask_stride, &tu0, &tu1);
+        load_unaligned_u8_4x4(mask_tmp, mask_stride, &s0, &s1);
 
-        mask0 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(tu0)));
-        mask1 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(tu1)));
+        mask0 = vreinterpretq_s16_u16(vmovl_u8(s0));
+        mask1 = vreinterpretq_s16_u16(vmovl_u8(s1));
 
         mask0_low = vget_low_s16(mask0);
         mask1_low = vget_high_s16(mask0);
@@ -412,14 +415,9 @@
       } while (i < h);
     } else {
       do {
-        load_unaligned_u8_4x4(mask_tmp, 2 * mask_stride, &tu0, &tu1);
-        load_unaligned_u8_4x4(mask_tmp + mask_stride, 2 * mask_stride, &tu2,
-                              &tu3);
-
-        s0 = vreinterpret_u8_u32(tu0);
-        s1 = vreinterpret_u8_u32(tu1);
-        s2 = vreinterpret_u8_u32(tu2);
-        s3 = vreinterpret_u8_u32(tu3);
+        load_unaligned_u8_4x4(mask_tmp, 2 * mask_stride, &s0, &s1);
+        load_unaligned_u8_4x4(mask_tmp + mask_stride, 2 * mask_stride, &s2,
+                              &s3);
 
         mask0 = vreinterpretq_s16_u16(vaddl_u8(s0, s2));
         mask1 = vreinterpretq_s16_u16(vaddl_u8(s1, s3));
diff --git a/aom_dsp/arm/fwd_txfm_neon.c b/aom_dsp/arm/fwd_txfm_neon.c
index 7fccdab..a7d66b3 100644
--- a/aom_dsp/arm/fwd_txfm_neon.c
+++ b/aom_dsp/arm/fwd_txfm_neon.c
@@ -67,7 +67,10 @@
     int16x4_t out_1 = vrshrn_n_s32(temp3, DCT_CONST_BITS);
     int16x4_t out_3 = vrshrn_n_s32(temp4, DCT_CONST_BITS);
 
-    transpose_s16_4x4d(&out_0, &out_1, &out_2, &out_3);
+    // Only transpose in the first pass
+    if (i == 0) {
+      transpose_s16_4x4d(&out_0, &out_1, &out_2, &out_3);
+    }
 
     *input_0 = out_0;
     *input_1 = out_1;
diff --git a/aom_dsp/arm/hadamard_neon.c b/aom_dsp/arm/hadamard_neon.c
index 75dd7d6..eda5db0 100644
--- a/aom_dsp/arm/hadamard_neon.c
+++ b/aom_dsp/arm/hadamard_neon.c
@@ -15,6 +15,38 @@
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 
+static INLINE void hadamard_4x4_one_pass(int16x4_t *a0, int16x4_t *a1,
+                                         int16x4_t *a2, int16x4_t *a3) {
+  const int16x4_t b0 = vhadd_s16(*a0, *a1);
+  const int16x4_t b1 = vhsub_s16(*a0, *a1);
+  const int16x4_t b2 = vhadd_s16(*a2, *a3);
+  const int16x4_t b3 = vhsub_s16(*a2, *a3);
+
+  *a0 = vadd_s16(b0, b2);
+  *a1 = vadd_s16(b1, b3);
+  *a2 = vsub_s16(b0, b2);
+  *a3 = vsub_s16(b1, b3);
+}
+
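+// 4x4 Hadamard transform: butterfly the rows, transpose, then butterfly the
+// columns.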
+void aom_hadamard_4x4_neon(const int16_t *src_diff, ptrdiff_t src_stride,
+                           tran_low_t *coeff) {
+  int16x4_t a0 = vld1_s16(src_diff);
+  int16x4_t a1 = vld1_s16(src_diff + src_stride);
+  int16x4_t a2 = vld1_s16(src_diff + 2 * src_stride);
+  int16x4_t a3 = vld1_s16(src_diff + 3 * src_stride);
+
+  hadamard_4x4_one_pass(&a0, &a1, &a2, &a3);
+
+  transpose_s16_4x4d(&a0, &a1, &a2, &a3);
+
+  hadamard_4x4_one_pass(&a0, &a1, &a2, &a3);
+
+  store_s16_to_tran_low(coeff, a0);
+  store_s16_to_tran_low(coeff + 4, a1);
+  store_s16_to_tran_low(coeff + 8, a2);
+  store_s16_to_tran_low(coeff + 12, a3);
+}
+
 static void hadamard8x8_one_pass(int16x8_t *a0, int16x8_t *a1, int16x8_t *a2,
                                  int16x8_t *a3, int16x8_t *a4, int16x8_t *a5,
                                  int16x8_t *a6, int16x8_t *a7) {
@@ -154,44 +186,106 @@
 
 void aom_hadamard_16x16_neon(const int16_t *src_diff, ptrdiff_t src_stride,
                              tran_low_t *coeff) {
-  DECLARE_ALIGNED(32, tran_low_t, temp_coeff[16 * 16]);
   /* Rearrange 16x16 to 8x32 and remove stride.
    * Top left first. */
-  aom_hadamard_8x8_neon(src_diff + 0 + 0 * src_stride, src_stride,
-                        temp_coeff + 0);
+  aom_hadamard_8x8_neon(src_diff + 0 + 0 * src_stride, src_stride, coeff + 0);
   /* Top right. */
-  aom_hadamard_8x8_neon(src_diff + 8 + 0 * src_stride, src_stride,
-                        temp_coeff + 64);
+  aom_hadamard_8x8_neon(src_diff + 8 + 0 * src_stride, src_stride, coeff + 64);
   /* Bottom left. */
-  aom_hadamard_8x8_neon(src_diff + 0 + 8 * src_stride, src_stride,
-                        temp_coeff + 128);
+  aom_hadamard_8x8_neon(src_diff + 0 + 8 * src_stride, src_stride, coeff + 128);
   /* Bottom right. */
-  aom_hadamard_8x8_neon(src_diff + 8 + 8 * src_stride, src_stride,
-                        temp_coeff + 192);
+  aom_hadamard_8x8_neon(src_diff + 8 + 8 * src_stride, src_stride, coeff + 192);
 
-  tran_low_t *t_coeff = temp_coeff;
-  for (int i = 0; i < 64; i += 8) {
-    const int16x8_t a0 = load_tran_low_to_s16q(t_coeff + 0);
-    const int16x8_t a1 = load_tran_low_to_s16q(t_coeff + 64);
-    const int16x8_t a2 = load_tran_low_to_s16q(t_coeff + 128);
-    const int16x8_t a3 = load_tran_low_to_s16q(t_coeff + 192);
+  for (int i = 0; i < 64; i += 16) {
+    const int16x8_t a00 = load_tran_low_to_s16q(coeff + 0);
+    const int16x8_t a01 = load_tran_low_to_s16q(coeff + 64);
+    const int16x8_t a02 = load_tran_low_to_s16q(coeff + 128);
+    const int16x8_t a03 = load_tran_low_to_s16q(coeff + 192);
 
-    const int16x8_t b0 = vhaddq_s16(a0, a1);
-    const int16x8_t b1 = vhsubq_s16(a0, a1);
-    const int16x8_t b2 = vhaddq_s16(a2, a3);
-    const int16x8_t b3 = vhsubq_s16(a2, a3);
+    const int16x8_t b00 = vhaddq_s16(a00, a01);
+    const int16x8_t b01 = vhsubq_s16(a00, a01);
+    const int16x8_t b02 = vhaddq_s16(a02, a03);
+    const int16x8_t b03 = vhsubq_s16(a02, a03);
+
+    const int16x8_t c00 = vaddq_s16(b00, b02);
+    const int16x8_t c01 = vaddq_s16(b01, b03);
+    const int16x8_t c02 = vsubq_s16(b00, b02);
+    const int16x8_t c03 = vsubq_s16(b01, b03);
+
+    const int16x8_t a10 = load_tran_low_to_s16q(coeff + 8 + 0);
+    const int16x8_t a11 = load_tran_low_to_s16q(coeff + 8 + 64);
+    const int16x8_t a12 = load_tran_low_to_s16q(coeff + 8 + 128);
+    const int16x8_t a13 = load_tran_low_to_s16q(coeff + 8 + 192);
+
+    const int16x8_t b10 = vhaddq_s16(a10, a11);
+    const int16x8_t b11 = vhsubq_s16(a10, a11);
+    const int16x8_t b12 = vhaddq_s16(a12, a13);
+    const int16x8_t b13 = vhsubq_s16(a12, a13);
+
+    const int16x8_t c10 = vaddq_s16(b10, b12);
+    const int16x8_t c11 = vaddq_s16(b11, b13);
+    const int16x8_t c12 = vsubq_s16(b10, b12);
+    const int16x8_t c13 = vsubq_s16(b11, b13);
+
+    store_s16_to_tran_low(coeff + 0 + 0, vget_low_s16(c00));
+    store_s16_to_tran_low(coeff + 0 + 4, vget_low_s16(c10));
+    store_s16_to_tran_low(coeff + 0 + 8, vget_high_s16(c00));
+    store_s16_to_tran_low(coeff + 0 + 12, vget_high_s16(c10));
+
+    store_s16_to_tran_low(coeff + 64 + 0, vget_low_s16(c01));
+    store_s16_to_tran_low(coeff + 64 + 4, vget_low_s16(c11));
+    store_s16_to_tran_low(coeff + 64 + 8, vget_high_s16(c01));
+    store_s16_to_tran_low(coeff + 64 + 12, vget_high_s16(c11));
+
+    store_s16_to_tran_low(coeff + 128 + 0, vget_low_s16(c02));
+    store_s16_to_tran_low(coeff + 128 + 4, vget_low_s16(c12));
+    store_s16_to_tran_low(coeff + 128 + 8, vget_high_s16(c02));
+    store_s16_to_tran_low(coeff + 128 + 12, vget_high_s16(c12));
+
+    store_s16_to_tran_low(coeff + 192 + 0, vget_low_s16(c03));
+    store_s16_to_tran_low(coeff + 192 + 4, vget_low_s16(c13));
+    store_s16_to_tran_low(coeff + 192 + 8, vget_high_s16(c03));
+    store_s16_to_tran_low(coeff + 192 + 12, vget_high_s16(c13));
+
+    coeff += 16;
+  }
+}
+
+void aom_hadamard_32x32_neon(const int16_t *src_diff, ptrdiff_t src_stride,
+                             tran_low_t *coeff) {
+  /* Top left first. */
+  aom_hadamard_16x16_neon(src_diff + 0 + 0 * src_stride, src_stride, coeff + 0);
+  /* Top right. */
+  aom_hadamard_16x16_neon(src_diff + 16 + 0 * src_stride, src_stride,
+                          coeff + 256);
+  /* Bottom left. */
+  aom_hadamard_16x16_neon(src_diff + 0 + 16 * src_stride, src_stride,
+                          coeff + 512);
+  /* Bottom right. */
+  aom_hadamard_16x16_neon(src_diff + 16 + 16 * src_stride, src_stride,
+                          coeff + 768);
+
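+  // Combine the four 16x16 sub-blocks. The extra >> 2 matches the scaling
+  // used by the C reference, aom_hadamard_32x32_c.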
+  for (int i = 0; i < 256; i += 8) {
+    const int16x8_t a0 = load_tran_low_to_s16q(coeff);
+    const int16x8_t a1 = load_tran_low_to_s16q(coeff + 256);
+    const int16x8_t a2 = load_tran_low_to_s16q(coeff + 512);
+    const int16x8_t a3 = load_tran_low_to_s16q(coeff + 768);
+
+    const int16x8_t b0 = vshrq_n_s16(vaddq_s16(a0, a1), 2);
+    const int16x8_t b1 = vshrq_n_s16(vsubq_s16(a0, a1), 2);
+    const int16x8_t b2 = vshrq_n_s16(vaddq_s16(a2, a3), 2);
+    const int16x8_t b3 = vshrq_n_s16(vsubq_s16(a2, a3), 2);
 
     const int16x8_t c0 = vaddq_s16(b0, b2);
     const int16x8_t c1 = vaddq_s16(b1, b3);
     const int16x8_t c2 = vsubq_s16(b0, b2);
     const int16x8_t c3 = vsubq_s16(b1, b3);
 
-    store_s16q_to_tran_low_offset_4(coeff + 0, c0);
-    store_s16q_to_tran_low_offset_4(coeff + 64, c1);
-    store_s16q_to_tran_low_offset_4(coeff + 128, c2);
-    store_s16q_to_tran_low_offset_4(coeff + 192, c3);
+    store_s16q_to_tran_low(coeff + 0, c0);
+    store_s16q_to_tran_low(coeff + 256, c1);
+    store_s16q_to_tran_low(coeff + 512, c2);
+    store_s16q_to_tran_low(coeff + 768, c3);
 
-    t_coeff += 8;
-    coeff += (4 + (((i >> 3) & 1) << 3));
+    coeff += 8;
   }
 }
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index 8e6dc12..ba17f8a 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -17,518 +17,1029 @@
 
 #include "aom/aom_integer.h"
 #include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
 #include "aom_dsp/intrapred_common.h"
 
 //------------------------------------------------------------------------------
 // DC 4x4
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_4x4(uint8_t *dst, ptrdiff_t stride, const uint8_t *above,
-                          const uint8_t *left, int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
+static INLINE uint16x8_t dc_load_sum_4(const uint8_t *in) {
+  const uint8x8_t a = load_u8_4x1_lane0(in);
+  const uint16x4_t p0 = vpaddl_u8(a);
+  const uint16x4_t p1 = vpadd_u16(p0, p0);
+  return vcombine_u16(p1, vdup_n_u16(0));
+}
 
-  if (do_above) {
-    const uint8x8_t A = vld1_u8(above);  // top row
-    const uint16x4_t p0 = vpaddl_u8(A);  // cascading summation of the top
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    sum_top = vcombine_u16(p1, p1);
-  }
-
-  if (do_left) {
-    const uint8x8_t L = vld1_u8(left);   // left border
-    const uint16x4_t p0 = vpaddl_u8(L);  // cascading summation of the left
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    sum_left = vcombine_u16(p1, p1);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 3);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 2);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 2);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
-
-  {
-    const uint8x8_t dc = vdup_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 4; ++i) {
-      vst1_lane_u32((uint32_t *)(dst + i * stride), vreinterpret_u32_u8(dc), 0);
-    }
+static INLINE void dc_store_4xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                uint8x8_t dc) {
+  for (int i = 0; i < h; ++i) {
+    store_u8_4x1(dst + i * stride, dc, 0);
   }
 }
 
 void aom_dc_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                const uint8_t *above, const uint8_t *left) {
-  dc_4x4(dst, stride, above, left, 1, 1);
+  const uint16x8_t sum_top = dc_load_sum_4(above);
+  const uint16x8_t sum_left = dc_load_sum_4(left);
+  const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum, 3);
+  dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_left_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                     const uint8_t *above, const uint8_t *left) {
+  const uint16x8_t sum_left = dc_load_sum_4(left);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 2);
   (void)above;
-  dc_4x4(dst, stride, NULL, left, 0, 1);
+  dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_top_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_sum_4(above);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 2);
   (void)left;
-  dc_4x4(dst, stride, above, NULL, 1, 0);
+  dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_128_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t dc0 = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_4x4(dst, stride, NULL, NULL, 0, 0);
+  dc_store_4xh(dst, stride, 4, dc0);
 }
 
 //------------------------------------------------------------------------------
 // DC 8x8
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_8x8(uint8_t *dst, ptrdiff_t stride, const uint8_t *above,
-                          const uint8_t *left, int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
+static INLINE uint16x8_t dc_load_sum_8(const uint8_t *in) {
+  // This isn't used in the case where we want to load both above and left
+  // vectors, since we want to avoid performing the reduction twice.
+  const uint8x8_t a = vld1_u8(in);
+  const uint16x4_t p0 = vpaddl_u8(a);
+  const uint16x4_t p1 = vpadd_u16(p0, p0);
+  const uint16x4_t p2 = vpadd_u16(p1, p1);
+  return vcombine_u16(p2, vdup_n_u16(0));
+}
 
-  if (do_above) {
-    const uint8x8_t A = vld1_u8(above);  // top row
-    const uint16x4_t p0 = vpaddl_u8(A);  // cascading summation of the top
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    sum_top = vcombine_u16(p2, p2);
-  }
+static INLINE uint16x8_t horizontal_add_and_broadcast_u16x8(uint16x8_t a) {
+#ifdef __aarch64__
+  // On AArch64 we could also use vdupq_n_u16(vaddvq_u16(a)) here to save an
+  // instruction; however, the addv instruction is usually slightly more
+  // expensive than a pairwise addition, so the need to immediately broadcast
+  // the result again seems to negate any benefit.
+  const uint16x8_t b = vpaddq_u16(a, a);
+  const uint16x8_t c = vpaddq_u16(b, b);
+  return vpaddq_u16(c, c);
+#else
+  const uint16x4_t b = vadd_u16(vget_low_u16(a), vget_high_u16(a));
+  const uint16x4_t c = vpadd_u16(b, b);
+  const uint16x4_t d = vpadd_u16(c, c);
+  return vcombine_u16(d, d);
+#endif
+}
 
-  if (do_left) {
-    const uint8x8_t L = vld1_u8(left);   // left border
-    const uint16x4_t p0 = vpaddl_u8(L);  // cascading summation of the left
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    sum_left = vcombine_u16(p2, p2);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 4);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 3);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 3);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
-
-  {
-    const uint8x8_t dc = vdup_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 8; ++i) {
-      vst1_u32((uint32_t *)(dst + i * stride), vreinterpret_u32_u8(dc));
-    }
+static INLINE void dc_store_8xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                uint8x8_t dc) {
+  for (int i = 0; i < h; ++i) {
+    vst1_u8(dst + i * stride, dc);
   }
 }
 
 void aom_dc_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                const uint8_t *above, const uint8_t *left) {
-  dc_8x8(dst, stride, above, left, 1, 1);
+  const uint8x8_t sum_top = vld1_u8(above);
+  const uint8x8_t sum_left = vld1_u8(left);
+  uint16x8_t sum = vaddl_u8(sum_left, sum_top);
+  sum = horizontal_add_and_broadcast_u16x8(sum);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum, 4);
+  dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_left_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                     const uint8_t *above, const uint8_t *left) {
+  const uint16x8_t sum_left = dc_load_sum_8(left);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 3);
   (void)above;
-  dc_8x8(dst, stride, NULL, left, 0, 1);
+  dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_top_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_sum_8(above);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 3);
   (void)left;
-  dc_8x8(dst, stride, above, NULL, 1, 0);
+  dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0));
 }
 
 void aom_dc_128_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t dc0 = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_8x8(dst, stride, NULL, NULL, 0, 0);
+  dc_store_8xh(dst, stride, 8, dc0);
 }
 
 //------------------------------------------------------------------------------
 // DC 16x16
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_16x16(uint8_t *dst, ptrdiff_t stride,
-                            const uint8_t *above, const uint8_t *left,
-                            int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
+static INLINE uint16x8_t dc_load_partial_sum_16(const uint8_t *in) {
+  const uint8x16_t a = vld1q_u8(in);
+  // Delay the remainder of the reduction until
+  // horizontal_add_and_broadcast_u16x8, since we want to do it once rather
+  // than twice in the case we are loading both above and left.
+  return vpaddlq_u8(a);
+}
 
-  if (do_above) {
-    const uint8x16_t A = vld1q_u8(above);  // top row
-    const uint16x8_t p0 = vpaddlq_u8(A);   // cascading summation of the top
-    const uint16x4_t p1 = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    const uint16x4_t p3 = vpadd_u16(p2, p2);
-    sum_top = vcombine_u16(p3, p3);
-  }
+static INLINE uint16x8_t dc_load_sum_16(const uint8_t *in) {
+  return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_16(in));
+}
 
-  if (do_left) {
-    const uint8x16_t L = vld1q_u8(left);  // left row
-    const uint16x8_t p0 = vpaddlq_u8(L);  // cascading summation of the left
-    const uint16x4_t p1 = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    const uint16x4_t p3 = vpadd_u16(p2, p2);
-    sum_left = vcombine_u16(p3, p3);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 5);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 4);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 4);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
-
-  {
-    const uint8x16_t dc = vdupq_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 16; ++i) {
-      vst1q_u8(dst + i * stride, dc);
-    }
+static INLINE void dc_store_16xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                 uint8x16_t dc) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + i * stride, dc);
   }
 }
 
 void aom_dc_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                  const uint8_t *above, const uint8_t *left) {
-  dc_16x16(dst, stride, above, left, 1, 1);
+  const uint16x8_t sum_top = dc_load_partial_sum_16(above);
+  const uint16x8_t sum_left = dc_load_partial_sum_16(left);
+  uint16x8_t sum = vaddq_u16(sum_left, sum_top);
+  sum = horizontal_add_and_broadcast_u16x8(sum);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum, 5);
+  dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_left_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                       const uint8_t *above,
                                       const uint8_t *left) {
+  const uint16x8_t sum_left = dc_load_sum_16(left);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 4);
   (void)above;
-  dc_16x16(dst, stride, NULL, left, 0, 1);
+  dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_top_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_sum_16(above);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 4);
   (void)left;
-  dc_16x16(dst, stride, above, NULL, 1, 0);
+  dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_128_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint8x16_t dc0 = vdupq_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_16x16(dst, stride, NULL, NULL, 0, 0);
+  dc_store_16xh(dst, stride, 16, dc0);
 }
 
 //------------------------------------------------------------------------------
 // DC 32x32
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_32x32(uint8_t *dst, ptrdiff_t stride,
-                            const uint8_t *above, const uint8_t *left,
-                            int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
+static INLINE uint16x8_t dc_load_partial_sum_32(const uint8_t *in) {
+  const uint8x16_t a0 = vld1q_u8(in);
+  const uint8x16_t a1 = vld1q_u8(in + 16);
+  // Delay the remainder of the reduction until
+  // horizontal_add_and_broadcast_u16x8, since we want to do it once rather
+  // than twice in the case we are loading both above and left.
+  return vpadalq_u8(vpaddlq_u8(a0), a1);
+}
 
-  if (do_above) {
-    const uint8x16_t A0 = vld1q_u8(above);  // top row
-    const uint8x16_t A1 = vld1q_u8(above + 16);
-    const uint16x8_t p0 = vpaddlq_u8(A0);  // cascading summation of the top
-    const uint16x8_t p1 = vpaddlq_u8(A1);
-    const uint16x8_t p2 = vaddq_u16(p0, p1);
-    const uint16x4_t p3 = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
-    const uint16x4_t p4 = vpadd_u16(p3, p3);
-    const uint16x4_t p5 = vpadd_u16(p4, p4);
-    sum_top = vcombine_u16(p5, p5);
-  }
+static INLINE uint16x8_t dc_load_sum_32(const uint8_t *in) {
+  return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_32(in));
+}
 
-  if (do_left) {
-    const uint8x16_t L0 = vld1q_u8(left);  // left row
-    const uint8x16_t L1 = vld1q_u8(left + 16);
-    const uint16x8_t p0 = vpaddlq_u8(L0);  // cascading summation of the left
-    const uint16x8_t p1 = vpaddlq_u8(L1);
-    const uint16x8_t p2 = vaddq_u16(p0, p1);
-    const uint16x4_t p3 = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
-    const uint16x4_t p4 = vpadd_u16(p3, p3);
-    const uint16x4_t p5 = vpadd_u16(p4, p4);
-    sum_left = vcombine_u16(p5, p5);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 6);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 5);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 5);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
-
-  {
-    const uint8x16_t dc = vdupq_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 32; ++i) {
-      vst1q_u8(dst + i * stride, dc);
-      vst1q_u8(dst + i * stride + 16, dc);
-    }
+static INLINE void dc_store_32xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                 uint8x16_t dc) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + i * stride, dc);
+    vst1q_u8(dst + i * stride + 16, dc);
   }
 }
 
 void aom_dc_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                  const uint8_t *above, const uint8_t *left) {
-  dc_32x32(dst, stride, above, left, 1, 1);
+  const uint16x8_t sum_top = dc_load_partial_sum_32(above);
+  const uint16x8_t sum_left = dc_load_partial_sum_32(left);
+  uint16x8_t sum = vaddq_u16(sum_left, sum_top);
+  sum = horizontal_add_and_broadcast_u16x8(sum);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum, 6);
+  dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_left_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                       const uint8_t *above,
                                       const uint8_t *left) {
+  const uint16x8_t sum_left = dc_load_sum_32(left);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 5);
   (void)above;
-  dc_32x32(dst, stride, NULL, left, 0, 1);
+  dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_top_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_sum_32(above);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 5);
   (void)left;
-  dc_32x32(dst, stride, above, NULL, 1, 0);
+  dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0));
 }
 
 void aom_dc_128_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint8x16_t dc0 = vdupq_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_32x32(dst, stride, NULL, NULL, 0, 0);
+  dc_store_32xh(dst, stride, 32, dc0);
 }
 
+//------------------------------------------------------------------------------
+// DC 64x64
+
+static INLINE uint16x8_t dc_load_partial_sum_64(const uint8_t *in) {
+  const uint8x16_t a0 = vld1q_u8(in);
+  const uint8x16_t a1 = vld1q_u8(in + 16);
+  const uint8x16_t a2 = vld1q_u8(in + 32);
+  const uint8x16_t a3 = vld1q_u8(in + 48);
+  const uint16x8_t p01 = vpadalq_u8(vpaddlq_u8(a0), a1);
+  const uint16x8_t p23 = vpadalq_u8(vpaddlq_u8(a2), a3);
+  // Delay the remainder of the reduction until
+  // horizontal_add_and_broadcast_u16x8, since we want to do it once rather
+  // than twice in the case we are loading both above and left.
+  return vaddq_u16(p01, p23);
+}
+
+static INLINE uint16x8_t dc_load_sum_64(const uint8_t *in) {
+  return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_64(in));
+}
+
+static INLINE void dc_store_64xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                 uint8x16_t dc) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + i * stride, dc);
+    vst1q_u8(dst + i * stride + 16, dc);
+    vst1q_u8(dst + i * stride + 32, dc);
+    vst1q_u8(dst + i * stride + 48, dc);
+  }
+}
+
+void aom_dc_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_partial_sum_64(above);
+  const uint16x8_t sum_left = dc_load_partial_sum_64(left);
+  uint16x8_t sum = vaddq_u16(sum_left, sum_top);
+  sum = horizontal_add_and_broadcast_u16x8(sum);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum, 7);
+  dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0));
+}
+
+void aom_dc_left_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                      const uint8_t *above,
+                                      const uint8_t *left) {
+  const uint16x8_t sum_left = dc_load_sum_64(left);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 6);
+  (void)above;
+  dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0));
+}
+
+void aom_dc_top_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                     const uint8_t *above,
+                                     const uint8_t *left) {
+  const uint16x8_t sum_top = dc_load_sum_64(above);
+  const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 6);
+  (void)left;
+  dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0));
+}
+
+void aom_dc_128_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                     const uint8_t *above,
+                                     const uint8_t *left) {
+  const uint8x16_t dc0 = vdupq_n_u8(0x80);
+  (void)above;
+  (void)left;
+  dc_store_64xh(dst, stride, 64, dc0);
+}
+
+//------------------------------------------------------------------------------
+// DC rectangular cases
+
+#define DC_MULTIPLIER_1X2 0x5556
+#define DC_MULTIPLIER_1X4 0x3334
+
+#define DC_SHIFT2 16
+
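+// Division by the non-power-of-two sample counts of the rectangular blocks
+// is done with a shift followed by a fixed-point multiply:
+// 0x5556 ~= 2^16 / 3 and 0x3334 ~= 2^16 / 5. For example, a 4x8 block
+// averages its 12 edge samples as ((sum + 6) >> 2) * 0x5556 >> 16.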
+static INLINE int divide_using_multiply_shift(int num, int shift1,
+                                              int multiplier, int shift2) {
+  const int interm = num >> shift1;
+  return interm * multiplier >> shift2;
+}
+
+static INLINE int calculate_dc_from_sum(int bw, int bh, uint32_t sum,
+                                        int shift1, int multiplier) {
+  const int expected_dc = divide_using_multiply_shift(
+      sum + ((bw + bh) >> 1), shift1, multiplier, DC_SHIFT2);
+  assert(expected_dc < (1 << 8));
+  return expected_dc;
+}
+
+#undef DC_SHIFT2
+
+void aom_dc_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  uint8x8_t a = load_u8_4x1_lane0(above);
+  uint8x8_t l = vld1_u8(left);
+  uint32_t sum = horizontal_add_u16x8(vaddl_u8(a, l));
+  uint32_t dc = calculate_dc_from_sum(4, 8, sum, 2, DC_MULTIPLIER_1X2);
+  dc_store_4xh(dst, stride, 8, vdup_n_u8(dc));
+}
+
+void aom_dc_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  uint8x8_t a = vld1_u8(above);
+  uint8x8_t l = load_u8_4x1_lane0(left);
+  uint32_t sum = horizontal_add_u16x8(vaddl_u8(a, l));
+  uint32_t dc = calculate_dc_from_sum(8, 4, sum, 2, DC_MULTIPLIER_1X2);
+  dc_store_8xh(dst, stride, 4, vdup_n_u8(dc));
+}
+
+void aom_dc_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint8x8_t a = load_u8_4x1_lane0(above);
+  uint8x16_t l = vld1q_u8(left);
+  uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(l), a);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(4, 16, sum, 2, DC_MULTIPLIER_1X4);
+  dc_store_4xh(dst, stride, 16, vdup_n_u8(dc));
+}
+
+void aom_dc_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint8x16_t a = vld1q_u8(above);
+  uint8x8_t l = load_u8_4x1_lane0(left);
+  uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(a), l);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(16, 4, sum, 2, DC_MULTIPLIER_1X4);
+  dc_store_16xh(dst, stride, 4, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint8x8_t a = vld1_u8(above);
+  uint8x16_t l = vld1q_u8(left);
+  uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(l), a);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(8, 16, sum, 3, DC_MULTIPLIER_1X2);
+  dc_store_8xh(dst, stride, 16, vdup_n_u8(dc));
+}
+
+void aom_dc_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint8x16_t a = vld1q_u8(above);
+  uint8x8_t l = vld1_u8(left);
+  uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(a), l);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(16, 8, sum, 3, DC_MULTIPLIER_1X2);
+  dc_store_16xh(dst, stride, 8, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint8x8_t a = vld1_u8(above);
+  uint16x8_t sum_left = dc_load_partial_sum_32(left);
+  uint16x8_t sum_al = vaddw_u8(sum_left, a);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(8, 32, sum, 3, DC_MULTIPLIER_1X4);
+  dc_store_8xh(dst, stride, 32, vdup_n_u8(dc));
+}
+
+void aom_dc_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_top = dc_load_partial_sum_32(above);
+  uint8x8_t l = vld1_u8(left);
+  uint16x8_t sum_al = vaddw_u8(sum_top, l);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(32, 8, sum, 3, DC_MULTIPLIER_1X4);
+  dc_store_32xh(dst, stride, 8, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_16(above);
+  uint16x8_t sum_left = dc_load_partial_sum_32(left);
+  uint16x8_t sum_al = vaddq_u16(sum_left, sum_above);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(16, 32, sum, 4, DC_MULTIPLIER_1X2);
+  dc_store_16xh(dst, stride, 32, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_32(above);
+  uint16x8_t sum_left = dc_load_partial_sum_16(left);
+  uint16x8_t sum_al = vaddq_u16(sum_left, sum_above);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(32, 16, sum, 4, DC_MULTIPLIER_1X2);
+  dc_store_32xh(dst, stride, 16, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_16(above);
+  uint16x8_t sum_left = dc_load_partial_sum_64(left);
+  uint16x8_t sum_al = vaddq_u16(sum_left, sum_above);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(16, 64, sum, 4, DC_MULTIPLIER_1X4);
+  dc_store_16xh(dst, stride, 64, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_64(above);
+  uint16x8_t sum_left = dc_load_partial_sum_16(left);
+  uint16x8_t sum_al = vaddq_u16(sum_above, sum_left);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(64, 16, sum, 4, DC_MULTIPLIER_1X4);
+  dc_store_64xh(dst, stride, 16, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_32(above);
+  uint16x8_t sum_left = dc_load_partial_sum_64(left);
+  uint16x8_t sum_al = vaddq_u16(sum_above, sum_left);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(32, 64, sum, 5, DC_MULTIPLIER_1X2);
+  dc_store_32xh(dst, stride, 64, vdupq_n_u8(dc));
+}
+
+void aom_dc_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                 const uint8_t *above, const uint8_t *left) {
+  uint16x8_t sum_above = dc_load_partial_sum_64(above);
+  uint16x8_t sum_left = dc_load_partial_sum_32(left);
+  uint16x8_t sum_al = vaddq_u16(sum_above, sum_left);
+  uint32_t sum = horizontal_add_u16x8(sum_al);
+  uint32_t dc = calculate_dc_from_sum(64, 32, sum, 5, DC_MULTIPLIER_1X2);
+  dc_store_64xh(dst, stride, 32, vdupq_n_u8(dc));
+}
+
+#undef DC_MULTIPLIER_1X2
+#undef DC_MULTIPLIER_1X4
+
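+// The 'q' parameter selects between the 64-bit (vdup_*) and 128-bit (vdupq_*)
+// forms of the intrinsics so that the vector width matches the dc_store_*
+// helper for each block width.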
+#define DC_PREDICTOR_128(w, h, q)                                            \
+  void aom_dc_128_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \
+                                             const uint8_t *above,           \
+                                             const uint8_t *left) {          \
+    (void)above;                                                             \
+    (void)left;                                                              \
+    dc_store_##w##xh(dst, stride, (h), vdup##q##_n_u8(0x80));                \
+  }
+
+DC_PREDICTOR_128(4, 8, )
+DC_PREDICTOR_128(4, 16, )
+DC_PREDICTOR_128(8, 4, )
+DC_PREDICTOR_128(8, 16, )
+DC_PREDICTOR_128(8, 32, )
+DC_PREDICTOR_128(16, 4, q)
+DC_PREDICTOR_128(16, 8, q)
+DC_PREDICTOR_128(16, 32, q)
+DC_PREDICTOR_128(16, 64, q)
+DC_PREDICTOR_128(32, 8, q)
+DC_PREDICTOR_128(32, 16, q)
+DC_PREDICTOR_128(32, 64, q)
+DC_PREDICTOR_128(64, 32, q)
+DC_PREDICTOR_128(64, 16, q)
+
+#undef DC_PREDICTOR_128
+
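+// The shift is log2 of the number of edge samples being averaged: h for the
+// left-only predictors below, w for the top-only predictors further down.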
+#define DC_PREDICTOR_LEFT(w, h, shift, q)                                     \
+  void aom_dc_left_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \
+                                              const uint8_t *above,           \
+                                              const uint8_t *left) {          \
+    (void)above;                                                              \
+    const uint16x8_t sum = dc_load_sum_##h(left);                             \
+    const uint8x8_t dc0 = vrshrn_n_u16(sum, (shift));                         \
+    dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u8(dc0, 0));            \
+  }
+
+DC_PREDICTOR_LEFT(4, 8, 3, )
+DC_PREDICTOR_LEFT(8, 4, 2, )
+DC_PREDICTOR_LEFT(8, 16, 4, )
+DC_PREDICTOR_LEFT(16, 8, 3, q)
+DC_PREDICTOR_LEFT(16, 32, 5, q)
+DC_PREDICTOR_LEFT(32, 16, 4, q)
+DC_PREDICTOR_LEFT(32, 64, 6, q)
+DC_PREDICTOR_LEFT(64, 32, 5, q)
+DC_PREDICTOR_LEFT(4, 16, 4, )
+DC_PREDICTOR_LEFT(16, 4, 2, q)
+DC_PREDICTOR_LEFT(8, 32, 5, )
+DC_PREDICTOR_LEFT(32, 8, 3, q)
+DC_PREDICTOR_LEFT(16, 64, 6, q)
+DC_PREDICTOR_LEFT(64, 16, 4, q)
+
+#undef DC_PREDICTOR_LEFT
+
+#define DC_PREDICTOR_TOP(w, h, shift, q)                                     \
+  void aom_dc_top_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \
+                                             const uint8_t *above,           \
+                                             const uint8_t *left) {          \
+    (void)left;                                                              \
+    const uint16x8_t sum = dc_load_sum_##w(above);                           \
+    const uint8x8_t dc0 = vrshrn_n_u16(sum, (shift));                        \
+    dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u8(dc0, 0));           \
+  }
+
+DC_PREDICTOR_TOP(4, 8, 2, )
+DC_PREDICTOR_TOP(4, 16, 2, )
+DC_PREDICTOR_TOP(8, 4, 3, )
+DC_PREDICTOR_TOP(8, 16, 3, )
+DC_PREDICTOR_TOP(8, 32, 3, )
+DC_PREDICTOR_TOP(16, 4, 4, q)
+DC_PREDICTOR_TOP(16, 8, 4, q)
+DC_PREDICTOR_TOP(16, 32, 4, q)
+DC_PREDICTOR_TOP(16, 64, 4, q)
+DC_PREDICTOR_TOP(32, 8, 5, q)
+DC_PREDICTOR_TOP(32, 16, 5, q)
+DC_PREDICTOR_TOP(32, 64, 5, q)
+DC_PREDICTOR_TOP(64, 16, 6, q)
+DC_PREDICTOR_TOP(64, 32, 6, q)
+
+#undef DC_PREDICTOR_TOP
+
 // -----------------------------------------------------------------------------
 
-void aom_d135_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
-                                 const uint8_t *above, const uint8_t *left) {
-  const uint8x8_t XABCD_u8 = vld1_u8(above - 1);
-  const uint64x1_t XABCD = vreinterpret_u64_u8(XABCD_u8);
-  const uint64x1_t ____XABC = vshl_n_u64(XABCD, 32);
-  const uint32x2_t zero = vdup_n_u32(0);
-  const uint32x2_t IJKL = vld1_lane_u32((const uint32_t *)left, zero, 0);
-  const uint8x8_t IJKL_u8 = vreinterpret_u8_u32(IJKL);
-  const uint64x1_t LKJI____ = vreinterpret_u64_u8(vrev32_u8(IJKL_u8));
-  const uint64x1_t LKJIXABC = vorr_u64(LKJI____, ____XABC);
-  const uint8x8_t KJIXABC_ = vreinterpret_u8_u64(vshr_n_u64(LKJIXABC, 8));
-  const uint8x8_t JIXABC__ = vreinterpret_u8_u64(vshr_n_u64(LKJIXABC, 16));
-  const uint8_t D = vget_lane_u8(XABCD_u8, 4);
-  const uint8x8_t JIXABCD_ = vset_lane_u8(D, JIXABC__, 6);
-  const uint8x8_t LKJIXABC_u8 = vreinterpret_u8_u64(LKJIXABC);
-  const uint8x8_t avg1 = vhadd_u8(JIXABCD_, LKJIXABC_u8);
-  const uint8x8_t avg2 = vrhadd_u8(avg1, KJIXABC_);
-  const uint64x1_t avg2_u64 = vreinterpret_u64_u8(avg2);
-  const uint32x2_t r3 = vreinterpret_u32_u8(avg2);
-  const uint32x2_t r2 = vreinterpret_u32_u64(vshr_n_u64(avg2_u64, 8));
-  const uint32x2_t r1 = vreinterpret_u32_u64(vshr_n_u64(avg2_u64, 16));
-  const uint32x2_t r0 = vreinterpret_u32_u64(vshr_n_u64(avg2_u64, 24));
-  vst1_lane_u32((uint32_t *)(dst + 0 * stride), r0, 0);
-  vst1_lane_u32((uint32_t *)(dst + 1 * stride), r1, 0);
-  vst1_lane_u32((uint32_t *)(dst + 2 * stride), r2, 0);
-  vst1_lane_u32((uint32_t *)(dst + 3 * stride), r3, 0);
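+// Vertical prediction: the row of samples above the block is copied into
+// every output row.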
+static INLINE void v_store_4xh(uint8_t *dst, ptrdiff_t stride, int h,
+                               uint8x8_t d0) {
+  for (int i = 0; i < h; ++i) {
+    store_u8_4x1(dst + i * stride, d0, 0);
+  }
+}
+
+static INLINE void v_store_8xh(uint8_t *dst, ptrdiff_t stride, int h,
+                               uint8x8_t d0) {
+  for (int i = 0; i < h; ++i) {
+    vst1_u8(dst + i * stride, d0);
+  }
+}
+
+static INLINE void v_store_16xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                uint8x16_t d0) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + i * stride, d0);
+  }
+}
+
+static INLINE void v_store_32xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                uint8x16_t d0, uint8x16_t d1) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + 0, d0);
+    vst1q_u8(dst + 16, d1);
+    dst += stride;
+  }
+}
+
+static INLINE void v_store_64xh(uint8_t *dst, ptrdiff_t stride, int h,
+                                uint8x16_t d0, uint8x16_t d1, uint8x16_t d2,
+                                uint8x16_t d3) {
+  for (int i = 0; i < h; ++i) {
+    vst1q_u8(dst + 0, d0);
+    vst1q_u8(dst + 16, d1);
+    vst1q_u8(dst + 32, d2);
+    vst1q_u8(dst + 48, d3);
+    dst += stride;
+  }
 }
 
 void aom_v_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                               const uint8_t *above, const uint8_t *left) {
-  int i;
-  uint32x2_t d0u32 = vdup_n_u32(0);
   (void)left;
-
-  d0u32 = vld1_lane_u32((const uint32_t *)above, d0u32, 0);
-  for (i = 0; i < 4; i++, dst += stride)
-    vst1_lane_u32((uint32_t *)dst, d0u32, 0);
+  v_store_4xh(dst, stride, 4, load_u8_4x1_lane0(above));
 }
 
 void aom_v_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                               const uint8_t *above, const uint8_t *left) {
-  int i;
-  uint8x8_t d0u8 = vdup_n_u8(0);
   (void)left;
-
-  d0u8 = vld1_u8(above);
-  for (i = 0; i < 8; i++, dst += stride) vst1_u8(dst, d0u8);
+  v_store_8xh(dst, stride, 8, vld1_u8(above));
 }
 
 void aom_v_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                 const uint8_t *above, const uint8_t *left) {
-  int i;
-  uint8x16_t q0u8 = vdupq_n_u8(0);
   (void)left;
-
-  q0u8 = vld1q_u8(above);
-  for (i = 0; i < 16; i++, dst += stride) vst1q_u8(dst, q0u8);
+  v_store_16xh(dst, stride, 16, vld1q_u8(above));
 }
 
 void aom_v_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                 const uint8_t *above, const uint8_t *left) {
-  int i;
-  uint8x16_t q0u8 = vdupq_n_u8(0);
-  uint8x16_t q1u8 = vdupq_n_u8(0);
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
   (void)left;
+  v_store_32xh(dst, stride, 32, d0, d1);
+}
 
-  q0u8 = vld1q_u8(above);
-  q1u8 = vld1q_u8(above + 16);
-  for (i = 0; i < 32; i++, dst += stride) {
-    vst1q_u8(dst, q0u8);
-    vst1q_u8(dst + 16, q1u8);
-  }
+void aom_v_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride,
+                              const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_4xh(dst, stride, 8, load_u8_4x1_lane0(above));
+}
+
+void aom_v_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_4xh(dst, stride, 16, load_u8_4x1_lane0(above));
+}
+
+void aom_v_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride,
+                              const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_8xh(dst, stride, 4, vld1_u8(above));
+}
+
+void aom_v_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_8xh(dst, stride, 16, vld1_u8(above));
+}
+
+void aom_v_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_8xh(dst, stride, 32, vld1_u8(above));
+}
+
+void aom_v_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_16xh(dst, stride, 4, vld1q_u8(above));
+}
+
+void aom_v_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_16xh(dst, stride, 8, vld1q_u8(above));
+}
+
+void aom_v_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_16xh(dst, stride, 32, vld1q_u8(above));
+}
+
+void aom_v_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  (void)left;
+  v_store_16xh(dst, stride, 64, vld1q_u8(above));
+}
+
+void aom_v_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  (void)left;
+  v_store_32xh(dst, stride, 8, d0, d1);
+}
+
+void aom_v_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  (void)left;
+  v_store_32xh(dst, stride, 16, d0, d1);
+}
+
+void aom_v_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  (void)left;
+  v_store_32xh(dst, stride, 64, d0, d1);
+}
+
+void aom_v_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  const uint8x16_t d2 = vld1q_u8(above + 32);
+  const uint8x16_t d3 = vld1q_u8(above + 48);
+  (void)left;
+  v_store_64xh(dst, stride, 16, d0, d1, d2, d3);
+}
+
+void aom_v_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  const uint8x16_t d2 = vld1q_u8(above + 32);
+  const uint8x16_t d3 = vld1q_u8(above + 48);
+  (void)left;
+  v_store_64xh(dst, stride, 32, d0, d1, d2, d3);
+}
+
+void aom_v_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(above);
+  const uint8x16_t d1 = vld1q_u8(above + 16);
+  const uint8x16_t d2 = vld1q_u8(above + 32);
+  const uint8x16_t d3 = vld1q_u8(above + 48);
+  (void)left;
+  v_store_64xh(dst, stride, 64, d0, d1, d2, d3);
+}
+
+// -----------------------------------------------------------------------------
+
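+// Horizontal prediction: each left-column sample is broadcast across its
+// corresponding output row.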
+static INLINE void h_store_4x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) {
+  store_u8_4x1(dst + 0 * stride, vdup_lane_u8(d0, 0), 0);
+  store_u8_4x1(dst + 1 * stride, vdup_lane_u8(d0, 1), 0);
+  store_u8_4x1(dst + 2 * stride, vdup_lane_u8(d0, 2), 0);
+  store_u8_4x1(dst + 3 * stride, vdup_lane_u8(d0, 3), 0);
+  store_u8_4x1(dst + 4 * stride, vdup_lane_u8(d0, 4), 0);
+  store_u8_4x1(dst + 5 * stride, vdup_lane_u8(d0, 5), 0);
+  store_u8_4x1(dst + 6 * stride, vdup_lane_u8(d0, 6), 0);
+  store_u8_4x1(dst + 7 * stride, vdup_lane_u8(d0, 7), 0);
+}
+
+static INLINE void h_store_8x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) {
+  vst1_u8(dst + 0 * stride, vdup_lane_u8(d0, 0));
+  vst1_u8(dst + 1 * stride, vdup_lane_u8(d0, 1));
+  vst1_u8(dst + 2 * stride, vdup_lane_u8(d0, 2));
+  vst1_u8(dst + 3 * stride, vdup_lane_u8(d0, 3));
+  vst1_u8(dst + 4 * stride, vdup_lane_u8(d0, 4));
+  vst1_u8(dst + 5 * stride, vdup_lane_u8(d0, 5));
+  vst1_u8(dst + 6 * stride, vdup_lane_u8(d0, 6));
+  vst1_u8(dst + 7 * stride, vdup_lane_u8(d0, 7));
+}
+
+static INLINE void h_store_16x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) {
+  vst1q_u8(dst + 0 * stride, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 1 * stride, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 2 * stride, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 3 * stride, vdupq_lane_u8(d0, 3));
+  vst1q_u8(dst + 4 * stride, vdupq_lane_u8(d0, 4));
+  vst1q_u8(dst + 5 * stride, vdupq_lane_u8(d0, 5));
+  vst1q_u8(dst + 6 * stride, vdupq_lane_u8(d0, 6));
+  vst1q_u8(dst + 7 * stride, vdupq_lane_u8(d0, 7));
+}
+
+static INLINE void h_store_32x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) {
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 0));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 1));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 2));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 3));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 3));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 4));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 4));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 5));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 5));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 6));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 6));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 7));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 7));
+}
+
+static INLINE void h_store_64x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) {
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 0));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 1));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 2));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 3));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 3));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 3));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 3));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 4));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 4));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 4));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 4));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 5));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 5));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 5));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 5));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 6));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 6));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 6));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 6));
+  dst += stride;
+  vst1q_u8(dst + 0, vdupq_lane_u8(d0, 7));
+  vst1q_u8(dst + 16, vdupq_lane_u8(d0, 7));
+  vst1q_u8(dst + 32, vdupq_lane_u8(d0, 7));
+  vst1q_u8(dst + 48, vdupq_lane_u8(d0, 7));
 }
 
 void aom_h_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                               const uint8_t *above, const uint8_t *left) {
-  uint8x8_t d0u8 = vdup_n_u8(0);
-  uint32x2_t d1u32 = vdup_n_u32(0);
+  const uint8x8_t d0 = load_u8_4x1_lane0(left);
   (void)above;
-
-  d1u32 = vld1_lane_u32((const uint32_t *)left, d1u32, 0);
-
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u32(d1u32), 0);
-  vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d0u8), 0);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u32(d1u32), 1);
-  vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d0u8), 0);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u32(d1u32), 2);
-  vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d0u8), 0);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u32(d1u32), 3);
-  vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(d0u8), 0);
+  store_u8_4x1(dst + 0 * stride, vdup_lane_u8(d0, 0), 0);
+  store_u8_4x1(dst + 1 * stride, vdup_lane_u8(d0, 1), 0);
+  store_u8_4x1(dst + 2 * stride, vdup_lane_u8(d0, 2), 0);
+  store_u8_4x1(dst + 3 * stride, vdup_lane_u8(d0, 3), 0);
 }
 
 void aom_h_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                               const uint8_t *above, const uint8_t *left) {
-  uint8x8_t d0u8 = vdup_n_u8(0);
-  uint64x1_t d1u64 = vdup_n_u64(0);
+  const uint8x8_t d0 = vld1_u8(left);
   (void)above;
-
-  d1u64 = vld1_u64((const uint64_t *)left);
-
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 0);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 1);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 2);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 3);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 4);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 5);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 6);
-  vst1_u8(dst, d0u8);
-  dst += stride;
-  d0u8 = vdup_lane_u8(vreinterpret_u8_u64(d1u64), 7);
-  vst1_u8(dst, d0u8);
+  h_store_8x8(dst, stride, d0);
 }
 
 void aom_h_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                 const uint8_t *above, const uint8_t *left) {
-  int j;
-  uint8x8_t d2u8 = vdup_n_u8(0);
-  uint8x16_t q0u8 = vdupq_n_u8(0);
-  uint8x16_t q1u8 = vdupq_n_u8(0);
+  const uint8x16_t d0 = vld1q_u8(left);
   (void)above;
-
-  q1u8 = vld1q_u8(left);
-  d2u8 = vget_low_u8(q1u8);
-  for (j = 0; j < 2; j++, d2u8 = vget_high_u8(q1u8)) {
-    q0u8 = vdupq_lane_u8(d2u8, 0);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 1);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 2);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 3);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 4);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 5);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 6);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-    q0u8 = vdupq_lane_u8(d2u8, 7);
-    vst1q_u8(dst, q0u8);
-    dst += stride;
-  }
+  h_store_16x8(dst, stride, vget_low_u8(d0));
+  h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0));
 }
 
 void aom_h_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                 const uint8_t *above, const uint8_t *left) {
-  int j, k;
-  uint8x8_t d2u8 = vdup_n_u8(0);
-  uint8x16_t q0u8 = vdupq_n_u8(0);
-  uint8x16_t q1u8 = vdupq_n_u8(0);
+  const uint8x16_t d0 = vld1q_u8(left);
+  const uint8x16_t d1 = vld1q_u8(left + 16);
   (void)above;
+  h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0));
+  h_store_32x8(dst + 16 * stride, stride, vget_low_u8(d1));
+  h_store_32x8(dst + 24 * stride, stride, vget_high_u8(d1));
+}
 
-  for (k = 0; k < 2; k++, left += 16) {
-    q1u8 = vld1q_u8(left);
-    d2u8 = vget_low_u8(q1u8);
-    for (j = 0; j < 2; j++, d2u8 = vget_high_u8(q1u8)) {
-      q0u8 = vdupq_lane_u8(d2u8, 0);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 1);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 2);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 3);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 4);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 5);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 6);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-      q0u8 = vdupq_lane_u8(d2u8, 7);
-      vst1q_u8(dst, q0u8);
-      vst1q_u8(dst + 16, q0u8);
-      dst += stride;
-    }
+void aom_h_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride,
+                              const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t d0 = vld1_u8(left);
+  (void)above;
+  h_store_4x8(dst, stride, d0);
+}
+
+void aom_h_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  (void)above;
+  h_store_4x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_4x8(dst + 8 * stride, stride, vget_high_u8(d0));
+}
+
+void aom_h_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride,
+                              const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t d0 = load_u8_4x1_lane0(left);
+  (void)above;
+  vst1_u8(dst + 0 * stride, vdup_lane_u8(d0, 0));
+  vst1_u8(dst + 1 * stride, vdup_lane_u8(d0, 1));
+  vst1_u8(dst + 2 * stride, vdup_lane_u8(d0, 2));
+  vst1_u8(dst + 3 * stride, vdup_lane_u8(d0, 3));
+}
+
+void aom_h_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  (void)above;
+  h_store_8x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_8x8(dst + 8 * stride, stride, vget_high_u8(d0));
+}
+
+void aom_h_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  const uint8x16_t d1 = vld1q_u8(left + 16);
+  (void)above;
+  h_store_8x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_8x8(dst + 8 * stride, stride, vget_high_u8(d0));
+  h_store_8x8(dst + 16 * stride, stride, vget_low_u8(d1));
+  h_store_8x8(dst + 24 * stride, stride, vget_high_u8(d1));
+}
+
+void aom_h_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t d0 = load_u8_4x1_lane0(left);
+  (void)above;
+  vst1q_u8(dst + 0 * stride, vdupq_lane_u8(d0, 0));
+  vst1q_u8(dst + 1 * stride, vdupq_lane_u8(d0, 1));
+  vst1q_u8(dst + 2 * stride, vdupq_lane_u8(d0, 2));
+  vst1q_u8(dst + 3 * stride, vdupq_lane_u8(d0, 3));
+}
+
+void aom_h_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t d0 = vld1_u8(left);
+  (void)above;
+  h_store_16x8(dst, stride, d0);
+}
+
+void aom_h_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  const uint8x16_t d1 = vld1q_u8(left + 16);
+  (void)above;
+  h_store_16x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0));
+  h_store_16x8(dst + 16 * stride, stride, vget_low_u8(d1));
+  h_store_16x8(dst + 24 * stride, stride, vget_high_u8(d1));
+}
+
+void aom_h_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  const uint8x16_t d1 = vld1q_u8(left + 16);
+  const uint8x16_t d2 = vld1q_u8(left + 32);
+  const uint8x16_t d3 = vld1q_u8(left + 48);
+  (void)above;
+  h_store_16x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0));
+  h_store_16x8(dst + 16 * stride, stride, vget_low_u8(d1));
+  h_store_16x8(dst + 24 * stride, stride, vget_high_u8(d1));
+  h_store_16x8(dst + 32 * stride, stride, vget_low_u8(d2));
+  h_store_16x8(dst + 40 * stride, stride, vget_high_u8(d2));
+  h_store_16x8(dst + 48 * stride, stride, vget_low_u8(d3));
+  h_store_16x8(dst + 56 * stride, stride, vget_high_u8(d3));
+}
+
+void aom_h_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride,
+                               const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t d0 = vld1_u8(left);
+  (void)above;
+  h_store_32x8(dst, stride, d0);
+}
+
+void aom_h_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  (void)above;
+  h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0));
+}
+
+void aom_h_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left + 0);
+  const uint8x16_t d1 = vld1q_u8(left + 16);
+  const uint8x16_t d2 = vld1q_u8(left + 32);
+  const uint8x16_t d3 = vld1q_u8(left + 48);
+  (void)above;
+  h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0));
+  h_store_32x8(dst + 16 * stride, stride, vget_low_u8(d1));
+  h_store_32x8(dst + 24 * stride, stride, vget_high_u8(d1));
+  h_store_32x8(dst + 32 * stride, stride, vget_low_u8(d2));
+  h_store_32x8(dst + 40 * stride, stride, vget_high_u8(d2));
+  h_store_32x8(dst + 48 * stride, stride, vget_low_u8(d3));
+  h_store_32x8(dst + 56 * stride, stride, vget_high_u8(d3));
+}
+
+void aom_h_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  const uint8x16_t d0 = vld1q_u8(left);
+  (void)above;
+  h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0));
+  h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0));
+}
+
+void aom_h_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  (void)above;
+  for (int i = 0; i < 2; ++i) {
+    const uint8x16_t d0 = vld1q_u8(left);
+    h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0));
+    h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0));
+    left += 16;
+    dst += 16 * stride;
+  }
+}
+
+void aom_h_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride,
+                                const uint8_t *above, const uint8_t *left) {
+  (void)above;
+  for (int i = 0; i < 4; ++i) {
+    const uint8x16_t d0 = vld1q_u8(left);
+    h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0));
+    h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0));
+    left += 16;
+    dst += 16 * stride;
   }
 }
 
@@ -638,7 +1149,6 @@
     0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
 };
 
-/* clang-format on */
 static AOM_FORCE_INLINE void dr_prediction_z1_HxW_internal_neon_64(
     int H, int W, uint8x8_t *dst, const uint8_t *above, int upsample_above,
     int dx) {
@@ -653,23 +1163,12 @@
   // final pixels will be calculated as:
   //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
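+  // i.e. a linear interpolation between above[x] and above[x+1] with a 5-bit
+  // fractional weight (shift lies in [0, 31]), rounded via the +16 term.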
 
-  uint16x8_t a0, a1;
-  uint16x8_t diff, a32;
-  uint16x8_t a16;
-  uint8x8_t a_mbase_x;
-
-  a16 = vdupq_n_u16(16);
-  a_mbase_x = vdup_n_u8(above[max_base_x]);
-  uint16x8_t v_32 = vdupq_n_u16(32);
-  int16x8_t v_upsample_above = vdupq_n_s16(upsample_above);
-  uint16x8_t c3f = vdupq_n_u16(0x3f);
+  const uint16x8_t a16 = vdupq_n_u16(16);
+  const uint8x8_t a_mbase_x = vdup_n_u8(above[max_base_x]);
+  const uint8x8_t v_32 = vdup_n_u8(32);
 
   int x = dx;
   for (int r = 0; r < W; r++) {
-    uint16x8_t res;
-    uint16x8_t shift;
-    uint8x8x2_t v_tmp_a0_128;
-
     int base = x >> frac_bits;
     int base_max_diff = (max_base_x - base) >> upsample_above;
     if (base_max_diff <= 0) {
@@ -681,24 +1180,22 @@
 
     if (base_max_diff > H) base_max_diff = H;
 
+    uint8x8x2_t a01_128;
+    uint16x8_t shift;
     if (upsample_above) {
-      v_tmp_a0_128 = vld2_u8(above + base);
-      shift = vshrq_n_u16(
-          vandq_u16(vshlq_u16(vdupq_n_u16(x), v_upsample_above), c3f), 1);
+      a01_128 = vld2_u8(above + base);
+      shift = vdupq_n_u16(((x << upsample_above) & 0x3f) >> 1);
     } else {
-      v_tmp_a0_128.val[0] = vld1_u8(above + base);
-      v_tmp_a0_128.val[1] = vld1_u8(above + base + 1);
-      shift = vshrq_n_u16(vandq_u16(vdupq_n_u16(x), c3f), 1);
+      a01_128.val[0] = vld1_u8(above + base);
+      a01_128.val[1] = vld1_u8(above + base + 1);
+      shift = vdupq_n_u16((x & 0x3f) >> 1);
     }
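+    // The fractional weight depends only on the scalar x, so it is computed
+    // with scalar arithmetic and broadcast once per row, replacing the vector
+    // and/shift sequence removed above.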
-    a0 = vmovl_u8(v_tmp_a0_128.val[0]);
-    a1 = vmovl_u8(v_tmp_a0_128.val[1]);
-    diff = vsubq_u16(a1, a0);        // a[x+1] - a[x]
-    a32 = vmlaq_u16(a16, a0, v_32);  // a[x] * 32 + 16
-    res = vmlaq_u16(a32, diff, shift);
+    uint16x8_t diff = vsubl_u8(a01_128.val[1], a01_128.val[0]);
+    uint16x8_t a32 = vmlal_u8(a16, a01_128.val[0], v_32);
+    uint16x8_t res = vmlaq_u16(a32, diff, shift);
 
     uint8x8_t mask = vld1_u8(BaseMask[base_max_diff]);
-    dst[r] =
-        vorr_u8(vand_u8(mask, vshrn_n_u16(res, 5)), vbic_u8(a_mbase_x, mask));
+    dst[r] = vbsl_u8(mask, vshrn_n_u16(res, 5), a_mbase_x);
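+    // vbsl_u8 picks the interpolated pixel where the mask is set and the
+    // max-base pixel otherwise, folding the previous and/bic/orr sequence
+    // into a single bitwise select.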
 
     x += dx;
   }
@@ -743,17 +1240,10 @@
   // final pixels will be calculated as:
   //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
 
-  uint8x16x2_t a0, a1;
-  uint16x8x2_t diff, a32;
-  uint16x8_t a16, c3f;
-  uint8x16_t a_mbase_x;
-
-  a16 = vdupq_n_u16(16);
-  a_mbase_x = vdupq_n_u8(above[max_base_x]);
-  c3f = vdupq_n_u16(0x3f);
-  uint16x8_t v_32 = vdupq_n_u16(32);
-  uint8x16_t v_zero = vdupq_n_u8(0);
-  int16x8_t v_upsample_above = vdupq_n_s16(upsample_above);
+  const uint16x8_t a16 = vdupq_n_u16(16);
+  const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]);
+  const uint8x8_t v_32 = vdup_n_u8(32);
+  const uint8x16_t v_zero = vdupq_n_u8(0);
 
   int x = dx;
   for (int r = 0; r < W; r++) {
@@ -776,30 +1266,24 @@
       uint8x8x2_t v_tmp_a0_128 = vld2_u8(above + base);
       a0_128 = vcombine_u8(v_tmp_a0_128.val[0], v_tmp_a0_128.val[1]);
       a1_128 = vextq_u8(a0_128, v_zero, 8);
-      shift = vshrq_n_u16(
-          vandq_u16(vshlq_u16(vdupq_n_u16(x), v_upsample_above), c3f), 1);
+      shift = vdupq_n_u16(((x << upsample_above) & 0x3f) >> 1);
     } else {
       a0_128 = vld1q_u8(above + base);
       a1_128 = vld1q_u8(above + base + 1);
-      shift = vshrq_n_u16(vandq_u16(vdupq_n_u16(x), c3f), 1);
+      shift = vdupq_n_u16((x & 0x3f) >> 1);
     }
-    a0 = vzipq_u8(a0_128, v_zero);
-    a1 = vzipq_u8(a1_128, v_zero);
-    diff.val[0] = vsubq_u16(vreinterpretq_u16_u8(a1.val[0]),
-                            vreinterpretq_u16_u8(a0.val[0]));  // a[x+1] - a[x]
-    diff.val[1] = vsubq_u16(vreinterpretq_u16_u8(a1.val[1]),
-                            vreinterpretq_u16_u8(a0.val[1]));  // a[x+1] - a[x]
-    a32.val[0] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[0]),
-                           v_32);  // a[x] * 32 + 16
-    a32.val[1] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[1]),
-                           v_32);  // a[x] * 32 + 16
+    uint16x8x2_t diff, a32;
+    diff.val[0] = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128));
+    diff.val[1] = vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128));
+    a32.val[0] = vmlal_u8(a16, vget_low_u8(a0_128), v_32);
+    a32.val[1] = vmlal_u8(a16, vget_high_u8(a0_128), v_32);
     res.val[0] = vmlaq_u16(a32.val[0], diff.val[0], shift);
     res.val[1] = vmlaq_u16(a32.val[1], diff.val[1], shift);
     uint8x16_t v_temp =
         vcombine_u8(vshrn_n_u16(res.val[0], 5), vshrn_n_u16(res.val[1], 5));
 
     uint8x16_t mask = vld1q_u8(BaseMask[base_max_diff]);
-    dst[r] = vorrq_u8(vandq_u8(mask, v_temp), vbicq_u8(a_mbase_x, mask));
+    dst[r] = vbslq_u8(mask, v_temp, a_mbase_x);
 
     x += dx;
   }
@@ -831,22 +1315,13 @@
   // final pixels will be calculated as:
   //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
 
-  uint8x16_t a_mbase_x;
-  uint8x16x2_t a0, a1;
-  uint16x8x2_t diff, a32;
-  uint16x8_t a16, c3f;
-
-  a_mbase_x = vdupq_n_u8(above[max_base_x]);
-  a16 = vdupq_n_u16(16);
-  c3f = vdupq_n_u16(0x3f);
-  uint16x8_t v_32 = vdupq_n_u16(32);
-  uint8x16_t v_zero = vdupq_n_u8(0);
+  const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]);
+  const uint16x8_t a16 = vdupq_n_u16(16);
+  const uint8x8_t v_32 = vdup_n_u8(32);
 
   int x = dx;
   for (int r = 0; r < N; r++) {
-    uint16x8x2_t res;
     uint8x16_t res16[2];
-    uint8x16_t a0_128, a1_128;
 
     int base = x >> frac_bits;
     int base_max_diff = (max_base_x - base);
@@ -859,27 +1334,21 @@
     }
     if (base_max_diff > 32) base_max_diff = 32;
 
-    uint16x8_t shift = vshrq_n_u16(vandq_u16(vdupq_n_u16(x), c3f), 1);
+    uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1);
 
     for (int j = 0, jj = 0; j < 32; j += 16, jj++) {
       int mdiff = base_max_diff - j;
       if (mdiff <= 0) {
         res16[jj] = a_mbase_x;
       } else {
+        uint16x8x2_t a32, diff, res;
+        uint8x16_t a0_128, a1_128;
         a0_128 = vld1q_u8(above + base + j);
         a1_128 = vld1q_u8(above + base + j + 1);
-        a0 = vzipq_u8(a0_128, v_zero);
-        a1 = vzipq_u8(a1_128, v_zero);
-        diff.val[0] =
-            vsubq_u16(vreinterpretq_u16_u8(a1.val[0]),
-                      vreinterpretq_u16_u8(a0.val[0]));  // a[x+1] - a[x]
-        diff.val[1] =
-            vsubq_u16(vreinterpretq_u16_u8(a1.val[1]),
-                      vreinterpretq_u16_u8(a0.val[1]));  // a[x+1] - a[x]
-        a32.val[0] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[0]),
-                               v_32);  // a[x] * 32 + 16
-        a32.val[1] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[1]),
-                               v_32);  // a[x] * 32 + 16
+        diff.val[0] = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128));
+        diff.val[1] = vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128));
+        a32.val[0] = vmlal_u8(a16, vget_low_u8(a0_128), v_32);
+        a32.val[1] = vmlal_u8(a16, vget_high_u8(a0_128), v_32);
         res.val[0] = vmlaq_u16(a32.val[0], diff.val[0], shift);
         res.val[1] = vmlaq_u16(a32.val[1], diff.val[1], shift);
 
@@ -892,10 +1361,8 @@
 
     mask.val[0] = vld1q_u8(BaseMask[base_max_diff]);
     mask.val[1] = vld1q_u8(BaseMask[base_max_diff] + 16);
-    dstvec[r].val[0] = vorrq_u8(vandq_u8(mask.val[0], res16[0]),
-                                vbicq_u8(a_mbase_x, mask.val[0]));
-    dstvec[r].val[1] = vorrq_u8(vandq_u8(mask.val[1], res16[1]),
-                                vbicq_u8(a_mbase_x, mask.val[1]));
+    dstvec[r].val[0] = vbslq_u8(mask.val[0], res16[0], a_mbase_x);
+    dstvec[r].val[1] = vbslq_u8(mask.val[1], res16[1], a_mbase_x);
     x += dx;
   }
 }
@@ -927,23 +1394,15 @@
   // final pixels will be calculated as:
   //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
 
-  uint8x16x2_t a0, a1;
-  uint16x8x2_t a32, diff;
-  uint16x8_t a16, c3f;
-  uint8x16_t a_mbase_x, max_base_x128, mask128;
-
-  a16 = vdupq_n_u16(16);
-  a_mbase_x = vdupq_n_u8(above[max_base_x]);
-  max_base_x128 = vdupq_n_u8(max_base_x);
-  c3f = vdupq_n_u16(0x3f);
-  uint16x8_t v_32 = vdupq_n_u16(32);
-  uint8x16_t v_zero = vdupq_n_u8(0);
-  uint8x16_t step = vdupq_n_u8(16);
+  const uint16x8_t a16 = vdupq_n_u16(16);
+  const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]);
+  const uint8x16_t max_base_x128 = vdupq_n_u8(max_base_x);
+  const uint8x8_t v_32 = vdup_n_u8(32);
+  const uint8x16_t v_zero = vdupq_n_u8(0);
+  const uint8x16_t step = vdupq_n_u8(16);
 
   int x = dx;
   for (int r = 0; r < N; r++, dst += stride) {
-    uint16x8x2_t res;
-
     int base = x >> frac_bits;
     if (base >= max_base_x) {
       for (int i = r; i < N; ++i) {
@@ -956,8 +1415,7 @@
       return;
     }
 
-    uint16x8_t shift = vshrq_n_u16(vandq_u16(vdupq_n_u16(x), c3f), 1);
-    uint8x16_t a0_128, a1_128, res128;
+    uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1);
     uint8x16_t base_inc128 =
         vaddq_u8(vdupq_n_u8(base), vcombine_u8(vcreate_u8(0x0706050403020100),
                                                vcreate_u8(0x0F0E0D0C0B0A0908)));
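+    // base_inc128 = base + {0, 1, ..., 15}; the vcreate_u8 constants spell
+    // out the byte indices 0..15 in little-endian order.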
@@ -967,28 +1425,21 @@
       if (mdif <= 0) {
         vst1q_u8(dst + j, a_mbase_x);
       } else {
+        uint16x8x2_t a32, diff, res;
+        uint8x16_t a0_128, a1_128, mask128, res128;
         a0_128 = vld1q_u8(above + base + j);
         a1_128 = vld1q_u8(above + base + 1 + j);
-        a0 = vzipq_u8(a0_128, v_zero);
-        a1 = vzipq_u8(a1_128, v_zero);
-        diff.val[0] =
-            vsubq_u16(vreinterpretq_u16_u8(a1.val[0]),
-                      vreinterpretq_u16_u8(a0.val[0]));  // a[x+1] - a[x]
-        diff.val[1] =
-            vsubq_u16(vreinterpretq_u16_u8(a1.val[1]),
-                      vreinterpretq_u16_u8(a0.val[1]));  // a[x+1] - a[x]
-        a32.val[0] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[0]),
-                               v_32);  // a[x] * 32 + 16
-        a32.val[1] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0.val[1]),
-                               v_32);  // a[x] * 32 + 16
+        diff.val[0] = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128));
+        diff.val[1] = vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128));
+        a32.val[0] = vmlal_u8(a16, vget_low_u8(a0_128), v_32);
+        a32.val[1] = vmlal_u8(a16, vget_high_u8(a0_128), v_32);
         res.val[0] = vmlaq_u16(a32.val[0], diff.val[0], shift);
         res.val[1] = vmlaq_u16(a32.val[1], diff.val[1], shift);
         uint8x16_t v_temp =
             vcombine_u8(vshrn_n_u16(res.val[0], 5), vshrn_n_u16(res.val[1], 5));
 
         mask128 = vcgtq_u8(vqsubq_u8(max_base_x128, base_inc128), v_zero);
-        res128 =
-            vorrq_u8(vandq_u8(mask128, v_temp), vbicq_u8(a_mbase_x, mask128));
+        res128 = vbslq_u8(mask128, v_temp, a_mbase_x);
         vst1q_u8(dst + j, res128);
 
         base_inc128 = vaddq_u8(base_inc128, step);
@@ -1023,7 +1474,6 @@
       break;
     default: break;
   }
-  return;
 }
 
 /* ---------------------P R E D I C T I O N   Z 2--------------------------- */
@@ -1289,6 +1739,14 @@
       }
     }
 
+    diff.val[0] =
+        vsubq_u16(vreinterpretq_u16_u8(a1_x.val[0]),
+                  vreinterpretq_u16_u8(a0_x.val[0]));  // a[x+1] - a[x]
+    a32.val[0] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0_x.val[0]),
+                           v_32);  // a[x] * 32 + 16
+    res.val[0] = vmlaq_u16(a32.val[0], diff.val[0], shift.val[0]);
+    resx = vshrn_n_u16(res.val[0], 5);
+
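+    // resx holds the row predicted from the above (x) samples alone; the
+    // left (y) path below is only needed when base_x falls below min_base_x.
+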
     // y calc
     if (base_x < min_base_x) {
       DECLARE_ALIGNED(32, int16_t, base_y_c[16]);
@@ -1334,26 +1792,20 @@
         shift.val[1] =
             vshrq_n_u16(vandq_u16(vreinterpretq_u16_s16(y_c128), c3f), 1);
       }
+      diff.val[1] =
+          vsubq_u16(vreinterpretq_u16_u8(a1_x.val[1]),
+                    vreinterpretq_u16_u8(a0_x.val[1]));  // a[x+1] - a[x]
+      a32.val[1] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0_x.val[1]),
+                             v_32);  // a[x] * 32 + 16
+      res.val[1] = vmlaq_u16(a32.val[1], diff.val[1], shift.val[1]);
+      resy = vshrn_n_u16(res.val[1], 5);
+      uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
+      resxy = vorr_u8(vand_u8(mask, resy), vbic_u8(resx, mask));
+      vst1_u8(dst, resxy);
+    } else {
+      vst1_u8(dst, resx);
     }
-    diff.val[0] =
-        vsubq_u16(vreinterpretq_u16_u8(a1_x.val[0]),
-                  vreinterpretq_u16_u8(a0_x.val[0]));  // a[x+1] - a[x]
-    diff.val[1] =
-        vsubq_u16(vreinterpretq_u16_u8(a1_x.val[1]),
-                  vreinterpretq_u16_u8(a0_x.val[1]));  // a[x+1] - a[x]
-    a32.val[0] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0_x.val[0]),
-                           v_32);  // a[x] * 32 + 16
-    a32.val[1] = vmlaq_u16(a16, vreinterpretq_u16_u8(a0_x.val[1]),
-                           v_32);  // a[x] * 32 + 16
-    res.val[0] = vmlaq_u16(a32.val[0], diff.val[0], shift.val[0]);
-    res.val[1] = vmlaq_u16(a32.val[1], diff.val[1], shift.val[1]);
-    resx = vshrn_n_u16(res.val[0], 5);
-    resy = vshrn_n_u16(res.val[1], 5);
 
-    uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]);
-
-    resxy = vorr_u8(vand_u8(mask, resy), vbic_u8(resx, mask));
-    vst1_u8(dst, resxy);
     dst += stride;
   }
 }
@@ -1629,7 +2081,6 @@
                                 upsample_above, upsample_left, dx, dy);
       break;
   }
-  return;
 }
 
 /* ---------------------P R E D I C T I O N   Z 3--------------------------- */
@@ -3212,7 +3663,7 @@
                                        int width, int height) {
   const uint8x8_t top_left = vdup_n_u8(top_row[-1]);
   const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]);
-  uint8x8_t top;
+  uint8x8_t UNINITIALIZED_IS_SAFE(top);
   if (width == 4) {
     load_u8_4x1(top_row, &top, 0);
   } else {  // width == 8
diff --git a/aom_dsp/arm/loopfilter_neon.c b/aom_dsp/arm/loopfilter_neon.c
index f3f86a2..8fc7ccb 100644
--- a/aom_dsp/arm/loopfilter_neon.c
+++ b/aom_dsp/arm/loopfilter_neon.c
@@ -628,7 +628,7 @@
   // row1: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y
   // row2: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y
   // row3: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y
-  load_u8_8x16(src - 8, stride, &row0, &row1, &row2, &row3);
+  load_u8_16x4(src - 8, stride, &row0, &row1, &row2, &row3);
 
   pxp3 = vget_low_u8(row0);
   p6p2 = vget_low_u8(row1);
@@ -841,8 +841,7 @@
   // row1: p1 p0 | q0 q1
   // row2: p1 p0 | q0 q1
   // row3: p1 p0 | q0 q1
-  load_unaligned_u8_4x4(src - 2, stride, (uint32x2_t *)&p1p0,
-                        (uint32x2_t *)&q0q1);
+  load_unaligned_u8_4x4(src - 2, stride, &p1p0, &q0q1);
 
   transpose_u8_4x4(&p1p0, &q0q1);
 
@@ -1037,7 +1036,7 @@
 
 void aom_lpf_horizontal_4_neon(uint8_t *src, int stride, const uint8_t *blimit,
                                const uint8_t *limit, const uint8_t *thresh) {
-  uint8x8_t p0q0, UNINITIALIZED_IS_SAFE(p1q1);
+  uint8x8_t UNINITIALIZED_IS_SAFE(p0q0), UNINITIALIZED_IS_SAFE(p1q1);
 
   load_u8_4x1(src - 2 * stride, &p1q1, 0);
   load_u8_4x1(src - 1 * stride, &p0q0, 0);
diff --git a/aom_dsp/arm/masked_sad_neon.c b/aom_dsp/arm/masked_sad_neon.c
new file mode 100644
index 0000000..340df05
--- /dev/null
+++ b/aom_dsp/arm/masked_sad_neon.c
@@ -0,0 +1,257 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/blend.h"
+#include "mem_neon.h"
+#include "sum_neon.h"
+
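+// Blends 16 pixels of a and b with the 6-bit mask m as in aom_blend_a64_mask
+// (out = (m * a + (64 - m) * b + 32) >> 6, assuming AOM_BLEND_A64_MAX_ALPHA
+// == 64 and AOM_BLEND_A64_ROUND_BITS == 6 from aom_dsp/blend.h), then
+// accumulates |blend - src| into the running SAD vector.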
+static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
+                                              const uint8_t *src,
+                                              const uint8_t *a,
+                                              const uint8_t *b,
+                                              const uint8_t *m) {
+  uint8x16_t m0 = vld1q_u8(m);
+  uint8x16_t a0 = vld1q_u8(a);
+  uint8x16_t b0 = vld1q_u8(b);
+  uint8x16_t s0 = vld1q_u8(src);
+
+  uint8x16_t m0_inv = vsubq_u8(vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0);
+  uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(m0), vget_low_u8(a0));
+  uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(m0), vget_high_u8(a0));
+  blend_u16_lo = vmlal_u8(blend_u16_lo, vget_low_u8(m0_inv), vget_low_u8(b0));
+  blend_u16_hi = vmlal_u8(blend_u16_hi, vget_high_u8(m0_inv), vget_high_u8(b0));
+
+  uint8x8_t blend_u8_lo = vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS);
+  uint8x8_t blend_u8_hi = vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS);
+  uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi);
+
+  return vpadalq_u8(sad, vabdq_u8(blend_u8, s0));
+}
+
+static INLINE unsigned masked_sad_128xh_neon(const uint8_t *src, int src_stride,
+                                             const uint8_t *a, int a_stride,
+                                             const uint8_t *b, int b_stride,
+                                             const uint8_t *m, int m_stride,
+                                             int height) {
+  // Eight accumulator vectors are required to avoid overflow in the 128x128
+  // case.
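+  // Each uint16x8_t accumulator covers a 16-pixel-wide column; vpadalq_u8
+  // adds two absolute differences into each 16-bit lane per row, so a lane
+  // holds at most 2 * 128 * 255 = 65280, which still fits in 16 bits.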
+  assert(height <= 128);
+  uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                       vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                       vdupq_n_u16(0), vdupq_n_u16(0) };
+
+  do {
+    sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]);
+    sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]);
+    sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]);
+    sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]);
+    sad[4] = masked_sad_16x1_neon(sad[4], &src[64], &a[64], &b[64], &m[64]);
+    sad[5] = masked_sad_16x1_neon(sad[5], &src[80], &a[80], &b[80], &m[80]);
+    sad[6] = masked_sad_16x1_neon(sad[6], &src[96], &a[96], &b[96], &m[96]);
+    sad[7] = masked_sad_16x1_neon(sad[7], &src[112], &a[112], &b[112], &m[112]);
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+    height--;
+  } while (height != 0);
+
+  return horizontal_long_add_u16x8(sad[0], sad[1]) +
+         horizontal_long_add_u16x8(sad[2], sad[3]) +
+         horizontal_long_add_u16x8(sad[4], sad[5]) +
+         horizontal_long_add_u16x8(sad[6], sad[7]);
+}
+
+static INLINE unsigned masked_sad_64xh_neon(const uint8_t *src, int src_stride,
+                                            const uint8_t *a, int a_stride,
+                                            const uint8_t *b, int b_stride,
+                                            const uint8_t *m, int m_stride,
+                                            int height) {
+  // Four accumulator vectors are required to avoid overflow in the 64x128 case.
+  assert(height <= 128);
+  uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                       vdupq_n_u16(0) };
+
+  do {
+    sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]);
+    sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]);
+    sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]);
+    sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]);
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+    height--;
+  } while (height != 0);
+
+  return horizontal_long_add_u16x8(sad[0], sad[1]) +
+         horizontal_long_add_u16x8(sad[2], sad[3]);
+}
+
+static INLINE unsigned masked_sad_32xh_neon(const uint8_t *src, int src_stride,
+                                            const uint8_t *a, int a_stride,
+                                            const uint8_t *b, int b_stride,
+                                            const uint8_t *m, int m_stride,
+                                            int height) {
+  // We could use a single accumulator up to height=64 without overflow.
+  assert(height <= 64);
+  uint16x8_t sad = vdupq_n_u16(0);
+
+  do {
+    sad = masked_sad_16x1_neon(sad, &src[0], &a[0], &b[0], &m[0]);
+    sad = masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]);
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+    height--;
+  } while (height != 0);
+
+  return horizontal_add_u16x8(sad);
+}
+
+static INLINE unsigned masked_sad_16xh_neon(const uint8_t *src, int src_stride,
+                                            const uint8_t *a, int a_stride,
+                                            const uint8_t *b, int b_stride,
+                                            const uint8_t *m, int m_stride,
+                                            int height) {
+  // We could use a single accumulator up to height=128 without overflow.
+  assert(height <= 128);
+  uint16x8_t sad = vdupq_n_u16(0);
+
+  do {
+    sad = masked_sad_16x1_neon(sad, src, a, b, m);
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+    height--;
+  } while (height != 0);
+
+  return horizontal_add_u16x8(sad);
+}
+
+static INLINE unsigned masked_sad_8xh_neon(const uint8_t *src, int src_stride,
+                                           const uint8_t *a, int a_stride,
+                                           const uint8_t *b, int b_stride,
+                                           const uint8_t *m, int m_stride,
+                                           int height) {
+  // We could use a single accumulator up to height=128 without overflow.
+  assert(height <= 128);
+  uint16x4_t sad = vdup_n_u16(0);
+
+  do {
+    uint8x8_t m0 = vld1_u8(m);
+    uint8x8_t a0 = vld1_u8(a);
+    uint8x8_t b0 = vld1_u8(b);
+    uint8x8_t s0 = vld1_u8(src);
+
+    uint8x8_t m0_inv = vsub_u8(vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0);
+    uint16x8_t blend_u16 = vmull_u8(m0, a0);
+    blend_u16 = vmlal_u8(blend_u16, m0_inv, b0);
+    uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+    sad = vpadal_u8(sad, vabd_u8(blend_u8, s0));
+
+    src += src_stride;
+    a += a_stride;
+    b += b_stride;
+    m += m_stride;
+    height--;
+  } while (height != 0);
+
+  return horizontal_add_u16x4(sad);
+}
+
+static INLINE unsigned masked_sad_4xh_neon(const uint8_t *src, int src_stride,
+                                           const uint8_t *a, int a_stride,
+                                           const uint8_t *b, int b_stride,
+                                           const uint8_t *m, int m_stride,
+                                           int height) {
+  // Process two rows per loop iteration.
+  assert(height % 2 == 0);
+
+  // We could use a single accumulator up to height=256 without overflow.
+  assert(height <= 256);
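+  // Two 4-byte rows are packed per vector, so each 16-bit lane accumulates
+  // two absolute differences per iteration: 2 * (256 / 2) * 255 = 65280,
+  // which fits in 16 bits.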
+  uint16x4_t sad = vdup_n_u16(0);
+
+  do {
+    uint8x8_t m0 = load_unaligned_u8(m, m_stride);
+    uint8x8_t a0 = load_unaligned_u8(a, a_stride);
+    uint8x8_t b0 = load_unaligned_u8(b, b_stride);
+    uint8x8_t s0 = load_unaligned_u8(src, src_stride);
+
+    uint8x8_t m0_inv = vsub_u8(vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0);
+    uint16x8_t blend_u16 = vmull_u8(m0, a0);
+    blend_u16 = vmlal_u8(blend_u16, m0_inv, b0);
+    uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
+
+    sad = vpadal_u8(sad, vabd_u8(blend_u8, s0));
+
+    src += 2 * src_stride;
+    a += 2 * a_stride;
+    b += 2 * b_stride;
+    m += 2 * m_stride;
+    height -= 2;
+  } while (height != 0);
+
+  return horizontal_add_u16x4(sad);
+}
+
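+// Emits aom_masked_sad{width}x{height}_neon. invert_mask swaps which of ref
+// and second_pred the mask weights; second_pred is passed with a stride equal
+// to the block width.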
+#define MASKED_SAD_WXH_NEON(width, height)                                    \
+  unsigned aom_masked_sad##width##x##height##_neon(                           \
+      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
+      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
+      int invert_mask) {                                                      \
+    if (!invert_mask)                                                         \
+      return masked_sad_##width##xh_neon(src, src_stride, ref, ref_stride,    \
+                                         second_pred, width, msk, msk_stride, \
+                                         height);                             \
+    else                                                                      \
+      return masked_sad_##width##xh_neon(src, src_stride, second_pred, width, \
+                                         ref, ref_stride, msk, msk_stride,    \
+                                         height);                             \
+  }
+
+MASKED_SAD_WXH_NEON(4, 4)
+MASKED_SAD_WXH_NEON(4, 8)
+MASKED_SAD_WXH_NEON(8, 4)
+MASKED_SAD_WXH_NEON(8, 8)
+MASKED_SAD_WXH_NEON(8, 16)
+MASKED_SAD_WXH_NEON(16, 8)
+MASKED_SAD_WXH_NEON(16, 16)
+MASKED_SAD_WXH_NEON(16, 32)
+MASKED_SAD_WXH_NEON(32, 16)
+MASKED_SAD_WXH_NEON(32, 32)
+MASKED_SAD_WXH_NEON(32, 64)
+MASKED_SAD_WXH_NEON(64, 32)
+MASKED_SAD_WXH_NEON(64, 64)
+MASKED_SAD_WXH_NEON(64, 128)
+MASKED_SAD_WXH_NEON(128, 64)
+MASKED_SAD_WXH_NEON(128, 128)
+#if !CONFIG_REALTIME_ONLY
+MASKED_SAD_WXH_NEON(4, 16)
+MASKED_SAD_WXH_NEON(16, 4)
+MASKED_SAD_WXH_NEON(8, 32)
+MASKED_SAD_WXH_NEON(32, 8)
+MASKED_SAD_WXH_NEON(16, 64)
+MASKED_SAD_WXH_NEON(64, 16)
+#endif
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index 73a5127..30328f7 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -12,6 +12,7 @@
 #define AOM_AOM_DSP_ARM_MEM_NEON_H_
 
 #include <arm_neon.h>
+#include <assert.h>
 #include <string.h>
 #include "aom_dsp/aom_dsp_common.h"
 
@@ -73,14 +74,18 @@
 #endif  // __GNUC__ < 9
 #endif  // defined(__GNUC__) && !defined(__clang__)
 
-static INLINE void store_row2_u8_8x8(uint8_t *s, int p, const uint8x8_t s0,
-                                     const uint8x8_t s1) {
+static INLINE void store_u8_8x2(uint8_t *s, ptrdiff_t p, const uint8x8_t s0,
+                                const uint8x8_t s1) {
   vst1_u8(s, s0);
   s += p;
   vst1_u8(s, s1);
   s += p;
 }
 
+static INLINE uint8x16_t load_u8_8x2(const uint8_t *s, ptrdiff_t p) {
+  return vcombine_u8(vld1_u8(s), vld1_u8(s + p));
+}
+
 /* These intrinsics require immediate values, so we must use #defines
    to enforce that. */
 #define load_u8_4x1(s, s0, lane)                                           \
@@ -89,6 +94,26 @@
         vld1_lane_u32((uint32_t *)(s), vreinterpret_u32_u8(*(s0)), lane)); \
   } while (0)
 
+// Load four bytes into the low half of a uint8x8_t and zero the upper half.
+static INLINE uint8x8_t load_u8_4x1_lane0(const uint8_t *p) {
+  uint8x8_t ret = vdup_n_u8(0);
+  load_u8_4x1(p, &ret, 0);
+  return ret;
+}
+
+// Load two rows of 4 bytes each when 4-byte alignment of buf and stride is
+// guaranteed (see the asserts below).
+static INLINE uint8x8_t load_u8(const uint8_t *buf, ptrdiff_t stride) {
+  uint32x2_t a = vdup_n_u32(0);
+
+  assert(!((intptr_t)buf % sizeof(uint32_t)));
+  assert(!(stride % sizeof(uint32_t)));
+
+  a = vld1_lane_u32((const uint32_t *)buf, a, 0);
+  buf += stride;
+  a = vld1_lane_u32((const uint32_t *)buf, a, 1);
+  return vreinterpret_u8_u32(a);
+}
+
 static INLINE void load_u8_8x8(const uint8_t *s, ptrdiff_t p,
                                uint8x8_t *const s0, uint8x8_t *const s1,
                                uint8x8_t *const s2, uint8x8_t *const s3,
@@ -111,16 +136,24 @@
   *s7 = vld1_u8(s);
 }
 
-static INLINE void load_u8_8x16(const uint8_t *s, ptrdiff_t p,
-                                uint8x16_t *const s0, uint8x16_t *const s1,
-                                uint8x16_t *const s2, uint8x16_t *const s3) {
-  *s0 = vld1q_u8(s);
+static INLINE void load_u8_8x7(const uint8_t *s, ptrdiff_t p,
+                               uint8x8_t *const s0, uint8x8_t *const s1,
+                               uint8x8_t *const s2, uint8x8_t *const s3,
+                               uint8x8_t *const s4, uint8x8_t *const s5,
+                               uint8x8_t *const s6) {
+  *s0 = vld1_u8(s);
   s += p;
-  *s1 = vld1q_u8(s);
+  *s1 = vld1_u8(s);
   s += p;
-  *s2 = vld1q_u8(s);
+  *s2 = vld1_u8(s);
   s += p;
-  *s3 = vld1q_u8(s);
+  *s3 = vld1_u8(s);
+  s += p;
+  *s4 = vld1_u8(s);
+  s += p;
+  *s5 = vld1_u8(s);
+  s += p;
+  *s6 = vld1_u8(s);
 }
 
 static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p,
@@ -148,6 +181,40 @@
   s += p;
 }
 
+static INLINE void load_u16_4x7(const uint16_t *s, ptrdiff_t p,
+                                uint16x4_t *const s0, uint16x4_t *const s1,
+                                uint16x4_t *const s2, uint16x4_t *const s3,
+                                uint16x4_t *const s4, uint16x4_t *const s5,
+                                uint16x4_t *const s6) {
+  *s0 = vld1_u16(s);
+  s += p;
+  *s1 = vld1_u16(s);
+  s += p;
+  *s2 = vld1_u16(s);
+  s += p;
+  *s3 = vld1_u16(s);
+  s += p;
+  *s4 = vld1_u16(s);
+  s += p;
+  *s5 = vld1_u16(s);
+  s += p;
+  *s6 = vld1_u16(s);
+}
+
+static INLINE void load_s16_8x2(const int16_t *s, const ptrdiff_t p,
+                                int16x8_t *const s0, int16x8_t *const s1) {
+  *s0 = vld1q_s16(s);
+  s += p;
+  *s1 = vld1q_s16(s);
+}
+
+static INLINE void load_u16_8x2(const uint16_t *s, const ptrdiff_t p,
+                                uint16x8_t *const s0, uint16x8_t *const s1) {
+  *s0 = vld1q_u16(s);
+  s += p;
+  *s1 = vld1q_u16(s);
+}
+
 static INLINE void load_u16_8x4(const uint16_t *s, const ptrdiff_t p,
                                 uint16x8_t *const s0, uint16x8_t *const s1,
                                 uint16x8_t *const s2, uint16x8_t *const s3) {
@@ -161,6 +228,66 @@
   s += p;
 }
 
+static INLINE void load_s16_4x11(const int16_t *s, ptrdiff_t p,
+                                 int16x4_t *const s0, int16x4_t *const s1,
+                                 int16x4_t *const s2, int16x4_t *const s3,
+                                 int16x4_t *const s4, int16x4_t *const s5,
+                                 int16x4_t *const s6, int16x4_t *const s7,
+                                 int16x4_t *const s8, int16x4_t *const s9,
+                                 int16x4_t *const s10) {
+  *s0 = vld1_s16(s);
+  s += p;
+  *s1 = vld1_s16(s);
+  s += p;
+  *s2 = vld1_s16(s);
+  s += p;
+  *s3 = vld1_s16(s);
+  s += p;
+  *s4 = vld1_s16(s);
+  s += p;
+  *s5 = vld1_s16(s);
+  s += p;
+  *s6 = vld1_s16(s);
+  s += p;
+  *s7 = vld1_s16(s);
+  s += p;
+  *s8 = vld1_s16(s);
+  s += p;
+  *s9 = vld1_s16(s);
+  s += p;
+  *s10 = vld1_s16(s);
+}
+
+static INLINE void load_u16_4x11(const uint16_t *s, ptrdiff_t p,
+                                 uint16x4_t *const s0, uint16x4_t *const s1,
+                                 uint16x4_t *const s2, uint16x4_t *const s3,
+                                 uint16x4_t *const s4, uint16x4_t *const s5,
+                                 uint16x4_t *const s6, uint16x4_t *const s7,
+                                 uint16x4_t *const s8, uint16x4_t *const s9,
+                                 uint16x4_t *const s10) {
+  *s0 = vld1_u16(s);
+  s += p;
+  *s1 = vld1_u16(s);
+  s += p;
+  *s2 = vld1_u16(s);
+  s += p;
+  *s3 = vld1_u16(s);
+  s += p;
+  *s4 = vld1_u16(s);
+  s += p;
+  *s5 = vld1_u16(s);
+  s += p;
+  *s6 = vld1_u16(s);
+  s += p;
+  *s7 = vld1_u16(s);
+  s += p;
+  *s8 = vld1_u16(s);
+  s += p;
+  *s9 = vld1_u16(s);
+  s += p;
+  *s10 = vld1_u16(s);
+}
+
 static INLINE void load_s16_4x8(const int16_t *s, ptrdiff_t p,
                                 int16x4_t *const s0, int16x4_t *const s1,
                                 int16x4_t *const s2, int16x4_t *const s3,
@@ -183,6 +310,88 @@
   *s7 = vld1_s16(s);
 }
 
+static INLINE void load_s16_4x7(const int16_t *s, ptrdiff_t p,
+                                int16x4_t *const s0, int16x4_t *const s1,
+                                int16x4_t *const s2, int16x4_t *const s3,
+                                int16x4_t *const s4, int16x4_t *const s5,
+                                int16x4_t *const s6) {
+  *s0 = vld1_s16(s);
+  s += p;
+  *s1 = vld1_s16(s);
+  s += p;
+  *s2 = vld1_s16(s);
+  s += p;
+  *s3 = vld1_s16(s);
+  s += p;
+  *s4 = vld1_s16(s);
+  s += p;
+  *s5 = vld1_s16(s);
+  s += p;
+  *s6 = vld1_s16(s);
+}
+
+static INLINE void load_s16_4x5(const int16_t *s, ptrdiff_t p,
+                                int16x4_t *const s0, int16x4_t *const s1,
+                                int16x4_t *const s2, int16x4_t *const s3,
+                                int16x4_t *const s4) {
+  *s0 = vld1_s16(s);
+  s += p;
+  *s1 = vld1_s16(s);
+  s += p;
+  *s2 = vld1_s16(s);
+  s += p;
+  *s3 = vld1_s16(s);
+  s += p;
+  *s4 = vld1_s16(s);
+}
+
+static INLINE void load_u16_4x5(const uint16_t *s, const ptrdiff_t p,
+                                uint16x4_t *const s0, uint16x4_t *const s1,
+                                uint16x4_t *const s2, uint16x4_t *const s3,
+                                uint16x4_t *const s4) {
+  *s0 = vld1_u16(s);
+  s += p;
+  *s1 = vld1_u16(s);
+  s += p;
+  *s2 = vld1_u16(s);
+  s += p;
+  *s3 = vld1_u16(s);
+  s += p;
+  *s4 = vld1_u16(s);
+  s += p;
+}
+
+static INLINE void load_u8_8x5(const uint8_t *s, ptrdiff_t p,
+                               uint8x8_t *const s0, uint8x8_t *const s1,
+                               uint8x8_t *const s2, uint8x8_t *const s3,
+                               uint8x8_t *const s4) {
+  *s0 = vld1_u8(s);
+  s += p;
+  *s1 = vld1_u8(s);
+  s += p;
+  *s2 = vld1_u8(s);
+  s += p;
+  *s3 = vld1_u8(s);
+  s += p;
+  *s4 = vld1_u8(s);
+}
+
+static INLINE void load_u16_8x5(const uint16_t *s, const ptrdiff_t p,
+                                uint16x8_t *const s0, uint16x8_t *const s1,
+                                uint16x8_t *const s2, uint16x8_t *const s3,
+                                uint16x8_t *const s4) {
+  *s0 = vld1q_u16(s);
+  s += p;
+  *s1 = vld1q_u16(s);
+  s += p;
+  *s2 = vld1q_u16(s);
+  s += p;
+  *s3 = vld1q_u16(s);
+  s += p;
+  *s4 = vld1q_u16(s);
+  s += p;
+}
+
 static INLINE void load_s16_4x4(const int16_t *s, ptrdiff_t p,
                                 int16x4_t *const s0, int16x4_t *const s1,
                                 int16x4_t *const s2, int16x4_t *const s3) {
@@ -197,6 +406,11 @@
 
 /* These intrinsics require immediate values, so we must use #defines
    to enforce that. */
+#define store_u8_2x1(s, s0, lane)                                  \
+  do {                                                             \
+    vst1_lane_u16((uint16_t *)(s), vreinterpret_u16_u8(s0), lane); \
+  } while (0)
+
 #define store_u8_4x1(s, s0, lane)                                  \
   do {                                                             \
     vst1_lane_u32((uint32_t *)(s), vreinterpret_u32_u8(s0), lane); \
@@ -282,6 +496,13 @@
   vst1_u16(s, s3);
 }
 
+static INLINE void store_u16_8x2(uint16_t *s, ptrdiff_t dst_stride,
+                                 const uint16x8_t s0, const uint16x8_t s1) {
+  vst1q_u16(s, s0);
+  s += dst_stride;
+  vst1q_u16(s, s1);
+}
+
 static INLINE void store_u16_8x4(uint16_t *s, ptrdiff_t dst_stride,
                                  const uint16x8_t s0, const uint16x8_t s1,
                                  const uint16x8_t s2, const uint16x8_t s3) {
@@ -328,6 +549,21 @@
   vst1_s16(s, s3);
 }
 
+/* These intrinsics require immediate values, so we must use #defines
+   to enforce that. */
+#define store_s16_2x1(s, s0, lane)                                 \
+  do {                                                             \
+    vst1_lane_s32((int32_t *)(s), vreinterpret_s32_s16(s0), lane); \
+  } while (0)
+#define store_u16_2x1(s, s0, lane)                                  \
+  do {                                                              \
+    vst1_lane_u32((uint32_t *)(s), vreinterpret_u32_u16(s0), lane); \
+  } while (0)
+#define store_u16q_2x1(s, s0, lane)                                   \
+  do {                                                                \
+    vst1q_lane_u32((uint32_t *)(s), vreinterpretq_u32_u16(s0), lane); \
+  } while (0)
+
 static INLINE void store_s16_8x4(int16_t *s, ptrdiff_t dst_stride,
                                  const int16x8_t s0, const int16x8_t s1,
                                  const int16x8_t s2, const int16x8_t s3) {
@@ -340,6 +576,96 @@
   vst1q_s16(s, s3);
 }
 
+static INLINE void load_u8_8x11(const uint8_t *s, ptrdiff_t p,
+                                uint8x8_t *const s0, uint8x8_t *const s1,
+                                uint8x8_t *const s2, uint8x8_t *const s3,
+                                uint8x8_t *const s4, uint8x8_t *const s5,
+                                uint8x8_t *const s6, uint8x8_t *const s7,
+                                uint8x8_t *const s8, uint8x8_t *const s9,
+                                uint8x8_t *const s10) {
+  *s0 = vld1_u8(s);
+  s += p;
+  *s1 = vld1_u8(s);
+  s += p;
+  *s2 = vld1_u8(s);
+  s += p;
+  *s3 = vld1_u8(s);
+  s += p;
+  *s4 = vld1_u8(s);
+  s += p;
+  *s5 = vld1_u8(s);
+  s += p;
+  *s6 = vld1_u8(s);
+  s += p;
+  *s7 = vld1_u8(s);
+  s += p;
+  *s8 = vld1_u8(s);
+  s += p;
+  *s9 = vld1_u8(s);
+  s += p;
+  *s10 = vld1_u8(s);
+}
+
+static INLINE void load_s16_8x11(const int16_t *s, ptrdiff_t p,
+                                 int16x8_t *const s0, int16x8_t *const s1,
+                                 int16x8_t *const s2, int16x8_t *const s3,
+                                 int16x8_t *const s4, int16x8_t *const s5,
+                                 int16x8_t *const s6, int16x8_t *const s7,
+                                 int16x8_t *const s8, int16x8_t *const s9,
+                                 int16x8_t *const s10) {
+  *s0 = vld1q_s16(s);
+  s += p;
+  *s1 = vld1q_s16(s);
+  s += p;
+  *s2 = vld1q_s16(s);
+  s += p;
+  *s3 = vld1q_s16(s);
+  s += p;
+  *s4 = vld1q_s16(s);
+  s += p;
+  *s5 = vld1q_s16(s);
+  s += p;
+  *s6 = vld1q_s16(s);
+  s += p;
+  *s7 = vld1q_s16(s);
+  s += p;
+  *s8 = vld1q_s16(s);
+  s += p;
+  *s9 = vld1q_s16(s);
+  s += p;
+  *s10 = vld1q_s16(s);
+}
+
+static INLINE void load_u16_8x11(const uint16_t *s, ptrdiff_t p,
+                                 uint16x8_t *const s0, uint16x8_t *const s1,
+                                 uint16x8_t *const s2, uint16x8_t *const s3,
+                                 uint16x8_t *const s4, uint16x8_t *const s5,
+                                 uint16x8_t *const s6, uint16x8_t *const s7,
+                                 uint16x8_t *const s8, uint16x8_t *const s9,
+                                 uint16x8_t *const s10) {
+  *s0 = vld1q_u16(s);
+  s += p;
+  *s1 = vld1q_u16(s);
+  s += p;
+  *s2 = vld1q_u16(s);
+  s += p;
+  *s3 = vld1q_u16(s);
+  s += p;
+  *s4 = vld1q_u16(s);
+  s += p;
+  *s5 = vld1q_u16(s);
+  s += p;
+  *s6 = vld1q_u16(s);
+  s += p;
+  *s7 = vld1q_u16(s);
+  s += p;
+  *s8 = vld1q_u16(s);
+  s += p;
+  *s9 = vld1q_u16(s);
+  s += p;
+  *s10 = vld1q_u16(s);
+}
+
 static INLINE void load_s16_8x8(const int16_t *s, ptrdiff_t p,
                                 int16x8_t *const s0, int16x8_t *const s1,
                                 int16x8_t *const s2, int16x8_t *const s3,
@@ -362,6 +688,61 @@
   *s7 = vld1q_s16(s);
 }
 
+static INLINE void load_u16_8x7(const uint16_t *s, ptrdiff_t p,
+                                uint16x8_t *const s0, uint16x8_t *const s1,
+                                uint16x8_t *const s2, uint16x8_t *const s3,
+                                uint16x8_t *const s4, uint16x8_t *const s5,
+                                uint16x8_t *const s6) {
+  *s0 = vld1q_u16(s);
+  s += p;
+  *s1 = vld1q_u16(s);
+  s += p;
+  *s2 = vld1q_u16(s);
+  s += p;
+  *s3 = vld1q_u16(s);
+  s += p;
+  *s4 = vld1q_u16(s);
+  s += p;
+  *s5 = vld1q_u16(s);
+  s += p;
+  *s6 = vld1q_u16(s);
+}
+
+static INLINE void load_s16_8x7(const int16_t *s, ptrdiff_t p,
+                                int16x8_t *const s0, int16x8_t *const s1,
+                                int16x8_t *const s2, int16x8_t *const s3,
+                                int16x8_t *const s4, int16x8_t *const s5,
+                                int16x8_t *const s6) {
+  *s0 = vld1q_s16(s);
+  s += p;
+  *s1 = vld1q_s16(s);
+  s += p;
+  *s2 = vld1q_s16(s);
+  s += p;
+  *s3 = vld1q_s16(s);
+  s += p;
+  *s4 = vld1q_s16(s);
+  s += p;
+  *s5 = vld1q_s16(s);
+  s += p;
+  *s6 = vld1q_s16(s);
+}
+
+static INLINE void load_s16_8x5(const int16_t *s, ptrdiff_t p,
+                                int16x8_t *const s0, int16x8_t *const s1,
+                                int16x8_t *const s2, int16x8_t *const s3,
+                                int16x8_t *const s4) {
+  *s0 = vld1q_s16(s);
+  s += p;
+  *s1 = vld1q_s16(s);
+  s += p;
+  *s2 = vld1q_s16(s);
+  s += p;
+  *s3 = vld1q_s16(s);
+  s += p;
+  *s4 = vld1q_s16(s);
+}
+
 static INLINE void load_s16_8x4(const int16_t *s, ptrdiff_t p,
                                 int16x8_t *const s0, int16x8_t *const s1,
                                 int16x8_t *const s2, int16x8_t *const s3) {
@@ -404,71 +785,61 @@
   return vreinterpretq_u8_u32(a_u32);
 }
 
-static INLINE void load_unaligned_u8_4x8(const uint8_t *buf, int stride,
-                                         uint32x2_t *tu0, uint32x2_t *tu1,
-                                         uint32x2_t *tu2, uint32x2_t *tu3) {
+static INLINE uint8x8_t load_unaligned_u8_2x2(const uint8_t *buf, int stride) {
+  uint16_t a;
+  uint16x4_t a_u16;
+
+  memcpy(&a, buf, 2);
+  buf += stride;
+  a_u16 = vdup_n_u16(a);
+  memcpy(&a, buf, 2);
+  a_u16 = vset_lane_u16(a, a_u16, 1);
+  return vreinterpret_u8_u16(a_u16);
+}
+
+static INLINE uint8x8_t load_unaligned_u8_4x1(const uint8_t *buf) {
   uint32_t a;
+  uint32x2_t a_u32;
+
+  memcpy(&a, buf, 4);
+  a_u32 = vdup_n_u32(0);
+  a_u32 = vset_lane_u32(a, a_u32, 0);
+  return vreinterpret_u8_u32(a_u32);
+}
+
+static INLINE uint8x8_t load_unaligned_u8_4x2(const uint8_t *buf, int stride) {
+  uint32_t a;
+  uint32x2_t a_u32;
 
   memcpy(&a, buf, 4);
   buf += stride;
-  *tu0 = vdup_n_u32(a);
+  a_u32 = vdup_n_u32(a);
   memcpy(&a, buf, 4);
-  buf += stride;
-  *tu0 = vset_lane_u32(a, *tu0, 1);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu1 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu1 = vset_lane_u32(a, *tu1, 1);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu2 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu2 = vset_lane_u32(a, *tu2, 1);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu3 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  *tu3 = vset_lane_u32(a, *tu3, 1);
+  a_u32 = vset_lane_u32(a, a_u32, 1);
+  return vreinterpret_u8_u32(a_u32);
 }
 
 static INLINE void load_unaligned_u8_4x4(const uint8_t *buf, int stride,
-                                         uint32x2_t *tu0, uint32x2_t *tu1) {
-  uint32_t a;
-
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu0 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu0 = vset_lane_u32(a, *tu0, 1);
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu1 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  *tu1 = vset_lane_u32(a, *tu1, 1);
+                                         uint8x8_t *tu0, uint8x8_t *tu1) {
+  *tu0 = load_unaligned_u8_4x2(buf, stride);
+  buf += 2 * stride;
+  *tu1 = load_unaligned_u8_4x2(buf, stride);
 }
 
-static INLINE void load_unaligned_u8_4x1(const uint8_t *buf, int stride,
-                                         uint32x2_t *tu0) {
-  uint32_t a;
-
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu0 = vset_lane_u32(a, *tu0, 0);
+static INLINE void load_unaligned_u8_3x8(const uint8_t *buf, int stride,
+                                         uint8x8_t *tu0, uint8x8_t *tu1,
+                                         uint8x8_t *tu2) {
+  load_unaligned_u8_4x4(buf, stride, tu0, tu1);
+  buf += 4 * stride;
+  *tu2 = load_unaligned_u8_4x2(buf, stride);
 }
 
-static INLINE void load_unaligned_u8_4x2(const uint8_t *buf, int stride,
-                                         uint32x2_t *tu0) {
-  uint32_t a;
-
-  memcpy(&a, buf, 4);
-  buf += stride;
-  *tu0 = vdup_n_u32(a);
-  memcpy(&a, buf, 4);
-  *tu0 = vset_lane_u32(a, *tu0, 1);
+static INLINE void load_unaligned_u8_4x8(const uint8_t *buf, int stride,
+                                         uint8x8_t *tu0, uint8x8_t *tu1,
+                                         uint8x8_t *tu2, uint8x8_t *tu3) {
+  load_unaligned_u8_4x4(buf, stride, tu0, tu1);
+  buf += 4 * stride;
+  load_unaligned_u8_4x4(buf, stride, tu2, tu3);
 }
 
 /* These intrinsics require immediate values, so we must use #defines
@@ -487,17 +858,6 @@
     memcpy(dst, &a, 2);                                \
   } while (0)
 
-static INLINE void load_unaligned_u8_2x2(const uint8_t *buf, int stride,
-                                         uint16x4_t *tu0) {
-  uint16_t a;
-
-  memcpy(&a, buf, 2);
-  buf += stride;
-  *tu0 = vdup_n_u16(a);
-  memcpy(&a, buf, 2);
-  *tu0 = vset_lane_u16(a, *tu0, 1);
-}
-
 static INLINE void load_u8_16x8(const uint8_t *s, ptrdiff_t p,
                                 uint8x16_t *const s0, uint8x16_t *const s1,
                                 uint8x16_t *const s2, uint8x16_t *const s3,
@@ -532,21 +892,43 @@
   *s3 = vld1q_u8(s);
 }
 
+static INLINE void load_u16_16x4(const uint16_t *s, ptrdiff_t p,
+                                 uint16x8_t *const s0, uint16x8_t *const s1,
+                                 uint16x8_t *const s2, uint16x8_t *const s3,
+                                 uint16x8_t *const s4, uint16x8_t *const s5,
+                                 uint16x8_t *const s6, uint16x8_t *const s7) {
+  *s0 = vld1q_u16(s);
+  *s1 = vld1q_u16(s + 8);
+  s += p;
+  *s2 = vld1q_u16(s);
+  *s3 = vld1q_u16(s + 8);
+  s += p;
+  *s4 = vld1q_u16(s);
+  *s5 = vld1q_u16(s + 8);
+  s += p;
+  *s6 = vld1q_u16(s);
+  *s7 = vld1q_u16(s + 8);
+}
+
 static INLINE void load_unaligned_u16_4x4(const uint16_t *buf, uint32_t stride,
-                                          uint64x2_t *tu0, uint64x2_t *tu1) {
+                                          uint16x8_t *tu0, uint16x8_t *tu1) {
   uint64_t a;
+  uint64x2_t a_u64;
 
   memcpy(&a, buf, 8);
   buf += stride;
-  *tu0 = vdupq_n_u64(a);
+  a_u64 = vdupq_n_u64(0);
+  a_u64 = vsetq_lane_u64(a, a_u64, 0);
   memcpy(&a, buf, 8);
   buf += stride;
-  *tu0 = vsetq_lane_u64(a, *tu0, 1);
+  a_u64 = vsetq_lane_u64(a, a_u64, 1);
+  *tu0 = vreinterpretq_u16_u64(a_u64);
   memcpy(&a, buf, 8);
   buf += stride;
-  *tu1 = vdupq_n_u64(a);
+  a_u64 = vdupq_n_u64(a);
   memcpy(&a, buf, 8);
-  *tu1 = vsetq_lane_u64(a, *tu1, 1);
+  a_u64 = vsetq_lane_u64(a, a_u64, 1);
+  *tu1 = vreinterpretq_u16_u64(a_u64);
 }
 
 static INLINE void load_s32_4x4(int32_t *s, int32_t p, int32x4_t *s1,
@@ -609,17 +991,9 @@
   vst1q_s32(buf + 4, v1);
 }
 
-// Stores the second result at an offset of 8 (instead of 4) to match the output
-// with that of C implementation and the function is similar to
-// store_s16q_to_tran_low(). The offset in the function name signifies that
-// pointer should be incremented by at least 4 in the calling function after
-// store_s16q_to_tran_low_offset_4() call.
-static INLINE void store_s16q_to_tran_low_offset_4(tran_low_t *buf,
-                                                   const int16x8_t a) {
-  const int32x4_t v0 = vmovl_s16(vget_low_s16(a));
-  const int32x4_t v1 = vmovl_s16(vget_high_s16(a));
+static INLINE void store_s16_to_tran_low(tran_low_t *buf, const int16x4_t a) {
+  const int32x4_t v0 = vmovl_s16(a);
   vst1q_s32(buf, v0);
-  vst1q_s32(buf + 8, v1);
 }
 
 #endif  // AOM_AOM_DSP_ARM_MEM_NEON_H_
diff --git a/aom_dsp/arm/obmc_variance_neon.c b/aom_dsp/arm/obmc_variance_neon.c
new file mode 100644
index 0000000..8702ba6
--- /dev/null
+++ b/aom_dsp/arm/obmc_variance_neon.c
@@ -0,0 +1,290 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+#include "mem_neon.h"
+#include "sum_neon.h"
+
+static INLINE void obmc_variance_8x1_s16_neon(int16x8_t pre_s16,
+                                              const int32_t *wsrc,
+                                              const int32_t *mask,
+                                              int32x4_t *ssev,
+                                              int32x4_t *sumv) {
+  // For 4xh and 8xh we observe it is faster to avoid the double-widening of
+  // pre. Instead we do a single widening step and narrow the mask to 16-bits
+  // to allow us to perform a widening multiply. Widening multiply
+  // instructions have better throughput on some micro-architectures, but for
+  // the larger block sizes this benefit is outweighed by the additional
+  // instruction needed to first narrow the mask vectors.
+
+  int32x4_t wsrc_s32_lo = vld1q_s32(&wsrc[0]);
+  int32x4_t wsrc_s32_hi = vld1q_s32(&wsrc[4]);
+  int16x8_t mask_s16 = vuzpq_s16(vreinterpretq_s16_s32(vld1q_s32(&mask[0])),
+                                 vreinterpretq_s16_s32(vld1q_s32(&mask[4])))
+                           .val[0];
+
+  int32x4_t diff_s32_lo =
+      vmlsl_s16(wsrc_s32_lo, vget_low_s16(pre_s16), vget_low_s16(mask_s16));
+  int32x4_t diff_s32_hi =
+      vmlsl_s16(wsrc_s32_hi, vget_high_s16(pre_s16), vget_high_s16(mask_s16));
+
+  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
+  // from zero, whereas vrshrq_n_s32 rounds to nearest with ties rounded up.
+  // The two only differ for values that lie exactly on a rounding breakpoint,
+  // so we can add -1 to all negative numbers to move the breakpoint one value
+  // across and into the correct rounding region.
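+  //
+  // Worked example for the tie case value = -2048:
+  //   ROUND_POWER_OF_TWO_SIGNED(-2048, 12) = -((2048 + 2048) >> 12) = -1
+  //   vrshrq_n_s32(-2048, 12)              = ((-2048 + 2048) >> 12) =  0
+  //   with the -1 bias applied first:
+  //   vrshrq_n_s32(-2049, 12)              = ((-2049 + 2048) >> 12) = -1
+  // which matches the scalar reference.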
+  diff_s32_lo = vsraq_n_s32(diff_s32_lo, diff_s32_lo, 31);
+  diff_s32_hi = vsraq_n_s32(diff_s32_hi, diff_s32_hi, 31);
+  int32x4_t round_s32_lo = vrshrq_n_s32(diff_s32_lo, 12);
+  int32x4_t round_s32_hi = vrshrq_n_s32(diff_s32_hi, 12);
+
+  *sumv = vrsraq_n_s32(*sumv, diff_s32_lo, 12);
+  *sumv = vrsraq_n_s32(*sumv, diff_s32_hi, 12);
+  *ssev = vmlaq_s32(*ssev, round_s32_lo, round_s32_lo);
+  *ssev = vmlaq_s32(*ssev, round_s32_hi, round_s32_hi);
+}
+
+#if defined(__aarch64__)
+
+// Use tbl for doing a double-width zero extension from 8->32 bits since we can
+// do this in one instruction rather than two (indices out of range (255 here)
+// are set to zero by tbl).
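+// For example, applying the first 16 indices below to a vector of bytes
+// {b0, b1, ..., b15} selects {b0, 0, 0, 0, b1, 0, 0, 0, b2, 0, 0, 0, b3, 0,
+// 0, 0}, which reinterpreted as int32x4_t is the zero-extension of the first
+// four bytes.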
+DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = {
+  0,  255, 255, 255, 1,  255, 255, 255, 2,  255, 255, 255, 3,  255, 255, 255,
+  4,  255, 255, 255, 5,  255, 255, 255, 6,  255, 255, 255, 7,  255, 255, 255,
+  8,  255, 255, 255, 9,  255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255,
+  12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255
+};
+
+static INLINE void obmc_variance_8x1_s32_neon(
+    int32x4_t pre_lo, int32x4_t pre_hi, const int32_t *wsrc,
+    const int32_t *mask, int32x4_t *ssev, int32x4_t *sumv) {
+  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
+  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);
+  int32x4_t mask_lo = vld1q_s32(&mask[0]);
+  int32x4_t mask_hi = vld1q_s32(&mask[4]);
+
+  int32x4_t diff_lo = vmlsq_s32(wsrc_lo, pre_lo, mask_lo);
+  int32x4_t diff_hi = vmlsq_s32(wsrc_hi, pre_hi, mask_hi);
+
+  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away from
+  // zero, whereas vrshrq_n_s32 rounds to nearest with ties rounded up. The two
+  // only differ for values that lie exactly on a rounding breakpoint, so we
+  // can add -1 to all negative numbers to move the breakpoint one value across
+  // and into the correct rounding region.
+  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
+  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
+  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
+  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);
+
+  *sumv = vrsraq_n_s32(*sumv, diff_lo, 12);
+  *sumv = vrsraq_n_s32(*sumv, diff_hi, 12);
+  *ssev = vmlaq_s32(*ssev, round_lo, round_lo);
+  *ssev = vmlaq_s32(*ssev, round_hi, round_hi);
+}
+
+static INLINE void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
+                                            const int32_t *wsrc,
+                                            const int32_t *mask, int width,
+                                            int height, unsigned *sse,
+                                            int *sum) {
+  assert(width % 16 == 0);
+
+  // Use tbl for doing a double-width zero extension from 8->32 bits since we
+  // can do this in one instruction rather than two.
+  uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]);
+  uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]);
+  uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]);
+  uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]);
+
+  int32x4_t ssev = vdupq_n_s32(0);
+  int32x4_t sumv = vdupq_n_s32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    do {
+      uint8x16_t pre_u8 = vld1q_u8(pre);
+
+      int32x4_t pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx0));
+      int32x4_t pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx1));
+      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[0], &mask[0],
+                                 &ssev, &sumv);
+
+      pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx2));
+      pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx3));
+      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[8], &mask[8],
+                                 &ssev, &sumv);
+
+      wsrc += 16;
+      mask += 16;
+      pre += 16;
+      w -= 16;
+    } while (w != 0);
+
+    pre += pre_stride - width;
+  } while (--h != 0);
+
+  *sse = horizontal_add_s32x4(ssev);
+  *sum = horizontal_add_s32x4(sumv);
+}
+
+#else  // !defined(__aarch64__)
+
+static INLINE void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
+                                            const int32_t *wsrc,
+                                            const int32_t *mask, int width,
+                                            int height, unsigned *sse,
+                                            int *sum) {
+  // Non-aarch64 targets do not have a 128-bit tbl instruction, so use the
+  // widening version of the core kernel instead.
+
+  assert(width % 16 == 0);
+
+  int32x4_t ssev = vdupq_n_s32(0);
+  int32x4_t sumv = vdupq_n_s32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    do {
+      uint8x16_t pre_u8 = vld1q_u8(pre);
+
+      int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pre_u8)));
+      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[0], &mask[0], &ssev, &sumv);
+
+      pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pre_u8)));
+      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[8], &mask[8], &ssev, &sumv);
+
+      wsrc += 16;
+      mask += 16;
+      pre += 16;
+      w -= 16;
+    } while (w != 0);
+
+    pre += pre_stride - width;
+  } while (--h != 0);
+
+  *sse = horizontal_add_s32x4(ssev);
+  *sum = horizontal_add_s32x4(sumv);
+}
+
+#endif  // defined(__aarch64__)
+
+static INLINE void obmc_variance_neon_128xh(const uint8_t *pre, int pre_stride,
+                                            const int32_t *wsrc,
+                                            const int32_t *mask, int h,
+                                            unsigned *sse, int *sum) {
+  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, sum);
+}
+
+static INLINE void obmc_variance_neon_64xh(const uint8_t *pre, int pre_stride,
+                                           const int32_t *wsrc,
+                                           const int32_t *mask, int h,
+                                           unsigned *sse, int *sum) {
+  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
+}
+
+static INLINE void obmc_variance_neon_32xh(const uint8_t *pre, int pre_stride,
+                                           const int32_t *wsrc,
+                                           const int32_t *mask, int h,
+                                           unsigned *sse, int *sum) {
+  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
+}
+
+static INLINE void obmc_variance_neon_16xh(const uint8_t *pre, int pre_stride,
+                                           const int32_t *wsrc,
+                                           const int32_t *mask, int h,
+                                           unsigned *sse, int *sum) {
+  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
+}
+
+static INLINE void obmc_variance_neon_8xh(const uint8_t *pre, int pre_stride,
+                                          const int32_t *wsrc,
+                                          const int32_t *mask, int h,
+                                          unsigned *sse, int *sum) {
+  int32x4_t ssev = vdupq_n_s32(0);
+  int32x4_t sumv = vdupq_n_s32(0);
+
+  do {
+    uint8x8_t pre_u8 = vld1_u8(pre);
+    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));
+
+    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);
+
+    pre += pre_stride;
+    wsrc += 8;
+    mask += 8;
+  } while (--h != 0);
+
+  *sse = horizontal_add_s32x4(ssev);
+  *sum = horizontal_add_s32x4(sumv);
+}
+
+static INLINE void obmc_variance_neon_4xh(const uint8_t *pre, int pre_stride,
+                                          const int32_t *wsrc,
+                                          const int32_t *mask, int h,
+                                          unsigned *sse, int *sum) {
+  assert(h % 2 == 0);
+
+  int32x4_t ssev = vdupq_n_s32(0);
+  int32x4_t sumv = vdupq_n_s32(0);
+
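+  // Each iteration handles two rows: load_unaligned_u8 packs a 4-wide row
+  // from pre and one from pre + pre_stride into a single 8-lane vector.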
+  do {
+    uint8x8_t pre_u8 = load_unaligned_u8(pre, pre_stride);
+    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));
+
+    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);
+
+    pre += 2 * pre_stride;
+    wsrc += 8;
+    mask += 8;
+    h -= 2;
+  } while (h != 0);
+
+  *sse = horizontal_add_s32x4(ssev);
+  *sum = horizontal_add_s32x4(sumv);
+}
+
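+// The variance is computed as sse - sum^2 / (W * H), with the square of the
+// sum widened to 64 bits so that it cannot overflow for the largest blocks.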
+#define OBMC_VARIANCE_WXH_NEON(W, H)                                       \
+  unsigned aom_obmc_variance##W##x##H##_neon(                              \
+      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
+      const int32_t *mask, unsigned *sse) {                                \
+    int sum;                                                               \
+    obmc_variance_neon_##W##xh(pre, pre_stride, wsrc, mask, H, sse, &sum); \
+    return *sse - (unsigned)(((int64_t)sum * sum) / (W * H));              \
+  }
+
+OBMC_VARIANCE_WXH_NEON(4, 4)
+OBMC_VARIANCE_WXH_NEON(4, 8)
+OBMC_VARIANCE_WXH_NEON(8, 4)
+OBMC_VARIANCE_WXH_NEON(8, 8)
+OBMC_VARIANCE_WXH_NEON(8, 16)
+OBMC_VARIANCE_WXH_NEON(16, 8)
+OBMC_VARIANCE_WXH_NEON(16, 16)
+OBMC_VARIANCE_WXH_NEON(16, 32)
+OBMC_VARIANCE_WXH_NEON(32, 16)
+OBMC_VARIANCE_WXH_NEON(32, 32)
+OBMC_VARIANCE_WXH_NEON(32, 64)
+OBMC_VARIANCE_WXH_NEON(64, 32)
+OBMC_VARIANCE_WXH_NEON(64, 64)
+OBMC_VARIANCE_WXH_NEON(64, 128)
+OBMC_VARIANCE_WXH_NEON(128, 64)
+OBMC_VARIANCE_WXH_NEON(128, 128)
+OBMC_VARIANCE_WXH_NEON(4, 16)
+OBMC_VARIANCE_WXH_NEON(16, 4)
+OBMC_VARIANCE_WXH_NEON(8, 32)
+OBMC_VARIANCE_WXH_NEON(32, 8)
+OBMC_VARIANCE_WXH_NEON(16, 64)
+OBMC_VARIANCE_WXH_NEON(64, 16)
diff --git a/aom_dsp/arm/sad4d_neon.c b/aom_dsp/arm/sad4d_neon.c
index e1eccc3..bc73fb8 100644
--- a/aom_dsp/arm/sad4d_neon.c
+++ b/aom_dsp/arm/sad4d_neon.c
@@ -15,9 +15,10 @@
 #include "config/aom_dsp_rtcd.h"
 
 #include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/sum_neon.h"
 
-#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#if defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
                               uint32x4_t *const sad_sum) {
@@ -25,148 +26,64 @@
   *sad_sum = vdotq_u32(*sad_sum, abs_diff, vdupq_n_u8(1));
 }
 
-static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
-                                    const uint8_t *const ref[4], int ref_stride,
-                                    uint32_t res[4], int h) {
+static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
+                                        const uint8_t *const ref[4],
+                                        int ref_stride, uint32_t res[4], int w,
+                                        int h) {
   uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                            vdupq_n_u32(0) };
   uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                            vdupq_n_u32(0) };
+  uint32x4_t sum[4];
 
-  int i = 0;
+  int ref_offset = 0;
+  int i = h;
   do {
-    const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+    int j = 0;
+    do {
+      const uint8x16_t s0 = vld1q_u8(src + j);
+      sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+      sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+      sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+      sad16_neon(s0, vld1q_u8(ref[3] + ref_offset + j), &sum_lo[3]);
 
-    const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+      const uint8x16_t s1 = vld1q_u8(src + j + 16);
+      sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+      sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+      sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+      sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + j + 16), &sum_hi[3]);
 
-    const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
-    sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
-    sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
-    sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
-    sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+      j += 32;
+    } while (j < w);
 
-    const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
-    sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
-    sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
-    sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
-    sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
 
-    const uint8x16_t s4 = vld1q_u8(src + i * src_stride + 64);
-    sad16_neon(s4, vld1q_u8(ref[0] + i * ref_stride + 64), &sum_lo[0]);
-    sad16_neon(s4, vld1q_u8(ref[1] + i * ref_stride + 64), &sum_lo[1]);
-    sad16_neon(s4, vld1q_u8(ref[2] + i * ref_stride + 64), &sum_lo[2]);
-    sad16_neon(s4, vld1q_u8(ref[3] + i * ref_stride + 64), &sum_lo[3]);
+  sum[0] = vaddq_u32(sum_lo[0], sum_hi[0]);
+  sum[1] = vaddq_u32(sum_lo[1], sum_hi[1]);
+  sum[2] = vaddq_u32(sum_lo[2], sum_hi[2]);
+  sum[3] = vaddq_u32(sum_lo[3], sum_hi[3]);
 
-    const uint8x16_t s5 = vld1q_u8(src + i * src_stride + 80);
-    sad16_neon(s5, vld1q_u8(ref[0] + i * ref_stride + 80), &sum_hi[0]);
-    sad16_neon(s5, vld1q_u8(ref[1] + i * ref_stride + 80), &sum_hi[1]);
-    sad16_neon(s5, vld1q_u8(ref[2] + i * ref_stride + 80), &sum_hi[2]);
-    sad16_neon(s5, vld1q_u8(ref[3] + i * ref_stride + 80), &sum_hi[3]);
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
+}
 
-    const uint8x16_t s6 = vld1q_u8(src + i * src_stride + 96);
-    sad16_neon(s6, vld1q_u8(ref[0] + i * ref_stride + 96), &sum_lo[0]);
-    sad16_neon(s6, vld1q_u8(ref[1] + i * ref_stride + 96), &sum_lo[1]);
-    sad16_neon(s6, vld1q_u8(ref[2] + i * ref_stride + 96), &sum_lo[2]);
-    sad16_neon(s6, vld1q_u8(ref[3] + i * ref_stride + 96), &sum_lo[3]);
-
-    const uint8x16_t s7 = vld1q_u8(src + i * src_stride + 112);
-    sad16_neon(s7, vld1q_u8(ref[0] + i * ref_stride + 112), &sum_hi[0]);
-    sad16_neon(s7, vld1q_u8(ref[1] + i * ref_stride + 112), &sum_hi[1]);
-    sad16_neon(s7, vld1q_u8(ref[2] + i * ref_stride + 112), &sum_hi[2]);
-    sad16_neon(s7, vld1q_u8(ref[3] + i * ref_stride + 112), &sum_hi[3]);
-
-    i++;
-  } while (i < h);
-
-  uint32x4_t res0 = vpaddq_u32(vaddq_u32(sum_lo[0], sum_hi[0]),
-                               vaddq_u32(sum_lo[1], sum_hi[1]));
-  uint32x4_t res1 = vpaddq_u32(vaddq_u32(sum_lo[2], sum_hi[2]),
-                               vaddq_u32(sum_lo[3], sum_hi[3]));
-  vst1q_u32(res, vpaddq_u32(res0, res1));
+static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *const ref[4], int ref_stride,
+                                    uint32_t res[4], int h) {
+  sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 128, h);
 }
 
 static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                           vdupq_n_u32(0) };
-  uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                           vdupq_n_u32(0) };
-
-  int i = 0;
-  do {
-    const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
-
-    const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
-
-    const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
-    sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
-    sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
-    sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
-    sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
-
-    const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
-    sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
-    sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
-    sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
-    sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
-
-    i++;
-  } while (i < h);
-
-  uint32x4_t res0 = vpaddq_u32(vaddq_u32(sum_lo[0], sum_hi[0]),
-                               vaddq_u32(sum_lo[1], sum_hi[1]));
-  uint32x4_t res1 = vpaddq_u32(vaddq_u32(sum_lo[2], sum_hi[2]),
-                               vaddq_u32(sum_lo[3], sum_hi[3]));
-  vst1q_u32(res, vpaddq_u32(res0, res1));
+  sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 64, h);
 }
 
 static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                           vdupq_n_u32(0) };
-  uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                           vdupq_n_u32(0) };
-
-  int i = 0;
-  do {
-    const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
-
-    const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
-
-    i++;
-  } while (i < h);
-
-  uint32x4_t res0 = vpaddq_u32(vaddq_u32(sum_lo[0], sum_hi[0]),
-                               vaddq_u32(sum_lo[1], sum_hi[1]));
-  uint32x4_t res1 = vpaddq_u32(vaddq_u32(sum_lo[2], sum_hi[2]),
-                               vaddq_u32(sum_lo[3], sum_hi[3]));
-  vst1q_u32(res, vpaddq_u32(res0, res1));
+  sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 32, h);
 }
 
 static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
@@ -175,23 +92,23 @@
   uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                         vdupq_n_u32(0) };
 
-  int i = 0;
+  int ref_offset = 0;
+  int i = h;
   do {
-    const uint8x16_t s = vld1q_u8(src + i * src_stride);
-    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
-    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
-    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
-    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
+    const uint8x16_t s = vld1q_u8(src);
+    sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + ref_offset), &sum[3]);
 
-    i++;
-  } while (i < h);
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
 
-  uint32x4_t res0 = vpaddq_u32(sum[0], sum[1]);
-  uint32x4_t res1 = vpaddq_u32(sum[2], sum[3]);
-  vst1q_u32(res, vpaddq_u32(res0, res1));
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
 }
 
-#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+#else  // !(defined(__ARM_FEATURE_DOTPROD))
 
 static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
                               uint16x8_t *const sad_sum) {
@@ -199,12 +116,15 @@
   *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
 }
 
-static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
-                                    const uint8_t *const ref[4], int ref_stride,
-                                    uint32_t res[4], int h) {
-  vst1q_u32(res, vdupq_n_u32(0));
-  int h_tmp = h > 32 ? 32 : h;
+static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
+                                        const uint8_t *const ref[4],
+                                        int ref_stride, uint32_t res[4], int w,
+                                        int h, int h_overflow) {
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+  int h_limit = h > h_overflow ? h_overflow : h;
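+  // The 16-bit per-row accumulators below can only hold h_overflow rows of
+  // absolute differences at this width before they could overflow, so process
+  // the rows in chunks of h_overflow and widen into the 32-bit sums between
+  // chunks.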
 
+  int ref_offset = 0;
   int i = 0;
   do {
     uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
@@ -213,114 +133,52 @@
                              vdupq_n_u16(0) };
 
     do {
-      const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+      int j = 0;
+      do {
+        const uint8x16_t s0 = vld1q_u8(src + j);
+        sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+        sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+        sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+        sad16_neon(s0, vld1q_u8(ref[3] + ref_offset + j), &sum_lo[3]);
 
-      const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+        const uint8x16_t s1 = vld1q_u8(src + j + 16);
+        sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+        sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+        sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+        sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + j + 16), &sum_hi[3]);
 
-      const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
-      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
-      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
-      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
-      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+        j += 32;
+      } while (j < w);
 
-      const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
-      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
-      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
-      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
-      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+      src += src_stride;
+      ref_offset += ref_stride;
+    } while (++i < h_limit);
 
-      const uint8x16_t s4 = vld1q_u8(src + i * src_stride + 64);
-      sad16_neon(s4, vld1q_u8(ref[0] + i * ref_stride + 64), &sum_lo[0]);
-      sad16_neon(s4, vld1q_u8(ref[1] + i * ref_stride + 64), &sum_lo[1]);
-      sad16_neon(s4, vld1q_u8(ref[2] + i * ref_stride + 64), &sum_lo[2]);
-      sad16_neon(s4, vld1q_u8(ref[3] + i * ref_stride + 64), &sum_lo[3]);
+    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
+    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
+    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
+    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
+    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
+    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
+    sum[3] = vpadalq_u16(sum[3], sum_lo[3]);
+    sum[3] = vpadalq_u16(sum[3], sum_hi[3]);
 
-      const uint8x16_t s5 = vld1q_u8(src + i * src_stride + 80);
-      sad16_neon(s5, vld1q_u8(ref[0] + i * ref_stride + 80), &sum_hi[0]);
-      sad16_neon(s5, vld1q_u8(ref[1] + i * ref_stride + 80), &sum_hi[1]);
-      sad16_neon(s5, vld1q_u8(ref[2] + i * ref_stride + 80), &sum_hi[2]);
-      sad16_neon(s5, vld1q_u8(ref[3] + i * ref_stride + 80), &sum_hi[3]);
-
-      const uint8x16_t s6 = vld1q_u8(src + i * src_stride + 96);
-      sad16_neon(s6, vld1q_u8(ref[0] + i * ref_stride + 96), &sum_lo[0]);
-      sad16_neon(s6, vld1q_u8(ref[1] + i * ref_stride + 96), &sum_lo[1]);
-      sad16_neon(s6, vld1q_u8(ref[2] + i * ref_stride + 96), &sum_lo[2]);
-      sad16_neon(s6, vld1q_u8(ref[3] + i * ref_stride + 96), &sum_lo[3]);
-
-      const uint8x16_t s7 = vld1q_u8(src + i * src_stride + 112);
-      sad16_neon(s7, vld1q_u8(ref[0] + i * ref_stride + 112), &sum_hi[0]);
-      sad16_neon(s7, vld1q_u8(ref[1] + i * ref_stride + 112), &sum_hi[1]);
-      sad16_neon(s7, vld1q_u8(ref[2] + i * ref_stride + 112), &sum_hi[2]);
-      sad16_neon(s7, vld1q_u8(ref[3] + i * ref_stride + 112), &sum_hi[3]);
-
-      i++;
-    } while (i < h_tmp);
-
-    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
-
-    h_tmp += 32;
+    h_limit += h_overflow;
   } while (i < h);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
+}
+
+static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *const ref[4], int ref_stride,
+                                    uint32_t res[4], int h) {
+  sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 128, h, 32);
 }
 
 static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  vst1q_u32(res, vdupq_n_u32(0));
-  int h_tmp = h > 64 ? 64 : h;
-
-  int i = 0;
-  do {
-    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                             vdupq_n_u16(0) };
-    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                             vdupq_n_u16(0) };
-
-    do {
-      const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
-
-      const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
-
-      const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
-      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
-      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
-      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
-      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
-
-      const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
-      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
-      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
-      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
-      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
-
-      i++;
-    } while (i < h_tmp);
-
-    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
-
-    h_tmp += 64;
-  } while (i < h);
+  sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 64, h, 64);
 }
 
 static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
@@ -331,128 +189,122 @@
   uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                            vdupq_n_u16(0) };
 
-  int i = 0;
+  int ref_offset = 0;
+  int i = h;
   do {
-    const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
-    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+    const uint8x16_t s0 = vld1q_u8(src);
+    sad16_neon(s0, vld1q_u8(ref[0] + ref_offset), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + ref_offset), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + ref_offset), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + ref_offset), &sum_lo[3]);
 
-    const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
-    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+    const uint8x16_t s1 = vld1q_u8(src + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + 16), &sum_hi[3]);
 
-    i++;
-  } while (i < h);
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
 
-  res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
-  res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
-  res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
-  res[3] = horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+  vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi));
 }
 
 static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
+  uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                            vdupq_n_u16(0) };
+  uint32x4_t sum_u32[4];
 
-  int i = 0;
+  int ref_offset = 0;
+  int i = h;
   do {
-    const uint8x16_t s = vld1q_u8(src + i * src_stride);
-    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
-    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
-    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
-    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
+    const uint8x16_t s = vld1q_u8(src);
+    sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum_u16[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum_u16[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum_u16[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + ref_offset), &sum_u16[3]);
 
-    i++;
-  } while (i < h);
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
 
-  res[0] = horizontal_add_u16x8(sum[0]);
-  res[1] = horizontal_add_u16x8(sum[1]);
-  res[2] = horizontal_add_u16x8(sum[2]);
-  res[3] = horizontal_add_u16x8(sum[3]);
+  sum_u32[0] = vpaddlq_u16(sum_u16[0]);
+  sum_u32[1] = vpaddlq_u16(sum_u16[1]);
+  sum_u32[2] = vpaddlq_u16(sum_u16[2]);
+  sum_u32[3] = vpaddlq_u16(sum_u16[3]);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32));
 }
 
-#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
-
-static INLINE void sad8_neon(uint8x8_t src, uint8x8_t ref,
-                             uint16x8_t *const sad_sum) {
-  uint8x8_t abs_diff = vabd_u8(src, ref);
-  *sad_sum = vaddw_u8(*sad_sum, abs_diff);
-}
+#endif  // defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE void sad8xhx4d_neon(const uint8_t *src, int src_stride,
                                   const uint8_t *const ref[4], int ref_stride,
                                   uint32_t res[4], int h) {
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
+  uint16x8_t sum[4];
 
-  int i = 0;
+  uint8x8_t s = vld1_u8(src);
+  sum[0] = vabdl_u8(s, vld1_u8(ref[0]));
+  sum[1] = vabdl_u8(s, vld1_u8(ref[1]));
+  sum[2] = vabdl_u8(s, vld1_u8(ref[2]));
+  sum[3] = vabdl_u8(s, vld1_u8(ref[3]));
+
+  src += src_stride;
+  int ref_offset = ref_stride;
+  int i = h - 1;
   do {
-    const uint8x8_t s = vld1_u8(src + i * src_stride);
-    sad8_neon(s, vld1_u8(ref[0] + i * ref_stride), &sum[0]);
-    sad8_neon(s, vld1_u8(ref[1] + i * ref_stride), &sum[1]);
-    sad8_neon(s, vld1_u8(ref[2] + i * ref_stride), &sum[2]);
-    sad8_neon(s, vld1_u8(ref[3] + i * ref_stride), &sum[3]);
+    s = vld1_u8(src);
+    sum[0] = vabal_u8(sum[0], s, vld1_u8(ref[0] + ref_offset));
+    sum[1] = vabal_u8(sum[1], s, vld1_u8(ref[1] + ref_offset));
+    sum[2] = vabal_u8(sum[2], s, vld1_u8(ref[2] + ref_offset));
+    sum[3] = vabal_u8(sum[3], s, vld1_u8(ref[3] + ref_offset));
 
-    i++;
-  } while (i < h);
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
 
-  res[0] = horizontal_add_u16x8(sum[0]);
-  res[1] = horizontal_add_u16x8(sum[1]);
-  res[2] = horizontal_add_u16x8(sum[2]);
-  res[3] = horizontal_add_u16x8(sum[3]);
+  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
 }
 
 static INLINE void sad4xhx4d_neon(const uint8_t *src, int src_stride,
                                   const uint8_t *const ref[4], int ref_stride,
                                   uint32_t res[4], int h) {
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
+  uint16x8_t sum[4];
 
-  int i = 0;
+  uint8x8_t s = load_unaligned_u8(src, src_stride);
+  uint8x8_t r0 = load_unaligned_u8(ref[0], ref_stride);
+  uint8x8_t r1 = load_unaligned_u8(ref[1], ref_stride);
+  uint8x8_t r2 = load_unaligned_u8(ref[2], ref_stride);
+  uint8x8_t r3 = load_unaligned_u8(ref[3], ref_stride);
+
+  sum[0] = vabdl_u8(s, r0);
+  sum[1] = vabdl_u8(s, r1);
+  sum[2] = vabdl_u8(s, r2);
+  sum[3] = vabdl_u8(s, r3);
+
+  src += 2 * src_stride;
+  int ref_offset = 2 * ref_stride;
+  int i = (h - 1) / 2;
   do {
-    uint32x2_t s, r0, r1, r2, r3;
-    uint32_t s_lo, s_hi, r0_lo, r0_hi, r1_lo, r1_hi, r2_lo, r2_hi, r3_lo, r3_hi;
+    s = load_unaligned_u8(src, src_stride);
+    r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride);
+    r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride);
+    r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride);
+    r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride);
 
-    memcpy(&s_lo, src + i * src_stride, 4);
-    memcpy(&r0_lo, ref[0] + i * ref_stride, 4);
-    memcpy(&r1_lo, ref[1] + i * ref_stride, 4);
-    memcpy(&r2_lo, ref[2] + i * ref_stride, 4);
-    memcpy(&r3_lo, ref[3] + i * ref_stride, 4);
-    s = vdup_n_u32(s_lo);
-    r0 = vdup_n_u32(r0_lo);
-    r1 = vdup_n_u32(r1_lo);
-    r2 = vdup_n_u32(r2_lo);
-    r3 = vdup_n_u32(r3_lo);
+    sum[0] = vabal_u8(sum[0], s, r0);
+    sum[1] = vabal_u8(sum[1], s, r1);
+    sum[2] = vabal_u8(sum[2], s, r2);
+    sum[3] = vabal_u8(sum[3], s, r3);
 
-    memcpy(&s_hi, src + (i + 1) * src_stride, 4);
-    memcpy(&r0_hi, ref[0] + (i + 1) * ref_stride, 4);
-    memcpy(&r1_hi, ref[1] + (i + 1) * ref_stride, 4);
-    memcpy(&r2_hi, ref[2] + (i + 1) * ref_stride, 4);
-    memcpy(&r3_hi, ref[3] + (i + 1) * ref_stride, 4);
-    s = vset_lane_u32(s_hi, s, 1);
-    r0 = vset_lane_u32(r0_hi, r0, 1);
-    r1 = vset_lane_u32(r1_hi, r1, 1);
-    r2 = vset_lane_u32(r2_hi, r2, 1);
-    r3 = vset_lane_u32(r3_hi, r3, 1);
+    src += 2 * src_stride;
+    ref_offset += 2 * ref_stride;
+  } while (--i != 0);
 
-    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r0), &sum[0]);
-    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r1), &sum[1]);
-    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r2), &sum[2]);
-    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r3), &sum[3]);
-
-    i += 2;
-  } while (i < h);
-
-  res[0] = horizontal_add_u16x8(sum[0]);
-  res[1] = horizontal_add_u16x8(sum[1]);
-  res[2] = horizontal_add_u16x8(sum[2]);
-  res[3] = horizontal_add_u16x8(sum[3]);
+  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
 }
 
 #define SAD_WXH_4D_NEON(w, h)                                                  \
diff --git a/aom_dsp/arm/sad_neon.c b/aom_dsp/arm/sad_neon.c
index 5ba7f10..6a22289 100644
--- a/aom_dsp/arm/sad_neon.c
+++ b/aom_dsp/arm/sad_neon.c
@@ -10,9 +10,12 @@
  */
 
 #include <arm_neon.h>
+
 #include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
+
 #include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/sum_neon.h"
 
 #if defined(__ARM_FEATURE_DOTPROD)
@@ -289,24 +292,13 @@
 
   int i = h / 2;
   do {
-    uint32x2_t s, r;
-    uint32_t s0, s1, r0, r1;
+    uint8x8_t s = load_unaligned_u8(src_ptr, src_stride);
+    uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride);
 
-    memcpy(&s0, src_ptr, 4);
-    memcpy(&r0, ref_ptr, 4);
-    s = vdup_n_u32(s0);
-    r = vdup_n_u32(r0);
-    src_ptr += src_stride;
-    ref_ptr += ref_stride;
+    sum = vabal_u8(sum, s, r);
 
-    memcpy(&s1, src_ptr, 4);
-    memcpy(&r1, ref_ptr, 4);
-    s = vset_lane_u32(s1, s, 1);
-    r = vset_lane_u32(r1, r, 1);
-    src_ptr += src_stride;
-    ref_ptr += ref_stride;
-
-    sum = vabal_u8(sum, vreinterpret_u8_u32(s), vreinterpret_u8_u32(r));
+    src_ptr += 2 * src_stride;
+    ref_ptr += 2 * ref_stride;
   } while (--i != 0);
 
   return horizontal_add_u16x8(sum);
@@ -732,28 +724,15 @@
 
   int i = h / 2;
   do {
-    uint32x2_t s, r;
-    uint32_t s0, s1, r0, r1;
-    uint8x8_t p, avg;
+    uint8x8_t s = load_unaligned_u8(src_ptr, src_stride);
+    uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride);
+    uint8x8_t p = vld1_u8(second_pred);
 
-    memcpy(&s0, src_ptr, 4);
-    memcpy(&r0, ref_ptr, 4);
-    s = vdup_n_u32(s0);
-    r = vdup_n_u32(r0);
-    src_ptr += src_stride;
-    ref_ptr += ref_stride;
+    uint8x8_t avg = vrhadd_u8(r, p);
+    sum = vabal_u8(sum, s, avg);
 
-    memcpy(&s1, src_ptr, 4);
-    memcpy(&r1, ref_ptr, 4);
-    s = vset_lane_u32(s1, s, 1);
-    r = vset_lane_u32(r1, r, 1);
-    src_ptr += src_stride;
-    ref_ptr += ref_stride;
-
-    p = vld1_u8(second_pred);
-    avg = vrhadd_u8(vreinterpret_u8_u32(r), p);
-
-    sum = vabal_u8(sum, vreinterpret_u8_u32(s), avg);
+    src_ptr += 2 * src_stride;
+    ref_ptr += 2 * ref_stride;
     second_pred += 8;
   } while (--i != 0);
 
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index 2c988dc..d1d3d93 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -348,7 +348,8 @@
 
 int64_t aom_highbd_sse_neon(const uint8_t *a8, int a_stride, const uint8_t *b8,
                             int b_stride, int width, int height) {
-  const uint16x8_t q0 = { 0, 1, 2, 3, 4, 5, 6, 7 };
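+  // Build the lane-index vector with vld1q_u16 from a plain array, since an
+  // initializer list on a Neon vector type is not portable across compilers.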
+  static const uint16_t k01234567[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
+  const uint16x8_t q0 = vld1q_u16(k01234567);
   int64_t sse = 0;
   uint16_t *a = CONVERT_TO_SHORTPTR(a8);
   uint16_t *b = CONVERT_TO_SHORTPTR(b8);
diff --git a/aom_dsp/arm/subpel_variance_neon.c b/aom_dsp/arm/subpel_variance_neon.c
index a058860..9599ae0 100644
--- a/aom_dsp/arm/subpel_variance_neon.c
+++ b/aom_dsp/arm/subpel_variance_neon.c
@@ -549,3 +549,239 @@
 
 #undef SUBPEL_AVG_VARIANCE_WXH_NEON
 #undef SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON
+
+#if !CONFIG_REALTIME_ONLY
+
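+// padding is the number of extra rows the first (horizontal) bilinear pass
+// produces so that the second (vertical) pass has the input rows it needs:
+// 2 for the 4-wide kernels, 1 for the wider kernels.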
+#define OBMC_SUBPEL_VARIANCE_WXH_NEON(w, h, padding)                   \
+  unsigned int aom_obmc_sub_pixel_variance##w##x##h##_neon(            \
+      const uint8_t *pre, int pre_stride, int xoffset, int yoffset,    \
+      const int32_t *wsrc, const int32_t *mask, unsigned int *sse) {   \
+    uint8_t tmp0[w * (h + padding)];                                   \
+    uint8_t tmp1[w * h];                                               \
+    var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding, \
+                                xoffset);                              \
+    var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);         \
+    return aom_obmc_variance##w##x##h(tmp1, w, wsrc, mask, sse);       \
+  }
+
+#define SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(w, h, padding)              \
+  unsigned int aom_obmc_sub_pixel_variance##w##x##h##_neon(                   \
+      const uint8_t *pre, int pre_stride, int xoffset, int yoffset,           \
+      const int32_t *wsrc, const int32_t *mask, unsigned int *sse) {          \
+    if (xoffset == 0) {                                                       \
+      if (yoffset == 0) {                                                     \
+        return aom_obmc_variance##w##x##h##_neon(pre, pre_stride, wsrc, mask, \
+                                                 sse);                        \
+      } else if (yoffset == 4) {                                              \
+        uint8_t tmp[w * h];                                                   \
+        var_filter_block2d_avg(pre, tmp, pre_stride, pre_stride, w, h);       \
+        return aom_obmc_variance##w##x##h##_neon(tmp, w, wsrc, mask, sse);    \
+      } else {                                                                \
+        uint8_t tmp[w * h];                                                   \
+        var_filter_block2d_bil_w##w(pre, tmp, pre_stride, pre_stride, h,      \
+                                    yoffset);                                 \
+        return aom_obmc_variance##w##x##h##_neon(tmp, w, wsrc, mask, sse);    \
+      }                                                                       \
+    } else if (xoffset == 4) {                                                \
+      uint8_t tmp0[w * (h + padding)];                                        \
+      if (yoffset == 0) {                                                     \
+        var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h);               \
+        return aom_obmc_variance##w##x##h##_neon(tmp0, w, wsrc, mask, sse);   \
+      } else if (yoffset == 4) {                                              \
+        uint8_t tmp1[w * (h + padding)];                                      \
+        var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h + padding);     \
+        var_filter_block2d_avg(tmp0, tmp1, w, w, w, h);                       \
+        return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse);   \
+      } else {                                                                \
+        uint8_t tmp1[w * (h + padding)];                                      \
+        var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h + padding);     \
+        var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);            \
+        return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse);   \
+      }                                                                       \
+    } else {                                                                  \
+      uint8_t tmp0[w * (h + padding)];                                        \
+      if (yoffset == 0) {                                                     \
+        var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h, xoffset);    \
+        return aom_obmc_variance##w##x##h##_neon(tmp0, w, wsrc, mask, sse);   \
+      } else if (yoffset == 4) {                                              \
+        uint8_t tmp1[w * h];                                                  \
+        var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding,    \
+                                    xoffset);                                 \
+        var_filter_block2d_avg(tmp0, tmp1, w, w, w, h);                       \
+        return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse);   \
+      } else {                                                                \
+        uint8_t tmp1[w * h];                                                  \
+        var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding,    \
+                                    xoffset);                                 \
+        var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);            \
+        return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse);   \
+      }                                                                       \
+    }                                                                         \
+  }
+
+OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 4, 2)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 8, 2)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 16, 2)
+
+OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 4, 1)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 1)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 1)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 1)
+
+OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 4, 1)
+OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 8, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 16, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 32, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 64, 1)
+
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 8, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 16, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 32, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 64, 1)
+
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 16, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 32, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 64, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 128, 1)
+
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(128, 64, 1)
+SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(128, 128, 1)
+
+#undef OBMC_SUBPEL_VARIANCE_WXH_NEON
+#undef SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON
+#endif  // !CONFIG_REALTIME_ONLY
+
+#define MASKED_SUBPEL_VARIANCE_WXH_NEON(w, h, padding)                         \
+  unsigned int aom_masked_sub_pixel_variance##w##x##h##_neon(                  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,          \
+      const uint8_t *msk, int msk_stride, int invert_mask,                     \
+      unsigned int *sse) {                                                     \
+    uint8_t tmp0[w * (h + padding)];                                           \
+    uint8_t tmp1[w * h];                                                       \
+    uint8_t tmp2[w * h];                                                       \
+    var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding),       \
+                                xoffset);                                      \
+    var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);                 \
+    aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, msk_stride, \
+                            invert_mask);                                      \
+    return aom_variance##w##x##h##_neon(tmp2, w, ref, ref_stride, sse);        \
+  }
+
+#define SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(w, h, padding)             \
+  unsigned int aom_masked_sub_pixel_variance##w##x##h##_neon(                  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,          \
+      const uint8_t *msk, int msk_stride, int invert_mask,                     \
+      unsigned int *sse) {                                                     \
+    if (xoffset == 0) {                                                        \
+      uint8_t tmp0[w * h];                                                     \
+      if (yoffset == 0) {                                                      \
+        aom_comp_mask_pred_neon(tmp0, second_pred, w, h, src, src_stride, msk, \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp0, w, ref, ref_stride, sse);    \
+      } else if (yoffset == 4) {                                               \
+        uint8_t tmp1[w * h];                                                   \
+        var_filter_block2d_avg(src, tmp0, src_stride, src_stride, w, h);       \
+        aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp1, w, ref, ref_stride, sse);    \
+      } else {                                                                 \
+        uint8_t tmp1[w * h];                                                   \
+        var_filter_block2d_bil_w##w(src, tmp0, src_stride, src_stride, h,      \
+                                    yoffset);                                  \
+        aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp1, w, ref, ref_stride, sse);    \
+      }                                                                        \
+    } else if (xoffset == 4) {                                                 \
+      uint8_t tmp0[w * (h + padding)];                                         \
+      if (yoffset == 0) {                                                      \
+        uint8_t tmp1[w * h];                                                   \
+        var_filter_block2d_avg(src, tmp0, src_stride, 1, w, h);                \
+        aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp1, w, ref, ref_stride, sse);    \
+      } else if (yoffset == 4) {                                               \
+        uint8_t tmp1[w * h];                                                   \
+        uint8_t tmp2[w * h];                                                   \
+        var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding));    \
+        var_filter_block2d_avg(tmp0, tmp1, w, w, w, h);                        \
+        aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp2, w, ref, ref_stride, sse);    \
+      } else {                                                                 \
+        uint8_t tmp1[w * h];                                                   \
+        uint8_t tmp2[w * h];                                                   \
+        var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding));    \
+        var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);             \
+        aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp2, w, ref, ref_stride, sse);    \
+      }                                                                        \
+    } else {                                                                   \
+      if (yoffset == 0) {                                                      \
+        uint8_t tmp0[w * h];                                                   \
+        uint8_t tmp1[w * h];                                                   \
+        var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, h, xoffset);     \
+        aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp1, w, ref, ref_stride, sse);    \
+      } else if (yoffset == 4) {                                               \
+        uint8_t tmp0[w * (h + padding)];                                       \
+        uint8_t tmp1[w * h];                                                   \
+        uint8_t tmp2[w * h];                                                   \
+        var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding),   \
+                                    xoffset);                                  \
+        var_filter_block2d_avg(tmp0, tmp1, w, w, w, h);                        \
+        aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp2, w, ref, ref_stride, sse);    \
+      } else {                                                                 \
+        uint8_t tmp0[w * (h + padding)];                                       \
+        uint8_t tmp1[w * (h + padding)];                                       \
+        uint8_t tmp2[w * h];                                                   \
+        var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding),   \
+                                    xoffset);                                  \
+        var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset);             \
+        aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk,         \
+                                msk_stride, invert_mask);                      \
+        return aom_variance##w##x##h##_neon(tmp2, w, ref, ref_stride, sse);    \
+      }                                                                        \
+    }                                                                          \
+  }
+
+MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 4, 2)
+MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 8, 2)
+
+MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 4, 1)
+MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 1)
+MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 1)
+
+MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 8, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 16, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 32, 1)
+
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 16, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 32, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 64, 1)
+
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 32, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 64, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 128, 1)
+
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(128, 64, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(128, 128, 1)
+
+// Realtime mode doesn't use 4x rectangular blocks.
+#if !CONFIG_REALTIME_ONLY
+MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 16, 2)
+MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 1)
+MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 4, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 64, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 8, 1)
+SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 16, 1)
+#endif  // !CONFIG_REALTIME_ONLY
+
+#undef MASKED_SUBPEL_VARIANCE_WXH_NEON
+#undef SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 855edf6..9101979 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -37,6 +37,14 @@
 #endif
 }
 
+static INLINE int64_t horizontal_add_s64x2(const int64x2_t a) {
+#if defined(__aarch64__)
+  return vaddvq_s64(a);
+#else
+  return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1);
+#endif
+}
+
 static INLINE uint64_t horizontal_add_u64x2(const uint64x2_t a) {
 #if defined(__aarch64__)
   return vaddvq_u64(a);
@@ -65,6 +73,21 @@
 #endif
 }
 
+static INLINE uint32x4_t horizontal_add_4d_u32x4(const uint32x4_t sum[4]) {
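+  // Lane i of the result holds the horizontal sum of sum[i].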
+#if defined(__aarch64__)
+  uint32x4_t res01 = vpaddq_u32(sum[0], sum[1]);
+  uint32x4_t res23 = vpaddq_u32(sum[2], sum[3]);
+  return vpaddq_u32(res01, res23);
+#else
+  uint32x4_t res = vdupq_n_u32(0);
+  res = vsetq_lane_u32(horizontal_add_u32x4(sum[0]), res, 0);
+  res = vsetq_lane_u32(horizontal_add_u32x4(sum[1]), res, 1);
+  res = vsetq_lane_u32(horizontal_add_u32x4(sum[2]), res, 2);
+  res = vsetq_lane_u32(horizontal_add_u32x4(sum[3]), res, 3);
+  return res;
+#endif
+}
+
 static INLINE uint32_t horizontal_long_add_u16x8(const uint16x8_t vec_lo,
                                                  const uint16x8_t vec_hi) {
 #if defined(__aarch64__)
@@ -82,6 +105,31 @@
 #endif
 }
 
+static INLINE uint32x4_t horizontal_long_add_4d_u16x8(
+    const uint16x8_t sum_lo[4], const uint16x8_t sum_hi[4]) {
+  const uint32x4_t a0 = vpaddlq_u16(sum_lo[0]);
+  const uint32x4_t a1 = vpaddlq_u16(sum_lo[1]);
+  const uint32x4_t a2 = vpaddlq_u16(sum_lo[2]);
+  const uint32x4_t a3 = vpaddlq_u16(sum_lo[3]);
+  const uint32x4_t b0 = vpadalq_u16(a0, sum_hi[0]);
+  const uint32x4_t b1 = vpadalq_u16(a1, sum_hi[1]);
+  const uint32x4_t b2 = vpadalq_u16(a2, sum_hi[2]);
+  const uint32x4_t b3 = vpadalq_u16(a3, sum_hi[3]);
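+  // The four lanes of b_i together sum all 16 elements of sum_lo[i] and
+  // sum_hi[i]; the reduction below collapses each b_i into a single lane.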
+#if defined(__aarch64__)
+  const uint32x4_t c0 = vpaddq_u32(b0, b1);
+  const uint32x4_t c1 = vpaddq_u32(b2, b3);
+  return vpaddq_u32(c0, c1);
+#else
+  const uint32x2_t c0 = vadd_u32(vget_low_u32(b0), vget_high_u32(b0));
+  const uint32x2_t c1 = vadd_u32(vget_low_u32(b1), vget_high_u32(b1));
+  const uint32x2_t c2 = vadd_u32(vget_low_u32(b2), vget_high_u32(b2));
+  const uint32x2_t c3 = vadd_u32(vget_low_u32(b3), vget_high_u32(b3));
+  const uint32x2_t d0 = vpadd_u32(c0, c1);
+  const uint32x2_t d1 = vpadd_u32(c2, c3);
+  return vcombine_u32(d0, d1);
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u16x8(const uint16x8_t a) {
 #if defined(__aarch64__)
   return vaddlvq_u16(a);
@@ -94,6 +142,23 @@
 #endif
 }
 
+static INLINE uint32x4_t horizontal_add_4d_u16x8(const uint16x8_t sum[4]) {
+#if defined(__aarch64__)
+  const uint16x8_t a0 = vpaddq_u16(sum[0], sum[1]);
+  const uint16x8_t a1 = vpaddq_u16(sum[2], sum[3]);
+  const uint16x8_t b0 = vpaddq_u16(a0, a1);
+  return vpaddlq_u16(b0);
+#else
+  const uint16x4_t a0 = vadd_u16(vget_low_u16(sum[0]), vget_high_u16(sum[0]));
+  const uint16x4_t a1 = vadd_u16(vget_low_u16(sum[1]), vget_high_u16(sum[1]));
+  const uint16x4_t a2 = vadd_u16(vget_low_u16(sum[2]), vget_high_u16(sum[2]));
+  const uint16x4_t a3 = vadd_u16(vget_low_u16(sum[3]), vget_high_u16(sum[3]));
+  const uint16x4_t b0 = vpadd_u16(a0, a1);
+  const uint16x4_t b1 = vpadd_u16(a2, a3);
+  return vpaddlq_u16(vcombine_u16(b0, b1));
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) {
 #if defined(__aarch64__)
   return vaddv_u32(a);
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index bf212a9..095a2c6 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -35,7 +35,7 @@
                                                        int stride, int height) {
   int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 
-  int h = 0;
+  int h = height;
   do {
     int16x4_t s0 = vld1_s16(src + 0 * stride);
     int16x4_t s1 = vld1_s16(src + 1 * stride);
@@ -48,8 +48,8 @@
     sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);
 
     src += 4 * stride;
-    h += 4;
-  } while (h < height);
+    h -= 4;
+  } while (h != 0);
 
   return horizontal_long_add_u32x4(
       vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
@@ -60,7 +60,7 @@
                                                        int height) {
   uint64x2_t sum_squares = vdupq_n_u64(0);
 
-  int h = 0;
+  int h = height;
   do {
     int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
     int w = 0;
@@ -86,8 +86,8 @@
         sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));
 
     src += 4 * stride;
-    h += 4;
-  } while (h < height);
+    h -= 4;
+  } while (h != 0);
 
   return horizontal_add_u64x2(sum_squares);
 }
@@ -134,7 +134,7 @@
   int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
   int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) };
 
-  int h = 0;
+  int h = height;
   do {
     int16x4_t s0 = vld1_s16(src + 0 * stride);
     int16x4_t s1 = vld1_s16(src + 1 * stride);
@@ -152,8 +152,8 @@
     sum_acc[1] = vpadal_s16(sum_acc[1], s3);
 
     src += 4 * stride;
-    h += 4;
-  } while (h < height);
+    h -= 4;
+  } while (h != 0);
 
   *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1]));
   return horizontal_long_add_u32x4(
@@ -166,7 +166,7 @@
   uint64x2_t sse = vdupq_n_u64(0);
   int32x4_t sum_acc = vdupq_n_s32(0);
 
-  int h = 0;
+  int h = height;
   do {
     int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
     int w = 0;
@@ -198,8 +198,8 @@
                       vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1])));
 
     src += 4 * stride;
-    h += 4;
-  } while (h < height);
+    h -= 4;
+  } while (h != 0);
 
   *sum += horizontal_add_s32x4(sum_acc);
   return horizontal_add_u64x2(sse);
diff --git a/aom_dsp/arm/transpose_neon.h b/aom_dsp/arm/transpose_neon.h
index 26fc1fd..d151c58 100644
--- a/aom_dsp/arm/transpose_neon.h
+++ b/aom_dsp/arm/transpose_neon.h
@@ -258,13 +258,19 @@
   a[3] = vreinterpretq_u16_u32(c1.val[1]);
 }
 
-static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(const uint32x4_t a0,
-                                                const uint32x4_t a1) {
+static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(uint32x4_t a0, uint32x4_t a1) {
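+  // On AArch64, vtrn1q/vtrn2q transpose the 64-bit halves directly; the
+  // vcombine-based fallback below produces the same interleaving.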
   uint16x8x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_u16_u64(
+      vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+  b0.val[1] = vreinterpretq_u16_u64(
+      vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+#else
   b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
                            vreinterpret_u16_u32(vget_low_u32(a1)));
   b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
                            vreinterpret_u16_u32(vget_high_u32(a1)));
+#endif
   return b0;
 }
 
@@ -343,7 +349,7 @@
                                      uint16x4_t *a6, uint16x4_t *a7,
                                      uint16x8_t *o0, uint16x8_t *o1,
                                      uint16x8_t *o2, uint16x8_t *o3) {
-  // Swap 16 bit elements. Goes from:
+  // Combine rows. Goes from:
   // a0: 00 01 02 03
   // a1: 10 11 12 13
   // a2: 20 21 22 23
@@ -353,53 +359,40 @@
   // a6: 60 61 62 63
   // a7: 70 71 72 73
   // to:
-  // b0.val[0]: 00 10 02 12
-  // b0.val[1]: 01 11 03 13
-  // b1.val[0]: 20 30 22 32
-  // b1.val[1]: 21 31 23 33
-  // b2.val[0]: 40 50 42 52
-  // b2.val[1]: 41 51 43 53
-  // b3.val[0]: 60 70 62 72
-  // b3.val[1]: 61 71 63 73
+  // b0: 00 01 02 03 40 41 42 43
+  // b1: 10 11 12 13 50 51 52 53
+  // b2: 20 21 22 23 60 61 62 63
+  // b3: 30 31 32 33 70 71 72 73
 
-  uint16x4x2_t b0 = vtrn_u16(*a0, *a1);
-  uint16x4x2_t b1 = vtrn_u16(*a2, *a3);
-  uint16x4x2_t b2 = vtrn_u16(*a4, *a5);
-  uint16x4x2_t b3 = vtrn_u16(*a6, *a7);
+  const uint16x8_t b0 = vcombine_u16(*a0, *a4);
+  const uint16x8_t b1 = vcombine_u16(*a1, *a5);
+  const uint16x8_t b2 = vcombine_u16(*a2, *a6);
+  const uint16x8_t b3 = vcombine_u16(*a3, *a7);
+
+  // Swap 16 bit elements resulting in:
+  // c0.val[0]: 00 10 02 12 40 50 42 52
+  // c0.val[1]: 01 11 03 13 41 51 43 53
+  // c1.val[0]: 20 30 22 32 60 70 62 72
+  // c1.val[1]: 21 31 23 33 61 71 63 73
+
+  const uint16x8x2_t c0 = vtrnq_u16(b0, b1);
+  const uint16x8x2_t c1 = vtrnq_u16(b2, b3);
 
   // Swap 32 bit elements resulting in:
-  // c0.val[0]: 00 10 20 30
-  // c0.val[1]: 02 12 22 32
-  // c1.val[0]: 01 11 21 31
-  // c1.val[1]: 03 13 23 33
-  // c2.val[0]: 40 50 60 70
-  // c2.val[1]: 42 52 62 72
-  // c3.val[0]: 41 51 61 71
-  // c3.val[1]: 43 53 63 73
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 02 12 22 32 42 52 62 72
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 03 13 23 33 43 53 63 73
 
-  uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]),
-                             vreinterpret_u32_u16(b1.val[0]));
-  uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]),
-                             vreinterpret_u32_u16(b1.val[1]));
-  uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]),
-                             vreinterpret_u32_u16(b3.val[0]));
-  uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]),
-                             vreinterpret_u32_u16(b3.val[1]));
+  const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
+                                    vreinterpretq_u32_u16(c1.val[0]));
+  const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
+                                    vreinterpretq_u32_u16(c1.val[1]));
 
-  // Swap 64 bit elements resulting in:
-  // o0: 00 10 20 30 40 50 60 70
-  // o1: 01 11 21 31 41 51 61 71
-  // o2: 02 12 22 32 42 52 62 72
-  // o3: 03 13 23 33 43 53 63 73
-
-  *o0 = vcombine_u16(vreinterpret_u16_u32(c0.val[0]),
-                     vreinterpret_u16_u32(c2.val[0]));
-  *o1 = vcombine_u16(vreinterpret_u16_u32(c1.val[0]),
-                     vreinterpret_u16_u32(c3.val[0]));
-  *o2 = vcombine_u16(vreinterpret_u16_u32(c0.val[1]),
-                     vreinterpret_u16_u32(c2.val[1]));
-  *o3 = vcombine_u16(vreinterpret_u16_u32(c1.val[1]),
-                     vreinterpret_u16_u32(c3.val[1]));
+  *o0 = vreinterpretq_u16_u32(d0.val[0]);
+  *o1 = vreinterpretq_u16_u32(d1.val[0]);
+  *o2 = vreinterpretq_u16_u32(d0.val[1]);
+  *o3 = vreinterpretq_u16_u32(d1.val[1]);
 }
 
 static INLINE void transpose_s16_4x8(int16x4_t *a0, int16x4_t *a1,
@@ -408,7 +401,7 @@
                                      int16x4_t *a6, int16x4_t *a7,
                                      int16x8_t *o0, int16x8_t *o1,
                                      int16x8_t *o2, int16x8_t *o3) {
-  // Swap 16 bit elements. Goes from:
+  // Combine rows. Goes from:
   // a0: 00 01 02 03
   // a1: 10 11 12 13
   // a2: 20 21 22 23
@@ -418,53 +411,40 @@
   // a6: 60 61 62 63
   // a7: 70 71 72 73
   // to:
-  // b0.val[0]: 00 10 02 12
-  // b0.val[1]: 01 11 03 13
-  // b1.val[0]: 20 30 22 32
-  // b1.val[1]: 21 31 23 33
-  // b2.val[0]: 40 50 42 52
-  // b2.val[1]: 41 51 43 53
-  // b3.val[0]: 60 70 62 72
-  // b3.val[1]: 61 71 63 73
+  // b0: 00 01 02 03 40 41 42 43
+  // b1: 10 11 12 13 50 51 52 53
+  // b2: 20 21 22 23 60 61 62 63
+  // b3: 30 31 32 33 70 71 72 73
 
-  int16x4x2_t b0 = vtrn_s16(*a0, *a1);
-  int16x4x2_t b1 = vtrn_s16(*a2, *a3);
-  int16x4x2_t b2 = vtrn_s16(*a4, *a5);
-  int16x4x2_t b3 = vtrn_s16(*a6, *a7);
+  const int16x8_t b0 = vcombine_s16(*a0, *a4);
+  const int16x8_t b1 = vcombine_s16(*a1, *a5);
+  const int16x8_t b2 = vcombine_s16(*a2, *a6);
+  const int16x8_t b3 = vcombine_s16(*a3, *a7);
+
+  // Swap 16 bit elements resulting in:
+  // c0.val[0]: 00 10 02 12 40 50 42 52
+  // c0.val[1]: 01 11 03 13 41 51 43 53
+  // c1.val[0]: 20 30 22 32 60 70 62 72
+  // c1.val[1]: 21 31 23 33 61 71 63 73
+
+  const int16x8x2_t c0 = vtrnq_s16(b0, b1);
+  const int16x8x2_t c1 = vtrnq_s16(b2, b3);
 
   // Swap 32 bit elements resulting in:
-  // c0.val[0]: 00 10 20 30
-  // c0.val[1]: 02 12 22 32
-  // c1.val[0]: 01 11 21 31
-  // c1.val[1]: 03 13 23 33
-  // c2.val[0]: 40 50 60 70
-  // c2.val[1]: 42 52 62 72
-  // c3.val[0]: 41 51 61 71
-  // c3.val[1]: 43 53 63 73
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 02 12 22 32 42 52 62 72
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 03 13 23 33 43 53 63 73
 
-  int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]),
-                            vreinterpret_s32_s16(b1.val[0]));
-  int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]),
-                            vreinterpret_s32_s16(b1.val[1]));
-  int32x2x2_t c2 = vtrn_s32(vreinterpret_s32_s16(b2.val[0]),
-                            vreinterpret_s32_s16(b3.val[0]));
-  int32x2x2_t c3 = vtrn_s32(vreinterpret_s32_s16(b2.val[1]),
-                            vreinterpret_s32_s16(b3.val[1]));
+  const int32x4x2_t d0 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[0]),
+                                   vreinterpretq_s32_s16(c1.val[0]));
+  const int32x4x2_t d1 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[1]),
+                                   vreinterpretq_s32_s16(c1.val[1]));
 
-  // Swap 64 bit elements resulting in:
-  // o0: 00 10 20 30 40 50 60 70
-  // o1: 01 11 21 31 41 51 61 71
-  // o2: 02 12 22 32 42 52 62 72
-  // o3: 03 13 23 33 43 53 63 73
-
-  *o0 = vcombine_s16(vreinterpret_s16_s32(c0.val[0]),
-                     vreinterpret_s16_s32(c2.val[0]));
-  *o1 = vcombine_s16(vreinterpret_s16_s32(c1.val[0]),
-                     vreinterpret_s16_s32(c3.val[0]));
-  *o2 = vcombine_s16(vreinterpret_s16_s32(c0.val[1]),
-                     vreinterpret_s16_s32(c2.val[1]));
-  *o3 = vcombine_s16(vreinterpret_s16_s32(c1.val[1]),
-                     vreinterpret_s16_s32(c3.val[1]));
+  *o0 = vreinterpretq_s16_s32(d0.val[0]);
+  *o1 = vreinterpretq_s16_s32(d1.val[0]);
+  *o2 = vreinterpretq_s16_s32(d0.val[1]);
+  *o3 = vreinterpretq_s16_s32(d1.val[1]);
 }
 
 static INLINE void transpose_u16_8x8(uint16x8_t *a0, uint16x8_t *a1,
@@ -514,25 +494,45 @@
   const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
                                     vreinterpretq_u32_u16(b3.val[1]));
 
-  *a0 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[0])),
-                     vget_low_u16(vreinterpretq_u16_u32(c2.val[0])));
-  *a4 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[0])),
-                     vget_high_u16(vreinterpretq_u16_u32(c2.val[0])));
+  // Swap 64 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 04 14 24 34 44 54 64 74
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 05 15 25 35 45 55 65 75
+  // d2.val[0]: 02 12 22 32 42 52 62 72
+  // d2.val[1]: 06 16 26 36 46 56 66 76
+  // d3.val[0]: 03 13 23 33 43 53 63 73
+  // d3.val[1]: 07 17 27 37 47 57 67 77
 
-  *a2 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c0.val[1])),
-                     vget_low_u16(vreinterpretq_u16_u32(c2.val[1])));
-  *a6 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c0.val[1])),
-                     vget_high_u16(vreinterpretq_u16_u32(c2.val[1])));
+  const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c2.val[0]);
+  const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c3.val[0]);
+  const uint16x8x2_t d2 = aom_vtrnq_u64_to_u16(c0.val[1], c2.val[1]);
+  const uint16x8x2_t d3 = aom_vtrnq_u64_to_u16(c1.val[1], c3.val[1]);
 
-  *a1 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[0])),
-                     vget_low_u16(vreinterpretq_u16_u32(c3.val[0])));
-  *a5 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[0])),
-                     vget_high_u16(vreinterpretq_u16_u32(c3.val[0])));
+  *a0 = d0.val[0];
+  *a1 = d1.val[0];
+  *a2 = d2.val[0];
+  *a3 = d3.val[0];
+  *a4 = d0.val[1];
+  *a5 = d1.val[1];
+  *a6 = d2.val[1];
+  *a7 = d3.val[1];
+}
 
-  *a3 = vcombine_u16(vget_low_u16(vreinterpretq_u16_u32(c1.val[1])),
-                     vget_low_u16(vreinterpretq_u16_u32(c3.val[1])));
-  *a7 = vcombine_u16(vget_high_u16(vreinterpretq_u16_u32(c1.val[1])),
-                     vget_high_u16(vreinterpretq_u16_u32(c3.val[1])));
+static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
+  int16x8x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_s16_s64(
+      vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+  b0.val[1] = vreinterpretq_s16_s64(
+      vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+#else
+  b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
+                           vreinterpret_s16_s32(vget_low_s32(a1)));
+  b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
+                           vreinterpret_s16_s32(vget_high_s32(a1)));
+#endif
+  return b0;
 }
 
 static INLINE void transpose_s16_8x8(int16x8_t *a0, int16x8_t *a1,
@@ -582,37 +582,32 @@
   const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
                                    vreinterpretq_s32_s16(b3.val[1]));
 
-  *a0 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[0])),
-                     vget_low_s16(vreinterpretq_s16_s32(c2.val[0])));
-  *a4 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[0])),
-                     vget_high_s16(vreinterpretq_s16_s32(c2.val[0])));
+  // Swap 64 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 04 14 24 34 44 54 64 74
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 05 15 25 35 45 55 65 75
+  // d2.val[0]: 02 12 22 32 42 52 62 72
+  // d2.val[1]: 06 16 26 36 46 56 66 76
+  // d3.val[0]: 03 13 23 33 43 53 63 73
+  // d3.val[1]: 07 17 27 37 47 57 67 77
 
-  *a2 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c0.val[1])),
-                     vget_low_s16(vreinterpretq_s16_s32(c2.val[1])));
-  *a6 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c0.val[1])),
-                     vget_high_s16(vreinterpretq_s16_s32(c2.val[1])));
+  const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
+  const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
+  const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);
+  const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]);
 
-  *a1 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[0])),
-                     vget_low_s16(vreinterpretq_s16_s32(c3.val[0])));
-  *a5 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[0])),
-                     vget_high_s16(vreinterpretq_s16_s32(c3.val[0])));
-
-  *a3 = vcombine_s16(vget_low_s16(vreinterpretq_s16_s32(c1.val[1])),
-                     vget_low_s16(vreinterpretq_s16_s32(c3.val[1])));
-  *a7 = vcombine_s16(vget_high_s16(vreinterpretq_s16_s32(c1.val[1])),
-                     vget_high_s16(vreinterpretq_s16_s32(c3.val[1])));
+  *a0 = d0.val[0];
+  *a1 = d1.val[0];
+  *a2 = d2.val[0];
+  *a3 = d3.val[0];
+  *a4 = d0.val[1];
+  *a5 = d1.val[1];
+  *a6 = d2.val[1];
+  *a7 = d3.val[1];
 }
 
-static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
-  int16x8x2_t b0;
-  b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
-                           vreinterpret_s16_s32(vget_low_s32(a1)));
-  b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
-                           vreinterpret_s16_s32(vget_high_s32(a1)));
-  return b0;
-}
-
-static INLINE void transpose_s16_8x8q(int16x8_t *a0, int16x8_t *out) {
+static INLINE void transpose_s16_8x8q(int16x8_t *a, int16x8_t *out) {
   // Swap 16 bit elements. Goes from:
   // a0: 00 01 02 03 04 05 06 07
   // a1: 10 11 12 13 14 15 16 17
@@ -632,10 +627,10 @@
   // b3.val[0]: 60 70 62 72 64 74 66 76
   // b3.val[1]: 61 71 63 73 65 75 67 77
 
-  const int16x8x2_t b0 = vtrnq_s16(*a0, *(a0 + 1));
-  const int16x8x2_t b1 = vtrnq_s16(*(a0 + 2), *(a0 + 3));
-  const int16x8x2_t b2 = vtrnq_s16(*(a0 + 4), *(a0 + 5));
-  const int16x8x2_t b3 = vtrnq_s16(*(a0 + 6), *(a0 + 7));
+  const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]);
+  const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]);
+  const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]);
+  const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]);
 
   // Swap 32 bit elements resulting in:
   // c0.val[0]: 00 10 20 30 04 14 24 34
@@ -665,19 +660,53 @@
   // d2.val[1]: 06 16 26 36 46 56 66 76
   // d3.val[0]: 03 13 23 33 43 53 63 73
   // d3.val[1]: 07 17 27 37 47 57 67 77
+
   const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]);
   const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]);
   const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]);
   const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]);
 
-  *out = d0.val[0];
-  *(out + 1) = d1.val[0];
-  *(out + 2) = d2.val[0];
-  *(out + 3) = d3.val[0];
-  *(out + 4) = d0.val[1];
-  *(out + 5) = d1.val[1];
-  *(out + 6) = d2.val[1];
-  *(out + 7) = d3.val[1];
+  out[0] = d0.val[0];
+  out[1] = d1.val[0];
+  out[2] = d2.val[0];
+  out[3] = d3.val[0];
+  out[4] = d0.val[1];
+  out[5] = d1.val[1];
+  out[6] = d2.val[1];
+  out[7] = d3.val[1];
+}
+
+static INLINE void transpose_u16_4x4d(uint16x4_t *a0, uint16x4_t *a1,
+                                      uint16x4_t *a2, uint16x4_t *a3) {
+  // Swap 16 bit elements. Goes from:
+  // a0: 00 01 02 03
+  // a1: 10 11 12 13
+  // a2: 20 21 22 23
+  // a3: 30 31 32 33
+  // to:
+  // b0.val[0]: 00 10 02 12
+  // b0.val[1]: 01 11 03 13
+  // b1.val[0]: 20 30 22 32
+  // b1.val[1]: 21 31 23 33
+
+  const uint16x4x2_t b0 = vtrn_u16(*a0, *a1);
+  const uint16x4x2_t b1 = vtrn_u16(*a2, *a3);
+
+  // Swap 32 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30
+  // c0.val[1]: 02 12 22 32
+  // c1.val[0]: 01 11 21 31
+  // c1.val[1]: 03 13 23 33
+
+  const uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]),
+                                   vreinterpret_u32_u16(b1.val[0]));
+  const uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]),
+                                   vreinterpret_u32_u16(b1.val[1]));
+
+  *a0 = vreinterpret_u16_u32(c0.val[0]);
+  *a1 = vreinterpret_u16_u32(c1.val[0]);
+  *a2 = vreinterpret_u16_u32(c0.val[1]);
+  *a3 = vreinterpret_u16_u32(c1.val[1]);
 }
 
 static INLINE void transpose_s16_4x4d(int16x4_t *a0, int16x4_t *a1,
@@ -715,8 +744,15 @@
 
 static INLINE int32x4x2_t aom_vtrnq_s64_to_s32(int32x4_t a0, int32x4_t a1) {
   int32x4x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_s32_s64(
+      vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+  b0.val[1] = vreinterpretq_s32_s64(
+      vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+#else
   b0.val[0] = vcombine_s32(vget_low_s32(a0), vget_low_s32(a1));
   b0.val[1] = vcombine_s32(vget_high_s32(a0), vget_high_s32(a1));
+#endif
   return b0;
 }
 
diff --git a/aom_dsp/arm/variance_neon.c b/aom_dsp/arm/variance_neon.c
index 40e40f0..5e33996 100644
--- a/aom_dsp/arm/variance_neon.c
+++ b/aom_dsp/arm/variance_neon.c
@@ -27,7 +27,7 @@
   uint32x4_t ref_sum = vdupq_n_u32(0);
   uint32x4_t sse_u32 = vdupq_n_u32(0);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s = load_unaligned_u8q(src, src_stride);
     uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
@@ -40,8 +40,8 @@
 
     src += 4 * src_stride;
     ref += 4 * ref_stride;
-    i += 4;
-  } while (i < h);
+    i -= 4;
+  } while (i != 0);
 
   int32x4_t sum_diff =
       vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -56,7 +56,7 @@
   uint32x4_t ref_sum = vdupq_n_u32(0);
   uint32x4_t sse_u32 = vdupq_n_u32(0);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
     uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
@@ -69,8 +69,8 @@
 
     src += 2 * src_stride;
     ref += 2 * ref_stride;
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   int32x4_t sum_diff =
       vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -85,7 +85,7 @@
   uint32x4_t ref_sum = vdupq_n_u32(0);
   uint32x4_t sse_u32 = vdupq_n_u32(0);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s = vld1q_u8(src);
     uint8x16_t r = vld1q_u8(ref);
@@ -98,8 +98,7 @@
 
     src += src_stride;
     ref += ref_stride;
-    i++;
-  } while (i < h);
+  } while (--i != 0);
 
   int32x4_t sum_diff =
       vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -114,7 +113,7 @@
   uint32x4_t ref_sum = vdupq_n_u32(0);
   uint32x4_t sse_u32 = vdupq_n_u32(0);
 
-  int i = 0;
+  int i = h;
   do {
     int j = 0;
     do {
@@ -132,8 +131,7 @@
 
     src += src_stride;
     ref += ref_stride;
-    i++;
-  } while (i < h);
+  } while (--i != 0);
 
   int32x4_t sum_diff =
       vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
@@ -171,7 +169,7 @@
   // 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows.
   assert(h <= 256);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x8_t s = load_unaligned_u8(src, src_stride);
     uint8x8_t r = load_unaligned_u8(ref, ref_stride);
@@ -184,8 +182,8 @@
 
     src += 2 * src_stride;
     ref += 2 * ref_stride;
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   *sum = horizontal_add_s16x8(sum_s16);
   *sse = (uint32_t)horizontal_add_s32x4(sse_s32);
@@ -201,7 +199,7 @@
   // 32767 / 255 ~= 128
   assert(h <= 128);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x8_t s = vld1_u8(src);
     uint8x8_t r = vld1_u8(ref);
@@ -215,8 +213,7 @@
 
     src += src_stride;
     ref += ref_stride;
-    i++;
-  } while (i < h);
+  } while (--i != 0);
 
   *sum = horizontal_add_s16x8(sum_s16);
   *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
@@ -232,7 +229,7 @@
   // 32767 / 255 ~= 128, so 128 16-wide rows.
   assert(h <= 128);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s = vld1q_u8(src);
     uint8x16_t r = vld1q_u8(ref);
@@ -256,8 +253,7 @@
 
     src += src_stride;
     ref += ref_stride;
-    i++;
-  } while (i < h);
+  } while (--i != 0);
 
   *sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1]));
   *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
@@ -378,17 +374,6 @@
 
 #undef VARIANCE_WXH_NEON
 
-void aom_get8x8var_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
-                        int ref_stride, unsigned int *sse, int *sum) {
-  variance_8xh_neon(src, src_stride, ref, ref_stride, 8, sse, sum);
-}
-
-void aom_get16x16var_neon(const uint8_t *src, int src_stride,
-                          const uint8_t *ref, int ref_stride, unsigned int *sse,
-                          int *sum) {
-  variance_16xh_neon(src, src_stride, ref, ref_stride, 16, sse, sum);
-}
-
 // TODO(yunqingwang): Perform variance of two/four 8x8 blocks similar to that of
 // AVX2. Also, implement the NEON for variance computation present in this
 // function.
@@ -409,6 +394,25 @@
     var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6);
 }
 
+void aom_get_var_sse_sum_16x16_dual_neon(const uint8_t *src, int src_stride,
+                                         const uint8_t *ref, int ref_stride,
+                                         uint32_t *sse16x16,
+                                         unsigned int *tot_sse, int *tot_sum,
+                                         uint32_t *var16x16) {
+  int sum16x16[2] = { 0 };
+  // Loop over two horizontally adjacent 16x16 blocks.
+  for (int k = 0; k < 2; k++) {
+    variance_16xh_neon(src + (k * 16), src_stride, ref + (k * 16), ref_stride,
+                       16, &sse16x16[k], &sum16x16[k]);
+  }
+
+  *tot_sse += sse16x16[0] + sse16x16[1];
+  *tot_sum += sum16x16[0] + sum16x16[1];
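+  // variance = sse - sum^2 / (16 * 16), hence the >> 8 below.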
+  for (int i = 0; i < 2; i++)
+    var16x16[i] =
+        sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8);
+}
+
 #if defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE unsigned int mse8xh_neon(const uint8_t *src, int src_stride,
@@ -416,7 +420,7 @@
                                        unsigned int *sse, int h) {
   uint32x4_t sse_u32 = vdupq_n_u32(0);
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
     uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
@@ -427,8 +431,8 @@
 
     src += 2 * src_stride;
     ref += 2 * ref_stride;
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   *sse = horizontal_add_u32x4(sse_u32);
   return horizontal_add_u32x4(sse_u32);
@@ -439,7 +443,7 @@
                                         unsigned int *sse, int h) {
   uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
 
-  int i = 0;
+  int i = h;
   do {
     uint8x16_t s0 = vld1q_u8(src);
     uint8x16_t s1 = vld1q_u8(src + src_stride);
@@ -454,25 +458,13 @@
 
     src += 2 * src_stride;
     ref += 2 * ref_stride;
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
   return horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
 }
 
-unsigned int aom_get4x4sse_cs_neon(const uint8_t *src, int src_stride,
-                                   const uint8_t *ref, int ref_stride) {
-  uint8x16_t s = load_unaligned_u8q(src, src_stride);
-  uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
-
-  uint8x16_t abs_diff = vabdq_u8(s, r);
-
-  uint32x4_t sse = vdotq_u32(vdupq_n_u32(0), abs_diff, abs_diff);
-
-  return horizontal_add_u32x4(sse);
-}
-
 #else  // !defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE unsigned int mse8xh_neon(const uint8_t *src, int src_stride,
@@ -483,7 +475,7 @@
   uint16x8_t diff[2];
   int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 
-  int i = 0;
+  int i = h;
   do {
     s[0] = vld1_u8(src);
     src += src_stride;
@@ -507,8 +499,8 @@
     sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
     sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);
 
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
 
@@ -525,7 +517,7 @@
   int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                            vdupq_n_s32(0) };
 
-  int i = 0;
+  int i = h;
   do {
     s[0] = vld1q_u8(src);
     src += src_stride;
@@ -561,8 +553,8 @@
     sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]);
     sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]);
 
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
 
   sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
   sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]);
@@ -572,40 +564,6 @@
   return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
 }
 
-unsigned int aom_get4x4sse_cs_neon(const uint8_t *src, int src_stride,
-                                   const uint8_t *ref, int ref_stride) {
-  uint8x8_t s[4], r[4];
-  int16x4_t diff[4];
-  int32x4_t sse;
-
-  s[0] = vld1_u8(src);
-  src += src_stride;
-  r[0] = vld1_u8(ref);
-  ref += ref_stride;
-  s[1] = vld1_u8(src);
-  src += src_stride;
-  r[1] = vld1_u8(ref);
-  ref += ref_stride;
-  s[2] = vld1_u8(src);
-  src += src_stride;
-  r[2] = vld1_u8(ref);
-  ref += ref_stride;
-  s[3] = vld1_u8(src);
-  r[3] = vld1_u8(ref);
-
-  diff[0] = vget_low_s16(vreinterpretq_s16_u16(vsubl_u8(s[0], r[0])));
-  diff[1] = vget_low_s16(vreinterpretq_s16_u16(vsubl_u8(s[1], r[1])));
-  diff[2] = vget_low_s16(vreinterpretq_s16_u16(vsubl_u8(s[2], r[2])));
-  diff[3] = vget_low_s16(vreinterpretq_s16_u16(vsubl_u8(s[3], r[3])));
-
-  sse = vmull_s16(diff[0], diff[0]);
-  sse = vmlal_s16(sse, diff[1], diff[1]);
-  sse = vmlal_s16(sse, diff[2], diff[2]);
-  sse = vmlal_s16(sse, diff[3], diff[3]);
-
-  return horizontal_add_u32x4(vreinterpretq_u32_s32(sse));
-}
-
 #endif  // defined(__ARM_FEATURE_DOTPROD)
 
 #define MSE_WXH_NEON(w, h)                                                 \
@@ -647,7 +605,7 @@
                                               int h) {
   uint64x2_t square_result = vdupq_n_u64(0);
   uint32_t d0, d1;
-  int i = 0;
+  int i = h;
   uint8_t *dst_ptr = dst;
   uint16_t *src_ptr = src;
   do {
@@ -678,8 +636,8 @@
     const uint16x8_t src_16x8 = vcombine_u16(src0_16x4, src1_16x4);
 
     COMPUTE_MSE_16BIT(src_16x8, dst_16x8)
-    i += 2;
-  } while (i < h);
+    i -= 2;
+  } while (i != 0);
   uint64x1_t sum =
       vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
   return vget_lane_u64(sum, 0);
@@ -689,16 +647,18 @@
                                               uint16_t *src, int sstride,
                                               int h) {
   uint64x2_t square_result = vdupq_n_u64(0);
-  int i = 0;
+  int i = h;
   do {
     // d7 d6 d5 d4 d3 d2 d1 d0 - 8 bit
-    const uint16x8_t dst_16x8 = vmovl_u8(vld1_u8(&dst[i * dstride]));
+    const uint16x8_t dst_16x8 = vmovl_u8(vld1_u8(dst));
     // s7 s6 s5 s4 s3 s2 s1 s0 - 16 bit
-    const uint16x8_t src_16x8 = vld1q_u16(&src[i * sstride]);
+    const uint16x8_t src_16x8 = vld1q_u16(src);
 
     COMPUTE_MSE_16BIT(src_16x8, dst_16x8)
-    i++;
-  } while (i < h);
+
+    dst += dstride;
+    src += sstride;
+  } while (--i != 0);
   uint64x1_t sum =
       vadd_u64(vget_high_u64(square_result), vget_low_u64(square_result));
   return vget_lane_u64(sum, 0);
diff --git a/aom_dsp/avg.c b/aom_dsp/avg.c
index ceb1026..e6eab47 100644
--- a/aom_dsp/avg.c
+++ b/aom_dsp/avg.c
@@ -87,7 +87,7 @@
   int i, j;
   const uint16_t *s = CONVERT_TO_SHORTPTR(s8);
   const uint16_t *d = CONVERT_TO_SHORTPTR(d8);
-  *min = 255;
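+  // High bitdepth pixel values can exceed 255, so seed the running minimum
+  // with the largest possible 16-bit value.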
+  *min = 65535;
   *max = 0;
   for (i = 0; i < 8; ++i, s += p, d += dp) {
     for (j = 0; j < 8; ++j) {
@@ -99,14 +99,6 @@
 }
 #endif  // CONFIG_AV1_HIGHBITDEPTH
 
-void aom_pixel_scale_c(const int16_t *src_diff, ptrdiff_t src_stride,
-                       int16_t *coeff, int log_scale, int h8, int w8) {
-  for (int idy = 0; idy < h8 * 8; ++idy)
-    for (int idx = 0; idx < w8 * 8; ++idx)
-      coeff[idy * (h8 * 8) + idx] = src_diff[idy * src_stride + idx]
-                                    << log_scale;
-}
-
 static void hadamard_col4(const int16_t *src_diff, ptrdiff_t src_stride,
                           int16_t *coeff) {
   int16_t b0 = (src_diff[0 * src_stride] + src_diff[1 * src_stride]) >> 1;
diff --git a/aom_dsp/entdec.c b/aom_dsp/entdec.c
index da43e8a..5bbcdda 100644
--- a/aom_dsp/entdec.c
+++ b/aom_dsp/entdec.c
@@ -205,14 +205,14 @@
   assert(dif >> (OD_EC_WINDOW_SIZE - 16) < r);
   assert(icdf[nsyms - 1] == OD_ICDF(CDF_PROB_TOP));
   assert(32768U <= r);
-  assert(7 - EC_PROB_SHIFT - CDF_SHIFT >= 0);
+  assert(7 - EC_PROB_SHIFT >= 0);
   c = (unsigned)(dif >> (OD_EC_WINDOW_SIZE - 16));
   v = r;
   ret = -1;
   do {
     u = v;
     v = ((r >> 8) * (uint32_t)(icdf[++ret] >> EC_PROB_SHIFT) >>
-         (7 - EC_PROB_SHIFT - CDF_SHIFT));
+         (7 - EC_PROB_SHIFT));
     v += EC_MIN_PROB * (N - ret);
   } while (c < v);
   assert(v < u);
diff --git a/aom_dsp/entenc.c b/aom_dsp/entenc.c
index 2fd4493..dfc1624 100644
--- a/aom_dsp/entenc.c
+++ b/aom_dsp/entenc.c
@@ -49,11 +49,11 @@
   }*/
 
 /*Takes updated low and range values, renormalizes them so that
-   32768 <= rng < 65536 (flushing bytes from low to the pre-carry buffer if
+   32768 <= rng < 65536 (flushing bytes from low to the output buffer if
    necessary), and stores them back in the encoder context.
   low: The new value of low.
   rng: The new value of the range.*/
-static void od_ec_enc_normalize(od_ec_enc *enc, od_ec_window low,
+static void od_ec_enc_normalize(od_ec_enc *enc, od_ec_enc_window low,
                                 unsigned rng) {
   int d;
   int c;
@@ -63,44 +63,59 @@
   /*The number of leading zeros in the 16-bit binary representation of rng.*/
   d = 16 - OD_ILOG_NZ(rng);
   s = c + d;
-  /*TODO: Right now we flush every time we have at least one byte available.
-    Instead we should use an od_ec_window and flush right before we're about to
-     shift bits off the end of the window.
-    For a 32-bit window this is about the same amount of work, but for a 64-bit
-     window it should be a fair win.*/
-  if (s >= 0) {
-    uint16_t *buf;
-    uint32_t storage;
-    uint32_t offs;
-    unsigned m;
-    buf = enc->precarry_buf;
-    storage = enc->precarry_storage;
-    offs = enc->offs;
-    if (offs + 2 > storage) {
-      storage = 2 * storage + 2;
-      buf = (uint16_t *)realloc(buf, sizeof(*buf) * storage);
-      if (buf == NULL) {
+
+  /* We flush every time "low" cannot safely and efficiently accommodate any
+     more data. Overall, c must not exceed 63 at the time of byte flush out. To
+     facilitate this, "s" cannot exceed 56-bits because we have to keep 1 byte
+     for carry. Also, we need to subtract 16 because we want to keep room for
+     the next symbol worth "d"-bits (max 15). An alternate condition would be if
+     (e < d), where e = number of leading zeros in "low", indicating there is
+     not enough rooom to accommodate "rng" worth of "d"-bits in "low". However,
+     this approach needs additional computations: (i) compute "e", (ii) push
+     the leading 0x00's as a special case.
+  */
+  if (s >= 40) {  // 56 - 16
+    unsigned char *out = enc->buf;
+    uint32_t storage = enc->storage;
+    uint32_t offs = enc->offs;
+    if (offs + 8 > storage) {
+      storage = 2 * storage + 8;
+      out = (unsigned char *)realloc(out, sizeof(*out) * storage);
+      if (out == NULL) {
         enc->error = -1;
         enc->offs = 0;
         return;
       }
-      enc->precarry_buf = buf;
-      enc->precarry_storage = storage;
+      enc->buf = out;
+      enc->storage = storage;
     }
-    c += 16;
-    m = (1 << c) - 1;
-    if (s >= 8) {
-      assert(offs < storage);
-      buf[offs++] = (uint16_t)(low >> c);
-      low &= m;
-      c -= 8;
-      m >>= 8;
-    }
-    assert(offs < storage);
-    buf[offs++] = (uint16_t)(low >> c);
+    // Need to add 1 byte here since enc->cnt always counts 1 byte less
+    // (enc->cnt = -9) to ensure correct operation
+    uint8_t num_bytes_ready = (s >> 3) + 1;
+
+    // Update "c" to contain the number of non-ready bits in "low". Since "low"
+    // has 64-bit capacity, we need to add the (64 - 40) cushion bits and take
+    // off the number of ready bits.
+    c += 24 - (num_bytes_ready << 3);
+
+    // Prepare "output" and update "low"
+    uint64_t output = low >> c;
+    low = low & (((uint64_t)1 << c) - 1);
+
+    // Prepare data and carry mask
+    uint64_t mask = (uint64_t)1 << (num_bytes_ready << 3);
+    uint64_t carry = output & mask;
+
+    mask = mask - 0x01;
+    output = output & mask;
+
+    // Write data in a single operation
+    write_enc_data_to_out_buf(out, offs, output, carry, &enc->offs,
+                              num_bytes_ready);
+
+    // Update state of the encoder: enc->cnt to contain the number of residual
+    // bits
     s = c + d - 24;
-    low &= m;
-    enc->offs = offs;
   }
   enc->low = low << d;
   enc->rng = rng << d;
@@ -117,12 +132,6 @@
     enc->storage = 0;
     enc->error = -1;
   }
-  enc->precarry_buf = (uint16_t *)malloc(sizeof(*enc->precarry_buf) * size);
-  enc->precarry_storage = size;
-  if (size > 0 && enc->precarry_buf == NULL) {
-    enc->precarry_storage = 0;
-    enc->error = -1;
-  }
 }
 
 /*Reinitializes the encoder.*/
@@ -141,21 +150,16 @@
 }
 
 /*Frees the buffers used by the encoder.*/
-void od_ec_enc_clear(od_ec_enc *enc) {
-  free(enc->precarry_buf);
-  free(enc->buf);
-}
+void od_ec_enc_clear(od_ec_enc *enc) { free(enc->buf); }
 
 /*Encodes a symbol given its frequency in Q15.
   fl: CDF_PROB_TOP minus the cumulative frequency of all symbols that come
-  before the
-       one to be encoded.
+  before the one to be encoded.
   fh: CDF_PROB_TOP minus the cumulative frequency of all symbols up to and
-  including
-       the one to be encoded.*/
+  including the one to be encoded.*/
 static void od_ec_encode_q15(od_ec_enc *enc, unsigned fl, unsigned fh, int s,
                              int nsyms) {
-  od_ec_window l;
+  od_ec_enc_window l;
   unsigned r;
   unsigned u;
   unsigned v;
@@ -164,20 +168,17 @@
   assert(32768U <= r);
   assert(fh <= fl);
   assert(fl <= 32768U);
-  assert(7 - EC_PROB_SHIFT - CDF_SHIFT >= 0);
+  assert(7 - EC_PROB_SHIFT >= 0);
   const int N = nsyms - 1;
   if (fl < CDF_PROB_TOP) {
-    u = ((r >> 8) * (uint32_t)(fl >> EC_PROB_SHIFT) >>
-         (7 - EC_PROB_SHIFT - CDF_SHIFT)) +
+    u = ((r >> 8) * (uint32_t)(fl >> EC_PROB_SHIFT) >> (7 - EC_PROB_SHIFT)) +
         EC_MIN_PROB * (N - (s - 1));
-    v = ((r >> 8) * (uint32_t)(fh >> EC_PROB_SHIFT) >>
-         (7 - EC_PROB_SHIFT - CDF_SHIFT)) +
+    v = ((r >> 8) * (uint32_t)(fh >> EC_PROB_SHIFT) >> (7 - EC_PROB_SHIFT)) +
         EC_MIN_PROB * (N - (s + 0));
     l += r - u;
     r = u - v;
   } else {
-    r -= ((r >> 8) * (uint32_t)(fh >> EC_PROB_SHIFT) >>
-          (7 - EC_PROB_SHIFT - CDF_SHIFT)) +
+    r -= ((r >> 8) * (uint32_t)(fh >> EC_PROB_SHIFT) >> (7 - EC_PROB_SHIFT)) +
          EC_MIN_PROB * (N - (s + 0));
   }
   od_ec_enc_normalize(enc, l, r);
@@ -191,7 +192,7 @@
   val: The value to encode (0 or 1).
   f: The probability that the val is one, scaled by 32768.*/
 void od_ec_encode_bool_q15(od_ec_enc *enc, int val, unsigned f) {
-  od_ec_window l;
+  od_ec_enc_window l;
   unsigned r;
   unsigned v;
   assert(0 < f);
@@ -251,12 +252,11 @@
   mask = ((1U << nbits) - 1) << shift;
   if (enc->offs > 0) {
     /*The first byte has been finalized.*/
-    enc->precarry_buf[0] =
-        (uint16_t)((enc->precarry_buf[0] & ~mask) | val << shift);
+    enc->buf[0] = (unsigned char)((enc->buf[0] & ~mask) | val << shift);
   } else if (9 + enc->cnt + (enc->rng == 0x8000) > nbits) {
     /*The first byte has yet to be output.*/
-    enc->low = (enc->low & ~((od_ec_window)mask << (16 + enc->cnt))) |
-               (od_ec_window)val << (16 + enc->cnt + shift);
+    enc->low = (enc->low & ~((od_ec_enc_window)mask << (16 + enc->cnt))) |
+               (od_ec_enc_window)val << (16 + enc->cnt + shift);
   } else {
     /*The encoder hasn't even encoded _nbits of data yet.*/
     enc->error = -1;
@@ -276,11 +276,10 @@
 unsigned char *od_ec_enc_done(od_ec_enc *enc, uint32_t *nbytes) {
   unsigned char *out;
   uint32_t storage;
-  uint16_t *buf;
   uint32_t offs;
-  od_ec_window m;
-  od_ec_window e;
-  od_ec_window l;
+  od_ec_enc_window m;
+  od_ec_enc_window e;
+  od_ec_enc_window l;
   int c;
   int s;
   if (enc->error) return NULL;
@@ -295,8 +294,7 @@
             (double)tell / enc->nb_symbols);
   }
 #endif
-  /*We output the minimum number of bits that ensures that the symbols encoded
-     thus far will be decoded correctly regardless of the bits that follow.*/
+
   l = enc->low;
   c = enc->cnt;
   s = 10;
@@ -304,36 +302,14 @@
   e = ((l + m) & ~m) | (m + 1);
   s += c;
   offs = enc->offs;
-  buf = enc->precarry_buf;
-  if (s > 0) {
-    unsigned n;
-    storage = enc->precarry_storage;
-    if (offs + ((s + 7) >> 3) > storage) {
-      storage = storage * 2 + ((s + 7) >> 3);
-      buf = (uint16_t *)realloc(buf, sizeof(*buf) * storage);
-      if (buf == NULL) {
-        enc->error = -1;
-        return NULL;
-      }
-      enc->precarry_buf = buf;
-      enc->precarry_storage = storage;
-    }
-    n = (1 << (c + 16)) - 1;
-    do {
-      assert(offs < storage);
-      buf[offs++] = (uint16_t)(e >> (c + 16));
-      e &= n;
-      s -= 8;
-      c -= 8;
-      n >>= 8;
-    } while (s > 0);
-  }
+
   /*Make sure there's enough room for the entropy-coded bits.*/
   out = enc->buf;
   storage = enc->storage;
-  c = OD_MAXI((s + 7) >> 3, 0);
-  if (offs + c > storage) {
-    storage = offs + c;
+  const int s_bits = (s + 7) >> 3;
+  int b = OD_MAXI(s_bits, 0);
+  if (offs + b > storage) {
+    storage = offs + b;
     out = (unsigned char *)realloc(out, sizeof(*out) * storage);
     if (out == NULL) {
       enc->error = -1;
@@ -342,23 +318,30 @@
     enc->buf = out;
     enc->storage = storage;
   }
-  *nbytes = offs;
-  /*Perform carry propagation.*/
-  assert(offs <= storage);
-  out = out + storage - offs;
-  c = 0;
-  while (offs > 0) {
-    offs--;
-    c = buf[offs] + c;
-    out[offs] = (unsigned char)c;
-    c >>= 8;
+
+  /*We output the minimum number of bits that ensures that the symbols encoded
+     thus far will be decoded correctly regardless of the bits that follow.*/
+  if (s > 0) {
+    uint64_t n;
+    n = ((uint64_t)1 << (c + 16)) - 1;
+    do {
+      assert(offs < storage);
+      uint16_t val = (uint16_t)(e >> (c + 16));
+      out[offs] = (unsigned char)(val & 0x00FF);
+      if (val & 0x0100) {
+        assert(offs > 0);
+        propagate_carry_bwd(out, offs - 1);
+      }
+      offs++;
+
+      e &= n;
+      s -= 8;
+      c -= 8;
+      n >>= 8;
+    } while (s > 0);
   }
-  /*Note: Unless there's an allocation error, if you keep encoding into the
-     current buffer and call this function again later, everything will work
-     just fine (you won't get a new packet out, but you will get a single
-     buffer with the new data appended to the old).
-    However, this function is O(N) where N is the amount of data coded so far,
-     so calling it more than once for a given packet is a bad idea.*/
+  *nbytes = offs;
+
   return out;
 }
 
@@ -407,17 +390,10 @@
 void od_ec_enc_rollback(od_ec_enc *dst, const od_ec_enc *src) {
   unsigned char *buf;
   uint32_t storage;
-  uint16_t *precarry_buf;
-  uint32_t precarry_storage;
   assert(dst->storage >= src->storage);
-  assert(dst->precarry_storage >= src->precarry_storage);
   buf = dst->buf;
   storage = dst->storage;
-  precarry_buf = dst->precarry_buf;
-  precarry_storage = dst->precarry_storage;
   OD_COPY(dst, src, 1);
   dst->buf = buf;
   dst->storage = storage;
-  dst->precarry_buf = precarry_buf;
-  dst->precarry_storage = precarry_storage;
 }
diff --git a/aom_dsp/entenc.h b/aom_dsp/entenc.h
index 3551d42..467e47b 100644
--- a/aom_dsp/entenc.h
+++ b/aom_dsp/entenc.h
@@ -13,11 +13,14 @@
 #define AOM_AOM_DSP_ENTENC_H_
 #include <stddef.h>
 #include "aom_dsp/entcode.h"
+#include "aom_ports/bitops.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+typedef uint64_t od_ec_enc_window;
+
 typedef struct od_ec_enc od_ec_enc;
 
 #define OD_MEASURE_EC_OVERHEAD (0)
@@ -30,14 +33,10 @@
   unsigned char *buf;
   /*The size of the buffer.*/
   uint32_t storage;
-  /*A buffer for output bytes with their associated carry flags.*/
-  uint16_t *precarry_buf;
-  /*The size of the pre-carry buffer.*/
-  uint32_t precarry_storage;
   /*The offset at which the next entropy-coded byte will be written.*/
   uint32_t offs;
   /*The low end of the current range.*/
-  od_ec_window low;
+  od_ec_enc_window low;
   /*The number of values in the current range.*/
   uint16_t rng;
   /*The number of bits of data in the current value.*/
@@ -78,6 +77,32 @@
 void od_ec_enc_checkpoint(od_ec_enc *dst, const od_ec_enc *src);
 void od_ec_enc_rollback(od_ec_enc *dst, const od_ec_enc *src);
 
+// buf is the frame bitbuffer, offs is where carry to be added
+static AOM_INLINE void propagate_carry_bwd(unsigned char *buf, uint32_t offs) {
+  uint16_t sum, carry = 1;
+  do {
+    sum = (uint16_t)buf[offs] + 1;
+    buf[offs--] = (unsigned char)sum;
+    carry = sum >> 8;
+  } while (carry);
+}
+
+// Reverse byte order and write data to buffer adding the carry-bit
+static AOM_INLINE void write_enc_data_to_out_buf(unsigned char *out,
+                                                 uint32_t offs, uint64_t output,
+                                                 uint64_t carry,
+                                                 uint32_t *enc_offs,
+                                                 uint8_t num_bytes_ready) {
+  const uint64_t reg = get_byteswap64(output) >> ((8 - num_bytes_ready) << 3);
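+  // The store below always writes 8 bytes; only the first num_bytes_ready of
+  // them are kept (enc_offs advances by that amount), and the caller reserves
+  // at least 8 spare bytes past offs (see od_ec_enc_normalize).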
+  memcpy(&out[offs], &reg, 8);
+  // Propagate carry backwards if exists
+  if (carry) {
+    assert(offs > 0);
+    propagate_carry_bwd(out, offs - 1);
+  }
+  *enc_offs = offs + num_bytes_ready;
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/aom_dsp/flow_estimation/corner_detect.c b/aom_dsp/flow_estimation/corner_detect.c
index c49e3fa..7848295 100644
--- a/aom_dsp/flow_estimation/corner_detect.c
+++ b/aom_dsp/flow_estimation/corner_detect.c
@@ -17,21 +17,149 @@
 
 #include "third_party/fastfeat/fast.h"
 
+#include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/flow_estimation/corner_detect.h"
+#include "aom_mem/aom_mem.h"
+#include "av1/common/common.h"
 
-// Fast_9 wrapper
 #define FAST_BARRIER 18
-int av1_fast_corner_detect(unsigned char *buf, int width, int height,
-                           int stride, int *points, int max_points) {
-  int num_points;
-  xy *const frm_corners_xy = aom_fast9_detect_nonmax(buf, width, height, stride,
-                                                     FAST_BARRIER, &num_points);
-  num_points = (num_points <= max_points ? num_points : max_points);
-  if (num_points > 0 && frm_corners_xy) {
-    memcpy(points, frm_corners_xy, sizeof(*frm_corners_xy) * num_points);
-    free(frm_corners_xy);
-    return num_points;
+
+size_t av1_get_corner_list_size() { return sizeof(CornerList); }
+
+CornerList *av1_alloc_corner_list() {
+  CornerList *corners = (CornerList *)aom_calloc(1, sizeof(CornerList));
+  if (!corners) {
+    return NULL;
   }
-  free(frm_corners_xy);
-  return 0;
+
+  corners->valid = false;
+#if CONFIG_MULTITHREAD
+  pthread_mutex_init(&corners->mutex, NULL);
+#endif  // CONFIG_MULTITHREAD
+  return corners;
+}
+
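+// Detect FAST features on the first pyramid layer (layers[0]) and keep at
+// most MAX_CORNERS of them, preferring the highest-scoring corners when the
+// detector returns more than that.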
+void compute_corner_list(const ImagePyramid *pyr, CornerList *corners) {
+  const uint8_t *buf = pyr->layers[0].buffer;
+  int width = pyr->layers[0].width;
+  int height = pyr->layers[0].height;
+  int stride = pyr->layers[0].stride;
+
+  int *scores = NULL;
+  int num_corners;
+  xy *const frame_corners_xy = aom_fast9_detect_nonmax(
+      buf, width, height, stride, FAST_BARRIER, &scores, &num_corners);
+
+  if (num_corners <= 0) {
+    // Some error occurred, so no corners are available
+    corners->num_corners = 0;
+  } else if (num_corners <= MAX_CORNERS) {
+    // Use all detected corners
+    memcpy(corners->corners, frame_corners_xy,
+           sizeof(*frame_corners_xy) * num_corners);
+    corners->num_corners = num_corners;
+  } else {
+    // There are more than MAX_CORNERS corners available, so pick out a subset
+    // of the sharpest corners, as these will be the most useful for flow
+    // estimation
+    int histogram[256];
+    av1_zero(histogram);
+    for (int i = 0; i < num_corners; i++) {
+      assert(FAST_BARRIER <= scores[i] && scores[i] <= 255);
+      histogram[scores[i]] += 1;
+    }
+
+    int threshold = -1;
+    int found_corners = 0;
+    for (int bucket = 255; bucket >= 0; bucket--) {
+      if (found_corners + histogram[bucket] > MAX_CORNERS) {
+        // Set threshold here
+        threshold = bucket;
+        break;
+      }
+      found_corners += histogram[bucket];
+    }
+    assert(threshold != -1 && "Failed to select a valid threshold");
+
+    int copied_corners = 0;
+    for (int i = 0; i < num_corners; i++) {
+      if (scores[i] > threshold) {
+        assert(copied_corners < MAX_CORNERS);
+        corners->corners[2 * copied_corners + 0] = frame_corners_xy[i].x;
+        corners->corners[2 * copied_corners + 1] = frame_corners_xy[i].y;
+        copied_corners += 1;
+      }
+    }
+    assert(copied_corners == found_corners);
+    corners->num_corners = copied_corners;
+  }
+
+  free(scores);
+  free(frame_corners_xy);
+}
+
+void av1_compute_corner_list(const ImagePyramid *pyr, CornerList *corners) {
+  assert(corners);
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_lock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  if (!corners->valid) {
+    compute_corner_list(pyr, corners);
+    corners->valid = true;
+  }
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_unlock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+}
+
+#ifndef NDEBUG
+// Check if a corner list has already been computed.
+// This is mostly a debug helper - as it is necessary to hold corners->mutex
+// while reading the valid flag, we cannot just write:
+//   assert(corners->valid);
+// This function allows the check to be correctly written as:
+//   assert(aom_is_corner_list_valid(corners));
+bool aom_is_corner_list_valid(CornerList *corners) {
+  assert(corners);
+
+  // Per the comments in the CornerList struct, we must take this mutex
+  // before reading or writing the "valid" flag, and hold it while computing
+  // the corner list, to ensure correct behaviour if multiple threads call
+  // this function simultaneously
+#if CONFIG_MULTITHREAD
+  pthread_mutex_lock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  bool valid = corners->valid;
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_unlock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  return valid;
+}
+#endif
+
+void av1_invalidate_corner_list(CornerList *corners) {
+  if (corners) {
+#if CONFIG_MULTITHREAD
+    pthread_mutex_lock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+    corners->valid = false;
+#if CONFIG_MULTITHREAD
+    pthread_mutex_unlock(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+  }
+}
+
+void av1_free_corner_list(CornerList *corners) {
+  if (corners) {
+#if CONFIG_MULTITHREAD
+    pthread_mutex_destroy(&corners->mutex);
+#endif  // CONFIG_MULTITHREAD
+    aom_free(corners);
+  }
 }
diff --git a/aom_dsp/flow_estimation/corner_detect.h b/aom_dsp/flow_estimation/corner_detect.h
index 4481c4e..c77813e 100644
--- a/aom_dsp/flow_estimation/corner_detect.h
+++ b/aom_dsp/flow_estimation/corner_detect.h
@@ -14,14 +14,64 @@
 
 #include <stdio.h>
 #include <stdlib.h>
+#include <stdbool.h>
 #include <memory.h>
 
+#include "aom_dsp/pyramid.h"
+#include "aom_util/aom_thread.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-int av1_fast_corner_detect(unsigned char *buf, int width, int height,
-                           int stride, int *points, int max_points);
+#define MAX_CORNERS 4096
+
+typedef struct corner_list {
+#if CONFIG_MULTITHREAD
+  // Mutex which is used to prevent the corner list from being computed twice
+  // at the same time
+  //
+  // Semantics:
+  // * This mutex must be held whenever reading or writing the `valid` flag
+  //
+  // * This mutex must also be held while computing the corner list,
+  //   to ensure that only one thread may do so at a time.
+  //
+  // * However, once you have read the valid flag and seen a true value,
+  //   it is safe to drop the mutex and read from the remaining fields.
+  //   This is because, once the corner list is computed, its contents
+  //   will not be changed until the parent frame buffer is recycled,
+  //   which will not happen until there are no more outstanding references
+  //   to the frame buffer.
+  pthread_mutex_t mutex;
+#endif  // CONFIG_MULTITHREAD
+  // Flag indicating whether the corner list contains valid data
+  bool valid;
+  // Number of corners found
+  int num_corners;
+  // (x, y) coordinates of each corner
+  int corners[2 * MAX_CORNERS];
+} CornerList;
+
+size_t av1_get_corner_list_size(void);
+
+CornerList *av1_alloc_corner_list(void);
+
+void av1_compute_corner_list(const ImagePyramid *pyr, CornerList *corners);
+
+#ifndef NDEBUG
+// Check if a corner list has already been computed.
+// This is mostly a debug helper - as it is necessary to hold corners->mutex
+// while reading the valid flag, we cannot just write:
+//   assert(corners->valid);
+// This function allows the check to be correctly written as:
+//   assert(aom_is_corner_list_valid(corners));
+bool aom_is_corner_list_valid(CornerList *corners);
+#endif
+
+void av1_invalidate_corner_list(CornerList *corners);
+
+void av1_free_corner_list(CornerList *corners);
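+
+// Typical usage (illustrative sketch; assumes `frame` is a YV12_BUFFER_CONFIG
+// whose y_pyramid has already been filled in via aom_compute_pyramid(), as is
+// done in corner_match.c):
+//
+//   CornerList *corners = frame->corners;
+//   av1_compute_corner_list(frame->y_pyramid, corners);
+//   for (int i = 0; i < corners->num_corners; i++) {
+//     const int x = corners->corners[2 * i];
+//     const int y = corners->corners[2 * i + 1];
+//     // ... use (x, y) ...
+//   }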
 
 #ifdef __cplusplus
 }
diff --git a/aom_dsp/flow_estimation/corner_match.c b/aom_dsp/flow_estimation/corner_match.c
index f675604..f34178e 100644
--- a/aom_dsp/flow_estimation/corner_match.c
+++ b/aom_dsp/flow_estimation/corner_match.c
@@ -19,6 +19,7 @@
 #include "aom_dsp/flow_estimation/corner_match.h"
 #include "aom_dsp/flow_estimation/flow_estimation.h"
 #include "aom_dsp/flow_estimation/ransac.h"
+#include "aom_dsp/pyramid.h"
 #include "aom_scale/yv12config.h"
 
 #define SEARCH_SZ 9
@@ -26,30 +27,32 @@
 
 #define THRESHOLD_NCC 0.75
 
-/* Compute var(im) * MATCH_SZ_SQ over a MATCH_SZ by MATCH_SZ window of im,
+/* Compute var(frame) * MATCH_SZ_SQ over a MATCH_SZ by MATCH_SZ window of frame,
    centered at (x, y).
 */
-static double compute_variance(unsigned char *im, int stride, int x, int y) {
+static double compute_variance(const unsigned char *frame, int stride, int x,
+                               int y) {
   int sum = 0;
   int sumsq = 0;
   int var;
   int i, j;
   for (i = 0; i < MATCH_SZ; ++i)
     for (j = 0; j < MATCH_SZ; ++j) {
-      sum += im[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
-      sumsq += im[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)] *
-               im[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
+      sum += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
+      sumsq += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)] *
+               frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
     }
   var = sumsq * MATCH_SZ_SQ - sum * sum;
   return (double)var;
 }
 
-/* Compute corr(im1, im2) * MATCH_SZ * stddev(im1), where the
+/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
    correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
    of each image, centered at (x1, y1) and (x2, y2) respectively.
 */
-double av1_compute_cross_correlation_c(unsigned char *im1, int stride1, int x1,
-                                       int y1, unsigned char *im2, int stride2,
+double av1_compute_cross_correlation_c(const unsigned char *frame1, int stride1,
+                                       int x1, int y1,
+                                       const unsigned char *frame2, int stride2,
                                        int x2, int y2) {
   int v1, v2;
   int sum1 = 0;
@@ -60,8 +63,8 @@
   int i, j;
   for (i = 0; i < MATCH_SZ; ++i)
     for (j = 0; j < MATCH_SZ; ++j) {
-      v1 = im1[(i + y1 - MATCH_SZ_BY2) * stride1 + (j + x1 - MATCH_SZ_BY2)];
-      v2 = im2[(i + y2 - MATCH_SZ_BY2) * stride2 + (j + x2 - MATCH_SZ_BY2)];
+      v1 = frame1[(i + y1 - MATCH_SZ_BY2) * stride1 + (j + x1 - MATCH_SZ_BY2)];
+      v2 = frame2[(i + y2 - MATCH_SZ_BY2) * stride2 + (j + x2 - MATCH_SZ_BY2)];
       sum1 += v1;
       sum2 += v2;
       sumsq2 += v2 * v2;
@@ -84,28 +87,30 @@
           (point1y - point2y) * (point1y - point2y)) <= thresh * thresh;
 }
 
-static void improve_correspondence(unsigned char *frm, unsigned char *ref,
-                                   int width, int height, int frm_stride,
-                                   int ref_stride,
+static void improve_correspondence(const unsigned char *src,
+                                   const unsigned char *ref, int width,
+                                   int height, int src_stride, int ref_stride,
                                    Correspondence *correspondences,
                                    int num_correspondences) {
   int i;
   for (i = 0; i < num_correspondences; ++i) {
     int x, y, best_x = 0, best_y = 0;
     double best_match_ncc = 0.0;
+    // For this algorithm, all points have integer coordinates.
+    // It's a little more efficient to convert them to ints once,
+    // before the inner loops
+    int x0 = (int)correspondences[i].x;
+    int y0 = (int)correspondences[i].y;
+    int rx0 = (int)correspondences[i].rx;
+    int ry0 = (int)correspondences[i].ry;
     for (y = -SEARCH_SZ_BY2; y <= SEARCH_SZ_BY2; ++y) {
       for (x = -SEARCH_SZ_BY2; x <= SEARCH_SZ_BY2; ++x) {
         double match_ncc;
-        if (!is_eligible_point(correspondences[i].rx + x,
-                               correspondences[i].ry + y, width, height))
+        if (!is_eligible_point(rx0 + x, ry0 + y, width, height)) continue;
+        if (!is_eligible_distance(x0, y0, rx0 + x, ry0 + y, width, height))
           continue;
-        if (!is_eligible_distance(correspondences[i].x, correspondences[i].y,
-                                  correspondences[i].rx + x,
-                                  correspondences[i].ry + y, width, height))
-          continue;
-        match_ncc = av1_compute_cross_correlation(
-            frm, frm_stride, correspondences[i].x, correspondences[i].y, ref,
-            ref_stride, correspondences[i].rx + x, correspondences[i].ry + y);
+        match_ncc = av1_compute_cross_correlation(src, src_stride, x0, y0, ref,
+                                                  ref_stride, rx0 + x, ry0 + y);
         if (match_ncc > best_match_ncc) {
           best_match_ncc = match_ncc;
           best_y = y;
@@ -119,19 +124,18 @@
   for (i = 0; i < num_correspondences; ++i) {
     int x, y, best_x = 0, best_y = 0;
     double best_match_ncc = 0.0;
+    int x0 = (int)correspondences[i].x;
+    int y0 = (int)correspondences[i].y;
+    int rx0 = (int)correspondences[i].rx;
+    int ry0 = (int)correspondences[i].ry;
     for (y = -SEARCH_SZ_BY2; y <= SEARCH_SZ_BY2; ++y)
       for (x = -SEARCH_SZ_BY2; x <= SEARCH_SZ_BY2; ++x) {
         double match_ncc;
-        if (!is_eligible_point(correspondences[i].x + x,
-                               correspondences[i].y + y, width, height))
-          continue;
-        if (!is_eligible_distance(
-                correspondences[i].x + x, correspondences[i].y + y,
-                correspondences[i].rx, correspondences[i].ry, width, height))
+        if (!is_eligible_point(x0 + x, y0 + y, width, height)) continue;
+        if (!is_eligible_distance(x0 + x, y0 + y, rx0, ry0, width, height))
           continue;
         match_ncc = av1_compute_cross_correlation(
-            ref, ref_stride, correspondences[i].rx, correspondences[i].ry, frm,
-            frm_stride, correspondences[i].x + x, correspondences[i].y + y);
+            ref, ref_stride, rx0, ry0, src, src_stride, x0 + x, y0 + y);
         if (match_ncc > best_match_ncc) {
           best_match_ncc = match_ncc;
           best_y = y;
@@ -143,14 +147,15 @@
   }
 }
 
-int aom_determine_correspondence(unsigned char *src, int *src_corners,
-                                 int num_src_corners, unsigned char *ref,
-                                 int *ref_corners, int num_ref_corners,
+int aom_determine_correspondence(const unsigned char *src,
+                                 const int *src_corners, int num_src_corners,
+                                 const unsigned char *ref,
+                                 const int *ref_corners, int num_ref_corners,
                                  int width, int height, int src_stride,
-                                 int ref_stride, int *correspondence_pts) {
+                                 int ref_stride,
+                                 Correspondence *correspondences) {
   // TODO(sarahparker) Improve this to include 2-way match
   int i, j;
-  Correspondence *correspondences = (Correspondence *)correspondence_pts;
   int num_correspondences = 0;
   for (i = 0; i < num_src_corners; ++i) {
     double best_match_ncc = 0.0;
@@ -195,71 +200,44 @@
   return num_correspondences;
 }
 
-static bool get_inliers_from_indices(MotionModel *params,
-                                     int *correspondences) {
-  int *inliers_tmp = (int *)aom_calloc(2 * MAX_CORNERS, sizeof(*inliers_tmp));
-  if (!inliers_tmp) return false;
-
-  for (int i = 0; i < params->num_inliers; i++) {
-    int index = params->inliers[i];
-    inliers_tmp[2 * i] = correspondences[4 * index];
-    inliers_tmp[2 * i + 1] = correspondences[4 * index + 1];
-  }
-  memcpy(params->inliers, inliers_tmp, sizeof(*inliers_tmp) * 2 * MAX_CORNERS);
-  aom_free(inliers_tmp);
-  return true;
-}
-
-int av1_compute_global_motion_feature_based(
-    TransformationType type, unsigned char *src_buffer, int src_width,
-    int src_height, int src_stride, int *src_corners, int num_src_corners,
-    YV12_BUFFER_CONFIG *ref, int bit_depth, int *num_inliers_by_motion,
-    MotionModel *params_by_motion, int num_motions) {
-  int i;
-  int num_ref_corners;
+bool av1_compute_global_motion_feature_match(
+    TransformationType type, YV12_BUFFER_CONFIG *src, YV12_BUFFER_CONFIG *ref,
+    int bit_depth, MotionModel *motion_models, int num_motion_models) {
   int num_correspondences;
-  int *correspondences;
-  int ref_corners[2 * MAX_CORNERS];
-  unsigned char *ref_buffer = ref->y_buffer;
-  RansacFunc ransac = av1_get_ransac_type(type);
+  Correspondence *correspondences;
+  ImagePyramid *src_pyramid = src->y_pyramid;
+  CornerList *src_corners = src->corners;
+  ImagePyramid *ref_pyramid = ref->y_pyramid;
+  CornerList *ref_corners = ref->corners;
 
-  if (ref->flags & YV12_FLAG_HIGHBITDEPTH) {
-    ref_buffer = av1_downconvert_frame(ref, bit_depth);
-  }
+  // Precompute information we will need about each frame
+  aom_compute_pyramid(src, bit_depth, src_pyramid);
+  av1_compute_corner_list(src_pyramid, src_corners);
+  aom_compute_pyramid(ref, bit_depth, ref_pyramid);
+  av1_compute_corner_list(ref_pyramid, ref_corners);
 
-  num_ref_corners =
-      av1_fast_corner_detect(ref_buffer, ref->y_width, ref->y_height,
-                             ref->y_stride, ref_corners, MAX_CORNERS);
+  const uint8_t *src_buffer = src_pyramid->layers[0].buffer;
+  const int src_width = src_pyramid->layers[0].width;
+  const int src_height = src_pyramid->layers[0].height;
+  const int src_stride = src_pyramid->layers[0].stride;
+
+  const uint8_t *ref_buffer = ref_pyramid->layers[0].buffer;
+  assert(ref_pyramid->layers[0].width == src_width);
+  assert(ref_pyramid->layers[0].height == src_height);
+  const int ref_stride = ref_pyramid->layers[0].stride;
 
   // find correspondences between the two images
-  correspondences =
-      (int *)aom_malloc(num_src_corners * 4 * sizeof(*correspondences));
-  if (!correspondences) return 0;
+  correspondences = (Correspondence *)aom_malloc(src_corners->num_corners *
+                                                 sizeof(*correspondences));
+  if (!correspondences) return false;
   num_correspondences = aom_determine_correspondence(
-      src_buffer, (int *)src_corners, num_src_corners, ref_buffer,
-      (int *)ref_corners, num_ref_corners, src_width, src_height, src_stride,
-      ref->y_stride, correspondences);
+      src_buffer, src_corners->corners, src_corners->num_corners, ref_buffer,
+      ref_corners->corners, ref_corners->num_corners, src_width, src_height,
+      src_stride, ref_stride, correspondences);
 
-  ransac(correspondences, num_correspondences, num_inliers_by_motion,
-         params_by_motion, num_motions);
-
-  // Set num_inliers = 0 for motions with too few inliers so they are ignored.
-  for (i = 0; i < num_motions; ++i) {
-    if (num_inliers_by_motion[i] < MIN_INLIER_PROB * num_correspondences ||
-        num_correspondences == 0) {
-      num_inliers_by_motion[i] = 0;
-    } else if (!get_inliers_from_indices(&params_by_motion[i],
-                                         correspondences)) {
-      aom_free(correspondences);
-      return 0;
-    }
-  }
+  bool result = ransac(correspondences, num_correspondences, type,
+                       motion_models, num_motion_models);
 
   aom_free(correspondences);
-
-  // Return true if any one of the motions has inliers.
-  for (i = 0; i < num_motions; ++i) {
-    if (num_inliers_by_motion[i] > 0) return 1;
-  }
-  return 0;
+  return result;
 }
diff --git a/aom_dsp/flow_estimation/corner_match.h b/aom_dsp/flow_estimation/corner_match.h
index 71afadf..bb69944 100644
--- a/aom_dsp/flow_estimation/corner_match.h
+++ b/aom_dsp/flow_estimation/corner_match.h
@@ -12,10 +12,12 @@
 #ifndef AOM_AOM_DSP_FLOW_ESTIMATION_CORNER_MATCH_H_
 #define AOM_AOM_DSP_FLOW_ESTIMATION_CORNER_MATCH_H_
 
+#include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <memory.h>
 
+#include "aom_dsp/flow_estimation/corner_detect.h"
 #include "aom_dsp/flow_estimation/flow_estimation.h"
 #include "aom_scale/yv12config.h"
 
@@ -27,22 +29,17 @@
 #define MATCH_SZ_BY2 ((MATCH_SZ - 1) / 2)
 #define MATCH_SZ_SQ (MATCH_SZ * MATCH_SZ)
 
-typedef struct {
-  int x, y;
-  int rx, ry;
-} Correspondence;
-
-int aom_determine_correspondence(unsigned char *src, int *src_corners,
-                                 int num_src_corners, unsigned char *ref,
-                                 int *ref_corners, int num_ref_corners,
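+// Match corners between `src` and `ref`, writing the matches into the
+// caller-provided `correspondences` array, which must have room for at least
+// num_src_corners entries. Returns the number of correspondences found.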
+int aom_determine_correspondence(const unsigned char *src,
+                                 const int *src_corners, int num_src_corners,
+                                 const unsigned char *ref,
+                                 const int *ref_corners, int num_ref_corners,
                                  int width, int height, int src_stride,
-                                 int ref_stride, int *correspondence_pts);
+                                 int ref_stride,
+                                 Correspondence *correspondences);
 
-int av1_compute_global_motion_feature_based(
-    TransformationType type, unsigned char *src_buffer, int src_width,
-    int src_height, int src_stride, int *src_corners, int num_src_corners,
-    YV12_BUFFER_CONFIG *ref, int bit_depth, int *num_inliers_by_motion,
-    MotionModel *params_by_motion, int num_motions);
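+// Estimate global motion between `src` and `ref` using feature matching
+// followed by RANSAC, fitting models of the given `type`. Returns the result
+// of the RANSAC fit, or false if memory allocation fails.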
+bool av1_compute_global_motion_feature_match(
+    TransformationType type, YV12_BUFFER_CONFIG *src, YV12_BUFFER_CONFIG *ref,
+    int bit_depth, MotionModel *motion_models, int num_motion_models);
 
 #ifdef __cplusplus
 }
diff --git a/aom_dsp/flow_estimation/disflow.c b/aom_dsp/flow_estimation/disflow.c
index 2a6ad4b..a8e7b06 100644
--- a/aom_dsp/flow_estimation/disflow.c
+++ b/aom_dsp/flow_estimation/disflow.c
@@ -9,626 +9,643 @@
  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
  */
 
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
+// Dense Inverse Search flow algorithm
+// Paper: https://arxiv.org/abs/1603.03590
 
+#include <assert.h>
+#include <math.h>
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/flow_estimation/corner_detect.h"
 #include "aom_dsp/flow_estimation/disflow.h"
-#include "aom_dsp/flow_estimation/flow_estimation.h"
 #include "aom_dsp/flow_estimation/ransac.h"
+#include "aom_dsp/pyramid.h"
+#include "aom_mem/aom_mem.h"
 
-#include "aom_scale/yv12config.h"
+#include "config/aom_dsp_rtcd.h"
 
+// TODO(rachelbarker):
+// Implement specialized functions for upscaling flow fields,
+// replacing av1_upscale_plane_double_prec().
+// Then we can avoid needing to include code from av1/
 #include "av1/common/resize.h"
 
-// Number of pyramid levels in disflow computation
-#define N_LEVELS 2
-// Size of square patches in the disflow dense grid
-#define PATCH_SIZE 8
-// Center point of square patch
-#define PATCH_CENTER ((PATCH_SIZE + 1) >> 1)
-// Step size between patches, lower value means greater patch overlap
-#define PATCH_STEP 1
-// Minimum size of border padding for disflow
-#define MIN_PAD 7
-// Warp error convergence threshold for disflow
-#define DISFLOW_ERROR_TR 0.01
-// Max number of iterations if warp convergence is not found
-#define DISFLOW_MAX_ITR 10
+// Amount to downsample the flow field by.
+// e.g. DOWNSAMPLE_SHIFT = 2 (DOWNSAMPLE_FACTOR == 4) means we calculate
+// one flow point for each 4x4 pixel region of the frame
+// Must be a power of 2
+#define DOWNSAMPLE_SHIFT 3
+#define DOWNSAMPLE_FACTOR (1 << DOWNSAMPLE_SHIFT)
+// Number of outermost flow field entries (on each edge) which can't be
+// computed, because the patch they correspond to extends outside of the
+// frame
+// The border is (DISFLOW_PATCH_SIZE >> 1) pixels wide, which corresponds to
+// ((DISFLOW_PATCH_SIZE >> 1) >> DOWNSAMPLE_SHIFT) flow field entries
+#define FLOW_BORDER ((DISFLOW_PATCH_SIZE >> 1) >> DOWNSAMPLE_SHIFT)
+// When downsampling the flow field, each flow field entry covers a square
+// region of pixels in the image pyramid. This value is equal to the position
+// of the center of that region, as an offset from the top/left edge.
+//
+// Note: Using ((DOWNSAMPLE_FACTOR - 1) / 2) is equivalent to the more
+// natural expression ((DOWNSAMPLE_FACTOR / 2) - 1),
+// unless DOWNSAMPLE_FACTOR == 1 (ie, no downsampling), in which case
+// this gives the correct offset of 0 instead of -1.
+#define UPSAMPLE_CENTER_OFFSET ((DOWNSAMPLE_FACTOR - 1) / 2)
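+
+// With the current DOWNSAMPLE_SHIFT of 3, DOWNSAMPLE_FACTOR is 8, so each
+// flow field entry covers an 8x8 pixel region and UPSAMPLE_CENTER_OFFSET
+// evaluates to (8 - 1) / 2 = 3.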
 
-// Struct for an image pyramid
-typedef struct {
-  int n_levels;
-  int pad_size;
-  int has_gradient;
-  int widths[N_LEVELS];
-  int heights[N_LEVELS];
-  int strides[N_LEVELS];
-  int level_loc[N_LEVELS];
-  unsigned char *level_buffer;
-  double *level_dx_buffer;
-  double *level_dy_buffer;
-} ImagePyramid;
-
-// Don't use points around the frame border since they are less reliable
-static INLINE int valid_point(int x, int y, int width, int height) {
-  return (x > (PATCH_SIZE + PATCH_CENTER)) &&
-         (x < (width - PATCH_SIZE - PATCH_CENTER)) &&
-         (y > (PATCH_SIZE + PATCH_CENTER)) &&
-         (y < (height - PATCH_SIZE - PATCH_CENTER));
+static INLINE void get_cubic_kernel_dbl(double x, double *kernel) {
+  assert(0 <= x && x < 1);
+  double x2 = x * x;
+  double x3 = x2 * x;
+  kernel[0] = -0.5 * x + x2 - 0.5 * x3;
+  kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3;
+  kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3;
+  kernel[3] = -0.5 * x2 + 0.5 * x3;
 }
 
-static int determine_disflow_correspondence(int *frm_corners,
-                                            int num_frm_corners, double *flow_u,
-                                            double *flow_v, int width,
-                                            int height, int stride,
-                                            double *correspondences) {
+static INLINE void get_cubic_kernel_int(double x, int *kernel) {
+  double kernel_dbl[4];
+  get_cubic_kernel_dbl(x, kernel_dbl);
+
+  kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS));
+  kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS));
+  kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS));
+  kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS));
+}
+
+static INLINE double get_cubic_value_dbl(const double *p,
+                                         const double *kernel) {
+  return kernel[0] * p[0] + kernel[1] * p[1] + kernel[2] * p[2] +
+         kernel[3] * p[3];
+}
+
+static INLINE int get_cubic_value_int(const int *p, const int *kernel) {
+  return kernel[0] * p[0] + kernel[1] * p[1] + kernel[2] * p[2] +
+         kernel[3] * p[3];
+}
+
+static INLINE double bicubic_interp_one(const double *arr, int stride,
+                                        double *h_kernel, double *v_kernel) {
+  double tmp[1 * 4];
+
+  // Horizontal convolution
+  for (int i = -1; i < 3; ++i) {
+    tmp[i + 1] = get_cubic_value_dbl(&arr[i * stride - 1], h_kernel);
+  }
+
+  // Vertical convolution
+  return get_cubic_value_dbl(tmp, v_kernel);
+}
+
+static int determine_disflow_correspondence(CornerList *corners,
+                                            const FlowField *flow,
+                                            Correspondence *correspondences) {
+  const int width = flow->width;
+  const int height = flow->height;
+  const int stride = flow->stride;
+
   int num_correspondences = 0;
-  int x, y;
-  for (int i = 0; i < num_frm_corners; ++i) {
-    x = frm_corners[2 * i];
-    y = frm_corners[2 * i + 1];
-    if (valid_point(x, y, width, height)) {
-      correspondences[4 * num_correspondences] = x;
-      correspondences[4 * num_correspondences + 1] = y;
-      correspondences[4 * num_correspondences + 2] = x + flow_u[y * stride + x];
-      correspondences[4 * num_correspondences + 3] = y + flow_v[y * stride + x];
-      num_correspondences++;
-    }
+  for (int i = 0; i < corners->num_corners; ++i) {
+    const int x0 = corners->corners[2 * i];
+    const int y0 = corners->corners[2 * i + 1];
+
+    // Offset points, to compensate for the fact that (say) a flow field entry
+    // at horizontal index i is nominally associated with the pixel at
+    // horizontal coordinate (i << DOWNSAMPLE_SHIFT) + UPSAMPLE_CENTER_OFFSET.
+    // This offset must be applied before we split the coordinate into integer
+    // and fractional parts, in order for the interpolation to be correct.
+    const int x = x0 - UPSAMPLE_CENTER_OFFSET;
+    const int y = y0 - UPSAMPLE_CENTER_OFFSET;
+
+    // Split the pixel coordinates into integer flow field coordinates and
+    // an offset for interpolation
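+    // For example, with DOWNSAMPLE_SHIFT == 3 a corner at x0 == 100 gives
+    // x == 97, which splits into flow_x == 12 and flow_sub_x == 0.125.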
+    const int flow_x = x >> DOWNSAMPLE_SHIFT;
+    const double flow_sub_x =
+        (x & (DOWNSAMPLE_FACTOR - 1)) / (double)DOWNSAMPLE_FACTOR;
+    const int flow_y = y >> DOWNSAMPLE_SHIFT;
+    const double flow_sub_y =
+        (y & (DOWNSAMPLE_FACTOR - 1)) / (double)DOWNSAMPLE_FACTOR;
+
+    // Make sure that bicubic interpolation won't read outside of the flow field
+    if (flow_x < 1 || (flow_x + 2) >= width) continue;
+    if (flow_y < 1 || (flow_y + 2) >= height) continue;
+
+    double h_kernel[4];
+    double v_kernel[4];
+    get_cubic_kernel_dbl(flow_sub_x, h_kernel);
+    get_cubic_kernel_dbl(flow_sub_y, v_kernel);
+
+    const double flow_u = bicubic_interp_one(&flow->u[flow_y * stride + flow_x],
+                                             stride, h_kernel, v_kernel);
+    const double flow_v = bicubic_interp_one(&flow->v[flow_y * stride + flow_x],
+                                             stride, h_kernel, v_kernel);
+
+    // Use original points (without offsets) when filling in correspondence
+    // array
+    correspondences[num_correspondences].x = x0;
+    correspondences[num_correspondences].y = y0;
+    correspondences[num_correspondences].rx = x0 + flow_u;
+    correspondences[num_correspondences].ry = y0 + flow_v;
+    num_correspondences++;
   }
   return num_correspondences;
 }
 
-static double getCubicValue(double p[4], double x) {
-  return p[1] + 0.5 * x *
-                    (p[2] - p[0] +
-                     x * (2.0 * p[0] - 5.0 * p[1] + 4.0 * p[2] - p[3] +
-                          x * (3.0 * (p[1] - p[2]) + p[3] - p[0])));
-}
+// Compare two patches of DISFLOW_PATCH_SIZE x DISFLOW_PATCH_SIZE pixels, one
+// rooted at position (x, y) in src and the other at (x + u, y + v) in ref
+// (evaluated with bicubic interpolation, since u and v are fractional).
+// The scaled per-pixel difference (warped ref minus src) is written into dt;
+// width, height and stride describe the full frame.
+static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref,
+                                      int width, int height, int stride, int x,
+                                      int y, double u, double v, int16_t *dt) {
+  // Split offset into integer and fractional parts, and compute cubic
+  // interpolation kernels
+  const int u_int = (int)floor(u);
+  const int v_int = (int)floor(v);
+  const double u_frac = u - floor(u);
+  const double v_frac = v - floor(v);
 
-static void get_subcolumn(unsigned char *ref, double col[4], int stride, int x,
-                          int y_start) {
-  int i;
-  for (i = 0; i < 4; ++i) {
-    col[i] = ref[(i + y_start) * stride + x];
-  }
-}
+  int h_kernel[4];
+  int v_kernel[4];
+  get_cubic_kernel_int(u_frac, h_kernel);
+  get_cubic_kernel_int(v_frac, v_kernel);
 
-static double bicubic(unsigned char *ref, double x, double y, int stride) {
-  double arr[4];
-  int k;
-  int i = (int)x;
-  int j = (int)y;
-  for (k = 0; k < 4; ++k) {
-    double arr_temp[4];
-    get_subcolumn(ref, arr_temp, stride, i + k - 1, j - 1);
-    arr[k] = getCubicValue(arr_temp, y - j);
-  }
-  return getCubicValue(arr, x - i);
-}
+  // Storage for intermediate values between the two convolution directions
+  int tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];
+  int *tmp = tmp_ + DISFLOW_PATCH_SIZE;  // Offset by one row
 
-// Interpolate a warped block using bicubic interpolation when possible
-static unsigned char interpolate(unsigned char *ref, double x, double y,
-                                 int width, int height, int stride) {
-  if (x < 0 && y < 0)
-    return ref[0];
-  else if (x < 0 && y > height - 1)
-    return ref[(height - 1) * stride];
-  else if (x > width - 1 && y < 0)
-    return ref[width - 1];
-  else if (x > width - 1 && y > height - 1)
-    return ref[(height - 1) * stride + (width - 1)];
-  else if (x < 0) {
-    int v;
-    int i = (int)y;
-    double a = y - i;
-    if (y > 1 && y < height - 2) {
-      double arr[4];
-      get_subcolumn(ref, arr, stride, 0, i - 1);
-      return clamp((int)(getCubicValue(arr, a) + 0.5), 0, 255);
-    }
-    v = (int)(ref[i * stride] * (1 - a) + ref[(i + 1) * stride] * a + 0.5);
-    return clamp(v, 0, 255);
-  } else if (y < 0) {
-    int v;
-    int j = (int)x;
-    double b = x - j;
-    if (x > 1 && x < width - 2) {
-      double arr[4] = { ref[j - 1], ref[j], ref[j + 1], ref[j + 2] };
-      return clamp((int)(getCubicValue(arr, b) + 0.5), 0, 255);
-    }
-    v = (int)(ref[j] * (1 - b) + ref[j + 1] * b + 0.5);
-    return clamp(v, 0, 255);
-  } else if (x > width - 1) {
-    int v;
-    int i = (int)y;
-    double a = y - i;
-    if (y > 1 && y < height - 2) {
-      double arr[4];
-      get_subcolumn(ref, arr, stride, width - 1, i - 1);
-      return clamp((int)(getCubicValue(arr, a) + 0.5), 0, 255);
-    }
-    v = (int)(ref[i * stride + width - 1] * (1 - a) +
-              ref[(i + 1) * stride + width - 1] * a + 0.5);
-    return clamp(v, 0, 255);
-  } else if (y > height - 1) {
-    int v;
-    int j = (int)x;
-    double b = x - j;
-    if (x > 1 && x < width - 2) {
-      int row = (height - 1) * stride;
-      double arr[4] = { ref[row + j - 1], ref[row + j], ref[row + j + 1],
-                        ref[row + j + 2] };
-      return clamp((int)(getCubicValue(arr, b) + 0.5), 0, 255);
-    }
-    v = (int)(ref[(height - 1) * stride + j] * (1 - b) +
-              ref[(height - 1) * stride + j + 1] * b + 0.5);
-    return clamp(v, 0, 255);
-  } else if (x > 1 && y > 1 && x < width - 2 && y < height - 2) {
-    return clamp((int)(bicubic(ref, x, y, stride) + 0.5), 0, 255);
-  } else {
-    int i = (int)y;
-    int j = (int)x;
-    double a = y - i;
-    double b = x - j;
-    int v = (int)(ref[i * stride + j] * (1 - a) * (1 - b) +
-                  ref[i * stride + j + 1] * (1 - a) * b +
-                  ref[(i + 1) * stride + j] * a * (1 - b) +
-                  ref[(i + 1) * stride + j + 1] * a * b);
-    return clamp(v, 0, 255);
-  }
-}
+  // Clamp coordinates so that all pixels we fetch will remain within the
+  // allocated border region, but allow them to go far enough out that
+  // the border pixels' values do not change.
+  // Since we are calculating an 8x8 block, the bottom-right pixel
+  // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
+  // interpolation has 4 taps, meaning that the output of pixel
+  // (x_w, y_w) depends on the pixels in the range
+  // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
+  //
+  // Thus the most extreme coordinates which will be fetched are
+  // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
+  const int x0 = clamp(x + u_int, -9, width);
+  const int y0 = clamp(y + v_int, -9, height);
 
-// Warps a block using flow vector [u, v] and computes the mse
-static double compute_warp_and_error(unsigned char *ref, unsigned char *frm,
-                                     int width, int height, int stride, int x,
-                                     int y, double u, double v, int16_t *dt) {
-  int i, j;
-  unsigned char warped;
-  double x_w, y_w;
-  double mse = 0;
-  int16_t err = 0;
-  for (i = y; i < y + PATCH_SIZE; ++i)
-    for (j = x; j < x + PATCH_SIZE; ++j) {
-      x_w = (double)j + u;
-      y_w = (double)i + v;
-      warped = interpolate(ref, x_w, y_w, width, height, stride);
-      err = warped - frm[j + i * stride];
-      mse += err * err;
-      dt[(i - y) * PATCH_SIZE + (j - x)] = err;
-    }
+  // Horizontal convolution
+  for (int i = -1; i < DISFLOW_PATCH_SIZE + 2; ++i) {
+    const int y_w = y0 + i;
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; ++j) {
+      const int x_w = x0 + j;
+      int arr[4];
 
-  mse /= (PATCH_SIZE * PATCH_SIZE);
-  return mse;
-}
+      arr[0] = (int)ref[y_w * stride + (x_w - 1)];
+      arr[1] = (int)ref[y_w * stride + (x_w + 0)];
+      arr[2] = (int)ref[y_w * stride + (x_w + 1)];
+      arr[3] = (int)ref[y_w * stride + (x_w + 2)];
 
-// Computes the components of the system of equations used to solve for
-// a flow vector. This includes:
-// 1.) The hessian matrix for optical flow. This matrix is in the
-// form of:
-//
-//       M = |sum(dx * dx)  sum(dx * dy)|
-//           |sum(dx * dy)  sum(dy * dy)|
-//
-// 2.)   b = |sum(dx * dt)|
-//           |sum(dy * dt)|
-// Where the sums are computed over a square window of PATCH_SIZE.
-static INLINE void compute_flow_system(const double *dx, int dx_stride,
-                                       const double *dy, int dy_stride,
-                                       const int16_t *dt, int dt_stride,
-                                       double *M, double *b) {
-  for (int i = 0; i < PATCH_SIZE; i++) {
-    for (int j = 0; j < PATCH_SIZE; j++) {
-      M[0] += dx[i * dx_stride + j] * dx[i * dx_stride + j];
-      M[1] += dx[i * dx_stride + j] * dy[i * dy_stride + j];
-      M[3] += dy[i * dy_stride + j] * dy[i * dy_stride + j];
-
-      b[0] += dx[i * dx_stride + j] * dt[i * dt_stride + j];
-      b[1] += dy[i * dy_stride + j] * dt[i * dt_stride + j];
+      // Apply kernel and round, keeping 6 extra bits of precision.
+      //
+      // 6 is the maximum allowable number of extra bits which will avoid
+      // the intermediate values overflowing an int16_t. The most extreme
+      // intermediate value occurs when:
+      // * The input pixels are [0, 255, 255, 0]
+      // * u_frac = 0.5
+      // In this case, the un-scaled output is 255 * 1.125 = 286.875.
+      // As an integer with 6 fractional bits, that is 18360, which fits
+      // in an int16_t. But with 7 fractional bits it would be 36720,
+      // which is too large.
+      tmp[i * DISFLOW_PATCH_SIZE + j] = ROUND_POWER_OF_TWO(
+          get_cubic_value_int(arr, h_kernel), DISFLOW_INTERP_BITS - 6);
     }
   }
 
-  M[2] = M[1];
-}
+  // Vertical convolution
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; ++j) {
+      const int *p = &tmp[i * DISFLOW_PATCH_SIZE + j];
+      const int arr[4] = { p[-DISFLOW_PATCH_SIZE], p[0], p[DISFLOW_PATCH_SIZE],
+                           p[2 * DISFLOW_PATCH_SIZE] };
+      const int result = get_cubic_value_int(arr, v_kernel);
 
-// Solves a general Mx = b where M is a 2x2 matrix and b is a 2x1 matrix
-static INLINE void solve_2x2_system(const double *M, const double *b,
-                                    double *output_vec) {
-  double M_0 = M[0];
-  double M_3 = M[3];
-  double det = (M_0 * M_3) - (M[1] * M[2]);
-  if (det < 1e-5) {
-    // Handle singular matrix
-    // TODO(sarahparker) compare results using pseudo inverse instead
-    M_0 += 1e-10;
-    M_3 += 1e-10;
-    det = (M_0 * M_3) - (M[1] * M[2]);
-  }
-  const double det_inv = 1 / det;
-  const double mult_b0 = det_inv * b[0];
-  const double mult_b1 = det_inv * b[1];
-  output_vec[0] = M_3 * mult_b0 - M[1] * mult_b1;
-  output_vec[1] = -M[2] * mult_b0 + M_0 * mult_b1;
-}
-
-/*
-static INLINE void image_difference(const uint8_t *src, int src_stride,
-                                    const uint8_t *ref, int ref_stride,
-                                    int16_t *dst, int dst_stride, int height,
-                                    int width) {
-  const int block_unit = 8;
-  // Take difference in 8x8 blocks to make use of optimized diff function
-  for (int i = 0; i < height; i += block_unit) {
-    for (int j = 0; j < width; j += block_unit) {
-      aom_subtract_block(block_unit, block_unit, dst + i * dst_stride + j,
-                         dst_stride, src + i * src_stride + j, src_stride,
-                         ref + i * ref_stride + j, ref_stride);
+      // Apply kernel and round.
+      // This time, we have to round off the 6 extra bits which were kept
+      // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
+      // of precision to match the scale of the dx and dy arrays.
+      const int round_bits = DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2;
+      const int warped = ROUND_POWER_OF_TWO(result, round_bits);
+      const int src_px = src[(x + j) + (y + i) * stride] << 3;
+      const int err = warped - src_px;
+      dt[i * DISFLOW_PATCH_SIZE + j] = err;
     }
   }
 }
-*/
 
-static INLINE void convolve_2d_sobel_y(const uint8_t *src, int src_stride,
-                                       double *dst, int dst_stride, int w,
-                                       int h, int dir, double norm) {
-  int16_t im_block[(MAX_SB_SIZE + MAX_FILTER_TAP - 1) * MAX_SB_SIZE];
-  DECLARE_ALIGNED(256, static const int16_t, sobel_a[3]) = { 1, 0, -1 };
-  DECLARE_ALIGNED(256, static const int16_t, sobel_b[3]) = { 1, 2, 1 };
+static INLINE void sobel_filter(const uint8_t *src, int src_stride,
+                                int16_t *dst, int dst_stride, int dir) {
+  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
+  int16_t *tmp = tmp_ + DISFLOW_PATCH_SIZE;
+
+  // Sobel filter kernel
+  // This must have an overall scale factor equal to DISFLOW_DERIV_SCALE,
+  // in order to produce correctly scaled outputs.
+  // To work out the scale factor, we multiply two factors:
+  //
+  // * For the derivative filter (sobel_a), comparing our filter
+  //    image[x - 1] - image[x + 1]
+  //   to the standard form
+  //    d/dx image[x] = image[x+1] - image[x]
+  //   tells us that we're actually calculating -2 * d/dx image[x]
+  //
+  // * For the smoothing filter (sobel_b), all coefficients are positive
+  //   so the scale factor is just the sum of the coefficients
+  //
+  // Thus we need to make sure that DISFLOW_DERIV_SCALE = 2 * sum(sobel_b)
+  // (and take care of the - sign from sobel_a elsewhere)
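+  //
+  // With sobel_b = { 1, 2, 1 }, sum(sobel_b) is 4, so DISFLOW_DERIV_SCALE
+  // must be 2 * 4 = 8.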
+  static const int16_t sobel_a[3] = { 1, 0, -1 };
+  static const int16_t sobel_b[3] = { 1, 2, 1 };
   const int taps = 3;
-  int im_h = h + taps - 1;
-  int im_stride = w;
-  const int fo_vert = 1;
-  const int fo_horiz = 1;
 
   // horizontal filter
-  const uint8_t *src_horiz = src - fo_vert * src_stride;
-  const int16_t *x_filter = dir ? sobel_a : sobel_b;
-  for (int y = 0; y < im_h; ++y) {
-    for (int x = 0; x < w; ++x) {
-      int16_t sum = 0;
+  const int16_t *h_kernel = dir ? sobel_a : sobel_b;
+
+  for (int y = -1; y < DISFLOW_PATCH_SIZE + 1; ++y) {
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
       for (int k = 0; k < taps; ++k) {
-        sum += x_filter[k] * src_horiz[y * src_stride + x - fo_horiz + k];
+        sum += h_kernel[k] * src[y * src_stride + (x + k - 1)];
       }
-      im_block[y * im_stride + x] = sum;
+      tmp[y * DISFLOW_PATCH_SIZE + x] = sum;
     }
   }
 
   // vertical filter
-  int16_t *src_vert = im_block + fo_vert * im_stride;
-  const int16_t *y_filter = dir ? sobel_b : sobel_a;
-  for (int y = 0; y < h; ++y) {
-    for (int x = 0; x < w; ++x) {
-      int16_t sum = 0;
+  const int16_t *v_kernel = dir ? sobel_b : sobel_a;
+
+  for (int y = 0; y < DISFLOW_PATCH_SIZE; ++y) {
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
       for (int k = 0; k < taps; ++k) {
-        sum += y_filter[k] * src_vert[(y - fo_vert + k) * im_stride + x];
+        sum += v_kernel[k] * tmp[(y + k - 1) * DISFLOW_PATCH_SIZE + x];
       }
-      dst[y * dst_stride + x] = sum * norm;
+      dst[y * dst_stride + x] = sum;
     }
   }
 }
 
-// Compute an image gradient using a sobel filter.
-// If dir == 1, compute the x gradient. If dir == 0, compute y. This function
-// assumes the images have been padded so that they can be processed in units
-// of 8.
-static INLINE void sobel_xy_image_gradient(const uint8_t *src, int src_stride,
-                                           double *dst, int dst_stride,
-                                           int height, int width, int dir) {
-  double norm = 1.0;
-  // TODO(sarahparker) experiment with doing this over larger block sizes
-  const int block_unit = 8;
-  // Filter in 8x8 blocks to eventually make use of optimized convolve function
-  for (int i = 0; i < height; i += block_unit) {
-    for (int j = 0; j < width; j += block_unit) {
-      convolve_2d_sobel_y(src + i * src_stride + j, src_stride,
-                          dst + i * dst_stride + j, dst_stride, block_unit,
-                          block_unit, dir, norm);
+// Computes the components of the system of equations used to solve for
+// a flow vector.
+//
+// The flow equations are a least-squares system, derived as follows:
+//
+// For each pixel in the patch, we calculate the current error `dt`,
+// and the x and y gradients `dx` and `dy` of the source patch.
+// This means that, to first order, the squared error for this pixel is
+//
+//    (dt + u * dx + v * dy)^2
+//
+// where (u, v) are the incremental changes to the flow vector.
+//
+// We then want to find the values of u and v which minimize the sum
+// of the squared error across all pixels. Conveniently, this fits exactly
+// into the form of a least squares problem, with one equation
+//
+//   u * dx + v * dy = -dt
+//
+// for each pixel.
+//
+// Summing across all pixels in a square window of size DISFLOW_PATCH_SIZE,
+// and absorbing the - sign elsewhere, this results in the least squares system
+//
+//   M = |sum(dx * dx)  sum(dx * dy)|
+//       |sum(dx * dy)  sum(dy * dy)|
+//
+//   b = |sum(dx * dt)|
+//       |sum(dy * dt)|
+static INLINE void compute_flow_matrix(const int16_t *dx, int dx_stride,
+                                       const int16_t *dy, int dy_stride,
+                                       double *M) {
+  int tmp[4] = { 0 };
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; j++) {
+      tmp[0] += dx[i * dx_stride + j] * dx[i * dx_stride + j];
+      tmp[1] += dx[i * dx_stride + j] * dy[i * dy_stride + j];
+      // Don't compute tmp[2], as it should be equal to tmp[1]
+      tmp[3] += dy[i * dy_stride + j] * dy[i * dy_stride + j];
+    }
+  }
+
+  // Apply regularization
+  // We follow the standard regularization method of adding `k * I` before
+  // inverting. This ensures that the matrix will be invertible.
+  //
+  // Setting the regularization strength k to 1 seems to work well here, as
+  // typical values coming from the other equations are very large (1e5 to
+  // 1e6, with an upper limit of around 6e7, at the time of writing).
+  // It also preserves the property that all matrix values are whole numbers,
+  // which is convenient for integerized SIMD implementation.
+  tmp[0] += 1;
+  tmp[3] += 1;
+
+  tmp[2] = tmp[1];
+
+  M[0] = (double)tmp[0];
+  M[1] = (double)tmp[1];
+  M[2] = (double)tmp[2];
+  M[3] = (double)tmp[3];
+}
+
+static INLINE void compute_flow_vector(const int16_t *dx, int dx_stride,
+                                       const int16_t *dy, int dy_stride,
+                                       const int16_t *dt, int dt_stride,
+                                       int *b) {
+  memset(b, 0, 2 * sizeof(*b));
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; j++) {
+      b[0] += dx[i * dx_stride + j] * dt[i * dt_stride + j];
+      b[1] += dy[i * dy_stride + j] * dt[i * dt_stride + j];
     }
   }
 }
 
-static void free_pyramid(ImagePyramid *pyr) {
-  aom_free(pyr->level_buffer);
-  if (pyr->has_gradient) {
-    aom_free(pyr->level_dx_buffer);
-    aom_free(pyr->level_dy_buffer);
-  }
-  aom_free(pyr);
+// Try to invert the matrix M
+// Note: Due to the nature of how a least-squares matrix is constructed, all of
+// the eigenvalues will be >= 0, and therefore det M >= 0 as well.
+// The regularization term `+ k * I` further ensures that det M >= k^2.
+// As mentioned in compute_flow_matrix(), here we use k = 1, so det M >= 1.
+// So we don't have to worry about non-invertible matrices here.
+static INLINE void invert_2x2(const double *M, double *M_inv) {
+  double det = (M[0] * M[3]) - (M[1] * M[2]);
+  assert(det >= 1);
+  const double det_inv = 1 / det;
+
+  M_inv[0] = M[3] * det_inv;
+  M_inv[1] = -M[1] * det_inv;
+  M_inv[2] = -M[2] * det_inv;
+  M_inv[3] = M[0] * det_inv;
 }
 
-static ImagePyramid *alloc_pyramid(int width, int height, int pad_size,
-                                   int compute_gradient) {
-  ImagePyramid *pyr = aom_calloc(1, sizeof(*pyr));
-  if (!pyr) return NULL;
-  pyr->has_gradient = compute_gradient;
-  // 2 * width * height is the upper bound for a buffer that fits
-  // all pyramid levels + padding for each level
-  const int buffer_size = sizeof(*pyr->level_buffer) * 2 * width * height +
-                          (width + 2 * pad_size) * 2 * pad_size * N_LEVELS;
-  pyr->level_buffer = aom_malloc(buffer_size);
-  if (!pyr->level_buffer) {
-    free_pyramid(pyr);
-    return NULL;
-  }
-  memset(pyr->level_buffer, 0, buffer_size);
+void aom_compute_flow_at_point_c(const uint8_t *src, const uint8_t *ref, int x,
+                                 int y, int width, int height, int stride,
+                                 double *u, double *v) {
+  double M[4];
+  double M_inv[4];
+  int b[2];
+  int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+  int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+  int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
 
-  if (compute_gradient) {
-    const int gradient_size =
-        sizeof(*pyr->level_dx_buffer) * 2 * width * height +
-        (width + 2 * pad_size) * 2 * pad_size * N_LEVELS;
-    pyr->level_dx_buffer = aom_calloc(1, gradient_size);
-    pyr->level_dy_buffer = aom_calloc(1, gradient_size);
-    if (!(pyr->level_dx_buffer && pyr->level_dy_buffer)) {
-      free_pyramid(pyr);
-      return NULL;
-    }
-  }
-  return pyr;
-}
+  // Compute gradients within this patch
+  const uint8_t *src_patch = &src[y * stride + x];
+  sobel_filter(src_patch, stride, dx, DISFLOW_PATCH_SIZE, 1);
+  sobel_filter(src_patch, stride, dy, DISFLOW_PATCH_SIZE, 0);
 
-static INLINE void update_level_dims(ImagePyramid *frm_pyr, int level) {
-  frm_pyr->widths[level] = frm_pyr->widths[level - 1] >> 1;
-  frm_pyr->heights[level] = frm_pyr->heights[level - 1] >> 1;
-  frm_pyr->strides[level] = frm_pyr->widths[level] + 2 * frm_pyr->pad_size;
-  // Point the beginning of the next level buffer to the correct location inside
-  // the padded border
-  frm_pyr->level_loc[level] =
-      frm_pyr->level_loc[level - 1] +
-      frm_pyr->strides[level - 1] *
-          (2 * frm_pyr->pad_size + frm_pyr->heights[level - 1]);
-}
-
-// Compute coarse to fine pyramids for a frame
-static void compute_flow_pyramids(unsigned char *frm, const int frm_width,
-                                  const int frm_height, const int frm_stride,
-                                  int n_levels, int pad_size, int compute_grad,
-                                  ImagePyramid *frm_pyr) {
-  int cur_width, cur_height, cur_stride, cur_loc;
-  assert((frm_width >> n_levels) > 0);
-  assert((frm_height >> n_levels) > 0);
-
-  // Initialize first level
-  frm_pyr->n_levels = n_levels;
-  frm_pyr->pad_size = pad_size;
-  frm_pyr->widths[0] = frm_width;
-  frm_pyr->heights[0] = frm_height;
-  frm_pyr->strides[0] = frm_width + 2 * frm_pyr->pad_size;
-  // Point the beginning of the level buffer to the location inside
-  // the padded border
-  frm_pyr->level_loc[0] =
-      frm_pyr->strides[0] * frm_pyr->pad_size + frm_pyr->pad_size;
-  // This essentially copies the original buffer into the pyramid buffer
-  // without the original padding
-  av1_resize_plane(frm, frm_height, frm_width, frm_stride,
-                   frm_pyr->level_buffer + frm_pyr->level_loc[0],
-                   frm_pyr->heights[0], frm_pyr->widths[0],
-                   frm_pyr->strides[0]);
-
-  if (compute_grad) {
-    cur_width = frm_pyr->widths[0];
-    cur_height = frm_pyr->heights[0];
-    cur_stride = frm_pyr->strides[0];
-    cur_loc = frm_pyr->level_loc[0];
-    assert(frm_pyr->has_gradient && frm_pyr->level_dx_buffer != NULL &&
-           frm_pyr->level_dy_buffer != NULL);
-    // Computation x gradient
-    sobel_xy_image_gradient(frm_pyr->level_buffer + cur_loc, cur_stride,
-                            frm_pyr->level_dx_buffer + cur_loc, cur_stride,
-                            cur_height, cur_width, 1);
-
-    // Computation y gradient
-    sobel_xy_image_gradient(frm_pyr->level_buffer + cur_loc, cur_stride,
-                            frm_pyr->level_dy_buffer + cur_loc, cur_stride,
-                            cur_height, cur_width, 0);
-  }
-
-  // Start at the finest level and resize down to the coarsest level
-  for (int level = 1; level < n_levels; ++level) {
-    update_level_dims(frm_pyr, level);
-    cur_width = frm_pyr->widths[level];
-    cur_height = frm_pyr->heights[level];
-    cur_stride = frm_pyr->strides[level];
-    cur_loc = frm_pyr->level_loc[level];
-
-    av1_resize_plane(frm_pyr->level_buffer + frm_pyr->level_loc[level - 1],
-                     frm_pyr->heights[level - 1], frm_pyr->widths[level - 1],
-                     frm_pyr->strides[level - 1],
-                     frm_pyr->level_buffer + cur_loc, cur_height, cur_width,
-                     cur_stride);
-
-    if (compute_grad) {
-      assert(frm_pyr->has_gradient && frm_pyr->level_dx_buffer != NULL &&
-             frm_pyr->level_dy_buffer != NULL);
-      // Computation x gradient
-      sobel_xy_image_gradient(frm_pyr->level_buffer + cur_loc, cur_stride,
-                              frm_pyr->level_dx_buffer + cur_loc, cur_stride,
-                              cur_height, cur_width, 1);
-
-      // Computation y gradient
-      sobel_xy_image_gradient(frm_pyr->level_buffer + cur_loc, cur_stride,
-                              frm_pyr->level_dy_buffer + cur_loc, cur_stride,
-                              cur_height, cur_width, 0);
-    }
-  }
-}
-
-static INLINE void compute_flow_at_point(unsigned char *frm, unsigned char *ref,
-                                         double *dx, double *dy, int x, int y,
-                                         int width, int height, int stride,
-                                         double *u, double *v) {
-  double M[4] = { 0 };
-  double b[2] = { 0 };
-  double tmp_output_vec[2] = { 0 };
-  double error = 0;
-  int16_t dt[PATCH_SIZE * PATCH_SIZE];
-  double o_u = *u;
-  double o_v = *v;
+  compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M);
+  invert_2x2(M, M_inv);
 
   for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
-    error = compute_warp_and_error(ref, frm, width, height, stride, x, y, *u,
-                                   *v, dt);
-    if (error <= DISFLOW_ERROR_TR) break;
-    compute_flow_system(dx, stride, dy, stride, dt, PATCH_SIZE, M, b);
-    solve_2x2_system(M, b, tmp_output_vec);
-    *u += tmp_output_vec[0];
-    *v += tmp_output_vec[1];
+    compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
+    compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
+                        DISFLOW_PATCH_SIZE, b);
+
+    // Solve flow equations to find a better estimate for the flow vector
+    // at this point
+    const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
+    const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
+    *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
+    *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
+
+    if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
+      // Stop iteration when we're close to convergence
+      break;
+    }
   }
-  if (fabs(*u - o_u) > PATCH_SIZE || fabs(*v - o_u) > PATCH_SIZE) {
-    *u = o_u;
-    *v = o_v;
+}
+
+static void fill_flow_field_borders(double *flow, int width, int height,
+                                    int stride) {
+  // Calculate the bounds of the rectangle which was filled in by
+  // compute_flow_field() before calling this function.
+  // These indices are inclusive on both ends.
+  const int left_index = FLOW_BORDER;
+  const int right_index = (width - FLOW_BORDER - 1);
+  const int top_index = FLOW_BORDER;
+  const int bottom_index = (height - FLOW_BORDER - 1);
+
+  // Left area
+  for (int i = top_index; i <= bottom_index; i += 1) {
+    double *row = flow + i * stride;
+    const double left = row[left_index];
+    for (int j = 0; j < left_index; j++) {
+      row[j] = left;
+    }
+  }
+
+  // Right area
+  for (int i = top_index; i <= bottom_index; i += 1) {
+    double *row = flow + i * stride;
+    const double right = row[right_index];
+    for (int j = right_index + 1; j < width; j++) {
+      row[j] = right;
+    }
+  }
+
+  // Top area
+  const double *top_row = flow + top_index * stride;
+  for (int i = 0; i < top_index; i++) {
+    double *row = flow + i * stride;
+    memcpy(row, top_row, width * sizeof(*row));
+  }
+
+  // Bottom area
+  const double *bottom_row = flow + bottom_index * stride;
+  for (int i = bottom_index + 1; i < height; i++) {
+    double *row = flow + i * stride;
+    memcpy(row, bottom_row, width * sizeof(*row));
   }
 }
 
 // make sure flow_u and flow_v start at 0
-static bool compute_flow_field(ImagePyramid *frm_pyr, ImagePyramid *ref_pyr,
-                               double *flow_u, double *flow_v) {
-  int cur_width, cur_height, cur_stride, cur_loc, patch_loc, patch_center;
-  double *u_upscale =
-      aom_malloc(frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_u));
-  double *v_upscale =
-      aom_malloc(frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_v));
-  if (!(u_upscale && v_upscale)) {
-    aom_free(u_upscale);
-    aom_free(v_upscale);
-    return false;
-  }
+static void compute_flow_field(const ImagePyramid *src_pyr,
+                               const ImagePyramid *ref_pyr, FlowField *flow) {
+  assert(src_pyr->n_levels == ref_pyr->n_levels);
 
-  assert(frm_pyr->n_levels == ref_pyr->n_levels);
+  double *flow_u = flow->u;
+  double *flow_v = flow->v;
+
+  const size_t flow_size = flow->stride * (size_t)flow->height;
+  double *u_upscale = aom_malloc(flow_size * sizeof(*u_upscale));
+  double *v_upscale = aom_malloc(flow_size * sizeof(*v_upscale));
 
   // Compute flow field from coarsest to finest level of the pyramid
-  for (int level = frm_pyr->n_levels - 1; level >= 0; --level) {
-    cur_width = frm_pyr->widths[level];
-    cur_height = frm_pyr->heights[level];
-    cur_stride = frm_pyr->strides[level];
-    cur_loc = frm_pyr->level_loc[level];
+  for (int level = src_pyr->n_levels - 1; level >= 0; --level) {
+    const PyramidLayer *cur_layer = &src_pyr->layers[level];
+    const int cur_width = cur_layer->width;
+    const int cur_height = cur_layer->height;
+    const int cur_stride = cur_layer->stride;
 
-    for (int i = PATCH_SIZE; i < cur_height - PATCH_SIZE; i += PATCH_STEP) {
-      for (int j = PATCH_SIZE; j < cur_width - PATCH_SIZE; j += PATCH_STEP) {
-        patch_loc = i * cur_stride + j;
-        patch_center = patch_loc + PATCH_CENTER * cur_stride + PATCH_CENTER;
-        compute_flow_at_point(frm_pyr->level_buffer + cur_loc,
-                              ref_pyr->level_buffer + cur_loc,
-                              frm_pyr->level_dx_buffer + cur_loc + patch_loc,
-                              frm_pyr->level_dy_buffer + cur_loc + patch_loc, j,
-                              i, cur_width, cur_height, cur_stride,
-                              flow_u + patch_center, flow_v + patch_center);
+    const uint8_t *src_buffer = cur_layer->buffer;
+    const uint8_t *ref_buffer = ref_pyr->layers[level].buffer;
+
+    const int cur_flow_width = cur_width >> DOWNSAMPLE_SHIFT;
+    const int cur_flow_height = cur_height >> DOWNSAMPLE_SHIFT;
+    const int cur_flow_stride = flow->stride;
+
+    for (int i = FLOW_BORDER; i < cur_flow_height - FLOW_BORDER; i += 1) {
+      for (int j = FLOW_BORDER; j < cur_flow_width - FLOW_BORDER; j += 1) {
+        const int flow_field_idx = i * cur_flow_stride + j;
+
+        // Calculate the position of a patch of size DISFLOW_PATCH_SIZE pixels,
+        // which is centered on the region covered by this flow field entry
+        const int patch_center_x =
+            (j << DOWNSAMPLE_SHIFT) + UPSAMPLE_CENTER_OFFSET;  // In pixels
+        const int patch_center_y =
+            (i << DOWNSAMPLE_SHIFT) + UPSAMPLE_CENTER_OFFSET;  // In pixels
+        const int patch_tl_x = patch_center_x - DISFLOW_PATCH_CENTER;
+        const int patch_tl_y = patch_center_y - DISFLOW_PATCH_CENTER;
+        assert(patch_tl_x >= 0);
+        assert(patch_tl_y >= 0);
+
+        aom_compute_flow_at_point(src_buffer, ref_buffer, patch_tl_x,
+                                  patch_tl_y, cur_width, cur_height, cur_stride,
+                                  &flow_u[flow_field_idx],
+                                  &flow_v[flow_field_idx]);
       }
     }
-    // TODO(sarahparker) Replace this with upscale function in resize.c
+
+    // Fill in the areas which we haven't explicitly computed, with copies
+    // of the outermost values which we did compute
+    fill_flow_field_borders(flow_u, cur_flow_width, cur_flow_height,
+                            cur_flow_stride);
+    fill_flow_field_borders(flow_v, cur_flow_width, cur_flow_height,
+                            cur_flow_stride);
+
     if (level > 0) {
-      int h_upscale = frm_pyr->heights[level - 1];
-      int w_upscale = frm_pyr->widths[level - 1];
-      int s_upscale = frm_pyr->strides[level - 1];
-      for (int i = 0; i < h_upscale; ++i) {
-        for (int j = 0; j < w_upscale; ++j) {
-          u_upscale[j + i * s_upscale] =
-              flow_u[(int)(j >> 1) + (int)(i >> 1) * cur_stride];
-          v_upscale[j + i * s_upscale] =
-              flow_v[(int)(j >> 1) + (int)(i >> 1) * cur_stride];
+      const int upscale_flow_width = cur_flow_width << 1;
+      const int upscale_flow_height = cur_flow_height << 1;
+      const int upscale_stride = flow->stride;
+
+      av1_upscale_plane_double_prec(
+          flow_u, cur_flow_height, cur_flow_width, cur_flow_stride, u_upscale,
+          upscale_flow_height, upscale_flow_width, upscale_stride);
+      av1_upscale_plane_double_prec(
+          flow_v, cur_flow_height, cur_flow_width, cur_flow_stride, v_upscale,
+          upscale_flow_height, upscale_flow_width, upscale_stride);
+
+      // Multiply all flow vectors by 2.
+      // When we move down a pyramid level, the image resolution doubles.
+      // Thus we need to double all vectors in order for them to represent
+      // the same translation at the next level down
+      for (int i = 0; i < upscale_flow_height; i++) {
+        for (int j = 0; j < upscale_flow_width; j++) {
+          const int index = i * upscale_stride + j;
+          flow_u[index] = u_upscale[index] * 2.0;
+          flow_v[index] = v_upscale[index] * 2.0;
         }
       }
-      memcpy(flow_u, u_upscale,
-             frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_u));
-      memcpy(flow_v, v_upscale,
-             frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_v));
+
+      // If we didn't fill in the rightmost column or bottommost row during
+      // upsampling (in order to keep the upscaling ratio at exactly 2), fill
+      // them in here by copying the nearest column/row.
+      const PyramidLayer *next_layer = &src_pyr->layers[level - 1];
+      const int next_flow_width = next_layer->width >> DOWNSAMPLE_SHIFT;
+      const int next_flow_height = next_layer->height >> DOWNSAMPLE_SHIFT;
+
+      // Rightmost column
+      if (next_flow_width > upscale_flow_width) {
+        assert(next_flow_width == upscale_flow_width + 1);
+        for (int i = 0; i < upscale_flow_height; i++) {
+          const int index = i * upscale_stride + upscale_flow_width;
+          flow_u[index] = flow_u[index - 1];
+          flow_v[index] = flow_v[index - 1];
+        }
+      }
+
+      // Bottommost row
+      if (next_flow_height > upscale_flow_height) {
+        assert(next_flow_height == upscale_flow_height + 1);
+        for (int j = 0; j < next_flow_width; j++) {
+          const int index = upscale_flow_height * upscale_stride + j;
+          flow_u[index] = flow_u[index - upscale_stride];
+          flow_v[index] = flow_v[index - upscale_stride];
+        }
+      }
     }
   }
   aom_free(u_upscale);
   aom_free(v_upscale);
-  return true;
 }
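The per-level loop above relies on a simple invariant: after upscaling the
flow field to the next (finer) pyramid level, every vector is doubled, because
one pixel of motion at the coarse level corresponds to two pixels at the finer
level. The sketch below shows that propagation in one dimension, using
nearest-neighbour upsampling as a stand-in for av1_upscale_plane_double_prec();
it is an illustration, not the library's upscaler.

    #include <stdio.h>

    int main(void) {
      const double coarse[4] = { 0.5, 1.0, -2.0, 0.0 }; /* flow at level L+1 */
      double fine[8];                                   /* flow at level L   */

      for (int j = 0; j < 8; j++) {
        /* Each coarse entry covers two fine entries... */
        const double upscaled = coarse[j >> 1];
        /* ...and the vector itself is doubled, because one pixel of motion
         * at level L+1 corresponds to two pixels of motion at level L. */
        fine[j] = upscaled * 2.0;
      }

      for (int j = 0; j < 8; j++) printf("%4.1f ", fine[j]);
      printf("\n");
      return 0;
    }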
 
-int av1_compute_global_motion_disflow_based(
-    TransformationType type, unsigned char *frm_buffer, int frm_width,
-    int frm_height, int frm_stride, int *frm_corners, int num_frm_corners,
-    YV12_BUFFER_CONFIG *ref, int bit_depth, int *num_inliers_by_motion,
-    MotionModel *params_by_motion, int num_motions) {
-  unsigned char *ref_buffer = ref->y_buffer;
-  const int ref_width = ref->y_width;
-  const int ref_height = ref->y_height;
-  const int pad_size = AOMMAX(PATCH_SIZE, MIN_PAD);
-  int num_correspondences;
-  double *correspondences;
-  RansacFuncDouble ransac = av1_get_ransac_double_prec_type(type);
-  assert(frm_width == ref_width);
-  assert(frm_height == ref_height);
+static FlowField *alloc_flow_field(int frame_width, int frame_height) {
+  FlowField *flow = (FlowField *)aom_malloc(sizeof(FlowField));
+  if (flow == NULL) return NULL;
 
-  // Ensure the number of pyramid levels will work with the frame resolution
-  const int msb =
-      frm_width < frm_height ? get_msb(frm_width) : get_msb(frm_height);
-  const int n_levels = AOMMIN(msb, N_LEVELS);
+  // Calculate the size of the bottom (largest) layer of the flow pyramid
+  flow->width = frame_width >> DOWNSAMPLE_SHIFT;
+  flow->height = frame_height >> DOWNSAMPLE_SHIFT;
+  flow->stride = flow->width;
 
-  if (ref->flags & YV12_FLAG_HIGHBITDEPTH) {
-    ref_buffer = av1_downconvert_frame(ref, bit_depth);
+  const size_t flow_size = flow->stride * (size_t)flow->height;
+  flow->u = aom_calloc(flow_size, sizeof(*flow->u));
+  flow->v = aom_calloc(flow_size, sizeof(*flow->v));
+
+  if (flow->u == NULL || flow->v == NULL) {
+    aom_free(flow->u);
+    aom_free(flow->v);
+    aom_free(flow);
+    return NULL;
   }
 
-  // TODO(sarahparker) We will want to do the source pyramid computation
-  // outside of this function so it doesn't get recomputed for every
-  // reference. We also don't need to compute every pyramid level for the
-  // reference in advance, since lower levels can be overwritten once their
-  // flow field is computed and upscaled. I'll add these optimizations
-  // once the full implementation is working.
-  // Allocate frm image pyramids
-  int compute_gradient = 1;
-  ImagePyramid *frm_pyr =
-      alloc_pyramid(frm_width, frm_height, pad_size, compute_gradient);
-  if (!frm_pyr) return 0;
-  compute_flow_pyramids(frm_buffer, frm_width, frm_height, frm_stride, n_levels,
-                        pad_size, compute_gradient, frm_pyr);
-  // Allocate ref image pyramids
-  compute_gradient = 0;
-  ImagePyramid *ref_pyr =
-      alloc_pyramid(ref_width, ref_height, pad_size, compute_gradient);
-  if (!ref_pyr) {
-    free_pyramid(frm_pyr);
-    return 0;
-  }
-  compute_flow_pyramids(ref_buffer, ref_width, ref_height, ref->y_stride,
-                        n_levels, pad_size, compute_gradient, ref_pyr);
+  return flow;
+}
 
-  int ret = 0;
-  double *flow_u =
-      aom_malloc(frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_u));
-  double *flow_v =
-      aom_malloc(frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_v));
-  if (!(flow_u && flow_v)) goto Error;
+static void free_flow_field(FlowField *flow) {
+  aom_free(flow->u);
+  aom_free(flow->v);
+  aom_free(flow);
+}
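alloc_flow_field() sizes the field at the frame resolution shifted down by
DOWNSAMPLE_SHIFT, i.e. one (u, v) pair per flow-field cell rather than per
pixel. DOWNSAMPLE_SHIFT itself is defined elsewhere in disflow.c and is not
visible in this diff, so the value 3 below (one vector per 8x8 luma block) is
an assumption used only to make the arithmetic concrete.

    #include <stdio.h>

    #define DOWNSAMPLE_SHIFT 3 /* assumption for illustration only */

    int main(void) {
      const int frame_width = 1920, frame_height = 1080;
      const int flow_width = frame_width >> DOWNSAMPLE_SHIFT;   /* 240 */
      const int flow_height = frame_height >> DOWNSAMPLE_SHIFT; /* 135 */
      const size_t flow_size = (size_t)flow_width * flow_height;

      printf("flow field: %d x %d = %zu entries per component\n", flow_width,
             flow_height, flow_size);
      printf("memory for u and v: %zu bytes\n", 2 * flow_size * sizeof(double));
      return 0;
    }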
 
-  memset(flow_u, 0,
-         frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_u));
-  memset(flow_v, 0,
-         frm_pyr->strides[0] * frm_pyr->heights[0] * sizeof(*flow_v));
+// Compute flow field between `src` and `ref`, and then use that flow to
+// compute a global motion model relating the two frames.
+//
+// Following the convention in flow_estimation.h, the flow vectors are computed
+// at fixed points in `src` and point to the corresponding locations in `ref`,
+// regardless of the temporal ordering of the frames.
+bool av1_compute_global_motion_disflow(TransformationType type,
+                                       YV12_BUFFER_CONFIG *src,
+                                       YV12_BUFFER_CONFIG *ref, int bit_depth,
+                                       MotionModel *motion_models,
+                                       int num_motion_models) {
+  // Precompute information we will need about each frame
+  ImagePyramid *src_pyramid = src->y_pyramid;
+  CornerList *src_corners = src->corners;
+  ImagePyramid *ref_pyramid = ref->y_pyramid;
+  aom_compute_pyramid(src, bit_depth, src_pyramid);
+  av1_compute_corner_list(src_pyramid, src_corners);
+  aom_compute_pyramid(ref, bit_depth, ref_pyramid);
 
-  if (!compute_flow_field(frm_pyr, ref_pyr, flow_u, flow_v)) goto Error;
+  const int src_width = src_pyramid->layers[0].width;
+  const int src_height = src_pyramid->layers[0].height;
+  assert(ref_pyramid->layers[0].width == src_width);
+  assert(ref_pyramid->layers[0].height == src_height);
+
+  FlowField *flow = alloc_flow_field(src_width, src_height);
+  if (!flow) return false;
+
+  compute_flow_field(src_pyramid, ref_pyramid, flow);
 
   // find correspondences between the two images using the flow field
-  correspondences = aom_malloc(num_frm_corners * 4 * sizeof(*correspondences));
-  if (!correspondences) goto Error;
-  num_correspondences = determine_disflow_correspondence(
-      frm_corners, num_frm_corners, flow_u, flow_v, frm_width, frm_height,
-      frm_pyr->strides[0], correspondences);
-  ransac(correspondences, num_correspondences, num_inliers_by_motion,
-         params_by_motion, num_motions);
-
-  // Set num_inliers = 0 for motions with too few inliers so they are ignored.
-  for (int i = 0; i < num_motions; ++i) {
-    if (num_inliers_by_motion[i] < MIN_INLIER_PROB * num_correspondences) {
-      num_inliers_by_motion[i] = 0;
-    }
+  Correspondence *correspondences =
+      aom_malloc(src_corners->num_corners * sizeof(*correspondences));
+  if (!correspondences) {
+    free_flow_field(flow);
+    return false;
   }
 
-  // Return true if any one of the motions has inliers.
-  for (int i = 0; i < num_motions; ++i) {
-    if (num_inliers_by_motion[i] > 0) {
-      ret = 1;
-      break;
-    }
-  }
+  const int num_correspondences =
+      determine_disflow_correspondence(src_corners, flow, correspondences);
+
+  bool result = ransac(correspondences, num_correspondences, type,
+                       motion_models, num_motion_models);
 
   aom_free(correspondences);
-Error:
-  free_pyramid(frm_pyr);
-  free_pyramid(ref_pyr);
-  aom_free(flow_u);
-  aom_free(flow_v);
-  return ret;
+  free_flow_field(flow);
+  return result;
 }
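A sketch of how a caller might drive the new bool-returning entry point
follows. It assumes the caller has already attached valid y_pyramid and
corners metadata to the frames (the function dereferences both); MAX_INLIERS
and the helper name are illustrative, not libaom definitions.

    #include <stdbool.h>
    #include <stdlib.h>

    #include "aom_dsp/flow_estimation/disflow.h"

    #define MAX_INLIERS 4096 // illustrative bound, not a libaom constant

    // Hypothetical helper: fit up to `num_models` ROTZOOM models between two
    // frames whose pyramid/corner metadata the caller has already set up.
    static bool fit_disflow_models(YV12_BUFFER_CONFIG *src,
                                   YV12_BUFFER_CONFIG *ref, int bit_depth,
                                   MotionModel *models, int num_models) {
      for (int i = 0; i < num_models; i++) {
        // Each MotionModel needs caller-provided storage for its inlier
        // (x, y) pairs; the search fills params, inliers and num_inliers.
        models[i].inliers =
            malloc(2 * MAX_INLIERS * sizeof(*models[i].inliers));
        if (!models[i].inliers) {
          while (--i >= 0) free(models[i].inliers);
          return false;
        }
        models[i].num_inliers = 0;
      }
      // On failure the output models should not be used; the caller still
      // owns and must eventually free each models[i].inliers.
      return av1_compute_global_motion_disflow(ROTZOOM, src, ref, bit_depth,
                                               models, num_models);
    }
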
diff --git a/aom_dsp/flow_estimation/disflow.h b/aom_dsp/flow_estimation/disflow.h
index 52fb261..2e97ba2 100644
--- a/aom_dsp/flow_estimation/disflow.h
+++ b/aom_dsp/flow_estimation/disflow.h
@@ -12,18 +12,88 @@
 #ifndef AOM_AOM_DSP_FLOW_ESTIMATION_DISFLOW_H_
 #define AOM_AOM_DSP_FLOW_ESTIMATION_DISFLOW_H_
 
+#include <stdbool.h>
+
 #include "aom_dsp/flow_estimation/flow_estimation.h"
+#include "aom_dsp/rect.h"
 #include "aom_scale/yv12config.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-int av1_compute_global_motion_disflow_based(
-    TransformationType type, unsigned char *frm_buffer, int frm_width,
-    int frm_height, int frm_stride, int *frm_corners, int num_frm_corners,
-    YV12_BUFFER_CONFIG *ref, int bit_depth, int *num_inliers_by_motion,
-    MotionModel *params_by_motion, int num_motions);
+// Number of pyramid levels in disflow computation
+#define DISFLOW_PYRAMID_LEVELS 12
+
+// Size of square patches in the disflow dense grid
+// Must be a power of 2
+#define DISFLOW_PATCH_SIZE_LOG2 3
+#define DISFLOW_PATCH_SIZE (1 << DISFLOW_PATCH_SIZE_LOG2)
+// Center point of square patch
+#define DISFLOW_PATCH_CENTER ((DISFLOW_PATCH_SIZE / 2) - 1)
+
+// Overall scale of the `dx`, `dy` and `dt` arrays in the disflow code
+// In other words, the various derivatives are calculated with an internal
+// precision of (8 + DISFLOW_DERIV_SCALE_LOG2) bits, from an 8-bit input.
+//
+// This must be carefully synchronized with the code in sobel_filter()
+// (which fills the dx and dy arrays) and compute_flow_error() (which
+// fills dt); see the comments in those functions for more details
+#define DISFLOW_DERIV_SCALE_LOG2 3
+#define DISFLOW_DERIV_SCALE (1 << DISFLOW_DERIV_SCALE_LOG2)
+
+// Scale factor applied to each step in the main refinement loop
+//
+// This should be <= 1.0 to avoid overshoot. Values below 1.0
+// may help in some cases, but slow convergence overall, so
+// will require careful tuning.
+// TODO(rachelbarker): Tune this value
+#define DISFLOW_STEP_SIZE 1.0
+
+// Step size at which we should terminate iteration
+// The idea here is that, if we take a step which is much smaller than 1px in
+// size, then the values won't change much from iteration to iteration, so
+// many future steps will also be small, and that won't have much effect
+// on the ultimate result. So we can terminate early.
+//
+// To look at it another way, when we take a small step, that means that
+// either we're near to convergence (so can stop), or we're stuck in a
+// shallow valley and will take many iterations to get unstuck.
+//
+// Solving the latter properly requires fancier methods, such as "gradient
+// descent with momentum". For now, we terminate to avoid wasting a ton of
+// time on points which are either nearly-converged or stuck.
+//
+// Terminating at 1/8 px seems to give good results for global motion estimation
+#define DISFLOW_STEP_SIZE_THRESOLD (1. / 8.)
+
+// Max number of iterations if warp convergence is not found
+#define DISFLOW_MAX_ITR 4
+
+// Internal precision of cubic interpolation filters
+// The limiting factor here is that:
+// * Before integerizing, the maximum value of any kernel tap is 1.0
+// * After integerizing, each tap must fit into an int16_t.
+// Thus the largest multiplier we can get away with is 2^14 = 16384,
+// as 2^15 = 32768 is too large to fit in an int16_t.
+#define DISFLOW_INTERP_BITS 14
+
+typedef struct {
+  // x and y directions of flow, per patch
+  double *u;
+  double *v;
+
+  // Sizes of the above arrays
+  int width;
+  int height;
+  int stride;
+} FlowField;
+
+bool av1_compute_global_motion_disflow(TransformationType type,
+                                       YV12_BUFFER_CONFIG *src,
+                                       YV12_BUFFER_CONFIG *ref, int bit_depth,
+                                       MotionModel *motion_models,
+                                       int num_motion_models);
 
 #ifdef __cplusplus
 }
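The DISFLOW_INTERP_BITS comment caps the integerized kernel scale at 2^14 so
that a tap of magnitude 1.0 still fits in an int16_t. A minimal check of that
bound (the tap value is illustrative, not the actual cubic kernel):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
      // A tap at the maximum magnitude of 1.0, integerized at two candidate
      // scales. Only the 2^14 scale stays within int16_t range.
      const double tap = 1.0;
      const int q14 = (int)(tap * (1 << 14)); // 16384
      const int q15 = (int)(tap * (1 << 15)); // 32768
      printf("q14 = %d, fits int16_t: %s\n", q14,
             q14 <= INT16_MAX ? "yes" : "no");
      printf("q15 = %d, fits int16_t: %s\n", q15,
             q15 <= INT16_MAX ? "yes" : "no");
      return 0;
    }
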
diff --git a/aom_dsp/flow_estimation/flow_estimation.c b/aom_dsp/flow_estimation/flow_estimation.c
index d8cf8bd..a6bf942 100644
--- a/aom_dsp/flow_estimation/flow_estimation.c
+++ b/aom_dsp/flow_estimation/flow_estimation.c
@@ -11,49 +11,48 @@
 
 #include <assert.h>
 
+#include "aom_dsp/flow_estimation/corner_detect.h"
 #include "aom_dsp/flow_estimation/corner_match.h"
 #include "aom_dsp/flow_estimation/disflow.h"
 #include "aom_dsp/flow_estimation/flow_estimation.h"
 #include "aom_ports/mem.h"
 #include "aom_scale/yv12config.h"
 
-int aom_compute_global_motion(TransformationType type,
-                              unsigned char *src_buffer, int src_width,
-                              int src_height, int src_stride, int *src_corners,
-                              int num_src_corners, YV12_BUFFER_CONFIG *ref,
-                              int bit_depth,
-                              GlobalMotionEstimationType gm_estimation_type,
-                              int *num_inliers_by_motion,
-                              MotionModel *params_by_motion, int num_motions) {
-  switch (gm_estimation_type) {
-    case GLOBAL_MOTION_FEATURE_BASED:
-      return av1_compute_global_motion_feature_based(
-          type, src_buffer, src_width, src_height, src_stride, src_corners,
-          num_src_corners, ref, bit_depth, num_inliers_by_motion,
-          params_by_motion, num_motions);
-    case GLOBAL_MOTION_DISFLOW_BASED:
-      return av1_compute_global_motion_disflow_based(
-          type, src_buffer, src_width, src_height, src_stride, src_corners,
-          num_src_corners, ref, bit_depth, num_inliers_by_motion,
-          params_by_motion, num_motions);
+// For each global motion method, how many pyramid levels should we allocate?
+// Note that this is a maximum, and fewer levels will be allocated if the frame
+// is not large enough to need all of the specified levels
+const int global_motion_pyr_levels[GLOBAL_MOTION_METHODS] = {
+  1,   // GLOBAL_MOTION_METHOD_FEATURE_MATCH
+  16,  // GLOBAL_MOTION_METHOD_DISFLOW
+};
+
+// clang-format off
+const double kIdentityParams[MAX_PARAMDIM] = {
+  0.0, 0.0, 1.0, 0.0, 0.0, 1.0
+};
+// clang-format on
+
+// Compute a global motion model between the given source and ref frames.
+//
+// As is standard for video codecs, the resulting model maps from (x, y)
+// coordinates in `src` to the corresponding points in `ref`, regardless
+// of the temporal order of the two frames.
+//
+// Returns true if global motion estimation succeeded, false if not.
+// The output models should only be used if this function succeeds.
+bool aom_compute_global_motion(TransformationType type, YV12_BUFFER_CONFIG *src,
+                               YV12_BUFFER_CONFIG *ref, int bit_depth,
+                               GlobalMotionMethod gm_method,
+                               MotionModel *motion_models,
+                               int num_motion_models) {
+  switch (gm_method) {
+    case GLOBAL_MOTION_METHOD_FEATURE_MATCH:
+      return av1_compute_global_motion_feature_match(
+          type, src, ref, bit_depth, motion_models, num_motion_models);
+    case GLOBAL_MOTION_METHOD_DISFLOW:
+      return av1_compute_global_motion_disflow(
+          type, src, ref, bit_depth, motion_models, num_motion_models);
     default: assert(0 && "Unknown global motion estimation type");
   }
   return 0;
 }
-
-unsigned char *av1_downconvert_frame(YV12_BUFFER_CONFIG *frm, int bit_depth) {
-  int i, j;
-  uint16_t *orig_buf = CONVERT_TO_SHORTPTR(frm->y_buffer);
-  uint8_t *buf_8bit = frm->y_buffer_8bit;
-  assert(buf_8bit);
-  if (!frm->buf_8bit_valid) {
-    for (i = 0; i < frm->y_height; ++i) {
-      for (j = 0; j < frm->y_width; ++j) {
-        buf_8bit[i * frm->y_stride + j] =
-            orig_buf[i * frm->y_stride + j] >> (bit_depth - 8);
-      }
-    }
-    frm->buf_8bit_valid = 1;
-  }
-  return buf_8bit;
-}
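kIdentityParams encodes the identity transform in the 6-parameter layout used
by MotionModel (translation first, then the 2x2 matrix in row-major order).
Below is a small sketch of using it to reset a model when
aom_compute_global_motion() reports failure; the helper name is illustrative.

    #include <string.h>

    #include "aom_dsp/flow_estimation/flow_estimation.h"

    // Reset a motion model to the identity transform, mirroring the fallback
    // used inside the RANSAC code. Helper name is illustrative.
    static void reset_motion_model(MotionModel *model) {
      memcpy(model->params, kIdentityParams,
             MAX_PARAMDIM * sizeof(*model->params));
      model->num_inliers = 0;
    }

    // Example use (assumes src/ref/models are set up as elsewhere in this
    // patch):
    //   if (!aom_compute_global_motion(ROTZOOM, src, ref, bit_depth,
    //                                  GLOBAL_MOTION_METHOD_DISFLOW,
    //                                  models, 1)) {
    //     reset_motion_model(&models[0]);
    //   }
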
diff --git a/aom_dsp/flow_estimation/flow_estimation.h b/aom_dsp/flow_estimation/flow_estimation.h
index ab9d328..ea38b27 100644
--- a/aom_dsp/flow_estimation/flow_estimation.h
+++ b/aom_dsp/flow_estimation/flow_estimation.h
@@ -12,6 +12,8 @@
 #ifndef AOM_AOM_DSP_FLOW_ESTIMATION_H_
 #define AOM_AOM_DSP_FLOW_ESTIMATION_H_
 
+#include "aom_dsp/pyramid.h"
+#include "aom_dsp/flow_estimation/corner_detect.h"
 #include "aom_ports/mem.h"
 #include "aom_scale/yv12config.h"
 
@@ -19,8 +21,7 @@
 extern "C" {
 #endif
 
-#define MAX_PARAMDIM 9
-#define MAX_CORNERS 4096
+#define MAX_PARAMDIM 6
 #define MIN_INLIER_PROB 0.1
 
 /* clang-format off */
@@ -37,26 +38,48 @@
 static const int trans_model_params[TRANS_TYPES] = { 0, 2, 4, 6 };
 
 typedef enum {
-  GLOBAL_MOTION_FEATURE_BASED,
-  GLOBAL_MOTION_DISFLOW_BASED,
-} GlobalMotionEstimationType;
+  GLOBAL_MOTION_METHOD_FEATURE_MATCH,
+  GLOBAL_MOTION_METHOD_DISFLOW,
+  GLOBAL_MOTION_METHOD_LAST = GLOBAL_MOTION_METHOD_DISFLOW,
+  GLOBAL_MOTION_METHODS
+} GlobalMotionMethod;
 
 typedef struct {
-  double params[MAX_PARAMDIM - 1];
+  double params[MAX_PARAMDIM];
   int *inliers;
   int num_inliers;
 } MotionModel;
 
-int aom_compute_global_motion(TransformationType type,
-                              unsigned char *src_buffer, int src_width,
-                              int src_height, int src_stride, int *src_corners,
-                              int num_src_corners, YV12_BUFFER_CONFIG *ref,
-                              int bit_depth,
-                              GlobalMotionEstimationType gm_estimation_type,
-                              int *num_inliers_by_motion,
-                              MotionModel *params_by_motion, int num_motions);
+// Data structure to store a single correspondence point during global
+// motion search.
+//
+// A correspondence (x, y) -> (rx, ry) means that point (x, y) in the
+// source frame corresponds to point (rx, ry) in the ref frame.
+typedef struct {
+  double x, y;
+  double rx, ry;
+} Correspondence;
 
-unsigned char *av1_downconvert_frame(YV12_BUFFER_CONFIG *frm, int bit_depth);
+// For each global motion method, how many pyramid levels should we allocate?
+// Note that this is a maximum, and fewer levels will be allocated if the frame
+// is not large enough to need all of the specified levels
+extern const int global_motion_pyr_levels[GLOBAL_MOTION_METHODS];
+
+extern const double kIdentityParams[MAX_PARAMDIM];
+
+// Compute a global motion model between the given source and ref frames.
+//
+// As is standard for video codecs, the resulting model maps from (x, y)
+// coordinates in `src` to the corresponding points in `ref`, regardless
+// of the temporal order of the two frames.
+//
+// Returns true if global motion estimation succeeded, false if not.
+// The output models should only be used if this function succeeds.
+bool aom_compute_global_motion(TransformationType type, YV12_BUFFER_CONFIG *src,
+                               YV12_BUFFER_CONFIG *ref, int bit_depth,
+                               GlobalMotionMethod gm_method,
+                               MotionModel *motion_models,
+                               int num_motion_models);
 
 #ifdef __cplusplus
 }
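A Correspondence records that point (x, y) in the source frame matches point
(rx, ry) in the reference frame. The sketch below builds one correspondence
from an affine model written in the parameter layout used by find_rotzoom()
and find_affine() in ransac.c (rx = p[0] + p[2]*x + p[3]*y,
ry = p[1] + p[4]*x + p[5]*y); the numbers themselves are made up.

    #include <stdio.h>

    typedef struct {
      double x, y;
      double rx, ry;
    } Correspondence; /* same layout as in flow_estimation.h */

    int main(void) {
      /* Affine model in the 6-parameter layout described above. */
      const double p[6] = { 4.0, -2.0, 1.0, 0.1, -0.1, 1.0 };
      const double x = 100.0, y = 50.0;

      Correspondence c = { x, y,
                           p[0] + p[2] * x + p[3] * y,
                           p[1] + p[4] * x + p[5] * y };
      printf("(%g, %g) in src -> (%g, %g) in ref\n", c.x, c.y, c.rx, c.ry);
      return 0;
    }
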
diff --git a/aom_dsp/flow_estimation/ransac.c b/aom_dsp/flow_estimation/ransac.c
index 8ffc30d..81c5f2c 100644
--- a/aom_dsp/flow_estimation/ransac.c
+++ b/aom_dsp/flow_estimation/ransac.c
@@ -13,37 +13,54 @@
 #include <math.h>
 #include <time.h>
 #include <stdio.h>
-#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
 #include <assert.h>
 
 #include "aom_dsp/flow_estimation/ransac.h"
 #include "aom_dsp/mathutils.h"
+#include "aom_mem/aom_mem.h"
 
 // TODO(rachelbarker): Remove dependence on code in av1/encoder/
 #include "av1/encoder/random.h"
 
 #define MAX_MINPTS 4
-#define MAX_DEGENERATE_ITER 10
 #define MINPTS_MULTIPLIER 5
 
 #define INLIER_THRESHOLD 1.25
-#define MIN_TRIALS 20
+#define INLIER_THRESHOLD_SQUARED (INLIER_THRESHOLD * INLIER_THRESHOLD)
+#define NUM_TRIALS 20
+
+// Flag to enable functions for finding TRANSLATION type models.
+//
+// These model types are not currently considered, due to a spec bug (see
+// comments in gm_get_motion_vector() in av1/common/mv.h). Thus we don't need
+// to compile the corresponding search functions, but it is useful to keep the
+// source around, disabled, for completeness.
+#define ALLOW_TRANSLATION_MODELS 0
 
 ////////////////////////////////////////////////////////////////////////////////
 // ransac
-typedef int (*IsDegenerateFunc)(double *p);
-typedef void (*NormalizeFunc)(double *p, int np, double *T);
-typedef void (*DenormalizeFunc)(double *params, double *T1, double *T2);
-typedef int (*FindTransformationFunc)(int points, double *points1,
-                                      double *points2, double *params);
-typedef void (*ProjectPointsDoubleFunc)(double *mat, double *points,
-                                        double *proj, int n, int stride_points,
-                                        int stride_proj);
+typedef bool (*IsDegenerateFunc)(double *p);
+typedef bool (*FindTransformationFunc)(int points, const double *points1,
+                                       const double *points2, double *params);
+typedef void (*ProjectPointsFunc)(const double *mat, const double *points,
+                                  double *proj, int n, int stride_points,
+                                  int stride_proj);
 
-static void project_points_double_translation(double *mat, double *points,
-                                              double *proj, int n,
-                                              int stride_points,
-                                              int stride_proj) {
+// vtable-like structure which stores all of the information needed by RANSAC
+// for a particular model type
+typedef struct {
+  IsDegenerateFunc is_degenerate;
+  FindTransformationFunc find_transformation;
+  ProjectPointsFunc project_points;
+  int minpts;
+} RansacModelInfo;
+
+#if ALLOW_TRANSLATION_MODELS
+static void project_points_translation(const double *mat, const double *points,
+                                       double *proj, int n, int stride_points,
+                                       int stride_proj) {
   int i;
   for (i = 0; i < n; ++i) {
     const double x = *(points++), y = *(points++);
@@ -53,23 +70,11 @@
     proj += stride_proj - 2;
   }
 }
+#endif  // ALLOW_TRANSLATION_MODELS
 
-static void project_points_double_rotzoom(double *mat, double *points,
-                                          double *proj, int n,
-                                          int stride_points, int stride_proj) {
-  int i;
-  for (i = 0; i < n; ++i) {
-    const double x = *(points++), y = *(points++);
-    *(proj++) = mat[2] * x + mat[3] * y + mat[0];
-    *(proj++) = -mat[3] * x + mat[2] * y + mat[1];
-    points += stride_points - 2;
-    proj += stride_proj - 2;
-  }
-}
-
-static void project_points_double_affine(double *mat, double *points,
-                                         double *proj, int n, int stride_points,
-                                         int stride_proj) {
+static void project_points_affine(const double *mat, const double *points,
+                                  double *proj, int n, int stride_points,
+                                  int stride_proj) {
   int i;
   for (i = 0; i < n; ++i) {
     const double x = *(points++), y = *(points++);
@@ -80,261 +85,135 @@
   }
 }
 
-static void normalize_homography(double *pts, int n, double *T) {
-  double *p = pts;
-  double mean[2] = { 0, 0 };
-  double msqe = 0;
-  double scale;
-  int i;
+#if ALLOW_TRANSLATION_MODELS
+static bool find_translation(int np, const double *pts1, const double *pts2,
+                             double *params) {
+  double sumx = 0;
+  double sumy = 0;
 
-  assert(n > 0);
-  for (i = 0; i < n; ++i, p += 2) {
-    mean[0] += p[0];
-    mean[1] += p[1];
-  }
-  mean[0] /= n;
-  mean[1] /= n;
-  for (p = pts, i = 0; i < n; ++i, p += 2) {
-    p[0] -= mean[0];
-    p[1] -= mean[1];
-    msqe += sqrt(p[0] * p[0] + p[1] * p[1]);
-  }
-  msqe /= n;
-  scale = (msqe == 0 ? 1.0 : sqrt(2) / msqe);
-  T[0] = scale;
-  T[1] = 0;
-  T[2] = -scale * mean[0];
-  T[3] = 0;
-  T[4] = scale;
-  T[5] = -scale * mean[1];
-  T[6] = 0;
-  T[7] = 0;
-  T[8] = 1;
-  for (p = pts, i = 0; i < n; ++i, p += 2) {
-    p[0] *= scale;
-    p[1] *= scale;
-  }
-}
-
-static void invnormalize_mat(double *T, double *iT) {
-  double is = 1.0 / T[0];
-  double m0 = -T[2] * is;
-  double m1 = -T[5] * is;
-  iT[0] = is;
-  iT[1] = 0;
-  iT[2] = m0;
-  iT[3] = 0;
-  iT[4] = is;
-  iT[5] = m1;
-  iT[6] = 0;
-  iT[7] = 0;
-  iT[8] = 1;
-}
-
-static void denormalize_homography(double *params, double *T1, double *T2) {
-  double iT2[9];
-  double params2[9];
-  invnormalize_mat(T2, iT2);
-  multiply_mat(params, T1, params2, 3, 3, 3);
-  multiply_mat(iT2, params2, params, 3, 3, 3);
-}
-
-static void denormalize_affine_reorder(double *params, double *T1, double *T2) {
-  double params_denorm[MAX_PARAMDIM];
-  params_denorm[0] = params[0];
-  params_denorm[1] = params[1];
-  params_denorm[2] = params[4];
-  params_denorm[3] = params[2];
-  params_denorm[4] = params[3];
-  params_denorm[5] = params[5];
-  params_denorm[6] = params_denorm[7] = 0;
-  params_denorm[8] = 1;
-  denormalize_homography(params_denorm, T1, T2);
-  params[0] = params_denorm[2];
-  params[1] = params_denorm[5];
-  params[2] = params_denorm[0];
-  params[3] = params_denorm[1];
-  params[4] = params_denorm[3];
-  params[5] = params_denorm[4];
-  params[6] = params[7] = 0;
-}
-
-static void denormalize_rotzoom_reorder(double *params, double *T1,
-                                        double *T2) {
-  double params_denorm[MAX_PARAMDIM];
-  params_denorm[0] = params[0];
-  params_denorm[1] = params[1];
-  params_denorm[2] = params[2];
-  params_denorm[3] = -params[1];
-  params_denorm[4] = params[0];
-  params_denorm[5] = params[3];
-  params_denorm[6] = params_denorm[7] = 0;
-  params_denorm[8] = 1;
-  denormalize_homography(params_denorm, T1, T2);
-  params[0] = params_denorm[2];
-  params[1] = params_denorm[5];
-  params[2] = params_denorm[0];
-  params[3] = params_denorm[1];
-  params[4] = -params[3];
-  params[5] = params[2];
-  params[6] = params[7] = 0;
-}
-
-static void denormalize_translation_reorder(double *params, double *T1,
-                                            double *T2) {
-  double params_denorm[MAX_PARAMDIM];
-  params_denorm[0] = 1;
-  params_denorm[1] = 0;
-  params_denorm[2] = params[0];
-  params_denorm[3] = 0;
-  params_denorm[4] = 1;
-  params_denorm[5] = params[1];
-  params_denorm[6] = params_denorm[7] = 0;
-  params_denorm[8] = 1;
-  denormalize_homography(params_denorm, T1, T2);
-  params[0] = params_denorm[2];
-  params[1] = params_denorm[5];
-  params[2] = params[5] = 1;
-  params[3] = params[4] = 0;
-  params[6] = params[7] = 0;
-}
-
-static int find_translation(int np, double *pts1, double *pts2, double *mat) {
-  int i;
-  double sx, sy, dx, dy;
-  double sumx, sumy;
-
-  double T1[9], T2[9];
-  normalize_homography(pts1, np, T1);
-  normalize_homography(pts2, np, T2);
-
-  sumx = 0;
-  sumy = 0;
-  for (i = 0; i < np; ++i) {
-    dx = *(pts2++);
-    dy = *(pts2++);
-    sx = *(pts1++);
-    sy = *(pts1++);
+  for (int i = 0; i < np; ++i) {
+    double dx = *(pts2++);
+    double dy = *(pts2++);
+    double sx = *(pts1++);
+    double sy = *(pts1++);
 
     sumx += dx - sx;
     sumy += dy - sy;
   }
-  mat[0] = sumx / np;
-  mat[1] = sumy / np;
-  denormalize_translation_reorder(mat, T1, T2);
-  return 0;
+
+  params[0] = sumx / np;
+  params[1] = sumy / np;
+  params[2] = 1;
+  params[3] = 0;
+  params[4] = 0;
+  params[5] = 1;
+  return true;
+}
+#endif  // ALLOW_TRANSLATION_MODELS
+
+static bool find_rotzoom(int np, const double *pts1, const double *pts2,
+                         double *params) {
+  const int n = 4;    // Size of least-squares problem
+  double mat[4 * 4];  // Accumulator for A'A
+  double y[4];        // Accumulator for A'b
+  double a[4];        // Single row of A
+  double b;           // Single element of b
+
+  least_squares_init(mat, y, n);
+  for (int i = 0; i < np; ++i) {
+    double dx = *(pts2++);
+    double dy = *(pts2++);
+    double sx = *(pts1++);
+    double sy = *(pts1++);
+
+    a[0] = 1;
+    a[1] = 0;
+    a[2] = sx;
+    a[3] = sy;
+    b = dx;
+    least_squares_accumulate(mat, y, a, b, n);
+
+    a[0] = 0;
+    a[1] = 1;
+    a[2] = sy;
+    a[3] = -sx;
+    b = dy;
+    least_squares_accumulate(mat, y, a, b, n);
+  }
+
+  // Fill in params[0] .. params[3] with output model
+  if (!least_squares_solve(mat, y, params, n)) {
+    return false;
+  }
+
+  // Fill in remaining parameters
+  params[4] = -params[3];
+  params[5] = params[2];
+
+  return true;
 }
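find_rotzoom() never materializes the full least-squares matrix A; it
accumulates the normal equations A'A and A'b one row at a time and then
solves the small square system. The sketch below applies the same pattern to
a 2-parameter line fit; the local accumulate() helper is a stand-in for the
least_squares_* routines from aom_dsp/mathutils.h.

    #include <stdio.h>

    // Accumulate one row `a` (length n) with target b into A'A (n*n, row
    // major) and A'b (length n).
    static void accumulate(double *mat, double *y, const double *a, double b,
                           int n) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) mat[i * n + j] += a[i] * a[j];
        y[i] += a[i] * b;
      }
    }

    int main(void) {
      // Fit y = p0 + p1 * x to three points lying exactly on y = 1 + 2x.
      const double xs[3] = { 0.0, 1.0, 2.0 };
      const double ys[3] = { 1.0, 3.0, 5.0 };

      double mat[4] = { 0 }, rhs[2] = { 0 };
      for (int i = 0; i < 3; i++) {
        const double row[2] = { 1.0, xs[i] };
        accumulate(mat, rhs, row, ys[i], 2);
      }

      // Solve the 2x2 normal equations directly (Cramer's rule).
      const double det = mat[0] * mat[3] - mat[1] * mat[2];
      const double p0 = (rhs[0] * mat[3] - mat[1] * rhs[1]) / det;
      const double p1 = (mat[0] * rhs[1] - rhs[0] * mat[2]) / det;
      printf("p0 = %g, p1 = %g\n", p0, p1); // expect 1 and 2
      return 0;
    }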
 
-static int find_rotzoom(int np, double *pts1, double *pts2, double *mat) {
-  const int np2 = np * 2;
-  double *a = (double *)aom_malloc(sizeof(*a) * (np2 * 5 + 20));
-  if (a == NULL) return 1;
-  double *b = a + np2 * 4;
-  double *temp = b + np2;
-  int i;
-  double sx, sy, dx, dy;
+static bool find_affine(int np, const double *pts1, const double *pts2,
+                        double *params) {
+  // Note: The least squares problem for affine models is 6-dimensional,
+  // but it splits into two independent 3-dimensional subproblems.
+  // Solving these two subproblems separately and recombining at the end
+  // results in less total computation than solving the 6-dimensional
+  // problem directly.
+  //
+  // The two subproblems correspond to all the parameters which contribute
+  // to the x output of the model, and all the parameters which contribute
+  // to the y output, respectively.
 
-  double T1[9], T2[9];
-  normalize_homography(pts1, np, T1);
-  normalize_homography(pts2, np, T2);
+  const int n = 3;       // Size of each least-squares problem
+  double mat[2][3 * 3];  // Accumulator for A'A
+  double y[2][3];        // Accumulator for A'b
+  double x[2][3];        // Output vector
+  double a[2][3];        // Single row of A
+  double b[2];           // Single element of b
 
-  for (i = 0; i < np; ++i) {
-    dx = *(pts2++);
-    dy = *(pts2++);
-    sx = *(pts1++);
-    sy = *(pts1++);
+  least_squares_init(mat[0], y[0], n);
+  least_squares_init(mat[1], y[1], n);
+  for (int i = 0; i < np; ++i) {
+    double dx = *(pts2++);
+    double dy = *(pts2++);
+    double sx = *(pts1++);
+    double sy = *(pts1++);
 
-    a[i * 2 * 4 + 0] = sx;
-    a[i * 2 * 4 + 1] = sy;
-    a[i * 2 * 4 + 2] = 1;
-    a[i * 2 * 4 + 3] = 0;
-    a[(i * 2 + 1) * 4 + 0] = sy;
-    a[(i * 2 + 1) * 4 + 1] = -sx;
-    a[(i * 2 + 1) * 4 + 2] = 0;
-    a[(i * 2 + 1) * 4 + 3] = 1;
+    a[0][0] = 1;
+    a[0][1] = sx;
+    a[0][2] = sy;
+    b[0] = dx;
+    least_squares_accumulate(mat[0], y[0], a[0], b[0], n);
 
-    b[2 * i] = dx;
-    b[2 * i + 1] = dy;
+    a[1][0] = 1;
+    a[1][1] = sx;
+    a[1][2] = sy;
+    b[1] = dy;
+    least_squares_accumulate(mat[1], y[1], a[1], b[1], n);
   }
-  if (!least_squares(4, a, np2, 4, b, temp, mat)) {
-    aom_free(a);
-    return 1;
+
+  if (!least_squares_solve(mat[0], y[0], x[0], n)) {
+    return false;
   }
-  denormalize_rotzoom_reorder(mat, T1, T2);
-  aom_free(a);
-  return 0;
-}
-
-static int find_affine(int np, double *pts1, double *pts2, double *mat) {
-  assert(np > 0);
-  const int np2 = np * 2;
-  double *a = (double *)aom_malloc(sizeof(*a) * (np2 * 7 + 42));
-  if (a == NULL) return 1;
-  double *b = a + np2 * 6;
-  double *temp = b + np2;
-  int i;
-  double sx, sy, dx, dy;
-
-  double T1[9], T2[9];
-  normalize_homography(pts1, np, T1);
-  normalize_homography(pts2, np, T2);
-
-  for (i = 0; i < np; ++i) {
-    dx = *(pts2++);
-    dy = *(pts2++);
-    sx = *(pts1++);
-    sy = *(pts1++);
-
-    a[i * 2 * 6 + 0] = sx;
-    a[i * 2 * 6 + 1] = sy;
-    a[i * 2 * 6 + 2] = 0;
-    a[i * 2 * 6 + 3] = 0;
-    a[i * 2 * 6 + 4] = 1;
-    a[i * 2 * 6 + 5] = 0;
-    a[(i * 2 + 1) * 6 + 0] = 0;
-    a[(i * 2 + 1) * 6 + 1] = 0;
-    a[(i * 2 + 1) * 6 + 2] = sx;
-    a[(i * 2 + 1) * 6 + 3] = sy;
-    a[(i * 2 + 1) * 6 + 4] = 0;
-    a[(i * 2 + 1) * 6 + 5] = 1;
-
-    b[2 * i] = dx;
-    b[2 * i + 1] = dy;
+  if (!least_squares_solve(mat[1], y[1], x[1], n)) {
+    return false;
   }
-  if (!least_squares(6, a, np2, 6, b, temp, mat)) {
-    aom_free(a);
-    return 1;
-  }
-  denormalize_affine_reorder(mat, T1, T2);
-  aom_free(a);
-  return 0;
-}
 
-static int get_rand_indices(int npoints, int minpts, int *indices,
-                            unsigned int *seed) {
-  int i, j;
-  int ptr = lcg_rand16(seed) % npoints;
-  if (minpts > npoints) return 0;
-  indices[0] = ptr;
-  ptr = (ptr == npoints - 1 ? 0 : ptr + 1);
-  i = 1;
-  while (i < minpts) {
-    int index = lcg_rand16(seed) % npoints;
-    while (index) {
-      ptr = (ptr == npoints - 1 ? 0 : ptr + 1);
-      for (j = 0; j < i; ++j) {
-        if (indices[j] == ptr) break;
-      }
-      if (j == i) index--;
-    }
-    indices[i++] = ptr;
-  }
-  return 1;
+  // Rearrange least squares result to form output model
+  params[0] = x[0][0];
+  params[1] = x[1][0];
+  params[2] = x[0][1];
+  params[3] = x[0][2];
+  params[4] = x[1][1];
+  params[5] = x[1][2];
+
+  return true;
 }
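The block comment in find_affine() notes that the 6-parameter problem splits
into two independent 3-parameter subproblems. The sketch below makes that
visible by building the full 6x6 normal matrix with the parameters ordered
{tx, a, b, ty, c, d} (an ordering chosen here only to expose the block
structure; the output params array uses {tx, ty, a, b, c, d}): the
off-diagonal 3x3 blocks stay exactly zero, so each half can be solved on its
own.

    #include <stdio.h>

    int main(void) {
      const double src[3][2] = { { 1, 2 }, { 4, -1 }, { -3, 5 } };
      double ata[6][6] = { { 0 } };

      for (int k = 0; k < 3; k++) {
        const double sx = src[k][0], sy = src[k][1];
        // Each correspondence contributes two rows of A with disjoint
        // support: one touches only the x-output parameters, the other only
        // the y-output parameters.
        const double row_x[6] = { 1, sx, sy, 0, 0, 0 };
        const double row_y[6] = { 0, 0, 0, 1, sx, sy };
        for (int i = 0; i < 6; i++) {
          for (int j = 0; j < 6; j++) {
            ata[i][j] += row_x[i] * row_x[j] + row_y[i] * row_y[j];
          }
        }
      }

      for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 6; j++) printf("%7.1f ", ata[i][j]);
        printf("\n");
      }
      return 0;
    }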
 
 typedef struct {
   int num_inliers;
-  double variance;
+  double sse;  // Sum of squared errors of inliers
   int *inlier_indices;
 } RANSAC_MOTION;
 
@@ -345,13 +224,13 @@
 
   if (motion_a->num_inliers > motion_b->num_inliers) return -1;
   if (motion_a->num_inliers < motion_b->num_inliers) return 1;
-  if (motion_a->variance < motion_b->variance) return -1;
-  if (motion_a->variance > motion_b->variance) return 1;
+  if (motion_a->sse < motion_b->sse) return -1;
+  if (motion_a->sse > motion_b->sse) return 1;
   return 0;
 }
 
-static int is_better_motion(const RANSAC_MOTION *motion_a,
-                            const RANSAC_MOTION *motion_b) {
+static bool is_better_motion(const RANSAC_MOTION *motion_a,
+                             const RANSAC_MOTION *motion_b) {
   return compare_motions(motion_a, motion_b) < 0;
 }
 
@@ -364,24 +243,14 @@
   }
 }
 
-static const double kInfiniteVariance = 1e12;
-
-static void clear_motion(RANSAC_MOTION *motion, int num_points) {
-  motion->num_inliers = 0;
-  motion->variance = kInfiniteVariance;
-  memset(motion->inlier_indices, 0,
-         sizeof(*motion->inlier_indices) * num_points);
-}
-
-static int ransac(const int *matched_points, int npoints,
-                  int *num_inliers_by_motion, MotionModel *params_by_motion,
-                  int num_desired_motions, int minpts,
-                  IsDegenerateFunc is_degenerate,
-                  FindTransformationFunc find_transformation,
-                  ProjectPointsDoubleFunc projectpoints) {
-  int trial_count = 0;
+// Returns true on success, false on error
+static bool ransac_internal(const Correspondence *matched_points, int npoints,
+                            MotionModel *motion_models, int num_desired_motions,
+                            const RansacModelInfo *model_info) {
+  assert(npoints >= 0);
   int i = 0;
-  int ret_val = 0;
+  int minpts = model_info->minpts;
+  bool ret_val = true;
 
   unsigned int seed = (unsigned int)npoints;
 
@@ -389,7 +258,7 @@
 
   double *points1, *points2;
   double *corners1, *corners2;
-  double *image1_coord;
+  double *projected_corners;
 
   // Store information for the num_desired_motions best transformations found
   // and the worst motion among them, as well as the motion currently under
@@ -401,123 +270,115 @@
   // currently under consideration.
   double params_this_motion[MAX_PARAMDIM];
 
-  double *cnp1, *cnp2;
-
-  for (i = 0; i < num_desired_motions; ++i) {
-    num_inliers_by_motion[i] = 0;
-  }
   if (npoints < minpts * MINPTS_MULTIPLIER || npoints == 0) {
-    return 1;
+    return false;
   }
 
+  int min_inliers = AOMMAX((int)(MIN_INLIER_PROB * npoints), minpts);
+
   points1 = (double *)aom_malloc(sizeof(*points1) * npoints * 2);
   points2 = (double *)aom_malloc(sizeof(*points2) * npoints * 2);
   corners1 = (double *)aom_malloc(sizeof(*corners1) * npoints * 2);
   corners2 = (double *)aom_malloc(sizeof(*corners2) * npoints * 2);
-  image1_coord = (double *)aom_malloc(sizeof(*image1_coord) * npoints * 2);
+  projected_corners =
+      (double *)aom_malloc(sizeof(*projected_corners) * npoints * 2);
   motions =
       (RANSAC_MOTION *)aom_calloc(num_desired_motions, sizeof(RANSAC_MOTION));
-  current_motion.inlier_indices =
-      (int *)aom_malloc(sizeof(*current_motion.inlier_indices) * npoints);
-  if (!(points1 && points2 && corners1 && corners2 && image1_coord && motions &&
-        current_motion.inlier_indices)) {
-    ret_val = 1;
+
+  // Allocate one large buffer, which will be carved up to store the inlier
+  // indices for the current motion plus each of the num_desired_motions
+  // output models.
+  // This allows us to keep the allocation/deallocation logic simple, without
+  // having to (for example) check that `motions` is non-null before allocating
+  // the inlier arrays.
+  int *inlier_buffer = (int *)aom_malloc(sizeof(*inlier_buffer) * npoints *
+                                         (num_desired_motions + 1));
+
+  if (!(points1 && points2 && corners1 && corners2 && projected_corners &&
+        motions && inlier_buffer)) {
+    ret_val = false;
     goto finish_ransac;
   }
 
-  for (i = 0; i < num_desired_motions; ++i) {
-    motions[i].inlier_indices =
-        (int *)aom_malloc(sizeof(*motions->inlier_indices) * npoints);
-    if (!motions[i].inlier_indices) {
-      ret_val = 1;
-      goto finish_ransac;
-    }
-    clear_motion(motions + i, npoints);
-  }
-  clear_motion(&current_motion, npoints);
-
+  // Once all our allocations are known-good, we can fill in our structures
   worst_kept_motion = motions;
 
-  cnp1 = corners1;
-  cnp2 = corners2;
+  for (i = 0; i < num_desired_motions; ++i) {
+    motions[i].inlier_indices = inlier_buffer + i * npoints;
+  }
+  memset(&current_motion, 0, sizeof(current_motion));
+  current_motion.inlier_indices = inlier_buffer + num_desired_motions * npoints;
+
   for (i = 0; i < npoints; ++i) {
-    *(cnp1++) = *(matched_points++);
-    *(cnp1++) = *(matched_points++);
-    *(cnp2++) = *(matched_points++);
-    *(cnp2++) = *(matched_points++);
+    corners1[2 * i + 0] = matched_points[i].x;
+    corners1[2 * i + 1] = matched_points[i].y;
+    corners2[2 * i + 0] = matched_points[i].rx;
+    corners2[2 * i + 1] = matched_points[i].ry;
   }
 
-  while (MIN_TRIALS > trial_count) {
-    double sum_distance = 0.0;
-    double sum_distance_squared = 0.0;
+  for (int trial_count = 0; trial_count < NUM_TRIALS; trial_count++) {
+    lcg_pick(npoints, minpts, indices, &seed);
 
-    clear_motion(&current_motion, npoints);
+    copy_points_at_indices(points1, corners1, indices, minpts);
+    copy_points_at_indices(points2, corners2, indices, minpts);
 
-    int degenerate = 1;
-    int num_degenerate_iter = 0;
-
-    while (degenerate) {
-      num_degenerate_iter++;
-      if (!get_rand_indices(npoints, minpts, indices, &seed)) {
-        ret_val = 1;
-        goto finish_ransac;
-      }
-
-      copy_points_at_indices(points1, corners1, indices, minpts);
-      copy_points_at_indices(points2, corners2, indices, minpts);
-
-      degenerate = is_degenerate(points1);
-      if (num_degenerate_iter > MAX_DEGENERATE_ITER) {
-        ret_val = 1;
-        goto finish_ransac;
-      }
-    }
-
-    if (find_transformation(minpts, points1, points2, params_this_motion)) {
-      trial_count++;
+    if (model_info->is_degenerate(points1)) {
       continue;
     }
 
-    projectpoints(params_this_motion, corners1, image1_coord, npoints, 2, 2);
+    if (!model_info->find_transformation(minpts, points1, points2,
+                                         params_this_motion)) {
+      continue;
+    }
 
+    model_info->project_points(params_this_motion, corners1, projected_corners,
+                               npoints, 2, 2);
+
+    current_motion.num_inliers = 0;
+    double sse = 0.0;
     for (i = 0; i < npoints; ++i) {
-      double dx = image1_coord[i * 2] - corners2[i * 2];
-      double dy = image1_coord[i * 2 + 1] - corners2[i * 2 + 1];
-      double distance = sqrt(dx * dx + dy * dy);
+      double dx = projected_corners[i * 2] - corners2[i * 2];
+      double dy = projected_corners[i * 2 + 1] - corners2[i * 2 + 1];
+      double squared_error = dx * dx + dy * dy;
 
-      if (distance < INLIER_THRESHOLD) {
+      if (squared_error < INLIER_THRESHOLD_SQUARED) {
         current_motion.inlier_indices[current_motion.num_inliers++] = i;
-        sum_distance += distance;
-        sum_distance_squared += distance * distance;
+        sse += squared_error;
       }
     }
 
-    if (current_motion.num_inliers >= worst_kept_motion->num_inliers &&
-        current_motion.num_inliers > 1) {
-      double mean_distance;
-      mean_distance = sum_distance / ((double)current_motion.num_inliers);
-      current_motion.variance =
-          sum_distance_squared / ((double)current_motion.num_inliers - 1.0) -
-          mean_distance * mean_distance * ((double)current_motion.num_inliers) /
-              ((double)current_motion.num_inliers - 1.0);
-      if (is_better_motion(&current_motion, worst_kept_motion)) {
-        // This motion is better than the worst currently kept motion. Remember
-        // the inlier points and variance. The parameters for each kept motion
-        // will be recomputed later using only the inliers.
-        worst_kept_motion->num_inliers = current_motion.num_inliers;
-        worst_kept_motion->variance = current_motion.variance;
-        memcpy(worst_kept_motion->inlier_indices, current_motion.inlier_indices,
-               sizeof(*current_motion.inlier_indices) * npoints);
-        assert(npoints > 0);
-        // Determine the new worst kept motion and its num_inliers and variance.
-        for (i = 0; i < num_desired_motions; ++i) {
-          if (is_better_motion(worst_kept_motion, &motions[i])) {
-            worst_kept_motion = &motions[i];
-          }
+    if (current_motion.num_inliers < min_inliers) {
+      // Reject models with too few inliers
+      continue;
+    }
+
+    current_motion.sse = sse;
+    if (is_better_motion(&current_motion, worst_kept_motion)) {
+      // This motion is better than the worst currently kept motion. Remember
+      // the inlier points and sse. The parameters for each kept motion
+      // will be recomputed later using only the inliers.
+      worst_kept_motion->num_inliers = current_motion.num_inliers;
+      worst_kept_motion->sse = current_motion.sse;
+
+      // Rather than copying the (potentially many) inlier indices from
+      // current_motion.inlier_indices to worst_kept_motion->inlier_indices,
+      // we can swap the underlying pointers.
+      //
+      // This is okay because the next time current_motion.inlier_indices
+      // is used will be in the next trial, where we ignore its previous
+      // contents anyway. And both arrays will be deallocated together at the
+      // end of this function, so there are no lifetime issues.
+      int *tmp = worst_kept_motion->inlier_indices;
+      worst_kept_motion->inlier_indices = current_motion.inlier_indices;
+      current_motion.inlier_indices = tmp;
+
+      // Determine the new worst kept motion and its num_inliers and sse.
+      for (i = 0; i < num_desired_motions; ++i) {
+        if (is_better_motion(worst_kept_motion, &motions[i])) {
+          worst_kept_motion = &motions[i];
         }
       }
     }
-    trial_count++;
   }
 
   // Sort the motions, best first.
@@ -525,310 +386,96 @@
 
   // Recompute the motions using only the inliers.
   for (i = 0; i < num_desired_motions; ++i) {
-    if (motions[i].num_inliers >= minpts) {
+    int num_inliers = motions[i].num_inliers;
+    if (num_inliers > 0) {
+      assert(num_inliers >= minpts);
+
       copy_points_at_indices(points1, corners1, motions[i].inlier_indices,
-                             motions[i].num_inliers);
+                             num_inliers);
       copy_points_at_indices(points2, corners2, motions[i].inlier_indices,
-                             motions[i].num_inliers);
+                             num_inliers);
 
-      find_transformation(motions[i].num_inliers, points1, points2,
-                          params_by_motion[i].params);
+      if (!model_info->find_transformation(num_inliers, points1, points2,
+                                           motion_models[i].params)) {
+        // In the unlikely event that this model fitting fails,
+        // we don't have a good fallback. So just clear the output
+        // model and move on
+        memcpy(motion_models[i].params, kIdentityParams,
+               MAX_PARAMDIM * sizeof(*(motion_models[i].params)));
+        motion_models[i].num_inliers = 0;
+        continue;
+      }
 
-      params_by_motion[i].num_inliers = motions[i].num_inliers;
-      memcpy(params_by_motion[i].inliers, motions[i].inlier_indices,
-             sizeof(*motions[i].inlier_indices) * npoints);
-      num_inliers_by_motion[i] = motions[i].num_inliers;
+      // Populate inliers array
+      for (int j = 0; j < num_inliers; j++) {
+        int index = motions[i].inlier_indices[j];
+        const Correspondence *corr = &matched_points[index];
+        motion_models[i].inliers[2 * j + 0] = (int)rint(corr->x);
+        motion_models[i].inliers[2 * j + 1] = (int)rint(corr->y);
+      }
+      motion_models[i].num_inliers = num_inliers;
+    } else {
+      memcpy(motion_models[i].params, kIdentityParams,
+             MAX_PARAMDIM * sizeof(*(motion_models[i].params)));
+      motion_models[i].num_inliers = 0;
     }
   }
 
 finish_ransac:
-  aom_free(points1);
-  aom_free(points2);
-  aom_free(corners1);
+  aom_free(inlier_buffer);
+  aom_free(motions);
+  aom_free(projected_corners);
   aom_free(corners2);
-  aom_free(image1_coord);
-  aom_free(current_motion.inlier_indices);
-  if (motions) {
-    for (i = 0; i < num_desired_motions; ++i) {
-      aom_free(motions[i].inlier_indices);
-    }
-    aom_free(motions);
-  }
+  aom_free(corners1);
+  aom_free(points2);
+  aom_free(points1);
 
   return ret_val;
 }
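Two thresholds gate the search above: RANSAC only runs at all when
npoints >= minpts * MINPTS_MULTIPLIER, and a candidate model is only kept when
it has at least AOMMAX(MIN_INLIER_PROB * npoints, minpts) inliers. The sketch
below evaluates both for a few point counts, using the values from this patch
(MINPTS_MULTIPLIER = 5, MIN_INLIER_PROB = 0.1, minpts = 3 for ROTZOOM/AFFINE).

    #include <stdio.h>

    #define MINPTS_MULTIPLIER 5
    #define MIN_INLIER_PROB 0.1
    #define AOMMAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
      const int minpts = 3; // ROTZOOM/AFFINE sample size, per ransac_model_info
      const int npoints_list[3] = { 10, 40, 500 };

      for (int k = 0; k < 3; k++) {
        const int npoints = npoints_list[k];
        const int enough = npoints >= minpts * MINPTS_MULTIPLIER;
        const int min_inliers =
            AOMMAX((int)(MIN_INLIER_PROB * npoints), minpts);
        printf("npoints=%3d: search %s, models need >= %d inliers to be kept\n",
               npoints, enough ? "runs" : "is skipped", min_inliers);
      }
      return 0;
    }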
 
-static int ransac_double_prec(const double *matched_points, int npoints,
-                              int *num_inliers_by_motion,
-                              MotionModel *params_by_motion,
-                              int num_desired_motions, int minpts,
-                              IsDegenerateFunc is_degenerate,
-                              FindTransformationFunc find_transformation,
-                              ProjectPointsDoubleFunc projectpoints) {
-  int trial_count = 0;
-  int i = 0;
-  int ret_val = 0;
-
-  unsigned int seed = (unsigned int)npoints;
-
-  int indices[MAX_MINPTS] = { 0 };
-
-  double *points1, *points2;
-  double *corners1, *corners2;
-  double *image1_coord;
-
-  // Store information for the num_desired_motions best transformations found
-  // and the worst motion among them, as well as the motion currently under
-  // consideration.
-  RANSAC_MOTION *motions, *worst_kept_motion = NULL;
-  RANSAC_MOTION current_motion;
-
-  // Store the parameters and the indices of the inlier points for the motion
-  // currently under consideration.
-  double params_this_motion[MAX_PARAMDIM];
-
-  double *cnp1, *cnp2;
-
-  for (i = 0; i < num_desired_motions; ++i) {
-    num_inliers_by_motion[i] = 0;
-  }
-  if (npoints < minpts * MINPTS_MULTIPLIER || npoints == 0) {
-    return 1;
-  }
-
-  points1 = (double *)aom_malloc(sizeof(*points1) * npoints * 2);
-  points2 = (double *)aom_malloc(sizeof(*points2) * npoints * 2);
-  corners1 = (double *)aom_malloc(sizeof(*corners1) * npoints * 2);
-  corners2 = (double *)aom_malloc(sizeof(*corners2) * npoints * 2);
-  image1_coord = (double *)aom_malloc(sizeof(*image1_coord) * npoints * 2);
-  motions =
-      (RANSAC_MOTION *)aom_calloc(num_desired_motions, sizeof(RANSAC_MOTION));
-  current_motion.inlier_indices =
-      (int *)aom_malloc(sizeof(*current_motion.inlier_indices) * npoints);
-  if (!(points1 && points2 && corners1 && corners2 && image1_coord && motions &&
-        current_motion.inlier_indices)) {
-    ret_val = 1;
-    goto finish_ransac;
-  }
-
-  for (i = 0; i < num_desired_motions; ++i) {
-    motions[i].inlier_indices =
-        (int *)aom_malloc(sizeof(*motions->inlier_indices) * npoints);
-    if (!motions[i].inlier_indices) {
-      ret_val = 1;
-      goto finish_ransac;
-    }
-    clear_motion(motions + i, npoints);
-  }
-  clear_motion(&current_motion, npoints);
-
-  worst_kept_motion = motions;
-
-  cnp1 = corners1;
-  cnp2 = corners2;
-  for (i = 0; i < npoints; ++i) {
-    *(cnp1++) = *(matched_points++);
-    *(cnp1++) = *(matched_points++);
-    *(cnp2++) = *(matched_points++);
-    *(cnp2++) = *(matched_points++);
-  }
-
-  while (MIN_TRIALS > trial_count) {
-    double sum_distance = 0.0;
-    double sum_distance_squared = 0.0;
-
-    clear_motion(&current_motion, npoints);
-
-    int degenerate = 1;
-    int num_degenerate_iter = 0;
-
-    while (degenerate) {
-      num_degenerate_iter++;
-      if (!get_rand_indices(npoints, minpts, indices, &seed)) {
-        ret_val = 1;
-        goto finish_ransac;
-      }
-
-      copy_points_at_indices(points1, corners1, indices, minpts);
-      copy_points_at_indices(points2, corners2, indices, minpts);
-
-      degenerate = is_degenerate(points1);
-      if (num_degenerate_iter > MAX_DEGENERATE_ITER) {
-        ret_val = 1;
-        goto finish_ransac;
-      }
-    }
-
-    if (find_transformation(minpts, points1, points2, params_this_motion)) {
-      trial_count++;
-      continue;
-    }
-
-    projectpoints(params_this_motion, corners1, image1_coord, npoints, 2, 2);
-
-    for (i = 0; i < npoints; ++i) {
-      double dx = image1_coord[i * 2] - corners2[i * 2];
-      double dy = image1_coord[i * 2 + 1] - corners2[i * 2 + 1];
-      double distance = sqrt(dx * dx + dy * dy);
-
-      if (distance < INLIER_THRESHOLD) {
-        current_motion.inlier_indices[current_motion.num_inliers++] = i;
-        sum_distance += distance;
-        sum_distance_squared += distance * distance;
-      }
-    }
-
-    if (current_motion.num_inliers >= worst_kept_motion->num_inliers &&
-        current_motion.num_inliers > 1) {
-      double mean_distance;
-      mean_distance = sum_distance / ((double)current_motion.num_inliers);
-      current_motion.variance =
-          sum_distance_squared / ((double)current_motion.num_inliers - 1.0) -
-          mean_distance * mean_distance * ((double)current_motion.num_inliers) /
-              ((double)current_motion.num_inliers - 1.0);
-      if (is_better_motion(&current_motion, worst_kept_motion)) {
-        // This motion is better than the worst currently kept motion. Remember
-        // the inlier points and variance. The parameters for each kept motion
-        // will be recomputed later using only the inliers.
-        worst_kept_motion->num_inliers = current_motion.num_inliers;
-        worst_kept_motion->variance = current_motion.variance;
-        memcpy(worst_kept_motion->inlier_indices, current_motion.inlier_indices,
-               sizeof(*current_motion.inlier_indices) * npoints);
-        assert(npoints > 0);
-        // Determine the new worst kept motion and its num_inliers and variance.
-        for (i = 0; i < num_desired_motions; ++i) {
-          if (is_better_motion(worst_kept_motion, &motions[i])) {
-            worst_kept_motion = &motions[i];
-          }
-        }
-      }
-    }
-    trial_count++;
-  }
-
-  // Sort the motions, best first.
-  qsort(motions, num_desired_motions, sizeof(RANSAC_MOTION), compare_motions);
-
-  // Recompute the motions using only the inliers.
-  for (i = 0; i < num_desired_motions; ++i) {
-    if (motions[i].num_inliers >= minpts) {
-      copy_points_at_indices(points1, corners1, motions[i].inlier_indices,
-                             motions[i].num_inliers);
-      copy_points_at_indices(points2, corners2, motions[i].inlier_indices,
-                             motions[i].num_inliers);
-
-      find_transformation(motions[i].num_inliers, points1, points2,
-                          params_by_motion[i].params);
-      memcpy(params_by_motion[i].inliers, motions[i].inlier_indices,
-             sizeof(*motions[i].inlier_indices) * npoints);
-    }
-    num_inliers_by_motion[i] = motions[i].num_inliers;
-  }
-
-finish_ransac:
-  aom_free(points1);
-  aom_free(points2);
-  aom_free(corners1);
-  aom_free(corners2);
-  aom_free(image1_coord);
-  aom_free(current_motion.inlier_indices);
-  if (motions) {
-    for (i = 0; i < num_desired_motions; ++i) {
-      aom_free(motions[i].inlier_indices);
-    }
-    aom_free(motions);
-  }
-
-  return ret_val;
-}
-
-static int is_collinear3(double *p1, double *p2, double *p3) {
+static bool is_collinear3(double *p1, double *p2, double *p3) {
   static const double collinear_eps = 1e-3;
   const double v =
       (p2[0] - p1[0]) * (p3[1] - p1[1]) - (p2[1] - p1[1]) * (p3[0] - p1[0]);
   return fabs(v) < collinear_eps;
 }
 
-static int is_degenerate_translation(double *p) {
+#if ALLOW_TRANSLATION_MODELS
+static bool is_degenerate_translation(double *p) {
   return (p[0] - p[2]) * (p[0] - p[2]) + (p[1] - p[3]) * (p[1] - p[3]) <= 2;
 }
+#endif  // ALLOW_TRANSLATION_MODELS
 
-static int is_degenerate_affine(double *p) {
+static bool is_degenerate_affine(double *p) {
   return is_collinear3(p, p + 2, p + 4);
 }
 
-static int ransac_translation(int *matched_points, int npoints,
-                              int *num_inliers_by_motion,
-                              MotionModel *params_by_motion,
-                              int num_desired_motions) {
-  return ransac(matched_points, npoints, num_inliers_by_motion,
-                params_by_motion, num_desired_motions, 3,
-                is_degenerate_translation, find_translation,
-                project_points_double_translation);
-}
+static const RansacModelInfo ransac_model_info[TRANS_TYPES] = {
+  // IDENTITY
+  { NULL, NULL, NULL, 0 },
+// TRANSLATION
+#if ALLOW_TRANSLATION_MODELS
+  { is_degenerate_translation, find_translation, project_points_translation,
+    3 },
+#else
+  { NULL, NULL, NULL, 0 },
+#endif
+  // ROTZOOM
+  { is_degenerate_affine, find_rotzoom, project_points_affine, 3 },
+  // AFFINE
+  { is_degenerate_affine, find_affine, project_points_affine, 3 },
+};
 
-static int ransac_rotzoom(int *matched_points, int npoints,
-                          int *num_inliers_by_motion,
-                          MotionModel *params_by_motion,
-                          int num_desired_motions) {
-  return ransac(matched_points, npoints, num_inliers_by_motion,
-                params_by_motion, num_desired_motions, 3, is_degenerate_affine,
-                find_rotzoom, project_points_double_rotzoom);
-}
+// Returns true on success, false on error
+bool ransac(const Correspondence *matched_points, int npoints,
+            TransformationType type, MotionModel *motion_models,
+            int num_desired_motions) {
+#if ALLOW_TRANSLATION_MODELS
+  assert(type > IDENTITY && type < TRANS_TYPES);
+#else
+  assert(type > TRANSLATION && type < TRANS_TYPES);
+#endif  // ALLOW_TRANSLATION_MODELS
 
-static int ransac_affine(int *matched_points, int npoints,
-                         int *num_inliers_by_motion,
-                         MotionModel *params_by_motion,
-                         int num_desired_motions) {
-  return ransac(matched_points, npoints, num_inliers_by_motion,
-                params_by_motion, num_desired_motions, 3, is_degenerate_affine,
-                find_affine, project_points_double_affine);
-}
-
-RansacFunc av1_get_ransac_type(TransformationType type) {
-  switch (type) {
-    case AFFINE: return ransac_affine;
-    case ROTZOOM: return ransac_rotzoom;
-    case TRANSLATION: return ransac_translation;
-    default: assert(0); return NULL;
-  }
-}
-
-static int ransac_translation_double_prec(double *matched_points, int npoints,
-                                          int *num_inliers_by_motion,
-                                          MotionModel *params_by_motion,
-                                          int num_desired_motions) {
-  return ransac_double_prec(matched_points, npoints, num_inliers_by_motion,
-                            params_by_motion, num_desired_motions, 3,
-                            is_degenerate_translation, find_translation,
-                            project_points_double_translation);
-}
-
-static int ransac_rotzoom_double_prec(double *matched_points, int npoints,
-                                      int *num_inliers_by_motion,
-                                      MotionModel *params_by_motion,
-                                      int num_desired_motions) {
-  return ransac_double_prec(matched_points, npoints, num_inliers_by_motion,
-                            params_by_motion, num_desired_motions, 3,
-                            is_degenerate_affine, find_rotzoom,
-                            project_points_double_rotzoom);
-}
-
-static int ransac_affine_double_prec(double *matched_points, int npoints,
-                                     int *num_inliers_by_motion,
-                                     MotionModel *params_by_motion,
-                                     int num_desired_motions) {
-  return ransac_double_prec(matched_points, npoints, num_inliers_by_motion,
-                            params_by_motion, num_desired_motions, 3,
-                            is_degenerate_affine, find_affine,
-                            project_points_double_affine);
-}
-
-RansacFuncDouble av1_get_ransac_double_prec_type(TransformationType type) {
-  switch (type) {
-    case AFFINE: return ransac_affine_double_prec;
-    case ROTZOOM: return ransac_rotzoom_double_prec;
-    case TRANSLATION: return ransac_translation_double_prec;
-    default: assert(0); return NULL;
-  }
+  return ransac_internal(matched_points, npoints, motion_models,
+                         num_desired_motions, &ransac_model_info[type]);
 }
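+
+// Usage sketch (illustrative only; `corrs`, `num_corrs`, `models` and
+// `kNumMotions` are hypothetical caller-provided values, with
+// `models[i].inliers` already allocated by the caller):
+//   if (ransac(corrs, num_corrs, ROTZOOM, models, kNumMotions)) {
+//     // models[0].params / models[0].inliers now describe the best model.
+//   }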
diff --git a/aom_dsp/flow_estimation/ransac.h b/aom_dsp/flow_estimation/ransac.h
index aa3a243..6047580 100644
--- a/aom_dsp/flow_estimation/ransac.h
+++ b/aom_dsp/flow_estimation/ransac.h
@@ -16,6 +16,7 @@
 #include <stdlib.h>
 #include <math.h>
 #include <memory.h>
+#include <stdbool.h>
 
 #include "aom_dsp/flow_estimation/flow_estimation.h"
 
@@ -23,14 +24,9 @@
 extern "C" {
 #endif
 
-typedef int (*RansacFunc)(int *matched_points, int npoints,
-                          int *num_inliers_by_motion,
-                          MotionModel *params_by_motion, int num_motions);
-typedef int (*RansacFuncDouble)(double *matched_points, int npoints,
-                                int *num_inliers_by_motion,
-                                MotionModel *params_by_motion, int num_motions);
-RansacFunc av1_get_ransac_type(TransformationType type);
-RansacFuncDouble av1_get_ransac_double_prec_type(TransformationType type);
+bool ransac(const Correspondence *matched_points, int npoints,
+            TransformationType type, MotionModel *motion_models,
+            int num_desired_motions);
 
 #ifdef __cplusplus
 }
diff --git a/aom_dsp/flow_estimation/x86/corner_match_avx2.c b/aom_dsp/flow_estimation/x86/corner_match_avx2.c
index 9830ad8..87c76fa 100644
--- a/aom_dsp/flow_estimation/x86/corner_match_avx2.c
+++ b/aom_dsp/flow_estimation/x86/corner_match_avx2.c
@@ -24,12 +24,13 @@
 #error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13"
 #endif
 
-/* Compute corr(im1, im2) * MATCH_SZ * stddev(im1), where the
+/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
 correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
 of each image, centered at (x1, y1) and (x2, y2) respectively.
 */
-double av1_compute_cross_correlation_avx2(unsigned char *im1, int stride1,
-                                          int x1, int y1, unsigned char *im2,
+double av1_compute_cross_correlation_avx2(const unsigned char *frame1,
+                                          int stride1, int x1, int y1,
+                                          const unsigned char *frame2,
                                           int stride2, int x2, int y2) {
   int i, stride1_i = 0, stride2_i = 0;
   __m256i temp1, sum_vec, sumsq2_vec, cross_vec, v, v1_1, v2_1;
@@ -41,13 +42,13 @@
   sumsq2_vec = zero;
   cross_vec = zero;
 
-  im1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
-  im2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
+  frame1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
+  frame2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
 
   for (i = 0; i < MATCH_SZ; ++i) {
-    v1 = _mm_and_si128(_mm_loadu_si128((__m128i *)&im1[stride1_i]), mask);
+    v1 = _mm_and_si128(_mm_loadu_si128((__m128i *)&frame1[stride1_i]), mask);
     v1_1 = _mm256_cvtepu8_epi16(v1);
-    v2 = _mm_and_si128(_mm_loadu_si128((__m128i *)&im2[stride2_i]), mask);
+    v2 = _mm_and_si128(_mm_loadu_si128((__m128i *)&frame2[stride2_i]), mask);
     v2_1 = _mm256_cvtepu8_epi16(v2);
 
     v = _mm256_insertf128_si256(_mm256_castsi128_si256(v1), v2, 1);
diff --git a/aom_dsp/flow_estimation/x86/corner_match_sse4.c b/aom_dsp/flow_estimation/x86/corner_match_sse4.c
index 40eec6c..b3cb5bc 100644
--- a/aom_dsp/flow_estimation/x86/corner_match_sse4.c
+++ b/aom_dsp/flow_estimation/x86/corner_match_sse4.c
@@ -28,12 +28,13 @@
 #error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13"
 #endif
 
-/* Compute corr(im1, im2) * MATCH_SZ * stddev(im1), where the
+/* Compute corr(frame1, frame2) * MATCH_SZ * stddev(frame1), where the
    correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
    of each image, centered at (x1, y1) and (x2, y2) respectively.
 */
-double av1_compute_cross_correlation_sse4_1(unsigned char *im1, int stride1,
-                                            int x1, int y1, unsigned char *im2,
+double av1_compute_cross_correlation_sse4_1(const unsigned char *frame1,
+                                            int stride1, int x1, int y1,
+                                            const unsigned char *frame2,
                                             int stride2, int x2, int y2) {
   int i;
   // 2 16-bit partial sums in lanes 0, 4 (== 2 32-bit partial sums in lanes 0,
@@ -47,14 +48,14 @@
   const __m128i mask = _mm_load_si128((__m128i *)byte_mask);
   const __m128i zero = _mm_setzero_si128();
 
-  im1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
-  im2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
+  frame1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
+  frame2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
 
   for (i = 0; i < MATCH_SZ; ++i) {
     const __m128i v1 =
-        _mm_and_si128(_mm_loadu_si128((__m128i *)&im1[i * stride1]), mask);
+        _mm_and_si128(_mm_loadu_si128((__m128i *)&frame1[i * stride1]), mask);
     const __m128i v2 =
-        _mm_and_si128(_mm_loadu_si128((__m128i *)&im2[i * stride2]), mask);
+        _mm_and_si128(_mm_loadu_si128((__m128i *)&frame2[i * stride2]), mask);
 
     // Using the 'sad' intrinsic here is a bit faster than adding
     // v1_l + v1_r and v2_l + v2_r, plus it avoids the need for a 16->32 bit
diff --git a/aom_dsp/flow_estimation/x86/disflow_sse4.c b/aom_dsp/flow_estimation/x86/disflow_sse4.c
new file mode 100644
index 0000000..6ef40da
--- /dev/null
+++ b/aom_dsp/flow_estimation/x86/disflow_sse4.c
@@ -0,0 +1,556 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/.  If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#include <assert.h>
+#include <math.h>
+#include <smmintrin.h>
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/flow_estimation/disflow.h"
+#include "aom_dsp/x86/synonyms.h"
+
+#include "config/aom_dsp_rtcd.h"
+
+// Internal cross-check against C code
+// If you set this to 1 and compile in debug mode, then the outputs of the two
+// convolution stages will be checked against the plain C version of the code,
+// and an assertion will fire if the results differ.
+#define CHECK_RESULTS 1
+
+// Note: Max sum(+ve coefficients) = 1.125 * scale
+static INLINE void get_cubic_kernel_dbl(double x, double *kernel) {
+  assert(0 <= x && x < 1);
+  double x2 = x * x;
+  double x3 = x2 * x;
+  kernel[0] = -0.5 * x + x2 - 0.5 * x3;
+  kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3;
+  kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3;
+  kernel[3] = -0.5 * x2 + 0.5 * x3;
+}
+
+static INLINE void get_cubic_kernel_int(double x, int16_t *kernel) {
+  double kernel_dbl[4];
+  get_cubic_kernel_dbl(x, kernel_dbl);
+
+  kernel[0] = (int16_t)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS));
+  kernel[1] = (int16_t)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS));
+  kernel[2] = (int16_t)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS));
+  kernel[3] = (int16_t)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS));
+}
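+
+// Worked example (illustrative): at x = 0.5 the floating-point kernel is
+// { -0.0625, 0.5625, 0.5625, -0.0625 }; its positive taps sum to 1.125 (the
+// bound noted above) and the whole kernel sums to 1.0. The integer kernel is
+// simply this scaled by (1 << DISFLOW_INTERP_BITS) and rounded.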
+
+#if CHECK_RESULTS
+static INLINE int get_cubic_value_int(const int *p, const int16_t *kernel) {
+  return kernel[0] * p[0] + kernel[1] * p[1] + kernel[2] * p[2] +
+         kernel[3] * p[3];
+}
+#endif  // CHECK_RESULTS
+
+// Compare two regions of width x height pixels, one rooted at position
+// (x, y) in src and the other at (x + u, y + v) in ref.
+// This function returns the sum of squared pixel differences between
+// the two regions.
+//
+// TODO(rachelbarker): Test speed/quality impact of using bilinear interpolation
+// instead of bicubic interpolation
+static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref,
+                                      int width, int height, int stride, int x,
+                                      int y, double u, double v, int16_t *dt) {
+  // This function is written to do 8x8 convolutions only
+  assert(DISFLOW_PATCH_SIZE == 8);
+
+  // Split offset into integer and fractional parts, and compute cubic
+  // interpolation kernels
+  const int u_int = (int)floor(u);
+  const int v_int = (int)floor(v);
+  const double u_frac = u - floor(u);
+  const double v_frac = v - floor(v);
+
+  int16_t h_kernel[4];
+  int16_t v_kernel[4];
+  get_cubic_kernel_int(u_frac, h_kernel);
+  get_cubic_kernel_int(v_frac, v_kernel);
+
+  // Storage for intermediate values between the two convolution directions
+  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];
+  int16_t *tmp = tmp_ + DISFLOW_PATCH_SIZE;  // Offset by one row
+
+  // Clamp coordinates so that all pixels we fetch will remain within the
+  // allocated border region, but allow them to go far enough out that
+  // the border pixels' values do not change.
+  // Since we are calculating an 8x8 block, the bottom-right pixel
+  // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
+  // interpolation has 4 taps, meaning that the output of pixel
+  // (x_w, y_w) depends on the pixels in the range
+  // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
+  //
+  // Thus the most extreme coordinates which will be fetched are
+  // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
+  const int x0 = clamp(x + u_int, -9, width);
+  const int y0 = clamp(y + v_int, -9, height);
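+  // With x0 in [-9, width] and y0 in [-9, height], the extreme coordinates
+  // noted above lie in [-10, width + 9] x [-10, height + 9], i.e. no more
+  // than 10 pixels outside the frame on any side.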
+
+  // Horizontal convolution
+
+  // Prepare the kernel vectors
+  // We split the kernel into two vectors with kernel indices:
+  // 0, 1, 0, 1, 0, 1, 0, 1, and
+  // 2, 3, 2, 3, 2, 3, 2, 3
+  __m128i h_kernel_01 = xx_set2_epi16(h_kernel[0], h_kernel[1]);
+  __m128i h_kernel_23 = xx_set2_epi16(h_kernel[2], h_kernel[3]);
+
+  __m128i round_const_h = _mm_set1_epi32(1 << (DISFLOW_INTERP_BITS - 6 - 1));
+
+  for (int i = -1; i < DISFLOW_PATCH_SIZE + 2; ++i) {
+    const int y_w = y0 + i;
+    const uint8_t *ref_row = &ref[y_w * stride + (x0 - 1)];
+    int16_t *tmp_row = &tmp[i * DISFLOW_PATCH_SIZE];
+
+    // Load this row of pixels.
+    // For an 8x8 patch, we need to load the 8 image pixels + 3 extras,
+    // for a total of 11 pixels. Here we load 16 pixels, but only use
+    // the first 11.
+    __m128i row = _mm_loadu_si128((__m128i *)ref_row);
+
+    // Expand pixels to int16s
+    __m128i px_0to7_i16 = _mm_cvtepu8_epi16(row);
+    __m128i px_4to10_i16 = _mm_cvtepu8_epi16(_mm_srli_si128(row, 4));
+
+    // The key instruction here is _mm_madd_epi16, which multiplies 16-bit
+    // lanes pointwise and then sums adjacent pairs into 32-bit lanes.
+
+    // Compute first four outputs
+    // input pixels 0, 1, 1, 2, 2, 3, 3, 4
+    // * kernel     0, 1, 0, 1, 0, 1, 0, 1
+    __m128i px0 =
+        _mm_unpacklo_epi16(px_0to7_i16, _mm_srli_si128(px_0to7_i16, 2));
+    // input pixels 2, 3, 3, 4, 4, 5, 5, 6
+    // * kernel     2, 3, 2, 3, 2, 3, 2, 3
+    __m128i px1 = _mm_unpacklo_epi16(_mm_srli_si128(px_0to7_i16, 4),
+                                     _mm_srli_si128(px_0to7_i16, 6));
+    // Convolve with kernel and sum 2x2 boxes to form first 4 outputs
+    __m128i sum0 = _mm_add_epi32(_mm_madd_epi16(px0, h_kernel_01),
+                                 _mm_madd_epi16(px1, h_kernel_23));
+
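+    // After the two madds, 32-bit lane j of sum0 (j = 0..3) holds
+    // ref_row[j] * h_kernel[0] + ref_row[j + 1] * h_kernel[1] +
+    // ref_row[j + 2] * h_kernel[2] + ref_row[j + 3] * h_kernel[3], i.e. the
+    // full 4-tap horizontal output for column j before rounding; sum1 below
+    // does the same for columns 4..7.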
+    __m128i out0 = _mm_srai_epi32(_mm_add_epi32(sum0, round_const_h),
+                                  DISFLOW_INTERP_BITS - 6);
+
+    // Compute second four outputs
+    __m128i px2 =
+        _mm_unpacklo_epi16(px_4to10_i16, _mm_srli_si128(px_4to10_i16, 2));
+    __m128i px3 = _mm_unpacklo_epi16(_mm_srli_si128(px_4to10_i16, 4),
+                                     _mm_srli_si128(px_4to10_i16, 6));
+    __m128i sum1 = _mm_add_epi32(_mm_madd_epi16(px2, h_kernel_01),
+                                 _mm_madd_epi16(px3, h_kernel_23));
+
+    // Round by just enough bits that the result is guaranteed to fit into an
+    // int16_t. This lets the next stage use 16 x 16 -> 32 bit multiplies,
+    // which should be a fair bit faster than 32 x 32 -> 32 bit multiplies.
+    // This means shifting down so we have 6 extra bits, for a maximum value
+    // of +18360, which can occur if u_frac == 0.5 and the input pixels are
+    // {0, 255, 255, 0}.
+    __m128i out1 = _mm_srai_epi32(_mm_add_epi32(sum1, round_const_h),
+                                  DISFLOW_INTERP_BITS - 6);
+
+    _mm_storeu_si128((__m128i *)tmp_row, _mm_packs_epi32(out0, out1));
+
+#if CHECK_RESULTS && !defined(NDEBUG)
+    // Cross-check
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; ++j) {
+      const int x_w = x0 + j;
+      int arr[4];
+
+      arr[0] = (int)ref[y_w * stride + (x_w - 1)];
+      arr[1] = (int)ref[y_w * stride + (x_w + 0)];
+      arr[2] = (int)ref[y_w * stride + (x_w + 1)];
+      arr[3] = (int)ref[y_w * stride + (x_w + 2)];
+
+      // Apply kernel and round, keeping 6 extra bits of precision.
+      //
+      // 6 is the maximum allowable number of extra bits which will avoid
+      // the intermediate values overflowing an int16_t. The most extreme
+      // intermediate value occurs when:
+      // * The input pixels are [0, 255, 255, 0]
+      // * u_frac = 0.5
+      // In this case, the un-scaled output is 255 * 1.125 = 286.875.
+      // As an integer with 6 fractional bits, that is 18360, which fits
+      // in an int16_t. But with 7 fractional bits it would be 36720,
+      // which is too large.
+      const int c_value = ROUND_POWER_OF_TWO(get_cubic_value_int(arr, h_kernel),
+                                             DISFLOW_INTERP_BITS - 6);
+      (void)c_value;  // Suppress warnings
+      assert(tmp_row[j] == c_value);
+    }
+#endif  // CHECK_RESULTS
+  }
+
+  // Vertical convolution
+  const int round_bits = DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2;
+  __m128i round_const_v = _mm_set1_epi32(1 << (round_bits - 1));
+
+  __m128i v_kernel_01 = xx_set2_epi16(v_kernel[0], v_kernel[1]);
+  __m128i v_kernel_23 = xx_set2_epi16(v_kernel[2], v_kernel[3]);
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
+    int16_t *tmp_row = &tmp[i * DISFLOW_PATCH_SIZE];
+
+    // Load 4 rows of 8 x 16-bit values
+    __m128i px0 = _mm_loadu_si128((__m128i *)(tmp_row - DISFLOW_PATCH_SIZE));
+    __m128i px1 = _mm_loadu_si128((__m128i *)tmp_row);
+    __m128i px2 = _mm_loadu_si128((__m128i *)(tmp_row + DISFLOW_PATCH_SIZE));
+    __m128i px3 =
+        _mm_loadu_si128((__m128i *)(tmp_row + 2 * DISFLOW_PATCH_SIZE));
+
+    // We want to calculate px0 * v_kernel[0] + px1 * v_kernel[1] + ... ,
+    // but each multiply expands its output to 32 bits. So we need to be
+    // a little clever about how we do this
+    __m128i sum0 = _mm_add_epi32(
+        _mm_madd_epi16(_mm_unpacklo_epi16(px0, px1), v_kernel_01),
+        _mm_madd_epi16(_mm_unpacklo_epi16(px2, px3), v_kernel_23));
+    __m128i sum1 = _mm_add_epi32(
+        _mm_madd_epi16(_mm_unpackhi_epi16(px0, px1), v_kernel_01),
+        _mm_madd_epi16(_mm_unpackhi_epi16(px2, px3), v_kernel_23));
+
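+    // After interleaving with unpacklo/unpackhi, each 32-bit lane of
+    // sum0/sum1 holds px0[j] * v_kernel[0] + px1[j] * v_kernel[1] +
+    // px2[j] * v_kernel[2] + px3[j] * v_kernel[3] for one column j, i.e. the
+    // vertical 4-tap convolution output for that column.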
+    __m128i sum0_rounded =
+        _mm_srai_epi32(_mm_add_epi32(sum0, round_const_v), round_bits);
+    __m128i sum1_rounded =
+        _mm_srai_epi32(_mm_add_epi32(sum1, round_const_v), round_bits);
+
+    __m128i warped = _mm_packs_epi32(sum0_rounded, sum1_rounded);
+    __m128i src_pixels_u8 =
+        _mm_loadl_epi64((__m128i *)&src[(y + i) * stride + x]);
+    __m128i src_pixels = _mm_slli_epi16(_mm_cvtepu8_epi16(src_pixels_u8), 3);
+
+    // Calculate delta from the target patch
+    __m128i err = _mm_sub_epi16(warped, src_pixels);
+    _mm_storeu_si128((__m128i *)&dt[i * DISFLOW_PATCH_SIZE], err);
+
+#if CHECK_RESULTS
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; ++j) {
+      int16_t *p = &tmp[i * DISFLOW_PATCH_SIZE + j];
+      int arr[4] = { p[-DISFLOW_PATCH_SIZE], p[0], p[DISFLOW_PATCH_SIZE],
+                     p[2 * DISFLOW_PATCH_SIZE] };
+      const int result = get_cubic_value_int(arr, v_kernel);
+
+      // Apply kernel and round.
+      // This time, we have to round off the 6 extra bits which were kept
+      // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
+      // of precision to match the scale of the dx and dy arrays.
+      const int c_warped = ROUND_POWER_OF_TWO(result, round_bits);
+      const int c_src_px = src[(x + j) + (y + i) * stride] << 3;
+      const int c_err = c_warped - c_src_px;
+      (void)c_err;
+      assert(dt[i * DISFLOW_PATCH_SIZE + j] == c_err);
+    }
+#endif  // CHECK_RESULTS
+  }
+}
+
+static INLINE void sobel_filter_x(const uint8_t *src, int src_stride,
+                                  int16_t *dst, int dst_stride) {
+  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
+  int16_t *tmp = tmp_ + DISFLOW_PATCH_SIZE;
+  const int taps = 3;
+
+  // Horizontal filter
+  // As the kernel is simply {1, 0, -1}, we implement this as simply
+  //  out[x] = image[x-1] - image[x+1]
+  // rather than doing a "proper" convolution operation
+  for (int y = -1; y < DISFLOW_PATCH_SIZE + 1; ++y) {
+    const uint8_t *src_row = src + y * src_stride;
+    int16_t *tmp_row = tmp + y * DISFLOW_PATCH_SIZE;
+
+    // Load pixels and expand to 16 bits
+    __m128i row = _mm_loadu_si128((__m128i *)(src_row - 1));
+    __m128i px0 = _mm_cvtepu8_epi16(row);
+    __m128i px2 = _mm_cvtepu8_epi16(_mm_srli_si128(row, 2));
+
+    __m128i out = _mm_sub_epi16(px0, px2);
+
+    // Store to intermediate array
+    _mm_storeu_si128((__m128i *)tmp_row, out);
+
+#if CHECK_RESULTS
+    // Cross-check
+    static const int16_t h_kernel[3] = { 1, 0, -1 };
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
+      for (int k = 0; k < taps; ++k) {
+        sum += h_kernel[k] * src_row[x + k - 1];
+      }
+      (void)sum;
+      assert(tmp_row[x] == sum);
+    }
+#endif  // CHECK_RESULTS
+  }
+
+  // Vertical filter
+  // Here the kernel is {1, 2, 1}, which can be implemented
+  // with simple sums rather than multiplies and adds.
+  // In order to minimize dependency chains, we evaluate in the order
+  // (image[y - 1] + image[y + 1]) + (image[y] << 1)
+  // This way, the first addition and the shift can happen in parallel
+  for (int y = 0; y < DISFLOW_PATCH_SIZE; ++y) {
+    const int16_t *tmp_row = tmp + y * DISFLOW_PATCH_SIZE;
+    int16_t *dst_row = dst + y * dst_stride;
+
+    __m128i px0 = _mm_loadu_si128((__m128i *)(tmp_row - DISFLOW_PATCH_SIZE));
+    __m128i px1 = _mm_loadu_si128((__m128i *)tmp_row);
+    __m128i px2 = _mm_loadu_si128((__m128i *)(tmp_row + DISFLOW_PATCH_SIZE));
+
+    __m128i out =
+        _mm_add_epi16(_mm_add_epi16(px0, px2), _mm_slli_epi16(px1, 1));
+
+    _mm_storeu_si128((__m128i *)dst_row, out);
+
+#if CHECK_RESULTS
+    static const int16_t v_kernel[3] = { 1, 2, 1 };
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
+      for (int k = 0; k < taps; ++k) {
+        sum += v_kernel[k] * tmp[(y + k - 1) * DISFLOW_PATCH_SIZE + x];
+      }
+      (void)sum;
+      assert(dst_row[x] == sum);
+    }
+#endif  // CHECK_RESULTS
+  }
+}
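+
+// Note: the two passes above are together equivalent to convolving with the
+// separable 3x3 kernel
+//   {  1, 0, -1 }
+//   {  2, 0, -2 }
+//   {  1, 0, -1 }
+// i.e. the outer product of the vertical taps {1, 2, 1} and the horizontal
+// taps {1, 0, -1}.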
+
+static INLINE void sobel_filter_y(const uint8_t *src, int src_stride,
+                                  int16_t *dst, int dst_stride) {
+  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
+  int16_t *tmp = tmp_ + DISFLOW_PATCH_SIZE;
+  const int taps = 3;
+
+  // Horizontal filter
+  // Here the kernel is {1, 2, 1}, which can be implemented
+  // with simple sums rather than multiplies and adds.
+  // In order to minimize dependency chains, we evaluate in the order
+  // (image[y - 1] + image[y + 1]) + (image[y] << 1)
+  // This way, the first addition and the shift can happen in parallel
+  for (int y = -1; y < DISFLOW_PATCH_SIZE + 1; ++y) {
+    const uint8_t *src_row = src + y * src_stride;
+    int16_t *tmp_row = tmp + y * DISFLOW_PATCH_SIZE;
+
+    // Load pixels and expand to 16 bits
+    __m128i row = _mm_loadu_si128((__m128i *)(src_row - 1));
+    __m128i px0 = _mm_cvtepu8_epi16(row);
+    __m128i px1 = _mm_cvtepu8_epi16(_mm_srli_si128(row, 1));
+    __m128i px2 = _mm_cvtepu8_epi16(_mm_srli_si128(row, 2));
+
+    __m128i out =
+        _mm_add_epi16(_mm_add_epi16(px0, px2), _mm_slli_epi16(px1, 1));
+
+    // Store to intermediate array
+    _mm_storeu_si128((__m128i *)tmp_row, out);
+
+#if CHECK_RESULTS
+    // Cross-check
+    static const int16_t h_kernel[3] = { 1, 2, 1 };
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
+      for (int k = 0; k < taps; ++k) {
+        sum += h_kernel[k] * src_row[x + k - 1];
+      }
+      (void)sum;
+      assert(tmp_row[x] == sum);
+    }
+#endif  // CHECK_RESULTS
+  }
+
+  // Vertical filter
+  // As the kernel is simply {1, 0, -1}, we implement this as simply
+  //  out[x] = image[x-1] - image[x+1]
+  // rather than doing a "proper" convolution operation
+  for (int y = 0; y < DISFLOW_PATCH_SIZE; ++y) {
+    const int16_t *tmp_row = tmp + y * DISFLOW_PATCH_SIZE;
+    int16_t *dst_row = dst + y * dst_stride;
+
+    __m128i px0 = _mm_loadu_si128((__m128i *)(tmp_row - DISFLOW_PATCH_SIZE));
+    __m128i px2 = _mm_loadu_si128((__m128i *)(tmp_row + DISFLOW_PATCH_SIZE));
+
+    __m128i out = _mm_sub_epi16(px0, px2);
+
+    _mm_storeu_si128((__m128i *)dst_row, out);
+
+#if CHECK_RESULTS
+    static const int16_t v_kernel[3] = { 1, 0, -1 };
+    for (int x = 0; x < DISFLOW_PATCH_SIZE; ++x) {
+      int sum = 0;
+      for (int k = 0; k < taps; ++k) {
+        sum += v_kernel[k] * tmp[(y + k - 1) * DISFLOW_PATCH_SIZE + x];
+      }
+      (void)sum;
+      assert(dst_row[x] == sum);
+    }
+#endif  // CHECK_RESULTS
+  }
+}
+
+static INLINE void compute_flow_vector(const int16_t *dx, int dx_stride,
+                                       const int16_t *dy, int dy_stride,
+                                       const int16_t *dt, int dt_stride,
+                                       int *b) {
+  __m128i b0_acc = _mm_setzero_si128();
+  __m128i b1_acc = _mm_setzero_si128();
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    // Need to load 8 values of dx, 8 of dy, 8 of dt, which conveniently
+    // works out to one register each. Then just calculate dx * dt, dy * dt,
+    // and (implicitly) sum horizontally in pairs.
+    // This gives four 32-bit partial sums for each of b[0] and b[1],
+    // which can be accumulated and summed at the end.
+    __m128i dx_row = _mm_loadu_si128((__m128i *)&dx[i * dx_stride]);
+    __m128i dy_row = _mm_loadu_si128((__m128i *)&dy[i * dy_stride]);
+    __m128i dt_row = _mm_loadu_si128((__m128i *)&dt[i * dt_stride]);
+
+    b0_acc = _mm_add_epi32(b0_acc, _mm_madd_epi16(dx_row, dt_row));
+    b1_acc = _mm_add_epi32(b1_acc, _mm_madd_epi16(dy_row, dt_row));
+  }
+
+  // We need to set b[0] = sum(b0_acc), b[1] = sum(b1_acc).
+  // We might as well use a `hadd` instruction to do 4 of the additions
+  // needed here. Then that just leaves two more additions, which can be
+  // done in scalar code
+  __m128i partial_sum = _mm_hadd_epi32(b0_acc, b1_acc);
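+  // _mm_hadd_epi32(b0_acc, b1_acc) lays its lanes out as
+  // { b0_acc[0] + b0_acc[1], b0_acc[2] + b0_acc[3],
+  //   b1_acc[0] + b1_acc[1], b1_acc[2] + b1_acc[3] },
+  // so b[0] is the sum of lanes 0 and 1, and b[1] the sum of lanes 2 and 3.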
+  b[0] = _mm_extract_epi32(partial_sum, 0) + _mm_extract_epi32(partial_sum, 1);
+  b[1] = _mm_extract_epi32(partial_sum, 2) + _mm_extract_epi32(partial_sum, 3);
+
+#if CHECK_RESULTS
+  int c_result[2] = { 0 };
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; j++) {
+      c_result[0] += dx[i * dx_stride + j] * dt[i * dt_stride + j];
+      c_result[1] += dy[i * dy_stride + j] * dt[i * dt_stride + j];
+    }
+  }
+
+  assert(b[0] == c_result[0]);
+  assert(b[1] == c_result[1]);
+#endif  // CHECK_RESULTS
+}
+
+static INLINE void compute_flow_matrix(const int16_t *dx, int dx_stride,
+                                       const int16_t *dy, int dy_stride,
+                                       double *M) {
+  __m128i acc[4] = { 0 };
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    __m128i dx_row = _mm_loadu_si128((__m128i *)&dx[i * dx_stride]);
+    __m128i dy_row = _mm_loadu_si128((__m128i *)&dy[i * dy_stride]);
+
+    acc[0] = _mm_add_epi32(acc[0], _mm_madd_epi16(dx_row, dx_row));
+    acc[1] = _mm_add_epi32(acc[1], _mm_madd_epi16(dx_row, dy_row));
+    // Don't compute acc[2], as it should be equal to acc[1]
+    acc[3] = _mm_add_epi32(acc[3], _mm_madd_epi16(dy_row, dy_row));
+  }
+
+  // Condense sums
+  __m128i partial_sum_0 = _mm_hadd_epi32(acc[0], acc[1]);
+  __m128i partial_sum_1 = _mm_hadd_epi32(acc[1], acc[3]);
+  __m128i result = _mm_hadd_epi32(partial_sum_0, partial_sum_1);
+
+  // Apply regularization
+  // We follow the standard regularization method of adding `k * I` before
+  // inverting. This ensures that the matrix will be invertible.
+  //
+  // Setting the regularization strength k to 1 seems to work well here, as
+  // typical values coming from the other equations are very large (1e5 to
+  // 1e6, with an upper limit of around 6e7, at the time of writing).
+  // It also preserves the property that all matrix values are whole numbers,
+  // which is convenient for integerized SIMD implementation.
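+  //
+  // The lanes of `result` are laid out as { M00, M01, M10, M11 } (with
+  // M10 == M01 by construction), so the _mm_set_epi32(1, 0, 0, 1) below adds
+  // 1 to lanes 0 and 3 only, i.e. to the two diagonal entries.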
+  result = _mm_add_epi32(result, _mm_set_epi32(1, 0, 0, 1));
+
+#if CHECK_RESULTS
+  int tmp[4] = { 0 };
+
+  for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+    for (int j = 0; j < DISFLOW_PATCH_SIZE; j++) {
+      tmp[0] += dx[i * dx_stride + j] * dx[i * dx_stride + j];
+      tmp[1] += dx[i * dx_stride + j] * dy[i * dy_stride + j];
+      // Don't compute tmp[2], as it should be equal to tmp[1]
+      tmp[3] += dy[i * dy_stride + j] * dy[i * dy_stride + j];
+    }
+  }
+
+  // Apply regularization
+  tmp[0] += 1;
+  tmp[3] += 1;
+
+  tmp[2] = tmp[1];
+
+  assert(tmp[0] == _mm_extract_epi32(result, 0));
+  assert(tmp[1] == _mm_extract_epi32(result, 1));
+  assert(tmp[2] == _mm_extract_epi32(result, 2));
+  assert(tmp[3] == _mm_extract_epi32(result, 3));
+#endif  // CHECK_RESULTS
+
+  // Convert results to doubles and store
+  _mm_storeu_pd(M, _mm_cvtepi32_pd(result));
+  _mm_storeu_pd(M + 2, _mm_cvtepi32_pd(_mm_srli_si128(result, 8)));
+}
+
+// Try to invert the matrix M
+// Note: Due to the nature of how a least-squares matrix is constructed, all of
+// the eigenvalues will be >= 0, and therefore det M >= 0 as well.
+// The regularization term `+ k * I` further ensures that det M >= k^2.
+// As mentioned in compute_flow_matrix(), here we use k = 1, so det M >= 1.
+// So we don't have to worry about non-invertible matrices here.
+static INLINE void invert_2x2(const double *M, double *M_inv) {
+  double det = (M[0] * M[3]) - (M[1] * M[2]);
+  assert(det >= 1);
+  const double det_inv = 1 / det;
+
+  M_inv[0] = M[3] * det_inv;
+  M_inv[1] = -M[1] * det_inv;
+  M_inv[2] = -M[2] * det_inv;
+  M_inv[3] = M[0] * det_inv;
+}
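+
+// Worked check (illustrative): for M = { 3, 1, 1, 2 }, det = 5 and
+// M_inv = { 0.4, -0.2, -0.2, 0.6 }; multiplying back gives the 2x2 identity,
+// and det >= 1 as the assertion above requires.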
+
+void aom_compute_flow_at_point_sse4_1(const uint8_t *src, const uint8_t *ref,
+                                      int x, int y, int width, int height,
+                                      int stride, double *u, double *v) {
+  double M[4];
+  double M_inv[4];
+  int b[2];
+  int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+  int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+  int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+
+  // Compute gradients within this patch
+  const uint8_t *src_patch = &src[y * stride + x];
+  sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE);
+  sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE);
+
+  compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M);
+  invert_2x2(M, M_inv);
+
+  for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
+    compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
+    compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
+                        DISFLOW_PATCH_SIZE, b);
+
+    // Solve flow equations to find a better estimate for the flow vector
+    // at this point
+    const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
+    const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
+    *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
+    *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
+
+    if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
+      // Stop iteration when we're close to convergence
+      break;
+    }
+  }
+}
diff --git a/aom_dsp/fwd_txfm.c b/aom_dsp/fwd_txfm.c
index 3d30444..5503501 100644
--- a/aom_dsp/fwd_txfm.c
+++ b/aom_dsp/fwd_txfm.c
@@ -16,19 +16,16 @@
 void aom_fdct4x4_c(const int16_t *input, tran_low_t *output, int stride) {
   // The 2D transform is done with two passes which are actually pretty
   // similar. In the first one, we transform the columns and transpose
-  // the results. In the second one, we transform the rows. To achieve that,
-  // as the first pass results are transposed, we transpose the columns (that
-  // is the transposed rows) and transpose the results (so that it goes back
-  // in normal/row positions).
+  // the results. In the second one, we transform the rows.
   // We need an intermediate buffer between passes.
   tran_low_t intermediate[4 * 4];
   const tran_low_t *in_low = NULL;
   tran_low_t *out = intermediate;
-  // Do the two transform/transpose passes
+  // Do the two transform passes
   for (int pass = 0; pass < 2; ++pass) {
-    tran_high_t in_high[4];    // canbe16
-    tran_high_t step[4];       // canbe16
-    tran_high_t temp1, temp2;  // needs32
+    tran_high_t in_high[4];  // canbe16
+    tran_high_t step[4];     // canbe16
+    tran_low_t temp[4];
     for (int i = 0; i < 4; ++i) {
       // Load inputs.
       if (pass == 0) {
@@ -39,30 +36,40 @@
         if (i == 0 && in_high[0]) {
           ++in_high[0];
         }
+        ++input;  // Next column
       } else {
         assert(in_low != NULL);
         in_high[0] = in_low[0 * 4];
         in_high[1] = in_low[1 * 4];
         in_high[2] = in_low[2 * 4];
         in_high[3] = in_low[3 * 4];
-        ++in_low;
+        ++in_low;  // Next column (which is a transposed row)
       }
       // Transform.
       step[0] = in_high[0] + in_high[3];
       step[1] = in_high[1] + in_high[2];
       step[2] = in_high[1] - in_high[2];
       step[3] = in_high[0] - in_high[3];
-      temp1 = (step[0] + step[1]) * cospi_16_64;
-      temp2 = (step[0] - step[1]) * cospi_16_64;
-      out[0] = (tran_low_t)fdct_round_shift(temp1);
-      out[2] = (tran_low_t)fdct_round_shift(temp2);
-      temp1 = step[2] * cospi_24_64 + step[3] * cospi_8_64;
-      temp2 = -step[2] * cospi_8_64 + step[3] * cospi_24_64;
-      out[1] = (tran_low_t)fdct_round_shift(temp1);
-      out[3] = (tran_low_t)fdct_round_shift(temp2);
-      // Do next column (which is a transposed row in second/horizontal pass)
-      ++input;
-      out += 4;
+      temp[0] = (tran_low_t)fdct_round_shift((step[0] + step[1]) * cospi_16_64);
+      temp[2] = (tran_low_t)fdct_round_shift((step[0] - step[1]) * cospi_16_64);
+      temp[1] = (tran_low_t)fdct_round_shift(step[2] * cospi_24_64 +
+                                             step[3] * cospi_8_64);
+      temp[3] = (tran_low_t)fdct_round_shift(-step[2] * cospi_8_64 +
+                                             step[3] * cospi_24_64);
+      // Only transpose the first pass.
+      if (pass == 0) {
+        out[0] = temp[0];
+        out[1] = temp[1];
+        out[2] = temp[2];
+        out[3] = temp[3];
+        out += 4;
+      } else {
+        out[0 * 4] = temp[0];
+        out[1 * 4] = temp[1];
+        out[2 * 4] = temp[2];
+        out[3 * 4] = temp[3];
+        ++out;
+      }
     }
     // Setup in/out for next pass.
     in_low = intermediate;
@@ -78,19 +85,16 @@
 void aom_fdct4x4_lp_c(const int16_t *input, int16_t *output, int stride) {
   // The 2D transform is done with two passes which are actually pretty
   // similar. In the first one, we transform the columns and transpose
-  // the results. In the second one, we transform the rows. To achieve that,
-  // as the first pass results are transposed, we transpose the columns (that
-  // is the transposed rows) and transpose the results (so that it goes back
-  // in normal/row positions).
+  // the results. In the second one, we transform the rows.
   // We need an intermediate buffer between passes.
   int16_t intermediate[4 * 4];
   const int16_t *in_low = NULL;
   int16_t *out = intermediate;
-  // Do the two transform/transpose passes
+  // Do the two transform passes
   for (int pass = 0; pass < 2; ++pass) {
-    int32_t in_high[4];    // canbe16
-    int32_t step[4];       // canbe16
-    int32_t temp1, temp2;  // needs32
+    int32_t in_high[4];  // canbe16
+    int32_t step[4];     // canbe16
+    int16_t temp[4];
     for (int i = 0; i < 4; ++i) {
       // Load inputs.
       if (pass == 0) {
@@ -98,6 +102,7 @@
         in_high[1] = input[1 * stride] * 16;
         in_high[2] = input[2 * stride] * 16;
         in_high[3] = input[3 * stride] * 16;
+        ++input;
         if (i == 0 && in_high[0]) {
           ++in_high[0];
         }
@@ -114,17 +119,26 @@
       step[1] = in_high[1] + in_high[2];
       step[2] = in_high[1] - in_high[2];
       step[3] = in_high[0] - in_high[3];
-      temp1 = (step[0] + step[1]) * (int32_t)cospi_16_64;
-      temp2 = (step[0] - step[1]) * (int32_t)cospi_16_64;
-      out[0] = (int16_t)fdct_round_shift(temp1);
-      out[2] = (int16_t)fdct_round_shift(temp2);
-      temp1 = step[2] * (int32_t)cospi_24_64 + step[3] * (int32_t)cospi_8_64;
-      temp2 = -step[2] * (int32_t)cospi_8_64 + step[3] * (int32_t)cospi_24_64;
-      out[1] = (int16_t)fdct_round_shift(temp1);
-      out[3] = (int16_t)fdct_round_shift(temp2);
-      // Do next column (which is a transposed row in second/horizontal pass)
-      ++input;
-      out += 4;
+      temp[0] = (int16_t)fdct_round_shift((step[0] + step[1]) * cospi_16_64);
+      temp[2] = (int16_t)fdct_round_shift((step[0] - step[1]) * cospi_16_64);
+      temp[1] = (int16_t)fdct_round_shift(step[2] * cospi_24_64 +
+                                          step[3] * cospi_8_64);
+      temp[3] = (int16_t)fdct_round_shift(-step[2] * cospi_8_64 +
+                                          step[3] * cospi_24_64);
+      // Only transpose the first pass.
+      if (pass == 0) {
+        out[0] = temp[0];
+        out[1] = temp[1];
+        out[2] = temp[2];
+        out[3] = temp[3];
+        out += 4;
+      } else {
+        out[0 * 4] = temp[0];
+        out[1 * 4] = temp[1];
+        out[2 * 4] = temp[2];
+        out[3 * 4] = temp[3];
+        ++out;
+      }
     }
     // Setup in/out for next pass.
     in_low = intermediate;
@@ -137,6 +151,7 @@
   }
 }
 
+#if CONFIG_INTERNAL_STATS
 void aom_fdct8x8_c(const int16_t *input, tran_low_t *final_output, int stride) {
   int i, j;
   tran_low_t intermediate[64];
@@ -220,8 +235,9 @@
     for (j = 0; j < 8; ++j) final_output[j + i * 8] /= 2;
   }
 }
+#endif  // CONFIG_INTERNAL_STATS
 
-#if CONFIG_AV1_HIGHBITDEPTH
+#if CONFIG_AV1_HIGHBITDEPTH && CONFIG_INTERNAL_STATS
 void aom_highbd_fdct8x8_c(const int16_t *input, tran_low_t *final_output,
                           int stride) {
   aom_fdct8x8_c(input, final_output, stride);
diff --git a/aom_dsp/grain_table.h b/aom_dsp/grain_table.h
index 3f75101..49e8498 100644
--- a/aom_dsp/grain_table.h
+++ b/aom_dsp/grain_table.h
@@ -52,7 +52,7 @@
 /*!\brief Add a mapping from [time_stamp, end_time) to the given grain
  * parameters
  *
- * \param[in/out] table      The grain table
+ * \param[in,out] table      The grain table
  * \param[in]     time_stamp The start time stamp
  * \param[in]     end_stamp  The end time_stamp
  * \param[in]     grain      The grain parameters
diff --git a/aom_dsp/mathutils.h b/aom_dsp/mathutils.h
index 22b0202..cbb6cf4 100644
--- a/aom_dsp/mathutils.h
+++ b/aom_dsp/mathutils.h
@@ -63,32 +63,51 @@
 // Solves for n-dim x in a least squares sense to minimize |Ax - b|^2
 // The solution is simply x = (A'A)^-1 A'b or simply the solution for
 // the system: A'A x = A'b
-static INLINE int least_squares(int n, double *A, int rows, int stride,
-                                double *b, double *scratch, double *x) {
-  int i, j, k;
-  double *scratch_ = NULL;
-  double *AtA, *Atb;
-  if (!scratch) {
-    scratch_ = (double *)aom_malloc(sizeof(*scratch) * n * (n + 1));
-    if (!scratch_) return 0;
-    scratch = scratch_;
-  }
-  AtA = scratch;
-  Atb = scratch + n * n;
+//
+// This process is split into three steps in order to avoid needing to
+// explicitly allocate the A matrix, which may be very large if there
+// are many equations to solve.
+//
+// The process for using this is (in pseudocode):
+//
+// Allocate mat (size n*n), y (size n), a (size n), x (size n)
+// least_squares_init(mat, y, n)
+// for each equation a . x = b {
+//    least_squares_accumulate(mat, y, a, b, n)
+// }
+// least_squares_solve(mat, y, x, n)
+//
+// where:
+// * mat, y are accumulators for the values A'A and A'b respectively,
+// * a is the coefficient vector and b the right-hand side of each equation,
+// * x is the result vector
+// * and n is the problem size
+static INLINE void least_squares_init(double *mat, double *y, int n) {
+  memset(mat, 0, n * n * sizeof(double));
+  memset(y, 0, n * sizeof(double));
+}
 
-  for (i = 0; i < n; ++i) {
-    for (j = i; j < n; ++j) {
-      AtA[i * n + j] = 0.0;
-      for (k = 0; k < rows; ++k)
-        AtA[i * n + j] += A[k * stride + i] * A[k * stride + j];
-      AtA[j * n + i] = AtA[i * n + j];
+// Round the given positive value to nearest integer
+static AOM_FORCE_INLINE int iroundpf(float x) {
+  assert(x >= 0.0);
+  return (int)(x + 0.5f);
+}
+
+static INLINE void least_squares_accumulate(double *mat, double *y,
+                                            const double *a, double b, int n) {
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < n; j++) {
+      mat[i * n + j] += a[i] * a[j];
     }
-    Atb[i] = 0;
-    for (k = 0; k < rows; ++k) Atb[i] += A[k * stride + i] * b[k];
   }
-  int ret = linsolve(n, AtA, n, Atb, x);
-  aom_free(scratch_);
-  return ret;
+  for (int i = 0; i < n; i++) {
+    y[i] += a[i] * b;
+  }
+}
+
+static INLINE int least_squares_solve(double *mat, double *y, double *x,
+                                      int n) {
+  return linsolve(n, mat, n, y, x);
 }
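+
+// Usage sketch (illustrative only): fit obs ~= m * t + c to a set of (t, obs)
+// samples with the three-step API above. `example_fit_line` is a hypothetical
+// helper, kept under #if 0 so that it is never compiled.
+#if 0
+static INLINE int example_fit_line(const double *t, const double *obs,
+                                   int num_samples, double *m, double *c) {
+  double mat[2 * 2], y[2], x[2];
+  least_squares_init(mat, y, 2);
+  for (int k = 0; k < num_samples; k++) {
+    // One equation per sample: m * t[k] + c = obs[k]
+    const double a[2] = { t[k], 1.0 };
+    least_squares_accumulate(mat, y, a, obs[k], 2);
+  }
+  // Solve the accumulated normal equations A'A x = A'b
+  if (!least_squares_solve(mat, y, x, 2)) return 0;
+  *m = x[0];
+  *c = x[1];
+  return 1;
+}
+#endif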
 
 // Matrix multiply
@@ -108,4 +127,19 @@
   }
 }
 
+static AOM_INLINE float approx_exp(float y) {
+#define A ((1 << 23) / 0.69314718056f)  // (1 << 23) / ln(2)
+#define B \
+  127  // Offset for the exponent according to IEEE floating point standard.
+#define C 60801  // Magic number controlling the approximation accuracy
+  union {
+    float as_float;
+    int32_t as_int32;
+  } container;
+  container.as_int32 = ((int32_t)(y * A)) + ((B << 23) - C);
+  return container.as_float;
+#undef A
+#undef B
+#undef C
+}
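+
+// Why the trick above works (sketch): e^y = 2^(y / ln(2)), and reinterpreting
+// the integer (e << 23) as an IEEE-754 float gives roughly 2^(e - 127), so
+// storing y * (2^23 / ln(2)) + (127 << 23) in the bit pattern approximates
+// e^y. C is an empirical offset that reduces the approximation error.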
 #endif  // AOM_AOM_DSP_MATHUTILS_H_
diff --git a/aom_dsp/noise_model.c b/aom_dsp/noise_model.c
index 8521232..13eaf1c 100644
--- a/aom_dsp/noise_model.c
+++ b/aom_dsp/noise_model.c
@@ -571,7 +571,6 @@
   const int num_blocks_w = (w + block_size - 1) / block_size;
   const int num_blocks_h = (h + block_size - 1) / block_size;
   int num_flat = 0;
-  int bx = 0, by = 0;
   double *plane = (double *)aom_malloc(n * sizeof(*plane));
   double *block = (double *)aom_malloc(n * sizeof(*block));
   index_and_score_t *scores = (index_and_score_t *)aom_malloc(
@@ -587,19 +586,18 @@
 #ifdef NOISE_MODEL_LOG_SCORE
   fprintf(stderr, "score = [");
 #endif
-  for (by = 0; by < num_blocks_h; ++by) {
-    for (bx = 0; bx < num_blocks_w; ++bx) {
+  for (int by = 0; by < num_blocks_h; ++by) {
+    for (int bx = 0; bx < num_blocks_w; ++bx) {
       // Compute gradient covariance matrix.
-      double Gxx = 0, Gxy = 0, Gyy = 0;
-      double var = 0;
-      double mean = 0;
-      int xi, yi;
       aom_flat_block_finder_extract_block(block_finder, data, w, h, stride,
                                           bx * block_size, by * block_size,
                                           plane, block);
+      double Gxx = 0, Gxy = 0, Gyy = 0;
+      double mean = 0;
+      double var = 0;
 
-      for (yi = 1; yi < block_size - 1; ++yi) {
-        for (xi = 1; xi < block_size - 1; ++xi) {
+      for (int yi = 1; yi < block_size - 1; ++yi) {
+        for (int xi = 1; xi < block_size - 1; ++xi) {
           const double gx = (block[yi * block_size + xi + 1] -
                              block[yi * block_size + xi - 1]) /
                             2;
@@ -1623,6 +1621,8 @@
   return 1;
 }
 
+// TODO(aomedia:3151): Handle a monochrome image (sd->u_buffer and sd->v_buffer
+// are null pointers) correctly.
 int aom_denoise_and_model_run(struct aom_denoise_and_model_t *ctx,
                               YV12_BUFFER_CONFIG *sd,
                               aom_film_grain_t *film_grain, int apply_denoise) {
diff --git a/aom_dsp/noise_model.h b/aom_dsp/noise_model.h
index f385251..8228aea 100644
--- a/aom_dsp/noise_model.h
+++ b/aom_dsp/noise_model.h
@@ -293,13 +293,13 @@
  * parameter will be true when the input buffer was successfully denoised and
  * grain was modelled. Returns false on error.
  *
- * \param[in]      ctx          Struct allocated with
+ * \param[in]     ctx           Struct allocated with
  *                              aom_denoise_and_model_alloc that holds some
  *                              buffers for denoising and the current noise
  *                              estimate.
- * \param[in/out]   buf         The raw input buffer to be denoised.
+ * \param[in,out] buf           The raw input buffer to be denoised.
  * \param[out]    grain         Output film grain parameters
- * \param[out]    apply_denoise Whether or not to apply the denoising to the
+ * \param[in]     apply_denoise Whether or not to apply the denoising to the
  *                              frame that will be encoded
  */
 int aom_denoise_and_model_run(struct aom_denoise_and_model_t *ctx,
diff --git a/aom_dsp/prob.h b/aom_dsp/prob.h
index 5e25b9c..5711a40 100644
--- a/aom_dsp/prob.h
+++ b/aom_dsp/prob.h
@@ -31,16 +31,12 @@
 #define CDF_SIZE(x) ((x) + 1)
 #define CDF_PROB_BITS 15
 #define CDF_PROB_TOP (1 << CDF_PROB_BITS)
-#define CDF_INIT_TOP 32768
-#define CDF_SHIFT (15 - CDF_PROB_BITS)
 /*The value stored in an iCDF is CDF_PROB_TOP minus the actual cumulative
   probability (an "inverse" CDF).
   This function converts from one representation to the other (and is its own
   inverse).*/
 #define AOM_ICDF(x) (CDF_PROB_TOP - (x))
 
-#if CDF_SHIFT == 0
-
 #define AOM_CDF2(a0) AOM_ICDF(a0), AOM_ICDF(CDF_PROB_TOP), 0
 #define AOM_CDF3(a0, a1) AOM_ICDF(a0), AOM_ICDF(a1), AOM_ICDF(CDF_PROB_TOP), 0
 #define AOM_CDF4(a0, a1, a2) \
@@ -101,535 +97,6 @@
       AOM_ICDF(a11), AOM_ICDF(a12), AOM_ICDF(a13), AOM_ICDF(a14),             \
       AOM_ICDF(CDF_PROB_TOP), 0
 
-#else
-#define AOM_CDF2(a0)                                       \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 2) + \
-            ((CDF_INIT_TOP - 2) >> 1)) /                   \
-               ((CDF_INIT_TOP - 2)) +                      \
-           1)                                              \
-  , AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF3(a0, a1)                                       \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 3) +     \
-            ((CDF_INIT_TOP - 3) >> 1)) /                       \
-               ((CDF_INIT_TOP - 3)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 3) + \
-                ((CDF_INIT_TOP - 3) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 3)) +                      \
-               2),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF4(a0, a1, a2)                                   \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 4) +     \
-            ((CDF_INIT_TOP - 4) >> 1)) /                       \
-               ((CDF_INIT_TOP - 4)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 4) + \
-                ((CDF_INIT_TOP - 4) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 4)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 4) + \
-                ((CDF_INIT_TOP - 4) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 4)) +                      \
-               3),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF5(a0, a1, a2, a3)                               \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 5) +     \
-            ((CDF_INIT_TOP - 5) >> 1)) /                       \
-               ((CDF_INIT_TOP - 5)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 5) + \
-                ((CDF_INIT_TOP - 5) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 5)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 5) + \
-                ((CDF_INIT_TOP - 5) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 5)) +                      \
-               3),                                             \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 5) + \
-                ((CDF_INIT_TOP - 5) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 5)) +                      \
-               4),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF6(a0, a1, a2, a3, a4)                           \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 6) +     \
-            ((CDF_INIT_TOP - 6) >> 1)) /                       \
-               ((CDF_INIT_TOP - 6)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 6) + \
-                ((CDF_INIT_TOP - 6) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 6)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 6) + \
-                ((CDF_INIT_TOP - 6) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 6)) +                      \
-               3),                                             \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 6) + \
-                ((CDF_INIT_TOP - 6) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 6)) +                      \
-               4),                                             \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 6) + \
-                ((CDF_INIT_TOP - 6) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 6)) +                      \
-               5),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF7(a0, a1, a2, a3, a4, a5)                       \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) +     \
-            ((CDF_INIT_TOP - 7) >> 1)) /                       \
-               ((CDF_INIT_TOP - 7)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) + \
-                ((CDF_INIT_TOP - 7) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 7)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) + \
-                ((CDF_INIT_TOP - 7) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 7)) +                      \
-               3),                                             \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) + \
-                ((CDF_INIT_TOP - 7) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 7)) +                      \
-               4),                                             \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) + \
-                ((CDF_INIT_TOP - 7) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 7)) +                      \
-               5),                                             \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 7) + \
-                ((CDF_INIT_TOP - 7) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 7)) +                      \
-               6),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF8(a0, a1, a2, a3, a4, a5, a6)                   \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) +     \
-            ((CDF_INIT_TOP - 8) >> 1)) /                       \
-               ((CDF_INIT_TOP - 8)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               3),                                             \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               4),                                             \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               5),                                             \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               6),                                             \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 8) + \
-                ((CDF_INIT_TOP - 8) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 8)) +                      \
-               7),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF9(a0, a1, a2, a3, a4, a5, a6, a7)               \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) +     \
-            ((CDF_INIT_TOP - 9) >> 1)) /                       \
-               ((CDF_INIT_TOP - 9)) +                          \
-           1)                                                  \
-  ,                                                            \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               2),                                             \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               3),                                             \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               4),                                             \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               5),                                             \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               6),                                             \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               7),                                             \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 9) + \
-                ((CDF_INIT_TOP - 9) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 9)) +                      \
-               8),                                             \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF10(a0, a1, a2, a3, a4, a5, a6, a7, a8)           \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) +     \
-            ((CDF_INIT_TOP - 10) >> 1)) /                       \
-               ((CDF_INIT_TOP - 10)) +                          \
-           1)                                                   \
-  ,                                                             \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               2),                                              \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               3),                                              \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               4),                                              \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               5),                                              \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               6),                                              \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               7),                                              \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               8),                                              \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 10) + \
-                ((CDF_INIT_TOP - 10) >> 1)) /                   \
-                   ((CDF_INIT_TOP - 10)) +                      \
-               9),                                              \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF11(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9)        \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +      \
-            ((CDF_INIT_TOP - 11) >> 1)) /                        \
-               ((CDF_INIT_TOP - 11)) +                           \
-           1)                                                    \
-  ,                                                              \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               2),                                               \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               3),                                               \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               4),                                               \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               5),                                               \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               6),                                               \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               7),                                               \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               8),                                               \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) +  \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               9),                                               \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 11) + \
-                ((CDF_INIT_TOP - 11) >> 1)) /                    \
-                   ((CDF_INIT_TOP - 11)) +                       \
-               10),                                              \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF12(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10)    \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +       \
-            ((CDF_INIT_TOP - 12) >> 1)) /                         \
-               ((CDF_INIT_TOP - 12)) +                            \
-           1)                                                     \
-  ,                                                               \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               2),                                                \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               3),                                                \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               4),                                                \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               5),                                                \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               6),                                                \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               7),                                                \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               8),                                                \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +   \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               9),                                                \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) +  \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               10),                                               \
-      AOM_ICDF((((a10)-11) * ((CDF_INIT_TOP >> CDF_SHIFT) - 12) + \
-                ((CDF_INIT_TOP - 12) >> 1)) /                     \
-                   ((CDF_INIT_TOP - 12)) +                        \
-               11),                                               \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF13(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +         \
-            ((CDF_INIT_TOP - 13) >> 1)) /                           \
-               ((CDF_INIT_TOP - 13)) +                              \
-           1)                                                       \
-  ,                                                                 \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               2),                                                  \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               3),                                                  \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               4),                                                  \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               5),                                                  \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               6),                                                  \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               7),                                                  \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               8),                                                  \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +     \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               9),                                                  \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +    \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               10),                                                 \
-      AOM_ICDF((((a10)-11) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +   \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               11),                                                 \
-      AOM_ICDF((((a11)-12) * ((CDF_INIT_TOP >> CDF_SHIFT) - 13) +   \
-                ((CDF_INIT_TOP - 13) >> 1)) /                       \
-                   ((CDF_INIT_TOP - 13)) +                          \
-               12),                                                 \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF14(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +              \
-            ((CDF_INIT_TOP - 14) >> 1)) /                                \
-               ((CDF_INIT_TOP - 14)) +                                   \
-           1)                                                            \
-  ,                                                                      \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               2),                                                       \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               3),                                                       \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               4),                                                       \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               5),                                                       \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               6),                                                       \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               7),                                                       \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               8),                                                       \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +          \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               9),                                                       \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +         \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               10),                                                      \
-      AOM_ICDF((((a10)-11) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +        \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               11),                                                      \
-      AOM_ICDF((((a11)-12) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +        \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               12),                                                      \
-      AOM_ICDF((((a12)-13) * ((CDF_INIT_TOP >> CDF_SHIFT) - 14) +        \
-                ((CDF_INIT_TOP - 14) >> 1)) /                            \
-                   ((CDF_INIT_TOP - 14)) +                               \
-               13),                                                      \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF15(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +                   \
-            ((CDF_INIT_TOP - 15) >> 1)) /                                     \
-               ((CDF_INIT_TOP - 15)) +                                        \
-           1)                                                                 \
-  ,                                                                           \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               2),                                                            \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               3),                                                            \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               4),                                                            \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               5),                                                            \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               6),                                                            \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               7),                                                            \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               8),                                                            \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +               \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               9),                                                            \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +              \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               10),                                                           \
-      AOM_ICDF((((a10)-11) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +             \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               11),                                                           \
-      AOM_ICDF((((a11)-12) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +             \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               12),                                                           \
-      AOM_ICDF((((a12)-13) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +             \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               13),                                                           \
-      AOM_ICDF((((a13)-14) * ((CDF_INIT_TOP >> CDF_SHIFT) - 15) +             \
-                ((CDF_INIT_TOP - 15) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 15)) +                                    \
-               14),                                                           \
-      AOM_ICDF(CDF_PROB_TOP), 0
-#define AOM_CDF16(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, \
-                  a14)                                                        \
-  AOM_ICDF((((a0)-1) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +                   \
-            ((CDF_INIT_TOP - 16) >> 1)) /                                     \
-               ((CDF_INIT_TOP - 16)) +                                        \
-           1)                                                                 \
-  ,                                                                           \
-      AOM_ICDF((((a1)-2) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               2),                                                            \
-      AOM_ICDF((((a2)-3) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               3),                                                            \
-      AOM_ICDF((((a3)-4) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               4),                                                            \
-      AOM_ICDF((((a4)-5) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               5),                                                            \
-      AOM_ICDF((((a5)-6) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               6),                                                            \
-      AOM_ICDF((((a6)-7) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               7),                                                            \
-      AOM_ICDF((((a7)-8) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               8),                                                            \
-      AOM_ICDF((((a8)-9) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +               \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               9),                                                            \
-      AOM_ICDF((((a9)-10) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +              \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               10),                                                           \
-      AOM_ICDF((((a10)-11) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +             \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               11),                                                           \
-      AOM_ICDF((((a11)-12) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +             \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               12),                                                           \
-      AOM_ICDF((((a12)-13) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +             \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               13),                                                           \
-      AOM_ICDF((((a13)-14) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +             \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               14),                                                           \
-      AOM_ICDF((((a14)-15) * ((CDF_INIT_TOP >> CDF_SHIFT) - 16) +             \
-                ((CDF_INIT_TOP - 16) >> 1)) /                                 \
-                   ((CDF_INIT_TOP - 16)) +                                    \
-               15),                                                           \
-      AOM_ICDF(CDF_PROB_TOP), 0
-
-#endif
-
 static INLINE uint8_t get_prob(unsigned int num, unsigned int den) {
   assert(den != 0);
   {
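The AOM_CDFn macros deleted above all repeat one piece of arithmetic per table entry. As a hedged reading of that expression (the helper below is illustrative and not part of the tree), entry i (0-based) of an n-symbol CDF rescales the initializer a_i from the CDF_INIT_TOP range into the CDF_SHIFT-reduced range with rounding, then adds back i + 1 so the table stays strictly increasing:

// Illustrative only; mirrors the expression inside the removed AOM_CDFn
// macros for entry i of an n-symbol CDF.
static aom_cdf_prob icdf_entry_sketch(int ai, int i, int n) {
  const int reduced_top = (CDF_INIT_TOP >> CDF_SHIFT) - n;  // rescaled top
  const int range = CDF_INIT_TOP - n;                       // input range
  // Round-to-nearest rescale of (ai - (i + 1)), then restore the offset.
  return AOM_ICDF(((ai - (i + 1)) * reduced_top + (range >> 1)) / range +
                  (i + 1));
}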
diff --git a/aom_dsp/psnr.c b/aom_dsp/psnr.c
index 08fb69c..f71590c 100644
--- a/aom_dsp/psnr.c
+++ b/aom_dsp/psnr.c
@@ -44,9 +44,9 @@
 }
 
 #if CONFIG_AV1_HIGHBITDEPTH
-static int64_t encoder_highbd_8_sse(const uint8_t *a8, int a_stride,
-                                    const uint8_t *b8, int b_stride, int w,
-                                    int h) {
+static int64_t encoder_highbd_sse(const uint8_t *a8, int a_stride,
+                                  const uint8_t *b8, int b_stride, int w,
+                                  int h) {
   const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
   const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
   int64_t sse = 0;
@@ -84,10 +84,8 @@
   for (y = 0; y < height / 16; ++y) {
     const uint8_t *pa = a;
     const uint8_t *pb = b;
-    unsigned int sse;
     for (x = 0; x < width / 16; ++x) {
-      aom_mse16x16(pa, a_stride, pb, b_stride, &sse);
-      total_sse += sse;
+      total_sse += aom_sse(pa, a_stride, pb, b_stride, 16, 16);
 
       pa += 16;
       pb += 16;
@@ -128,22 +126,20 @@
   const int dh = height % 16;
 
   if (dw > 0) {
-    total_sse += encoder_highbd_8_sse(&a[width - dw], a_stride, &b[width - dw],
-                                      b_stride, dw, height);
+    total_sse += encoder_highbd_sse(&a[width - dw], a_stride, &b[width - dw],
+                                    b_stride, dw, height);
   }
   if (dh > 0) {
-    total_sse += encoder_highbd_8_sse(&a[(height - dh) * a_stride], a_stride,
-                                      &b[(height - dh) * b_stride], b_stride,
-                                      width - dw, dh);
+    total_sse += encoder_highbd_sse(&a[(height - dh) * a_stride], a_stride,
+                                    &b[(height - dh) * b_stride], b_stride,
+                                    width - dw, dh);
   }
 
   for (y = 0; y < height / 16; ++y) {
     const uint8_t *pa = a;
     const uint8_t *pb = b;
-    unsigned int sse;
     for (x = 0; x < width / 16; ++x) {
-      aom_highbd_8_mse16x16(pa, a_stride, pb, b_stride, &sse);
-      total_sse += sse;
+      total_sse += aom_highbd_sse(pa, a_stride, pb, b_stride, 16, 16);
       pa += 16;
       pb += 16;
     }
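The loops above accumulate a raw sum of squared errors into total_sse, now block by block through aom_sse()/aom_highbd_sse() plus the leftover edge strips. That total is what the usual SSE-to-PSNR conversion consumes; a minimal sketch of that conversion, under an illustrative name rather than the library's own helper:

#include <math.h>

// PSNR in dB from a total sum of squared errors over `samples` pixels with
// peak sample value `peak` (255 for 8-bit data). The cap for sse == 0 is an
// arbitrary choice for identical images.
static double sse_to_psnr_sketch(double samples, double peak, double sse) {
  if (sse <= 0.0) return 100.0;
  return 10.0 * log10(samples * peak * peak / sse);
}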
diff --git a/aom_dsp/pyramid.c b/aom_dsp/pyramid.c
new file mode 100644
index 0000000..a26d302
--- /dev/null
+++ b/aom_dsp/pyramid.c
@@ -0,0 +1,411 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "aom_dsp/pyramid.h"
+#include "aom_mem/aom_mem.h"
+#include "aom_ports/bitops.h"
+#include "aom_util/aom_thread.h"
+
+// TODO(rachelbarker): Move needed code from av1/ to aom_dsp/
+#include "av1/common/resize.h"
+
+#include <assert.h>
+#include <string.h>
+
+// Lifecycle:
+// * Frame buffer alloc code calls aom_get_pyramid_alloc_size()
+//   to work out how much space is needed for a given number of pyramid
+//   levels. This size is included in the total that is checked against the
+//   max allocation limit
+// * Then calls aom_alloc_pyramid() to actually create the pyramid
+// * Pyramid is initially marked as invalid (no data)
+// * Whenever pyramid is needed, we check the valid flag. If set, use existing
+//   data. If not set, compute full pyramid
+// * Whenever frame buffer is reused, clear the valid flag
+// * Whenever frame buffer is resized, reallocate pyramid
+
+size_t aom_get_pyramid_alloc_size(int width, int height, int n_levels,
+                                  bool image_is_16bit) {
+  // Limit number of levels on small frames
+  const int msb = get_msb(AOMMIN(width, height));
+  const int max_levels = AOMMAX(msb - MIN_PYRAMID_SIZE_LOG2, 1);
+  n_levels = AOMMIN(n_levels, max_levels);
+
+  size_t alloc_size = 0;
+  alloc_size += sizeof(ImagePyramid);
+  alloc_size += n_levels * sizeof(PyramidLayer);
+
+  // Calculate how much memory is needed for downscaled frame buffers
+  size_t buffer_size = 0;
+
+  // Work out if we need to allocate a few extra bytes for alignment.
+  // aom_memalign() will ensure that the start of the allocation is aligned
+  // to a multiple of PYRAMID_ALIGNMENT. But we want the first image pixel
+  // to be aligned, not the first byte of the allocation.
+  //
+  // In the loop below, we ensure that the stride of every image is a multiple
+  // of PYRAMID_ALIGNMENT. Thus the allocated size of each pyramid level will
+  // also be a multiple of PYRAMID_ALIGNMENT. Thus, as long as we can get the
+  // first pixel in the first pyramid layer aligned properly, that will
+  // automatically mean that the first pixel of every row of every layer is
+  // properly aligned too.
+  //
+  // Thus all we need to consider is the first pixel in the first layer.
+  // This is located at offset
+  //   extra_bytes + level_stride * PYRAMID_PADDING + PYRAMID_PADDING
+  // bytes into the buffer. Since level_stride is a multiple of
+  // PYRAMID_ALIGNMENT, we can ignore that. So we need
+  //   extra_bytes + PYRAMID_PADDING = multiple of PYRAMID_ALIGNMENT
+  //
+  // To solve this, we can round PYRAMID_PADDING up to the next multiple
+  // of PYRAMID_ALIGNMENT, then subtract the original value to calculate
+  // how many extra bytes are needed.
+  size_t first_px_offset =
+      (PYRAMID_PADDING + PYRAMID_ALIGNMENT - 1) & ~(PYRAMID_ALIGNMENT - 1);
+  size_t extra_bytes = first_px_offset - PYRAMID_PADDING;
+  buffer_size += extra_bytes;
+
+  // If the original image is stored in an 8-bit buffer, then we can point the
+  // lowest pyramid level at that buffer rather than allocating a new one.
+  int first_allocated_level = image_is_16bit ? 0 : 1;
+
+  for (int level = first_allocated_level; level < n_levels; level++) {
+    int level_width = width >> level;
+    int level_height = height >> level;
+
+    // Allocate padding for each layer
+    int padded_width = level_width + 2 * PYRAMID_PADDING;
+    int padded_height = level_height + 2 * PYRAMID_PADDING;
+
+    // Align the layer stride to be a multiple of PYRAMID_ALIGNMENT
+    // This ensures that, as long as the top-left pixel in this pyramid level
+    // is properly aligned, so is the leftmost pixel in every row of the
+    // pyramid level.
+    int level_stride =
+        (padded_width + PYRAMID_ALIGNMENT - 1) & ~(PYRAMID_ALIGNMENT - 1);
+
+    buffer_size += level_stride * padded_height;
+  }
+
+  alloc_size += buffer_size;
+
+  return alloc_size;
+}
+
+ImagePyramid *aom_alloc_pyramid(int width, int height, int n_levels,
+                                bool image_is_16bit) {
+  // Limit number of levels on small frames
+  const int msb = get_msb(AOMMIN(width, height));
+  const int max_levels = AOMMAX(msb - MIN_PYRAMID_SIZE_LOG2, 1);
+  n_levels = AOMMIN(n_levels, max_levels);
+
+  ImagePyramid *pyr = aom_calloc(1, sizeof(*pyr));
+  if (!pyr) {
+    return NULL;
+  }
+
+  pyr->layers = aom_calloc(n_levels, sizeof(PyramidLayer));
+  if (!pyr->layers) {
+    aom_free(pyr);
+    return NULL;
+  }
+
+  pyr->valid = false;
+  pyr->n_levels = n_levels;
+
+  // Compute sizes and offsets for each pyramid level
+  // These are gathered up first, so that we can allocate all pyramid levels
+  // in a single buffer
+  size_t buffer_size = 0;
+  size_t *layer_offsets = aom_calloc(n_levels, sizeof(size_t));
+  if (!layer_offsets) {
+    aom_free(pyr->layers);
+    aom_free(pyr);
+    return NULL;
+  }
+
+  // Work out if we need to allocate a few extra bytes for alignment.
+  // aom_memalign() will ensure that the start of the allocation is aligned
+  // to a multiple of PYRAMID_ALIGNMENT. But we want the first image pixel
+  // to be aligned, not the first byte of the allocation.
+  //
+  // In the loop below, we ensure that the stride of every image is a multiple
+  // of PYRAMID_ALIGNMENT. Thus the allocated size of each pyramid level will
+  // also be a multiple of PYRAMID_ALIGNMENT. Thus, as long as we can get the
+  // first pixel in the first pyramid layer aligned properly, that will
+  // automatically mean that the first pixel of every row of every layer is
+  // properly aligned too.
+  //
+  // Thus all we need to consider is the first pixel in the first layer.
+  // This is located at offset
+  //   extra_bytes + level_stride * PYRAMID_PADDING + PYRAMID_PADDING
+  // bytes into the buffer. Since level_stride is a multiple of
+  // PYRAMID_ALIGNMENT, we can ignore that. So we need
+  //   extra_bytes + PYRAMID_PADDING = multiple of PYRAMID_ALIGNMENT
+  //
+  // To solve this, we can round PYRAMID_PADDING up to the next multiple
+  // of PYRAMID_ALIGNMENT, then subtract the original value to calculate
+  // how many extra bytes are needed.
+  size_t first_px_offset =
+      (PYRAMID_PADDING + PYRAMID_ALIGNMENT - 1) & ~(PYRAMID_ALIGNMENT - 1);
+  size_t extra_bytes = first_px_offset - PYRAMID_PADDING;
+  buffer_size += extra_bytes;
+
+  // If the original image is stored in an 8-bit buffer, then we can point the
+  // lowest pyramid level at that buffer rather than allocating a new one.
+  int first_allocated_level = image_is_16bit ? 0 : 1;
+
+  for (int level = first_allocated_level; level < n_levels; level++) {
+    PyramidLayer *layer = &pyr->layers[level];
+
+    int level_width = width >> level;
+    int level_height = height >> level;
+
+    // Allocate padding for each layer
+    int padded_width = level_width + 2 * PYRAMID_PADDING;
+    int padded_height = level_height + 2 * PYRAMID_PADDING;
+
+    // Align the layer stride to be a multiple of PYRAMID_ALIGNMENT
+    // This ensures that, as long as the top-left pixel in this pyramid level
+    // is properly aligned, so is the leftmost pixel in every row of the
+    // pyramid level.
+    int level_stride =
+        (padded_width + PYRAMID_ALIGNMENT - 1) & ~(PYRAMID_ALIGNMENT - 1);
+
+    size_t level_alloc_start = buffer_size;
+    size_t level_start =
+        level_alloc_start + PYRAMID_PADDING * level_stride + PYRAMID_PADDING;
+
+    buffer_size += level_stride * padded_height;
+
+    layer_offsets[level] = level_start;
+    layer->width = level_width;
+    layer->height = level_height;
+    layer->stride = level_stride;
+  }
+
+  pyr->buffer_alloc =
+      aom_memalign(PYRAMID_ALIGNMENT, buffer_size * sizeof(*pyr->buffer_alloc));
+  if (!pyr->buffer_alloc) {
+    aom_free(pyr->layers);
+    aom_free(pyr);
+    aom_free(layer_offsets);
+    return NULL;
+  }
+
+  // Fill in pointers for each level
+  // If image is 8-bit, then the lowest level is left unconfigured for now,
+  // and will be set up properly when the pyramid is filled in
+  for (int level = first_allocated_level; level < n_levels; level++) {
+    PyramidLayer *layer = &pyr->layers[level];
+    layer->buffer = pyr->buffer_alloc + layer_offsets[level];
+  }
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_init(&pyr->mutex, NULL);
+#endif  // CONFIG_MULTITHREAD
+
+  aom_free(layer_offsets);
+  return pyr;
+}
+
+// Fill the border region of a pyramid frame.
+// This must be called after the main image area is filled out.
+// `img_buf` should point to the first pixel in the image area,
+// i.e. it should be pyr->layers[level].buffer.
+static INLINE void fill_border(uint8_t *img_buf, const int width,
+                               const int height, const int stride) {
+  // Fill left and right areas
+  for (int row = 0; row < height; row++) {
+    uint8_t *row_start = &img_buf[row * stride];
+    uint8_t left_pixel = row_start[0];
+    memset(row_start - PYRAMID_PADDING, left_pixel, PYRAMID_PADDING);
+    uint8_t right_pixel = row_start[width - 1];
+    memset(row_start + width, right_pixel, PYRAMID_PADDING);
+  }
+
+  // Fill top area
+  for (int row = -PYRAMID_PADDING; row < 0; row++) {
+    uint8_t *row_start = &img_buf[row * stride];
+    memcpy(row_start - PYRAMID_PADDING, img_buf - PYRAMID_PADDING,
+           width + 2 * PYRAMID_PADDING);
+  }
+
+  // Fill bottom area
+  uint8_t *last_row_start = &img_buf[(height - 1) * stride];
+  for (int row = height; row < height + PYRAMID_PADDING; row++) {
+    uint8_t *row_start = &img_buf[row * stride];
+    memcpy(row_start - PYRAMID_PADDING, last_row_start - PYRAMID_PADDING,
+           width + 2 * PYRAMID_PADDING);
+  }
+}
+
+// Compute coarse to fine pyramids for a frame
+// This must only be called while holding frame_pyr->mutex
+static INLINE void fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
+                                ImagePyramid *frame_pyr) {
+  int n_levels = frame_pyr->n_levels;
+  const int frame_width = frame->y_crop_width;
+  const int frame_height = frame->y_crop_height;
+  const int frame_stride = frame->y_stride;
+  assert((frame_width >> n_levels) >= 0);
+  assert((frame_height >> n_levels) >= 0);
+
+  PyramidLayer *first_layer = &frame_pyr->layers[0];
+  if (frame->flags & YV12_FLAG_HIGHBITDEPTH) {
+    // For frames stored in a 16-bit buffer, we need to downconvert to 8 bits
+    assert(first_layer->width == frame_width);
+    assert(first_layer->height == frame_height);
+
+    uint16_t *frame_buffer = CONVERT_TO_SHORTPTR(frame->y_buffer);
+    uint8_t *pyr_buffer = first_layer->buffer;
+    int pyr_stride = first_layer->stride;
+    for (int y = 0; y < frame_height; y++) {
+      uint16_t *frame_row = frame_buffer + y * frame_stride;
+      uint8_t *pyr_row = pyr_buffer + y * pyr_stride;
+      for (int x = 0; x < frame_width; x++) {
+        pyr_row[x] = frame_row[x] >> (bit_depth - 8);
+      }
+    }
+
+    fill_border(pyr_buffer, frame_width, frame_height, pyr_stride);
+  } else {
+    // For frames stored in an 8-bit buffer, we need to configure the first
+    // pyramid layer to point at the original image buffer
+    first_layer->buffer = frame->y_buffer;
+    first_layer->width = frame_width;
+    first_layer->height = frame_height;
+    first_layer->stride = frame_stride;
+  }
+
+  // Fill in the remaining levels through progressive downsampling
+  for (int level = 1; level < n_levels; ++level) {
+    PyramidLayer *prev_layer = &frame_pyr->layers[level - 1];
+    uint8_t *prev_buffer = prev_layer->buffer;
+    int prev_stride = prev_layer->stride;
+
+    PyramidLayer *this_layer = &frame_pyr->layers[level];
+    uint8_t *this_buffer = this_layer->buffer;
+    int this_width = this_layer->width;
+    int this_height = this_layer->height;
+    int this_stride = this_layer->stride;
+
+    // Compute this pyramid level by downsampling the previous (finer) level.
+    //
+    // We downsample by a factor of exactly 2, clipping the rightmost and
+    // bottommost pixel off of the current level if needed. We do this for
+    // two main reasons:
+    //
+    // 1) In the disflow code, when stepping from a higher pyramid level to a
+    //    lower pyramid level, we need to not just interpolate the flow field
+    //    but also to scale each flow vector by the upsampling ratio.
+    //    So it is much more convenient if this ratio is simply 2.
+    //
+    // 2) Up/downsampling by a factor of 2 can be implemented much more
+    //    efficiently than up/downsampling by a generic ratio.
+    //    TODO(rachelbarker): Use optimized downsample-by-2 function
+    av1_resize_plane(prev_buffer, this_height << 1, this_width << 1,
+                     prev_stride, this_buffer, this_height, this_width,
+                     this_stride);
+    fill_border(this_buffer, this_width, this_height, this_stride);
+  }
+}
+
+// Fill out a downsampling pyramid for a given frame.
+//
+// The top level (index 0) will always be an 8-bit copy of the input frame,
+// regardless of the input bit depth. Additional levels are then downscaled
+// by powers of 2.
+//
+// For small input frames, the number of levels actually constructed
+// will be limited so that the smallest image is at least MIN_PYRAMID_SIZE
+// pixels along each side.
+//
+// However, if the input frame has a side of length < MIN_PYRAMID_SIZE,
+// we will still construct the top level.
+void aom_compute_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
+                         ImagePyramid *pyr) {
+  assert(pyr);
+
+  // Per the comments in the ImagePyramid struct, we must take this mutex
+  // before reading or writing the "valid" flag, and hold it while computing
+  // the pyramid, to ensure proper behaviour if multiple threads call this
+  // function simultaneously
+#if CONFIG_MULTITHREAD
+  pthread_mutex_lock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  if (!pyr->valid) {
+    fill_pyramid(frame, bit_depth, pyr);
+    pyr->valid = true;
+  }
+
+  // At this point, the pyramid is guaranteed to be valid, and can be safely
+  // read from without holding the mutex any more
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_unlock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+}
+
+#ifndef NDEBUG
+// Check if a pyramid has already been computed.
+// This is mostly a debug helper - as it is necessary to hold pyr->mutex
+// while reading the valid flag, we cannot just write:
+//   assert(pyr->valid);
+// This function allows the check to be correctly written as:
+//   assert(aom_is_pyramid_valid(pyr));
+bool aom_is_pyramid_valid(ImagePyramid *pyr) {
+  assert(pyr);
+
+  // Per the comments in the ImagePyramid struct, we must take this mutex
+  // before reading or writing the "valid" flag, and hold it while computing
+  // the pyramid, to ensure proper behaviour if multiple threads call this
+  // function simultaneously
+#if CONFIG_MULTITHREAD
+  pthread_mutex_lock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  bool valid = pyr->valid;
+
+#if CONFIG_MULTITHREAD
+  pthread_mutex_unlock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+
+  return valid;
+}
+#endif
+
+// Mark a pyramid as no longer containing valid data.
+// This must be done whenever the corresponding frame buffer is reused
+void aom_invalidate_pyramid(ImagePyramid *pyr) {
+  if (pyr) {
+#if CONFIG_MULTITHREAD
+    pthread_mutex_lock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+    pyr->valid = false;
+#if CONFIG_MULTITHREAD
+    pthread_mutex_unlock(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+  }
+}
+
+// Release the memory associated with a pyramid
+void aom_free_pyramid(ImagePyramid *pyr) {
+  if (pyr) {
+#if CONFIG_MULTITHREAD
+    pthread_mutex_destroy(&pyr->mutex);
+#endif  // CONFIG_MULTITHREAD
+    aom_free(pyr->buffer_alloc);
+    aom_free(pyr->layers);
+    aom_free(pyr);
+  }
+}
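Both aom_get_pyramid_alloc_size() and aom_alloc_pyramid() above lean on the same round-up-to-alignment idiom for extra_bytes and level_stride. A small worked sketch using the constants from pyramid.h (the ALIGN_UP name is illustrative):

// Round x up to the next multiple of a power-of-two alignment a.
#define ALIGN_UP(x, a) (((x) + (a) - 1) & ~((a) - 1))

// With PYRAMID_PADDING = 16 and PYRAMID_ALIGNMENT = 32:
//   first_px_offset = ALIGN_UP(16, 32) = 32, so extra_bytes = 32 - 16 = 16.
// A level that is 100 pixels wide is padded to 100 + 2 * 16 = 132 bytes per
// row, and its stride is rounded up to ALIGN_UP(132, 32) = 160 bytes.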
diff --git a/aom_dsp/pyramid.h b/aom_dsp/pyramid.h
new file mode 100644
index 0000000..812aae1
--- /dev/null
+++ b/aom_dsp/pyramid.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AOM_DSP_PYRAMID_H_
+#define AOM_AOM_DSP_PYRAMID_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#include "config/aom_config.h"
+
+#include "aom_scale/yv12config.h"
+#include "aom_util/aom_thread.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Minimum dimensions of a downsampled image
+#define MIN_PYRAMID_SIZE_LOG2 3
+#define MIN_PYRAMID_SIZE (1 << MIN_PYRAMID_SIZE_LOG2)
+
+// Size of border around each pyramid image, in pixels
+// Similarly to the border around regular image buffers, this border is filled
+// with copies of the outermost pixels of the frame, to allow for more efficient
+// convolution code
+// TODO(rachelbarker): How many pixels do we actually need here?
+// I think we only need 9 for disflow, but how many for corner matching?
+#define PYRAMID_PADDING 16
+
+// Byte alignment of each line within the image pyramids.
+// That is, the first pixel inside the image (ie, not in the border region),
+// on each row of each pyramid level, is aligned to this byte alignment.
+// This value must be a power of 2.
+#define PYRAMID_ALIGNMENT 32
+
+typedef struct {
+  uint8_t *buffer;
+  int width;
+  int height;
+  int stride;
+} PyramidLayer;
+
+// Struct for an image pyramid
+typedef struct image_pyramid {
+#if CONFIG_MULTITHREAD
+  // Mutex which is used to prevent the pyramid being computed twice at the
+  // same time
+  //
+  // Semantics:
+  // * This mutex must be held whenever reading or writing the `valid` flag
+  //
+  // * This mutex must also be held while computing the image pyramid,
+  //   to ensure that only one thread may do so at a time.
+  //
+  // * However, once you have read the valid flag and seen a true value,
+  //   it is safe to drop the mutex and read from the remaining fields.
+  //   This is because, once the image pyramid is computed, its contents
+  //   will not be changed until the parent frame buffer is recycled,
+  //   which will not happen until there are no more outstanding references
+  //   to the frame buffer.
+  pthread_mutex_t mutex;
+#endif
+  // Flag indicating whether the pyramid contains valid data
+  bool valid;
+  // Number of allocated/filled levels in this pyramid
+  int n_levels;
+  // Pointer to allocated buffer
+  uint8_t *buffer_alloc;
+  // Data for each level
+  // The `buffer` pointers inside this array point into the region which
+  // is stored in the `buffer_alloc` field here
+  PyramidLayer *layers;
+} ImagePyramid;
+
+size_t aom_get_pyramid_alloc_size(int width, int height, int n_levels,
+                                  bool image_is_16bit);
+
+ImagePyramid *aom_alloc_pyramid(int width, int height, int n_levels,
+                                bool image_is_16bit);
+
+// Fill out a downsampling pyramid for a given frame.
+//
+// The top level (index 0) will always be an 8-bit copy of the input frame,
+// regardless of the input bit depth. Additional levels are then downscaled
+// by powers of 2.
+//
+// For small input frames, the number of levels actually constructed
+// will be limited so that the smallest image is at least MIN_PYRAMID_SIZE
+// pixels along each side.
+//
+// However, if the input frame has a side of length < MIN_PYRAMID_SIZE,
+// we will still construct the top level.
+void aom_compute_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
+                         ImagePyramid *pyr);
+
+#ifndef NDEBUG
+// Check if a pyramid has already been computed.
+// This is mostly a debug helper - as it is necessary to hold pyr->mutex
+// while reading the valid flag, we cannot just write:
+//   assert(pyr->valid);
+// This function allows the check to be correctly written as:
+//   assert(aom_is_pyramid_valid(pyr));
+bool aom_is_pyramid_valid(ImagePyramid *pyr);
+#endif
+
+// Mark a pyramid as no longer containing valid data.
+// This must be done whenever the corresponding frame buffer is reused
+void aom_invalidate_pyramid(ImagePyramid *pyr);
+
+// Release the memory associated with a pyramid
+void aom_free_pyramid(ImagePyramid *pyr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // AOM_AOM_DSP_PYRAMID_H_
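Taken together with the lifecycle comment at the top of pyramid.c, the header above implies a usage pattern along these lines. A hedged sketch only: the caller name is hypothetical, `frame` is assumed to be an already-filled YV12_BUFFER_CONFIG, and error handling is trimmed.

#include <assert.h>

#include "aom_dsp/pyramid.h"

static void pyramid_usage_sketch(const YV12_BUFFER_CONFIG *frame,
                                 int bit_depth) {
  const bool is_16bit = (frame->flags & YV12_FLAG_HIGHBITDEPTH) != 0;
  ImagePyramid *pyr = aom_alloc_pyramid(
      frame->y_crop_width, frame->y_crop_height, /*n_levels=*/4, is_16bit);
  if (!pyr) return;

  // Lazily fills the pyramid; a no-op if it is already marked valid.
  aom_compute_pyramid(frame, bit_depth, pyr);
  assert(aom_is_pyramid_valid(pyr));

  // Each level is an 8-bit image with its own width, height and stride.
  const PyramidLayer *coarsest = &pyr->layers[pyr->n_levels - 1];
  (void)coarsest;

  // Clear the cached data whenever the frame buffer is reused, and free the
  // pyramid when the buffer itself goes away.
  aom_invalidate_pyramid(pyr);
  aom_free_pyramid(pyr);
}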
diff --git a/aom_dsp/sad.c b/aom_dsp/sad.c
index 5b7b0e4..341a5ff 100644
--- a/aom_dsp/sad.c
+++ b/aom_dsp/sad.c
@@ -35,13 +35,6 @@
   return sad;
 }
 
-#define SAD_MXH(m)                                                         \
-  unsigned int aom_sad##m##xh_c(const uint8_t *a, int a_stride,            \
-                                const uint8_t *b, int b_stride, int width, \
-                                int height) {                              \
-    return sad(a, a_stride, b, b_stride, width, height);                   \
-  }
-
 #define SADMXN(m, n)                                                          \
   unsigned int aom_sad##m##x##n##_c(const uint8_t *src, int src_stride,       \
                                     const uint8_t *ref, int ref_stride) {     \
@@ -68,7 +61,6 @@
     return 2 * sad(src, 2 * src_stride, ref, 2 * ref_stride, (m), (n / 2));   \
   }
 
-#if CONFIG_REALTIME_ONLY
 // Calculate sad against 4 reference locations and store each in sad_array
 #define SAD_MXNX4D(m, n)                                                      \
   void aom_sad##m##x##n##x4d_c(const uint8_t *src, int src_stride,            \
@@ -89,37 +81,6 @@
                              2 * ref_stride, (m), (n / 2));                   \
     }                                                                         \
   }
-#else  // !CONFIG_REALTIME_ONLY
-// Calculate sad against 4 reference locations and store each in sad_array
-#define SAD_MXNX4D(m, n)                                                      \
-  void aom_sad##m##x##n##x4d_c(const uint8_t *src, int src_stride,            \
-                               const uint8_t *const ref_array[4],             \
-                               int ref_stride, uint32_t sad_array[4]) {       \
-    int i;                                                                    \
-    for (i = 0; i < 4; ++i) {                                                 \
-      sad_array[i] =                                                          \
-          aom_sad##m##x##n##_c(src, src_stride, ref_array[i], ref_stride);    \
-    }                                                                         \
-  }                                                                           \
-  void aom_sad##m##x##n##x4d_avg_c(                                           \
-      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
-      int ref_stride, const uint8_t *second_pred, uint32_t sad_array[4]) {    \
-    int i;                                                                    \
-    for (i = 0; i < 4; ++i) {                                                 \
-      sad_array[i] = aom_sad##m##x##n##_avg_c(src, src_stride, ref_array[i],  \
-                                              ref_stride, second_pred);       \
-    }                                                                         \
-  }                                                                           \
-  void aom_sad_skip_##m##x##n##x4d_c(const uint8_t *src, int src_stride,      \
-                                     const uint8_t *const ref_array[4],       \
-                                     int ref_stride, uint32_t sad_array[4]) { \
-    int i;                                                                    \
-    for (i = 0; i < 4; ++i) {                                                 \
-      sad_array[i] = 2 * sad(src, 2 * src_stride, ref_array[i],               \
-                             2 * ref_stride, (m), (n / 2));                   \
-    }                                                                         \
-  }
-#endif  // CONFIG_REALTIME_ONLY
 // Call SIMD version of aom_sad_mxnx4d if the 3d version is unavailable.
 #define SAD_MXNX3D(m, n)                                                      \
   void aom_sad##m##x##n##x3d_c(const uint8_t *src, int src_stride,            \
@@ -208,13 +169,7 @@
 SAD_MXNX4D(4, 4)
 SAD_MXNX3D(4, 4)
 
-SAD_MXH(128)
-SAD_MXH(64)
-SAD_MXH(32)
-SAD_MXH(16)
-SAD_MXH(8)
-SAD_MXH(4)
-
+#if !CONFIG_REALTIME_ONLY
 SADMXN(4, 16)
 SAD_MXNX4D(4, 16)
 SADMXN(16, 4)
@@ -227,7 +182,6 @@
 SAD_MXNX4D(16, 64)
 SADMXN(64, 16)
 SAD_MXNX4D(64, 16)
-#if !CONFIG_REALTIME_ONLY
 SAD_MXNX3D(4, 16)
 SAD_MXNX3D(16, 4)
 SAD_MXNX3D(8, 32)
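The x4d and skip macros kept above all reduce to the plain sad() helper at the top of sad.c plus a skip-rows trick: sample every other row at twice the stride and double the result. A reference sketch of both, assuming sad() is the usual sum of absolute differences (helper names are illustrative):

#include <stdint.h>
#include <stdlib.h>

// Reference sum of absolute differences over a width x height block.
static unsigned int sad_sketch(const uint8_t *a, int a_stride,
                               const uint8_t *b, int b_stride, int width,
                               int height) {
  unsigned int acc = 0;
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) acc += abs(a[x] - b[x]);
    a += a_stride;
    b += b_stride;
  }
  return acc;
}

// The aom_sad_skip_* variants estimate a full-block SAD from half the rows:
// visit every other row (stride doubled) and double the partial sum.
static unsigned int sad_skip_sketch(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride, int width,
                                    int height) {
  return 2 * sad_sketch(a, 2 * a_stride, b, 2 * b_stride, width, height / 2);
}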
diff --git a/aom_dsp/simd/v128_intrinsics_arm.h b/aom_dsp/simd/v128_intrinsics_arm.h
index 2d497f4..fb89d60 100644
--- a/aom_dsp/simd/v128_intrinsics_arm.h
+++ b/aom_dsp/simd/v128_intrinsics_arm.h
@@ -29,7 +29,7 @@
 SIMD_INLINE v128 v128_from_v64(v64 a, v64 b) { return vcombine_s64(b, a); }
 
 SIMD_INLINE v128 v128_from_64(uint64_t a, uint64_t b) {
-  return vcombine_s64((int64x1_t)b, (int64x1_t)a);
+  return vcombine_s64(vcreate_s64(b), vcreate_s64(a));
 }
 
 SIMD_INLINE v128 v128_from_32(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
@@ -101,7 +101,7 @@
   return vaddlvq_s16(t1) + vaddlvq_s16(t2);
 #else
   int64x2_t t = vpaddlq_s32(vaddq_s32(vpaddlq_s16(t1), vpaddlq_s16(t2)));
-  return (int64_t)vget_high_s64(t) + (int64_t)vget_low_s64(t);
+  return vget_lane_s64(vadd_s64(vget_high_s64(t), vget_low_s64(t)), 0);
 #endif
 }
 
@@ -113,7 +113,7 @@
 SIMD_INLINE int64_t v128_dotp_s32(v128 a, v128 b) {
   int64x2_t t = vpaddlq_s32(
       vmulq_s32(vreinterpretq_s32_s64(a), vreinterpretq_s32_s64(b)));
-  return (int64_t)vget_high_s64(t) + (int64_t)vget_low_s64(t);
+  return vget_lane_s64(vadd_s64(vget_high_s64(t), vget_low_s64(t)), 0);
 }
 
 SIMD_INLINE uint64_t v128_hadd_u8(v128 x) {
@@ -159,7 +159,8 @@
   return vaddlvq_u16(s.hi) + vaddlvq_u16(s.lo);
 #else
   uint64x2_t t = vpaddlq_u32(vpaddlq_u16(vaddq_u16(s.hi, s.lo)));
-  return (uint32_t)(uint64_t)(vget_high_u64(t) + vget_low_u64(t));
+  return (uint32_t)vget_lane_u64(vadd_u64(vget_high_u64(t), vget_low_u64(t)),
+                                 0);
 #endif
 }
 
@@ -377,8 +378,8 @@
   uint64x2_t m = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(
       vandq_u8(vreinterpretq_u8_s64(a),
                vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201ULL))))));
-  return v64_low_u32(
-      v64_ziplo_8(v128_high_v64((v128)m), v128_low_v64((v128)m)));
+  int64x2_t s = vreinterpretq_s64_u64(m);
+  return v64_low_u32(v64_ziplo_8(vget_high_s64(s), vget_low_s64(s)));
 #endif
 }
 
@@ -488,12 +489,11 @@
 }
 
 SIMD_INLINE v128 v128_ziplo_64(v128 a, v128 b) {
-  return v128_from_v64(vget_low_s64((int64x2_t)a), vget_low_s64((int64x2_t)b));
+  return v128_from_v64(vget_low_s64(a), vget_low_s64(b));
 }
 
 SIMD_INLINE v128 v128_ziphi_64(v128 a, v128 b) {
-  return v128_from_v64(vget_high_s64((int64x2_t)a),
-                       vget_high_s64((int64x2_t)b));
+  return v128_from_v64(vget_high_s64(a), vget_high_s64(b));
 }
 
 SIMD_INLINE v128 v128_unziplo_8(v128 x, v128 y) {
@@ -643,10 +643,12 @@
 #else
   uint8x8x2_t p = { { vget_low_u8(vreinterpretq_u8_s64(x)),
                       vget_high_u8(vreinterpretq_u8_s64(x)) } };
-  return v128_from_64((uint64_t)vreinterpret_s64_u8(vtbl2_u8(
-                          p, vreinterpret_u8_s64(vget_high_s64(pattern)))),
-                      (uint64_t)vreinterpret_s64_u8(vtbl2_u8(
-                          p, vreinterpret_u8_s64(vget_low_s64(pattern)))));
+  uint8x8_t shuffle_hi =
+      vtbl2_u8(p, vreinterpret_u8_s64(vget_high_s64(pattern)));
+  uint8x8_t shuffle_lo =
+      vtbl2_u8(p, vreinterpret_u8_s64(vget_low_s64(pattern)));
+  return v128_from_64(vget_lane_u64(vreinterpret_u64_u8(shuffle_hi), 0),
+                      vget_lane_u64(vreinterpret_u64_u8(shuffle_lo), 0));
 #endif
 }
 
@@ -697,72 +699,72 @@
 
 SIMD_INLINE v128 v128_shl_8(v128 a, unsigned int c) {
   return (c > 7) ? v128_zero()
-                 : vreinterpretq_s64_u8(
-                       vshlq_u8(vreinterpretq_u8_s64(a), vdupq_n_s8(c)));
+                 : vreinterpretq_s64_u8(vshlq_u8(vreinterpretq_u8_s64(a),
+                                                 vdupq_n_s8((int8_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_u8(v128 a, unsigned int c) {
   return (c > 7) ? v128_zero()
-                 : vreinterpretq_s64_u8(
-                       vshlq_u8(vreinterpretq_u8_s64(a), vdupq_n_s8(-c)));
+                 : vreinterpretq_s64_u8(vshlq_u8(vreinterpretq_u8_s64(a),
+                                                 vdupq_n_s8(-(int8_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_s8(v128 a, unsigned int c) {
   return (c > 7) ? v128_ones()
-                 : vreinterpretq_s64_s8(
-                       vshlq_s8(vreinterpretq_s8_s64(a), vdupq_n_s8(-c)));
+                 : vreinterpretq_s64_s8(vshlq_s8(vreinterpretq_s8_s64(a),
+                                                 vdupq_n_s8(-(int8_t)c)));
 }
 
 SIMD_INLINE v128 v128_shl_16(v128 a, unsigned int c) {
   return (c > 15) ? v128_zero()
-                  : vreinterpretq_s64_u16(
-                        vshlq_u16(vreinterpretq_u16_s64(a), vdupq_n_s16(c)));
+                  : vreinterpretq_s64_u16(vshlq_u16(vreinterpretq_u16_s64(a),
+                                                    vdupq_n_s16((int16_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_u16(v128 a, unsigned int c) {
   return (c > 15) ? v128_zero()
-                  : vreinterpretq_s64_u16(
-                        vshlq_u16(vreinterpretq_u16_s64(a), vdupq_n_s16(-c)));
+                  : vreinterpretq_s64_u16(vshlq_u16(vreinterpretq_u16_s64(a),
+                                                    vdupq_n_s16(-(int16_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_s16(v128 a, unsigned int c) {
   return (c > 15) ? v128_ones()
-                  : vreinterpretq_s64_s16(
-                        vshlq_s16(vreinterpretq_s16_s64(a), vdupq_n_s16(-c)));
+                  : vreinterpretq_s64_s16(vshlq_s16(vreinterpretq_s16_s64(a),
+                                                    vdupq_n_s16(-(int16_t)c)));
 }
 
 SIMD_INLINE v128 v128_shl_32(v128 a, unsigned int c) {
   return (c > 31) ? v128_zero()
-                  : vreinterpretq_s64_u32(
-                        vshlq_u32(vreinterpretq_u32_s64(a), vdupq_n_s32(c)));
+                  : vreinterpretq_s64_u32(vshlq_u32(vreinterpretq_u32_s64(a),
+                                                    vdupq_n_s32((int32_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_u32(v128 a, unsigned int c) {
   return (c > 31) ? v128_zero()
-                  : vreinterpretq_s64_u32(
-                        vshlq_u32(vreinterpretq_u32_s64(a), vdupq_n_s32(-c)));
+                  : vreinterpretq_s64_u32(vshlq_u32(vreinterpretq_u32_s64(a),
+                                                    vdupq_n_s32(-(int32_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_s32(v128 a, unsigned int c) {
   return (c > 31) ? v128_ones()
-                  : vreinterpretq_s64_s32(
-                        vshlq_s32(vreinterpretq_s32_s64(a), vdupq_n_s32(-c)));
+                  : vreinterpretq_s64_s32(vshlq_s32(vreinterpretq_s32_s64(a),
+                                                    vdupq_n_s32(-(int32_t)c)));
 }
 
 SIMD_INLINE v128 v128_shl_64(v128 a, unsigned int c) {
   return (c > 63) ? v128_zero()
-                  : vreinterpretq_s64_u64(
-                        vshlq_u64(vreinterpretq_u64_s64(a), vdupq_n_s64(c)));
+                  : vreinterpretq_s64_u64(vshlq_u64(vreinterpretq_u64_s64(a),
+                                                    vdupq_n_s64((int64_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_u64(v128 a, unsigned int c) {
   return (c > 63) ? v128_zero()
-                  : vreinterpretq_s64_u64(
-                        vshlq_u64(vreinterpretq_u64_s64(a), vdupq_n_s64(-c)));
+                  : vreinterpretq_s64_u64(vshlq_u64(vreinterpretq_u64_s64(a),
+                                                    vdupq_n_s64(-(int64_t)c)));
 }
 
 SIMD_INLINE v128 v128_shr_s64(v128 a, unsigned int c) {
-  return (c > 63) ? v128_ones() : vshlq_s64(a, vdupq_n_s64(-c));
+  return (c > 63) ? v128_ones() : vshlq_s64(a, vdupq_n_s64(-(int64_t)c));
 }
 
 #if defined(__OPTIMIZE__) && __OPTIMIZE__ && !defined(__clang__)
@@ -949,8 +951,8 @@
 
 SIMD_INLINE uint32_t v128_sad_u16_sum(sad128_internal_u16 s) {
   uint64x2_t t = vpaddlq_u32(s);
-  return (uint32_t)(uint64_t)vget_high_u64(t) +
-         (uint32_t)(uint64_t)vget_low_u64(t);
+  return (uint32_t)vget_lane_u64(vadd_u64(vget_high_u64(t), vget_low_u64(t)),
+                                 0);
 }
 
 typedef v128 ssd128_internal_s16;
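
The v128 changes above replace direct casts from 64-bit Neon vector types to
scalars (a compiler extension rather than portable C) with an explicit
lane-wise add followed by a lane read. A minimal sketch of the replacement
pattern, assuming only <arm_neon.h>:

    #include <arm_neon.h>
    #include <stdint.h>

    /* Horizontal add of a 2-lane 64-bit vector: add the lanes in a vector
     * register, then extract lane 0, instead of casting int64x1_t to
     * int64_t. */
    static inline int64_t hadd_s64x2(int64x2_t t) {
      const int64x1_t sum = vadd_s64(vget_high_s64(t), vget_low_s64(t));
      return vget_lane_s64(sum, 0);
    }
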
diff --git a/aom_dsp/simd/v256_intrinsics_v128.h b/aom_dsp/simd/v256_intrinsics_v128.h
index 0d22667..cf44965 100644
--- a/aom_dsp/simd/v256_intrinsics_v128.h
+++ b/aom_dsp/simd/v256_intrinsics_v128.h
@@ -626,15 +626,18 @@
                       vget_high_u8(vreinterpretq_u8_s64(x.val[0])),
                       vget_low_u8(vreinterpretq_u8_s64(x.val[1])),
                       vget_high_u8(vreinterpretq_u8_s64(x.val[1])) } };
-  return v256_from_64(
-      (uint64_t)vreinterpret_s64_u8(
-          vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(pattern.val[1])))),
-      (uint64_t)vreinterpret_s64_u8(
-          vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(pattern.val[1])))),
-      (uint64_t)vreinterpret_s64_u8(
-          vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(pattern.val[0])))),
-      (uint64_t)vreinterpret_s64_u8(
-          vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(pattern.val[0])))));
+  uint8x8_t shuffle1_hi =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(pattern.val[1])));
+  uint8x8_t shuffle1_lo =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(pattern.val[1])));
+  uint8x8_t shuffle0_hi =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(pattern.val[0])));
+  uint8x8_t shuffle0_lo =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(pattern.val[0])));
+  return v256_from_64(vget_lane_u64(vreinterpret_u64_u8(shuffle1_hi), 0),
+                      vget_lane_u64(vreinterpret_u64_u8(shuffle1_lo), 0),
+                      vget_lane_u64(vreinterpret_u64_u8(shuffle0_hi), 0),
+                      vget_lane_u64(vreinterpret_u64_u8(shuffle0_lo), 0));
 #endif
 #else
   v128 c16 = v128_dup_8(16);
@@ -672,24 +675,26 @@
                       vget_high_u8(vreinterpretq_u8_s64(y.val[0])),
                       vget_low_u8(vreinterpretq_u8_s64(y.val[1])),
                       vget_high_u8(vreinterpretq_u8_s64(y.val[1])) } };
-  v256 r1 =
-      v256_from_64((uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       p, vreinterpret_u8_s64(vget_high_s64(p32.val[1])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       p, vreinterpret_u8_s64(vget_low_s64(p32.val[1])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       p, vreinterpret_u8_s64(vget_high_s64(p32.val[0])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       p, vreinterpret_u8_s64(vget_low_s64(p32.val[0])))));
-  v256 r2 =
-      v256_from_64((uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       q, vreinterpret_u8_s64(vget_high_s64(pattern.val[1])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       q, vreinterpret_u8_s64(vget_low_s64(pattern.val[1])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       q, vreinterpret_u8_s64(vget_high_s64(pattern.val[0])))),
-                   (uint64_t)vreinterpret_s64_u8(vtbl4_u8(
-                       q, vreinterpret_u8_s64(vget_low_s64(pattern.val[0])))));
+  uint8x8_t shuffle1_hi =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(p32.val[1])));
+  uint8x8_t shuffle1_lo =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(p32.val[1])));
+  uint8x8_t shuffle0_hi =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_high_s64(p32.val[0])));
+  uint8x8_t shuffle0_lo =
+      vtbl4_u8(p, vreinterpret_u8_s64(vget_low_s64(p32.val[0])));
+  v256 r1 = v256_from_64(vget_lane_u64(vreinterpret_u64_u8(shuffle1_hi), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle1_lo), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle0_hi), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle0_lo), 0));
+  shuffle1_hi = vtbl4_u8(q, vreinterpret_u8_s64(vget_high_s64(pattern.val[1])));
+  shuffle1_lo = vtbl4_u8(q, vreinterpret_u8_s64(vget_low_s64(pattern.val[1])));
+  shuffle0_hi = vtbl4_u8(q, vreinterpret_u8_s64(vget_high_s64(pattern.val[0])));
+  shuffle0_lo = vtbl4_u8(q, vreinterpret_u8_s64(vget_low_s64(pattern.val[0])));
+  v256 r2 = v256_from_64(vget_lane_u64(vreinterpret_u64_u8(shuffle1_hi), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle1_lo), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle0_hi), 0),
+                         vget_lane_u64(vreinterpret_u64_u8(shuffle0_lo), 0));
   return v256_blend_8(r1, r2, v256_cmplt_s8(pattern, c32));
 #endif
 #else
diff --git a/aom_dsp/simd/v64_intrinsics_arm.h b/aom_dsp/simd/v64_intrinsics_arm.h
index a4ecdf4..265ebed 100644
--- a/aom_dsp/simd/v64_intrinsics_arm.h
+++ b/aom_dsp/simd/v64_intrinsics_arm.h
@@ -13,6 +13,7 @@
 #define AOM_AOM_DSP_SIMD_V64_INTRINSICS_ARM_H_
 
 #include <arm_neon.h>
+#include <string.h>
 
 #include "aom_dsp/simd/v64_intrinsics_arm.h"
 #include "aom_ports/arm.h"
@@ -50,7 +51,7 @@
 
 SIMD_INLINE v64 v64_from_64(uint64_t x) { return vcreate_s64(x); }
 
-SIMD_INLINE uint64_t v64_u64(v64 x) { return (uint64_t)x; }
+SIMD_INLINE uint64_t v64_u64(v64 x) { return (uint64_t)vget_lane_s64(x, 0); }
 
 SIMD_INLINE uint32_t u32_load_aligned(const void *p) {
   return *((uint32_t *)p);
@@ -77,8 +78,7 @@
   } __attribute__((__packed__));
   ((struct Unaligned32Struct *)p)->value = a;
 #else
-  vst1_lane_u32((uint32_t *)p, vreinterpret_u32_s64((uint64x1_t)(uint64_t)a),
-                0);
+  memcpy(p, &a, 4);
 #endif
 }
 
@@ -106,7 +106,8 @@
                  vext_s8(vreinterpret_s8_s64(b), vreinterpret_s8_s64(a), c))
            : b;
 #else
-  return c ? v64_from_64(((uint64_t)b >> c * 8) | ((uint64_t)a << (8 - c) * 8))
+  return c ? v64_from_64(((uint64_t)vget_lane_s64(b, 0) >> c * 8) |
+                         ((uint64_t)vget_lane_s64(a, 0) << (8 - c) * 8))
            : b;
 #endif
 }
@@ -133,7 +134,7 @@
   return vaddlvq_s16(t);
 #else
   int64x2_t r = vpaddlq_s32(vpaddlq_s16(t));
-  return (int64_t)vadd_s64(vget_high_s64(r), vget_low_s64(r));
+  return vget_lane_s64(vadd_s64(vget_high_s64(r), vget_low_s64(r)), 0);
 #endif
 }
 
@@ -144,7 +145,7 @@
 #else
   int64x2_t r =
       vpaddlq_s32(vmull_s16(vreinterpret_s16_s64(x), vreinterpret_s16_s64(y)));
-  return (int64_t)(vget_high_s64(r) + vget_low_s64(r));
+  return vget_lane_s64(vadd_s64(vget_high_s64(r), vget_low_s64(r)), 0);
 #endif
 }
 
@@ -152,12 +153,13 @@
 #if defined(__aarch64__)
   return vaddlv_u8(vreinterpret_u8_s64(x));
 #else
-  return (uint64_t)vpaddl_u32(vpaddl_u16(vpaddl_u8(vreinterpret_u8_s64(x))));
+  return vget_lane_u64(
+      vpaddl_u32(vpaddl_u16(vpaddl_u8(vreinterpret_u8_s64(x)))), 0);
 #endif
 }
 
 SIMD_INLINE int64_t v64_hadd_s16(v64 a) {
-  return (int64_t)vpaddl_s32(vpaddl_s16(vreinterpret_s16_s64(a)));
+  return vget_lane_s64(vpaddl_s32(vpaddl_s16(vreinterpret_s16_s64(a))), 0);
 }
 
 typedef uint16x8_t sad64_internal;
@@ -175,7 +177,8 @@
   return vaddlvq_u16(s);
 #else
   uint64x2_t r = vpaddlq_u32(vpaddlq_u16(s));
-  return (uint32_t)(uint64_t)(vget_high_u64(r) + vget_low_u64(r));
+  return (uint32_t)vget_lane_u64(vadd_u64(vget_high_u64(r), vget_low_u64(r)),
+                                 0);
 #endif
 }
 
@@ -556,43 +559,48 @@
 }
 
 SIMD_INLINE v64 v64_shl_8(v64 a, unsigned int c) {
-  return vreinterpret_s64_u8(vshl_u8(vreinterpret_u8_s64(a), vdup_n_s8(c)));
+  return vreinterpret_s64_u8(
+      vshl_u8(vreinterpret_u8_s64(a), vdup_n_s8((int8_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_u8(v64 a, unsigned int c) {
-  return vreinterpret_s64_u8(vshl_u8(vreinterpret_u8_s64(a), vdup_n_s8(-c)));
+  return vreinterpret_s64_u8(
+      vshl_u8(vreinterpret_u8_s64(a), vdup_n_s8(-(int8_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_s8(v64 a, unsigned int c) {
-  return vreinterpret_s64_s8(vshl_s8(vreinterpret_s8_s64(a), vdup_n_s8(-c)));
+  return vreinterpret_s64_s8(
+      vshl_s8(vreinterpret_s8_s64(a), vdup_n_s8(-(int8_t)c)));
 }
 
 SIMD_INLINE v64 v64_shl_16(v64 a, unsigned int c) {
-  return vreinterpret_s64_u16(vshl_u16(vreinterpret_u16_s64(a), vdup_n_s16(c)));
+  return vreinterpret_s64_u16(
+      vshl_u16(vreinterpret_u16_s64(a), vdup_n_s16((int16_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_u16(v64 a, unsigned int c) {
   return vreinterpret_s64_u16(
-      vshl_u16(vreinterpret_u16_s64(a), vdup_n_s16(-(int)c)));
+      vshl_u16(vreinterpret_u16_s64(a), vdup_n_s16(-(int16_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_s16(v64 a, unsigned int c) {
   return vreinterpret_s64_s16(
-      vshl_s16(vreinterpret_s16_s64(a), vdup_n_s16(-(int)c)));
+      vshl_s16(vreinterpret_s16_s64(a), vdup_n_s16(-(int16_t)c)));
 }
 
 SIMD_INLINE v64 v64_shl_32(v64 a, unsigned int c) {
-  return vreinterpret_s64_u32(vshl_u32(vreinterpret_u32_s64(a), vdup_n_s32(c)));
+  return vreinterpret_s64_u32(
+      vshl_u32(vreinterpret_u32_s64(a), vdup_n_s32((int32_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_u32(v64 a, unsigned int c) {
   return vreinterpret_s64_u32(
-      vshl_u32(vreinterpret_u32_s64(a), vdup_n_s32(-(int)c)));
+      vshl_u32(vreinterpret_u32_s64(a), vdup_n_s32(-(int32_t)c)));
 }
 
 SIMD_INLINE v64 v64_shr_s32(v64 a, unsigned int c) {
   return vreinterpret_s64_s32(
-      vshl_s32(vreinterpret_s32_s64(a), vdup_n_s32(-(int)c)));
+      vshl_s32(vreinterpret_s32_s64(a), vdup_n_s32(-(int32_t)c)));
 }
 
 // The following functions require an immediate.
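
The unaligned 32-bit store above is rewritten as a memcpy, which compilers
lower to a single unaligned store where the target allows it, without
asserting an alignment the pointer may not have. A minimal sketch of the same
idiom, using only standard headers:

    #include <stdint.h>
    #include <string.h>

    /* Store a 32-bit value to a possibly unaligned address; memcpy carries
     * no alignment or strict-aliasing assumptions. */
    static inline void u32_store_unaligned(void *p, uint32_t a) {
      memcpy(p, &a, sizeof(a));
    }
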
diff --git a/aom_dsp/variance.c b/aom_dsp/variance.c
index f72feea..63c1e5f 100644
--- a/aom_dsp/variance.c
+++ b/aom_dsp/variance.c
@@ -25,24 +25,6 @@
 #include "av1/common/filter.h"
 #include "av1/common/reconinter.h"
 
-uint32_t aom_get4x4sse_cs_c(const uint8_t *a, int a_stride, const uint8_t *b,
-                            int b_stride) {
-  int distortion = 0;
-  int r, c;
-
-  for (r = 0; r < 4; ++r) {
-    for (c = 0; c < 4; ++c) {
-      int diff = a[c] - b[c];
-      distortion += diff * diff;
-    }
-
-    a += a_stride;
-    b += b_stride;
-  }
-
-  return distortion;
-}
-
 uint32_t aom_get_mb_ss_c(const int16_t *a) {
   unsigned int i, sum = 0;
 
@@ -198,17 +180,6 @@
     return aom_variance##W##x##H(temp3, W, b, b_stride, sse);                  \
   }
 
-/* Identical to the variance call except it takes an additional parameter, sum,
- * and returns that value using pass-by-reference instead of returning
- * sse - sum^2 / w*h
- */
-#define GET_VAR(W, H)                                                         \
-  void aom_get##W##x##H##var_c(const uint8_t *a, int a_stride,                \
-                               const uint8_t *b, int b_stride, uint32_t *sse, \
-                               int *sum) {                                    \
-    variance(a, a_stride, b, b_stride, W, H, sse, sum);                       \
-  }
-
 void aom_get_var_sse_sum_8x8_quad_c(const uint8_t *a, int a_stride,
                                     const uint8_t *b, int b_stride,
                                     uint32_t *sse8x8, int *sum8x8,
@@ -231,7 +202,7 @@
                                       const uint8_t *ref_ptr, int ref_stride,
                                       uint32_t *sse16x16, unsigned int *tot_sse,
                                       int *tot_sum, uint32_t *var16x16) {
-  int sum16x16[64] = { 0 };
+  int sum16x16[2] = { 0 };
   // Loop over two consecutive 16x16 blocks and process as one 16x32 block.
   for (int k = 0; k < 2; k++) {
     variance(src_ptr + (k * 16), source_stride, ref_ptr + (k * 16), ref_stride,
@@ -281,9 +252,6 @@
 VARIANCES(8, 4)
 VARIANCES(4, 8)
 VARIANCES(4, 4)
-VARIANCES(4, 2)
-VARIANCES(2, 4)
-VARIANCES(2, 2)
 
 // Realtime mode doesn't use rectangular blocks.
 #if !CONFIG_REALTIME_ONLY
@@ -295,9 +263,6 @@
 VARIANCES(64, 16)
 #endif
 
-GET_VAR(16, 16)
-GET_VAR(8, 8)
-
 MSE(16, 16)
 MSE(16, 8)
 MSE(8, 16)
@@ -428,25 +393,6 @@
     return (var >= 0) ? (uint32_t)var : 0;                                     \
   }
 
-#define HIGHBD_GET_VAR(S)                                                    \
-  void aom_highbd_8_get##S##x##S##var_c(const uint8_t *src, int src_stride,  \
-                                        const uint8_t *ref, int ref_stride,  \
-                                        uint32_t *sse, int *sum) {           \
-    highbd_8_variance(src, src_stride, ref, ref_stride, S, S, sse, sum);     \
-  }                                                                          \
-                                                                             \
-  void aom_highbd_10_get##S##x##S##var_c(const uint8_t *src, int src_stride, \
-                                         const uint8_t *ref, int ref_stride, \
-                                         uint32_t *sse, int *sum) {          \
-    highbd_10_variance(src, src_stride, ref, ref_stride, S, S, sse, sum);    \
-  }                                                                          \
-                                                                             \
-  void aom_highbd_12_get##S##x##S##var_c(const uint8_t *src, int src_stride, \
-                                         const uint8_t *ref, int ref_stride, \
-                                         uint32_t *sse, int *sum) {          \
-    highbd_12_variance(src, src_stride, ref, ref_stride, S, S, sse, sum);    \
-  }
-
 #define HIGHBD_MSE(W, H)                                                      \
   uint32_t aom_highbd_8_mse##W##x##H##_c(const uint8_t *src, int src_stride,  \
                                          const uint8_t *ref, int ref_stride,  \
@@ -706,9 +652,6 @@
 HIGHBD_VARIANCES(8, 4)
 HIGHBD_VARIANCES(4, 8)
 HIGHBD_VARIANCES(4, 4)
-HIGHBD_VARIANCES(4, 2)
-HIGHBD_VARIANCES(2, 4)
-HIGHBD_VARIANCES(2, 2)
 
 // Realtime mode doesn't use 4x rectangular blocks.
 #if !CONFIG_REALTIME_ONLY
@@ -720,9 +663,6 @@
 HIGHBD_VARIANCES(64, 16)
 #endif
 
-HIGHBD_GET_VAR(8)
-HIGHBD_GET_VAR(16)
-
 HIGHBD_MSE(16, 16)
 HIGHBD_MSE(16, 8)
 HIGHBD_MSE(8, 16)
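
The removed GET_VAR comment spells out the relationship those wrappers relied
on: the variance kernels return sse - sum^2 / (w*h), and the wrappers only
added sum as an extra by-reference output. A scalar sketch of that
computation; variance_ref is a stand-in name, not a libaom function:

    #include <stdint.h>

    static uint32_t variance_ref(const uint8_t *a, int a_stride,
                                 const uint8_t *b, int b_stride, int w, int h,
                                 uint32_t *sse) {
      int64_t sum = 0;
      uint64_t sse_acc = 0;
      for (int i = 0; i < h; i++) {
        for (int j = 0; j < w; j++) {
          const int diff = a[j] - b[j];
          sum += diff;
          sse_acc += (uint32_t)(diff * diff);
        }
        a += a_stride;
        b += b_stride;
      }
      *sse = (uint32_t)sse_acc;
      /* sse - sum^2 / (w * h), as described in the removed comment. */
      return (uint32_t)(sse_acc - (uint64_t)(sum * sum / ((int64_t)w * h)));
    }
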
diff --git a/aom_dsp/x86/avg_intrin_sse2.c b/aom_dsp/x86/avg_intrin_sse2.c
index 71e7028..2d94f0e 100644
--- a/aom_dsp/x86/avg_intrin_sse2.c
+++ b/aom_dsp/x86/avg_intrin_sse2.c
@@ -344,56 +344,6 @@
   hadamard_8x8_sse2(src_diff, src_stride, coeff, 1);
 }
 
-void aom_pixel_scale_sse2(const int16_t *src_diff, ptrdiff_t src_stride,
-                          int16_t *coeff, int log_scale, int h8, int w8) {
-  __m128i src[8];
-  const int16_t *org_src_diff = src_diff;
-  int16_t *org_coeff = coeff;
-  int coeff_stride = w8 << 3;
-  for (int idy = 0; idy < h8; ++idy) {
-    for (int idx = 0; idx < w8; ++idx) {
-      src_diff = org_src_diff + (idx << 3);
-      coeff = org_coeff + (idx << 3);
-
-      src[0] = _mm_load_si128((const __m128i *)src_diff);
-      src[1] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[2] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[3] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[4] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[5] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[6] = _mm_load_si128((const __m128i *)(src_diff += src_stride));
-      src[7] = _mm_load_si128((const __m128i *)(src_diff + src_stride));
-
-      src[0] = _mm_slli_epi16(src[0], log_scale);
-      src[1] = _mm_slli_epi16(src[1], log_scale);
-      src[2] = _mm_slli_epi16(src[2], log_scale);
-      src[3] = _mm_slli_epi16(src[3], log_scale);
-      src[4] = _mm_slli_epi16(src[4], log_scale);
-      src[5] = _mm_slli_epi16(src[5], log_scale);
-      src[6] = _mm_slli_epi16(src[6], log_scale);
-      src[7] = _mm_slli_epi16(src[7], log_scale);
-
-      _mm_store_si128((__m128i *)coeff, src[0]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[1]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[2]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[3]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[4]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[5]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[6]);
-      coeff += coeff_stride;
-      _mm_store_si128((__m128i *)coeff, src[7]);
-    }
-    org_src_diff += (src_stride << 3);
-    org_coeff += (coeff_stride << 3);
-  }
-}
-
 static INLINE void hadamard_lp_8x8_sse2(const int16_t *src_diff,
                                         ptrdiff_t src_stride, int16_t *coeff) {
   __m128i src[8];
diff --git a/aom_dsp/x86/convolve_avx2.h b/aom_dsp/x86/convolve_avx2.h
index a709008..f5a382c 100644
--- a/aom_dsp/x86/convolve_avx2.h
+++ b/aom_dsp/x86/convolve_avx2.h
@@ -329,20 +329,20 @@
           _mm256_castsi128_si256(_mm_loadu_si128(                              \
               (__m128i *)(&src_ptr[i * src_stride + src_stride + j]))),        \
           0x20);                                                               \
-      const __m256i s_16l = _mm256_unpacklo_epi8(data, v_zero);                \
-      const __m256i s_16h = _mm256_unpackhi_epi8(data, v_zero);                \
-      const __m256i s_ll = _mm256_unpacklo_epi16(s_16l, s_16l);                \
-      const __m256i s_lh = _mm256_unpackhi_epi16(s_16l, s_16l);                \
+      const __m256i s_16lo = _mm256_unpacklo_epi8(data, v_zero);               \
+      const __m256i s_16hi = _mm256_unpackhi_epi8(data, v_zero);               \
+      const __m256i s_lolo = _mm256_unpacklo_epi16(s_16lo, s_16lo);            \
+      const __m256i s_lohi = _mm256_unpackhi_epi16(s_16lo, s_16lo);            \
                                                                                \
-      const __m256i s_hl = _mm256_unpacklo_epi16(s_16h, s_16h);                \
-      const __m256i s_hh = _mm256_unpackhi_epi16(s_16h, s_16h);                \
+      const __m256i s_hilo = _mm256_unpacklo_epi16(s_16hi, s_16hi);            \
+      const __m256i s_hihi = _mm256_unpackhi_epi16(s_16hi, s_16hi);            \
                                                                                \
-      s[0] = _mm256_alignr_epi8(s_lh, s_ll, 2);                                \
-      s[1] = _mm256_alignr_epi8(s_lh, s_ll, 10);                               \
-      s[2] = _mm256_alignr_epi8(s_hl, s_lh, 2);                                \
-      s[3] = _mm256_alignr_epi8(s_hl, s_lh, 10);                               \
-      s[4] = _mm256_alignr_epi8(s_hh, s_hl, 2);                                \
-      s[5] = _mm256_alignr_epi8(s_hh, s_hl, 10);                               \
+      s[0] = _mm256_alignr_epi8(s_lohi, s_lolo, 2);                            \
+      s[1] = _mm256_alignr_epi8(s_lohi, s_lolo, 10);                           \
+      s[2] = _mm256_alignr_epi8(s_hilo, s_lohi, 2);                            \
+      s[3] = _mm256_alignr_epi8(s_hilo, s_lohi, 10);                           \
+      s[4] = _mm256_alignr_epi8(s_hihi, s_hilo, 2);                            \
+      s[5] = _mm256_alignr_epi8(s_hihi, s_hilo, 10);                           \
                                                                                \
       const __m256i res_lo = convolve_12taps(s, coeffs_h);                     \
                                                                                \
@@ -373,21 +373,21 @@
           _mm256_castsi128_si256(                                              \
               _mm_loadu_si128((__m128i *)(&src_ptr[i * src_stride + j + 4]))), \
           0x20);                                                               \
-      const __m256i s_16l = _mm256_unpacklo_epi8(data, v_zero);                \
-      const __m256i s_16h = _mm256_unpackhi_epi8(data, v_zero);                \
+      const __m256i s_16lo = _mm256_unpacklo_epi8(data, v_zero);               \
+      const __m256i s_16hi = _mm256_unpackhi_epi8(data, v_zero);               \
                                                                                \
-      const __m256i s_ll = _mm256_unpacklo_epi16(s_16l, s_16l);                \
-      const __m256i s_lh = _mm256_unpackhi_epi16(s_16l, s_16l);                \
+      const __m256i s_lolo = _mm256_unpacklo_epi16(s_16lo, s_16lo);            \
+      const __m256i s_lohi = _mm256_unpackhi_epi16(s_16lo, s_16lo);            \
                                                                                \
-      const __m256i s_hl = _mm256_unpacklo_epi16(s_16h, s_16h);                \
-      const __m256i s_hh = _mm256_unpackhi_epi16(s_16h, s_16h);                \
+      const __m256i s_hilo = _mm256_unpacklo_epi16(s_16hi, s_16hi);            \
+      const __m256i s_hihi = _mm256_unpackhi_epi16(s_16hi, s_16hi);            \
                                                                                \
-      s[0] = _mm256_alignr_epi8(s_lh, s_ll, 2);                                \
-      s[1] = _mm256_alignr_epi8(s_lh, s_ll, 10);                               \
-      s[2] = _mm256_alignr_epi8(s_hl, s_lh, 2);                                \
-      s[3] = _mm256_alignr_epi8(s_hl, s_lh, 10);                               \
-      s[4] = _mm256_alignr_epi8(s_hh, s_hl, 2);                                \
-      s[5] = _mm256_alignr_epi8(s_hh, s_hl, 10);                               \
+      s[0] = _mm256_alignr_epi8(s_lohi, s_lolo, 2);                            \
+      s[1] = _mm256_alignr_epi8(s_lohi, s_lolo, 10);                           \
+      s[2] = _mm256_alignr_epi8(s_hilo, s_lohi, 2);                            \
+      s[3] = _mm256_alignr_epi8(s_hilo, s_lohi, 10);                           \
+      s[4] = _mm256_alignr_epi8(s_hihi, s_hilo, 2);                            \
+      s[5] = _mm256_alignr_epi8(s_hihi, s_hilo, 10);                           \
                                                                                \
       const __m256i res_lo = convolve_12taps(s, coeffs_h);                     \
                                                                                \
diff --git a/aom_dsp/x86/fwd_txfm_impl_sse2.h b/aom_dsp/x86/fwd_txfm_impl_sse2.h
index 89fe189..7ee8ba3 100644
--- a/aom_dsp/x86/fwd_txfm_impl_sse2.h
+++ b/aom_dsp/x86/fwd_txfm_impl_sse2.h
@@ -180,25 +180,8 @@
       const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS2);
       const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS2);
       const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS2);
-      // w0 = [o0 o4 o8 oC]
-      // w1 = [o2 o6 oA oE]
-      // w2 = [o1 o5 o9 oD]
-      // w3 = [o3 o7 oB oF]
-      // remember the o's are numbered according to the correct output location
-      const __m128i x0 = _mm_packs_epi32(w0, w1);
-      const __m128i x1 = _mm_packs_epi32(w2, w3);
-      {
-        // x0 = [o0 o4 o8 oC o2 o6 oA oE]
-        // x1 = [o1 o5 o9 oD o3 o7 oB oF]
-        const __m128i y0 = _mm_unpacklo_epi16(x0, x1);
-        const __m128i y1 = _mm_unpackhi_epi16(x0, x1);
-        // y0 = [o0 o1 o4 o5 o8 o9 oC oD]
-        // y1 = [o2 o3 o6 o7 oA oB oE oF]
-        *in0 = _mm_unpacklo_epi32(y0, y1);
-        // in0 = [o0 o1 o2 o3 o4 o5 o6 o7]
-        *in1 = _mm_unpackhi_epi32(y0, y1);
-        // in1 = [o8 o9 oA oB oC oD oE oF]
-      }
+      *in0 = _mm_packs_epi32(w0, w2);
+      *in1 = _mm_packs_epi32(w1, w3);
     }
   }
 }
@@ -230,6 +213,7 @@
   _mm_storeu_si128((__m128i *)(output + 2 * 4), in1);
 }
 
+#if CONFIG_INTERNAL_STATS
 void FDCT8x8_2D(const int16_t *input, tran_low_t *output, int stride) {
   int pass;
   // Constants
@@ -539,6 +523,7 @@
     store_output(&in7, (output + 7 * 8));
   }
 }
+#endif  // CONFIG_INTERNAL_STATS
 
 #undef ADD_EPI16
 #undef SUB_EPI16
diff --git a/aom_dsp/x86/highbd_variance_sse2.c b/aom_dsp/x86/highbd_variance_sse2.c
index d45885c..e897aab 100644
--- a/aom_dsp/x86/highbd_variance_sse2.c
+++ b/aom_dsp/x86/highbd_variance_sse2.c
@@ -98,43 +98,6 @@
   *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);
 }
 
-#define HIGH_GET_VAR(S)                                                       \
-  void aom_highbd_get##S##x##S##var_sse2(const uint8_t *src8, int src_stride, \
-                                         const uint8_t *ref8, int ref_stride, \
-                                         uint32_t *sse, int *sum) {           \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
-    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
-                                       sum);                                  \
-  }                                                                           \
-                                                                              \
-  void aom_highbd_10_get##S##x##S##var_sse2(                                  \
-      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
-      int ref_stride, uint32_t *sse, int *sum) {                              \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
-    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
-                                       sum);                                  \
-    *sum = ROUND_POWER_OF_TWO(*sum, 2);                                       \
-    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                       \
-  }                                                                           \
-                                                                              \
-  void aom_highbd_12_get##S##x##S##var_sse2(                                  \
-      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
-      int ref_stride, uint32_t *sse, int *sum) {                              \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
-    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
-                                       sum);                                  \
-    *sum = ROUND_POWER_OF_TWO(*sum, 4);                                       \
-    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                       \
-  }
-
-HIGH_GET_VAR(16)
-HIGH_GET_VAR(8)
-
-#undef HIGH_GET_VAR
-
 #define VAR_FN(w, h, block_size, shift)                                    \
   uint32_t aom_highbd_8_variance##w##x##h##_sse2(                          \
       const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
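
The removed HIGH_GET_VAR wrappers differ only in how they scale the raw
statistics back to an 8-bit range: 10-bit sums are rounded down by 2 bits and
SSEs by 4, and 12-bit sums by 4 bits and SSEs by 8, since each extra bit of
depth doubles per-sample differences and squares that factor for the SSE. A
small sketch of that normalization, assuming the usual rounding definition of
ROUND_POWER_OF_TWO; normalize_highbd_stats is a name chosen here:

    #include <stdint.h>

    #define ROUND_POWER_OF_TWO(value, n) (((value) + (1 << ((n)-1))) >> (n))

    /* Scale high-bit-depth variance statistics back to an 8-bit range:
     * sums scale by 2^(bd - 8) per sample, SSEs by 2^(2 * (bd - 8)). */
    static void normalize_highbd_stats(int bd, int *sum, uint32_t *sse) {
      const int shift = bd - 8;  /* 2 for 10-bit, 4 for 12-bit */
      if (shift <= 0) return;    /* 8-bit statistics need no scaling */
      *sum = ROUND_POWER_OF_TWO(*sum, shift);
      *sse = ROUND_POWER_OF_TWO(*sse, 2 * shift);
    }
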
diff --git a/aom_dsp/x86/intrapred_sse4.c b/aom_dsp/x86/intrapred_sse4.c
index 3f72dc4..fb30420 100644
--- a/aom_dsp/x86/intrapred_sse4.c
+++ b/aom_dsp/x86/intrapred_sse4.c
@@ -602,7 +602,7 @@
   const __m128i c1234 = _mm_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8);
 
   for (int r = 0; r < N; r++) {
-    __m128i b, res, res1, shift, shifty;
+    __m128i b, res, res1, shift;
     __m128i resx, resy, resxy, r6, ydx;
 
     int y = r + 1;
@@ -620,11 +620,7 @@
     }
 
     if (base_shift > 7) {
-      a0_x = _mm_setzero_si128();
-      a1_x = _mm_setzero_si128();
-      a0_y = _mm_setzero_si128();
-      a1_y = _mm_setzero_si128();
-      shift = _mm_setzero_si128();
+      resx = _mm_setzero_si128();
     } else {
       a0_above = _mm_loadu_si128((__m128i *)(above + base_x + base_shift));
       ydx = _mm_set1_epi16(y * dx);
@@ -649,9 +645,15 @@
       }
       a0_x = _mm_cvtepu8_epi16(a0_above);
       a1_x = _mm_cvtepu8_epi16(a1_above);
-      a0_y = _mm_setzero_si128();
-      a1_y = _mm_setzero_si128();
-      shifty = shift;
+
+      diff = _mm_sub_epi16(a1_x, a0_x);  // a[x+1] - a[x]
+      a32 = _mm_slli_epi16(a0_x, 5);     // a[x] * 32
+      a32 = _mm_add_epi16(a32, a16);     // a[x] * 32 + 16
+
+      b = _mm_mullo_epi16(diff, shift);
+      res = _mm_add_epi16(a32, b);
+      res = _mm_srli_epi16(res, 5);
+      resx = _mm_packus_epi16(res, res);
     }
 
     // y calc
@@ -678,34 +680,27 @@
                             left[base_y_c[6]], left[base_y_c[7]]);
 
       if (upsample_left) {
-        shifty = _mm_srli_epi16(
+        shift = _mm_srli_epi16(
             _mm_and_si128(_mm_slli_epi16(y_c, upsample_left), c3f), 1);
       } else {
-        shifty = _mm_srli_epi16(_mm_and_si128(y_c, c3f), 1);
+        shift = _mm_srli_epi16(_mm_and_si128(y_c, c3f), 1);
       }
+
+      diff = _mm_sub_epi16(a1_y, a0_y);  // a[x+1] - a[x]
+      a32 = _mm_slli_epi16(a0_y, 5);     // a[x] * 32
+      a32 = _mm_add_epi16(a32, a16);     // a[x] * 32 + 16
+
+      b = _mm_mullo_epi16(diff, shift);
+      res1 = _mm_add_epi16(a32, b);
+      res1 = _mm_srli_epi16(res1, 5);
+
+      resy = _mm_packus_epi16(res1, res1);
+      resxy = _mm_blendv_epi8(resx, resy, *(__m128i *)Mask[0][base_min_diff]);
+      _mm_storel_epi64((__m128i *)dst, resxy);
+    } else {
+      _mm_storel_epi64((__m128i *)dst, resx);
     }
 
-    diff = _mm_sub_epi16(a1_x, a0_x);  // a[x+1] - a[x]
-    a32 = _mm_slli_epi16(a0_x, 5);     // a[x] * 32
-    a32 = _mm_add_epi16(a32, a16);     // a[x] * 32 + 16
-
-    b = _mm_mullo_epi16(diff, shift);
-    res = _mm_add_epi16(a32, b);
-    res = _mm_srli_epi16(res, 5);
-
-    diff = _mm_sub_epi16(a1_y, a0_y);  // a[x+1] - a[x]
-    a32 = _mm_slli_epi16(a0_y, 5);     // a[x] * 32
-    a32 = _mm_add_epi16(a32, a16);     // a[x] * 32 + 16
-
-    b = _mm_mullo_epi16(diff, shifty);
-    res1 = _mm_add_epi16(a32, b);
-    res1 = _mm_srli_epi16(res1, 5);
-
-    resx = _mm_packus_epi16(res, res);
-    resy = _mm_packus_epi16(res1, res1);
-
-    resxy = _mm_blendv_epi8(resx, resy, *(__m128i *)Mask[0][base_min_diff]);
-    _mm_storel_epi64((__m128i *)(dst), resxy);
     dst += stride;
   }
 }
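
The restructured directional predictor keeps the per-pixel interpolation that
the inline comments describe: each output is a 1/32-pel blend of two
neighboring reference samples, (a[x] * 32 + 16 + (a[x+1] - a[x]) * shift) >> 5,
computed once along the above row (resx) and once along the left column
(resy); the change simply skips whichever path a row does not need. A scalar
sketch of that blend, with dr_blend as a name chosen here and shift being the
0-31 fractional position the SIMD code derives from the x/y deltas:

    #include <stdint.h>

    /* 1/32-pel blend of two neighboring reference samples:
     * result = (a0 * 32 + 16 + (a1 - a0) * shift) >> 5, shift in [0, 31]. */
    static uint8_t dr_blend(uint8_t a0, uint8_t a1, int shift) {
      const int v = a0 * 32 + 16 + (a1 - a0) * shift;
      return (uint8_t)(v >> 5);
    }
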
diff --git a/aom_dsp/x86/jnt_sad_ssse3.c b/aom_dsp/x86/jnt_sad_sse2.c
similarity index 66%
rename from aom_dsp/x86/jnt_sad_ssse3.c
rename to aom_dsp/x86/jnt_sad_sse2.c
index 357f70a..16d2f4b 100644
--- a/aom_dsp/x86/jnt_sad_ssse3.c
+++ b/aom_dsp/x86/jnt_sad_sse2.c
@@ -10,16 +10,16 @@
  */
 
 #include <assert.h>
-#include <emmintrin.h>  // SSE2
-#include <tmmintrin.h>
+#include <emmintrin.h>
 
 #include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
 
 #include "aom_dsp/x86/synonyms.h"
 
-unsigned int aom_sad4xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                             int b_stride, int width, int height) {
+static unsigned int sad4xh_sse2(const uint8_t *a, int a_stride,
+                                const uint8_t *b, int b_stride, int width,
+                                int height) {
   int i;
   assert(width == 4);
   (void)width;
@@ -59,8 +59,9 @@
   return res;
 }
 
-unsigned int aom_sad8xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                             int b_stride, int width, int height) {
+static unsigned int sad8xh_sse2(const uint8_t *a, int a_stride,
+                                const uint8_t *b, int b_stride, int width,
+                                int height) {
   int i;
   assert(width == 8);
   (void)width;
@@ -91,8 +92,9 @@
   return res;
 }
 
-unsigned int aom_sad16xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                              int b_stride, int width, int height) {
+static unsigned int sad16xh_sse2(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, int width,
+                                 int height) {
   int i;
   assert(width == 16);
   (void)width;
@@ -116,8 +118,9 @@
   return res;
 }
 
-unsigned int aom_sad32xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                              int b_stride, int width, int height) {
+static unsigned int sad32xh_sse2(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, int width,
+                                 int height) {
   int i, j;
   assert(width == 32);
   (void)width;
@@ -143,8 +146,9 @@
   return res;
 }
 
-unsigned int aom_sad64xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                              int b_stride, int width, int height) {
+static unsigned int sad64xh_sse2(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, int width,
+                                 int height) {
   int i, j;
   assert(width == 64);
   (void)width;
@@ -170,8 +174,9 @@
   return res;
 }
 
-unsigned int aom_sad128xh_sse2(const uint8_t *a, int a_stride, const uint8_t *b,
-                               int b_stride, int width, int height) {
+static unsigned int sad128xh_sse2(const uint8_t *a, int a_stride,
+                                  const uint8_t *b, int b_stride, int width,
+                                  int height) {
   int i, j;
   assert(width == 128);
   (void)width;
@@ -197,47 +202,37 @@
   return res;
 }
 
-#define dist_wtd_sadMxN_sse2(m, n)                                            \
-  unsigned int aom_dist_wtd_sad##m##x##n##_avg_ssse3(                         \
+#define DIST_WTD_SADMXN_SSE2(m, n)                                            \
+  unsigned int aom_dist_wtd_sad##m##x##n##_avg_sse2(                          \
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
       const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {    \
     uint8_t comp_pred[m * n];                                                 \
     aom_dist_wtd_comp_avg_pred(comp_pred, second_pred, m, n, ref, ref_stride, \
                                jcp_param);                                    \
-    return aom_sad##m##xh_sse2(src, src_stride, comp_pred, m, m, n);          \
+    return sad##m##xh_sse2(src, src_stride, comp_pred, m, m, n);              \
   }
 
-#define dist_wtd_sadMxN_avx2(m, n)                                            \
-  unsigned int aom_dist_wtd_sad##m##x##n##_avg_avx2(                          \
-      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
-      const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {    \
-    uint8_t comp_pred[m * n];                                                 \
-    aom_dist_wtd_comp_avg_pred(comp_pred, second_pred, m, n, ref, ref_stride, \
-                               jcp_param);                                    \
-    return aom_sad##m##xh_avx2(src, src_stride, comp_pred, m, m, n);          \
-  }
-
-/* clang-format off */
-dist_wtd_sadMxN_sse2(128, 128)
-dist_wtd_sadMxN_sse2(128, 64)
-dist_wtd_sadMxN_sse2(64, 128)
-dist_wtd_sadMxN_sse2(64, 64)
-dist_wtd_sadMxN_sse2(64, 32)
-dist_wtd_sadMxN_sse2(32, 64)
-dist_wtd_sadMxN_sse2(32, 32)
-dist_wtd_sadMxN_sse2(32, 16)
-dist_wtd_sadMxN_sse2(16, 32)
-dist_wtd_sadMxN_sse2(16, 16)
-dist_wtd_sadMxN_sse2(16, 8)
-dist_wtd_sadMxN_sse2(8, 16)
-dist_wtd_sadMxN_sse2(8, 8)
-dist_wtd_sadMxN_sse2(8, 4)
-dist_wtd_sadMxN_sse2(4, 8)
-dist_wtd_sadMxN_sse2(4, 4)
-dist_wtd_sadMxN_sse2(4, 16)
-dist_wtd_sadMxN_sse2(16, 4)
-dist_wtd_sadMxN_sse2(8, 32)
-dist_wtd_sadMxN_sse2(32, 8)
-dist_wtd_sadMxN_sse2(16, 64)
-dist_wtd_sadMxN_sse2(64, 16)
-    /* clang-format on */
+DIST_WTD_SADMXN_SSE2(128, 128)
+DIST_WTD_SADMXN_SSE2(128, 64)
+DIST_WTD_SADMXN_SSE2(64, 128)
+DIST_WTD_SADMXN_SSE2(64, 64)
+DIST_WTD_SADMXN_SSE2(64, 32)
+DIST_WTD_SADMXN_SSE2(32, 64)
+DIST_WTD_SADMXN_SSE2(32, 32)
+DIST_WTD_SADMXN_SSE2(32, 16)
+DIST_WTD_SADMXN_SSE2(16, 32)
+DIST_WTD_SADMXN_SSE2(16, 16)
+DIST_WTD_SADMXN_SSE2(16, 8)
+DIST_WTD_SADMXN_SSE2(8, 16)
+DIST_WTD_SADMXN_SSE2(8, 8)
+DIST_WTD_SADMXN_SSE2(8, 4)
+DIST_WTD_SADMXN_SSE2(4, 8)
+DIST_WTD_SADMXN_SSE2(4, 4)
+#if !CONFIG_REALTIME_ONLY
+DIST_WTD_SADMXN_SSE2(4, 16)
+DIST_WTD_SADMXN_SSE2(16, 4)
+DIST_WTD_SADMXN_SSE2(8, 32)
+DIST_WTD_SADMXN_SSE2(32, 8)
+DIST_WTD_SADMXN_SSE2(16, 64)
+DIST_WTD_SADMXN_SSE2(64, 16)
+#endif
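
The DIST_WTD_SADMXN_SSE2 wrappers first build a distance-weighted compound
prediction with aom_dist_wtd_comp_avg_pred and then take a plain SAD against
the source. The following scalar sketch of the weighting step is illustrative
only: the real offsets come from jcp_param, and the 4-bit precision together
with the assumption that the two offsets sum to 16 are not taken from this
diff.

    #include <stdint.h>

    /* Hypothetical precision; for illustration only. */
    #define DIST_PRECISION_BITS 4

    /* Each output pixel is a weighted mix of the two predictions. */
    static void dist_wtd_comp_avg_ref(uint8_t *comp_pred, const uint8_t *pred,
                                      int width, int height,
                                      const uint8_t *ref, int ref_stride,
                                      int fwd_offset, int bck_offset) {
      for (int i = 0; i < height; i++) {
        for (int j = 0; j < width; j++) {
          const int tmp = pred[j] * bck_offset + ref[j] * fwd_offset;
          comp_pred[j] = (uint8_t)((tmp + (1 << (DIST_PRECISION_BITS - 1))) >>
                                   DIST_PRECISION_BITS);
        }
        comp_pred += width;
        pred += width;
        ref += ref_stride;
      }
    }
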
diff --git a/aom_dsp/x86/sad4d_sse2.asm b/aom_dsp/x86/sad4d_sse2.asm
index 6de708b..6696c40 100644
--- a/aom_dsp/x86/sad4d_sse2.asm
+++ b/aom_dsp/x86/sad4d_sse2.asm
@@ -15,13 +15,6 @@
 
 SECTION .text
 
-%macro AVG_4x2x4 2
-  movh                  m2, [second_predq]
-  movlhps               m2, m2
-  pavgb                 %1, m2
-  pavgb                 %2, m2
-  lea                   second_predq, [second_predq+8]
-%endmacro
 ; 'spill_src_stride' has a large effect on how the code works.
 ;
 ; When 'spill_src_stride' is false, the 'src_strideq' resides in
@@ -64,8 +57,8 @@
   lea                ref4q, [ref4q+ref_strideq*2]
 %endmacro
 
-; PROCESS_4x2x4 first, do_avg
-%macro PROCESS_4x2x4 2
+; PROCESS_4x2x4 first
+%macro PROCESS_4x2x4 1
   movd                  m0, [srcq]
   HANDLE_SECOND_OFFSET
 %if %1 == 1
@@ -87,9 +80,6 @@
   movlhps               m0, m0
   movlhps               m6, m4
   movlhps               m7, m5
-%if %2 == 1
-  AVG_4x2x4             m6, m7
-%endif
   psadbw                m6, m0
   psadbw                m7, m0
 %else
@@ -110,9 +100,6 @@
   movlhps               m0, m0
   movlhps               m1, m2
   movlhps               m3, m4
-%if %2 == 1
-  AVG_4x2x4             m1, m3
-%endif
   psadbw                m1, m0
   psadbw                m3, m0
   paddd                 m6, m1
@@ -120,8 +107,8 @@
 %endif
 %endmacro
 
-; PROCESS_8x2x4 first, do_avg
-%macro PROCESS_8x2x4 2
+; PROCESS_8x2x4 first
+%macro PROCESS_8x2x4 1
   movh                  m0, [srcq]
   HANDLE_SECOND_OFFSET
 %if %1 == 1
@@ -134,14 +121,6 @@
   movhps                m5, [ref2q+ref_strideq]
   movhps                m6, [ref3q+ref_strideq]
   movhps                m7, [ref4q+ref_strideq]
-%if %2 == 1
-  movu                  m3, [second_predq]
-  pavgb                 m4, m3
-  pavgb                 m5, m3
-  pavgb                 m6, m3
-  pavgb                 m7, m3
-  lea                   second_predq, [second_predq+mmsize]
-%endif
   psadbw                m4, m0
   psadbw                m5, m0
   psadbw                m6, m0
@@ -152,11 +131,6 @@
   movhps                m0, [srcq + second_offset]
   movhps                m1, [ref1q+ref_strideq]
   movhps                m2, [ref2q+ref_strideq]
-%if %2 == 1
-  movu                  m3, [second_predq]
-  pavgb                 m1, m3
-  pavgb                 m2, m3
-%endif
   psadbw                m1, m0
   psadbw                m2, m0
   paddd                 m4, m1
@@ -166,11 +140,6 @@
   movhps                m1, [ref3q+ref_strideq]
   movh                  m2, [ref4q]
   movhps                m2, [ref4q+ref_strideq]
-%if %2 == 1
-  pavgb                 m1, m3
-  pavgb                 m2, m3
-  lea                   second_predq, [second_predq+mmsize]
-%endif
   psadbw                m1, m0
   psadbw                m2, m0
   paddd                 m6, m1
@@ -178,37 +147,24 @@
 %endif
 %endmacro
 
-; PROCESS_FIRST_MMSIZE do_avg
-%macro PROCESS_FIRST_MMSIZE 1
+; PROCESS_FIRST_MMSIZE
+%macro PROCESS_FIRST_MMSIZE 0
   mova                  m0, [srcq]
   movu                  m4, [ref1q]
   movu                  m5, [ref2q]
   movu                  m6, [ref3q]
   movu                  m7, [ref4q]
-%if %1 == 1
-  movu                  m3, [second_predq]
-  pavgb                 m4, m3
-  pavgb                 m5, m3
-  pavgb                 m6, m3
-  pavgb                 m7, m3
-  lea                   second_predq, [second_predq+mmsize]
-%endif
   psadbw                m4, m0
   psadbw                m5, m0
   psadbw                m6, m0
   psadbw                m7, m0
 %endmacro
 
-; PROCESS_16x1x4 offset, do_avg
-%macro PROCESS_16x1x4 2
+; PROCESS_16x1x4 offset
+%macro PROCESS_16x1x4 1
   mova                  m0, [srcq + %1]
   movu                  m1, [ref1q + ref_offsetq + %1]
   movu                  m2, [ref2q + ref_offsetq + %1]
-%if %2 == 1
-  movu                  m3, [second_predq]
-  pavgb                 m1, m3
-  pavgb                 m2, m3
-%endif
   psadbw                m1, m0
   psadbw                m2, m0
   paddd                 m4, m1
@@ -216,11 +172,6 @@
 
   movu                  m1, [ref3q + ref_offsetq + %1]
   movu                  m2, [ref4q + ref_offsetq + %1]
-%if %2 == 1
-  pavgb                 m1, m3
-  pavgb                 m2, m3
-  lea                   second_predq, [second_predq+mmsize]
-%endif
   psadbw                m1, m0
   psadbw                m2, m0
   paddd                 m6, m1
@@ -233,9 +184,8 @@
 ; Macro Arguments:
 ;   1: Width
 ;   2: Height
-;   3: If 0, then normal sad, else avg
-;   4: If 0, then normal sad, else skip rows
-%macro SADNXN4D 2-4 0,0
+;   3: If 0, then normal sad, else skip rows
+%macro SADNXN4D 2-3 0
 
 %define spill_src_stride 0
 %define spill_ref_stride 0
@@ -249,7 +199,7 @@
 ; Remove loops in the 4x4 and 8x4 case
 %define use_loop (use_ref_offset || %2 > 4)
 
-%if %4 == 1  ; skip rows
+%if %3 == 1  ; skip rows
 %if ARCH_X86_64
 %if use_ref_offset
 cglobal sad_skip_%1x%2x4d, 5, 10, 8, src, src_stride, ref1, ref_stride, res, \
@@ -276,7 +226,7 @@
                                     ref3, ref4
 %endif
 %endif
-%elif %3 == 0  ; normal sad
+%else ; normal sad
 %if ARCH_X86_64
 %if use_ref_offset
 cglobal sad%1x%2x4d, 5, 10, 8, src, src_stride, ref1, ref_stride, res, ref2, \
@@ -301,34 +251,6 @@
                               ref4
 %endif
 %endif
-%else ; avg
-%if ARCH_X86_64
-%if use_ref_offset
-cglobal sad%1x%2x4d_avg, 6, 11, 8, src, src_stride, ref1, ref_stride, \
-                                   second_pred, res, ref2, ref3, ref4, cnt, \
-                                   ref_offset
-%elif use_loop
-cglobal sad%1x%2x4d_avg, 6, 10, 8, src, src_stride, ref1, ref_stride, \
-                                   second_pred, res, ref2, ref3, ref4, cnt
-%else
-cglobal sad%1x%2x4d_avg, 6, 9, 8, src, src_stride, ref1, ref_stride, \
-                                   second_pred, res, ref2, ref3, ref4
-%endif
-%else
-%if use_ref_offset
-cglobal sad%1x%2x4d_avg, 5, 7, 8, src, ref4, ref1, ref_offset, second_pred, ref2, ref3
-  %define spill_src_stride 1
-  %define spill_ref_stride 1
-  %define spill_cnt 1
-%elif use_loop
-cglobal sad%1x%2x4d_avg, 5, 7, 8, src, ref4, ref1, ref_stride, second_pred, ref2, ref3
-  %define spill_src_stride 1
-  %define spill_cnt 1
-%else
-cglobal sad%1x%2x4d_avg, 5, 7, 8, src, ref4, ref1, ref_stride, second_pred, ref2, ref3
-  %define spill_src_stride 1
-%endif
-%endif
 %endif
 
 %if spill_src_stride
@@ -345,7 +267,7 @@
   %define cntd word [rsp]
 %endif
 
-%if %4 == 1
+%if %3 == 1
   sal          src_strided, 1
   sal          ref_strided, 1
 %endif
@@ -362,14 +284,12 @@
 %define external_loop (use_ref_offset && %1 > mmsize && %1 != %2)
 
 %if use_ref_offset
-  PROCESS_FIRST_MMSIZE %3
+  PROCESS_FIRST_MMSIZE
 %if %1 > mmsize
   mov          ref_offsetq, 0
-  mov                 cntd, %2 >> %4
+  mov                 cntd, %2 >> %3
 ; Jump part way into the loop for the square version of this width
 %if %3 == 1
-  jmp mangle(private_prefix %+ _sad%1x%1x4d_avg %+ SUFFIX).midloop
-%elif %4 == 1
   jmp mangle(private_prefix %+ _sad_skip_%1x%1x4d %+ SUFFIX).midloop
 %else
   jmp mangle(private_prefix %+ _sad%1x%1x4d %+ SUFFIX).midloop
@@ -377,14 +297,14 @@
 %else
   mov          ref_offsetq, ref_strideq
   add                 srcq, src_strideq
-  mov                 cntd, (%2 >> %4) - 1
+  mov                 cntd, (%2 >> %3) - 1
 %endif
 %if external_loop == 0
 .loop:
 ; Unrolled horizontal loop
 %assign h_offset 0
 %rep %1/mmsize
-  PROCESS_16x1x4 h_offset, %3
+  PROCESS_16x1x4 h_offset
 %if h_offset == 0
 ; The first row of the first column is done outside the loop and jumps here
 .midloop:
@@ -398,13 +318,13 @@
   jnz .loop
 %endif
 %else
-  PROCESS_%1x2x4 1, %3
+  PROCESS_%1x2x4 1
   ADVANCE_END_OF_TWO_LINES
 %if use_loop
-  mov                 cntd, (%2/2 >> %4) - 1
+  mov                 cntd, (%2/2 >> %3) - 1
 .loop:
 %endif
-  PROCESS_%1x2x4 0, %3
+  PROCESS_%1x2x4 0
 %if use_loop
   ADVANCE_END_OF_TWO_LINES
   sub                 cntd, 1
@@ -421,13 +341,10 @@
 %if %3 == 0
   %define resultq r4
   %define resultmp r4mp
-%else
-  %define resultq r5
-  %define resultmp r5mp
 %endif
 
 ; Undo modifications on parameters on the stack
-%if %4 == 1
+%if %3 == 1
 %if spill_src_stride
   shr          src_strided, 1
 %endif
@@ -446,7 +363,7 @@
   punpcklqdq            m4, m6
   punpckhqdq            m5, m7
   paddd                 m4, m5
-%if %4 == 1
+%if %3 == 1
   pslld                 m4, 1
 %endif
   movifnidn             resultq, resultmp
@@ -455,7 +372,7 @@
 %else
   pshufd            m6, m6, 0x08
   pshufd            m7, m7, 0x08
-%if %4 == 1
+%if %3 == 1
   pslld                 m6, 1
   pslld                 m7, 1
 %endif
@@ -492,7 +409,6 @@
 SADNXN4D  16,  64
 SADNXN4D  64,  16
 %endif
-%if CONFIG_REALTIME_ONLY==0
 SADNXN4D 128, 128, 1
 SADNXN4D 128,  64, 1
 SADNXN4D  64, 128, 1
@@ -506,39 +422,16 @@
 SADNXN4D  16,   8, 1
 SADNXN4D   8,  16, 1
 SADNXN4D   8,   8, 1
-SADNXN4D   8,   4, 1
 SADNXN4D   4,   8, 1
-SADNXN4D   4,   4, 1
+%if CONFIG_REALTIME_ONLY==0
 SADNXN4D   4,  16, 1
-SADNXN4D  16,   4, 1
 SADNXN4D   8,  32, 1
 SADNXN4D  32,   8, 1
 SADNXN4D  16,  64, 1
 SADNXN4D  64,  16, 1
 %endif
-SADNXN4D 128, 128, 0, 1
-SADNXN4D 128,  64, 0, 1
-SADNXN4D  64, 128, 0, 1
-SADNXN4D  64,  64, 0, 1
-SADNXN4D  64,  32, 0, 1
-SADNXN4D  32,  64, 0, 1
-SADNXN4D  32,  32, 0, 1
-SADNXN4D  32,  16, 0, 1
-SADNXN4D  16,  32, 0, 1
-SADNXN4D  16,  16, 0, 1
-SADNXN4D  16,   8, 0, 1
-SADNXN4D   8,  16, 0, 1
-SADNXN4D   8,   8, 0, 1
-SADNXN4D   4,   8, 0, 1
-%if CONFIG_REALTIME_ONLY==0
-SADNXN4D   4,  16, 0, 1
-SADNXN4D   8,  32, 0, 1
-SADNXN4D  32,   8, 0, 1
-SADNXN4D  16,  64, 0, 1
-SADNXN4D  64,  16, 0, 1
-%endif
 
 ; Different assembly is needed when the height gets subsampled to 2
-; SADNXN4D 16,  4, 0, 1
-; SADNXN4D  8,  4, 0, 1
-; SADNXN4D  4,  4, 0, 1
+; SADNXN4D 16,  4, 1
+; SADNXN4D  8,  4, 1
+; SADNXN4D  4,  4, 1
diff --git a/aom_dsp/x86/synonyms.h b/aom_dsp/x86/synonyms.h
index d538015..6744ec5 100644
--- a/aom_dsp/x86/synonyms.h
+++ b/aom_dsp/x86/synonyms.h
@@ -85,6 +85,16 @@
 #endif
 }
 
+// Fill an SSE register using an interleaved pair of values, i.e. set the
+// 8 channels to {a, b, a, b, a, b, a, b}, using the same channel ordering
+// as when a register is stored to / loaded from memory.
+//
+// This is useful for rearranging filter kernels for use with the _mm_madd_epi16
+// instruction.
+static INLINE __m128i xx_set2_epi16(int16_t a, int16_t b) {
+  return _mm_setr_epi16(a, b, a, b, a, b, a, b);
+}
+
 static INLINE __m128i xx_round_epu16(__m128i v_val_w) {
   return _mm_avg_epu16(v_val_w, _mm_setzero_si128());
 }
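
The new xx_set2_epi16 helper pairs naturally with _mm_madd_epi16: with the
taps register holding {a, b, a, b, ...} and the data register holding adjacent
samples, each 32-bit product lane becomes a * s[2k] + b * s[2k + 1], i.e. one
two-tap filter output per lane. A small usage sketch (sample layout and names
are illustrative; xx_set2_epi16 is repeated so the snippet stands alone):

    #include <emmintrin.h>
    #include <stdint.h>

    static inline __m128i xx_set2_epi16(int16_t a, int16_t b) {
      return _mm_setr_epi16(a, b, a, b, a, b, a, b);
    }

    /* Apply a 2-tap filter (taps a, b) to four pairs of 16-bit samples:
     * each 32-bit result lane is a * s[2k] + b * s[2k + 1]. */
    static inline __m128i two_tap_madd(const int16_t *s, int16_t a,
                                       int16_t b) {
      const __m128i taps = xx_set2_epi16(a, b);
      const __m128i samples = _mm_loadu_si128((const __m128i *)s);
      return _mm_madd_epi16(samples, taps);
    }
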
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index a475fb7..046d6f1 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -269,6 +269,95 @@
   _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
 }
 
+void aom_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
+                            int height, const uint8_t *ref, int ref_stride) {
+  int row = 0;
+  if (width == 8) {
+    do {