lgt-from-pred: transforms based on prediction
In this experiment, sharp image discontinuities in the predicted
block are detected. Based on these discontinuities, particular LGTs
(line graph transforms) are chosen as the row and column transforms.
Bitstream syntax, entropy coding, and RD search for LGT are added.
One binary symbol signals whether LGT is used. This experiment can
work independently of (or together with) the lgt experiment.
lowres: -0.414% for key frames, -0.151% overall
midres: -0.413% for key frames, -0.161% overall
Change-Id: Iaa2f2c2839c34ca4134fa55e77870dc3f1fa879f
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index f5bf46b..bfe040b 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -456,6 +456,10 @@
add_proto qw/void av1_quantize_b/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan, int log_scale";
}
+ if (aom_config("CONFIG_LGT_FROM_PRED") eq "yes") {
+ add_proto qw/void flgt2d_from_pred/, "const int16_t *input, tran_low_t *output, int stride, struct txfm_param *param";
+ }
+
if (aom_config("CONFIG_HIGHBITDEPTH") eq "yes") {
# ENCODEMB INVOKE
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index ffd3680..30ad337 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -385,6 +385,9 @@
#if CONFIG_TXK_SEL
TX_TYPE txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
#endif
+#if CONFIG_LGT_FROM_PRED
+ int use_lgt;
+#endif
#if CONFIG_FILTER_INTRA
FILTER_INTRA_MODE_INFO filter_intra_mode_info;
@@ -1053,6 +1056,36 @@
return av1_num_ext_tx_set[set_type];
}
+#if CONFIG_LGT_FROM_PRED
+static INLINE int is_lgt_allowed(PREDICTION_MODE mode, TX_SIZE tx_size) {
+ if (!LGT_FROM_PRED_INTRA && !is_inter_mode(mode)) return 0;
+ if (!LGT_FROM_PRED_INTER && is_inter_mode(mode)) return 0;
+
+ switch (mode) {
+ case D45_PRED:
+ case D63_PRED:
+ case D117_PRED:
+ case V_PRED:
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_V_PRED:
+#endif
+ return tx_size_wide[tx_size] <= 8;
+ case D135_PRED:
+ case D153_PRED:
+ case D207_PRED:
+ case H_PRED:
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_H_PRED:
+#endif
+ return tx_size_high[tx_size] <= 8;
+ case DC_PRED:
+ case SMOOTH_PRED: return 0;
+ case TM_PRED:
+ default: return tx_size_wide[tx_size] <= 8 || tx_size_high[tx_size] <= 8;
+ }
+}
+#endif // CONFIG_LGT_FROM_PRED
+
#if CONFIG_RECT_TX
static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
static const char LUT[BLOCK_SIZES_ALL] = {
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 600d693..207f1e2 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -2653,6 +2653,23 @@
};
#endif
+#if CONFIG_LGT_FROM_PRED
+static const aom_prob default_intra_lgt_prob[LGT_SIZES][INTRA_MODES] = {
+ { 255, 208, 208, 180, 230, 208, 194, 214, 220, 255,
+#if CONFIG_SMOOTH_HV
+ 220, 220,
+#endif
+ 230 },
+ { 255, 192, 216, 180, 180, 180, 180, 200, 200, 255,
+#if CONFIG_SMOOTH_HV
+ 220, 220,
+#endif
+ 222 },
+};
+
+static const aom_prob default_inter_lgt_prob[LGT_SIZES] = { 230, 230 };
+#endif // CONFIG_LGT_FROM_PRED
+
#if CONFIG_EXT_INTRA && CONFIG_INTRA_INTERP
static const aom_prob
default_intra_filter_probs[INTRA_FILTERS + 1][INTRA_FILTERS - 1] = {
@@ -5798,6 +5815,10 @@
#if CONFIG_FILTER_INTRA
av1_copy(fc->filter_intra_probs, default_filter_intra_probs);
#endif // CONFIG_FILTER_INTRA
+#if CONFIG_LGT_FROM_PRED
+ av1_copy(fc->intra_lgt_prob, default_intra_lgt_prob);
+ av1_copy(fc->inter_lgt_prob, default_inter_lgt_prob);
+#endif // CONFIG_LGT_FROM_PRED
#if CONFIG_LOOP_RESTORATION
av1_copy(fc->switchable_restore_prob, default_switchable_restore_prob);
#endif // CONFIG_LOOP_RESTORATION
@@ -6005,6 +6026,23 @@
fc->skip_probs[i] =
av1_mode_mv_merge_probs(pre_fc->skip_probs[i], counts->skip[i]);
+#if CONFIG_LGT_FROM_PRED
+ int j;
+ if (LGT_FROM_PRED_INTRA) {
+ for (i = TX_4X4; i < LGT_SIZES; ++i) {
+ for (j = 0; j < INTRA_MODES; ++j)
+ fc->intra_lgt_prob[i][j] = av1_mode_mv_merge_probs(
+ pre_fc->intra_lgt_prob[i][j], counts->intra_lgt[i][j]);
+ }
+ }
+ if (LGT_FROM_PRED_INTER) {
+ for (i = TX_4X4; i < LGT_SIZES; ++i) {
+ fc->inter_lgt_prob[i] = av1_mode_mv_merge_probs(pre_fc->inter_lgt_prob[i],
+ counts->inter_lgt[i]);
+ }
+ }
+#endif // CONFIG_LGT_FROM_PRED
+
if (cm->seg.temporal_update) {
for (i = 0; i < PREDICTION_PROBS; i++)
fc->seg.pred_probs[i] = av1_mode_mv_merge_probs(pre_fc->seg.pred_probs[i],
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index e5c8f9d..3452241 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -386,6 +386,10 @@
aom_cdf_prob intra_ext_tx_cdf[EXT_TX_SIZES][TX_TYPES][CDF_SIZE(TX_TYPES)];
aom_cdf_prob inter_ext_tx_cdf[EXT_TX_SIZES][CDF_SIZE(TX_TYPES)];
#endif // CONFIG_EXT_TX
+#if CONFIG_LGT_FROM_PRED
+ aom_prob intra_lgt_prob[LGT_SIZES][INTRA_MODES];
+ aom_prob inter_lgt_prob[LGT_SIZES];
+#endif // CONFIG_LGT_FROM_PRED
#if CONFIG_EXT_INTRA && CONFIG_INTRA_INTERP
aom_cdf_prob intra_filter_cdf[INTRA_FILTERS + 1][CDF_SIZE(INTRA_FILTERS)];
#endif // CONFIG_EXT_INTRA && CONFIG_INTRA_INTERP
@@ -528,6 +532,10 @@
unsigned int intrabc[2];
nmv_context_counts dv;
#endif
+#if CONFIG_LGT_FROM_PRED
+ unsigned int intra_lgt[LGT_SIZES][INTRA_MODES][2];
+ unsigned int inter_lgt[LGT_SIZES][2];
+#endif // CONFIG_LGT_FROM_PRED
unsigned int delta_q[DELTA_Q_PROBS][2];
#if CONFIG_EXT_DELTA_Q
#if CONFIG_LOOPFILTER_LEVEL
diff --git a/av1/common/enums.h b/av1/common/enums.h
index b60e0d8..e8c4003 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -771,6 +771,15 @@
} OBU_TYPE;
#endif
+#if CONFIG_LGT_FROM_PRED
+#define LGT_SIZES 2
+// Note: at least one of LGT_FROM_PRED_INTRA and LGT_FROM_PRED_INTER must be 1
+#define LGT_FROM_PRED_INTRA 1
+#define LGT_FROM_PRED_INTER 1
+// LGT_SL_INTRA: LGTs with a mode-dependent first self-loop and a break point
+#define LGT_SL_INTRA 0
+#endif // CONFIG_LGT_FROM_PRED
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 56019cc..53c2ba1 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -205,10 +205,21 @@
#endif // CONFIG_EXT_TX && CONFIG_TX64X64
#endif // CONFIG_HIGHBITDEPTH
-#if CONFIG_LGT
+#if CONFIG_LGT || CONFIG_LGT_FROM_PRED
void ilgt4(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
if (!lgtmtx) assert(0);
+#if CONFIG_LGT_FROM_PRED
+ // For DCT/ADST, use butterfly implementations
+ if (lgtmtx[0] == DCT4) {
+ aom_idct4_c(input, output);
+ return;
+ } else if (lgtmtx[0] == ADST4) {
+ aom_iadst4_c(input, output);
+ return;
+ }
+#endif // CONFIG_LGT_FROM_PRED
+
// evaluate s[j] = sum of all lgtmtx[j]*input[i] over i=1,...,4
tran_high_t s[4] = { 0 };
for (int i = 0; i < 4; ++i)
@@ -220,6 +231,17 @@
void ilgt8(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
if (!lgtmtx) assert(0);
+#if CONFIG_LGT_FROM_PRED
+ // For DCT/ADST, use butterfly implementations
+ if (lgtmtx[0] == DCT8) {
+ aom_idct8_c(input, output);
+ return;
+ } else if (lgtmtx[0] == ADST8) {
+ aom_iadst8_c(input, output);
+ return;
+ }
+#endif // CONFIG_LGT_FROM_PRED
+
// evaluate s[j] = sum of all lgtmtx[j]*input[i] over i=1,...,8
tran_high_t s[8] = { 0 };
for (int i = 0; i < 8; ++i)
@@ -227,7 +249,9 @@
for (int i = 0; i < 8; ++i) output[i] = WRAPLOW(dct_const_round_shift(s[i]));
}
+#endif // CONFIG_LGT || CONFIG_LGT_FROM_PRED
+#if CONFIG_LGT
// get_lgt4 and get_lgt8 return 1 and pick a lgt matrix if LGT is chosen to
// apply. Otherwise they return 0
int get_lgt4(const TxfmParam *txfm_param, int is_col,
@@ -261,6 +285,427 @@
}
#endif // CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
+void ilgt16up(const tran_low_t *input, tran_low_t *output,
+ const tran_high_t *lgtmtx) {
+ if (lgtmtx[0] == DCT16) {
+ aom_idct16_c(input, output);
+ return;
+ } else if (lgtmtx[0] == ADST16) {
+ aom_iadst16_c(input, output);
+ return;
+ } else if (lgtmtx[0] == DCT32) {
+ aom_idct32_c(input, output);
+ return;
+ } else if (lgtmtx[0] == ADST32) {
+ ihalfright32_c(input, output);
+ return;
+ } else {
+ assert(0);
+ }
+}
+
+// Sets *idx_max_diff to the index i in [1, n) that maximizes
+// |arr[i] - arr[i - 1]|, or to -1 if all neighbors are equal or n < 2.
+void get_discontinuity_1d(uint8_t *arr, int n, int *idx_max_diff) {
+  *idx_max_diff = -1;
+  int max_diff = 0;  // strict '>' below keeps the earliest index on ties
+  for (int i = 1; i < n; ++i) {
+    const int step = abs(arr[i] - arr[i - 1]);
+    if (step > max_diff) {
+      max_diff = step;
+      *idx_max_diff = i;
+    }
+  }
+}
+
+// Sets *idx_max_diff to the index i in [1, n) with the largest sum of
+// squared differences between lines i - 1 and i (rows if is_col, else
+// columns), each line ntx pixels long; -1 if every sum is zero or n < 2.
+void get_discontinuity_2d(uint8_t *dst, int stride, int n, int is_col,
+                          int *idx_max_diff, int ntx) {
+  *idx_max_diff = -1;
+  int max_diff = 0;
+  for (int i = 1; i < n; ++i) {
+    int sse = 0;  // sse / ntx is the mean squared step at position i
+    for (int j = 0; j < ntx; ++j) {
+      const int diff = is_col
+                           ? dst[i * stride + j] - dst[(i - 1) * stride + j]
+                           : dst[j * stride + i] - dst[j * stride + i - 1];
+      sse += diff * diff;
+    }
+    // strict '>' keeps the earliest index when sums tie
+    if (sse > max_diff) {
+      max_diff = sse;
+      *idx_max_diff = i;
+    }
+  }
+}
+
+int idx_selfloop_wrt_mode(PREDICTION_MODE mode, int is_col) {
+ // 0: no self-loop
+ // 1: small self-loop
+ // 2: medium self-loop
+ // 3: large self-loop
+ switch (mode) {
+ case DC_PRED:
+ case SMOOTH_PRED:
+ // prediction is good for both directions: large SLs for row and col
+ return 3;
+ case TM_PRED: return 0;
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_H_PRED:
+#endif
+ case H_PRED:
+ // prediction is good for H direction: large SL for row only
+ return is_col ? 0 : 3;
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_V_PRED:
+#endif
+ case V_PRED:
+ // prediction is good for V direction: large SL for col only
+ return is_col ? 3 : 0;
+#if LGT_SL_INTRA
+ // directional mode: choose SL based on the direction
+ case D45_PRED: return is_col ? 2 : 0;
+ case D63_PRED: return is_col ? 3 : 0;
+ case D117_PRED: return is_col ? 3 : 1;
+ case D135_PRED: return 2;
+ case D153_PRED: return is_col ? 1 : 3;
+ case D207_PRED: return is_col ? 0 : 3;
+#else
+ case D45_PRED:
+ case D63_PRED:
+ case D117_PRED: return is_col ? 3 : 0;
+ case D135_PRED:
+ case D153_PRED:
+ case D207_PRED: return is_col ? 0 : 3;
+#endif
+ // inter modes (and any mode not listed above): no SL
+ default: return 0;
+ }
+}
+
+void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx) {
+ PREDICTION_MODE mode = txfm_param->mode;
+ int stride = txfm_param->stride;
+ uint8_t *dst = txfm_param->dst;
+ int bp = -1;
+ uint8_t arr[4];
+
+ // Each lgt4mtx_arr[k][i] corresponds to a line graph with a self-loop on
+ // the first node, and possibly a weak edge within the line graph. i is
+ // the index of the weak edge (between the i-th and (i+1)-th pixels, i=0
+ // means no weak edge). k corresponds to the first self-loop's weight
+ const tran_high_t *lgt4mtx_arr[4][4] = {
+ { &lgt4_000[0][0], &lgt4_000w1[0][0], &lgt4_000w2[0][0],
+ &lgt4_000w3[0][0] },
+ { &lgt4_060[0][0], &lgt4_060_000w1[0][0], &lgt4_060_000w2[0][0],
+ &lgt4_060_000w3[0][0] },
+ { &lgt4_100[0][0], &lgt4_100_000w1[0][0], &lgt4_100_000w2[0][0],
+ &lgt4_100_000w3[0][0] },
+ { &lgt4_150[0][0], &lgt4_150_000w1[0][0], &lgt4_150_000w2[0][0],
+ &lgt4_150_000w3[0][0] },
+ };
+
+ // initialize to DCT or some LGTs, and then change later if necessary
+ int idx_sl = idx_selfloop_wrt_mode(mode, is_col);
+ lgtmtx[0] = lgt4mtx_arr[idx_sl][0];
+
+ // find the break point and replace the line graph by the one with a
+ // break point
+ if (mode == DC_PRED || mode == SMOOTH_PRED) {
+ // Do not use a break point, since 1) is_left_available and
+ // is_top_available in DC_PRED are not known by txfm_param for now, so
+ // accessing both boundaries anyway may cause a mismatch; 2) DC prediction
+ // typically yields very smooth residues, so having the break point
+ // does not usually improve the RD result.
+ return;
+ } else if (mode == TM_PRED) {
+ // TM_PRED: use both 1D top boundary and 1D left boundary
+ if (is_col)
+ for (int i = 0; i < 4; ++i) arr[i] = dst[i * stride];
+ else
+ for (int i = 0; i < 4; ++i) arr[i] = dst[i];
+ get_discontinuity_1d(&arr[0], 4, &bp);
+ } else if (mode == V_PRED) {
+ // V_PRED: use 1D top boundary only
+ if (is_col) return;
+ for (int i = 0; i < 4; ++i) arr[i] = dst[i];
+ get_discontinuity_1d(&arr[0], 4, &bp);
+ } else if (mode == H_PRED) {
+ // H_PRED: use 1D left boundary only
+ if (!is_col) return;
+ for (int i = 0; i < 4; ++i) arr[i] = dst[i * stride];
+ get_discontinuity_1d(&arr[0], 4, &bp);
+#if CONFIG_SMOOTH_HV
+ } else if (mode == SMOOTH_V_PRED) {
+ if (is_col) return;
+ for (int i = 0; i < 4; ++i) arr[i] = dst[-stride + i];
+ get_discontinuity_1d(&arr[0], 4, &bp);
+ } else if (mode == SMOOTH_H_PRED) {
+ if (!is_col) return;
+ for (int i = 0; i < 4; ++i) arr[i] = dst[i * stride - 1];
+ get_discontinuity_1d(&arr[0], 4, &bp);
+#endif
+ } else if (mode == D45_PRED || mode == D63_PRED || mode == D117_PRED) {
+ // directional modes closer to vertical (maybe include D135 later)
+ if (!is_col) get_discontinuity_2d(dst, stride, 4, 0, &bp, ntx);
+ } else if (mode == D135_PRED || mode == D153_PRED || mode == D207_PRED) {
+ // directional modes closer to horizontal
+ if (is_col) get_discontinuity_2d(dst, stride, 4, 1, &bp, ntx);
+ } else if (mode > TM_PRED) {
+ // inter modes (all modes ordered after TM_PRED)
+ get_discontinuity_2d(dst, stride, 4, is_col, &bp, ntx);
+ }
+
+#if LGT_SL_INTRA
+ if (bp != -1) lgtmtx[0] = lgt4mtx_arr[idx_sl][bp];
+#else
+ if (bp != -1) lgtmtx[0] = lgt4mtx_arr[0][bp];
+#endif
+}
+
+void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx) {
+ PREDICTION_MODE mode = txfm_param->mode;
+ int stride = txfm_param->stride;
+ uint8_t *dst = txfm_param->dst;
+ int bp = -1;
+ uint8_t arr[8];
+
+ const tran_high_t *lgt8mtx_arr[4][8] = {
+ { &lgt8_000[0][0], &lgt8_000w1[0][0], &lgt8_000w2[0][0], &lgt8_000w3[0][0],
+ &lgt8_000w4[0][0], &lgt8_000w5[0][0], &lgt8_000w6[0][0],
+ &lgt8_000w7[0][0] },
+ { &lgt8_060[0][0], &lgt8_060_000w1[0][0], &lgt8_060_000w2[0][0],
+ &lgt8_060_000w3[0][0], &lgt8_060_000w4[0][0], &lgt8_060_000w5[0][0],
+ &lgt8_060_000w6[0][0], &lgt8_060_000w7[0][0] },
+ { &lgt8_100[0][0], &lgt8_100_000w1[0][0], &lgt8_100_000w2[0][0],
+ &lgt8_100_000w3[0][0], &lgt8_100_000w4[0][0], &lgt8_100_000w5[0][0],
+ &lgt8_100_000w6[0][0], &lgt8_100_000w7[0][0] },
+ { &lgt8_150[0][0], &lgt8_150_000w1[0][0], &lgt8_150_000w2[0][0],
+ &lgt8_150_000w3[0][0], &lgt8_150_000w4[0][0], &lgt8_150_000w5[0][0],
+ &lgt8_150_000w6[0][0], &lgt8_150_000w7[0][0] },
+ };
+
+ int idx_sl = idx_selfloop_wrt_mode(mode, is_col);
+ lgtmtx[0] = lgt8mtx_arr[idx_sl][0];
+
+ if (mode == DC_PRED || mode == SMOOTH_PRED) {
+ return;
+ } else if (mode == TM_PRED) {
+ if (is_col)
+ for (int i = 0; i < 8; ++i) arr[i] = dst[i * stride];
+ else
+ for (int i = 0; i < 8; ++i) arr[i] = dst[i];
+ get_discontinuity_1d(&arr[0], 8, &bp);
+ } else if (mode == V_PRED) {
+ if (is_col) return;
+ for (int i = 0; i < 8; ++i) arr[i] = dst[i];
+ get_discontinuity_1d(&arr[0], 8, &bp);
+ } else if (mode == H_PRED) {
+ if (!is_col) return;
+ for (int i = 0; i < 8; ++i) arr[i] = dst[i * stride];
+ get_discontinuity_1d(&arr[0], 8, &bp);
+#if CONFIG_SMOOTH_HV
+ } else if (mode == SMOOTH_V_PRED) {
+ if (is_col) return;
+ for (int i = 0; i < 8; ++i) arr[i] = dst[-stride + i];
+ get_discontinuity_1d(&arr[0], 8, &bp);
+ } else if (mode == SMOOTH_H_PRED) {
+ if (!is_col) return;
+ for (int i = 0; i < 8; ++i) arr[i] = dst[i * stride - 1];
+ get_discontinuity_1d(&arr[0], 8, &bp);
+#endif
+ } else if (mode == D45_PRED || mode == D63_PRED || mode == D117_PRED) {
+ if (!is_col) get_discontinuity_2d(dst, stride, 8, 0, &bp, ntx);
+ } else if (mode == D135_PRED || mode == D153_PRED || mode == D207_PRED) {
+ if (is_col) get_discontinuity_2d(dst, stride, 8, 1, &bp, ntx);
+ } else if (mode > TM_PRED) {
+ get_discontinuity_2d(dst, stride, 8, is_col, &bp, ntx);
+ }
+
+#if LGT_SL_INTRA
+ if (bp != -1) lgtmtx[0] = lgt8mtx_arr[idx_sl][bp];
+#else
+ if (bp != -1) lgtmtx[0] = lgt8mtx_arr[0][bp];
+#endif
+}
+
+// LGTs longer than 8 are not implemented yet, so this function only picks
+// the DCT (lgt*_000) or ADST (lgt*_200) table for the given direction.
+void get_lgt16up_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx) {
+ int tx_length = is_col ? tx_size_high[txfm_param->tx_size]
+ : tx_size_wide[txfm_param->tx_size];
+ assert(tx_length == 16 || tx_length == 32);
+ PREDICTION_MODE mode = txfm_param->mode;
+
+ (void)ntx;
+ const tran_high_t *dctmtx =
+ tx_length == 16 ? &lgt16_000[0][0] : &lgt32_000[0][0];
+ const tran_high_t *adstmtx =
+ tx_length == 16 ? &lgt16_200[0][0] : &lgt32_200[0][0];
+
+ switch (mode) {
+ case DC_PRED:
+ case TM_PRED:
+ case SMOOTH_PRED:
+ // prediction from both top and left -> ADST for rows and columns
+ lgtmtx[0] = adstmtx;
+ break;
+ case V_PRED:
+ case D45_PRED:
+ case D63_PRED:
+ case D117_PRED:
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_V_PRED:
+#endif
+ // prediction mainly from the top -> ADST for columns, DCT for rows
+ lgtmtx[0] = is_col ? adstmtx : dctmtx;
+ break;
+ case H_PRED:
+ case D135_PRED:
+ case D153_PRED:
+ case D207_PRED:
+#if CONFIG_SMOOTH_HV
+ case SMOOTH_H_PRED:
+#endif
+ // prediction mainly from the left -> ADST for rows, DCT for columns
+ lgtmtx[0] = is_col ? dctmtx : adstmtx;
+ break;
+ default: lgtmtx[0] = dctmtx; break;
+ }
+}
+
+typedef void (*IlgtFunc)(const tran_low_t *input, tran_low_t *output,
+ const tran_high_t *lgtmtx);
+
+static IlgtFunc ilgt_func[4] = { ilgt4, ilgt8, ilgt16up, ilgt16up };
+
+typedef void (*GetLgtFunc)(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx);
+
+static GetLgtFunc get_lgt_func[4] = { get_lgt4_from_pred, get_lgt8_from_pred,
+ get_lgt16up_from_pred,
+ get_lgt16up_from_pred };
+
+// this inline function corresponds to the up scaling before the transpose
+// operation in the av1_iht* functions
+static INLINE tran_low_t inv_upscale_wrt_txsize(const tran_high_t val,
+ const TX_SIZE tx_size) {
+ switch (tx_size) {
+ case TX_4X4:
+ case TX_8X8:
+ case TX_4X16:
+ case TX_16X4:
+ case TX_8X32:
+ case TX_32X8: return (tran_low_t)val;
+ case TX_4X8:
+ case TX_8X4:
+ case TX_8X16:
+ case TX_16X8: return (tran_low_t)dct_const_round_shift(val * Sqrt2);
+ default: assert(0); break;
+ }
+ return 0;
+}
+
+// This inline function corresponds to the bit shift before summing with the
+// destination in the av1_iht* functions
+static INLINE tran_low_t inv_downscale_wrt_txsize(const tran_low_t val,
+ const TX_SIZE tx_size) {
+ switch (tx_size) {
+ case TX_4X4: return ROUND_POWER_OF_TWO(val, 4);
+ case TX_4X8:
+ case TX_8X4:
+ case TX_8X8:
+ case TX_4X16:
+ case TX_16X4: return ROUND_POWER_OF_TWO(val, 5);
+ case TX_8X16:
+ case TX_16X8:
+ case TX_8X32:
+ case TX_32X8: return ROUND_POWER_OF_TWO(val, 6);
+ default: assert(0); break;
+ }
+ return 0;
+}
+
+void ilgt2d_from_pred_add(const tran_low_t *input, uint8_t *dest, int stride,
+ const TxfmParam *txfm_param) {
+ const TX_SIZE tx_size = txfm_param->tx_size;
+ const int w = tx_size_wide[tx_size];
+ const int h = tx_size_high[tx_size];
+ const int wlog2 = tx_size_wide_log2[tx_size];
+ const int hlog2 = tx_size_high_log2[tx_size];
+ assert(w <= 8 || h <= 8);
+
+ int i, j;
+ // largest 1D size allowed for LGT: 32
+ // largest 2D size allowed for LGT: 8x32=256
+ tran_low_t tmp[256], out[256], temp1d[32];
+ const tran_high_t *lgtmtx_col[1];
+ const tran_high_t *lgtmtx_row[1];
+ get_lgt_func[hlog2 - 2](txfm_param, 1, lgtmtx_col, w);
+ get_lgt_func[wlog2 - 2](txfm_param, 0, lgtmtx_row, h);
+
+// The enabled (#if 1) path applies the column transforms first, on the
+// transposed input, and the row transforms second; the #else path is a
+// row-first variant kept for future tests (e.g. choosing the order
+// depending on the selected transforms). NOTE(review): confirm that the
+// forward flgt2d_from_pred uses the matching order.
+#if 1
+ // inverse column transforms
+ for (i = 0; i < w; ++i) {
+ // transpose
+ for (j = 0; j < h; ++j) tmp[i * h + j] = input[j * w + i];
+ ilgt_func[hlog2 - 2](&tmp[i * h], temp1d, lgtmtx_col[0]);
+ // upscale, and store in place
+ for (j = 0; j < h; ++j)
+ tmp[i * h + j] = inv_upscale_wrt_txsize(temp1d[j], tx_size);
+ }
+ // inverse row transforms
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) temp1d[j] = tmp[j * h + i];
+ ilgt_func[wlog2 - 2](temp1d, &out[i * w], lgtmtx_row[0]);
+ }
+ // downscale + sum with the destination
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
+ int d = i * stride + j;
+ int s = i * w + j;
+ dest[d] =
+ clip_pixel_add(dest[d], inv_downscale_wrt_txsize(out[s], tx_size));
+ }
+ }
+#else
+ // inverse row transforms
+ for (i = 0; i < h; ++i) {
+ ilgt_func[wlog2 - 2](input, temp1d, lgtmtx_row[0]);
+ // upscale and transpose (tmp[j*h+i] <--> tmp[j][i])
+ for (j = 0; j < w; ++j)
+ tmp[j * h + i] = inv_upscale_wrt_txsize(temp1d[j], tx_size);
+ input += w;
+ }
+ // inverse column transforms
+ for (i = 0; i < w; ++i)
+ ilgt_func[hlog2 - 2](&tmp[i * h], &out[i * h], lgtmtx_col[0]);
+ // here, out[] holds the transpose of the 2D residual block
+
+ // downscale + transpose back + sum with dest
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
+ int d = i * stride + j;
+ int s = j * h + i;
+ dest[d] =
+ clip_pixel_add(dest[d], inv_downscale_wrt_txsize(out[s], tx_size));
+ }
+ }
+#endif
+}
+#endif // CONFIG_LGT_FROM_PRED
+
void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
const TX_TYPE tx_type = txfm_param->tx_type;
@@ -2453,6 +2898,13 @@
void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param) {
const TX_SIZE tx_size = txfm_param->tx_size;
+#if CONFIG_LGT_FROM_PRED
+ if (txfm_param->use_lgt) {
+ assert(is_lgt_allowed(txfm_param->mode, tx_size));
+ ilgt2d_from_pred_add(input, dest, stride, txfm_param);
+ return;
+ }
+#endif // CONFIG_LGT_FROM_PRED
switch (tx_size) {
#if CONFIG_TX64X64
case TX_64X64: inv_txfm_add_64x64(input, dest, stride, txfm_param); break;
@@ -2499,6 +2951,9 @@
#if CONFIG_LGT
txfm_param->is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif
+#if CONFIG_LGT_FROM_PRED
+ txfm_param->use_lgt = xd->mi[0]->mbmi.use_lgt;
+#endif
#if CONFIG_ADAPT_SCAN
txfm_param->eob_threshold =
(const int16_t *)&xd->eob_threshold_md[tx_size][tx_type][0];
@@ -2515,7 +2970,7 @@
void av1_inverse_transform_block(const MACROBLOCKD *xd,
const tran_low_t *dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
PREDICTION_MODE mode,
#endif
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -2541,15 +2996,17 @@
init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
#if CONFIG_LGT || CONFIG_MRC_TX
txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
- txfm_param.dst = dst;
- txfm_param.stride = stride;
+#endif // CONFIG_LGT || CONFIG_MRC_TX
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
txfm_param.mask = mrc_mask;
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
+ txfm_param.dst = dst;
+ txfm_param.stride = stride;
+#if CONFIG_LGT_FROM_PRED
txfm_param.mode = mode;
-#endif // CONFIG_LGT
-#endif // CONFIG_LGT || CONFIG_MRC_TX
+#endif // CONFIG_LGT_FROM_PRED
+#endif // CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
const int is_hbd = get_bitdepth_data_path_index(xd);
#if CONFIG_TXMG
@@ -2595,9 +3052,9 @@
uint8_t *dst =
&pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
xd->mi[0]->mbmi.mode,
-#endif // CONFIG_LGT
+#endif // CONFIG_LGT_FROM_PRED
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
diff --git a/av1/common/idct.h b/av1/common/idct.h
index 0859a75..e4e4ad6 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -39,6 +39,15 @@
const tran_high_t **lgtmtx);
#endif // CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
+void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx);
+void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx);
+void get_lgt16up_from_pred(const TxfmParam *txfm_param, int is_col,
+ const tran_high_t **lgtmtx, int ntx);
+#endif // CONFIG_LGT_FROM_PRED
+
#if CONFIG_HIGHBITDEPTH
typedef void (*highbd_transform_1d)(const tran_low_t *, tran_low_t *, int bd);
@@ -59,7 +68,7 @@
TxfmParam *txfm_param);
void av1_inverse_transform_block(const MACROBLOCKD *xd,
const tran_low_t *dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
PREDICTION_MODE mode,
#endif
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
diff --git a/av1/common/scan.h b/av1/common/scan.h
index 4c8dd50..82d2e91 100644
--- a/av1/common/scan.h
+++ b/av1/common/scan.h
@@ -109,6 +109,9 @@
// use the DCT_DCT scan order for MRC_DCT for now
if (tx_type == MRC_DCT) tx_type = DCT_DCT;
#endif // CONFIG_MRC_TX
+#if CONFIG_LGT_FROM_PRED
+ if (mbmi->use_lgt) tx_type = DCT_DCT;
+#endif
const int is_inter = is_inter_block(mbmi);
#if CONFIG_ADAPT_SCAN
(void)mbmi;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 1b047e9..4f14a28 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -268,7 +268,7 @@
#endif
static void inverse_transform_block(MACROBLOCKD *xd, int plane,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
PREDICTION_MODE mode,
#endif
const TX_TYPE tx_type,
@@ -277,7 +277,7 @@
struct macroblockd_plane *const pd = &xd->plane[plane];
tran_low_t *const dqcoeff = pd->dqcoeff;
av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
mode,
#endif
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -510,7 +510,7 @@
uint8_t *dst =
&pd->dst.buf[(row * pd->dst.stride + col) << tx_size_wide_log2[0]];
inverse_transform_block(xd, plane,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
mbmi->mode,
#endif
tx_type, tx_size, dst, pd->dst.stride,
@@ -568,7 +568,7 @@
&max_scan_line, r, mbmi->segment_id);
#endif // CONFIG_LV_MAP
inverse_transform_block(xd, plane,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
mbmi->mode,
#endif
tx_type, plane_tx_size,
@@ -656,7 +656,7 @@
&pd->dst.buf[(row * pd->dst.stride + col) << tx_size_wide_log2[0]];
if (eob)
inverse_transform_block(xd, plane,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
xd->mi[0]->mbmi.mode,
#endif
tx_type, tx_size, dst, pd->dst.stride,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 96c06b6..1215aa2 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -994,6 +994,9 @@
(void)block;
TX_TYPE *tx_type = &mbmi->txk_type[(blk_row << 4) + blk_col];
#endif
+#if CONFIG_LGT_FROM_PRED
+ mbmi->use_lgt = 0;
+#endif
if (!FIXED_TX_TYPE) {
#if CONFIG_EXT_TX
@@ -1014,6 +1017,8 @@
// eset == 0 should correspond to a set with only DCT_DCT and
// there is no need to read the tx_type
assert(eset != 0);
+
+#if !CONFIG_LGT_FROM_PRED
if (inter_block) {
*tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
r, ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
@@ -1023,10 +1028,73 @@
r, ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
}
+#else
+ // only signal tx_type when lgt is not allowed or not selected
+ if (inter_block) {
+ if (LGT_FROM_PRED_INTER) {
+ if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used) {
+ mbmi->use_lgt =
+ aom_read(r, ec_ctx->inter_lgt_prob[square_tx_size], ACCT_STR);
+#if CONFIG_ENTROPY_STATS
+ if (counts) ++counts->inter_lgt[square_tx_size][mbmi->use_lgt];
+#endif // CONFIG_ENTROPY_STATS
+ }
+ if (!mbmi->use_lgt) {
+ *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
+ r, ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
+ av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
+#if CONFIG_ENTROPY_STATS
+ if (counts) ++counts->inter_ext_tx[eset][square_tx_size][*tx_type];
+#endif // CONFIG_ENTROPY_STATS
+ } else {
+ *tx_type = DCT_DCT; // assign a dummy tx_type
+ }
+ } else {
+ *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
+ r, ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
+ av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
+#if CONFIG_ENTROPY_STATS
+ if (counts) ++counts->inter_ext_tx[eset][square_tx_size][*tx_type];
+#endif // CONFIG_ENTROPY_STATS
+ }
+ } else if (ALLOW_INTRA_EXT_TX) {
+ if (LGT_FROM_PRED_INTRA) {
+ if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used) {
+ mbmi->use_lgt =
+ aom_read(r, ec_ctx->intra_lgt_prob[square_tx_size][mbmi->mode],
+ ACCT_STR);
+#if CONFIG_ENTROPY_STATS
+ if (counts)
+ ++counts->intra_lgt[square_tx_size][mbmi->mode][mbmi->use_lgt];
+#endif // CONFIG_ENTROPY_STATS
+ }
+ if (!mbmi->use_lgt) {
+ *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
+ r, ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
+ av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
+#if CONFIG_ENTROPY_STATS
+ if (counts)
+ ++counts
+ ->intra_ext_tx[eset][square_tx_size][mbmi->mode][*tx_type];
+#endif // CONFIG_ENTROPY_STATS
+ } else {
+ *tx_type = DCT_DCT; // assign a dummy tx_type
+ }
+ } else {
+ *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
+ r, ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
+ av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
+#if CONFIG_ENTROPY_STATS
+ if (counts)
+ ++counts->intra_ext_tx[eset][square_tx_size][mbmi->mode][*tx_type];
+#endif // CONFIG_ENTROPY_STATS
+ }
+ }
+#endif // CONFIG_LGT_FROM_PRED
} else {
*tx_type = DCT_DCT;
}
-#else
+#else // CONFIG_EXT_TX
if (tx_size < TX_32X32 &&
((!cm->seg.enabled && cm->base_qindex > 0) ||
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 33db1ae..90dca1e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1581,6 +1581,7 @@
// is no need to send the tx_type
assert(eset > 0);
assert(av1_ext_tx_used[tx_set_type][tx_type]);
+#if !CONFIG_LGT_FROM_PRED
if (is_inter) {
aom_write_symbol(w, av1_ext_tx_ind[tx_set_type][tx_type],
ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
@@ -1591,8 +1592,41 @@
ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
av1_num_ext_tx_set[tx_set_type]);
}
- }
#else
+ // only signal tx_type when lgt is not allowed or not selected
+ if (is_inter) {
+ if (LGT_FROM_PRED_INTER) {
+ if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used)
+ aom_write(w, mbmi->use_lgt, ec_ctx->inter_lgt_prob[square_tx_size]);
+ if (!mbmi->use_lgt)
+ aom_write_symbol(w, av1_ext_tx_ind[tx_set_type][tx_type],
+ ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
+ av1_num_ext_tx_set[tx_set_type]);
+ } else {
+ aom_write_symbol(w, av1_ext_tx_ind[tx_set_type][tx_type],
+ ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
+ av1_num_ext_tx_set[tx_set_type]);
+ }
+ } else if (ALLOW_INTRA_EXT_TX) {
+ if (LGT_FROM_PRED_INTRA) {
+ if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used)
+ aom_write(w, mbmi->use_lgt,
+ ec_ctx->intra_lgt_prob[square_tx_size][mbmi->mode]);
+ if (!mbmi->use_lgt)
+ aom_write_symbol(
+ w, av1_ext_tx_ind[tx_set_type][tx_type],
+ ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
+ av1_num_ext_tx_set[tx_set_type]);
+ } else {
+ aom_write_symbol(
+ w, av1_ext_tx_ind[tx_set_type][tx_type],
+ ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
+ av1_num_ext_tx_set[tx_set_type]);
+ }
+ }
+#endif // CONFIG_LGT_FROM_PRED
+ }
+#else // CONFIG_EXT_TX
if (tx_size < TX_32X32 &&
((!cm->seg.enabled && cm->base_qindex > 0) ||
(cm->seg.enabled && xd->qindex[mbmi->segment_id] > 0)) &&
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3b1672a..8b66278 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -283,6 +283,10 @@
#endif // CONFIG_CFL
int tx_size_cost[TX_SIZES - 1][TX_SIZE_CONTEXTS][TX_SIZES];
#if CONFIG_EXT_TX
+#if CONFIG_LGT_FROM_PRED
+ int intra_lgt_cost[LGT_SIZES][INTRA_MODES][2];
+ int inter_lgt_cost[LGT_SIZES][2];
+#endif
int inter_tx_type_costs[EXT_TX_SETS_INTER][EXT_TX_SIZES][TX_TYPES];
int intra_tx_type_costs[EXT_TX_SETS_INTRA][EXT_TX_SIZES][INTRA_MODES]
[TX_TYPES];
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index c91e289..a04d46b 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1070,10 +1070,20 @@
 }
 #endif  // CONFIG_MRC_TX
