#!/usr/bin/env python3
from __future__ import annotations
from functools import partial
import gc
import glob
import pathlib
import sys
import tempfile
from typing import Sequence, Type
from absl import app
from absl import flags
from absl import logging
from avm_stats.extract_proto import *
from avm_stats.frame_visualizations import *
from avm_stats.proto_helpers import *
from avm_stats.stats_aggregation import *
import matplotlib.pyplot as plt
import multiprocessing
import pandas as pd
from extractors.partition_type_extractor import PartitionTypeExtractor
from extractors.prediction_mode_extractor import PredictionModeExtractor
from extractors.symbol_bits_extractor import SymbolBitsExtractor
from extractors.tx_type_extractor import TxTypeExtractor
# Maps every value accepted by --extractor to the extractor class it selects.
_EXTRACTORS = {
    "partition_type": PartitionTypeExtractor,
    "prediction_mode": PredictionModeExtractor,
    "symbol_bits": SymbolBitsExtractor,
    "tx_type": TxTypeExtractor,
}

# Command-line flags. Each holder is read through `.value` once absl has
# parsed the command line.

# Glob pattern(s) for the input streams; may be repeated. main() expands them
# with recursive globbing.
_STREAM_GLOB = flags.DEFINE_multi_string(
    "stream_glob", None, "Path to AVM streams."
)

# Location of the extract_proto binary; wrapped in a pathlib.Path by
# extract_to_temp_dir.
_EXTRACT_PROTO_BIN = flags.DEFINE_string(
    "extract_proto_bin", None, "Path to extract_proto binary."
)

_OUTPUT_CSV = flags.DEFINE_string(
    "output_csv", None, "Path to output CSV file (optional)."
)

# Column that is summed when rows are grouped.
_AGGREGATED_FIELD = flags.DEFINE_string(
    "aggregated_field", "count", "Field to aggregate on. By default will use the count."
)

# Split on "," by the code that consumes it.
_GROUP_BY = flags.DEFINE_string(
    "group_by", None, "Group by these fields (comma separated)."
)

# Restricted to the keys of _EXTRACTORS, so flag parsing rejects unknown names.
_EXTRACTOR = flags.DEFINE_enum(
    "extractor", None, _EXTRACTORS.keys(), "Data extractor to use."
)

# Size of the multiprocessing pool used by main().
_THREADS = flags.DEFINE_integer(
    "threads", 1, "Number of parallel workers to spawn."
)

_FRAME_LIMIT = flags.DEFINE_integer(
    "frame_limit", None, "Use at most this many frames from each stream."
)

# Each occurrence is parsed by PlotArgs and produces one pie chart.
_PLOT = flags.DEFINE_multi_string(
    "plot", None, """Plot command args, e.g.: --plot 'title:"Block Sizes", field:"block_size", filter:"not is_intra_frame", limit:10'"""
)

# Parent directory for the per-stream temporary directories; None lets
# tempfile pick the platform default.
_TMP_DIR = flags.DEFINE_string(
    "tmp_dir", None, "Temp working dir."
)
class PlotArgs:
    """Options for one chart, parsed from a single ``--plot`` value.

    The value is a comma-separated list of ``key:value`` pairs, for example
    ``title:"Block Sizes", field:"block_size", limit:10``. Values may be
    wrapped in double quotes; the quotes are removed. Only the first colon of
    a pair separates key from value, and commas inside double quotes do not
    separate pairs, so quoted titles and filter expressions may contain both.

    Every attribute is stored as a string; keys that are not given keep their
    empty-string default.
    """

    # Chart title.
    title: str = ""
    # Name of the column the chart groups by.
    field: str = ""
    # Row filter expression; empty means no filtering.
    filter: str = ""
    # Maximum number of groups, kept as text; empty means unlimited.
    limit: str = ""

    def __init__(self, args_str: str):
        """Parse ``args_str`` and set the matching attributes.

        Args:
            args_str: Comma-separated ``key:value`` pairs. Blank segments
                (empty input, trailing commas) are ignored.

        Raises:
            ValueError: If a segment has no ``:`` separator or names a key
                that is not one of the plot options.
        """
        # Imported locally so the class does not depend on a module-level
        # import that the file does not already have.
        import re

        def kv_pair(kv: str) -> tuple[str, str]:
            # partition() splits on the first colon only, so the value itself
            # may contain colons.
            key, separator, value = kv.partition(":")
            if not separator:
                raise ValueError(
                    f"Malformed plot arg (expected key:value): {kv.strip()}")
            key, value = key.strip(), value.strip()
            if value.startswith('"') and value.endswith('"'):
                value = value[1:-1]
            return (key, value)

        # Split on commas that are followed by an even number of double
        # quotes, i.e. commas that are not inside a quoted value.
        segments = re.split(r',(?=(?:[^"]*"[^"]*")*[^"]*$)', args_str)
        for segment in segments:
            if not segment.strip():
                continue
            k, v = kv_pair(segment)
            # hasattr() alone would also accept dunder names such as
            # "__class__"; only the public option names are valid keys.
            if k.startswith("_") or not hasattr(self, k):
                raise ValueError(f"Unknown plot arg: {k}")
            setattr(self, k, v)
