blob: 8959e2a5d6c8ede738513fd6439e61380f3deacf [file] [log] [blame]
use crate::{
CodingUnit, Frame, FrameError, PartitionContext, Plane, PredictionParams, ProtoEnumMapping, SuperblockContext,
SuperblockLocator, SymbolContext, SymbolRange, TransformUnitContext, TransformUnitLocator,
};
use serde::{Deserialize, Serialize};
use FrameError::BadCodingUnit;
impl CodingUnit {
pub fn plane_index(&self, plane: Plane) -> Result<usize, FrameError> {
match plane {
Plane::Y => Ok(0),
Plane::U | Plane::V => {
match self.transform_planes.len() {
// Split luma and chroma partition trees
2 => Ok(plane.to_usize() - 1),
// Unified luma and chroma partition tree
3 => Ok(plane.to_usize()),
_ => Err(BadCodingUnit(format!(
"Unexpected number of transform planes: got {}, expected 2 or 3 for plane {plane:?}",
self.transform_planes.len()
))),
}
}
}
}
pub fn has_chroma(&self) -> Result<bool, FrameError> {
let num_transform_planes = self.transform_planes.len();
match num_transform_planes {
2 | 3 => Ok(true),
1 => Ok(false),
_ => Err(BadCodingUnit(format!(
"Unexpected number of transform planes: {num_transform_planes}"
))),
}
}
pub fn has_luma(&self) -> Result<bool, FrameError> {
let num_transform_planes = self.transform_planes.len();
match num_transform_planes {
1 | 3 => Ok(true),
2 => Ok(false),
_ => Err(BadCodingUnit(format!(
"Unexpected number of transform planes: {num_transform_planes}"
))),
}
}
pub fn get_prediction_mode(&self) -> Result<&PredictionParams, FrameError> {
self.prediction_mode
.as_ref()
.ok_or(BadCodingUnit("Missing prediction mode.".into()))
}
pub fn get_symbol_range(&self) -> Result<&SymbolRange, FrameError> {
self.symbol_range
.as_ref()
.ok_or(BadCodingUnit("Missing symbol range.".into()))
}
pub fn lookup_mode_name(&self, frame: &Frame) -> Result<String, FrameError> {
let mode = self.get_prediction_mode()?;
frame.enum_lookup(ProtoEnumMapping::PredictionMode, mode.mode)
}
pub fn luma_mode_angle_delta(&self, frame: &Frame) -> Option<i32> {
if let Ok(mode) = self.lookup_mode_name(frame) {
if mode.ends_with("_PRED") {
return self.prediction_mode.as_ref().map(|mode| mode.angle_delta);
}
}
None
}
pub fn lookup_uv_mode_name(&self, frame: &Frame) -> Result<String, FrameError> {
let mode = self.get_prediction_mode()?;
frame.enum_lookup(ProtoEnumMapping::UvPredictionMode, mode.uv_mode)
}
pub fn chroma_mode_angle_delta(&self, frame: &Frame) -> Option<i32> {
if let Ok(mode) = self.lookup_uv_mode_name(frame) {
if mode.ends_with("_PRED") {
return self.prediction_mode.as_ref().map(|mode| mode.uv_angle_delta);
}
}
None
}
pub fn lookup_motion_vector_precision_name(&self, frame: &Frame) -> Result<String, FrameError> {
let mode = self.get_prediction_mode()?;
frame.enum_lookup(ProtoEnumMapping::MotionVectorPrecision, mode.motion_vector_precision)
}
}
/// Which planes this coding unit contains.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum CodingUnitKind {
/// Coding unit contains all three planes. Equivalent to the shared partition tree type.
Shared,
/// Coding unit contains only luma.
LumaOnly,
/// Coding unit contains only chroma.
ChromaOnly,
}
/// Index of a coding unit within its parent superblock.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct CodingUnitLocator {
/// Index of parent superblock within the frame.
pub superblock: SuperblockLocator,
/// Either Shared or ChromaOnly. Note that LumaOnly refers to the same underlying buffer as Shared.
pub kind: CodingUnitKind,
/// Index of coding unit with its parent superblock.
pub index: usize,
}
impl CodingUnitLocator {
pub fn new(superblock: SuperblockLocator, kind: CodingUnitKind, index: usize) -> Self {
Self {
superblock,
kind,
index,
}
}
/// Convert this index into a `CodingUnitContext`.
pub fn try_resolve<'a>(&self, frame: &'a Frame) -> Option<CodingUnitContext<'a>> {
let superblock_context = self.superblock.try_resolve(frame)?;
let coding_unit = match self.kind {
CodingUnitKind::Shared => superblock_context.superblock.coding_units_shared.get(self.index),
CodingUnitKind::LumaOnly => superblock_context.superblock.coding_units_shared.get(self.index),
CodingUnitKind::ChromaOnly => superblock_context.superblock.coding_units_chroma.get(self.index),
};
coding_unit.map(|coding_unit| CodingUnitContext {
coding_unit,
superblock_context,
locator: *self,
})
}
pub fn resolve<'a>(&self, frame: &'a Frame) -> CodingUnitContext<'a> {
self.try_resolve(frame).unwrap()
}
}
/// Context about a coding unit during iteration.
#[derive(Copy, Clone)]
pub struct CodingUnitContext<'a> {
/// Coding unit being iterated over.
pub coding_unit: &'a CodingUnit,
/// Superblock that owns this coding unit.
pub superblock_context: SuperblockContext<'a>,
/// The index of this coding unit within its parent superblock.
pub locator: CodingUnitLocator,
}
impl<'a> CodingUnitContext<'a> {
pub fn iter_symbols(&self) -> impl Iterator<Item = SymbolContext<'a>> {
let symbol_range = self.coding_unit.symbol_range.clone().unwrap_or_default();
self.superblock_context.iter_symbols(Some(symbol_range))
}
pub fn total_bits(&self) -> f32 {
self.iter_symbols().map(|sym| sym.symbol.bits).sum()
}
pub fn iter_transform_units(self, plane: Plane) -> impl Iterator<Item = TransformUnitContext<'a>> {
let plane_index = self.coding_unit.plane_index(plane);
self.coding_unit.transform_planes[plane_index.unwrap()]
.transform_units
.iter()
.enumerate()
.map(move |(index, transform_unit)| TransformUnitContext {
transform_unit,
coding_unit_context: self,
locator: TransformUnitLocator::new(self.locator, plane, index),
})
}
pub fn find_parent_partition(&self) -> Option<PartitionContext<'a>> {
self.superblock_context
.root_partition(self.locator.kind)
.and_then(|root| root.find_coding_unit_parent(self.coding_unit))
}
}