feat(renderer): add TAA with Halton jitter and neighborhood clamping

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 15:40:53 +09:00
parent d321c0695c
commit 41c7f9607e
3 changed files with 233 additions and 0 deletions

View File

@@ -37,6 +37,7 @@ pub mod auto_exposure;
pub mod instancing;
pub mod bilateral_blur;
pub mod temporal_accum;
pub mod taa;
pub use gpu::{GpuContext, DEPTH_FORMAT};
pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
@@ -77,6 +78,7 @@ pub use auto_exposure::AutoExposure;
pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
pub use bilateral_blur::BilateralBlur;
pub use temporal_accum::TemporalAccumulation;
pub use taa::Taa;
pub use png::parse_png;
pub use jpg::parse_jpg;
pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};

View File

@@ -0,0 +1,179 @@
use bytemuck::{Pod, Zeroable};
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct TaaParams {
jitter: [f32; 2],
prev_jitter: [f32; 2],
blend_factor: f32,
_pad: [f32; 3],
}
pub struct Taa {
pipeline: wgpu::ComputePipeline,
bind_group_layout: wgpu::BindGroupLayout,
params_buffer: wgpu::Buffer,
jitter_index: usize,
prev_jitter: [f32; 2],
pub blend_factor: f32,
pub enabled: bool,
}
/// Halton sequence for jitter offsets (base 2 and 3).
pub fn halton(index: usize, base: usize) -> f32 {
let mut result = 0.0_f32;
let mut f = 1.0 / base as f32;
let mut i = index;
while i > 0 {
result += f * (i % base) as f32;
i /= base;
f /= base as f32;
}
result
}
/// Generate jitter offset for frame N (in pixel units, centered around 0).
pub fn jitter_offset(frame: usize) -> [f32; 2] {
let idx = (frame % 16) + 1; // avoid index 0
[halton(idx, 2) - 0.5, halton(idx, 3) - 0.5]
}
impl Taa {
pub fn new(device: &wgpu::Device) -> Self {
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("TAA Compute"),
source: wgpu::ShaderSource::Wgsl(include_str!("taa.wgsl").into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("TAA BGL"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2 },
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None },
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("TAA PL"), bind_group_layouts: &[&bind_group_layout], immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("TAA Pipeline"), layout: Some(&pipeline_layout),
module: &shader, entry_point: Some("main"),
compilation_options: wgpu::PipelineCompilationOptions::default(), cache: None,
});
let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("TAA Params"), size: std::mem::size_of::<TaaParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Taa {
pipeline, bind_group_layout, params_buffer,
jitter_index: 0, prev_jitter: [0.0; 2],
blend_factor: 0.1, enabled: true,
}
}
/// Get current jitter and advance frame index.
pub fn next_jitter(&mut self) -> [f32; 2] {
self.prev_jitter = jitter_offset(self.jitter_index);
self.jitter_index += 1;
let current = jitter_offset(self.jitter_index);
current
}
pub fn dispatch(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
encoder: &mut wgpu::CommandEncoder,
current_view: &wgpu::TextureView,
history_view: &wgpu::TextureView,
output_view: &wgpu::TextureView,
jitter: [f32; 2],
width: u32,
height: u32,
) {
let params = TaaParams {
jitter: [jitter[0] / width as f32, jitter[1] / height as f32],
prev_jitter: [self.prev_jitter[0] / width as f32, self.prev_jitter[1] / height as f32],
blend_factor: self.blend_factor,
_pad: [0.0; 3],
};
queue.write_buffer(&self.params_buffer, 0, bytemuck::cast_slice(&[params]));
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("TAA BG"), layout: &self.bind_group_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(current_view) },
wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(history_view) },
wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(output_view) },
wgpu::BindGroupEntry { binding: 3, resource: self.params_buffer.as_entire_binding() },
],
});
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("TAA Pass"), timestamp_writes: None });
cpass.set_pipeline(&self.pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_halton_base2() {
assert!((halton(1, 2) - 0.5).abs() < 1e-6);
assert!((halton(2, 2) - 0.25).abs() < 1e-6);
assert!((halton(3, 2) - 0.75).abs() < 1e-6);
}
#[test]
fn test_halton_base3() {
assert!((halton(1, 3) - 1.0/3.0).abs() < 1e-6);
assert!((halton(2, 3) - 2.0/3.0).abs() < 1e-6);
}
#[test]
fn test_jitter_offset_centered() {
for i in 0..16 {
let j = jitter_offset(i);
assert!(j[0] >= -0.5 && j[0] <= 0.5, "jitter x out of range: {}", j[0]);
assert!(j[1] >= -0.5 && j[1] <= 0.5, "jitter y out of range: {}", j[1]);
}
}
#[test]
fn test_jitter_varies() {
let j0 = jitter_offset(0);
let j1 = jitter_offset(1);
assert!(j0 != j1, "consecutive jitters should differ");
}
#[test]
fn test_halton_zero() {
assert!((halton(0, 2) - 0.0).abs() < 1e-6);
}
}

View File

@@ -0,0 +1,52 @@
struct TaaParams {
jitter: vec2<f32>, // current frame jitter offset
prev_jitter: vec2<f32>, // previous frame jitter offset
blend_factor: f32, // 0.1 (10% new, 90% history)
_pad: vec3<f32>,
};
@group(0) @binding(0) var current_tex: texture_2d<f32>;
@group(0) @binding(1) var history_tex: texture_2d<f32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba16float, write>;
@group(0) @binding(3) var<uniform> params: TaaParams;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let dims = textureDimensions(current_tex);
if (gid.x >= dims.x || gid.y >= dims.y) { return; }
let pos = vec2<i32>(gid.xy);
let current = textureLoad(current_tex, pos, 0);
// Neighborhood color clamping (3x3 AABB)
var color_min = vec4<f32>(999.0);
var color_max = vec4<f32>(-999.0);
for (var dy = -1; dy <= 1; dy++) {
for (var dx = -1; dx <= 1; dx++) {
let np = pos + vec2<i32>(dx, dy);
let nc = textureLoad(current_tex, clamp(np, vec2<i32>(0), vec2<i32>(dims) - 1), 0);
color_min = min(color_min, nc);
color_max = max(color_max, nc);
}
}
// Reprojection: unjitter to find history position
let uv = (vec2<f32>(gid.xy) + 0.5) / vec2<f32>(dims);
let history_uv = uv + params.prev_jitter - params.jitter;
let history_pos = vec2<i32>(history_uv * vec2<f32>(dims));
var history: vec4<f32>;
if (history_pos.x >= 0 && history_pos.y >= 0 &&
history_pos.x < i32(dims.x) && history_pos.y < i32(dims.y)) {
history = textureLoad(history_tex, history_pos, 0);
} else {
history = current;
}
// Clamp history to neighborhood
let clamped_history = clamp(history, color_min, color_max);
// Blend
let result = mix(clamped_history, current, params.blend_factor);
textureStore(output_tex, pos, result);
}