-#if CONFIG_LGT
+#if CONFIG_LGT || CONFIG_LGT_FROM_PRED
 static void flgt4(const tran_low_t *input, tran_low_t *output,
                   const tran_high_t *lgtmtx) {
   if (!lgtmtx) assert(0);
+#if CONFIG_LGT_FROM_PRED
+  // For DCT/ADST, use butterfly implementations
+  if (lgtmtx[0] == DCT4) {
+    fdct4(input, output);
+    return;
+  } else if (lgtmtx[0] == ADST4) {
+    fadst4(input, output);
+    return;
+  }
+#endif  // CONFIG_LGT_FROM_PRED
   // evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,4
   tran_high_t s[4] = { 0 };
@@ -1086,6 +1096,16 @@
 static void flgt8(const tran_low_t *input, tran_low_t *output,
                   const tran_high_t *lgtmtx) {
   if (!lgtmtx) assert(0);
+#if CONFIG_LGT_FROM_PRED
+  // For DCT/ADST, use butterfly implementations
+  if (lgtmtx[0] == DCT8) {
+    fdct8(input, output);
+    return;
+  } else if (lgtmtx[0] == ADST8) {
+    fadst8(input, output);
+    return;
+  }
+#endif  // CONFIG_LGT_FROM_PRED
   // evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,8
   tran_high_t s[8] = { 0 };
@@ -1094,7 +1114,140 @@
   for (int i = 0; i < 8; ++i) output[i] = (tran_low_t)fdct_round_shift(s[i]);
 }
-#endif  // CONFIG_LGT
+#endif  // CONFIG_LGT || CONFIG_LGT_FROM_PRED
+
+#if CONFIG_LGT_FROM_PRED
+static void flgt16up(const tran_low_t *input, tran_low_t *output,
+                     const tran_high_t *lgtmtx) {
+  if (lgtmtx[0] == DCT16) {
+    fdct16(input, output);
+    return;
+  } else if (lgtmtx[0] == ADST16) {
+    fadst16(input, output);
+    return;
+  } else if (lgtmtx[0] == DCT32) {
+    fdct32(input, output);
+    return;
+  } else if (lgtmtx[0] == ADST32) {
+    fhalfright32(input, output);
+    return;
+  } else {
+    assert(0);
+  }
+}
+
+typedef void (*FlgtFunc)(const tran_low_t *input, tran_low_t *output,
+                         const tran_high_t *lgtmtx);
+
+static const FlgtFunc flgt_func[4] = { flgt4, flgt8, flgt16up, flgt16up };
+
+typedef void (*GetLgtFunc)(const TxfmParam *txfm_param, int is_col,
+                           const tran_high_t *lgtmtx[], int ntx);
+
+static const GetLgtFunc get_lgt_func[4] = {
+  get_lgt4_from_pred, get_lgt8_from_pred, get_lgt16up_from_pred,
+  get_lgt16up_from_pred };  // both tables: index = log2(1-D length) - 2
+
+// Up-scaling applied before the first 1-D transform, matching av1_fht*.
+// Multiplication is used because left-shifting a negative residual is UB.
+static INLINE tran_low_t fwd_upscale_wrt_txsize(const tran_high_t val,
+                                                const TX_SIZE tx_size) {
+  switch (tx_size) {
+    case TX_4X4: return (tran_low_t)(val * 16);
+    case TX_8X8:
+    case TX_4X16:
+    case TX_16X4:
+    case TX_8X32:
+    case TX_32X8: return (tran_low_t)(val * 4);
+    case TX_4X8:
+    case TX_8X4:
+    case TX_8X16:
+    case TX_16X8: return (tran_low_t)fdct_round_shift(val * 4 * Sqrt2);
+    default: assert(0); break;
+  }
+  return 0;
+}
+
+// This inline function corresponds to the bit shift after the second
+// transform in the av1_fht* functions
+static INLINE tran_low_t fwd_downscale_wrt_txsize(const tran_low_t val,
+                                                  const TX_SIZE tx_size) {
+  switch (tx_size) {
+    case TX_4X4: return (val + 1) >> 2;
+    case TX_4X8:
+    case TX_8X4:
+    case TX_8X8:
+    case TX_4X16:
+    case TX_16X4: return (val + (val < 0)) >> 1;
+    case TX_8X16:
+    case TX_16X8: return val;
+    case TX_8X32:
+    case TX_32X8: return ROUND_POWER_OF_TWO_SIGNED(val, 2);
+    default: assert(0); break;
+  }
+  return 0;
+}
+
+void flgt2d_from_pred_c(const int16_t *input, tran_low_t *output, int stride,
+                        TxfmParam *txfm_param) {
+  const TX_SIZE tx_size = txfm_param->tx_size;
+  const int w = tx_size_wide[tx_size];
+  const int h = tx_size_high[tx_size];
+  const int wlog2 = tx_size_wide_log2[tx_size];
+  const int hlog2 = tx_size_high_log2[tx_size];
+  assert(w <= 8 || h <= 8);  // keeps w * h within out[] below
+
+  int i, j;
+  tran_low_t out[256];  // max size: 8x32 and 32x8
+  tran_low_t temp_in[32], temp_out[32];
+  const tran_high_t *lgtmtx_col[1];
+  const tran_high_t *lgtmtx_row[1];
+  get_lgt_func[hlog2 - 2](txfm_param, 1, lgtmtx_col, w);
+  get_lgt_func[wlog2 - 2](txfm_param, 0, lgtmtx_row, h);
+
+  // For forward transforms, to be consistent with av1_fht functions, we apply
+  // short transform first and long transform second.
+  if (w < h) {
+    // Row transforms
+    for (i = 0; i < h; ++i) {
+      for (j = 0; j < w; ++j)
+        temp_in[j] = fwd_upscale_wrt_txsize(input[i * stride + j], tx_size);
+      flgt_func[wlog2 - 2](temp_in, temp_out, lgtmtx_row[0]);
+      // right shift of 2 bits here in fht8x16 and fht16x8
+      for (j = 0; j < w; ++j)
+        out[j * h + i] = (tx_size == TX_16X8 || tx_size == TX_8X16)
+                             ? ROUND_POWER_OF_TWO_SIGNED(temp_out[j], 2)
+                             : temp_out[j];
+    }
+    // Column transforms
+    for (i = 0; i < w; ++i) {
+      for (j = 0; j < h; ++j) temp_in[j] = out[j + i * h];
+      flgt_func[hlog2 - 2](temp_in, temp_out, lgtmtx_col[0]);
+      for (j = 0; j < h; ++j)
+        output[j * w + i] = fwd_downscale_wrt_txsize(temp_out[j], tx_size);
+    }
+  } else {
+    // Column transforms
+    for (i = 0; i < w; ++i) {
+      for (j = 0; j < h; ++j)
+        temp_in[j] = fwd_upscale_wrt_txsize(input[j * stride + i], tx_size);
+      flgt_func[hlog2 - 2](temp_in, temp_out, lgtmtx_col[0]);
+      // fht8x16 and fht16x8 have right shift of 2 bits here
+      for (j = 0; j < h; ++j)
+        out[j * w + i] = (tx_size == TX_16X8 || tx_size == TX_8X16)
+                             ? ROUND_POWER_OF_TWO_SIGNED(temp_out[j], 2)
+                             : temp_out[j];
+    }
+    // Row transforms
+    for (i = 0; i < h; ++i) {
+      for (j = 0; j < w; ++j) temp_in[j] = out[j + i * w];
+      flgt_func[wlog2 - 2](temp_in, temp_out, lgtmtx_row[0]);
+      for (j = 0; j < w; ++j)
+        output[j + i * w] = fwd_downscale_wrt_txsize(temp_out[j], tx_size);
+    }
+  }
+}
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_EXT_TX
 // TODO(sarahparker) these functions will be removed once the highbitdepth
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 094f2d9..f79a678 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -6186,6 +6186,7 @@
     const int eset =
         get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
     if (eset > 0) {
+#if !CONFIG_LGT_FROM_PRED
       const TxSetType tx_set_type = get_ext_tx_set_type(
           tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       if (is_inter) {
@@ -6205,6 +6206,46 @@
                    av1_ext_tx_ind[tx_set_type][tx_type],
                    av1_num_ext_tx_set[tx_set_type]);
       }
+#else
+      (void)tx_type;
+      (void)fc;
+      if (is_inter) {
+        if (LGT_FROM_PRED_INTER) {
+          if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used)
+            ++counts->inter_lgt[txsize_sqr_map[tx_size]][mbmi->use_lgt];
+          if (mbmi->use_lgt) {
+            mbmi->tx_type = DCT_DCT;  // tx_type is not signaled with LGT
+          } else {
+#if CONFIG_ENTROPY_STATS
+            ++counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]][tx_type];
+#endif  // CONFIG_ENTROPY_STATS
+          }
+        } else {
+#if CONFIG_ENTROPY_STATS
+          ++counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]][tx_type];
+#endif  // CONFIG_ENTROPY_STATS
+        }
+      } else {
+        if (LGT_FROM_PRED_INTRA) {
+          if (is_lgt_allowed(mbmi->mode, tx_size) && !cm->reduced_tx_set_used)
+            ++counts->intra_lgt[txsize_sqr_map[tx_size]][mbmi->mode]
+                               [mbmi->use_lgt];
+          if (mbmi->use_lgt) {
+            mbmi->tx_type = DCT_DCT;  // tx_type is not signaled with LGT
+          } else {
+#if CONFIG_ENTROPY_STATS
+            ++counts->intra_ext_tx[eset][txsize_sqr_map[tx_size]][mbmi->mode]
+                                  [tx_type];
+#endif  // CONFIG_ENTROPY_STATS
+          }
+        } else {
+#if CONFIG_ENTROPY_STATS
+          ++counts->intra_ext_tx[eset][txsize_sqr_map[tx_size]][mbmi->mode]
+                                [tx_type];
+#endif  // CONFIG_ENTROPY_STATS
+        }
+      }
+#endif  // CONFIG_LGT_FROM_PRED
     }
   }
 #else
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index a700b3a..f35ce8a 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -498,7 +498,7 @@
   TxfmParam txfm_param;
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
   uint8_t *dst;
   const int dst_stride = pd->dst.stride;
 #if CONFIG_PVQ || CONFIG_DIST_8X8
