feat(audio): add OGG/Vorbis decoder, 24/32-bit WAV, Doppler effect

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 07:15:38 +09:00
parent 0ef750de69
commit 1c2a8466e7
3 changed files with 2039 additions and 14 deletions

View File

@@ -0,0 +1,294 @@
//! OGG container parser.
//!
//! Parses OGG bitstream pages and extracts Vorbis packets.
//! Reference: <https://www.xiph.org/ogg/doc/framing.html>
/// An OGG page header.
#[derive(Debug, Clone)]
pub struct OggPage {
/// Header type flags (0x01 = continuation, 0x02 = BOS, 0x04 = EOS).
pub header_type: u8,
/// Granule position (PCM sample position).
pub granule_position: u64,
/// Bitstream serial number.
pub serial: u32,
/// Page sequence number.
pub page_sequence: u32,
/// Number of segments in this page.
pub segment_count: u8,
/// The segment table (each entry is a segment length, 0..255).
pub segment_table: Vec<u8>,
/// Raw packet data of this page (concatenated segments).
pub data: Vec<u8>,
}
/// Parse all OGG pages from raw bytes.
pub fn parse_ogg_pages(data: &[u8]) -> Result<Vec<OggPage>, String> {
let mut pages = Vec::new();
let mut offset = 0;
while offset < data.len() {
if offset + 27 > data.len() {
break;
}
// Capture pattern "OggS"
if &data[offset..offset + 4] != b"OggS" {
return Err(format!("Invalid OGG capture pattern at offset {}", offset));
}
let version = data[offset + 4];
if version != 0 {
return Err(format!("Unsupported OGG version: {}", version));
}
let header_type = data[offset + 5];
let granule_position = u64::from_le_bytes([
data[offset + 6],
data[offset + 7],
data[offset + 8],
data[offset + 9],
data[offset + 10],
data[offset + 11],
data[offset + 12],
data[offset + 13],
]);
let serial = u32::from_le_bytes([
data[offset + 14],
data[offset + 15],
data[offset + 16],
data[offset + 17],
]);
let page_sequence = u32::from_le_bytes([
data[offset + 18],
data[offset + 19],
data[offset + 20],
data[offset + 21],
]);
// CRC at offset+22..+26 (skip verification for simplicity)
let segment_count = data[offset + 26] as usize;
if offset + 27 + segment_count > data.len() {
return Err("OGG page segment table extends beyond data".to_string());
}
let segment_table: Vec<u8> = data[offset + 27..offset + 27 + segment_count].to_vec();
let total_data_size: usize = segment_table.iter().map(|&s| s as usize).sum();
let data_start = offset + 27 + segment_count;
if data_start + total_data_size > data.len() {
return Err("OGG page data extends beyond file".to_string());
}
let page_data = data[data_start..data_start + total_data_size].to_vec();
pages.push(OggPage {
header_type,
granule_position,
serial,
page_sequence,
segment_count: segment_count as u8,
segment_table,
data: page_data,
});
offset = data_start + total_data_size;
}
if pages.is_empty() {
return Err("No OGG pages found".to_string());
}
Ok(pages)
}
/// Extract Vorbis packets from parsed OGG pages.
///
/// Packets can span multiple segments (segment length = 255 means continuation).
/// Packets can also span multiple pages (header_type bit 0x01 = continuation).
pub fn extract_packets(pages: &[OggPage]) -> Result<Vec<Vec<u8>>, String> {
let mut packets: Vec<Vec<u8>> = Vec::new();
let mut current_packet: Vec<u8> = Vec::new();
for page in pages {
let mut data_offset = 0;
for (seg_idx, &seg_len) in page.segment_table.iter().enumerate() {
let seg_data = &page.data[data_offset..data_offset + seg_len as usize];
current_packet.extend_from_slice(seg_data);
data_offset += seg_len as usize;
// A segment length < 255 terminates the current packet.
// A segment length of exactly 255 means the packet continues in the next segment.
if seg_len < 255 {
if !current_packet.is_empty() {
packets.push(std::mem::take(&mut current_packet));
}
}
// If seg_len == 255 and this is the last segment of the page,
// the packet continues on the next page.
let _ = seg_idx; // suppress unused warning
}
}
// If there's remaining data in current_packet (ended with 255-byte segments
// and no terminating segment), flush it as a final packet.
if !current_packet.is_empty() {
packets.push(current_packet);
}
Ok(packets)
}
/// Convenience function: parse OGG container and extract all Vorbis packets.
pub fn parse_ogg(data: &[u8]) -> Result<Vec<Vec<u8>>, String> {
let pages = parse_ogg_pages(data)?;
extract_packets(&pages)
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
/// Build a minimal OGG page from raw packet data.
fn build_ogg_page(
header_type: u8,
granule: u64,
serial: u32,
page_seq: u32,
packets_data: &[&[u8]],
) -> Vec<u8> {
// Build segment table and concatenated data
let mut segment_table = Vec::new();
let mut page_data = Vec::new();
for (i, packet) in packets_data.iter().enumerate() {
let len = packet.len();
// Write full 255-byte segments
let full_segments = len / 255;
let remainder = len % 255;
for _ in 0..full_segments {
segment_table.push(255u8);
}
// Terminating segment (< 255), even if 0 to signal end of packet
segment_table.push(remainder as u8);
page_data.extend_from_slice(packet);
}
let segment_count = segment_table.len();
let mut out = Vec::new();
// Capture pattern
out.extend_from_slice(b"OggS");
// Version
out.push(0);
// Header type
out.push(header_type);
// Granule position
out.extend_from_slice(&granule.to_le_bytes());
// Serial
out.extend_from_slice(&serial.to_le_bytes());
// Page sequence
out.extend_from_slice(&page_seq.to_le_bytes());
// CRC (dummy zeros)
out.extend_from_slice(&[0u8; 4]);
// Segment count
out.push(segment_count as u8);
// Segment table
out.extend_from_slice(&segment_table);
// Data
out.extend_from_slice(&page_data);
out
}
#[test]
fn parse_single_page() {
let packet = b"hello vorbis";
let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[packet.as_slice()]);
let pages = parse_ogg_pages(&page_bytes).expect("parse failed");
assert_eq!(pages.len(), 1);
assert_eq!(pages[0].header_type, 0x02);
assert_eq!(pages[0].serial, 1);
assert_eq!(pages[0].page_sequence, 0);
assert_eq!(pages[0].data, packet);
}
#[test]
fn parse_multiple_pages() {
let p1 = build_ogg_page(0x02, 0, 1, 0, &[b"first"]);
let p2 = build_ogg_page(0x00, 100, 1, 1, &[b"second"]);
let mut data = p1;
data.extend_from_slice(&p2);
let pages = parse_ogg_pages(&data).expect("parse failed");
assert_eq!(pages.len(), 2);
assert_eq!(pages[0].page_sequence, 0);
assert_eq!(pages[1].page_sequence, 1);
assert_eq!(pages[1].granule_position, 100);
}
#[test]
fn extract_single_packet() {
let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[b"packet_one"]);
let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
assert_eq!(packets.len(), 1);
assert_eq!(packets[0], b"packet_one");
}
#[test]
fn extract_multiple_packets_single_page() {
let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[b"pkt1", b"pkt2", b"pkt3"]);
let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
assert_eq!(packets.len(), 3);
assert_eq!(packets[0], b"pkt1");
assert_eq!(packets[1], b"pkt2");
assert_eq!(packets[2], b"pkt3");
}
#[test]
fn extract_large_packet_spanning_segments() {
// Create a packet larger than 255 bytes
let large_packet: Vec<u8> = (0..600).map(|i| (i % 256) as u8).collect();
let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[&large_packet]);
let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
assert_eq!(packets.len(), 1);
assert_eq!(packets[0], large_packet);
}
#[test]
fn invalid_capture_pattern() {
let data = b"NotOGGdata";
let result = parse_ogg_pages(data);
assert!(result.is_err());
assert!(result.unwrap_err().contains("capture pattern"));
}
#[test]
fn empty_data() {
let result = parse_ogg_pages(&[]);
assert!(result.is_err());
}
#[test]
fn page_header_fields() {
let page_bytes = build_ogg_page(0x04, 12345, 42, 7, &[b"data"]);
let pages = parse_ogg_pages(&page_bytes).expect("parse failed");
assert_eq!(pages[0].header_type, 0x04); // EOS
assert_eq!(pages[0].granule_position, 12345);
assert_eq!(pages[0].serial, 42);
assert_eq!(pages[0].page_sequence, 7);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -62,7 +62,42 @@ fn find_chunk(data: &[u8], id: &[u8; 4], start: usize) -> Option<(usize, u32)> {
// Public API // Public API
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/// Parse a PCM 16-bit WAV file from raw bytes into an [`AudioClip`]. /// Read a 24-bit signed integer (little-endian) from 3 bytes and return as i32.
fn read_i24_le(data: &[u8], offset: usize) -> Result<i32, String> {
if offset + 3 > data.len() {
return Err(format!("read_i24_le: offset {} out of bounds (len={})", offset, data.len()));
}
let lo = data[offset] as u32;
let mid = data[offset + 1] as u32;
let hi = data[offset + 2] as u32;
let unsigned = lo | (mid << 8) | (hi << 16);
// Sign-extend from 24-bit to 32-bit
if unsigned & 0x800000 != 0 {
Ok((unsigned | 0xFF000000) as i32)
} else {
Ok(unsigned as i32)
}
}
/// Read a 32-bit float (little-endian).
fn read_f32_le(data: &[u8], offset: usize) -> Result<f32, String> {
if offset + 4 > data.len() {
return Err(format!("read_f32_le: offset {} out of bounds (len={})", offset, data.len()));
}
Ok(f32::from_le_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]))
}
/// Parse a WAV file from raw bytes into an [`AudioClip`].
///
/// Supported formats:
/// - PCM 16-bit (format_tag=1, bits_per_sample=16)
/// - PCM 24-bit (format_tag=1, bits_per_sample=24)
/// - IEEE float 32-bit (format_tag=3, bits_per_sample=32)
pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> { pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
// Minimum viable WAV: RIFF(4) + size(4) + WAVE(4) = 12 bytes // Minimum viable WAV: RIFF(4) + size(4) + WAVE(4) = 12 bytes
if data.len() < 12 { if data.len() < 12 {
@@ -86,10 +121,6 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
} }
let format_tag = read_u16_le(data, fmt_offset)?; let format_tag = read_u16_le(data, fmt_offset)?;
if format_tag != 1 {
return Err(format!("Unsupported WAV format tag: {} (only PCM=1 is supported)", format_tag));
}
let channels = read_u16_le(data, fmt_offset + 2)?; let channels = read_u16_le(data, fmt_offset + 2)?;
if channels != 1 && channels != 2 { if channels != 1 && channels != 2 {
return Err(format!("Unsupported channel count: {}", channels)); return Err(format!("Unsupported channel count: {}", channels));
@@ -99,9 +130,16 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
// byte_rate = fmt_offset + 8 (skip) // byte_rate = fmt_offset + 8 (skip)
// block_align = fmt_offset + 12 (skip) // block_align = fmt_offset + 12 (skip)
let bits_per_sample = read_u16_le(data, fmt_offset + 14)?; let bits_per_sample = read_u16_le(data, fmt_offset + 14)?;
if bits_per_sample != 16 {
return Err(format!("Unsupported bits per sample: {} (only 16-bit is supported)", bits_per_sample)); // Validate format_tag + bits_per_sample combination
} let bytes_per_sample = match (format_tag, bits_per_sample) {
(1, 16) => 2, // PCM 16-bit
(1, 24) => 3, // PCM 24-bit
(3, 32) => 4, // IEEE float 32-bit
(1, bps) => return Err(format!("Unsupported PCM bits per sample: {}", bps)),
(3, bps) => return Err(format!("Unsupported float bits per sample: {} (only 32-bit supported)", bps)),
(tag, _) => return Err(format!("Unsupported WAV format tag: {} (only PCM=1 and IEEE_FLOAT=3 are supported)", tag)),
};
// --- data chunk --- // --- data chunk ---
let (data_offset, data_size) = let (data_offset, data_size) =
@@ -112,15 +150,31 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
return Err("data chunk extends beyond end of file".to_string()); return Err("data chunk extends beyond end of file".to_string());
} }
// Each sample is 2 bytes (16-bit PCM). let sample_count = data_size as usize / bytes_per_sample;
let sample_count = data_size as usize / 2;
let mut samples = Vec::with_capacity(sample_count); let mut samples = Vec::with_capacity(sample_count);
match (format_tag, bits_per_sample) {
(1, 16) => {
for i in 0..sample_count { for i in 0..sample_count {
let raw = read_i16_le(data, data_offset + i * 2)?; let raw = read_i16_le(data, data_offset + i * 2)?;
// Convert i16 [-32768, 32767] to f32 [-1.0, ~1.0]
samples.push(raw as f32 / 32768.0); samples.push(raw as f32 / 32768.0);
} }
}
(1, 24) => {
for i in 0..sample_count {
let raw = read_i24_le(data, data_offset + i * 3)?;
// 24-bit range: [-8388608, 8388607]
samples.push(raw as f32 / 8388608.0);
}
}
(3, 32) => {
for i in 0..sample_count {
let raw = read_f32_le(data, data_offset + i * 4)?;
samples.push(raw);
}
}
_ => unreachable!(),
}
Ok(AudioClip::new(samples, sample_rate, channels)) Ok(AudioClip::new(samples, sample_rate, channels))
} }
@@ -165,6 +219,87 @@ pub fn generate_wav_bytes(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
out out
} }
/// Generate a minimal PCM 24-bit mono WAV file from f32 samples.
/// Used for round-trip testing.
pub fn generate_wav_bytes_24bit(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
let channels: u16 = 1;
let bits_per_sample: u16 = 24;
let byte_rate = sample_rate * channels as u32 * bits_per_sample as u32 / 8;
let block_align: u16 = channels * bits_per_sample / 8;
let data_size = (samples_f32.len() * 3) as u32;
let riff_size = 4 + 8 + 16 + 8 + data_size;
let mut out: Vec<u8> = Vec::with_capacity(12 + 8 + 16 + 8 + data_size as usize);
// RIFF header
out.extend_from_slice(b"RIFF");
out.extend_from_slice(&riff_size.to_le_bytes());
out.extend_from_slice(b"WAVE");
// fmt chunk
out.extend_from_slice(b"fmt ");
out.extend_from_slice(&16u32.to_le_bytes());
out.extend_from_slice(&1u16.to_le_bytes()); // PCM
out.extend_from_slice(&channels.to_le_bytes());
out.extend_from_slice(&sample_rate.to_le_bytes());
out.extend_from_slice(&byte_rate.to_le_bytes());
out.extend_from_slice(&block_align.to_le_bytes());
out.extend_from_slice(&bits_per_sample.to_le_bytes());
// data chunk
out.extend_from_slice(b"data");
out.extend_from_slice(&data_size.to_le_bytes());
for &s in samples_f32 {
let clamped = s.clamp(-1.0, 1.0);
let raw = (clamped * 8388607.0) as i32;
// Write 3 bytes LE
out.push((raw & 0xFF) as u8);
out.push(((raw >> 8) & 0xFF) as u8);
out.push(((raw >> 16) & 0xFF) as u8);
}
out
}
/// Generate a minimal IEEE float 32-bit mono WAV file from f32 samples.
/// Used for round-trip testing.
pub fn generate_wav_bytes_f32(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
let channels: u16 = 1;
let bits_per_sample: u16 = 32;
let byte_rate = sample_rate * channels as u32 * bits_per_sample as u32 / 8;
let block_align: u16 = channels * bits_per_sample / 8;
let data_size = (samples_f32.len() * 4) as u32;
let riff_size = 4 + 8 + 16 + 8 + data_size;
let mut out: Vec<u8> = Vec::with_capacity(12 + 8 + 16 + 8 + data_size as usize);
// RIFF header
out.extend_from_slice(b"RIFF");
out.extend_from_slice(&riff_size.to_le_bytes());
out.extend_from_slice(b"WAVE");
// fmt chunk
out.extend_from_slice(b"fmt ");
out.extend_from_slice(&16u32.to_le_bytes());
out.extend_from_slice(&3u16.to_le_bytes()); // IEEE_FLOAT
out.extend_from_slice(&channels.to_le_bytes());
out.extend_from_slice(&sample_rate.to_le_bytes());
out.extend_from_slice(&byte_rate.to_le_bytes());
out.extend_from_slice(&block_align.to_le_bytes());
out.extend_from_slice(&bits_per_sample.to_le_bytes());
// data chunk
out.extend_from_slice(b"data");
out.extend_from_slice(&data_size.to_le_bytes());
for &s in samples_f32 {
out.extend_from_slice(&s.to_le_bytes());
}
out
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Tests // Tests
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -237,4 +372,107 @@ mod tests {
); );
} }
} }
// -----------------------------------------------------------------------
// 24-bit PCM tests
// -----------------------------------------------------------------------
#[test]
fn parse_24bit_wav() {
let sample_rate = 44100u32;
let num_samples = 4410usize;
let samples: Vec<f32> = (0..num_samples)
.map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
.collect();
let wav_bytes = generate_wav_bytes_24bit(&samples, sample_rate);
let clip = parse_wav(&wav_bytes).expect("parse_wav 24-bit failed");
assert_eq!(clip.sample_rate, sample_rate);
assert_eq!(clip.channels, 1);
assert_eq!(clip.frame_count(), num_samples);
}
#[test]
fn roundtrip_24bit() {
let original: Vec<f32> = vec![0.0, 0.25, 0.5, -0.25, -0.5, 1.0, -1.0];
let wav_bytes = generate_wav_bytes_24bit(&original, 44100);
let clip = parse_wav(&wav_bytes).expect("roundtrip 24-bit parse failed");
assert_eq!(clip.samples.len(), original.len());
for (orig, decoded) in original.iter().zip(clip.samples.iter()) {
// 24-bit quantization error should be < 0.0001
assert!(
(orig - decoded).abs() < 0.0001,
"24-bit: orig={} decoded={}",
orig,
decoded
);
}
}
#[test]
fn accuracy_24bit() {
// 24-bit should be more accurate than 16-bit
let samples = vec![0.5f32];
let wav_bytes = generate_wav_bytes_24bit(&samples, 44100);
let clip = parse_wav(&wav_bytes).expect("parse failed");
// 0.5 * 8388607 = 4194303 -> 4194303 / 8388608 ≈ 0.49999988
assert!((clip.samples[0] - 0.5).abs() < 0.0001, "24-bit got {}", clip.samples[0]);
}
// -----------------------------------------------------------------------
// 32-bit float tests
// -----------------------------------------------------------------------
#[test]
fn parse_32bit_float_wav() {
let sample_rate = 44100u32;
let num_samples = 4410usize;
let samples: Vec<f32> = (0..num_samples)
.map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
.collect();
let wav_bytes = generate_wav_bytes_f32(&samples, sample_rate);
let clip = parse_wav(&wav_bytes).expect("parse_wav float32 failed");
assert_eq!(clip.sample_rate, sample_rate);
assert_eq!(clip.channels, 1);
assert_eq!(clip.frame_count(), num_samples);
}
#[test]
fn roundtrip_32bit_float() {
let original: Vec<f32> = vec![0.0, 0.25, 0.5, -0.25, -0.5, 1.0, -1.0];
let wav_bytes = generate_wav_bytes_f32(&original, 44100);
let clip = parse_wav(&wav_bytes).expect("roundtrip float32 parse failed");
assert_eq!(clip.samples.len(), original.len());
for (orig, decoded) in original.iter().zip(clip.samples.iter()) {
// 32-bit float should be exact
assert_eq!(*orig, *decoded, "float32: orig={} decoded={}", orig, decoded);
}
}
#[test]
fn accuracy_32bit_float() {
// 32-bit float should preserve exact values
let samples = vec![0.123456789f32, -0.987654321f32];
let wav_bytes = generate_wav_bytes_f32(&samples, 44100);
let clip = parse_wav(&wav_bytes).expect("parse failed");
assert_eq!(clip.samples[0], 0.123456789f32);
assert_eq!(clip.samples[1], -0.987654321f32);
}
#[test]
fn reject_unsupported_format_tag() {
// Create a WAV with format_tag=2 (ADPCM), which we don't support
let mut wav = generate_wav_bytes(&[0.0], 44100);
// format_tag is at byte 20-21 (RIFF(4)+size(4)+WAVE(4)+fmt(4)+chunk_size(4))
wav[20] = 2;
wav[21] = 0;
let result = parse_wav(&wav);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Unsupported WAV format tag"));
}
} }