blob: eb268c5482109235949663f6ab1a8ed1f6797212 [file] [log] [blame] [edit]
/*
* Copyright (c) 2021, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 3-Clause Clear License
* and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
* License was not distributed with this source code in the LICENSE file, you
* can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
* Alliance for Open Media Patent License 1.0 was not distributed with this
* source code in the PATENTS file, you can obtain it at
* aomedia.org/license/patent-license/.
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include "tools/txfm_analyzer/txfm_graph.h"
// Selects which generator gen_hybrid_code() dispatches to.
typedef enum CODE_TYPE {
// Handled by gen_code_c().
CODE_TYPE_C,
// Handled by gen_code_sse2().
CODE_TYPE_SSE2,
// Handled by gen_code_sse4_1().
CODE_TYPE_SSE4_1
} CODE_TYPE;
// Converts a weight into an integer angle index.
//
// Returns round(acos(|value|) / PI * mod). The sign of `value` is discarded,
// so callers that need it must inspect `value` themselves.
int get_cos_idx(double value, int mod) {
  const double angle = acos(fabs(value));
  const double scaled = angle / PI * mod;
  return (int)round(scaled);
}
char *cos_text_arr(double value, int mod, char *text, int size) {
int num = get_cos_idx(value, mod);
if (value < 0) {
snprintf(text, size, "-cospi[%2d]", num);
} else {
snprintf(text, size, " cospi[%2d]", num);
}
if (num == 0)
printf("v: %f -> %d/%d v==-1 is %d\n", value, num, mod, value == -1);
return text;
}
// Builds the identifier used for a packed pair of weights in the SSE2 output.
//
// The result has the form "cospi_<s0><k0>_<s1><k1>", where each <s> is "m"
// for a negative weight and "p" otherwise, and each <k> is the two-digit
// get_cos_idx() of the corresponding weight. Writes into `text` (bounded by
// `size`) and returns `text`.
char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) {
  const char *const sign0 = (w0 < 0) ? "m" : "p";
  const char *const sign1 = (w1 < 0) ? "m" : "p";
  const int k0 = get_cos_idx(w0, mod);
  const int k1 = get_cos_idx(w1, mod);
  snprintf(text, size, "cospi_%s%02d_%s%02d", sign0, k0, sign1, k1);
  return text;
}
// Builds the identifier used for a single weight in the SSE4.1 output.
//
// The result has the form "cospi_<s><k>", where <s> is "m" for a negative
// weight and "p" otherwise, and <k> is the two-digit get_cos_idx() of `w`.
// Writes into `text` (bounded by `size`) and returns `text`.
char *cos_text_sse4_1(double w, int mod, char *text, int size) {
  const char *const sign = (w < 0) ? "m" : "p";
  const int k = get_cos_idx(w, mod);
  snprintf(text, size, "cospi_%s%02d", sign, k);
  return text;
}
// Prints one C assignment statement for `node` to stdout.
//
// `buf0` names the buffer the inputs are read from and `buf1` the buffer the
// result is stored into. When both input weights are in {-1, 0, 1} the
// statement is a plain sum/difference/copy (wrapped in apply_value() if both
// weights are +/-1); otherwise a half_btf() call with cospi[] weights is
// printed.
void node_to_code_c(Node *node, const char *buf0, const char *buf1) {
  int trivial_weights = 0;  // weights whose magnitude is 0 or 1
  int unit_weights = 0;     // weights whose magnitude is exactly 1
  for (int k = 0; k < 2; ++k) {
    const double magnitude = fabs(node->inWeight[k]);
    if (magnitude == 1) {
      ++trivial_weights;
      ++unit_weights;
    } else if (magnitude == 0) {
      ++trivial_weights;
    }
  }

  if (trivial_weights != 2) {
    // At least one weight is a real cosine factor: emit a butterfly call.
    // The cos_text_arr() calls stay in the argument list because they may
    // print a diagnostic themselves.
    char w0[100];
    char w1[100];
    printf(
        " %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], "
        "cos_bit);\n",
        buf1, node->nodeIdx, cos_text_arr(node->inWeight[0], COS_MOD, w0, 100),
        buf0, node->inNodeIdx[0],
        cos_text_arr(node->inWeight[1], COS_MOD, w1, 100), buf0,
        node->inNodeIdx[1]);
    return;
  }

  const int two_operands = (unit_weights == 2);
  printf(" %s[%d] =", buf1, node->nodeIdx);
  if (two_operands) {
    printf(" apply_value(");
  }

  int is_first_term = 1;
  for (int k = 0; k < 2; ++k) {
    const double weight = node->inWeight[k];
    if (weight == 1) {
      if (is_first_term) {
        printf(" %s[%d]", buf0, node->inNodeIdx[k]);
      } else {
        printf(" + %s[%d]", buf0, node->inNodeIdx[k]);
      }
      is_first_term = 0;
    } else if (weight == -1) {
      if (is_first_term) {
        printf("-%s[%d]", buf0, node->inNodeIdx[k]);
      } else {
        printf(" - %s[%d]", buf0, node->inNodeIdx[k]);
      }
      is_first_term = 0;
    }
  }

  if (two_operands) {
    printf(", stage_range[stage])");
  }
  printf(";\n");
}
// Prints a complete scalar C transform function for the graph to stdout.
//
// node      - first node of the graph; nodes are addressed through
//             get_idx(stage, index, node_num).
// stage_num - number of stages in the graph.
// node_num  - number of nodes per stage (also the emitted `size`).
// type      - transform type, used only to build the function name through
//             get_fun_name().
//
// The emitted function alternates between the `output` and `step` buffers
// from stage to stage, chosen so that the final stage writes to `output`.
void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // The name is only needed while printing, so keep it on the stack; the
  // previous heap allocation was never released.
  char fun_name[100];
  get_fun_name(fun_name, 100, type, node_num);

  printf("\n");
  printf(
      "void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, "
      "const int8_t* stage_range) "
      "{\n",
      fun_name);
  printf(" assert(output != input);\n");
  printf(" const int32_t size = %d;\n", node_num);
  printf(" const int32_t *cospi = cospi_arr(cos_bit);\n");
  printf("\n");
  printf(" int32_t stage = 0;\n");
  printf(" int32_t *bf0, *bf1;\n");
  printf(" int32_t step[%d];\n", node_num);

  const char *buf0 = "bf0";
  const char *buf1 = "bf1";
  const char *input = "input";
  // Stages with this parity write to `output`; the others write to `step`.
  const int output_parity = (stage_num - 1) % 2;

  // Stage 0 only range-adjusts the input.
  printf("\n");
  printf(" // stage %d;\n", 0);
  printf(" apply_range(stage, input, %s, size, stage_range[stage]);\n", input);

  // Stage 1 reads directly from `input`, so it is emitted outside the loop.
  const int first_stage = 1;
  printf("\n");
  printf(" // stage %d;\n", first_stage);
  printf(" stage++;\n");
  if (first_stage % 2 == output_parity) {
    printf(" %s = output;\n", buf1);
  } else {
    printf(" %s = step;\n", buf1);
  }
  for (int ni = 0; ni < node_num; ni++) {
    const int idx = get_idx(first_stage, ni, node_num);
    node_to_code_c(node + idx, input, buf1);
  }
  printf(" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");

  // Remaining stages ping-pong between the two buffers.
  for (int si = 2; si < stage_num; si++) {
    printf("\n");
    printf(" // stage %d\n", si);
    printf(" stage++;\n");
    if (si % 2 == output_parity) {
      printf(" %s = step;\n", buf0);
      printf(" %s = output;\n", buf1);
    } else {
      printf(" %s = output;\n", buf0);
      printf(" %s = step;\n", buf1);
    }
    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      const int idx = get_idx(si, ni, node_num);
      node_to_code_c(node + idx, buf0, buf1);
    }
    // The last stage is followed by apply_range() instead of a range check.
    if (si != stage_num - 1) {
      printf(
          " range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
    }
  }
  printf(" apply_range(stage, input, output, size, stage_range[stage]);\n");
  printf("}\n");
}
// Prints one SSE2 assignment for a node whose weights are all in {-1, 0, 1}.
//
// Emits "<buf1>[idx] = <expr>;" where <expr> is a saturating 16-bit add or
// subtract of the inputs in `buf0`, a plain copy, or a negation written as
// (__zero - x). Weight pairs not listed below produce an empty expression.
void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
  const double w0 = node->inWeight[0];
  const double w1 = node->inWeight[1];
  const int in0 = node->inNodeIdx[0];
  const int in1 = node->inNodeIdx[1];

  printf(" %s[%2d] =", buf1, node->nodeIdx);
  if (w0 == 1) {
    if (w1 == 1) {
      printf(" _mm_adds_epi16(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (w1 == -1) {
      printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (w1 == 0) {
      printf(" %s[%d]", buf0, in0);
    }
  } else if (w0 == -1) {
    if (w1 == 1) {
      printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, in1, buf0, in0);
    } else if (w1 == 0) {
      printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, in0);
    }
  } else if (w0 == 0) {
    if (w1 == 1) {
      printf(" %s[%d]", buf0, in1);
    } else if (w1 == -1) {
      printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, in1);
    }
  }
  printf(";\n");
}
// Prints one btf_16_sse2() call that computes `node` and `partnerNode`.
//
// The first constant is built from node's weights in order. The second is
// built from the partner's weights, swapped when the partner's first input
// differs from node's first input, so both constants refer to the inputs in
// node's order.
void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0,
                            const char *buf1) {
  char node_const[100];
  char partner_const[100];

  const int same_first_input =
      (node->inNodeIdx[0] == partnerNode->inNodeIdx[0]);
  const double partner_first =
      same_first_input ? partnerNode->inWeight[0] : partnerNode->inWeight[1];
  const double partner_second =
      same_first_input ? partnerNode->inWeight[1] : partnerNode->inWeight[0];

  cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, node_const,
                100);
  cos_text_sse2(partner_first, partner_second, COS_MOD, partner_const, 100);

  // btf_16_sse2_type0(w0, w1, in0, in1, out0, out1)
  printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n",
         node_const, partner_const, buf0, node->inNodeIdx[0], buf0,
         node->inNodeIdx[1], buf1, node->nodeIdx, buf1, partnerNode->nodeIdx);
}
// Returns the node located (inNode[1]->nodeIdx - nodeIdx) elements away from
// `node`. NOTE(review): this presumably relies on the nodes of one stage
// being stored contiguously in nodeIdx order -- confirm against init_graph().
Node *get_partner_node(Node *node) {
  const int partner_idx = node->inNode[1]->nodeIdx;
  return node + (partner_idx - node->nodeIdx);
}
// Prints the SSE2 statement(s) for `node` unless it was already visited.
//
// Marks the node visited. A node with a non-trivial (cosine) weight is
// emitted together with its partner as one butterfly call; a node with two
// +/-1 weights is emitted followed by its partner; any other node is emitted
// on its own. Partners are marked visited so they are not emitted twice.
void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
  if (node->visited != 0) return;
  node->visited = 1;

  int trivial_weights = 0;  // magnitude 0 or 1
  int unit_weights = 0;     // magnitude exactly 1
  for (int k = 0; k < 2; ++k) {
    const double magnitude = fabs(node->inWeight[k]);
    if (magnitude == 1) {
      ++trivial_weights;
      ++unit_weights;
    } else if (magnitude == 0) {
      ++trivial_weights;
    }
  }

  if (trivial_weights != 2) {
    Node *partner = get_partner_node(node);
    partner->visited = 1;
    pair_node_to_code_sse2(node, partner, buf0, buf1);
    return;
  }

  if (unit_weights != 2) {
    single_node_to_code_sse2(node, buf0, buf1);
    return;
  }

  // Both weights are +/-1: the node has a partner.
  Node *partner = get_partner_node(node);
  partner->visited = 1;
  single_node_to_code_sse2(node, buf0, buf1);
  single_node_to_code_sse2(partner, buf0, buf1);
}
void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) {
int visited[65][65][2][2];
memset(visited, 0, sizeof(visited));
char text[100];
char text1[100];
char text2[100];
int size = 100;
printf("\n");
for (int si = 1; si < stage_num; si++) {
for (int ni = 0; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
int cnt = 0;
Node *node0 = node + idx;
if (node0->visited == 0) {
node0->visited = 1;
for (int i = 0; i < 2; i++) {
if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
cnt++;
}
if (cnt != 2) {
{
double w0 = node0->inWeight[0];
double w1 = node0->inWeight[1];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0;
int sgn1 = w1 < 0 ? 1 : 0;
if (!visited[idx0][idx1][sgn0][sgn1]) {
visited[idx0][idx1][sgn0][sgn1] = 1;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
cos_text_sse2(w0, w1, COS_MOD, text, size),
cos_text_arr(w0, COS_MOD, text1, size),
cos_text_arr(w1, COS_MOD, text2, size));
}
}
Node *node1 = get_partner_node(node0);
node1->visited = 1;
if (node1->inNode[0]->nodeIdx != node0->inNode[0]->nodeIdx) {
double w0 = node1->inWeight[0];
double w1 = node1->inWeight[1];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0;
int sgn1 = w1 < 0 ? 1 : 0;
if (!visited[idx1][idx0][sgn1][sgn0]) {
visited[idx1][idx0][sgn1][sgn0] = 1;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
cos_text_sse2(w1, w0, COS_MOD, text, size),
cos_text_arr(w1, COS_MOD, text1, size),
cos_text_arr(w0, COS_MOD, text2, size));
}
} else {
double w0 = node1->inWeight[0];
double w1 = node1->inWeight[1];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0;
int sgn1 = w1 < 0 ? 1 : 0;
if (!visited[idx0][idx1][sgn0][sgn1]) {
visited[idx0][idx1][sgn0][sgn1] = 1;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
cos_text_sse2(w0, w1, COS_MOD, text, size),
cos_text_arr(w0, COS_MOD, text1, size),
cos_text_arr(w1, COS_MOD, text2, size));
}
}
}
}
}
}
}
// Prints a complete SSE2 transform function for the graph to stdout.
//
// node      - first node of the graph; nodes are addressed through
//             get_idx(stage, index, node_num).
// stage_num - number of stages in the graph.
// node_num  - number of nodes per stage.
// type      - transform type, used only to build the function name through
//             get_fun_name().
//
// The visited flags of the graph are reset before the constant list is
// generated and again before the per-stage code is generated, because both
// passes use them to avoid emitting a node twice.
void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // The name is only needed while printing, so keep it on the stack; the
  // previous heap allocation was never released.
  char fun_name[100];
  get_fun_name(fun_name, 100, type, node_num);

  printf("\n");
  printf(
      "void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) "
      "{\n",
      fun_name);
  printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
  printf(" const __m128i __zero = _mm_setzero_si128();\n");
  printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");

  graph_reset_visited(node, stage_num, node_num);
  gen_cospi_list_sse2(node, stage_num, node_num);
  graph_reset_visited(node, stage_num, node_num);

  for (int si = 1; si < stage_num; si++) {
    char in[100];
    char out[100];
    printf("\n");
    printf(" // stage %d\n", si);

    // The first stage reads `input`; later stages read the previous x<N>.
    if (si == 1) {
      snprintf(in, 100, "%s", "input");
    } else {
      snprintf(in, 100, "x%d", si - 1);
    }

    // The last stage writes `output`; earlier stages declare a temporary.
    if (si == stage_num - 1) {
      snprintf(out, 100, "%s", "output");
    } else {
      snprintf(out, 100, "x%d", si);
      printf(" __m128i %s[%d];\n", out, node_num);
    }

    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      const int idx = get_idx(si, ni, node_num);
      node_to_code_sse2(node + idx, in, out);
    }
  }
  printf("}\n");
}
// Prints the "__m128i cospi_... = _mm_set1_epi32(...)" declarations needed by
// the SSE4.1 code for the graph, each distinct (index, sign) only once.
//
// Walks stages 1..stage_num-1. For every unvisited node that has at least one
// weight outside {-1, 0, 1}, each such weight gets a constant, and the node's
// partner is marked visited. Visited flags are modified as a side effect, so
// callers reset them before and after.
void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) {
  int seen[65][2];
  memset(seen, 0, sizeof(seen));
  char name[100];
  char value_text[100];
  const int size = 100;

  printf("\n");
  for (int si = 1; si < stage_num; si++) {
    for (int ni = 0; ni < node_num; ni++) {
      Node *current = node + get_idx(si, ni, node_num);
      if (current->visited != 0) continue;
      current->visited = 1;

      int trivial_weights = 0;
      for (int k = 0; k < 2; ++k) {
        const double magnitude = fabs(current->inWeight[k]);
        if (magnitude == 1 || magnitude == 0) ++trivial_weights;
      }
      if (trivial_weights == 2) continue;

      for (int k = 0; k < 2; ++k) {
        const double w = current->inWeight[k];
        const double magnitude = fabs(w);
        if (magnitude == 1 || magnitude == 0) continue;

        const int cos_idx = get_cos_idx(w, COS_MOD);
        const int negative = w < 0 ? 1 : 0;
        if (seen[cos_idx][negative]) continue;
        seen[cos_idx][negative] = 1;
        printf(" __m128i %s = _mm_set1_epi32(%s);\n",
               cos_text_sse4_1(w, COS_MOD, name, size),
               cos_text_arr(w, COS_MOD, value_text, size));
      }

      Node *partner = get_partner_node(current);
      partner->visited = 1;
    }
  }
}
// Prints one SSE4.1 assignment for a node whose weights are all in
// {-1, 0, 1}.
//
// Emits "<buf1>[idx] = <expr>;" where <expr> is a 32-bit add or subtract of
// the inputs in `buf0`, a plain copy, or a negation written as (__zero - x).
// Weight pairs not listed below produce an empty expression.
void single_node_to_code_sse4_1(Node *node, const char *buf0,
                                const char *buf1) {
  const double w0 = node->inWeight[0];
  const double w1 = node->inWeight[1];
  const int in0 = node->inNodeIdx[0];
  const int in1 = node->inNodeIdx[1];

  printf(" %s[%2d] =", buf1, node->nodeIdx);
  if (w0 == 1) {
    if (w1 == 1) {
      printf(" _mm_add_epi32(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (w1 == -1) {
      printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (w1 == 0) {
      printf(" %s[%d]", buf0, in0);
    }
  } else if (w0 == -1) {
    if (w1 == 1) {
      printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, in1, buf0, in0);
    } else if (w1 == 0) {
      printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, in0);
    }
  } else if (w0 == 0) {
    if (w1 == 1) {
      printf(" %s[%d]", buf0, in1);
    } else if (w1 == -1) {
      printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, in1);
    }
  }
  printf(";\n");
}
// Prints one SSE4.1 butterfly call that computes `node` and `partnerNode`.
//
// Both variants receive node's two weights, node's two inputs from `buf0`
// and the two output slots in `buf1`. The "type0" helper is chosen when the
// first weights of the node and its partner have opposite signs, otherwise
// the "type1" helper is used.
void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0,
                              const char *buf1) {
  char w0_name[100];
  char w1_name[100];
  cos_text_sse4_1(node->inWeight[0], COS_MOD, w0_name, 100);
  cos_text_sse4_1(node->inWeight[1], COS_MOD, w1_name, 100);

  const int opposite_signs =
      node->inWeight[0] * partnerNode->inWeight[0] < 0;

  if (!opposite_signs) {
    /* type1
     * cos sin
     * -sin cos
     */
    // btf_32_sse2_type1(w0, w1, in0, in1, out0, out1)
    // out0 = w0*in0 + w1*in1
    // out1 = w0*in1 - w1*in0
    printf(
        " btf_32_type1_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
        "__rounding, cos_bit);\n",
        w0_name, w1_name, buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1],
        buf1, node->nodeIdx, buf1, partnerNode->nodeIdx);
    return;
  }

  /* type0
   * cos sin
   * sin -cos
   */
  // btf_32_sse2_type0(w0, w1, in0, in1, out0, out1)
  // out0 = w0*in0 + w1*in1
  // out1 = -w0*in1 + w1*in0
  printf(
      " btf_32_type0_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
      "__rounding, cos_bit);\n",
      w0_name, w1_name, buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1],
      buf1, node->nodeIdx, buf1, partnerNode->nodeIdx);
}
// Prints the SSE4.1 statement(s) for `node` unless it was already visited.
//
// Marks the node visited. A node with a non-trivial (cosine) weight is
// emitted together with its partner as one butterfly call; a node with two
// +/-1 weights is emitted followed by its partner; any other node is emitted
// on its own. Partners are marked visited so they are not emitted twice.
void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) {
  if (node->visited != 0) return;
  node->visited = 1;

  int cnt = 0;   // weights whose magnitude is 0 or 1
  int cnt1 = 0;  // weights whose magnitude is exactly 1
  for (int i = 0; i < 2; i++) {
    if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
    if (fabs(node->inWeight[i]) == 1) cnt1++;
  }

  if (cnt != 2) {
    Node *partnerNode = get_partner_node(node);
    partnerNode->visited = 1;
    pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1);
    return;
  }

  if (cnt1 == 2) {
    // has a partner
    Node *partnerNode = get_partner_node(node);
    partnerNode->visited = 1;
    single_node_to_code_sse4_1(node, buf0, buf1);
    single_node_to_code_sse4_1(partnerNode, buf0, buf1);
  } else {
    // Use the 32-bit emitter here as well: the SSE2 emitter would print a
    // 16-bit saturating _mm_subs_epi16 for a -1 weight, which does not match
    // the 32-bit operations used everywhere else in this code path.
    single_node_to_code_sse4_1(node, buf0, buf1);
  }
}
// Prints a complete SSE4.1 transform function for the graph to stdout.
//
// node      - first node of the graph; nodes are addressed through
//             get_idx(stage, index, node_num).
// stage_num - number of stages in the graph.
// node_num  - number of nodes per stage.
// type      - transform type, used only to build the function name through
//             get_fun_name().
//
// The visited flags of the graph are reset before the constant list is
// generated and again before the per-stage code is generated, because both
// passes use them to avoid emitting a node twice.
void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // The name is only needed while printing, so keep it on the stack; the
  // previous heap allocation was never released.
  char fun_name[100];
  get_fun_name(fun_name, 100, type, node_num);

  printf("\n");
  printf(
      "void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) "
      "{\n",
      fun_name);
  printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
  printf(" const __m128i __zero = _mm_setzero_si128();\n");
  printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");

  graph_reset_visited(node, stage_num, node_num);
  gen_cospi_list_sse4_1(node, stage_num, node_num);
  graph_reset_visited(node, stage_num, node_num);

  for (int si = 1; si < stage_num; si++) {
    char in[100];
    char out[100];
    printf("\n");
    printf(" // stage %d\n", si);

    // The first stage reads `input`; later stages read the previous x<N>.
    if (si == 1) {
      snprintf(in, 100, "%s", "input");
    } else {
      snprintf(in, 100, "x%d", si - 1);
    }

    // The last stage writes `output`; earlier stages declare a temporary.
    if (si == stage_num - 1) {
      snprintf(out, 100, "%s", "output");
    } else {
      snprintf(out, 100, "x%d", si);
      printf(" __m128i %s[%d];\n", out, node_num);
    }

    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      const int idx = get_idx(si, ni, node_num);
      node_to_code_sse4_1(node + idx, in, out);
    }
  }
  printf("}\n");
}
// Builds the graph for one transform type and size, prints the generated
// code for the requested target, then frees the graph.
//
// The graph holds node_num * stage_num nodes, where stage_num comes from
// get_hybrid_stage_num(). A `code_type` that matches none of the known
// values prints nothing.
void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) {
  const int stage_num = get_hybrid_stage_num(txfm_type, node_num);
  Node *const graph = new Node[node_num * stage_num];

  init_graph(graph, stage_num, node_num);
  gen_hybrid_graph_1d(graph, stage_num, node_num, 0, 0, node_num, txfm_type);

  if (code_type == CODE_TYPE_C) {
    gen_code_c(graph, stage_num, node_num, txfm_type);
  } else if (code_type == CODE_TYPE_SSE2) {
    gen_code_sse2(graph, stage_num, node_num, txfm_type);
  } else if (code_type == CODE_TYPE_SSE4_1) {
    gen_code_sse4_1(graph, stage_num, node_num, txfm_type);
  }

  delete[] graph;
}
int main(int argc, char **argv) {
CODE_TYPE code_type = CODE_TYPE_SSE4_1;
for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) {
for (int node_num = 4; node_num <= 64; node_num *= 2) {
gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num);
}
}
return 0;
}