@@ -561,9 +561,10 @@
 #endif  // CONFIG_HIGHBITDEPTH
 #endif
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
   dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
-#endif  // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
+#endif  // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT_FROM_PRED ||
+        // CONFIG_MRC_TX
 #if CONFIG_PVQ || CONFIG_DIST_8X8
   if (CONFIG_PVQ
@@ -599,6 +600,8 @@
   txfm_param.lossless = xd->lossless[mbmi->segment_id];
 #if CONFIG_MRC_TX || CONFIG_LGT
   txfm_param.is_inter = is_inter_block(mbmi);
+#endif  // CONFIG_MRC_TX || CONFIG_LGT
+#if CONFIG_MRC_TX || CONFIG_LGT_FROM_PRED
   txfm_param.dst = dst;
   txfm_param.stride = dst_stride;
 #if CONFIG_MRC_TX
@@ -607,10 +610,11 @@
   txfm_param.mask = BLOCK_OFFSET(xd->mrc_mask, block);
 #endif  // SIGNAL_ANY_MRC_MASK
 #endif  // CONFIG_MRC_TX
-#endif  // CONFIG_MRC_TX || CONFIG_LGT
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
   txfm_param.mode = mbmi->mode;
-#endif  // CONFIG_LGT
+  txfm_param.use_lgt = mbmi->use_lgt;  // read by av1_fwd_txfm (LGT path)
+#endif  // CONFIG_LGT_FROM_PRED
+#endif  // CONFIG_MRC_TX || CONFIG_LGT_FROM_PRED
 #if !CONFIG_PVQ
   txfm_param.bd = xd->bd;
@@ -740,15 +744,15 @@
     if (!x->pvq_skip[plane])
 #endif
     {
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
       PREDICTION_MODE mode = xd->mi[0]->mbmi.mode;
-#endif  // CONFIG_LGT
+#endif  // CONFIG_LGT_FROM_PRED
       TX_TYPE tx_type =
           av1_get_tx_type(pd->plane_type, xd, blk_row, blk_col, block, tx_size);
       av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                   mode,
-#endif  // CONFIG_LGT
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
                                   mrc_mask,
 #endif  // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -1095,7 +1099,7 @@
   if (x->pvq_skip[plane]) return;
 #endif  // CONFIG_PVQ
   av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                               xd->mi[0]->mbmi.mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index c36b177..6ddeb2b 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -555,6 +555,14 @@
 void av1_fwd_txfm(const int16_t *src_diff, tran_low_t *coeff, int diff_stride,
                   TxfmParam *txfm_param) {
   const TX_SIZE tx_size = txfm_param->tx_size;
+#if CONFIG_LGT_FROM_PRED
+  if (txfm_param->use_lgt) {
+    // use_lgt overrides tx_type; dispatch through the av1_rtcd entry point.
+    assert(is_lgt_allowed(txfm_param->mode, tx_size));
+    flgt2d_from_pred(src_diff, coeff, diff_stride, txfm_param);
+    return;
+  }
+#endif  // CONFIG_LGT_FROM_PRED
   switch (tx_size) {
 #if CONFIG_TX64X64
     case TX_64X64:
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index cd8e323..5dd4853 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -210,6 +210,22 @@
                           NULL);
 #if CONFIG_EXT_TX
+#if CONFIG_LGT_FROM_PRED
+  if (LGT_FROM_PRED_INTRA) {  // cost of coding use_lgt = 0 / 1, intra
+    for (i = 0; i < LGT_SIZES; ++i) {
+      for (j = 0; j < INTRA_MODES; ++j) {
+        x->intra_lgt_cost[i][j][0] = av1_cost_bit(fc->intra_lgt_prob[i][j], 0);
+        x->intra_lgt_cost[i][j][1] = av1_cost_bit(fc->intra_lgt_prob[i][j], 1);
+      }
+    }
+  }
+  if (LGT_FROM_PRED_INTER) {  // cost of coding use_lgt = 0 / 1, inter
+    for (i = 0; i < LGT_SIZES; ++i) {
+      x->inter_lgt_cost[i][0] = av1_cost_bit(fc->inter_lgt_prob[i], 0);
+      x->inter_lgt_cost[i][1] = av1_cost_bit(fc->inter_lgt_prob[i], 1);
+    }
+  }
+#endif  // CONFIG_LGT_FROM_PRED
   for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
     int s;
     for (s = 1; s < EXT_TX_SETS_INTER; ++s) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 01fa606..2360459 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1843,7 +1843,7 @@
     TX_TYPE tx_type =
         av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
     av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                 xd->mi[0]->mbmi.mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -2237,12 +2237,41 @@
   }
 }
+#if CONFIG_LGT_FROM_PRED
+// Returns the rate of signaling use_lgt for a luma block, or 0 when the
+// conditions below determine that the flag is not coded.
+int av1_lgt_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
+                 const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
+                 TX_SIZE tx_size, int use_lgt) {
+  // Mirror the bitstream writer, which skips the flag for a reduced tx set.
+  if (plane > 0 || cm->reduced_tx_set_used) return 0;
+  const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  const int is_inter = is_inter_block(mbmi);
+
+  assert(is_lgt_allowed(mbmi->mode, tx_size));
+  if (get_ext_tx_types(tx_size, bsize, is_inter, cm->reduced_tx_set_used) > 1 &&
+      !xd->lossless[mbmi->segment_id]) {
+    const int ext_tx_set =
+        get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
+    if (LGT_FROM_PRED_INTRA && !is_inter && ext_tx_set > 0 &&
+        ALLOW_INTRA_EXT_TX)
+      return x->intra_lgt_cost[txsize_sqr_map[tx_size]][mbmi->mode][use_lgt];
+    if (LGT_FROM_PRED_INTER && is_inter && ext_tx_set > 0)
+      return x->inter_lgt_cost[txsize_sqr_map[tx_size]][use_lgt];
+  }
+  return 0;
+}
+#endif  // CONFIG_LGT_FROM_PRED
+
 // TODO(angiebird): use this function whenever it's possible
 int av1_tx_type_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
                      const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type) {
   if (plane > 0) return 0;
+#if CONFIG_LGT_FROM_PRED
+  assert(!xd->mi[0]->mbmi.use_lgt);  // tx_type is not coded with LGT
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_VAR_TX
   tx_size = get_min_tx_size(tx_size);
 #endif
@@ -2313,7 +2342,15 @@
   if (rd_stats->rate == INT_MAX) return INT64_MAX;
 #if !CONFIG_TXK_SEL
   int plane = 0;
+#if CONFIG_LGT_FROM_PRED
+  if (is_lgt_allowed(mbmi->mode, tx_size))
+    rd_stats->rate +=
+        av1_lgt_cost(cm, x, xd, bs, plane, tx_size, mbmi->use_lgt);
+  if (!mbmi->use_lgt)
+    rd_stats->rate += av1_tx_type_cost(cm, x, xd, bs, plane, tx_size, tx_type);
+#else
   rd_stats->rate += av1_tx_type_cost(cm, x, xd, bs, plane, tx_size, tx_type);
+#endif  // CONFIG_LGT_FROM_PRED
 #endif
   if (rd_stats->skip) {
@@ -2356,6 +2390,9 @@
         tx_size != TX_32X32))
      return 1;
 #endif  // CONFIG_MRC_TX
