blob: 42ee7d05861ec6dbf4b618bdc3addcddd9e91c89 [file] [log] [blame]
#!/usr/bin/env python3
## Copyright (c) 2023, Alliance for Open Media. All rights reserved
##
## This source code is subject to the terms of the BSD 3-Clause Clear License and the
## Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear License was
## not distributed with this source code in the LICENSE file, you can obtain it
## at aomedia.org/license/software-license/bsd-3-c-c/. If the Alliance for Open Media Patent
## License 1.0 was not distributed with this source code in the PATENTS file, you
## can obtain it at aomedia.org/license/patent-license/.
##
"""Helper functions for working with YUV files from Codec ML Colab notebooks."""
import dataclasses
import numpy as np
# Frame-header line that precedes each frame's sample data in a .y4m file
# (the form with no per-frame parameters). The frame's planar sample bytes
# begin immediately after this marker.
_Y4M_FRAME_MARKER = b"FRAME\n"
# TODO(comc): More sophisticated colorspace handling.
def yuv_to_rgb(
    y: np.ndarray, u: np.ndarray, v: np.ndarray, bit_depth: int = 8
) -> np.ndarray:
    """Converts individual Y, U, V planes into a single RGB array.

    matplotlib can't show YUVs directly, so this helper function lets us
    visualize all three planes at once with pyplot.imshow(). Note that this
    function does not compensate for chroma subsampling. If necessary, the
    chroma planes should be upscaled beforehand.

    Args:
        y: (H, W) shaped luma plane.
        u: (H, W) shaped chroma U plane.
        v: (H, W) shaped chroma V plane.
        bit_depth: image bit-depth (8 or 10 bits per pixel).

    Returns:
        (H, W, 3) shaped RGB image. dtype is uint8 for bit_depth=8, and uint16
        otherwise.

    Raises:
        ValueError: If bit_depth is not 8 or 10.
    """
    if bit_depth not in (8, 10):
        raise ValueError(
            f"Only 8-bit and 10-bit YUVs are supported: bit_depth={bit_depth}."
        )
    y_max = (1 << bit_depth) - 1
    # Neutral chroma is the mid code value (128 for 8-bit, 512 for 10-bit).
    # Subtracting it (rather than y_max / 2) maps neutral chroma to exactly 0,
    # so gray input pixels stay gray.
    chroma_mid = 1 << (bit_depth - 1)
    # Scale luma to [0, 1] and chroma to roughly [-0.5, 0.5].
    y = y.astype(np.float32) / y_max
    u = (u.astype(np.float32) - chroma_mid) / y_max
    v = (v.astype(np.float32) - chroma_mid) / y_max
    # RGB conversion matrix from:
    # https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.470
    # It is possible to end up with values outside the range [0.0, 1.0] after
    # this transformation, so clamp them.
    r = np.clip(y + 1.13983 * v, 0, 1)
    g = np.clip(y - 0.39465 * u - 0.58060 * v, 0, 1)
    b = np.clip(y + 2.03211 * u, 0, 1)
    # Round to the nearest code value; a plain astype() would truncate, so
    # float error such as 127.99998 would lose a whole code value.
    rgb = np.rint(np.stack((r, g, b), axis=2) * y_max)
    dtype = np.uint8 if bit_depth == 8 else np.uint16
    return rgb.astype(dtype)
def upscale(plane: np.ndarray, factor: int = 2) -> np.ndarray:
    """Blow up a 2d array by repeating each element into a factor x factor tile.

    Handy for undoing chroma subsampling before visualizing a frame. For
    instance, `upscale(np.array([[1, 2], [3, 4]]), 2)` gives:
    ```
    array([[1, 1, 2, 2],
           [1, 1, 2, 2],
           [3, 3, 4, 4],
           [3, 3, 4, 4]])
    ```

    Args:
        plane: (H, W) shaped ndarray.
        factor: how many times each row and each column is repeated.

    Returns:
        (H * factor, W * factor) shaped ndarray.
    """
    # Duplicate rows first, then duplicate columns of the taller result.
    taller = plane.repeat(factor, axis=0)
    return taller.repeat(factor, axis=1)
