// blob: abcf84498a433f04ebff2374bff9a358c22342fd [file] [log] [blame]
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use crate::{CodingUnitKind, Frame, Plane, PlaneType, ProtoEnumMapping, Spatial};
/// The per-frame statistics that `FrameStatistic::calculate` can compute.
///
/// Every statistic is produced as a map from a name to an `f64` value, which is then
/// filtered, truncated and sorted according to a `StatsSettings`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub enum FrameStatistic {
/// Occurrence count of each luma prediction mode over coding units of kind `CodingUnitKind::Shared`.
LumaModes,
/// Occurrence count of each UV prediction mode over coding units of the kind the frame reports for the U plane.
ChromaModes,
/// Occurrence count of each "{width}x{height}" block size over coding units of kind `CodingUnitKind::Shared`.
BlockSizes,
/// Bits spent per symbol source function, summed over every superblock of the frame.
Symbols,
/// Occurrence count of each partition type over partitions of kind `CodingUnitKind::Shared`,
/// optionally restricted to the block sizes in `StatsSettings::partition_split_block_sizes`.
PartitionSplit,
}
/// A pair of name lists used to include or exclude samples by name.
///
/// The lists only hold the parsed entries; whether an entry has to equal a name or merely be
/// contained in it is decided by the code that applies the filter.
#[derive(Default, Debug, Deserialize, Serialize, PartialEq)]
pub struct StatsFilter {
/// Entries a sample name must match to be kept. An empty list places no restriction.
pub include: Vec<String>,
/// Entries that cause a sample to be dropped when its name matches one of them.
pub exclude: Vec<String>,
}
impl StatsFilter {
    /// Builds a filter from two comma-separated lists.
    ///
    /// Every entry is stripped of surrounding whitespace, and entries that are empty after
    /// trimming are discarded, so `" a, ,b"` yields `["a", "b"]`.
    fn from_comma_separated(include: &str, exclude: &str) -> Self {
        // Both lists are parsed the same way, so the parsing lives in one place.
        fn parse_list(list: &str) -> Vec<String> {
            let mut entries = Vec::new();
            for raw in list.split(',') {
                let entry = raw.trim();
                if !entry.is_empty() {
                    entries.push(entry.to_string());
                }
            }
            entries
        }

        let include = parse_list(include);
        let exclude = parse_list(exclude);
        Self { include, exclude }
    }
}
/// How the final list of samples is ordered.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub enum StatSortMode {
/// Leave the samples in the order they were produced.
Unsorted,
/// Sort by sample name, ascending.
ByName,
/// Sort by sample value, ascending, using the name as a tie-breaker.
ByValue,
}
impl StatSortMode {
pub fn name(&self) -> &'static str {
match self {
Self::Unsorted => "Unsorted",
Self::ByName => "By name",
Self::ByValue => "By value",
}
}
}
/// Options controlling how a raw statistic is filtered, truncated and sorted before display.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct StatsSettings {
/// Ordering applied to the final list of samples.
pub sort_by: StatSortMode,
// Using separate bool + value fields rather than an Option to make the UI design a bit more intuitive (e.g. checkbox + disabled number input).
/// When true, only the `limit_count` highest-valued samples are kept.
pub apply_limit_count: bool,
/// Maximum number of samples kept when `apply_limit_count` is set.
pub limit_count: usize,
/// When true, samples whose share of the total is below `limit_frac` are dropped.
pub apply_limit_frac: bool,
/// Minimum share of the total (as a fraction, e.g. 0.01) a sample needs when `apply_limit_frac` is set.
pub limit_frac: f32,
/// Comma separated list of entries; when non-empty, only samples matching at least one entry are kept.
pub include_filter: String,
/// Comma separated list of entries; samples matching any entry are dropped.
pub exclude_filter: String,
/// When true an include entry must equal the sample name; otherwise it only has to be a substring of it.
pub include_filter_exact_match: bool,
/// When true an exclude entry must equal the sample name; otherwise it only has to be a substring of it.
pub exclude_filter_exact_match: bool,
/// When true, fractions are computed against the total of the samples that passed the name filters
/// and no "Other" sample is appended; when false, fractions use the total of all samples and the
/// dropped samples are summed into an "Other" sample.
pub show_relative_total: bool,
// Comma separated list of block sizes to include for partition split stats.
// An empty list includes every block size.
pub partition_split_block_sizes: String,
}
impl Default for StatsSettings {
fn default() -> Self {
Self {
sort_by: StatSortMode::ByValue,
apply_limit_count: false,
limit_count: 20,
apply_limit_frac: false,
limit_frac: 0.01,
include_filter: "".into(),
exclude_filter: "".into(),
include_filter_exact_match: false,
exclude_filter_exact_match: false,
show_relative_total: false,
partition_split_block_sizes: "".into(),
}
}
}
/// One named data point of a computed statistic.
#[derive(Debug)]
pub struct Sample {
/// Label of the data point (for example a mode name, a block size or "Other").
pub name: String,
/// Value of the data point: an occurrence count or a number of bits, depending on the statistic.
pub value: f64,
}
impl Sample {
pub fn new(name: String, value: f64) -> Self {
Self { name, value }
}
}
impl FrameStatistic {
    /// Counts how often each label occurs in `labels`.
    ///
    /// Counts are kept as `f64` so that every statistic produces the same map type.
    fn count_occurrences(labels: impl Iterator<Item = String>) -> HashMap<String, f64> {
        let mut counts: HashMap<String, f64> = HashMap::new();
        for label in labels {
            *counts.entry(label).or_default() += 1.0;
        }
        counts
    }