def filter_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count", filt: str = "", limit: int | None = None):
if filt:
filtered_df = df.query(filt)[group_by + [aggregated_field]].groupby(group_by, as_index=False)
filtered_df = df[group_by + [aggregated_field]].groupby(group_by, as_index=False)
filtered_df = filtered_df.sum()
filtered_df["percent"] = filtered_df[aggregated_field].transform(lambda x: x / x.sum() * 100)
filtered_df = filtered_df.sort_values(by=[aggregated_field], ascending=False)
if limit is not None:
filtered_df_top_n = filtered_df.iloc[:limit, :]
other_total = filtered_df.iloc[limit:, :][aggregated_field].sum()
other_percent = filtered_df.iloc[limit:, :]["percent"].sum()
others = pd.DataFrame({group_by[0]: ["Other"], aggregated_field: other_total, "percent": other_percent})
filtered_df = pd.concat([filtered_df_top_n, others])
return filtered_df
def sum_dataframe(df: pd.DataFrame, *, group_by: list[str], aggregated_field: str = "count"):
    """Sum ``aggregated_field`` within each ``group_by`` group.

    Args:
        df: Input rows; must contain the ``group_by`` columns and
            ``aggregated_field``.
        group_by: Columns to group on; they stay regular columns in the
            result rather than becoming the index.
        aggregated_field: Column whose values are summed per group.

    Returns:
        A DataFrame holding only the ``group_by`` columns and the summed
        ``aggregated_field``, one row per group.
    """
    kept_columns = group_by + [aggregated_field]
    grouped = df[kept_columns].groupby(group_by, as_index=False)
    return grouped.sum()
def create_plot(
    df: pd.DataFrame,
    plot_args: PlotArgs,
    ax: plt.Axes,
    legend_ax: plt.Axes,
    aggregated_field: str,
) -> None:
    """Draw one pie chart and its legend.

    Args:
        df: Rows to plot; filtered and grouped through filter_dataframe.
        plot_args: Parsed plot options. ``field`` is the grouping column,
            ``filter`` the optional row filter and ``limit`` the optional
            maximum number of slices (an empty string means no limit).
        ax: Axes that receives the pie chart.
        legend_ax: Axes that receives the legend.
        aggregated_field: Column whose per-group sum sizes the slices.
    """
    limit = int(plot_args.limit) if plot_args.limit else None
    aggregated_df = filter_dataframe(
        df,
        group_by=[plot_args.field],
        aggregated_field=aggregated_field,
        filt=plot_args.filter,
        limit=limit,
    )

    # Fall back to the field name when no explicit title was given.
    plot_title = plot_args.title or plot_args.field
    ax.set_title(plot_title)

    # With autopct=None, pie() returns the wedge patches and the label texts;
    # only the patches are needed, as legend handles.
    patches, _ = ax.pie(
        x=aggregated_df[aggregated_field],
        labels=aggregated_df[plot_args.field],
        autopct=None,
    )
    labels = [
        f"{n} - {p:.1f}% ({v:.1f})"
        for n, p, v in zip(
            aggregated_df[plot_args.field],
            aggregated_df["percent"],
            aggregated_df[aggregated_field],
        )
    ]
    legend_ax.legend(patches, labels, loc="best", ncol=2)
