feat(renderer): add screen space reflections compute shader

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 15:41:29 +09:00
parent 41c7f9607e
commit 764ee96ec1
3 changed files with 261 additions and 0 deletions

View File

@@ -38,6 +38,7 @@ pub mod instancing;
pub mod bilateral_blur; pub mod bilateral_blur;
pub mod temporal_accum; pub mod temporal_accum;
pub mod taa; pub mod taa;
pub mod ssr;
pub use gpu::{GpuContext, DEPTH_FORMAT}; pub use gpu::{GpuContext, DEPTH_FORMAT};
pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT}; pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
@@ -79,6 +80,7 @@ pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
pub use bilateral_blur::BilateralBlur; pub use bilateral_blur::BilateralBlur;
pub use temporal_accum::TemporalAccumulation; pub use temporal_accum::TemporalAccumulation;
pub use taa::Taa; pub use taa::Taa;
pub use ssr::Ssr;
pub use png::parse_png; pub use png::parse_png;
pub use jpg::parse_jpg; pub use jpg::parse_jpg;
pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial}; pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};

View File

@@ -0,0 +1,177 @@
use bytemuck::{Pod, Zeroable};
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct SsrParams {
pub view_proj: [[f32; 4]; 4],
pub inv_view_proj: [[f32; 4]; 4],
pub camera_pos: [f32; 3],
pub max_steps: u32,
pub step_size: f32,
pub thickness: f32,
pub _pad: [f32; 2],
}
impl SsrParams {
pub fn new() -> Self {
SsrParams {
view_proj: [[0.0; 4]; 4],
inv_view_proj: [[0.0; 4]; 4],
camera_pos: [0.0; 3],
max_steps: 64,
step_size: 0.1,
thickness: 0.05,
_pad: [0.0; 2],
}
}
}
pub struct Ssr {
pipeline: wgpu::ComputePipeline,
bind_group_layout: wgpu::BindGroupLayout,
params_buffer: wgpu::Buffer,
pub enabled: bool,
}
impl Ssr {
pub fn new(device: &wgpu::Device) -> Self {
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("SSR Compute"),
source: wgpu::ShaderSource::Wgsl(include_str!("ssr.wgsl").into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("SSR BGL"),
entries: &[
// binding 0: position texture
wgpu::BindGroupLayoutEntry {
binding: 0, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
count: None,
},
// binding 1: normal texture
wgpu::BindGroupLayoutEntry {
binding: 1, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
count: None,
},
// binding 2: color texture (lit scene)
wgpu::BindGroupLayoutEntry {
binding: 2, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
count: None,
},
// binding 3: output
wgpu::BindGroupLayoutEntry {
binding: 3, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2 },
count: None,
},
// binding 4: params uniform
wgpu::BindGroupLayoutEntry {
binding: 4, visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None },
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("SSR PL"), bind_group_layouts: &[&bind_group_layout], immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("SSR Pipeline"), layout: Some(&pipeline_layout),
module: &shader, entry_point: Some("main"),
compilation_options: wgpu::PipelineCompilationOptions::default(), cache: None,
});
let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("SSR Params"),
size: std::mem::size_of::<SsrParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Ssr { pipeline, bind_group_layout, params_buffer, enabled: true }
}
pub fn dispatch(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
encoder: &mut wgpu::CommandEncoder,
position_view: &wgpu::TextureView,
normal_view: &wgpu::TextureView,
color_view: &wgpu::TextureView,
output_view: &wgpu::TextureView,
params: &SsrParams,
width: u32,
height: u32,
) {
queue.write_buffer(&self.params_buffer, 0, bytemuck::cast_slice(&[*params]));
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("SSR BG"), layout: &self.bind_group_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(position_view) },
wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(normal_view) },
wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(color_view) },
wgpu::BindGroupEntry { binding: 3, resource: wgpu::BindingResource::TextureView(output_view) },
wgpu::BindGroupEntry { binding: 4, resource: self.params_buffer.as_entire_binding() },
],
});
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("SSR Pass"), timestamp_writes: None });
cpass.set_pipeline(&self.pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
}
}
/// Reflect a vector about a normal (CPU version for testing).
pub fn reflect(incident: [f32; 3], normal: [f32; 3]) -> [f32; 3] {
let d = 2.0 * (incident[0]*normal[0] + incident[1]*normal[1] + incident[2]*normal[2]);
[
incident[0] - d * normal[0],
incident[1] - d * normal[1],
incident[2] - d * normal[2],
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_reflect_horizontal() {
// Ray going down, reflected off horizontal surface (normal up)
let r = reflect([0.0, -1.0, 0.0], [0.0, 1.0, 0.0]);
assert!((r[0] - 0.0).abs() < 1e-6);
assert!((r[1] - 1.0).abs() < 1e-6);
assert!((r[2] - 0.0).abs() < 1e-6);
}
#[test]
fn test_reflect_45_degrees() {
let s = std::f32::consts::FRAC_1_SQRT_2;
let r = reflect([s, -s, 0.0], [0.0, 1.0, 0.0]);
assert!((r[0] - s).abs() < 1e-5);
assert!((r[1] - s).abs() < 1e-5);
}
#[test]
fn test_ssr_params_default() {
let p = SsrParams::new();
assert_eq!(p.max_steps, 64);
assert!((p.step_size - 0.1).abs() < 1e-6);
assert!((p.thickness - 0.05).abs() < 1e-6);
}
#[test]
fn test_ssr_params_size() {
// Must be aligned for uniform buffer
let size = std::mem::size_of::<SsrParams>();
assert_eq!(size % 16, 0, "SsrParams size must be 16-byte aligned, got {}", size);
}
}

View File

@@ -0,0 +1,82 @@
struct SsrParams {
view_proj: mat4x4<f32>,
inv_view_proj: mat4x4<f32>,
camera_pos: vec3<f32>,
max_steps: u32,
step_size: f32,
thickness: f32,
_pad: vec2<f32>,
};
@group(0) @binding(0) var position_tex: texture_2d<f32>; // G-Buffer world position
@group(0) @binding(1) var normal_tex: texture_2d<f32>; // G-Buffer world normal
@group(0) @binding(2) var color_tex: texture_2d<f32>; // Lit HDR color
@group(0) @binding(3) var output_tex: texture_storage_2d<rgba16float, write>;
@group(0) @binding(4) var<uniform> params: SsrParams;
fn world_to_screen(world_pos: vec3<f32>) -> vec3<f32> {
let clip = params.view_proj * vec4<f32>(world_pos, 1.0);
let ndc = clip.xyz / clip.w;
let dims = vec2<f32>(textureDimensions(color_tex));
return vec3<f32>(
(ndc.x * 0.5 + 0.5) * dims.x,
(1.0 - (ndc.y * 0.5 + 0.5)) * dims.y,
ndc.z,
);
}
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let dims = textureDimensions(position_tex);
if (gid.x >= dims.x || gid.y >= dims.y) { return; }
let pos = vec2<i32>(gid.xy);
let world_pos = textureLoad(position_tex, pos, 0).xyz;
let normal = normalize(textureLoad(normal_tex, pos, 0).xyz);
// Skip pixels with no geometry (position = 0)
if (dot(world_pos, world_pos) < 0.001) {
textureStore(output_tex, pos, vec4<f32>(0.0));
return;
}
// Reflect view direction
let view_dir = normalize(world_pos - params.camera_pos);
let reflect_dir = reflect(view_dir, normal);
// Ray march in world space, project to screen each step
var ray_pos = world_pos + reflect_dir * params.step_size;
var hit_color = vec4<f32>(0.0);
for (var i = 0u; i < params.max_steps; i++) {
let screen = world_to_screen(ray_pos);
let sx = i32(screen.x);
let sy = i32(screen.y);
// Bounds check
if (sx < 0 || sy < 0 || sx >= i32(dims.x) || sy >= i32(dims.y) || screen.z < 0.0 || screen.z > 1.0) {
break;
}
// Compare depth
let sample_pos = textureLoad(position_tex, vec2<i32>(sx, sy), 0).xyz;
let sample_screen = world_to_screen(sample_pos);
let depth_diff = screen.z - sample_screen.z;
if (depth_diff > 0.0 && depth_diff < params.thickness) {
// Hit! Sample the color
hit_color = textureLoad(color_tex, vec2<i32>(sx, sy), 0);
// Fade at edges
let edge_fade = 1.0 - max(
abs(f32(sx) / f32(dims.x) * 2.0 - 1.0),
abs(f32(sy) / f32(dims.y) * 2.0 - 1.0),
);
hit_color = hit_color * clamp(edge_fade, 0.0, 1.0);
break;
}
ray_pos += reflect_dir * params.step_size;
}
textureStore(output_tex, pos, hit_color);
}