Merge "Refactor transform type-size search function" into nextgenv2
diff --git a/vp10/encoder/rdopt.c b/vp10/encoder/rdopt.c
index a71278b..87c9d0f 100644
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -1311,6 +1311,179 @@
 }
 #endif  // CONFIG_SUPERTX
 
+static int64_t txfm_yrd(VP10_COMP *cpi, MACROBLOCK *x,  // Returns the RDCOST of one fixed (tx_type, tx_size) candidate.
+                        int *r, int64_t *d, int *s, int64_t *sse,  // Outputs of txfm_rd_in_plane(); *r is adjusted further below.
+                        int64_t ref_best_rd,  // Forwarded unchanged to txfm_rd_in_plane().
+                        BLOCK_SIZE bs, TX_TYPE tx_type, int tx_size) {  // Returns INT64_MAX when *r comes back as INT_MAX.
+  VP10_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int64_t rd = INT64_MAX;
+  vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
+  int s0, s1;
+  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
+  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+  const int is_inter = is_inter_block(mbmi);
+  const int r_tx_size =  // Looked up from cpi->tx_size_cost; only charged when tx_select is set.
+      cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)][tx_size];  // NOTE(review): first index is negative if max_tx_size < TX_8X8 -- confirm callers never pass such a bs.
+#if CONFIG_EXT_TX
+  int ext_tx_set;
+#endif  // CONFIG_EXT_TX
+  // s0 / s1: vp10_cost_bit() costs of coding bit 0 / bit 1 under skip_prob.
+  assert(skip_prob > 0);
+  s0 = vp10_cost_bit(skip_prob, 0);
+  s1 = vp10_cost_bit(skip_prob, 1);
+  // The candidate is written into mbmi before the search call and is left there on return.
+  mbmi->tx_type = tx_type;
+  mbmi->tx_size = tx_size;
+  txfm_rd_in_plane(x,
+                   cpi,
+                   r, d, s,
+                   sse, ref_best_rd, 0, bs, tx_size,  // The literal 0 is presumably the plane index -- TODO confirm.
+                   cpi->sf.use_fast_coef_costing);
+  if (*r == INT_MAX)  // Rate left at INT_MAX: report the candidate as unusable.
+    return INT64_MAX;
+#if CONFIG_EXT_TX
+  ext_tx_set = get_ext_tx_set(tx_size, bs, is_inter);
+  if (get_ext_tx_types(tx_size, bs, is_inter) > 1 &&  // Tx-type cost is added only if get_ext_tx_types() reports more than one...
+      !xd->lossless[xd->mi[0]->mbmi.segment_id]) {  // ...and the block's segment is not lossless.
+    if (is_inter) {
+      if (ext_tx_set > 0)
+        *r += cpi->inter_tx_type_costs[ext_tx_set]
+                                      [mbmi->tx_size][mbmi->tx_type];
+    } else {
+      if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
+        *r += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
+                                      [mbmi->mode][mbmi->tx_type];
+    }
+  }
+
+#else
+  if (tx_size < TX_32X32 &&  // No tx-type cost for TX_32X32 and above, lossless segments, or when FIXED_TX_TYPE is set.
+      !xd->lossless[xd->mi[0]->mbmi.segment_id] && !FIXED_TX_TYPE) {
+    if (is_inter) {
+      *r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
+    } else {
+      *r += cpi->intra_tx_type_costs[mbmi->tx_size]
+           [intra_mode_to_tx_type_context[mbmi->mode]]
+           [mbmi->tx_type];
+    }
+  }
+#endif  // CONFIG_EXT_TX
+  // Skipped inter: rate s1 vs *sse. Skipped intra: s1 plus r_tx_size (if tx_select) vs *sse. Not skipped: *r + s0 plus r_tx_size (if tx_select) vs *d.
+  if (*s) {
+    if (is_inter) {
+      rd = RDCOST(x->rdmult, x->rddiv, s1, *sse);
+    } else {
+      rd =  RDCOST(x->rdmult, x->rddiv, s1 + r_tx_size * tx_select, *sse);
+    }
+  } else {
+    rd = RDCOST(x->rdmult, x->rddiv, *r + s0 + r_tx_size * tx_select, *d);
+  }
+  // Fold r_tx_size into the reported *r in exactly the cases where it was charged in rd above.
+  if (tx_select && !(*s && is_inter))
+    *r += r_tx_size;
+  // A non-skipped inter block in a non-lossless segment is capped at the RDCOST of (s1, *sse).
+  if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !(*s))
+    rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, *sse));
+
+  return rd;
+}
+
+static int64_t choose_tx_size_fix_type(VP10_COMP *cpi, MACROBLOCK *x,  // Searches tx sizes for one fixed tx_type; returns the lowest rd found.
+                                       int *rate,  // Out: rate of the best candidate (INT_MAX if none was accepted).
+                                       int64_t *distortion,  // Out: distortion of the best candidate (INT64_MAX if none).
+                                       int *skip,  // Out: skip flag of the best candidate (0 if none).
+                                       int64_t *psse,  // Out: sse of the best candidate (INT64_MAX if none).
+                                       int64_t ref_best_rd,  // Forwarded unchanged to txfm_yrd().
+                                       BLOCK_SIZE bs, TX_TYPE tx_type,
+                                       int prune) {  // Forwarded to do_tx_type_search().
+  VP10_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int r, s;
+  int64_t d, sse;
+  int64_t rd = INT64_MAX;
+  int n;
+  int start_tx, end_tx;
+  int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
+  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
+  TX_SIZE best_tx = max_tx_size;
+  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+  const int is_inter = is_inter_block(mbmi);
+#if CONFIG_EXT_TX
+  int ext_tx_set;
+#endif  // CONFIG_EXT_TX
+  // Under TX_MODE_SELECT every size from max_tx_size down to 0 is a candidate; otherwise exactly one size is evaluated.
+  if (tx_select) {
+    start_tx = max_tx_size;
+    end_tx = 0;
+  } else {
+    const TX_SIZE chosen_tx_size =
+        VPXMIN(max_tx_size, tx_mode_to_biggest_tx_size[cm->tx_mode]);
+    start_tx = chosen_tx_size;
+    end_tx = chosen_tx_size;
+  }
+  // Sentinel outputs; they are kept if no candidate is accepted in the loop below.
+  *distortion = INT64_MAX;
+  *rate       = INT_MAX;
+  *skip       = 0;
+  *psse       = INT64_MAX;
+
+  mbmi->tx_type = tx_type;
+  last_rd = INT64_MAX;
+  for (n = start_tx; n >= end_tx; --n) {  // Sizes are tried in descending order.
+    if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, n))  // With FIXED_TX_TYPE only the default type for this size is tried.
+        continue;
+#if CONFIG_EXT_TX
+    ext_tx_set = get_ext_tx_set(n, bs, is_inter);
+    if (is_inter) {
+      if (!ext_tx_used_inter[ext_tx_set][tx_type])  // tx_type is not flagged as used in this inter set.
+        continue;
+      if (cpi->sf.tx_type_search > 0) {
+        if (!do_tx_type_search(tx_type, prune))  // Rejected by do_tx_type_search() for this prune value.
+          continue;
+      }
+    } else {
+      if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
+        if (tx_type != intra_mode_to_tx_type_context[mbmi->mode])  // Only the type mapped from the intra mode is tried here.
+          continue;
+      }
+      if (!ext_tx_used_intra[ext_tx_set][tx_type])  // tx_type is not flagged as used in this intra set.
+        continue;
+    }
+#else  // CONFIG_EXT_TX
+    if (n >= TX_32X32 && tx_type != DCT_DCT)  // Sizes TX_32X32 and above are evaluated with DCT_DCT only.
+      continue;
+    if (is_inter && cpi->sf.tx_type_search > 0 &&
+        !do_tx_type_search(tx_type, prune))
+        continue;
+#endif  // CONFIG_EXT_TX
+    // txfm_yrd() returns INT64_MAX when the candidate produced no usable rate.
+    rd = txfm_yrd(cpi, x, &r, &d, &s, &sse, ref_best_rd, bs, tx_type, n);
+
+    // Early termination in transform size search.
+    if (cpi->sf.tx_size_search_breakout &&
+        (rd == INT64_MAX ||  // This candidate was unusable.
+         (s == 1 && tx_type != DCT_DCT && n < start_tx) ||  // Skip block, non-DCT type, and not the first size tried.
+         (n < (int) max_tx_size && rd > last_rd)))  // Cost rose relative to the previously evaluated size.
+      break;
+
+    last_rd = rd;
+    if (rd < best_rd) {
+      best_tx = n;
+      best_rd = rd;
+      *distortion = d;
+      *rate       = r;
+      *skip       = s;
+      *psse       = sse;
+    }
+  }
+  mbmi->tx_size = best_tx;  // Stays max_tx_size when no candidate was accepted.
+
+  return best_rd;  // INT64_MAX when no candidate was accepted.
+}
+
 static void choose_largest_tx_size(VP10_COMP *cpi, MACROBLOCK *x,
                                    int *rate, int64_t *distortion,
                                    int *skip, int64_t *sse,
@@ -1464,155 +1637,36 @@
                                    int64_t *psse,
                                    int64_t ref_best_rd,
                                    BLOCK_SIZE bs) {
-  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
-  VP10_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
-  vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
   int r, s;
   int64_t d, sse;
   int64_t rd = INT64_MAX;
-  int n;
-  int s0, s1;
-  int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
-  TX_SIZE best_tx = max_tx_size;
-  int start_tx, end_tx;
-  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+  int64_t best_rd = INT64_MAX;
+  TX_SIZE best_tx = max_txsize_lookup[bs];
   const int is_inter = is_inter_block(mbmi);
   TX_TYPE tx_type, best_tx_type = DCT_DCT;
   int prune = 0;
-#if CONFIG_EXT_TX
-  int ext_tx_set;
-#endif  // CONFIG_EXT_TX
 
   if (is_inter && cpi->sf.tx_type_search > 0)
     prune = prune_tx_types(cpi, bs, x, xd);
 
-  assert(skip_prob > 0);
-  s0 = vp10_cost_bit(skip_prob, 0);
-  s1 = vp10_cost_bit(skip_prob, 1);
-
-  if (tx_select) {
-    start_tx = max_tx_size;
-    end_tx = 0;
-  } else {
-    const TX_SIZE chosen_tx_size =
-        VPXMIN(max_tx_size, tx_mode_to_biggest_tx_size[cm->tx_mode]);
-    start_tx = chosen_tx_size;
-    end_tx = chosen_tx_size;
-  }
-
   *distortion = INT64_MAX;
   *rate       = INT_MAX;
   *skip       = 0;
   *psse       = INT64_MAX;
 
   for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
-    last_rd = INT64_MAX;
-    for (n = start_tx; n >= end_tx; --n) {
-      const int r_tx_size =
-          cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)][n];
-      if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, n))
-          continue;
-#if CONFIG_EXT_TX
-      ext_tx_set = get_ext_tx_set(n, bs, is_inter);
-      if (is_inter) {
-        if (!ext_tx_used_inter[ext_tx_set][tx_type])
-          continue;
-        if (cpi->sf.tx_type_search > 0) {
-          if (!do_tx_type_search(tx_type, prune))
-            continue;
-        }
-      } else {
-        if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
-          if (tx_type != intra_mode_to_tx_type_context[mbmi->mode])
-            continue;
-        }
-        if (!ext_tx_used_intra[ext_tx_set][tx_type])
-          continue;
-      }
-      mbmi->tx_type = tx_type;
-      txfm_rd_in_plane(x,
-                       cpi,
-                       &r, &d, &s,
-                       &sse, ref_best_rd, 0, bs, n,
-                       cpi->sf.use_fast_coef_costing);
-      if (get_ext_tx_types(n, bs, is_inter) > 1 &&
-          !xd->lossless[xd->mi[0]->mbmi.segment_id] &&
-          r != INT_MAX) {
-        if (is_inter) {
-          if (ext_tx_set > 0)
-            r += cpi->inter_tx_type_costs[ext_tx_set]
-                                         [mbmi->tx_size][mbmi->tx_type];
-        } else {
-          if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
-            r += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
-                                         [mbmi->mode][mbmi->tx_type];
-        }
-      }
-#else  // CONFIG_EXT_TX
-      if (n >= TX_32X32 && tx_type != DCT_DCT) {
-        continue;
-      }
-      mbmi->tx_type = tx_type;
-      txfm_rd_in_plane(x,
-                       cpi,
-                       &r, &d, &s,
-                       &sse, ref_best_rd, 0, bs, n,
-                       cpi->sf.use_fast_coef_costing);
-      if (n < TX_32X32 &&
-          !xd->lossless[xd->mi[0]->mbmi.segment_id] &&
-          r != INT_MAX && !FIXED_TX_TYPE) {
-        if (is_inter) {
-          r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
-          if (cpi->sf.tx_type_search > 0 && !do_tx_type_search(tx_type, prune))
-              continue;
-        } else {
-          r += cpi->intra_tx_type_costs[mbmi->tx_size]
-              [intra_mode_to_tx_type_context[mbmi->mode]]
-              [mbmi->tx_type];
-        }
-      }
-#endif  // CONFIG_EXT_TX
-
-      if (r == INT_MAX)
-        continue;
-
-      if (s) {
-        if (is_inter) {
-          rd = RDCOST(x->rdmult, x->rddiv, s1, sse);
-        } else {
-          rd =  RDCOST(x->rdmult, x->rddiv, s1 + r_tx_size * tx_select, sse);
-        }
-      } else {
-        rd = RDCOST(x->rdmult, x->rddiv, r + s0 + r_tx_size * tx_select, d);
-      }
-
-      if (tx_select && !(s && is_inter))
-        r += r_tx_size;
-
-      if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !s)
-        rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, sse));
-
-      // Early termination in transform size search.
-      if (cpi->sf.tx_size_search_breakout &&
-          (rd == INT64_MAX ||
-           (s == 1 && tx_type != DCT_DCT && n < start_tx) ||
-           (n < (int) max_tx_size && rd > last_rd)))
-        break;
-
-      last_rd = rd;
-      if (rd <
-          (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) *
-          best_rd) {
-        best_tx = n;
-        best_rd = rd;
-        *distortion = d;
-        *rate       = r;
-        *skip       = s;
-        *psse       = sse;
-        best_tx_type = mbmi->tx_type;
-      }
+    rd = choose_tx_size_fix_type(cpi, x, &r, &d, &s, &sse, ref_best_rd, bs,
+                                 tx_type, prune);
+    if (rd < (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) * best_rd) {
+      best_rd = rd;
+      *distortion = d;
+      *rate       = r;
+      *skip       = s;
+      *psse       = sse;
+      best_tx_type = tx_type;
+      best_tx = mbmi->tx_size;
     }
   }
 