+#if CONFIG_LGT_FROM_PRED
+  if (mbmi->use_lgt && mbmi->ref_mv_idx > 0) return 1;
+#endif  // CONFIG_LGT_FROM_PRED
   if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) return 1;
   if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, tx_size))
     return 1;
@@ -2418,6 +2455,14 @@
   const int is_inter = is_inter_block(mbmi);
   int prune = 0;
   const int plane = 0;
+#if CONFIG_LGT_FROM_PRED
+  int is_lgt_best = 0;
+  int search_lgt = is_inter
+                       ? LGT_FROM_PRED_INTER && !x->use_default_inter_tx_type &&
+                             cpi->sf.tx_type_search.prune_mode == NO_PRUNE
+                       : LGT_FROM_PRED_INTRA && !x->use_default_intra_tx_type &&
+                             ALLOW_INTRA_EXT_TX;
+#endif  // CONFIG_LGT_FROM_PRED
   av1_invalid_rd_stats(rd_stats);
   mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
@@ -2498,6 +2543,34 @@
 #if CONFIG_PVQ
       od_encode_rollback(&x->daala_enc, &post_buf);
 #endif  // CONFIG_PVQ
+#if CONFIG_LGT_FROM_PRED
+    // Evaluate LGT at this tx size; its rate includes the use_lgt flag cost.
+    if (search_lgt && is_lgt_allowed(mbmi->mode, mbmi->tx_size) &&
+        !cm->reduced_tx_set_used) {
+      RD_STATS this_rd_stats;
+      mbmi->use_lgt = 1;
+      txfm_rd_in_plane(x, cpi, &this_rd_stats, ref_best_rd, 0, bs,
+                       mbmi->tx_size, cpi->sf.use_fast_coef_costing);
+      if (this_rd_stats.rate != INT_MAX) {
+        this_rd_stats.rate +=
+            av1_lgt_cost(cm, x, xd, bs, plane, mbmi->tx_size, 1);
+        if (this_rd_stats.skip)
+          this_rd = RDCOST(x->rdmult, s1, this_rd_stats.sse);
+        else
+          this_rd =
+              RDCOST(x->rdmult, this_rd_stats.rate + s0, this_rd_stats.dist);
+        if (is_inter_block(mbmi) && !xd->lossless[mbmi->segment_id] &&
+            !this_rd_stats.skip)
+          this_rd = AOMMIN(this_rd, RDCOST(x->rdmult, s1, this_rd_stats.sse));
+        if (this_rd < best_rd) {
+          best_rd = this_rd;
+          is_lgt_best = 1;
+          *rd_stats = this_rd_stats;
+        }
+      }
+      mbmi->use_lgt = 0;
+    }
+#endif  // CONFIG_LGT_FROM_PRED
   } else {
     mbmi->tx_type = DCT_DCT;
     txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, bs, mbmi->tx_size,
@@ -2545,6 +2618,9 @@
   }
 #endif  // CONFIG_EXT_TX
   mbmi->tx_type = best_tx_type;
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = is_lgt_best;
+#endif  // CONFIG_LGT_FROM_PRED
 }
 static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -2583,6 +2658,11 @@
   const TX_SIZE max_tx_size = max_txsize_lookup[bs];
   TX_SIZE best_tx_size = max_tx_size;
   TX_TYPE best_tx_type = DCT_DCT;
