AVM Analyzer initial commit.
diff --git a/tools/avm_analyzer/avm_stats/src/stats.rs b/tools/avm_analyzer/avm_stats/src/stats.rs
new file mode 100644
index 0000000..abcf844
--- /dev/null
+++ b/tools/avm_analyzer/avm_stats/src/stats.rs
@@ -0,0 +1,331 @@
+use std::collections::{HashMap, HashSet};
+
+use itertools::Itertools;
+
+use ordered_float::OrderedFloat;
+use serde::{Deserialize, Serialize};
+
+use crate::{CodingUnitKind, Frame, Plane, PlaneType, ProtoEnumMapping, Spatial};
+
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub enum FrameStatistic { // Selects which per-frame statistic `calculate` computes.
+ LumaModes, // Number of Shared coding units per prediction mode name.
+ ChromaModes, // Number of coding units per UV prediction mode name.
+ BlockSizes, // Number of Shared coding units per "{width}x{height}" size.
+ Symbols, // Total symbol bits per source function name.
+ PartitionSplit, // Number of Shared partitions per partition type name.
+}
+
+#[derive(Default, Debug, Deserialize, Serialize, PartialEq)]
+pub struct StatsFilter { // Name patterns parsed from comma-separated filter strings.
+ pub include: Vec<String>, // Patterns a name must match to be kept; an empty list means no restriction.
+ pub exclude: Vec<String>, // Patterns that cause a matching name to be dropped.
+}
+
+impl StatsFilter {
+    /// Builds a filter from two comma-separated pattern lists.
+    ///
+    /// Each entry is trimmed of surrounding whitespace and empty entries are dropped.
+    fn from_comma_separated(include: &str, exclude: &str) -> Self {
+        // Shared tokenizer for both lists.
+        let parse = |list: &str| -> Vec<String> {
+            list.split(',')
+                .map(str::trim)
+                .filter(|token| !token.is_empty())
+                .map(String::from)
+                .collect()
+        };
+        Self { include: parse(include), exclude: parse(exclude) }
+    }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub enum StatSortMode { // How the final list of samples is ordered.
+ Unsorted, // Leave the samples in the order they were produced.
+ ByName, // Sort by sample name.
+ ByValue, // Sort ascending by value, with the name as a tie-breaker.
+}
+impl StatSortMode {
+    pub fn name(&self) -> &'static str { // Human-readable label for this sort mode.
+        match *self {
+            StatSortMode::Unsorted => "Unsorted",
+            StatSortMode::ByName => "By name",
+            StatSortMode::ByValue => "By value",
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+pub struct StatsSettings { // Filtering, limiting and sorting options applied to a statistic's samples.
+ pub sort_by: StatSortMode, // Ordering of the returned samples.
+ // Limits use separate bool + value fields rather than an Option to make the UI design more intuitive (e.g. checkbox + disabled number input).
+ pub apply_limit_count: bool, // When true, keep only the `limit_count` largest samples.
+ pub limit_count: usize,
+ pub apply_limit_frac: bool, // When true, drop samples whose share of the total is below `limit_frac`.
+ pub limit_frac: f32,
+ pub include_filter: String, // Comma-separated patterns; when non-empty, only matching names are kept.
+ pub exclude_filter: String, // Comma-separated patterns; matching names are dropped.
+ pub include_filter_exact_match: bool, // true: a pattern must equal the whole name; false: substring match.
+ pub exclude_filter_exact_match: bool, // true: a pattern must equal the whole name; false: substring match.
+ pub show_relative_total: bool, // Shares are taken relative to the post-filter total and no "Other" sample is added.
+ // Comma-separated list of block size names to include for partition split stats; an empty list includes every size.
+ pub partition_split_block_sizes: String,
+}
+
+impl Default for StatsSettings {
+    fn default() -> Self { // Value-sorted, with every limit and filter switched off.
+        StatsSettings {
+            sort_by: StatSortMode::ByValue,
+            show_relative_total: false,
+            apply_limit_count: false,
+            limit_count: 20,
+            apply_limit_frac: false,
+            limit_frac: 0.01,
+            include_filter: String::new(),
+            include_filter_exact_match: false,
+            exclude_filter: String::new(),
+            exclude_filter_exact_match: false,
+            partition_split_block_sizes: String::new(),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Sample { // One named data point of a statistic.
+    pub name: String,
+    pub value: f64,
+}
+impl Sample {
+    pub fn new(name: String, value: f64) -> Self { // Plain field-wise constructor.
+        Sample { name, value }
+    }
+}
+
+impl FrameStatistic {
+ /// Tallies `CodingUnitKind::Shared` coding units by prediction mode name.
+ ///
+ /// Mode ids that `enum_lookup` cannot resolve are grouped under "UNKNOWN".
+ /// Panics if a coding unit has no `prediction_mode`.
+ fn luma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
+     let mut histogram: HashMap<String, f64> = HashMap::new();
+     for ctx in frame.iter_coding_units(CodingUnitKind::Shared) {
+         let unit = ctx.coding_unit;
+         let prediction = unit.prediction_mode.as_ref().unwrap();
+         let label = frame
+             .enum_lookup(ProtoEnumMapping::PredictionMode, prediction.mode)
+             .unwrap_or("UNKNOWN".into());
+         *histogram.entry(label).or_default() += 1.0;
+     }
+     histogram
+ }
+
+ /// Tallies coding units of the U plane's coding unit kind by UV prediction mode name.
+ ///
+ /// Mode ids that `enum_lookup` cannot resolve are grouped under "UNKNOWN".
+ /// Panics if a coding unit has no `prediction_mode`.
+ fn chroma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
+     let chroma_kind = frame.coding_unit_kind(PlaneType::Planar(Plane::U));
+     let mut histogram: HashMap<String, f64> = HashMap::new();
+     for ctx in frame.iter_coding_units(chroma_kind) {
+         let unit = ctx.coding_unit;
+         let prediction = unit.prediction_mode.as_ref().unwrap();
+         let label = frame
+             .enum_lookup(ProtoEnumMapping::UvPredictionMode, prediction.uv_mode)
+             .unwrap_or("UNKNOWN".into());
+         *histogram.entry(label).or_default() += 1.0;
+     }
+     histogram
+ }
+
+ /// Tallies `CodingUnitKind::Shared` coding units by block dimensions.
+ ///
+ /// Keys are formatted as "{width}x{height}" from the coding unit's `width()` and
+ /// `height()`; each value is the number of coding units with that size.
+ fn block_sizes(&self, frame: &Frame) -> HashMap<String, f64> {
+     let mut histogram: HashMap<String, f64> = HashMap::new();
+
+     for ctx in frame.iter_coding_units(CodingUnitKind::Shared) {
+         let unit = ctx.coding_unit;
+         let label = format!("{}x{}", unit.width(), unit.height());
+         *histogram.entry(label).or_default() += 1.0;
+     }
+     histogram
+ }
+
+ /// Tallies `CodingUnitKind::Shared` partitions by partition type name.
+ ///
+ /// If `settings.partition_split_block_sizes` lists any sizes, only partitions whose `size_name()`
+ /// equals one of them are counted. Unresolvable type ids are grouped under "UNKNOWN".
+ fn partition_split(&self, frame: &Frame, settings: &StatsSettings) -> HashMap<String, f64> {
+     // TODO(comc): Add settings option for partition kind.
+     let allowed_sizes = StatsFilter::from_comma_separated(&settings.partition_split_block_sizes, "").include;
+     let mut histogram: HashMap<String, f64> = HashMap::new();
+     for ctx in frame.iter_partitions(CodingUnitKind::Shared) {
+         let partition = ctx.partition;
+         let size = partition.size_name();
+         if !allowed_sizes.is_empty() && !allowed_sizes.iter().any(|incl| &size == incl) {
+             continue;
+         }
+         let label = frame
+             .enum_lookup(ProtoEnumMapping::PartitionType, partition.partition_type)
+             .unwrap_or("UNKNOWN".into());
+         *histogram.entry(label).or_default() += 1.0;
+     }
+     histogram
+ }
+
+ /// Sums the bits spent on symbols, grouped by the symbol's source function name.
+ ///
+ /// Bits are first summed within each superblock and the per-superblock subtotals are then
+ /// added to the frame-wide totals. Panics if a symbol's `info_id` is not in `frame.symbol_info`.
+ fn symbols(&self, frame: &Frame) -> HashMap<String, f64> {
+     let mut totals: HashMap<String, f64> = HashMap::new();
+     // TODO(comc): Use iter_symbols. Add iter_symbols method for partition blocks as well.
+     for sb_ctx in frame.iter_superblocks() {
+         let sb = sb_ctx.superblock;
+         // Key the per-superblock subtotals by borrowed name so no String is allocated per symbol.
+         let mut subtotals: HashMap<&str, f64> = HashMap::new();
+         for symbol in sb.symbols.iter() {
+             let info = &frame.symbol_info[&symbol.info_id];
+             *subtotals.entry(info.source_function.as_str()).or_default() += symbol.bits as f64;
+         }
+         // Merging whole subtotals (not single symbols) keeps the two-level summation order.
+         for (name, bits) in subtotals {
+             *totals.entry(name.to_owned()).or_default() += bits;
+         }
+     }
+     totals
+ }
+
+ /// Applies filters, limits and sorting to raw `name -> value` totals.
+ ///
+ /// 1. Samples rejected by the include/exclude filters are removed (substring match, or whole-name
+ ///    equality when the matching `*_exact_match` flag is set).
+ /// 2. In one pass over what is left: with `apply_limit_count`, only the `limit_count` largest
+ ///    samples survive; with `apply_limit_frac`, samples whose share is below `limit_frac` are
+ ///    removed. The share is relative to the filtered total if `show_relative_total` is set,
+ ///    otherwise to the unfiltered total.
+ /// 3. Unless `show_relative_total` is set, the sum of all removed values is appended as "Other"
+ ///    when it is positive.
+ /// 4. The result is ordered according to `sort_by`.
+ fn apply_settings(&self, mapping: HashMap<String, f64>, settings: &StatsSettings) -> Vec<Sample> {
+     let mut samples: Vec<Sample> = mapping
+         .into_iter()
+         .map(|(name, value)| Sample::new(name, value))
+         .collect();
+     let filter = StatsFilter::from_comma_separated(&settings.include_filter, &settings.exclude_filter);
+     let total: f64 = samples.iter().map(|sample| sample.value).sum();
+     // Running sum of every value removed by a filter or a limit.
+     let mut other = 0.0;
+
+     // TODO(comc): Make this a method of StatsFilter.
+     let name_matches = |patterns: &[String], name: &String, exact: bool| {
+         if exact {
+             patterns.contains(name)
+         } else {
+             patterns.iter().any(|pattern| name.contains(pattern.as_str()))
+         }
+     };
+     samples.retain(|Sample { name, value }| {
+         // An empty include list places no restriction on the name.
+         let included = filter.include.is_empty()
+             || name_matches(&filter.include, name, settings.include_filter_exact_match);
+         let excluded = name_matches(&filter.exclude, name, settings.exclude_filter_exact_match);
+         let keep = included && !excluded;
+         if !keep {
+             other += value;
+         }
+         keep
+     });
+
+     let filtered_total: f64 = samples.iter().map(|sample| sample.value).sum();
+
+     // Names of the `limit_count` largest remaining samples, or None when the limit is off.
+     let top_n: Option<HashSet<String>> = settings.apply_limit_count.then(|| {
+         samples
+             .iter()
+             .sorted_by_key(|sample| (OrderedFloat(sample.value), &sample.name)) // name used as a tie-breaker.
+             .rev()
+             .take(settings.limit_count)
+             .map(|sample| sample.name.clone())
+             .collect()
+     });
+
+     let share_base = if settings.show_relative_total { filtered_total } else { total };
+     let min_share = settings.limit_frac as f64;
+     samples.retain(|Sample { name, value }| {
+         let mut keep = true;
+         // Keep the `<` form: a NaN share (e.g. 0.0 / 0.0) compares false and so is retained.
+         if settings.apply_limit_frac && *value / share_base < min_share {
+             keep = false;
+         }
+         if let Some(top_n) = &top_n {
+             if !top_n.contains(name) {
+                 keep = false;
+             }
+         }
+         if !keep {
+             other += value;
+         }
+         keep
+     });
+
+     if !settings.show_relative_total && other > 0.0 {
+         samples.push(Sample::new("Other".into(), other));
+     }
+
+     // Sort in place on borrowed fields so no name is cloned just to build a sort key.
+     match settings.sort_by {
+         StatSortMode::ByName => samples.sort_by(|a, b| a.name.cmp(&b.name)),
+         StatSortMode::ByValue => samples.sort_by(|a, b| {
+             // name used as a tie-breaker.
+             OrderedFloat(a.value).cmp(&OrderedFloat(b.value)).then_with(|| a.name.cmp(&b.name))
+         }),
+         StatSortMode::Unsorted => {}
+     }
+     samples
+ }
+
+ pub fn calculate(&self, frame: &Frame, settings: &StatsSettings) -> Vec<Sample> {
+     let raw_totals = match *self { // Raw name -> value totals for the selected statistic.
+         Self::LumaModes => self.luma_modes(frame),
+         Self::ChromaModes => self.chroma_modes(frame),
+         Self::BlockSizes => self.block_sizes(frame),
+         Self::Symbols => self.symbols(frame),
+         Self::PartitionSplit => self.partition_split(frame, settings),
+     };
+     self.apply_settings(raw_totals, settings) // Filter, limit and sort per `settings`.
+ }
+
+ pub fn name(&self) -> &'static str { // Human-readable label for this statistic.
+     match *self {
+         Self::LumaModes => "Luma modes",
+         Self::ChromaModes => "Chroma modes",
+         Self::BlockSizes => "Block sizes",
+         Self::Symbols => "Symbols",
+         Self::PartitionSplit => "Partition split",
+     }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// With exact matching, the exclude pattern "A" equals neither name, so both samples survive.
+    #[test]
+    fn test_exclude_filter() {
+        let mapping = HashMap::from([("ABC".to_string(), 1.0), ("AAA".to_string(), 2.0)]);
+        let settings = StatsSettings {
+            exclude_filter: "A".into(),
+            exclude_filter_exact_match: true,
+            // No "Other" sample is appended in this mode.
+            show_relative_total: true,
+            ..StatsSettings::default()
+        };
+
+        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
+
+        assert_eq!(samples.len(), 2);
+    }
+}