[tcq] Simplify get_rate*_c functions.
Split get_rate_dist* into separate luma and chroma versions.
diff --git a/av1/encoder/x86/trellis_quant_avx2.c b/av1/encoder/x86/trellis_quant_avx2.c
index 64b35b5..209efa9 100644
--- a/av1/encoder/x86/trellis_quant_avx2.c
+++ b/av1/encoder/x86/trellis_quant_avx2.c
@@ -550,10 +550,8 @@
struct tcq_lf_ctx_t *lf_ctx) {
__m256i upd_last_a;
__m256i upd_last_b;
-#if MORESTATES
__m256i upd_last_c;
__m256i upd_last_d;
-#endif
for (int st = 0; st < TOTALSTATES; st += 2) {
int absLevel0 = decision[st].absLevel;
int prevId0 = decision[st].prevId;
@@ -573,10 +571,8 @@
upd1 = _mm_insert_epi8(upd1, AOMMIN(absLevel1, INT8_MAX), 0);
__m256i upd01 = _mm256_castsi128_si256(upd0);
upd01 = _mm256_inserti128_si256(upd01, upd1, 1);
-#if MORESTATES
upd_last_d = upd_last_c;
upd_last_c = upd_last_b;
-#endif
upd_last_b = upd_last_a;
upd_last_a = upd01;
}
@@ -586,12 +582,126 @@
_mm256_storeu_si256((__m256i *)lf_ctx[4].last, upd_last_b);
_mm256_storeu_si256((__m256i *)lf_ctx[6].last, upd_last_a);
#else
+ (void)upd_last_d;
+ (void)upd_last_c;
_mm256_storeu_si256((__m256i *)lf_ctx[0].last, upd_last_b);
_mm256_storeu_si256((__m256i *)lf_ctx[2].last, upd_last_a);
#endif
}
-void av1_get_rate_dist_lf_avx2(
+void av1_get_rate_dist_lf_luma_avx2(  // Luma-only (plane 0) TCQ low-frequency rate/distortion: fills dist, rate_zero and rate.
+ const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,  // cost tables; pq supplies absLevel[0..3], deltaDist[0..3], qIdx
+ const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx,  // coeff_ctx[i] & 15 is state i's context offset from diag_ctx
+ int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int blk_pos,  // blk_pos is split into row/col with a row stride of (1 << bwl)
+ int coeff_sign, int32_t rate_zero[TOTALSTATES],  // out: per-state cost of coding a zero level
+ int32_t rate[2 * TOTALSTATES], int64_t dist[2 * TOTALSTATES]) {  // out: rate/dist of the two candidates of state i at 2*i and 2*i+1
+#define Z -1  // pshufb index with the high bit set writes a zero byte. NOTE(review): Z is not #undef'd in this function.
+ static const int8_t kShuf[2][32] = {
+ { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,  // [0]: per 128-bit lane, low bytes of the 8 uint16 first, then high bytes
+ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 },
+ { 0, 8, Z, Z, 1, 9, Z, Z, 2, 10, Z, Z, 3, 11, Z, Z,  // [1]: rejoin low byte k and high byte k+8 into a zero-extended 32-bit value
+ 4, 12, Z, Z, 5, 13, Z, Z, 6, 14, Z, Z, 7, 15, Z, Z }
+ };
+ const uint16_t(*cost_low)[LF_BASE_SYMBOLS][LF_SIG_COEF_CONTEXTS] =  // indexed [dq][symbol][ctx]; symbol 0 feeds rate_zero
+ txb_costs->base_lf_cost_low;
+ const uint16_t(*cost_low_tbl)[LF_SIG_COEF_CONTEXTS][DQ_CTXS][2] =  // indexed [idx][ctx][dq][candidate]
+ txb_costs->base_lf_cost_low_tbl;
+ const tran_low_t *absLevel = pq->absLevel;
+ const int64_t *deltaDist = pq->deltaDist;
+ const int plane = 0;  // luma; forwarded to the get_mid_cost_lf* helpers
+
+ // Distortions: states with (i & 2) == 0 get deltaDist[0]/[2], the others deltaDist[1]/[3].
+ __m256i delta_dist = _mm256_lddqu_si256((__m256i *)deltaDist);  // NOTE(review): casts away const; (const __m256i *) would suffice
+ __m256i dist02 = _mm256_permute4x64_epi64(delta_dist, 0x88);  // {d0, d2, d0, d2}
+ __m256i dist13 = _mm256_permute4x64_epi64(delta_dist, 0xDD);  // {d1, d3, d1, d3}
+ _mm256_storeu_si256((__m256i *)&dist[0], dist02);  // states 0, 1
+ _mm256_storeu_si256((__m256i *)&dist[4], dist13);  // states 2, 3
+#if MORESTATES
+ _mm256_storeu_si256((__m256i *)&dist[8], dist02);  // states 4, 5
+ _mm256_storeu_si256((__m256i *)&dist[12], dist13);  // states 6, 7
+#endif
+
+ // Zero-level costs: rate_zero[i] = cost_low[dq][0][diag_ctx + (coeff_ctx[i] & 15)], dq = (i & 2) ? 1 : 0.
+ __m256i cost_zero_dq0 =
+ _mm256_lddqu_si256((__m256i *)&cost_low[0][0][diag_ctx]);  // 16 uint16 costs for contexts diag_ctx .. diag_ctx + 15
+ __m256i cost_zero_dq1 =
+ _mm256_lddqu_si256((__m256i *)&cost_low[1][0][diag_ctx]);  // NOTE(review): assumes diag_ctx + 16 <= LF_SIG_COEF_CONTEXTS (or padding) — confirm
+ __m256i shuf = _mm256_lddqu_si256((__m256i *)kShuf[0]);
+ cost_zero_dq0 = _mm256_shuffle_epi8(cost_zero_dq0, shuf);  // per lane: low bytes, then high bytes
+ cost_zero_dq1 = _mm256_shuffle_epi8(cost_zero_dq1, shuf);
+ __m256i cost_dq0 = _mm256_permute4x64_epi64(cost_zero_dq0, 0xD8);  // lane 0 = low bytes of all 16 costs, lane 1 = high bytes
+ __m256i cost_dq1 = _mm256_permute4x64_epi64(cost_zero_dq1, 0xD8);
+ __m256i ctx = _mm256_castsi128_si256(_mm_loadu_si64(coeff_ctx));  // 8 context bytes; upper 128 bits undefined until the broadcast below
+ __m256i fifteen = _mm256_set1_epi8(15);
+ __m256i base_ctx = _mm256_and_si256(ctx, fifteen);  // context offsets 0..15, usable as pshufb indices
+ base_ctx = _mm256_permute4x64_epi64(base_ctx, 0);  // replicate the 8 offsets into every 64-bit quarter
+ __m256i ratez_dq0 = _mm256_shuffle_epi8(cost_dq0, base_ctx);  // lane 0 gathers low cost bytes, lane 1 high cost bytes
+ __m256i ratez_dq1 = _mm256_shuffle_epi8(cost_dq1, base_ctx);
+ __m256i ratez = _mm256_blend_epi16(ratez_dq0, ratez_dq1, 0xAA);  // even words (states with (i & 2) == 0) from dq0, odd words from dq1
+ ratez = _mm256_permute4x64_epi64(ratez, 0x88);  // each lane = {low bytes of states 0-7, high bytes of states 0-7}
+ __m256i shuf1 = _mm256_lddqu_si256((__m256i *)kShuf[1]);
+ ratez = _mm256_shuffle_epi8(ratez, shuf1);  // eight 32-bit zero costs, states 0..7
+#if MORESTATES
+ _mm256_storeu_si256((__m256i *)&rate_zero[0], ratez);
+#else
+ _mm_storeu_si128((__m128i *)&rate_zero[0], _mm256_castsi256_si128(ratez));  // only states 0..3 exist
+#endif
+
+ // Base rates: rate[2*i + k] = cost_low_tbl[idx][diag_ctx + (coeff_ctx[i] & 15)][dq][k], dq = (i & 2) ? 1 : 0.
+ int idx = AOMMIN(pq->qIdx - 1, 8);  // clamped to 8; NOTE(review): negative if qIdx == 0 — assumes callers guarantee qIdx >= 1
+ for (int i = 0; i < TOTALSTATES / 4; i++) {  // four states per iteration
+ int j = 4 * i;
+ int ctx0 = diag_ctx + (coeff_ctx[j + 0] & 15);
+ int ctx1 = diag_ctx + (coeff_ctx[j + 1] & 15);
+ int ctx2 = diag_ctx + (coeff_ctx[j + 2] & 15);
+ int ctx3 = diag_ctx + (coeff_ctx[j + 3] & 15);
+ __m128i rate_01 = _mm_loadu_si64(&cost_low_tbl[idx][ctx0][0]);  // only the low 32 bits (the [dq][0..1] pair) are used below
+ __m128i rate_23 = _mm_loadu_si64(&cost_low_tbl[idx][ctx1][0]);
+ __m128i rate_45 = _mm_loadu_si64(&cost_low_tbl[idx][ctx2][1]);  // NOTE(review): 8-byte load of a 4-byte pair; may over-read the table's last entry — confirm padding
+ __m128i rate_67 = _mm_loadu_si64(&cost_low_tbl[idx][ctx3][1]);
+ __m128i rate_0123 = _mm_unpacklo_epi32(rate_01, rate_23);  // low 64 bits = {pair of state j, pair of state j+1}
+ __m128i rate_4567 = _mm_unpacklo_epi32(rate_45, rate_67);  // low 64 bits = {pair of state j+2, pair of state j+3}
+ __m128i c_zero = _mm_setzero_si128();
+ rate_0123 = _mm_unpacklo_epi16(rate_0123, c_zero);  // zero-extend the four uint16 costs to int32
+ rate_4567 = _mm_unpacklo_epi16(rate_4567, c_zero);
+ _mm_storeu_si128((__m128i *)&rate[8 * i], rate_0123);
+ _mm_storeu_si128((__m128i *)&rate[8 * i + 4], rate_4567);
+ }
+
+ const int row = blk_pos >> bwl;  // row/col of blk_pos with a row stride of 1 << bwl
+ const int col = blk_pos - (row << bwl);
+ const bool dc_2dtx = (blk_pos == 0);  // position 0 counts as DC for every tx_class
+ const bool dc_hor = (col == 0) && tx_class == TX_CLASS_HORIZ;  // horizontal class: whole column 0
+ const bool dc_ver = (row == 0) && tx_class == TX_CLASS_VERT;  // vertical class: whole row 0
+ const bool is_dc_coeff = dc_2dtx || dc_hor || dc_ver;
+ if (is_dc_coeff) {  // DC position: add get_mid_cost_lf_dc() for both candidates of every state
+ for (int i = 0; i < TOTALSTATES; i++) {
+ int a0 = i & 2 ? 1 : 0;  // state i's candidates are absLevel[a0] and absLevel[a0 + 2]
+ int a1 = a0 + 2;
+ int mid_cost0 =
+ get_mid_cost_lf_dc(blk_pos, absLevel[a0], coeff_sign, coeff_ctx[i],
+ dc_sign_ctx, txb_costs, tmp_sign, plane);
+ int mid_cost1 =
+ get_mid_cost_lf_dc(blk_pos, absLevel[a1], coeff_sign, coeff_ctx[i],
+ dc_sign_ctx, txb_costs, tmp_sign, plane);
+ rate[2 * i] += mid_cost0;
+ rate[2 * i + 1] += mid_cost1;
+ }
+ } else if (idx > 4) {  // non-DC: get_mid_cost_lf() is only added when qIdx > 5
+ for (int i = 0; i < TOTALSTATES; i++) {
+ int a0 = i & 2 ? 1 : 0;  // same candidate mapping as the DC branch
+ int a1 = a0 + 2;
+ int mid_cost0 =
+ get_mid_cost_lf(absLevel[a0], coeff_ctx[i], txb_costs, plane);
+ int mid_cost1 =
+ get_mid_cost_lf(absLevel[a1], coeff_ctx[i], txb_costs, plane);
+ rate[2 * i] += mid_cost0;
+ rate[2 * i + 1] += mid_cost1;
+ }
+ }
+}
+
+void av1_get_rate_dist_lf_chroma_avx2(
const struct LV_MAP_COEFF_COST *txb_costs, const struct prequant_t *pq,
const uint8_t coeff_ctx[TOTALSTATES + 4], int diag_ctx, int dc_sign_ctx,
int32_t *tmp_sign, int bwl, TX_CLASS tx_class, int plane, int blk_pos,