feat(renderer): add TAA with Halton jitter and neighborhood clamping
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,7 @@ pub mod auto_exposure;
|
|||||||
pub mod instancing;
|
pub mod instancing;
|
||||||
pub mod bilateral_blur;
|
pub mod bilateral_blur;
|
||||||
pub mod temporal_accum;
|
pub mod temporal_accum;
|
||||||
|
pub mod taa;
|
||||||
|
|
||||||
pub use gpu::{GpuContext, DEPTH_FORMAT};
|
pub use gpu::{GpuContext, DEPTH_FORMAT};
|
||||||
pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
|
pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
|
||||||
@@ -77,6 +78,7 @@ pub use auto_exposure::AutoExposure;
|
|||||||
pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
|
pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
|
||||||
pub use bilateral_blur::BilateralBlur;
|
pub use bilateral_blur::BilateralBlur;
|
||||||
pub use temporal_accum::TemporalAccumulation;
|
pub use temporal_accum::TemporalAccumulation;
|
||||||
|
pub use taa::Taa;
|
||||||
pub use png::parse_png;
|
pub use png::parse_png;
|
||||||
pub use jpg::parse_jpg;
|
pub use jpg::parse_jpg;
|
||||||
pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};
|
pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};
|
||||||
|
|||||||
179
crates/voltex_renderer/src/taa.rs
Normal file
179
crates/voltex_renderer/src/taa.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
use bytemuck::{Pod, Zeroable};
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Copy, Clone, Pod, Zeroable)]
|
||||||
|
struct TaaParams {
|
||||||
|
jitter: [f32; 2],
|
||||||
|
prev_jitter: [f32; 2],
|
||||||
|
blend_factor: f32,
|
||||||
|
_pad: [f32; 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Taa {
|
||||||
|
pipeline: wgpu::ComputePipeline,
|
||||||
|
bind_group_layout: wgpu::BindGroupLayout,
|
||||||
|
params_buffer: wgpu::Buffer,
|
||||||
|
jitter_index: usize,
|
||||||
|
prev_jitter: [f32; 2],
|
||||||
|
pub blend_factor: f32,
|
||||||
|
pub enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Halton sequence for jitter offsets (base 2 and 3).
|
||||||
|
pub fn halton(index: usize, base: usize) -> f32 {
|
||||||
|
let mut result = 0.0_f32;
|
||||||
|
let mut f = 1.0 / base as f32;
|
||||||
|
let mut i = index;
|
||||||
|
while i > 0 {
|
||||||
|
result += f * (i % base) as f32;
|
||||||
|
i /= base;
|
||||||
|
f /= base as f32;
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate jitter offset for frame N (in pixel units, centered around 0).
|
||||||
|
pub fn jitter_offset(frame: usize) -> [f32; 2] {
|
||||||
|
let idx = (frame % 16) + 1; // avoid index 0
|
||||||
|
[halton(idx, 2) - 0.5, halton(idx, 3) - 0.5]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Taa {
|
||||||
|
pub fn new(device: &wgpu::Device) -> Self {
|
||||||
|
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||||
|
label: Some("TAA Compute"),
|
||||||
|
source: wgpu::ShaderSource::Wgsl(include_str!("taa.wgsl").into()),
|
||||||
|
});
|
||||||
|
|
||||||
|
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||||
|
label: Some("TAA BGL"),
|
||||||
|
entries: &[
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 0, visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 1, visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 2, visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2 },
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 3, visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None },
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||||
|
label: Some("TAA PL"), bind_group_layouts: &[&bind_group_layout], immediate_size: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||||
|
label: Some("TAA Pipeline"), layout: Some(&pipeline_layout),
|
||||||
|
module: &shader, entry_point: Some("main"),
|
||||||
|
compilation_options: wgpu::PipelineCompilationOptions::default(), cache: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("TAA Params"), size: std::mem::size_of::<TaaParams>() as u64,
|
||||||
|
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
Taa {
|
||||||
|
pipeline, bind_group_layout, params_buffer,
|
||||||
|
jitter_index: 0, prev_jitter: [0.0; 2],
|
||||||
|
blend_factor: 0.1, enabled: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current jitter and advance frame index.
|
||||||
|
pub fn next_jitter(&mut self) -> [f32; 2] {
|
||||||
|
self.prev_jitter = jitter_offset(self.jitter_index);
|
||||||
|
self.jitter_index += 1;
|
||||||
|
let current = jitter_offset(self.jitter_index);
|
||||||
|
current
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dispatch(
|
||||||
|
&self,
|
||||||
|
device: &wgpu::Device,
|
||||||
|
queue: &wgpu::Queue,
|
||||||
|
encoder: &mut wgpu::CommandEncoder,
|
||||||
|
current_view: &wgpu::TextureView,
|
||||||
|
history_view: &wgpu::TextureView,
|
||||||
|
output_view: &wgpu::TextureView,
|
||||||
|
jitter: [f32; 2],
|
||||||
|
width: u32,
|
||||||
|
height: u32,
|
||||||
|
) {
|
||||||
|
let params = TaaParams {
|
||||||
|
jitter: [jitter[0] / width as f32, jitter[1] / height as f32],
|
||||||
|
prev_jitter: [self.prev_jitter[0] / width as f32, self.prev_jitter[1] / height as f32],
|
||||||
|
blend_factor: self.blend_factor,
|
||||||
|
_pad: [0.0; 3],
|
||||||
|
};
|
||||||
|
queue.write_buffer(&self.params_buffer, 0, bytemuck::cast_slice(&[params]));
|
||||||
|
|
||||||
|
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||||
|
label: Some("TAA BG"), layout: &self.bind_group_layout,
|
||||||
|
entries: &[
|
||||||
|
wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(current_view) },
|
||||||
|
wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(history_view) },
|
||||||
|
wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(output_view) },
|
||||||
|
wgpu::BindGroupEntry { binding: 3, resource: self.params_buffer.as_entire_binding() },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("TAA Pass"), timestamp_writes: None });
|
||||||
|
cpass.set_pipeline(&self.pipeline);
|
||||||
|
cpass.set_bind_group(0, &bind_group, &[]);
|
||||||
|
cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_halton_base2() {
|
||||||
|
assert!((halton(1, 2) - 0.5).abs() < 1e-6);
|
||||||
|
assert!((halton(2, 2) - 0.25).abs() < 1e-6);
|
||||||
|
assert!((halton(3, 2) - 0.75).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_halton_base3() {
|
||||||
|
assert!((halton(1, 3) - 1.0/3.0).abs() < 1e-6);
|
||||||
|
assert!((halton(2, 3) - 2.0/3.0).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jitter_offset_centered() {
|
||||||
|
for i in 0..16 {
|
||||||
|
let j = jitter_offset(i);
|
||||||
|
assert!(j[0] >= -0.5 && j[0] <= 0.5, "jitter x out of range: {}", j[0]);
|
||||||
|
assert!(j[1] >= -0.5 && j[1] <= 0.5, "jitter y out of range: {}", j[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jitter_varies() {
|
||||||
|
let j0 = jitter_offset(0);
|
||||||
|
let j1 = jitter_offset(1);
|
||||||
|
assert!(j0 != j1, "consecutive jitters should differ");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_halton_zero() {
|
||||||
|
assert!((halton(0, 2) - 0.0).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
52
crates/voltex_renderer/src/taa.wgsl
Normal file
52
crates/voltex_renderer/src/taa.wgsl
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
struct TaaParams {
|
||||||
|
jitter: vec2<f32>, // current frame jitter offset
|
||||||
|
prev_jitter: vec2<f32>, // previous frame jitter offset
|
||||||
|
blend_factor: f32, // 0.1 (10% new, 90% history)
|
||||||
|
_pad: vec3<f32>,
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0) var current_tex: texture_2d<f32>;
|
||||||
|
@group(0) @binding(1) var history_tex: texture_2d<f32>;
|
||||||
|
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba16float, write>;
|
||||||
|
@group(0) @binding(3) var<uniform> params: TaaParams;
|
||||||
|
|
||||||
|
@compute @workgroup_size(16, 16)
|
||||||
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
let dims = textureDimensions(current_tex);
|
||||||
|
if (gid.x >= dims.x || gid.y >= dims.y) { return; }
|
||||||
|
|
||||||
|
let pos = vec2<i32>(gid.xy);
|
||||||
|
let current = textureLoad(current_tex, pos, 0);
|
||||||
|
|
||||||
|
// Neighborhood color clamping (3x3 AABB)
|
||||||
|
var color_min = vec4<f32>(999.0);
|
||||||
|
var color_max = vec4<f32>(-999.0);
|
||||||
|
for (var dy = -1; dy <= 1; dy++) {
|
||||||
|
for (var dx = -1; dx <= 1; dx++) {
|
||||||
|
let np = pos + vec2<i32>(dx, dy);
|
||||||
|
let nc = textureLoad(current_tex, clamp(np, vec2<i32>(0), vec2<i32>(dims) - 1), 0);
|
||||||
|
color_min = min(color_min, nc);
|
||||||
|
color_max = max(color_max, nc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reprojection: unjitter to find history position
|
||||||
|
let uv = (vec2<f32>(gid.xy) + 0.5) / vec2<f32>(dims);
|
||||||
|
let history_uv = uv + params.prev_jitter - params.jitter;
|
||||||
|
let history_pos = vec2<i32>(history_uv * vec2<f32>(dims));
|
||||||
|
|
||||||
|
var history: vec4<f32>;
|
||||||
|
if (history_pos.x >= 0 && history_pos.y >= 0 &&
|
||||||
|
history_pos.x < i32(dims.x) && history_pos.y < i32(dims.y)) {
|
||||||
|
history = textureLoad(history_tex, history_pos, 0);
|
||||||
|
} else {
|
||||||
|
history = current;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamp history to neighborhood
|
||||||
|
let clamped_history = clamp(history, color_min, color_max);
|
||||||
|
|
||||||
|
// Blend
|
||||||
|
let result = mix(clamped_history, current, params.blend_factor);
|
||||||
|
textureStore(output_tex, pos, result);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user