blob: abcf84498a433f04ebff2374bff9a358c22342fd [file] [log] [blame]
Conor McCullough1b651f32024-01-25 02:50:55 -08001use std::collections::{HashMap, HashSet};
2
3use itertools::Itertools;
4
5use ordered_float::OrderedFloat;
6use serde::{Deserialize, Serialize};
7
8use crate::{CodingUnitKind, Frame, Plane, PlaneType, ProtoEnumMapping, Spatial};
9
10#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
11pub enum FrameStatistic {
12 LumaModes,
13 ChromaModes,
14 BlockSizes,
15 Symbols,
16 PartitionSplit,
17}
18
19#[derive(Default, Debug, Deserialize, Serialize, PartialEq)]
20pub struct StatsFilter {
21 pub include: Vec<String>,
22 pub exclude: Vec<String>,
23}
24
25impl StatsFilter {
26 fn from_comma_separated(include: &str, exclude: &str) -> Self {
27 Self {
28 include: include
29 .split(',')
30 .map(|s| s.trim().to_string())
31 .filter(|s| !s.is_empty())
32 .collect(),
33 exclude: exclude
34 .split(',')
35 .map(|s| s.trim().to_string())
36 .filter(|s| !s.is_empty())
37 .collect(),
38 }
39 }
40}
41
42#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
43pub enum StatSortMode {
44 Unsorted,
45 ByName,
46 ByValue,
47}
48impl StatSortMode {
49 pub fn name(&self) -> &'static str {
50 match self {
51 Self::Unsorted => "Unsorted",
52 Self::ByName => "By name",
53 Self::ByValue => "By value",
54 }
55 }
56}
57
58#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
59pub struct StatsSettings {
60 pub sort_by: StatSortMode,
61 // Using separate bool + value fields rather than an Option to make the UI design a bit more intuitive (e.g. checkbox + disabled number input).
62 pub apply_limit_count: bool,
63 pub limit_count: usize,
64 pub apply_limit_frac: bool,
65 pub limit_frac: f32,
66 pub include_filter: String,
67 pub exclude_filter: String,
68 pub include_filter_exact_match: bool,
69 pub exclude_filter_exact_match: bool,
70 pub show_relative_total: bool,
71 // Comma separated list of block sizes to include for partition split stats.
72 pub partition_split_block_sizes: String,
73}
74
75impl Default for StatsSettings {
76 fn default() -> Self {
77 Self {
78 sort_by: StatSortMode::ByValue,
79 apply_limit_count: false,
80 limit_count: 20,
81 apply_limit_frac: false,
82 limit_frac: 0.01,
83 include_filter: "".into(),
84 exclude_filter: "".into(),
85 include_filter_exact_match: false,
86 exclude_filter_exact_match: false,
87 show_relative_total: false,
88 partition_split_block_sizes: "".into(),
89 }
90 }
91}
92
/// A single named data point of a statistic.
#[derive(Debug, Clone, PartialEq)]
pub struct Sample {
    /// Label of the data point.
    pub name: String,
    /// Magnitude of the data point.
    pub value: f64,
}

