Refactor lgt

Restructure the LGT matrix selection code so that the later
lgt-from-pred experiment can be integrated cleanly.

The main purpose of this change is to unify the get_fwd_lgt and
get_inv_lgt functions into a single get_lgt function, so that the
LGT basis functions are always selected through the same function
in both the forward and inverse transform paths. The new functions
are also structured consistently with the get_lgt_from_pred
functions that will be added in the lgt-from-pred experiment.
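
For reference, a minimal sketch of the unified selection logic,
equivalent to the new get_lgt4 below (get_lgt8 differs only in using
the lgt8_* tables); all identifiers are the existing libaom
definitions:

  int get_lgt4(const TxfmParam *txfm_param, int is_col,
               const tran_high_t **lgtmtx) {
    // Look up the 1D transform type for this direction.
    const int tx1d = is_col ? vtx_tab[txfm_param->tx_type]
                            : htx_tab[txfm_param->tx_type];
    if (tx1d == ADST_1D || tx1d == FLIPADST_1D) {
      // Same inter/intra split as the old get_fwd_lgt4/get_inv_lgt4.
      lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
      return 1;
    }
    lgtmtx[0] = NULL;  // LGT not selected for this 1D transform
    return 0;
  }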

These changes have no impact on the bitstream.

Change-Id: Ifd3dfc1a9e1a250495830ddbf42c201e80aa913e
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index b07c9b0..fa96aca 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -31,13 +31,13 @@
   int is_inter;
   int stride;
   uint8_t *dst;
+#if CONFIG_LGT
+  int mode;
+#endif
 #if CONFIG_MRC_TX
   int *valid_mask;
 #endif  // CONFIG_MRC_TX
 #endif  // CONFIG_MRC_TX || CONFIG_LGT
-#if CONFIG_LGT
-  int mode;
-#endif
 // for inverse transforms only
 #if CONFIG_ADAPT_SCAN
   const int16_t *eob_threshold;
@@ -97,9 +97,10 @@
 }
 
 #if CONFIG_LGT
-// The Line Graph Transforms (LGTs) matrices are written as follows.
-// Each 2D array is 16384 times an LGT matrix, which is the matrix of
-// eigenvectors of the graph Laplacian matrices for the line graph.
+/* The Line Graph Transform (LGT) matrices are defined as follows.
+ * Each 2D array is sqrt(2)*16384 times an LGT matrix, which is the
+ * matrix of eigenvectors of the graph Laplacian matrix of the associated
+ * line graph. */
 
 // LGT4 name: lgt4_140
 // Self loops: 1.400, 0.000, 0.000, 0.000
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 30607f0..024404e 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1235,19 +1235,6 @@
   return (tx_size == TX_4X4) ? raster_order : (raster_order > 0) ? 2 : 0;
 }
 
-#if CONFIG_LGT
-static INLINE PREDICTION_MODE get_prediction_mode(const MODE_INFO *mi,
-                                                  int plane, TX_SIZE tx_size,
-                                                  int block_idx) {
-  const MB_MODE_INFO *const mbmi = &mi->mbmi;
-  if (is_inter_block(mbmi)) return mbmi->mode;
-
-  int block_raster_idx = av1_block_index_to_raster_order(tx_size, block_idx);
-  return (plane == PLANE_TYPE_Y) ? get_y_mode(mi, block_raster_idx)
-                                 : get_uv_mode(mbmi->uv_mode);
-}
-#endif  // CONFIG_LGT
-
 static INLINE TX_TYPE get_default_tx_type(PLANE_TYPE plane_type,
                                           const MACROBLOCKD *xd, int block_idx,
                                           TX_SIZE tx_size) {
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 737a597..3a3705a 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -265,12 +265,8 @@
 #if CONFIG_LGT
 void ilgt4(const tran_low_t *input, tran_low_t *output,
            const tran_high_t *lgtmtx) {
-  if (!(input[0] | input[1] | input[2] | input[3])) {
-    output[0] = output[1] = output[2] = output[3] = 0;
-    return;
-  }
-
-  // evaluate s[j] = sum of all lgtmtx[i][j]*input[i] over i=1,...,4
+  if (!lgtmtx) assert(0);
+  // evaluate s[j] = sum of all lgtmtx[i * 4 + j] * input[i] over i = 0, ..., 3
   tran_high_t s[4] = { 0 };
   for (int i = 0; i < 4; ++i)
     for (int j = 0; j < 4; ++j) s[j] += lgtmtx[i * 4 + j] * input[i];
@@ -280,7 +276,8 @@
 
 void ilgt8(const tran_low_t *input, tran_low_t *output,
            const tran_high_t *lgtmtx) {
-  // evaluate s[j] = sum of all lgtmtx[i][j]*input[i] over i=1,...,8
+  if (!lgtmtx) assert(0);
+  // evaluate s[j] = sum of all lgtmtx[i * 8 + j] * input[i] over i = 0, ..., 7
   tran_high_t s[8] = { 0 };
   for (int i = 0; i < 8; ++i)
     for (int j = 0; j < 8; ++j) s[j] += lgtmtx[i * 8 + j] * input[i];
@@ -288,26 +285,35 @@
   for (int i = 0; i < 8; ++i) output[i] = WRAPLOW(dct_const_round_shift(s[i]));
 }
 
-// The get_inv_lgt functions return 1 if LGT is chosen to apply, and 0 otherwise
-int get_inv_lgt4(transform_1d tx_orig, const TxfmParam *txfm_param,
-                 const tran_high_t *lgtmtx[], int ntx) {
-  // inter/intra split
-  if (tx_orig == &aom_iadst4_c) {
-    for (int i = 0; i < ntx; ++i)
-      lgtmtx[i] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
+// get_lgt4 and get_lgt8 return 1 and set the LGT matrix pointer if LGT is
+// chosen for the given 1D transform; otherwise they return 0.
+int get_lgt4(const TxfmParam *txfm_param, int is_col,
+             const tran_high_t **lgtmtx) {
+  if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
+                 vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
+    lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
+    return 1;
+  } else if (!is_col && (htx_tab[txfm_param->tx_type] == ADST_1D ||
+                         htx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
+    lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
     return 1;
   }
+  lgtmtx[0] = NULL;
   return 0;
 }
 
-int get_inv_lgt8(transform_1d tx_orig, const TxfmParam *txfm_param,
-                 const tran_high_t *lgtmtx[], int ntx) {
-  // inter/intra split
-  if (tx_orig == &aom_iadst8_c) {
-    for (int i = 0; i < ntx; ++i)
-      lgtmtx[i] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
+int get_lgt8(const TxfmParam *txfm_param, int is_col,
+             const tran_high_t **lgtmtx) {
+  if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
+                 vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
+    lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
+    return 1;
+  } else if (!is_col && (htx_tab[txfm_param->tx_type] == ADST_1D ||
+                         htx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
+    lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
     return 1;
   }
+  lgtmtx[0] = NULL;
   return 0;
 }
 #endif  // CONFIG_LGT
@@ -356,12 +362,10 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[4];
-  const tran_high_t *lgtmtx_row[4];
-  int use_lgt_col =
-      get_inv_lgt4(IHT_4[tx_type].cols, txfm_param, lgtmtx_col, 4);
-  int use_lgt_row =
-      get_inv_lgt4(IHT_4[tx_type].rows, txfm_param, lgtmtx_row, 4);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors
@@ -373,7 +377,7 @@
 #else
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt4(input, out[i], lgtmtx_row[i]);
+      ilgt4(input, out[i], lgtmtx_row[0]);
     else
 #endif
       IHT_4[tx_type].rows(input, out[i]);
@@ -392,7 +396,7 @@
   for (i = 0; i < 4; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt4(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt4(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_4[tx_type].cols(tmp[i], out[i]);
@@ -454,19 +458,17 @@
   int outstride = n2;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[4];
-  const tran_high_t *lgtmtx_row[8];
-  int use_lgt_col =
-      get_inv_lgt8(IHT_4x8[tx_type].cols, txfm_param, lgtmtx_col, 4);
-  int use_lgt_row =
-      get_inv_lgt4(IHT_4x8[tx_type].rows, txfm_param, lgtmtx_row, 8);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors and transpose
   for (i = 0; i < n2; ++i) {
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt4(input, outtmp, lgtmtx_row[i]);
+      ilgt4(input, outtmp, lgtmtx_row[0]);
     else
 #endif
       IHT_4x8[tx_type].rows(input, outtmp);
@@ -479,7 +481,7 @@
   for (i = 0; i < n; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt8(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt8(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_4x8[tx_type].cols(tmp[i], out[i]);
@@ -538,19 +540,17 @@
   int outstride = n;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[8];
-  const tran_high_t *lgtmtx_row[4];
-  int use_lgt_col =
-      get_inv_lgt4(IHT_8x4[tx_type].cols, txfm_param, lgtmtx_col, 8);
-  int use_lgt_row =
-      get_inv_lgt8(IHT_8x4[tx_type].rows, txfm_param, lgtmtx_row, 4);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors and transpose
   for (i = 0; i < n; ++i) {
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt8(input, outtmp, lgtmtx_row[i]);
+      ilgt8(input, outtmp, lgtmtx_row[0]);
     else
 #endif
       IHT_8x4[tx_type].rows(input, outtmp);
@@ -563,7 +563,7 @@
   for (i = 0; i < n2; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt4(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt4(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_8x4[tx_type].cols(tmp[i], out[i]);
@@ -621,16 +621,15 @@
   int outstride = n4;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[16];
-  int use_lgt_row =
-      get_inv_lgt4(IHT_4x16[tx_type].rows, txfm_param, lgtmtx_row, 16);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors and transpose
   for (i = 0; i < n4; ++i) {
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt4(input, outtmp, lgtmtx_row[i]);
+      ilgt4(input, outtmp, lgtmtx_row[0]);
     else
 #endif
       IHT_4x16[tx_type].rows(input, outtmp);
@@ -696,9 +695,8 @@
   int outstride = n;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[16];
-  int use_lgt_col =
-      get_inv_lgt4(IHT_16x4[tx_type].cols, txfm_param, lgtmtx_col, 16);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
 #endif
 
   // inverse transform row vectors and transpose
@@ -712,7 +710,7 @@
   for (i = 0; i < n4; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt4(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt4(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_16x4[tx_type].cols(tmp[i], out[i]);
@@ -770,16 +768,15 @@
   int outstride = n2;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[16];
-  int use_lgt_row =
-      get_inv_lgt8(IHT_8x16[tx_type].rows, txfm_param, lgtmtx_row, 16);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors and transpose
   for (i = 0; i < n2; ++i) {
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt8(input, outtmp, lgtmtx_row[i]);
+      ilgt8(input, outtmp, lgtmtx_row[0]);
     else
 #endif
       IHT_8x16[tx_type].rows(input, outtmp);
@@ -846,9 +843,8 @@
   int outstride = n;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[16];
-  int use_lgt_col =
-      get_inv_lgt8(IHT_16x8[tx_type].cols, txfm_param, lgtmtx_col, 16);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
 #endif
 
   // inverse transform row vectors and transpose
@@ -863,7 +859,7 @@
   for (i = 0; i < n2; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt8(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt8(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_16x8[tx_type].cols(tmp[i], out[i]);
@@ -921,16 +917,15 @@
   int outstride = n4;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[32];
-  int use_lgt_row =
-      get_inv_lgt8(IHT_8x32[tx_type].rows, txfm_param, lgtmtx_row, 32);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors and transpose
   for (i = 0; i < n4; ++i) {
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt8(input, outtmp, lgtmtx_row[i]);
+      ilgt8(input, outtmp, lgtmtx_row[0]);
     else
 #endif
       IHT_8x32[tx_type].rows(input, outtmp);
@@ -996,9 +991,8 @@
   int outstride = n;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[32];
-  int use_lgt_col =
-      get_inv_lgt4(IHT_32x8[tx_type].cols, txfm_param, lgtmtx_col, 32);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
 #endif
 
   // inverse transform row vectors and transpose
@@ -1012,7 +1006,7 @@
   for (i = 0; i < n4; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt8(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt8(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_32x8[tx_type].cols(tmp[i], out[i]);
@@ -1193,12 +1187,10 @@
   int outstride = 8;
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[8];
-  const tran_high_t *lgtmtx_row[8];
-  int use_lgt_col =
-      get_inv_lgt8(IHT_8[tx_type].cols, txfm_param, lgtmtx_col, 8);
-  int use_lgt_row =
-      get_inv_lgt8(IHT_8[tx_type].rows, txfm_param, lgtmtx_row, 8);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // inverse transform row vectors
@@ -1210,7 +1202,7 @@
 #else
 #if CONFIG_LGT
     if (use_lgt_row)
-      ilgt8(input, out[i], lgtmtx_row[i]);
+      ilgt8(input, out[i], lgtmtx_row[0]);
     else
 #endif
       IHT_8[tx_type].rows(input, out[i]);
@@ -1229,7 +1221,7 @@
   for (i = 0; i < 8; ++i) {
 #if CONFIG_LGT
     if (use_lgt_col)
-      ilgt8(tmp[i], out[i], lgtmtx_col[i]);
+      ilgt8(tmp[i], out[i], lgtmtx_col[0]);
     else
 #endif
       IHT_8[tx_type].cols(tmp[i], out[i]);
@@ -2294,9 +2286,6 @@
                                         av1_highbd_inv_txfm_add };
 #endif
 
-// TODO(kslu) Change input arguments to TxfmParam, which contains mode,
-// tx_type, tx_size, dst, stride, eob. Thus, the additional argument when LGT
-// is on will no longer be needed.
 void av1_inverse_transform_block(const MACROBLOCKD *xd,
                                  const tran_low_t *dqcoeff,
 #if CONFIG_LGT
@@ -2321,13 +2310,13 @@
   TxfmParam txfm_param;
   init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
 #if CONFIG_LGT || CONFIG_MRC_TX
+  txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
   txfm_param.dst = dst;
   txfm_param.stride = stride;
-  txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
-#endif  // CONFIG_LGT || CONFIG_MRC_TX
 #if CONFIG_LGT
   txfm_param.mode = mode;
-#endif
+#endif  // CONFIG_LGT
+#endif  // CONFIG_LGT || CONFIG_MRC_TX
 
   const int is_hbd = get_bitdepth_data_path_index(xd);
 #if CONFIG_TXMG
@@ -2369,14 +2358,11 @@
   const int dst_stride = pd->dst.stride;
   uint8_t *dst =
       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
+  av1_inverse_transform_block(xd, dqcoeff,
 #if CONFIG_LGT
-  PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
-  av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, dst,
-                              dst_stride, eob);
-#else
-  av1_inverse_transform_block(xd, dqcoeff, tx_type, tx_size, dst, dst_stride,
-                              eob);
+                              xd->mi[0]->mbmi.mode,
 #endif  // CONFIG_LGT
+                              tx_type, tx_size, dst, dst_stride, eob);
 }
 
 void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
diff --git a/av1/common/idct.h b/av1/common/idct.h
index db3238c..f8a9e91 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -26,13 +26,19 @@
 extern "C" {
 #endif
 
-// TODO(kslu) move the common stuff in idct.h to av1_txfm.h or txfm_common.h
 typedef void (*transform_1d)(const tran_low_t *, tran_low_t *);
 
 typedef struct {
   transform_1d cols, rows;  // vertical and horizontal
 } transform_2d;
 
+#if CONFIG_LGT
+int get_lgt4(const TxfmParam *txfm_param, int is_col,
+             const tran_high_t **lgtmtx);
+int get_lgt8(const TxfmParam *txfm_param, int is_col,
+             const tran_high_t **lgtmtx);
+#endif  // CONFIG_LGT
+
 #if CONFIG_HIGHBITDEPTH
 typedef void (*highbd_transform_1d)(const tran_low_t *, tran_low_t *, int bd);
 
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 6216aa2..ef2bdb2 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -504,13 +504,9 @@
     if (eob) {
       uint8_t *dst =
           &pd->dst.buf[(row * pd->dst.stride + col) << tx_size_wide_log2[0]];
-#if CONFIG_LGT
-      const PREDICTION_MODE mode =
-          get_prediction_mode(xd->mi[0], plane, tx_size, block_idx);
-#endif  // CONFIG_LGT
       inverse_transform_block(xd, plane,
 #if CONFIG_LGT
-                              mode,
+                              mbmi->mode,
 #endif
                               tx_type, tx_size, dst, pd->dst.stride,
                               max_scan_line, eob);
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index cda6aba..3b47fe9 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1179,10 +1179,7 @@
 #if CONFIG_LGT
 static void flgt4(const tran_low_t *input, tran_low_t *output,
                   const tran_high_t *lgtmtx) {
-  if (!(input[0] | input[1] | input[2] | input[3])) {
-    output[0] = output[1] = output[2] = output[3] = 0;
-    return;
-  }
+  if (!lgtmtx) assert(0);
 
   // evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,4
   tran_high_t s[4] = { 0 };
@@ -1194,6 +1191,8 @@
 
 static void flgt8(const tran_low_t *input, tran_low_t *output,
                   const tran_high_t *lgtmtx) {
+  if (!lgtmtx) assert(0);
+
   // evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,8
   tran_high_t s[8] = { 0 };
   for (int i = 0; i < 8; ++i)
@@ -1201,29 +1200,6 @@
 
   for (int i = 0; i < 8; ++i) output[i] = (tran_low_t)fdct_round_shift(s[i]);
 }
-
-// The get_fwd_lgt functions return 1 if LGT is chosen to apply, and 0 otherwise
-int get_fwd_lgt4(transform_1d tx_orig, TxfmParam *txfm_param,
-                 const tran_high_t *lgtmtx[], int ntx) {
-  // inter/intra split
-  if (tx_orig == &fadst4) {
-    for (int i = 0; i < ntx; ++i)
-      lgtmtx[i] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
-    return 1;
-  }
-  return 0;
-}
-
-int get_fwd_lgt8(transform_1d tx_orig, TxfmParam *txfm_param,
-                 const tran_high_t *lgtmtx[], int ntx) {
-  // inter/intra split
-  if (tx_orig == &fadst8) {
-    for (int i = 0; i < ntx; ++i)
-      lgtmtx[i] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
-    return 1;
-  }
-  return 0;
-}
 #endif  // CONFIG_LGT
 
 #if CONFIG_EXT_TX
@@ -1422,10 +1398,10 @@
 #if CONFIG_LGT
     // Choose LGT adaptive to the prediction. We may apply different LGTs for
     // different rows/columns, indicated by the pointers to 2D arrays
-    const tran_high_t *lgtmtx_col[4];
-    const tran_high_t *lgtmtx_row[4];
-    int use_lgt_col = get_fwd_lgt4(ht.cols, txfm_param, lgtmtx_col, 4);
-    int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 4);
+    const tran_high_t *lgtmtx_col[1];
+    const tran_high_t *lgtmtx_row[1];
+    int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
+    int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
     // Columns
@@ -1437,7 +1413,7 @@
 #endif
 #if CONFIG_LGT
       if (use_lgt_col)
-        flgt4(temp_in, temp_out, lgtmtx_col[i]);
+        flgt4(temp_in, temp_out, lgtmtx_col[0]);
       else
 #endif
         ht.cols(temp_in, temp_out);
@@ -1449,7 +1425,7 @@
       for (j = 0; j < 4; ++j) temp_in[j] = out[j + i * 4];
 #if CONFIG_LGT
       if (use_lgt_row)
-        flgt4(temp_in, temp_out, lgtmtx_row[i]);
+        flgt4(temp_in, temp_out, lgtmtx_row[0]);
       else
 #endif
         ht.rows(temp_in, temp_out);
@@ -1505,10 +1481,10 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[4];
-  const tran_high_t *lgtmtx_row[8];
-  int use_lgt_col = get_fwd_lgt8(ht.cols, txfm_param, lgtmtx_col, 4);
-  int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 8);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
   // Rows
@@ -1518,7 +1494,7 @@
           (tran_low_t)fdct_round_shift(input[i * stride + j] * 4 * Sqrt2);
 #if CONFIG_LGT
     if (use_lgt_row)
-      flgt4(temp_in, temp_out, lgtmtx_row[i]);
+      flgt4(temp_in, temp_out, lgtmtx_row[0]);
     else
 #endif
       ht.rows(temp_in, temp_out);
@@ -1530,7 +1506,7 @@
     for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
 #if CONFIG_LGT
     if (use_lgt_col)
-      flgt8(temp_in, temp_out, lgtmtx_col[i]);
+      flgt8(temp_in, temp_out, lgtmtx_col[0]);
     else
 #endif
       ht.cols(temp_in, temp_out);
@@ -1581,10 +1557,10 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[8];
-  const tran_high_t *lgtmtx_row[4];
-  int use_lgt_col = get_fwd_lgt4(ht.cols, txfm_param, lgtmtx_col, 8);
-  int use_lgt_row = get_fwd_lgt8(ht.rows, txfm_param, lgtmtx_row, 4);
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // Columns
@@ -1594,7 +1570,7 @@
           (tran_low_t)fdct_round_shift(input[j * stride + i] * 4 * Sqrt2);
 #if CONFIG_LGT
     if (use_lgt_col)
-      flgt4(temp_in, temp_out, lgtmtx_col[i]);
+      flgt4(temp_in, temp_out, lgtmtx_col[0]);
     else
 #endif
       ht.cols(temp_in, temp_out);
@@ -1606,7 +1582,7 @@
     for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
 #if CONFIG_LGT
     if (use_lgt_row)
-      flgt8(temp_in, temp_out, lgtmtx_row[i]);
+      flgt8(temp_in, temp_out, lgtmtx_row[0]);
     else
 #endif
       ht.rows(temp_in, temp_out);
@@ -1657,8 +1633,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[16];
-  int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 16);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
 #endif
 
   // Rows
@@ -1666,7 +1642,7 @@
     for (j = 0; j < n; ++j) temp_in[j] = input[i * stride + j] * 4;
 #if CONFIG_LGT
     if (use_lgt_row)
-      flgt4(temp_in, temp_out, lgtmtx_row[i]);
+      flgt4(temp_in, temp_out, lgtmtx_row[0]);
     else
 #endif
       ht.rows(temp_in, temp_out);
@@ -1724,8 +1700,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[16];
-  int use_lgt_col = get_fwd_lgt4(ht.cols, txfm_param, lgtmtx_col, 16);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
 #endif
 
   // Columns
@@ -1733,7 +1709,7 @@
     for (j = 0; j < n; ++j) temp_in[j] = input[j * stride + i] * 4;
 #if CONFIG_LGT
     if (use_lgt_col)
-      flgt4(temp_in, temp_out, lgtmtx_col[i]);
+      flgt4(temp_in, temp_out, lgtmtx_col[0]);
     else
 #endif
       ht.cols(temp_in, temp_out);
@@ -1791,8 +1767,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[16];
-  int use_lgt_row = get_fwd_lgt8(ht.rows, txfm_param, lgtmtx_row, 16);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // Rows
@@ -1802,7 +1778,7 @@
           (tran_low_t)fdct_round_shift(input[i * stride + j] * 4 * Sqrt2);
 #if CONFIG_LGT
     if (use_lgt_row)
-      flgt8(temp_in, temp_out, lgtmtx_row[i]);
+      flgt8(temp_in, temp_out, lgtmtx_row[0]);
     else
 #endif
       ht.rows(temp_in, temp_out);
@@ -1860,8 +1836,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[16];
-  int use_lgt_col = get_fwd_lgt8(ht.cols, txfm_param, lgtmtx_col, 16);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
 #endif
 
   // Columns
@@ -1871,7 +1847,7 @@
           (tran_low_t)fdct_round_shift(input[j * stride + i] * 4 * Sqrt2);
 #if CONFIG_LGT
     if (use_lgt_col)
-      flgt8(temp_in, temp_out, lgtmtx_col[i]);
+      flgt8(temp_in, temp_out, lgtmtx_col[0]);
     else
 #endif
       ht.cols(temp_in, temp_out);
@@ -1929,8 +1905,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_row[32];
-  int use_lgt_row = get_fwd_lgt8(ht.rows, txfm_param, lgtmtx_row, 32);
+  const tran_high_t *lgtmtx_row[1];
+  int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
   // Rows
@@ -1938,7 +1914,7 @@
     for (j = 0; j < n; ++j) temp_in[j] = input[i * stride + j] * 4;
 #if CONFIG_LGT
     if (use_lgt_row)
-      flgt8(temp_in, temp_out, lgtmtx_row[i]);
+      flgt8(temp_in, temp_out, lgtmtx_row[0]);
     else
 #endif
       ht.rows(temp_in, temp_out);
@@ -1996,8 +1972,8 @@
 #endif
 
 #if CONFIG_LGT
-  const tran_high_t *lgtmtx_col[32];
-  int use_lgt_col = get_fwd_lgt8(ht.cols, txfm_param, lgtmtx_col, 32);
+  const tran_high_t *lgtmtx_col[1];
+  int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
 #endif
 
   // Columns
@@ -2005,7 +1981,7 @@
     for (j = 0; j < n; ++j) temp_in[j] = input[j * stride + i] * 4;
 #if CONFIG_LGT
     if (use_lgt_col)
-      flgt8(temp_in, temp_out, lgtmtx_col[i]);
+      flgt8(temp_in, temp_out, lgtmtx_col[0]);
     else
 #endif
       ht.cols(temp_in, temp_out);
@@ -2300,10 +2276,10 @@
 #endif
 
 #if CONFIG_LGT
-    const tran_high_t *lgtmtx_col[8];
-    const tran_high_t *lgtmtx_row[8];
-    int use_lgt_col = get_fwd_lgt8(ht.cols, txfm_param, lgtmtx_col, 8);
-    int use_lgt_row = get_fwd_lgt8(ht.rows, txfm_param, lgtmtx_row, 8);
+    const tran_high_t *lgtmtx_col[1];
+    const tran_high_t *lgtmtx_row[1];
+    int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
+    int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
 #endif
 
     // Columns
@@ -2315,7 +2291,7 @@
 #endif
 #if CONFIG_LGT
       if (use_lgt_col)
-        flgt8(temp_in, temp_out, lgtmtx_col[i]);
+        flgt8(temp_in, temp_out, lgtmtx_col[0]);
       else
 #endif
         ht.cols(temp_in, temp_out);
@@ -2327,7 +2303,7 @@
       for (j = 0; j < 8; ++j) temp_in[j] = out[j + i * 8];
 #if CONFIG_LGT
       if (use_lgt_row)
-        flgt8(temp_in, temp_out, lgtmtx_row[i]);
+        flgt8(temp_in, temp_out, lgtmtx_row[0]);
       else
 #endif
         ht.rows(temp_in, temp_out);
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 13ac4d4..9ecf210 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -615,8 +615,8 @@
 #endif  // CONFIG_MRC_TX
 #endif  // CONFIG_MRC_TX || CONFIG_LGT
 #if CONFIG_LGT
-  txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
-#endif
+  txfm_param.mode = mbmi->mode;
+#endif  // CONFIG_LGT
 
 #if !CONFIG_PVQ
   txfm_param.bd = xd->bd;
@@ -746,7 +746,7 @@
   TX_TYPE tx_type =
       av1_get_tx_type(pd->plane_type, xd, blk_row, blk_col, block, tx_size);
 #if CONFIG_LGT
-  PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
+  PREDICTION_MODE mode = xd->mi[0]->mbmi.mode;
   av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, dst,
                               pd->dst.stride, p->eobs[block]);
 #else
@@ -1063,11 +1063,6 @@
 
   av1_predict_intra_block_facade(xd, plane, block, blk_col, blk_row, tx_size);
 
-#if CONFIG_LGT
-  const PREDICTION_MODE mode =
-      get_prediction_mode(xd->mi[0], plane, tx_size, block);
-#endif  // CONFIG_LGT
-
   av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
 
   const ENTROPY_CONTEXT *a = &args->ta[blk_col];
@@ -1091,7 +1086,7 @@
 #endif  // CONFIG_PVQ
   av1_inverse_transform_block(xd, dqcoeff,
 #if CONFIG_LGT
-                              mode,
+                              xd->mi[0]->mbmi.mode,
 #endif
                               tx_type, tx_size, dst, dst_stride, *eob);
 #if !CONFIG_PVQ
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 54465b8..27171e8 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4560,14 +4560,11 @@
 
   const int eob = p->eobs[block];
 
+  av1_inverse_transform_block(xd, dqcoeff,
 #if CONFIG_LGT
-  PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
-  av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, rec_buffer,
-                              MAX_TX_SIZE, eob);
-#else
-  av1_inverse_transform_block(xd, dqcoeff, tx_type, tx_size, rec_buffer,
-                              MAX_TX_SIZE, eob);
+                              xd->mi[0]->mbmi.mode,
 #endif
+                              tx_type, tx_size, rec_buffer, MAX_TX_SIZE, eob);
   if (eob > 0) {
 #if CONFIG_DIST_8X8
     if (x->using_dist_8x8 && plane == 0 && (bw < 8 && bh < 8)) {