+#if CONFIG_LGT_FROM_PRED
+  int breakout = 0;
+  int is_lgt_best = 0;
+  mbmi->use_lgt = 0;
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_TXK_SEL
   TX_TYPE best_txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
 #endif  // CONFIG_TXK_SEL
@@ -2639,6 +2719,21 @@
       if (mbmi->sb_type < BLOCK_8X8 && is_inter) break;
 #endif  // CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
     }
+#if CONFIG_LGT_FROM_PRED
+    const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
+    if (is_lgt_allowed(mbmi->mode, rect_tx_size) && !cm->reduced_tx_set_used) {
+      RD_STATS this_rd_stats;
+      mbmi->use_lgt = 1;
+      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, 0, rect_tx_size);
+      if (rd < best_rd) {
+        is_lgt_best = 1;
+        best_tx_size = rect_tx_size;
+        best_rd = rd;
+        *rd_stats = this_rd_stats;
+      }
+      mbmi->use_lgt = 0;
+    }
+#endif  // CONFIG_LGT_FROM_PRED
   }
 #if CONFIG_RECT_TX_EXT
@@ -2677,6 +2772,9 @@
               sizeof(best_txk_type[0]) * num_blk);
 #endif
         best_tx_type = tx_type;
+#if CONFIG_LGT_FROM_PRED
+        is_lgt_best = 0;
+#endif  // CONFIG_LGT_FROM_PRED
         best_tx_size = tx_size;
         best_rd = rd;
         *rd_stats = this_rd_stats;
