| """ |
| Copyright (c) 2024, Alliance for Open Media. All rights reserved |
| |
| This source code is subject to the terms of the BSD 3-Clause Clear License |
| and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear |
| License was not distributed with this source code in the LICENSE file, you |
| can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the |
| Alliance for Open Media Patent License 1.0 was not distributed with this |
| source code in the PATENTS file, you can obtain it at |
| aomedia.org/license/patent-license/. |
| """ |
| import multiprocessing |
| import os |
| |
| from termcolor import cprint |
| |
| import parakit.config.user as user |
| from parakit.config.training import RATE_LIST |
| from parakit.entropy.file_collector import FileCollector |
| from parakit.entropy.trainer import Trainer |
| |
| |
| def train_task(train_info): |
| file_fullpath = train_info[0] |
| trainer = Trainer(file_fullpath, RATE_LIST) |
| trainer.run_rate_training_on_file() |
| print(f"Trained: {train_info[2]}", flush=True) |
| |
| |
| def run(path_ctxdata="./results/data", user_config_file="parameters.yaml"): |
| test_output_tag, _ = user.read_config_data(user_config_file) |
| fc = FileCollector(path_ctxdata, "csv", subtext=f"_{test_output_tag}.") |
| data_files = fc.get_files() |
| num_data_files = len(data_files) |
| cprint( |
| f"Training based on {num_data_files} csv files under {path_ctxdata}:", |
| attrs=["bold"], |
| ) |
| # prepare decoding task information |
| train_info = [] |
| for idx, file_csv in enumerate(data_files): |
| file_fullpath = f"{path_ctxdata}/{file_csv}" |
| # print(f'[{idx}]: {file_csv}') |
| train_info.append((file_fullpath, idx, file_csv)) |
| # run training in paralel using all available cores |
| num_cpu = os.cpu_count() |
| with multiprocessing.Pool(num_cpu) as pool: |
| pool.map(train_task, train_info) |
| cprint("Training complete!\n", "green", attrs=["bold"]) |
| |
| |
| if __name__ == "__main__": |
| run() |