|  | #!/usr/bin/env python3 | 
|  | ## Copyright (c) 2023, Alliance for Open Media. All rights reserved | 
|  | ## | 
|  | ## This source code is subject to the terms of the BSD 3-Clause Clear License and the | 
|  | ## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was | 
|  | ## not distributed with this source code in the LICENSE file, you can obtain it | 
|  | ## at aomedia.org/license/software-license/bsd-3-c-c/.  If the Alliance for Open Media Patent | 
|  | ## License 1.0 was not distributed with this source code in the PATENTS file, you | 
|  | ## can obtain it at aomedia.org/license/patent-license/. | 
|  | ## | 
|  | from __future__ import annotations | 
|  |  | 
|  | from functools import partial | 
|  | import pathlib | 
|  | import tempfile | 
|  |  | 
|  | from absl import app | 
|  | from absl import flags | 
|  | from absl import logging | 
|  | from avm_stats.extract_proto import * | 
|  | from avm_stats.frame_visualizations import * | 
|  | from avm_stats.proto_helpers import * | 
|  | from avm_stats.yuv_tools import * | 
|  | import matplotlib.pyplot as plt | 
|  |  | 
# Command-line flags. --stream and --extract_proto_bin are mandatory; the rest
# have defaults.
_STREAM = flags.DEFINE_string("stream", None, "Path to AVM stream.")
flags.mark_flag_as_required("stream")

# Optional: when omitted, the layers that compare against the original
# (pre-encoding) picture are not shown.
_SOURCE = flags.DEFINE_string(
    "source", None, "Path to source YUV/Y4M before encoding (optional)."
)

_EXTRACT_PROTO_BIN = flags.DEFINE_string(
    "extract_proto_bin", None, "Path to extract_proto binary."
)
flags.mark_flag_as_required("extract_proto_bin")

# 0-based frame index. The lower bound rejects negative values at parse time;
# otherwise they would feed a nonsensical frame limit and Python's negative
# indexing would silently pick the wrong frame.
_FRAME = flags.DEFINE_integer(
    "frame",
    0,
    "Frame number to visualize (defaults to the first frame).",
    lower_bound=0,
)
_PLANE = flags.DEFINE_enum(
    "plane", "y", ["y", "u", "v"], "Plane to visualize (defaults to luma: 'y')"
)
|  |  | 
|  |  | 
def _select_visualizations(has_source: bool) -> list:
    """Return the layer factories to plot, in display order.

    Each entry is called as ``factory(plane=...)`` to build a layer. Layers
    that compare against the original (pre-encoding) picture are included only
    when ``has_source`` is True.
    """
    layers = [
        PredictionYuvLayer,
        partial(ResidualYuvLayer, show_relative=False),
        partial(ResidualYuvLayer, show_relative=True),
        PrefilteredYuvLayer,
        partial(FilterDeltaYuvLayer, show_relative=False),
        partial(FilterDeltaYuvLayer, show_relative=True),
        ReconstructionYuvLayer,
    ]
    if has_source:
        # The original picture goes first; the distortion panels go last.
        layers = [
            OriginalYuvLayer,
            *layers,
            partial(DistortionYuvLayer, show_relative=False),
            partial(DistortionYuvLayer, show_relative=True),
        ]
    return layers


def _grid_rows(num_panels: int, cols: int) -> int:
    """Return how many subplot rows hold `num_panels` panels in `cols` columns (min 1)."""
    return max(1, -(-num_panels // cols))


def main(argv: Sequence[str]) -> None:
    """Extract frame protos from an AVM stream and plot per-stage YUV layers.

    Draws one subplot per visualization layer for the frame chosen by --frame
    and the plane chosen by --plane, then calls ``plt.show()``.

    Args:
      argv: Leftover positional arguments from absl. Only the program name is
        allowed.

    Raises:
      app.UsageError: If extra arguments are given, or if --frame is negative
        or beyond the last frame that was loaded.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    frame_index = _FRAME.value
    if frame_index < 0:
        raise app.UsageError(f"--frame must be non-negative, got {frame_index}.")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass every path as pathlib.Path, matching the stream and binary paths.
        yuv_path = pathlib.Path(_SOURCE.value) if _SOURCE.value else None
        frames = extract_and_load_protos(
            extract_proto_path=pathlib.Path(_EXTRACT_PROTO_BIN.value),
            stream_path=pathlib.Path(_STREAM.value),
            output_path=pathlib.Path(tmp_dir),
            skip_if_output_already_exists=False,
            yuv_path=yuv_path,
            # Presumably frame_limit caps how many frames are extracted, so
            # 0-based frame N needs N + 1 of them. If it is an inclusive index
            # instead, this just loads one extra frame.
            frame_limit=frame_index + 1,
        )
        seq = list(frames)
        num_frames = len(seq)
        logging.info(f"Loaded {num_frames} frame protos.")
        if frame_index >= num_frames:
            raise app.UsageError(
                f"--frame={frame_index} is out of range: only {num_frames} "
                "frame(s) were loaded."
            )
        frame = seq[frame_index]

        visualizations = _select_visualizations(has_source=yuv_path is not None)
        subplot_cols = 5
        subplot_rows = _grid_rows(len(visualizations), subplot_cols)
        # squeeze=False always returns a 2-D axes array, so .flat works even
        # when there is only one row.
        fig, axes = plt.subplots(subplot_rows, subplot_cols, squeeze=False)

        plane = Plane[_PLANE.value.upper()]
        for i, ax in enumerate(axes.flat):
            if i < len(visualizations):
                visualizations[i](plane=plane).show(frame, ax)
            else:
                # Remove grid cells that have no layer to show.
                fig.delaxes(ax)
        plt.show()
|  |  | 
|  |  | 
# Script entry point: absl's app.run parses the command-line flags and then
# calls main() with the remaining positional arguments.
if __name__ == "__main__":
    app.run(main)