    /// Returns true if `name` matches at least one of `patterns`.
    ///
    /// With `exact` set a pattern must equal the whole name; otherwise it only has to be a
    /// substring of it.
    fn matches_any(patterns: &[String], name: &str, exact: bool) -> bool {
        patterns.iter().any(|pattern| {
            if exact {
                name == pattern.as_str()
            } else {
                name.contains(pattern.as_str())
            }
        })
    }

    /// Returns true if a sample called `name` survives the include and exclude filters.
    ///
    /// An empty include list accepts every name; the exclude list is always applied.
    // TODO(comc): Make this a method of StatsFilter.
    fn passes_name_filter(filter: &StatsFilter, settings: &StatsSettings, name: &str) -> bool {
        let included = filter.include.is_empty()
            || Self::matches_any(&filter.include, name, settings.include_filter_exact_match);
        let excluded =
            Self::matches_any(&filter.exclude, name, settings.exclude_filter_exact_match);
        included && !excluded
    }

    /// Counts the luma prediction modes of all coding units of kind `CodingUnitKind::Shared`.
    ///
    /// Coding units without a prediction mode, and modes without a known name, are counted
    /// under "UNKNOWN".
    fn luma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
        let modes = frame.iter_coding_units(CodingUnitKind::Shared).map(|ctx| {
            let cu = ctx.coding_unit;
            match cu.prediction_mode.as_ref() {
                Some(prediction_mode) => frame
                    .enum_lookup(ProtoEnumMapping::PredictionMode, prediction_mode.mode)
                    .unwrap_or("UNKNOWN".into()),
                // A missing prediction mode must not bring down the whole statistic.
                None => String::from("UNKNOWN"),
            }
        });
        Self::count_occurrences(modes)
    }

    /// Counts the UV prediction modes of all coding units of the kind used for the U plane.
    ///
    /// Coding units without a prediction mode, and modes without a known name, are counted
    /// under "UNKNOWN".
    fn chroma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
        let kind = frame.coding_unit_kind(PlaneType::Planar(Plane::U));
        let modes = frame.iter_coding_units(kind).map(|ctx| {
            let cu = ctx.coding_unit;
            match cu.prediction_mode.as_ref() {
                Some(prediction_mode) => frame
                    .enum_lookup(ProtoEnumMapping::UvPredictionMode, prediction_mode.uv_mode)
                    .unwrap_or("UNKNOWN".into()),
                // A missing prediction mode must not bring down the whole statistic.
                None => String::from("UNKNOWN"),
            }
        });
        Self::count_occurrences(modes)
    }

    /// Counts coding units of kind `CodingUnitKind::Shared` per "{width}x{height}" size.
    fn block_sizes(&self, frame: &Frame) -> HashMap<String, f64> {
        let sizes = frame.iter_coding_units(CodingUnitKind::Shared).map(|ctx| {
            let cu = ctx.coding_unit;
            let w = cu.width();
            let h = cu.height();
            format!("{w}x{h}")
        });
        Self::count_occurrences(sizes)
    }

    /// Counts the partition types of all partitions of kind `CodingUnitKind::Shared`.
    ///
    /// When `settings.partition_split_block_sizes` lists any sizes, only partitions whose
    /// size name equals one of them are counted. Partition types without a known name are
    /// counted under "UNKNOWN".
    fn partition_split(&self, frame: &Frame, settings: &StatsSettings) -> HashMap<String, f64> {
        // TODO(comc): Add settings option for partition kind.
        let filter = StatsFilter::from_comma_separated(&settings.partition_split_block_sizes, "");
        let splits = frame.iter_partitions(CodingUnitKind::Shared).filter_map(|ctx| {
            let partition = ctx.partition;
            let size = partition.size_name();
            if !filter.include.is_empty() && !filter.include.iter().any(|incl| &size == incl) {
                return None;
            }
            let partition_type = frame
                .enum_lookup(ProtoEnumMapping::PartitionType, partition.partition_type)
                .unwrap_or("UNKNOWN".into());
            Some(partition_type)
        });
        Self::count_occurrences(splits)
    }

    /// Sums the bits of every symbol in every superblock, grouped by the symbol's source
    /// function.
    ///
    /// Symbols whose `info_id` has no entry in `frame.symbol_info` are grouped under
    /// "UNKNOWN" instead of panicking.
    fn symbols(&self, frame: &Frame) -> HashMap<String, f64> {
        let mut bits_by_function: HashMap<String, f64> = HashMap::new();
        // TODO(comc): Use iter_symbols. Add iter_symbols method for partition blocks as well.
        for sb_ctx in frame.iter_superblocks() {
            let sb = sb_ctx.superblock;
            for symbol in sb.symbols.iter() {
                let name = match frame.symbol_info.get(&symbol.info_id) {
                    Some(info) => info.source_function.clone(),
                    None => String::from("UNKNOWN"),
                };
                *bits_by_function.entry(name).or_default() += symbol.bits as f64;
            }
        }
        bits_by_function
    }

    /// Turns a raw name-to-value map into the list of samples to display.
    ///
    /// The steps are, in order:
    /// 1. Drop samples rejected by the include/exclude name filters.
    /// 2. If `apply_limit_count` is set, keep only the `limit_count` highest-valued samples
    ///    (ties are broken by name).
    /// 3. If `apply_limit_frac` is set, drop samples whose share of the total is below
    ///    `limit_frac`. The total covers the name-filtered samples when `show_relative_total`
    ///    is set and all samples otherwise.
    /// 4. Unless `show_relative_total` is set, append an "Other" sample holding the summed
    ///    value of everything dropped in steps 1-3, provided that sum is positive.
    /// 5. Sort according to `sort_by`.
    fn apply_settings(&self, mapping: HashMap<String, f64>, settings: &StatsSettings) -> Vec<Sample> {
        let mut samples: Vec<Sample> = mapping
            .into_iter()
            .map(|(name, value)| Sample::new(name, value))
            .collect();
        let filter =
            StatsFilter::from_comma_separated(&settings.include_filter, &settings.exclude_filter);
        let total: f64 = samples.iter().map(|sample| sample.value).sum();
        // Running sum of the values of every sample that gets dropped.
        let mut other = 0.0;

        samples.retain(|Sample { name, value }| {
            let keep = Self::passes_name_filter(&filter, settings, name);
            if !keep {
                other += value;
            }
            keep
        });

        let filtered_total: f64 = samples.iter().map(|sample| sample.value).sum();
        // Names are cloned because `samples` is mutated by the `retain` below.
        let top_n: Option<HashSet<String>> = if settings.apply_limit_count {
            Some(
                samples
                    .iter()
                    .sorted_by_key(|sample| (OrderedFloat(sample.value), &sample.name)) // name used as a tie-breaker.
                    .rev()
                    .map(|sample| sample.name.clone())
                    .take(settings.limit_count)
                    .collect(),
            )
        } else {
            None
        };

        let reference_total = if settings.show_relative_total {
            filtered_total
        } else {
            total
        };
        let min_frac = f64::from(settings.limit_frac);
        samples.retain(|Sample { name, value }| {
            let below_min_frac = settings.apply_limit_frac && *value / reference_total < min_frac;
            let outside_top_n = match &top_n {
                Some(top_n) => !top_n.contains(name),
                None => false,
            };
            let keep = !(below_min_frac || outside_top_n);
            if !keep {
                other += value;
            }
            keep
        });

        if !settings.show_relative_total && other > 0.0 {
            samples.push(Sample::new("Other".into(), other));
        }

        // Sort in place, comparing by reference so no name is cloned per comparison.
        match settings.sort_by {
            StatSortMode::ByName => samples.sort_by(|a, b| a.name.cmp(&b.name)),
            StatSortMode::ByValue => samples.sort_by(|a, b| {
                OrderedFloat(a.value)
                    .cmp(&OrderedFloat(b.value))
                    .then_with(|| a.name.cmp(&b.name)) // name used as a tie-breaker.
            }),
            StatSortMode::Unsorted => {}
        }
        samples
    }

    /// Computes this statistic for `frame` and post-processes it according to `settings`.
    pub fn calculate(&self, frame: &Frame, settings: &StatsSettings) -> Vec<Sample> {
        let mapping = match self {
            FrameStatistic::LumaModes => self.luma_modes(frame),
            FrameStatistic::ChromaModes => self.chroma_modes(frame),
            FrameStatistic::BlockSizes => self.block_sizes(frame),
            FrameStatistic::Symbols => self.symbols(frame),
            FrameStatistic::PartitionSplit => self.partition_split(frame, settings),
        };
        self.apply_settings(mapping, settings)
    }

    /// Returns the human-readable label for this statistic.
    pub fn name(&self) -> &'static str {
        match self {
            FrameStatistic::LumaModes => "Luma modes",
            FrameStatistic::ChromaModes => "Chroma modes",
            FrameStatistic::BlockSizes => "Block sizes",
            FrameStatistic::Symbols => "Symbols",
            FrameStatistic::PartitionSplit => "Partition split",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With exact matching enabled, the exclude entry "A" equals neither "ABC" nor "AAA",
    /// so both samples must survive.
    #[test]
    fn test_exclude_filter() {
        let settings = StatsSettings {
            show_relative_total: true,
            exclude_filter_exact_match: true,
            exclude_filter: String::from("A"),
            ..Default::default()
        };
        let mapping: HashMap<String, f64> =
            HashMap::from([("ABC".to_string(), 1.0), ("AAA".to_string(), 2.0)]);

        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);

        assert_eq!(samples.len(), 2);
    }
}