use crate::vertex::MeshVertex; /// Data needed to build a BLAS for one mesh. pub struct BlasMeshData<'a> { pub vertex_buffer: &'a wgpu::Buffer, pub index_buffer: &'a wgpu::Buffer, pub vertex_count: u32, pub index_count: u32, } /// One instance transform fed to the TLAS: a world transform and a BLAS index. pub struct RtInstance { /// Column-major 4x4 transform matrix. pub transform: [f32; 16], /// Index into `RtAccel::blas_list`. pub blas_index: usize, } /// Bottom + Top Level Acceleration Structures for a scene. pub struct RtAccel { pub blas_list: Vec, pub tlas: wgpu::Tlas, } impl RtAccel { /// Create BLAS for each mesh and build a TLAS with the given instances. /// /// The encoder must be submitted after this call so the GPU builds fire. pub fn new( device: &wgpu::Device, encoder: &mut wgpu::CommandEncoder, meshes: &[BlasMeshData<'_>], instances: &[RtInstance], ) -> Self { let vertex_stride = std::mem::size_of::() as u64; // ── Build one BLAS per mesh ─────────────────────────────────────────── let size_descs: Vec = meshes .iter() .map(|m| wgpu::BlasTriangleGeometrySizeDescriptor { vertex_format: wgpu::VertexFormat::Float32x3, vertex_count: m.vertex_count, index_format: Some(wgpu::IndexFormat::Uint32), index_count: Some(m.index_count), flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE, }) .collect(); let blas_list: Vec = meshes .iter() .zip(size_descs.iter()) .map(|(_mesh, size_desc)| { device.create_blas( &wgpu::CreateBlasDescriptor { label: Some("Mesh BLAS"), flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE, update_mode: wgpu::AccelerationStructureUpdateMode::Build, }, wgpu::BlasGeometrySizeDescriptors::Triangles { descriptors: vec![size_desc.clone()], }, ) }) .collect(); // ── Create TLAS ─────────────────────────────────────────────────────── let max_instances = instances.len().max(1) as u32; let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor { label: Some("Scene TLAS"), max_instances, flags: 
wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE, update_mode: wgpu::AccelerationStructureUpdateMode::Build, }); // ── Populate TLAS instances ─────────────────────────────────────────── for (i, inst) in instances.iter().enumerate() { let blas = &blas_list[inst.blas_index]; tlas[i] = Some(wgpu::TlasInstance::new( blas, mat4_to_tlas_transform(&inst.transform), 0, 0xFF, )); } // ── Build entries ───────────────────────────────────────────────────── let blas_entries: Vec> = meshes .iter() .zip(blas_list.iter()) .zip(size_descs.iter()) .map(|((mesh, blas), size_desc)| wgpu::BlasBuildEntry { blas, geometry: wgpu::BlasGeometries::TriangleGeometries(vec![ wgpu::BlasTriangleGeometry { size: size_desc, vertex_buffer: mesh.vertex_buffer, first_vertex: 0, vertex_stride, index_buffer: Some(mesh.index_buffer), first_index: Some(0), transform_buffer: None, transform_buffer_offset: None, }, ]), }) .collect(); encoder.build_acceleration_structures(blas_entries.iter(), std::iter::once(&tlas)); Self { blas_list, tlas } } /// Update TLAS instance transforms and rebuild. pub fn update_instances( &mut self, encoder: &mut wgpu::CommandEncoder, instances: &[RtInstance], ) { for (i, inst) in instances.iter().enumerate() { if i >= self.tlas.get().len() { break; } let blas = &self.blas_list[inst.blas_index]; self.tlas[i] = Some(wgpu::TlasInstance::new( blas, mat4_to_tlas_transform(&inst.transform), 0, 0xFF, )); } encoder.build_acceleration_structures(std::iter::empty(), std::iter::once(&self.tlas)); } } /// Convert a column-major 4×4 matrix to a row-major 3×4 affine transform. /// /// The TLAS expects `[f32; 12]` in row-major order (3 rows × 4 columns). /// A standard column-major mat4 stores: col0=[m0,m1,m2,m3], col1=[m4,m5,m6,m7], ... 
/// Each output row is therefore gathered across the four input columns:
///
/// - row 0 = `m[0], m[4], m[8], m[12]`
/// - row 1 = `m[1], m[5], m[9], m[13]`
/// - row 2 = `m[2], m[6], m[10], m[14]`
///
/// The bottom input row (`m[3], m[7], m[11], m[15]`) is not copied, so only the affine part
/// of `m` survives the conversion.
pub fn mat4_to_tlas_transform(m: &[f32; 16]) -> [f32; 12] {
    // Output slot `k` is (row = k / 4, col = k % 4); column-major input stores that
    // element at `col * 4 + row`.
    std::array::from_fn(|k| {
        let (row, col) = (k / 4, k % 4);
        m[col * 4 + row]
    })
}

// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Column-major 4×4 identity matrix.
    #[rustfmt::skip]
    const IDENTITY: [f32; 16] = [
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 0.0, 1.0,
    ];

    #[test]
    fn test_mat4_to_tlas_transform_identity() {
        // Rows of the 3×4 identity: [1,0,0,0], [0,1,0,0], [0,0,1,0].
        #[rustfmt::skip]
        let want: [f32; 12] = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
        ];
        assert_eq!(mat4_to_tlas_transform(&IDENTITY), want);
    }

    #[test]
    fn test_mat4_to_tlas_transform_translation() {
        // Identity rotation plus translation (1, 2, 3): in column-major storage the
        // translation occupies column 3, i.e. indices 12..15.
        let mut translated = IDENTITY;
        translated[12..15].copy_from_slice(&[1.0, 2.0, 3.0]);

        // Row-major 3×4: identity rotation, translation in the last column.
        #[rustfmt::skip]
        let want: [f32; 12] = [
            1.0, 0.0, 0.0, 1.0,
            0.0, 1.0, 0.0, 2.0,
            0.0, 0.0, 1.0, 3.0,
        ];
        assert_eq!(mat4_to_tlas_transform(&translated), want);
    }
}