@@ -2687,6 +2785,21 @@
       if (mbmi->sb_type < BLOCK_8X8 && is_inter) break;
 #endif  // CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
     }
+#if CONFIG_LGT_FROM_PRED
+    const TX_SIZE lgt_tx_size = quarter_txsize_lookup[bs];
+    if (is_lgt_allowed(mbmi->mode, lgt_tx_size) && !cm->reduced_tx_set_used) {
+      RD_STATS this_rd_stats;
+      mbmi->use_lgt = 1;
+      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, 0, lgt_tx_size);
+      if (rd < best_rd) {
+        is_lgt_best = 1;
+        best_tx_size = lgt_tx_size;
+        best_rd = rd;
+        *rd_stats = this_rd_stats;
+      }
+      mbmi->use_lgt = 0;
+    }
+#endif  // CONFIG_LGT_FROM_PRED
   }
 #endif  // CONFIG_RECT_TX_EXT
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -2725,8 +2838,12 @@
       if (cpi->sf.tx_size_search_breakout &&
           (rd == INT64_MAX ||
            (this_rd_stats.skip == 1 && tx_type != DCT_DCT && n < start_tx) ||
-           (n < (int)max_tx_size && rd > last_rd)))
+           (n < (int)max_tx_size && rd > last_rd))) {
+#if CONFIG_LGT_FROM_PRED
+        breakout = 1;
+#endif  // CONFIG_LGT_FROM_PRED
         break;
+      }
       last_rd = rd;
       ref_best_rd = AOMMIN(rd, ref_best_rd);