@@ -3102,21 +3156,75 @@
   }
 }
 
+static int64_t select_tx_size_fix_type(const VP10_COMP *cpi, MACROBLOCK *x,  // Runs inter_block_yrd() for one fixed tx_type and returns its RDCOST.
+                                       int *rate, int64_t *dist,  // Outputs of inter_block_yrd(); *rate also receives the tx-type cost.
+                                       int *skippable,  // Out: non-zero selects the skip costing path below.
+                                       int64_t *sse, BLOCK_SIZE bsize,
+                                       int64_t ref_best_rd, TX_TYPE tx_type) {  // Returns INT64_MAX when *rate comes back as INT_MAX.
+  const VP10_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
+  const int is_inter = is_inter_block(mbmi);
+#if CONFIG_EXT_TX
+  int ext_tx_set = get_ext_tx_set(max_tx_size, bsize, is_inter);  // Derived from max_tx_size for the whole block.
+#endif  // CONFIG_EXT_TX
+  vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
+  int s0 = vp10_cost_bit(skip_prob, 0);  // Cost of coding bit 0 under skip_prob.
+  int s1 = vp10_cost_bit(skip_prob, 1);  // Cost of coding bit 1 under skip_prob.
+  int64_t rd;
+  // The tx type is written into mbmi before inter_block_yrd() runs and is left there on return.
+  mbmi->tx_type = tx_type;
+  inter_block_yrd(cpi, x, rate, dist, skippable, sse, bsize, ref_best_rd);
+
+  if (*rate == INT_MAX)  // Rate left at INT_MAX: report this tx_type as unusable.
+    return INT64_MAX;
+
+#if CONFIG_EXT_TX
+  if (get_ext_tx_types(max_tx_size, bsize, is_inter) > 1 &&  // Tx-type cost is added only if more than one type is reported...
+      !xd->lossless[xd->mi[0]->mbmi.segment_id]) {  // ...and the block's segment is not lossless.
+    if (is_inter) {
+      if (ext_tx_set > 0)
+        *rate += cpi->inter_tx_type_costs[ext_tx_set]
+                                         [max_tx_size][mbmi->tx_type];  // Cost table is indexed by max_tx_size.
+    } else {
+      if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
+        *rate += cpi->intra_tx_type_costs[ext_tx_set][max_tx_size]
+                                         [mbmi->mode][mbmi->tx_type];
+    }
+  }
+#else  // CONFIG_EXT_TX
+  if (max_tx_size < TX_32X32 && !xd->lossless[xd->mi[0]->mbmi.segment_id]) {  // NOTE(review): unlike txfm_yrd(), FIXED_TX_TYPE is not tested here -- confirm intended.
+    if (is_inter)
+      *rate += cpi->inter_tx_type_costs[max_tx_size][mbmi->tx_type];
+    else
+      *rate += cpi->intra_tx_type_costs[max_tx_size]
+                 [intra_mode_to_tx_type_context[mbmi->mode]][mbmi->tx_type];
+  }
+#endif  // CONFIG_EXT_TX
+  // Skippable: RDCOST of (s1, *sse). Otherwise: RDCOST of (*rate + s0, *dist).
+  if (*skippable)
+    rd = RDCOST(x->rdmult, x->rddiv, s1, *sse);
+  else
+    rd = RDCOST(x->rdmult, x->rddiv, *rate + s0, *dist);
+  // A non-skippable inter block in a non-lossless segment is capped at the RDCOST of (s1, *sse).
+  if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !(*skippable))
+    rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, *sse));
+
+  return rd;
+}
+
 static void select_tx_type_yrd(const VP10_COMP *cpi, MACROBLOCK *x,
                                int *rate, int64_t *distortion, int *skippable,
                                int64_t *sse, BLOCK_SIZE bsize,
                                int64_t ref_best_rd) {
   const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
-  const VP10_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int64_t rd = INT64_MAX;
   int64_t best_rd = INT64_MAX;
   TX_TYPE tx_type, best_tx_type = DCT_DCT;
   const int is_inter = is_inter_block(mbmi);
-  vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
-  int s0 = vp10_cost_bit(skip_prob, 0);
-  int s1 = vp10_cost_bit(skip_prob, 1);
   TX_SIZE best_tx_size[MI_BLOCK_SIZE][MI_BLOCK_SIZE];
   TX_SIZE best_tx = TX_SIZES;
   uint8_t best_blk_skip[256];
@@ -3156,59 +3264,15 @@
       if (!ext_tx_used_intra[ext_tx_set][tx_type])
         continue;
     }