def _plane_size_420(width: int, height: int, *, is_chroma: bool) -> int:
"""Calculates how many samples make up a single plane.
Note that this is specifically for 4:2:0 chroma subsampled planes, i.e. the U
and V chroma planes will have 1/4th the samples of the luma plane.
Args:
width: Width of the plane in pixels.
height: Height of the plane in pixels.
is_chroma: If True, scale dimensions by a factor of 2 to compensate for
chroma subsampling.
Returns:
Number of samples that make up this plane.
"""
if is_chroma:
chroma_width = (width + 1) // 2
chroma_height = (height + 1) // 2
return chroma_width * chroma_height
return width * height
def _frame_size_420(width: int, height: int) -> int:
    """Returns the total number of samples in one 4:2:0 frame.

    Args:
        width: Width of the image in pixels.
        height: Height of the image in pixels.

    Returns:
        Sum of the sample counts of the luma plane and both chroma planes.
    """
    # Planar layout: one full-size luma plane, then the U and V chroma planes.
    plane_is_chroma = (False, True, True)
    return sum(
        _plane_size_420(width, height, is_chroma=chroma)
        for chroma in plane_is_chroma
    )
class Yuv420:
    """Represents a single YUV 4:2:0 frame.

    Attributes:
        width: Frame width in pixels.
        height: Frame height in pixels.
        bit_depth: Bits per sample, either 8 or 10.
        y: (height, width) numpy array storing luma plane data.
        u: (ceil(height / 2), ceil(width / 2)) numpy array storing chroma U
            plane data.
        v: numpy array storing chroma V plane data, same shape as `u`.
        rgb: (height, width, 3) numpy array of this frame converted to RGB.
    """

    def __init__(
        self,
        raw: np.ndarray,
        width: int,
        height: int,
        bit_depth: int = 8,
        offset: int = 0,
    ):
        """Slices the Y, U and V planes out of a flat sample array.

        Args:
            raw: 1-D array of samples holding a planar 4:2:0 frame (Y plane,
                then U plane, then V plane) starting at `offset`.
            width: Frame width in pixels.
            height: Frame height in pixels.
            bit_depth: Bits per sample, either 8 or 10.
            offset: Index, in samples, of the first luma sample within `raw`.

        Raises:
            ValueError: If bit_depth is unsupported, or `raw` does not hold a
                complete frame starting at `offset`.
        """
        self.width = width
        self.height = height
        if bit_depth not in (8, 10):
            raise ValueError(
                f"Only 8-bit and 10-bit YUVs are supported: bit_depth={bit_depth}."
            )
        self.bit_depth = bit_depth
        chroma_width = (width + 1) // 2
        chroma_height = (height + 1) // 2
        luma_size = _plane_size_420(width, height, is_chroma=False)
        chroma_size = _plane_size_420(width, height, is_chroma=True)
        u_offset = offset + luma_size
        v_offset = u_offset + chroma_size
        frame_end = v_offset + chroma_size
        if len(raw) < frame_end:
            raise ValueError(
                f"A {width}x{height} 4:2:0 frame at offset {offset} needs "
                f"{frame_end} samples, but raw only has {len(raw)}."
            )
        self.y = raw[offset:u_offset].reshape((height, width))
        self.u = raw[u_offset:v_offset].reshape((chroma_height, chroma_width))
        self.v = raw[v_offset:frame_end].reshape((chroma_height, chroma_width))
        # Doubling a ceil(n / 2)-sized chroma plane yields one extra row or
        # column when n is odd; crop so the chroma lines up with the luma plane
        # (otherwise yuv_to_rgb cannot broadcast the planes together).
        uu = upscale(self.u, 2)[:height, :width]
        vv = upscale(self.v, 2)[:height, :width]
        self.rgb = yuv_to_rgb(self.y, uu, vv, bit_depth)
@dataclasses.dataclass()
class YuvSequence:
    """An ordered collection of decoded YUV frames.

    Attributes:
        yuvs: The decoded frames, in the order they were read.
    """

    yuvs: list[Yuv420]