@@ -2735,6 +2852,9 @@
         memcpy(best_txk_type, mbmi->txk_type, sizeof(best_txk_type[0]) * 256);
 #endif
         best_tx_type = tx_type;
+#if CONFIG_LGT_FROM_PRED
+        is_lgt_best = 0;
+#endif  // CONFIG_LGT_FROM_PRED
         best_tx_size = n;
         best_rd = rd;
         *rd_stats = this_rd_stats;
@@ -2744,9 +2864,28 @@
       if (mbmi->sb_type < BLOCK_8X8 && is_inter) break;
 #endif  // CONFIG_CB4X4 && !USE_TXTYPE_SEARCH_FOR_SUB8X8_IN_CB4X4
     }
+#if CONFIG_LGT_FROM_PRED
+    mbmi->use_lgt = 1;  // skip_txfm_search() below checks use_lgt
+    if (is_lgt_allowed(mbmi->mode, n) && !cm->reduced_tx_set_used &&
+        !skip_txfm_search(cpi, x, bs, 0, n) && !breakout) {
+      RD_STATS this_rd_stats;
+      rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, 0, n);
+      if (rd < best_rd) {
+        is_lgt_best = 1;
+        best_tx_size = n;
+        best_rd = rd;
+        *rd_stats = this_rd_stats;
+      }
+    }
+    mbmi->use_lgt = 0;
+#endif  // CONFIG_LGT_FROM_PRED
   }
   mbmi->tx_size = best_tx_size;
   mbmi->tx_type = best_tx_type;
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = is_lgt_best;
+  assert(!is_lgt_best || is_lgt_allowed(mbmi->mode, mbmi->tx_size));
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_TXK_SEL
   memcpy(mbmi->txk_type, best_txk_type, sizeof(best_txk_type[0]) * 256);
 #endif
@@ -3241,7 +3380,7 @@
    if (!skip)
 #endif
      av1_inverse_transform_block(xd, BLOCK_OFFSET(pd->dqcoeff, block),
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                  mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -3297,7 +3436,7 @@
    if (!skip)
 #endif
      av1_inverse_transform_block(xd, BLOCK_OFFSET(pd->dqcoeff, block),
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                  mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -3486,7 +3625,7 @@
    if (!skip)
 #endif  // CONFIG_PVQ
      av1_inverse_transform_block(xd, BLOCK_OFFSET(pd->dqcoeff, block),
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                  mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -3507,7 +3646,7 @@
    if (!skip)
 #endif  // CONFIG_PVQ
      av1_inverse_transform_block(xd, BLOCK_OFFSET(pd->dqcoeff, block),
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                                  mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -3598,6 +3737,9 @@
   // expense of speed.
   mbmi->tx_type = DCT_DCT;
   mbmi->tx_size = tx_size;
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = 0;  // this path forces DCT_DCT, so LGT is off
+#endif  // CONFIG_LGT_FROM_PRED
   if (y_skip) *y_skip = 1;
@@ -3672,8 +3814,14 @@
      1) {
    const int eset =
        get_ext_tx_set(tx_size, bsize, 0, cpi->common.reduced_tx_set_used);
-    rate_tx_type = mb->intra_tx_type_costs[eset][txsize_sqr_map[tx_size]]
-                                          [mbmi->mode][mbmi->tx_type];
+#if CONFIG_LGT_FROM_PRED
+    if (LGT_FROM_PRED_INTRA && is_lgt_allowed(mbmi->mode, tx_size))
+      rate_tx_type += mb->intra_lgt_cost[txsize_sqr_map[tx_size]][mbmi->mode]
+                                        [mbmi->use_lgt];
+    if (!LGT_FROM_PRED_INTRA || !mbmi->use_lgt)  // no tx_type cost with LGT
+#endif  // CONFIG_LGT_FROM_PRED
+      rate_tx_type += mb->intra_tx_type_costs[eset][txsize_sqr_map[tx_size]]
+                                             [mbmi->mode][mbmi->tx_type];
  }
 #else
    rate_tx_type =
@@ -3709,6 +3857,9 @@
   TX_SIZE best_tx_size = TX_4X4;
   FILTER_INTRA_MODE_INFO filter_intra_mode_info;
   TX_TYPE best_tx_type;
+#if CONFIG_LGT_FROM_PRED
+  int use_lgt_when_selected;  // use_lgt of the best candidate found below
+#endif  // CONFIG_LGT_FROM_PRED
   av1_zero(filter_intra_mode_info);
   mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 1;
@@ -3738,6 +3889,9 @@
       best_tx_size = mic->mbmi.tx_size;
       filter_intra_mode_info = mbmi->filter_intra_mode_info;
       best_tx_type = mic->mbmi.tx_type;
+#if CONFIG_LGT_FROM_PRED
+      use_lgt_when_selected = mic->mbmi.use_lgt;
+#endif  // CONFIG_LGT_FROM_PRED
       *rate = this_rate;
       *rate_tokenonly = tokenonly_rd_stats.rate;
       *distortion = tokenonly_rd_stats.dist;
@@ -3749,6 +3903,9 @@
   if (filter_intra_selected_flag) {
     mbmi->mode = DC_PRED;
     mbmi->tx_size = best_tx_size;
+#if CONFIG_LGT_FROM_PRED
+    mbmi->use_lgt = use_lgt_when_selected;
+#endif  // CONFIG_LGT_FROM_PRED
     mbmi->filter_intra_mode_info.use_filter_intra_mode[0] =
         filter_intra_mode_info.use_filter_intra_mode[0];
     mbmi->filter_intra_mode_info.filter_intra_mode[0] =
@@ -3769,6 +3926,9 @@
     int64_t best_rd_in, int8_t angle_delta, int max_angle_delta, int *rate,
     RD_STATS *rd_stats, int *best_angle_delta, TX_SIZE *best_tx_size,
     TX_TYPE *best_tx_type,
+#if CONFIG_LGT_FROM_PRED
+    int *use_lgt_when_selected,
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_INTRA_INTERP
     INTRA_FILTER *best_filter,
 #endif  // CONFIG_INTRA_INTERP
@@ -3801,6 +3961,9 @@
     *best_filter = mbmi->intra_filter;
 #endif  // CONFIG_INTRA_INTERP
     *best_tx_type = mbmi->tx_type;
+#if CONFIG_LGT_FROM_PRED
+    *use_lgt_when_selected = mbmi->use_lgt;
+#endif  // CONFIG_LGT_FROM_PRED
     *rate = this_rate;
     rd_stats->rate = tokenonly_rd_stats.rate;
     rd_stats->dist = tokenonly_rd_stats.dist;
@@ -3830,6 +3993,9 @@
   int64_t this_rd, best_rd_in, rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
   TX_SIZE best_tx_size = mic->mbmi.tx_size;
   TX_TYPE best_tx_type = mbmi->tx_type;
+#if CONFIG_LGT_FROM_PRED
+  int use_lgt_when_selected = mbmi->use_lgt;
+#endif  // CONFIG_LGT_FROM_PRED
   for (i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;
@@ -3852,6 +4018,9 @@
 #endif  // CONFIG_INTRA_INTERP
             best_rd_in, (1 - 2 * i) * angle_delta, MAX_ANGLE_DELTA, rate,
            rd_stats, &best_angle_delta, &best_tx_size, &best_tx_type,
+#if CONFIG_LGT_FROM_PRED
+            &use_lgt_when_selected,
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_INTRA_INTERP
             &best_filter,
 #endif  // CONFIG_INTRA_INTERP
@@ -3893,6 +4062,9 @@
 #endif  // CONFIG_INTRA_INTERP
               best_rd, (1 - 2 * i) * angle_delta, MAX_ANGLE_DELTA, rate,
               rd_stats, &best_angle_delta, &best_tx_size, &best_tx_type,
+#if CONFIG_LGT_FROM_PRED
+              &use_lgt_when_selected,
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_INTRA_INTERP
               &best_filter,
 #endif  // CONFIG_INTRA_INTERP
@@ -3914,8 +4086,11 @@
               cpi, x, bsize,
               mode_cost + x->intra_filter_cost[intra_filter_ctx][filter], best_rd,
               best_angle_delta, MAX_ANGLE_DELTA, rate, rd_stats,
-              &best_angle_delta, &best_tx_size, &best_tx_type, &best_filter,
-              &best_rd, best_model_rd);
+              &best_angle_delta, &best_tx_size, &best_tx_type,
+#if CONFIG_LGT_FROM_PRED
+              &use_lgt_when_selected,
+#endif  // CONFIG_LGT_FROM_PRED
+              &best_filter, &best_rd, best_model_rd);
             }
           }
         }
