Rework loop filter tx size selection

Update and capture the effective transform block size per color
plane.

Change-Id: Ib6e0e7abb3973db6b8d511ee7c9948aaab048788
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index c1b596b..80f6b09 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -125,6 +125,11 @@
 #if CONFIG_VAR_TX
   aom_free(cm->above_txfm_context);
   cm->above_txfm_context = NULL;
+
+  for (i = 0; i < MAX_MB_PLANE; ++i) {
+    aom_free(cm->top_txfm_context[i]);
+    cm->top_txfm_context[i] = NULL;
+  }
 #endif
 }
 
@@ -170,6 +175,14 @@
     cm->above_txfm_context = (TXFM_CONTEXT *)aom_calloc(
         aligned_mi_cols << TX_UNIT_WIDE_LOG2, sizeof(*cm->above_txfm_context));
     if (!cm->above_txfm_context) goto fail;
+
+    for (i = 0; i < MAX_MB_PLANE; ++i) {
+      aom_free(cm->top_txfm_context[i]);
+      cm->top_txfm_context[i] =
+          (TXFM_CONTEXT *)aom_calloc(aligned_mi_cols << TX_UNIT_WIDE_LOG2,
+                                     sizeof(*cm->top_txfm_context[0]));
+      if (!cm->top_txfm_context[i]) goto fail;
+    }
 #endif
 
     cm->above_context_alloc_cols = aligned_mi_cols;
