Filling in speed feature functions for ext tx search

Filled in the prune-one and prune-two transform-type pruning functions.
Prune three is still being experimented with.

Change-Id: Ic07f828c448e86cacb0369aa3a9a0feb2edae054
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index 1dbac3d..c1a5abc 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -87,6 +87,8 @@
 const double ext_tx_th = 0.99;
 #endif
 
+const double ADST_FLIP_SVM[8] = {-7.3283, -3.0450, -3.2450, 3.6403,  // vert: w0, w1, w2, bias
+                                 -9.4204, -3.1821, -4.6851, 4.1469};  // horz: w0, w1, w2, bias
 
 typedef struct {
   PREDICTION_MODE mode;
@@ -350,7 +352,12 @@
   }
 }
 
-#if CONFIG_EXT_TX
+// constants for prune 1 and prune 2 decision boundaries
+#define FAST_EXT_TX_CORR_MID 0.0  // midpoint of the correlation test (dct_vs_dst)
+#define FAST_EXT_TX_EDST_MID 0.1  // midpoint of the SVM energy-distribution test (adst_vs_flipadst)
+#define FAST_EXT_TX_CORR_MARGIN 0.5  // no pruning inside +/- margin around the midpoint
+#define FAST_EXT_TX_EDST_MARGIN 0.05
+
 typedef enum {
   DCT_1D = 0,
   ADST_1D = 1,
@@ -359,15 +366,222 @@
   TX_TYPES_1D = 4,
 } TX_TYPE_1D;
 
