blob: 0f06e454ec479a2a2ef73ad83e8b836ab549cf5e [file] [log] [blame]
#!/usr/bin/env python3
## Copyright (c) 2023, Alliance for Open Media. All rights reserved
##
## This source code is subject to the terms of the BSD 3-Clause Clear License and the
## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
## not distributed with this source code in the LICENSE file, you can obtain it
## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
## License 1.0 was not distributed with this source code in the PATENTS file, you
## can obtain it at aomedia.org/license/patent-license/.
##
from __future__ import annotations
from functools import partial
import pathlib
import tempfile
from absl import app
from absl import flags
from absl import logging
from avm_stats.extract_proto import *
from avm_stats.frame_visualizations import *
from avm_stats.proto_helpers import *
from avm_stats.yuv_tools import *
import matplotlib.pyplot as plt
# Command-line flags. Values are read in main() through the returned
# FlagHolder objects (e.g. _STREAM.value).

# Encoded AVM bitstream to inspect (required).
_STREAM = flags.DEFINE_string(
    name="stream",
    default=None,
    help="Path to AVM stream.",
)
flags.mark_flag_as_required("stream")

# Pre-encoding source video. When omitted, main() skips the views that need
# the original pixels.
_SOURCE = flags.DEFINE_string(
    name="source",
    default=None,
    help="Path to source YUV/Y4M before encoding (optional).",
)

# The extract_proto tool used to dump per-frame protos (required).
_EXTRACT_PROTO_BIN = flags.DEFINE_string(
    name="extract_proto_bin",
    default=None,
    help="Path to extract_proto binary.",
)
flags.mark_flag_as_required("extract_proto_bin")

# Index of the frame to display.
_FRAME = flags.DEFINE_integer(
    name="frame",
    default=0,
    help="Frame number to visualize (defaults to the first frame).",
)

# Colour plane to display; main() maps it to Plane[value.upper()].
_PLANE = flags.DEFINE_enum(
    name="plane",
    default="y",
    enum_values=["y", "u", "v"],
    help="Plane to visualize (defaults to luma: 'y')",
)
def _build_visualizations(has_source: bool) -> list:
    """Return the visualization layer factories to display, in grid order.

    Each entry is a callable that takes a ``plane`` keyword argument and
    returns a layer object exposing ``show(frame, ax)``.

    Args:
      has_source: Whether the original (pre-encoding) YUV was supplied. The
        original-image and distortion layers depend on it, so they are only
        included when it is available.

    Returns:
      A new list of layer factories.
    """
    visualizations = [
        PredictionYuvLayer,
        partial(ResidualYuvLayer, show_relative=False),
        partial(ResidualYuvLayer, show_relative=True),
        PrefilteredYuvLayer,
        partial(FilterDeltaYuvLayer, show_relative=False),
        partial(FilterDeltaYuvLayer, show_relative=True),
        ReconstructionYuvLayer,
    ]
    if has_source:
        # The source-dependent layers bracket the decoder-only ones: the
        # original comes first and the distortion views come last.
        visualizations = (
            [OriginalYuvLayer]
            + visualizations
            + [
                partial(DistortionYuvLayer, show_relative=False),
                partial(DistortionYuvLayer, show_relative=True),
            ]
        )
    return visualizations


def _plot_visualizations(
    visualizations: list, frame, plane, subplot_cols: int = 5
) -> None:
    """Draw each visualization of ``frame`` in a grid and show the figure.

    Args:
      visualizations: Layer factories, as returned by _build_visualizations.
      frame: One element of the loaded frame-proto sequence.
      plane: The Plane member to render.
      subplot_cols: Number of grid columns. Rows are added as needed so that
        no visualization is dropped.
    """
    # Ceiling division: enough rows to hold every visualization (at least 1).
    subplot_rows = max(1, (len(visualizations) + subplot_cols - 1) // subplot_cols)
    # squeeze=False guarantees a 2-D axes array even when there is one row.
    fig, axes = plt.subplots(subplot_rows, subplot_cols, squeeze=False)
    # axes.flat walks the grid row-major, matching the visualization order.
    for i, ax in enumerate(axes.flat):
        if i < len(visualizations):
            visualizations[i](plane=plane).show(frame, ax)
        else:
            # Remove unused trailing cells so they don't render as empty plots.
            fig.delaxes(ax)
    plt.show()


def main(argv: Sequence[str]) -> None:
    """Load the frame protos for --stream and plot one frame's YUV layers.

    Calls extract_and_load_protos with the extract_proto binary, writing into
    a temporary directory. It then shows a grid of YUV layer visualizations
    for the frame and plane chosen by --frame and --plane. Views that need
    the original video are shown only when --source is given.

    Args:
      argv: Positional arguments left after flag parsing. Only the program
        name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are given, or if --frame
        is not a valid index into the loaded frames.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    # NOTE: plotting stays inside the `with` so tmp_dir outlives every use of
    # the loaded frames, in case any of that data is still file-backed.
    # This is unverified, so err on the side of keeping the directory alive.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = pathlib.Path(tmp_dir)
        stream_path = pathlib.Path(_STREAM.value)
        extract_proto_path = pathlib.Path(_EXTRACT_PROTO_BIN.value)
        yuv_path = _SOURCE.value or None
        frames = extract_and_load_protos(
            extract_proto_path=extract_proto_path,
            stream_path=stream_path,
            output_path=tmp_path,
            skip_if_output_already_exists=False,
            yuv_path=yuv_path,
        )
        seq = list(frames)
        num_frames = len(seq)
        logging.info(f"Loaded {num_frames} frame protos.")

        # Validate before plotting. Without this check an out-of-range index
        # raises a bare IndexError, and a negative one silently selects a
        # frame counted from the end.
        frame_index = _FRAME.value
        if not 0 <= frame_index < num_frames:
            raise app.UsageError(
                f"--frame={frame_index} is out of range: the stream has "
                f"{num_frames} frame(s) (valid indices 0..{num_frames - 1})."
            )

        visualizations = _build_visualizations(has_source=bool(yuv_path))
        plane = Plane[_PLANE.value.upper()]
        _plot_visualizations(visualizations, seq[frame_index], plane)
if __name__ == "__main__":
    # absl parses the command-line flags first, then invokes main() with the
    # leftover argv.
    app.run(main=main)