# Generator: sets up a temporary output directory for one stream and yields
# whatever extract_and_load_protos yields for it.
#
# stream_path: stream to process; its stem names the output subdirectory.
# frame_limit: optional frame cap (main passes --frame_limit here).
#
# NOTE(review): this definition looks truncated in this view -- the `try:`
# block that the `except FileExistsError` below belongs to, and the argument
# list and closing parenthesis of the extract_and_load_protos(...) call, are
# not visible. Confirm against the original file.
def extract_to_temp_dir(stream_path: pathlib.Path, frame_limit: int | None = None) -> Iterator[proto_helpers.Frame]:
# The temporary directory is created under --tmp_dir (tempfile's default
# location when the flag is unset) and removed when the `with` block exits.
with tempfile.TemporaryDirectory(dir=_TMP_DIR.value) as tmp_dir:
tmp_path = pathlib.Path(tmp_dir)
extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
# The output directory is named after the stream file without its extension.
stream_name = stream_path.stem
output_path = tmp_path / stream_name
# NOTE(review): presumably guards the creation of output_path -- the guarded
# statement is not visible here; confirm.
except FileExistsError:
logging.fatal(f"Duplicate stream name: {stream_name}")
yield from extract_and_load_protos(
# Preserve the entire stream path so that the CTC config and class can be retrieved if available from the path.
extra_args = ["--preserve_stream_path_depth", "-1"]
def process_stream(stream_path: pathlib.Path, *, extractor_class: Type[Extractor]) -> pd.DataFrame:
    """Extract one stream and aggregate its per-frame data.

    Each frame yielded by extract_to_temp_dir is turned into a DataFrame with
    a fresh ``extractor_class()`` instance, given a ``count`` column of 1, and
    folded into a running total grouped by --group_by and summed over
    --aggregated_field.

    Args:
        stream_path: Stream to process.
        extractor_class: Extractor type; instantiated once per frame.

    Returns:
        The aggregated DataFrame for the stream, or an empty DataFrame when
        the stream yields no frames.
    """
    group_by = _GROUP_BY.value.split(",")
    aggregated_field = _AGGREGATED_FIELD.value

    df = None
    for frame in extract_to_temp_dir(stream_path, _FRAME_LIMIT.value):
        frame_df = aggregate_to_dataframe([frame], extractor_class())
        frame_df["count"] = 1
        if df is None:
            df = frame_df
        else:
            # Only append from the second frame on; concatenating
            # unconditionally would add the first frame to itself.
            df = pd.concat([df, frame_df])
            # Collapse the running total right away so it stays at one row
            # per group instead of growing with every frame.
            if len(df):
                df = sum_dataframe(df, group_by=group_by, aggregated_field=aggregated_field)

    # A stream without frames leaves df as None; len(None) would raise.
    if df is None:
        df = pd.DataFrame()
    # Final aggregation; also covers a stream with exactly one frame, which
    # never reaches the in-loop aggregation.
    if len(df):
        df = sum_dataframe(df, group_by=group_by, aggregated_field=aggregated_field)

    # Per-stream progress marker. Opening in "w" mode creates (or truncates)
    # the file.
    # NOTE(review): what, if anything, should be written into the marker is
    # not visible in this view; only its creation is reproduced here.
    with open(f"/opt/tmp/progress/{stream_path.stem}.txt", "w") as f:
        pass
    return df
# Entry point: processes every stream matched by --stream_glob with the
# extractor chosen by --extractor, combines the per-stream results, and
# optionally draws one pie chart per --plot value.
#
# NOTE(review): several statements appear to be missing from this view -- the
# `try:` that pairs with `except KeyError`, the body of the inner glob loop,
# the callee on the `dfs =` line, and the body of `if _OUTPUT_CSV.value:`.
# Confirm against the original file before relying on the comments below.
def main(argv: Sequence[str]) -> None:
# Only flags are accepted; any positional argument beyond argv[0] is rejected.
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Resolve the extractor class from the --extractor value.
extractor_name = _EXTRACTOR.value
extractor_class = _EXTRACTORS[extractor_name]
except KeyError:
print(f"Unknown extractor: {extractor_name}", file=sys.stderr)
# Expand every --stream_glob pattern; recursive=True lets "**" match nested
# directories.
stream_paths = []
for stream_glob in _STREAM_GLOB.value:
for stream in glob.glob(stream_glob, recursive=True):
path = pathlib.Path(stream)
# Sorted so the streams are handled in a deterministic order.
stream_paths = sorted(stream_paths)
# One pool worker per --threads; partial() binds the keyword-only
# extractor_class so each task only needs a stream path.
with multiprocessing.Pool(_THREADS.value) as pool:
dfs =
partial(process_stream, extractor_class=extractor_class), stream_paths)
# Combine the per-stream frames, then aggregate once more across streams.
df = pd.concat(dfs)
group_by = _GROUP_BY.value.split(",")
aggregated_df = sum_dataframe(df, group_by=group_by, aggregated_field=_AGGREGATED_FIELD.value)
if _OUTPUT_CSV.value:
plots = _PLOT.value
if plots:
# Two rows of axes, one column per plot: row 0 is passed to create_plot as
# the pie axes, row 1 as the legend axes. squeeze=False keeps `axes`
# two-dimensional even for a single plot.
_, axes = plt.subplots(2, len(plots), squeeze=False, figsize=(10, 5))
for i, plot in enumerate(plots):
args = PlotArgs(plot)
# Note that the concatenated `df` is plotted, not `aggregated_df`.
create_plot(df, args, axes[0, i], axes[1, i], aggregated_field=_AGGREGATED_FIELD.value)
if __name__ == "__main__":