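Fix 10-bit handling in py_stats.

- Derive visualization value ranges from the frame bit depth instead of
  assuming 8-bit samples, and rescale superblock samples whose bit depth
  differs from the frame's.
- Scale planes by a new pixel_scale property (1 for 8-bit, 4 for 10-bit)
  before converting to RGB.
- Fix the pixels property so that it caches into self._pixels rather than
  assigning to self.pixels.
- Add an optional extra_args parameter to extract_proto and
  extract_and_load_protos-style callers, and use it in the aggregation
  example to pass --preserve_stream_path_depth -1.
- Report stream_path instead of stream_name from the example extractors,
  and drop the partition_similarity extractor from the aggregation example.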
diff --git a/tools/py_stats/avm_stats/extract_proto.py b/tools/py_stats/avm_stats/extract_proto.py
index 0a70284..3b57bf8 100644
--- a/tools/py_stats/avm_stats/extract_proto.py
+++ b/tools/py_stats/avm_stats/extract_proto.py
@@ -42,7 +42,8 @@
output_path: pathlib.Path,
skip_if_output_already_exists: bool = False,
yuv_path: pathlib.Path | None = None,
- frame_limit: int | None = None
+ frame_limit: int | None = None,
+ extra_args: list[str] | None = None
) -> ExtractProtoResult:
if (
skip_if_output_already_exists
@@ -75,6 +76,8 @@
"--limit",
str(frame_limit),
])
+ if extra_args is not None:
+ extract_proto_args.extend(extra_args)
logging.info("Running:\n %s", " ".join(extract_proto_args))
p = subprocess.run(extract_proto_args, capture_output=True, check=True)
@@ -107,7 +110,8 @@
output_path: pathlib.Path,
skip_if_output_already_exists: bool = False,
yuv_path: pathlib.Path | None = None,
- frame_limit: int | None = None
+ frame_limit: int | None = None,
+ extra_args: list[str] | None = None
) -> Iterator[proto_helpers.Frame]:
result = extract_proto(
extract_proto_path=extract_proto_path,
@@ -115,6 +119,7 @@
output_path=output_path,
skip_if_output_already_exists=skip_if_output_already_exists,
yuv_path=yuv_path,
- frame_limit=frame_limit
+ frame_limit=frame_limit,
+ extra_args=extra_args,
)
yield from load_protos(result.output_path)
diff --git a/tools/py_stats/avm_stats/frame_visualizations.py b/tools/py_stats/avm_stats/frame_visualizations.py
index 2dfbfd1..ce42e23 100644
--- a/tools/py_stats/avm_stats/frame_visualizations.py
+++ b/tools/py_stats/avm_stats/frame_visualizations.py
@@ -108,8 +108,9 @@
pixels = getattr(frame.pixels[self.plane], self.pixels_attribute)
width = frame.width
height = frame.height
+ vmax = 2 ** frame.bit_depth - 1
ax.imshow(
- pixels, cmap="gray", vmin=0, vmax=255, extent=[0, width, height, 0]
+ pixels, cmap="gray", vmin=0, vmax=vmax, extent=[0, width, height, 0]
)
@@ -160,7 +161,8 @@
pixels = self.get_pixels(frame)
pixel_min = np.min(pixels)
pixel_max = np.max(pixels)
- vmin, vmax = -255, 255
+ vmax = 2 ** frame.bit_depth - 1
+ vmin = -vmax
annotation = "Relative" if self.show_relative else "Absolute"
if self.show_relative:
vmin, vmax = pixel_min, pixel_max
diff --git a/tools/py_stats/avm_stats/proto_helpers.py b/tools/py_stats/avm_stats/proto_helpers.py
index b6c861a..0bd57a5 100644
--- a/tools/py_stats/avm_stats/proto_helpers.py
+++ b/tools/py_stats/avm_stats/proto_helpers.py
@@ -160,9 +160,13 @@
pixels_width = superblock_plane.width
pixels_height = superblock_plane.height
- superblock_samples = np.array(superblock_plane.pixels).reshape(
+ superblock_samples = np.array(superblock_plane.pixels, dtype=dtype).reshape(
(pixels_height, pixels_width)
)[:sb_height_clipped, :sb_width_clipped]
+ if superblock_plane.bit_depth > frame.bit_depth:
+ superblock_samples //= int(2 ** (superblock_plane.bit_depth - frame.bit_depth))
+ elif superblock_plane.bit_depth < frame.bit_depth:
+ superblock_samples *= int(2 ** (frame.bit_depth - superblock_plane.bit_depth))
samples[sb_y: sb_y + sb_height_clipped, sb_x: sb_x + sb_width_clipped] = (
superblock_samples
@@ -194,11 +198,11 @@
assert reconstruction is not None
# Even for 8-bit frames, 16 bits are needed for the deltas, since the range is
# [-255, 255].
- residual = pre_filtered.astype(np.int16) - prediction
- filter_delta = reconstruction.astype(np.int16) - pre_filtered
+ residual = pre_filtered.astype(np.int16) - prediction.astype(np.int16)
+ filter_delta = reconstruction.astype(np.int16) - pre_filtered.astype(np.int16)
# Since distortion is computed from the original, this might also be None.
distortion = (
- original.astype(np.int16) - reconstruction
+ original.astype(np.int16) - reconstruction.astype(np.int16)
if original is not None
else None
)
@@ -563,13 +567,22 @@
return self.proto.frame_params.bit_depth
@property
+ def pixel_scale(self) -> int:
+ if self.bit_depth == 8:
+ return 1
+ elif self.bit_depth == 10:
+ return 4
+ else:
+ raise RuntimeError(f"Unsupported bit depth: {self.bit_depth}")
+
+ @property
def is_intra_frame(self) -> bool:
return self.proto.frame_params.frame_type == 0
@property
def pixels(self) -> list[PlaneBuffer]:
if self._pixels is None:
- self.pixels = [
+ self._pixels = [
_create_plane_buffer(self, p) for p in (Plane.Y, Plane.U, Plane.V)
]
return self._pixels
@@ -578,9 +591,9 @@
def original_rgb(self) -> np.ndarray:
if self._original_rgb is None and self.pixels[0].original is not None:
self._original_rgb = yuv_tools.yuv_to_rgb(
- self.pixels[0].original,
- yuv_tools.upscale(self.pixels[1].original, 2),
- yuv_tools.upscale(self.pixels[2].original, 2),
+ self.pixels[0].original // self.pixel_scale,
+ yuv_tools.upscale(self.pixels[1].original, 2) // self.pixel_scale,
+ yuv_tools.upscale(self.pixels[2].original, 2) // self.pixel_scale,
)
return self._original_rgb
@@ -588,9 +601,11 @@
def reconstruction_rgb(self) -> np.ndarray:
if self._reconstruction_rgb is None:
self._reconstruction_rgb = yuv_tools.yuv_to_rgb(
- self.pixels[0].reconstruction,
- yuv_tools.upscale(self.pixels[1].reconstruction, 2),
- yuv_tools.upscale(self.pixels[2].reconstruction, 2),
+ self.pixels[0].reconstruction // self.pixel_scale,
+ yuv_tools.upscale(
+ self.pixels[1].reconstruction, 2) // self.pixel_scale,
+ yuv_tools.upscale(
+ self.pixels[2].reconstruction, 2) // self.pixel_scale,
)
return self._reconstruction_rgb
diff --git a/tools/py_stats/examples/aggregate_from_extractor.py b/tools/py_stats/examples/aggregate_from_extractor.py
index 2ae1747..f99e531 100755
--- a/tools/py_stats/examples/aggregate_from_extractor.py
+++ b/tools/py_stats/examples/aggregate_from_extractor.py
@@ -29,14 +29,12 @@
import multiprocessing
import pandas as pd
-from extractors.partition_similarity_extractor import PartitionSimilarityExtractor
from extractors.partition_type_extractor import PartitionTypeExtractor
from extractors.prediction_mode_extractor import PredictionModeExtractor
from extractors.symbol_bits_extractor import SymbolBitsExtractor
from extractors.tx_type_extractor import TxTypeExtractor
_EXTRACTORS = {
- "partition_similarity": PartitionSimilarityExtractor,
"partition_type": PartitionTypeExtractor,
"prediction_mode": PredictionModeExtractor,
"symbol_bits": SymbolBitsExtractor,
@@ -163,7 +161,9 @@
extract_proto_path=extract_proto_path,
stream_path=stream_path,
output_path=output_path,
- frame_limit=frame_limit
+ frame_limit=frame_limit,
+ # Preserve the entire stream path so that the CTC config and class can be recovered from it, if the path contains them.
+ extra_args = ["--preserve_stream_path_depth", "-1"]
)
diff --git a/tools/py_stats/examples/extractors/partition_type_extractor.py b/tools/py_stats/examples/extractors/partition_type_extractor.py
index c972d5f..bd7f8c76 100644
--- a/tools/py_stats/examples/extractors/partition_type_extractor.py
+++ b/tools/py_stats/examples/extractors/partition_type_extractor.py
@@ -24,11 +24,11 @@
class PartitionTypeExtractor(SuperblockExtractor):
PartitionType = collections.namedtuple(
"PartitionType",
- ["width", "height", "block_size", "partition_type", "is_intra_frame", "stream_name"],
+ ["width", "height", "block_size", "partition_type", "is_intra_frame", "stream_path"],
)
def sample(self, superblock: Superblock):
- stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ stream_path = superblock.frame.proto.stream_params.stream_path
is_intra_frame = superblock.frame.is_intra_frame
for partition_block in iter_partitions(superblock.proto.luma_partition_tree):
width = partition_block.size.width
@@ -36,4 +36,4 @@
block_size = f"{width}x{height}"
partition_type = partition_block.partition_type
partition_type_name = superblock.frame.proto.enum_mappings.partition_type_mapping[partition_type]
- yield self.PartitionType(width, height, block_size, partition_type_name, is_intra_frame, stream_name)
+ yield self.PartitionType(width, height, block_size, partition_type_name, is_intra_frame, stream_path)
diff --git a/tools/py_stats/examples/extractors/prediction_mode_extractor.py b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
index 6333989..43eec36 100644
--- a/tools/py_stats/examples/extractors/prediction_mode_extractor.py
+++ b/tools/py_stats/examples/extractors/prediction_mode_extractor.py
@@ -26,11 +26,11 @@
class PredictionModeExtractor(CodingUnitExtractor):
PredictionMode = collections.namedtuple(
"PredictionMode",
- ["width", "height", "block_size", "mode", "mode_bits", "is_intra_frame", "stream_name"],
+ ["width", "height", "block_size", "mode", "mode_bits", "is_intra_frame", "stream_path"],
)
def sample(self, coding_unit: CodingUnit):
- stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ stream_path = coding_unit.frame.proto.stream_params.stream_path
width = coding_unit.rect.width
height = coding_unit.rect.height
block_size = f"{width}x{height}"
@@ -40,4 +40,4 @@
for sym in coding_unit.get_symbols(prediction_mode_symbol_filter)
)
is_intra_frame = coding_unit.frame.is_intra_frame
- yield self.PredictionMode(width, height, block_size, mode, mode_bits, is_intra_frame, stream_name)
+ yield self.PredictionMode(width, height, block_size, mode, mode_bits, is_intra_frame, stream_path)
diff --git a/tools/py_stats/examples/extractors/symbol_bits_extractor.py b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
index 2a70bd5..b889f6f 100644
--- a/tools/py_stats/examples/extractors/symbol_bits_extractor.py
+++ b/tools/py_stats/examples/extractors/symbol_bits_extractor.py
@@ -15,32 +15,16 @@
from avm_stats.proto_helpers import *
from avm_stats.stats_aggregation import *
-CTC_CLASSES = [
- 2160,
- 1080,
- 720,
- 360,
- 270,
-]
-
class SymbolBitsExtractor(SuperblockExtractor):
SymbolBits = collections.namedtuple(
"SymbolBits",
- ["symbol_name", "symbol_tags", "bits", "is_intra_frame", "stream_name", "qp", "ctc_class", "ctc_config"],
+ ["symbol_name", "symbol_tags", "bits", "is_intra_frame", "stream_path"],
)
def sample(self, superblock: Superblock):
- stream_name = superblock.frame.proto.stream_params.stream_name.removesuffix(".bin")
- qp = stream_name.split("_")[-1]
- ctc_config = stream_name.split("_")[-3]
- ctc_class = None
- for i, c in enumerate(CTC_CLASSES):
- if str(c) in stream_name:
- assert ctc_class is None
- ctc_class = f"A{i+1}"
- assert ctc_class is not None
+ stream_path = superblock.frame.proto.stream_params.stream_path
is_intra_frame = superblock.frame.is_intra_frame
for symbol in superblock.proto.symbols:
symbol_info = superblock.frame.proto.symbol_info[symbol.info_id]
symbol_tags = "/".join(symbol_info.tags)
- yield self.SymbolBits(symbol_info.source_function, symbol_tags, symbol.bits, is_intra_frame, stream_name, qp, ctc_class, ctc_config)
+ yield self.SymbolBits(symbol_info.source_function, symbol_tags, symbol.bits, is_intra_frame, stream_path)
diff --git a/tools/py_stats/examples/extractors/tx_type_extractor.py b/tools/py_stats/examples/extractors/tx_type_extractor.py
index bf9ae93..a07697d 100644
--- a/tools/py_stats/examples/extractors/tx_type_extractor.py
+++ b/tools/py_stats/examples/extractors/tx_type_extractor.py
@@ -19,11 +19,11 @@
class TxTypeExtractor(CodingUnitExtractor):
TxSize = collections.namedtuple(
"TxSize",
- ["width", "height", "tx_size", "tx_type", "is_intra_frame", "stream_name"],
+ ["width", "height", "tx_size", "tx_type", "is_intra_frame", "stream_path"],
)
def sample(self, coding_unit: CodingUnit):
- stream_name = coding_unit.frame.proto.stream_params.stream_name.removesuffix(".bin")
+ stream_path = coding_unit.frame.proto.stream_params.stream_path
# Luma only
is_chroma = len(coding_unit.proto.transform_planes) == 2
if is_chroma:
@@ -36,4 +36,4 @@
# TODO(comc): Add method on transform_unit for this.
tx_type = transform_unit.tx_type & 0xF
tx_type_name = coding_unit.frame.proto.enum_mappings.transform_type_mapping[tx_type]
- yield self.TxSize(width, height, tx_size, tx_type_name, is_intra_frame, stream_name)
+ yield self.TxSize(width, height, tx_size, tx_type_name, is_intra_frame, stream_path)
diff --git a/tools/py_stats/examples/pixel_pipeline.py b/tools/py_stats/examples/pixel_pipeline.py
index 0f06e45..fe26b21 100755
--- a/tools/py_stats/examples/pixel_pipeline.py
+++ b/tools/py_stats/examples/pixel_pipeline.py
@@ -57,6 +57,7 @@
output_path=tmp_path,
skip_if_output_already_exists=False,
yuv_path=yuv_path,
+ frame_limit=_FRAME.value,
)
seq = list(frames)
num_frames = len(seq)