Refactor rate-distortion optimization of recursive transform partition

Support rectangular transform blocks in the rate-distortion cost
estimator by tracking the block width and height separately instead
of assuming a square transform.

Change-Id: I99201fcae797c1ed2f2184021a215867eac0288f
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c11f357..bd5bffe 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2918,9 +2918,12 @@
   TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, is_inter_block(&xd->mi[0]->mbmi));
-
   BLOCK_SIZE txm_bsize = txsize_to_bsize[tx_size];
-  int bh = 4 * num_4x4_blocks_wide_lookup[txm_bsize];
+  int bh = block_size_high[txm_bsize];
+  int bw = block_size_wide[txm_bsize];
+  int txb_h = tx_size_high_unit[tx_size];
+  int txb_w = tx_size_wide_unit[tx_size];
+
   int src_stride = p->src.stride;
   uint8_t *src = &p->src.buf[4 * blk_row * src_stride + 4 * blk_col];
   uint8_t *dst = &pd->dst.buf[4 * blk_row * pd->dst.stride + 4 * blk_col];
@@ -2930,20 +2933,21 @@
 #else
   DECLARE_ALIGNED(16, uint8_t, rec_buffer[MAX_TX_SQUARE]);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-  const int diff_stride = 4 * num_4x4_blocks_wide_lookup[plane_bsize];
+  int max_blocks_high = block_size_high[plane_bsize];
+  int max_blocks_wide = block_size_wide[plane_bsize];
+  const int diff_stride = max_blocks_wide;
   const int16_t *diff = &p->src_diff[4 * (blk_row * diff_stride + blk_col)];
-
-  int max_blocks_high = num_4x4_blocks_high_lookup[plane_bsize];
-  int max_blocks_wide = num_4x4_blocks_wide_lookup[plane_bsize];
-
 #if CONFIG_EXT_TX
   assert(tx_size < TX_SIZES);
 #endif  // CONFIG_EXT_TX
 
   if (xd->mb_to_bottom_edge < 0)
-    max_blocks_high += xd->mb_to_bottom_edge >> (5 + pd->subsampling_y);
+    max_blocks_high += xd->mb_to_bottom_edge >> (3 + pd->subsampling_y);
   if (xd->mb_to_right_edge < 0)
-    max_blocks_wide += xd->mb_to_right_edge >> (5 + pd->subsampling_x);
+    max_blocks_wide += xd->mb_to_right_edge >> (3 + pd->subsampling_x);
+
+  max_blocks_high >>= tx_size_wide_log2[0];
+  max_blocks_wide >>= tx_size_wide_log2[0];
 
 #if CONFIG_NEW_QUANT
   av1_xform_quant_fp_nuq(cm, x, plane, block, blk_row, blk_col, plane_bsize,
@@ -2960,22 +2964,21 @@
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
     rec_buffer = CONVERT_TO_BYTEPTR(rec_buffer16);
     aom_highbd_convolve_copy(dst, pd->dst.stride, rec_buffer, MAX_TX_SIZE, NULL,
-                             0, NULL, 0, bh, bh, xd->bd);
+                             0, NULL, 0, bw, bh, xd->bd);
   } else {
     rec_buffer = (uint8_t *)rec_buffer16;
     aom_convolve_copy(dst, pd->dst.stride, rec_buffer, MAX_TX_SIZE, NULL, 0,
-                      NULL, 0, bh, bh);
+                      NULL, 0, bw, bh);
   }
 #else
   aom_convolve_copy(dst, pd->dst.stride, rec_buffer, MAX_TX_SIZE, NULL, 0, NULL,
-                    0, bh, bh);
+                    0, bw, bh);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
-  if (blk_row + (bh >> 2) > max_blocks_high ||
-      blk_col + (bh >> 2) > max_blocks_wide) {
+  if (blk_row + txb_h > max_blocks_high || blk_col + txb_w > max_blocks_wide) {
     int idx, idy;
-    int blocks_height = AOMMIN(bh >> 2, max_blocks_high - blk_row);
-    int blocks_width = AOMMIN(bh >> 2, max_blocks_wide - blk_col);
+    int blocks_height = AOMMIN(txb_h, max_blocks_high - blk_row);
+    int blocks_width = AOMMIN(txb_w, max_blocks_wide - blk_col);
     tmp = 0;
     for (idy = 0; idy < blocks_height; idy += 2) {
       for (idx = 0; idx < blocks_width; idx += 2) {
@@ -2984,7 +2987,7 @@
       }
     }
   } else {
-    tmp = aom_sum_squares_2d_i16(diff, diff_stride, bh);
+    tmp = sum_squares_2d(diff, diff_stride, tx_size);
   }
 
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -3010,12 +3013,12 @@
     inv_txfm_add(dqcoeff, rec_buffer, MAX_TX_SIZE, &inv_txfm_param);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
-    if ((bh >> 2) + blk_col > max_blocks_wide ||
-        (bh >> 2) + blk_row > max_blocks_high) {
+    if (txb_w + blk_col > max_blocks_wide ||
+        txb_h + blk_row > max_blocks_high) {
       int idx, idy;
       unsigned int this_dist;
-      int blocks_height = AOMMIN(bh >> 2, max_blocks_high - blk_row);
-      int blocks_width = AOMMIN(bh >> 2, max_blocks_wide - blk_col);
+      int blocks_height = AOMMIN(txb_h, max_blocks_high - blk_row);
+      int blocks_width = AOMMIN(txb_w, max_blocks_wide - blk_col);
       tmp = 0;
       for (idy = 0; idy < blocks_height; idy += 2) {
         for (idx = 0; idx < blocks_width; idx += 2) {