diff --git a/av1/common/av1_loopfilter.c b/av1/common/av1_loopfilter.c
index f40de33..1f0cb3b 100644
--- a/av1/common/av1_loopfilter.c
+++ b/av1/common/av1_loopfilter.c
@@ -1363,7 +1363,7 @@
 // the non420 case).
 // Note: 'row_masks_ptr' and/or 'col_masks_ptr' can be passed NULL.
 static void get_filter_level_and_masks_non420(
-    AV1_COMMON *const cm, const struct macroblockd_plane *const plane,
+    AV1_COMMON *const cm, const struct macroblockd_plane *const plane, int pl,
     MODE_INFO **mib, int mi_row, int mi_col, int idx_r, uint8_t *const lfl_r,
     unsigned int *const mask_4x4_int_r_ptr,
     unsigned int *const mask_4x4_int_c_ptr, FilterMasks *const row_masks_ptr,
@@ -1400,10 +1400,6 @@
         (num_4x4_blocks_high_lookup[sb_type] > 1) ? !blk_row : 1;
     const int skip_this_r = skip_this && !block_edge_above;
 
-#if CONFIG_VAR_TX
-    const TX_SIZE mb_tx_size = mbmi->inter_tx_size[blk_row][blk_col];
-#endif
-
     TX_SIZE tx_size = (plane->plane_type == PLANE_TYPE_UV)
                           ? get_uv_tx_size(mbmi, plane)
                           : mbmi->tx_size;
@@ -1420,8 +1416,15 @@
 
 #if CONFIG_VAR_TX
     if (is_inter_block(mbmi) && !mbmi->skip) {
+      const int tx_row_idx =
+          (blk_row * mi_size_high[BLOCK_8X8] << TX_UNIT_HIGH_LOG2) >> 1;
+      const int tx_col_idx =
+          (blk_col * mi_size_wide[BLOCK_8X8] << TX_UNIT_WIDE_LOG2) >> 1;
+      const BLOCK_SIZE bsize =
+          AOMMAX(BLOCK_4X4, get_plane_block_size(mbmi->sb_type, plane));
+      const TX_SIZE mb_tx_size = mbmi->inter_tx_size[tx_row_idx][tx_col_idx];
       tx_size = (plane->plane_type == PLANE_TYPE_UV)
-                    ? uv_txsize_lookup[sb_type][mb_tx_size][ss_x][ss_y]
+                    ? uv_txsize_lookup[bsize][mb_tx_size][0][0]
                     : mb_tx_size;
     }
 #endif
@@ -1434,18 +1437,29 @@
 #endif
 
 #if CONFIG_VAR_TX
-    TX_SIZE tx_size_r = AOMMIN(
-        tx_size, cm->above_txfm_context[(mi_col + c) << TX_UNIT_WIDE_LOG2]);
-    TX_SIZE tx_size_c =
-        AOMMIN(tx_size, cm->left_txfm_context[((mi_row + r) & MAX_MIB_MASK)
-                                              << TX_UNIT_HIGH_LOG2]);
+    TX_SIZE tx_size_r, tx_size_c;
 
-    cm->above_txfm_context[(mi_col + c) << TX_UNIT_WIDE_LOG2] = tx_size;
-    cm->left_txfm_context[((mi_row + r) & MAX_MIB_MASK) << TX_UNIT_HIGH_LOG2] =
-        tx_size;
+    const int tx_wide =
+        AOMMIN(tx_size_wide[tx_size],
+               tx_size_wide[cm->top_txfm_context[pl][(mi_col + idx_c)
+                                                     << TX_UNIT_WIDE_LOG2]]);
+    const int tx_high = AOMMIN(
+        tx_size_high[tx_size],
+        tx_size_high[cm->left_txfm_context[pl][((mi_row + idx_r) & MAX_MIB_MASK)
+                                               << TX_UNIT_HIGH_LOG2]]);
+
+    tx_size_c = get_sqr_tx_size(tx_wide);
+    tx_size_r = get_sqr_tx_size(tx_high);
+
+    memset(cm->top_txfm_context[pl] + ((mi_col + idx_c) << TX_UNIT_WIDE_LOG2),
+           tx_size, mi_size_wide[BLOCK_8X8] << TX_UNIT_WIDE_LOG2);
+    memset(cm->left_txfm_context[pl] +
+               (((mi_row + idx_r) & MAX_MIB_MASK) << TX_UNIT_HIGH_LOG2),
+           tx_size, mi_size_high[BLOCK_8X8] << TX_UNIT_HIGH_LOG2);
 #else
     TX_SIZE tx_size_c = txsize_horz_map[tx_size];
     TX_SIZE tx_size_r = txsize_vert_map[tx_size];
+    (void)pl;
 #endif  // CONFIG_VAR_TX
 
     if (tx_size_c == TX_32X32)
@@ -1530,8 +1544,8 @@
 
 void av1_filter_block_plane_non420_ver(AV1_COMMON *const cm,
                                        struct macroblockd_plane *plane,
-                                       MODE_INFO **mib, int mi_row,
-                                       int mi_col) {
+                                       MODE_INFO **mib, int mi_row, int mi_col,
+                                       int pl) {
   const int ss_y = plane->subsampling_y;
   const int row_step = mi_size_high[BLOCK_8X8] << ss_y;
   struct buf_2d *const dst = &plane->dst;
@@ -1544,7 +1558,7 @@
     unsigned int mask_4x4_int;
     FilterMasks col_masks;
     const int r = idx_r >> mi_height_log2_lookup[BLOCK_8X8];
-    get_filter_level_and_masks_non420(cm, plane, mib, mi_row, mi_col, idx_r,
+    get_filter_level_and_masks_non420(cm, plane, pl, mib, mi_row, mi_col, idx_r,
                                       &lfl[r][0], NULL, &mask_4x4_int, NULL,
                                       &col_masks);
 
@@ -1579,8 +1593,8 @@
 
 void av1_filter_block_plane_non420_hor(AV1_COMMON *const cm,
                                        struct macroblockd_plane *plane,
-                                       MODE_INFO **mib, int mi_row,
-                                       int mi_col) {
+                                       MODE_INFO **mib, int mi_row, int mi_col,
+                                       int pl) {
   const int ss_y = plane->subsampling_y;
   const int row_step = mi_size_high[BLOCK_8X8] << ss_y;
   struct buf_2d *const dst = &plane->dst;
@@ -1592,7 +1606,7 @@
   for (idx_r = 0; idx_r < cm->mib_size && mi_row + idx_r < cm->mi_rows;
        idx_r += row_step) {
     const int r = idx_r >> mi_height_log2_lookup[BLOCK_8X8];
-    get_filter_level_and_masks_non420(cm, plane, mib, mi_row, mi_col, idx_r,
+    get_filter_level_and_masks_non420(cm, plane, pl, mib, mi_row, mi_col, idx_r,
                                       &lfl[r][0], mask_4x4_int + r, NULL,
                                       row_masks_array + r, NULL);
   }
@@ -2192,12 +2206,15 @@
   int mi_row, mi_col;
 
 #if CONFIG_VAR_TX
-  memset(cm->above_txfm_context, TX_SIZES, cm->mi_cols << TX_UNIT_WIDE_LOG2);
+  for (int i = 0; i < MAX_MB_PLANE; ++i)
+    memset(cm->top_txfm_context[i], TX_32X32, cm->mi_cols << TX_UNIT_WIDE_LOG2);
 #endif  // CONFIG_VAR_TX
   for (mi_row = start; mi_row < stop; mi_row += cm->mib_size) {
     MODE_INFO **mi = cm->mi_grid_visible + mi_row * cm->mi_stride;
 #if CONFIG_VAR_TX
-    memset(cm->left_txfm_context, TX_SIZES, MAX_MIB_SIZE << TX_UNIT_WIDE_LOG2);
+    for (int i = 0; i < MAX_MB_PLANE; ++i)
+      memset(cm->left_txfm_context[i], TX_32X32, MAX_MIB_SIZE
+                                                     << TX_UNIT_WIDE_LOG2);
 #endif  // CONFIG_VAR_TX
     for (mi_col = 0; mi_col < cm->mi_cols; mi_col += cm->mib_size) {
       int plane;
@@ -2206,9 +2223,9 @@
 
       for (plane = 0; plane < num_planes; ++plane) {
         av1_filter_block_plane_non420_ver(cm, &planes[plane], mi + mi_col,
-                                          mi_row, mi_col);
+                                          mi_row, mi_col, plane);
         av1_filter_block_plane_non420_hor(cm, &planes[plane], mi + mi_col,
-                                          mi_row, mi_col);
+                                          mi_row, mi_col, plane);
       }
     }
   }
@@ -2282,9 +2299,9 @@
             break;
           case LF_PATH_SLOW:
             av1_filter_block_plane_non420_ver(cm, &planes[plane], mi + mi_col,
-                                              mi_row, mi_col);
+                                              mi_row, mi_col, plane);
             av1_filter_block_plane_non420_hor(cm, &planes[plane], mi + mi_col,
-                                              mi_row, mi_col);
+                                              mi_row, mi_col, plane);
 
             break;
         }