+static void get_energy_distribution_fine(const VP10_COMP *cpi,
+                                         BLOCK_SIZE bsize,
+                                         uint8_t *src, int src_stride,
+                                         uint8_t *dst, int dst_stride,
+                                         double *hordist, double *verdist) {  // out: first 3 of 4 normalized energy bins per direction
+  int bw = 4 << (b_width_log2_lookup[bsize]);  // block width in pixels
+  int bh = 4 << (b_height_log2_lookup[bsize]);  // block height in pixels
+  unsigned int esq[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};  // SSE of each cell of a 4x4 grid over the block
+  unsigned int var[16];
+  double total = 0;
+  const int f_index = bsize - 6;  // fn_ptr index of the quarter-size variance fn; presumably BLOCK_16X16 == 6 -- TODO confirm
+  if (f_index < 0) {
+    int i, j, index;  // below 16x16: accumulate squared pixel differences directly
+    int w_shift = bw == 8 ? 1 : 2;  // maps column j to grid column j >> w_shift
+    int h_shift = bh == 8 ? 1 : 2;
+    for (i = 0; i < bh; ++i)
+      for (j = 0; j < bw; ++j) {
+        index = (j >> w_shift) + ((i >> h_shift) << 2);
+        esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
+                      (src[j + i * src_stride] - dst[j + i * dst_stride]);
+      }
+  } else {  // 16x16 and up: the variance fn returns each cell's SSE via its last arg
+    var[0] = cpi->fn_ptr[f_index].vf(src, src_stride,
+                                     dst, dst_stride, &esq[0]);
+    var[1] = cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride,
+                                     dst + bw / 4, dst_stride, &esq[1]);
+    var[2] = cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride,
+                                     dst + bw / 2, dst_stride, &esq[2]);
+    var[3] = cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride,
+                                     dst + 3 * bw / 4, dst_stride, &esq[3]);
+    src += bh / 4 * src_stride;  // advance both buffers one grid row
+    dst += bh / 4 * dst_stride;
+
+    var[4] = cpi->fn_ptr[f_index].vf(src, src_stride,
+                                     dst, dst_stride, &esq[4]);
+    var[5] = cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride,
+                                     dst + bw / 4, dst_stride, &esq[5]);
+    var[6] = cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride,
+                                     dst + bw / 2, dst_stride, &esq[6]);
+    var[7] = cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride,
+                                     dst + 3 * bw / 4, dst_stride, &esq[7]);
+    src += bh / 4 * src_stride;
+    dst += bh / 4 * dst_stride;
+
+    var[8] = cpi->fn_ptr[f_index].vf(src, src_stride,
+                                     dst, dst_stride, &esq[8]);
+    var[9] = cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride,
+                                     dst + bw / 4, dst_stride, &esq[9]);
+    var[10] = cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride,
+                                      dst + bw / 2, dst_stride, &esq[10]);
+    var[11] = cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride,
+                                      dst + 3 * bw / 4, dst_stride, &esq[11]);
+    src += bh / 4 * src_stride;
+    dst += bh / 4 * dst_stride;
+
+    var[12] = cpi->fn_ptr[f_index].vf(src, src_stride,
+                                      dst, dst_stride, &esq[12]);
+    var[13] = cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride,
+                                      dst + bw / 4, dst_stride, &esq[13]);
+    var[14] = cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride,
+                                      dst + bw / 2, dst_stride, &esq[14]);
+    var[15] = cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride,
+                                      dst + 3 * bw / 4, dst_stride, &esq[15]);
+  }
+
+  total = esq[0] + esq[1] + esq[2] + esq[3] +
+          esq[4] + esq[5] + esq[6] + esq[7] +
+          esq[8] + esq[9] + esq[10] + esq[11] +
+          esq[12] + esq[13] + esq[14] + esq[15];
+  if (total > 0) {  // normalize: each marginal's four bins sum to 1 (4th bin implied)
+    const double e_recip = 1.0 / total;
+    hordist[0] = ((double)esq[0] + (double)esq[4] + (double)esq[8] +
+                  (double)esq[12]) * e_recip;  // energy share of grid column 0
+    hordist[1] = ((double)esq[1] + (double)esq[5] + (double)esq[9] +
+                  (double)esq[13]) * e_recip;
+    hordist[2] = ((double)esq[2] + (double)esq[6] + (double)esq[10] +
+                  (double)esq[14]) * e_recip;
+    verdist[0] = ((double)esq[0] + (double)esq[1] + (double)esq[2] +
+                  (double)esq[3]) * e_recip;  // energy share of grid row 0
+    verdist[1] = ((double)esq[4] + (double)esq[5] + (double)esq[6] +
+                  (double)esq[7]) * e_recip;
+    verdist[2] = ((double)esq[8] + (double)esq[9] + (double)esq[10] +
+                  (double)esq[11]) * e_recip;
+  } else {
+    hordist[0] = verdist[0] = 0.25;  // no residual energy: report a flat distribution
+    hordist[1] = verdist[1] = 0.25;
+    hordist[2] = verdist[2] = 0.25;
+  }
+  (void) var[0];  // the variances are never read (only SSEs); silence set-but-unused warnings
+  (void) var[1];
+  (void) var[2];
+  (void) var[3];
+  (void) var[4];
+  (void) var[5];
+  (void) var[6];
+  (void) var[7];
+  (void) var[8];
+  (void) var[9];
+  (void) var[10];
+  (void) var[11];
+  (void) var[12];
+  (void) var[13];
+  (void) var[14];
+  (void) var[15];
+}
+
+int adst_vs_flipadst(const VP10_COMP *cpi,  // Returns a bitmask of 1D tx types (ADST_1D / FLIPADST_1D) to prune.
+                     BLOCK_SIZE bsize,
+                     uint8_t *src, int src_stride,
+                     uint8_t *dst, int dst_stride,
+                     double *hdist, double *vdist) {  // out: energy distributions from get_energy_distribution_fine
+  int prune_bitmask = 0;  // bits 0..7: vertical 1D choice, bits 8..15: horizontal
+  double svm_proj_h = 0, svm_proj_v = 0;
+  get_energy_distribution_fine(cpi, bsize, src, src_stride,
+                               dst, dst_stride, hdist, vdist);
+
+
+
+  svm_proj_v = vdist[0] * ADST_FLIP_SVM[0] +  // linear SVM decision value, vertical direction
+               vdist[1] * ADST_FLIP_SVM[1] +
+               vdist[2] * ADST_FLIP_SVM[2] + ADST_FLIP_SVM[3];
+  svm_proj_h = hdist[0] * ADST_FLIP_SVM[4] +  // linear SVM decision value, horizontal direction
+               hdist[1] * ADST_FLIP_SVM[5] +
+               hdist[2] * ADST_FLIP_SVM[6] + ADST_FLIP_SVM[7];
+  if (svm_proj_v > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)  // prune only outside the margin band
+    prune_bitmask |= 1 << FLIPADST_1D;
+  else if (svm_proj_v < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
+    prune_bitmask |= 1 << ADST_1D;
+
+  if (svm_proj_h > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)
+    prune_bitmask |= 1 << (FLIPADST_1D + 8);  // horizontal decisions occupy the upper byte
+  else if (svm_proj_h < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
+    prune_bitmask |= 1 << (ADST_1D + 8);
+
+  return prune_bitmask;
+}
+
+#if CONFIG_EXT_TX
+static void get_horver_correlation(int16_t *diff, int stride,
+                                   int w, int h,
+                                   double *hcorr, double *vcorr) {
+  // Returns hor/ver correlation coefficient
+  const int num = (h - 1) * (w - 1);  // samples with both a left and an above neighbor
+  double num_r;
+  int i, j;
+  int64_t xy_sum = 0, xz_sum = 0;  // x: current pixel, y: left neighbor, z: above neighbor
+  int64_t x_sum = 0, y_sum = 0, z_sum = 0;
+  int64_t x2_sum = 0, y2_sum = 0, z2_sum = 0;
+  double x_var_n, y_var_n, z_var_n, xy_var_n, xz_var_n;
+  *hcorr = *vcorr = 1;  // default when a variance below is degenerate
+
+  assert(num > 0);
+  num_r = 1.0 / num;
+  for (i = 1; i < h; ++i) {
+    for (j = 1; j < w; ++j) {
+      const int16_t x = diff[i * stride + j];
+      const int16_t y = diff[i * stride + j - 1];
+      const int16_t z = diff[(i - 1) * stride + j];
+      xy_sum += x * y;
+      xz_sum += x * z;
+      x_sum += x;
+      y_sum += y;
+      z_sum += z;
+      x2_sum += x * x;
+      y2_sum += y * y;
+      z2_sum += z * z;
+    }
+  }
+  x_var_n =  x2_sum - (x_sum * x_sum) * num_r;  // n-scaled (co)variances; the common n factor cancels in the ratio
+  y_var_n =  y2_sum - (y_sum * y_sum) * num_r;
+  z_var_n =  z2_sum - (z_sum * z_sum) * num_r;
+  xy_var_n = xy_sum - (x_sum * y_sum) * num_r;
+  xz_var_n = xz_sum - (x_sum * z_sum) * num_r;
+  if (x_var_n > 0 && y_var_n > 0) {
+    *hcorr = xy_var_n / sqrt(x_var_n * y_var_n);  // Pearson correlation with the left neighbor
+    *hcorr = *hcorr < 0 ? 0 : *hcorr;  // clamp: only non-negative correlation is used
+  }
+  if (x_var_n > 0 && z_var_n > 0) {
+    *vcorr = xz_var_n / sqrt(x_var_n * z_var_n);  // Pearson correlation with the above neighbor
+    *vcorr = *vcorr < 0 ? 0 : *vcorr;
+  }
+}
+
+int dct_vs_dst(int16_t *diff, int stride, int w, int h,  // Returns a bitmask pruning DCT_1D or DST_1D per direction.
+               double *hcorr, double *vcorr) {
+  int prune_bitmask = 0;
+  get_horver_correlation(diff, stride, w, h, hcorr, vcorr);
+
+  if (*vcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)  // strong vertical correlation: keep DCT, prune DST
+    prune_bitmask |= 1 << DST_1D;
+  else if (*vcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)  // weak correlation: keep DST, prune DCT
+    prune_bitmask |= 1 << DCT_1D;
+
+  if (*hcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)
+    prune_bitmask |= 1 << (DST_1D + 8);  // horizontal decisions occupy the upper byte
+  else if (*hcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)
+    prune_bitmask |= 1 << (DCT_1D + 8);
+  return prune_bitmask;
+}
+
+// Performance drop: 0.5%, Speed improvement: 24%
 static int prune_two_for_sby(const VP10_COMP *cpi,
                              BLOCK_SIZE bsize,
                              MACROBLOCK *x,
                              MACROBLOCKD *xd) {
-  (void) cpi;
-  (void) bsize;
-  (void) x;
-  (void) xd;
-  return 3;
+  struct macroblock_plane *const p = &x->plane[0];  // luma plane only
+  struct macroblockd_plane *const pd = &xd->plane[0];
+  const BLOCK_SIZE bs = get_plane_block_size(bsize, pd);
+  const int bw = 4 << (b_width_log2_lookup[bs]);  // bw doubles as src_diff's stride below
+  const int bh = 4 << (b_height_log2_lookup[bs]);
+  double hdist[3] = {0, 0, 0}, vdist[3] = {0, 0, 0};
+  double hcorr, vcorr;
+  vp10_subtract_plane(x, bsize, 0);  // fill p->src_diff with the luma residual
+  return adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride, pd->dst.buf,
+                          pd->dst.stride, hdist, vdist) |
+         dct_vs_dst(p->src_diff, bw, bw, bh, &hcorr, &vcorr);  // union of both per-direction prune masks
 }
 
 static int prune_three_for_sby(const VP10_COMP *cpi,
@@ -378,20 +592,22 @@
   (void) bsize;
   (void) x;
   (void) xd;
-  return 7;
+  return 0;
 }
 
 #endif  // CONFIG_EXT_TX
 
+// Performance drop: 0.3%, Speed improvement: 5%
 static int prune_one_for_sby(const VP10_COMP *cpi,
                              BLOCK_SIZE bsize,
                              MACROBLOCK *x,
                              MACROBLOCKD *xd) {
-  (void) cpi;
-  (void) bsize;
-  (void) x;
-  (void) xd;
-  return 1;
+  struct macroblock_plane *const p = &x->plane[0];  // luma plane only
+  struct macroblockd_plane *const pd = &xd->plane[0];
+  double hdist[3] = {0, 0, 0}, vdist[3] = {0, 0, 0};
+  vp10_subtract_plane(x, bsize, 0);  // NOTE(review): src_diff appears unread here (SVM uses src/dst) -- confirm this call is needed
+  return adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride, pd->dst.buf,
+                          pd->dst.stride, hdist, vdist);  // ADST-vs-FLIPADST pruning only
 }
 
 static int prune_tx_types(const VP10_COMP *cpi,
@@ -458,10 +674,10 @@
     DST_1D,
     DST_1D,
   };
-  if (tx_type == IDTX)
+  if (tx_type >= IDTX)
     return 1;
   return !(((prune >> vtx_tab[tx_type]) & 1) |
-         ((prune >> (htx_tab[tx_type] + TX_TYPES_1D)) & 1));
+         ((prune >> (htx_tab[tx_type] + 8)) & 1));
 #else
   // temporary to avoid compiler warnings
   (void) tx_type;