AVM Analyzer initial commit.
diff --git a/tools/avm_analyzer/avm_stats/src/avm_proto.rs b/tools/avm_analyzer/avm_stats/src/avm_proto.rs
new file mode 100644
index 0000000..6733548
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/avm_proto.rs
@@ -0,0 +1 @@
+include!(concat!(env!("OUT_DIR"), "/avm.tools.rs"));
diff --git a/tools/avm_analyzer/avm_stats/src/coding_unit.rs b/tools/avm_analyzer/avm_stats/src/coding_unit.rs
new file mode 100644
index 0000000..8959e2a
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/coding_unit.rs
@@ -0,0 +1,185 @@
+use crate::{
+ CodingUnit, Frame, FrameError, PartitionContext, Plane, PredictionParams, ProtoEnumMapping, SuperblockContext,
+ SuperblockLocator, SymbolContext, SymbolRange, TransformUnitContext, TransformUnitLocator,
+};
+
+use serde::{Deserialize, Serialize};
+use FrameError::BadCodingUnit;
+impl CodingUnit {
+ pub fn plane_index(&self, plane: Plane) -> Result<usize, FrameError> {
+ match plane {
+ Plane::Y => Ok(0),
+ Plane::U | Plane::V => {
+ match self.transform_planes.len() {
+ // Split luma and chroma partition trees
+ 2 => Ok(plane.to_usize() - 1),
+ // Unified luma and chroma partition tree
+ 3 => Ok(plane.to_usize()),
+ _ => Err(BadCodingUnit(format!(
+ "Unexpected number of transform planes: got {}, expected 2 or 3 for plane {plane:?}",
+ self.transform_planes.len()
+ ))),
+ }
+ }
+ }
+ }
+
+ pub fn has_chroma(&self) -> Result<bool, FrameError> {
+ let num_transform_planes = self.transform_planes.len();
+ match num_transform_planes {
+ 2 | 3 => Ok(true),
+ 1 => Ok(false),
+ _ => Err(BadCodingUnit(format!(
+ "Unexpected number of transform planes: {num_transform_planes}"
+ ))),
+ }
+ }
+
+ pub fn has_luma(&self) -> Result<bool, FrameError> {
+ let num_transform_planes = self.transform_planes.len();
+ match num_transform_planes {
+ 1 | 3 => Ok(true),
+ 2 => Ok(false),
+ _ => Err(BadCodingUnit(format!(
+ "Unexpected number of transform planes: {num_transform_planes}"
+ ))),
+ }
+ }
+
+ /// Returns the prediction parameters, or `BadCodingUnit` when the coding unit carries none.
+ pub fn get_prediction_mode(&self) -> Result<&PredictionParams, FrameError> {
+     // Build the error lazily so the common (present) case does not allocate a message.
+     self.prediction_mode.as_ref().ok_or_else(|| BadCodingUnit("Missing prediction mode.".into()))
+ }
+
+ /// Returns the symbol range, or `BadCodingUnit` when the coding unit carries none.
+ pub fn get_symbol_range(&self) -> Result<&SymbolRange, FrameError> {
+     // Build the error lazily so the common (present) case does not allocate a message.
+     self.symbol_range.as_ref().ok_or_else(|| BadCodingUnit("Missing symbol range.".into()))
+ }
+
+ pub fn lookup_mode_name(&self, frame: &Frame) -> Result<String, FrameError> {
+ let mode = self.get_prediction_mode()?;
+ frame.enum_lookup(ProtoEnumMapping::PredictionMode, mode.mode)
+ }
+
+ /// Luma angle delta, provided the luma mode name can be looked up and ends in `_PRED`; otherwise `None`.
+ pub fn luma_mode_angle_delta(&self, frame: &Frame) -> Option<i32> {
+     let mode_name = self.lookup_mode_name(frame).ok()?;
+     if !mode_name.ends_with("_PRED") {
+         return None;
+     }
+     self.prediction_mode.as_ref().map(|params| params.angle_delta)
+ }
+
+ pub fn lookup_uv_mode_name(&self, frame: &Frame) -> Result<String, FrameError> {
+ let mode = self.get_prediction_mode()?;
+ frame.enum_lookup(ProtoEnumMapping::UvPredictionMode, mode.uv_mode)
+ }
+
+ /// Chroma angle delta, provided the UV mode name can be looked up and ends in `_PRED`; otherwise `None`.
+ pub fn chroma_mode_angle_delta(&self, frame: &Frame) -> Option<i32> {
+     let uv_mode_name = self.lookup_uv_mode_name(frame).ok()?;
+     if !uv_mode_name.ends_with("_PRED") {
+         return None;
+     }
+     self.prediction_mode.as_ref().map(|params| params.uv_angle_delta)
+ }
+
+ pub fn lookup_motion_vector_precision_name(&self, frame: &Frame) -> Result<String, FrameError> {
+ let mode = self.get_prediction_mode()?;
+ frame.enum_lookup(ProtoEnumMapping::MotionVectorPrecision, mode.motion_vector_precision)
+ }
+}
+
+/// Which planes this coding unit contains.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub enum CodingUnitKind {
+ /// Coding unit contains all three planes. Equivalent to the shared partition tree type.
+ Shared,
+ /// Coding unit contains only luma.
+ LumaOnly,
+ /// Coding unit contains only chroma.
+ ChromaOnly,
+}
+
+/// Locates a coding unit: its parent superblock, which coding-unit list it is in, and its index there.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub struct CodingUnitLocator {
+ /// Locator of the parent superblock within the frame.
+ pub superblock: SuperblockLocator,
+ /// Which coding-unit list `index` refers to. Note that LumaOnly resolves to the same list as Shared.
+ pub kind: CodingUnitKind,
+ /// Index of the coding unit within its parent superblock.
+ pub index: usize,
+}
+
+impl CodingUnitLocator {
+ pub fn new(superblock: SuperblockLocator, kind: CodingUnitKind, index: usize) -> Self {
+ Self {
+ superblock,
+ kind,
+ index,
+ }
+ }
+ /// Convert this index into a `CodingUnitContext`.
+ pub fn try_resolve<'a>(&self, frame: &'a Frame) -> Option<CodingUnitContext<'a>> {
+ let superblock_context = self.superblock.try_resolve(frame)?;
+ let coding_unit = match self.kind {
+ CodingUnitKind::Shared => superblock_context.superblock.coding_units_shared.get(self.index),
+ CodingUnitKind::LumaOnly => superblock_context.superblock.coding_units_shared.get(self.index),
+ CodingUnitKind::ChromaOnly => superblock_context.superblock.coding_units_chroma.get(self.index),
+ };
+
+ coding_unit.map(|coding_unit| CodingUnitContext {
+ coding_unit,
+ superblock_context,
+ locator: *self,
+ })
+ }
+
+ pub fn resolve<'a>(&self, frame: &'a Frame) -> CodingUnitContext<'a> {
+ self.try_resolve(frame).unwrap()
+ }
+}
+
+/// Context about a coding unit during iteration.
+#[derive(Copy, Clone)]
+pub struct CodingUnitContext<'a> {
+ /// Coding unit being iterated over.
+ pub coding_unit: &'a CodingUnit,
+ /// Superblock that owns this coding unit.
+ pub superblock_context: SuperblockContext<'a>,
+ /// The index of this coding unit within its parent superblock.
+ pub locator: CodingUnitLocator,
+}
+
+impl<'a> CodingUnitContext<'a> {
+ pub fn iter_symbols(&self) -> impl Iterator<Item = SymbolContext<'a>> {
+ let symbol_range = self.coding_unit.symbol_range.clone().unwrap_or_default();
+ self.superblock_context.iter_symbols(Some(symbol_range))
+ }
+
+ pub fn total_bits(&self) -> f32 {
+ self.iter_symbols().map(|sym| sym.symbol.bits).sum()
+ }
+
+ /// Iterates the transform units of `plane`; yields nothing when the coding unit has no such plane.
+ pub fn iter_transform_units(self, plane: Plane) -> impl Iterator<Item = TransformUnitContext<'a>> {
+     let coding_unit: &'a CodingUnit = self.coding_unit;
+     // A failed plane lookup (or an index past the end) becomes an empty iterator instead of a panic.
+     let transform_plane = coding_unit.plane_index(plane).ok().and_then(move |i| coding_unit.transform_planes.get(i));
+     let transform_units = transform_plane.into_iter().flat_map(|tp| tp.transform_units.iter());
+     transform_units.enumerate().map(move |(index, transform_unit)| TransformUnitContext {
+         transform_unit,
+         coding_unit_context: self,
+         locator: TransformUnitLocator::new(self.locator, plane, index),
+     })
+ }
+
+ pub fn find_parent_partition(&self) -> Option<PartitionContext<'a>> {
+ self.superblock_context
+ .root_partition(self.locator.kind)
+ .and_then(|root| root.find_coding_unit_parent(self.coding_unit))
+ }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/constants.rs b/tools/avm_analyzer/avm_stats/src/constants.rs
new file mode 100644
index 0000000..970f22b
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/constants.rs
@@ -0,0 +1,2 @@
+// TODO(comc): This can vary by frame.
+pub const MOTION_VECTOR_PRECISION: f32 = 8.0;
diff --git a/tools/avm_analyzer/avm_stats/src/frame.rs b/tools/avm_analyzer/avm_stats/src/frame.rs
new file mode 100644
index 0000000..969bfd8
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/frame.rs
@@ -0,0 +1,195 @@
+use crate::{
+ CodingUnitContext, CodingUnitKind, EnumMappings, Frame, FrameError, PartitionContext, Plane, PlaneType, Spatial,
+ SuperblockContext, SuperblockLocator, SymbolContext, TransformUnitContext,
+};
+
+pub enum ProtoEnumMapping {
+ TransformType,
+ EntropyCodingMode,
+ InterpolationFilter,
+ PredictionMode,
+ UvPredictionMode,
+ MotionMode,
+ TransformSize,
+ BlockSize,
+ PartitionType,
+ FrameType,
+ TipMode,
+ MotionVectorPrecision,
+}
+
+impl Frame {
+ pub fn iter_coding_units(&self, kind: CodingUnitKind) -> impl Iterator<Item = CodingUnitContext> + '_ {
+ self.iter_superblocks().flat_map(move |ctx| ctx.iter_coding_units(kind))
+ }
+
+ /// Whether this frame has separate luma and chroma partition trees (i.e. semi-decoupled partitioning - SDP).
+ ///
+ /// The flag is stored per superblock; the first superblock's value is taken to hold for the whole frame.
+ pub fn has_separate_chroma_partition_tree(&self) -> bool {
+     match self.superblocks.first() {
+         Some(first) => first.has_separate_chroma_partition_tree,
+         // A frame with no superblocks reports `false`.
+         None => false,
+     }
+ }
+
+ /// Selects which coding-unit list to use when viewing `plane_type`.
+ ///
+ /// RGB views, and every view of a frame without separate luma/chroma partition trees, use `Shared`.
+ /// Otherwise luma maps to `LumaOnly` and chroma (U or V) maps to `ChromaOnly`.
+ pub fn coding_unit_kind(&self, plane_type: PlaneType) -> CodingUnitKind {
+     // RGB never consults the partition-tree flag.
+     let plane = match plane_type {
+         PlaneType::Rgb => return CodingUnitKind::Shared,
+         PlaneType::Planar(plane) => plane,
+     };
+     // Without separate trees, all planes share one set of coding units.
+     if !self.has_separate_chroma_partition_tree() {
+         return CodingUnitKind::Shared;
+     }
+     match plane {
+         Plane::Y => CodingUnitKind::LumaOnly,
+         Plane::U | Plane::V => CodingUnitKind::ChromaOnly,
+     }
+ }
+
+ pub fn iter_coding_unit_rects(&self, kind: CodingUnitKind) -> impl Iterator<Item = emath::Rect> + '_ {
+ self.iter_coding_units(kind).map(|ctx| ctx.coding_unit.rect())
+ }
+
+ pub fn iter_transform_units(&self, plane: Plane) -> impl Iterator<Item = TransformUnitContext> {
+ let kind = self.coding_unit_kind(PlaneType::Planar(plane));
+ self.iter_coding_units(kind)
+ .flat_map(move |ctx| ctx.iter_transform_units(plane))
+ }
+
+ pub fn iter_transform_rects(&self, plane: Plane) -> impl Iterator<Item = emath::Rect> + '_ {
+ self.iter_transform_units(plane).map(|ctx| ctx.transform_unit.rect())
+ }
+
+ pub fn iter_superblocks(&self) -> impl Iterator<Item = SuperblockContext> {
+ self.superblocks
+ .iter()
+ .enumerate()
+ .map(|(i, superblock)| SuperblockContext {
+ superblock,
+ frame: self,
+ locator: SuperblockLocator::new(i),
+ })
+ }
+
+ pub fn iter_partitions(&self, kind: CodingUnitKind) -> impl Iterator<Item = PartitionContext> {
+ self.iter_superblocks()
+ .flat_map(move |superblock_context| superblock_context.iter_partitions(kind))
+ }
+
+ /// Returns the frame's enum mappings, or `BadFrame` when the frame carries none.
+ fn get_enum_mappings(&self) -> Result<&EnumMappings, FrameError> {
+     // Build the error lazily: this runs on every enum lookup and normally succeeds.
+     self.enum_mappings.as_ref().ok_or_else(|| FrameError::BadFrame("Missing enum mappings.".into()))
+ }
+ /// Maps raw enum `value` to its name via the `enum_type` mapping; `BadFrame` if the mappings or value are missing.
+ pub fn enum_lookup(&self, enum_type: ProtoEnumMapping, value: i32) -> Result<String, FrameError> {
+     let enum_mappings = self.get_enum_mappings()?;
+     match enum_type {
+         ProtoEnumMapping::TransformType => enum_mappings
+             .transform_type_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing transform type value: {value}"))),
+         ProtoEnumMapping::EntropyCodingMode => enum_mappings
+             .entropy_coding_mode_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing entropy coding mode value: {value}"))),
+         ProtoEnumMapping::InterpolationFilter => enum_mappings
+             .interpolation_filter_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing interpolation filter value: {value}"))),
+         ProtoEnumMapping::PredictionMode => enum_mappings
+             .prediction_mode_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing prediction mode value: {value}"))),
+         ProtoEnumMapping::UvPredictionMode => enum_mappings
+             .uv_prediction_mode_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing UV prediction mode value: {value}"))),
+         ProtoEnumMapping::MotionMode => enum_mappings
+             .motion_mode_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing motion mode value: {value}"))),
+         ProtoEnumMapping::TransformSize => enum_mappings
+             .transform_size_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing transform size value: {value}"))),
+         ProtoEnumMapping::BlockSize => enum_mappings
+             .block_size_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing block size value: {value}"))),
+         ProtoEnumMapping::PartitionType => enum_mappings
+             .partition_type_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing partition type value: {value}"))),
+         ProtoEnumMapping::FrameType => enum_mappings
+             .frame_type_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing frame type value: {value}"))),
+         ProtoEnumMapping::TipMode => enum_mappings
+             .tip_mode_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing TIP mode value: {value}"))),
+         ProtoEnumMapping::MotionVectorPrecision => enum_mappings
+             .motion_vector_precision_mapping
+             .get(&value)
+             .ok_or_else(|| FrameError::BadFrame(format!("Missing MV precision value: {value}"))),
+     }
+     .cloned()
+ }
+
+ pub fn iter_superblock_rects(&self) -> impl Iterator<Item = emath::Rect> + '_ {
+ self.iter_superblocks()
+ .map(|superblock_context| superblock_context.superblock.rect())
+ }
+
+ pub fn iter_symbols(&self) -> impl Iterator<Item = SymbolContext> {
+ self.iter_superblocks()
+ .flat_map(move |superblock_context| superblock_context.iter_symbols(None))
+ }
+
+ pub fn bit_depth(&self) -> u8 {
+ self.frame_params
+ .as_ref()
+ .map_or(0, |frame_params| frame_params.bit_depth as u8)
+ }
+
+ pub fn decode_index(&self) -> usize {
+ self.frame_params
+ .as_ref()
+ .map_or(0, |frame_params| frame_params.decode_index as usize)
+ }
+
+ pub fn display_index(&self) -> usize {
+ self.frame_params
+ .as_ref()
+ .map_or(0, |frame_params| frame_params.display_index as usize)
+ }
+
+ /// Human-readable name of this frame's type.
+ ///
+ /// Falls back to `"UNKNOWN"` when the frame params are absent or the enum lookup fails.
+ pub fn frame_type_name(&self) -> String {
+     self.frame_params
+         .as_ref()
+         .and_then(|params| self.enum_lookup(ProtoEnumMapping::FrameType, params.frame_type).ok())
+         .unwrap_or_else(|| "UNKNOWN".into())
+ }
+
+ /// Human-readable name of this frame's TIP mode.
+ ///
+ /// Falls back to `"UNKNOWN"` when the TIP frame params are absent or the enum lookup fails.
+ pub fn tip_mode_name(&self) -> String {
+     self.tip_frame_params
+         .as_ref()
+         .and_then(|params| self.enum_lookup(ProtoEnumMapping::TipMode, params.tip_mode).ok())
+         .unwrap_or_else(|| "UNKNOWN".into())
+ }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/frame_error.rs b/tools/avm_analyzer/avm_stats/src/frame_error.rs
new file mode 100644
index 0000000..259b648
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/frame_error.rs
@@ -0,0 +1,23 @@
+use thiserror::Error;
+
+#[derive(Error, Debug, Clone)]
+pub enum FrameError {
+ #[error("Badly formed frame: {0}")]
+ BadFrame(String),
+ #[error("Badly formed superblock: {0}")]
+ BadSuperblock(String),
+ #[error("Badly formed coding unit: {0}")]
+ BadCodingUnit(String),
+ #[error("Badly formed transform unit: {0}")]
+ BadTransformUnit(String),
+ #[error("Badly formed symbol: {0}")]
+ BadSymbol(String),
+ #[error("Badly formed pixel buffer: {0}")]
+ BadPixelBuffer(String),
+ #[error("Missing pixel buffer: {0}")]
+ MissingPixelBuffer(String),
+ #[error("Internal error: {0}")]
+ Internal(String),
+ #[error("Unknown frame error: {0}")]
+ Unknown(String),
+}
diff --git a/tools/avm_analyzer/avm_stats/src/heatmap.rs b/tools/avm_analyzer/avm_stats/src/heatmap.rs
new file mode 100644
index 0000000..dad3e80
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/heatmap.rs
@@ -0,0 +1,99 @@
+use itertools::{Itertools, MinMaxResult};
+use serde::{Deserialize, Serialize};
+
+use crate::{CodingUnitKind, Frame, FrameError, Spatial};
+// TODO(comc): Allow filtering by symbol type.
+// TODO(comc): Consider some way of handling this for TIP frames, e.g. weighted average of the two reference frames?
+pub const DEFAULT_HISTROGRAM_BUCKETS: usize = 32;
+
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+pub struct HeatmapSettings {
+ pub symbol_filter: String,
+ pub histogram_buckets: usize,
+ pub coding_unit_kind: CodingUnitKind,
+}
+
+impl Default for HeatmapSettings {
+ fn default() -> Self {
+ Self {
+ symbol_filter: "".to_string(),
+ histogram_buckets: DEFAULT_HISTROGRAM_BUCKETS,
+ coding_unit_kind: CodingUnitKind::Shared,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct Heatmap {
+ pub width: usize,
+ pub height: usize,
+ pub data: Vec<u8>,
+ pub min_value: f32,
+ pub max_value: f32,
+ pub bucket_width: f32,
+ pub histogram: Vec<f32>,
+}
+
+/// Builds a per-pixel heatmap of symbol bits for `frame`, with a histogram of the normalized values.
+///
+/// Each coding unit of `settings.coding_unit_kind` spreads the bits of its symbols (filtered by source function
+/// when `settings.symbol_filter` is non-empty) evenly over its frame-clipped rectangle; values are scaled to 0..=255.
+pub fn calculate_heatmap(frame: &Frame, settings: &HeatmapSettings) -> Result<Heatmap, FrameError> {
+    let width = frame.width() as usize;
+    let height = frame.height() as usize;
+    let num_buckets = settings.histogram_buckets;
+    let mut heatmap = vec![0.0_f32; width * height];
+    // TODO(comc): Option to iterate over both luma and chroma symbols and add them up (for SDP frames).
+    for ctx in frame.iter_coding_units(settings.coding_unit_kind) {
+        let cu = ctx.coding_unit;
+        // A symbol without info cannot match a non-empty filter: skip it instead of panicking.
+        let bits: f32 = ctx
+            .iter_symbols()
+            .filter(|sym| {
+                settings.symbol_filter.is_empty()
+                    || sym.info.as_ref().map_or(false, |info| info.source_function.contains(&settings.symbol_filter))
+            })
+            .map(|sym| sym.symbol.bits)
+            .sum();
+        let y0 = (cu.y() as usize).min(height);
+        let y1 = (cu.y() as usize + cu.height() as usize).min(height);
+        let x0 = (cu.x() as usize).min(width);
+        let x1 = (cu.x() as usize + cu.width() as usize).min(width);
+        // Nothing to paint when the clipped rectangle is empty; this also keeps the area non-zero below.
+        if y1 <= y0 || x1 <= x0 {
+            continue;
+        }
+        let bits_per_pixel = bits / ((y1 - y0) * (x1 - x0)) as f32;
+        for y in y0..y1 {
+            heatmap[y * width + x0..y * width + x1].fill(bits_per_pixel);
+        }
+    }
+    // Value range used for scaling; with fewer than two pixels the 0..255 default is kept.
+    let (min, max) = match heatmap.iter().minmax() {
+        MinMaxResult::NoElements | MinMaxResult::OneElement(_) => (0.0, 255.0),
+        MinMaxResult::MinMax(&min_v, &max_v) => (min_v, max_v),
+    };
+    let range = max - min;
+    let mut histogram = vec![0.0; num_buckets];
+    // Zero buckets: skip (`num_buckets - 1` would underflow). Flat heatmap (`range == 0`): all pixels go to bucket 0.
+    if num_buckets > 0 {
+        for &value in &heatmap {
+            let frac = if range > 0.0 { (value - min) / range } else { 0.0 };
+            let bucket = ((frac * num_buckets as f32) as usize).min(num_buckets - 1);
+            histogram[bucket] += 1.0;
+        }
+    }
+    let data: Vec<u8> = heatmap
+        .iter()
+        .map(|&value| if range > 0.0 { (255.0 * (value - min) / range) as u8 } else { 0 })
+        .collect();
+    Ok(Heatmap {
+        width,
+        height,
+        data,
+        min_value: min,
+        max_value: max,
+        bucket_width: range / num_buckets as f32,
+        histogram,
+    })
+}
diff --git a/tools/avm_analyzer/avm_stats/src/lib.rs b/tools/avm_analyzer/avm_stats/src/lib.rs
new file mode 100644
index 0000000..1fc8a92
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/lib.rs
@@ -0,0 +1,29 @@
+pub mod avm_proto;
+pub mod coding_unit;
+pub mod constants;
+pub mod frame;
+pub mod frame_error;
+pub mod heatmap;
+pub mod partition;
+pub mod pixels;
+pub mod plane;
+pub mod spatial;
+pub mod stats;
+pub mod superblock;
+pub mod symbol;
+pub mod transform_unit;
+
+pub use avm_proto::*;
+pub use coding_unit::*;
+pub use constants::*;
+pub use frame::*;
+pub use frame_error::*;
+pub use heatmap::*;
+pub use partition::*;
+pub use pixels::*;
+pub use plane::*;
+pub use spatial::*;
+pub use stats::*;
+pub use superblock::*;
+pub use symbol::*;
+pub use transform_unit::*;
diff --git a/tools/avm_analyzer/avm_stats/src/main.rs b/tools/avm_analyzer/avm_stats/src/main.rs
new file mode 100644
index 0000000..89fdd3a
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/main.rs
@@ -0,0 +1,19 @@
+use avm_stats::Frame;
+use clap::Parser;
+use prost::Message;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Path to protobuf file.
+ #[arg(short, long)]
+ proto: String,
+}
+
+/// Decodes the frame protobuf named on the command line and prints its superblock count; read/decode errors propagate.
+fn main() -> Result<(), anyhow::Error> {
+    let bytes = std::fs::read(Args::parse().proto)?;
+    let frame = Frame::decode(bytes.as_slice())?;
+    println!("{:?}", frame.superblocks.len());
+    Ok(())
+}
diff --git a/tools/avm_analyzer/avm_stats/src/partition.rs b/tools/avm_analyzer/avm_stats/src/partition.rs
new file mode 100644
index 0000000..97b1d8e
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/partition.rs
@@ -0,0 +1,151 @@
+use serde::{Deserialize, Serialize};
+
+use crate::{CodingUnit, Partition, Spatial, SuperblockContext};
+
+use crate::{CodingUnitKind, Frame, SuperblockLocator, SymbolContext};
+
+pub struct PartitionIterator<'a> {
+ pub stack: Vec<(PartitionContext<'a>, usize)>,
+ pub max_depth: Option<usize>,
+}
+
+impl<'a> PartitionIterator<'a> {
+ fn new(root: PartitionContext<'a>) -> Self {
+ Self {
+ stack: vec![(root, 0)],
+ max_depth: None,
+ }
+ }
+
+ fn with_max_depth(root: PartitionContext<'a>, max_depth: usize) -> Self {
+ Self {
+ stack: vec![(root, 0)],
+ max_depth: Some(max_depth),
+ }
+ }
+}
+
+impl<'a> Iterator for PartitionIterator<'a> {
+ type Item = PartitionContext<'a>;
+ fn next(&mut self) -> Option<Self::Item> {
+ let (current, depth) = self.stack.pop()?;
+ let max_depth = self.max_depth.unwrap_or(usize::MAX);
+ let child_depth = depth + 1;
+ if child_depth <= max_depth {
+ self.stack
+ .extend(current.partition.children.iter().enumerate().rev().map(|(i, child)| {
+ let mut child_context = current.clone();
+ child_context.partition = child;
+ child_context.locator.path_indices.push(i);
+ (child_context, child_depth)
+ }));
+ }
+ Some(current)
+ }
+}
+
+// TODO(comc): Handle shared vs luma differently at the shared level of the partition tree.
+#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub struct PartitionLocator {
+ pub path_indices: Vec<usize>,
+ pub kind: CodingUnitKind,
+ pub superblock: SuperblockLocator,
+}
+
+impl PartitionLocator {
+ pub fn new(path_indices: Vec<usize>, kind: CodingUnitKind, superblock: SuperblockLocator) -> Self {
+ Self {
+ path_indices,
+ kind,
+ superblock,
+ }
+ }
+
+ pub fn try_resolve<'a>(&self, frame: &'a Frame) -> Option<PartitionContext<'a>> {
+ let superblock_context = self.superblock.try_resolve(frame)?;
+ let mut current = match self.kind {
+ CodingUnitKind::Shared | CodingUnitKind::LumaOnly => {
+ superblock_context.superblock.luma_partition_tree.as_ref()?
+ }
+ CodingUnitKind::ChromaOnly => superblock_context.superblock.chroma_partition_tree.as_ref()?,
+ };
+ for index in self.path_indices.iter() {
+ if let Some(child) = current.children.get(*index) {
+ current = child;
+ } else {
+ return None;
+ }
+ }
+ Some(PartitionContext {
+ partition: current,
+ superblock_context,
+ locator: self.clone(),
+ })
+ }
+
+ pub fn resolve(self, frame: &Frame) -> PartitionContext {
+ self.try_resolve(frame).unwrap()
+ }
+
+ pub fn is_root(&self) -> bool {
+ self.path_indices.is_empty()
+ }
+
+ pub fn parent(&self) -> Option<PartitionLocator> {
+ if self.is_root() {
+ None
+ } else {
+ let mut parent = self.clone();
+ parent.path_indices.pop();
+ Some(parent)
+ }
+ }
+}
+
+/// Context about a partition block during iteration.
+#[derive(Clone)]
+pub struct PartitionContext<'a> {
+ /// Partition block being iterated over.
+ pub partition: &'a Partition,
+ /// Superblock that owns this partition block.
+ pub superblock_context: SuperblockContext<'a>,
+ /// The index of this partition block within its parent superblock.
+ pub locator: PartitionLocator,
+}
+
+impl<'a> PartitionContext<'a> {
+ // Note: Also yields self.
+ pub fn iter(&self) -> impl Iterator<Item = PartitionContext<'a>> {
+ PartitionIterator::new(self.clone())
+ }
+
+ // Note: Also yields self.
+ pub fn iter_with_max_depth(&self, max_depth: usize) -> impl Iterator<Item = PartitionContext<'a>> {
+ PartitionIterator::with_max_depth(self.clone(), max_depth)
+ }
+
+ pub fn iter_direct_children(&self) -> impl Iterator<Item = PartitionContext<'a>> {
+ self.iter_with_max_depth(1).skip(1)
+ }
+
+ pub fn iter_symbols(&self) -> impl Iterator<Item = SymbolContext<'a>> {
+ let symbol_range = self.partition.symbol_range.clone().unwrap_or_default();
+ self.superblock_context.iter_symbols(Some(symbol_range))
+ }
+
+ /// Finds the partition in this subtree whose rect equals `coding_unit`'s rect.
+ /// Only the first direct child whose rect contains the coding unit is descended into.
+ pub fn find_coding_unit_parent(&self, coding_unit: &'a CodingUnit) -> Option<PartitionContext<'a>> {
+     let target = coding_unit.rect();
+     if self.partition.rect() == target {
+         return Some(self.clone());
+     }
+     self.iter_direct_children()
+         .find(|child| child.partition.rect().contains_rect(target))
+         .and_then(|child| child.find_coding_unit_parent(coding_unit))
+ }
+
+ pub fn is_root(&self) -> bool {
+ self.locator.is_root()
+ }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/pixels.rs b/tools/avm_analyzer/avm_stats/src/pixels.rs
new file mode 100644
index 0000000..4f218ff
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/pixels.rs
@@ -0,0 +1,359 @@
+use crate::FrameError;
+use crate::Plane;
+use crate::{Frame, PixelBuffer, Spatial, Superblock};
+use std::fmt;
+
+/// Where in the codec pipeline a pixel buffer was sampled from.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub enum PixelType {
+ /// Pre-encode pixels (may not always be available).
+ Original,
+ /// Intra or inter predicted pixels.
+ Prediction,
+ /// Reconstructed pixels BEFORE filtering.
+ PreFiltered,
+ /// Final reconstructed pixels AFTER filtering.
+ Reconstruction,
+ /// Residual, i.e. (PreFiltered - Prediction).
+ Residual,
+ /// The effect of the in-loop filtering, i.e. (Reconstruction - PreFiltered).
+ FilterDelta,
+ /// (Original - Reconstruction) - depends on Original pixels being available.
+ Distortion,
+}
+
+impl PixelType {
+ /// Whether this pixel type represents a difference between two other pixel types.
+ pub fn is_delta(&self) -> bool {
+ match self {
+ Self::Original | Self::Prediction | Self::PreFiltered | Self::Reconstruction => false,
+ Self::Residual | Self::FilterDelta | Self::Distortion => true,
+ }
+ }
+}
+
+impl PixelBuffer {
+ /// Retrieves a pixel from the buffer, and compensates for bit_depth adjustment if necessary.
+ /// `desired_bit_depth` will typically be the bit_depth of the stream itself. The underlying
+ /// buffer may have a different bit_depth in the case of original YUV pixels.
+ ///
+ /// Returns `BadPixelBuffer` when `(x, y)` lies outside the buffer.
+ pub fn get_pixel(&self, x: i32, y: i32, desired_bit_depth: u8) -> Result<i16, FrameError> {
+     // The flat index alone cannot catch an out-of-range `x`: it would land in a neighbouring row.
+     let in_row = x >= 0 && x < self.width && y >= 0;
+     // Widen before multiplying so large coordinates cannot overflow `i32`.
+     let index = (i64::from(y) * i64::from(self.width) + i64::from(x)) as usize;
+     let mut pixel = *self.pixels.get(index).filter(|_| in_row).ok_or_else(|| {
+         FrameError::BadPixelBuffer(format!(
+             "Out of bounds access (x={x}, y={y}) on pixel buffer (width={}, height={}).",
+             self.width, self.height
+         ))
+     })?;
+     if (self.bit_depth as u8) < desired_bit_depth {
+         pixel <<= desired_bit_depth - self.bit_depth as u8;
+     } else if (self.bit_depth as u8) > desired_bit_depth {
+         pixel >>= self.bit_depth as u8 - desired_bit_depth;
+     }
+     Ok(pixel as i16)
+ }
+}
+
+/// Reference to a pixel buffer, or two pixel buffers in the case of a delta pixel type.
+#[derive(Debug, Clone)]
+pub enum PixelBufferRef<'a> {
+ Single(&'a PixelBuffer),
+ Delta(&'a PixelBuffer, &'a PixelBuffer),
+}
+
+impl<'a> PixelBufferRef<'a> {
+ pub fn new_single(buf: &'a PixelBuffer) -> Self {
+ Self::Single(buf)
+ }
+ pub fn new_delta(buf_1: &'a PixelBuffer, buf_2: &'a PixelBuffer) -> Self {
+ Self::Delta(buf_1, buf_2)
+ }
+
+ /// Assumes both underlying buffers have the same width.
+ pub fn width(&self) -> i32 {
+ match self {
+ Self::Single(buf) => buf.width,
+ Self::Delta(buf_1, _) => buf_1.width,
+ }
+ }
+
+ /// Assumes both underlying buffers have the same height.
+ pub fn height(&self) -> i32 {
+ match self {
+ Self::Single(buf) => buf.height,
+ Self::Delta(buf_1, _) => buf_1.height,
+ }
+ }
+
+ /// Get a pixel from the underlying buffer(s), or a `FrameError` if OoB access occurs.
+ pub fn get_pixel(&self, x: i32, y: i32, desired_bit_depth: u8) -> Result<i16, FrameError> {
+ match self {
+ Self::Single(buf) => {
+ buf.get_pixel(x, y, desired_bit_depth)
+ }
+ Self::Delta(buf_1, buf_2) => {
+ let pixel_1 = buf_1.get_pixel(x, y, desired_bit_depth)?;
+ let pixel_2 = buf_2.get_pixel(x, y, desired_bit_depth)?;
+ Ok(pixel_1 - pixel_2)
+ }
+ }
+ }
+}
+
+impl fmt::Display for PixelType {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let text = match self {
+ PixelType::Original => "Original YUV",
+ PixelType::Prediction => "Prediction",
+ PixelType::PreFiltered => "Prefiltered",
+ PixelType::Reconstruction => "Reconstruction",
+ PixelType::Residual => "Residual",
+ PixelType::FilterDelta => "Filter Delta",
+ PixelType::Distortion => "Distortion",
+ };
+ write!(f, "{text}")
+ }
+}
+
+/// Pixel data for a single plane (Y, U or V) and single pixel type.
+pub struct PixelPlane {
+ pub bit_depth: u8,
+ pub width: i32,
+ pub height: i32,
+ pub pixels: Vec<i16>,
+ pub plane: Plane,
+ pub pixel_type: PixelType,
+}
+
+impl PixelPlane {
+ fn create_from_tip_frame(frame: &Frame, plane: Plane, pixel_type: PixelType) -> Result<Self, FrameError> {
+ use FrameError::*;
+ let tip_params = frame.tip_frame_params.as_ref().unwrap();
+ let width = plane.subsampled(frame.width());
+ let height = plane.subsampled(frame.height());
+ let bit_depth = frame.bit_depth();
+ let mut pixels = vec![0; (width * height) as usize];
+
+ let pixel_data = tip_params
+ .pixel_data
+ .get(plane.to_i32() as usize)
+ .ok_or(BadFrame("Missing pixel data in tip frame.".into()))?;
+ let pixel_buffer = match pixel_type {
+ PixelType::Original => {
+ PixelBufferRef::new_single(pixel_data.original.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Original pixel data for plane {} not present",
+ plane.to_usize()
+ )))?)
+ }
+
+ PixelType::Reconstruction => {
+ PixelBufferRef::new_single(pixel_data.reconstruction.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Reconstruction pixel data for plane {} not present",
+ plane.to_usize()
+ )))?)
+ }
+
+ PixelType::Distortion => {
+ let original = pixel_data.original.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Original pixel data for plane {} not present",
+ plane.to_usize()
+ )))?;
+ let reconstruction = pixel_data.reconstruction.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Reconstruction pixel data for plane {} not present",
+ plane.to_usize()
+ )))?;
+ PixelBufferRef::new_delta(original, reconstruction)
+ }
+
+ _ => {
+ return Err(FrameError::Internal(format!(
+ "Tried to retrieve invalid single pixel buffer type ({pixel_type:?}) from TIP params."
+ )))
+ }
+ };
+
+ for y in 0..height {
+ for x in 0..width {
+ let index = (y * width + x) as usize;
+ pixels[index] = pixel_buffer.get_pixel(x, y, bit_depth)?;
+ }
+ }
+ Ok(Self {
+ bit_depth,
+ width,
+ height,
+ pixels,
+ plane,
+ pixel_type,
+ })
+ }
+
+ fn create_from_superblocks(frame: &Frame, plane: Plane, pixel_type: PixelType) -> Result<Self, FrameError> {
+ use FrameError::*;
+ let width = plane.subsampled(frame.width());
+ let height = plane.subsampled(frame.height());
+ let bit_depth = frame.bit_depth();
+ let mut pixels = vec![0; (width * height) as usize];
+
+ for sb_ctx in frame.iter_superblocks() {
+ let sb = sb_ctx.superblock;
+ let sb_width = plane.subsampled(sb.width());
+ let sb_height = plane.subsampled(sb.height());
+
+ if sb_width <= 0 || sb_height <= 0 {
+ return Err(BadSuperblock(format!("Invalid dimensions: {sb_width}x{sb_height}")));
+ }
+
+ let sb_x = plane.subsampled(sb.x());
+ let sb_y = plane.subsampled(sb.y());
+ if sb_x < 0 || sb_x >= width || sb_y < 0 || sb_y >= height {
+ return Err(BadSuperblock(format!("Outside frame bounds: x={sb_x}, y={sb_y}")));
+ }
+
+ let remaining_width = width - sb_x;
+ let remaining_height = height - sb_y;
+ let cropped_sb_width = sb_width.min(remaining_width);
+ let cropped_sb_height = sb_height.min(remaining_height);
+
+ let pixel_buffer = sb.get_pixels(plane, pixel_type)?;
+
+ if cropped_sb_width > pixel_buffer.width() || cropped_sb_height > pixel_buffer.height() {
+ return Err(BadPixelBuffer(format!(
+ "Expected pixel buffer shape: ({}x{}), Actual: ({}x{})",
+ cropped_sb_width,
+ cropped_sb_height,
+ pixel_buffer.width(),
+ pixel_buffer.height(),
+ )));
+ }
+
+ for rel_y in 0..sb_height {
+ let abs_y = sb_y + rel_y;
+ // Clip on frame bottom edge if frame height isn't a multiple of superblock size.
+ if abs_y >= height {
+ break;
+ }
+ for rel_x in 0..sb_width {
+ let abs_x = sb_x + rel_x;
+ // Clip on frame right edge if frame width isn't a multiple of superblock size.
+ if abs_x >= width {
+ break;
+ }
+ let dest_index = (abs_y * width + abs_x) as usize;
+ pixels[dest_index] = pixel_buffer.get_pixel(rel_x, rel_y, bit_depth)?;
+ }
+ }
+ }
+ Ok(Self {
+ bit_depth,
+ width,
+ height,
+ pixels,
+ plane,
+ pixel_type,
+ })
+ }
+
+ pub fn create_from_frame(frame: &Frame, plane: Plane, pixel_type: PixelType) -> Result<Self, FrameError> {
+ if let Some(tip_params) = &frame.tip_frame_params {
+ // TODO(comc): Const for this 2.
+ if tip_params.tip_mode == 2 {
+ return Self::create_from_tip_frame(frame, plane, pixel_type);
+ }
+ }
+ Self::create_from_superblocks(frame, plane, pixel_type)
+ }
+}
+
+impl Superblock {
+ /// Retrieves a single `PixelBuffer` from this superblock.
+ pub fn get_single_pixel_buffer(&self, plane: Plane, pixel_type: PixelType) -> Result<&PixelBuffer, FrameError> {
+ use FrameError::*;
+ let pixel_data = self.pixel_data.get(plane.to_usize()).ok_or(MissingPixelBuffer(format!(
+ "Pixel data for plane {} not present ({} total)",
+ plane.to_usize(),
+ self.pixel_data.len()
+ )))?;
+
+ let pixels = match pixel_type {
+ PixelType::Original => pixel_data.original.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Original pixel data for plane {} not present",
+ plane.to_usize()
+ )))?,
+
+ PixelType::Prediction => pixel_data.prediction.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Prediction pixel data for plane {} not present",
+ plane.to_usize()
+ )))?,
+
+ PixelType::PreFiltered => pixel_data.pre_filtered.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Pre-filtered pixel data for plane {} not present",
+ plane.to_usize()
+ )))?,
+
+ PixelType::Reconstruction => pixel_data.reconstruction.as_ref().ok_or(MissingPixelBuffer(format!(
+ "Reconstruction pixel data for plane {} not present",
+ plane.to_usize()
+ )))?,
+
+ _ => {
+ return Err(FrameError::Internal(format!(
+ "Tried to retrieve invalid single pixel buffer type ({pixel_type:?}) from protobuf superblock."
+ )))
+ }
+ };
+ let width = pixels.width;
+ let height = pixels.height;
+ let num_pixels = width * height;
+ let actual_pixels = pixels.pixels.len() as i32;
+ if num_pixels != actual_pixels {
+ return Err(FrameError::BadPixelBuffer(format!(
+ "Pixel buffer contains {actual_pixels} pixels, but dimensions require {num_pixels} pixels ({}x{})",
+ width, height
+ )));
+ }
+ Ok(pixels)
+ }
+
+ pub fn get_pixels(&self, plane: Plane, pixel_type: PixelType) -> Result<PixelBufferRef, FrameError> {
+ if pixel_type.is_delta() {
+ let (buf_1, buf_2) = match pixel_type {
+ PixelType::Residual => {
+ let pre_filtered = self.get_single_pixel_buffer(plane, PixelType::PreFiltered)?;
+ let prediction = self.get_single_pixel_buffer(plane, PixelType::Prediction)?;
+ (pre_filtered, prediction)
+ }
+ PixelType::FilterDelta => {
+ let reconstruction = self.get_single_pixel_buffer(plane, PixelType::Reconstruction)?;
+ let pre_filtered = self.get_single_pixel_buffer(plane, PixelType::PreFiltered)?;
+ (reconstruction, pre_filtered)
+ }
+ PixelType::Distortion => {
+ let original = self.get_single_pixel_buffer(plane, PixelType::Original)?;
+ let reconstruction = self.get_single_pixel_buffer(plane, PixelType::Reconstruction)?;
+ (original, reconstruction)
+ }
+ _ => {
+ return Err(FrameError::Internal(format!(
+ "Tried to retrieve invalid pixel delta type: {pixel_type:?}"
+ )));
+ }
+ };
+
+ if buf_1.width != buf_2.width || buf_1.height != buf_2.height {
+ return Err(FrameError::BadPixelBuffer(format!(
+ "Mismatched dimensions: {}x{} vs {}x{}",
+ buf_1.width, buf_1.height, buf_2.width, buf_2.height
+ )));
+ }
+ Ok(PixelBufferRef::new_delta(buf_1, buf_2))
+ } else {
+ let buf = self.get_single_pixel_buffer(plane, pixel_type)?;
+ Ok(PixelBufferRef::new_single(buf))
+ }
+ }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/plane.rs b/tools/avm_analyzer/avm_stats/src/plane.rs
new file mode 100644
index 0000000..025dd5f
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/plane.rs
@@ -0,0 +1,91 @@
+use std::fmt;
+
+use serde::{Deserialize, Serialize};
+
+/// One of the three planes handled by the analyzer: `Y`, `U` or `V`.
+///
+/// `Y` maps to id 0, `U` to 1 and `V` to 2 (see `to_i32` / `from_i32`);
+/// `U` and `V` are the planes reported as chroma by `is_chroma`.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)]
+pub enum Plane {
+ Y,
+ U,
+ V,
+}
+impl Plane {
+    /// Human-readable label for the plane ("Y plane", "U plane", "V plane").
+    ///
+    /// The labels are string literals, so the returned slice is `'static` and
+    /// does not keep `self` borrowed.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Plane::Y => "Y plane",
+            Plane::U => "U plane",
+            Plane::V => "V plane",
+        }
+    }
+
+    /// Converts a numeric plane id (0, 1, 2) into a `Plane`.
+    ///
+    /// # Panics
+    /// Panics with "Bad plane id" for any value other than 0, 1 or 2.
+    pub fn from_i32(i: i32) -> Self {
+        match i {
+            0 => Plane::Y,
+            1 => Plane::U,
+            2 => Plane::V,
+            _ => panic!("Bad plane id: {i}"),
+        }
+    }
+
+    /// Numeric plane id: 0 for `Y`, 1 for `U`, 2 for `V`. Inverse of `from_i32`.
+    pub fn to_i32(&self) -> i32 {
+        match self {
+            Plane::Y => 0,
+            Plane::U => 1,
+            Plane::V => 2,
+        }
+    }
+
+    /// Same id as `to_i32`, as a `usize` (the id is never negative).
+    pub fn to_usize(&self) -> usize {
+        self.to_i32() as usize
+    }
+
+    /// Returns `true` for `U` and `V`, `false` for `Y`.
+    pub fn is_chroma(&self) -> bool {
+        matches!(self, Plane::U | Plane::V)
+    }
+
+    /// Scales a luma-resolution `dimension` to this plane's resolution.
+    ///
+    /// Chroma planes get half the dimension, rounded up for positive values;
+    /// `Y` returns the dimension unchanged.
+    /// NOTE(review): the factor of two is hard-coded, so this presumably assumes
+    /// 2x chroma subsampling in the given direction - confirm for other formats.
+    pub fn subsampled(&self, dimension: i32) -> i32 {
+        if self.is_chroma() {
+            (dimension + 1) / 2
+        } else {
+            dimension
+        }
+    }
+}
+
+/// Plane selection: either one specific `Plane`, or `Rgb` (the default).
+///
+/// `Rgb` is displayed as "YUV" and resolves to `Plane::Y` in `to_plane`.
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub enum PlaneType {
+ Planar(Plane),
+ #[default]
+ Rgb,
+}
+
+impl PlaneType {
+    /// Whether this selection refers to a chroma plane (`U` or `V`).
+    // For partition tree selection
+    pub fn use_chroma(&self) -> bool {
+        matches!(self, PlaneType::Planar(Plane::U | Plane::V))
+    }
+
+    /// The concrete plane behind this selection; `Rgb` resolves to `Plane::Y`.
+    pub fn to_plane(&self) -> Plane {
+        if let PlaneType::Planar(plane) = self {
+            *plane
+        } else {
+            Plane::Y
+        }
+    }
+}
+
+impl fmt::Display for PlaneType {
+    /// Writes the short label: "Y", "U" or "V" for a single plane, "YUV" for `Rgb`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let label = match self {
+            PlaneType::Rgb => "YUV",
+            PlaneType::Planar(plane) => match plane {
+                Plane::Y => "Y",
+                Plane::U => "U",
+                Plane::V => "V",
+            },
+        };
+        f.write_str(label)
+    }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/spatial.rs b/tools/avm_analyzer/avm_stats/src/spatial.rs
new file mode 100644
index 0000000..d0dd00c
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/spatial.rs
@@ -0,0 +1,109 @@
+use crate::{CodingUnit, Frame, Partition, Superblock, TransformUnit};
+
+/// Common accessors for objects that occupy a rectangle with integer
+/// position and size.
+pub trait Spatial {
+    /// Width of the object.
+    fn width(&self) -> i32;
+    /// Height of the object.
+    fn height(&self) -> i32;
+    /// Horizontal position of the object's origin.
+    fn x(&self) -> i32;
+    /// Vertical position of the object's origin.
+    fn y(&self) -> i32;
+
+    /// The object's rectangle as an `emath::Rect`, with `(x, y)` as its
+    /// minimum corner and `(width, height)` as its size.
+    fn rect(&self) -> emath::Rect {
+        let min_corner = emath::pos2(self.x() as f32, self.y() as f32);
+        let extent = emath::vec2(self.width() as f32, self.height() as f32);
+        emath::Rect::from_min_size(min_corner, extent)
+    }
+
+    /// Size formatted as `"<width>x<height>"`.
+    fn size_name(&self) -> String {
+        let (w, h) = (self.width(), self.height());
+        format!("{w}x{h}")
+    }
+}
+
+// TODO(comc): Could use a macro to implement each of these.
+/// Geometry of a transform unit; a missing `size` or `position` reads as 0.
+impl Spatial for TransformUnit {
+    fn width(&self) -> i32 {
+        // TODO(comc): This is very messy. Add derive Default to prost build script.
+        match &self.size {
+            Some(size) => size.width,
+            None => 0,
+        }
+    }
+
+    fn height(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.height,
+            None => 0,
+        }
+    }
+
+    fn x(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.x,
+            None => 0,
+        }
+    }
+
+    fn y(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.y,
+            None => 0,
+        }
+    }
+}
+
+/// Geometry of a coding unit; a missing `size` or `position` reads as 0.
+impl Spatial for CodingUnit {
+    fn width(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.width,
+            None => 0,
+        }
+    }
+
+    fn height(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.height,
+            None => 0,
+        }
+    }
+
+    fn x(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.x,
+            None => 0,
+        }
+    }
+
+    fn y(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.y,
+            None => 0,
+        }
+    }
+}
+
+/// Geometry of a partition node; a missing `size` or `position` reads as 0.
+impl Spatial for Partition {
+    fn width(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.width,
+            None => 0,
+        }
+    }
+
+    fn height(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.height,
+            None => 0,
+        }
+    }
+
+    fn x(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.x,
+            None => 0,
+        }
+    }
+
+    fn y(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.y,
+            None => 0,
+        }
+    }
+}
+
+/// Geometry of a superblock; a missing `size` or `position` reads as 0.
+impl Spatial for Superblock {
+    fn width(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.width,
+            None => 0,
+        }
+    }
+
+    fn height(&self) -> i32 {
+        match &self.size {
+            Some(size) => size.height,
+            None => 0,
+        }
+    }
+
+    fn x(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.x,
+            None => 0,
+        }
+    }
+
+    fn y(&self) -> i32 {
+        match &self.position {
+            Some(position) => position.y,
+            None => 0,
+        }
+    }
+}
+
+/// Geometry of a whole frame: anchored at the origin, sized by `frame_params`
+/// (0x0 when `frame_params` is absent).
+impl Spatial for Frame {
+    fn width(&self) -> i32 {
+        match &self.frame_params {
+            Some(frame_params) => frame_params.width,
+            None => 0,
+        }
+    }
+
+    fn height(&self) -> i32 {
+        match &self.frame_params {
+            Some(frame_params) => frame_params.height,
+            None => 0,
+        }
+    }
+
+    // A frame always starts at (0, 0).
+    fn x(&self) -> i32 {
+        0
+    }
+
+    fn y(&self) -> i32 {
+        0
+    }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/stats.rs b/tools/avm_analyzer/avm_stats/src/stats.rs
new file mode 100644
index 0000000..abcf844
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/stats.rs
@@ -0,0 +1,331 @@
+use std::collections::{HashMap, HashSet};
+
+use itertools::Itertools;
+
+use ordered_float::OrderedFloat;
+use serde::{Deserialize, Serialize};
+
+use crate::{CodingUnitKind, Frame, Plane, PlaneType, ProtoEnumMapping, Spatial};
+
+/// The per-frame statistics that `FrameStatistic::calculate` can produce.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub enum FrameStatistic {
+ /// Occurrence count of each luma prediction mode.
+ LumaModes,
+ /// Occurrence count of each chroma (UV) prediction mode.
+ ChromaModes,
+ /// Occurrence count of each coding unit size ("WxH").
+ BlockSizes,
+ /// Total bits per symbol source function.
+ Symbols,
+ /// Occurrence count of each partition type.
+ PartitionSplit,
+}
+
+/// Parsed include / exclude name lists used to filter statistic samples.
+///
+/// Built by `from_comma_separated`; entries are trimmed and never empty.
+#[derive(Default, Debug, Deserialize, Serialize, PartialEq)]
+pub struct StatsFilter {
+ /// When non-empty, only names matching one of these entries are kept.
+ pub include: Vec<String>,
+ /// Names matching one of these entries are dropped.
+ pub exclude: Vec<String>,
+}
+
+impl StatsFilter {
+    /// Builds a filter from two comma-separated lists.
+    ///
+    /// Each entry is trimmed of surrounding whitespace; entries that are empty
+    /// after trimming are discarded, so `""` yields an empty list.
+    fn from_comma_separated(include: &str, exclude: &str) -> Self {
+        // Shared parsing for both lists.
+        fn parse_list(list: &str) -> Vec<String> {
+            list.split(',')
+                .map(str::trim)
+                .filter(|entry| !entry.is_empty())
+                .map(String::from)
+                .collect()
+        }
+
+        Self {
+            include: parse_list(include),
+            exclude: parse_list(exclude),
+        }
+    }
+}
+
+/// Ordering applied to the samples returned by `FrameStatistic::calculate`.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub enum StatSortMode {
+ /// Leave the samples in whatever order they were produced.
+ Unsorted,
+ /// Sort by sample name.
+ ByName,
+ /// Sort by sample value, with the name as a tie-breaker.
+ ByValue,
+}
+impl StatSortMode {
+    /// Display label for this sort mode.
+    pub fn name(&self) -> &'static str {
+        match *self {
+            StatSortMode::Unsorted => "Unsorted",
+            StatSortMode::ByName => "By name",
+            StatSortMode::ByValue => "By value",
+        }
+    }
+}
+
+/// User-adjustable options controlling how a statistic is filtered, limited
+/// and sorted (consumed by `FrameStatistic::calculate`).
+#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+pub struct StatsSettings {
+ /// Final ordering of the returned samples.
+ pub sort_by: StatSortMode,
+ // Using separate bool + value fields rather than an Option to make the UI design a bit more intuitive (e.g. checkbox + disabled number input).
+ /// Enables `limit_count`.
+ pub apply_limit_count: bool,
+ /// Keep only this many of the highest-valued samples.
+ pub limit_count: usize,
+ /// Enables `limit_frac`.
+ pub apply_limit_frac: bool,
+ /// Drop samples whose share of the total is below this fraction.
+ pub limit_frac: f32,
+ /// Comma-separated names to include; empty means "include everything".
+ pub include_filter: String,
+ /// Comma-separated names to exclude.
+ pub exclude_filter: String,
+ /// Match include entries against the whole name instead of as substrings.
+ pub include_filter_exact_match: bool,
+ /// Match exclude entries against the whole name instead of as substrings.
+ pub exclude_filter_exact_match: bool,
+ /// Measure fractions against the filtered total and suppress the "Other" sample.
+ pub show_relative_total: bool,
+ // Comma separated list of block sizes to include for partition split stats.
+ pub partition_split_block_sizes: String,
+}
+
+impl Default for StatsSettings {
+    /// Defaults: sort by value, no limits applied (the limit values are
+    /// pre-filled with 20 entries / 1%), no filters, absolute totals.
+    fn default() -> Self {
+        StatsSettings {
+            sort_by: StatSortMode::ByValue,
+            // Limits are disabled; the values only seed the UI inputs.
+            apply_limit_count: false,
+            limit_count: 20,
+            apply_limit_frac: false,
+            limit_frac: 0.01,
+            // Empty filters match everything.
+            include_filter: String::new(),
+            exclude_filter: String::new(),
+            include_filter_exact_match: false,
+            exclude_filter_exact_match: false,
+            show_relative_total: false,
+            partition_split_block_sizes: String::new(),
+        }
+    }
+}
+
+/// One named data point of a statistic (e.g. a mode name and its count).
+#[derive(Debug)]
+pub struct Sample {
+ /// Label of the data point.
+ pub name: String,
+ /// Accumulated value for `name`.
+ pub value: f64,
+}
+impl Sample {
+    /// Creates a sample pairing `name` with `value`.
+    pub fn new(name: String, value: f64) -> Self {
+        Sample { name, value }
+    }
+}
+
+impl FrameStatistic {
+    /// Counts how often each name occurs in `names`.
+    fn count_occurrences(names: impl Iterator<Item = String>) -> HashMap<String, f64> {
+        let mut counts: HashMap<String, f64> = HashMap::new();
+        for name in names {
+            *counts.entry(name).or_default() += 1.0;
+        }
+        counts
+    }
+
+    /// Occurrence count of each luma prediction mode over the shared coding units.
+    ///
+    /// Coding units without prediction parameters, or whose mode has no enum
+    /// name, are counted under "UNKNOWN" rather than aborting the statistic.
+    fn luma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
+        let modes = frame
+            .iter_coding_units(CodingUnitKind::Shared)
+            .map(|ctx| match ctx.coding_unit.prediction_mode.as_ref() {
+                Some(prediction_mode) => frame
+                    .enum_lookup(ProtoEnumMapping::PredictionMode, prediction_mode.mode)
+                    .unwrap_or("UNKNOWN".into()),
+                None => String::from("UNKNOWN"),
+            });
+        Self::count_occurrences(modes)
+    }
+
+    /// Occurrence count of each chroma (UV) prediction mode.
+    ///
+    /// Iterates the coding unit kind the frame reports for the U plane.
+    /// Missing prediction parameters or unnamed modes count as "UNKNOWN".
+    fn chroma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
+        let kind = frame.coding_unit_kind(PlaneType::Planar(Plane::U));
+        let modes = frame
+            .iter_coding_units(kind)
+            .map(|ctx| match ctx.coding_unit.prediction_mode.as_ref() {
+                Some(prediction_mode) => frame
+                    .enum_lookup(ProtoEnumMapping::UvPredictionMode, prediction_mode.uv_mode)
+                    .unwrap_or("UNKNOWN".into()),
+                None => String::from("UNKNOWN"),
+            });
+        Self::count_occurrences(modes)
+    }
+
+    /// Occurrence count of each coding unit size, keyed as "WxH".
+    fn block_sizes(&self, frame: &Frame) -> HashMap<String, f64> {
+        let sizes = frame.iter_coding_units(CodingUnitKind::Shared).map(|ctx| {
+            let cu = ctx.coding_unit;
+            let w = cu.width();
+            let h = cu.height();
+            format!("{w}x{h}")
+        });
+        Self::count_occurrences(sizes)
+    }
+
+    /// Occurrence count of each partition type.
+    ///
+    /// When `settings.partition_split_block_sizes` lists any sizes, only
+    /// partitions whose `size_name()` equals one of them are counted.
+    fn partition_split(&self, frame: &Frame, settings: &StatsSettings) -> HashMap<String, f64> {
+        // TODO(comc): Add settings option for partition kind.
+        let filter = StatsFilter::from_comma_separated(&settings.partition_split_block_sizes, "");
+        let splits = frame
+            .iter_partitions(CodingUnitKind::Shared)
+            .filter(|ctx| filter.include.is_empty() || filter.include.contains(&ctx.partition.size_name()))
+            .map(|ctx| {
+                frame
+                    .enum_lookup(ProtoEnumMapping::PartitionType, ctx.partition.partition_type)
+                    .unwrap_or("UNKNOWN".into())
+            });
+        Self::count_occurrences(splits)
+    }
+
+    /// Total bits spent per symbol source function.
+    ///
+    /// Symbols whose `info_id` has no entry in `frame.symbol_info` are
+    /// attributed to "UNKNOWN" instead of panicking on the missing key.
+    fn symbols(&self, frame: &Frame) -> HashMap<String, f64> {
+        let mut symbols: HashMap<String, f64> = HashMap::new();
+        // TODO(comc): Use iter_symbols. Add iter_symbols method for partition blocks as well.
+        for sb_ctx in frame.iter_superblocks() {
+            // Sum per superblock first, then merge, so the accumulation order
+            // (and therefore float rounding) matches a per-superblock breakdown.
+            let mut symbols_sb: HashMap<&str, f64> = HashMap::new();
+            for symbol in sb_ctx.superblock.symbols.iter() {
+                let name = frame
+                    .symbol_info
+                    .get(&symbol.info_id)
+                    .map_or("UNKNOWN", |info| info.source_function.as_str());
+                *symbols_sb.entry(name).or_default() += symbol.bits as f64;
+            }
+            for (name, bits) in symbols_sb {
+                *symbols.entry(name.to_owned()).or_default() += bits;
+            }
+        }
+        symbols
+    }
+
+    /// Turns a raw name-to-value mapping into the list of samples to display.
+    ///
+    /// Steps, in order:
+    /// 1. apply the include / exclude name filters;
+    /// 2. if enabled, keep only the `limit_count` highest-valued samples and /
+    ///    or drop samples below `limit_frac` of the total;
+    /// 3. unless `show_relative_total` is set, append everything that was
+    ///    removed as a single "Other" sample (when its value is positive);
+    /// 4. sort according to `settings.sort_by` (ascending).
+    fn apply_settings(&self, mapping: HashMap<String, f64>, settings: &StatsSettings) -> Vec<Sample> {
+        let mut samples: Vec<_> = mapping
+            .into_iter()
+            .map(|(name, value)| Sample::new(name, value))
+            .collect();
+        let filter: StatsFilter = StatsFilter::from_comma_separated(&settings.include_filter, &settings.exclude_filter);
+        let total: f64 = samples.iter().map(|sample| sample.value).sum();
+        // Running sum of every value removed by any step below.
+        let mut other = 0.0;
+        samples.retain(|Sample { name, value }| {
+            // TODO(comc): Make this a method of StatsFilter.
+            // An empty include list accepts every name.
+            let included = filter.include.is_empty()
+                || if settings.include_filter_exact_match {
+                    filter.include.contains(name)
+                } else {
+                    filter.include.iter().any(|incl| name.contains(incl))
+                };
+            let excluded = if settings.exclude_filter_exact_match {
+                filter.exclude.contains(name)
+            } else {
+                filter.exclude.iter().any(|excl| name.contains(excl))
+            };
+            let keep = included && !excluded;
+            if !keep {
+                other += value;
+            }
+            keep
+        });
+
+        let filtered_total: f64 = samples.iter().map(|sample| sample.value).sum();
+
+        // Names of the `limit_count` largest samples, if that limit is active.
+        let top_n = if settings.apply_limit_count {
+            let top_n: HashSet<_> = samples
+                .iter()
+                .sorted_by_key(|sample| (OrderedFloat(sample.value), &sample.name)) // name used as a tie-breaker.
+                .rev()
+                .map(|sample| sample.name.clone())
+                .take(settings.limit_count)
+                .collect();
+            Some(top_n)
+        } else {
+            None
+        };
+
+        samples.retain(|Sample { name, value }| {
+            let mut keep = true;
+            if settings.apply_limit_frac {
+                // Fractions are relative to the filtered total only when requested.
+                let frac = if settings.show_relative_total {
+                    *value / filtered_total
+                } else {
+                    *value / total
+                };
+                if frac < settings.limit_frac as f64 {
+                    keep = false;
+                }
+            }
+
+            if let Some(top_n) = &top_n {
+                if !top_n.contains(name) {
+                    keep = false;
+                }
+            }
+
+            if !keep {
+                other += value;
+            }
+            keep
+        });
+
+        if !settings.show_relative_total && other > 0.0 {
+            samples.push(Sample::new("Other".into(), other));
+        }
+        match settings.sort_by {
+            StatSortMode::ByName => samples
+                .into_iter()
+                .sorted_by_key(|sample| sample.name.clone())
+                .collect(),
+            StatSortMode::ByValue => samples
+                .into_iter()
+                .sorted_by_key(|sample| (OrderedFloat(sample.value), sample.name.clone())) // name used as a tie-breaker.
+                .collect(),
+            StatSortMode::Unsorted => samples,
+        }
+    }
+
+    /// Computes this statistic for `frame` and post-processes it with `settings`.
+    pub fn calculate(&self, frame: &Frame, settings: &StatsSettings) -> Vec<Sample> {
+        let mapping = match self {
+            FrameStatistic::LumaModes => self.luma_modes(frame),
+            FrameStatistic::ChromaModes => self.chroma_modes(frame),
+            FrameStatistic::BlockSizes => self.block_sizes(frame),
+            FrameStatistic::Symbols => self.symbols(frame),
+            FrameStatistic::PartitionSplit => self.partition_split(frame, settings),
+        };
+        self.apply_settings(mapping, settings)
+    }
+
+    /// Display label for this statistic.
+    pub fn name(&self) -> &'static str {
+        match self {
+            FrameStatistic::LumaModes => "Luma modes",
+            FrameStatistic::ChromaModes => "Chroma modes",
+            FrameStatistic::BlockSizes => "Block sizes",
+            FrameStatistic::Symbols => "Symbols",
+            FrameStatistic::PartitionSplit => "Partition split",
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    /// With exact matching enabled, the exclude entry "A" must not remove
+    /// names that merely contain "A".
+    #[test]
+    fn test_exclude_filter() {
+        let mut mapping: HashMap<String, f64> = HashMap::new();
+        mapping.insert("ABC".to_string(), 1.0);
+        mapping.insert("AAA".to_string(), 2.0);
+
+        let settings = StatsSettings {
+            exclude_filter: "A".into(),
+            exclude_filter_exact_match: true,
+            show_relative_total: true,
+            ..Default::default()
+        };
+
+        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
+        assert_eq!(samples.len(), 2);
+    }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/superblock.rs b/tools/avm_analyzer/avm_stats/src/superblock.rs
new file mode 100644
index 0000000..66f933b
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/superblock.rs
@@ -0,0 +1,87 @@
+use serde::{Deserialize, Serialize};
+
+use crate::{
+ CodingUnitContext, CodingUnitKind, CodingUnitLocator, Frame, PartitionContext, PartitionIterator, PartitionLocator,
+ Superblock, SymbolContext, SymbolRange,
+};
+
+/// Serializable handle identifying a superblock by its position in
+/// `Frame::superblocks`.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub struct SuperblockLocator {
+ /// Index into the frame's `superblocks` list.
+ pub index: usize,
+}
+
+impl SuperblockLocator {
+    /// Creates a locator for the superblock at `index`.
+    pub fn new(index: usize) -> Self {
+        SuperblockLocator { index }
+    }
+
+    /// Looks the superblock up in `frame`; `None` if the index is out of range.
+    pub fn try_resolve<'a>(&self, frame: &'a Frame) -> Option<SuperblockContext<'a>> {
+        let superblock = frame.superblocks.get(self.index)?;
+        Some(SuperblockContext {
+            superblock,
+            frame,
+            locator: *self,
+        })
+    }
+
+    /// Like `try_resolve`, but panics if the index is out of range for `frame`.
+    pub fn resolve<'a>(&self, frame: &'a Frame) -> SuperblockContext<'a> {
+        self.try_resolve(frame).unwrap()
+    }
+}
+
+/// A superblock borrowed from a frame, together with the frame it belongs
+/// to and the locator that identifies it.
+#[derive(Copy, Clone, Debug)]
+pub struct SuperblockContext<'a> {
+ /// The superblock itself.
+ pub superblock: &'a Superblock,
+ /// The frame that owns `superblock`.
+ pub frame: &'a Frame,
+ /// Locator that resolves back to `superblock` within `frame`.
+ pub locator: SuperblockLocator,
+}
+
+impl<'a> SuperblockContext<'a> {
+    /// Iterates over the partitions of this superblock, starting at the root
+    /// of the partition tree selected by `kind` (`ChromaOnly` selects the
+    /// chroma tree, everything else the luma tree).
+    ///
+    /// If the superblock has no such tree the iterator starts with an empty
+    /// stack instead of panicking.
+    /// NOTE(review): this relies on `PartitionIterator` yielding nothing for
+    /// an empty stack - confirm against its `Iterator` impl.
+    pub fn iter_partitions(&self, kind: CodingUnitKind) -> impl Iterator<Item = PartitionContext<'a>> {
+        let tree = match kind {
+            CodingUnitKind::Shared | CodingUnitKind::LumaOnly => self.superblock.luma_partition_tree.as_ref(),
+            CodingUnitKind::ChromaOnly => self.superblock.chroma_partition_tree.as_ref(),
+        };
+
+        let stack = match tree {
+            Some(root) => {
+                // The root partition has an empty path within the tree.
+                let root_locator = PartitionLocator::new(Vec::new(), kind, self.locator);
+                let root_context = PartitionContext {
+                    partition: root,
+                    superblock_context: *self,
+                    locator: root_locator,
+                };
+                vec![(root_context, 0)]
+            }
+            None => Vec::new(),
+        };
+        PartitionIterator { stack, max_depth: None }
+    }
+
+    /// The first partition produced by `iter_partitions(kind)`, i.e. the root
+    /// of the selected tree, or `None` when the iterator is empty.
+    pub fn root_partition(&self, kind: CodingUnitKind) -> Option<PartitionContext<'a>> {
+        self.iter_partitions(kind).next()
+    }
+
+    /// Iterates over the coding units of this superblock for `kind`
+    /// (`ChromaOnly` reads `coding_units_chroma`, everything else
+    /// `coding_units_shared`), attaching a locator to each.
+    // Consuming self simplifies lifetime management in caller.
+    pub fn iter_coding_units(self, kind: CodingUnitKind) -> impl Iterator<Item = CodingUnitContext<'a>> {
+        let coding_units = match kind {
+            CodingUnitKind::Shared | CodingUnitKind::LumaOnly => self.superblock.coding_units_shared.iter(),
+            CodingUnitKind::ChromaOnly => self.superblock.coding_units_chroma.iter(),
+        };
+        coding_units
+            .enumerate()
+            .map(move |(index, coding_unit)| CodingUnitContext {
+                coding_unit,
+                superblock_context: self,
+                locator: CodingUnitLocator::new(self.locator, kind, index),
+            })
+    }
+
+    /// Iterates over the symbols of this superblock, optionally restricted to
+    /// the half-open index range `range.start..range.end`.
+    ///
+    /// `None` selects every symbol. A range that reaches past the end of the
+    /// symbol list, or whose start exceeds its end, is clamped to the valid
+    /// portion (possibly empty) rather than panicking. Each symbol is paired
+    /// with its `SymbolInfo` when the frame has one for its `info_id`.
+    pub fn iter_symbols(&self, range: Option<SymbolRange>) -> impl Iterator<Item = SymbolContext<'a>> {
+        let superblock = self.superblock;
+        let frame = self.frame;
+        let symbols = &superblock.symbols;
+        let len = symbols.len();
+
+        let (start, end) = match range {
+            Some(range) => (range.start as usize, range.end as usize),
+            None => (0, len),
+        };
+        // Clamp so that start <= end <= len always holds for the slice below.
+        let end = end.min(len);
+        let start = start.min(end);
+
+        symbols[start..end].iter().map(move |sym| SymbolContext {
+            symbol: sym,
+            info: frame.symbol_info.get(&sym.info_id),
+            superblock,
+        })
+    }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/symbol.rs b/tools/avm_analyzer/avm_stats/src/symbol.rs
new file mode 100644
index 0000000..886c8c6
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/symbol.rs
@@ -0,0 +1,34 @@
+use crate::{Superblock, Symbol, SymbolInfo};
+use once_cell::sync::Lazy;
+
+/// Lazily-built placeholder `SymbolInfo` whose file and function are
+/// "UNKNOWN" and whose id and line are -1.
+// NOTE(review): presumably substituted when a symbol's info is missing;
+// no use of it is visible in this part of the file - confirm at call sites.
+pub static MISSING_SYMBOL_INFO: Lazy<SymbolInfo> = Lazy::new(|| SymbolInfo {
+ id: -1,
+ source_file: "UNKNOWN".into(),
+ source_line: -1,
+ source_function: "UNKNOWN".into(),
+ tags: Vec::new(),
+});
+
+/// A symbol borrowed from a superblock, with its descriptive info if known.
+#[derive(Copy, Clone)]
+pub struct SymbolContext<'a> {
+ /// The symbol itself.
+ pub symbol: &'a Symbol,
+ /// Info for `symbol.info_id`; `None` when no entry was found for that id.
+ pub info: Option<&'a SymbolInfo>,
+ /// The superblock that contains `symbol`.
+ pub superblock: &'a Superblock,
+}
+
+impl<'a> SymbolContext<'a> {
+ // NOTE(review): this impl currently defines no methods. The disabled
+ // constructor below fills in `transform_unit`, `coding_unit` and `index`,
+ // none of which are fields of `SymbolContext`, so it looks like a leftover
+ // template from a transform-unit context - confirm, then delete or port it.
+ // pub fn from_coding_unit_context(
+ // transform_unit: &'a TransformUnit,
+ // plane: Plane,
+ // transform_unit_index: usize,
+ // coding_unit_context: CodingUnitContext<'a>,
+ // ) -> Self {
+ // let index = TransformUnitIndex::new(coding_unit_context.index, plane, transform_unit_index);
+ // Self {
+ // transform_unit,
+ // coding_unit: coding_unit_context.coding_unit,
+ // superblock: coding_unit_context.superblock,
+ // index,
+ // }
+ // }
+}
diff --git a/tools/avm_analyzer/avm_stats/src/transform_unit.rs b/tools/avm_analyzer/avm_stats/src/transform_unit.rs
new file mode 100644
index 0000000..310ab6f
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/transform_unit.rs
@@ -0,0 +1,69 @@
+use serde::{Deserialize, Serialize};
+
+use crate::{CodingUnitContext, CodingUnitLocator, Frame, Plane, ProtoEnumMapping, TransformUnit};
+
+/// Largest transform dimension (per side) that carries coefficients.
+///
+/// TX blocks larger than 32x32 have all coefficients other than the top-left 32x32 set to 0.
+pub const MAX_COEFFS_SIZE: usize = 32;
+
+impl TransformUnit {
+    /// Name of the primary transform type, or "SKIP" when `skip == 1`.
+    ///
+    /// Falls back to "UNKNOWN (<value>)" when the frame has no name for the
+    /// transform type value.
+    pub fn primary_tx_type_or_skip(&self, frame: &Frame) -> String {
+        if self.skip == 1 {
+            return "SKIP".to_owned();
+        }
+        // Only lower 4-bits used for primary transform. Upper bits are IST.
+        let tx_type = self.tx_type & 0xF;
+        let looked_up = frame.enum_lookup(ProtoEnumMapping::TransformType, tx_type);
+        looked_up.unwrap_or(format!("UNKNOWN ({tx_type})"))
+    }
+}
+
+/// Serializable handle identifying a transform unit by its parent coding
+/// unit, its plane and its index within that plane's transform units.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub struct TransformUnitLocator {
+ /// Locator of the coding unit that contains the transform unit.
+ pub coding_unit: CodingUnitLocator,
+ /// Plane the transform unit belongs to.
+ pub plane: Plane,
+ /// Index of this transform unit within its parent.
+ pub index: usize,
+}
+
+// Note: Converting plane to a usize does not automatically get the correct index into a coding unit's transform planes.
+// e.g., in SDP mode, the chroma coding units will have two planes, but with plane IDs (1, 2), not (0, 1).
+
+impl TransformUnitLocator {
+    /// Creates a locator for transform unit `index` of `plane` inside `coding_unit`.
+    pub fn new(coding_unit: CodingUnitLocator, plane: Plane, index: usize) -> Self {
+        TransformUnitLocator {
+            coding_unit,
+            plane,
+            index,
+        }
+    }
+
+    /// Looks the transform unit up in `frame`.
+    ///
+    /// Returns `None` if the coding unit cannot be resolved, the plane has no
+    /// valid index within that coding unit, or either index is out of range.
+    pub fn try_resolve<'a>(&self, frame: &'a Frame) -> Option<TransformUnitContext<'a>> {
+        let coding_unit_context = self.coding_unit.try_resolve(frame)?;
+        let coding_unit = coding_unit_context.coding_unit;
+        // The plane id is not the index into `transform_planes`; let the coding unit map it.
+        let plane_index = match coding_unit.plane_index(self.plane) {
+            Ok(plane_index) => plane_index,
+            Err(_) => return None,
+        };
+        let transform_plane = coding_unit.transform_planes.get(plane_index)?;
+        transform_plane
+            .transform_units
+            .get(self.index)
+            .map(|transform_unit| TransformUnitContext {
+                transform_unit,
+                coding_unit_context,
+                locator: *self,
+            })
+    }
+
+    /// Like `try_resolve`, but panics if the locator does not resolve in `frame`.
+    pub fn resolve<'a>(&self, frame: &'a Frame) -> TransformUnitContext<'a> {
+        self.try_resolve(frame).unwrap()
+    }
+}
+
+/// A transform unit borrowed from a frame, together with the context of its
+/// parent coding unit and the locator that identifies it.
+#[derive(Copy, Clone)]
+pub struct TransformUnitContext<'a> {
+ /// The transform unit itself.
+ pub transform_unit: &'a TransformUnit,
+ /// Context of the coding unit that contains `transform_unit`.
+ pub coding_unit_context: CodingUnitContext<'a>,
+ /// Locator that resolves back to `transform_unit`.
+ pub locator: TransformUnitLocator,
+}