-
-    mbmi->tx_type = tx_type;
-
-    inter_block_yrd(cpi, x, &this_rate, &this_dist, &this_skip, &this_sse,
-                    bsize, ref_best_rd);
-
-    if (get_ext_tx_types(max_tx_size, bsize, is_inter) > 1 &&
-        !xd->lossless[xd->mi[0]->mbmi.segment_id] &&
-        this_rate != INT_MAX) {
-      if (is_inter) {
-        if (ext_tx_set > 0)
-          this_rate += cpi->inter_tx_type_costs[ext_tx_set]
-                                       [max_tx_size][mbmi->tx_type];
-      } else {
-        if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
-          this_rate += cpi->intra_tx_type_costs[ext_tx_set][max_tx_size]
-                                               [mbmi->mode][mbmi->tx_type];
-      }
-    }
 #else  // CONFIG_EXT_TX
-      if (max_tx_size >= TX_32X32 && tx_type != DCT_DCT)
-        continue;
-
-      mbmi->tx_type = tx_type;
-
-      inter_block_yrd(cpi, x, &this_rate, &this_dist, &this_skip, &this_sse,
-                      bsize, ref_best_rd);
-
-      if (max_tx_size < TX_32X32 &&
-          !xd->lossless[xd->mi[0]->mbmi.segment_id] &&
-          this_rate != INT_MAX) {
-        if (is_inter) {
-          this_rate += cpi->inter_tx_type_costs[max_tx_size][mbmi->tx_type];
-          if (cpi->sf.tx_type_search > 0 && !do_tx_type_search(tx_type, prune))
-              continue;
-        } else {
-          this_rate += cpi->intra_tx_type_costs[max_tx_size]
-              [intra_mode_to_tx_type_context[mbmi->mode]]
-              [mbmi->tx_type];
-        }
-      }
-#endif  // CONFIG_EXT_TX
-
-    if (this_rate == INT_MAX)
+    if (max_tx_size >= TX_32X32 && tx_type != DCT_DCT)
       continue;
-
-    if (this_skip)
-      rd = RDCOST(x->rdmult, x->rddiv, s1, this_sse);
-    else
-      rd = RDCOST(x->rdmult, x->rddiv, this_rate + s0, this_dist);
-
-    if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !this_skip)
-      rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, this_sse));
+    if (is_inter && cpi->sf.tx_type_search > 0 &&
+        !do_tx_type_search(tx_type, prune))
+      continue;
+#endif  // CONFIG_EXT_TX
+    rd = select_tx_size_fix_type(cpi, x, &this_rate, &this_dist, &this_skip,
+                                 &this_sse, bsize, ref_best_rd, tx_type);
 
     if (rd < (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) * best_rd) {
       best_rd = rd;