impl Sample {
    /// Creates a sample from its name and value.
    pub fn new(name: String, value: f64) -> Self {
        Self { name, value }
    }
}
103
104impl FrameStatistic {
105 fn luma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
106 let modes = frame.iter_coding_units(CodingUnitKind::Shared).map(|ctx| {
107 let cu = ctx.coding_unit;
108 let prediction_mode = cu.prediction_mode.as_ref().unwrap();
109 frame
110 .enum_lookup(ProtoEnumMapping::PredictionMode, prediction_mode.mode)
111 .unwrap_or("UNKNOWN".into())
112 });
113 let mut modes_map: HashMap<String, f64> = HashMap::new();
114
115 for mode in modes {
116 *modes_map.entry(mode).or_default() += 1.0;
117 }
118
119 modes_map
120 }
121
122 fn chroma_modes(&self, frame: &Frame) -> HashMap<String, f64> {
123 let kind = frame.coding_unit_kind(PlaneType::Planar(Plane::U));
124 let modes = frame.iter_coding_units(kind).map(|ctx| {
125 let cu = ctx.coding_unit;
126 let prediction_mode = cu.prediction_mode.as_ref().unwrap();
127 frame
128 .enum_lookup(ProtoEnumMapping::UvPredictionMode, prediction_mode.uv_mode)
129 .unwrap_or("UNKNOWN".into())
130 });
131 let mut modes_map: HashMap<String, f64> = HashMap::new();
132
133 for mode in modes {
134 *modes_map.entry(mode).or_default() += 1.0;
135 }
136
137 modes_map
138 }
139
140 fn block_sizes(&self, frame: &Frame) -> HashMap<String, f64> {
141 let sizes = frame.iter_coding_units(CodingUnitKind::Shared).map(|ctx| {
142 let cu = ctx.coding_unit;
143 let w = cu.width();
144 let h = cu.height();
145 format!("{w}x{h}")
146 });
147 let mut sizes_map: HashMap<String, f64> = HashMap::new();
148
149 for size in sizes {
150 *sizes_map.entry(size).or_default() += 1.0;
151 }
152 sizes_map
153 }
154
155 fn partition_split(&self, frame: &Frame, settings: &StatsSettings) -> HashMap<String, f64> {
156 // TODO(comc): Add settings option for partition kind.
157 let filter = StatsFilter::from_comma_separated(&settings.partition_split_block_sizes, "");
158 let splits = frame.iter_partitions(CodingUnitKind::Shared).filter_map(|ctx| {
159 let partition = ctx.partition;
160 let size = partition.size_name();
161 if !filter.include.is_empty() && !filter.include.iter().any(|incl| &size == incl) {
162 return None;
163 }
164 let partition_type = frame
165 .enum_lookup(ProtoEnumMapping::PartitionType, partition.partition_type)
166 .unwrap_or("UNKNOWN".into());
167 Some(partition_type)
168 });
169 let mut splits_map: HashMap<String, f64> = HashMap::new();
170
171 for split in splits {
172 *splits_map.entry(split).or_default() += 1.0;
173 }
174 splits_map
175 }
176
177 fn symbols(&self, frame: &Frame) -> HashMap<String, f64> {
178 let mut symbols: HashMap<String, f64> = HashMap::new();
179 // TODO(comc): Use iter_symbols. Add iter_symbols method for partition blocks as well.
180 let sbs = frame.iter_superblocks().map(|sb_ctx| {
181 let sb = sb_ctx.superblock;
182 let mut symbols_sb: HashMap<String, f64> = HashMap::new();
183 for symbol in sb.symbols.iter() {
184 let info = symbol.info_id;
185 let info = &frame.symbol_info[&info];
186 let name = info.source_function.clone();
187 let bits = symbol.bits;
188 *symbols_sb.entry(name.clone()).or_default() += bits as f64;
189 }
190 symbols_sb
191 });
192 for symbols_sb in sbs {
193 for (name, bits) in symbols_sb {
194 *symbols.entry(name.clone()).or_default() += bits;
195 }
196 }
197 symbols
198 }
199
200 fn apply_settings(&self, mapping: HashMap<String, f64>, settings: &StatsSettings) -> Vec<Sample> {
201 let mut samples: Vec<_> = mapping
202 .into_iter()
203 .map(|(name, value)| Sample::new(name, value))
204 .collect();
205 let filter: StatsFilter = StatsFilter::from_comma_separated(&settings.include_filter, &settings.exclude_filter);
206 let total: f64 = samples.iter().map(|sample| sample.value).sum();
207 let mut other = 0.0;
208 samples.retain(|Sample { name, value }| {
209 let mut keep = true;
210 // TODO(comc): Make this a method of StatsFilter.
211 if !filter.include.is_empty() {
212 if settings.include_filter_exact_match {
213 if !filter.include.contains(name) {
214 keep = false;
215 }
216 } else if !filter.include.iter().any(|incl| name.contains(incl)) {
217 keep = false;
218 }
219 }
220 if settings.exclude_filter_exact_match {
221 if filter.exclude.contains(name) {
222 keep = false
223 }
224 } else if filter.exclude.iter().any(|excl| name.contains(excl)) {
225 keep = false;
226 }
227 if !keep {
228 other += value;
229 }
230 keep
231 });
232
233 let filtered_total: f64 = samples.iter().map(|sample| sample.value).sum();
234
235 let top_n = if settings.apply_limit_count {
236 let top_n: HashSet<_> = samples
237 .iter()
238 .sorted_by_key(|sample| (OrderedFloat(sample.value), &sample.name)) // name used as a tie-breaker.
239 .rev()
240 .map(|sample| sample.name.clone())
241 .take(settings.limit_count)
242 .collect();
243 Some(top_n)
244 } else {
245 None
246 };
247
248 samples.retain(|Sample { name, value }| {
249 let mut keep = true;
250 if settings.apply_limit_frac {
251 let frac = if settings.show_relative_total {
252 *value / filtered_total
253 } else {
254 *value / total
255 };
256 if frac < settings.limit_frac as f64 {
257 keep = false;
258 }
259 }
260
261 if let Some(top_n) = &top_n {
262 if !top_n.contains(name) {
263 keep = false;
264 }
265 }
266
267 if !keep {
268 other += value;
269 }
270 keep
271 });
272
273 if !settings.show_relative_total && other > 0.0 {
274 samples.push(Sample::new("Other".into(), other));
275 }
276 match settings.sort_by {
277 StatSortMode::ByName => samples
278 .into_iter()
279 .sorted_by_key(|sample| sample.name.clone())
280 .collect(),
281 StatSortMode::ByValue => samples
282 .into_iter()
283 .sorted_by_key(|sample| (OrderedFloat(sample.value), sample.name.clone())) // name used as a tie-breaker.
284 .collect(),
285 StatSortMode::Unsorted => samples,
286 }
287 }
288
289 pub fn calculate(&self, frame: &Frame, settings: &StatsSettings) -> Vec<Sample> {
290 let mapping = match self {
291 FrameStatistic::LumaModes => self.luma_modes(frame),
292 FrameStatistic::ChromaModes => self.chroma_modes(frame),
293 FrameStatistic::BlockSizes => self.block_sizes(frame),
294 FrameStatistic::Symbols => self.symbols(frame),
295 FrameStatistic::PartitionSplit => self.partition_split(frame, settings),
296 };
297 self.apply_settings(mapping, settings)
298 }
299
300 pub fn name(&self) -> &'static str {
301 match self {
302 FrameStatistic::LumaModes => "Luma modes",
303 FrameStatistic::ChromaModes => "Chroma modes",
304 FrameStatistic::BlockSizes => "Block sizes",
305 FrameStatistic::Symbols => "Symbols",
306 FrameStatistic::PartitionSplit => "Partition split",
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a name → value mapping from string-literal pairs.
    fn mapping_of(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    // With exact matching, "A" equals neither "ABC" nor "AAA", so nothing is excluded.
    #[test]
    fn test_exclude_filter() {
        let mapping = mapping_of(&[("ABC", 1.0), ("AAA", 2.0)]);
        let settings = StatsSettings {
            exclude_filter: "A".into(),
            exclude_filter_exact_match: true,
            show_relative_total: true,
            ..Default::default()
        };
        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
        assert_eq!(samples.len(), 2);
    }

    // With substring matching, "A" is contained in both names, so both are excluded.
    #[test]
    fn test_exclude_filter_substring_match() {
        let mapping = mapping_of(&[("ABC", 1.0), ("AAA", 2.0)]);
        let settings = StatsSettings {
            exclude_filter: "A".into(),
            show_relative_total: true,
            ..Default::default()
        };
        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
        assert!(samples.is_empty());
    }

    // Without the relative total, excluded samples are summed into a single "Other" sample.
    #[test]
    fn test_excluded_samples_are_grouped_as_other() {
        let mapping = mapping_of(&[("ABC", 1.0), ("AAA", 2.0)]);
        let settings = StatsSettings {
            exclude_filter: "A".into(),
            ..Default::default()
        };
        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
        assert_eq!(samples.len(), 1);
        assert_eq!(samples[0].name, "Other");
        assert_eq!(samples[0].value, 3.0);
    }

    // The count limit keeps the largest samples; the default sort is ascending by value.
    #[test]
    fn test_limit_count_keeps_largest() {
        let mapping = mapping_of(&[("A", 1.0), ("B", 2.0), ("C", 3.0)]);
        let settings = StatsSettings {
            apply_limit_count: true,
            limit_count: 2,
            ..Default::default()
        };
        let samples = FrameStatistic::Symbols.apply_settings(mapping, &settings);
        let names: Vec<&str> = samples.iter().map(|sample| sample.name.as_str()).collect();
        let values: Vec<f64> = samples.iter().map(|sample| sample.value).collect();
        assert_eq!(names, vec!["Other", "B", "C"]);
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
}
331}