Python stats aggregation improvements. Added better examples and updated README.
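
The new `aggregate_from_extractor.py` script takes `--plot` specs written as comma-separated `key:value` pairs (`title`, `field`, `filter`, `limit`). The sketch below shows one way such a spec can be parsed, using only the standard library. `PlotSpec` and `parse_plot_spec` are hypothetical names used for illustration; this is not the script's `PlotArgs` class. Unlike the script, the sketch splits each pair on the first colon only.

```python
from dataclasses import dataclass, fields


@dataclass
class PlotSpec:
    # Hypothetical stand-in for the script's PlotArgs; same four keys.
    title: str = ""
    field: str = ""
    filter: str = ""
    limit: str = ""


def parse_plot_spec(spec: str) -> PlotSpec:
    """Parse 'key:value, key:value' into a PlotSpec."""
    known = {f.name for f in fields(PlotSpec)}
    parsed = PlotSpec()
    for item in spec.split(","):
        # Split on the first colon only, so a value may itself contain colons.
        key, _, value = item.partition(":")
        key, value = key.strip(), value.strip()
        if key not in known:
            raise ValueError(f"Unknown plot arg: {key}")
        # Drop one pair of surrounding double quotes, if present.
        if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
            value = value[1:-1]
        setattr(parsed, key, value)
    return parsed


spec = parse_plot_spec(
    'title:"Symbols (Intra frames)", field:symbol_name, '
    'filter:"is_intra_frame", limit:10'
)
print(spec)
```

Because pairs are separated by splitting on commas, a quoted title or filter that itself contains a comma would still be split incorrectly; the README examples avoid this.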
diff --git a/tools/py_stats/README.md b/tools/py_stats/README.md index aa6ad85..0bccb86 100644 --- a/tools/py_stats/README.md +++ b/tools/py_stats/README.md
@@ -197,30 +197,105 @@ Bits for read_intra_luma_mode: 760.7987823486328 (9.55%) ``` -### Prediction mode / block size aggregation -All previous examples focused on visualizing a single frame from a single stream. This example shows how to compute aggregated stats across an arbitrary number of AVM streams. -A typical use case might be to compare how bits are spent in CTC frames under different test conditions. +### Statistics aggregation +All previous examples focused on visualizing a single frame from a single stream. The following examples show how to compute aggregated stats across an arbitrary number of AVM streams. A typical use case might be to compare how bits are spent in CTC frames under different test conditions. -This example script takes one or more glob paths and will extract all frames from all streams found: +Each of the following examples uses the `aggregate_from_extractor.py` script to plot different types of data from the streams, in addition to dumping that data as a .csv file. + +### Block size distribution (A5, QP=235, 10 frames); Plot Intra vs Inter ```bash -python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_prediction_modes.py \ - --stream_glob "/path/to/some/interesting/streams/*.ivf" \ - --stream_glob "/path/to/some/other/streams/*.ivf" \ - --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ - --output_csv /tmp/aggregate_dump.csv +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "partition_type" \ + --group_by block_size,partition_type,is_intra_frame \ + --output_csv /tmp/aggregate_block_sizes.csv \ + --plot 'title:"Intra Block Sizes", field:block_size, filter:"is_intra_frame"' \ + --plot 'title:"Inter Block Sizes", field:block_size, filter:"not is_intra_frame"' ``` + -The script will produce three different plots: -1. 
The overall distribution of block sizes, comparing intra and inter frames: - -2. The overall distribution of prediction modes, comparing intra and inter frames: - +### Prediction mode distribution (A5, QP=235, 10 frames); Plot Intra vs Inter +```bash +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "prediction_mode" \ + --group_by mode,is_intra_frame \ + --output_csv /tmp/prediction_modes.csv \ + --plot 'title:"Intra Prediction Modes", field:mode, filter:"is_intra_frame"' \ + --plot 'title:"Inter Prediction Modes", field:mode, filter:"not is_intra_frame"' +``` + -3. Prediction mode distribution, weighted by the number of bits used to code each mode, comparing intra and inter frames: - -The script also takes an optional `--output_csv` argument that will dump the final Pandas dataframe to a CSV file. 
+### Prediction modes weighted by bits used (A5, QP=235, 10 frames); Plot Intra vs Inter +```bash +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "prediction_mode" \ + --aggregated_field mode_bits \ + --group_by mode,is_intra_frame \ + --output_csv /tmp/prediction_mode_bits.csv \ + --plot 'title:"Intra Prediction Modes by bits", field:mode, filter:"is_intra_frame"' \ + --plot 'title:"Inter Prediction Modes by bits", field:mode, filter:"not is_intra_frame"' +``` + + +### TX types (A5, QP=235, 10 frames); Plot Intra vs Inter +```bash +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "tx_type" \ + --group_by tx_type,is_intra_frame \ + --output_csv /tmp/aggregate_tx_types.csv \ + --plot 'title:"Intra TX Types", field:tx_type, filter:"is_intra_frame"' \ + --plot 'title:"Inter TX Types", field:tx_type, filter:"not is_intra_frame"' +``` + + +### Intra partition types (A5, QP=210, 10 frames); Block sizes 64x64 vs 32x32 +```bash +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/AllIntra/*_QP_210.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "partition_type" \ + --group_by partition_type,block_size \ + --output_csv /tmp/aggregate_partition_types.csv \ + --plot 'title:"Partition types (64x64)", field:partition_type, filter:block_size=="64x64"' \ + --plot 'title:"Partition types (32x32)", field:partition_type, filter:block_size=="32x32"' +``` + + +### Top 10 symbol types weighted by bits used (A5, QP=235, 10 frames); Intra vs Inter 
frames +```bash +python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \ + --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \ + --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_210.bin" \ + --threads 32 \ + --frame_limit 10 \ + --extractor "symbol_bits" \ + --aggregated_field bits \ + --group_by symbol_name,is_intra_frame \ + --output_csv /tmp/aggregate_symbols.csv \ + --plot 'title:"Symbols (Intra frames)", field:symbol_name, filter:"is_intra_frame", limit:10' \ + --plot 'title:"Symbols (Inter frames)", field:symbol_name, filter:"not is_intra_frame", limit:10' +``` + + ### Launch an example Jupyter notebook This demonstrates many of the same visualizations listed above, but in a single Jupyter notebook.
diff --git a/tools/py_stats/avm_stats/proto_helpers.py b/tools/py_stats/avm_stats/proto_helpers.py index b012ec7..b6c861a 100644 --- a/tools/py_stats/avm_stats/proto_helpers.py +++ b/tools/py_stats/avm_stats/proto_helpers.py
@@ -164,7 +164,7 @@ (pixels_height, pixels_width) )[:sb_height_clipped, :sb_width_clipped] - samples[sb_y : sb_y + sb_height_clipped, sb_x : sb_x + sb_width_clipped] = ( + samples[sb_y: sb_y + sb_height_clipped, sb_x: sb_x + sb_width_clipped] = ( superblock_samples ) @@ -541,23 +541,10 @@ self.superblocks = [ Superblock(self, sb_proto) for sb_proto in proto.superblocks ] - self.pixels = [ - _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V) - ] - - if self.pixels[0].original is None: - self.original_rgb = None - else: - self.original_rgb = yuv_tools.yuv_to_rgb( - self.pixels[0].original, - yuv_tools.upscale(self.pixels[1].original, 2), - yuv_tools.upscale(self.pixels[2].original, 2), - ) - self.reconstruction_rgb = yuv_tools.yuv_to_rgb( - self.pixels[0].reconstruction, - yuv_tools.upscale(self.pixels[1].reconstruction, 2), - yuv_tools.upscale(self.pixels[2].reconstruction, 2), - ) + # Pixel data is created lazily. + self._pixels = None + self._original_rgb = None + self._reconstruction_rgb = None @property def frame_id(self) -> int: @@ -575,6 +562,38 @@ def bit_depth(self) -> int: return self.proto.frame_params.bit_depth + @property + def is_intra_frame(self) -> bool: + return self.proto.frame_params.frame_type == 0 + + @property + def pixels(self) -> list[PlaneBuffer]: + if self._pixels is None: + self._pixels = [ + _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V) + ] + return self._pixels + + @property + def original_rgb(self) -> np.ndarray: + if self._original_rgb is None and self.pixels[0].original is not None: + self._original_rgb = yuv_tools.yuv_to_rgb( + self.pixels[0].original, + yuv_tools.upscale(self.pixels[1].original, 2), + yuv_tools.upscale(self.pixels[2].original, 2), + ) + return self._original_rgb + + @property + def reconstruction_rgb(self) -> np.ndarray: + if self._reconstruction_rgb is None: + self._reconstruction_rgb = yuv_tools.yuv_to_rgb( + self.pixels[0].reconstruction, + 
yuv_tools.upscale(self.pixels[1].reconstruction, 2), + yuv_tools.upscale(self.pixels[2].reconstruction, 2), + ) + return self._reconstruction_rgb + def clip_rect(self, rect: Rectangle) -> Rectangle: """Clips a rectangle to be contained with the frame boundaries.
diff --git a/tools/py_stats/examples/aggregate_from_extractor.py b/tools/py_stats/examples/aggregate_from_extractor.py new file mode 100755 index 0000000..2ae1747 --- /dev/null +++ b/tools/py_stats/examples/aggregate_from_extractor.py
@@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +## Copyright (c) 2023, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. +## +from __future__ import annotations + +from functools import partial +import gc +import glob +import pathlib +import sys +import tempfile +from typing import Sequence, Type + +from absl import app +from absl import flags +from absl import logging +from avm_stats.extract_proto import * +from avm_stats.frame_visualizations import * +from avm_stats.proto_helpers import * +from avm_stats.stats_aggregation import * +import matplotlib.pyplot as plt +import multiprocessing +import pandas as pd + +from extractors.partition_similarity_extractor import PartitionSimilarityExtractor +from extractors.partition_type_extractor import PartitionTypeExtractor +from extractors.prediction_mode_extractor import PredictionModeExtractor +from extractors.symbol_bits_extractor import SymbolBitsExtractor +from extractors.tx_type_extractor import TxTypeExtractor + +_EXTRACTORS = { + "partition_similarity": PartitionSimilarityExtractor, + "partition_type": PartitionTypeExtractor, + "prediction_mode": PredictionModeExtractor, + "symbol_bits": SymbolBitsExtractor, + "tx_type": TxTypeExtractor, +} + +_STREAM_GLOB = flags.DEFINE_multi_string( + "stream_glob", None, "Path to AVM streams." +) +flags.mark_flag_as_required("stream_glob") + +_EXTRACT_PROTO_BIN = flags.DEFINE_string( + "extract_proto_bin", None, "Path to extract_proto binary." 
+) +flags.mark_flag_as_required("extract_proto_bin") + +_OUTPUT_CSV = flags.DEFINE_string( + "output_csv", None, "Path to output CSV file (optional)." +) + +_AGGREGATED_FIELD = flags.DEFINE_string( + "aggregated_field", "count", "Field to aggregate on. By default will use the count." +) + +_GROUP_BY = flags.DEFINE_string( + "group_by", None, "Group by these fields (comma separated)." +) + +_EXTRACTOR = flags.DEFINE_enum( + "extractor", None, _EXTRACTORS.keys(), "Data extractor to use." +) +flags.mark_flag_as_required("extractor") + +_THREADS = flags.DEFINE_integer( + "threads", 1, "Number of parallel workers to spawn." +) + +_FRAME_LIMIT = flags.DEFINE_integer( + "frame_limit", None, "Use at most this many frames from each stream." +) + +_PLOT = flags.DEFINE_multi_string( + "plot", None, """Plot command args, e.g.: --plot 'title:"Block Sizes", field:"block_size", filter:"not is_intra_frame", limit:10'""" +) + +_TMP_DIR = flags.DEFINE_string( + "tmp_dir", None, "Temp working dir." +) + +@dataclasses.dataclass +class PlotArgs: + title: str = "" + field: str = "" + filter: str = "" + limit: str = "" + + def __init__(self, args_str: str): + def kv_pair(kv: str) -> tuple[str, str]: + k, v = [i.strip() for i in kv.split(":")] + if v.startswith('"') and v.endswith('"'): + v = v[1:-1] + return (k, v) + + for k, v in [kv_pair(kv) for kv in args_str.split(",")]: + if not hasattr(self, k): + raise ValueError(f"Unknown plot arg: {k}") + setattr(self, k, v) + + +def filter_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count", filt: str = "", limit: int | None = None): + if filt: + filtered_df = df.query(filt)[group_by + [aggregated_field]].groupby(group_by, as_index=False) + else: + filtered_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False) + filtered_df = filtered_df.sum() + filtered_df["percent"] = filtered_df[aggregated_field].transform(lambda x: x / x.sum() * 100) + filtered_df = 
filtered_df.sort_values(by=[aggregated_field], ascending=False) + if limit is not None: + filtered_df_top_n = filtered_df.iloc[:limit, :] + other_total = filtered_df.iloc[limit:, :][aggregated_field].sum() + other_percent = filtered_df.iloc[limit:, :]["percent"].sum() + others = pd.DataFrame({group_by[0]: ["Other"], aggregated_field: other_total, "percent": other_percent}) + filtered_df = pd.concat([filtered_df_top_n, others]) + return filtered_df + + +def sum_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count"): + aggregated_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False) + aggregated_df = aggregated_df.sum() + return aggregated_df + + +def create_plot( + df: pd.DataFrame, + plot_args: PlotArgs, + ax: plt.Axes, + legend_ax: plt.Axes, + aggregated_field: str, +): + limit = int(plot_args.limit) if plot_args.limit else None + aggregated_df = filter_dataframe( + df, group_by=[plot_args.field], aggregated_field=aggregated_field, filt=plot_args.filter, limit=limit) + plot_title = plot_args.title or plot_args.field + ax.set_title(f"{plot_title}") + patches, _ = ax.pie( + x=aggregated_df[aggregated_field], labels=aggregated_df[plot_args.field], autopct=None) + labels = [f"{n} - {p:.1f}% ({v:.1f})" for n, p, v in zip(aggregated_df[plot_args.field], + aggregated_df["percent"], aggregated_df[aggregated_field])] + legend_ax.legend(patches, labels, loc="best", ncol=2) + legend_ax.axis("off") + + +def extract_to_temp_dir(stream_path: pathlib.Path, frame_limit: int | None = None) -> Iterator[proto_helpers.Frame]: + with tempfile.TemporaryDirectory(dir=_TMP_DIR.value) as tmp_dir: + tmp_path = pathlib.Path(tmp_dir) + extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value) + stream_name = stream_path.stem + output_path = tmp_path / stream_name + try: + output_path.mkdir() + except FileExistsError: + logging.fatal(f"Duplicate stream name: {stream_name}") + yield from extract_and_load_protos( + 
extract_proto_path=extract_proto_path, + stream_path=stream_path, + output_path=output_path, + frame_limit=frame_limit + ) + + +def process_stream(stream_path: pathlib.Path, *, extractor_class: Type[Extractor]) -> pd.DataFrame: + group_by = _GROUP_BY.value.split(",") + frames = extract_to_temp_dir(stream_path, _FRAME_LIMIT.value) + df = None + for frame in frames: + frame_df = aggregate_to_dataframe([frame], extractor_class()) + frame_df["count"] = 1 + if df is None: + df = frame_df + else: + df = pd.concat([df, frame_df]) + if len(df): + df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value) + gc.collect() + if len(df): + df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value) + gc.collect() + with open(f"/opt/tmp/progress/{stream_path.stem}.txt", "w") as f: + f.write("Done") + return df + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + extractor_name = _EXTRACTOR.value + try: + extractor_class = _EXTRACTORS[extractor_name] + except KeyError: + print(f"Unknown extractor: {extractor_name}", file=sys.stderr) + sys.exit(1) + + stream_paths = [] + for stream_glob in _STREAM_GLOB.value: + for stream in glob.glob(stream_glob, recursive=True): + path = pathlib.Path(stream) + stream_paths.append(path) + + stream_paths = sorted(stream_paths) + with multiprocessing.Pool(_THREADS.value) as pool: + dfs = pool.map( + partial(process_stream, extractor_class=extractor_class), stream_paths) + + df = pd.concat(dfs) + group_by = _GROUP_BY.value.split(",") + aggregated_df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value) + + if _OUTPUT_CSV.value: + aggregated_df.to_csv(_OUTPUT_CSV.value) + + plots = _PLOT.value + if plots: + _, axes = plt.subplots(2, len(plots), squeeze=False, figsize=(10, 5)) + for i, plot in enumerate(plots): + args = PlotArgs(plot) + create_plot(df, args, axes[0, i], axes[1, i], 
aggregated_field=_AGGREGATED_FIELD.value) + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + app.run(main)
diff --git a/tools/py_stats/examples/aggregate_prediction_modes.py b/tools/py_stats/examples/aggregate_prediction_modes.py deleted file mode 100644 index 9d73fd3..0000000 --- a/tools/py_stats/examples/aggregate_prediction_modes.py +++ /dev/null
@@ -1,154 +0,0 @@ -#!/usr/bin/env python3 -## Copyright (c) 2023, Alliance for Open Media. All rights reserved -## -## This source code is subject to the terms of the BSD 3-Clause Clear License and the -## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was -## not distributed with this source code in the LICENSE file, you can obtain it -## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent -## License 1.0 was not distributed with this source code in the PATENTS file, you -## can obtain it at aomedia.org/license/patent-license/. -## -from __future__ import annotations - -import collections -import glob -import itertools -import pathlib -import tempfile -from typing import Sequence - -from absl import app -from absl import flags -from absl import logging -from avm_stats.extract_proto import * -from avm_stats.frame_visualizations import * -from avm_stats.proto_helpers import * -from avm_stats.stats_aggregation import * -import matplotlib.pyplot as plt -import pandas as pd - -_STREAM_GLOB = flags.DEFINE_multi_string( - "stream_glob", None, "Path to AVM stream." -) -flags.mark_flag_as_required("stream_glob") - -_EXTRACT_PROTO_BIN = flags.DEFINE_string( - "extract_proto_bin", None, "Path to extract_proto binary." -) -flags.mark_flag_as_required("extract_proto_bin") - -_OUTPUT_CSV = flags.DEFINE_string( - "output_csv", None, "Path to output CSV file (optional)." 
-) - - -def prediction_mode_symbol_filter(symbol: Symbol): - return symbol.source_function in ( - "read_inter_mode", - "read_intra_luma_mode", - "read_drl_index", - "read_inter_compound_mode", - ) - - -class PredictionModeExtractor(CodingUnitExtractor): - PredictionMode = collections.namedtuple( - "PredictionMode", - ["width", "height", "mode", "mode_bits", "is_intra_frame"], - ) - - def sample(self, coding_unit: CodingUnit): - width = coding_unit.rect.width - height = coding_unit.rect.height - mode = coding_unit.get_prediction_mode() - mode_bits = sum( - sym.bits - for sym in coding_unit.get_symbols(prediction_mode_symbol_filter) - ) - is_intra_frame = coding_unit.frame.proto.frame_params.frame_type == 0 - yield self.PredictionMode(width, height, mode, mode_bits, is_intra_frame) - - -def compare_intra_inter( - df: pd.DataFrame, - column: str, - *, - aggregated_field: str = "count", - aggregator: str = "count", - plot_title: str | None = None, -): - df_intra = df.query("is_intra_frame")[[column, aggregated_field]].groupby( - column, as_index=False - ) - df_intra = getattr(df_intra, aggregator)() - df_inter = df.query("not is_intra_frame")[[column, aggregated_field]].groupby( - column, as_index=False - ) - df_inter = getattr(df_inter, aggregator)() - _, axes = plt.subplots(1, 2) - plot_title = plot_title or column - axes[0].pie( - x=df_intra[aggregated_field], labels=df_intra[column], autopct="%.1f%%" - ) - axes[0].set_title(f"{plot_title} (Intra frames)") - axes[1].pie( - x=df_inter[aggregated_field], labels=df_inter[column], autopct="%.1f%%" - ) - axes[1].set_title(f"{plot_title} (Inter frames)") - plt.show() - - -def main(argv: Sequence[str]) -> None: - if len(argv) > 1: - raise app.UsageError("Too many command-line arguments.") - - stream_paths = [] - for stream_glob in _STREAM_GLOB.value: - for stream in glob.glob(stream_glob, recursive=True): - path = pathlib.Path(stream) - stream_paths.append(path) - - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path 
= pathlib.Path(tmp_dir) - extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value) - - def extract_to_temp_dir(stream_path: pathlib.Path) -> Frame: - stream_name = stream_path.stem - output_path = tmp_path / stream_name - try: - output_path.mkdir() - except FileExistsError: - logging.fatal(f"Duplicate stream name: {stream_name}") - yield from extract_and_load_protos( - extract_proto_path=extract_proto_path, - stream_path=stream_path, - output_path=output_path, - ) - - all_frames = itertools.chain.from_iterable( - map(extract_to_temp_dir, stream_paths) - ) - df = aggregate_to_dataframe(all_frames, PredictionModeExtractor()) - df["count"] = 1 - df["width_height"] = df.apply( - lambda row: f"{row.width}x{row.height}", axis=1 - ) - - compare_intra_inter( - df, "width_height", plot_title="Block size distribution" - ) - compare_intra_inter(df, "mode", plot_title="Prediction mode distribution") - compare_intra_inter( - df, - "mode", - aggregated_field="mode_bits", - aggregator="sum", - plot_title="Prediction modes by bits spent", - ) - - if _OUTPUT_CSV.value: - df.to_csv(_OUTPUT_CSV.value) - - -if __name__ == "__main__": - app.run(main)
diff --git a/tools/py_stats/examples/extractors/partition_type_extractor.py b/tools/py_stats/examples/extractors/partition_type_extractor.py new file mode 100644 index 0000000..c972d5f --- /dev/null +++ b/tools/py_stats/examples/extractors/partition_type_extractor.py
@@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +## Copyright (c) 2023, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. +## +import collections + +from avm_stats.extract_proto import * +from avm_stats.frame_visualizations import * +from avm_stats.proto_helpers import * +from avm_stats.stats_aggregation import * + +def iter_partitions(partition_block): + yield partition_block + if not partition_block.is_leaf_node: + for child in partition_block.children: + yield from iter_partitions(child) + +class PartitionTypeExtractor(SuperblockExtractor): + PartitionType = collections.namedtuple( + "PartitionType", + ["width", "height", "block_size", "partition_type", "is_intra_frame", "stream_name"], + ) + + def sample(self, superblock: Superblock): + stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin") + is_intra_frame = superblock.frame.is_intra_frame + for partition_block in iter_partitions(superblock.proto.luma_partition_tree): + width = partition_block.size.width + height = partition_block.size.height + block_size = f"{width}x{height}" + partition_type = partition_block.partition_type + partition_type_name = superblock.frame.proto.enum_mappings.partition_type_mapping[partition_type] + yield self.PartitionType(width, height, block_size, partition_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/prediction_mode_extractor.py b/tools/py_stats/examples/extractors/prediction_mode_extractor.py new file mode 100644 index 0000000..6333989 --- /dev/null +++ b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
@@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +## Copyright (c) 2023, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. +## +import collections + +from avm_stats.extract_proto import * +from avm_stats.frame_visualizations import * +from avm_stats.proto_helpers import * +from avm_stats.stats_aggregation import * + +def prediction_mode_symbol_filter(symbol: Symbol): + return symbol.source_function in ( + "read_inter_mode", + "read_intra_luma_mode", + "read_drl_index", + "read_inter_compound_mode", + ) + +class PredictionModeExtractor(CodingUnitExtractor): + PredictionMode = collections.namedtuple( + "PredictionMode", + ["width", "height", "block_size", "mode", "mode_bits", "is_intra_frame", "stream_name"], + ) + + def sample(self, coding_unit: CodingUnit): + stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin") + width = coding_unit.rect.width + height = coding_unit.rect.height + block_size = f"{width}x{height}" + mode = coding_unit.get_prediction_mode() + mode_bits = sum( + sym.bits + for sym in coding_unit.get_symbols(prediction_mode_symbol_filter) + ) + is_intra_frame = coding_unit.frame.is_intra_frame + yield self.PredictionMode(width, height, block_size, mode, mode_bits, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/symbol_bits_extractor.py b/tools/py_stats/examples/extractors/symbol_bits_extractor.py new file mode 100644 index 0000000..2a70bd5 --- /dev/null +++ b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
@@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +## Copyright (c) 2023, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. +## +import collections + +from avm_stats.extract_proto import * +from avm_stats.frame_visualizations import * +from avm_stats.proto_helpers import * +from avm_stats.stats_aggregation import * + +CTC_CLASSES = [ + 2160, + 1080, + 720, + 360, + 270, +] + +class SymbolBitsExtractor(SuperblockExtractor): + SymbolBits = collections.namedtuple( + "SymbolBits", + ["symbol_name", "symbol_tags", "bits", "is_intra_frame", "stream_name", "qp", "ctc_class", "ctc_config"], + ) + + def sample(self, superblock: Superblock): + stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin") + qp = stream_name.split("_")[-1] + ctc_config = stream_name.split("_")[-3] + ctc_class = None + for i, c in enumerate(CTC_CLASSES): + if str(c) in stream_name: + assert ctc_class is None + ctc_class = f"A{i+1}" + assert ctc_class is not None + is_intra_frame = superblock.frame.is_intra_frame + for symbol in superblock.proto.symbols: + symbol_info = superblock.frame.proto.symbol_info[symbol.info_id] + symbol_tags = "/".join(symbol_info.tags) + yield self.SymbolBits(symbol_info.source_function, symbol_tags, symbol.bits, is_intra_frame, stream_name, qp, ctc_class, ctc_config)
diff --git a/tools/py_stats/examples/extractors/tx_type_extractor.py b/tools/py_stats/examples/extractors/tx_type_extractor.py new file mode 100644 index 0000000..bf9ae93 --- /dev/null +++ b/tools/py_stats/examples/extractors/tx_type_extractor.py
@@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +## Copyright (c) 2023, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. +## +import collections + +from avm_stats.extract_proto import * +from avm_stats.frame_visualizations import * +from avm_stats.proto_helpers import * +from avm_stats.stats_aggregation import * + + +class TxTypeExtractor(CodingUnitExtractor): + TxSize = collections.namedtuple( + "TxSize", + ["width", "height", "tx_size", "tx_type", "is_intra_frame", "stream_name"], + ) + + def sample(self, coding_unit: CodingUnit): + stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin") + # Luma only + is_chroma = len(coding_unit.proto.transform_planes) == 2 + if is_chroma: + return + is_intra_frame = coding_unit.frame.is_intra_frame + for transform_unit in coding_unit.proto.transform_planes[0].transform_units: + width = transform_unit.size.width + height = transform_unit.size.height + tx_size = f"{width}x{height}" + # TODO(comc): Add method on transform_unit for this. + tx_type = transform_unit.tx_type & 0xF + tx_type_name = coding_unit.frame.proto.enum_mappings.transform_type_mapping[tx_type] + yield self.TxSize(width, height, tx_size, tx_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/partition_tree.py b/tools/py_stats/examples/partition_tree.py index 3a26160..d2c32ad 100755 --- a/tools/py_stats/examples/partition_tree.py +++ b/tools/py_stats/examples/partition_tree.py
@@ -1,12 +1,12 @@ #!/usr/bin/env python3 -# Copyright (c) 2023, Alliance for Open Media. All rights reserved +## Copyright (c) 2023, Alliance for Open Media. All rights reserved ## -# This source code is subject to the terms of the BSD 3-Clause Clear License and the -# Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was -# not distributed with this source code in the LICENSE file, you can obtain it -# at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent -# License 1.0 was not distributed with this source code in the PATENTS file, you -# can obtain it at aomedia.org/license/patent-license/. +## This source code is subject to the terms of the BSD 3-Clause Clear License and the +## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was +## not distributed with this source code in the LICENSE file, you can obtain it +## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent +## License 1.0 was not distributed with this source code in the PATENTS file, you +## can obtain it at aomedia.org/license/patent-license/. ## from __future__ import annotations
diff --git a/tools/py_stats/img/aggregation_block_sizes.png b/tools/py_stats/img/aggregation_block_sizes.png new file mode 100644 index 0000000..8064457 --- /dev/null +++ b/tools/py_stats/img/aggregation_block_sizes.png Binary files differ
diff --git a/tools/py_stats/img/aggregation_partition_types.png b/tools/py_stats/img/aggregation_partition_types.png new file mode 100644 index 0000000..f632dea --- /dev/null +++ b/tools/py_stats/img/aggregation_partition_types.png Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes.png b/tools/py_stats/img/aggregation_prediction_modes.png new file mode 100644 index 0000000..81dafce --- /dev/null +++ b/tools/py_stats/img/aggregation_prediction_modes.png Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes_by_bits.png b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png new file mode 100644 index 0000000..e87f9d9 --- /dev/null +++ b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png Binary files differ
diff --git a/tools/py_stats/img/aggregation_symbols.png b/tools/py_stats/img/aggregation_symbols.png new file mode 100644 index 0000000..cc6e226 --- /dev/null +++ b/tools/py_stats/img/aggregation_symbols.png Binary files differ
diff --git a/tools/py_stats/img/aggregation_tx_types.png b/tools/py_stats/img/aggregation_tx_types.png new file mode 100644 index 0000000..6c63310 --- /dev/null +++ b/tools/py_stats/img/aggregation_tx_types.png Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap.png b/tools/py_stats/img/bits_heatmap.png new file mode 100644 index 0000000..bf7e6e5 --- /dev/null +++ b/tools/py_stats/img/bits_heatmap.png Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png new file mode 100644 index 0000000..4db7259 --- /dev/null +++ b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_inter.png b/tools/py_stats/img/partition_tree_luma_inter.png new file mode 100644 index 0000000..e4ee744 --- /dev/null +++ b/tools/py_stats/img/partition_tree_luma_inter.png Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_intra.png b/tools/py_stats/img/partition_tree_luma_intra.png new file mode 100644 index 0000000..169dffc --- /dev/null +++ b/tools/py_stats/img/partition_tree_luma_intra.png Binary files differ
diff --git a/tools/py_stats/img/pixel_pipeline.png b/tools/py_stats/img/pixel_pipeline.png new file mode 100644 index 0000000..7fcac2d --- /dev/null +++ b/tools/py_stats/img/pixel_pipeline.png Binary files differ
diff --git a/tools/py_stats/img/prediction_modes_luma_intra.png b/tools/py_stats/img/prediction_modes_luma_intra.png new file mode 100644 index 0000000..d27c7b8 --- /dev/null +++ b/tools/py_stats/img/prediction_modes_luma_intra.png Binary files differ
diff --git a/tools/py_stats/pyproject.toml b/tools/py_stats/pyproject.toml index 469ba6a..6d94bab 100644 --- a/tools/py_stats/pyproject.toml +++ b/tools/py_stats/pyproject.toml
@@ -7,7 +7,7 @@ license = "BSD-3-Clause" readme = "README.md" # By default Poetry will use .gitignore, which excludes the generated protobuf bindings. -include = ["*.py"] +include = ["**/*.py"] [tool.poetry.dependencies] absl-py = "^2.0.0"