diff --git a/av1/common/enums.h b/av1/common/enums.h
index 8b3d00d..0c09a1b 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -291,8 +291,6 @@
   EXT_TX_SET_TYPES
 } UENUM1BYTE(TxSetType);
 
-#define IS_2D_TRANSFORM(tx_type) (tx_type < IDTX)
-
 #define EXT_TX_SIZES 4       // number of sizes that use extended transforms
 #define EXT_TX_SETS_INTER 4  // Sets of transform selections for INTER
 #define EXT_TX_SETS_INTRA 3  // Sets of transform selections for INTRA
diff --git a/av1/common/quant_common.c b/av1/common/quant_common.c
index 804eb6a..9b967fe 100644
--- a/av1/common/quant_common.c
+++ b/av1/common/quant_common.c
@@ -244,6 +244,36 @@
   return quant_params->gqmatrix[qmlevel][plane][tx_size];
 }
 
+// Returns true if the tx_type corresponds to non-identity transform in both
+// horizontal and vertical directions.
+static INLINE bool is_2d_transform(TX_TYPE tx_type) { return (tx_type < IDTX); }
+
+const qm_val_t *av1_get_iqmatrix(const CommonQuantParams *const quant_params,
+                                 const MACROBLOCKD *const xd, int plane,
+                                 TX_SIZE tx_size, TX_TYPE tx_type) {
+  const struct macroblockd_plane *const pd = &xd->plane[plane];
+  const MB_MODE_INFO *const mbmi = xd->mi[0];
+  const int seg_id = mbmi->segment_id;
+  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
+  // Use a flat matrix (i.e. no weighting) for 1D and Identity transforms
+  return is_2d_transform(tx_type)
+             ? pd->seg_iqmatrix[seg_id][qm_tx_size]
+             : quant_params->giqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
+}
+
+const qm_val_t *av1_get_qmatrix(const CommonQuantParams *const quant_params,
+                                const MACROBLOCKD *const xd, int plane,
+                                TX_SIZE tx_size, TX_TYPE tx_type) {
+  const struct macroblockd_plane *const pd = &xd->plane[plane];
+  const MB_MODE_INFO *const mbmi = xd->mi[0];
+  const int seg_id = mbmi->segment_id;
+  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
+  // Use a flat matrix (i.e. no weighting) for 1D and Identity transforms
+  return is_2d_transform(tx_type)
+             ? pd->seg_qmatrix[seg_id][qm_tx_size]
+             : quant_params->gqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
+}
+
 #define QM_TOTAL_SIZE 3344
 // We only use wt_matrix_ref[q] and iwt_matrix_ref[q]
 // for q = 0, ..., NUM_QM_LEVELS - 2.
diff --git a/av1/common/quant_common.h b/av1/common/quant_common.h
index 8dc0419..fb27326 100644
--- a/av1/common/quant_common.h
+++ b/av1/common/quant_common.h
@@ -60,11 +60,22 @@
 // Initialize all global quant/dequant matrices.
 void av1_qm_init(struct CommonQuantParams *quant_params, int num_planes);
 
+// Get global dequant matrix.
 const qm_val_t *av1_iqmatrix(const struct CommonQuantParams *quant_params,
                              int qmlevel, int plane, TX_SIZE tx_size);
+// Get global quant matrix.
 const qm_val_t *av1_qmatrix(const struct CommonQuantParams *quant_params,
                             int qmlevel, int plane, TX_SIZE tx_size);
 
+// Get either local / global dequant matrix as appropriate.
+const qm_val_t *av1_get_iqmatrix(const struct CommonQuantParams *quant_params,
+                                 const struct macroblockd *xd, int plane,
+                                 TX_SIZE tx_size, TX_TYPE tx_type);
+// Get either local / global quant matrix as appropriate.
+const qm_val_t *av1_get_qmatrix(const struct CommonQuantParams *quant_params,
+                                const struct macroblockd *xd, int plane,
+                                TX_SIZE tx_size, TX_TYPE tx_type);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index aa08022..24c6031 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -161,11 +161,8 @@
       av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                       cm->features.reduced_tx_set_used);
   const TX_CLASS tx_class = tx_type_to_class[tx_type];
-  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
   const qm_val_t *iqmatrix =
