feat(renderer): add BLAS/TLAS acceleration structure management for RT

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-25 13:14:06 +09:00
parent 71045d8603
commit e2424bf8c9

View File

@@ -0,0 +1,195 @@
use crate::vertex::MeshVertex;
/// Data needed to build a BLAS for one mesh.
///
/// Holds borrowed GPU buffers plus the element counts that `RtAccel::new`
/// forwards to the BLAS size descriptor and build entry for this mesh.
pub struct BlasMeshData<'a> {
/// Vertex buffer of the mesh. `RtAccel::new` reads it from vertex 0 with a
/// stride of `size_of::<MeshVertex>()` and declares the position format as
/// `Float32x3`.
pub vertex_buffer: &'a wgpu::Buffer,
/// Index buffer of the mesh. `RtAccel::new` declares it as `Uint32` indices
/// and reads it from index 0.
pub index_buffer: &'a wgpu::Buffer,
/// Number of vertices reported in the BLAS size descriptor.
pub vertex_count: u32,
/// Number of indices reported in the BLAS size descriptor.
pub index_count: u32,
}
/// One instance fed to the TLAS: a world transform and a BLAS index.
///
/// This is plain data (a fixed-size float array and an index), so it derives
/// the usual value-type traits: instances can be copied freely, compared in
/// tests, and printed while debugging.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RtInstance {
    /// Column-major 4x4 transform matrix.
    ///
    /// Only the upper three rows are consumed when the matrix is converted to
    /// the TLAS's 3x4 row-major layout; the bottom row is ignored.
    pub transform: [f32; 16],
    /// Index into `RtAccel::blas_list`.
    ///
    /// Must be in range for that list; it is used to index the list directly.
    pub blas_index: usize,
}
/// Bottom + Top Level Acceleration Structures for a scene.
///
/// Created by `RtAccel::new`, which records the GPU builds for every BLAS and
/// for the TLAS into a caller-supplied command encoder.
pub struct RtAccel {
/// One BLAS per mesh, in the same order as the `meshes` slice passed to
/// `RtAccel::new`. `RtInstance::blas_index` indexes into this list.
pub blas_list: Vec<wgpu::Blas>,
/// The scene TLAS. Its instance slots reference entries of `blas_list`.
pub tlas: wgpu::Tlas,
}
impl RtAccel {
    /// Create a BLAS for each mesh and build a TLAS with the given instances.
    ///
    /// One BLAS is created per entry of `meshes` (same order), a TLAS with room
    /// for `instances.len()` instances (at least one slot) is created, and the
    /// builds for all of them are recorded into `encoder`.
    ///
    /// The encoder must be submitted after this call so the GPU builds fire.
    ///
    /// # Panics
    ///
    /// Panics if an instance's `blas_index` is out of range for `meshes`.
    pub fn new(
        device: &wgpu::Device,
        encoder: &mut wgpu::CommandEncoder,
        meshes: &[BlasMeshData<'_>],
        instances: &[RtInstance],
    ) -> Self {
        // Positions are read as `Float32x3` starting at vertex 0 with a stride
        // of one whole `MeshVertex`.
        // NOTE(review): this assumes `MeshVertex` starts with its position —
        // confirm against the vertex definition.
        let vertex_stride = std::mem::size_of::<MeshVertex>() as u64;

        // ── Describe and create one BLAS per mesh ────────────────────────────
        let size_descs: Vec<wgpu::BlasTriangleGeometrySizeDescriptor> = meshes
            .iter()
            .map(|m| wgpu::BlasTriangleGeometrySizeDescriptor {
                vertex_format: wgpu::VertexFormat::Float32x3,
                vertex_count: m.vertex_count,
                index_format: Some(wgpu::IndexFormat::Uint32),
                index_count: Some(m.index_count),
                flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
            })
            .collect();

        // Creating a BLAS only needs the size descriptor, not the mesh buffers.
        let blas_list: Vec<wgpu::Blas> = size_descs
            .iter()
            .map(|size_desc| {
                device.create_blas(
                    &wgpu::CreateBlasDescriptor {
                        label: Some("Mesh BLAS"),
                        flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
                        update_mode: wgpu::AccelerationStructureUpdateMode::Build,
                    },
                    wgpu::BlasGeometrySizeDescriptors::Triangles {
                        descriptors: vec![size_desc.clone()],
                    },
                )
            })
            .collect();

        // ── Create TLAS ──────────────────────────────────────────────────────
        // Always allocate at least one slot so an empty scene still gets a TLAS.
        // Clamp before narrowing so an oversized slice cannot silently wrap
        // around to a tiny capacity.
        let max_instances = instances.len().max(1).min(u32::MAX as usize) as u32;
        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
            label: Some("Scene TLAS"),
            max_instances,
            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
        });

        // ── Populate TLAS instances ──────────────────────────────────────────
        Self::write_instances(&mut tlas, &blas_list, instances);

