Python stats aggregation improvements.
Added better examples and updated README.
diff --git a/tools/py_stats/README.md b/tools/py_stats/README.md
index aa6ad85..0bccb86 100644
--- a/tools/py_stats/README.md
+++ b/tools/py_stats/README.md
@@ -197,30 +197,105 @@
Bits for read_intra_luma_mode: 760.7987823486328 (9.55%)
```
-### Prediction mode / block size aggregation
-All previous examples focused on visualizing a single frame from a single stream. This example shows how to compute aggregated stats across an arbitrary number of AVM streams.
-A typical use case might be to compare how bits are spent in CTC frames under different test conditions.
+### Statistics aggregation
+All previous examples focused on visualizing a single frame from a single stream. The following examples show how to compute aggregated stats across an arbitrary number of AVM streams. A typical use case might be to compare how bits are spent in CTC frames under different test conditions.
-This example script takes one or more glob paths and will extract all frames from all streams found:
+Each of the following examples uses the `aggregate_from_extractor.py` script to plot different types of data from the streams, in addition to dumping that data as a .csv file.
+
+### Block size distribution (A5, QP=235, 10 frames); Plot Intra vs Inter
```bash
-python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_prediction_modes.py \
- --stream_glob "/path/to/some/interesting/streams/*.ivf" \
- --stream_glob "/path/to/some/other/streams/*.ivf" \
- --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
- --output_csv /tmp/aggregate_dump.csv
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "partition_type" \
+ --group_by block_size,partition_type,is_intra_frame \
+ --output_csv /tmp/aggregate_block_sizes.csv \
+ --plot 'title:"Intra Block Sizes", field:block_size, filter:"is_intra_frame"' \
+ --plot 'title:"Inter Block Sizes", field:block_size, filter:"not is_intra_frame"'
```
+
-The script will produce three different plots:
-1. The overall distribution of block sizes, comparing intra and inter frames:
-
-2. The overall distribution of prediction modes, comparing intra and inter frames:
-
+### Prediction mode distribution (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "prediction_mode" \
+ --group_by mode,is_intra_frame \
+ --output_csv /tmp/prediction_modes.csv \
+ --plot 'title:"Intra Prediction Modes", field:mode, filter:"is_intra_frame"' \
+ --plot 'title:"Inter Prediction Modes", field:mode, filter:"not is_intra_frame"'
+```
+
-3. Prediction mode distribution, weighted by the number of bits used to code each mode, comparing intra and inter frames:
-
-The script also takes an optional `--output_csv` argument that will dump the final Pandas dataframe to a CSV file.
+### Prediction modes weighted by bits used (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "prediction_mode" \
+ --aggregated_field mode_bits \
+ --group_by mode,is_intra_frame \
+ --output_csv /tmp/prediction_mode_bits.csv \
+ --plot 'title:"Intra Prediction Modes by bits", field:mode, filter:"is_intra_frame"' \
+ --plot 'title:"Inter Prediction Modes by bits", field:mode, filter:"not is_intra_frame"'
+```
+
+
+### TX types (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "tx_type" \
+ --group_by tx_type,is_intra_frame \
+ --output_csv /tmp/aggregate_tx_types.csv \
+ --plot 'title:"Intra TX Types", field:tx_type, filter:"is_intra_frame"' \
+ --plot 'title:"Inter TX Types", field:tx_type, filter:"not is_intra_frame"'
+```
+
+
+### Intra partition types (A5, QP=210, 10 frames); Block sizes 64x64 vs 32x32
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/AllIntra/*_QP_210.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "partition_type" \
+ --group_by partition_type,block_size \
+ --output_csv /tmp/aggregate_partition_types.csv \
+ --plot 'title:"Partition types (64x64)", field:partition_type, filter:block_size=="64x64"' \
+ --plot 'title:"Partition types (32x32)", field:partition_type, filter:block_size=="32x32"'
+```
+
+
+### Top 10 symbols types weighted by bits used (A5, QP=235, 10 frames); Intra vs Inter frames
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+ --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+ --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_210.bin" \
+ --threads 32 \
+ --frame_limit 10 \
+ --extractor "symbol_bits" \
+ --aggregated_field bits \
+ --group_by symbol_name,is_intra_frame \
+ --output_csv /tmp/aggregate_symbols.csv \
+ --plot 'title:"Symbols (Intra frames)", field:symbol_name, filter:"is_intra_frame", limit:10' \
+ --plot 'title:"Symbols (Inter frames)", field:symbol_name, filter:"not is_intra_frame", limit:10'
+```
+
+
### Launch an example Jupyter notebook
This demonstrates many of the same visualizations listed above, but in a single Jupyter notebook.
diff --git a/tools/py_stats/avm_stats/proto_helpers.py b/tools/py_stats/avm_stats/proto_helpers.py
index b012ec7..b6c861a 100644
--- a/tools/py_stats/avm_stats/proto_helpers.py
+++ b/tools/py_stats/avm_stats/proto_helpers.py
@@ -164,7 +164,7 @@
(pixels_height, pixels_width)
)[:sb_height_clipped, :sb_width_clipped]
- samples[sb_y : sb_y + sb_height_clipped, sb_x : sb_x + sb_width_clipped] = (
+ samples[sb_y: sb_y + sb_height_clipped, sb_x: sb_x + sb_width_clipped] = (
superblock_samples
)
@@ -541,23 +541,10 @@
self.superblocks = [
Superblock(self, sb_proto) for sb_proto in proto.superblocks
]
- self.pixels = [
- _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V)
- ]
-
- if self.pixels[0].original is None:
- self.original_rgb = None
- else:
- self.original_rgb = yuv_tools.yuv_to_rgb(
- self.pixels[0].original,
- yuv_tools.upscale(self.pixels[1].original, 2),
- yuv_tools.upscale(self.pixels[2].original, 2),
- )
- self.reconstruction_rgb = yuv_tools.yuv_to_rgb(
- self.pixels[0].reconstruction,
- yuv_tools.upscale(self.pixels[1].reconstruction, 2),
- yuv_tools.upscale(self.pixels[2].reconstruction, 2),
- )
+ # Pixel data is created lazily.
+ self._pixels = None
+ self._original_rgb = None
+ self._reconstruction_rgb = None
@property
def frame_id(self) -> int:
@@ -575,6 +562,38 @@
def bit_depth(self) -> int:
return self.proto.frame_params.bit_depth
+ @property
+ def is_intra_frame(self) -> bool:
+ return self.proto.frame_params.frame_type == 0
+
+ @property
+ def pixels(self) -> list[PlaneBuffer]:
+ if self._pixels is None:
+ self.pixels = [
+ _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V)
+ ]
+ return self._pixels
+
+ @property
+ def original_rgb(self) -> np.ndarray:
+ if self._original_rgb is None and self.pixels[0].original is not None:
+ self._original_rgb = yuv_tools.yuv_to_rgb(
+ self.pixels[0].original,
+ yuv_tools.upscale(self.pixels[1].original, 2),
+ yuv_tools.upscale(self.pixels[2].original, 2),
+ )
+ return self._original_rgb
+
+ @property
+ def reconstruction_rgb(self) -> np.ndarray:
+ if self._reconstruction_rgb is None:
+ self._reconstruction_rgb = yuv_tools.yuv_to_rgb(
+ self.pixels[0].reconstruction,
+ yuv_tools.upscale(self.pixels[1].reconstruction, 2),
+ yuv_tools.upscale(self.pixels[2].reconstruction, 2),
+ )
+ return self._reconstruction_rgb
+
def clip_rect(self, rect: Rectangle) -> Rectangle:
"""Clips a rectangle to be contained with the frame boundaries.
diff --git a/tools/py_stats/examples/aggregate_from_extractor.py b/tools/py_stats/examples/aggregate_from_extractor.py
new file mode 100755
index 0000000..2ae1747
--- /dev/null
+++ b/tools/py_stats/examples/aggregate_from_extractor.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+from __future__ import annotations
+
+from functools import partial
+import gc
+import glob
+import pathlib
+import sys
+import tempfile
+from typing import Sequence, Type
+
+from absl import app
+from absl import flags
+from absl import logging
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+import matplotlib.pyplot as plt
+import multiprocessing
+import pandas as pd
+
+from extractors.partition_similarity_extractor import PartitionSimilarityExtractor
+from extractors.partition_type_extractor import PartitionTypeExtractor
+from extractors.prediction_mode_extractor import PredictionModeExtractor
+from extractors.symbol_bits_extractor import SymbolBitsExtractor
+from extractors.tx_type_extractor import TxTypeExtractor
+
+_EXTRACTORS = {
+ "partition_similarity": PartitionSimilarityExtractor,
+ "partition_type": PartitionTypeExtractor,
+ "prediction_mode": PredictionModeExtractor,
+ "symbol_bits": SymbolBitsExtractor,
+ "tx_type": TxTypeExtractor,
+}
+
+_STREAM_GLOB = flags.DEFINE_multi_string(
+ "stream_glob", None, "Path to AVM streams."
+)
+flags.mark_flag_as_required("stream_glob")
+
+_EXTRACT_PROTO_BIN = flags.DEFINE_string(
+ "extract_proto_bin", None, "Path to extract_proto binary."
+)
+flags.mark_flag_as_required("extract_proto_bin")
+
+_OUTPUT_CSV = flags.DEFINE_string(
+ "output_csv", None, "Path to output CSV file (optional)."
+)
+
+_AGGREGATED_FIELD = flags.DEFINE_string(
+ "aggregated_field", "count", "Field to aggregate on. By default will use the count."
+)
+
+_GROUP_BY = flags.DEFINE_string(
+ "group_by", None, "Group by these fields (comma separated)."
+)
+
+_EXTRACTOR = flags.DEFINE_enum(
+ "extractor", None, _EXTRACTORS.keys(), "Data extractor to use."
+)
+flags.mark_flag_as_required("extractor")
+
+_THREADS = flags.DEFINE_integer(
+ "threads", 1, "Number of parallel workers to spawn."
+)
+
+_FRAME_LIMIT = flags.DEFINE_integer(
+ "frame_limit", None, "Use at most this many frames from each stream."
+)
+
+_PLOT = flags.DEFINE_multi_string(
+ "plot", None, """Plot command args, e.g.: --plot 'title:"Block Sizes", field:"block_size", filter:"not is_intra_frame", limit:10'"""
+)
+
+_TMP_DIR = flags.DEFINE_string(
+ "tmp_dir", None, "Temp working dir."
+)
+
+@dataclasses.dataclass
+class PlotArgs:
+ title: str = ""
+ field: str = ""
+ filter: str = ""
+ limit: str = ""
+
+ def __init__(self, args_str: str):
+ def kv_pair(kv: str) -> tuple[str, str]:
+ k, v = [i.strip() for i in kv.split(":")]
+ if v.startswith('"') and v.endswith('"'):
+ v = v[1:-1]
+ return (k, v)
+
+ for k, v in [kv_pair(kv) for kv in args_str.split(",")]:
+ if not hasattr(self, k):
+ raise ValueError(f"Unknown plot arg: {k}")
+ setattr(self, k, v)
+
+
+def filter_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count", filt: str = "", limit: int | None = None):
+ if filt:
+ filtered_df = df.query(filt)[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+ else:
+ filtered_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+ filtered_df = filtered_df.sum()
+ filtered_df["percent"] = filtered_df[aggregated_field].transform(lambda x: x / x.sum() * 100)
+ filtered_df = filtered_df.sort_values(by=[aggregated_field], ascending=False)
+ if limit is not None:
+ filtered_df_top_n = filtered_df.iloc[:limit, :]
+ other_total = filtered_df.iloc[limit:, :][aggregated_field].sum()
+ other_percent = filtered_df.iloc[limit:, :]["percent"].sum()
+ others = pd.DataFrame({group_by[0]: ["Other"], aggregated_field: other_total, "percent": other_percent})
+ filtered_df = pd.concat([filtered_df_top_n, others])
+ return filtered_df
+
+
+def sum_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count"):
+ aggregated_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+ aggregated_df = aggregated_df.sum()
+ return aggregated_df
+
+
+def create_plot(
+ df: pd.DataFrame,
+ plot_args: PlotArgs,
+ ax: plt.Axes,
+ legend_ax: plt.Axes,
+ aggregated_field: str,
+):
+ limit = int(plot_args.limit) if plot_args.limit else None
+ aggregated_df = filter_dataframe(
+ df, group_by=[plot_args.field], aggregated_field=aggregated_field, filt=plot_args.filter, limit=limit)
+ plot_title = plot_args.title or plot_args.field
+ ax.set_title(f"{plot_title}")
+ patches, _ = ax.pie(
+ x=aggregated_df[aggregated_field], labels=aggregated_df[plot_args.field], autopct=None)
+ labels = [f"{n} - {p:.1f}% ({v:.1f})" for n, p, v in zip(aggregated_df[plot_args.field],
+ aggregated_df["percent"], aggregated_df[aggregated_field])]
+ legend_ax.legend(patches, labels, loc="best", ncol=2)
+ legend_ax.axis("off")
+
+
+def extract_to_temp_dir(stream_path: pathlib.Path, frame_limit: int | None = None) -> Iterator[proto_helpers.Frame]:
+ with tempfile.TemporaryDirectory(dir=_TMP_DIR.value) as tmp_dir:
+ tmp_path = pathlib.Path(tmp_dir)
+ extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
+ stream_name = stream_path.stem
+ output_path = tmp_path / stream_name
+ try:
+ output_path.mkdir()
+ except FileExistsError:
+ logging.fatal(f"Duplicate stream name: {stream_name}")
+ yield from extract_and_load_protos(
+ extract_proto_path=extract_proto_path,
+ stream_path=stream_path,
+ output_path=output_path,
+ frame_limit=frame_limit
+ )
+
+
+def process_stream(stream_path: pathlib.Path, *, extractor_class: Type[Extractor]) -> pd.DataFrame:
+ group_by = _GROUP_BY.value.split(",")
+ frames = extract_to_temp_dir(stream_path, _FRAME_LIMIT.value)
+ df = None
+ for frame in frames:
+ frame_df = aggregate_to_dataframe([frame], extractor_class())
+ frame_df["count"] = 1
+ if df is None:
+ df = frame_df
+ else:
+ df = pd.concat([df, frame_df])
+ if len(df):
+ df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+ gc.collect()
+ if len(df):
+ df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+ gc.collect()
+ with open(f"/opt/tmp/progress/{stream_path.stem}.txt", "w") as f:
+ f.write("Done")
+ return df
+
+
+def main(argv: Sequence[str]) -> None:
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ extractor_name = _EXTRACTOR.value
+ try:
+ extractor_class = _EXTRACTORS[extractor_name]
+ except KeyError:
+ print(f"Unknown extractor: {extractor_name}", file=sys.stderr)
+ sys.exit(1)
+
+ stream_paths = []
+ for stream_glob in _STREAM_GLOB.value:
+ for stream in glob.glob(stream_glob, recursive=True):
+ path = pathlib.Path(stream)
+ stream_paths.append(path)
+
+ stream_paths = sorted(stream_paths)
+ with multiprocessing.Pool(_THREADS.value) as pool:
+ dfs = pool.map(
+ partial(process_stream, extractor_class=extractor_class), stream_paths)
+
+ df = pd.concat(dfs)
+ group_by = _GROUP_BY.value.split(",")
+ aggregated_df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+
+ if _OUTPUT_CSV.value:
+ aggregated_df.to_csv(_OUTPUT_CSV.value)
+
+ plots = _PLOT.value
+ if plots:
+ _, axes = plt.subplots(2, len(plots), squeeze=False, figsize=(10, 5))
+ for i, plot in enumerate(plots):
+ args = PlotArgs(plot)
+ create_plot(df, args, axes[0, i], axes[1, i], aggregated_field=_AGGREGATED_FIELD.value)
+ plt.tight_layout()
+ plt.show()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/tools/py_stats/examples/aggregate_prediction_modes.py b/tools/py_stats/examples/aggregate_prediction_modes.py
deleted file mode 100644
index 9d73fd3..0000000
--- a/tools/py_stats/examples/aggregate_prediction_modes.py
+++ /dev/null
@@ -1,154 +0,0 @@
-#!/usr/bin/env python3
-## Copyright (c) 2023, Alliance for Open Media. All rights reserved
-##
-## This source code is subject to the terms of the BSD 3-Clause Clear License and the
-## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
-## not distributed with this source code in the LICENSE file, you can obtain it
-## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
-## License 1.0 was not distributed with this source code in the PATENTS file, you
-## can obtain it at aomedia.org/license/patent-license/.
-##
-from __future__ import annotations
-
-import collections
-import glob
-import itertools
-import pathlib
-import tempfile
-from typing import Sequence
-
-from absl import app
-from absl import flags
-from absl import logging
-from avm_stats.extract_proto import *
-from avm_stats.frame_visualizations import *
-from avm_stats.proto_helpers import *
-from avm_stats.stats_aggregation import *
-import matplotlib.pyplot as plt
-import pandas as pd
-
-_STREAM_GLOB = flags.DEFINE_multi_string(
- "stream_glob", None, "Path to AVM stream."
-)
-flags.mark_flag_as_required("stream_glob")
-
-_EXTRACT_PROTO_BIN = flags.DEFINE_string(
- "extract_proto_bin", None, "Path to extract_proto binary."
-)
-flags.mark_flag_as_required("extract_proto_bin")
-
-_OUTPUT_CSV = flags.DEFINE_string(
- "output_csv", None, "Path to output CSV file (optional)."
-)
-
-
-def prediction_mode_symbol_filter(symbol: Symbol):
- return symbol.source_function in (
- "read_inter_mode",
- "read_intra_luma_mode",
- "read_drl_index",
- "read_inter_compound_mode",
- )
-
-
-class PredictionModeExtractor(CodingUnitExtractor):
- PredictionMode = collections.namedtuple(
- "PredictionMode",
- ["width", "height", "mode", "mode_bits", "is_intra_frame"],
- )
-
- def sample(self, coding_unit: CodingUnit):
- width = coding_unit.rect.width
- height = coding_unit.rect.height
- mode = coding_unit.get_prediction_mode()
- mode_bits = sum(
- sym.bits
- for sym in coding_unit.get_symbols(prediction_mode_symbol_filter)
- )
- is_intra_frame = coding_unit.frame.proto.frame_params.frame_type == 0
- yield self.PredictionMode(width, height, mode, mode_bits, is_intra_frame)
-
-
-def compare_intra_inter(
- df: pd.DataFrame,
- column: str,
- *,
- aggregated_field: str = "count",
- aggregator: str = "count",
- plot_title: str | None = None,
-):
- df_intra = df.query("is_intra_frame")[[column, aggregated_field]].groupby(
- column, as_index=False
- )
- df_intra = getattr(df_intra, aggregator)()
- df_inter = df.query("not is_intra_frame")[[column, aggregated_field]].groupby(
- column, as_index=False
- )
- df_inter = getattr(df_inter, aggregator)()
- _, axes = plt.subplots(1, 2)
- plot_title = plot_title or column
- axes[0].pie(
- x=df_intra[aggregated_field], labels=df_intra[column], autopct="%.1f%%"
- )
- axes[0].set_title(f"{plot_title} (Intra frames)")
- axes[1].pie(
- x=df_inter[aggregated_field], labels=df_inter[column], autopct="%.1f%%"
- )
- axes[1].set_title(f"{plot_title} (Inter frames)")
- plt.show()
-
-
-def main(argv: Sequence[str]) -> None:
- if len(argv) > 1:
- raise app.UsageError("Too many command-line arguments.")
-
- stream_paths = []
- for stream_glob in _STREAM_GLOB.value:
- for stream in glob.glob(stream_glob, recursive=True):
- path = pathlib.Path(stream)
- stream_paths.append(path)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_path = pathlib.Path(tmp_dir)
- extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
-
- def extract_to_temp_dir(stream_path: pathlib.Path) -> Frame:
- stream_name = stream_path.stem
- output_path = tmp_path / stream_name
- try:
- output_path.mkdir()
- except FileExistsError:
- logging.fatal(f"Duplicate stream name: {stream_name}")
- yield from extract_and_load_protos(
- extract_proto_path=extract_proto_path,
- stream_path=stream_path,
- output_path=output_path,
- )
-
- all_frames = itertools.chain.from_iterable(
- map(extract_to_temp_dir, stream_paths)
- )
- df = aggregate_to_dataframe(all_frames, PredictionModeExtractor())
- df["count"] = 1
- df["width_height"] = df.apply(
- lambda row: f"{row.width}x{row.height}", axis=1
- )
-
- compare_intra_inter(
- df, "width_height", plot_title="Block size distribution"
- )
- compare_intra_inter(df, "mode", plot_title="Prediction mode distribution")
- compare_intra_inter(
- df,
- "mode",
- aggregated_field="mode_bits",
- aggregator="sum",
- plot_title="Prediction modes by bits spent",
- )
-
- if _OUTPUT_CSV.value:
- df.to_csv(_OUTPUT_CSV.value)
-
-
-if __name__ == "__main__":
- app.run(main)
diff --git a/tools/py_stats/examples/extractors/partition_type_extractor.py b/tools/py_stats/examples/extractors/partition_type_extractor.py
new file mode 100644
index 0000000..c972d5f
--- /dev/null
+++ b/tools/py_stats/examples/extractors/partition_type_extractor.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+def iter_partitions(partition_block):
+ yield partition_block
+ if not partition_block.is_leaf_node:
+ for child in partition_block.children:
+ yield from iter_partitions(child)
+
+class PartitionTypeExtractor(SuperblockExtractor):
+ PartitionType = collections.namedtuple(
+ "PartitionType",
+ ["width", "height", "block_size", "partition_type", "is_intra_frame", "stream_name"],
+ )
+
+ def sample(self, superblock: Superblock):
+ stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ is_intra_frame = superblock.frame.is_intra_frame
+ for partition_block in iter_partitions(superblock.proto.luma_partition_tree):
+ width = partition_block.size.width
+ height = partition_block.size.height
+ block_size = f"{width}x{height}"
+ partition_type = partition_block.partition_type
+ partition_type_name = superblock.frame.proto.enum_mappings.partition_type_mapping[partition_type]
+ yield self.PartitionType(width, height, block_size, partition_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/prediction_mode_extractor.py b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
new file mode 100644
index 0000000..6333989
--- /dev/null
+++ b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+def prediction_mode_symbol_filter(symbol: Symbol):
+ return symbol.source_function in (
+ "read_inter_mode",
+ "read_intra_luma_mode",
+ "read_drl_index",
+ "read_inter_compound_mode",
+ )
+
+class PredictionModeExtractor(CodingUnitExtractor):
+ PredictionMode = collections.namedtuple(
+ "PredictionMode",
+ ["width", "height", "block_size", "mode", "mode_bits", "is_intra_frame", "stream_name"],
+ )
+
+ def sample(self, coding_unit: CodingUnit):
+ stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ width = coding_unit.rect.width
+ height = coding_unit.rect.height
+ block_size = f"{width}x{height}"
+ mode = coding_unit.get_prediction_mode()
+ mode_bits = sum(
+ sym.bits
+ for sym in coding_unit.get_symbols(prediction_mode_symbol_filter)
+ )
+ is_intra_frame = coding_unit.frame.is_intra_frame
+ yield self.PredictionMode(width, height, block_size, mode, mode_bits, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/symbol_bits_extractor.py b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
new file mode 100644
index 0000000..2a70bd5
--- /dev/null
+++ b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+CTC_CLASSES = [
+ 2160,
+ 1080,
+ 720,
+ 360,
+ 270,
+]
+
+class SymbolBitsExtractor(SuperblockExtractor):
+ SymbolBits = collections.namedtuple(
+ "SymbolBits",
+ ["symbol_name", "symbol_tags", "bits", "is_intra_frame", "stream_name", "qp", "ctc_class", "ctc_config"],
+ )
+
+ def sample(self, superblock: Superblock):
+ stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ qp = stream_name.split("_")[-1]
+ ctc_config = stream_name.split("_")[-3]
+ ctc_class = None
+ for i, c in enumerate(CTC_CLASSES):
+ if str(c) in stream_name:
+ assert ctc_class is None
+ ctc_class = f"A{i+1}"
+ assert ctc_class is not None
+ is_intra_frame = superblock.frame.is_intra_frame
+ for symbol in superblock.proto.symbols:
+ symbol_info = superblock.frame.proto.symbol_info[symbol.info_id]
+ symbol_tags = "/".join(symbol_info.tags)
+ yield self.SymbolBits(symbol_info.source_function, symbol_tags, symbol.bits, is_intra_frame, stream_name, qp, ctc_class, ctc_config)
diff --git a/tools/py_stats/examples/extractors/tx_type_extractor.py b/tools/py_stats/examples/extractors/tx_type_extractor.py
new file mode 100644
index 0000000..bf9ae93
--- /dev/null
+++ b/tools/py_stats/examples/extractors/tx_type_extractor.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+
+class TxTypeExtractor(CodingUnitExtractor):
+ TxSize = collections.namedtuple(
+ "TxSize",
+ ["width", "height", "tx_size", "tx_type", "is_intra_frame", "stream_name"],
+ )
+
+ def sample(self, coding_unit: CodingUnit):
+ stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ # Luma only
+ is_chroma = len(coding_unit.proto.transform_planes) == 2
+ if is_chroma:
+ return
+ is_intra_frame = coding_unit.frame.is_intra_frame
+ for transform_unit in coding_unit.proto.transform_planes[0].transform_units:
+ width = transform_unit.size.width
+ height = transform_unit.size.height
+ tx_size = f"{width}x{height}"
+ # TODO(comc): Add method on transform_unit for this.
+ tx_type = transform_unit.tx_type & 0xF
+ tx_type_name = coding_unit.frame.proto.enum_mappings.transform_type_mapping[tx_type]
+ yield self.TxSize(width, height, tx_size, tx_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/partition_tree.py b/tools/py_stats/examples/partition_tree.py
index 3a26160..d2c32ad 100755
--- a/tools/py_stats/examples/partition_tree.py
+++ b/tools/py_stats/examples/partition_tree.py
@@ -1,12 +1,12 @@
#!/usr/bin/env python3
-# Copyright (c) 2023, Alliance for Open Media. All rights reserved
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
##
-# This source code is subject to the terms of the BSD 3-Clause Clear License and the
-# Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
-# not distributed with this source code in the LICENSE file, you can obtain it
-# at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
-# License 1.0 was not distributed with this source code in the PATENTS file, you
-# can obtain it at aomedia.org/license/patent-license/.
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
##
from __future__ import annotations
diff --git a/tools/py_stats/img/aggregation_block_sizes.png b/tools/py_stats/img/aggregation_block_sizes.png
new file mode 100644
index 0000000..8064457
--- /dev/null
+++ b/tools/py_stats/img/aggregation_block_sizes.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_partition_types.png b/tools/py_stats/img/aggregation_partition_types.png
new file mode 100644
index 0000000..f632dea
--- /dev/null
+++ b/tools/py_stats/img/aggregation_partition_types.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes.png b/tools/py_stats/img/aggregation_prediction_modes.png
new file mode 100644
index 0000000..81dafce
--- /dev/null
+++ b/tools/py_stats/img/aggregation_prediction_modes.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes_by_bits.png b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png
new file mode 100644
index 0000000..e87f9d9
--- /dev/null
+++ b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_symbols.png b/tools/py_stats/img/aggregation_symbols.png
new file mode 100644
index 0000000..cc6e226
--- /dev/null
+++ b/tools/py_stats/img/aggregation_symbols.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_tx_types.png b/tools/py_stats/img/aggregation_tx_types.png
new file mode 100644
index 0000000..6c63310
--- /dev/null
+++ b/tools/py_stats/img/aggregation_tx_types.png
Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap.png b/tools/py_stats/img/bits_heatmap.png
new file mode 100644
index 0000000..bf7e6e5
--- /dev/null
+++ b/tools/py_stats/img/bits_heatmap.png
Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png
new file mode 100644
index 0000000..4db7259
--- /dev/null
+++ b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png
Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_inter.png b/tools/py_stats/img/partition_tree_luma_inter.png
new file mode 100644
index 0000000..e4ee744
--- /dev/null
+++ b/tools/py_stats/img/partition_tree_luma_inter.png
Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_intra.png b/tools/py_stats/img/partition_tree_luma_intra.png
new file mode 100644
index 0000000..169dffc
--- /dev/null
+++ b/tools/py_stats/img/partition_tree_luma_intra.png
Binary files differ
diff --git a/tools/py_stats/img/pixel_pipeline.png b/tools/py_stats/img/pixel_pipeline.png
new file mode 100644
index 0000000..7fcac2d
--- /dev/null
+++ b/tools/py_stats/img/pixel_pipeline.png
Binary files differ
diff --git a/tools/py_stats/img/prediction_modes_luma_intra.png b/tools/py_stats/img/prediction_modes_luma_intra.png
new file mode 100644
index 0000000..d27c7b8
--- /dev/null
+++ b/tools/py_stats/img/prediction_modes_luma_intra.png
Binary files differ
diff --git a/tools/py_stats/pyproject.toml b/tools/py_stats/pyproject.toml
index 469ba6a..6d94bab 100644
--- a/tools/py_stats/pyproject.toml
+++ b/tools/py_stats/pyproject.toml
@@ -7,7 +7,7 @@
license = "BSD-3-Clause"
readme = "README.md"
# By default Poetry will use .gitignore, which excludes the generated protobuf bindings.
-include = ["*.py"]
+include = ["**/*.py"]
[tool.poetry.dependencies]
absl-py = "^2.0.0"