blob: 792b700c289aa22c0db9bd23505d2d703a63c9ea [file] [log] [blame] [edit]
"""
Copyright (c) 2024, Alliance for Open Media. All rights reserved
This source code is subject to the terms of the BSD 3-Clause Clear License
and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
License was not distributed with this source code in the LICENSE file, you
can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
Alliance for Open Media Patent License 1.0 was not distributed with this
source code in the PATENTS file, you can obtain it at
aomedia.org/license/patent-license/.
"""
import os
from termcolor import cprint
import parakit.tasks.collect_results as collect
import parakit.tasks.decoding as decode
import parakit.tasks.generate_tables as generate_table
import parakit.tasks.training as train
# Locations used by the unit test. They are relative paths, so they resolve
# against the current working directory at run time.
_UNIT_TEST_DIR = "./unit_test"

# Bitstreams fed to the decoder in step 1.
PATH_BITSTREAM = f"{_UNIT_TEST_DIR}/bitstreams"
# Context data collected by the decoder and consumed by the later steps.
PATH_CTXDATA = f"{_UNIT_TEST_DIR}/data"
# Directory into which the generated context table is written.
PATH_TABLE = _UNIT_TEST_DIR
# Configuration file passed to every pipeline step.
CFG_FILE = f"{_UNIT_TEST_DIR}/parameters_unit_test.yaml"
# Generated header whose presence in PATH_TABLE marks the test as passed.
TABLE_FILE = "Context-Table_Combined_Result_Unit-Test.h"
def main():
    """Run the full pipeline on the bundled unit-test data and report the result.

    The steps run in order, each configured by CFG_FILE:
      1. decode the bitstreams in PATH_BITSTREAM to collect context data
         into PATH_CTXDATA;
      2. run training on that data and create the result report;
      3. collect the results;
      4. generate the context tables into PATH_TABLE.

    The test passes when TABLE_FILE exists in PATH_TABLE after the run. Any
    copy left over from a previous run is deleted first, so a stale file
    cannot hide a failure of the current run.

    Returns:
        0 if the table file was generated, 1 otherwise. The value is suitable
        as a process exit status.
    """
    cprint(
        "-------------------- UNIT TEST ----------------------", "black", attrs=["bold"]
    )
    table_file = os.path.join(PATH_TABLE, TABLE_FILE)
    # Remove output from an earlier run. Otherwise the existence check below
    # would pass even if this run never regenerated the table.
    try:
        os.remove(table_file)
    except FileNotFoundError:
        pass

    # Step 1: run decoder to collect data
    decode.run(
        path_bitstream=PATH_BITSTREAM,
        path_ctx_data=PATH_CTXDATA,
        user_config_file=CFG_FILE,
    )
    # Step 2: run training and create result report
    train.run(path_ctxdata=PATH_CTXDATA, user_config_file=CFG_FILE)
    # Step 3: collect results
    collect.run(path_ctxdata=PATH_CTXDATA, user_config_file=CFG_FILE)
    # Step 4: generate context tables
    generate_table.run(
        path_ctxdata=PATH_CTXDATA, path_table=PATH_TABLE, user_config_file=CFG_FILE
    )

    # The table file must have been (re)created by this run.
    passed = os.path.exists(table_file)
    if passed:
        cprint("Unit test successful!", "green", attrs=["bold"])
    else:
        cprint(
            f"Unit test failed: {TABLE_FILE} cannot be found.", "red", attrs=["bold"]
        )
    cprint(
        "-----------------------------------------------------", "black", attrs=["bold"]
    )
    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the result so shell/CI callers can detect a failed unit test.
    raise SystemExit(main())