Python stats aggregation improvements.

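Replace the single-purpose aggregate_prediction_modes.py example with a
generic aggregate_from_extractor.py script. The new script runs a
selectable extractor over each stream in parallel workers, groups and sums
the extracted rows, optionally writes the result to a CSV file, and draws
one pie chart per --plot argument. This change adds extractors for
partition types, prediction modes, symbol bits and TX types.

Also make the Frame pixel and RGB buffers lazy properties so they are only
computed when accessed, add a Frame.is_intra_frame property, and update the
README with an example command for each kind of plot.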
diff --git a/tools/py_stats/README.md b/tools/py_stats/README.md
index aa6ad85..0bccb86 100644
--- a/tools/py_stats/README.md
+++ b/tools/py_stats/README.md
@@ -197,30 +197,105 @@
 Bits for read_intra_luma_mode: 760.7987823486328 (9.55%)
 ```
 
-### Prediction mode / block size aggregation
-All previous examples focused on visualizing a single frame from a single stream. This example shows how to compute aggregated stats across an arbitrary number of AVM streams.
-A typical use case might be to compare how bits are spent in CTC frames under different test conditions.
+### Statistics aggregation
+All previous examples focused on visualizing a single frame from a single stream. The following examples show how to compute aggregated stats across an arbitrary number of AVM streams. A typical use case might be to compare how bits are spent in CTC frames under different test conditions.
 
-This example script takes one or more glob paths and will extract all frames from all streams found:
+Each of the following examples uses the `aggregate_from_extractor.py` script to plot a different type of data from the streams. Passing the optional `--output_csv` flag also dumps the aggregated data to a CSV file.
+
+### Block size distribution (A5, QP=235, 10 frames); Plot Intra vs Inter
 ```bash
-python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_prediction_modes.py \
-    --stream_glob "/path/to/some/interesting/streams/*.ivf" \
-    --stream_glob "/path/to/some/other/streams/*.ivf" \
-    --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
-    --output_csv /tmp/aggregate_dump.csv
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "partition_type" \
+  --group_by block_size,partition_type,is_intra_frame \
+  --output_csv /tmp/aggregate_block_sizes.csv \
+  --plot 'title:"Intra Block Sizes", field:block_size, filter:"is_intra_frame"' \
+  --plot 'title:"Inter Block Sizes", field:block_size, filter:"not is_intra_frame"'
 ```
+![](img/aggregation_block_sizes.png "Sample output from aggregate_from_extractor.py with the partition_type extractor using block sizes.")
 
-The script will produce three different plots:
-1. The overall distribution of block sizes, comparing intra and inter frames:
-![](img/block_size_distribution.png "Sample output from aggregate_prediction_modes.py showing block size distribution")
 
-2. The overall distribution of prediction modes, comparing intra and inter frames:
-![](img/prediction_mode_distribution.png "Sample output from aggregate_prediction_modes.py showing prediction mode distribution")
+### Prediction mode distribution (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "prediction_mode" \
+  --group_by mode,is_intra_frame \
+  --output_csv /tmp/prediction_modes.csv \
+  --plot 'title:"Intra Prediction Modes", field:mode, filter:"is_intra_frame"' \
+  --plot 'title:"Inter Prediction Modes", field:mode, filter:"not is_intra_frame"'
+```
+![](img/aggregation_prediction_modes.png "Sample output from aggregate_from_extractor.py with the prediction_mode extractor counting each mode.")
 
-3. Prediction mode distribution, weighted by the number of bits used to code each mode, comparing intra and inter frames:
-![](img/prediction_mode_distribution_by_bits.png "Sample output from aggregate_prediction_modes.py showing prediction mode distribution weighted by bits")
 
-The script also takes an optional `--output_csv` argument that will dump the final Pandas dataframe to a CSV file.
+### Prediction modes weighted by bits used (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "prediction_mode" \
+  --aggregated_field mode_bits \
+  --group_by mode,is_intra_frame \
+  --output_csv /tmp/prediction_mode_bits.csv \
+  --plot 'title:"Intra Prediction Modes by bits", field:mode, filter:"is_intra_frame"' \
+  --plot 'title:"Inter Prediction Modes by bits", field:mode, filter:"not is_intra_frame"'
+```
+![](img/aggregation_prediction_modes_by_bits.png "Sample output from aggregate_from_extractor.py with the prediction_mode extractor weighting by number of bits.")
+
+### TX types (A5, QP=235, 10 frames); Plot Intra vs Inter
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_235.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "tx_type" \
+  --group_by tx_type,is_intra_frame \
+  --output_csv /tmp/aggregate_tx_types.csv \
+  --plot 'title:"Intra TX Types", field:tx_type, filter:"is_intra_frame"' \
+  --plot 'title:"Inter TX Types", field:tx_type, filter:"not is_intra_frame"'
+```
+![](img/aggregation_tx_types.png "Sample output from aggregate_from_extractor.py with the tx_type extractor.")
+
+### Intra partition types (A5, QP=210, 10 frames); Block sizes 64x64 vs 32x32
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/AllIntra/*_QP_210.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "partition_type" \
+  --group_by partition_type,block_size \
+  --output_csv /tmp/aggregate_partition_types.csv \
+  --plot 'title:"Partition types (64x64)", field:partition_type, filter:block_size=="64x64"' \
+  --plot 'title:"Partition types (32x32)", field:partition_type, filter:block_size=="32x32"'
+```
+![](img/aggregation_partition_types.png "Sample output from aggregate_from_extractor.py with the partition_type extractor.")
+
+### Top 10 symbol types weighted by bits used (A5, QP=235, 10 frames); Intra vs Inter frames
+```bash
+python3 ${AOM_ROOT}/tools/py_stats/examples/aggregate_from_extractor.py \
+  --extract_proto_bin ${AOM_BUILD_DIR}/extract_proto \
+  --stream_glob "${CTC_STREAMS}/A5_Natural_270p/LowDelay/*_QP_210.bin" \
+  --threads 32 \
+  --frame_limit 10 \
+  --extractor "symbol_bits" \
+  --aggregated_field bits \
+  --group_by symbol_name,is_intra_frame \
+  --output_csv /tmp/aggregate_symbols.csv \
+  --plot 'title:"Symbols (Intra frames)", field:symbol_name, filter:"is_intra_frame", limit:10' \
+  --plot 'title:"Symbols (Inter frames)", field:symbol_name, filter:"not is_intra_frame", limit:10'
+```
+![](img/aggregation_symbols.png "Sample output from aggregate_from_extractor.py with the symbol_bits extractor.")
+
 
 ### Launch an example Jupyter notebook
 This demonstrates many of the same visualizations listed above, but in a single Jupyter notebook.
diff --git a/tools/py_stats/avm_stats/proto_helpers.py b/tools/py_stats/avm_stats/proto_helpers.py
index b012ec7..b6c861a 100644
--- a/tools/py_stats/avm_stats/proto_helpers.py
+++ b/tools/py_stats/avm_stats/proto_helpers.py
@@ -164,7 +164,7 @@
         (pixels_height, pixels_width)
     )[:sb_height_clipped, :sb_width_clipped]
 
-    samples[sb_y : sb_y + sb_height_clipped, sb_x : sb_x + sb_width_clipped] = (
+    samples[sb_y: sb_y + sb_height_clipped, sb_x: sb_x + sb_width_clipped] = (
         superblock_samples
     )
 
@@ -541,23 +541,10 @@
     self.superblocks = [
         Superblock(self, sb_proto) for sb_proto in proto.superblocks
     ]
-    self.pixels = [
-        _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V)
-    ]
-
-    if self.pixels[0].original is None:
-      self.original_rgb = None
-    else:
-      self.original_rgb = yuv_tools.yuv_to_rgb(
-          self.pixels[0].original,
-          yuv_tools.upscale(self.pixels[1].original, 2),
-          yuv_tools.upscale(self.pixels[2].original, 2),
-      )
-    self.reconstruction_rgb = yuv_tools.yuv_to_rgb(
-        self.pixels[0].reconstruction,
-        yuv_tools.upscale(self.pixels[1].reconstruction, 2),
-        yuv_tools.upscale(self.pixels[2].reconstruction, 2),
-    )
+    # Pixel data is created lazily.
+    self._pixels = None
+    self._original_rgb = None
+    self._reconstruction_rgb = None
 
   @property
   def frame_id(self) -> int:
@@ -575,6 +562,38 @@
   def bit_depth(self) -> int:
     return self.proto.frame_params.bit_depth
 
+  @property
+  def is_intra_frame(self) -> bool:
+    return self.proto.frame_params.frame_type == 0
+
+  @property
+  def pixels(self) -> list[PlaneBuffer]:
+    if self._pixels is None:
+      self._pixels = [
+          _create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V)
+      ]
+    return self._pixels
+
+  @property
+  def original_rgb(self) -> np.ndarray:
+    if self._original_rgb is None and self.pixels[0].original is not None:
+      self._original_rgb = yuv_tools.yuv_to_rgb(
+          self.pixels[0].original,
+          yuv_tools.upscale(self.pixels[1].original, 2),
+          yuv_tools.upscale(self.pixels[2].original, 2),
+      )
+    return self._original_rgb
+
+  @property
+  def reconstruction_rgb(self) -> np.ndarray:
+    if self._reconstruction_rgb is None:
+      self._reconstruction_rgb = yuv_tools.yuv_to_rgb(
+          self.pixels[0].reconstruction,
+          yuv_tools.upscale(self.pixels[1].reconstruction, 2),
+          yuv_tools.upscale(self.pixels[2].reconstruction, 2),
+      )
+    return self._reconstruction_rgb
+
   def clip_rect(self, rect: Rectangle) -> Rectangle:
     """Clips a rectangle to be contained with the frame boundaries.
 
diff --git a/tools/py_stats/examples/aggregate_from_extractor.py b/tools/py_stats/examples/aggregate_from_extractor.py
new file mode 100755
index 0000000..2ae1747
--- /dev/null
+++ b/tools/py_stats/examples/aggregate_from_extractor.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+from __future__ import annotations
+
+import dataclasses
+from functools import partial
+import gc
+import glob
+import pathlib
+import sys
+import tempfile
+from typing import Iterator, Sequence, Type
+
+from absl import app
+from absl import flags
+from absl import logging
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+import matplotlib.pyplot as plt
+import multiprocessing
+import pandas as pd
+
+from extractors.partition_similarity_extractor import PartitionSimilarityExtractor
+from extractors.partition_type_extractor import PartitionTypeExtractor
+from extractors.prediction_mode_extractor import PredictionModeExtractor
+from extractors.symbol_bits_extractor import SymbolBitsExtractor
+from extractors.tx_type_extractor import TxTypeExtractor
+
+_EXTRACTORS = {
+    "partition_similarity": PartitionSimilarityExtractor,
+    "partition_type": PartitionTypeExtractor,
+    "prediction_mode": PredictionModeExtractor,
+    "symbol_bits": SymbolBitsExtractor,
+    "tx_type": TxTypeExtractor,
+}
+
+_STREAM_GLOB = flags.DEFINE_multi_string(
+    "stream_glob", None, "Path to AVM streams."
+)
+flags.mark_flag_as_required("stream_glob")
+
+_EXTRACT_PROTO_BIN = flags.DEFINE_string(
+    "extract_proto_bin", None, "Path to extract_proto binary."
+)
+flags.mark_flag_as_required("extract_proto_bin")
+
+_OUTPUT_CSV = flags.DEFINE_string(
+    "output_csv", None, "Path to output CSV file (optional)."
+)
+
+_AGGREGATED_FIELD = flags.DEFINE_string(
+    "aggregated_field", "count", "Field to aggregate on. By default will use the count."
+)
+
+_GROUP_BY = flags.DEFINE_string(
+    "group_by", None, "Group by these fields (comma separated)."
+)
+
+_EXTRACTOR = flags.DEFINE_enum(
+    "extractor", None, _EXTRACTORS.keys(), "Data extractor to use."
+)
+flags.mark_flag_as_required("extractor")
+
+_THREADS = flags.DEFINE_integer(
+    "threads", 1, "Number of parallel workers to spawn."
+)
+
+_FRAME_LIMIT = flags.DEFINE_integer(
+    "frame_limit", None, "Use at most this many frames from each stream."
+)
+
+_PLOT = flags.DEFINE_multi_string(
+    "plot", None, """Plot command args, e.g.: --plot 'title:"Block Sizes", field:"block_size", filter:"not is_intra_frame", limit:10'"""
+)
+
+_TMP_DIR = flags.DEFINE_string(
+    "tmp_dir", None, "Temp working dir."
+)
+
+@dataclasses.dataclass
+class PlotArgs:
+  title: str = ""
+  field: str = ""
+  filter: str = ""
+  limit: str = ""
+
+  def __init__(self, args_str: str):
+    def kv_pair(kv: str) -> tuple[str, str]:
+      k, v = [i.strip() for i in kv.split(":", 1)]
+      if v.startswith('"') and v.endswith('"'):
+        v = v[1:-1]
+      return (k, v)
+
+    for k, v in [kv_pair(kv) for kv in args_str.split(",")]:
+      if not hasattr(self, k):
+        raise ValueError(f"Unknown plot arg: {k}")
+      setattr(self, k, v)
+
+
+def filter_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count", filt: str = "", limit: int | None = None):
+  if filt:
+    filtered_df = df.query(filt)[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+  else:
+    filtered_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+  filtered_df = filtered_df.sum()
+  filtered_df["percent"] = filtered_df[aggregated_field] / filtered_df[aggregated_field].sum() * 100
+  filtered_df = filtered_df.sort_values(by=[aggregated_field], ascending=False)
+  if limit is not None:
+    filtered_df_top_n = filtered_df.iloc[:limit, :]
+    other_total = filtered_df.iloc[limit:, :][aggregated_field].sum()
+    other_percent = filtered_df.iloc[limit:, :]["percent"].sum()
+    others = pd.DataFrame({group_by[0]: ["Other"], aggregated_field: other_total, "percent": other_percent})
+    filtered_df = pd.concat([filtered_df_top_n, others])
+  return filtered_df
+
+
+def sum_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count"):
+  aggregated_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False)
+  aggregated_df = aggregated_df.sum()
+  return aggregated_df
+
+
+def create_plot(
+    df: pd.DataFrame,
+    plot_args: PlotArgs,
+    ax: plt.Axes,
+    legend_ax: plt.Axes,
+    aggregated_field: str,
+):
+  limit = int(plot_args.limit) if plot_args.limit else None
+  aggregated_df = filter_dataframe(
+      df, group_by=[plot_args.field], aggregated_field=aggregated_field, filt=plot_args.filter, limit=limit)
+  plot_title = plot_args.title or plot_args.field
+  ax.set_title(f"{plot_title}")
+  patches, _ = ax.pie(
+      x=aggregated_df[aggregated_field], labels=aggregated_df[plot_args.field], autopct=None)
+  labels = [f"{n} - {p:.1f}% ({v:.1f})" for n, p, v in zip(aggregated_df[plot_args.field],
+                                                           aggregated_df["percent"], aggregated_df[aggregated_field])]
+  legend_ax.legend(patches, labels, loc="best", ncol=2)
+  legend_ax.axis("off")
+
+
+def extract_to_temp_dir(stream_path: pathlib.Path, frame_limit: int | None = None) -> Iterator[Frame]:
+  with tempfile.TemporaryDirectory(dir=_TMP_DIR.value) as tmp_dir:
+    tmp_path = pathlib.Path(tmp_dir)
+    extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
+    stream_name = stream_path.stem
+    output_path = tmp_path / stream_name
+    try:
+      output_path.mkdir()
+    except FileExistsError:
+      logging.fatal(f"Duplicate stream name: {stream_name}")
+    yield from extract_and_load_protos(
+        extract_proto_path=extract_proto_path,
+        stream_path=stream_path,
+        output_path=output_path,
+        frame_limit=frame_limit
+    )
+
+
+def process_stream(stream_path: pathlib.Path, *, extractor_class: Type[Extractor]) -> pd.DataFrame:
+  group_by = _GROUP_BY.value.split(",")
+  frames = extract_to_temp_dir(stream_path, _FRAME_LIMIT.value)
+  df = None
+  for frame in frames:
+    frame_df = aggregate_to_dataframe([frame], extractor_class())
+    frame_df["count"] = 1
+    if df is None:
+      df = frame_df
+    else:
+      df = pd.concat([df, frame_df])
+    if len(df):
+      df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+    gc.collect()
+  if df is not None and len(df):
+    df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+  gc.collect()
+  # Report progress through the logger instead of writing to a hard-coded path.
+  logging.info("Finished processing stream: %s", stream_path.stem)
+  return df
+
+
+def main(argv: Sequence[str]) -> None:
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+
+  extractor_name = _EXTRACTOR.value
+  try:
+    extractor_class = _EXTRACTORS[extractor_name]
+  except KeyError:
+    print(f"Unknown extractor: {extractor_name}", file=sys.stderr)
+    sys.exit(1)
+
+  stream_paths = []
+  for stream_glob in _STREAM_GLOB.value:
+    for stream in glob.glob(stream_glob, recursive=True):
+      path = pathlib.Path(stream)
+      stream_paths.append(path)
+
+  stream_paths = sorted(stream_paths)
+  with multiprocessing.Pool(_THREADS.value) as pool:
+    dfs = pool.map(
+        partial(process_stream, extractor_class=extractor_class), stream_paths)
+
+  df = pd.concat(dfs)
+  group_by = _GROUP_BY.value.split(",")
+  aggregated_df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
+
+  if _OUTPUT_CSV.value:
+    aggregated_df.to_csv(_OUTPUT_CSV.value)
+
+  plots = _PLOT.value
+  if plots:
+    _, axes = plt.subplots(2, len(plots), squeeze=False, figsize=(10, 5))
+    for i, plot in enumerate(plots):
+      args = PlotArgs(plot)
+      create_plot(df, args, axes[0, i], axes[1, i], aggregated_field=_AGGREGATED_FIELD.value)
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/tools/py_stats/examples/aggregate_prediction_modes.py b/tools/py_stats/examples/aggregate_prediction_modes.py
deleted file mode 100644
index 9d73fd3..0000000
--- a/tools/py_stats/examples/aggregate_prediction_modes.py
+++ /dev/null
@@ -1,154 +0,0 @@
-#!/usr/bin/env python3
-## Copyright (c) 2023, Alliance for Open Media. All rights reserved
-##
-## This source code is subject to the terms of the BSD 3-Clause Clear License and the
-## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
-## not distributed with this source code in the LICENSE file, you can obtain it
-## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
-## License 1.0 was not distributed with this source code in the PATENTS file, you
-## can obtain it at aomedia.org/license/patent-license/.
-##
-from __future__ import annotations
-
-import collections
-import glob
-import itertools
-import pathlib
-import tempfile
-from typing import Sequence
-
-from absl import app
-from absl import flags
-from absl import logging
-from avm_stats.extract_proto import *
-from avm_stats.frame_visualizations import *
-from avm_stats.proto_helpers import *
-from avm_stats.stats_aggregation import *
-import matplotlib.pyplot as plt
-import pandas as pd
-
-_STREAM_GLOB = flags.DEFINE_multi_string(
-    "stream_glob", None, "Path to AVM stream."
-)
-flags.mark_flag_as_required("stream_glob")
-
-_EXTRACT_PROTO_BIN = flags.DEFINE_string(
-    "extract_proto_bin", None, "Path to extract_proto binary."
-)
-flags.mark_flag_as_required("extract_proto_bin")
-
-_OUTPUT_CSV = flags.DEFINE_string(
-    "output_csv", None, "Path to output CSV file (optional)."
-)
-
-
-def prediction_mode_symbol_filter(symbol: Symbol):
-  return symbol.source_function in (
-      "read_inter_mode",
-      "read_intra_luma_mode",
-      "read_drl_index",
-      "read_inter_compound_mode",
-  )
-
-
-class PredictionModeExtractor(CodingUnitExtractor):
-  PredictionMode = collections.namedtuple(
-      "PredictionMode",
-      ["width", "height", "mode", "mode_bits", "is_intra_frame"],
-  )
-
-  def sample(self, coding_unit: CodingUnit):
-    width = coding_unit.rect.width
-    height = coding_unit.rect.height
-    mode = coding_unit.get_prediction_mode()
-    mode_bits = sum(
-        sym.bits
-        for sym in coding_unit.get_symbols(prediction_mode_symbol_filter)
-    )
-    is_intra_frame = coding_unit.frame.proto.frame_params.frame_type == 0
-    yield self.PredictionMode(width, height, mode, mode_bits, is_intra_frame)
-
-
-def compare_intra_inter(
-    df: pd.DataFrame,
-    column: str,
-    *,
-    aggregated_field: str = "count",
-    aggregator: str = "count",
-    plot_title: str | None = None,
-):
-  df_intra = df.query("is_intra_frame")[[column, aggregated_field]].groupby(
-      column, as_index=False
-  )
-  df_intra = getattr(df_intra, aggregator)()
-  df_inter = df.query("not is_intra_frame")[[column, aggregated_field]].groupby(
-      column, as_index=False
-  )
-  df_inter = getattr(df_inter, aggregator)()
-  _, axes = plt.subplots(1, 2)
-  plot_title = plot_title or column
-  axes[0].pie(
-      x=df_intra[aggregated_field], labels=df_intra[column], autopct="%.1f%%"
-  )
-  axes[0].set_title(f"{plot_title} (Intra frames)")
-  axes[1].pie(
-      x=df_inter[aggregated_field], labels=df_inter[column], autopct="%.1f%%"
-  )
-  axes[1].set_title(f"{plot_title} (Inter frames)")
-  plt.show()
-
-
-def main(argv: Sequence[str]) -> None:
-  if len(argv) > 1:
-    raise app.UsageError("Too many command-line arguments.")
-
-  stream_paths = []
-  for stream_glob in _STREAM_GLOB.value:
-    for stream in glob.glob(stream_glob, recursive=True):
-      path = pathlib.Path(stream)
-      stream_paths.append(path)
-
-  with tempfile.TemporaryDirectory() as tmp_dir:
-    tmp_path = pathlib.Path(tmp_dir)
-    extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
-
-    def extract_to_temp_dir(stream_path: pathlib.Path) -> Frame:
-      stream_name = stream_path.stem
-      output_path = tmp_path / stream_name
-      try:
-        output_path.mkdir()
-      except FileExistsError:
-        logging.fatal(f"Duplicate stream name: {stream_name}")
-      yield from extract_and_load_protos(
-          extract_proto_path=extract_proto_path,
-          stream_path=stream_path,
-          output_path=output_path,
-      )
-
-    all_frames = itertools.chain.from_iterable(
-        map(extract_to_temp_dir, stream_paths)
-    )
-    df = aggregate_to_dataframe(all_frames, PredictionModeExtractor())
-    df["count"] = 1
-    df["width_height"] = df.apply(
-        lambda row: f"{row.width}x{row.height}", axis=1
-    )
-
-    compare_intra_inter(
-        df, "width_height", plot_title="Block size distribution"
-    )
-    compare_intra_inter(df, "mode", plot_title="Prediction mode distribution")
-    compare_intra_inter(
-        df,
-        "mode",
-        aggregated_field="mode_bits",
-        aggregator="sum",
-        plot_title="Prediction modes by bits spent",
-    )
-
-    if _OUTPUT_CSV.value:
-      df.to_csv(_OUTPUT_CSV.value)
-
-
-if __name__ == "__main__":
-  app.run(main)
diff --git a/tools/py_stats/examples/extractors/partition_type_extractor.py b/tools/py_stats/examples/extractors/partition_type_extractor.py
new file mode 100644
index 0000000..c972d5f
--- /dev/null
+++ b/tools/py_stats/examples/extractors/partition_type_extractor.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+def iter_partitions(partition_block):
+  yield partition_block
+  if not partition_block.is_leaf_node:
+    for child in partition_block.children:
+      yield from iter_partitions(child)
+
+class PartitionTypeExtractor(SuperblockExtractor):
+  PartitionType = collections.namedtuple(
+      "PartitionType",
+      ["width", "height", "block_size", "partition_type", "is_intra_frame", "stream_name"],
+  )
+
+  def sample(self, superblock: Superblock):
+    stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
+    is_intra_frame = superblock.frame.is_intra_frame
+    for partition_block in iter_partitions(superblock.proto.luma_partition_tree):
+      width = partition_block.size.width
+      height = partition_block.size.height
+      block_size = f"{width}x{height}"
+      partition_type = partition_block.partition_type
+      partition_type_name = superblock.frame.proto.enum_mappings.partition_type_mapping[partition_type]
+      yield self.PartitionType(width, height, block_size, partition_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/prediction_mode_extractor.py b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
new file mode 100644
index 0000000..6333989
--- /dev/null
+++ b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+def prediction_mode_symbol_filter(symbol: Symbol):
+  return symbol.source_function in (
+      "read_inter_mode",
+      "read_intra_luma_mode",
+      "read_drl_index",
+      "read_inter_compound_mode",
+  )
+
+class PredictionModeExtractor(CodingUnitExtractor):
+  PredictionMode = collections.namedtuple(
+      "PredictionMode",
+      ["width", "height", "block_size", "mode", "mode_bits", "is_intra_frame", "stream_name"],
+  )
+
+  def sample(self, coding_unit: CodingUnit):
+    stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+    width = coding_unit.rect.width
+    height = coding_unit.rect.height
+    block_size = f"{width}x{height}"
+    mode = coding_unit.get_prediction_mode()
+    mode_bits = sum(
+        sym.bits
+        for sym in coding_unit.get_symbols(prediction_mode_symbol_filter)
+    )
+    is_intra_frame = coding_unit.frame.is_intra_frame
+    yield self.PredictionMode(width, height, block_size, mode, mode_bits, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/extractors/symbol_bits_extractor.py b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
new file mode 100644
index 0000000..2a70bd5
--- /dev/null
+++ b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+CTC_CLASSES = [
+  2160,
+  1080,
+  720,
+  360,
+  270,
+]
+
+class SymbolBitsExtractor(SuperblockExtractor):
+  SymbolBits = collections.namedtuple(
+      "SymbolBits",
+      ["symbol_name", "symbol_tags", "bits", "is_intra_frame", "stream_name", "qp", "ctc_class", "ctc_config"],
+  )
+
+  def sample(self, superblock: Superblock):
+    stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
+    qp = stream_name.split("_")[-1]
+    ctc_config = stream_name.split("_")[-3]
+    ctc_class = None
+    for i, c in enumerate(CTC_CLASSES):
+      if str(c) in stream_name:
+        assert ctc_class is None
+        ctc_class = f"A{i+1}"
+    assert ctc_class is not None
+    is_intra_frame = superblock.frame.is_intra_frame
+    for symbol in superblock.proto.symbols:
+      symbol_info = superblock.frame.proto.symbol_info[symbol.info_id]
+      symbol_tags = "/".join(symbol_info.tags)
+      yield self.SymbolBits(symbol_info.source_function, symbol_tags, symbol.bits, is_intra_frame, stream_name, qp, ctc_class, ctc_config)
diff --git a/tools/py_stats/examples/extractors/tx_type_extractor.py b/tools/py_stats/examples/extractors/tx_type_extractor.py
new file mode 100644
index 0000000..bf9ae93
--- /dev/null
+++ b/tools/py_stats/examples/extractors/tx_type_extractor.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
+##
+import collections
+
+from avm_stats.extract_proto import *
+from avm_stats.frame_visualizations import *
+from avm_stats.proto_helpers import *
+from avm_stats.stats_aggregation import *
+
+
+class TxTypeExtractor(CodingUnitExtractor):
+  TxSize = collections.namedtuple(
+      "TxSize",
+      ["width", "height", "tx_size", "tx_type", "is_intra_frame", "stream_name"],
+  )
+
+  def sample(self, coding_unit: CodingUnit):
+    stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+    # Luma only
+    is_chroma = len(coding_unit.proto.transform_planes) == 2
+    if is_chroma:
+      return
+    is_intra_frame = coding_unit.frame.is_intra_frame
+    for transform_unit in coding_unit.proto.transform_planes[0].transform_units:
+      width = transform_unit.size.width
+      height = transform_unit.size.height
+      tx_size = f"{width}x{height}"
+      # TODO(comc): Add method on transform_unit for this.
+      tx_type = transform_unit.tx_type & 0xF
+      tx_type_name = coding_unit.frame.proto.enum_mappings.transform_type_mapping[tx_type]
+      yield self.TxSize(width, height, tx_size, tx_type_name, is_intra_frame, stream_name)
diff --git a/tools/py_stats/examples/partition_tree.py b/tools/py_stats/examples/partition_tree.py
index 3a26160..d2c32ad 100755
--- a/tools/py_stats/examples/partition_tree.py
+++ b/tools/py_stats/examples/partition_tree.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
-# Copyright (c) 2023, Alliance for Open Media. All rights reserved
+## Copyright (c) 2023, Alliance for Open Media. All rights reserved
 ##
-# This source code is subject to the terms of the BSD 3-Clause Clear License and the
-# Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
-# not distributed with this source code in the LICENSE file, you can obtain it
-# at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
-# License 1.0 was not distributed with this source code in the PATENTS file, you
-# can obtain it at aomedia.org/license/patent-license/.
+## This source code is subject to the terms of the BSD 3-Clause Clear License and the
+## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
+## not distributed with this source code in the LICENSE file, you can obtain it
+## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent
+## License 1.0 was not distributed with this source code in the PATENTS file, you
+## can obtain it at aomedia.org/license/patent-license/.
 ##
 from __future__ import annotations
 
diff --git a/tools/py_stats/img/aggregation_block_sizes.png b/tools/py_stats/img/aggregation_block_sizes.png
new file mode 100644
index 0000000..8064457
--- /dev/null
+++ b/tools/py_stats/img/aggregation_block_sizes.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_partition_types.png b/tools/py_stats/img/aggregation_partition_types.png
new file mode 100644
index 0000000..f632dea
--- /dev/null
+++ b/tools/py_stats/img/aggregation_partition_types.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes.png b/tools/py_stats/img/aggregation_prediction_modes.png
new file mode 100644
index 0000000..81dafce
--- /dev/null
+++ b/tools/py_stats/img/aggregation_prediction_modes.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_prediction_modes_by_bits.png b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png
new file mode 100644
index 0000000..e87f9d9
--- /dev/null
+++ b/tools/py_stats/img/aggregation_prediction_modes_by_bits.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_symbols.png b/tools/py_stats/img/aggregation_symbols.png
new file mode 100644
index 0000000..cc6e226
--- /dev/null
+++ b/tools/py_stats/img/aggregation_symbols.png
Binary files differ
diff --git a/tools/py_stats/img/aggregation_tx_types.png b/tools/py_stats/img/aggregation_tx_types.png
new file mode 100644
index 0000000..6c63310
--- /dev/null
+++ b/tools/py_stats/img/aggregation_tx_types.png
Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap.png b/tools/py_stats/img/bits_heatmap.png
new file mode 100644
index 0000000..bf7e6e5
--- /dev/null
+++ b/tools/py_stats/img/bits_heatmap.png
Binary files differ
diff --git a/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png
new file mode 100644
index 0000000..4db7259
--- /dev/null
+++ b/tools/py_stats/img/bits_heatmap_read_intra_luma_mode.png
Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_inter.png b/tools/py_stats/img/partition_tree_luma_inter.png
new file mode 100644
index 0000000..e4ee744
--- /dev/null
+++ b/tools/py_stats/img/partition_tree_luma_inter.png
Binary files differ
diff --git a/tools/py_stats/img/partition_tree_luma_intra.png b/tools/py_stats/img/partition_tree_luma_intra.png
new file mode 100644
index 0000000..169dffc
--- /dev/null
+++ b/tools/py_stats/img/partition_tree_luma_intra.png
Binary files differ
diff --git a/tools/py_stats/img/pixel_pipeline.png b/tools/py_stats/img/pixel_pipeline.png
new file mode 100644
index 0000000..7fcac2d
--- /dev/null
+++ b/tools/py_stats/img/pixel_pipeline.png
Binary files differ
diff --git a/tools/py_stats/img/prediction_modes_luma_intra.png b/tools/py_stats/img/prediction_modes_luma_intra.png
new file mode 100644
index 0000000..d27c7b8
--- /dev/null
+++ b/tools/py_stats/img/prediction_modes_luma_intra.png
Binary files differ
diff --git a/tools/py_stats/pyproject.toml b/tools/py_stats/pyproject.toml
index 469ba6a..6d94bab 100644
--- a/tools/py_stats/pyproject.toml
+++ b/tools/py_stats/pyproject.toml
@@ -7,7 +7,7 @@
 license = "BSD-3-Clause"
 readme = "README.md"
 # By default Poetry will use .gitignore, which excludes the generated protobuf bindings.
-include = ["*.py"]
+include = ["**/*.py"]
 
 [tool.poetry.dependencies]
 absl-py = "^2.0.0"