@@ -3927,6 +4102,9 @@
   mic->mbmi.intra_filter = best_filter;
 #endif  // CONFIG_INTRA_INTERP
   mbmi->tx_type = best_tx_type;
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = use_lgt_when_selected;
+#endif  // CONFIG_LGT_FROM_PRED
   return best_rd;
 }
@@ -4478,7 +4656,7 @@
   const int eob = p->eobs[block];
   av1_inverse_transform_block(xd, dqcoeff,
-#if CONFIG_LGT
+#if CONFIG_LGT_FROM_PRED
                               xd->mi[0]->mbmi.mode,
 #endif
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
@@ -5065,18 +5243,34 @@
       !xd->lossless[xd->mi[0]->mbmi.segment_id]) {
     const int ext_tx_set = get_ext_tx_set(mbmi->min_tx_size, bsize, is_inter,
                                           cm->reduced_tx_set_used);
-    if (is_inter) {
-      if (ext_tx_set > 0)
+#if CONFIG_LGT_FROM_PRED
+    if (is_lgt_allowed(mbmi->mode, mbmi->min_tx_size)) {  // use_lgt flag cost
+      if (LGT_FROM_PRED_INTRA && !is_inter && ext_tx_set > 0 &&
+          ALLOW_INTRA_EXT_TX)
+        rd_stats->rate += x->intra_lgt_cost[txsize_sqr_map[mbmi->min_tx_size]]
+                                           [mbmi->mode][mbmi->use_lgt];
+      if (LGT_FROM_PRED_INTER && is_inter && ext_tx_set > 0)
        rd_stats->rate +=
-            x->inter_tx_type_costs[ext_tx_set]
-                                  [txsize_sqr_map[mbmi->min_tx_size]]
-                                  [mbmi->tx_type];
-    } else {
-      if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
-        rd_stats->rate += x->intra_tx_type_costs[ext_tx_set][mbmi->min_tx_size]
-                                                [mbmi->mode][mbmi->tx_type];
+            x->inter_lgt_cost[txsize_sqr_map[mbmi->min_tx_size]][mbmi->use_lgt];
    }
+    if (!mbmi->use_lgt) {  // tx_type is not coded when LGT is used
+#endif  // CONFIG_LGT_FROM_PRED
+      if (is_inter) {
+        if (ext_tx_set > 0)
+          rd_stats->rate +=
+              x->inter_tx_type_costs[ext_tx_set]
+                                    [txsize_sqr_map[mbmi->min_tx_size]]
+                                    [mbmi->tx_type];
+      } else {
+        if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
+          rd_stats->rate +=
+              x->intra_tx_type_costs[ext_tx_set][mbmi->min_tx_size][mbmi->mode]
+                                    [mbmi->tx_type];
+      }
    }
+#if CONFIG_LGT_FROM_PRED
+  }
+#endif  // CONFIG_LGT_FROM_PRED
 #else
  if (mbmi->min_tx_size < TX_32X32 && !xd->lossless[xd->mi[0]->mbmi.segment_id])
    rd_stats->rate += x->inter_tx_type_costs[mbmi->min_tx_size][mbmi->tx_type];
@@ -5284,6 +5478,14 @@
   av1_invalid_rd_stats(rd_stats);
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = 0;
+  int search_lgt = is_inter
+                       ? LGT_FROM_PRED_INTER &&
+                             (cpi->sf.tx_type_search.prune_mode == NO_PRUNE)
+                       : LGT_FROM_PRED_INTRA && ALLOW_INTRA_EXT_TX;
+#endif  // CONFIG_LGT_FROM_PRED
+
   const uint32_t hash = get_block_residue_hash(x, bsize);
   TX_RD_RECORD *tx_rd_record = &x->tx_rd_record;
@@ -5379,6 +5581,26 @@
   assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
   if (!found) return;
+#if CONFIG_LGT_FROM_PRED
+  if (search_lgt && is_lgt_allowed(mbmi->mode, max_tx_size) &&
+      !cm->reduced_tx_set_used) {
+    RD_STATS this_rd_stats;
+    mbmi->use_lgt = 1;  // kept at 1 only if LGT beats the best result so far
+    rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, ref_best_rd, 0);
+    if (rd < best_rd) {
+      best_rd = rd;
+      *rd_stats = this_rd_stats;
+      best_tx = mbmi->tx_size;
+      best_min_tx_size = mbmi->min_tx_size;
+      memcpy(best_blk_skip, x->blk_skip[0], sizeof(best_blk_skip[0]) * n4);
+      for (idy = 0; idy < xd->n8_h; ++idy)
+        for (idx = 0; idx < xd->n8_w; ++idx)
+          best_tx_size[idy][idx] = mbmi->inter_tx_size[idy][idx];
+    } else {
+      mbmi->use_lgt = 0;
+    }
+  }
+#endif  // CONFIG_LGT_FROM_PRED
   // We found a candidate transform to use. Copy our results from the "best"
   // array into mbmi.
   mbmi->tx_type = best_tx_type;
@@ -8988,6 +9210,9 @@
   int compmode_interinter_cost = 0;
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
 #endif
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = 0;
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_INTERINTRA
   if (!cm->allow_interintra_compound && is_comp_interintra_pred)
@@ -9785,6 +10010,9 @@
   mbmi->use_intrabc = 0;
   mbmi->mv[0].as_int = 0;
 #endif  // CONFIG_INTRABC
+#if CONFIG_LGT_FROM_PRED
+  mbmi->use_lgt = 0;
+#endif  // CONFIG_LGT_FROM_PRED
   const int64_t intra_yrd =
       (bsize >= BLOCK_8X8 || unify_bsize)
@@ -11564,6 +11792,9 @@
 #endif  // CONFIG_VAR_TX
         best_mbmode.tx_type = mbmi->tx_type;
         best_mbmode.tx_size = mbmi->tx_size;
+#if CONFIG_LGT_FROM_PRED
+        best_mbmode.use_lgt = mbmi->use_lgt;  // keep LGT flag with best mode
+#endif  // CONFIG_LGT_FROM_PRED
 #if CONFIG_VAR_TX
         for (idy = 0; idy < xd->n8_h; ++idy)
           for (idx = 0; idx < xd->n8_w; ++idx)
@@ -11989,6 +12220,9 @@
mbmi->ref_mv_idx = 0;
mbmi->pred_mv[0].as_int = 0;
+#if CONFIG_LGT_FROM_PRED
+ mbmi->use_lgt = 0;
+#endif
mbmi->motion_mode = SIMPLE_TRANSLATION;
#if CONFIG_MOTION_VAR