diff --git a/crates/voltex_renderer/src/rt_accel.rs b/crates/voltex_renderer/src/rt_accel.rs new file mode 100644 index 0000000..7e63552 --- /dev/null +++ b/crates/voltex_renderer/src/rt_accel.rs @@ -0,0 +1,195 @@ +use crate::vertex::MeshVertex; + +/// Data needed to build a BLAS for one mesh. +pub struct BlasMeshData<'a> { + pub vertex_buffer: &'a wgpu::Buffer, + pub index_buffer: &'a wgpu::Buffer, + pub vertex_count: u32, + pub index_count: u32, +} + +/// One instance transform fed to the TLAS: a world transform and a BLAS index. +pub struct RtInstance { + /// Column-major 4x4 transform matrix. + pub transform: [f32; 16], + /// Index into `RtAccel::blas_list`. + pub blas_index: usize, +} + +/// Bottom + Top Level Acceleration Structures for a scene. +pub struct RtAccel { + pub blas_list: Vec, + pub tlas: wgpu::Tlas, +} + +impl RtAccel { + /// Create BLAS for each mesh and build a TLAS with the given instances. + /// + /// The encoder must be submitted after this call so the GPU builds fire. + pub fn new( + device: &wgpu::Device, + encoder: &mut wgpu::CommandEncoder, + meshes: &[BlasMeshData<'_>], + instances: &[RtInstance], + ) -> Self { + let vertex_stride = std::mem::size_of::() as u64; + + // ── Build one BLAS per mesh ─────────────────────────────────────────── + let size_descs: Vec = meshes + .iter() + .map(|m| wgpu::BlasTriangleGeometrySizeDescriptor { + vertex_format: wgpu::VertexFormat::Float32x3, + vertex_count: m.vertex_count, + index_format: Some(wgpu::IndexFormat::Uint32), + index_count: Some(m.index_count), + flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE, + }) + .collect(); + + let blas_list: Vec = meshes + .iter() + .zip(size_descs.iter()) + .map(|(_mesh, size_desc)| { + device.create_blas( + &wgpu::CreateBlasDescriptor { + label: Some("Mesh BLAS"), + flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE, + update_mode: wgpu::AccelerationStructureUpdateMode::Build, + }, + wgpu::BlasGeometrySizeDescriptors::Triangles { + descriptors: vec![size_desc.clone()], + }, + ) + }) + .collect(); + + // ── Create TLAS ─────────────────────────────────────────────────────── + let max_instances = instances.len().max(1) as u32; + let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor { + label: Some("Scene TLAS"), + max_instances, + flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE, + update_mode: wgpu::AccelerationStructureUpdateMode::Build, + }); + + // ── Populate TLAS instances ─────────────────────────────────────────── + for (i, inst) in instances.iter().enumerate() { + let blas = &blas_list[inst.blas_index]; + tlas[i] = Some(wgpu::TlasInstance::new( + blas, + mat4_to_tlas_transform(&inst.transform), + 0, + 0xFF, + )); + } + + // ── Build entries ───────────────────────────────────────────────────── + let blas_entries: Vec> = meshes + .iter() + .zip(blas_list.iter()) + .zip(size_descs.iter()) + .map(|((mesh, blas), size_desc)| wgpu::BlasBuildEntry { + blas, + geometry: wgpu::BlasGeometries::TriangleGeometries(vec![ + wgpu::BlasTriangleGeometry { + size: size_desc, + vertex_buffer: mesh.vertex_buffer, + first_vertex: 0, + vertex_stride, + index_buffer: Some(mesh.index_buffer), + first_index: Some(0), + transform_buffer: None, + transform_buffer_offset: None, + }, + ]), + }) + .collect(); + + encoder.build_acceleration_structures(blas_entries.iter(), std::iter::once(&tlas)); + + Self { blas_list, tlas } + } + + /// Update TLAS instance transforms and rebuild. + pub fn update_instances( + &mut self, + encoder: &mut wgpu::CommandEncoder, + instances: &[RtInstance], + ) { + for (i, inst) in instances.iter().enumerate() { + if i >= self.tlas.get().len() { + break; + } + let blas = &self.blas_list[inst.blas_index]; + self.tlas[i] = Some(wgpu::TlasInstance::new( + blas, + mat4_to_tlas_transform(&inst.transform), + 0, + 0xFF, + )); + } + encoder.build_acceleration_structures(std::iter::empty(), std::iter::once(&self.tlas)); + } +} + +/// Convert a column-major 4×4 matrix to a row-major 3×4 affine transform. +/// +/// The TLAS expects `[f32; 12]` in row-major order (3 rows × 4 columns). +/// A standard column-major mat4 stores: col0=[m0,m1,m2,m3], col1=[m4,m5,m6,m7], ... +/// Row 0: m[0], m[4], m[8], m[12] +/// Row 1: m[1], m[5], m[9], m[13] +/// Row 2: m[2], m[6], m[10], m[14] +pub fn mat4_to_tlas_transform(m: &[f32; 16]) -> [f32; 12] { + [ + m[0], m[4], m[8], m[12], + m[1], m[5], m[9], m[13], + m[2], m[6], m[10], m[14], + ] +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mat4_to_tlas_transform_identity() { + #[rustfmt::skip] + let identity: [f32; 16] = [ + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, + ]; + let result = mat4_to_tlas_transform(&identity); + // Row 0: [1,0,0,0], Row 1: [0,1,0,0], Row 2: [0,0,1,0] + let expected: [f32; 12] = [ + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + ]; + assert_eq!(result, expected); + } + + #[test] + fn test_mat4_to_tlas_transform_translation() { + // Column-major: identity rotation + translation (tx=1, ty=2, tz=3) + // Column 3 = [tx, ty, tz, 1] stored at indices [12,13,14,15] + #[rustfmt::skip] + let mat: [f32; 16] = [ + 1.0, 0.0, 0.0, 0.0, // col 0 + 0.0, 1.0, 0.0, 0.0, // col 1 + 0.0, 0.0, 1.0, 0.0, // col 2 + 1.0, 2.0, 3.0, 1.0, // col 3 (translation) + ]; + let result = mat4_to_tlas_transform(&mat); + // Row-major 3x4: rotation part identity, translation in last column + let expected: [f32; 12] = [ + 1.0, 0.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 2.0, + 0.0, 0.0, 1.0, 3.0, + ]; + assert_eq!(result, expected); + } +}