diff --git a/crates/voltex_renderer/src/lib.rs b/crates/voltex_renderer/src/lib.rs
index b3e8228..1cdfbd8 100644
--- a/crates/voltex_renderer/src/lib.rs
+++ b/crates/voltex_renderer/src/lib.rs
@@ -37,6 +37,7 @@ pub mod auto_exposure;
 pub mod instancing;
 pub mod bilateral_blur;
 pub mod temporal_accum;
+pub mod taa;
 
 pub use gpu::{GpuContext, DEPTH_FORMAT};
 pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
@@ -77,6 +78,7 @@ pub use auto_exposure::AutoExposure;
 pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
 pub use bilateral_blur::BilateralBlur;
 pub use temporal_accum::TemporalAccumulation;
+pub use taa::Taa;
 pub use png::parse_png;
 pub use jpg::parse_jpg;
 pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};
diff --git a/crates/voltex_renderer/src/taa.rs b/crates/voltex_renderer/src/taa.rs
new file mode 100644
index 0000000..ffeba05
--- /dev/null
+++ b/crates/voltex_renderer/src/taa.rs
@@ -0,0 +1,221 @@
+//! Temporal anti-aliasing (TAA): sub-pixel jitter generation plus a compute
+//! pass that blends the current frame with a history buffer.
+
+use bytemuck::{Pod, Zeroable};
+
+/// Uniform block consumed by `taa.wgsl`.
+///
+/// `repr(C)` makes this exactly 32 bytes. The WGSL `TaaParams` declares its
+/// padding as three scalar `f32`s so that both sides agree on the size.
+#[repr(C)]
+#[derive(Copy, Clone, Pod, Zeroable)]
+struct TaaParams {
+    jitter: [f32; 2],
+    prev_jitter: [f32; 2],
+    blend_factor: f32,
+    _pad: [f32; 3],
+}
+
+/// Compute-based TAA resolve pass together with its jitter state.
+pub struct Taa {
+    pipeline: wgpu::ComputePipeline,
+    bind_group_layout: wgpu::BindGroupLayout,
+    params_buffer: wgpu::Buffer,
+    jitter_index: usize,
+    prev_jitter: [f32; 2],
+    /// Weight of the current frame in the final blend (`0.1` = 10% new).
+    pub blend_factor: f32,
+    /// Toggle for callers; `dispatch` does not consult it.
+    pub enabled: bool,
+}
+
+/// Halton sequence for jitter offsets (base 2 and 3).
+///
+/// Returns the radical inverse of `index` in `base`; `index == 0` yields `0.0`.
+///
+/// # Panics
+///
+/// Panics if `base < 2`: base 1 would never terminate and base 0 would divide
+/// by zero.
+pub fn halton(index: usize, base: usize) -> f32 {
+    assert!(base >= 2, "halton base must be >= 2");
+    let mut result = 0.0_f32;
+    let mut f = 1.0 / base as f32;
+    let mut i = index;
+    while i > 0 {
+        result += f * (i % base) as f32;
+        i /= base;
+        f /= base as f32;
+    }
+    result
+}
+
+/// Generate jitter offset for frame N (in pixel units, centered around 0).
+///
+/// Cycles through 16 Halton(2, 3) samples, so each component stays within
+/// half a pixel of the pixel centre.
+pub fn jitter_offset(frame: usize) -> [f32; 2] {
+    let idx = (frame % 16) + 1; // avoid index 0
+    [halton(idx, 2) - 0.5, halton(idx, 3) - 0.5]
+}
+
+impl Taa {
+    /// Builds the TAA compute pipeline, its bind group layout and the uniform
+    /// buffer. The pass starts enabled with a blend factor of `0.1`.
+    pub fn new(device: &wgpu::Device) -> Self {
+        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: Some("TAA Compute"),
+            source: wgpu::ShaderSource::Wgsl(include_str!("taa.wgsl").into()),
+        });
+
+        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+            label: Some("TAA BGL"),
+            entries: &[
+                wgpu::BindGroupLayoutEntry {
+                    binding: 0, visibility: wgpu::ShaderStages::COMPUTE,
+                    ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
+                    count: None,
+                },
+                wgpu::BindGroupLayoutEntry {
+                    binding: 1, visibility: wgpu::ShaderStages::COMPUTE,
+                    ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
+                    count: None,
+                },
+                wgpu::BindGroupLayoutEntry {
+                    binding: 2, visibility: wgpu::ShaderStages::COMPUTE,
+                    ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2 },
+                    count: None,
+                },
+                wgpu::BindGroupLayoutEntry {
+                    binding: 3, visibility: wgpu::ShaderStages::COMPUTE,
+                    ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None },
+                    count: None,
+                },
+            ],
+        });
+
+        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+            label: Some("TAA PL"), bind_group_layouts: &[&bind_group_layout], immediate_size: 0,
+        });
+
+        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+            label: Some("TAA Pipeline"), layout: Some(&pipeline_layout),
+            module: &shader, entry_point: Some("main"),
+            compilation_options: wgpu::PipelineCompilationOptions::default(), cache: None,
+        });
+
+        let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("TAA Params"), size: std::mem::size_of::<TaaParams>() as u64,
+            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+
+        Taa {
+            pipeline, bind_group_layout, params_buffer,
+            jitter_index: 0, prev_jitter: [0.0; 2],
+            blend_factor: 0.1, enabled: true,
+        }
+    }
+
+    /// Get current jitter and advance frame index.
+    ///
+    /// The jitter that was current before this call becomes the previous jitter
+    /// uploaded by the next `dispatch`.
+    pub fn next_jitter(&mut self) -> [f32; 2] {
+        self.prev_jitter = jitter_offset(self.jitter_index);
+        self.jitter_index += 1;
+        jitter_offset(self.jitter_index)
+    }
+
+    /// Records the TAA resolve pass into `encoder`.
+    ///
+    /// `jitter` is the current frame's jitter in pixels; it and the stored
+    /// previous jitter are divided by `width`/`height` before upload. One
+    /// 16x16 workgroup is dispatched per tile of the `width` x `height`
+    /// target. Does nothing when either dimension is zero.
+    pub fn dispatch(
+        &self,
+        device: &wgpu::Device,
+        queue: &wgpu::Queue,
+        encoder: &mut wgpu::CommandEncoder,
+        current_view: &wgpu::TextureView,
+        history_view: &wgpu::TextureView,
+        output_view: &wgpu::TextureView,
+        jitter: [f32; 2],
+        width: u32,
+        height: u32,
+    ) {
+        // A zero-sized target has no pixels, and dividing the jitter by zero
+        // would upload non-finite parameters.
+        if width == 0 || height == 0 {
+            return;
+        }
+
+        let params = TaaParams {
+            jitter: [jitter[0] / width as f32, jitter[1] / height as f32],
+            prev_jitter: [self.prev_jitter[0] / width as f32, self.prev_jitter[1] / height as f32],
+            blend_factor: self.blend_factor,
+            _pad: [0.0; 3],
+        };
+        queue.write_buffer(&self.params_buffer, 0, bytemuck::bytes_of(&params));
+
+        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
+            label: Some("TAA BG"), layout: &self.bind_group_layout,
+            entries: &[
+                wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(current_view) },
+                wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(history_view) },
+                wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(output_view) },
+                wgpu::BindGroupEntry { binding: 3, resource: self.params_buffer.as_entire_binding() },
+            ],
+        });
+
+        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("TAA Pass"), timestamp_writes: None });
+        cpass.set_pipeline(&self.pipeline);
+        cpass.set_bind_group(0, &bind_group, &[]);
+        cpass.dispatch_workgroups(width.div_ceil(16), height.div_ceil(16), 1);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_halton_base2() {
+        assert!((halton(1, 2) - 0.5).abs() < 1e-6);
+        assert!((halton(2, 2) - 0.25).abs() < 1e-6);
+        assert!((halton(3, 2) - 0.75).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_halton_base3() {
+        assert!((halton(1, 3) - 1.0/3.0).abs() < 1e-6);
+        assert!((halton(2, 3) - 2.0/3.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_jitter_offset_centered() {
+        for i in 0..16 {
+            let j = jitter_offset(i);
+            assert!(j[0] >= -0.5 && j[0] <= 0.5, "jitter x out of range: {}", j[0]);
+            assert!(j[1] >= -0.5 && j[1] <= 0.5, "jitter y out of range: {}", j[1]);
+        }
+    }
+
+    #[test]
+    fn test_jitter_varies() {
+        let j0 = jitter_offset(0);
+        let j1 = jitter_offset(1);
+        assert!(j0 != j1, "consecutive jitters should differ");
+    }
+
+    #[test]
+    fn test_halton_zero() {
+        assert!((halton(0, 2) - 0.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_params_size_matches_shader() {
+        assert_eq!(std::mem::size_of::<TaaParams>(), 32);
+    }
+}
diff --git a/crates/voltex_renderer/src/taa.wgsl b/crates/voltex_renderer/src/taa.wgsl
new file mode 100644
index 0000000..96e78aa
--- /dev/null
+++ b/crates/voltex_renderer/src/taa.wgsl
@@ -0,0 +1,60 @@
+struct TaaParams {
+    jitter: vec2<f32>,       // current frame jitter offset
+    prev_jitter: vec2<f32>,  // previous frame jitter offset
+    blend_factor: f32,       // 0.1 (10% new, 90% history)
+    // Scalar padding keeps the struct at 32 bytes to match the Rust side;
+    // a vec3<f32> is 16-byte aligned and would grow it to 48 bytes.
+    _pad0: f32,
+    _pad1: f32,
+    _pad2: f32,
+};
+
+@group(0) @binding(0) var current_tex: texture_2d<f32>;
+@group(0) @binding(1) var history_tex: texture_2d<f32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba16float, write>;
+@group(0) @binding(3) var<uniform> params: TaaParams;
+
+@compute @workgroup_size(16, 16)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    let dims = textureDimensions(current_tex);
+    if (gid.x >= dims.x || gid.y >= dims.y) { return; }
+
+    let pos = vec2<i32>(gid.xy);
+    let current = textureLoad(current_tex, pos, 0);
+
+    // Neighborhood color clamping (3x3 AABB).
+    // Seeding from the centre texel (always part of the 3x3 window) keeps the
+    // bounds correct for HDR values beyond any fixed sentinel.
+    var color_min = current;
+    var color_max = current;
+    for (var dy = -1; dy <= 1; dy++) {
+        for (var dx = -1; dx <= 1; dx++) {
+            let np = pos + vec2<i32>(dx, dy);
+            let nc = textureLoad(current_tex, clamp(np, vec2<i32>(0), vec2<i32>(dims) - 1), 0);
+            color_min = min(color_min, nc);
+            color_max = max(color_max, nc);
+        }
+    }
+
+    // Reprojection: unjitter to find history position
+    let uv = (vec2<f32>(gid.xy) + 0.5) / vec2<f32>(dims);
+    let history_uv = uv + params.prev_jitter - params.jitter;
+    // floor() rather than a truncating cast, so slightly negative coordinates
+    // become -1 and are rejected below instead of snapping to texel 0.
+    let history_pos = vec2<i32>(floor(history_uv * vec2<f32>(dims)));
+
+    var history: vec4<f32>;
+    if (history_pos.x >= 0 && history_pos.y >= 0 &&
+        history_pos.x < i32(dims.x) && history_pos.y < i32(dims.y)) {
+        history = textureLoad(history_tex, history_pos, 0);
+    } else {
+        history = current;
+    }
+
+    // Clamp history to neighborhood
+    let clamped_history = clamp(history, color_min, color_max);
+
+    // Blend
+    let result = mix(clamped_history, current, params.blend_factor);
+    textureStore(output_tex, pos, result);
+}