-      IS_2D_TRANSFORM(tx_type)
-          ? pd->seg_iqmatrix[mbmi->segment_id][qm_tx_size]
-          : cm->quant_params.giqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
+      av1_get_iqmatrix(&cm->quant_params, xd, plane, tx_size, tx_type);
   const SCAN_ORDER *const scan_order = get_scan(tx_size, tx_type);
   const int16_t *const scan = scan_order->scan;
   int eob_extra = 0;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 95b4e03..e3915bf 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -339,24 +339,11 @@
   qparam->iqmatrix = NULL;
 }
 void av1_setup_qmatrix(const CommonQuantParams *quant_params,
-                       const MACROBLOCK *x, int plane, TX_SIZE tx_size,
+                       const MACROBLOCKD *const xd, int plane, TX_SIZE tx_size,
                        TX_TYPE tx_type, QUANT_PARAM *qparam) {
-  const MACROBLOCKD *const xd = &x->e_mbd;
-  const struct macroblockd_plane *const pd = &xd->plane[plane];
-  const MB_MODE_INFO *const mbmi = xd->mi[0];
-  const int seg_id = mbmi->segment_id;
-  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
-  // Use a flat matrix (i.e. no weighting) for 1D and Identity transforms
-  const qm_val_t *qmatrix =
-      IS_2D_TRANSFORM(tx_type)
-          ? pd->seg_qmatrix[seg_id][qm_tx_size]
-          : quant_params->gqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
-  const qm_val_t *iqmatrix =
-      IS_2D_TRANSFORM(tx_type)
-          ? pd->seg_iqmatrix[seg_id][qm_tx_size]
-          : quant_params->giqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
-  qparam->qmatrix = qmatrix;
-  qparam->iqmatrix = iqmatrix;
+  qparam->qmatrix = av1_get_qmatrix(quant_params, xd, plane, tx_size, tx_type);
+  qparam->iqmatrix =
+      av1_get_iqmatrix(quant_params, xd, plane, tx_size, tx_type);
 }
 
 static void encode_block(int plane, int block, int blk_row, int blk_col,
@@ -398,7 +385,7 @@
     av1_setup_xform(cm, x, tx_size, tx_type, &txfm_param);
     av1_setup_quant(tx_size, use_trellis, quant_idx, cpi->use_quant_b_adapt,
                     &quant_param);
-    av1_setup_qmatrix(&cm->quant_params, x, plane, tx_size, tx_type,
+    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                       &quant_param);
     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                     &quant_param);
@@ -589,7 +576,7 @@
   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
   av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, cpi->use_quant_b_adapt,
                   &quant_param);
-  av1_setup_qmatrix(&cm->quant_params, x, plane, tx_size, DCT_DCT,
+  av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, DCT_DCT,
                     &quant_param);
 
   av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
@@ -731,7 +718,7 @@
     av1_setup_xform(cm, x, tx_size, tx_type, &txfm_param);
     av1_setup_quant(tx_size, use_trellis, quant_idx, cpi->use_quant_b_adapt,
                     &quant_param);
-    av1_setup_qmatrix(&cm->quant_params, x, plane, tx_size, tx_type,
+    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                       &quant_param);
 
     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
diff --git a/av1/encoder/encodemb.h b/av1/encoder/encodemb.h
index e05f4ce..a337c83 100644
--- a/av1/encoder/encodemb.h
+++ b/av1/encoder/encodemb.h
@@ -69,7 +69,7 @@
 void av1_setup_quant(TX_SIZE tx_size, int use_optimize_b, int xform_quant_idx,
                      int use_quant_b_adapt, QUANT_PARAM *qparam);
 void av1_setup_qmatrix(const CommonQuantParams *quant_params,
-                       const MACROBLOCK *x, int plane, TX_SIZE tx_size,
+                       const MACROBLOCKD *xd, int plane, TX_SIZE tx_size,
                        TX_TYPE tx_type, QUANT_PARAM *qparam);
 
 void av1_xform_quant(MACROBLOCK *x, int plane, int block, int blk_row,
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 697192e..046d66e 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -1750,12 +1750,8 @@
   const int shift = av1_get_tx_scale(tx_size);
   int eob = p->eobs[block];
   const int16_t *dequant = p->dequant_QTX;
-  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
   const qm_val_t *iqmatrix =
-      IS_2D_TRANSFORM(tx_type)
-          ? pd->seg_iqmatrix[xd->mi[0]->segment_id][qm_tx_size]
-          : cpi->common.quant_params
-                .giqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
+      av1_get_iqmatrix(&cpi->common.quant_params, xd, plane, tx_size, tx_type);
   const int block_offset = BLOCK_OFFSET(block);
   tran_low_t *qcoeff = p->qcoeff + block_offset;
   tran_low_t *dqcoeff = pd->dqcoeff + block_offset;
@@ -1948,11 +1944,8 @@
       2;
   uint8_t levels_buf[TX_PAD_2D];
   uint8_t *const levels = set_levels(levels_buf, width);
-  const TX_SIZE qm_tx_size = av1_get_adjusted_tx_size(tx_size);
   const qm_val_t *iqmatrix =
-      IS_2D_TRANSFORM(tx_type)
-          ? pd->seg_iqmatrix[mbmi->segment_id][qm_tx_size]
-          : cm->quant_params.giqmatrix[NUM_QM_LEVELS - 1][0][qm_tx_size];
+      av1_get_iqmatrix(&cpi->common.quant_params, xd, plane, tx_size, tx_type);
   assert(width == (1 << bwl));
   const int tx_type_cost =
       get_tx_type_cost(x, xd, plane, tx_size, tx_type, reduced_tx_set_used);
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 48cc47b..3129b3a 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -1102,7 +1102,7 @@
                                                     : AV1_XFORM_QUANT_FP)
                           : AV1_XFORM_QUANT_FP,
                       cpi->use_quant_b_adapt, &quant_param_intra);
-      av1_setup_qmatrix(&cm->quant_params, x, plane, tx_size, best_tx_type,
+      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                         &quant_param_intra);
       av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                       &txfm_param_intra, &quant_param_intra);
@@ -2208,7 +2208,7 @@
     if (!(allowed_tx_mask & (1 << tx_type))) continue;
     txfm_param.tx_type = tx_type;
     if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
-      av1_setup_qmatrix(&cm->quant_params, x, plane, tx_size, tx_type,
+      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                         &quant_param);
     }
     if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