diff --git a/av1/common/av1_loopfilter.h b/av1/common/av1_loopfilter.h
index 8ac5d99..424df24 100644
--- a/av1/common/av1_loopfilter.h
+++ b/av1/common/av1_loopfilter.h
@@ -115,11 +115,11 @@
 void av1_filter_block_plane_non420_ver(struct AV1Common *const cm,
                                        struct macroblockd_plane *plane,
                                        MODE_INFO **mi_8x8, int mi_row,
-                                       int mi_col);
+                                       int mi_col, int pl);
 void av1_filter_block_plane_non420_hor(struct AV1Common *const cm,
                                        struct macroblockd_plane *plane,
                                        MODE_INFO **mi_8x8, int mi_row,
-                                       int mi_col);
+                                       int mi_col, int pl);
 
 void av1_loop_filter_init(struct AV1Common *cm);
 
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 795af30..38750c4 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -386,7 +386,8 @@
   ENTROPY_CONTEXT *above_context[MAX_MB_PLANE];
 #if CONFIG_VAR_TX
   TXFM_CONTEXT *above_txfm_context;
-  TXFM_CONTEXT left_txfm_context[2 * MAX_MIB_SIZE];
+  TXFM_CONTEXT *top_txfm_context[MAX_MB_PLANE];
+  TXFM_CONTEXT left_txfm_context[MAX_MB_PLANE][2 * MAX_MIB_SIZE];
 #endif
   int above_context_alloc_cols;
 
@@ -928,6 +929,21 @@
   for (i = 0; i < bw; ++i) above_ctx[i] = txw;
 }
 