        // ── Record the builds ────────────────────────────────────────────────
        let blas_entries: Vec<wgpu::BlasBuildEntry<'_>> = meshes
            .iter()
            .zip(blas_list.iter())
            .zip(size_descs.iter())
            .map(|((mesh, blas), size_desc)| wgpu::BlasBuildEntry {
                blas,
                geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
                    wgpu::BlasTriangleGeometry {
                        size: size_desc,
                        vertex_buffer: mesh.vertex_buffer,
                        first_vertex: 0,
                        vertex_stride,
                        index_buffer: Some(mesh.index_buffer),
                        first_index: Some(0),
                        transform_buffer: None,
                        transform_buffer_offset: None,
                    },
                ]),
            })
            .collect();
        encoder.build_acceleration_structures(blas_entries.iter(), std::iter::once(&tlas));

        Self { blas_list, tlas }
    }

    /// Replace the TLAS instance set and record a TLAS rebuild into `encoder`.
    ///
    /// Slot `i` of the TLAS receives `instances[i]`. The TLAS capacity is fixed
    /// when it is created, so instances beyond that capacity are ignored.
    /// Slots past the end of `instances` are cleared, so instances written by
    /// an earlier call do not linger when the new list is shorter.
    ///
    /// No BLAS is rebuilt; only the TLAS build is recorded.
    ///
    /// # Panics
    ///
    /// Panics if a written instance's `blas_index` is out of range for
    /// `self.blas_list`.
    pub fn update_instances(
        &mut self,
        encoder: &mut wgpu::CommandEncoder,
        instances: &[RtInstance],
    ) {
        Self::write_instances(&mut self.tlas, &self.blas_list, instances);
        encoder.build_acceleration_structures(std::iter::empty(), std::iter::once(&self.tlas));
    }

    /// Write `instances` into the leading TLAS slots and clear every remaining slot.
    ///
    /// Every instance is written with custom value `0` and mask `0xFF`.
    fn write_instances(tlas: &mut wgpu::Tlas, blas_list: &[wgpu::Blas], instances: &[RtInstance]) {
        let capacity = tlas.get().len();

        // Instances that do not fit in the TLAS are dropped rather than panicking.
        for (slot, inst) in instances.iter().take(capacity).enumerate() {
            let blas = &blas_list[inst.blas_index];
            tlas[slot] = Some(wgpu::TlasInstance::new(
                blas,
                mat4_to_tlas_transform(&inst.transform),
                0,
                0xFF,
            ));
        }

        // Clear the tail so stale instances from a previous, longer list are
        // not built into the TLAS again.
        for slot in instances.len().min(capacity)..capacity {
            tlas[slot] = None;
        }
    }
}
/// Convert a column-major 4×4 matrix into the row-major 3×4 affine transform
/// used for TLAS instances.
///
/// The input stores its columns contiguously (`col0 = m[0..4]`,
/// `col1 = m[4..8]`, ...). The output stores its three rows contiguously, four
/// values each, so the result is the transpose of the input with the bottom
/// row (`m[3]`, `m[7]`, `m[11]`, `m[15]`) dropped:
///
/// ```text
/// row 0: m[0], m[4], m[8],  m[12]
/// row 1: m[1], m[5], m[9],  m[13]
/// row 2: m[2], m[6], m[10], m[14]
/// ```
pub fn mat4_to_tlas_transform(m: &[f32; 16]) -> [f32; 12] {
    std::array::from_fn(|out_idx| {
        // Output slot `out_idx` is element (row, col) of the 3×4 result; in
        // column-major storage that element lives at `col * 4 + row`.
        let (row, col) = (out_idx / 4, out_idx % 4);
        m[col * 4 + row]
    })
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// The identity matrix maps to the identity 3×4 transform.
    #[test]
    fn test_mat4_to_tlas_transform_identity() {
        #[rustfmt::skip]
        let identity: [f32; 16] = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        let result = mat4_to_tlas_transform(&identity);
        // Row 0: [1,0,0,0], Row 1: [0,1,0,0], Row 2: [0,0,1,0]
        let expected: [f32; 12] = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
        ];
        assert_eq!(result, expected);
    }

    /// The translation stored in column 3 ends up in the last column of each row.
    #[test]
    fn test_mat4_to_tlas_transform_translation() {
        // Column-major: identity rotation + translation (tx=1, ty=2, tz=3)
        // Column 3 = [tx, ty, tz, 1] stored at indices [12,13,14,15]
        #[rustfmt::skip]
        let mat: [f32; 16] = [
            1.0, 0.0, 0.0, 0.0, // col 0
            0.0, 1.0, 0.0, 0.0, // col 1
            0.0, 0.0, 1.0, 0.0, // col 2
            1.0, 2.0, 3.0, 1.0, // col 3 (translation)
        ];
        let result = mat4_to_tlas_transform(&mat);
        // Row-major 3x4: rotation part identity, translation in last column
        let expected: [f32; 12] = [
            1.0, 0.0, 0.0, 1.0,
            0.0, 1.0, 0.0, 2.0,
            0.0, 0.0, 1.0, 3.0,
        ];
        assert_eq!(result, expected);
    }

    /// The linear 3×3 part is transposed and the bottom row is dropped.
    ///
    /// The other tests use an identity linear part, which is symmetric, so
    /// they would still pass if rows and columns were not swapped.
    #[test]
    fn test_mat4_to_tlas_transform_transposes_linear_part() {
        // Each element's value equals its storage index, so every output slot
        // identifies exactly which input element it was read from.
        #[rustfmt::skip]
        let mat: [f32; 16] = [
             0.0,  1.0,  2.0,  3.0, // col 0
             4.0,  5.0,  6.0,  7.0, // col 1
             8.0,  9.0, 10.0, 11.0, // col 2
            12.0, 13.0, 14.0, 15.0, // col 3
        ];
        let result = mat4_to_tlas_transform(&mat);
        #[rustfmt::skip]
        let expected: [f32; 12] = [
            0.0, 4.0,  8.0, 12.0,
            1.0, 5.0,  9.0, 13.0,
            2.0, 6.0, 10.0, 14.0,
        ];
        assert_eq!(result, expected);
    }
}