def parse_raw_yuv(
    yuv_path: str, width: int, height: int, num_frames: int, bit_depth: int = 8
) -> YuvSequence:
    """Parse a raw (headerless, planar) YUV 4:2:0 file.

    Args:
        yuv_path: Local file path to .yuv file.
        width: Width of YUV in pixels.
        height: Height of YUV in pixels.
        num_frames: Number of frames to read from YUV file.
        bit_depth: Bits per sample, either 8 or 10.

    Returns:
        YuvSequence of length num_frames.

    Raises:
        ValueError: If bit_depth is not 8 or 10, or the file ends before
            num_frames complete frames have been read.
    """
    if bit_depth not in (8, 10):
        raise ValueError(
            f"Only 8-bit and 10-bit YUVs are supported: bit_depth={bit_depth}."
        )
    # Samples wider than 8 bits occupy two bytes each. Read them as explicit
    # little-endian words (the usual *p10le layout) so the result does not
    # depend on the host's byte order.
    dtype = np.dtype(np.uint8) if bit_depth == 8 else np.dtype("<u2")
    frame_size = _frame_size_420(width, height)
    frame_size_bytes = frame_size * dtype.itemsize
    yuvs = []
    with open(yuv_path, "rb") as f:
        for index in range(num_frames):
            raw_bytes = f.read(frame_size_bytes)
            # A short read means the file holds fewer frames than requested.
            if len(raw_bytes) != frame_size_bytes:
                raise ValueError(
                    f"{yuv_path}: frame {index} is truncated "
                    f"({len(raw_bytes)} of {frame_size_bytes} bytes); "
                    f"requested {num_frames} frames of {width}x{height}."
                )
            raw = np.frombuffer(raw_bytes, dtype=dtype)
            yuvs.append(Yuv420(raw, width, height, bit_depth))
    return YuvSequence(yuvs)
def parse_y4m(
    y4m_path: str, width: int, height: int, num_frames: int, bit_depth: int = 8
) -> YuvSequence:
    """Parse a 4:2:0 .y4m file.

    The frame size and bit depth come from the arguments; the stream header's
    own parameters are not parsed. Frames are read sequentially, one at a
    time, so the whole file is never held in memory.

    Args:
        y4m_path: Local file path to .y4m file.
        width: Width of YUV in pixels.
        height: Height of YUV in pixels.
        num_frames: Number of frames to read from YUV file.
        bit_depth: Bits per sample, either 8 or 10.

    Returns:
        YuvSequence of length num_frames.

    Raises:
        ValueError: If bit_depth is not 8 or 10, a frame header is missing or
            malformed (e.g. because width/height do not match the file), or
            the file ends before num_frames complete frames have been read.
    """
    if bit_depth not in (8, 10):
        raise ValueError(
            f"Only 8-bit and 10-bit YUVs are supported: bit_depth={bit_depth}."
        )
    # >8-bit samples occupy two bytes each; read them as explicit
    # little-endian words so the result does not depend on host byte order.
    dtype = np.dtype(np.uint8) if bit_depth == 8 else np.dtype("<u2")
    frame_size = _frame_size_420(width, height)
    frame_size_bytes = frame_size * dtype.itemsize
    yuvs = []
    with open(y4m_path, "rb") as f:
        # Skip the stream header, a single newline-terminated text line.
        f.readline()
        for index in range(num_frames):
            # Each frame starts with a "FRAME" line, optionally carrying
            # space-separated parameters before the newline. Reading it at
            # the exact position right after the previous frame's payload
            # avoids mistaking pixel bytes for a marker.
            frame_header = f.readline()
            is_frame_header = frame_header == _Y4M_FRAME_MARKER or (
                frame_header.startswith(b"FRAME ")
                and frame_header.endswith(b"\n")
            )
            if not is_frame_header:
                raise ValueError(
                    f"{y4m_path}: expected a FRAME header for frame {index}, "
                    f"got {frame_header[:32]!r}."
                )
            raw_bytes = f.read(frame_size_bytes)
            if len(raw_bytes) != frame_size_bytes:
                raise ValueError(
                    f"{y4m_path}: frame {index} is truncated "
                    f"({len(raw_bytes)} of {frame_size_bytes} bytes)."
                )
            raw = np.frombuffer(raw_bytes, dtype=dtype)
            yuvs.append(Yuv420(raw, width, height, bit_depth))
    return YuvSequence(yuvs)