+static INLINE TX_SIZE get_sqr_tx_size(int tx_dim) {
+  TX_SIZE tx_size;
+  switch (tx_dim) {
+#if CONFIG_EXT_PARTITION
+    case 128:
+#endif
+    case 64:
+    case 32: tx_size = TX_32X32; break;
+    case 16: tx_size = TX_16X16; break;
+    case 8: tx_size = TX_8X8; break;
+    default: tx_size = TX_4X4;
+  }
+  return tx_size;
+}
+
 static INLINE int txfm_partition_context(TXFM_CONTEXT *above_ctx,
                                          TXFM_CONTEXT *left_ctx,
                                          BLOCK_SIZE bsize, TX_SIZE tx_size) {
@@ -935,22 +951,13 @@
   const uint8_t txh = tx_size_high[tx_size];
   const int above = *above_ctx < txw;
   const int left = *left_ctx < txh;
-  TX_SIZE max_tx_size = max_txsize_lookup[bsize];
   int category = TXFM_PARTITION_CONTEXTS - 1;
 
   // dummy return, not used by others.
   if (tx_size <= TX_4X4) return 0;
 
-  switch (AOMMAX(block_size_wide[bsize], block_size_high[bsize])) {
-#if CONFIG_EXT_PARTITION
-    case 128:
-#endif
-    case 64:
-    case 32: max_tx_size = TX_32X32; break;
-    case 16: max_tx_size = TX_16X16; break;
-    case 8: max_tx_size = TX_8X8; break;
-    default: assert(max_tx_size == max_txsize_lookup[bsize]); assert(0);
-  }
+  TX_SIZE max_tx_size =
+      get_sqr_tx_size(AOMMAX(block_size_wide[bsize], block_size_high[bsize]));
 
   if (max_tx_size >= TX_8X8) {
     category = (tx_size != max_tx_size && max_tx_size > TX_8X8) +
diff --git a/av1/common/thread_common.c b/av1/common/thread_common.c
index ca8b1b3..d96a71a 100644
--- a/av1/common/thread_common.c
+++ b/av1/common/thread_common.c
@@ -113,7 +113,7 @@
         break;
       case LF_PATH_SLOW:
         av1_filter_block_plane_non420_ver(cm, &planes[plane], mi, mi_row,
-                                          mi_col);
+                                          mi_col, plane);
         break;
     }
   }
@@ -135,7 +135,7 @@
         break;
       case LF_PATH_SLOW:
         av1_filter_block_plane_non420_hor(cm, &planes[plane], mi, mi_row,
-                                          mi_col);
+                                          mi_col, plane);
         break;
     }
   }
@@ -168,7 +168,7 @@
 #if CONFIG_EXT_PARTITION_TYPES
       for (plane = 0; plane < num_planes; ++plane)
         av1_filter_block_plane_non420_ver(lf_data->cm, &lf_data->planes[plane],
-                                          mi + mi_col, mi_row, mi_col);
+                                          mi + mi_col, mi_row, mi_col, plane);
 #else
 
       for (plane = 0; plane < num_planes; ++plane)
@@ -213,7 +213,7 @@
 #if CONFIG_EXT_PARTITION_TYPES
       for (plane = 0; plane < num_planes; ++plane)
         av1_filter_block_plane_non420_hor(lf_data->cm, &lf_data->planes[plane],
-                                          mi + mi_col, mi_row, mi_col);
+                                          mi + mi_col, mi_row, mi_col, plane);
 #else
       for (plane = 0; plane < num_planes; ++plane)
         loop_filter_block_plane_hor(lf_data->cm, lf_data->planes, plane,
@@ -263,9 +263,9 @@
 #if CONFIG_EXT_PARTITION_TYPES
       for (plane = 0; plane < num_planes; ++plane) {
         av1_filter_block_plane_non420_ver(lf_data->cm, &lf_data->planes[plane],
-                                          mi + mi_col, mi_row, mi_col);
+                                          mi + mi_col, mi_row, mi_col, plane);
         av1_filter_block_plane_non420_hor(lf_data->cm, &lf_data->planes[plane],
-                                          mi + mi_col, mi_row, mi_col);
+                                          mi + mi_col, mi_row, mi_col, plane);
       }
 #else
       av1_setup_mask(lf_data->cm, mi_row, mi_col, mi + mi_col,