Merge "vp9: sync threads after a longjmp"
diff --git a/test/vp9_avg_test.cc b/test/vp9_avg_test.cc
index c2e472b..fa04528 100644
--- a/test/vp9_avg_test.cc
+++ b/test/vp9_avg_test.cc
@@ -57,7 +57,7 @@
   }
 
   // Sum Pixels
-  unsigned int ReferenceAverage(const uint8_t* source, int pitch ) {
+  unsigned int ReferenceAverage8x8(const uint8_t* source, int pitch) {
     unsigned int average = 0;
     for (int h = 0; h < 8; ++h)
       for (int w = 0; w < 8; ++w)
@@ -65,6 +65,14 @@
     return ((average + 32) >> 6);
   }
 
+  unsigned int ReferenceAverage4x4(const uint8_t* source, int pitch) {
+    unsigned int average = 0;
+    for (int h = 0; h < 4; ++h)
+      for (int w = 0; w < 4; ++w)
+        average += source[h * source_stride_ + w];
+    return ((average + 8) >> 4);
+  }
+
   void FillConstant(uint8_t fill_constant) {
     for (int i = 0; i < width_ * height_; ++i) {
         source_data_[i] = fill_constant;
@@ -85,7 +93,7 @@
 };
 typedef unsigned int (*AverageFunction)(const uint8_t* s, int pitch);
 
-typedef std::tr1::tuple<int, int, int, AverageFunction> AvgFunc;
+typedef std::tr1::tuple<int, int, int, int, AverageFunction> AvgFunc;
 
 class AverageTest
     : public AverageTestBase,
@@ -95,12 +103,18 @@
 
  protected:
   void CheckAverages() {
-    unsigned int expected = ReferenceAverage(source_data_+ GET_PARAM(2),
-                                             source_stride_);
+    unsigned int expected = 0;
+    if (GET_PARAM(3) == 8) {
+      expected = ReferenceAverage8x8(source_data_+ GET_PARAM(2),
+                                     source_stride_);
+    } else if (GET_PARAM(3) == 4) {
+      expected = ReferenceAverage4x4(source_data_+ GET_PARAM(2),
+                                     source_stride_);
+    }
 
-    ASM_REGISTER_STATE_CHECK(GET_PARAM(3)(source_data_+ GET_PARAM(2),
+    ASM_REGISTER_STATE_CHECK(GET_PARAM(4)(source_data_+ GET_PARAM(2),
                                           source_stride_));
-    unsigned int actual = GET_PARAM(3)(source_data_+ GET_PARAM(2),
+    unsigned int actual = GET_PARAM(4)(source_data_+ GET_PARAM(2),
                                        source_stride_);
 
     EXPECT_EQ(expected, actual);
@@ -134,16 +148,20 @@
 INSTANTIATE_TEST_CASE_P(
     C, AverageTest,
     ::testing::Values(
-        make_tuple(16, 16, 1, &vp9_avg_8x8_c)));
+        make_tuple(16, 16, 1, 8, &vp9_avg_8x8_c),
+        make_tuple(16, 16, 1, 4, &vp9_avg_4x4_c)));
 
 
 #if HAVE_SSE2
 INSTANTIATE_TEST_CASE_P(
     SSE2, AverageTest,
     ::testing::Values(
-        make_tuple(16, 16, 0, &vp9_avg_8x8_sse2),
-        make_tuple(16, 16, 5, &vp9_avg_8x8_sse2),
-        make_tuple(32, 32, 15, &vp9_avg_8x8_sse2)));
+        make_tuple(16, 16, 0, 8, &vp9_avg_8x8_sse2),
+        make_tuple(16, 16, 5, 8, &vp9_avg_8x8_sse2),
+        make_tuple(32, 32, 15, 8, &vp9_avg_8x8_sse2),
+        make_tuple(16, 16, 0, 4, &vp9_avg_4x4_sse2),
+        make_tuple(16, 16, 5, 4, &vp9_avg_4x4_sse2),
+        make_tuple(32, 32, 15, 4, &vp9_avg_4x4_sse2)));
 
 #endif
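
Side note on the test's reference rounding: (average + 32) >> 6 and (average + 8) >> 4 are round-to-nearest integer averages over 64 and 16 pixels respectively. A quick worked example (illustrative only, not part of the patch):

    /* 16 pixels summing to 1627: the exact mean is 101.6875.       */
    /* Truncation would give 1627 >> 4 = 101; adding half of the    */
    /* divisor (8) first rounds to nearest: (1627 + 8) >> 4 = 102.  */
    unsigned int sum = 1627;
    unsigned int avg = (sum + 8) >> 4;  /* == 102 */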
 
diff --git a/vp9/common/vp9_rtcd_defs.pl b/vp9/common/vp9_rtcd_defs.pl
index ae12808..281dcbd 100644
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -1135,9 +1135,14 @@
 add_proto qw/unsigned int vp9_avg_8x8/, "const uint8_t *, int p";
 specialize qw/vp9_avg_8x8 sse2/;
 
+add_proto qw/unsigned int vp9_avg_4x4/, "const uint8_t *, int p";
+specialize qw/vp9_avg_4x4 sse2/;
+
 if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   add_proto qw/unsigned int vp9_highbd_avg_8x8/, "const uint8_t *, int p";
   specialize qw/vp9_highbd_avg_8x8/;
+  add_proto qw/unsigned int vp9_highbd_avg_4x4/, "const uint8_t *, int p";
+  specialize qw/vp9_highbd_avg_4x4/;
 }
 
 # ENCODEMB INVOKE
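
For context: add_proto declares the run-time-dispatched prototype and specialize lists the optimized variants, so the generated vp9_rtcd.h resolves vp9_avg_4x4 to the SSE2 version on capable CPUs and to vp9_avg_4x4_c otherwise. Roughly (a sketch of the generated dispatch, not the literal generated header):

    unsigned int vp9_avg_4x4_c(const uint8_t *, int p);
    unsigned int vp9_avg_4x4_sse2(const uint8_t *, int p);
    /* On x86 this becomes a function pointer that setup_rtcd_internal() */
    /* points at vp9_avg_4x4_sse2 when the SSE2 CPU flag is present.     */
    extern unsigned int (*vp9_avg_4x4)(const uint8_t *, int p);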
diff --git a/vp9/encoder/vp9_avg.c b/vp9/encoder/vp9_avg.c
index e9810c8..f8fa7d2 100644
--- a/vp9/encoder/vp9_avg.c
+++ b/vp9/encoder/vp9_avg.c
@@ -19,6 +19,15 @@
   return (sum + 32) >> 6;
 }
 
+unsigned int vp9_avg_4x4_c(const uint8_t *s, int p) {
+  int i, j;
+  int sum = 0;
+  for (i = 0; i < 4; ++i, s+=p)
+    for (j = 0; j < 4; sum += s[j], ++j) {}
+
+  return (sum + 8) >> 4;
+}
+
 #if CONFIG_VP9_HIGHBITDEPTH
 unsigned int vp9_highbd_avg_8x8_c(const uint8_t *s8, int p) {
   int i, j;
@@ -29,5 +38,16 @@
 
   return (sum + 32) >> 6;
 }
+
+unsigned int vp9_highbd_avg_4x4_c(const uint8_t *s8, int p) {
+  int i, j;
+  int sum = 0;
+  const uint16_t* s = CONVERT_TO_SHORTPTR(s8);
+  for (i = 0; i < 4; ++i, s+=p)
+    for (j = 0; j < 4; sum += s[j], ++j) {}
+
+  return (sum + 8) >> 4;
+}
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 
+
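
The 4x4 average is used the same way as the existing 8x8 one: each call collapses a small block to a single sample before variances are computed. A minimal usage sketch (downsample_16x16_to_4x4 is a hypothetical helper, not part of this change; it calls the C version defined in this file):

    #include <stdint.h>

    /* Collapse a 16x16 source block into a 4x4 grid of 4x4 averages,   */
    /* i.e. the per-sample values the key-frame path feeds into the     */
    /* variance tree. s points at the top-left pixel, sp is the stride. */
    static void downsample_16x16_to_4x4(const uint8_t *s, int sp,
                                        int ds[4][4]) {
      int r, c;
      for (r = 0; r < 4; ++r)
        for (c = 0; c < 4; ++c)
          ds[r][c] = vp9_avg_4x4_c(s + (r * 4) * sp + c * 4, sp);
    }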
diff --git a/vp9/encoder/vp9_bitstream.c b/vp9/encoder/vp9_bitstream.c
index cad3109..4d88fb5 100644
--- a/vp9/encoder/vp9_bitstream.c
+++ b/vp9/encoder/vp9_bitstream.c
@@ -535,6 +535,8 @@
   const vp9_prob upd = DIFF_UPDATE_PROB;
   const int entropy_nodes_update = UNCONSTRAINED_NODES;
   int i, j, k, l, t;
+  int stepsize = cpi->sf.coeff_prob_appx_step;
+
   switch (cpi->sf.use_fast_coef_updates) {
     case TWO_LOOP: {
       /* dry run to see if there is any update at all needed */
@@ -552,7 +554,7 @@
                 if (t == PIVOT_NODE)
                   s = vp9_prob_diff_update_savings_search_model(
                       frame_branch_ct[i][j][k][l][0],
-                      old_coef_probs[i][j][k][l], &newp, upd);
+                      old_coef_probs[i][j][k][l], &newp, upd, stepsize);
                 else
                   s = vp9_prob_diff_update_savings_search(
                       frame_branch_ct[i][j][k][l][t], oldp, &newp, upd);
@@ -590,7 +592,7 @@
                 if (t == PIVOT_NODE)
                   s = vp9_prob_diff_update_savings_search_model(
                       frame_branch_ct[i][j][k][l][0],
-                      old_coef_probs[i][j][k][l], &newp, upd);
+                      old_coef_probs[i][j][k][l], &newp, upd, stepsize);
                 else
                   s = vp9_prob_diff_update_savings_search(
                       frame_branch_ct[i][j][k][l][t],
@@ -611,16 +613,15 @@
       return;
     }
 
-    case ONE_LOOP:
     case ONE_LOOP_REDUCED: {
-      const int prev_coef_contexts_to_update =
-          cpi->sf.use_fast_coef_updates == ONE_LOOP_REDUCED ?
-              COEFF_CONTEXTS >> 1 : COEFF_CONTEXTS;
-      const int coef_band_to_update =
-          cpi->sf.use_fast_coef_updates == ONE_LOOP_REDUCED ?
-              COEF_BANDS >> 1 : COEF_BANDS;
       int updates = 0;
       int noupdates_before_first = 0;
+
+      if (tx_size >= TX_16X16 && cpi->sf.tx_size_search_method == USE_TX_8X8) {
+        vp9_write_bit(bc, 0);
+        return;
+      }
+
       for (i = 0; i < PLANE_TYPES; ++i) {
         for (j = 0; j < REF_TYPES; ++j) {
           for (k = 0; k < COEF_BANDS; ++k) {
@@ -631,21 +632,19 @@
                 vp9_prob *oldp = old_coef_probs[i][j][k][l] + t;
                 int s;
                 int u = 0;
-                if (l >= prev_coef_contexts_to_update ||
-                    k >= coef_band_to_update) {
-                  u = 0;
+
+                if (t == PIVOT_NODE) {
+                  s = vp9_prob_diff_update_savings_search_model(
+                      frame_branch_ct[i][j][k][l][0],
+                      old_coef_probs[i][j][k][l], &newp, upd, stepsize);
                 } else {
-                  if (t == PIVOT_NODE)
-                    s = vp9_prob_diff_update_savings_search_model(
-                        frame_branch_ct[i][j][k][l][0],
-                        old_coef_probs[i][j][k][l], &newp, upd);
-                  else
-                    s = vp9_prob_diff_update_savings_search(
-                        frame_branch_ct[i][j][k][l][t],
-                        *oldp, &newp, upd);
-                  if (s > 0 && newp != *oldp)
-                    u = 1;
+                  s = vp9_prob_diff_update_savings_search(
+                      frame_branch_ct[i][j][k][l][t],
+                      *oldp, &newp, upd);
                 }
+
+                if (s > 0 && newp != *oldp)
+                  u = 1;
                 updates += u;
                 if (u == 0 && updates == 0) {
                   noupdates_before_first++;
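
Note: stepsize here is cpi->sf.coeff_prob_appx_step, set to 1 in the default path and to 4 at the higher real-time speeds (see the vp9_speed_features.c change below). The effect is a coarser PIVOT_NODE search: for example, walking from *bestp = 40 toward oldp[PIVOT_NODE] = 20 with stepsize 4 evaluates newp = 40, 36, 32, 28, 24 instead of all 20 intermediate probabilities.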
diff --git a/vp9/encoder/vp9_encodeframe.c b/vp9/encoder/vp9_encodeframe.c
index d5122d0..7788e50 100644
--- a/vp9/encoder/vp9_encodeframe.c
+++ b/vp9/encoder/vp9_encodeframe.c
@@ -291,6 +291,11 @@
 typedef struct {
   partition_variance part_variances;
   var split[4];
+} v4x4;
+
+typedef struct {
+  partition_variance part_variances;
+  v4x4 split[4];
 } v8x8;
 
 typedef struct {
@@ -349,6 +354,13 @@
       v8x8 *vt = (v8x8 *) data;
       node->part_variances = &vt->part_variances;
       for (i = 0; i < 4; i++)
+        node->split[i] = &vt->split[i].part_variances.none;
+      break;
+    }
+    case BLOCK_4X4: {
+      v4x4 *vt = (v4x4 *) data;
+      node->part_variances = &vt->part_variances;
+      for (i = 0; i < 4; i++)
         node->split[i] = &vt->split[i];
       break;
     }
@@ -398,64 +410,76 @@
   variance_node vt;
   const int block_width = num_8x8_blocks_wide_lookup[bsize];
   const int block_height = num_8x8_blocks_high_lookup[bsize];
-  // TODO(debargha): Choose this more intelligently.
-  const int threshold_multiplier = cm->frame_type == KEY_FRAME ? 64 : 4;
+  // TODO(marpan): Adjust/tune these thresholds.
+  const int threshold_multiplier = cm->frame_type == KEY_FRAME ? 80 : 4;
   int64_t threshold =
       (int64_t)(threshold_multiplier *
                 vp9_convert_qindex_to_q(cm->base_qindex, cm->bit_depth));
+  int64_t threshold_bsize_ref = threshold << 6;
+  int64_t threshold_low = threshold;
+  BLOCK_SIZE bsize_ref = BLOCK_16X16;
+
   assert(block_height == block_width);
   tree_to_node(data, bsize, &vt);
 
-  // Split none is available only if we have more than half a block size
-  // in width and height inside the visible image.
-  if (mi_col + block_width / 2 < cm->mi_cols &&
-      mi_row + block_height / 2 < cm->mi_rows &&
-      vt.part_variances->none.variance < threshold) {
-    set_block_size(cpi, xd, mi_row, mi_col, bsize);
-    return 1;
+  if (cm->frame_type == KEY_FRAME) {
+    bsize_ref = BLOCK_8X8;
+    // Choose lower thresholds for key frame variance to favor split.
+    threshold_bsize_ref = threshold >> 1;
+    threshold_low = threshold >> 2;
   }
 
-  // Only allow split for blocks above 16x16.
-  if (bsize > BLOCK_16X16) {
-    // Vertical split is available on all but the bottom border.
+  // For bsize == bsize_ref (16x16 for 8x8 downsampling, 8x8 for 4x4), select
+  // this size if variance is below threshold; otherwise split is selected.
+  // Vert/horiz splits are not checked: too few samples to compute variance.
+  if (bsize == bsize_ref) {
+    if (mi_col + block_width / 2 < cm->mi_cols &&
+        mi_row + block_height / 2 < cm->mi_rows &&
+        vt.part_variances->none.variance < threshold_bsize_ref) {
+      set_block_size(cpi, xd, mi_row, mi_col, bsize);
+      return 1;
+    }
+    return 0;
+  } else if (bsize > bsize_ref) {
+    // On key frames, take split for bsize above 32X32 or very high variance.
+    if (cm->frame_type == KEY_FRAME &&
+        (bsize > BLOCK_32X32 ||
+        vt.part_variances->none.variance > (threshold << 2))) {
+      return 0;
+    }
+    // If variance is low, take the bsize (no split).
+    if (mi_col + block_width / 2 < cm->mi_cols &&
+        mi_row + block_height / 2 < cm->mi_rows &&
+        vt.part_variances->none.variance < threshold_low) {
+      set_block_size(cpi, xd, mi_row, mi_col, bsize);
+      return 1;
+    }
+    // Check vertical split.
     if (mi_row + block_height / 2 < cm->mi_rows &&
-        vt.part_variances->vert[0].variance < threshold &&
-        vt.part_variances->vert[1].variance < threshold) {
+        vt.part_variances->vert[0].variance < threshold_low &&
+        vt.part_variances->vert[1].variance < threshold_low) {
       BLOCK_SIZE subsize = get_subsize(bsize, PARTITION_VERT);
       set_block_size(cpi, xd, mi_row, mi_col, subsize);
       set_block_size(cpi, xd, mi_row, mi_col + block_width / 2, subsize);
       return 1;
     }
-
-    // Horizontal split is available on all but the right border.
+    // Check horizontal split.
     if (mi_col + block_width / 2 < cm->mi_cols &&
-        vt.part_variances->horz[0].variance < threshold &&
-        vt.part_variances->horz[1].variance < threshold) {
+        vt.part_variances->horz[0].variance < threshold_low &&
+        vt.part_variances->horz[1].variance < threshold_low) {
       BLOCK_SIZE subsize = get_subsize(bsize, PARTITION_HORZ);
       set_block_size(cpi, xd, mi_row, mi_col, subsize);
       set_block_size(cpi, xd, mi_row + block_height / 2, mi_col, subsize);
       return 1;
     }
-  }
-
-  // This will only allow 8x8 if the 16x16 variance is very large.
-  if (bsize == BLOCK_16X16) {
-    if (mi_col + block_width / 2 < cm->mi_cols &&
-        mi_row + block_height / 2 < cm->mi_rows &&
-        vt.part_variances->none.variance < (threshold << 6)) {
-      set_block_size(cpi, xd, mi_row, mi_col, bsize);
-      return 1;
-    }
+    return 0;
   }
   return 0;
 }
 
-// This function chooses partitioning based on the variance
-// between source and reconstructed last, where variance is
-// computed for 8x8 downsampled inputs. Some things to check:
-// using the last source rather than reconstructed last, and
-// allowing for small downsampling (4x4 or 2x2) for selection
-// of smaller block sizes (i.e., < 16x16).
+// This function chooses partitioning based on the variance between source and
+// reconstructed last, where variance is computed for downsampled inputs.
+// Currently 8x8 downsampling is used for delta frames, 4x4 for key frames.
 static void choose_partitioning(VP9_COMP *cpi,
                                 const TileInfo *const tile,
                                 MACROBLOCK *x,
@@ -463,7 +487,7 @@
   VP9_COMMON * const cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
 
-  int i, j, k;
+  int i, j, k, m;
   v64x64 vt;
   uint8_t *s;
   const uint8_t *d;
@@ -525,38 +549,63 @@
       const int y16_idx = y32_idx + ((j >> 1) << 4);
       v16x16 *vst = &vt.split[i].split[j];
       for (k = 0; k < 4; k++) {
-        int x_idx = x16_idx + ((k & 1) << 3);
-        int y_idx = y16_idx + ((k >> 1) << 3);
-        unsigned int sse = 0;
-        int sum = 0;
-
-        if (x_idx < pixels_wide && y_idx < pixels_high) {
-          int s_avg, d_avg;
+        int x8_idx = x16_idx + ((k & 1) << 3);
+        int y8_idx = y16_idx + ((k >> 1) << 3);
+        if (cm->frame_type != KEY_FRAME) {
+          unsigned int sse = 0;
+          int sum = 0;
+          if (x8_idx < pixels_wide && y8_idx < pixels_high) {
+            int s_avg, d_avg;
 #if CONFIG_VP9_HIGHBITDEPTH
-          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-            s_avg = vp9_highbd_avg_8x8(s + y_idx * sp + x_idx, sp);
-            d_avg = vp9_highbd_avg_8x8(d + y_idx * dp + x_idx, dp);
-          } else {
-            s_avg = vp9_avg_8x8(s + y_idx * sp + x_idx, sp);
-            d_avg = vp9_avg_8x8(d + y_idx * dp + x_idx, dp);
-          }
+            if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+              s_avg = vp9_highbd_avg_8x8(s + y8_idx * sp + x8_idx, sp);
+              d_avg = vp9_highbd_avg_8x8(d + y8_idx * dp + x8_idx, dp);
+            } else {
+              s_avg = vp9_avg_8x8(s + y8_idx * sp + x8_idx, sp);
+              d_avg = vp9_avg_8x8(d + y8_idx * dp + x8_idx, dp);
+            }
 #else
-          s_avg = vp9_avg_8x8(s + y_idx * sp + x_idx, sp);
-          d_avg = vp9_avg_8x8(d + y_idx * dp + x_idx, dp);
+            s_avg = vp9_avg_8x8(s + y8_idx * sp + x8_idx, sp);
+            d_avg = vp9_avg_8x8(d + y8_idx * dp + x8_idx, dp);
 #endif
-          sum = s_avg - d_avg;
-          sse = sum * sum;
+            sum = s_avg - d_avg;
+            sse = sum * sum;
+          }
+          // If variance is based on 8x8 downsampling, we stop here and have
+          // one sample per 8x8 block (so use 1 for count in fill_variance),
+          // which of course means the variance of an 8x8 block is 0.
+          fill_variance(sse, sum, 1, &vst->split[k].part_variances.none);
+        } else {
+          // For key frame, go down to 4x4.
+          v8x8 *vst2 = &vst->split[k];
+          for (m = 0; m < 4; m++) {
+            int x4_idx = x8_idx + ((m & 1) << 2);
+            int y4_idx = y8_idx + ((m >> 1) << 2);
+            unsigned int sse = 0;
+            int sum = 0;
+            if (x4_idx < pixels_wide && y4_idx < pixels_high) {
+              int s_avg = vp9_avg_4x4(s + y4_idx * sp + x4_idx, sp);
+              // For key frame, reference is set to 128.
+              sum = s_avg - 128;
+              sse = sum * sum;
+            }
+            // If variance is based on 4x4 downsampling, we stop here and
+            // have one sample per 4x4 block (so use 1 for count in
+            // fill_variance), which means the variance of a 4x4 block is 0.
+            fill_variance(sse, sum, 1, &vst2->split[m].part_variances.none);
+          }
         }
-        // For an 8x8 block we have just one value the average of all 64
-        // pixels,  so use 1.   This means of course that there is no variance
-        // in an 8x8 block.
-        fill_variance(sse, sum, 1, &vst->split[k].part_variances.none);
       }
     }
   }
   // Fill the rest of the variance tree by summing split partition values.
   for (i = 0; i < 4; i++) {
     for (j = 0; j < 4; j++) {
+      if (cm->frame_type == KEY_FRAME) {
+        for (m = 0; m < 4; m++) {
+          fill_variance_tree(&vt.split[i].split[j].split[m], BLOCK_8X8);
+        }
+      }
       fill_variance_tree(&vt.split[i].split[j], BLOCK_16X16);
     }
     fill_variance_tree(&vt.split[i], BLOCK_32X32);
@@ -564,8 +613,7 @@
   fill_variance_tree(&vt, BLOCK_64X64);
 
   // Now go through the entire structure,  splitting every block size until
-  // we get to one that's got a variance lower than our threshold,  or we
-  // hit 8x8.
+  // we get to one that's got a variance lower than our threshold.
   if ( mi_col + 8 > cm->mi_cols || mi_row + 8 > cm->mi_rows ||
       !set_vt_partitioning(cpi, xd, &vt, BLOCK_64X64, mi_row, mi_col)) {
     for (i = 0; i < 4; ++i) {
@@ -576,11 +624,13 @@
         for (j = 0; j < 4; ++j) {
           const int x16_idx = ((j & 1) << 1);
           const int y16_idx = ((j >> 1) << 1);
-          // NOTE: Since this uses 8x8 downsampling for variance calculation
-          // we cannot really select block size 8x8 (or even 8x16/16x8),
-          // since we do not sufficient samples for variance.
-          // For now, 8x8 partition is only set if the variance of the 16x16
-          // block is very high. This is controlled in set_vt_partitioning.
+          // Note: If 8x8 downsampling is used for variance calculation we
+          // cannot really select block size 8x8 (or even 8x16/16x8), since we
+          // don't have sufficient samples for variance. So on delta frames,
+          // 8x8 partition is only set if variance of the 16x16 block is very
+          // high. For key frames, 4x4 downsampling is used, so we can better
+          // select 8x16/16x8 and 8x8. 4x4 partition could potentially also be
+          // set here, but for now 4x4 is not allowed.
           if (!set_vt_partitioning(cpi, xd, &vt.split[i].split[j],
                                    BLOCK_16X16,
                                    mi_row + y32_idx + y16_idx,
@@ -588,10 +638,26 @@
             for (k = 0; k < 4; ++k) {
               const int x8_idx = (k & 1);
               const int y8_idx = (k >> 1);
-              set_block_size(cpi, xd,
-                             (mi_row + y32_idx + y16_idx + y8_idx),
-                             (mi_col + x32_idx + x16_idx + x8_idx),
-                             BLOCK_8X8);
+              // TODO(marpan): Allow for setting 4x4 partition on key frame.
+              /*
+              if (cm->frame_type == KEY_FRAME) {
+                if (!set_vt_partitioning(cpi, xd,
+                                         &vt.split[i].split[j].split[k],
+                                         BLOCK_8X8,
+                                         mi_row + y32_idx + y16_idx + y8_idx,
+                                         mi_col + x32_idx + x16_idx + x8_idx)) {
+                    set_block_size(cpi, xd,
+                                  (mi_row + y32_idx + y16_idx + y8_idx),
+                                  (mi_col + x32_idx + x16_idx + x8_idx),
+                                   BLOCK_4X4);
+                }
+              } else {
+              */
+                set_block_size(cpi, xd,
+                               (mi_row + y32_idx + y16_idx + y8_idx),
+                               (mi_col + x32_idx + x16_idx + x8_idx),
+                               BLOCK_8X8);
+              // }
             }
           }
         }
@@ -2511,7 +2577,7 @@
       rd_use_partition(cpi, td, tile_data, mi, tp, mi_row, mi_col,
                        BLOCK_64X64, &dummy_rate, &dummy_dist, 1, td->pc_root);
     } else if (sf->partition_search_type == VAR_BASED_PARTITION &&
-               cm->frame_type != KEY_FRAME ) {
+               cm->frame_type != KEY_FRAME) {
       choose_partitioning(cpi, tile_info, x, mi_row, mi_col);
       rd_use_partition(cpi, td, tile_data, mi, tp, mi_row, mi_col,
                        BLOCK_64X64, &dummy_rate, &dummy_dist, 1, td->pc_root);
@@ -3532,6 +3598,11 @@
                  cm->uv_ac_delta_q == 0;
 
   cm->tx_mode = select_tx_mode(cpi, xd);
+  if (cm->frame_type == KEY_FRAME &&
+      cpi->sf.use_nonrd_pick_mode &&
+      cpi->sf.partition_search_type == VAR_BASED_PARTITION) {
+    cm->tx_mode = ALLOW_16X16;
+  }
 
 #if CONFIG_VP9_HIGHBITDEPTH
   if (cm->use_highbitdepth)
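
For reference, the variance tree only ever stores first and second moments per node (sum, sum of squares, count); parents accumulate their children's moments and the variance falls out of the usual identity Var[x] = E[x^2] - E[x]^2. A sketch of that recovery step, assuming the node fields are exactly these three values (the real get_variance() in vp9_encodeframe.c may differ in rounding):

    /* var = sse/count - (sum/count)^2, in integer arithmetic; count > 0. */
    static int64_t variance_from_moments(int64_t sse, int64_t sum, int count) {
      return (sse - (sum * sum) / count) / count;
    }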
diff --git a/vp9/encoder/vp9_speed_features.c b/vp9/encoder/vp9_speed_features.c
index 5c70b4e..4a0c797 100644
--- a/vp9/encoder/vp9_speed_features.c
+++ b/vp9/encoder/vp9_speed_features.c
@@ -249,7 +249,6 @@
     sf->use_uv_intra_rd_estimate = 1;
     sf->skip_encode_sb = 1;
     sf->mv.subpel_iters_per_step = 1;
-    sf->use_fast_coef_updates = ONE_LOOP_REDUCED;
     sf->adaptive_rd_thresh = 4;
     sf->mode_skip_start = 6;
     sf->allow_skip_recode = 0;
@@ -304,6 +303,9 @@
     // This feature is only enabled when partition search is disabled.
     sf->reuse_inter_pred_sby = 1;
     sf->partition_search_breakout_rate_thr = 200;
+    sf->coeff_prob_appx_step = 4;
+    sf->use_fast_coef_updates = is_keyframe ? TWO_LOOP : ONE_LOOP_REDUCED;
+
     if (!is_keyframe) {
       int i;
       if (content == VP9E_CONTENT_SCREEN) {
@@ -321,7 +323,7 @@
     sf->partition_search_type = VAR_BASED_PARTITION;
 
     // Turn on this to use non-RD key frame coding mode.
-    // sf->use_nonrd_pick_mode = 1;
+    sf->use_nonrd_pick_mode = 1;
     sf->mv.search_method = NSTEP;
     sf->tx_size_search_method = is_keyframe ? USE_LARGESTALL : USE_TX_8X8;
     sf->mv.reduce_first_step_size = 1;
@@ -394,6 +396,7 @@
   sf->mv.subpel_force_stop = 0;
   sf->optimize_coefficients = !is_lossless_requested(&cpi->oxcf);
   sf->mv.reduce_first_step_size = 0;
+  sf->coeff_prob_appx_step = 1;
   sf->mv.auto_mv_step_size = 0;
   sf->mv.fullpel_search_step_param = 6;
   sf->comp_inter_joint_search_thresh = BLOCK_4X4;
diff --git a/vp9/encoder/vp9_speed_features.h b/vp9/encoder/vp9_speed_features.h
index efea503..c2cfd62 100644
--- a/vp9/encoder/vp9_speed_features.h
+++ b/vp9/encoder/vp9_speed_features.h
@@ -163,12 +163,9 @@
   // before the final run.
   TWO_LOOP = 0,
 
-  // No dry run conducted.
-  ONE_LOOP = 1,
-
   // No dry run, also only half the coef contexts and bands are updated.
   // The rest are not updated at all.
-  ONE_LOOP_REDUCED = 2
+  ONE_LOOP_REDUCED = 1
 } FAST_COEFF_UPDATE;
 
 typedef struct MV_SPEED_FEATURES {
@@ -236,6 +233,9 @@
   // level within a frame.
   int allow_skip_recode;
 
+  // Coefficient probability model approximation step size.
+  int coeff_prob_appx_step;
+
   // The threshold is to determine how slow the motino is, it is used when
   // use_lastframe_partitioning is set to LAST_FRAME_PARTITION_LOW_MOTION
   MOTION_THRESHOLD lf_motion_threshold;
diff --git a/vp9/encoder/vp9_subexp.c b/vp9/encoder/vp9_subexp.c
index 530b592..180dadd 100644
--- a/vp9/encoder/vp9_subexp.c
+++ b/vp9/encoder/vp9_subexp.c
@@ -140,7 +140,8 @@
 int vp9_prob_diff_update_savings_search_model(const unsigned int *ct,
                                               const vp9_prob *oldp,
                                               vp9_prob *bestp,
-                                              vp9_prob upd) {
+                                              vp9_prob upd,
+                                              int stepsize) {
   int i, old_b, new_b, update_b, savings, bestsavings, step;
   int newp;
   vp9_prob bestnewp, newplist[ENTROPY_NODES], oldplist[ENTROPY_NODES];
@@ -153,24 +154,44 @@
   bestsavings = 0;
   bestnewp = oldp[PIVOT_NODE];
 
-  step = (*bestp > oldp[PIVOT_NODE] ? -1 : 1);
-
-  for (newp = *bestp; newp != oldp[PIVOT_NODE]; newp += step) {
-    if (newp < 1 || newp > 255)
-      continue;
-    newplist[PIVOT_NODE] = newp;
-    vp9_model_to_full_probs(newplist, newplist);
-    for (i = UNCONSTRAINED_NODES, new_b = 0; i < ENTROPY_NODES; ++i)
-      new_b += cost_branch256(ct + 2 * i, newplist[i]);
-    new_b += cost_branch256(ct + 2 * PIVOT_NODE, newplist[PIVOT_NODE]);
-    update_b = prob_diff_update_cost(newp, oldp[PIVOT_NODE]) +
-        vp9_cost_upd256;
-    savings = old_b - new_b - update_b;
-    if (savings > bestsavings) {
-      bestsavings = savings;
-      bestnewp = newp;
+  if (*bestp > oldp[PIVOT_NODE]) {
+    step = -stepsize;
+    for (newp = *bestp; newp > oldp[PIVOT_NODE]; newp += step) {
+      if (newp < 1 || newp > 255)
+        continue;
+      newplist[PIVOT_NODE] = newp;
+      vp9_model_to_full_probs(newplist, newplist);
+      for (i = UNCONSTRAINED_NODES, new_b = 0; i < ENTROPY_NODES; ++i)
+        new_b += cost_branch256(ct + 2 * i, newplist[i]);
+      new_b += cost_branch256(ct + 2 * PIVOT_NODE, newplist[PIVOT_NODE]);
+      update_b = prob_diff_update_cost(newp, oldp[PIVOT_NODE]) +
+          vp9_cost_upd256;
+      savings = old_b - new_b - update_b;
+      if (savings > bestsavings) {
+        bestsavings = savings;
+        bestnewp = newp;
+      }
+    }
+  } else {
+    step = stepsize;
+    for (newp = *bestp; newp < oldp[PIVOT_NODE]; newp += step) {
+      if (newp < 1 || newp > 255)
+        continue;
+      newplist[PIVOT_NODE] = newp;
+      vp9_model_to_full_probs(newplist, newplist);
+      for (i = UNCONSTRAINED_NODES, new_b = 0; i < ENTROPY_NODES; ++i)
+        new_b += cost_branch256(ct + 2 * i, newplist[i]);
+      new_b += cost_branch256(ct + 2 * PIVOT_NODE, newplist[PIVOT_NODE]);
+      update_b = prob_diff_update_cost(newp, oldp[PIVOT_NODE]) +
+          vp9_cost_upd256;
+      savings = old_b - new_b - update_b;
+      if (savings > bestsavings) {
+        bestsavings = savings;
+        bestnewp = newp;
+      }
     }
   }
+
   *bestp = bestnewp;
   return bestsavings;
 }
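
The two branches above are the previous single loop split by search direction, so a coarse step of size stepsize never walks past oldp[PIVOT_NODE]. A simplified, generic sketch of the same idea (illustrative only; the real code also rebuilds the full probability list and costs every branch per candidate):

    /* Coarse 1-D search: try candidates from start toward target in */
    /* increments of stepsize and keep the best-scoring one.         */
    static int coarse_search(int start, int target, int stepsize,
                             int (*score)(int), int *best_candidate) {
      int best = 0;
      int cand;
      const int step = (start > target) ? -stepsize : stepsize;
      for (cand = start; (step < 0) ? (cand > target) : (cand < target);
           cand += step) {
        const int s = score(cand);
        if (s > best) {
          best = s;
          *best_candidate = cand;
        }
      }
      return best;
    }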
diff --git a/vp9/encoder/vp9_subexp.h b/vp9/encoder/vp9_subexp.h
index 8e02a1d..ac54893 100644
--- a/vp9/encoder/vp9_subexp.h
+++ b/vp9/encoder/vp9_subexp.h
@@ -30,7 +30,8 @@
 int vp9_prob_diff_update_savings_search_model(const unsigned int *ct,
                                               const vp9_prob *oldp,
                                               vp9_prob *bestp,
-                                              vp9_prob upd);
+                                              vp9_prob upd,
+                                              int stepsize);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/vp9/encoder/x86/vp9_avg_intrin_sse2.c b/vp9/encoder/x86/vp9_avg_intrin_sse2.c
index ca6cf1a..4c3495b 100644
--- a/vp9/encoder/x86/vp9_avg_intrin_sse2.c
+++ b/vp9/encoder/x86/vp9_avg_intrin_sse2.c
@@ -38,3 +38,21 @@
   avg = _mm_extract_epi16(s0, 0);
   return (avg + 32) >> 6;
 }
+
+unsigned int vp9_avg_4x4_sse2(const uint8_t *s, int p) {
+  __m128i s0, s1, u0;
+  unsigned int avg = 0;
+  u0  = _mm_setzero_si128();
+  s0 = _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)(s)), u0);
+  s1 = _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)(s + p)), u0);
+  s0 = _mm_adds_epu16(s0, s1);
+  s1 = _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)(s + 2 * p)), u0);
+  s0 = _mm_adds_epu16(s0, s1);
+  s1 = _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)(s + 3 * p)), u0);
+  s0 = _mm_adds_epu16(s0, s1);
+
+  s0 = _mm_adds_epu16(s0, _mm_srli_si128(s0, 4));
+  s0 = _mm_adds_epu16(s0, _mm_srli_epi64(s0, 16));
+  avg = _mm_extract_epi16(s0, 0);
+  return (avg + 8) >> 4;
+}
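
On the reduction at the end: after the four row adds, the low four 16-bit lanes of s0 hold the column sums c0..c3. _mm_srli_si128(s0, 4) shifts by two lanes, so the following add yields c0+c2 and c1+c3; _mm_srli_epi64(s0, 16) shifts one more lane, leaving c0+c1+c2+c3 in lane 0, which _mm_extract_epi16 pulls out. A scalar equivalent for reference (a sketch, not part of the patch):

    #include <stdint.h>

    static unsigned int avg_4x4_scalar(const uint8_t *s, int p) {
      unsigned int sum = 0;
      int r, c;
      for (r = 0; r < 4; ++r)
        for (c = 0; c < 4; ++c)
          sum += s[r * p + c];
      return (sum + 8) >> 4;  /* round-to-nearest average of 16 pixels */
    }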