CWG-E120-v1: Implementation ParaKit software on AVM research-v7.0.0
Implementation of the ParaKit training software based on CWG-E120-v1.
The training software is integrated into AVM software under ParaKit/ directory.
For usage and detailed instructions, the developers are referred to ParaKit/README.md file.
diff --git a/.gitignore b/.gitignore
index a488609..3ca43d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,11 @@
.vscode
.gitignore
TAGS
-__pycache__
\ No newline at end of file
+__pycache__
+ParaKit/venv/
+ParaKit/bitstreams/*
+ParaKit/results/*
+ParaKit/unit_test/data/*
+ParaKit/unit_test/*.h
+ParaKit/binaries/*
+parakit.egg-info/
\ No newline at end of file
diff --git a/ParaKit/LICENSE b/ParaKit/LICENSE
new file mode 100644
index 0000000..d1a8710
--- /dev/null
+++ b/ParaKit/LICENSE
@@ -0,0 +1,30 @@
+BSD 3-Clause Clear License The Clear BSD License
+
+Copyright (c) 2024, Alliance for Open Media
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the Alliance for Open Media nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+
+NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
+THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/ParaKit/PATENTS b/ParaKit/PATENTS
new file mode 100644
index 0000000..acb5f97
--- /dev/null
+++ b/ParaKit/PATENTS
@@ -0,0 +1,107 @@
+Alliance for Open Media Patent License 1.0
+
+1. License Terms.
+
+1.1. Patent License. Subject to the terms and conditions of this License, each
+ Licensor, on behalf of itself and successors in interest and assigns,
+ grants Licensee a non-sublicensable, perpetual, worldwide, non-exclusive,
+ no-charge, royalty-free, irrevocable (except as expressly stated in this
+ License) patent license to its Necessary Claims to make, use, sell, offer
+ for sale, import or distribute any Implementation.
+
+1.2. Conditions.
+
+1.2.1. Availability. As a condition to the grant of rights to Licensee to make,
+ sell, offer for sale, import or distribute an Implementation under
+ Section 1.1, Licensee must make its Necessary Claims available under
+ this License, and must reproduce this License with any Implementation
+ as follows:
+
+ a. For distribution in source code, by including this License in the
+ root directory of the source code with its Implementation.
+
+ b. For distribution in any other form (including binary, object form,
+ and/or hardware description code (e.g., HDL, RTL, Gate Level Netlist,
+ GDSII, etc.)), by including this License in the documentation, legal
+ notices, and/or other written materials provided with the
+ Implementation.
+
+1.2.2. Additional Conditions. This license is directly from Licensor to
+ Licensee. Licensee acknowledges as a condition of benefiting from it
+ that no rights from Licensor are received from suppliers, distributors,
+ or otherwise in connection with this License.
+
+1.3. Defensive Termination. If any Licensee, its Affiliates, or its agents
+ initiates patent litigation or files, maintains, or voluntarily
+ participates in a lawsuit against another entity or any person asserting
+ that any Implementation infringes Necessary Claims, any patent licenses
+ granted under this License directly to the Licensee are immediately
+ terminated as of the date of the initiation of action unless 1) that suit
+ was in response to a corresponding suit regarding an Implementation first
+ brought against an initiating entity, or 2) that suit was brought to
+ enforce the terms of this License (including intervention in a third-party
+ action by a Licensee).
+
+1.4. Disclaimers. The Reference Implementation and Specification are provided
+ "AS IS" and without warranty. The entire risk as to implementing or
+ otherwise using the Reference Implementation or Specification is assumed
+ by the implementer and user. Licensor expressly disclaims any warranties
+ (express, implied, or otherwise), including implied warranties of
+ merchantability, non-infringement, fitness for a particular purpose, or
+ title, related to the material. IN NO EVENT WILL LICENSOR BE LIABLE TO
+ ANY OTHER PARTY FOR LOST PROFITS OR ANY FORM OF INDIRECT, SPECIAL,
+ INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER FROM ANY CAUSES OF
+ ACTION OF ANY KIND WITH RESPECT TO THIS LICENSE, WHETHER BASED ON BREACH
+ OF CONTRACT, TORT (INCLUDING NEGLIGENCE), OR OTHERWISE, AND WHETHER OR
+ NOT THE OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+2. Definitions.
+
+2.1. Affiliate. "Affiliate" means an entity that directly or indirectly
+ Controls, is Controlled by, or is under common Control of that party.
+
+2.2. Control. "Control" means direct or indirect control of more than 50% of
+ the voting power to elect directors of that corporation, or for any other
+ entity, the power to direct management of such entity.
+
+2.3. Decoder. "Decoder" means any decoder that conforms fully with all
+ non-optional portions of the Specification.
+
+2.4. Encoder. "Encoder" means any encoder that produces a bitstream that can
+ be decoded by a Decoder only to the extent it produces such a bitstream.
+
+2.5. Final Deliverable. "Final Deliverable" means the final version of a
+ deliverable approved by the Alliance for Open Media as a Final
+ Deliverable.
+
+2.6. Implementation. "Implementation" means any implementation, including the
+ Reference Implementation, that is an Encoder and/or a Decoder. An
+ Implementation also includes components of an Implementation only to the
+ extent they are used as part of an Implementation.
+
+2.7. License. "License" means this license.
+
+2.8. Licensee. "Licensee" means any person or entity who exercises patent
+ rights granted under this License.
+
+2.9. Licensor. "Licensor" means (i) any Licensee that makes, sells, offers
+ for sale, imports or distributes any Implementation, or (ii) a person
+ or entity that has a licensing obligation to the Implementation as a
+ result of its membership and/or participation in the Alliance for Open
+ Media working group that developed the Specification.
+
+2.10. Necessary Claims. "Necessary Claims" means all claims of patents or
+ patent applications, (a) that currently or at any time in the future,
+ are owned or controlled by the Licensor, and (b) (i) would be an
+ Essential Claim as defined by the W3C Policy as of February 5, 2004
+ (https://www.w3.org/Consortium/Patent-Policy-20040205/#def-essential)
+ as if the Specification was a W3C Recommendation; or (ii) are infringed
+ by the Reference Implementation.
+
+2.11. Reference Implementation. "Reference Implementation" means an Encoder
+ and/or Decoder released by the Alliance for Open Media as a Final
+ Deliverable.
+
+2.12. Specification. "Specification" means the specification designated by
+ the Alliance for Open Media as a Final Deliverable for which this
+ License was issued.
\ No newline at end of file
diff --git a/ParaKit/README.md b/ParaKit/README.md
new file mode 100644
index 0000000..a45e1a6
--- /dev/null
+++ b/ParaKit/README.md
@@ -0,0 +1,125 @@
+<p align="center">
+<img src="./logo.png" width=320>
+</p>
+
+# ParaKit
+ParaKit is a Python toolkit for training <u>P</u>robability and <u>A</u>daptation <u>R</u>ate <u>A</u>djustment (PARA) parameters used to define context model initializations for the AV2 video coding standard, currently under development by the Alliance for Open Media (AOM).
+
+ParaKit is named after Apple's CWG-D115 proposal to the AOM's Coding Working Group, entitled "<u>PARA</u>: Probability Adaptation Rate Adjustment for Entropy Coding", where the "<u>Kit</u>" comes from the word toolkit, often referring to a collection of software tools.
+
+---
+
+## 1. Requirements
+ParaKit is built on top of the AV2 reference software (AVM), so the requirements include:
+
+- For the compilation of AVM software, it is recommended to install a recent version of `cmake` (e.g., version 3.29 and after).
+- For setting up necessary Python packages, it is required to install `Homebrew` (e.g., version 4.3.3 and after).
+
+ParaKit's training is data-driven, so it requires collecting data from AVM coded bitstreams. For this purpose, a separate branch (`research-v7.0.0-parakit`) in AVM is created as a reference implementation that allows developers to collect data for a selection of contexts.
+[Section 5](#5-data-collection-guidelines-for-modifying-avm) below provides some instructions on how to modify the ParaKit-AVM codebase to collect data for the context(s) of interest.
+
+After making necessary modifications to the AVM for data collection, ParaKit has the following two requirements to be able to run training:
+
+1. a binary `aomdec`*, compiled from a version of AVM for data collection, and
+2. compatible AVM bitstreams, from which the data will be collected using `aomdec`.
+
+*Note: `aomdec` needs to be compiled on the same platform that the developer will run ParaKit.
+
+---
+
+## 2. Installation
+<b>Step 1:</b> clone AVM, change directory and switch to `research-v7.0.0-parakit` branch.
+```
+git clone https://gitlab.com/AOMediaCodec/avm.git ParaKit-AVM
+cd ParaKit-AVM
+git checkout research-v7.0.0-parakit
+```
+
+<b>Step 2:</b> change directory to `ParaKit` and run the setup.sh script, which creates a python virtual environment `venv` and installs necessary python packages within `venv`.
+```
+cd ParaKit
+source setup.sh
+```
+
+<b>Step 3:</b> compile the AVM decoder by running the following command which will build `aomdec` under ./binaries directory.
+```
+source setup_decoder.sh
+```
+Note that `setup_decoder.sh` script is provided for convenience. The user can always copy a valid `aomdec` binary compiled from the modified AVM codebase (e.g., `research-v7.0.0-parakit` branch ).
+
+Also, make sure that `aomdec` binary under `binaries/` is executable, if not, run the following command.
+```
+sudo chmod +x ./binaries/aomdec
+```
+<b>Important note:</b> The sample AVM implementation in `research-v7.0.0-parakit` branch can collect data only for `eob_flag_cdf16` and `eob_flag_cdf32` contexts. To support other contexts, the developer needs to replace the binary compiled with the necessary changes to AVM. Please refer to [Section 5](#5-data-collection-guidelines-for-modifying-avm) for more details on modifying the AVM codebase.
+
+<b>Installation complete:</b> After the steps above, we are ready to use ParaKit and train for `eob_flag_cdf16` and `eob_flag_cdf32` contexts. You may now run the unit test as the next step.
+
+<b>Unit test (optional, but recommended):</b> run the following unit test script to further check if the installation is complete without any issues.
+```
+python run_unit_test.py
+```
+The unit test uses the two sample bitstreams under `unit_test/bitstreams` compatible with `research-v7.0.0-parakit`, as discussed in [Section 1](#1-requirements).
+
+## 3. Usage: running training via ParaKit
+<b>Step 1:</b> replace `aomdec` under `binaries/` with a new decoder binary (based on to collect data for desired contexts). The `setup_decoder.sh` script can be used to compile new binaries from modified AVM. Make sure that `aomdec` is executable (see Step 3 in [Section 2](#2-installation)).
+
+<b>Step 2:</b> copy compatible AVM bitstreams under `bitstreams/` directory. The only requirement is that each bitstream's filename should start with `Bin_`.
+
+<b>Step 3:</b> check and modify `parameters.yaml` file to set necessary user defined configurations. See [Section 4](#4-details-of-configuring-parametersyaml) for more details.
+
+<b>Step 4:</b> run the `run.py` python script.
+```
+python run.py
+```
+This step will run the whole training pipeline that:
+
+1. collects the data in csv format by decoding all the bitstreams in the `bitstreams/` directory. The csv files will be generated under `results/data/` directory,
+2. runs the training for each csv data under `results/data/` and generates a result report file in json format,
+3. collects and combines the results in json files, and
+4. generates the context initialization tables in a single `Context-Table_*.h` file under `results/`.
+
+<b>Step 5:</b> use the generated tables from the `Context-Table_*.h` file under `results/` by copying them into the AVM codebase for testing.
+
+---
+
+## 4. Details of configuring parameters.yaml
+ParaKit requires the `parameters.yaml` file present in the main directory.
+The sample `./parameters.yaml` provided in the repository is configured to train for `eob_flag_cdf16` and `eob_flag_cdf32` contexts as follows:
+```
+TEST_OUTPUT_TAG: "Test-Tag"
+BITSTREAM_EXTENSION: "av1"
+DESIRED_CTX_LIST:
+ - eob_flag_cdf16
+ - eob_flag_cdf32
+eob_flag_cdf16: "av1_default_eob_multi16_cdfs"
+eob_flag_cdf32: "av1_default_eob_multi32_cdfs"
+```
+where the mandatory fields are:
+
+- `TEST_OUTPUT_TAG` is the tag used to identify a test (this tag appears in the resulting generated context table `results/Context-Table_*.h`),
+- `BITSTREAM_EXTENSION` specifies the extension of bitstreams copied under `bitstreams/`,
+- `DESIRED_CTX_LIST` specifies the context(s) to be trained,
+- `eob_flag_cdf16: "av1_default_eob_multi16_cdfs"` and `eob_flag_cdf32: "av1_default_eob_multi32_cdfs"` define the context name to context table mapping.
+
+<b>Important note:</b> The developer needs to make sure that the context names (e.g., `eob_flag_cdf16` or `eob_flag_cdf32`) follow the same convention in the ParaKit's AVM decoder (`aomdec`).
+The csv data files obtained from `aomdec` are in `Stat_context_name_*.csv` format, where in the above example, `context_name` is replaced by `eob_flag_cdf16` or `eob_flag_cdf32`.
+
+---
+
+## 5. Data collection: guidelines for modifying AVM
+The data collection requires some modifications to AVM decoder implementation. For this purpose, `research-v7.0.0-parakit` branch is created as a reference implementation based on AVM.
+
+In the `research-v7.0.0-parakit` branch, the basic data collection module is implemented in `aom_read_symbol_probdata` function by extending the existing `aom_read_symbol` function in AVM. All the changes related to data collection are implemented under the `CONFIG_PARAKIT_COLLECT_DATA` macro. The comments including `@ParaKit` text provides additional information to guide developers on how to extend data collection for different contexts.
+
+The `research-v7.0.0-parakit` branch implements the necessary changes on top `research-v7.0.0` tag to collect data specifically for `eob_flag_cdf16` and `eob_flag_cdf32` context groups.
+Developers can extend this to add support for new (or any other) contexts on by following the changes under `CONFIG_PARAKIT_COLLECT_DATA` macro and instructions in the comments by searching the text `@ParaKit` on their local AVM version.
+
+---
+
+## Contact
+Please contact Hilmi Egilmez for any questions regarding the use of ParaKit.
+
+E-mail: h_egilmez@apple.com
+
+---
diff --git a/ParaKit/logo.png b/ParaKit/logo.png
new file mode 100644
index 0000000..b869b40
--- /dev/null
+++ b/ParaKit/logo.png
Binary files differ
diff --git a/ParaKit/parameters.yaml b/ParaKit/parameters.yaml
new file mode 100644
index 0000000..f70c292
--- /dev/null
+++ b/ParaKit/parameters.yaml
@@ -0,0 +1,7 @@
+TEST_OUTPUT_TAG: "Test-Tag"
+BITSTREAM_EXTENSION: "av1"
+DESIRED_CTX_LIST:
+ - eob_flag_cdf16
+ - eob_flag_cdf32
+eob_flag_cdf16: "av1_default_eob_multi16_cdfs"
+eob_flag_cdf32: "av1_default_eob_multi32_cdfs"
diff --git a/ParaKit/pyproject.toml b/ParaKit/pyproject.toml
new file mode 100644
index 0000000..7e58f13
--- /dev/null
+++ b/ParaKit/pyproject.toml
@@ -0,0 +1,13 @@
+[build-system]
+requires = ["setuptools>=64", "setuptools_scm>=8"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "ParaKit"
+version = "1.0.0"
+authors = [
+ { name="Hilmi E. Egilmez", email="hegilmez@apple.com" },
+]
+description = "ParaKit is a Python toolkit for training Probability and Adaptation Rate Adjustment (PARA)."
+readme = "README.md"
+requires-python = ">=3.8"
diff --git a/ParaKit/requirements.txt b/ParaKit/requirements.txt
new file mode 100644
index 0000000..b6c9221
--- /dev/null
+++ b/ParaKit/requirements.txt
@@ -0,0 +1,9 @@
+numpy==1.24.4
+pandas==1.2.4
+pandas-stubs==2.2.2.240603
+python-dateutil==2.8.2
+pytz==2024.1
+PyYAML==6.0.1
+types-PyYAML==6.0.12.20240311
+six==1.16.0
+termcolor==2.4.0
diff --git a/ParaKit/run.py b/ParaKit/run.py
new file mode 100644
index 0000000..4b784a0
--- /dev/null
+++ b/ParaKit/run.py
@@ -0,0 +1,56 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+from datetime import datetime
+
+from termcolor import cprint
+
+import parakit.tasks.collect_results as collect
+import parakit.tasks.decoding as decode
+import parakit.tasks.generate_tables as generate_table
+import parakit.tasks.training as train
+
+
+def print_elapsed_time(start, end):
+ total_time = end - start
+ hours = total_time.seconds // 3600
+ mins = total_time.seconds // 60 % 60
+ sec = total_time.seconds % 60
+ print(f"Elapsed time: {total_time.days}d {hours}h:{mins}m:{sec}s")
+
+
+def main():
+ cprint(
+ "------------------ RUN TRAINING ---------------------", "black", attrs=["bold"]
+ )
+ start_time = datetime.now()
+
+ # Step 1: run decoder to collect data
+ decode.run()
+
+ # Step 2: run training and create result report
+ train.run()
+
+ # Step 3: collect results
+ collect.run()
+
+ # Step 4: generate context tables
+ generate_table.run()
+
+ end_time = datetime.now()
+ print_elapsed_time(start_time, end_time)
+ cprint(
+ "-----------------------------------------------------", "black", attrs=["bold"]
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ParaKit/run_unit_test.py b/ParaKit/run_unit_test.py
new file mode 100644
index 0000000..792b700
--- /dev/null
+++ b/ParaKit/run_unit_test.py
@@ -0,0 +1,65 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import os
+
+from termcolor import cprint
+
+import parakit.tasks.collect_results as collect
+import parakit.tasks.decoding as decode
+import parakit.tasks.generate_tables as generate_table
+import parakit.tasks.training as train
+
+PATH_BITSTREAM = "./unit_test/bitstreams"
+PATH_CTXDATA = "./unit_test/data"
+PATH_TABLE = "./unit_test"
+CFG_FILE = "./unit_test/parameters_unit_test.yaml"
+TABLE_FILE = "Context-Table_Combined_Result_Unit-Test.h"
+
+
+def main():
+ cprint(
+ "-------------------- UNIT TEST ----------------------", "black", attrs=["bold"]
+ )
+ # Step 1: run decoder to collect data
+ decode.run(
+ path_bitstream=PATH_BITSTREAM,
+ path_ctx_data=PATH_CTXDATA,
+ user_config_file=CFG_FILE,
+ )
+
+ # Step 2: run training and create result report
+ train.run(path_ctxdata=PATH_CTXDATA, user_config_file=CFG_FILE)
+
+ # Step 3: collect results
+ collect.run(path_ctxdata=PATH_CTXDATA, user_config_file=CFG_FILE)
+
+ # Step 4: generate context tables
+ generate_table.run(
+ path_ctxdata=PATH_CTXDATA, path_table=PATH_TABLE, user_config_file=CFG_FILE
+ )
+
+ # Check if TABLE_FILE exists after unit test
+ table_file = f"{PATH_TABLE}/{TABLE_FILE}"
+ check_file = os.path.exists(table_file)
+ if check_file:
+ cprint("Unit test successful!", "green", attrs=["bold"])
+ else:
+ cprint(
+ f"Unit test failed: {TABLE_FILE} cannot be found.", "red", attrs=["bold"]
+ )
+ cprint(
+ "-----------------------------------------------------", "black", attrs=["bold"]
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ParaKit/setup.sh b/ParaKit/setup.sh
new file mode 100644
index 0000000..704e639
--- /dev/null
+++ b/ParaKit/setup.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+echo "--- Installing ParaKit ---"
+if [ ! -d venv ]; then
+ pipdir=`which pip3`
+ if [ $? -ne 0 ]; then
+ echo "Installing pip3 with the latest version of python3"
+ brew install python3
+ fi
+ python3 -m venv venv
+ source venv/bin/activate
+ # update pip
+ pip install --upgrade pip
+ # install required packages
+ pip install -r requirements.txt
+ # install package locally
+ pip install -e .
+else
+ echo "venv exists: activating"
+ source venv/bin/activate
+fi
+mkdir -p binaries
+mkdir -p bitstreams
+mkdir -p results/
+mkdir -p results/data
+mkdir -p unit_test/data
+
+export PYTHONPATH=$(pwd)
+echo "Setup Complete!"
diff --git a/ParaKit/setup_decoder.sh b/ParaKit/setup_decoder.sh
new file mode 100644
index 0000000..a97b837
--- /dev/null
+++ b/ParaKit/setup_decoder.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+#create directories
+mkdir -p binaries
+
+if [ ! -f binaries/aomdec ]; then
+
+ #check if AVM directory is available based on LICENSE file
+ is_parent_avm=$(cat ../LICENSE | grep "Alliance for Open Media")
+ if [ -n "$is_parent_avm" ]; then
+ echo "AVM software is available in parent directory."
+ else
+ echo "Error: AVM is not available in parent directory."
+ exit 1
+ fi
+ # create build directory
+ mkdir -p binaries/build
+
+ # check Makefile
+ if [ ! -f binaries/build/Makefile ]; then
+ cmake -S ../ -B ./binaries/build -DCONFIG_PARAKIT_COLLECT_DATA=1
+ else
+ echo "Makefile exists: building aomdec..."
+ fi
+
+ # Makefile should exist
+ if [ -f binaries/build/Makefile ]; then
+ make aomdec -C ./binaries/build
+ else
+ echo "Error: Makefile does not exist cannot compile aomdec"
+ exit 1
+ fi
+
+ # copy aomdec under binaries
+ if [ -f binaries/build/aomdec ]; then
+ cp ./binaries/build/aomdec ./binaries/aomdec
+ else
+ echo "Error: aomdec does not exist under ./binaries/build/"
+ exit 1
+ fi
+
+ #clear build if aomdec is under binaries
+ if [ -f binaries/aomdec ]; then
+ rm -rf ./binaries/build/
+ echo "Compilation complete!"
+ else
+ echo "Error: aomdec does not exist under ./binaries/"
+ exit 1
+ fi
+else
+ echo "Compilation skipped, because ./binaries/aomdec exists (delete aomdec and rerun this script to recompile from parent directory)."
+fi
diff --git a/ParaKit/src/parakit/__init__.py b/ParaKit/src/parakit/__init__.py
new file mode 100644
index 0000000..fe2df03
--- /dev/null
+++ b/ParaKit/src/parakit/__init__.py
@@ -0,0 +1,11 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
diff --git a/ParaKit/src/parakit/config/__init__.py b/ParaKit/src/parakit/config/__init__.py
new file mode 100644
index 0000000..fe2df03
--- /dev/null
+++ b/ParaKit/src/parakit/config/__init__.py
@@ -0,0 +1,11 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
diff --git a/ParaKit/src/parakit/config/table.py b/ParaKit/src/parakit/config/table.py
new file mode 100644
index 0000000..c16baa4
--- /dev/null
+++ b/ParaKit/src/parakit/config/table.py
@@ -0,0 +1,21 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+DEFAULT_PROB_INITIALIZER = False
+DEFAULT_RATE_PARAMETER = False
+ZERO_RATE_PARAMETER = False
+MIN_NUM_DATA_SAMPLES_NEEDED = 10
+DESIRED_RATE_IDX = "best_rate_idx"
+
+REGULAR_CTX_GROUP_MAPPING = {
+ "eob_flag_cdf16": "av1_default_eob_multi16_cdfs",
+ "eob_flag_cdf32": "av1_default_eob_multi32_cdfs",
+}
diff --git a/ParaKit/src/parakit/config/training.py b/ParaKit/src/parakit/config/training.py
new file mode 100644
index 0000000..be6e812
--- /dev/null
+++ b/ParaKit/src/parakit/config/training.py
@@ -0,0 +1,140 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+RATE_LIST = [
+ (0, 0, 0),
+ (0, 0, -1),
+ (0, 0, -2),
+ (0, 0, 1),
+ (0, 0, 2),
+ (0, -1, 0),
+ (0, -1, -1),
+ (0, -1, -2),
+ (0, -1, 1),
+ (0, -1, 2),
+ (0, -2, 0),
+ (0, -2, -1),
+ (0, -2, -2),
+ (0, -2, 1),
+ (0, -2, 2),
+ (0, 1, 0),
+ (0, 1, -1),
+ (0, 1, -2),
+ (0, 1, 1),
+ (0, 1, 2),
+ (0, 2, 0),
+ (0, 2, -1),
+ (0, 2, -2),
+ (0, 2, 1),
+ (0, 2, 2),
+ (-1, 0, 0),
+ (-1, 0, -1),
+ (-1, 0, -2),
+ (-1, 0, 1),
+ (-1, 0, 2),
+ (-1, -1, 0),
+ (-1, -1, -1),
+ (-1, -1, -2),
+ (-1, -1, 1),
+ (-1, -1, 2),
+ (-1, -2, 0),
+ (-1, -2, -1),
+ (-1, -2, -2),
+ (-1, -2, 1),
+ (-1, -2, 2),
+ (-1, 1, 0),
+ (-1, 1, -1),
+ (-1, 1, -2),
+ (-1, 1, 1),
+ (-1, 1, 2),
+ (-1, 2, 0),
+ (-1, 2, -1),
+ (-1, 2, -2),
+ (-1, 2, 1),
+ (-1, 2, 2),
+ (-2, 0, 0),
+ (-2, 0, -1),
+ (-2, 0, -2),
+ (-2, 0, 1),
+ (-2, 0, 2),
+ (-2, -1, 0),
+ (-2, -1, -1),
+ (-2, -1, -2),
+ (-2, -1, 1),
+ (-2, -1, 2),
+ (-2, -2, 0),
+ (-2, -2, -1),
+ (-2, -2, -2),
+ (-2, -2, 1),
+ (-2, -2, 2),
+ (-2, 1, 0),
+ (-2, 1, -1),
+ (-2, 1, -2),
+ (-2, 1, 1),
+ (-2, 1, 2),
+ (-2, 2, 0),
+ (-2, 2, -1),
+ (-2, 2, -2),
+ (-2, 2, 1),
+ (-2, 2, 2),
+ (1, 0, 0),
+ (1, 0, -1),
+ (1, 0, -2),
+ (1, 0, 1),
+ (1, 0, 2),
+ (1, -1, 0),
+ (1, -1, -1),
+ (1, -1, -2),
+ (1, -1, 1),
+ (1, -1, 2),
+ (1, -2, 0),
+ (1, -2, -1),
+ (1, -2, -2),
+ (1, -2, 1),
+ (1, -2, 2),
+ (1, 1, 0),
+ (1, 1, -1),
+ (1, 1, -2),
+ (1, 1, 1),
+ (1, 1, 2),
+ (1, 2, 0),
+ (1, 2, -1),
+ (1, 2, -2),
+ (1, 2, 1),
+ (1, 2, 2),
+ (2, 0, 0),
+ (2, 0, -1),
+ (2, 0, -2),
+ (2, 0, 1),
+ (2, 0, 2),
+ (2, -1, 0),
+ (2, -1, -1),
+ (2, -1, -2),
+ (2, -1, 1),
+ (2, -1, 2),
+ (2, -2, 0),
+ (2, -2, -1),
+ (2, -2, -2),
+ (2, -2, 1),
+ (2, -2, 2),
+ (2, 1, 0),
+ (2, 1, -1),
+ (2, 1, -2),
+ (2, 1, 1),
+ (2, 1, 2),
+ (2, 2, 0),
+ (2, 2, -1),
+ (2, 2, -2),
+ (2, 2, 1),
+ (2, 2, 2),
+]
+CHANGE_INITIALIZERS = False
+MIN_NUM_DATA_SAMPLES_NEEDED = 10
diff --git a/ParaKit/src/parakit/config/user.py b/ParaKit/src/parakit/config/user.py
new file mode 100644
index 0000000..06c9ec7
--- /dev/null
+++ b/ParaKit/src/parakit/config/user.py
@@ -0,0 +1,30 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import yaml
+
+
+def read_config_data(cfg_file):
+ with open(cfg_file) as f:
+ data = yaml.load(f, Loader=yaml.SafeLoader)
+ return (data["TEST_OUTPUT_TAG"], data["DESIRED_CTX_LIST"])
+
+
+def read_config_decode(cfg_file):
+ with open(cfg_file) as f:
+ data = yaml.load(f, Loader=yaml.SafeLoader)
+ return data["BITSTREAM_EXTENSION"].replace(".", "")
+
+
+def read_config_context(cfg_file, context_name):
+ with open(cfg_file) as f:
+ data = yaml.load(f, Loader=yaml.SafeLoader)
+ return data[context_name]
diff --git a/ParaKit/src/parakit/entropy/__init__.py b/ParaKit/src/parakit/entropy/__init__.py
new file mode 100644
index 0000000..fe2df03
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/__init__.py
@@ -0,0 +1,11 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
diff --git a/ParaKit/src/parakit/entropy/codec_cdf_functions.py b/ParaKit/src/parakit/entropy/codec_cdf_functions.py
new file mode 100644
index 0000000..b4d0b1c
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/codec_cdf_functions.py
@@ -0,0 +1,203 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import numpy as np
+
+from parakit.entropy.codec_default_cdf import (
+ AV1_PROB_COST,
+ AV1_PROB_COST_SHIFT,
+ CDF_INIT_TOP,
+ CDF_PROB_BITS,
+ CDF_PROB_TOP,
+)
+
+
+def clog2(x):
+ """Ceiling of log2"""
+ if x <= 0:
+ raise ValueError("clog2 input error")
+ return (x - 1).bit_length()
+
+
+def flog2(x):
+ """Floor of log2"""
+ if x <= 0:
+ raise ValueError("flog2 input error")
+ return x.bit_length() - 1
+
+
+def update_cdfinv_av1(cdf, val, counter, nsymb, roffset=0):
+ """Python implementation of the following C code from AVM codec:
+ --------------------------------------------------------------
+ static INLINE void update_cdf(aom_cdf_prob *cdf, int8_t val, int nsymbs) {
+ int rate;
+ int i, tmp;
+ static const int nsymbs2speed[17] = { 0, 0, 1, 1, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2 };
+ assert(nsymbs < 17);
+ rate = 3 + (cdf[nsymbs] > 15) + (cdf[nsymbs] > 31) +
+ nsymbs2speed[nsymbs]; // + get_msb(nsymbs);
+ tmp = AOM_ICDF(0);
+ // Single loop (faster)
+ for (i = 0; i < nsymbs - 1; ++i) {
+ tmp = (i == val) ? 0 : tmp;
+ if (tmp < cdf[i]) {
+ cdf[i] -= ((cdf[i] - tmp) >> rate);
+ } else {
+ cdf[i] += ((tmp - cdf[i]) >> rate);
+ }
+ }
+ cdf[nsymbs] += (cdf[nsymbs] < 32);
+ }
+ --------------------------------------------------------------
+ """
+ nsymbs2speed = [0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+ rate = 3 + nsymbs2speed[nsymb]
+ if counter > 15:
+ rate = rate + 1
+ if counter > 31:
+ rate = rate + 1
+ rate = rate + roffset
+ tmp = CDF_INIT_TOP
+ for i in range(nsymb - 1):
+ if i == val:
+ tmp = 0
+ if tmp < cdf[i]:
+ cdf[i] -= (cdf[i] - tmp) // (2**rate)
+ else:
+ cdf[i] += (tmp - cdf[i]) // (2**rate)
+ return cdf
+
+
+def get_prob(num, den):
+ """Python implementation of the following C code from AVM codec:
+ --------------------------------------------------------------
+ static INLINE uint8_t get_prob(unsigned int num, unsigned int den) {
+ assert(den != 0);
+ {
+ const int p = (int)(((uint64_t)num * 256 + (den >> 1)) / den);
+ // (p > 255) ? 255 : (p < 1) ? 1 : p;
+ const int clipped_prob = p | ((255 - p) >> 23) | (p == 0);
+ return (uint8_t)clipped_prob;
+ }
+ }
+ --------------------------------------------------------------
+ """
+ p = int(((num * 256) + (den // 2)) / den)
+ if p > 255:
+ p = 255
+ if p < 1:
+ p = 1
+ return p
+
+
+def cost_literal_av1(n):
+ """Python implementation of the following C code from AVM codec:
+ --------------------------------------------------------------
+ define av1_cost_literal(n) ((n) * (1 << AV1_PROB_COST_SHIFT))
+ --------------------------------------------------------------
+ """
+ return n * (2**AV1_PROB_COST_SHIFT)
+
+
+def cost_symbol_av1(p15):
+ """Python implementation of the following C code from AVM codec:
+ --------------------------------------------------------------
+ static INLINE int av1_cost_symbol(aom_cdf_prob p15) {
+ // p15 can be out of range [1, CDF_PROB_TOP - 1]. Clamping it, so that the
+ // following cost calculation works correctly. Otherwise, if p15 =
+ // CDF_PROB_TOP, shift would be -1, and "p15 << shift" would be wrong.
+ p15 = (aom_cdf_prob)clamp(p15, 1, CDF_PROB_TOP - 1);
+ assert(0 < p15 && p15 < CDF_PROB_TOP);
+ const int shift = CDF_PROB_BITS - 1 - get_msb(p15);
+ const int prob = get_prob(p15 << shift, CDF_PROB_TOP);
+ assert(prob >= 128);
+ return av1_prob_cost[prob - 128] + av1_cost_literal(shift);
+ }
+ --------------------------------------------------------------
+ """
+ if p15 > (CDF_PROB_TOP - 1):
+ p15 = CDF_PROB_TOP - 1
+ if p15 < 1:
+ p15 = 1
+ msb = flog2(int(p15)) # or int(math.floor(math.log2(p15))) using math
+ shift = CDF_PROB_BITS - 1 - msb
+ prob_scaled = p15 * (2**shift)
+ prob = get_prob(prob_scaled, CDF_PROB_TOP)
+ if prob < 128:
+ raise ValueError(
+ f"Normalized probability value is less than 128 (prob={prob},msb={msb},prob_scaled={prob_scaled})"
+ )
+ return AV1_PROB_COST[prob - 128] + cost_literal_av1(shift)
+
+
+def pmf2cdfinv_av1(pmf):
+ """converts pmf to cdf-inverse"""
+ cdf = CDF_INIT_TOP - np.cumsum(pmf)
+ return cdf
+
+
+def cdfinv2pmf_av1(cdf_inv):
+ """converts cdf-inverse to pmf"""
+ cdf = np.insert(cdf_inv, 0, CDF_INIT_TOP)
+ pmf = np.diff(CDF_INIT_TOP - cdf)
+ return pmf
+
+
+def pmf2cdf_av1(pmf):
+ """converts pmf to cdf"""
+ return CDF_INIT_TOP - pmf2cdfinv_av1(pmf)
+
+
+def cdf2pmf_av1(cdf):
+ """converts cdf to pmf"""
+ cdf_ext = np.insert(cdf, 0, 0)
+ pmf = np.diff(cdf_ext)
+ return pmf
+
+
+def count2cdf_av1(value_count):
+ """Python implementation of the following C code from AVM codec:
+ --------------------------------------------------------------
+ static void counts_to_cdf(const aom_count_type *counts, aom_cdf_prob *cdf, int modes) {
+ int64_t csum[CDF_MAX_SIZE];
+ assert(modes <= CDF_MAX_SIZE);
+
+ csum[0] = counts[0] + 1;
+ for (int i = 1; i < modes; ++i) csum[i] = counts[i] + 1 + csum[i - 1];
+
+ for (int i = 0; i < modes; ++i) fprintf(logfile, "%d ", counts[i]);
+ fprintf(logfile, "\n");
+
+ int64_t sum = csum[modes - 1];
+ const int64_t round_shift = sum >> 1;
+ for (int i = 0; i < modes; ++i) {
+ cdf[i] = (csum[i] * CDF_PROB_TOP + round_shift) / sum;
+ cdf[i] = AOMMIN(cdf[i], CDF_PROB_TOP - (modes - 1 + i) * 4);
+ cdf[i] = (i == 0) ? AOMMAX(cdf[i], 4) : AOMMAX(cdf[i], cdf[i - 1] + 4);
+ }
+ }
+ --------------------------------------------------------------
+ """
+ value_count = value_count + 1
+ cdf_count = np.cumsum(value_count)
+ total_count = value_count.sum()
+ round_shift = total_count // 2
+ nsymb = len(cdf_count)
+ cdf = np.zeros(nsymb).astype(int)
+ for i in range(nsymb):
+ cdf[i] = int((cdf_count[i] * CDF_INIT_TOP + round_shift) / total_count)
+ cdf[i] = min(cdf[i], CDF_INIT_TOP - (nsymb - 1 + i) * 4)
+ if i == 0:
+ cdf[i] = max(cdf[i], 4)
+ else:
+ cdf[i] = max(cdf[i], cdf[i - 1] + 4)
+ return cdf
diff --git a/ParaKit/src/parakit/entropy/codec_default_cdf.py b/ParaKit/src/parakit/entropy/codec_default_cdf.py
new file mode 100644
index 0000000..e9ba146
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/codec_default_cdf.py
@@ -0,0 +1,204 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import numpy as np
+
+CDF_PROB_BITS = 15
+CDF_INIT_TOP = 32768
+CDF_PROB_TOP = 2**CDF_PROB_BITS
+AV1_PROB_COST_SHIFT = 9
+
+MAX_CTX_DIM = 4 # maximum dimension in context tables
+
+AV1_PROB_COST = (
+ 512,
+ 506,
+ 501,
+ 495,
+ 489,
+ 484,
+ 478,
+ 473,
+ 467,
+ 462,
+ 456,
+ 451,
+ 446,
+ 441,
+ 435,
+ 430,
+ 425,
+ 420,
+ 415,
+ 410,
+ 405,
+ 400,
+ 395,
+ 390,
+ 385,
+ 380,
+ 375,
+ 371,
+ 366,
+ 361,
+ 356,
+ 352,
+ 347,
+ 343,
+ 338,
+ 333,
+ 329,
+ 324,
+ 320,
+ 316,
+ 311,
+ 307,
+ 302,
+ 298,
+ 294,
+ 289,
+ 285,
+ 281,
+ 277,
+ 273,
+ 268,
+ 264,
+ 260,
+ 256,
+ 252,
+ 248,
+ 244,
+ 240,
+ 236,
+ 232,
+ 228,
+ 224,
+ 220,
+ 216,
+ 212,
+ 209,
+ 205,
+ 201,
+ 197,
+ 194,
+ 190,
+ 186,
+ 182,
+ 179,
+ 175,
+ 171,
+ 168,
+ 164,
+ 161,
+ 157,
+ 153,
+ 150,
+ 146,
+ 143,
+ 139,
+ 136,
+ 132,
+ 129,
+ 125,
+ 122,
+ 119,
+ 115,
+ 112,
+ 109,
+ 105,
+ 102,
+ 99,
+ 95,
+ 92,
+ 89,
+ 86,
+ 82,
+ 79,
+ 76,
+ 73,
+ 70,
+ 66,
+ 63,
+ 60,
+ 57,
+ 54,
+ 51,
+ 48,
+ 45,
+ 42,
+ 38,
+ 35,
+ 32,
+ 29,
+ 26,
+ 23,
+ 20,
+ 18,
+ 15,
+ 12,
+ 9,
+ 6,
+ 3,
+)
+
+
+def av1_default_cdf_parameters(n_taps):
+ arr = np.arange(1, n_taps)
+ cdf = (2**15) / n_taps * arr
+ cdf = cdf.round().astype(int)
+ return cdf
+
+
+def av1_default_pmf(n_taps):
+ cdf = av1_default_cdf_parameters(n_taps)
+ cdf = np.append(cdf, CDF_INIT_TOP)
+ pmf = np.diff(cdf)
+ return pmf
+
+
+def print_default_cdf_parameters(n_taps):
+ print(get_default_aom_cdf_string(n_taps))
+
+
+def get_aom_cdf_entry(n_taps, cdf):
+ str_cdf = f"AOM_CDF{n_taps}("
+ for i, p in enumerate(cdf):
+ if i < n_taps - 2:
+ str_cdf += str(p).rjust(5) + ", "
+ else:
+ str_cdf += str(p).rjust(5) + ")"
+ return str_cdf
+
+
+def get_default_aom_cdf_string(n_taps):
+ cdf = av1_default_cdf_parameters(n_taps)
+ return get_aom_cdf_entry(n_taps, cdf)
+
+
+def get_aom_cdf_string(cdf):
+ n_taps = len(cdf) + 1
+ return get_aom_cdf_entry(n_taps, cdf)
+
+
+if __name__ == "__main__":
+ print_default_cdf_parameters(n_taps=2)
+ print_default_cdf_parameters(n_taps=3)
+ print_default_cdf_parameters(n_taps=4)
+ print_default_cdf_parameters(n_taps=5)
+ print_default_cdf_parameters(n_taps=6)
+ print_default_cdf_parameters(n_taps=7)
+ print_default_cdf_parameters(n_taps=8)
+ print_default_cdf_parameters(n_taps=9)
+ print_default_cdf_parameters(n_taps=10)
+ print_default_cdf_parameters(n_taps=11)
+ print_default_cdf_parameters(n_taps=12)
+ print_default_cdf_parameters(n_taps=15)
+ print_default_cdf_parameters(n_taps=16)
diff --git a/ParaKit/src/parakit/entropy/data_collector.py b/ParaKit/src/parakit/entropy/data_collector.py
new file mode 100644
index 0000000..1a45e2c
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/data_collector.py
@@ -0,0 +1,141 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import pandas as pd
+
+import parakit.entropy.model as model
+from parakit.entropy.codec_default_cdf import CDF_INIT_TOP, MAX_CTX_DIM
+
+
+class DataCollector:
+ def __init__(self, csv_filename):
+ self.csv_filename = csv_filename
+ self._checkfile()
+
+ self._context_model = None
+ self._num_rows_header = None
+ self._parse_header_information()
+
+ def get_context_model(self):
+ return self._context_model
+
+ def collect_dataframe(self, max_rows=None):
+ csv_filename = self.csv_filename
+ file_path = "./" + csv_filename
+ num_lineskips = self._num_rows_header + 1
+ df = pd.read_csv(file_path, header=num_lineskips, nrows=max_rows)
+ return df
+
+ def _checkfile(self):
+ if not self.csv_filename.endswith(".csv"):
+ raise ValueError("File should have .csv extension")
+
+ def _parse_header_information(self):
+ csv_filename = self.csv_filename
+ file_path = "./" + csv_filename
+ with open(file_path) as f:
+ header_line = f.readline()
+ # create tokens
+ tokens = header_line.split(",")
+ num_tokens = len(tokens)
+ # parse
+ header_str, ctx_group_name = tokens[0].split(":")
+ num_symb = int(tokens[1])
+ num_dims = int(tokens[2])
+ size_list = []
+ if num_dims > 0:
+ size_list = [int(tokens[3 + i]) for i in range(num_dims)]
+
+ # check if first line if header includes 'Header'
+ if header_str != "Header":
+ raise Warning(
+ "First line of header does not match in file: " + csv_filename
+ )
+ # check size of tokens
+ expected_num_tokens = num_dims + 3
+ if num_tokens != expected_num_tokens:
+ raise Warning(
+ f"Expected and actual number of tokens in top header ({num_tokens} and {expected_num_tokens}) does not match for file "
+ + csv_filename
+ )
+
+ if ctx_group_name not in csv_filename:
+ raise Warning(
+ f"Context group {ctx_group_name} does not appear in {csv_filename}"
+ )
+
+ num_context_groups = 1
+ for size in size_list:
+ num_context_groups *= size
+
+ default_prob_inits = {}
+ for i in range(num_context_groups):
+ init_line = f.readline()
+ init_tokens = init_line.split(",")
+ expected_num_tokens = (
+ 1 + MAX_CTX_DIM + num_symb + 2
+ ) # ctx_idx, {4 dims}, {num_symb}, counter, rate
+
+ if len(init_tokens) != expected_num_tokens:
+ raise Warning(
+ f"Expected and actual number of tokens ({num_tokens} and {expected_num_tokens}) for context group at {i} does not match for file "
+ + csv_filename
+ )
+
+ # check header line-by-line
+ ctx_idx = int(init_tokens[0])
+ if ctx_idx != i:
+ raise Warning(
+ f"Expected context group index {i} does not match with {ctx_idx} in file "
+ + csv_filename
+ )
+ index_list = []
+ if num_dims > 0:
+ index_list = [
+ int(token) for token in init_tokens[1 : (MAX_CTX_DIM + 1)]
+ ]
+ index_list = index_list[(MAX_CTX_DIM - num_dims) :]
+ # create dictionary string
+ dict_str = ctx_group_name
+ for index in index_list:
+ dict_str = dict_str + "[" + str(index) + "]"
+
+ index_init_counter = MAX_CTX_DIM + 1 + num_symb
+ init_cdfs = [
+ int(token)
+ for token in init_tokens[MAX_CTX_DIM + 1 : index_init_counter]
+ ]
+ init_counter = int(init_tokens[index_init_counter])
+ index_init_rateidx = index_init_counter + 1
+ init_rateidx = int(init_tokens[index_init_rateidx])
+ # checks & assertions
+ if init_cdfs[-1] != CDF_INIT_TOP:
+ raise Warning(
+ f"At context group index {i} CDF entry does not match with maximum CDF value {CDF_INIT_TOP} in file "
+ + csv_filename
+ )
+ if init_counter != 0:
+ raise Warning(
+ f"At context group index {i} initial counter does not {0} in file "
+ + csv_filename
+ )
+
+ default_prob_inits[dict_str] = {
+ "initializer": init_cdfs,
+ "index": index_list,
+ "init_rateidx": init_rateidx,
+ "init_counter": init_counter,
+ }
+
+ self._context_model = model.EntropyContext(
+ ctx_group_name, num_symb, num_dims, size_list, default_prob_inits
+ )
+ self._num_rows_header = len(default_prob_inits)
diff --git a/ParaKit/src/parakit/entropy/file_collector.py b/ParaKit/src/parakit/entropy/file_collector.py
new file mode 100644
index 0000000..0c7ddc6
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/file_collector.py
@@ -0,0 +1,78 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import os
+
+
+class FileCollector:
+ def __init__(
+ self,
+ data_path,
+ file_extension,
+ ctx_group="",
+ coding_config="",
+ desired_qp=(),
+ subtext="",
+ starttext="",
+ ):
+ # non-public
+ self._data_path = data_path
+ self._file_extension = "." + file_extension # e.g., 'csv' and 'json'
+ self._ctx_group = ctx_group
+ self._coding_config = coding_config
+ self._desired_qp = desired_qp
+ self._subtext = subtext
+ self._starttext = starttext
+ # public
+ self.all_files = self._get_all_files()
+ self.files = self._filter_files()
+
+ # public
+ def get_files(self):
+ return self.files
+
+ # non-public
+ def _get_all_files(self):
+ files_all = os.listdir(self._data_path)
+ files = [f for f in files_all if f.endswith(self._file_extension)]
+ return sorted(files)
+
+ def _filter_files(self):
+ filtered_files = self.all_files
+ # ctx_group
+ ctx_group = self._ctx_group
+ if ctx_group != "":
+ filtered_files = [
+ f for f in filtered_files if "Stat_" + ctx_group + "_Bin_" in f
+ ]
+ # coding config (AI, RA, LD, etc..)
+ cfg = self._coding_config
+ if cfg != "":
+ filtered_files = [f for f in filtered_files if "_" + cfg + "_" in f]
+ # filter based on desired QP list
+ desired_qp = self._desired_qp
+ if len(desired_qp) > 0:
+ all_filt_files = filtered_files.copy()
+ filtered_files = []
+ for qp in desired_qp:
+ new_files = [f for f in all_filt_files if "_QP_" + str(qp) + "_" in f]
+ for f in new_files:
+ filtered_files.append(f)
+ # any subtext between underscores in '_{subtext}_' format
+ subtext = self._subtext
+ if subtext != "":
+ filtered_files = [f for f in filtered_files if subtext in f]
+
+ starttext = self._starttext
+ if starttext != "":
+ filtered_files = [f for f in filtered_files if f.startswith(starttext)]
+
+ return filtered_files
diff --git a/ParaKit/src/parakit/entropy/model.py b/ParaKit/src/parakit/entropy/model.py
new file mode 100644
index 0000000..2845fd6
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/model.py
@@ -0,0 +1,271 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import parakit.config.user as user
+from parakit.entropy.codec_default_cdf import get_aom_cdf_string
+
+ADD_COMMENTS = False
+
+
+class EntropyContext:
+ def __init__(
+ self,
+ ctx_group_name,
+ num_symb,
+ num_dims,
+ size_list,
+ model_dict,
+ user_cfg_file=(),
+ ):
+ self.ctx_group_name = ctx_group_name
+ self.num_symb = num_symb
+ self.num_dims = num_dims
+ self.size_list = size_list
+ self.model_dict = model_dict
+ self.user_cfg = user_cfg_file
+
+ def get_model_dictionary(self, index_list):
+ key = self._get_key(index_list)
+ return self.model_dict[key]
+
+ def get_model_string(self, index_list):
+ return self.get_cdf_string(index_list) + ", " + self.get_rate_string(index_list)
+
+ def get_rate_string(self, index_list):
+ model = self.get_model_dictionary(index_list)
+ rate = model["init_rate"]
+ rate_str = str(rate).rjust(3)
+ return rate_str
+
+ def get_cdf_string(self, index_list):
+ model = self.get_model_dictionary(index_list)
+ cdf_list = model["initializer"]
+ num_symb = len(cdf_list)
+ cdf = cdf_list[: num_symb - 1]
+ cdf_str = get_aom_cdf_string(cdf)
+ return cdf_str
+
+ def get_frequency_string(self, index_list):
+ model = self.get_model_dictionary(index_list)
+ num_samples = model["num_samples"]
+ sample_str = str(num_samples).rjust(8)
+ return sample_str
+
+ def get_complete_model_string(self, is_avm_style=True):
+ model_str = []
+ if is_avm_style:
+ model_str = self._get_avm_variable_name() + " = "
+ else:
+ model_str = self._get_arrayname() + " = "
+ num_dim = self.num_dims
+ if num_dim == 0:
+ model_str = self._model_string_dim0(model_str)
+ elif num_dim == 1:
+ model_str = self._model_string_dim1(model_str)
+ elif num_dim == 2:
+ model_str = self._model_string_dim2(model_str)
+ elif num_dim == 3:
+ model_str = self._model_string_dim3(model_str)
+ elif num_dim == 4:
+ model_str = self._model_string_dim4(model_str)
+ else:
+ model_str = "Number of dimensions needs to be between 0 and 4."
+
+ # last characters
+ model_str += "\n"
+
+ return model_str
+
+ def _get_num_ctx_groups(self):
+ num_context_groups = 1
+ for size in self.size_list:
+ num_context_groups *= size
+ return num_context_groups
+
+ def _get_key(self, index_list):
+ key = self.ctx_group_name
+ for index in index_list:
+ key = key + "[" + str(index) + "]"
+ return key
+
+ def _get_arrayname(self):
+ array_name = self.ctx_group_name
+ for size in self.size_list:
+ array_name = array_name + "[" + str(size) + "]"
+ array_name += "[CDF_SIZE(" + str(self.num_symb) + ")]"
+ return array_name
+
+ def _get_avm_variable_name(self):
+ array_name = "static const aom_cdf_prob " + user.read_config_context(
+ self.user_cfg, self.ctx_group_name
+ )
+ for size in self.size_list:
+ array_name = array_name + "[" + str(size) + "]"
+ array_name += "[CDF_SIZE(" + str(self.num_symb) + ")]"
+ return array_name
+
+ def _model_string_dim0(self, model_str, is_commented=ADD_COMMENTS):
+ end_char = ";"
+ index_list = []
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + "{ "
+ + self.get_model_string(index_list)
+ + " }"
+ + end_char
+ + comment_str
+ )
+ return model_str
+
+ def _model_string_dim1_special_nobracket(
+ self, model_str, is_commented=ADD_COMMENTS
+ ):
+ for i0 in range(self.size_list[0]):
+ index_list = [i0]
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + "{ "
+ + self.get_model_string(index_list)
+ + " },"
+ + comment_str
+ + "\n"
+ )
+ return model_str
+
+ def _model_string_dim1(self, model_str, is_commented=ADD_COMMENTS):
+ end_char = ";"
+ model_str += "{\n"
+ for i0 in range(self.size_list[0]):
+ index_list = [i0]
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + " { "
+ + self.get_model_string(index_list)
+ + " },"
+ + comment_str
+ + "\n"
+ )
+ model_str += "}" + end_char
+ return model_str
+
+ def _model_string_dim2(self, model_str, is_commented=ADD_COMMENTS):
+ end_char = ";"
+ model_str += "{\n"
+ for i0 in range(self.size_list[0]):
+ model_str += " {\n"
+ for i1 in range(self.size_list[1]):
+ index_list = [i0, i1]
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + " { "
+ + self.get_model_string(index_list)
+ + " },"
+ + comment_str
+ + "\n"
+ )
+ model_str += " },\n"
+ model_str += "}" + end_char # + '\n'
+ return model_str
+
+ def _model_string_dim3(self, model_str, is_commented=ADD_COMMENTS):
+ end_char = ";"
+ model_str += "{\n"
+ for i0 in range(self.size_list[0]):
+ model_str += " {\n"
+ for i1 in range(self.size_list[1]):
+ model_str += " {\n"
+ for i2 in range(self.size_list[2]):
+ index_list = [i0, i1, i2]
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + " { "
+ + self.get_model_string(index_list)
+ + " },"
+ + comment_str
+ + "\n"
+ )
+ model_str += " },\n"
+ model_str += " },\n"
+ model_str += "}" + end_char # + '\n'
+ return model_str
+
+ def _model_string_dim4(self, model_str, is_commented=ADD_COMMENTS):
+ end_char = ";"
+ model_str += "{\n"
+ for i0 in range(self.size_list[0]):
+ model_str += " {\n"
+ for i1 in range(self.size_list[1]):
+ model_str += " {\n"
+ for i2 in range(self.size_list[2]):
+ model_str += " {\n"
+ for i3 in range(self.size_list[3]):
+ index_list = [i0, i1, i2, i3]
+ comment_str = (
+ " // "
+ + self._get_key(index_list)
+ + " "
+ + self.get_frequency_string(index_list)
+ if is_commented
+ else ""
+ )
+ model_str = (
+ model_str
+ + " { "
+ + self.get_model_string(index_list)
+ + " },"
+ + comment_str
+ + "\n"
+ )
+ model_str += " },\n"
+ model_str += " },\n"
+ model_str += " },\n"
+ model_str += "}" + end_char # + '\n'
+ return model_str
diff --git a/ParaKit/src/parakit/entropy/result_collector.py b/ParaKit/src/parakit/entropy/result_collector.py
new file mode 100644
index 0000000..7aacdc9
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/result_collector.py
@@ -0,0 +1,147 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import json
+
+import numpy as np
+
+from parakit.entropy.codec_cdf_functions import count2cdf_av1
+from parakit.entropy.codec_default_cdf import CDF_INIT_TOP, av1_default_cdf_parameters
+
+
+class ResultCollector:
+ def __init__(self, json_filename):
+ self.json_filename = json_filename
+ self._checkfile()
+
+ def _checkfile(self):
+ if not self.json_filename.endswith(".json"):
+ raise ValueError("File should have .json extension")
+
+ def _checkdata(self, data, key, isIgnored, default_rateidx=0):
+ data_cost = data[key]["current_cost"]
+ cost_default = int(data_cost[default_rateidx])
+ if cost_default != 0:
+ print(
+ f"Warning: Default cost is {cost_default} for {key} in file {self.json_filename}",
+ end=" -- ",
+ )
+ if isIgnored:
+ print("Ignored!")
+ return False
+ else:
+ print("Not ignored!")
+ return True
+
+ def parse_json_file(self, isIgnored=True):
+ filepath = self.json_filename
+ data = {}
+ with open(filepath) as json_file:
+ data = json.load(json_file)
+ keys = list(data.keys())
+ keys.remove("information")
+ for key in keys:
+ init_rateidx = data[key]["init_rateidx"]
+ if self._checkdata(data, key, isIgnored, default_rateidx=init_rateidx):
+ data[key]["current_cost"] = np.array(data[key]["current_cost"])
+ data[key]["value_count"] = np.array(data[key]["value_count"])
+ data[key]["value_cost"] = np.array(data[key]["value_cost"])
+ else:
+ data[key]["current_cost"] = np.zeros(
+ len(data[key]["current_cost"]), dtype=int
+ )
+ data[key]["value_count"] = np.zeros(
+ len(data[key]["value_count"]), dtype=int
+ )
+ data[key]["value_cost"] = np.zeros(
+ len(data[key]["value_cost"]), dtype=int
+ )
+ return data
+
+ def calculate_percent_reduction(self, data_in_key):
+ actual_cost = data_in_key["initial_cost"]
+ if actual_cost == 0:
+ percent_reduction = np.zeros(len(data_in_key["current_cost"]))
+ else:
+ percent_reduction = 100 * (data_in_key["current_cost"] / actual_cost)
+ percent_reduction = np.nan_to_num(percent_reduction)
+ return percent_reduction
+
+ def combine_data(self, combined_data, data):
+ keys = list(data.keys())
+ keys.remove("information") # skip information key
+ if len(combined_data) == 0:
+ combined_data = data.copy()
+ # one can add new fields below, if needed
+ for key in keys:
+ percent_reduction = self.calculate_percent_reduction(data[key])
+ combined_data[key][
+ "percent_cost_total"
+ ] = percent_reduction # add percent reduction field
+ combined_data[key][
+ "percent_cost_min"
+ ] = percent_reduction # add percent reduction field
+ combined_data[key][
+ "percent_cost_max"
+ ] = percent_reduction # add percent reduction field
+ else:
+ for key in keys:
+ combined_data[key]["current_cost"] += data[key]["current_cost"]
+ combined_data[key]["num_samples"] += data[key]["num_samples"]
+ combined_data[key]["initial_cost"] += data[key]["initial_cost"]
+ combined_data[key]["codec_cost"] += data[key]["codec_cost"]
+ combined_data[key]["upper_cost"] += data[key]["upper_cost"]
+ combined_data[key]["value_count"] += data[key]["value_count"]
+ combined_data[key]["value_cost"] += data[key]["value_cost"]
+ # update percent reduction fields
+ percent_reduction = self.calculate_percent_reduction(data[key])
+ combined_data[key]["percent_cost_total"] += percent_reduction
+ combined_data[key]["percent_cost_min"] = np.minimum(
+ percent_reduction, combined_data[key]["percent_cost_min"]
+ )
+ combined_data[key]["percent_cost_max"] = np.maximum(
+ percent_reduction, combined_data[key]["percent_cost_max"]
+ )
+ return combined_data
+
+ def update_probability_initialzer(self, data):
+ keys = list(data.keys())
+ keys.remove("information") # skip information key
+ for key in keys:
+ value_count = data[key]["value_count"]
+ total_count = value_count.sum()
+ if total_count > 0:
+ pmf = value_count / total_count
+ cdf = np.cumsum(pmf)
+ scaled_cdf = np.round(CDF_INIT_TOP * cdf).astype(int)
+ if scaled_cdf[-1] > CDF_INIT_TOP:
+ scaled_cdf[-1] = CDF_INIT_TOP
+ data[key]["initializer"] = scaled_cdf.tolist()
+ return data
+
+ def update_probability_initialzer_av1_style(self, data):
+ keys = list(data.keys())
+ keys.remove("information") # skip information key
+ for key in keys:
+ value_count = data[key]["value_count"]
+ scaled_cdf = count2cdf_av1(value_count)
+ data[key]["initializer"] = scaled_cdf.tolist()
+ return data
+
+ def update_default_probability_initializer(self, data):
+ keys = list(data.keys())
+ keys.remove("information") # skip information key
+ for key in keys:
+ num_symb = data[key]["header"]["num_symb"]
+ cdf_list = av1_default_cdf_parameters(num_symb).tolist()
+ cdf_list.append(CDF_INIT_TOP)
+ data[key]["initializer"] = cdf_list
+ return data
diff --git a/ParaKit/src/parakit/entropy/trainer.py b/ParaKit/src/parakit/entropy/trainer.py
new file mode 100644
index 0000000..087fef1
--- /dev/null
+++ b/ParaKit/src/parakit/entropy/trainer.py
@@ -0,0 +1,342 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import json
+import os
+
+import numpy as np
+import pandas as pd
+
+import parakit.entropy.model as model
+from parakit.entropy.codec_cdf_functions import (
+ cdf2pmf_av1,
+ cdfinv2pmf_av1,
+ cost_symbol_av1,
+ pmf2cdfinv_av1,
+ update_cdfinv_av1,
+)
+from parakit.entropy.data_collector import DataCollector
+
+COST_REGULARIZATION = 4 # modifies maximum cost
+MIN_SAMPLE_REQUIRED = 10 # minimum number of data points needed to run training
+MAX_NUMBER_ROWS = 5000000
+
+
+class Trainer:
+ def __init__(
+ self,
+ filename_data,
+ rate_search_list,
+ cost_regularization=COST_REGULARIZATION,
+ min_sample=MIN_SAMPLE_REQUIRED,
+ max_rows=MAX_NUMBER_ROWS,
+ ):
+ # training parameters
+ self._filename_data = filename_data # data file
+ self._cost_regularization = cost_regularization
+ self._min_sample = min_sample
+ self._max_rows = max_rows
+ self._rate_search_list = rate_search_list
+
+ # internal data and models
+ self.dataframe = None
+ self.initial_model = None
+ self.resulting_model = None
+ self._max_cost_regularization = None
+ self._output_filename = None # resulting json file
+
+ self.initialize()
+
+ def initialize(self):
+ dc = DataCollector(self._filename_data)
+ self._initial_model = dc.get_context_model()
+ self.dataframe = dc.collect_dataframe(max_rows=self._max_rows)
+ self._output_filename = self._get_output_filename()
+ self._max_cost_regularization = cost_symbol_av1(self._cost_regularization)
+
+ def prepare_dataframe(self, df, ctx_idx_list):
+ dims = self._initial_model.num_dims
+ df_filtered = df
+ if dims > 0:
+ mask = df["Dim0"] == ctx_idx_list[0]
+ for i in range(1, dims):
+ mask &= df[f"Dim{i}"] == ctx_idx_list[i]
+ df_filtered = df.loc[mask]
+ # add "CalcCost" column calculating cost
+ new_cols = ["CalcCost"]
+ df_filtered = df_filtered.reindex(
+ columns=[*df_filtered.columns.tolist(), *new_cols], fill_value=0
+ )
+ df_filtered.reset_index(inplace=True)
+
+ # return empty df if very few (< _min_sample) data points available
+ if len(df_filtered) < self._min_sample:
+ return pd.DataFrame()
+
+ # convert NaN values to -1, if any
+ df_filtered = df_filtered.fillna(-1)
+ df_filtered = df_filtered[df_filtered.columns].astype(
+ int
+ ) # all columns are integer
+
+ return df_filtered
+
+ def map_rate_from_counter(self, rates_tuple, counter):
+ if counter <= 15:
+ return rates_tuple[0]
+ elif counter <= 31:
+ return rates_tuple[1]
+ else:
+ return rates_tuple[2]
+
+ def update_pmf_per_rate(self, pmf_list, cost_list, val, counter):
+ rate_offset_list = self._rate_search_list
+ nsymb = self._initial_model.num_symb
+ for r, rates_tuple in enumerate(rate_offset_list):
+ roffset = self.map_rate_from_counter(rates_tuple, counter)
+ # cost
+ cost = cost_symbol_av1(pmf_list[r][val])
+ if cost > self._max_cost_regularization:
+ cost = self._max_cost_regularization
+ cost_list[r] += cost
+ # update
+ cdf_inv = pmf2cdfinv_av1(pmf_list[r])
+ cdf_inv_updated = update_cdfinv_av1(cdf_inv, val, counter, nsymb, roffset)
+ pmf_list[r] = cdfinv2pmf_av1(cdf_inv_updated)
+ return (pmf_list, cost_list)
+
+ def estimate_cost_multiple(self, df_filtered, pmf_init, rateidx_init=0):
+ nsymb = self._initial_model.num_symb
+ rate_offset_list = self._rate_search_list
+ num_rates = len(rate_offset_list)
+ cost_list = np.zeros(num_rates, dtype=int)
+
+ # initialization
+ pmf_init_list = [pmf_init] * num_rates
+ prev_pmf_buffer = [pmf_init_list]
+ cdf_cols = []
+ for i in range(nsymb):
+ if f"cdf{i}" in df_filtered.keys():
+ cdf_cols.append(f"cdf{i}")
+
+ mask_index = (df_filtered["Counter"] == 0) | (df_filtered["isBeginFrame"] == 1)
+ index_list = df_filtered[mask_index].index.to_list()
+ index_list.append(len(df_filtered))
+ index_pair_list = [
+ (index_list[ind], index_list[ind + 1]) for ind in range(len(index_list) - 1)
+ ]
+
+ for index_pair in index_pair_list:
+ start_idx, end_idx = index_pair
+ df_sub = df_filtered.iloc[start_idx:end_idx].copy()
+
+ num_data_samples = len(df_sub)
+ success = False
+ if len(cdf_cols) == nsymb:
+ cdf_init = np.array(df_sub.iloc[0][cdf_cols], dtype=int)
+ prev_pmf_buffer = [[cdf2pmf_av1(cdf_init)] * num_rates]
+ for _, prev_pmf in enumerate(prev_pmf_buffer):
+ rateidx_init = df_sub.iloc[0]["rate"]
+ pmf_default = prev_pmf[rateidx_init] # default one
+ # initial values
+ val_init = df_sub.iloc[0]["Value"]
+ cost = df_sub.iloc[0]["Cost"]
+ calculated_cost = cost_symbol_av1(pmf_default[val_init])
+ if calculated_cost > self._max_cost_regularization:
+ calculated_cost = self._max_cost_regularization
+
+ cost_list_persub = np.zeros(num_rates, dtype=int)
+
+ if calculated_cost != cost:
+ continue
+
+ pmf_list = [pmf_default] * num_rates
+
+ # try best cases
+ for i in range(num_data_samples):
+ val = df_sub.iloc[i]["Value"]
+ counter = df_sub.iloc[i]["Counter"]
+ cost = df_sub.iloc[i]["Cost"]
+ calculated_cost = cost_symbol_av1(pmf_list[rateidx_init][val])
+ if calculated_cost > self._max_cost_regularization:
+ calculated_cost = self._max_cost_regularization
+ df_sub.iloc[i]["CalcCost"] = calculated_cost
+
+ if cost != calculated_cost:
+ break
+
+ # update variables
+ pmf_list, cost_list_persub = self.update_pmf_per_rate(
+ pmf_list, cost_list_persub, val, counter
+ )
+
+ # success
+ if i == num_data_samples - 1:
+ success = True
+ # update cost
+ cost_list += cost_list_persub
+ # save latest pmf before updating
+ prev_pmf_buffer.insert(1, pmf_list)
+ if success:
+ break
+
+ return (cost_list, 0)
+
+ def get_cost_per_value(self, df):
+ num_symb = self._initial_model.num_symb
+ cost_pervalue_list = [0] * num_symb
+ if df.empty:
+ return cost_pervalue_list
+ for i in range(num_symb):
+ cost = df[df["Value"] == i]["Cost"].sum()
+ cost_pervalue_list[i] = int(cost)
+ return cost_pervalue_list
+
+ def get_value_count(self, df):
+ num_symb = self._initial_model.num_symb
+ value_count_list = [0] * num_symb
+ if df.empty:
+ return value_count_list
+ count_series = df["Value"].value_counts()
+ for i in count_series.index:
+ count = count_series[i]
+ value_count_list[i] = int(count)
+ return value_count_list
+
+ def run_rate_training_on_file(self):
+ # parameters
+ rate_parameters = self._rate_search_list
+ size_list = self._initial_model.size_list
+ num_dims = self._initial_model.num_dims
+ num_symb = self._initial_model.num_symb
+ ctx_group_name = self._initial_model.ctx_group_name
+ prob_dict = self._initial_model.model_dict
+
+ ctx_name_list = list(prob_dict.keys())
+ result_dict = prob_dict.copy()
+ for ctx_name in ctx_name_list:
+ ctx_idx_interest = prob_dict[ctx_name]["index"]
+ print(f"Context {ctx_name}:", end=" ")
+ df_filtered = self.prepare_dataframe(self.dataframe, ctx_idx_interest)
+ value_count_list = self.get_value_count(df_filtered)
+ value_cost_list = self.get_cost_per_value(df_filtered)
+ # check empty df
+ if df_filtered.empty:
+ print("Skip training - insufficient data.")
+ result_dict[ctx_name] = {
+ "current_cost": np.zeros(len(rate_parameters), dtype=int),
+ "adapt_rate": rate_parameters,
+ "initial_cost": 0,
+ "codec_cost": 0,
+ "upper_cost": 0,
+ "num_samples": 0,
+ "initializer": prob_dict[ctx_name]["initializer"],
+ "init_rateidx": prob_dict[ctx_name]["init_rateidx"],
+ "value_count": value_count_list,
+ "value_cost": value_cost_list,
+ }
+ continue
+
+ # actual cost
+ codec_cost = df_filtered["Cost"].sum()
+
+ # estimate cost
+ default_cdf = np.array(prob_dict[ctx_name]["initializer"])
+ default_pmf = cdf2pmf_av1(default_cdf)
+ default_rateidx = int(prob_dict[ctx_name]["init_rateidx"])
+ # default_rate = RATE_LIST[default_rateidx]
+
+ curr_cost_list, upper_cost = self.estimate_cost_multiple(
+ df_filtered, pmf_init=default_pmf, rateidx_init=default_rateidx
+ )
+ init_cost = curr_cost_list[default_rateidx]
+ curr_cost_list = curr_cost_list - init_cost
+ result_dict[ctx_name] = {
+ "current_cost": curr_cost_list,
+ "adapt_rate": rate_parameters,
+ "initial_cost": init_cost,
+ "codec_cost": codec_cost,
+ "upper_cost": upper_cost,
+ "num_samples": len(df_filtered),
+ "initializer": prob_dict[ctx_name]["initializer"],
+ "init_rateidx": prob_dict[ctx_name]["init_rateidx"],
+ "value_count": value_count_list,
+ "value_cost": value_cost_list,
+ }
+ print("Training...")
+
+ self.resulting_model = model.EntropyContext(
+ ctx_group_name, num_symb, num_dims, size_list, result_dict
+ )
+ self.write_results()
+
+ return self.resulting_model
+
+ def _get_output_filename(self):
+ # get path and filename
+ fullpath_data = self._filename_data
+ filename = os.path.basename(fullpath_data)
+ dirname = os.path.dirname(fullpath_data)
+ if len(dirname) == 0:
+ dirname = "."
+ base_filename = filename.split(".")[0]
+ fullpath_result_json = dirname + "/" + "Result_" + base_filename + ".json"
+ return fullpath_result_json
+
+ def get_searched_rate_parameter_list(self):
+ return self._rate_search_list
+
+ def write_results(self):
+ # information
+ ctx_group_name = self.resulting_model.ctx_group_name
+ num_symb = self.resulting_model.num_symb
+ num_dims = self.resulting_model.num_dims
+ size_list = self.resulting_model.size_list
+ result_dict = self.resulting_model.model_dict
+
+ # collect keys
+ key_list = ["information"]
+ key_list.extend(list(result_dict.keys()))
+ json_dict = dict.fromkeys(key_list)
+ for key in key_list:
+ if key == "information":
+ json_dict[key] = {
+ "header": {
+ "ctx_group_name": ctx_group_name,
+ "num_symb": num_symb,
+ "num_dims": num_dims,
+ "size_list": size_list,
+ },
+ "training": {
+ "rate_search_list": self._rate_search_list,
+ "cost_regularization": self._cost_regularization,
+ "min_sample": self._min_sample,
+ },
+ }
+ else:
+ cost_list = result_dict[key]["current_cost"].copy()
+ min_idx = np.argmin(cost_list)
+ json_dict[key] = {
+ "best_rate": self._rate_search_list[min_idx],
+ "best_rate_idx": int(min_idx),
+ "current_cost": cost_list.astype(int).tolist(),
+ "initial_cost": int(result_dict[key]["initial_cost"]),
+ "codec_cost": int(result_dict[key]["codec_cost"]),
+ "upper_cost": int(result_dict[key]["upper_cost"]),
+ "num_samples": int(result_dict[key]["num_samples"]),
+ "initializer": result_dict[key]["initializer"],
+ "init_rateidx": result_dict[key]["init_rateidx"],
+ "value_count": result_dict[key]["value_count"],
+ "value_cost": result_dict[key]["value_cost"],
+ }
+ # write to json
+ with open(self._output_filename, "w") as jsonfile:
+ json.dump(json_dict, jsonfile, indent=4)
diff --git a/ParaKit/src/parakit/tasks/__init__.py b/ParaKit/src/parakit/tasks/__init__.py
new file mode 100644
index 0000000..fe2df03
--- /dev/null
+++ b/ParaKit/src/parakit/tasks/__init__.py
@@ -0,0 +1,11 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
diff --git a/ParaKit/src/parakit/tasks/collect_results.py b/ParaKit/src/parakit/tasks/collect_results.py
new file mode 100644
index 0000000..a5c05a9
--- /dev/null
+++ b/ParaKit/src/parakit/tasks/collect_results.py
@@ -0,0 +1,107 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import json
+
+import numpy as np
+from termcolor import cprint
+
+import parakit.config.user as user
+from parakit.config.training import (
+ CHANGE_INITIALIZERS,
+ MIN_NUM_DATA_SAMPLES_NEEDED,
+ RATE_LIST,
+)
+from parakit.entropy.file_collector import FileCollector
+from parakit.entropy.result_collector import ResultCollector
+
+
+def run(path_ctxdata="./results/data", user_config_file="parameters.yaml"):
+ path_result = path_ctxdata
+ test_output_tag, desired_ctx_list = user.read_config_data(user_config_file)
+ if CHANGE_INITIALIZERS:
+ test_output_tag += "_newinit"
+ final_dict = dict.fromkeys(desired_ctx_list)
+ cprint(
+ f"Combining results for {len(desired_ctx_list)} sets of probability models:",
+ attrs=["bold"],
+ )
+ for desired_ctx in desired_ctx_list:
+ fc = FileCollector(
+ path_result, "json", desired_ctx, subtext=f"_{test_output_tag}."
+ )
+ json_files = fc.get_files()
+ num_json_files = len(json_files)
+ print(f"Working on {desired_ctx} to combine {num_json_files} json files")
+ combined_results = {}
+ for file_idx in range(num_json_files):
+ filename = json_files[file_idx]
+ print(f"Combining: {filename}")
+ filepath = path_result + "/" + filename
+ rc = ResultCollector(filepath)
+ result_dict = rc.parse_json_file()
+ combined_results = rc.combine_data(combined_results, result_dict)
+
+ if CHANGE_INITIALIZERS: # change initializer
+ combined_results = rc.update_probability_initialzer_av1_style(
+ combined_results
+ )
+
+ if len(combined_results) == 0:
+ final_dict[desired_ctx] = {}
+ else:
+ key_list = list(combined_results.keys())
+ temp_dict = dict.fromkeys(key_list)
+ for key in key_list:
+ if key == "information":
+ temp_dict[key] = combined_results["information"]
+ continue
+ cost_list = combined_results[key]["current_cost"]
+ actual_cost = combined_results[key]["initial_cost"]
+ if actual_cost == 0:
+ overall_percent_reduction = np.zeros(len(cost_list))
+ else:
+ overall_percent_reduction = 100 * (cost_list / actual_cost)
+ overall_percent_reduction = np.nan_to_num(overall_percent_reduction)
+ num_samples = combined_results[key]["num_samples"]
+ combined_percent_reduction = combined_results[key]["percent_cost_total"]
+ MAX_RATE_SEARCH = len(RATE_LIST)
+ min_idx = np.argmin(cost_list[0:MAX_RATE_SEARCH])
+ min_idx_perc = np.argmin(combined_percent_reduction[0:MAX_RATE_SEARCH])
+ if num_samples < MIN_NUM_DATA_SAMPLES_NEEDED:
+ min_idx = 0
+ min_idx_perc = 0
+ init_rateidx = int(combined_results[key]["init_rateidx"])
+ temp_dict[key] = {
+ "best_rate": RATE_LIST[min_idx],
+ "best_rate_idx": int(min_idx),
+ "current_cost": cost_list.astype(int).tolist(),
+ "initial_cost": int(actual_cost),
+ "best_rate_perc": RATE_LIST[min_idx_perc],
+ "best_rate_idx_perc": int(min_idx_perc),
+ "percent_cost": combined_percent_reduction.astype(float).tolist(),
+ "num_samples": int(num_samples),
+ "overall_percent_reduction": overall_percent_reduction.astype(
+ float
+ ).tolist(),
+ "initializer": combined_results[key]["initializer"],
+ "init_rate": RATE_LIST[init_rateidx],
+ "init_rateidx": int(init_rateidx),
+ }
+ final_dict[desired_ctx] = temp_dict
+
+ with open(f"{path_result}/Combined_Result_{test_output_tag}.json", "w") as outfile:
+ json.dump(final_dict, outfile, indent=4)
+ cprint("Done combining results!\n", "green", attrs=["bold"])
+
+
+if __name__ == "__main__":
+ run()
diff --git a/ParaKit/src/parakit/tasks/decoding.py b/ParaKit/src/parakit/tasks/decoding.py
new file mode 100644
index 0000000..0b20a15
--- /dev/null
+++ b/ParaKit/src/parakit/tasks/decoding.py
@@ -0,0 +1,71 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import multiprocessing
+import os
+import sys
+
+from termcolor import cprint
+
+import parakit.config.user as user
+from parakit.entropy.file_collector import FileCollector
+
+
+def decode_task(decode_info):
+ os.system(decode_info[0])
+ print(f"Decoded: {decode_info[2]}", flush=True)
+
+
+def run(
+ path_bitstream="./bitstreams",
+ path_ctx_data="./results/data",
+ user_config_file="parameters.yaml",
+):
+ test_output_tag, _ = user.read_config_data(user_config_file)
+ bitstream_extension = user.read_config_decode(user_config_file)
+ fc = FileCollector(path_bitstream, bitstream_extension)
+ bitstreams = fc.get_files()
+ num_bitstreams = len(bitstreams)
+ if num_bitstreams == 0:
+ cprint(
+ f"No bistream files with extension .{bitstream_extension} under {path_bitstream}",
+ "red",
+ attrs=["bold"],
+ file=sys.stderr,
+ )
+ print(
+ f"Usage: (i) add files under {path_bitstream} path and (ii) choose the correct extension in parameters.yaml (BITSTREAM_EXTENSION field)."
+ )
+ sys.exit()
+ cprint(
+ f"Decoding {num_bitstreams} bitstreams to collect data under {path_ctx_data}:",
+ attrs=["bold"],
+ )
+ # prepare decoding task information
+ decode_info = []
+ for idx, bitstream in enumerate(bitstreams):
+ suffix = os.path.splitext(bitstream)[0] + "_" + test_output_tag
+ decode_info.append(
+ (
+ f"binaries/aomdec {path_bitstream}/{bitstream} --path-ctxdata={path_ctx_data} --suffix-ctxdata={suffix} -o /dev/null",
+ idx,
+ bitstream,
+ )
+ )
+ # run using all available cores
+ num_cpu = os.cpu_count()
+ with multiprocessing.Pool(num_cpu) as pool:
+ pool.map(decode_task, decode_info)
+ cprint("Decoding complete!\n", "green", attrs=["bold"])
+
+
+if __name__ == "__main__":
+ run()
diff --git a/ParaKit/src/parakit/tasks/generate_tables.py b/ParaKit/src/parakit/tasks/generate_tables.py
new file mode 100644
index 0000000..c4f7510
--- /dev/null
+++ b/ParaKit/src/parakit/tasks/generate_tables.py
@@ -0,0 +1,91 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import json
+
+from termcolor import cprint
+
+import parakit.config.user as user
+import parakit.entropy.model as model
+from parakit.entropy.codec_default_cdf import CDF_INIT_TOP, av1_default_cdf_parameters
+
+DEFAULT_PROB_INITIALIZER = False
+DEFAULT_RATE_PARAMETER = False
+ZERO_RATE_PARAMETER = False
+
+
+def run(
+ path_ctxdata="./results/data",
+ path_table="./results",
+ user_config_file="parameters.yaml",
+):
+ results = {}
+ test_output_tag, desired_ctx_list = user.read_config_data(user_config_file)
+ combined_file = f"Combined_Result_{test_output_tag}"
+ with open(f"{path_ctxdata}/{combined_file}.json") as json_file:
+ results = json.load(json_file)
+ table_filename = f"{path_table}/Context-Table_{combined_file}.h"
+ cprint(f"Generating context tables in file: {table_filename}", attrs=["bold"])
+ num_symbols = 0
+ num_symbol_groups = 0
+ table_string = "\n"
+ for desired_ctx in desired_ctx_list:
+ result_ctx_group = results[desired_ctx]
+ print(f"Generating table for context: {desired_ctx}")
+ result_info = result_ctx_group["information"]
+ ctx_group_name = result_info["header"]["ctx_group_name"]
+ num_symb = result_info["header"]["num_symb"]
+ num_dims = result_info["header"]["num_dims"]
+ size_list = result_info["header"]["size_list"]
+
+ sym_grp = 1
+ for d in range(num_dims):
+ sym_grp = sym_grp * size_list[d]
+ num_symbol_groups += sym_grp
+ num_symbols = num_symbols + (sym_grp * num_symb)
+
+ key_list = list(result_ctx_group.keys())
+ key_list.remove("information")
+ for key in key_list:
+ if ZERO_RATE_PARAMETER:
+ result_ctx_group[key]["init_rate"] = 0
+ elif DEFAULT_RATE_PARAMETER:
+ result_ctx_group[key]["init_rate"] = result_ctx_group[key][
+ "init_rateidx"
+ ]
+ else:
+ result_ctx_group[key]["init_rate"] = result_ctx_group[key][
+ "best_rate_idx"
+ ]
+
+ if DEFAULT_PROB_INITIALIZER:
+ cdf_list = av1_default_cdf_parameters(num_symb).tolist()
+ cdf_list.append(CDF_INIT_TOP)
+ result_ctx_group[key]["initializer"] = cdf_list
+
+ resulting_model = model.EntropyContext(
+ ctx_group_name,
+ num_symb,
+ num_dims,
+ size_list,
+ result_ctx_group,
+ user_config_file,
+ )
+ table_string += resulting_model.get_complete_model_string(is_avm_style=True)
+ table_string += "\n"
+
+ with open(table_filename, "w") as table_file:
+ table_file.write(table_string)
+ cprint("Done generating context tables!\n", "green", attrs=["bold"])
+
+
+if __name__ == "__main__":
+ run()
diff --git a/ParaKit/src/parakit/tasks/training.py b/ParaKit/src/parakit/tasks/training.py
new file mode 100644
index 0000000..36898ed
--- /dev/null
+++ b/ParaKit/src/parakit/tasks/training.py
@@ -0,0 +1,53 @@
+"""
+Copyright (c) 2024, Alliance for Open Media. All rights reserved
+
+This source code is subject to the terms of the BSD 3-Clause Clear License
+and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+License was not distributed with this source code in the LICENSE file, you
+can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+Alliance for Open Media Patent License 1.0 was not distributed with this
+source code in the PATENTS file, you can obtain it at
+aomedia.org/license/patent-license/.
+"""
+import multiprocessing
+import os
+
+from termcolor import cprint
+
+import parakit.config.user as user
+from parakit.config.training import RATE_LIST
+from parakit.entropy.file_collector import FileCollector
+from parakit.entropy.trainer import Trainer
+
+
+def train_task(train_info):
+ file_fullpath = train_info[0]
+ trainer = Trainer(file_fullpath, RATE_LIST)
+ trainer.run_rate_training_on_file()
+ print(f"Trained: {train_info[2]}", flush=True)
+
+
+def run(path_ctxdata="./results/data", user_config_file="parameters.yaml"):
+ test_output_tag, _ = user.read_config_data(user_config_file)
+ fc = FileCollector(path_ctxdata, "csv", subtext=f"_{test_output_tag}.")
+ data_files = fc.get_files()
+ num_data_files = len(data_files)
+ cprint(
+ f"Training based on {num_data_files} csv files under {path_ctxdata}:",
+ attrs=["bold"],
+ )
+ # prepare decoding task information
+ train_info = []
+ for idx, file_csv in enumerate(data_files):
+ file_fullpath = f"{path_ctxdata}/{file_csv}"
+ # print(f'[{idx}]: {file_csv}')
+ train_info.append((file_fullpath, idx, file_csv))
+ # run training in paralel using all available cores
+ num_cpu = os.cpu_count()
+ with multiprocessing.Pool(num_cpu) as pool:
+ pool.map(train_task, train_info)
+ cprint("Training complete!\n", "green", attrs=["bold"])
+
+
+if __name__ == "__main__":
+ run()
diff --git a/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File1.avm b/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File1.avm
new file mode 100644
index 0000000..398aab0
--- /dev/null
+++ b/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File1.avm
Binary files differ
diff --git a/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File2.avm b/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File2.avm
new file mode 100644
index 0000000..4ec8ba6
--- /dev/null
+++ b/ParaKit/unit_test/bitstreams/Bin_AVM-v7_File2.avm
Binary files differ
diff --git a/ParaKit/unit_test/parameters_unit_test.yaml b/ParaKit/unit_test/parameters_unit_test.yaml
new file mode 100644
index 0000000..8ce3d70
--- /dev/null
+++ b/ParaKit/unit_test/parameters_unit_test.yaml
@@ -0,0 +1,7 @@
+TEST_OUTPUT_TAG: "Unit-Test"
+BITSTREAM_EXTENSION: "avm"
+DESIRED_CTX_LIST:
+ - eob_flag_cdf16
+ - eob_flag_cdf32
+eob_flag_cdf16: "av1_default_eob_multi16_cdfs"
+eob_flag_cdf32: "av1_default_eob_multi32_cdfs"
diff --git a/aom/aom_decoder.h b/aom/aom_decoder.h
index 46846a9..2d4a60e 100644
--- a/aom/aom_decoder.h
+++ b/aom/aom_decoder.h
@@ -93,6 +93,8 @@
unsigned int threads; /**< Maximum number of threads to use, default 1 */
unsigned int w; /**< Width */
unsigned int h; /**< Height */
+ char *path_parakit; /**< ParaKit data path */
+ char *suffix_parakit; /**< ParaKit data suffix */
} aom_codec_dec_cfg_t; /**< alias for struct aom_codec_dec_cfg */
/*!\brief Initialize a decoder instance
diff --git a/aom_dsp/bitreader.h b/aom_dsp/bitreader.h
index 0e5cd29..d6c0ea6 100644
--- a/aom_dsp/bitreader.h
+++ b/aom_dsp/bitreader.h
@@ -27,6 +27,12 @@
#include "aom_dsp/recenter.h"
#endif // ENABLE_LR_4PART_CODE
+#if CONFIG_PARAKIT_COLLECT_DATA
+#include "av1/common/cost.h"
+#include "av1/common/av1_common_int.h"
+#include "av1/common/entropy_sideinfo.h"
+#endif
+
#if CONFIG_BITSTREAM_DEBUG
#include "aom_util/debug_util.h"
#endif // CONFIG_BITSTREAM_DEBUG
@@ -387,6 +393,78 @@
return ret;
}
+#if CONFIG_PARAKIT_COLLECT_DATA
+static INLINE int aom_read_cdf_probdata(aom_reader *r, const aom_cdf_prob *cdf,
+ int nsymbs) {
+ int symb;
+ assert(cdf != NULL);
+ symb = od_ec_decode_cdf_q15(&r->ec, cdf, nsymbs);
+ return symb;
+}
+
+// @ParaKit: use aom_read_symbol_probdata function for decoding to collect data
+// make sure that "const AV1_COMMON *const cm" pointer that has
+// prob_info information
+static INLINE int aom_read_symbol_probdata(aom_reader *r, aom_cdf_prob *cdf,
+ const int *indexlist,
+ ProbModelInfo prob_info) {
+ FILE *filedata = prob_info.fDataCollect;
+ const int symLength = prob_info.num_symb;
+ // Estimated probability and counter information
+ const int counter_engine = (int)cdf[symLength];
+ for (int i = 0; i < prob_info.num_dim; i++) {
+ fprintf(filedata, "%d,", *(indexlist + i));
+ }
+
+ const int frameNumber = prob_info.frameNumber;
+ fprintf(filedata, "%d,", frameNumber);
+ const int frameType = prob_info.frameType;
+ fprintf(filedata, "%d,", frameType);
+ int begin_idx[4] = { 0, 0, 0, 0 };
+ for (int i = 0; i < prob_info.num_dim; i++) {
+ const int offset = 4 - prob_info.num_dim;
+ assert(offset >= 0);
+ begin_idx[i + offset] = indexlist[i];
+ }
+ assert(begin_idx[0] >= 0 && begin_idx[0] < MAX_DIMS_CONTEXT3);
+ assert(begin_idx[1] >= 0 && begin_idx[1] < MAX_DIMS_CONTEXT2);
+ assert(begin_idx[2] >= 0 && begin_idx[2] < MAX_DIMS_CONTEXT1);
+ assert(begin_idx[3] >= 0 && begin_idx[3] < MAX_DIMS_CONTEXT0);
+ const int beginFrameFlag =
+ beginningFrameFlag[prob_info.model_idx][begin_idx[0]][begin_idx[1]]
+ [begin_idx[2]][begin_idx[3]];
+ fprintf(filedata, "%d,", beginFrameFlag);
+
+ int cdf_list[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+ for (int sym = 0; sym < symLength; sym++) {
+ cdf_list[sym] = CDF_INIT_TOP - cdf[sym];
+ }
+ int cost_list[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+ av1_cost_tokens_from_cdf(cost_list, cdf, NULL);
+
+ int ret;
+ ret = aom_read_cdf_probdata(r, cdf, symLength);
+ if (r->allow_update_cdf) update_cdf(cdf, ret, symLength);
+
+ const int cost = cost_list[ret];
+ fprintf(filedata, "%d,%d,%d", counter_engine, ret, cost);
+
+ if (beginningFrameFlag[prob_info.model_idx][begin_idx[0]][begin_idx[1]]
+ [begin_idx[2]][begin_idx[3]] ||
+ counter_engine == 0) {
+ for (int sym = 0; sym < symLength; sym++) {
+ fprintf(filedata, ",%d", cdf_list[sym]);
+ }
+ fprintf(filedata, ",%d", (int)cdf[symLength + 1]);
+ }
+ beginningFrameFlag[prob_info.model_idx][begin_idx[0]][begin_idx[1]]
+ [begin_idx[2]][begin_idx[3]] = 0;
+ fprintf(filedata, "\n");
+
+ return ret;
+}
+#endif
+
#if ENABLE_LR_4PART_CODE
// Implements a code where a symbol with an alphabet size a power of 2 with
// nsymb_bits bits (with nsymb_bits >= 3), is coded by decomposing the symbol
diff --git a/aom_dsp/bitwriter.h b/aom_dsp/bitwriter.h
index c73e3ff..746b8a9 100644
--- a/aom_dsp/bitwriter.h
+++ b/aom_dsp/bitwriter.h
@@ -26,7 +26,7 @@
#if CONFIG_RD_DEBUG
#include "av1/common/blockd.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#endif
#if CONFIG_BITSTREAM_DEBUG
diff --git a/apps/aomdec.c b/apps/aomdec.c
index f8de253..8c9778d 100644
--- a/apps/aomdec.c
+++ b/apps/aomdec.c
@@ -50,6 +50,10 @@
static const char *exec_name;
+#if CONFIG_PARAKIT_COLLECT_DATA
+#include "av1/common/entropy_sideinfo.h"
+#endif
+
struct AvxDecInputContext {
struct AvxInputContext *aom_input_ctx;
struct ObuDecInputContext *obu_ctx;
@@ -81,6 +85,14 @@
ARG_DEF(NULL, "summary", 0, "Show timing summary");
static const arg_def_t outputfile =
ARG_DEF("o", "output", 1, "Output file name pattern (see below)");
+#if CONFIG_PARAKIT_COLLECT_DATA
+static const arg_def_t datafilesuffix =
+ ARG_DEF(NULL, "suffix-ctxdata", 1,
+ "Filename prefix for collecting probability data");
+static const arg_def_t datafilepath =
+ ARG_DEF(NULL, "path-ctxdata", 1,
+ "Path for the file used to collect probability data");
+#endif
static const arg_def_t threadsarg =
ARG_DEF("t", "threads", 1, "Max threads to use");
static const arg_def_t verbosearg =
@@ -111,12 +123,15 @@
ARG_DEF(NULL, "skip-film-grain", 0, "Skip film grain application");
static const arg_def_t *all_args[] = {
- &help, &codecarg, &use_yv12, &use_i420,
- &flipuvarg, &rawvideo, &noblitarg, &progressarg,
- &limitarg, &skiparg, &summaryarg, &outputfile,
- &threadsarg, &verbosearg, &scalearg, &fb_arg,
- &md5arg, &verifyarg, &framestatsarg, &continuearg,
- &outbitdeptharg, &isannexb, &oppointarg, &outallarg,
+ &help, &codecarg, &use_yv12, &use_i420,
+ &flipuvarg, &rawvideo, &noblitarg, &progressarg,
+ &limitarg, &skiparg, &summaryarg, &outputfile,
+#if CONFIG_PARAKIT_COLLECT_DATA
+ &datafilesuffix, &datafilepath,
+#endif
+ &threadsarg, &verbosearg, &scalearg, &fb_arg,
+ &md5arg, &verifyarg, &framestatsarg, &continuearg,
+ &outbitdeptharg, &isannexb, &oppointarg, &outallarg,
&skipfilmgrain, NULL
};
@@ -536,7 +551,7 @@
int opt_yv12 = 0;
int opt_i420 = 0;
int opt_raw = 0;
- aom_codec_dec_cfg_t cfg = { 0, 0, 0 };
+ aom_codec_dec_cfg_t cfg = { 0, 0, 0, NULL, NULL };
unsigned int fixed_output_bit_depth = 0;
unsigned int is_annexb = 0;
int frames_corrupted = 0;
@@ -560,6 +575,11 @@
MD5Context md5_ctx;
unsigned char md5_digest[16];
+#if CONFIG_PARAKIT_COLLECT_DATA
+ char *datafilename_path = NULL;
+ char *datafilename_suffix = NULL;
+#endif
+
struct AvxDecInputContext input = { NULL, NULL, NULL };
struct AvxInputContext aom_input_ctx;
memset(&aom_input_ctx, 0, sizeof(aom_input_ctx));
@@ -595,6 +615,12 @@
// no-op
} else if (arg_match(&arg, &outputfile, argi)) {
outfile_pattern = arg.val;
+#if CONFIG_PARAKIT_COLLECT_DATA
+ } else if (arg_match(&arg, &datafilesuffix, argi)) {
+ datafilename_suffix = (char *)arg.val;
+ } else if (arg_match(&arg, &datafilepath, argi)) {
+ datafilename_path = (char *)arg.val;
+#endif
} else if (arg_match(&arg, &use_yv12, argi)) {
use_y4m = 0;
flipuv = 1;
@@ -771,6 +797,10 @@
if (!interface) interface = get_aom_decoder_by_index(0);
+#if CONFIG_PARAKIT_COLLECT_DATA
+ cfg.path_parakit = datafilename_path;
+ cfg.suffix_parakit = datafilename_suffix;
+#endif
dec_flags = 0;
if (aom_codec_dec_init(&decoder, interface, &cfg, dec_flags)) {
fprintf(stderr, "Failed to initialize decoder: %s\n",
diff --git a/apps/aomenc.c b/apps/aomenc.c
index 5e5a200..8061fe5 100644
--- a/apps/aomenc.c
+++ b/apps/aomenc.c
@@ -1827,7 +1827,7 @@
if (global->test_decode != TEST_DECODE_OFF) {
aom_codec_iface_t *decoder = get_aom_decoder_by_short_name(
get_short_name_by_aom_encoder(global->codec));
- aom_codec_dec_cfg_t cfg = { 0, 0, 0 };
+ aom_codec_dec_cfg_t cfg = { 0, 0, 0, NULL, NULL };
aom_codec_dec_init(&stream->decoder, decoder, &cfg, 0);
if (strcmp(get_short_name_by_aom_encoder(global->codec), "av1") == 0) {
diff --git a/av1/arg_defs.c b/av1/arg_defs.c
index a149f7e..bdaf826 100644
--- a/av1/arg_defs.c
+++ b/av1/arg_defs.c
@@ -147,6 +147,12 @@
.debugmode =
ARG_DEF("D", "debug", 0, "Debug mode (makes output deterministic)"),
.outputfile = ARG_DEF("o", "output", 1, "Output filename"),
+#if CONFIG_PARAKIT_COLLECT_DATA
+ .datafilesuffix = ARG_DEF(NULL, "suffix-ctxdata", 1,
+ "Prefix for filename used for data collection"),
+ .datafilepath = ARG_DEF(NULL, "path-ctxdata", 1,
+ "Path for file used for data collection"),
+#endif
.reconfile = ARG_DEF(NULL, "recon", 1, "Recon filename"),
.use_yv12 = ARG_DEF(NULL, "yv12", 0, "Input file is YV12 "),
.use_i420 = ARG_DEF(NULL, "i420", 0, "Input file is I420 (default)"),
diff --git a/av1/arg_defs.h b/av1/arg_defs.h
index 6ca7e17..413bebe 100644
--- a/av1/arg_defs.h
+++ b/av1/arg_defs.h
@@ -33,6 +33,10 @@
arg_def_t help;
arg_def_t debugmode;
arg_def_t outputfile;
+#if CONFIG_PARAKIT_COLLECT_DATA
+ arg_def_t datafilesuffix;
+ arg_def_t datafilepath;
+#endif
arg_def_t reconfile;
arg_def_t use_yv12;
arg_def_t use_i420;
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 1b4472a..5429c7a 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -94,7 +94,13 @@
"${AOM_ROOT}/av1/common/txb_common.c"
"${AOM_ROOT}/av1/common/txb_common.h"
"${AOM_ROOT}/av1/common/warped_motion.c"
- "${AOM_ROOT}/av1/common/warped_motion.h")
+ "${AOM_ROOT}/av1/common/warped_motion.h"
+ "${AOM_ROOT}/av1/common/cost.c"
+ "${AOM_ROOT}/av1/common/cost.h"
+ "${AOM_ROOT}/av1/common/entropy_inits_coeffs.h"
+ "${AOM_ROOT}/av1/common/entropy_inits_modes.h"
+ "${AOM_ROOT}/av1/common/entropy_inits_mv.h"
+ "${AOM_ROOT}/av1/common/entropy_sideinfo.h")
list(APPEND AOM_AV1_COMMON_SOURCES "${AOM_ROOT}/av1/common/pef.h")
list(APPEND AOM_AV1_COMMON_SOURCES "${AOM_ROOT}/av1/common/pef.c")
@@ -166,8 +172,6 @@
"${AOM_ROOT}/av1/encoder/compound_type.h"
"${AOM_ROOT}/av1/encoder/context_tree.c"
"${AOM_ROOT}/av1/encoder/context_tree.h"
- "${AOM_ROOT}/av1/encoder/cost.c"
- "${AOM_ROOT}/av1/encoder/cost.h"
"${AOM_ROOT}/av1/encoder/encodeframe.c"
"${AOM_ROOT}/av1/encoder/encodeframe.h"
"${AOM_ROOT}/av1/encoder/encodeframe_utils.c"
diff --git a/av1/av1_dx_iface.c b/av1/av1_dx_iface.c
index 36b4f40..cfa5418 100644
--- a/av1/av1_dx_iface.c
+++ b/av1/av1_dx_iface.c
@@ -471,7 +471,12 @@
return AOM_CODEC_MEM_ERROR;
}
frame_worker_data = (FrameWorkerData *)worker->data1;
+#if CONFIG_PARAKIT_COLLECT_DATA
+ frame_worker_data->pbi = av1_decoder_create(
+ ctx->buffer_pool, ctx->cfg.path_parakit, ctx->cfg.suffix_parakit);
+#else
frame_worker_data->pbi = av1_decoder_create(ctx->buffer_pool);
+#endif
if (frame_worker_data->pbi == NULL) {
set_error_detail(ctx, "Failed to allocate frame_worker_data");
return AOM_CODEC_MEM_ERROR;
diff --git a/av1/common/av1_common_int.h b/av1/common/av1_common_int.h
index c69f9ad..38d9f8d 100644
--- a/av1/common/av1_common_int.h
+++ b/av1/common/av1_common_int.h
@@ -157,6 +157,22 @@
/*!\cond */
+#if CONFIG_PARAKIT_COLLECT_DATA
+#define MAX_CTX_DIM 4
+typedef struct ProbModelInfo {
+ char *ctx_group_name;
+ aom_cdf_prob *prob;
+ int cdf_stride;
+ int num_symb;
+ int num_dim;
+ int num_idx[MAX_CTX_DIM];
+ FILE *fDataCollect;
+ int frameNumber;
+ int frameType;
+ int model_idx;
+} ProbModelInfo;
+#endif
+
enum {
SINGLE_REFERENCE = 0,
COMPOUND_REFERENCE = 1,
@@ -1785,6 +1801,10 @@
FILE *fDecCoeffLog;
#endif
+#if CONFIG_PARAKIT_COLLECT_DATA
+ ProbModelInfo prob_models[MAX_NUM_CTX_GROUPS];
+#endif
+
/*!
* Flag to indicate if current frame has backward ref frame
*/
diff --git a/av1/encoder/cost.c b/av1/common/cost.c
similarity index 98%
rename from av1/encoder/cost.c
rename to av1/common/cost.c
index 347789c..568cb63 100644
--- a/av1/encoder/cost.c
+++ b/av1/common/cost.c
@@ -11,7 +11,7 @@
*/
#include <assert.h>
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/common/entropy.h"
// round(-log2(i/256.) * (1 << AV1_PROB_COST_SHIFT)); i = 128~255.
diff --git a/av1/encoder/cost.h b/av1/common/cost.h
similarity index 94%
rename from av1/encoder/cost.h
rename to av1/common/cost.h
index 8ce12af..ae2bf03 100644
--- a/av1/encoder/cost.h
+++ b/av1/common/cost.h
@@ -10,8 +10,8 @@
* aomedia.org/license/patent-license/.
*/
-#ifndef AOM_AV1_ENCODER_COST_H_
-#define AOM_AV1_ENCODER_COST_H_
+#ifndef AOM_AV1_COMMON_COST_H_
+#define AOM_AV1_COMMON_COST_H_
#include "aom_dsp/prob.h"
#include "aom/aom_integer.h"
@@ -49,4 +49,4 @@
} // extern "C"
#endif
-#endif // AOM_AV1_ENCODER_COST_H_
+#endif // AOM_AV1_COMMON_COST_H_
diff --git a/av1/common/entropy_inits_coeffs.h b/av1/common/entropy_inits_coeffs.h
new file mode 100644
index 0000000..f840df4
--- /dev/null
+++ b/av1/common/entropy_inits_coeffs.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#ifndef AOM_AV1_COMMON_ENTROPY_INITS_COEFFS_H_
+#define AOM_AV1_COMMON_ENTROPY_INITS_COEFFS_H_
+
+#include "config/aom_config.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+static const aom_cdf_prob
+ av1_default_eob_multi16_cdfs[TOKEN_CDF_Q_CTXS][EOB_PLANE_CTXS][CDF_SIZE(
+ 5)] = {
+ {
+ { AOM_CDF5(1413, 1933, 3768, 9455), 5 },
+ { AOM_CDF5(1954, 2400, 4205, 7753), 5 },
+ { AOM_CDF5(9359, 11741, 16061, 22179), 31 },
+ },
+ {
+ { AOM_CDF5(2832, 4201, 8578, 17754), 30 },
+ { AOM_CDF5(4563, 5208, 7444, 11962), 5 },
+ { AOM_CDF5(10524, 13197, 18032, 24922), 0 },
+ },
+ {
+ { AOM_CDF5(4390, 6907, 13987, 24674), 55 },
+ { AOM_CDF5(1870, 2463, 3813, 9299), 30 },
+ { AOM_CDF5(15137, 18012, 23056, 29705), 6 },
+ },
+ {
+ { AOM_CDF5(5508, 11837, 26327, 32095), 56 },
+ { AOM_CDF5(6554, 8738, 10923, 24030), 0 },
+ { AOM_CDF5(28607, 29647, 32421, 32595), 50 },
+ },
+ };
+
+static const aom_cdf_prob
+ av1_default_eob_multi32_cdfs[TOKEN_CDF_Q_CTXS][EOB_PLANE_CTXS][CDF_SIZE(
+ 6)] = {
+ {
+ { AOM_CDF6(1183, 1539, 2981, 7359, 12851), 31 },
+ { AOM_CDF6(1847, 2098, 2631, 4422, 9368), 5 },
+ { AOM_CDF6(14803, 16649, 20616, 25021, 29117), 6 },
+ },
+ {
+ { AOM_CDF6(2170, 3095, 6309, 12580, 18493), 31 },
+ { AOM_CDF6(1194, 1592, 2551, 4712, 9835), 6 },
+ { AOM_CDF6(12842, 15056, 19310, 24033, 29143), 6 },
+ },
+ {
+ { AOM_CDF6(3673, 5100, 10624, 18431, 23892), 31 },
+ { AOM_CDF6(1891, 2179, 3130, 6874, 14672), 0 },
+ { AOM_CDF6(17990, 20534, 24659, 28946, 31883), 6 },
+ },
+ {
+ { AOM_CDF6(6158, 10781, 23027, 30726, 32275), 36 },
+ { AOM_CDF6(971, 6554, 17719, 26336, 29370), 32 },
+ { AOM_CDF6(15245, 19009, 26979, 32073, 32710), 5 },
+ },
+ };
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_COMMON_ENTROPY_INITS_COEFFS_H_
diff --git a/av1/common/entropy_inits_modes.h b/av1/common/entropy_inits_modes.h
new file mode 100644
index 0000000..e230df8
--- /dev/null
+++ b/av1/common/entropy_inits_modes.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#ifndef AOM_AV1_COMMON_ENTROPY_INITS_MODES_H_
+#define AOM_AV1_COMMON_ENTROPY_INITS_MODES_H_
+
+#include "config/aom_config.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_COMMON_ENTROPY_INITS_MODES_H_
diff --git a/av1/common/entropy_inits_mv.h b/av1/common/entropy_inits_mv.h
new file mode 100644
index 0000000..7920826
--- /dev/null
+++ b/av1/common/entropy_inits_mv.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#ifndef AOM_AV1_COMMON_ENTROPY_INITS_MV_H_
+#define AOM_AV1_COMMON_ENTROPY_INITS_MV_H_
+
+#include "config/aom_config.h"
+#include "av1/common/entropymv.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_COMMON_ENTROPY_INITS_MV_H_
diff --git a/av1/common/entropy_sideinfo.h b/av1/common/entropy_sideinfo.h
new file mode 100644
index 0000000..6e15287
--- /dev/null
+++ b/av1/common/entropy_sideinfo.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 3-Clause Clear License
+ * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
+ * License was not distributed with this source code in the LICENSE file, you
+ * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/. If the
+ * Alliance for Open Media Patent License 1.0 was not distributed with this
+ * source code in the PATENTS file, you can obtain it at
+ * aomedia.org/license/patent-license/.
+ */
+
+#ifndef AOM_AV1_COMMON_SIDEINFO_H_
+#define AOM_AV1_COMMON_SIDEINFO_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define MAX_NUMBER_CONTEXTS 160 // relaxed upper bound
+#define MAX_DIMS_CONTEXT0 100 // relaxed upper bound
+#define MAX_DIMS_CONTEXT1 10 // relaxed upper bound
+#define MAX_DIMS_CONTEXT2 10 // relaxed upper bound
+#define MAX_DIMS_CONTEXT3 5 // relaxed upper bound
+
+extern int beginningFrameFlag[MAX_NUMBER_CONTEXTS][MAX_DIMS_CONTEXT3]
+ [MAX_DIMS_CONTEXT2][MAX_DIMS_CONTEXT1]
+ [MAX_DIMS_CONTEXT0];
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_COMMON_SIDEINFO_H_
diff --git a/av1/common/enums.h b/av1/common/enums.h
index a959da6..5cd42d4 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -42,6 +42,11 @@
#define DEFAULT_IMP_MSK_WT 0 // default implict masked blending weight
#endif // CONFIG_D071_IMP_MSK_BLD
+#if CONFIG_PARAKIT_COLLECT_DATA
+// @ParaKit: add enum variables to indiciate context groups
+enum { EOB_FLAG_CDF16, EOB_FLAG_CDF32, MAX_NUM_CTX_GROUPS };
+#endif
+
#if CONFIG_WEDGE_MOD_EXT
/*WEDGE_0 is defined in the three o'clock direciton, the angles are defined in
* the anticlockwise.*/
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 0ffce05..127a315 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -7252,6 +7252,18 @@
sframe_info->is_s_frame = 0;
sframe_info->is_s_frame_at_altref = 0;
+#if CONFIG_PARAKIT_COLLECT_DATA
+ for (int i = 0; i < MAX_NUM_CTX_GROUPS; i++) {
+ cm->prob_models[i].frameNumber = current_frame->frame_number;
+ cm->prob_models[i].frameType = current_frame->frame_type;
+ for (int j = 0; j < MAX_DIMS_CONTEXT3; j++)
+ for (int k = 0; k < MAX_DIMS_CONTEXT2; k++)
+ for (int l = 0; l < MAX_DIMS_CONTEXT1; l++)
+ for (int h = 0; h < MAX_DIMS_CONTEXT0; h++)
+ beginningFrameFlag[i][j][k][l][h] = 1;
+ }
+#endif
+
if (!pbi->sequence_header_ready) {
aom_internal_error(&cm->error, AOM_CODEC_CORRUPT_FRAME,
"No sequence header");
@@ -8809,5 +8821,13 @@
// Non frame parallel update frame context here.
if (!tiles->large_scale) {
cm->cur_frame->frame_context = *cm->fc;
+#if CONFIG_PARAKIT_COLLECT_DATA
+ for (int i = 0; i < MAX_NUM_CTX_GROUPS; i++)
+ for (int j = 0; j < MAX_DIMS_CONTEXT3; j++)
+ for (int k = 0; k < MAX_DIMS_CONTEXT2; k++)
+ for (int l = 0; l < MAX_DIMS_CONTEXT1; l++)
+ for (int h = 0; h < MAX_DIMS_CONTEXT0; h++)
+ beginningFrameFlag[i][j][k][l][h] = 1;
+#endif
}
}
diff --git a/av1/decoder/decoder.c b/av1/decoder/decoder.c
index 1189b76..733e208 100644
--- a/av1/decoder/decoder.c
+++ b/av1/decoder/decoder.c
@@ -42,6 +42,12 @@
#include "av1/decoder/detokenize.h"
#include "av1/decoder/obu.h"
+#if CONFIG_PARAKIT_COLLECT_DATA
+#include "av1/common/entropy_sideinfo.h"
+int beginningFrameFlag[MAX_NUMBER_CONTEXTS][MAX_DIMS_CONTEXT3]
+ [MAX_DIMS_CONTEXT2][MAX_DIMS_CONTEXT1][MAX_DIMS_CONTEXT0];
+#endif
+
static void initialize_dec(void) {
av1_rtcd();
aom_dsp_rtcd();
@@ -183,7 +189,12 @@
}
#endif // CONFIG_OPTFLOW_ON_TIP
+#if CONFIG_PARAKIT_COLLECT_DATA
+AV1Decoder *av1_decoder_create(BufferPool *const pool, const char *path,
+ const char *suffix) {
+#else
AV1Decoder *av1_decoder_create(BufferPool *const pool) {
+#endif
AV1Decoder *volatile const pbi = aom_memalign(32, sizeof(*pbi));
if (!pbi) return NULL;
av1_zero(*pbi);
@@ -250,6 +261,103 @@
cm->fDecCoeffLog = fopen("DecCoeffLog.txt", "wt");
#endif
+#if CONFIG_PARAKIT_COLLECT_DATA
+#include "av1/common/entropy_inits_coeffs.h"
+#include "av1/common/entropy_inits_modes.h"
+#include "av1/common/entropy_inits_mv.h"
+
+ // @ParaKit: add side information needed in array of prob_models structure to
+ // be used in collecting data
+ cm->prob_models[EOB_FLAG_CDF16] =
+ (ProbModelInfo){ .ctx_group_name = "eob_flag_cdf16",
+ .prob = (aom_cdf_prob *)av1_default_eob_multi16_cdfs,
+ .cdf_stride = 0,
+ .num_symb = 5,
+ .num_dim = 2,
+ .num_idx = { 0, 0, 4, 3 } };
+ cm->prob_models[EOB_FLAG_CDF32] =
+ (ProbModelInfo){ .ctx_group_name = "eob_flag_cdf32",
+ .prob = (aom_cdf_prob *)av1_default_eob_multi32_cdfs,
+ .cdf_stride = 0,
+ .num_symb = 6,
+ .num_dim = 2,
+ .num_idx = { 0, 0, 4, 3 } };
+
+ for (int i = 0; i < MAX_NUM_CTX_GROUPS; i++) {
+ for (int j = 0; j < MAX_DIMS_CONTEXT3; j++)
+ for (int k = 0; k < MAX_DIMS_CONTEXT2; k++)
+ for (int l = 0; l < MAX_DIMS_CONTEXT1; l++)
+ for (int h = 0; h < MAX_DIMS_CONTEXT0; h++)
+ beginningFrameFlag[i][j][k][l][h] = 0;
+ }
+
+ for (int f = 0; f < MAX_NUM_CTX_GROUPS; f++) {
+ cm->prob_models[f].model_idx = f;
+ const int fixed_stride = cm->prob_models[f].cdf_stride;
+ const int num_sym = cm->prob_models[f].num_symb;
+ const int num_dims = cm->prob_models[f].num_dim;
+ const int num_idx0 = cm->prob_models[f].num_idx[0];
+ const int num_idx1 = cm->prob_models[f].num_idx[1];
+ const int num_idx2 = cm->prob_models[f].num_idx[2];
+ const int num_idx3 = cm->prob_models[f].num_idx[3];
+ const char *str_ctx = cm->prob_models[f].ctx_group_name;
+ const char *str_path = path ? path : ".";
+ const char *str_suffix = suffix ? suffix : "data";
+ char filename[2048];
+ sprintf(filename, "%s/Stat_%s_%s.csv", str_path, str_ctx, str_suffix);
+ FILE *fData = fopen(filename, "wt");
+ cm->prob_models[f].fDataCollect = fData;
+
+ fprintf(fData, "Header:%s,%d,%d", str_ctx, num_sym, num_dims);
+ const int dim_offset = MAX_CTX_DIM - num_dims;
+ for (int i = 0; i < num_dims; i++) {
+ fprintf(fData, ",%d", cm->prob_models[f].num_idx[i + dim_offset]);
+ }
+ fprintf(fData, "\n");
+
+ aom_cdf_prob *prob_ptr;
+ prob_ptr = cm->prob_models[f].prob;
+ int ctx_group_counter = 0;
+ for (int d0 = 0; d0 < (num_idx0 == 0 ? 1 : num_idx0); d0++)
+ for (int d1 = 0; d1 < (num_idx1 == 0 ? 1 : num_idx1); d1++)
+ for (int d2 = 0; d2 < (num_idx2 == 0 ? 1 : num_idx2); d2++)
+ for (int d3 = 0; d3 < (num_idx3 == 0 ? 1 : num_idx3); d3++) {
+ // indexing according to MAX_CTX_DIM
+ fprintf(fData, "%d,%d,%d,%d,%d,", ctx_group_counter, d0, d1, d2,
+ d3);
+ ctx_group_counter++;
+ for (int sym = 0; sym < CDF_SIZE(num_sym); sym++) {
+ int cdf_stride = (fixed_stride == 0) ? num_sym : fixed_stride;
+ int offset =
+ (d0 * num_idx3 * num_idx2 * num_idx1 * CDF_SIZE(cdf_stride)) +
+ (d1 * num_idx3 * num_idx2 * CDF_SIZE(cdf_stride)) +
+ (d2 * num_idx3 * CDF_SIZE(cdf_stride)) +
+ (d3 * CDF_SIZE(cdf_stride)) + sym;
+ if (sym < num_sym)
+ fprintf(fData, "%d", (int)AOM_ICDF(*(prob_ptr + offset)));
+ else
+ fprintf(fData, "%d", (int)*(prob_ptr + offset));
+ if (sym < CDF_SIZE(num_sym - 1)) {
+ fprintf(fData, ",");
+ } else {
+ fprintf(fData, "\n");
+ }
+ }
+ }
+ // main header
+ for (int i = 0; i < num_dims; i++) {
+ fprintf(fData, "Dim%d,", i);
+ }
+
+ fprintf(fData, "FrameNum,FrameType,isBeginFrame,Counter,Value,Cost");
+
+ for (int sym = 0; sym < num_sym; sym++) {
+ fprintf(fData, ",cdf%d", sym);
+ }
+ fprintf(fData, ",rate");
+ fprintf(fData, "\n");
+ }
+#endif
return pbi;
}
@@ -341,6 +449,14 @@
}
#endif
+#if CONFIG_PARAKIT_COLLECT_DATA
+ for (int f = 0; f < MAX_NUM_CTX_GROUPS; f++) {
+ if (pbi->common.prob_models[f].fDataCollect != NULL) {
+ fclose(pbi->common.prob_models[f].fDataCollect);
+ }
+ }
+#endif
+
aom_free(pbi);
}
diff --git a/av1/decoder/decoder.h b/av1/decoder/decoder.h
index b100eb2..7f4ef89 100644
--- a/av1/decoder/decoder.h
+++ b/av1/decoder/decoder.h
@@ -404,7 +404,12 @@
YV12_BUFFER_CONFIG *new_frame,
YV12_BUFFER_CONFIG *sd);
+#if CONFIG_PARAKIT_COLLECT_DATA
+struct AV1Decoder *av1_decoder_create(BufferPool *const pool, const char *path,
+ const char *suffix);
+#else
struct AV1Decoder *av1_decoder_create(BufferPool *const pool);
+#endif
void av1_decoder_remove(struct AV1Decoder *pbi);
void av1_dealloc_dec_jobs(struct AV1DecTileMTData *tile_mt_info);
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index dd6373f..20e449f 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -20,6 +20,19 @@
#include "av1/common/reconintra.h"
#include "av1/decoder/decodemv.h"
+#if CONFIG_PARAKIT_COLLECT_DATA
+#include "av1/common/cost.h"
+#endif
+
+#if CONFIG_PARAKIT_COLLECT_DATA
+static int get_q_ctx(int q) {
+ if (q <= 90) return 0;
+ if (q <= 140) return 1;
+ if (q <= 190) return 2;
+ return 3;
+}
+#endif
+
static int read_golomb(MACROBLOCKD *xd, aom_reader *r) {
int x = 1;
int length = 0;
@@ -347,7 +360,12 @@
// Decode the end-of-block syntax.
static INLINE void decode_eob(DecoderCodingBlock *dcb, aom_reader *const r,
- const int plane, const TX_SIZE tx_size) {
+ const int plane, const TX_SIZE tx_size
+#if CONFIG_PARAKIT_COLLECT_DATA
+ ,
+ const AV1_COMMON *const cm
+#endif
+) {
MACROBLOCKD *const xd = &dcb->xd;
const PLANE_TYPE plane_type = get_plane_type(plane);
FRAME_CONTEXT *const ec_ctx = xd->tile_ctx;
@@ -367,19 +385,47 @@
int eob_extra = 0;
int eob_pt = 1;
const int eob_multi_size = txsize_log2_minus4[tx_size];
+#if CONFIG_PARAKIT_COLLECT_DATA
+ const int qp_index = get_q_ctx(cm->quant_params.base_qindex);
+ int idxlist[MAX_CTX_DIM];
+ idxlist[0] = qp_index;
+ idxlist[1] = pl_ctx;
+ idxlist[2] = -1;
+ idxlist[3] = -1;
+#endif
switch (eob_multi_size) {
case 0:
+#if CONFIG_PARAKIT_COLLECT_DATA
+ {
+ eob_pt =
+ aom_read_symbol_probdata(r, ec_ctx->eob_flag_cdf16[pl_ctx], idxlist,
+ cm->prob_models[EOB_FLAG_CDF16]) +
+ 1;
+ break;
+ }
+#else
eob_pt =
aom_read_symbol(r, ec_ctx->eob_flag_cdf16[pl_ctx], EOB_MAX_SYMS - 6,
ACCT_INFO("eob_pt", "eob_multi_size:0")) +
1;
break;
+#endif
case 1:
+#if CONFIG_PARAKIT_COLLECT_DATA
+ {
+ eob_pt =
+ aom_read_symbol_probdata(r, ec_ctx->eob_flag_cdf32[pl_ctx], idxlist,
+ cm->prob_models[EOB_FLAG_CDF32]) +
+ 1;
+ break;
+ }
+#else
eob_pt =
aom_read_symbol(r, ec_ctx->eob_flag_cdf32[pl_ctx], EOB_MAX_SYMS - 5,
ACCT_INFO("eob_pt", "eob_multi_size:1")) +
1;
break;
+#endif
case 2:
eob_pt =
aom_read_symbol(r, ec_ctx->eob_flag_cdf64[pl_ctx], EOB_MAX_SYMS - 4,
@@ -538,7 +584,12 @@
}
return 0;
}
- decode_eob(dcb, r, plane, tx_size);
+ decode_eob(dcb, r, plane, tx_size
+#if CONFIG_PARAKIT_COLLECT_DATA
+ ,
+ cm
+#endif
+ );
av1_read_tx_type(cm, xd, blk_row, blk_col, tx_size, r, plane, *eob,
is_inter ? 0 : *eob);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index cbe70a9..b5481b7 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -49,7 +49,7 @@
#include "av1/common/tile_common.h"
#include "av1/encoder/bitstream.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encodemv.h"
#include "av1/encoder/encodetxb.h"
#include "av1/encoder/mcomp.h"
diff --git a/av1/encoder/encodemv.c b/av1/encoder/encodemv.c
index 6684c2a..e8c2679 100644
--- a/av1/encoder/encodemv.c
+++ b/av1/encoder/encodemv.c
@@ -15,7 +15,7 @@
#include "av1/common/common.h"
#include "av1/common/entropymode.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encodemv.h"
#include "aom_dsp/aom_dsp_common.h"
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index ab7fcc6..7e35b65 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -19,7 +19,7 @@
#include "av1/common/scan.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/bitstream.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encodeframe.h"
#include "av1/encoder/hash.h"
#include "av1/encoder/rdopt.h"
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index 7a493d8..664c109 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -27,7 +27,7 @@
#include "av1/common/mvref_common.h"
#include "av1/common/reconinter.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/encodemv.h"
#include "av1/encoder/mcomp.h"
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 421078e..d04dde1 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -16,7 +16,7 @@
#include "av1/common/pred_common.h"
#include "av1/encoder/block.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/intra_mode_search.h"
#include "av1/encoder/intra_mode_search_utils.h"
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 0e71e22..f873a18 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -35,7 +35,7 @@
#include "av1/common/seg_common.h"
#include "av1/encoder/av1_quantize.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encodemb.h"
#include "av1/encoder/encodemv.h"
#include "av1/encoder/encoder.h"
diff --git a/av1/encoder/rd.h b/av1/encoder/rd.h
index f10c2c2..0c5fd93 100644
--- a/av1/encoder/rd.h
+++ b/av1/encoder/rd.h
@@ -19,7 +19,7 @@
#include "av1/encoder/block.h"
#include "av1/encoder/context_tree.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#ifdef __cplusplus
extern "C" {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index e5d4d3e..96def2a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -46,7 +46,7 @@
#include "av1/encoder/aq_variance.h"
#include "av1/encoder/av1_quantize.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/compound_type.h"
#include "av1/encoder/encodemb.h"
#include "av1/encoder/encodemv.h"
diff --git a/av1/encoder/segmentation.c b/av1/encoder/segmentation.c
index 31d1ffc..4363500 100644
--- a/av1/encoder/segmentation.c
+++ b/av1/encoder/segmentation.c
@@ -17,7 +17,7 @@
#include "av1/common/pred_common.h"
#include "av1/common/tile_common.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/segmentation.h"
void av1_enable_segmentation(struct segmentation *seg) {
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index d22f466..68e15c4 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -22,7 +22,7 @@
#include "av1/common/scan.h"
#include "av1/common/seg_common.h"
-#include "av1/encoder/cost.h"
+#include "av1/common/cost.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/encodetxb.h"
#include "av1/encoder/rdopt.h"
diff --git a/build/cmake/aom_config_defaults.cmake b/build/cmake/aom_config_defaults.cmake
index 8ced654..af6c50b 100644
--- a/build/cmake/aom_config_defaults.cmake
+++ b/build/cmake/aom_config_defaults.cmake
@@ -77,7 +77,7 @@
set_aom_config_var(CONFIG_GPROF 0 "Enable gprof support.")
set_aom_config_var(CONFIG_LIBYUV 1 "Enables libyuv scaling/conversion support.")
-set_aom_config_var(CONFIG_MULTITHREAD 1 "Multithread support.")
+set_aom_config_var(CONFIG_MULTITHREAD 0 "Multithread support.")
set_aom_config_var(CONFIG_OS_SUPPORT 0 "Internal flag.")
set_aom_config_var(CONFIG_PIC 0 "Build with PIC enabled.")
set_aom_config_var(CONFIG_RUNTIME_CPU_DETECT 1 "Runtime CPU detection support.")
@@ -139,6 +139,9 @@
set_aom_config_var(CONFIG_ZERO_OFFSET_BITUPSHIFT 1
"Use zero offset for non-normative bit upshift")
+set_aom_config_var(CONFIG_PARAKIT_COLLECT_DATA 1
+ "enables data collection for ParaKit training.")
+
# AV2 experiment flags.
set_aom_config_var(CONFIG_IMPROVEIDTX_CTXS 1
"AV2 enable improved identity transform coding 1/2.")
diff --git a/build/cmake/aom_experiment_deps.cmake b/build/cmake/aom_experiment_deps.cmake
index a6b2334..ffe9aec 100644
--- a/build/cmake/aom_experiment_deps.cmake
+++ b/build/cmake/aom_experiment_deps.cmake
@@ -34,6 +34,11 @@
change_config_and_warn(CONFIG_TENSORFLOW_LITE 1 CONFIG_EXT_RECUR_PARTITIONS)
endif()
+ # CONFIG_MULTITHREAD is dependent on CONFIG_PARAKIT_COLLECT_DATA.
+ if(CONFIG_PARAKIT_COLLECT_DATA AND CONFIG_MULTITHREAD)
+ change_config_and_warn(CONFIG_MULTITHREAD 0 CONFIG_PARAKIT_COLLECT_DATA)
+ endif()
+
# CONFIG_THROUGHPUT_ANALYSIS requires CONFIG_ACCOUNTING. If CONFIG_ACCOUNTING
# is off, we also turn off CONFIG_THROUGHPUT_ANALYSIS.
if(NOT CONFIG_ACCOUNTING AND CONFIG_THROUGHPUT_ANALYSIS)
diff --git a/test/av1_key_value_api_test.cc b/test/av1_key_value_api_test.cc
index c55c8e3..930dcc8 100644
--- a/test/av1_key_value_api_test.cc
+++ b/test/av1_key_value_api_test.cc
@@ -37,7 +37,7 @@
#endif
#if CONFIG_AV1_DECODER
aom_codec_iface_t *iface_dx = aom_codec_av1_dx();
- aom_codec_dec_cfg_t dec_cfg = { 0, 0, 0 };
+ aom_codec_dec_cfg_t dec_cfg = { 0, 0, 0, NULL, NULL };
EXPECT_EQ(AOM_CODEC_OK, aom_codec_dec_init(&dec_, iface_dx, &dec_cfg, 0));
#endif
diff --git a/test/invalid_file_test.cc b/test/invalid_file_test.cc
index 2ae87aa..3bbde06 100644
--- a/test/invalid_file_test.cc
+++ b/test/invalid_file_test.cc
@@ -102,7 +102,7 @@
void RunTest() {
const DecodeParam input = GET_PARAM(1);
- aom_codec_dec_cfg_t cfg = { 0, 0, 0 };
+ aom_codec_dec_cfg_t cfg = { 0, 0, 0, NULL, NULL };
cfg.threads = input.threads;
const std::string filename = input.filename;
libaom_test::IVFVideoSource decode_video(filename);