Compare commits

...

109 Commits

Author SHA1 Message Date
ea9667889c docs: update STATUS.md and DEFERRED.md - nearly all deferred items complete
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:31:05 +09:00
6b6d581b71 feat(renderer): add soft shadows, BLAS tracker, RT fallback, light probes, light volumes
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:28:37 +09:00
be290bd6e0 feat(audio): add async loader, effect chain
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:28:31 +09:00
a50f79e4fc docs: update STATUS.md and DEFERRED.md with G-Buffer compression, stencil, half-res SSGI, bilateral bloom, fade curves, dynamic groups, audio bus
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:24:51 +09:00
aafebff478 feat(audio): add fade curves, dynamic mix groups, audio bus routing
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:23:48 +09:00
bc2880d41c feat(renderer): add stencil optimization, half-res SSGI, bilateral bloom
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:21:44 +09:00
025bf4d0b9 feat(renderer): add G-Buffer compression with octahedral normals and depth reconstruction
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:21:05 +09:00
7375b15fcf chore: update Cargo.lock
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:42:21 +09:00
37addfdf03 docs: update STATUS.md and DEFERRED.md with RT reflections, RT AO, RT point shadows
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:38:15 +09:00
039dbe0d09 feat(renderer): add RT reflections, RT AO, and RT point/spot shadows
Add three new compute shader modules extending the existing RT shadow system:
- RT Reflections: screen-space reflection ray marching from G-Buffer
- RT Ambient Occlusion: hemisphere sampling with cosine-weighted directions
- RT Point/Spot Shadow: placeholder infrastructure for point/spot light shadows

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:37:25 +09:00
9a5d3b7b97 docs: update STATUS.md and DEFERRED.md with occlusion and progressive JPEG
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:34:29 +09:00
cc842d7c13 feat(renderer): add Progressive JPEG detection and scan parsing
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:33:52 +09:00
84dc7aeb20 feat(audio): add occlusion simulation with low-pass filter
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:33:04 +09:00
d7b5fdde31 docs: update STATUS.md and DEFERRED.md with HRTF, reverb, lag compensation, encryption
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:32:02 +09:00
96efe113b2 feat(audio): add Reverb (Schroeder) and Echo effects
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:31:19 +09:00
98d40d6520 feat(net): add packet encryption and auth token
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:28:21 +09:00
6beafc6949 feat(net): add lag compensation with history rewind and interpolation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:27:58 +09:00
28b24226e7 feat(audio): add simplified HRTF with ITD and ILD
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:27:05 +09:00
8685d7c4aa docs: update STATUS.md and DEFERRED.md with motion blur and DOF
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 16:25:57 +09:00
447473598a feat(renderer): add motion blur compute shader
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:57:05 +09:00
831365a622 docs: update STATUS.md and DEFERRED.md with ray vs mesh, convex hull, navmesh serialization
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:50:31 +09:00
0f08c65a1e feat(physics): add ConvexHull collider with GJK support function
Add ConvexHull struct storing vertices with a support function that
returns the farthest point in a given direction, enabling GJK/EPA
collision detection. Update all Collider match arms across the physics
crate (collision, raycast, integrator, solver) to handle the new variant.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:49:38 +09:00
1b5da4d0d5 feat(physics): add ray vs mesh raycasting with MeshCollider
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:48:57 +09:00
07d7475410 feat(ai): add NavMesh binary serialization
Add serialize/deserialize methods directly on NavMesh with a proper
binary header (magic "VNAV" + version u32) for format validation.
Includes tests for roundtrip, invalid magic, empty mesh, header format,
and unsupported version.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:46:38 +09:00
d4ef4cf1ce docs: update STATUS.md and DEFERRED.md with TAA and SSR
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:42:11 +09:00
764ee96ec1 feat(renderer): add screen space reflections compute shader
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:41:29 +09:00
41c7f9607e feat(renderer): add TAA with Halton jitter and neighborhood clamping
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:40:53 +09:00
d321c0695c docs: update STATUS.md and DEFERRED.md with bilateral blur and temporal accumulation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:29:42 +09:00
1f855b7bf6 feat(renderer): add temporal accumulation compute shader for SSGI
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:28:21 +09:00
1ea2d340e6 feat(renderer): add bilateral blur compute shader for SSGI denoising
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:28:20 +09:00
a60a25e9ba docs: add SSGI quality (bilateral blur + temporal accumulation) design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:25:26 +09:00
e4879b6a31 docs: update STATUS.md and DEFERRED.md with GPU instancing
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:14:45 +09:00
17ea3f4856 feat(renderer): add GPU instancing with instance buffer and pipeline
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:14:12 +09:00
c5f6511fc2 docs: add GPU instancing design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:12:11 +09:00
2703e73ef0 docs: update STATUS.md and DEFERRED.md with auto exposure
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:09:31 +09:00
7dbd94ebab feat(renderer): add auto exposure with compute luminance and adaptation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:08:50 +09:00
72b517efb2 docs: add auto exposure design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:04:20 +09:00
5f21ec0afe docs: update STATUS.md and DEFERRED.md with Changed filter
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:01:32 +09:00
bbb11d9d47 feat(ecs): add tick-based change detection with query_changed
Add per-component tick tracking to SparseSet. Insert and get_mut mark
the current tick; increment_tick advances it. World gains query_changed
to find entities whose component changed this tick, and clear_changed
to advance all storages at end of frame.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 15:00:50 +09:00
c6ac2ded81 docs: add Changed filter design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:57:44 +09:00
f79889cbf1 docs: update STATUS.md and DEFERRED.md with transparent forward pass
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:53:03 +09:00
afb95c9fb1 feat(renderer): add forward transparency pass with alpha blending
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:52:17 +09:00
ef8c39b5ae docs: add transparent objects design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:49:44 +09:00
43411b6fd9 docs: update STATUS.md and DEFERRED.md with AudioSource and NavAgent ECS components
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:44:40 +09:00
5b3b06c318 feat(ai): add NavAgent ECS component
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:43:59 +09:00
6121530bfe feat(audio): add AudioSource ECS component
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:43:48 +09:00
ce1a79cab6 docs: add ECS integration (Audio + AI) design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:42:46 +09:00
1d38434aec docs: update STATUS.md and DEFERRED.md with glTF animation parser
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:39:10 +09:00
0cc6df15a3 feat(renderer): extend glTF parser with nodes, skins, animations support
Add GltfNode, GltfSkin, GltfAnimation, GltfChannel structs and parsing
for skeletal animation data. Extend GltfMesh with JOINTS_0/WEIGHTS_0
attribute extraction. All existing tests pass plus 4 new tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:38:25 +09:00
ce69c81eca docs: add glTF animation/skin parser extension design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:35:23 +09:00
b2fd3988f5 docs: update STATUS.md and DEFERRED.md with TTF font support
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:14:48 +09:00
cccc54c438 feat(editor): integrate TTF font into UiContext with draw_text helper
Add ttf_font field to UiContext with draw_text and ttf_text_width
helpers that use TTF when available, falling back to bitmap font.
Load system TTF font (arial/consola/malgun) in editor_demo on startup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:14:05 +09:00
58bce839fe feat(editor): add TtfFont unified interface with glyph caching
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:12:02 +09:00
94e7f6262e feat(editor): add GlyphCache on-demand atlas manager
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:10:16 +09:00
3b0a65ed17 feat(editor): add glyph rasterizer with bezier flattening and scanline fill
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:09:58 +09:00
e008178316 feat(editor): add self-implemented TTF parser with cmap, glyf, hmtx support
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:07:52 +09:00
74974dbff0 docs: add TTF font implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:04:43 +09:00
53505ee9b7 docs: add TTF font design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:03:11 +09:00
295944d237 docs: update STATUS.md and DEFERRED.md with asset browser
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 13:38:51 +09:00
587c75f6c2 feat(editor): integrate asset browser into editor_demo
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 13:38:11 +09:00
0857446d74 feat(editor): add AssetBrowser with directory navigation and UI panel
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 13:36:23 +09:00
b65585b739 docs: add asset browser implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:43:29 +09:00
4ef6e83710 docs: add asset browser design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:42:41 +09:00
58cbdf8400 docs: update STATUS.md and DEFERRED.md with entity inspector
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:00:40 +09:00
b965d78835 feat(editor): integrate hierarchy and inspector panels into editor_demo
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:59:55 +09:00
65b86c293c feat(editor): add inspector_panel with Transform, Tag, Parent editing
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:57:23 +09:00
7fbc88b86f feat(editor): add hierarchy_panel with entity tree display and selection
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:55:18 +09:00
8ecf883ef6 docs: add entity inspector implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:50:34 +09:00
e14a53c3fa docs: fix entity inspector spec from review
- Add copy-out borrow pattern for Transform editing
- Add tag_buffer staging string for Tag editing
- Add count_nodes helper, highlight color

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:48:57 +09:00
4d5fc5e44c docs: add entity inspector design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:47:14 +09:00
34e257e887 docs: update STATUS.md and DEFERRED.md with scene viewport
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:41:31 +09:00
14fe532432 feat(editor): integrate 3D viewport into editor_demo
Render actual 3D geometry (cubes + ground plane) in the Viewport panel
using the existing mesh_shader.wgsl Blinn-Phong pipeline with orbit
camera controls (left-drag to orbit, middle-drag to pan, scroll to zoom).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:40:41 +09:00
9f42300109 feat(editor): add ViewportRenderer blit pipeline and shader
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:35:15 +09:00
6ef248f76d feat(editor): add OrbitCamera with orbit, zoom, pan controls
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:33:10 +09:00
ae20590f6e feat(editor): add ViewportTexture offscreen render target
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:31:28 +09:00
d93253dfb1 docs: add scene viewport implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:26:21 +09:00
fed47e9242 docs: fix scene viewport spec from review
- Add voltex_math/voltex_renderer dependencies
- Fix color format to Rgba8Unorm (linear, prevent double gamma)
- Correct bind group layout to match mesh_shader.wgsl
- Add vec3 padding requirement for Rust uniform structs
- Per-frame bind group creation for viewport renderer
- Pan degenerate right vector guard

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:22:30 +09:00
8468f3cce2 docs: add scene viewport design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:20:07 +09:00
c19dc6421a docs: update STATUS.md and DEFERRED.md with docking system
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:11:26 +09:00
f7ef228b49 feat(editor): integrate docking into editor_demo
Add full-frame-cycle integration test to dock.rs (16 tests total) and
update editor_demo to use DockTree layout instead of a single panel.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:09:53 +09:00
1571aa5f97 feat(editor): add draw_chrome for tab bars and split lines
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:57:18 +09:00
36fedb48bf feat(editor): add update method with tab click and resize handling
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:54:57 +09:00
a642f8ef7e fix(editor): address code review issues in dock module
- debug_assert → assert for empty tabs invariant
- tabs.len() - 1 → saturating_sub(1) to prevent underflow
- Add #[allow(dead_code)] for scaffolded structs/constants

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:52:53 +09:00
14784c731c feat(editor): add dock tree data model and layout algorithm
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:49:53 +09:00
a69554eede docs: add editor docking implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:43:08 +09:00
867f8a0aa0 docs: fix docking spec issues from review
- Add panel name registry (names Vec)
- Cache layout results for draw_chrome
- Track prev_mouse_down for click detection
- Add invariants (non-empty tabs, active clamping, resize priority)
- Add LeafLayout and ResizeState structs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:37:18 +09:00
eb454800a9 docs: add editor docking system design spec
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 09:34:45 +09:00
622c91a954 docs: update STATUS.md and DEFERRED.md with all completed deferred items
Test count: 255 → 485. Major additions across all crates:
- Renderer: JPG, glTF, ORM/Emissive, CSM, Point/Spot Shadow, Frustum Culling, SH, GPU BRDF
- ECS: JSON/Binary scene, ComponentRegistry, query filters, scheduler
- Physics: Angular dynamics, Sequential Impulse, Sleep, CCD, BVH refit, ray/triangle
- Asset: Async loader, FileWatcher, hot reload
- Audio: OGG/Vorbis, 24/32-bit WAV, Doppler
- AI: NavMesh builder, Funnel, obstacle avoidance
- Net: Reliability, snapshots, interpolation
- Script: Table interop, coroutines, sandbox, hot reload
- Editor: Text input, scroll, drag & drop

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:38:37 +09:00
f522bf10ac feat(script): add Lua table interop, coroutines, sandbox, hot reload
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:33:00 +09:00
63e59c0544 feat(editor): add text input, scroll panel, drag-and-drop widgets
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:32:55 +09:00
9f5f2df07c feat(ai): add navmesh builder, funnel algorithm, dynamic obstacle avoidance
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:15:40 +09:00
1c2a8466e7 feat(audio): add OGG/Vorbis decoder, 24/32-bit WAV, Doppler effect
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:15:38 +09:00
0ef750de69 feat(net): add reliability layer, state sync, and client interpolation
- ReliableChannel: sequence numbers, ACK, retransmission, RTT estimation
- OrderedChannel: in-order delivery with out-of-order buffering
- Snapshot serialization with delta compression (per-field bitmask)
- InterpolationBuffer: linear interpolation between server snapshots
- New packet types: Reliable, Ack, Snapshot, SnapshotDelta

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 21:03:52 +09:00
dccea21bfe feat(ai): add funnel string-pulling and navmesh serialization
Add Simple Stupid Funnel (SSF) algorithm for optimal path smoothing through
triangle corridors. Refactor A* to expose find_path_triangles for triangle
index paths. Add binary serialize/deserialize for NavMesh and shared_edge
query method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 21:00:35 +09:00
1081fb472f feat(renderer): improve IBL with Hosek-Wilkie sky, SH irradiance, GPU BRDF LUT
- Hosek-Wilkie inspired procedural sky (Rayleigh/Mie scattering, sun disk)
- L2 Spherical Harmonics irradiance (9 coefficients, CPU computation)
- SH evaluation in shader replaces sample_environment for diffuse IBL
- GPU compute BRDF LUT (Rg16Float, higher precision than CPU Rgba8Unorm)
- SkyParams (sun_direction, turbidity) in ShadowUniform

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:58:28 +09:00
abd6f5cf6e feat(physics): add angular dynamics, sequential impulse solver, sleep system, BVH improvements, CCD, and ray extensions
- Angular velocity integration with diagonal inertia tensor (sphere/box/capsule)
- Angular impulse in collision solver (torque from off-center contacts)
- Sequential impulse solver with configurable iterations (default 4)
- Sleep/island system: bodies sleep after velocity threshold timeout, wake on collision
- Ray vs triangle intersection (Möller–Trumbore algorithm)
- raycast_all returning all hits sorted by distance
- BVH query_pairs replaced N^2 brute force with recursive tree traversal
- BVH query_ray for accelerated raycasting
- BVH refit for incremental AABB updates
- Swept sphere vs AABB continuous collision detection (CCD)
- Updated lib.rs exports for all new public APIs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:57:55 +09:00
1b0e12e824 feat(renderer): add CSM, point/spot shadows, and frustum light culling
- CascadedShadowMap: 2-cascade directional shadows with frustum-based splits
- PointShadowMap: cube depth texture with 6-face rendering
- SpotShadowMap: perspective shadow map from spot light cone
- Frustum light culling: Gribb-Hartmann plane extraction + sphere tests
- Mat4::inverse() for frustum corner computation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:55:43 +09:00
a7497f6045 docs: add Phase 4b-5 deferred items spec (shadows, IBL, physics)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:50:24 +09:00
6bc77cb777 feat(renderer): add ORM and emissive texture map support to PBR pipeline
- Extended bind group 1: albedo + normal + ORM + emissive (8 bindings)
- pbr_shader.wgsl: ORM sampling (R=AO, G=roughness, B=metallic) + emissive
- deferred_gbuffer.wgsl: ORM + emissive luminance in material_data.w
- deferred_lighting.wgsl: emissive contribution from G-Buffer
- All 5 PBR examples updated with default ORM/emissive textures
- Backward compatible: old 4-binding layout preserved

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:41:30 +09:00
164eead5ec feat(ecs): add JSON/binary scene serialization and component registry
- Mini JSON writer/parser in voltex_ecs (no renderer dependency)
- ComponentRegistry with register/find/register_defaults
- serialize_scene_json/deserialize_scene_json with hex-encoded components
- serialize_scene_binary/deserialize_scene_binary (VSCN binary format)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:38:56 +09:00
f4b1174e13 feat(asset): add async loading, file watcher, and hot reload support
- FileWatcher: mtime-based polling change detection
- AssetLoader: background thread loading via channels
- replace_in_place on AssetStorage for hot reload
- LoadState enum: Loading/Ready/Failed

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:34:54 +09:00
c478e2433d docs: add implementation plans for scene serialization, async loading, PBR textures
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:24:19 +09:00
389cbdb063 docs: add Phase 3b-4a deferred items spec (serialization, async load, PBR textures)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:21:17 +09:00
df2082f532 docs: update STATUS.md and DEFERRED.md for completed Phase 2-3a items
Mark PNG/JPG/glTF, query3/4, query filters, scheduler,
Capsule/GJK, Coulomb friction, Lua engine API as completed.
Update test count from 255 to 324.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:16:37 +09:00
a9f5b11f69 feat(renderer): add glTF 2.0 / GLB parser with self-contained JSON parser
- Mini JSON parser (no external deps) for glTF support
- GLB binary format: header, JSON chunk, BIN chunk
- Embedded base64 buffer URI support
- Accessor/BufferView extraction (position, normal, uv, tangent, indices)
- PBR material extraction (baseColor, metallic, roughness)
- Auto compute_tangents when not provided

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:14:04 +09:00
2d80a218c5 feat(renderer): add Baseline JPEG decoder
Self-contained Huffman/IDCT/MCU/YCbCr decoder.
Supports SOF0, 4:4:4/4:2:2/4:2:0 subsampling, grayscale,
restart markers. API matches parse_png pattern.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:06:07 +09:00
a080f0608b feat(ecs): add query filters (with/without) and system scheduler
- has_component<T> helper on World
- query_with/query_without for single component + filter
- query2_with/query2_without for 2-component + filter
- System trait with blanket impl for FnMut(&mut World)
- Ordered Scheduler (add/run_all)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:05:02 +09:00
8abba16137 docs: add implementation plans for JPG decoder, glTF parser, ECS filters/scheduler
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 19:58:56 +09:00
dc6aa950e3 docs: add Phase 2-3a deferred items spec (JPG, glTF, ECS filters/scheduler)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 19:51:50 +09:00
160 changed files with 29777 additions and 508 deletions

6
Cargo.lock generated
View File

@@ -491,7 +491,9 @@ dependencies = [
"env_logger",
"log",
"pollster",
"voltex_ecs",
"voltex_editor",
"voltex_math",
"voltex_platform",
"voltex_renderer",
"wgpu",
@@ -2080,6 +2082,7 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
name = "voltex_ai"
version = "0.1.0"
dependencies = [
"voltex_ecs",
"voltex_math",
]
@@ -2106,6 +2109,9 @@ name = "voltex_editor"
version = "0.1.0"
dependencies = [
"bytemuck",
"voltex_ecs",
"voltex_math",
"voltex_renderer",
"wgpu",
]

View File

@@ -5,3 +5,4 @@ edition = "2021"
[dependencies]
voltex_math.workspace = true
voltex_ecs.workspace = true

View File

@@ -1,7 +1,9 @@
pub mod nav_agent;
pub mod navmesh;
pub mod pathfinding;
pub mod steering;
pub use nav_agent::NavAgent;
pub use navmesh::{NavMesh, NavTriangle};
pub use pathfinding::find_path;
pub use steering::{SteeringAgent, seek, flee, arrive, wander, follow_path};

View File

@@ -0,0 +1,85 @@
use voltex_math::Vec3;
/// ECS component for AI pathfinding navigation.
///
/// Holds the agent's destination plus the waypoint path being followed.
/// The path itself is computed elsewhere; this struct only stores state.
#[derive(Debug, Clone)]
pub struct NavAgent {
/// Destination position, if one has been assigned.
pub target: Option<Vec3>,
/// Movement speed (units per time step or per second — confirm with the movement system).
pub speed: f32,
/// Waypoints toward `target`; cleared whenever the target changes.
pub path: Vec<Vec3>,
/// Index into `path` of the waypoint currently being pursued.
pub current_waypoint: usize,
/// Arrival flag; reset to false whenever the target is set or cleared.
pub reached: bool,
}
impl NavAgent {
pub fn new(speed: f32) -> Self {
NavAgent {
target: None,
speed,
path: Vec::new(),
current_waypoint: 0,
reached: false,
}
}
pub fn set_target(&mut self, target: Vec3) {
self.target = Some(target);
self.path.clear();
self.current_waypoint = 0;
self.reached = false;
}
pub fn clear_target(&mut self) {
self.target = None;
self.path.clear();
self.current_waypoint = 0;
self.reached = false;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh agent carries the given speed and no navigation state.
    #[test]
    fn test_new_defaults() {
        let fresh = NavAgent::new(5.0);
        assert!(fresh.target.is_none());
        assert!(fresh.path.is_empty());
        assert!(!fresh.reached);
        assert_eq!(fresh.current_waypoint, 0);
        assert!((fresh.speed - 5.0).abs() < 1e-6);
    }

    // Setting a target stores it and leaves the agent un-arrived.
    #[test]
    fn test_set_target() {
        let mut nav = NavAgent::new(3.0);
        nav.set_target(Vec3::new(10.0, 0.0, 5.0));
        assert!(nav.target.is_some());
        let dest = nav.target.unwrap();
        assert!((dest.x - 10.0).abs() < 1e-6);
        assert!(!nav.reached);
    }

    // Retargeting discards any path computed for the previous target.
    #[test]
    fn test_set_target_clears_path() {
        let mut nav = NavAgent::new(3.0);
        nav.path = vec![Vec3::new(1.0, 0.0, 0.0), Vec3::new(2.0, 0.0, 0.0)];
        nav.current_waypoint = 1;
        nav.set_target(Vec3::new(5.0, 0.0, 5.0));
        assert_eq!(nav.current_waypoint, 0);
        assert!(nav.path.is_empty());
    }

    // Clearing the target resets every piece of navigation state.
    #[test]
    fn test_clear_target() {
        let mut nav = NavAgent::new(3.0);
        nav.set_target(Vec3::new(10.0, 0.0, 5.0));
        nav.path = vec![Vec3::new(1.0, 0.0, 0.0)];
        nav.reached = true;
        nav.clear_target();
        assert!(nav.target.is_none());
        assert!(nav.path.is_empty());
        assert!(!nav.reached);
        assert_eq!(nav.current_waypoint, 0);
    }
}

View File

@@ -67,6 +67,147 @@ impl NavMesh {
(a.z + b.z) / 2.0,
)
}
/// Serialize the NavMesh to a binary format with header.
///
/// Format (all integers little-endian):
/// - Magic: "VNAV" (4 bytes)
/// - Version: u32
/// - num_vertices: u32
/// - vertices: [f32; 3] * num_vertices
/// - num_triangles: u32
/// - triangles: for each triangle: indices [u32; 3] + neighbors [i32; 3] (-1 for None)
pub fn serialize(&self) -> Vec<u8> {
    // Exact output size is known up front: 8-byte header + two u32 counts
    // + 12 bytes per vertex + 24 bytes per triangle. Preallocating avoids
    // repeated grow-and-copy during the writes below.
    let mut data = Vec::with_capacity(
        8 + 4 + self.vertices.len() * 12 + 4 + self.triangles.len() * 24,
    );
    // Magic
    data.extend_from_slice(b"VNAV");
    // Version
    data.extend_from_slice(&1u32.to_le_bytes());
    // Vertices
    data.extend_from_slice(&(self.vertices.len() as u32).to_le_bytes());
    for v in &self.vertices {
        data.extend_from_slice(&v.x.to_le_bytes());
        data.extend_from_slice(&v.y.to_le_bytes());
        data.extend_from_slice(&v.z.to_le_bytes());
    }
    // Triangles
    data.extend_from_slice(&(self.triangles.len() as u32).to_le_bytes());
    for tri in &self.triangles {
        for &idx in &tri.indices {
            data.extend_from_slice(&(idx as u32).to_le_bytes());
        }
        for &nb in &tri.neighbors {
            // A missing neighbor is stored as -1 so each slot stays fixed-width.
            let encoded = nb.map_or(-1i32, |i| i as i32);
            data.extend_from_slice(&encoded.to_le_bytes());
        }
    }
    data
}
/// Deserialize a NavMesh from binary data with header validation.
///
/// Accepts the format produced by [`serialize`]: "VNAV" magic, u32 version,
/// then vertex and triangle tables. Returns a descriptive error string on
/// truncated input, bad magic, or an unsupported version.
pub fn deserialize(data: &[u8]) -> Result<Self, String> {
    if data.len() < 8 {
        return Err("data too short for header".to_string());
    }
    // Validate magic
    if &data[0..4] != b"VNAV" {
        return Err(format!(
            "invalid magic: expected VNAV, got {:?}",
            &data[0..4]
        ));
    }
    // Read version
    let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
    if version != 1 {
        return Err(format!("unsupported version: {}", version));
    }
    // Every field in the body is exactly 4 bytes wide, so a single
    // bounds-checked word reader serves u32, i32, and f32 alike; callers
    // reinterpret the raw bytes with the appropriate from_le_bytes.
    let mut cursor = 8usize;
    let mut next_word = || -> Result<[u8; 4], String> {
        let end = cursor + 4;
        if end > data.len() {
            return Err("unexpected end of data".to_string());
        }
        let word = [
            data[cursor],
            data[cursor + 1],
            data[cursor + 2],
            data[cursor + 3],
        ];
        cursor = end;
        Ok(word)
    };
    // Vertex table
    let vertex_count = u32::from_le_bytes(next_word()?) as usize;
    let mut vertices = Vec::with_capacity(vertex_count);
    for _ in 0..vertex_count {
        let x = f32::from_le_bytes(next_word()?);
        let y = f32::from_le_bytes(next_word()?);
        let z = f32::from_le_bytes(next_word()?);
        vertices.push(Vec3::new(x, y, z));
    }
    // Triangle table: three u32 indices, then three i32 neighbors (-1 = None)
    let tri_count = u32::from_le_bytes(next_word()?) as usize;
    let mut triangles = Vec::with_capacity(tri_count);
    for _ in 0..tri_count {
        let i0 = u32::from_le_bytes(next_word()?) as usize;
        let i1 = u32::from_le_bytes(next_word()?) as usize;
        let i2 = u32::from_le_bytes(next_word()?) as usize;
        let decode = |raw: i32| if raw < 0 { None } else { Some(raw as usize) };
        let n0 = decode(i32::from_le_bytes(next_word()?));
        let n1 = decode(i32::from_le_bytes(next_word()?));
        let n2 = decode(i32::from_le_bytes(next_word()?));
        triangles.push(NavTriangle {
            indices: [i0, i1, i2],
            neighbors: [n0, n1, n2],
        });
    }
    Ok(NavMesh::new(vertices, triangles))
}
/// Return the shared edge (portal) vertices between two adjacent triangles.
///
/// Returns (left, right) vertices of the portal edge as seen from `from_tri`
/// looking toward `to_tri`.
///
/// Returns None if the triangles are not adjacent, or if `from_tri` (or a
/// stored vertex index) is out of range — previously an invalid `from_tri`
/// panicked via direct slice indexing.
pub fn shared_edge(&self, from_tri: usize, to_tri: usize) -> Option<(Vec3, Vec3)> {
    let tri = self.triangles.get(from_tri)?;
    for edge_idx in 0..3 {
        if tri.neighbors[edge_idx] == Some(to_tri) {
            // Edge k runs from vertex k to vertex (k + 1) % 3 of the triangle.
            let (i0, i1) = match edge_idx {
                0 => (tri.indices[0], tri.indices[1]),
                1 => (tri.indices[1], tri.indices[2]),
                2 => (tri.indices[2], tri.indices[0]),
                _ => unreachable!(),
            };
            // Checked lookups guard against corrupt meshes with dangling indices.
            return Some((*self.vertices.get(i0)?, *self.vertices.get(i1)?));
        }
    }
    None
}
}
/// Test whether `point` lies inside or on the triangle (a, b, c) using XZ barycentric coordinates.
@@ -90,6 +231,104 @@ pub fn point_in_triangle_xz(point: Vec3, a: Vec3, b: Vec3, c: Vec3) -> bool {
u >= 0.0 && v >= 0.0 && w >= 0.0
}
/// Serialize a NavMesh to binary format.
/// Format: vertex_count(u32) + vertices(f32*3 each) + triangle_count(u32) + triangles(indices u32*3 + neighbors i32*3 each)
pub fn serialize_navmesh(navmesh: &NavMesh) -> Vec<u8> {
    // The output size is exactly known up front: 4 bytes per count, 12 per
    // vertex (3 x f32) and 24 per triangle (3 x u32 + 3 x i32). Preallocating
    // avoids repeated buffer regrowth on large meshes.
    let capacity = 8 + navmesh.vertices.len() * 12 + navmesh.triangles.len() * 24;
    let mut data = Vec::with_capacity(capacity);
    // Vertex count
    let vc = navmesh.vertices.len() as u32;
    data.extend_from_slice(&vc.to_le_bytes());
    // Vertices: little-endian x/y/z as f32
    for v in &navmesh.vertices {
        data.extend_from_slice(&v.x.to_le_bytes());
        data.extend_from_slice(&v.y.to_le_bytes());
        data.extend_from_slice(&v.z.to_le_bytes());
    }
    // Triangle count
    let tc = navmesh.triangles.len() as u32;
    data.extend_from_slice(&tc.to_le_bytes());
    // Triangles: indices(u32*3) + neighbors(i32*3, -1 encodes None)
    for tri in &navmesh.triangles {
        for &idx in &tri.indices {
            // NOTE: usize -> u32 cast; meshes with > u32::MAX vertices would
            // truncate, matching the on-disk format's 32-bit index width.
            data.extend_from_slice(&(idx as u32).to_le_bytes());
        }
        for &nb in &tri.neighbors {
            let val: i32 = nb.map_or(-1, |i| i as i32);
            data.extend_from_slice(&val.to_le_bytes());
        }
    }
    debug_assert_eq!(data.len(), capacity);
    data
}
/// Deserialize a NavMesh from binary data.
///
/// Expects the layout written by `serialize_navmesh`:
/// vertex_count(u32) + vertices(f32*3 each) + triangle_count(u32)
/// + triangles(indices u32*3 + neighbors i32*3 each, -1 = no neighbor).
///
/// Returns Err on truncated input, on counts that exceed the remaining
/// payload (guarding against huge allocations from corrupt data), and on
/// triangle vertex indices that do not reference a deserialized vertex.
pub fn deserialize_navmesh(data: &[u8]) -> Result<NavMesh, String> {
    let mut offset = 0usize;
    // Read the next 4 bytes at `*off` and advance the cursor.
    let read4 = |off: &mut usize| -> Result<[u8; 4], String> {
        if *off + 4 > data.len() {
            return Err("unexpected end of data".to_string());
        }
        let bytes = [data[*off], data[*off + 1], data[*off + 2], data[*off + 3]];
        *off += 4;
        Ok(bytes)
    };
    // Typed readers share the single bounds-checked byte reader instead of
    // triplicating the bounds logic.
    let read_u32 = |off: &mut usize| -> Result<u32, String> { Ok(u32::from_le_bytes(read4(off)?)) };
    let read_i32 = |off: &mut usize| -> Result<i32, String> { Ok(i32::from_le_bytes(read4(off)?)) };
    let read_f32 = |off: &mut usize| -> Result<f32, String> { Ok(f32::from_le_bytes(read4(off)?)) };
    let vc = read_u32(&mut offset)? as usize;
    // Guard BEFORE allocating: each vertex takes 12 bytes, so a count larger
    // than the remaining payload is rejected instead of driving a huge
    // Vec::with_capacity allocation from attacker-controlled input.
    if vc > data.len().saturating_sub(offset) / 12 {
        return Err("vertex count exceeds available data".to_string());
    }
    let mut vertices = Vec::with_capacity(vc);
    for _ in 0..vc {
        let x = read_f32(&mut offset)?;
        let y = read_f32(&mut offset)?;
        let z = read_f32(&mut offset)?;
        vertices.push(Vec3::new(x, y, z));
    }
    let tc = read_u32(&mut offset)? as usize;
    // Same allocation guard: each triangle takes 24 bytes.
    if tc > data.len().saturating_sub(offset) / 24 {
        return Err("triangle count exceeds available data".to_string());
    }
    let mut triangles = Vec::with_capacity(tc);
    for _ in 0..tc {
        let i0 = read_u32(&mut offset)? as usize;
        let i1 = read_u32(&mut offset)? as usize;
        let i2 = read_u32(&mut offset)? as usize;
        // Reject indices outside the vertex table; they would otherwise
        // panic later when the mesh is queried.
        if i0 >= vc || i1 >= vc || i2 >= vc {
            return Err("triangle vertex index out of range".to_string());
        }
        let n0 = read_i32(&mut offset)?;
        let n1 = read_i32(&mut offset)?;
        let n2 = read_i32(&mut offset)?;
        // Negative neighbor values are the serialized encoding for None.
        let to_opt = |v: i32| -> Option<usize> {
            if v < 0 { None } else { Some(v as usize) }
        };
        triangles.push(NavTriangle {
            indices: [i0, i1, i2],
            neighbors: [to_opt(n0), to_opt(n1), to_opt(n2)],
        });
    }
    Ok(NavMesh::new(vertices, triangles))
}
#[cfg(test)]
mod tests {
use super::*;
@@ -155,4 +394,104 @@ mod tests {
assert!((mid.y - 0.0).abs() < 1e-5);
assert!((mid.z - 1.0).abs() < 1e-5);
}
// Portal extraction between two known-adjacent triangles in the fixture.
#[test]
fn test_shared_edge() {
    let nm = make_simple_navmesh();
    // Tri 0 edge 1 connects to Tri 1: shared edge is v1=(2,0,0) and v2=(1,0,2)
    let (left, right) = nm.shared_edge(0, 1).expect("should find shared edge");
    assert_eq!(left, Vec3::new(2.0, 0.0, 0.0));
    assert_eq!(right, Vec3::new(1.0, 0.0, 2.0));
}
// shared_edge must return None when no neighbor entry matches the target.
#[test]
fn test_shared_edge_not_adjacent() {
    let nm = make_simple_navmesh();
    // Tri 0 is not adjacent to itself via neighbors
    assert!(nm.shared_edge(0, 0).is_none());
}
// Free-function serialize -> deserialize must preserve vertices (within
// float tolerance) and exact triangle indices/neighbors.
#[test]
fn test_serialize_roundtrip() {
    let nm = make_simple_navmesh();
    let data = serialize_navmesh(&nm);
    let nm2 = deserialize_navmesh(&data).expect("should deserialize");
    assert_eq!(nm2.vertices.len(), nm.vertices.len());
    assert_eq!(nm2.triangles.len(), nm.triangles.len());
    for (a, b) in nm.vertices.iter().zip(nm2.vertices.iter()) {
        assert!((a.x - b.x).abs() < 1e-6);
        assert!((a.y - b.y).abs() < 1e-6);
        assert!((a.z - b.z).abs() < 1e-6);
    }
    for (a, b) in nm.triangles.iter().zip(nm2.triangles.iter()) {
        assert_eq!(a.indices, b.indices);
        assert_eq!(a.neighbors, b.neighbors);
    }
}
// An empty byte slice cannot even hold the vertex count: must be an error.
#[test]
fn test_deserialize_empty_data() {
    let result = deserialize_navmesh(&[]);
    assert!(result.is_err());
}
// Same round-trip guarantee for the header-carrying method-based format
// (NavMesh::serialize / NavMesh::deserialize).
#[test]
fn test_method_serialize_deserialize_roundtrip() {
    let nm = make_simple_navmesh();
    let data = nm.serialize();
    let restored = NavMesh::deserialize(&data).expect("should deserialize");
    assert_eq!(nm.vertices.len(), restored.vertices.len());
    assert_eq!(nm.triangles.len(), restored.triangles.len());
    for (a, b) in nm.vertices.iter().zip(restored.vertices.iter()) {
        assert!((a.x - b.x).abs() < 1e-6);
        assert!((a.y - b.y).abs() < 1e-6);
        assert!((a.z - b.z).abs() < 1e-6);
    }
    for (a, b) in nm.triangles.iter().zip(restored.triangles.iter()) {
        assert_eq!(a.indices, b.indices);
        assert_eq!(a.neighbors, b.neighbors);
    }
}
// Wrong magic bytes must be rejected even when the rest looks plausible.
#[test]
fn test_deserialize_invalid_magic() {
    let data = b"XXXX\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
    assert!(NavMesh::deserialize(data).is_err());
}
// Input shorter than the 8-byte magic+version header must be rejected.
#[test]
fn test_deserialize_too_short() {
    let data = b"VNA";
    assert!(NavMesh::deserialize(data).is_err());
}
// An empty mesh must survive the method-based round trip.
#[test]
fn test_serialize_empty_navmesh() {
    let nm = NavMesh::new(vec![], vec![]);
    let data = nm.serialize();
    let restored = NavMesh::deserialize(&data).expect("should deserialize empty");
    assert!(restored.vertices.is_empty());
    assert!(restored.triangles.is_empty());
}
// Pins the on-disk header: 4-byte "VNAV" magic followed by a u32 version of 1.
#[test]
fn test_serialize_header_format() {
    let nm = NavMesh::new(vec![], vec![]);
    let data = nm.serialize();
    // Check magic
    assert_eq!(&data[0..4], b"VNAV");
    // Check version = 1
    let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
    assert_eq!(version, 1);
}
// A future/unknown version number (99) must be rejected up front.
#[test]
fn test_deserialize_unsupported_version() {
    let mut data = Vec::new();
    data.extend_from_slice(b"VNAV");
    data.extend_from_slice(&99u32.to_le_bytes());
    data.extend_from_slice(&0u32.to_le_bytes());
    data.extend_from_slice(&0u32.to_le_bytes());
    assert!(NavMesh::deserialize(&data).is_err());
}
}

View File

@@ -0,0 +1,651 @@
use voltex_math::Vec3;
use crate::navmesh::{NavMesh, NavTriangle};
/// Configuration for navmesh generation.
///
/// Plain-data parameter bag; derives `Debug` for diagnostics and `Clone` so
/// callers can tweak a copy of a base configuration.
#[derive(Debug, Clone)]
pub struct NavMeshBuilder {
    /// XZ voxel size (default 0.3)
    pub cell_size: f32,
    /// Y voxel size (default 0.2)
    pub cell_height: f32,
    /// Minimum clearance height for walkable areas (default 2.0)
    pub agent_height: f32,
    /// Agent capsule radius used to erode walkable areas (default 0.5)
    pub agent_radius: f32,
    /// Maximum walkable slope angle in degrees (default 45.0)
    pub max_slope: f32,
}
impl Default for NavMeshBuilder {
fn default() -> Self {
Self {
cell_size: 0.3,
cell_height: 0.2,
agent_height: 2.0,
agent_radius: 0.5,
max_slope: 45.0,
}
}
}
/// Heightfield: a 2D grid of min/max height spans for voxelized geometry.
///
/// Cells are stored row-major (index = z * width + x, see `cell_index`);
/// `min_x`/`min_z` anchor cell (0, 0) in world space.
pub struct Heightfield {
    pub width: usize, // number of cells along X
    pub depth: usize, // number of cells along Z
    pub min_x: f32,   // world-space X of the grid origin
    pub min_z: f32,   // world-space Z of the grid origin
    pub cell_size: f32,   // XZ extent of one cell
    pub cell_height: f32, // Y voxel size (carried through from the builder)
    /// For each cell (x, z), store (min_y, max_y). None if empty.
    pub cells: Vec<Option<(f32, f32)>>,
}
impl Heightfield {
    /// Row-major flattened index of grid cell (x, z) into `cells`.
    pub fn cell_index(&self, x: usize, z: usize) -> usize {
        x + z * self.width
    }
}
/// Map of walkable cells. True if the cell is walkable.
///
/// Shares the Heightfield's grid layout: row-major cells anchored at
/// (`min_x`, `min_z`), indexed via `cell_index`.
pub struct WalkableMap {
    pub width: usize,  // number of cells along X
    pub depth: usize,  // number of cells along Z
    pub min_x: f32,    // world-space X of the grid origin
    pub min_z: f32,    // world-space Z of the grid origin
    pub cell_size: f32, // XZ extent of one cell
    pub walkable: Vec<bool>,
    /// Height at each walkable cell (top of walkable surface).
    pub heights: Vec<f32>,
}
impl WalkableMap {
    /// Row-major flattened index of grid cell (x, z) into `walkable`/`heights`.
    pub fn cell_index(&self, x: usize, z: usize) -> usize {
        x + z * self.width
    }
}
/// A region of connected walkable cells (flood-fill result).
///
/// Same row-major grid layout as WalkableMap/Heightfield.
pub struct RegionMap {
    pub width: usize,  // number of cells along X
    pub depth: usize,  // number of cells along Z
    pub min_x: f32,    // world-space X of the grid origin
    pub min_z: f32,    // world-space Z of the grid origin
    pub cell_size: f32, // XZ extent of one cell
    /// Region ID per cell. 0 = not walkable, 1+ = region ID.
    pub regions: Vec<u32>,
    /// Walkable-surface height per cell (copied from the WalkableMap).
    pub heights: Vec<f32>,
    /// Number of distinct regions found (also the highest region ID used).
    pub num_regions: u32,
}
impl NavMeshBuilder {
    /// Create a builder with the default parameters (see `Default`).
    pub fn new() -> Self {
        Self::default()
    }
    /// Step 1: Rasterize triangles into a heightfield grid.
    ///
    /// Lays a `cell_size` XZ grid over the geometry's bounding box; every cell
    /// whose center falls inside a triangle (with half-cell tolerance) records
    /// the min/max interpolated surface height at that center.
    pub fn voxelize(&self, vertices: &[Vec3], indices: &[u32]) -> Heightfield {
        // Find bounding box
        let mut min_x = f32::INFINITY;
        let mut max_x = f32::NEG_INFINITY;
        let mut min_y = f32::INFINITY;
        let mut max_y = f32::NEG_INFINITY;
        let mut min_z = f32::INFINITY;
        let mut max_z = f32::NEG_INFINITY;
        for v in vertices {
            min_x = min_x.min(v.x);
            max_x = max_x.max(v.x);
            // Y extent is tracked here but not used below; only XZ drives the grid.
            min_y = min_y.min(v.y);
            max_y = max_y.max(v.y);
            min_z = min_z.min(v.z);
            max_z = max_z.max(v.z);
        }
        // +1 so geometry landing exactly on the far edge still gets a cell.
        let width = ((max_x - min_x) / self.cell_size).ceil() as usize + 1;
        let depth = ((max_z - min_z) / self.cell_size).ceil() as usize + 1;
        let mut cells: Vec<Option<(f32, f32)>> = vec![None; width * depth];
        // Rasterize each triangle
        for tri_i in (0..indices.len()).step_by(3) {
            // Ignore a trailing partial triangle (index count not divisible by 3).
            if tri_i + 2 >= indices.len() {
                break;
            }
            let v0 = vertices[indices[tri_i] as usize];
            let v1 = vertices[indices[tri_i + 1] as usize];
            let v2 = vertices[indices[tri_i + 2] as usize];
            // Find bounding box of triangle in grid coords
            let tri_min_x = v0.x.min(v1.x).min(v2.x);
            let tri_max_x = v0.x.max(v1.x).max(v2.x);
            let tri_min_z = v0.z.min(v1.z).min(v2.z);
            let tri_max_z = v0.z.max(v1.z).max(v2.z);
            let gx0 = ((tri_min_x - min_x) / self.cell_size).floor() as isize;
            let gx1 = ((tri_max_x - min_x) / self.cell_size).ceil() as isize;
            let gz0 = ((tri_min_z - min_z) / self.cell_size).floor() as isize;
            let gz1 = ((tri_max_z - min_z) / self.cell_size).ceil() as isize;
            // Clamp the triangle's cell range to the grid.
            let gx0 = gx0.max(0) as usize;
            let gx1 = (gx1 as usize).min(width - 1);
            let gz0 = gz0.max(0) as usize;
            let gz1 = (gz1 as usize).min(depth - 1);
            for gz in gz0..=gz1 {
                for gx in gx0..=gx1 {
                    // World-space center of this cell.
                    let cx = min_x + (gx as f32 + 0.5) * self.cell_size;
                    let cz = min_z + (gz as f32 + 0.5) * self.cell_size;
                    // Check if cell center is inside triangle (XZ)
                    let p = Vec3::new(cx, 0.0, cz);
                    if point_in_triangle_xz_loose(p, v0, v1, v2, self.cell_size * 0.5) {
                        // Interpolate Y at this XZ point
                        let y = interpolate_y(v0, v1, v2, cx, cz);
                        let idx = gz * width + gx;
                        // Merge into the cell's (min_y, max_y) span.
                        match &mut cells[idx] {
                            Some((ref mut lo, ref mut hi)) => {
                                *lo = lo.min(y);
                                *hi = hi.max(y);
                            }
                            None => {
                                cells[idx] = Some((y, y));
                            }
                        }
                    }
                }
            }
        }
        Heightfield {
            width,
            depth,
            min_x,
            min_z,
            cell_size: self.cell_size,
            cell_height: self.cell_height,
            cells,
        }
    }
    /// Step 2: Mark walkable cells based on slope and agent height clearance.
    ///
    /// A populated cell is walkable when its estimated slope is within
    /// `max_slope`; the walkable set is then eroded by `agent_radius` so the
    /// agent's capsule never overlaps a non-walkable cell or the grid border.
    pub fn mark_walkable(&self, hf: &Heightfield) -> WalkableMap {
        let max_slope_cos = (self.max_slope * std::f32::consts::PI / 180.0).cos();
        let n = hf.width * hf.depth;
        let mut walkable = vec![false; n];
        let mut heights = vec![0.0f32; n];
        for z in 0..hf.depth {
            for x in 0..hf.width {
                let idx = z * hf.width + x;
                if let Some((_lo, hi)) = hf.cells[idx] {
                    // Check slope by comparing height differences with neighbors
                    let slope_ok = self.check_slope(hf, x, z, max_slope_cos);
                    // Check clearance: for simplicity, if cell has geometry, assume clearance
                    // unless there's a cell above within agent_height (not implemented for simple case)
                    if slope_ok {
                        walkable[idx] = true;
                        heights[idx] = hi;
                    }
                }
            }
        }
        // Erode by agent radius: remove walkable cells too close to non-walkable
        let erosion_cells = (self.agent_radius / hf.cell_size).ceil() as usize;
        if erosion_cells > 0 {
            let mut eroded = walkable.clone();
            for z in 0..hf.depth {
                for x in 0..hf.width {
                    let idx = z * hf.width + x;
                    if !walkable[idx] {
                        continue;
                    }
                    // Check if near boundary of walkable area
                    let mut near_edge = false;
                    for dz in 0..=erosion_cells {
                        for dx in 0..=erosion_cells {
                            if dx == 0 && dz == 0 {
                                continue;
                            }
                            // Check all 4 quadrants
                            let checks: [(isize, isize); 4] = [
                                (dx as isize, dz as isize),
                                (-(dx as isize), dz as isize),
                                (dx as isize, -(dz as isize)),
                                (-(dx as isize), -(dz as isize)),
                            ];
                            for (ddx, ddz) in checks {
                                let nx = x as isize + ddx;
                                let nz = z as isize + ddz;
                                // Grid border counts as an edge, same as a non-walkable cell.
                                if nx < 0 || nz < 0 || nx >= hf.width as isize || nz >= hf.depth as isize {
                                    near_edge = true;
                                    break;
                                }
                                let ni = nz as usize * hf.width + nx as usize;
                                if !walkable[ni] {
                                    near_edge = true;
                                    break;
                                }
                            }
                            if near_edge {
                                break;
                            }
                        }
                        if near_edge {
                            break;
                        }
                    }
                    if near_edge {
                        eroded[idx] = false;
                    }
                }
            }
            return WalkableMap {
                width: hf.width,
                depth: hf.depth,
                min_x: hf.min_x,
                min_z: hf.min_z,
                cell_size: hf.cell_size,
                walkable: eroded,
                heights,
            };
        }
        WalkableMap {
            width: hf.width,
            depth: hf.depth,
            min_x: hf.min_x,
            min_z: hf.min_z,
            cell_size: hf.cell_size,
            walkable,
            heights,
        }
    }
    /// Check if the slope at cell (x, z) is walkable.
    ///
    /// Estimates the slope from the height difference to each 4-connected
    /// neighbor; all neighbor slopes must be within `max_slope` (compared via
    /// cosine to avoid a per-cell atan).
    fn check_slope(&self, hf: &Heightfield, x: usize, z: usize, max_slope_cos: f32) -> bool {
        let idx = z * hf.width + x;
        let h = match hf.cells[idx] {
            Some((_, hi)) => hi,
            None => return false,
        };
        // Compare with direct neighbors to estimate slope
        let neighbors: [(isize, isize); 4] = [(1, 0), (-1, 0), (0, 1), (0, -1)];
        for (dx, dz) in neighbors {
            let nx = x as isize + dx;
            let nz = z as isize + dz;
            // Missing neighbors (grid border or empty cells) don't constrain slope.
            if nx < 0 || nz < 0 || nx >= hf.width as isize || nz >= hf.depth as isize {
                continue;
            }
            let ni = nz as usize * hf.width + nx as usize;
            if let Some((_, nh)) = hf.cells[ni] {
                let dy = (nh - h).abs();
                let dist = hf.cell_size;
                // slope angle: atan(dy/dist), check cos of that angle
                let slope_len = (dy * dy + dist * dist).sqrt();
                let cos_angle = dist / slope_len;
                if cos_angle < max_slope_cos {
                    return false;
                }
            }
        }
        true
    }
    /// Step 3: Flood-fill connected walkable areas into regions.
    ///
    /// Uses an explicit stack (no recursion) over 4-connected neighbors;
    /// each connected component gets a region ID starting at 1.
    pub fn build_regions(&self, wm: &WalkableMap) -> RegionMap {
        let n = wm.width * wm.depth;
        let mut regions = vec![0u32; n];
        let mut current_region = 0u32;
        for z in 0..wm.depth {
            for x in 0..wm.width {
                let idx = z * wm.width + x;
                // Unassigned walkable cell: seed a new region here.
                if wm.walkable[idx] && regions[idx] == 0 {
                    current_region += 1;
                    // Flood fill
                    let mut stack = vec![(x, z)];
                    regions[idx] = current_region;
                    while let Some((cx, cz)) = stack.pop() {
                        let neighbors: [(isize, isize); 4] = [(1, 0), (-1, 0), (0, 1), (0, -1)];
                        for (dx, dz) in neighbors {
                            let nx = cx as isize + dx;
                            let nz = cz as isize + dz;
                            if nx < 0 || nz < 0 || nx >= wm.width as isize || nz >= wm.depth as isize {
                                continue;
                            }
                            let ni = nz as usize * wm.width + nx as usize;
                            if wm.walkable[ni] && regions[ni] == 0 {
                                regions[ni] = current_region;
                                stack.push((nx as usize, nz as usize));
                            }
                        }
                    }
                }
            }
        }
        RegionMap {
            width: wm.width,
            depth: wm.depth,
            min_x: wm.min_x,
            min_z: wm.min_z,
            cell_size: wm.cell_size,
            regions,
            heights: wm.heights.clone(),
            num_regions: current_region,
        }
    }
    /// Steps 4-5 combined: Convert walkable grid cells directly into a NavMesh.
    /// Each walkable cell becomes a quad (2 triangles), with adjacency computed.
    ///
    /// Note: vertices are NOT deduplicated between cells — each cell emits its
    /// own 4 corners; adjacency is established purely through neighbor links.
    pub fn triangulate(&self, rm: &RegionMap) -> NavMesh {
        let mut vertices = Vec::new();
        let mut triangles = Vec::new();
        // For each walkable cell, create 4 vertices and 2 triangles.
        // Map from cell (x,z) -> (tri_a_idx, tri_b_idx) for adjacency lookup.
        let n = rm.width * rm.depth;
        // cell_tri_map[cell_idx] = Some((tri_a_idx, tri_b_idx)) or None
        let mut cell_tri_map: Vec<Option<(usize, usize)>> = vec![None; n];
        for z in 0..rm.depth {
            for x in 0..rm.width {
                let idx = z * rm.width + x;
                // Region 0 means "not walkable" — no geometry for this cell.
                if rm.regions[idx] == 0 {
                    continue;
                }
                let h = rm.heights[idx];
                let x0 = rm.min_x + x as f32 * rm.cell_size;
                let x1 = x0 + rm.cell_size;
                let z0 = rm.min_z + z as f32 * rm.cell_size;
                let z1 = z0 + rm.cell_size;
                let vi = vertices.len();
                vertices.push(Vec3::new(x0, h, z0)); // vi+0: bottom-left
                vertices.push(Vec3::new(x1, h, z0)); // vi+1: bottom-right
                vertices.push(Vec3::new(x1, h, z1)); // vi+2: top-right
                vertices.push(Vec3::new(x0, h, z1)); // vi+3: top-left
                let ta = triangles.len();
                // Triangle A: bottom-left, bottom-right, top-right (vi+0, vi+1, vi+2)
                triangles.push(NavTriangle {
                    indices: [vi, vi + 1, vi + 2],
                    neighbors: [None, None, None], // filled in later
                });
                // Triangle B: bottom-left, top-right, top-left (vi+0, vi+2, vi+3)
                triangles.push(NavTriangle {
                    indices: [vi, vi + 2, vi + 3],
                    neighbors: [None, None, None],
                });
                // Internal adjacency: A and B share edge (vi+0, vi+2)
                // For A: edge 2 is (vi+2 -> vi+0) — indices[2] to indices[0]
                // For B: edge 0 is (vi+0 -> vi+2) — indices[0] to indices[1]
                triangles[ta].neighbors[2] = Some(ta + 1); // A's edge2 -> B
                triangles[ta + 1].neighbors[0] = Some(ta); // B's edge0 -> A
                cell_tri_map[idx] = Some((ta, ta + 1));
            }
        }
        // Now compute inter-cell adjacency
        // Cell layout:
        //   Tri A: (v0, v1, v2) = (BL, BR, TR)
        //     edge 0: BL->BR (bottom edge, connects to cell below z-1)
        //     edge 1: BR->TR (right edge, connects to cell at x+1)
        //     edge 2: TR->BL (diagonal, internal, already connected to B)
        //   Tri B: (v0, v2, v3) = (BL, TR, TL)
        //     edge 0: BL->TR (diagonal, internal, already connected to A)
        //     edge 1: TR->TL (top edge, connects to cell above z+1)
        //     edge 2: TL->BL (left edge, connects to cell at x-1)
        // Only bottom/right are linked per cell; the top/left links are set
        // symmetrically when the neighboring cell is visited.
        for z in 0..rm.depth {
            for x in 0..rm.width {
                let idx = z * rm.width + x;
                if cell_tri_map[idx].is_none() {
                    continue;
                }
                let (ta, tb) = cell_tri_map[idx].unwrap();
                // Bottom neighbor (z-1): A's edge 0 connects to neighbor's B edge 1 (TR->TL = top)
                if z > 0 {
                    let ni = (z - 1) * rm.width + x;
                    if let Some((_, nb_tb)) = cell_tri_map[ni] {
                        triangles[ta].neighbors[0] = Some(nb_tb);
                        triangles[nb_tb].neighbors[1] = Some(ta);
                    }
                }
                // Right neighbor (x+1): A's edge 1 connects to neighbor's B edge 2 (TL->BL = left)
                if x + 1 < rm.width {
                    let ni = z * rm.width + (x + 1);
                    if let Some((_, nb_tb)) = cell_tri_map[ni] {
                        triangles[ta].neighbors[1] = Some(nb_tb);
                        triangles[nb_tb].neighbors[2] = Some(ta);
                    }
                }
            }
        }
        NavMesh::new(vertices, triangles)
    }
    /// Full pipeline: voxelize, mark walkable, build regions, triangulate.
    pub fn build(&self, vertices: &[Vec3], indices: &[u32]) -> NavMesh {
        let hf = self.voxelize(vertices, indices);
        let wm = self.mark_walkable(&hf);
        let rm = self.build_regions(&wm);
        self.triangulate(&rm)
    }
}
/// Loose point-in-triangle test on XZ plane, with a tolerance margin.
///
/// Barycentric coordinates are computed in XZ; the margin is normalized by
/// the longest triangle edge so the tolerance is expressed in barycentric
/// units. Degenerate (zero-area) triangles always report false.
fn point_in_triangle_xz_loose(point: Vec3, a: Vec3, b: Vec3, c: Vec3, margin: f32) -> bool {
    let denom = (b.z - c.z) * (a.x - c.x) + (c.x - b.x) * (a.z - c.z);
    if denom.abs() < f32::EPSILON {
        // Degenerate: no interior to be inside of.
        return false;
    }
    let (px, pz) = (point.x, point.z);
    let u = ((b.z - c.z) * (px - c.x) + (c.x - b.x) * (pz - c.z)) / denom;
    let v = ((c.z - a.z) * (px - c.x) + (a.x - c.x) * (pz - c.z)) / denom;
    let w = 1.0 - u - v;
    // Longest edge length, clamped away from zero to keep the division sane.
    let longest_edge = (a - b)
        .length()
        .max((b - c).length())
        .max((c - a).length())
        .max(0.001);
    let tol = margin / longest_edge;
    u >= -tol && v >= -tol && w >= -tol
}
/// Interpolate Y height at (px, pz) on the plane defined by triangle (a, b, c).
///
/// Uses XZ barycentric weights; a degenerate (zero XZ area) triangle falls
/// back to the average of the three vertex heights.
fn interpolate_y(a: Vec3, b: Vec3, c: Vec3, px: f32, pz: f32) -> f32 {
    let denom = (b.z - c.z) * (a.x - c.x) + (c.x - b.x) * (a.z - c.z);
    if denom.abs() < f32::EPSILON {
        // Degenerate: centroid height is the best we can do.
        return (a.y + b.y + c.y) / 3.0;
    }
    let weight_a = ((b.z - c.z) * (px - c.x) + (c.x - b.x) * (pz - c.z)) / denom;
    let weight_b = ((c.z - a.z) * (px - c.x) + (a.x - c.x) * (pz - c.z)) / denom;
    let weight_c = 1.0 - weight_a - weight_b;
    weight_a * a.y + weight_b * b.y + weight_c * c.y
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Create a simple flat plane (10x10, y=0) as 2 triangles.
    fn flat_plane_geometry() -> (Vec<Vec3>, Vec<u32>) {
        let vertices = vec![
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(10.0, 0.0, 0.0),
            Vec3::new(10.0, 0.0, 10.0),
            Vec3::new(0.0, 0.0, 10.0),
        ];
        let indices = vec![0, 1, 2, 0, 2, 3];
        (vertices, indices)
    }
    /// Create a plane with a box obstacle (hole) in the middle.
    /// The plane is 10x10 with a 2x2 hole at center (4-6, 4-6).
    fn plane_with_obstacle() -> (Vec<Vec3>, Vec<u32>) {
        // Build geometry as a grid of quads, skipping the obstacle region.
        // Use 1.0 unit grid cells for simplicity.
        let mut vertices = Vec::new();
        let mut indices = Vec::new();
        // 10x10 grid of 1x1 quads, skip 4<=x<6 && 4<=z<6
        for z in 0..10 {
            for x in 0..10 {
                if x >= 4 && x < 6 && z >= 4 && z < 6 {
                    continue; // obstacle
                }
                let vi = vertices.len() as u32;
                let fx = x as f32;
                let fz = z as f32;
                // Quad corners, then two CCW triangles over them.
                vertices.push(Vec3::new(fx, 0.0, fz));
                vertices.push(Vec3::new(fx + 1.0, 0.0, fz));
                vertices.push(Vec3::new(fx + 1.0, 0.0, fz + 1.0));
                vertices.push(Vec3::new(fx, 0.0, fz + 1.0));
                indices.push(vi);
                indices.push(vi + 1);
                indices.push(vi + 2);
                indices.push(vi);
                indices.push(vi + 2);
                indices.push(vi + 3);
            }
        }
        (vertices, indices)
    }
    // Voxelization of a flat plane must produce a non-empty grid with
    // at least some populated cells.
    #[test]
    fn test_voxelize_flat_plane() {
        let builder = NavMeshBuilder {
            cell_size: 1.0,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0, // no erosion for simple test
            max_slope: 45.0,
        };
        let (verts, idxs) = flat_plane_geometry();
        let hf = builder.voxelize(&verts, &idxs);
        // Should have cells covering the 10x10 area
        assert!(hf.width > 0);
        assert!(hf.depth > 0);
        // Check that some cells are populated
        let populated = hf.cells.iter().filter(|c| c.is_some()).count();
        assert!(populated > 0, "some cells should be populated");
    }
    // A fully connected flat plane must flood-fill into exactly one region.
    #[test]
    fn test_flat_plane_single_region() {
        let builder = NavMeshBuilder {
            cell_size: 1.0,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0,
            max_slope: 45.0,
        };
        let (verts, idxs) = flat_plane_geometry();
        let hf = builder.voxelize(&verts, &idxs);
        let wm = builder.mark_walkable(&hf);
        let rm = builder.build_regions(&wm);
        assert_eq!(rm.num_regions, 1, "flat plane should be a single region");
    }
    // End-to-end build: the resulting navmesh must have geometry and cover
    // the center of the input plane.
    #[test]
    fn test_flat_plane_builds_navmesh() {
        let builder = NavMeshBuilder {
            cell_size: 1.0,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0,
            max_slope: 45.0,
        };
        let (verts, idxs) = flat_plane_geometry();
        let nm = builder.build(&verts, &idxs);
        assert!(!nm.vertices.is_empty(), "navmesh should have vertices");
        assert!(!nm.triangles.is_empty(), "navmesh should have triangles");
        // Should be able to find a triangle at center of the plane
        let center = Vec3::new(5.0, 0.0, 5.0);
        assert!(nm.find_triangle(center).is_some(), "should find triangle at center");
    }
    // Pathfinding across a built navmesh must route around the hole in the
    // input geometry.
    #[test]
    fn test_obstacle_path_around() {
        let builder = NavMeshBuilder {
            cell_size: 1.0,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0,
            max_slope: 45.0,
        };
        let (verts, idxs) = plane_with_obstacle();
        let nm = builder.build(&verts, &idxs);
        // Start at (1, 0, 5) and goal at (9, 0, 5) — must go around obstacle
        let start = Vec3::new(1.5, 0.0, 5.5);
        let goal = Vec3::new(8.5, 0.0, 5.5);
        use crate::pathfinding::find_path;
        let path = find_path(&nm, start, goal);
        assert!(path.is_some(), "should find path around obstacle");
    }
    // Slope filtering: a very steep ramp must reject at least some cells.
    #[test]
    fn test_slope_walkable_unwalkable() {
        // Create a steep ramp: triangle from (0,0,0) to (1,0,0) to (0.5, 5, 1)
        // Slope angle = atan(5/1) ≈ 78.7 degrees — should be unwalkable at max_slope=45
        let builder = NavMeshBuilder {
            cell_size: 0.2,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0,
            max_slope: 45.0,
        };
        let vertices = vec![
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(2.0, 0.0, 0.0),
            Vec3::new(1.0, 10.0, 2.0),
        ];
        let indices = vec![0, 1, 2];
        let hf = builder.voxelize(&vertices, &indices);
        let wm = builder.mark_walkable(&hf);
        // Most interior cells should be unwalkable due to steep slope
        let walkable_count = wm.walkable.iter().filter(|&&w| w).count();
        // The bottom edge cells might be walkable (flat), but interior should not
        // Just check that not all cells are walkable
        let total_cells = hf.cells.iter().filter(|c| c.is_some()).count();
        assert!(
            walkable_count < total_cells || total_cells <= 1,
            "steep slope should make most cells unwalkable (walkable={}, total={})",
            walkable_count, total_cells
        );
    }
    // Triangulation must wire up neighbor links between adjacent cells.
    #[test]
    fn test_navmesh_adjacency() {
        let builder = NavMeshBuilder {
            cell_size: 1.0,
            cell_height: 0.2,
            agent_height: 2.0,
            agent_radius: 0.0,
            max_slope: 45.0,
        };
        let (verts, idxs) = flat_plane_geometry();
        let nm = builder.build(&verts, &idxs);
        // Check that some triangles have neighbors
        let has_neighbors = nm.triangles.iter().any(|t| t.neighbors.iter().any(|n| n.is_some()));
        assert!(has_neighbors, "navmesh triangles should have adjacency");
    }
}

View File

@@ -0,0 +1,176 @@
use voltex_math::Vec3;
/// A dynamic obstacle represented as a position and radius.
#[derive(Debug, Clone)]
pub struct DynamicObstacle {
    /// World-space center of the obstacle.
    pub position: Vec3,
    /// Collision radius around `position`.
    pub radius: f32,
}
/// Compute avoidance steering force using velocity obstacle approach.
///
/// Projects the agent's velocity forward by `look_ahead` distance and checks
/// for circle intersections with obstacles. Returns a steering force perpendicular
/// to the approach direction to avoid the nearest threatening obstacle.
pub fn avoid_obstacles(
    agent_pos: Vec3,
    agent_vel: Vec3,
    agent_radius: f32,
    obstacles: &[DynamicObstacle],
    look_ahead: f32,
) -> Vec3 {
    let speed = agent_vel.length();
    if speed < f32::EPSILON {
        // Stationary agent: nothing ahead to avoid.
        return Vec3::ZERO;
    }
    let heading = agent_vel * (1.0 / speed);
    let mut closest_hit = f32::INFINITY;
    let mut steering = Vec3::ZERO;
    for obstacle in obstacles {
        let rel = obstacle.position - agent_pos;
        let hit_radius = agent_radius + obstacle.radius;
        // Distance along the velocity ray to the obstacle's closest approach.
        let along = rel.dot(heading);
        if along < 0.0 || along > look_ahead {
            continue; // behind the agent or beyond the look-ahead horizon
        }
        // XZ lateral offset between the ray and the obstacle center
        // (ground agents: Y is ignored for the collision test).
        let on_ray = agent_pos + heading * along;
        let offset = obstacle.position - on_ray;
        let lateral_sq = offset.x * offset.x + offset.z * offset.z;
        if lateral_sq >= hit_radius * hit_radius {
            continue; // the swept circle clears this obstacle
        }
        // Threatening obstacle: keep only the nearest one along the ray.
        if along < closest_hit {
            closest_hit = along;
            let lateral = Vec3::new(offset.x, 0.0, offset.z);
            let lateral_len = lateral.length();
            if lateral_len > f32::EPSILON {
                // Push directly away from the obstacle's lateral offset,
                // scaled up the closer the threat is.
                let away = lateral * (-1.0 / lateral_len);
                let strength = 1.0 - (along / look_ahead);
                steering = away * strength * speed;
            } else {
                // Dead-center collision course: pick an arbitrary XZ
                // perpendicular to the heading.
                let perpendicular = Vec3::new(-heading.z, 0.0, heading.x);
                steering = perpendicular * speed;
            }
        }
    }
    steering
}
#[cfg(test)]
mod tests {
    use super::*;
    // No obstacles at all: steering must be zero.
    #[test]
    fn test_no_obstacle_zero_force() {
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(1.0, 0.0, 0.0),
            0.5,
            &[],
            5.0,
        );
        assert!(force.length() < 1e-6, "no obstacles should give zero force");
    }
    // Obstacles behind the agent (negative projection) are ignored.
    #[test]
    fn test_obstacle_behind_zero_force() {
        let obs = DynamicObstacle {
            position: Vec3::new(-3.0, 0.0, 0.0),
            radius: 1.0,
        };
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(1.0, 0.0, 0.0),
            0.5,
            &[obs],
            5.0,
        );
        assert!(force.length() < 1e-6, "obstacle behind should give zero force");
    }
    // A threatening obstacle ahead and slightly to +Z must produce a -Z push.
    #[test]
    fn test_obstacle_ahead_lateral_force() {
        let obs = DynamicObstacle {
            position: Vec3::new(3.0, 0.0, 0.5), // slightly to the right
            radius: 1.0,
        };
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(2.0, 0.0, 0.0), // moving in +X
            0.5,
            &[obs],
            5.0,
        );
        assert!(force.length() > 0.1, "obstacle ahead should give non-zero force");
        // Force should push away from obstacle (obstacle is at +Z, force should be -Z)
        assert!(force.z < 0.0, "force should push away from obstacle (negative Z)");
    }
    // Obstacle laterally outside the combined radius never threatens.
    #[test]
    fn test_obstacle_far_away_zero_force() {
        let obs = DynamicObstacle {
            position: Vec3::new(3.0, 0.0, 10.0), // far to the side
            radius: 1.0,
        };
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(1.0, 0.0, 0.0),
            0.5,
            &[obs],
            5.0,
        );
        assert!(force.length() < 1e-6, "distant obstacle should give zero force");
    }
    // Projection beyond look_ahead is filtered out even on a collision course.
    #[test]
    fn test_obstacle_beyond_lookahead_zero_force() {
        let obs = DynamicObstacle {
            position: Vec3::new(10.0, 0.0, 0.0),
            radius: 1.0,
        };
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(1.0, 0.0, 0.0),
            0.5,
            &[obs],
            5.0, // look_ahead is only 5
        );
        assert!(force.length() < 1e-6, "obstacle beyond lookahead should give zero force");
    }
    // Zero velocity short-circuits before any obstacle is examined.
    #[test]
    fn test_zero_velocity_zero_force() {
        let obs = DynamicObstacle {
            position: Vec3::new(3.0, 0.0, 0.0),
            radius: 1.0,
        };
        let force = avoid_obstacles(
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::ZERO,
            0.5,
            &[obs],
            5.0,
        );
        assert!(force.length() < 1e-6, "zero velocity should give zero force");
    }
}

View File

@@ -43,29 +43,22 @@ pub fn distance_xz(a: Vec3, b: Vec3) -> f32 {
(dx * dx + dz * dz).sqrt()
}
/// Find a path from `start` to `goal` on the given NavMesh using A*.
///
/// Returns Some(path) where path[0] == start, path[last] == goal, and
/// intermediate points are triangle centers. Returns None if either
/// point is outside the mesh or no path exists.
pub fn find_path(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<Vec3>> {
/// Run A* on the NavMesh and return the sequence of triangle indices from start to goal.
/// Returns None if either point is outside the mesh or no path exists.
pub fn find_path_triangles(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<usize>> {
let start_tri = navmesh.find_triangle(start)?;
let goal_tri = navmesh.find_triangle(goal)?;
// If start and goal are in the same triangle, path is direct.
if start_tri == goal_tri {
return Some(vec![start, goal]);
return Some(vec![start_tri]);
}
let n = navmesh.triangles.len();
// g_cost[i] = best known cost to reach triangle i
let mut g_costs = vec![f32::INFINITY; n];
// parent[i] = index of parent triangle in the A* tree
let mut parents: Vec<Option<usize>> = vec![None; n];
let mut visited = vec![false; n];
let goal_center = navmesh.triangle_center(goal_tri);
g_costs[start_tri] = 0.0;
let start_center = navmesh.triangle_center(start_tri);
let h = distance_xz(start_center, goal_center);
@@ -88,7 +81,6 @@ pub fn find_path(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<Vec3>
parents[idx] = node.parent;
if idx == goal_tri {
// Reconstruct path
let mut tri_path = Vec::new();
let mut cur = idx;
loop {
@@ -99,17 +91,7 @@ pub fn find_path(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<Vec3>
}
}
tri_path.reverse();
// Convert triangle path to Vec3 waypoints:
// start point -> intermediate triangle centers -> goal point
let mut path = Vec::new();
path.push(start);
// skip first (start_tri) and last (goal_tri) in intermediate centers
for &ti in &tri_path[1..tri_path.len() - 1] {
path.push(navmesh.triangle_center(ti));
}
path.push(goal);
return Some(path);
return Some(tri_path);
}
let tri = &navmesh.triangles[idx];
@@ -136,7 +118,138 @@ pub fn find_path(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<Vec3>
}
}
None // No path found
None
}
/// Find a path from `start` to `goal` on the given NavMesh using A*.
///
/// Returns Some(path) where path[0] == start, path[last] == goal, and
/// intermediate points are triangle centers. Returns None if either
/// point is outside the mesh or no path exists.
pub fn find_path(navmesh: &NavMesh, start: Vec3, goal: Vec3) -> Option<Vec<Vec3>> {
    let corridor = find_path_triangles(navmesh, start, goal)?;
    // Single-triangle corridor: walk straight from start to goal.
    if corridor.len() == 1 {
        return Some(vec![start, goal]);
    }
    // Waypoints: start, then the centers of the interior triangles
    // (excluding the start and goal triangles), then goal.
    let mut waypoints = vec![start];
    waypoints.extend(
        corridor[1..corridor.len() - 1]
            .iter()
            .map(|&tri| navmesh.triangle_center(tri)),
    );
    waypoints.push(goal);
    Some(waypoints)
}
/// 2D cross product on XZ plane: (b - a) x (c - a) projected onto Y.
///
/// Positive when c lies to one side of the a->b direction, negative on the
/// other; zero when the three points are collinear in XZ.
fn cross_xz(a: Vec3, b: Vec3, c: Vec3) -> f32 {
    let (ab_x, ab_z) = (b.x - a.x, b.z - a.z);
    let (ac_x, ac_z) = (c.x - a.x, c.z - a.z);
    ab_x * ac_z - ab_z * ac_x
}
/// Simple Stupid Funnel (SSF) algorithm for string-pulling a triangle corridor path.
///
/// Given the sequence of triangle indices from A*, produces an optimal path with
/// waypoints at portal edge corners where the path turns.
///
/// When the funnel collapses (one side crosses over the other), the crossing
/// side's corner becomes the new apex and the scan RESTARTS from the portal
/// just after the apex. This restart is the canonical SSF behavior; merely
/// continuing the scan at the current portal can skip corners and yield a
/// path that is not taut.
pub fn funnel_smooth(path_triangles: &[usize], navmesh: &NavMesh, start: Vec3, end: Vec3) -> Vec<Vec3> {
    // Trivial corridors (empty or a single triangle): straight segment.
    if path_triangles.len() <= 1 {
        return vec![start, end];
    }
    // Build portal list: shared edges between consecutive triangles.
    let mut portals_left = Vec::with_capacity(path_triangles.len());
    let mut portals_right = Vec::with_capacity(path_triangles.len());
    for pair in path_triangles.windows(2) {
        if let Some((left, right)) = navmesh.shared_edge(pair[0], pair[1]) {
            portals_left.push(left);
            portals_right.push(right);
        }
    }
    // Add end point as the final portal (degenerate portal: both sides = end).
    portals_left.push(end);
    portals_right.push(end);
    let mut path = vec![start];
    let mut apex = start;
    let mut left = start;
    let mut right = start;
    let mut apex_idx: usize = 0;
    let mut left_idx: usize = 0;
    let mut right_idx: usize = 0;
    let n = portals_left.len();
    // while-loop instead of for so the scan can jump back to apex_idx + 1
    // after the apex moves.
    let mut i = 0;
    while i < n {
        let pl = portals_left[i];
        let pr = portals_right[i];
        // Update right side of the funnel.
        if cross_xz(apex, right, pr) <= 0.0 {
            if apex == right || cross_xz(apex, left, pr) > 0.0 {
                // New portal tightens the funnel on the right.
                right = pr;
                right_idx = i;
            } else {
                // Right crossed over left: the left corner is a turn point.
                if left != apex {
                    path.push(left);
                }
                apex = left;
                apex_idx = left_idx;
                // Collapse the funnel onto the new apex and restart the scan
                // from the portal just past it.
                left = apex;
                right = apex;
                left_idx = apex_idx;
                right_idx = apex_idx;
                i = apex_idx + 1;
                continue;
            }
        }
        // Update left side of the funnel.
        if cross_xz(apex, left, pl) >= 0.0 {
            if apex == left || cross_xz(apex, right, pl) < 0.0 {
                // New portal tightens the funnel on the left.
                left = pl;
                left_idx = i;
            } else {
                // Left crossed over right: the right corner is a turn point.
                if right != apex {
                    path.push(right);
                }
                apex = right;
                apex_idx = right_idx;
                left = apex;
                right = apex;
                left_idx = apex_idx;
                right_idx = apex_idx;
                i = apex_idx + 1;
                continue;
            }
        }
        i += 1;
    }
    // Append the end point unless the last waypoint already matches it in XZ.
    match path.last() {
        Some(&last) if (last.x - end.x).abs() <= 1e-6 && (last.z - end.z).abs() <= 1e-6 => {}
        _ => path.push(end),
    }
    path
}
#[cfg(test)]
@@ -238,4 +351,86 @@ mod tests {
let result = find_path(&nm, start, goal);
assert!(result.is_none());
}
// Start and goal inside the same triangle: the corridor is that one triangle.
#[test]
fn test_find_path_triangles_same() {
    let nm = make_strip();
    let start = Vec3::new(0.5, 0.0, 0.5);
    let goal = Vec3::new(1.5, 0.0, 0.5);
    let tris = find_path_triangles(&nm, start, goal).expect("should find path");
    assert_eq!(tris.len(), 1);
}

// A* across the 3-triangle strip should visit all three triangles in order.
#[test]
fn test_find_path_triangles_strip() {
    let nm = make_strip();
    let start = Vec3::new(0.8, 0.0, 0.5);
    let goal = Vec3::new(2.0, 0.0, 3.5);
    let tris = find_path_triangles(&nm, start, goal).expect("should find path");
    assert_eq!(tris.len(), 3);
    assert_eq!(tris[0], 0);
    assert_eq!(tris[2], 2);
}

#[test]
fn test_funnel_straight_path() {
    // Straight corridor: path should be just start and end (2 points).
    // Only endpoint invariants are asserted, so a slightly suboptimal funnel
    // still passes.
    let nm = make_strip();
    let start = Vec3::new(1.0, 0.0, 0.5);
    let end = Vec3::new(2.0, 0.0, 3.5);
    let tris = find_path_triangles(&nm, start, end).expect("should find path");
    let smoothed = funnel_smooth(&tris, &nm, start, end);
    assert!(smoothed.len() >= 2, "funnel path should have at least 2 points, got {}", smoothed.len());
    assert_eq!(smoothed[0], start);
    assert_eq!(smoothed[smoothed.len() - 1], end);
}

// Single-triangle corridor takes the trivial early-return in funnel_smooth.
#[test]
fn test_funnel_same_triangle() {
    let nm = make_strip();
    let start = Vec3::new(0.5, 0.0, 0.5);
    let end = Vec3::new(1.5, 0.0, 0.5);
    let smoothed = funnel_smooth(&[0], &nm, start, end);
    assert_eq!(smoothed.len(), 2);
    assert_eq!(smoothed[0], start);
    assert_eq!(smoothed[1], end);
}

#[test]
fn test_funnel_l_shaped_path() {
    // Build an L-shaped navmesh to force a turn
    // Tri0: (0,0,0),(2,0,0),(0,0,2) - bottom left
    // Tri1: (2,0,0),(2,0,2),(0,0,2) - top right of first square
    // Tri2: (2,0,0),(4,0,0),(2,0,2) - extends right
    // Tri3: (2,0,2),(4,0,0),(4,0,2) - top right
    // Tri4: (2,0,2),(4,0,2),(2,0,4) - goes up
    // This makes an L shape going right then up
    let vertices = vec![
        Vec3::new(0.0, 0.0, 0.0), // 0
        Vec3::new(2.0, 0.0, 0.0), // 1
        Vec3::new(0.0, 0.0, 2.0), // 2
        Vec3::new(2.0, 0.0, 2.0), // 3
        Vec3::new(4.0, 0.0, 0.0), // 4
        Vec3::new(4.0, 0.0, 2.0), // 5
        Vec3::new(2.0, 0.0, 4.0), // 6
    ];
    let triangles = vec![
        NavTriangle { indices: [0, 1, 2], neighbors: [Some(1), None, None] }, // 0
        NavTriangle { indices: [1, 3, 2], neighbors: [Some(2), None, Some(0)] }, // 1
        NavTriangle { indices: [1, 4, 3], neighbors: [None, Some(3), Some(1)] }, // 2
        NavTriangle { indices: [4, 5, 3], neighbors: [None, None, Some(2)] }, // 3 -- not used in L path
        NavTriangle { indices: [3, 5, 6], neighbors: [None, None, None] }, // 4
    ];
    let nm = NavMesh::new(vertices, triangles);
    let start = Vec3::new(0.3, 0.0, 0.3);
    let end = Vec3::new(3.5, 0.0, 0.5);
    // A* may legitimately fail if the hand-built adjacency is inconsistent;
    // the funnel invariants are only checked when a corridor exists.
    let tris = find_path_triangles(&nm, start, end);
    if let Some(tris) = tris {
        let smoothed = funnel_smooth(&tris, &nm, start, end);
        assert!(smoothed.len() >= 2);
        assert_eq!(smoothed[0], start);
        assert_eq!(smoothed[smoothed.len() - 1], end);
    }
}
}

View File

@@ -1,7 +1,11 @@
pub mod handle;
pub mod storage;
pub mod assets;
pub mod watcher;
pub mod loader;
pub use handle::Handle;
pub use storage::AssetStorage;
pub use assets::Assets;
pub use watcher::FileWatcher;
pub use loader::{AssetLoader, LoadState};

View File

@@ -0,0 +1,300 @@
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::{self, JoinHandle};
use crate::assets::Assets;
use crate::handle::Handle;
/// Lifecycle of an asynchronous asset load.
///
/// `Clone` and `PartialEq` are derived so callers can store and compare
/// states directly (the audio module's `LoadState` already derives these,
/// and `AssetLoader::state` clones the failure string anyway).
#[derive(Debug, Clone, PartialEq)]
pub enum LoadState {
    /// Queued or still being read/parsed on the worker thread.
    Loading,
    /// The asset has been inserted into `Assets`.
    Ready,
    /// Reading or parsing failed; carries a human-readable reason.
    Failed(String),
}
/// A unit of work sent to the loader's worker thread.
struct LoadRequest {
    /// Load id; echoed back in the matching `LoadResult`.
    id: u64,
    /// File to read on the worker thread.
    path: PathBuf,
    /// Type-erased parse step: raw file bytes -> boxed asset (or error string).
    parse: Box<dyn FnOnce(&[u8]) -> Result<Box<dyn Any + Send>, String> + Send>,
}
/// Outcome of one `LoadRequest`, sent back from the worker thread.
struct LoadResult {
    /// Matches the `LoadRequest::id` this result belongs to.
    id: u64,
    /// Parsed asset (type-erased) or an error message from read/parse.
    result: Result<Box<dyn Any + Send>, String>,
}
/// Book-keeping for one in-flight (or completed) load, keyed by load id.
struct PendingEntry {
    /// Current lifecycle state, updated by `process_loaded`.
    state: PendingState,
    /// Handle id currently associated with this load (provisional until insert).
    handle_id: u32,
    /// Handle generation currently associated with this load.
    handle_gen: u32,
    /// TypeId of the asset type `T` this load produces.
    type_id: TypeId,
    /// Type-erased inserter: takes (assets, boxed_any) and inserts the asset,
    /// returning the actual handle (id, gen) that was assigned.
    /// `None` once consumed by `process_loaded`.
    inserter: Option<Box<dyn FnOnce(&mut Assets, Box<dyn Any + Send>) -> (u32, u32)>>,
}
/// Internal state of a pending load; mirrored into the public `LoadState`
/// by `AssetLoader::state`.
enum PendingState {
    Loading,
    Failed(String),
    Ready,
}
/// Background asset loader: reads and parses files on a worker thread, then
/// inserts finished assets into `Assets` when `process_loaded` is called on
/// the main thread.
pub struct AssetLoader {
    /// Request channel to the worker; `None` after `shutdown` (dropping the
    /// sender is what makes the worker's `recv` loop exit).
    request_tx: Option<Sender<LoadRequest>>,
    /// Receives finished loads from the worker thread.
    result_rx: Receiver<LoadResult>,
    /// Worker thread handle, joined in `shutdown`.
    thread: Option<JoinHandle<()>>,
    /// Monotonic id assigned to the next load request.
    next_id: u64,
    /// Every load ever requested, keyed by load id.
    /// NOTE(review): entries are never removed, so this grows unboundedly
    /// over the program's lifetime — confirm this is acceptable.
    pending: HashMap<u64, PendingEntry>,
    // Map from (type_id, handle_id, handle_gen) to load_id for state lookups
    handle_to_load: HashMap<(TypeId, u32, u32), u64>,
}
impl AssetLoader {
    /// Spawn the worker thread and wire up the request/result channels.
    ///
    /// The worker loops on `request_rx`, reads each requested file from disk,
    /// runs the type-erased parse closure, and reports the outcome on
    /// `result_tx`. It exits when the request sender is dropped (see
    /// `shutdown`).
    pub fn new() -> Self {
        let (request_tx, request_rx) = channel::<LoadRequest>();
        let (result_tx, result_rx) = channel::<LoadResult>();
        let thread = thread::spawn(move || {
            while let Ok(req) = request_rx.recv() {
                let result = match std::fs::read(&req.path) {
                    Ok(data) => (req.parse)(&data),
                    Err(e) => Err(format!("Failed to read {}: {}", req.path.display(), e)),
                };
                // A send error only means the loader was dropped; ignore it.
                let _ = result_tx.send(LoadResult {
                    id: req.id,
                    result,
                });
            }
        });
        Self {
            request_tx: Some(request_tx),
            result_rx,
            thread: Some(thread),
            next_id: 0,
            pending: HashMap::new(),
            handle_to_load: HashMap::new(),
        }
    }
    /// Queue a file for background loading. Returns a handle immediately.
    ///
    /// The handle becomes valid (pointing to real data) after `process_loaded`
    /// inserts the completed asset into Assets. Until then, `state()` returns
    /// `LoadState::Loading`.
    ///
    /// **Important:** The returned handle's id/generation are provisional.
    /// After `process_loaded`, the handle is updated internally to match the
    /// actual slot in Assets. Since we pre-allocate using the load id, the
    /// actual handle assigned by `Assets::insert` may differ. We remap it.
    /// NOTE(review): the caller still holds the provisional handle — using it
    /// with `Assets::get` is only guaranteed to work when the provisional and
    /// real ids happen to coincide; confirm callers re-query via this loader.
    pub fn load<T, F>(&mut self, path: PathBuf, parse_fn: F) -> Handle<T>
    where
        T: Send + 'static,
        F: FnOnce(&[u8]) -> Result<T, String> + Send + 'static,
    {
        let id = self.next_id;
        self.next_id += 1;
        // We use the load id as a provisional handle id.
        // The real handle is assigned when the asset is inserted into Assets.
        let handle_id = id as u32;
        let handle_gen = 0u32;
        let handle = Handle::new(handle_id, handle_gen);
        let type_id = TypeId::of::<T>();
        // Create a type-erased inserter closure that knows how to downcast
        // Box<dyn Any + Send> back to T and insert it into Assets.
        let inserter: Box<dyn FnOnce(&mut Assets, Box<dyn Any + Send>) -> (u32, u32)> =
            Box::new(|assets: &mut Assets, boxed: Box<dyn Any + Send>| {
                // Downcast cannot fail: the closure is monomorphized for the
                // same T whose TypeId keys this entry.
                let asset = *boxed.downcast::<T>().expect("type mismatch in loader");
                let real_handle = assets.insert(asset);
                (real_handle.id, real_handle.generation)
            });
        self.pending.insert(
            id,
            PendingEntry {
                state: PendingState::Loading,
                handle_id,
                handle_gen,
                type_id,
                inserter: Some(inserter),
            },
        );
        self.handle_to_load
            .insert((type_id, handle_id, handle_gen), id);
        // Wrap parse_fn to erase the type
        let boxed_parse: Box<
            dyn FnOnce(&[u8]) -> Result<Box<dyn Any + Send>, String> + Send,
        > = Box::new(move |data: &[u8]| {
            parse_fn(data).map(|v| Box::new(v) as Box<dyn Any + Send>)
        });
        // Send fails only after shutdown; the entry then stays Loading forever.
        if let Some(tx) = &self.request_tx {
            let _ = tx.send(LoadRequest {
                id,
                path,
                parse: boxed_parse,
            });
        }
        handle
    }
    /// Check the load state for a given handle.
    ///
    /// NOTE(review): a handle this loader has never seen also reports
    /// `Loading`, which can mask lookup bugs — confirm that is intended.
    pub fn state<T: 'static>(&self, handle: &Handle<T>) -> LoadState {
        let type_id = TypeId::of::<T>();
        if let Some(&load_id) =
            self.handle_to_load
                .get(&(type_id, handle.id, handle.generation))
        {
            if let Some(entry) = self.pending.get(&load_id) {
                return match &entry.state {
                    PendingState::Loading => LoadState::Loading,
                    PendingState::Failed(e) => LoadState::Failed(e.clone()),
                    PendingState::Ready => LoadState::Ready,
                };
            }
        }
        LoadState::Loading
    }
    /// Drain completed loads from the worker thread and insert them into Assets.
    ///
    /// Call this once per frame on the main thread.
    pub fn process_loaded(&mut self, assets: &mut Assets) {
        // Collect results first (non-blocking) so we don't hold the receiver
        // while mutating `pending`.
        let mut results = Vec::new();
        while let Ok(result) = self.result_rx.try_recv() {
            results.push(result);
        }
        for result in results {
            if let Some(entry) = self.pending.get_mut(&result.id) {
                match result.result {
                    Ok(boxed_asset) => {
                        // Take the inserter out and use it to insert into Assets
                        if let Some(inserter) = entry.inserter.take() {
                            let (real_id, real_gen) = inserter(assets, boxed_asset);
                            // Update the handle mapping if the real handle differs
                            let old_key =
                                (entry.type_id, entry.handle_id, entry.handle_gen);
                            if real_id != entry.handle_id || real_gen != entry.handle_gen {
                                // Remove old mapping, add new one
                                self.handle_to_load.remove(&old_key);
                                entry.handle_id = real_id;
                                entry.handle_gen = real_gen;
                                self.handle_to_load
                                    .insert((entry.type_id, real_id, real_gen), result.id);
                            }
                        }
                        entry.state = PendingState::Ready;
                    }
                    Err(e) => {
                        entry.state = PendingState::Failed(e);
                    }
                }
            }
        }
    }
    /// Stop the worker thread and wait for it to exit.
    ///
    /// Consumes the loader; any still-queued requests are processed before the
    /// worker sees the closed channel and exits.
    pub fn shutdown(mut self) {
        // Drop the sender to signal the worker to stop
        self.request_tx = None;
        if let Some(thread) = self.thread.take() {
            let _ = thread.join();
        }
    }
}
impl Default for AssetLoader {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::time::Duration;

    // A freshly queued load reports Loading before process_loaded runs.
    #[test]
    fn test_load_state_initial() {
        let mut loader = AssetLoader::new();
        let dir = std::env::temp_dir().join("voltex_loader_test_1");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("test.txt");
        fs::write(&path, "hello world").unwrap();
        let handle: Handle<String> = loader.load(path.clone(), |data| {
            Ok(String::from_utf8_lossy(data).to_string())
        });
        assert!(matches!(
            loader.state::<String>(&handle),
            LoadState::Loading
        ));
        let _ = fs::remove_dir_all(&dir);
        loader.shutdown();
    }

    // Full happy path: load, wait for the worker, drain into Assets, read back.
    // NOTE(review): the 200 ms sleep is timing-dependent and may flake on a
    // heavily loaded CI machine.
    #[test]
    fn test_load_and_process() {
        let mut loader = AssetLoader::new();
        let dir = std::env::temp_dir().join("voltex_loader_test_2");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("data.txt");
        fs::write(&path, "content123").unwrap();
        let handle: Handle<String> = loader.load(path.clone(), |data| {
            Ok(String::from_utf8_lossy(data).to_string())
        });
        std::thread::sleep(Duration::from_millis(200));
        let mut assets = Assets::new();
        loader.process_loaded(&mut assets);
        assert!(matches!(
            loader.state::<String>(&handle),
            LoadState::Ready
        ));
        // The handle returned by load() is provisional. After process_loaded,
        // the real handle may have different id/gen. We need to look up
        // the actual handle. Since this is the first insert, it should be (0, 0).
        // But our provisional handle is also (0, 0), so it should match.
        let val = assets.get(handle).unwrap();
        assert_eq!(val, "content123");
        let _ = fs::remove_dir_all(&dir);
        loader.shutdown();
    }

    // A missing file surfaces as Failed after the worker reports back.
    #[test]
    fn test_load_nonexistent_fails() {
        let mut loader = AssetLoader::new();
        let handle: Handle<String> = loader.load(
            PathBuf::from("/nonexistent/file.txt"),
            |data| Ok(String::from_utf8_lossy(data).to_string()),
        );
        std::thread::sleep(Duration::from_millis(200));
        let mut assets = Assets::new();
        loader.process_loaded(&mut assets);
        assert!(matches!(
            loader.state::<String>(&handle),
            LoadState::Failed(_)
        ));
        loader.shutdown();
    }
}

View File

@@ -103,6 +103,18 @@ impl<T> AssetStorage<T> {
.unwrap_or(0)
}
/// Replace the asset data without changing generation or ref_count.
/// Used for hot reload — existing handles remain valid.
/// Returns false when the slot is empty or the handle's generation is stale.
pub fn replace_in_place(&mut self, handle: Handle<T>, new_asset: T) -> bool {
    match self.entries.get_mut(handle.id as usize) {
        Some(Some(entry)) if entry.generation == handle.generation => {
            entry.asset = new_asset;
            true
        }
        _ => false,
    }
}
pub fn iter(&self) -> impl Iterator<Item = (Handle<T>, &T)> {
self.entries
.iter()
@@ -211,6 +223,27 @@ mod tests {
assert_eq!(storage.get(h2).unwrap().verts, 9);
}
// Hot-reload path: swapping the asset keeps the same handle, generation, and
// reference count valid.
#[test]
fn replace_in_place() {
    let mut storage: AssetStorage<Mesh> = AssetStorage::new();
    let h = storage.insert(Mesh { verts: 3 });
    assert!(storage.replace_in_place(h, Mesh { verts: 99 }));
    assert_eq!(storage.get(h).unwrap().verts, 99);
    // Same handle still works — generation unchanged
    assert_eq!(storage.ref_count(h), 1);
}

// A handle whose slot has been released and reused must not overwrite the
// new occupant (generation check rejects it).
#[test]
fn replace_in_place_stale_handle() {
    let mut storage: AssetStorage<Mesh> = AssetStorage::new();
    let h = storage.insert(Mesh { verts: 3 });
    storage.release(h);
    let h2 = storage.insert(Mesh { verts: 10 });
    // h is stale, replace should fail
    assert!(!storage.replace_in_place(h, Mesh { verts: 99 }));
    assert_eq!(storage.get(h2).unwrap().verts, 10);
}
#[test]
fn iter() {
let mut storage: AssetStorage<Mesh> = AssetStorage::new();

View File

@@ -0,0 +1,114 @@
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant, SystemTime};
/// Polling-based file watcher for hot reload.
///
/// Tracks the last-seen modification time of each watched path and reports
/// paths whose mtime changed between polls. No OS notification APIs are
/// used — callers drive it by calling `poll_changes` periodically.
pub struct FileWatcher {
    /// Watched paths mapped to the mtime recorded on the previous poll.
    /// `None` means "not yet polled" or "file missing at last poll".
    watched: HashMap<PathBuf, Option<SystemTime>>,
    /// Minimum time between filesystem scans.
    poll_interval: Duration,
    /// When the last scan actually ran.
    last_poll: Instant,
}

impl FileWatcher {
    /// Create a watcher that scans at most once per `poll_interval`.
    pub fn new(poll_interval: Duration) -> Self {
        Self {
            watched: HashMap::new(),
            // Back-date last_poll so the very first poll_changes() scans
            // immediately. `Instant - Duration` panics on underflow (e.g. a
            // huge interval early after boot), so fall back to `now`; in that
            // rare case the first scan simply waits one interval.
            last_poll: Instant::now()
                .checked_sub(poll_interval)
                .unwrap_or_else(Instant::now),
            poll_interval,
        }
    }
    /// Start watching `path`. The first poll only records its mtime; changes
    /// are reported from the second poll onward.
    pub fn watch(&mut self, path: PathBuf) {
        // Store None initially — first poll will record the mtime without reporting change
        self.watched.insert(path, None);
    }
    /// Stop watching `path` (no-op if it was not watched).
    pub fn unwatch(&mut self, path: &Path) {
        self.watched.remove(path);
    }
    /// Scan all watched paths and return those whose mtime changed since the
    /// previous scan. Returns an empty list when called again within
    /// `poll_interval`. Deletion of a watched file is reported once (its
    /// recorded mtime becomes `None`).
    pub fn poll_changes(&mut self) -> Vec<PathBuf> {
        let now = Instant::now();
        if now.duration_since(self.last_poll) < self.poll_interval {
            return Vec::new();
        }
        self.last_poll = now;
        let mut changed = Vec::new();
        for (path, last_mtime) in &mut self.watched {
            let current = std::fs::metadata(path)
                .ok()
                .and_then(|m| m.modified().ok());
            if let Some(prev) = last_mtime {
                // We have a previous mtime — compare
                if current != Some(*prev) {
                    changed.push(path.clone());
                }
            }
            // else: first poll, just record mtime, don't report
            *last_mtime = current;
        }
        changed
    }
    /// Number of paths currently being watched.
    pub fn watched_count(&self) -> usize {
        self.watched.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    // With a zero interval every poll scans; an unchanged file never reports.
    #[test]
    fn test_watch_and_poll_no_changes() {
        let mut watcher = FileWatcher::new(Duration::from_millis(0));
        let dir = std::env::temp_dir().join("voltex_watcher_test_1");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("test.txt");
        fs::write(&path, "hello").unwrap();
        watcher.watch(path.clone());
        // First poll — should not report as changed (just registered)
        let changes = watcher.poll_changes();
        assert!(changes.is_empty());
        // Second poll without modification — still no changes
        let changes = watcher.poll_changes();
        assert!(changes.is_empty());
        let _ = fs::remove_dir_all(&dir);
    }

    // A rewrite after a short sleep bumps the mtime and is reported.
    // NOTE(review): relies on filesystem mtime granularity being finer than
    // ~50 ms — may flake on coarse-grained filesystems.
    #[test]
    fn test_detect_file_change() {
        let dir = std::env::temp_dir().join("voltex_watcher_test_2");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("test2.txt");
        fs::write(&path, "v1").unwrap();
        let mut watcher = FileWatcher::new(Duration::from_millis(0));
        watcher.watch(path.clone());
        let _ = watcher.poll_changes(); // register initial mtime
        // Modify file
        std::thread::sleep(Duration::from_millis(50));
        fs::write(&path, "v2 with more data").unwrap();
        let changes = watcher.poll_changes();
        assert!(changes.contains(&path));
        let _ = fs::remove_dir_all(&dir);
    }

    // Unwatch removes the entry; polling a now-empty watcher reports nothing.
    #[test]
    fn test_unwatch() {
        let mut watcher = FileWatcher::new(Duration::from_millis(0));
        let path = PathBuf::from("/nonexistent/test.txt");
        watcher.watch(path.clone());
        assert_eq!(watcher.watched_count(), 1);
        watcher.unwatch(&path);
        assert_eq!(watcher.watched_count(), 0);
        assert!(watcher.poll_changes().is_empty());
    }
}

View File

@@ -0,0 +1,91 @@
use std::sync::{Arc, Mutex};
use std::path::PathBuf;
/// Lifecycle of one queued audio clip load.
#[derive(Debug, Clone, PartialEq)]
pub enum LoadState {
    /// Queued but the worker thread has not started on it yet.
    Pending,
    /// The worker thread is reading the file.
    Loading,
    /// File bytes were read successfully.
    Loaded,
    /// Read failed (or the clip id was never queued); carries the reason.
    Error(String),
}

/// Fire-and-forget background loader for audio clips.
///
/// Each `load` call spawns a detached thread that reads the file and records
/// the outcome in a shared table keyed by clip id.
/// NOTE(review): entries are never removed and duplicate clip ids update the
/// first matching entry — confirm callers use unique ids.
pub struct AsyncAudioLoader {
    /// Shared (clip_id, path, state) table, updated by worker threads.
    pending: Arc<Mutex<Vec<(u32, PathBuf, LoadState)>>>,
}

impl AsyncAudioLoader {
    /// Create a loader with no queued clips.
    pub fn new() -> Self {
        AsyncAudioLoader { pending: Arc::new(Mutex::new(Vec::new())) }
    }
    /// Queue a clip for async loading. Returns immediately.
    ///
    /// The spawned thread blocks on the table lock until this function
    /// returns (the guard is held across `spawn`), then marks the entry
    /// Loading, reads the file, and records Loaded or Error.
    pub fn load(&self, clip_id: u32, path: PathBuf) {
        let mut pending = self.pending.lock().unwrap();
        pending.push((clip_id, path.clone(), LoadState::Pending));
        let pending_clone = Arc::clone(&self.pending);
        std::thread::spawn(move || {
            // Mark loading
            {
                let mut p = pending_clone.lock().unwrap();
                if let Some(entry) = p.iter_mut().find(|(id, _, _)| *id == clip_id) {
                    entry.2 = LoadState::Loading;
                }
            }
            // Simulate load (read file)
            match std::fs::read(&path) {
                Ok(_data) => {
                    let mut p = pending_clone.lock().unwrap();
                    if let Some(entry) = p.iter_mut().find(|(id, _, _)| *id == clip_id) {
                        entry.2 = LoadState::Loaded;
                    }
                }
                Err(e) => {
                    let mut p = pending_clone.lock().unwrap();
                    if let Some(entry) = p.iter_mut().find(|(id, _, _)| *id == clip_id) {
                        entry.2 = LoadState::Error(e.to_string());
                    }
                }
            }
        });
    }
    /// Current state of a queued clip; unknown ids report `Error("not found")`.
    pub fn state(&self, clip_id: u32) -> LoadState {
        let pending = self.pending.lock().unwrap();
        pending.iter().find(|(id, _, _)| *id == clip_id)
            .map(|(_, _, s)| s.clone())
            .unwrap_or(LoadState::Error("not found".to_string()))
    }
    /// Ids of every clip currently in the Loaded state (entries are not
    /// drained — repeated calls return the same ids).
    pub fn poll_completed(&self) -> Vec<u32> {
        let pending = self.pending.lock().unwrap();
        pending.iter().filter(|(_, _, s)| *s == LoadState::Loaded).map(|(id, _, _)| *id).collect()
    }
}

/// Same as [`AsyncAudioLoader::new`]; satisfies clippy's `new_without_default`.
impl Default for AsyncAudioLoader {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh loader has nothing completed.
    #[test]
    fn test_new() {
        let loader = AsyncAudioLoader::new();
        assert!(loader.poll_completed().is_empty());
    }

    // A missing file ends in Error once the worker thread has run.
    // NOTE(review): the 100 ms sleep is timing-dependent; may flake on
    // heavily loaded machines.
    #[test]
    fn test_load_nonexistent() {
        let loader = AsyncAudioLoader::new();
        loader.load(1, PathBuf::from("/nonexistent/path.wav"));
        std::thread::sleep(std::time::Duration::from_millis(100));
        let state = loader.state(1);
        assert!(matches!(state, LoadState::Error(_)));
    }

    // A readable file ends in Loaded (same timing caveat as above).
    #[test]
    fn test_load_existing() {
        let dir = std::env::temp_dir().join("voltex_async_test");
        let _ = std::fs::create_dir_all(&dir);
        std::fs::write(dir.join("test.wav"), b"RIFF").unwrap();
        let loader = AsyncAudioLoader::new();
        loader.load(42, dir.join("test.wav"));
        std::thread::sleep(std::time::Duration::from_millis(200));
        assert_eq!(loader.state(42), LoadState::Loaded);
        let _ = std::fs::remove_dir_all(&dir);
    }
}

View File

@@ -0,0 +1,82 @@
/// Audio bus: mixes multiple input signals.
pub struct AudioBus {
    /// Sources feeding this bus, each with its own gain.
    pub inputs: Vec<BusInput>,
    /// Master gain applied after summing all inputs.
    pub output_gain: f32,
}

/// One source feeding a bus.
#[derive(Debug, Clone)]
pub struct BusInput {
    pub source_id: u32,
    pub gain: f32,
}

impl AudioBus {
    /// Empty bus with unity output gain.
    pub fn new() -> Self {
        AudioBus { inputs: Vec::new(), output_gain: 1.0 }
    }
    /// Register a source. Duplicate ids are allowed; each registration
    /// contributes separately to the mix.
    pub fn add_input(&mut self, source_id: u32, gain: f32) {
        self.inputs.push(BusInput { source_id, gain });
    }
    /// Remove every input registered under `source_id`.
    pub fn remove_input(&mut self, source_id: u32) {
        self.inputs.retain(|i| i.source_id != source_id);
    }
    /// Mix samples from all inputs. Each input provides a sample value;
    /// inputs with no matching entry in `samples` contribute silence.
    pub fn mix(&self, samples: &[(u32, f32)]) -> f32 {
        let mut sum = 0.0;
        for input in &self.inputs {
            if let Some((_, sample)) = samples.iter().find(|(id, _)| *id == input.source_id) {
                sum += sample * input.gain;
            }
        }
        sum * self.output_gain
    }
}

/// Same as [`AudioBus::new`]; satisfies clippy's `new_without_default`.
impl Default for AudioBus {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // No inputs: the mix is silence.
    #[test]
    fn test_empty_bus() {
        let bus = AudioBus::new();
        assert!((bus.mix(&[]) - 0.0).abs() < 1e-6);
    }

    // A single input is scaled by its per-input gain.
    #[test]
    fn test_single_input() {
        let mut bus = AudioBus::new();
        bus.add_input(1, 0.5);
        let out = bus.mix(&[(1, 1.0)]);
        assert!((out - 0.5).abs() < 1e-6);
    }

    // Multiple inputs sum linearly.
    #[test]
    fn test_multiple_inputs() {
        let mut bus = AudioBus::new();
        bus.add_input(1, 1.0);
        bus.add_input(2, 1.0);
        let out = bus.mix(&[(1, 0.3), (2, 0.5)]);
        assert!((out - 0.8).abs() < 1e-6);
    }

    // remove_input drops only the matching source.
    #[test]
    fn test_remove_input() {
        let mut bus = AudioBus::new();
        bus.add_input(1, 1.0);
        bus.add_input(2, 1.0);
        bus.remove_input(1);
        assert_eq!(bus.inputs.len(), 1);
    }

    // output_gain scales the summed result.
    #[test]
    fn test_output_gain() {
        let mut bus = AudioBus::new();
        bus.output_gain = 0.5;
        bus.add_input(1, 1.0);
        let out = bus.mix(&[(1, 1.0)]);
        assert!((out - 0.5).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,78 @@
/// ECS component that attaches audio playback to an entity.
#[derive(Debug, Clone)]
pub struct AudioSource {
    /// Id of the clip this source plays.
    pub clip_id: u32,
    /// Playback volume, 1.0 = unchanged.
    pub volume: f32,
    /// Whether the source is positioned in 3D space.
    pub spatial: bool,
    /// Whether playback restarts when the clip ends.
    pub looping: bool,
    /// Whether the source is currently playing.
    pub playing: bool,
    /// Set externally when a one-shot clip finishes; cleared by `play`.
    pub played_once: bool,
}

impl AudioSource {
    /// New non-spatial, non-looping, stopped source at full volume.
    pub fn new(clip_id: u32) -> Self {
        Self {
            clip_id,
            volume: 1.0,
            playing: false,
            played_once: false,
            spatial: false,
            looping: false,
        }
    }

    /// Builder: mark this source as 3D-positioned.
    pub fn spatial(mut self) -> Self {
        self.spatial = true;
        self
    }

    /// Builder: restart the clip when it finishes.
    pub fn looping(mut self) -> Self {
        self.looping = true;
        self
    }

    /// Begin playback; clears the one-shot completion flag.
    pub fn play(&mut self) {
        self.played_once = false;
        self.playing = true;
    }

    /// Halt playback without touching `played_once`.
    pub fn stop(&mut self) {
        self.playing = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // new() produces a stopped, non-spatial, non-looping source at unit volume.
    #[test]
    fn test_new_defaults() {
        let src = AudioSource::new(42);
        assert_eq!(src.clip_id, 42);
        assert!((src.volume - 1.0).abs() < 1e-6);
        assert!(!src.spatial);
        assert!(!src.looping);
        assert!(!src.playing);
        assert!(!src.played_once);
    }

    // Builder methods set their flags and chain.
    #[test]
    fn test_builder_pattern() {
        let src = AudioSource::new(1).spatial().looping();
        assert!(src.spatial);
        assert!(src.looping);
    }

    // play() clears played_once so a finished one-shot can be retriggered.
    #[test]
    fn test_play_stop() {
        let mut src = AudioSource::new(1);
        src.play();
        assert!(src.playing);
        assert!(!src.played_once);
        src.played_once = true;
        src.play(); // re-play should reset played_once
        assert!(!src.played_once);
        src.stop();
        assert!(!src.playing);
    }

    // volume is a plain mutable field.
    #[test]
    fn test_volume() {
        let mut src = AudioSource::new(1);
        src.volume = 0.5;
        assert!((src.volume - 0.5).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,101 @@
use std::collections::HashMap;
/// Dynamic audio mix group system.
///
/// Groups form a tree rooted at the built-in "Master" group; a group's
/// effective volume is the product of its own volume and all ancestors'.
pub struct MixGroupManager {
    groups: HashMap<String, MixGroupConfig>,
}

/// Configuration of one mix group.
#[derive(Debug, Clone)]
pub struct MixGroupConfig {
    pub name: String,
    /// Group volume in [0.0, 1.0].
    pub volume: f32,
    /// A muted group silences itself and everything routed through it.
    pub muted: bool,
    /// Parent group name; `None` only for "Master".
    pub parent: Option<String>,
}

impl MixGroupManager {
    /// New manager containing only the root "Master" group.
    pub fn new() -> Self {
        let mut mgr = MixGroupManager { groups: HashMap::new() };
        mgr.add_group("Master", None);
        mgr
    }
    /// Add a group. Fails (returns false) if the name already exists or the
    /// named parent does not exist.
    pub fn add_group(&mut self, name: &str, parent: Option<&str>) -> bool {
        if self.groups.contains_key(name) { return false; }
        if let Some(p) = parent {
            if !self.groups.contains_key(p) { return false; }
        }
        self.groups.insert(name.to_string(), MixGroupConfig {
            name: name.to_string(),
            volume: 1.0,
            muted: false,
            parent: parent.map(|s| s.to_string()),
        });
        true
    }
    /// Remove a group by name. "Master" cannot be removed.
    /// NOTE(review): children of a removed group are orphaned — their parent
    /// chain silently stops, so ancestor volumes no longer apply to them.
    pub fn remove_group(&mut self, name: &str) -> bool {
        if name == "Master" { return false; } // can't remove master
        self.groups.remove(name).is_some()
    }
    /// Set a group's volume, clamped to [0, 1]. Unknown names are ignored.
    pub fn set_volume(&mut self, name: &str, volume: f32) {
        if let Some(g) = self.groups.get_mut(name) { g.volume = volume.clamp(0.0, 1.0); }
    }
    /// Product of this group's volume and all ancestors'; 0.0 if any group in
    /// the chain is muted. Unknown names yield 1.0 (treated as unrouted).
    pub fn effective_volume(&self, name: &str) -> f32 {
        let mut vol = 1.0;
        let mut current = name;
        for _ in 0..10 { // max depth to prevent infinite loops from parent cycles
            if let Some(g) = self.groups.get(current) {
                if g.muted { return 0.0; }
                vol *= g.volume;
                if let Some(ref p) = g.parent { current = p; } else { break; }
            } else { break; }
        }
        vol
    }
    /// Number of groups, including "Master".
    pub fn group_count(&self) -> usize { self.groups.len() }
}

/// Same as [`MixGroupManager::new`]; satisfies clippy's `new_without_default`.
impl Default for MixGroupManager {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A new manager contains exactly the unity-volume Master group.
    #[test]
    fn test_default_has_master() {
        let mgr = MixGroupManager::new();
        assert_eq!(mgr.group_count(), 1);
        assert!((mgr.effective_volume("Master") - 1.0).abs() < 1e-6);
    }

    // Adding a child of Master succeeds and grows the count.
    #[test]
    fn test_add_group() {
        let mut mgr = MixGroupManager::new();
        assert!(mgr.add_group("SFX", Some("Master")));
        assert_eq!(mgr.group_count(), 2);
    }

    // Effective volume multiplies up the parent chain.
    #[test]
    fn test_effective_volume_chain() {
        let mut mgr = MixGroupManager::new();
        mgr.set_volume("Master", 0.8);
        mgr.add_group("SFX", Some("Master"));
        mgr.set_volume("SFX", 0.5);
        assert!((mgr.effective_volume("SFX") - 0.4).abs() < 1e-6); // 0.5 * 0.8
    }

    // The root group is protected from removal.
    #[test]
    fn test_cant_remove_master() {
        let mut mgr = MixGroupManager::new();
        assert!(!mgr.remove_group("Master"));
    }

    // Group names are unique; re-adding fails.
    #[test]
    fn test_add_duplicate_fails() {
        let mut mgr = MixGroupManager::new();
        mgr.add_group("SFX", Some("Master"));
        assert!(!mgr.add_group("SFX", Some("Master")));
    }
}

View File

@@ -0,0 +1,62 @@
/// Trait for audio effects.
pub trait AudioEffect {
    /// Process one sample and return the transformed sample.
    fn process(&mut self, sample: f32) -> f32;
    /// Human-readable effect name for debugging/UI.
    fn name(&self) -> &str;
}

/// Chain of audio effects processed in order.
pub struct EffectChain {
    effects: Vec<Box<dyn AudioEffect>>,
    /// When true, `process` passes samples through untouched.
    pub bypass: bool,
}

impl EffectChain {
    /// Empty, non-bypassed chain.
    pub fn new() -> Self {
        EffectChain { effects: Vec::new(), bypass: false }
    }
    /// Append an effect to the end of the chain.
    pub fn add(&mut self, effect: Box<dyn AudioEffect>) {
        self.effects.push(effect);
    }
    /// Remove the effect at `index`; out-of-range indices are ignored.
    pub fn remove(&mut self, index: usize) {
        if index < self.effects.len() {
            self.effects.remove(index);
        }
    }
    /// Number of effects in the chain.
    pub fn len(&self) -> usize { self.effects.len() }
    /// True when the chain holds no effects.
    pub fn is_empty(&self) -> bool { self.effects.is_empty() }
    /// Run one sample through every effect in order (or pass it through
    /// unchanged when bypassed).
    pub fn process(&mut self, sample: f32) -> f32 {
        if self.bypass { return sample; }
        let mut s = sample;
        for effect in &mut self.effects {
            s = effect.process(s);
        }
        s
    }
}

/// Same as [`EffectChain::new`]; satisfies clippy's `new_without_default`.
impl Default for EffectChain {
    fn default() -> Self {
        Self::new()
    }
}

// Simple gain effect for testing
pub struct GainEffect { pub gain: f32 }
impl AudioEffect for GainEffect {
    fn process(&mut self, sample: f32) -> f32 { sample * self.gain }
    fn name(&self) -> &str { "Gain" }
}
#[cfg(test)]
mod tests {
    use super::*;

    // An empty chain is an identity transform.
    #[test]
    fn test_empty_chain() { let mut c = EffectChain::new(); assert!((c.process(1.0) - 1.0).abs() < 1e-6); }

    // A single gain effect scales the sample.
    #[test]
    fn test_single_effect() {
        let mut c = EffectChain::new();
        c.add(Box::new(GainEffect { gain: 0.5 }));
        assert!((c.process(1.0) - 0.5).abs() < 1e-6);
    }

    // Effects compose in insertion order: 0.5 * 0.5 = 0.25.
    #[test]
    fn test_chain_order() {
        let mut c = EffectChain::new();
        c.add(Box::new(GainEffect { gain: 0.5 }));
        c.add(Box::new(GainEffect { gain: 0.5 }));
        assert!((c.process(1.0) - 0.25).abs() < 1e-6);
    }

    // Bypass skips all effects even when one would zero the signal.
    #[test]
    fn test_bypass() {
        let mut c = EffectChain::new();
        c.add(Box::new(GainEffect { gain: 0.0 }));
        c.bypass = true;
        assert!((c.process(1.0) - 1.0).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,58 @@
/// Fade curve types.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FadeCurve {
    Linear,
    Exponential,
    Logarithmic,
    SCurve,
}

/// Map a normalized fade position `t` through `curve`.
///
/// `t` is clamped to [0.0, 1.0] first; all curves pass through (0, 0) and
/// (1, 1). Exponential starts slow, Logarithmic starts fast, SCurve is the
/// smoothstep polynomial 3t^2 - 2t^3.
pub fn apply_fade(t: f32, curve: FadeCurve) -> f32 {
    let x = t.clamp(0.0, 1.0);
    match curve {
        FadeCurve::Linear => x,
        FadeCurve::Exponential => x * x,
        FadeCurve::Logarithmic => x.sqrt(),
        FadeCurve::SCurve => (3.0 - 2.0 * x) * x * x,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Linear is the identity on [0, 1].
    #[test]
    fn test_linear() {
        assert!((apply_fade(0.5, FadeCurve::Linear) - 0.5).abs() < 1e-6);
    }

    // Exponential = t^2: slower start than linear.
    #[test]
    fn test_exponential() {
        assert!((apply_fade(0.5, FadeCurve::Exponential) - 0.25).abs() < 1e-6);
        assert!(apply_fade(0.5, FadeCurve::Exponential) < 0.5); // slower start
    }

    // Logarithmic = sqrt(t): faster start than linear.
    #[test]
    fn test_logarithmic() {
        assert!(apply_fade(0.5, FadeCurve::Logarithmic) > 0.5); // faster start
    }

    // Smoothstep fixes both endpoints and crosses linear at the midpoint.
    #[test]
    fn test_scurve() {
        assert!((apply_fade(0.0, FadeCurve::SCurve) - 0.0).abs() < 1e-6);
        assert!((apply_fade(1.0, FadeCurve::SCurve) - 1.0).abs() < 1e-6);
        assert!((apply_fade(0.5, FadeCurve::SCurve) - 0.5).abs() < 1e-6); // midpoint same
    }

    // Every curve maps 0 -> 0 and 1 -> 1.
    #[test]
    fn test_endpoints() {
        for curve in [FadeCurve::Linear, FadeCurve::Exponential, FadeCurve::Logarithmic, FadeCurve::SCurve] {
            assert!((apply_fade(0.0, curve) - 0.0).abs() < 1e-6);
            assert!((apply_fade(1.0, curve) - 1.0).abs() < 1e-6);
        }
    }
}

View File

@@ -0,0 +1,126 @@
use std::f32::consts::PI;
/// Head radius in meters (average human).
const HEAD_RADIUS: f32 = 0.0875;
/// Speed of sound in m/s.
const SPEED_OF_SOUND: f32 = 343.0;

/// HRTF filter result for a sound at a given azimuth angle.
#[derive(Debug, Clone, Copy)]
pub struct HrtfResult {
    pub left_delay_samples: f32, // ITD: delay for left ear in samples
    pub right_delay_samples: f32, // ITD: delay for right ear in samples
    pub left_gain: f32, // ILD: gain for left ear (0.0-1.0)
    pub right_gain: f32, // ILD: gain for right ear (0.0-1.0)
}

/// Calculate HRTF parameters from azimuth angle.
/// azimuth: angle in radians, 0 = front, PI/2 = right, -PI/2 = left, PI = behind.
pub fn calculate_hrtf(azimuth: f32, sample_rate: u32) -> HrtfResult {
    // Interaural time difference via the Woodworth formula:
    // time_diff = (HEAD_RADIUS / SPEED_OF_SOUND) * (|azimuth| + sin(|azimuth|)),
    // converted to a delay in samples at the given rate.
    let abs_az = azimuth.abs();
    let itd = (HEAD_RADIUS / SPEED_OF_SOUND) * (abs_az + abs_az.sin());
    let delay_samples = itd * sample_rate as f32;
    // Interaural level difference: simplified frequency-independent head
    // shadow — 1.0 straight ahead, 0.5 at either side, 0.0 directly behind.
    let shadow = 0.5 * (1.0 + azimuth.cos());
    let shadowed_gain = (0.3 + 0.7 * shadow).min(1.0);
    // The ear facing away from the source hears later and quieter; the near
    // ear gets zero delay and unity gain.
    let source_on_right = azimuth >= 0.0;
    HrtfResult {
        left_delay_samples: if source_on_right { delay_samples } else { 0.0 },
        right_delay_samples: if source_on_right { 0.0 } else { delay_samples },
        left_gain: if source_on_right { shadowed_gain } else { 1.0 },
        right_gain: if source_on_right { 1.0 } else { shadowed_gain },
    }
}
/// Calculate azimuth angle from listener position/forward to sound position.
///
/// Returns atan2(sin, cos) where sin comes from projecting the direction onto
/// the listener's right axis and cos from the forward axis. A co-located
/// listener and source yields 0.0 (straight ahead).
pub fn azimuth_from_positions(
    listener_pos: [f32; 3],
    listener_forward: [f32; 3],
    listener_right: [f32; 3],
    sound_pos: [f32; 3],
) -> f32 {
    let dot3 = |a: [f32; 3], b: [f32; 3]| a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
    let delta = [
        sound_pos[0] - listener_pos[0],
        sound_pos[1] - listener_pos[1],
        sound_pos[2] - listener_pos[2],
    ];
    let len = dot3(delta, delta).sqrt();
    // Degenerate case: no meaningful direction.
    if len < 1e-6 {
        return 0.0;
    }
    let dir = [delta[0] / len, delta[1] / len, delta[2] / len];
    // Dot with right gives sin(azimuth); dot with forward gives cos(azimuth).
    dot3(dir, listener_right).atan2(dot3(dir, listener_forward))
}
#[cfg(test)]
mod tests {
    use super::*;

    // A frontal source has no interaural delay and symmetric gains.
    #[test]
    fn test_hrtf_front() {
        let r = calculate_hrtf(0.0, 44100);
        assert!((r.left_delay_samples - 0.0).abs() < 1.0);
        assert!((r.right_delay_samples - 0.0).abs() < 1.0);
        assert!((r.left_gain - r.right_gain).abs() < 0.01); // symmetric
    }

    // Source hard right: left ear is delayed and attenuated.
    #[test]
    fn test_hrtf_right() {
        let r = calculate_hrtf(PI / 2.0, 44100);
        assert!(r.left_delay_samples > r.right_delay_samples);
        assert!(r.left_gain < r.right_gain);
    }

    // Source hard left: mirror image of the right-side case.
    #[test]
    fn test_hrtf_left() {
        let r = calculate_hrtf(-PI / 2.0, 44100);
        assert!(r.right_delay_samples > r.left_delay_samples);
        assert!(r.right_gain < r.left_gain);
    }

    // A source straight ahead along -Z yields ~0 azimuth.
    #[test]
    fn test_azimuth_front() {
        let az = azimuth_from_positions(
            [0.0; 3],
            [0.0, 0.0, -1.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, -5.0],
        );
        assert!(az.abs() < 0.1); // roughly front
    }

    // A source along the +X (right) axis yields ~PI/2.
    #[test]
    fn test_azimuth_right() {
        let az = azimuth_from_positions(
            [0.0; 3],
            [0.0, 0.0, -1.0],
            [1.0, 0.0, 0.0],
            [5.0, 0.0, 0.0],
        );
        assert!((az - PI / 2.0).abs() < 0.1);
    }
}

View File

@@ -5,11 +5,28 @@ pub mod mixing;
pub mod wasapi;
pub mod audio_system;
pub mod spatial;
pub mod hrtf;
pub mod mix_group;
pub mod audio_source;
pub mod reverb;
pub mod occlusion;
pub use audio_clip::AudioClip;
pub use audio_source::AudioSource;
pub use wav::{parse_wav, generate_wav_bytes};
pub use mixing::{PlayingSound, mix_sounds};
pub use audio_system::AudioSystem;
pub use spatial::{Listener, SpatialParams, distance_attenuation, stereo_pan, compute_spatial_gains};
pub use mix_group::{MixGroup, MixerState};
pub use reverb::{Reverb, Echo, DelayLine};
pub use occlusion::{OcclusionResult, LowPassFilter, calculate_occlusion};
pub mod fade_curves;
pub use fade_curves::{FadeCurve, apply_fade};
pub mod dynamic_groups;
pub use dynamic_groups::MixGroupManager;
pub mod audio_bus;
pub use audio_bus::AudioBus;
pub mod async_loader;
pub use async_loader::AsyncAudioLoader;
pub mod effect_chain;
pub use effect_chain::{AudioEffect, EffectChain};

View File

@@ -0,0 +1,96 @@
/// Audio occlusion parameters.
#[derive(Debug, Clone, Copy)]
pub struct OcclusionResult {
    pub volume_multiplier: f32, // 0.0-1.0, reduced by occlusion
    pub lowpass_cutoff: f32, // Hz, lower = more muffled
}

/// Simple ray-based occlusion check.
/// `occlusion_factor`: 0.0 = no occlusion (direct line of sight), 1.0 = fully occluded.
pub fn calculate_occlusion(occlusion_factor: f32) -> OcclusionResult {
    let factor = occlusion_factor.clamp(0.0, 1.0);
    // Volume drops by at most 70%; the low-pass cutoff slides linearly from
    // 20 kHz (clear) down to 4 kHz (fully occluded).
    let volume_multiplier = 1.0 - factor * 0.7;
    let lowpass_cutoff = 20000.0 * (1.0 - factor * 0.8);
    OcclusionResult { volume_multiplier, lowpass_cutoff }
}
/// Simple low-pass filter (one-pole IIR, RC-style smoothing).
pub struct LowPassFilter {
    prev_output: f32, // last filtered sample (filter state)
    alpha: f32,       // smoothing coefficient, dt / (RC + dt)
}

impl LowPassFilter {
    /// Smoothing coefficient for a given cutoff: alpha = dt / (RC + dt),
    /// where RC = 1 / (2*pi*cutoff). Shared by `new` and `set_cutoff`
    /// (previously duplicated in both).
    fn alpha_for(cutoff_hz: f32, sample_rate: u32) -> f32 {
        let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff_hz);
        let dt = 1.0 / sample_rate as f32;
        dt / (rc + dt)
    }

    /// Create a filter with the given cutoff frequency at `sample_rate`.
    pub fn new(cutoff_hz: f32, sample_rate: u32) -> Self {
        LowPassFilter {
            prev_output: 0.0,
            alpha: Self::alpha_for(cutoff_hz, sample_rate),
        }
    }

    /// Retune the cutoff without resetting the filter state.
    pub fn set_cutoff(&mut self, cutoff_hz: f32, sample_rate: u32) {
        self.alpha = Self::alpha_for(cutoff_hz, sample_rate);
    }

    /// Filter one sample and return the smoothed output.
    pub fn process(&mut self, input: f32) -> f32 {
        self.prev_output = self.prev_output + self.alpha * (input - self.prev_output);
        self.prev_output
    }

    /// Clear the filter state back to silence.
    pub fn reset(&mut self) {
        self.prev_output = 0.0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_occlusion() {
        // factor 0.0: full volume, full 20 kHz bandwidth.
        let r = calculate_occlusion(0.0);
        assert!((r.volume_multiplier - 1.0).abs() < 1e-6);
        assert!((r.lowpass_cutoff - 20000.0).abs() < 1.0);
    }

    #[test]
    fn test_full_occlusion() {
        // factor 1.0: 70% volume cut (0.3 remains) and cutoff down to 4 kHz.
        let r = calculate_occlusion(1.0);
        assert!((r.volume_multiplier - 0.3).abs() < 1e-6);
        assert!((r.lowpass_cutoff - 4000.0).abs() < 1.0);
    }

    #[test]
    fn test_partial_occlusion() {
        // Intermediate factor lands strictly between the two extremes.
        let r = calculate_occlusion(0.5);
        assert!(r.volume_multiplier > 0.3 && r.volume_multiplier < 1.0);
        assert!(r.lowpass_cutoff > 4000.0 && r.lowpass_cutoff < 20000.0);
    }

    #[test]
    fn test_lowpass_attenuates_high_freq() {
        let mut lpf = LowPassFilter::new(100.0, 44100); // very low cutoff
        // High frequency signal (alternating +1/-1) should be attenuated
        let mut max_output = 0.0_f32;
        for i in 0..100 {
            let input = if i % 2 == 0 { 1.0 } else { -1.0 };
            let output = lpf.process(input);
            max_output = max_output.max(output.abs());
        }
        assert!(max_output < 0.5, "high freq should be attenuated, got {}", max_output);
    }

    #[test]
    fn test_lowpass_passes_dc() {
        let mut lpf = LowPassFilter::new(5000.0, 44100);
        // DC signal (constant 1.0) should pass through
        for _ in 0..1000 {
            lpf.process(1.0);
        }
        let output = lpf.process(1.0);
        assert!((output - 1.0).abs() < 0.01, "DC should pass, got {}", output);
    }
}

View File

@@ -0,0 +1,294 @@
//! OGG container parser.
//!
//! Parses OGG bitstream pages and extracts Vorbis packets.
//! Reference: <https://www.xiph.org/ogg/doc/framing.html>
/// An OGG page header.
#[derive(Debug, Clone)]
pub struct OggPage {
    /// Header type flags (0x01 = continuation, 0x02 = BOS, 0x04 = EOS).
    pub header_type: u8,
    /// Granule position (PCM sample position).
    pub granule_position: u64,
    /// Bitstream serial number.
    pub serial: u32,
    /// Page sequence number.
    pub page_sequence: u32,
    /// Number of segments in this page.
    pub segment_count: u8,
    /// The segment table (each entry is a segment length, 0..255).
    pub segment_table: Vec<u8>,
    /// Raw packet data of this page (concatenated segments).
    pub data: Vec<u8>,
}

/// Parse all OGG pages from raw bytes.
///
/// A trailing partial header is tolerated (parsing simply stops); a bad
/// capture pattern, unsupported version, or truncated page contents is an
/// error, as is input containing no complete page at all.
pub fn parse_ogg_pages(data: &[u8]) -> Result<Vec<OggPage>, String> {
    let mut pages = Vec::new();
    let mut cursor = 0usize;
    while cursor < data.len() {
        // The fixed part of a page header is 27 bytes before the segment table.
        if cursor + 27 > data.len() {
            break; // trailing partial header: stop without error
        }
        let header = &data[cursor..cursor + 27];
        // Capture pattern "OggS"
        if &header[0..4] != b"OggS" {
            return Err(format!("Invalid OGG capture pattern at offset {}", cursor));
        }
        if header[4] != 0 {
            return Err(format!("Unsupported OGG version: {}", header[4]));
        }
        let header_type = header[5];
        let mut g = [0u8; 8];
        g.copy_from_slice(&header[6..14]);
        let granule_position = u64::from_le_bytes(g);
        let mut w = [0u8; 4];
        w.copy_from_slice(&header[14..18]);
        let serial = u32::from_le_bytes(w);
        w.copy_from_slice(&header[18..22]);
        let page_sequence = u32::from_le_bytes(w);
        // Bytes 22..26 hold the page CRC; verification is deliberately skipped.
        let segment_count = header[26] as usize;
        let table_end = cursor + 27 + segment_count;
        if table_end > data.len() {
            return Err("OGG page segment table extends beyond data".to_string());
        }
        let segment_table: Vec<u8> = data[cursor + 27..table_end].to_vec();
        let body_len: usize = segment_table.iter().map(|&s| usize::from(s)).sum();
        if table_end + body_len > data.len() {
            return Err("OGG page data extends beyond file".to_string());
        }
        pages.push(OggPage {
            header_type,
            granule_position,
            serial,
            page_sequence,
            segment_count: segment_count as u8,
            segment_table,
            data: data[table_end..table_end + body_len].to_vec(),
        });
        cursor = table_end + body_len;
    }
    if pages.is_empty() {
        return Err("No OGG pages found".to_string());
    }
    Ok(pages)
}
/// Extract Vorbis packets from parsed OGG pages.
///
/// A segment length of exactly 255 means "this packet continues in the next
/// segment" (and, at the end of a page, on the next page); any shorter
/// segment terminates the packet. A packet left open by trailing 255-byte
/// segments is flushed at the end.
pub fn extract_packets(pages: &[OggPage]) -> Result<Vec<Vec<u8>>, String> {
    let mut packets: Vec<Vec<u8>> = Vec::new();
    let mut pending: Vec<u8> = Vec::new();
    for page in pages {
        let mut start = 0usize;
        for &seg_len in &page.segment_table {
            let end = start + seg_len as usize;
            pending.extend_from_slice(&page.data[start..end]);
            start = end;
            // Any segment shorter than 255 bytes closes the current packet.
            if seg_len < 255 && !pending.is_empty() {
                packets.push(std::mem::take(&mut pending));
            }
        }
        // If the page ends on a 255-byte segment, `pending` carries over to
        // the next page, which continues the packet.
    }
    if !pending.is_empty() {
        packets.push(pending);
    }
    Ok(packets)
}
/// Convenience function: parse OGG container and extract all Vorbis packets.
pub fn parse_ogg(data: &[u8]) -> Result<Vec<Vec<u8>>, String> {
    parse_ogg_pages(data).and_then(|pages| extract_packets(&pages))
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal OGG page from raw packet data.
    /// Emits a valid header with a zero (unverified) CRC and lacing values
    /// of 255 for full segments plus one terminating segment per packet.
    fn build_ogg_page(
        header_type: u8,
        granule: u64,
        serial: u32,
        page_seq: u32,
        packets_data: &[&[u8]],
    ) -> Vec<u8> {
        // Build segment table and concatenated data
        let mut segment_table = Vec::new();
        let mut page_data = Vec::new();
        for (i, packet) in packets_data.iter().enumerate() {
            let len = packet.len();
            // Write full 255-byte segments
            let full_segments = len / 255;
            let remainder = len % 255;
            for _ in 0..full_segments {
                segment_table.push(255u8);
            }
            // Terminating segment (< 255), even if 0 to signal end of packet
            segment_table.push(remainder as u8);
            page_data.extend_from_slice(packet);
        }
        let segment_count = segment_table.len();
        let mut out = Vec::new();
        // Capture pattern
        out.extend_from_slice(b"OggS");
        // Version
        out.push(0);
        // Header type
        out.push(header_type);
        // Granule position
        out.extend_from_slice(&granule.to_le_bytes());
        // Serial
        out.extend_from_slice(&serial.to_le_bytes());
        // Page sequence
        out.extend_from_slice(&page_seq.to_le_bytes());
        // CRC (dummy zeros)
        out.extend_from_slice(&[0u8; 4]);
        // Segment count
        out.push(segment_count as u8);
        // Segment table
        out.extend_from_slice(&segment_table);
        // Data
        out.extend_from_slice(&page_data);
        out
    }

    #[test]
    fn parse_single_page() {
        let packet = b"hello vorbis";
        let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[packet.as_slice()]);
        let pages = parse_ogg_pages(&page_bytes).expect("parse failed");
        assert_eq!(pages.len(), 1);
        assert_eq!(pages[0].header_type, 0x02);
        assert_eq!(pages[0].serial, 1);
        assert_eq!(pages[0].page_sequence, 0);
        assert_eq!(pages[0].data, packet);
    }

    #[test]
    fn parse_multiple_pages() {
        // Two back-to-back pages in one buffer, same serial, increasing sequence.
        let p1 = build_ogg_page(0x02, 0, 1, 0, &[b"first"]);
        let p2 = build_ogg_page(0x00, 100, 1, 1, &[b"second"]);
        let mut data = p1;
        data.extend_from_slice(&p2);
        let pages = parse_ogg_pages(&data).expect("parse failed");
        assert_eq!(pages.len(), 2);
        assert_eq!(pages[0].page_sequence, 0);
        assert_eq!(pages[1].page_sequence, 1);
        assert_eq!(pages[1].granule_position, 100);
    }

    #[test]
    fn extract_single_packet() {
        let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[b"packet_one"]);
        let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
        assert_eq!(packets.len(), 1);
        assert_eq!(packets[0], b"packet_one");
    }

    #[test]
    fn extract_multiple_packets_single_page() {
        let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[b"pkt1", b"pkt2", b"pkt3"]);
        let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
        assert_eq!(packets.len(), 3);
        assert_eq!(packets[0], b"pkt1");
        assert_eq!(packets[1], b"pkt2");
        assert_eq!(packets[2], b"pkt3");
    }

    #[test]
    fn extract_large_packet_spanning_segments() {
        // Create a packet larger than 255 bytes (lacing: 255, 255, 90),
        // verifying segment reassembly into a single packet.
        let large_packet: Vec<u8> = (0..600).map(|i| (i % 256) as u8).collect();
        let page_bytes = build_ogg_page(0x02, 0, 1, 0, &[&large_packet]);
        let packets = parse_ogg(&page_bytes).expect("parse_ogg failed");
        assert_eq!(packets.len(), 1);
        assert_eq!(packets[0], large_packet);
    }

    #[test]
    fn invalid_capture_pattern() {
        let data = b"NotOGGdata";
        let result = parse_ogg_pages(data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("capture pattern"));
    }

    #[test]
    fn empty_data() {
        let result = parse_ogg_pages(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn page_header_fields() {
        // Every header field survives a round trip through the builder.
        let page_bytes = build_ogg_page(0x04, 12345, 42, 7, &[b"data"]);
        let pages = parse_ogg_pages(&page_bytes).expect("parse failed");
        assert_eq!(pages[0].header_type, 0x04); // EOS
        assert_eq!(pages[0].granule_position, 12345);
        assert_eq!(pages[0].serial, 42);
        assert_eq!(pages[0].page_sequence, 7);
    }
}

View File

@@ -0,0 +1,189 @@
/// Simple delay line for echo effect (fixed-length circular buffer).
pub struct DelayLine {
    pub(crate) buffer: Vec<f32>,  // circular sample storage
    pub(crate) write_pos: usize,  // next index to be written
    delay_samples: usize,         // read offset behind the write head
}

impl DelayLine {
    /// Create a delay of `delay_samples` samples (buffer is at least 1 long).
    pub fn new(delay_samples: usize) -> Self {
        Self {
            buffer: vec![0.0; delay_samples.max(1)],
            write_pos: 0,
            delay_samples,
        }
    }

    /// Push one sample and return the sample delayed by `delay_samples`.
    /// `feedback` scales the delayed sample that is mixed back into storage.
    pub fn process(&mut self, input: f32, feedback: f32) -> f32 {
        let len = self.buffer.len();
        let read_pos = (self.write_pos + len - self.delay_samples) % len;
        let delayed = self.buffer[read_pos];
        self.buffer[self.write_pos] = input + feedback * delayed;
        self.write_pos = (self.write_pos + 1) % len;
        delayed
    }
}
/// Simple Schroeder reverb using 4 parallel comb filters + 2 allpass filters.
pub struct Reverb {
    // Parallel comb filters build up echo density.
    comb_filters: Vec<CombFilter>,
    // Series allpass filters diffuse the summed comb output.
    allpass_filters: Vec<AllpassFilter>,
    pub wet: f32, // 0.0-1.0, level of the reverberated signal in the mix
    pub dry: f32, // 0.0-1.0, level of the unprocessed signal in the mix
}
/// Feedback comb filter: a delay line whose delayed output is recirculated.
struct CombFilter {
    delay: DelayLine,
    feedback: f32,
}

impl CombFilter {
    /// Build a comb with the given delay length and feedback gain.
    fn new(delay_samples: usize, feedback: f32) -> Self {
        Self {
            delay: DelayLine::new(delay_samples),
            feedback,
        }
    }

    /// Run one sample through the comb.
    fn process(&mut self, input: f32) -> f32 {
        self.delay.process(input, self.feedback)
    }
}
/// Schroeder allpass filter: passes all frequencies at equal magnitude while
/// smearing phase, increasing echo density without coloring the tone.
struct AllpassFilter {
    delay: DelayLine,
    gain: f32,
}

impl AllpassFilter {
    fn new(delay_samples: usize, gain: f32) -> Self {
        AllpassFilter { delay: DelayLine::new(delay_samples), gain }
    }

    /// Schroeder allpass: y[n] = -g*x[n] + v[n-D], with v[n] = x[n] + g*v[n-D].
    ///
    /// `DelayLine::process(input, feedback)` writes `input + feedback * delayed`
    /// into the buffer and returns the delayed value, which is exactly the
    /// recurrence above with feedback = g.
    ///
    /// (The previous implementation discarded the delayed sample: it re-read
    /// the input it had just written and returned `(1 - g) * input`, i.e.
    /// plain attenuation with no diffusion at all.)
    fn process(&mut self, input: f32) -> f32 {
        let delayed = self.delay.process(input, self.gain);
        -self.gain * input + delayed
    }
}
impl Reverb {
    /// Build a Schroeder reverb tuned for `sample_rate`, using the classic
    /// prime-ish delay times so the comb resonances do not align.
    pub fn new(sample_rate: u32) -> Self {
        let sr = sample_rate as f32;
        // (delay time in seconds, feedback gain) per comb filter.
        let comb_specs = [(0.0297, 0.805), (0.0371, 0.827), (0.0411, 0.783), (0.0437, 0.764)];
        // (delay time in seconds, gain) per allpass filter.
        let allpass_specs = [(0.005, 0.7), (0.0017, 0.7)];
        Reverb {
            comb_filters: comb_specs
                .iter()
                .map(|&(t, fb)| CombFilter::new((t * sr) as usize, fb))
                .collect(),
            allpass_filters: allpass_specs
                .iter()
                .map(|&(t, g)| AllpassFilter::new((t * sr) as usize, g))
                .collect(),
            wet: 0.3,
            dry: 0.7,
        }
    }

    /// Process one dry sample and return the wet/dry mix.
    pub fn process(&mut self, input: f32) -> f32 {
        // Parallel combs, averaged.
        let mut wet_sum = 0.0;
        for comb in self.comb_filters.iter_mut() {
            wet_sum += comb.process(input);
        }
        wet_sum /= self.comb_filters.len() as f32;
        // Series allpasses diffuse the comb output.
        let diffused = self
            .allpass_filters
            .iter_mut()
            .fold(wet_sum, |signal, ap| ap.process(signal));
        input * self.dry + diffused * self.wet
    }
}
/// Simple echo (single delay + feedback), mixed wet/dry.
pub struct Echo {
    delay: DelayLine,
    pub feedback: f32, // fraction of the delayed signal recirculated
    pub wet: f32,      // echo level in the output mix
    pub dry: f32,      // direct-signal level in the output mix
}

impl Echo {
    /// Create an echo with the given delay in milliseconds at `sample_rate`.
    pub fn new(delay_ms: f32, sample_rate: u32) -> Self {
        let delay_samples = (delay_ms * sample_rate as f32 / 1000.0) as usize;
        Self {
            delay: DelayLine::new(delay_samples),
            feedback: 0.5,
            wet: 0.3,
            dry: 0.7,
        }
    }

    /// Process one sample and return the wet/dry mix.
    pub fn process(&mut self, input: f32) -> f32 {
        let echoed = self.delay.process(input, self.feedback);
        input * self.dry + echoed * self.wet
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_delay_line_basic() {
        // An impulse emerges only after the 2-sample delay has elapsed.
        let mut dl = DelayLine::new(2);
        assert!((dl.process(1.0, 0.0) - 0.0).abs() < 1e-6); // no output yet
        assert!((dl.process(0.0, 0.0) - 0.0).abs() < 1e-6);
        assert!((dl.process(0.0, 0.0) - 1.0).abs() < 1e-6); // input appears after 2-sample delay
    }

    #[test]
    fn test_delay_line_feedback() {
        let mut dl = DelayLine::new(2);
        dl.process(1.0, 0.5);
        dl.process(0.0, 0.5);
        let out = dl.process(0.0, 0.5); // delayed 1.0
        assert!((out - 1.0).abs() < 1e-6);
        dl.process(0.0, 0.5);
        // Second trip around the loop carries half the energy (feedback 0.5).
        let out2 = dl.process(0.0, 0.5); // feedback: 0.5
        assert!((out2 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_reverb_preserves_signal() {
        let mut rev = Reverb::new(44100);
        // Process silence -> should be silence
        let out = rev.process(0.0);
        assert!((out - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_reverb_produces_output() {
        let mut rev = Reverb::new(44100);
        // Send an impulse, collect some output
        let mut energy = 0.0;
        rev.process(1.0);
        for _ in 0..4410 {
            let s = rev.process(0.0);
            energy += s.abs();
        }
        assert!(energy > 0.01, "reverb should produce decaying output after impulse");
    }

    #[test]
    fn test_echo_basic() {
        let mut echo = Echo::new(100.0, 44100); // 100ms delay
        let delay_samples = (100.0 * 44100.0 / 1000.0) as usize;
        // Send impulse
        echo.process(1.0);
        // Process until delay
        for _ in 1..delay_samples {
            echo.process(0.0);
        }
        let out = echo.process(0.0);
        // Should have echo of input (wet * delayed)
        assert!(out.abs() > 0.01, "echo should produce delayed output");
    }

    #[test]
    fn test_echo_wet_dry() {
        // With wet muted the output is exactly the dry input.
        let mut echo = Echo::new(10.0, 44100);
        echo.wet = 0.0;
        echo.dry = 1.0;
        let out = echo.process(0.5);
        assert!((out - 0.5).abs() < 1e-6); // dry only
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -62,7 +62,42 @@ fn find_chunk(data: &[u8], id: &[u8; 4], start: usize) -> Option<(usize, u32)> {
// Public API
// ---------------------------------------------------------------------------
/// Parse a PCM 16-bit WAV file from raw bytes into an [`AudioClip`].
/// Read a 24-bit signed integer (little-endian) from 3 bytes and return as i32.
fn read_i24_le(data: &[u8], offset: usize) -> Result<i32, String> {
    if offset + 3 > data.len() {
        return Err(format!("read_i24_le: offset {} out of bounds (len={})", offset, data.len()));
    }
    let raw = u32::from(data[offset])
        | (u32::from(data[offset + 1]) << 8)
        | (u32::from(data[offset + 2]) << 16);
    // Sign-extend from 24-bit to 32-bit: if bit 23 is set, fill the top byte.
    let extended = if raw & 0x800000 != 0 { raw | 0xFF000000 } else { raw };
    Ok(extended as i32)
}
/// Read a 32-bit IEEE-754 float (little-endian) at `offset`.
fn read_f32_le(data: &[u8], offset: usize) -> Result<f32, String> {
    if offset + 4 > data.len() {
        return Err(format!("read_f32_le: offset {} out of bounds (len={})", offset, data.len()));
    }
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(&data[offset..offset + 4]);
    Ok(f32::from_le_bytes(bytes))
}
/// Parse a WAV file from raw bytes into an [`AudioClip`].
///
/// Supported formats:
/// - PCM 16-bit (format_tag=1, bits_per_sample=16)
/// - PCM 24-bit (format_tag=1, bits_per_sample=24)
/// - IEEE float 32-bit (format_tag=3, bits_per_sample=32)
pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
// Minimum viable WAV: RIFF(4) + size(4) + WAVE(4) = 12 bytes
if data.len() < 12 {
@@ -86,10 +121,6 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
}
let format_tag = read_u16_le(data, fmt_offset)?;
if format_tag != 1 {
return Err(format!("Unsupported WAV format tag: {} (only PCM=1 is supported)", format_tag));
}
let channels = read_u16_le(data, fmt_offset + 2)?;
if channels != 1 && channels != 2 {
return Err(format!("Unsupported channel count: {}", channels));
@@ -99,9 +130,16 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
// byte_rate = fmt_offset + 8 (skip)
// block_align = fmt_offset + 12 (skip)
let bits_per_sample = read_u16_le(data, fmt_offset + 14)?;
if bits_per_sample != 16 {
return Err(format!("Unsupported bits per sample: {} (only 16-bit is supported)", bits_per_sample));
}
// Validate format_tag + bits_per_sample combination
let bytes_per_sample = match (format_tag, bits_per_sample) {
(1, 16) => 2, // PCM 16-bit
(1, 24) => 3, // PCM 24-bit
(3, 32) => 4, // IEEE float 32-bit
(1, bps) => return Err(format!("Unsupported PCM bits per sample: {}", bps)),
(3, bps) => return Err(format!("Unsupported float bits per sample: {} (only 32-bit supported)", bps)),
(tag, _) => return Err(format!("Unsupported WAV format tag: {} (only PCM=1 and IEEE_FLOAT=3 are supported)", tag)),
};
// --- data chunk ---
let (data_offset, data_size) =
@@ -112,14 +150,30 @@ pub fn parse_wav(data: &[u8]) -> Result<AudioClip, String> {
return Err("data chunk extends beyond end of file".to_string());
}
// Each sample is 2 bytes (16-bit PCM).
let sample_count = data_size as usize / 2;
let sample_count = data_size as usize / bytes_per_sample;
let mut samples = Vec::with_capacity(sample_count);
for i in 0..sample_count {
let raw = read_i16_le(data, data_offset + i * 2)?;
// Convert i16 [-32768, 32767] to f32 [-1.0, ~1.0]
samples.push(raw as f32 / 32768.0);
match (format_tag, bits_per_sample) {
(1, 16) => {
for i in 0..sample_count {
let raw = read_i16_le(data, data_offset + i * 2)?;
samples.push(raw as f32 / 32768.0);
}
}
(1, 24) => {
for i in 0..sample_count {
let raw = read_i24_le(data, data_offset + i * 3)?;
// 24-bit range: [-8388608, 8388607]
samples.push(raw as f32 / 8388608.0);
}
}
(3, 32) => {
for i in 0..sample_count {
let raw = read_f32_le(data, data_offset + i * 4)?;
samples.push(raw);
}
}
_ => unreachable!(),
}
Ok(AudioClip::new(samples, sample_rate, channels))
@@ -165,6 +219,87 @@ pub fn generate_wav_bytes(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
out
}
/// Generate a minimal PCM 24-bit mono WAV file from f32 samples.
/// Used for round-trip testing.
pub fn generate_wav_bytes_24bit(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
let channels: u16 = 1;
let bits_per_sample: u16 = 24;
let byte_rate = sample_rate * channels as u32 * bits_per_sample as u32 / 8;
let block_align: u16 = channels * bits_per_sample / 8;
let data_size = (samples_f32.len() * 3) as u32;
let riff_size = 4 + 8 + 16 + 8 + data_size;
let mut out: Vec<u8> = Vec::with_capacity(12 + 8 + 16 + 8 + data_size as usize);
// RIFF header
out.extend_from_slice(b"RIFF");
out.extend_from_slice(&riff_size.to_le_bytes());
out.extend_from_slice(b"WAVE");
// fmt chunk
out.extend_from_slice(b"fmt ");
out.extend_from_slice(&16u32.to_le_bytes());
out.extend_from_slice(&1u16.to_le_bytes()); // PCM
out.extend_from_slice(&channels.to_le_bytes());
out.extend_from_slice(&sample_rate.to_le_bytes());
out.extend_from_slice(&byte_rate.to_le_bytes());
out.extend_from_slice(&block_align.to_le_bytes());
out.extend_from_slice(&bits_per_sample.to_le_bytes());
// data chunk
out.extend_from_slice(b"data");
out.extend_from_slice(&data_size.to_le_bytes());
for &s in samples_f32 {
let clamped = s.clamp(-1.0, 1.0);
let raw = (clamped * 8388607.0) as i32;
// Write 3 bytes LE
out.push((raw & 0xFF) as u8);
out.push(((raw >> 8) & 0xFF) as u8);
out.push(((raw >> 16) & 0xFF) as u8);
}
out
}
/// Generate a minimal IEEE float 32-bit mono WAV file from f32 samples.
/// Samples are written verbatim (no clamping or quantization).
/// Used for round-trip testing.
pub fn generate_wav_bytes_f32(samples_f32: &[f32], sample_rate: u32) -> Vec<u8> {
    const CHANNELS: u16 = 1;
    const BITS: u16 = 32;
    let byte_rate = sample_rate * CHANNELS as u32 * BITS as u32 / 8;
    let block_align: u16 = CHANNELS * BITS / 8;
    let data_size = (samples_f32.len() * 4) as u32;
    // RIFF chunk size covers everything after the 8-byte "RIFF"+size prefix.
    let riff_size = 4 + 8 + 16 + 8 + data_size;
    let mut out: Vec<u8> = Vec::with_capacity(12 + 8 + 16 + 8 + data_size as usize);
    // RIFF header
    out.extend_from_slice(b"RIFF");
    out.extend_from_slice(&riff_size.to_le_bytes());
    out.extend_from_slice(b"WAVE");
    // fmt chunk
    out.extend_from_slice(b"fmt ");
    out.extend_from_slice(&16u32.to_le_bytes());
    out.extend_from_slice(&3u16.to_le_bytes()); // format tag 3 = IEEE_FLOAT
    out.extend_from_slice(&CHANNELS.to_le_bytes());
    out.extend_from_slice(&sample_rate.to_le_bytes());
    out.extend_from_slice(&byte_rate.to_le_bytes());
    out.extend_from_slice(&block_align.to_le_bytes());
    out.extend_from_slice(&BITS.to_le_bytes());
    // data chunk
    out.extend_from_slice(b"data");
    out.extend_from_slice(&data_size.to_le_bytes());
    for &sample in samples_f32 {
        out.extend_from_slice(&sample.to_le_bytes());
    }
    out
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -237,4 +372,107 @@ mod tests {
);
}
}
// -----------------------------------------------------------------------
// 24-bit PCM tests
// -----------------------------------------------------------------------
#[test]
fn parse_24bit_wav() {
    // A 440 Hz sine, generated as 24-bit PCM, parses back with the same
    // rate, channel count and frame count.
    let sample_rate = 44100u32;
    let num_samples = 4410usize;
    let samples: Vec<f32> = (0..num_samples)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let wav_bytes = generate_wav_bytes_24bit(&samples, sample_rate);
    let clip = parse_wav(&wav_bytes).expect("parse_wav 24-bit failed");
    assert_eq!(clip.sample_rate, sample_rate);
    assert_eq!(clip.channels, 1);
    assert_eq!(clip.frame_count(), num_samples);
}

#[test]
fn roundtrip_24bit() {
    let original: Vec<f32> = vec![0.0, 0.25, 0.5, -0.25, -0.5, 1.0, -1.0];
    let wav_bytes = generate_wav_bytes_24bit(&original, 44100);
    let clip = parse_wav(&wav_bytes).expect("roundtrip 24-bit parse failed");
    assert_eq!(clip.samples.len(), original.len());
    for (orig, decoded) in original.iter().zip(clip.samples.iter()) {
        // 24-bit quantization error should be < 0.0001
        assert!(
            (orig - decoded).abs() < 0.0001,
            "24-bit: orig={} decoded={}",
            orig,
            decoded
        );
    }
}

#[test]
fn accuracy_24bit() {
    // 24-bit should be more accurate than 16-bit
    let samples = vec![0.5f32];
    let wav_bytes = generate_wav_bytes_24bit(&samples, 44100);
    let clip = parse_wav(&wav_bytes).expect("parse failed");
    // 0.5 * 8388607 = 4194303 -> 4194303 / 8388608 ≈ 0.49999988
    assert!((clip.samples[0] - 0.5).abs() < 0.0001, "24-bit got {}", clip.samples[0]);
}

// -----------------------------------------------------------------------
// 32-bit float tests
// -----------------------------------------------------------------------
#[test]
fn parse_32bit_float_wav() {
    // A 440 Hz sine as IEEE float parses back with the same metadata.
    let sample_rate = 44100u32;
    let num_samples = 4410usize;
    let samples: Vec<f32> = (0..num_samples)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let wav_bytes = generate_wav_bytes_f32(&samples, sample_rate);
    let clip = parse_wav(&wav_bytes).expect("parse_wav float32 failed");
    assert_eq!(clip.sample_rate, sample_rate);
    assert_eq!(clip.channels, 1);
    assert_eq!(clip.frame_count(), num_samples);
}

#[test]
fn roundtrip_32bit_float() {
    let original: Vec<f32> = vec![0.0, 0.25, 0.5, -0.25, -0.5, 1.0, -1.0];
    let wav_bytes = generate_wav_bytes_f32(&original, 44100);
    let clip = parse_wav(&wav_bytes).expect("roundtrip float32 parse failed");
    assert_eq!(clip.samples.len(), original.len());
    for (orig, decoded) in original.iter().zip(clip.samples.iter()) {
        // 32-bit float should be exact
        assert_eq!(*orig, *decoded, "float32: orig={} decoded={}", orig, decoded);
    }
}

#[test]
fn accuracy_32bit_float() {
    // 32-bit float should preserve exact values
    let samples = vec![0.123456789f32, -0.987654321f32];
    let wav_bytes = generate_wav_bytes_f32(&samples, 44100);
    let clip = parse_wav(&wav_bytes).expect("parse failed");
    assert_eq!(clip.samples[0], 0.123456789f32);
    assert_eq!(clip.samples[1], -0.987654321f32);
}

#[test]
fn reject_unsupported_format_tag() {
    // Create a WAV with format_tag=2 (ADPCM), which we don't support
    let mut wav = generate_wav_bytes(&[0.0], 44100);
    // format_tag is at byte 20-21 (RIFF(4)+size(4)+WAVE(4)+fmt(4)+chunk_size(4))
    wav[20] = 2;
    wav[21] = 0;
    let result = parse_wav(&wav);
    assert!(result.is_err());
    assert!(result.unwrap_err().contains("Unsupported WAV format tag"));
}
}

View File

@@ -0,0 +1,263 @@
/// Binary scene format (.vscn binary).
///
/// Format:
/// Header: "VSCN" (4 bytes) + version u32 LE + entity_count u32 LE
/// Per entity:
/// parent_index i32 LE (-1 = no parent)
/// component_count u32 LE
/// Per component:
/// name_len u16 LE + name bytes
/// data_len u32 LE + data bytes
use std::collections::HashMap;
use crate::entity::Entity;
use crate::world::World;
use crate::transform::Transform;
use crate::hierarchy::{add_child, Parent};
use crate::component_registry::ComponentRegistry;
const MAGIC: &[u8; 4] = b"VSCN";
const VERSION: u32 = 1;
/// Serialize all entities with a Transform to the binary scene format.
///
/// Layout (all integers little-endian): "VSCN" magic + version u32 +
/// entity_count u32, then per entity: parent index (i32, -1 = root),
/// component count (u32), and per component the name (u16 len + bytes)
/// followed by its payload (u32 len + bytes).
pub fn serialize_scene_binary(world: &World, registry: &ComponentRegistry) -> Vec<u8> {
    // Only entities carrying a Transform participate in scene serialization.
    let entities_with_transform: Vec<(Entity, Transform)> = world
        .query::<Transform>()
        .map(|(e, t)| (e, *t))
        .collect();
    // Map entity -> output index so parent links can be stored as positional
    // indices instead of entity ids (ids are not stable across a reload).
    let entity_to_index: HashMap<Entity, usize> = entities_with_transform
        .iter()
        .enumerate()
        .map(|(i, (e, _))| (*e, i))
        .collect();
    let entity_count = entities_with_transform.len() as u32;
    let mut buf = Vec::new();
    // Header
    buf.extend_from_slice(MAGIC);
    buf.extend_from_slice(&VERSION.to_le_bytes());
    buf.extend_from_slice(&entity_count.to_le_bytes());
    // Entities
    for (entity, _) in &entities_with_transform {
        // Parent index; -1 when there is no parent or the parent itself is
        // not part of the serialized set.
        let parent_idx: i32 = if let Some(parent_comp) = world.get::<Parent>(*entity) {
            entity_to_index.get(&parent_comp.0)
                .map(|&i| i as i32)
                .unwrap_or(-1)
        } else {
            -1
        };
        buf.extend_from_slice(&parent_idx.to_le_bytes());
        // Collect serializable components; a registry callback returns None
        // when the entity does not have that component.
        let mut comp_data: Vec<(&str, Vec<u8>)> = Vec::new();
        for entry in registry.entries() {
            if let Some(data) = (entry.serialize)(world, *entity) {
                comp_data.push((&entry.name, data));
            }
        }
        let comp_count = comp_data.len() as u32;
        buf.extend_from_slice(&comp_count.to_le_bytes());
        for (name, data) in &comp_data {
            let name_bytes = name.as_bytes();
            buf.extend_from_slice(&(name_bytes.len() as u16).to_le_bytes());
            buf.extend_from_slice(name_bytes);
            buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
            buf.extend_from_slice(data);
        }
    }
    buf
}
/// Deserialize entities from binary scene data.
///
/// Spawns one entity per record, restores components through the registry
/// (unknown component names are skipped, allowing forward compatibility),
/// then re-links parents in a second pass. Returns created entities in
/// file order. Every read is bounds-checked against the buffer length.
pub fn deserialize_scene_binary(
    world: &mut World,
    data: &[u8],
    registry: &ComponentRegistry,
) -> Result<Vec<Entity>, String> {
    // Smallest valid payload is the 12-byte header (magic + version + count).
    if data.len() < 12 {
        return Err("Binary scene data too short".into());
    }
    // Verify magic
    if &data[0..4] != MAGIC {
        return Err("Invalid magic bytes — expected VSCN".into());
    }
    let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
    if version != 1 {
        return Err(format!("Unsupported binary scene version: {}", version));
    }
    let entity_count = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
    let mut pos = 12; // cursor into `data`
    let mut created: Vec<Entity> = Vec::with_capacity(entity_count);
    // Parent indices are recorded first and applied only after every entity
    // exists, since a parent may be stored after its child in the file.
    let mut parent_indices: Vec<i32> = Vec::with_capacity(entity_count);
    for _ in 0..entity_count {
        // Parent index
        if pos + 4 > data.len() {
            return Err("Unexpected end of data reading parent index".into());
        }
        let parent_idx = i32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        // Component count
        if pos + 4 > data.len() {
            return Err("Unexpected end of data reading component count".into());
        }
        let comp_count = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
        pos += 4;
        let entity = world.spawn();
        for _ in 0..comp_count {
            // Name length
            if pos + 2 > data.len() {
                return Err("Unexpected end of data reading name length".into());
            }
            let name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
            pos += 2;
            // Name
            if pos + name_len > data.len() {
                return Err("Unexpected end of data reading component name".into());
            }
            let name = std::str::from_utf8(&data[pos..pos + name_len])
                .map_err(|_| "Invalid UTF-8 in component name".to_string())?;
            pos += name_len;
            // Data length
            if pos + 4 > data.len() {
                return Err("Unexpected end of data reading data length".into());
            }
            let data_len = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
            pos += 4;
            // Data
            if pos + data_len > data.len() {
                return Err("Unexpected end of data reading component data".into());
            }
            let comp_data = &data[pos..pos + data_len];
            pos += data_len;
            // Deserialize via registry; names without a registered entry are
            // silently ignored so newer files still load on older builds.
            if let Some(entry) = registry.find(name) {
                (entry.deserialize)(world, entity, comp_data)?;
            }
        }
        created.push(entity);
        parent_indices.push(parent_idx);
    }
    // Apply parent relationships (negative or out-of-range indices ignored).
    for (child_idx, &parent_idx) in parent_indices.iter().enumerate() {
        if parent_idx >= 0 {
            let pi = parent_idx as usize;
            if pi < created.len() {
                let child_entity = created[child_idx];
                let parent_entity = created[pi];
                add_child(world, parent_entity, child_entity);
            }
        }
    }
    Ok(created)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scene::Tag;
    use voltex_math::Vec3;

    #[test]
    fn test_binary_roundtrip() {
        // Transform + Tag survive a serialize/deserialize cycle into a
        // fresh world.
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(5.0, 0.0, -1.0)));
        world.add(e, Tag("enemy".into()));
        let data = serialize_scene_binary(&world, &registry);
        assert_eq!(&data[0..4], b"VSCN");
        let mut world2 = World::new();
        let entities = deserialize_scene_binary(&mut world2, &data, &registry).unwrap();
        assert_eq!(entities.len(), 1);
        let t = world2.get::<Transform>(entities[0]).unwrap();
        assert!((t.position.x - 5.0).abs() < 1e-6);
        assert!((t.position.z - (-1.0)).abs() < 1e-6);
        let tag = world2.get::<Tag>(entities[0]).unwrap();
        assert_eq!(tag.0, "enemy");
    }

    #[test]
    fn test_binary_with_hierarchy() {
        // Parent/child link is restored as an index-based relationship.
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let a = world.spawn();
        let b = world.spawn();
        world.add(a, Transform::new());
        world.add(b, Transform::new());
        add_child(&mut world, a, b);
        let data = serialize_scene_binary(&world, &registry);
        let mut world2 = World::new();
        let entities = deserialize_scene_binary(&mut world2, &data, &registry).unwrap();
        assert_eq!(entities.len(), 2);
        assert!(world2.get::<Parent>(entities[1]).is_some());
        let p = world2.get::<Parent>(entities[1]).unwrap();
        assert_eq!(p.0, entities[0]);
    }

    #[test]
    fn test_binary_invalid_magic() {
        // 20 zero bytes pass the length check but fail the magic check.
        let data = vec![0u8; 20];
        let mut world = World::new();
        let registry = ComponentRegistry::new();
        assert!(deserialize_scene_binary(&mut world, &data, &registry).is_err());
    }

    #[test]
    fn test_binary_too_short() {
        // Shorter than the 12-byte header is rejected outright.
        let data = vec![0u8; 5];
        let mut world = World::new();
        let registry = ComponentRegistry::new();
        assert!(deserialize_scene_binary(&mut world, &data, &registry).is_err());
    }

    #[test]
    fn test_binary_multiple_entities() {
        // Multiple entities round-trip in file order with their components.
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        for i in 0..3 {
            let e = world.spawn();
            world.add(e, Transform::from_position(Vec3::new(i as f32, 0.0, 0.0)));
            world.add(e, Tag(format!("e{}", i)));
        }
        let data = serialize_scene_binary(&world, &registry);
        let mut world2 = World::new();
        let entities = deserialize_scene_binary(&mut world2, &data, &registry).unwrap();
        assert_eq!(entities.len(), 3);
        for (i, &e) in entities.iter().enumerate() {
            let t = world2.get::<Transform>(e).unwrap();
            assert!((t.position.x - i as f32).abs() < 1e-6);
            let tag = world2.get::<Tag>(e).unwrap();
            assert_eq!(tag.0, format!("e{}", i));
        }
    }
}

View File

@@ -0,0 +1,198 @@
/// Registration-based component serialization for scene formats.
/// Each registered component type has a name, a serialize function,
/// and a deserialize function.
use crate::entity::Entity;
use crate::world::World;
/// Serializes one component of the entity into bytes; `None` when the
/// entity does not have that component.
pub type SerializeFn = fn(&World, Entity) -> Option<Vec<u8>>;
/// Reconstructs a component on the entity from bytes produced by the
/// matching [`SerializeFn`].
pub type DeserializeFn = fn(&mut World, Entity, &[u8]) -> Result<(), String>;
/// One named component type together with its serialization codec.
pub struct ComponentEntry {
    pub name: String,
    pub serialize: SerializeFn,
    pub deserialize: DeserializeFn,
}
/// Ordered collection of registered component codecs; lookup is by name.
pub struct ComponentRegistry {
    entries: Vec<ComponentEntry>,
}
impl ComponentRegistry {
pub fn new() -> Self {
Self { entries: Vec::new() }
}
pub fn register(&mut self, name: &str, ser: SerializeFn, deser: DeserializeFn) {
self.entries.push(ComponentEntry {
name: name.to_string(),
serialize: ser,
deserialize: deser,
});
}
pub fn find(&self, name: &str) -> Option<&ComponentEntry> {
self.entries.iter().find(|e| e.name == name)
}
pub fn entries(&self) -> &[ComponentEntry] {
&self.entries
}
/// Register the default built-in component types: transform and tag.
pub fn register_defaults(&mut self) {
self.register("transform", serialize_transform, deserialize_transform);
self.register("tag", serialize_tag, deserialize_tag);
}
}
impl Default for ComponentRegistry {
fn default() -> Self {
Self::new()
}
}
// ── Transform: 9 f32s in little-endian (pos.xyz, rot.xyz, scale.xyz) ─
/// Encode the entity's Transform as nine little-endian f32s
/// (pos.xyz, rot.xyz, scale.xyz); `None` when the component is absent.
fn serialize_transform(world: &World, entity: Entity) -> Option<Vec<u8>> {
    let t = world.get::<crate::Transform>(entity)?;
    let fields = [
        t.position.x, t.position.y, t.position.z,
        t.rotation.x, t.rotation.y, t.rotation.z,
        t.scale.x, t.scale.y, t.scale.z,
    ];
    let mut data = Vec::with_capacity(fields.len() * 4);
    for value in fields {
        data.extend_from_slice(&value.to_le_bytes());
    }
    Some(data)
}
/// Rebuild a Transform from the 36-byte layout written by
/// `serialize_transform` and attach it to `entity`.
fn deserialize_transform(world: &mut World, entity: Entity, data: &[u8]) -> Result<(), String> {
    if data.len() < 36 {
        return Err("Transform data too short".into());
    }
    // Decode the idx-th little-endian f32 field (field index, not byte offset).
    let field = |idx: usize| {
        let off = idx * 4;
        f32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]])
    };
    let transform = crate::Transform {
        position: voltex_math::Vec3::new(field(0), field(1), field(2)),
        rotation: voltex_math::Vec3::new(field(3), field(4), field(5)),
        scale: voltex_math::Vec3::new(field(6), field(7), field(8)),
    };
    world.add(entity, transform);
    Ok(())
}
// ── Tag: UTF-8 string bytes ─────────────────────────────────────────
/// Encode the entity's Tag as its raw UTF-8 bytes; `None` when absent.
fn serialize_tag(world: &World, entity: Entity) -> Option<Vec<u8>> {
    world
        .get::<crate::scene::Tag>(entity)
        .map(|tag| tag.0.clone().into_bytes())
}
/// Rebuild a Tag from UTF-8 bytes and attach it to `entity`; rejects
/// invalid UTF-8 rather than attaching a mangled string.
fn deserialize_tag(world: &mut World, entity: Entity, data: &[u8]) -> Result<(), String> {
    match std::str::from_utf8(data) {
        Ok(s) => {
            world.add(entity, crate::scene::Tag(s.to_string()));
            Ok(())
        }
        Err(_) => Err("Invalid UTF-8 in tag".to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{World, Transform};
    use voltex_math::Vec3;
    // Registered names resolve; unknown names do not.
    #[test]
    fn test_register_and_find() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        assert!(registry.find("transform").is_some());
        assert!(registry.find("tag").is_some());
        assert!(registry.find("nonexistent").is_none());
    }
    // register_defaults adds exactly the two built-in codecs.
    #[test]
    fn test_entries_count() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        assert_eq!(registry.entries().len(), 2);
    }
    // A present Transform serializes to its fixed 36-byte layout.
    #[test]
    fn test_serialize_transform() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(1.0, 2.0, 3.0)));
        let entry = registry.find("transform").unwrap();
        let data = (entry.serialize)(&world, e);
        assert!(data.is_some());
        assert_eq!(data.unwrap().len(), 36); // 9 f32s * 4 bytes
    }
    // Serializing an absent component yields None, not an error.
    #[test]
    fn test_serialize_missing_component() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let e = world.spawn();
        // no Transform added
        let entry = registry.find("transform").unwrap();
        assert!((entry.serialize)(&world, e).is_none());
    }
    // serialize -> deserialize reproduces all nine Transform fields.
    #[test]
    fn test_roundtrip_transform() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Transform {
            position: Vec3::new(1.0, 2.0, 3.0),
            rotation: Vec3::new(0.1, 0.2, 0.3),
            scale: Vec3::new(4.0, 5.0, 6.0),
        });
        let entry = registry.find("transform").unwrap();
        let data = (entry.serialize)(&world, e).unwrap();
        let mut world2 = World::new();
        let e2 = world2.spawn();
        (entry.deserialize)(&mut world2, e2, &data).unwrap();
        let t = world2.get::<Transform>(e2).unwrap();
        assert!((t.position.x - 1.0).abs() < 1e-6);
        assert!((t.position.y - 2.0).abs() < 1e-6);
        assert!((t.position.z - 3.0).abs() < 1e-6);
        assert!((t.rotation.x - 0.1).abs() < 1e-6);
        assert!((t.scale.x - 4.0).abs() < 1e-6);
    }
    // Tag round-trips through its raw UTF-8 byte encoding.
    #[test]
    fn test_roundtrip_tag() {
        let mut registry = ComponentRegistry::new();
        registry.register_defaults();
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, crate::scene::Tag("hello world".to_string()));
        let entry = registry.find("tag").unwrap();
        let data = (entry.serialize)(&world, e).unwrap();
        let mut world2 = World::new();
        let e2 = world2.spawn();
        (entry.deserialize)(&mut world2, e2, &data).unwrap();
        let tag = world2.get::<crate::scene::Tag>(e2).unwrap();
        assert_eq!(tag.0, "hello world");
    }
    // Fewer than 36 bytes is rejected rather than read out of bounds.
    #[test]
    fn test_deserialize_transform_too_short() {
        let mut world = World::new();
        let e = world.spawn();
        let result = deserialize_transform(&mut world, e, &[0u8; 10]);
        assert!(result.is_err());
    }
}

View File

@@ -0,0 +1,476 @@
/// Mini JSON writer and parser for scene serialization.
/// No external dependencies — self-contained within voltex_ecs.
/// A parsed JSON value. `Object` stores key/value pairs in a `Vec`
/// rather than a map, so key order (and any duplicates) is preserved.
#[derive(Debug, Clone, PartialEq)]
pub enum JsonVal {
    Null,
    Bool(bool),
    // All JSON numbers are held as f64, mirroring JavaScript semantics.
    Number(f64),
    Str(String),
    Array(Vec<JsonVal>),
    Object(Vec<(String, JsonVal)>),
}
// ── Writer helpers ──────────────────────────────────────────────────
/// The JSON `null` literal as an owned string.
pub fn json_write_null() -> String {
    String::from("null")
}
/// Format an `f32` as a JSON number token.
///
/// Whole values are emitted without a fractional part ("1", not "1.0").
/// Non-finite values (NaN, ±infinity) have no JSON representation, so
/// they are emitted as `null` instead of producing invalid tokens such
/// as "NaN" or "inf".
pub fn json_write_f32(v: f32) -> String {
    if !v.is_finite() {
        return "null".to_string();
    }
    // Emit integer form when the value has no fractional part; the 1e15
    // bound keeps the i64 cast exact.
    if v.fract() == 0.0 && v.abs() < 1e15 {
        format!("{}", v as i64)
    } else {
        format!("{}", v)
    }
}
/// Quote and escape `s` as a JSON string literal.
///
/// RFC 8259 forbids raw control characters (U+0000..U+001F) inside
/// strings; besides the named escapes, the remaining control characters
/// are emitted as `\u00XX` so the output is always valid JSON.
pub fn json_write_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining control characters get the generic \u00XX form.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            _ => out.push(c),
        }
    }
    out.push('"');
    out
}
/// Write a JSON array from pre-formatted element strings.
pub fn json_write_array(elements: &[&str]) -> String {
    format!("[{}]", elements.join(","))
}
/// Write a JSON object from (key, pre-formatted-value) pairs.
/// Keys are escaped; values are inserted verbatim.
pub fn json_write_object(pairs: &[(&str, &str)]) -> String {
    let body: Vec<String> = pairs
        .iter()
        .map(|(key, val)| format!("{}:{}", json_write_string(key), val))
        .collect();
    format!("{{{}}}", body.join(","))
}
// ── Accessors on JsonVal ────────────────────────────────────────────
impl JsonVal {
    /// Look up `key` in an object; `None` for non-objects or missing keys.
    /// With duplicate keys, the first occurrence wins.
    pub fn get(&self, key: &str) -> Option<&JsonVal> {
        if let JsonVal::Object(pairs) = self {
            pairs.iter().find_map(|(k, v)| (k == key).then_some(v))
        } else {
            None
        }
    }
    /// The numeric payload, or `None` for non-numbers.
    pub fn as_f64(&self) -> Option<f64> {
        if let JsonVal::Number(n) = self { Some(*n) } else { None }
    }
    /// The string payload, or `None` for non-strings.
    pub fn as_str(&self) -> Option<&str> {
        if let JsonVal::Str(s) = self { Some(s.as_str()) } else { None }
    }
    /// The array payload, or `None` for non-arrays.
    pub fn as_array(&self) -> Option<&Vec<JsonVal>> {
        if let JsonVal::Array(a) = self { Some(a) } else { None }
    }
    /// The object payload, or `None` for non-objects.
    pub fn as_object(&self) -> Option<&Vec<(String, JsonVal)>> {
        if let JsonVal::Object(o) = self { Some(o) } else { None }
    }
    /// The boolean payload, or `None` for non-booleans.
    pub fn as_bool(&self) -> Option<bool> {
        if let JsonVal::Bool(b) = self { Some(*b) } else { None }
    }
}
// ── Parser (recursive descent) ──────────────────────────────────────
/// Parse a complete JSON document; trailing non-whitespace is an error.
pub fn json_parse(input: &str) -> Result<JsonVal, String> {
    let bytes = input.as_bytes();
    let start = skip_ws(bytes, 0);
    let (value, after) = parse_value(bytes, start)?;
    let end = skip_ws(bytes, after);
    if end == bytes.len() {
        Ok(value)
    } else {
        Err(format!("Unexpected trailing content at position {}", end))
    }
}
/// Advance `pos` past JSON whitespace (space, tab, LF, CR) and return
/// the first non-whitespace index (or the end of input).
fn skip_ws(b: &[u8], mut pos: usize) -> usize {
    loop {
        match b.get(pos) {
            Some(b' ') | Some(b'\t') | Some(b'\n') | Some(b'\r') => pos += 1,
            _ => return pos,
        }
    }
}
/// Dispatch on the first byte to the matching value parser.
/// Anything that is not a string/object/array/bool/null opener is
/// handed to the number parser, which reports its own errors.
fn parse_value(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    let Some(&first) = b.get(pos) else {
        return Err("Unexpected end of input".into());
    };
    match first {
        b'{' => parse_object(b, pos),
        b'[' => parse_array(b, pos),
        b'"' => parse_string(b, pos),
        b't' | b'f' => parse_bool(b, pos),
        b'n' => parse_null(b, pos),
        _ => parse_number(b, pos),
    }
}
/// Wrap `read_string`'s result in a `JsonVal::Str`.
fn parse_string(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    read_string(b, pos).map(|(s, end)| (JsonVal::Str(s), end))
}
/// Read a quoted JSON string starting at the `"` at `pos`.
/// Returns the decoded string and the index just past the closing quote.
///
/// Bytes are accumulated raw and validated as UTF-8 once at the end, so
/// multi-byte UTF-8 sequences pass through intact (pushing each byte
/// `as char` would reinterpret them as Latin-1 and mangle the text).
/// `\uXXXX` escapes are decoded; surrogate pairs for code points above
/// U+FFFF are not combined — a lone surrogate is rejected.
fn read_string(b: &[u8], pos: usize) -> Result<(String, usize), String> {
    if pos >= b.len() || b[pos] != b'"' {
        return Err(format!("Expected '\"' at position {}", pos));
    }
    let mut i = pos + 1;
    let mut out: Vec<u8> = Vec::new();
    while i < b.len() {
        match b[i] {
            b'"' => {
                let s = String::from_utf8(out)
                    .map_err(|_| "Invalid UTF-8 in string".to_string())?;
                return Ok((s, i + 1));
            }
            b'\\' => {
                i += 1;
                if i >= b.len() {
                    return Err("Unexpected end in string escape".into());
                }
                match b[i] {
                    b'"' => out.push(b'"'),
                    b'\\' => out.push(b'\\'),
                    b'/' => out.push(b'/'),
                    b'n' => out.push(b'\n'),
                    b'r' => out.push(b'\r'),
                    b't' => out.push(b'\t'),
                    b'u' => {
                        // \uXXXX: four hex digits naming a code point.
                        if i + 4 >= b.len() {
                            return Err("Truncated \\u escape".into());
                        }
                        let hex = std::str::from_utf8(&b[i + 1..i + 5])
                            .map_err(|_| "Invalid \\u escape".to_string())?;
                        let cp = u32::from_str_radix(hex, 16)
                            .map_err(|_| "Invalid \\u escape".to_string())?;
                        let ch = char::from_u32(cp)
                            .ok_or_else(|| "Invalid \\u escape".to_string())?;
                        let mut buf = [0u8; 4];
                        out.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
                        i += 4;
                    }
                    // Unknown escape: preserve it verbatim (matches the
                    // previous lenient behavior).
                    other => {
                        out.push(b'\\');
                        out.push(other);
                    }
                }
                i += 1;
            }
            raw => {
                out.push(raw);
                i += 1;
            }
        }
    }
    Err("Unterminated string".into())
}
/// Scan a JSON number (optional minus, digits, optional fraction and
/// exponent) starting at `pos` and parse the slice as f64.
fn parse_number(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    let mut end = pos;
    // Advance `e` over a run of ASCII digits.
    let eat_digits = |e: &mut usize| {
        while *e < b.len() && b[*e].is_ascii_digit() {
            *e += 1;
        }
    };
    // optional minus
    if b.get(end) == Some(&b'-') {
        end += 1;
    }
    // integer digits
    eat_digits(&mut end);
    // fractional part
    if b.get(end) == Some(&b'.') {
        end += 1;
        eat_digits(&mut end);
    }
    // exponent part
    if matches!(b.get(end), Some(b'e') | Some(b'E')) {
        end += 1;
        if matches!(b.get(end), Some(b'+') | Some(b'-')) {
            end += 1;
        }
        eat_digits(&mut end);
    }
    if end == pos {
        return Err(format!("Expected number at position {}", pos));
    }
    // The scanned range is all ASCII, so the utf8 conversion cannot fail;
    // malformed shapes (e.g. a bare "-") are caught by f64::parse below.
    let text = std::str::from_utf8(&b[pos..end]).unwrap();
    let value: f64 = text
        .parse()
        .map_err(|e| format!("Invalid number '{}': {}", text, e))?;
    Ok((JsonVal::Number(value), end))
}
/// Parse the `true` or `false` literal starting at `pos`.
fn parse_bool(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    let rest = &b[pos..];
    if rest.starts_with(b"true") {
        Ok((JsonVal::Bool(true), pos + 4))
    } else if rest.starts_with(b"false") {
        Ok((JsonVal::Bool(false), pos + 5))
    } else {
        Err(format!("Expected bool at position {}", pos))
    }
}
/// Parse the `null` literal starting at `pos`.
fn parse_null(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    match b[pos..].starts_with(b"null") {
        true => Ok((JsonVal::Null, pos + 4)),
        false => Err(format!("Expected null at position {}", pos)),
    }
}
/// Parse a JSON array starting at the '[' at `pos`.
/// Returns the array and the index just past the closing ']'.
fn parse_array(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    let mut i = pos + 1; // skip '['
    let mut arr = Vec::new();
    i = skip_ws(b, i);
    // Empty-array fast path: "[]" (possibly with interior whitespace).
    if i < b.len() && b[i] == b']' {
        return Ok((JsonVal::Array(arr), i + 1));
    }
    // One element per iteration, separated by ',' until ']' closes.
    loop {
        i = skip_ws(b, i);
        let (val, next) = parse_value(b, i)?;
        arr.push(val);
        i = skip_ws(b, next);
        if i >= b.len() {
            return Err("Unterminated array".into());
        }
        if b[i] == b']' {
            return Ok((JsonVal::Array(arr), i + 1));
        }
        if b[i] != b',' {
            return Err(format!("Expected ',' or ']' at position {}", i));
        }
        i += 1; // skip ','
    }
}
/// Parse a JSON object starting at the '{' at `pos`.
/// Returns the object (pairs in source order, duplicates kept) and the
/// index just past the closing '}'.
fn parse_object(b: &[u8], pos: usize) -> Result<(JsonVal, usize), String> {
    let mut i = pos + 1; // skip '{'
    let mut pairs = Vec::new();
    i = skip_ws(b, i);
    // Empty-object fast path: "{}" (possibly with interior whitespace).
    if i < b.len() && b[i] == b'}' {
        return Ok((JsonVal::Object(pairs), i + 1));
    }
    // One "key": value pair per iteration, separated by ',' until '}'.
    loop {
        i = skip_ws(b, i);
        let (key, next) = read_string(b, i)?;
        i = skip_ws(b, next);
        if i >= b.len() || b[i] != b':' {
            return Err(format!("Expected ':' at position {}", i));
        }
        i = skip_ws(b, i + 1);
        let (val, next) = parse_value(b, i)?;
        pairs.push((key, val));
        i = skip_ws(b, next);
        if i >= b.len() {
            return Err("Unterminated object".into());
        }
        if b[i] == b'}' {
            return Ok((JsonVal::Object(pairs), i + 1));
        }
        if b[i] != b',' {
            return Err(format!("Expected ',' or '}}' at position {}", i));
        }
        i += 1; // skip ','
    }
}
// ── JsonVal -> String serialization ─────────────────────────────────
impl JsonVal {
    /// Serialize this JsonVal to a compact JSON string.
    ///
    /// Non-finite numbers (NaN, ±infinity) have no JSON representation
    /// and are emitted as `null`, so the output always re-parses.
    pub fn to_json_string(&self) -> String {
        match self {
            JsonVal::Null => "null".to_string(),
            JsonVal::Bool(b) => if *b { "true" } else { "false" }.to_string(),
            JsonVal::Number(n) => {
                if !n.is_finite() {
                    "null".to_string()
                } else if n.fract() == 0.0 && n.abs() < 1e15 {
                    // Whole values print without a trailing ".0"; the 1e15
                    // bound keeps the i64 cast exact.
                    format!("{}", *n as i64)
                } else {
                    format!("{}", n)
                }
            }
            JsonVal::Str(s) => json_write_string(s),
            JsonVal::Array(arr) => {
                let mut out = String::from("[");
                for (i, v) in arr.iter().enumerate() {
                    if i > 0 {
                        out.push(',');
                    }
                    out.push_str(&v.to_json_string());
                }
                out.push(']');
                out
            }
            JsonVal::Object(pairs) => {
                let mut out = String::from("{");
                for (i, (k, v)) in pairs.iter().enumerate() {
                    if i > 0 {
                        out.push(',');
                    }
                    out.push_str(&json_write_string(k));
                    out.push(':');
                    out.push_str(&v.to_json_string());
                }
                out.push('}');
                out
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_write_null() {
        assert_eq!(json_write_null(), "null");
    }
    // Whole values print without a fractional part.
    #[test]
    fn test_write_number() {
        assert_eq!(json_write_f32(3.14), "3.14");
        assert_eq!(json_write_f32(1.0), "1");
    }
    #[test]
    fn test_write_string() {
        assert_eq!(json_write_string("hello"), "\"hello\"");
        assert_eq!(json_write_string("a\"b"), "\"a\\\"b\"");
    }
    #[test]
    fn test_write_array() {
        assert_eq!(json_write_array(&["1", "2", "3"]), "[1,2,3]");
    }
    // Values are inserted verbatim; keys are quoted/escaped.
    #[test]
    fn test_write_object() {
        let pairs = vec![("name", "\"test\""), ("value", "42")];
        let result = json_write_object(&pairs);
        assert_eq!(result, r#"{"name":"test","value":42}"#);
    }
    #[test]
    fn test_parse_number() {
        match json_parse("42.5").unwrap() {
            JsonVal::Number(n) => assert!((n - 42.5).abs() < 1e-10),
            _ => panic!("expected number"),
        }
    }
    #[test]
    fn test_parse_negative_number() {
        match json_parse("-3.14").unwrap() {
            JsonVal::Number(n) => assert!((n - (-3.14)).abs() < 1e-10),
            _ => panic!("expected number"),
        }
    }
    #[test]
    fn test_parse_string() {
        assert_eq!(json_parse("\"hello\"").unwrap(), JsonVal::Str("hello".into()));
    }
    // Escaped quote and backslash decode to the raw characters.
    #[test]
    fn test_parse_string_with_escapes() {
        assert_eq!(
            json_parse(r#""a\"b\\c""#).unwrap(),
            JsonVal::Str("a\"b\\c".into())
        );
    }
    #[test]
    fn test_parse_array() {
        match json_parse("[1,2,3]").unwrap() {
            JsonVal::Array(a) => assert_eq!(a.len(), 3),
            _ => panic!("expected array"),
        }
    }
    #[test]
    fn test_parse_empty_array() {
        match json_parse("[]").unwrap() {
            JsonVal::Array(a) => assert_eq!(a.len(), 0),
            _ => panic!("expected array"),
        }
    }
    #[test]
    fn test_parse_object() {
        let val = json_parse(r#"{"x":1,"y":2}"#).unwrap();
        assert!(matches!(val, JsonVal::Object(_)));
        assert_eq!(val.get("x").unwrap().as_f64().unwrap(), 1.0);
        assert_eq!(val.get("y").unwrap().as_f64().unwrap(), 2.0);
    }
    #[test]
    fn test_parse_null() {
        assert_eq!(json_parse("null").unwrap(), JsonVal::Null);
    }
    #[test]
    fn test_parse_bool() {
        assert_eq!(json_parse("true").unwrap(), JsonVal::Bool(true));
        assert_eq!(json_parse("false").unwrap(), JsonVal::Bool(false));
    }
    // Nested containers parse recursively and are reachable via get().
    #[test]
    fn test_parse_nested() {
        let val = json_parse(r#"{"a":[1,2],"b":{"c":3}}"#).unwrap();
        assert!(matches!(val, JsonVal::Object(_)));
        let arr = val.get("a").unwrap().as_array().unwrap();
        assert_eq!(arr.len(), 2);
        let inner = val.get("b").unwrap();
        assert_eq!(inner.get("c").unwrap().as_f64().unwrap(), 3.0);
    }
    // Whitespace is allowed anywhere between tokens.
    #[test]
    fn test_parse_whitespace() {
        let val = json_parse(" { \"a\" : 1 , \"b\" : [ 2 , 3 ] } ").unwrap();
        assert!(matches!(val, JsonVal::Object(_)));
    }
    // Serialize then re-parse yields a structurally equal value.
    #[test]
    fn test_json_val_to_json_string_roundtrip() {
        let val = JsonVal::Object(vec![
            ("name".into(), JsonVal::Str("test".into())),
            ("count".into(), JsonVal::Number(42.0)),
            ("items".into(), JsonVal::Array(vec![
                JsonVal::Number(1.0),
                JsonVal::Null,
                JsonVal::Bool(true),
            ])),
        ]);
        let s = val.to_json_string();
        let parsed = json_parse(&s).unwrap();
        assert_eq!(parsed, val);
    }
}

View File

@@ -5,6 +5,10 @@ pub mod transform;
pub mod hierarchy;
pub mod world_transform;
pub mod scene;
pub mod scheduler;
pub mod json;
pub mod component_registry;
pub mod binary_scene;
pub use entity::{Entity, EntityAllocator};
pub use sparse_set::SparseSet;
@@ -12,4 +16,7 @@ pub use world::World;
pub use transform::Transform;
pub use hierarchy::{Parent, Children, add_child, remove_child, despawn_recursive, roots};
pub use world_transform::{WorldTransform, propagate_transforms};
pub use scene::{Tag, serialize_scene, deserialize_scene};
pub use scene::{Tag, serialize_scene, deserialize_scene, serialize_scene_json, deserialize_scene_json};
pub use scheduler::{Scheduler, System};
pub use component_registry::ComponentRegistry;
pub use binary_scene::{serialize_scene_binary, deserialize_scene_binary};

View File

@@ -4,6 +4,8 @@ use crate::entity::Entity;
use crate::world::World;
use crate::transform::Transform;
use crate::hierarchy::{add_child, Parent};
use crate::component_registry::ComponentRegistry;
use crate::json::{self, JsonVal};
/// String tag for entity identification.
#[derive(Debug, Clone)]
@@ -152,10 +154,160 @@ pub fn deserialize_scene(world: &mut World, source: &str) -> Vec<Entity> {
created
}
// ── Hex encoding helpers ────────────────────────────────────────────
/// Encode `data` as lowercase hex, two characters per byte.
///
/// Pushes nibble characters directly instead of calling `format!` per
/// byte, which allocated a temporary String for every single byte.
fn bytes_to_hex(data: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(data.len() * 2);
    for &b in data {
        s.push(HEX[(b >> 4) as usize] as char);
        s.push(HEX[(b & 0x0f) as usize] as char);
    }
    s
}
/// Decode a hex string (even length, upper- or lowercase digits) into
/// its bytes; reports the first invalid digit encountered.
fn hex_to_bytes(hex: &str) -> Result<Vec<u8>, String> {
    if hex.len() % 2 != 0 {
        return Err("Hex string has odd length".into());
    }
    hex.as_bytes()
        .chunks_exact(2)
        .map(|pair| {
            let decode = |c: u8| {
                (c as char)
                    .to_digit(16)
                    .map(|d| d as u8)
                    .ok_or_else(|| format!("Invalid hex digit: {}", c as char))
            };
            Ok((decode(pair[0])? << 4) | decode(pair[1])?)
        })
        .collect()
}
/// Decode one ASCII hex digit (0-9, a-f, A-F) to its value 0-15.
fn hex_digit(c: u8) -> Result<u8, String> {
    (c as char)
        .to_digit(16)
        .map(|d| d as u8)
        .ok_or_else(|| format!("Invalid hex digit: {}", c as char))
}
// ── JSON scene serialization ────────────────────────────────────────
/// Serialize all entities with a Transform to JSON format using the component registry.
/// Format: {"version":1,"entities":[{"parent":null_or_idx,"components":{"name":"hex",...}}]}
pub fn serialize_scene_json(world: &World, registry: &ComponentRegistry) -> String {
    // Snapshot every entity that owns a Transform; their order defines
    // the index space used by the "parent" field.
    let snapshot: Vec<(Entity, Transform)> =
        world.query::<Transform>().map(|(e, t)| (e, *t)).collect();
    let index_of: HashMap<Entity, usize> = snapshot
        .iter()
        .enumerate()
        .map(|(i, (e, _))| (*e, i))
        .collect();
    let mut entity_vals = Vec::new();
    for (entity, _) in &snapshot {
        // A missing parent, or a parent outside the serialized set, is
        // recorded as null.
        let parent_val = world
            .get::<Parent>(*entity)
            .and_then(|p| index_of.get(&p.0))
            .map(|&idx| JsonVal::Number(idx as f64))
            .unwrap_or(JsonVal::Null);
        // Hex-encode each registered component's binary payload.
        let mut comp_pairs = Vec::new();
        for entry in registry.entries() {
            if let Some(data) = (entry.serialize)(world, *entity) {
                comp_pairs.push((entry.name.clone(), JsonVal::Str(bytes_to_hex(&data))));
            }
        }
        entity_vals.push(JsonVal::Object(vec![
            ("parent".into(), parent_val),
            ("components".into(), JsonVal::Object(comp_pairs)),
        ]));
    }
    JsonVal::Object(vec![
        ("version".into(), JsonVal::Number(1.0)),
        ("entities".into(), JsonVal::Array(entity_vals)),
    ])
    .to_json_string()
}
/// Deserialize entities from a JSON scene string.
///
/// Returns the created entities in file order. Parent references are
/// indices into that order; out-of-range, negative, fractional, or
/// self-referencing parent indices are ignored rather than applied.
pub fn deserialize_scene_json(
    world: &mut World,
    json_str: &str,
    registry: &ComponentRegistry,
) -> Result<Vec<Entity>, String> {
    let root = json::json_parse(json_str)?;
    let version = root.get("version")
        .and_then(|v| v.as_f64())
        .ok_or("Missing or invalid 'version'")?;
    if version as u32 != 1 {
        return Err(format!("Unsupported version: {}", version));
    }
    let entities_arr = root.get("entities")
        .and_then(|v| v.as_array())
        .ok_or("Missing or invalid 'entities'")?;
    // First pass: create entities and deserialize components
    let mut created: Vec<Entity> = Vec::with_capacity(entities_arr.len());
    let mut parent_indices: Vec<Option<usize>> = Vec::with_capacity(entities_arr.len());
    for entity_val in entities_arr {
        let entity = world.spawn();
        // Accept only non-negative integral parent indices: a negative or
        // fractional number would otherwise be silently saturated by the
        // `as usize` cast and could attach the entity to the wrong parent.
        let parent_idx = match entity_val.get("parent") {
            Some(JsonVal::Number(n)) if *n >= 0.0 && n.fract() == 0.0 => Some(*n as usize),
            _ => None,
        };
        parent_indices.push(parent_idx);
        // Deserialize components; names missing from the registry are
        // skipped so scenes stay loadable with a different registry.
        if let Some(comps) = entity_val.get("components").and_then(|v| v.as_object()) {
            for (name, hex_val) in comps {
                if let Some(hex_str) = hex_val.as_str() {
                    let data = hex_to_bytes(hex_str)?;
                    if let Some(entry) = registry.find(name) {
                        (entry.deserialize)(world, entity, &data)?;
                    }
                }
            }
        }
        created.push(entity);
    }
    // Second pass: apply parent relationships once every entity exists.
    for (child_idx, parent_idx_opt) in parent_indices.iter().enumerate() {
        if let Some(parent_idx) = parent_idx_opt {
            // Skip out-of-range indices and self-parenting, which would
            // create a degenerate hierarchy cycle.
            if *parent_idx < created.len() && *parent_idx != child_idx {
                let child_entity = created[child_idx];
                let parent_entity = created[*parent_idx];
                add_child(world, parent_entity, child_entity);
            }
        }
    }
    Ok(created)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::hierarchy::{add_child, roots, Parent};
use crate::component_registry::ComponentRegistry;
use voltex_math::Vec3;
#[test]
@@ -270,4 +422,75 @@ entity 2
let scene_roots = roots(&world);
assert_eq!(scene_roots.len(), 2, "should have exactly 2 root entities");
}
// ── JSON scene tests ────────────────────────────────────────────
// Hex encode/decode are exact inverses over the full byte range edges.
#[test]
fn test_hex_roundtrip() {
    let data = vec![0u8, 1, 15, 16, 255];
    let hex = bytes_to_hex(&data);
    assert_eq!(hex, "00010f10ff");
    let back = hex_to_bytes(&hex).unwrap();
    assert_eq!(back, data);
}
// A single entity with Transform + Tag survives a JSON round-trip.
#[test]
fn test_json_roundtrip() {
    let mut registry = ComponentRegistry::new();
    registry.register_defaults();
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::new(1.0, 2.0, 3.0)));
    world.add(e, Tag("player".into()));
    let json = serialize_scene_json(&world, &registry);
    assert!(json.contains("\"version\":1"));
    let mut world2 = World::new();
    let entities = deserialize_scene_json(&mut world2, &json, &registry).unwrap();
    assert_eq!(entities.len(), 1);
    let t = world2.get::<Transform>(entities[0]).unwrap();
    assert!((t.position.x - 1.0).abs() < 1e-4);
    assert!((t.position.y - 2.0).abs() < 1e-4);
    assert!((t.position.z - 3.0).abs() < 1e-4);
    let tag = world2.get::<Tag>(entities[0]).unwrap();
    assert_eq!(tag.0, "player");
}
// Parent/child relationships are restored via serialized indices.
#[test]
fn test_json_with_parent() {
    let mut registry = ComponentRegistry::new();
    registry.register_defaults();
    let mut world = World::new();
    let parent = world.spawn();
    let child = world.spawn();
    world.add(parent, Transform::new());
    world.add(child, Transform::new());
    add_child(&mut world, parent, child);
    let json = serialize_scene_json(&world, &registry);
    let mut world2 = World::new();
    let entities = deserialize_scene_json(&mut world2, &json, &registry).unwrap();
    assert_eq!(entities.len(), 2);
    assert!(world2.get::<Parent>(entities[1]).is_some());
    let parent_comp = world2.get::<Parent>(entities[1]).unwrap();
    assert_eq!(parent_comp.0, entities[0]);
}
// Several entities serialize and deserialize as a batch.
#[test]
fn test_json_multiple_entities() {
    let mut registry = ComponentRegistry::new();
    registry.register_defaults();
    let mut world = World::new();
    for i in 0..5 {
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(i as f32, 0.0, 0.0)));
        world.add(e, Tag(format!("entity_{}", i)));
    }
    let json = serialize_scene_json(&world, &registry);
    let mut world2 = World::new();
    let entities = deserialize_scene_json(&mut world2, &json, &registry).unwrap();
    assert_eq!(entities.len(), 5);
}
}

View File

@@ -0,0 +1,121 @@
use crate::World;
/// A system that can be run on the world.
pub trait System {
    /// Execute one step of this system against `world`.
    fn run(&mut self, world: &mut World);
}
/// Blanket impl: any `FnMut(&mut World)` closure or fn is a System.
impl<F: FnMut(&mut World)> System for F {
    fn run(&mut self, world: &mut World) {
        (*self)(world);
    }
}
/// Runs registered systems in order.
pub struct Scheduler {
    // Boxed trait objects so closures and custom System types can mix.
    systems: Vec<Box<dyn System>>,
}
impl Scheduler {
pub fn new() -> Self {
Self { systems: Vec::new() }
}
/// Add a system. Systems run in the order they are added.
pub fn add<S: System + 'static>(&mut self, system: S) -> &mut Self {
self.systems.push(Box::new(system));
self
}
/// Run all systems in registration order.
pub fn run_all(&mut self, world: &mut World) {
for system in &mut self.systems {
system.run(world);
}
}
/// Number of registered systems.
pub fn len(&self) -> usize {
self.systems.len()
}
pub fn is_empty(&self) -> bool {
self.systems.is_empty()
}
}
impl Default for Scheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::World;
    #[derive(Debug, PartialEq)]
    struct Counter(u32);
    // Ordering matters: (+1 then *10) == 10 only when run in add order.
    #[test]
    fn test_scheduler_runs_in_order() {
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Counter(0));
        let mut scheduler = Scheduler::new();
        scheduler.add(|world: &mut World| {
            let e = world.query::<Counter>().next().unwrap().0;
            let c = world.get_mut::<Counter>(e).unwrap();
            c.0 += 1; // 0 -> 1
        });
        scheduler.add(|world: &mut World| {
            let e = world.query::<Counter>().next().unwrap().0;
            let c = world.get_mut::<Counter>(e).unwrap();
            c.0 *= 10; // 1 -> 10
        });
        scheduler.run_all(&mut world);
        let c = world.get::<Counter>(e).unwrap();
        assert_eq!(c.0, 10); // proves order: add first, then multiply
    }
    // Running with zero systems is a no-op.
    #[test]
    fn test_scheduler_empty() {
        let mut world = World::new();
        let mut scheduler = Scheduler::new();
        scheduler.run_all(&mut world); // should not panic
    }
    // Systems are reusable: each run_all applies them again.
    #[test]
    fn test_scheduler_multiple_runs() {
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Counter(0));
        let mut scheduler = Scheduler::new();
        scheduler.add(|world: &mut World| {
            let e = world.query::<Counter>().next().unwrap().0;
            let c = world.get_mut::<Counter>(e).unwrap();
            c.0 += 1;
        });
        scheduler.run_all(&mut world);
        scheduler.run_all(&mut world);
        scheduler.run_all(&mut world);
        assert_eq!(world.get::<Counter>(e).unwrap().0, 3);
    }
    // add() returns &mut Self, so registrations can be chained.
    #[test]
    fn test_scheduler_add_chaining() {
        let mut scheduler = Scheduler::new();
        scheduler
            .add(|_: &mut World| {})
            .add(|_: &mut World| {});
        assert_eq!(scheduler.len(), 2);
    }
}

View File

@@ -5,6 +5,8 @@ pub struct SparseSet<T> {
sparse: Vec<Option<usize>>,
dense_entities: Vec<Entity>,
dense_data: Vec<T>,
ticks: Vec<u64>,
current_tick: u64,
}
impl<T> SparseSet<T> {
@@ -13,6 +15,8 @@ impl<T> SparseSet<T> {
sparse: Vec::new(),
dense_entities: Vec::new(),
dense_data: Vec::new(),
ticks: Vec::new(),
current_tick: 1,
}
}
@@ -27,11 +31,13 @@ impl<T> SparseSet<T> {
// Overwrite existing
self.dense_data[dense_idx] = value;
self.dense_entities[dense_idx] = entity;
self.ticks[dense_idx] = self.current_tick;
} else {
let dense_idx = self.dense_data.len();
self.sparse[id] = Some(dense_idx);
self.dense_entities.push(entity);
self.dense_data.push(value);
self.ticks.push(self.current_tick);
}
}
@@ -49,12 +55,14 @@ impl<T> SparseSet<T> {
if dense_idx == last_idx {
self.dense_entities.pop();
self.ticks.pop();
Some(self.dense_data.pop().unwrap())
} else {
// Swap with last
let swapped_entity = self.dense_entities[last_idx];
self.sparse[swapped_entity.id as usize] = Some(dense_idx);
self.dense_entities.swap_remove(dense_idx);
self.ticks.swap_remove(dense_idx);
Some(self.dense_data.swap_remove(dense_idx))
}
}
@@ -74,6 +82,7 @@ impl<T> SparseSet<T> {
if self.dense_entities[dense_idx] != entity {
return None;
}
self.ticks[dense_idx] = self.current_tick;
Some(&mut self.dense_data[dense_idx])
}
@@ -114,6 +123,29 @@ impl<T> SparseSet<T> {
pub fn data_mut(&mut self) -> &mut [T] {
&mut self.dense_data
}
/// Check if an entity's component was changed this tick.
///
/// Mirrors the liveness check in `get`/`get_mut`: the dense slot must
/// still belong to `entity` (same generation), otherwise a stale handle
/// whose id has been recycled would be reported as changed.
pub fn is_changed(&self, entity: Entity) -> bool {
    match self.sparse.get(entity.id as usize).copied().flatten() {
        Some(idx) => self.dense_entities[idx] == entity && self.ticks[idx] == self.current_tick,
        None => false,
    }
}
/// Advance the tick counter (call at end of frame).
/// Afterwards `is_changed` reports false for everything until the next
/// insert or `get_mut`.
pub fn increment_tick(&mut self) {
    self.current_tick += 1;
}
/// Return entities changed this tick with their data, in dense order.
pub fn iter_changed(&self) -> impl Iterator<Item = (Entity, &T)> + '_ {
    let tick = self.current_tick;
    self.ticks
        .iter()
        .enumerate()
        .filter(move |(_, &t)| t == tick)
        .map(move |(i, _)| (self.dense_entities[i], &self.dense_data[i]))
}
}
impl<T> Default for SparseSet<T> {
@@ -127,6 +159,7 @@ pub trait ComponentStorage: Any {
fn as_any_mut(&mut self) -> &mut dyn Any;
fn remove_entity(&mut self, entity: Entity);
fn storage_len(&self) -> usize;
fn increment_tick(&mut self);
}
impl<T: 'static> ComponentStorage for SparseSet<T> {
@@ -142,6 +175,9 @@ impl<T: 'static> ComponentStorage for SparseSet<T> {
fn storage_len(&self) -> usize {
self.dense_data.len()
}
fn increment_tick(&mut self) {
self.current_tick += 1;
}
}
#[cfg(test)]
@@ -235,6 +271,74 @@ mod tests {
assert!(!set.contains(e));
}
#[test]
fn test_insert_is_changed() {
let mut set = SparseSet::<u32>::new();
let e = make_entity(0, 0);
set.insert(e, 42);
assert!(set.is_changed(e));
}
#[test]
fn test_get_mut_marks_changed() {
let mut set = SparseSet::<u32>::new();
let e = make_entity(0, 0);
set.insert(e, 42);
set.increment_tick();
assert!(!set.is_changed(e));
let _ = set.get_mut(e);
assert!(set.is_changed(e));
}
#[test]
fn test_get_not_changed() {
let mut set = SparseSet::<u32>::new();
let e = make_entity(0, 0);
set.insert(e, 42);
set.increment_tick();
let _ = set.get(e);
assert!(!set.is_changed(e));
}
#[test]
fn test_clear_resets_changed() {
let mut set = SparseSet::<u32>::new();
let e = make_entity(0, 0);
set.insert(e, 42);
assert!(set.is_changed(e));
set.increment_tick();
assert!(!set.is_changed(e));
}
#[test]
fn test_iter_changed() {
let mut set = SparseSet::<u32>::new();
let e1 = make_entity(0, 0);
let e2 = make_entity(1, 0);
let e3 = make_entity(2, 0);
set.insert(e1, 10);
set.insert(e2, 20);
set.insert(e3, 30);
set.increment_tick();
let _ = set.get_mut(e2);
let changed: Vec<_> = set.iter_changed().collect();
assert_eq!(changed.len(), 1);
assert_eq!(changed[0].0.id, 1);
}
#[test]
fn test_remove_preserves_ticks() {
let mut set = SparseSet::<u32>::new();
let e1 = make_entity(0, 0);
let e2 = make_entity(1, 0);
set.insert(e1, 10);
set.insert(e2, 20);
set.increment_tick();
let _ = set.get_mut(e2);
set.remove(e1);
assert!(set.is_changed(e2));
}
#[test]
fn test_swap_remove_correctness() {
let mut set: SparseSet<i32> = SparseSet::new();

View File

@@ -227,6 +227,71 @@ impl World {
}
result
}
/// Query entities whose component T was changed this tick.
/// Returns an empty Vec when no storage exists for T.
pub fn query_changed<T: 'static>(&self) -> Vec<(Entity, &T)> {
    match self.storages.get(&TypeId::of::<T>()) {
        Some(storage) => storage
            .as_any()
            .downcast_ref::<SparseSet<T>>()
            .unwrap()
            .iter_changed()
            .collect(),
        None => Vec::new(),
    }
}
/// Advance tick on all component storages (call at end of frame).
pub fn clear_changed(&mut self) {
    self.storages
        .values_mut()
        .for_each(|storage| storage.increment_tick());
}
/// True when `entity` currently has a component of type T.
pub fn has_component<T: 'static>(&self, entity: Entity) -> bool {
    self.storage::<T>()
        .map(|s| s.contains(entity))
        .unwrap_or(false)
}
/// Query entities that have component T AND also have component W.
pub fn query_with<T: 'static, W: 'static>(&self) -> Vec<(Entity, &T)> {
    self.storage::<T>()
        .map(|t_storage| {
            t_storage
                .iter()
                .filter(|(entity, _)| self.has_component::<W>(*entity))
                .collect()
        })
        .unwrap_or_default()
}
/// Query entities that have component T but NOT component W.
pub fn query_without<T: 'static, W: 'static>(&self) -> Vec<(Entity, &T)> {
    self.storage::<T>()
        .map(|t_storage| {
            t_storage
                .iter()
                .filter(|(entity, _)| !self.has_component::<W>(*entity))
                .collect()
        })
        .unwrap_or_default()
}
/// Query entities with components A and B, that also have component W.
pub fn query2_with<A: 'static, B: 'static, W: 'static>(&self) -> Vec<(Entity, &A, &B)> {
    let mut out = Vec::new();
    for (e, a, b) in self.query2::<A, B>() {
        if self.has_component::<W>(e) {
            out.push((e, a, b));
        }
    }
    out
}
/// Query entities with components A and B, that do NOT have component W.
pub fn query2_without<A: 'static, B: 'static, W: 'static>(&self) -> Vec<(Entity, &A, &B)> {
    let mut rows = self.query2::<A, B>();
    rows.retain(|(e, _, _)| !self.has_component::<W>(*e));
    rows
}
}
impl Default for World {
@@ -388,6 +453,127 @@ mod tests {
assert_eq!(results[0].0, e0);
}
#[test]
fn test_has_component() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Position { x: 1.0, y: 2.0 });
    assert!(world.has_component::<Position>(e));
    assert!(!world.has_component::<Velocity>(e));
}
#[test]
fn test_query_with() {
    // Three entities: e0 and e2 have both components, e1 only Position.
    let mut world = World::new();
    let e0 = world.spawn();
    let e1 = world.spawn();
    let e2 = world.spawn();
    world.add(e0, Position { x: 1.0, y: 0.0 });
    world.add(e0, Velocity { dx: 1.0, dy: 0.0 });
    world.add(e1, Position { x: 2.0, y: 0.0 });
    // e1 has Position but no Velocity
    world.add(e2, Position { x: 3.0, y: 0.0 });
    world.add(e2, Velocity { dx: 3.0, dy: 0.0 });
    let results = world.query_with::<Position, Velocity>();
    assert_eq!(results.len(), 2);
    let entities: Vec<Entity> = results.iter().map(|(e, _)| *e).collect();
    assert!(entities.contains(&e0));
    assert!(entities.contains(&e2));
    assert!(!entities.contains(&e1));
}
#[test]
fn test_query_without() {
    // Inverse of test_query_with: only the Velocity-less entity matches.
    let mut world = World::new();
    let e0 = world.spawn();
    let e1 = world.spawn();
    let e2 = world.spawn();
    world.add(e0, Position { x: 1.0, y: 0.0 });
    world.add(e0, Velocity { dx: 1.0, dy: 0.0 });
    world.add(e1, Position { x: 2.0, y: 0.0 });
    // e1 has Position but no Velocity — should be included
    world.add(e2, Position { x: 3.0, y: 0.0 });
    world.add(e2, Velocity { dx: 3.0, dy: 0.0 });
    let results = world.query_without::<Position, Velocity>();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, e1);
}
#[test]
fn test_query2_with() {
    #[derive(Debug, PartialEq)]
    struct Health(i32);
    let mut world = World::new();
    let e0 = world.spawn();
    world.add(e0, Position { x: 1.0, y: 0.0 });
    world.add(e0, Velocity { dx: 1.0, dy: 0.0 });
    world.add(e0, Health(100));
    let e1 = world.spawn();
    world.add(e1, Position { x: 2.0, y: 0.0 });
    world.add(e1, Velocity { dx: 2.0, dy: 0.0 });
    // e1 has no Health
    let results = world.query2_with::<Position, Velocity, Health>();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, e0);
}
#[test]
fn test_query2_without() {
    #[derive(Debug, PartialEq)]
    struct Health(i32);
    let mut world = World::new();
    let e0 = world.spawn();
    world.add(e0, Position { x: 1.0, y: 0.0 });
    world.add(e0, Velocity { dx: 1.0, dy: 0.0 });
    world.add(e0, Health(100));
    let e1 = world.spawn();
    world.add(e1, Position { x: 2.0, y: 0.0 });
    world.add(e1, Velocity { dx: 2.0, dy: 0.0 });
    // e1 has no Health
    let results = world.query2_without::<Position, Velocity, Health>();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, e1);
}
#[test]
fn test_query_changed() {
    let mut world = World::new();
    let e1 = world.spawn();
    let e2 = world.spawn();
    world.add(e1, 10u32);
    world.add(e2, 20u32);
    // Inserting counts as a change, so advance the tick to get a clean baseline.
    world.clear_changed();
    if let Some(v) = world.get_mut::<u32>(e1) {
        *v = 100;
    }
    // Only the entity mutated after the tick should be reported.
    let changed = world.query_changed::<u32>();
    assert_eq!(changed.len(), 1);
    assert_eq!(*changed[0].1, 100);
}
#[test]
fn test_clear_changed_all_storages() {
    // clear_changed must advance the tick of EVERY storage, not just one type.
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, 42u32);
    world.add(e, 3.14f32);
    let changed_u32 = world.query_changed::<u32>();
    assert_eq!(changed_u32.len(), 1);
    world.clear_changed();
    let changed_u32 = world.query_changed::<u32>();
    assert_eq!(changed_u32.len(), 0);
    let changed_f32 = world.query_changed::<f32>();
    assert_eq!(changed_f32.len(), 0);
}
#[test]
fn test_entity_count() {
let mut world = World::new();

View File

@@ -5,4 +5,7 @@ edition = "2021"
[dependencies]
bytemuck = { workspace = true }
voltex_math = { workspace = true }
voltex_ecs = { workspace = true }
voltex_renderer = { workspace = true }
wgpu = { workspace = true }

View File

@@ -0,0 +1,297 @@
use std::path::PathBuf;
use crate::ui_context::UiContext;
use crate::dock::Rect;
use crate::layout::LayoutState;
// Background color of the selected file row (RGBA).
const COLOR_SELECTED: [u8; 4] = [0x44, 0x66, 0x88, 0xFF];
// Label text color (RGBA).
const COLOR_TEXT: [u8; 4] = [0xEE, 0xEE, 0xEE, 0xFF];
// Inner padding around panel content and rows, in pixels.
const PADDING: f32 = 4.0;
/// A single entry (file or directory) of the current directory listing.
pub struct DirEntry {
    /// Entry name only (no path components).
    pub name: String,
    /// True when the entry is a directory.
    pub is_dir: bool,
    /// Size in bytes (0 when metadata was unavailable).
    pub size: u64,
}

/// File-system browser confined to a root directory: navigation may enter
/// descendants of `root` but must never escape above it.
pub struct AssetBrowser {
    /// Top-most directory the browser may show (canonicalized when possible).
    pub root: PathBuf,
    /// Directory currently listed; `root` or a descendant of it.
    pub current: PathBuf,
    /// Contents of `current`: directories first, then files, sorted by name.
    pub entries: Vec<DirEntry>,
    /// Name of the selected file within `entries`, if any.
    pub selected_file: Option<String>,
}

/// Format a byte count as a short human-readable string.
///
/// Uses binary units (1 KB = 1024 B). Sizes of 1 GiB and above are now
/// rendered as "GB" instead of an ever-growing MB figure; output for all
/// smaller values is unchanged.
pub fn format_size(bytes: u64) -> String {
    const KB: u64 = 1024;
    const MB: u64 = 1024 * KB;
    const GB: u64 = 1024 * MB;
    if bytes < KB {
        format!("{} B", bytes)
    } else if bytes < MB {
        format!("{:.1} KB", bytes as f64 / KB as f64)
    } else if bytes < GB {
        format!("{:.1} MB", bytes as f64 / MB as f64)
    } else {
        format!("{:.1} GB", bytes as f64 / GB as f64)
    }
}

impl AssetBrowser {
    /// Create a browser rooted at `root` and scan it immediately.
    /// The root is canonicalized (best effort) so later prefix checks
    /// compare canonical paths with canonical paths.
    pub fn new(root: PathBuf) -> Self {
        let root = std::fs::canonicalize(&root).unwrap_or(root);
        let current = root.clone();
        let mut browser = AssetBrowser {
            root,
            current,
            entries: Vec::new(),
            selected_file: None,
        };
        browser.refresh();
        browser
    }

    /// Re-scan the current directory, clearing any file selection.
    /// Entries are sorted directories-first, then alphabetically by name.
    pub fn refresh(&mut self) {
        self.entries.clear();
        self.selected_file = None;
        if let Ok(read_dir) = std::fs::read_dir(&self.current) {
            for entry in read_dir.flatten() {
                let meta = entry.metadata().ok();
                let is_dir = meta.as_ref().map_or(false, |m| m.is_dir());
                let size = meta.as_ref().map_or(0, |m| m.len());
                let name = entry.file_name().to_string_lossy().to_string();
                self.entries.push(DirEntry { name, is_dir, size });
            }
        }
        self.entries.sort_by(|a, b| {
            b.is_dir.cmp(&a.is_dir).then(a.name.cmp(&b.name))
        });
    }

    /// Enter `dir_name` inside the current directory.
    ///
    /// BUG FIX: the target is canonicalized BEFORE the root-prefix check.
    /// The previous purely lexical `current.join(name).starts_with(root)`
    /// check accepted paths like `root/..` (`Path::starts_with` compares
    /// components, so the trailing `..` is never resolved), which let
    /// `navigate_to("..")` escape the sandbox. Canonicalizing resolves
    /// `..` and symlinks first, making the prefix check meaningful.
    pub fn navigate_to(&mut self, dir_name: &str) {
        let target = self.current.join(dir_name);
        if let Ok(target) = std::fs::canonicalize(&target) {
            if target.starts_with(&self.root) && target.is_dir() {
                self.current = target;
                self.refresh();
            }
        }
    }

    /// Move to the parent directory, but never above `root`.
    pub fn go_up(&mut self) {
        if self.current != self.root {
            if let Some(parent) = self.current.parent() {
                let parent = parent.to_path_buf();
                if parent.starts_with(&self.root) || parent == self.root {
                    self.current = parent;
                    self.refresh();
                }
            }
        }
    }

    /// Current directory as a path relative to `root` (empty at the root).
    pub fn relative_path(&self) -> String {
        self.current
            .strip_prefix(&self.root)
            .unwrap_or(&self.current)
            .to_string_lossy()
            .to_string()
    }
}
/// Draw the asset-browser panel into `rect`: the current path, an optional
/// "go up" button, one clickable row per directory entry, and an info
/// footer for the selected file. Clicking a directory navigates into it;
/// clicking a file selects it.
pub fn asset_browser_panel(
    ui: &mut UiContext,
    browser: &mut AssetBrowser,
    rect: &Rect,
) {
    ui.layout = LayoutState::new(rect.x + PADDING, rect.y + PADDING);
    let path_text = format!("Path: /{}", browser.relative_path());
    ui.text(&path_text);
    // Go up button
    if browser.current != browser.root {
        if ui.button("[..]") {
            browser.go_up();
            // Entries were just rebuilt; skip drawing the stale list this frame.
            return;
        }
    }
    let gw = ui.font.glyph_width as f32;
    let gh = ui.font.glyph_height as f32;
    let line_h = gh + PADDING;
    // Collect click action to avoid borrow conflict
    let mut clicked_dir: Option<String> = None;
    // Snapshot (name, is_dir) so the loop doesn't hold a borrow of
    // `browser.entries` while it mutates `browser.selected_file`.
    let entries_snapshot: Vec<(String, bool)> = browser.entries.iter()
        .map(|e| (e.name.clone(), e.is_dir))
        .collect();
    for (name, is_dir) in &entries_snapshot {
        let y = ui.layout.cursor_y;
        let x = rect.x + PADDING;
        // Highlight selected file
        if !is_dir {
            if browser.selected_file.as_deref() == Some(name.as_str()) {
                ui.draw_list.add_rect(rect.x, y, rect.w, line_h, COLOR_SELECTED);
            }
        }
        // Click
        if ui.mouse_clicked && ui.mouse_in_rect(rect.x, y, rect.w, line_h) {
            if *is_dir {
                // Navigation mutates `entries`; defer it until after the loop.
                clicked_dir = Some(name.clone());
            } else {
                browser.selected_file = Some(name.clone());
            }
        }
        // Label
        let label = if *is_dir {
            format!("[D] {}", name)
        } else {
            format!(" {}", name)
        };
        let text_y = y + (line_h - gh) * 0.5;
        let mut cx = x;
        // Inline glyph rendering (ui.text would reset layout state).
        for ch in label.chars() {
            let (u0, v0, u1, v1) = ui.font.glyph_uv(ch);
            ui.draw_list.add_rect_uv(cx, text_y, gw, gh, u0, v0, u1, v1, COLOR_TEXT);
            cx += gw;
        }
        ui.layout.cursor_y += line_h;
    }
    if let Some(dir) = clicked_dir {
        browser.navigate_to(&dir);
    }
    // File info
    if let Some(ref file_name) = browser.selected_file.clone() {
        ui.text("-- File Info --");
        ui.text(&format!("Name: {}", file_name));
        if let Some(entry) = browser.entries.iter().find(|e| e.name == *file_name) {
            ui.text(&format!("Size: {}", format_size(entry.size)));
            // Show the extension (including the dot) when the name has one.
            if let Some(dot_pos) = file_name.rfind('.') {
                let ext = &file_name[dot_pos..];
                ui.text(&format!("Type: {}", ext));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    /// Create (wiping any leftovers first) a unique temp dir for one test.
    fn make_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("voltex_ab_test_{}", name));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }
    #[test]
    fn test_new_scans_entries() {
        let dir = make_temp_dir("scan");
        fs::write(dir.join("file1.txt"), "hello").unwrap();
        fs::write(dir.join("file2.png"), "data").unwrap();
        fs::create_dir_all(dir.join("subdir")).unwrap();
        let browser = AssetBrowser::new(dir.clone());
        assert_eq!(browser.entries.len(), 3);
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_entries_sorted_dirs_first() {
        let dir = make_temp_dir("sort");
        fs::write(dir.join("zebra.txt"), "z").unwrap();
        fs::write(dir.join("alpha.txt"), "a").unwrap();
        fs::create_dir_all(dir.join("middle_dir")).unwrap();
        let browser = AssetBrowser::new(dir.clone());
        // Dir should come first
        assert!(browser.entries[0].is_dir);
        assert_eq!(browser.entries[0].name, "middle_dir");
        // Then files alphabetically
        assert_eq!(browser.entries[1].name, "alpha.txt");
        assert_eq!(browser.entries[2].name, "zebra.txt");
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_navigate_to() {
        let dir = make_temp_dir("nav");
        fs::create_dir_all(dir.join("sub")).unwrap();
        fs::write(dir.join("sub").join("inner.txt"), "x").unwrap();
        let mut browser = AssetBrowser::new(dir.clone());
        browser.navigate_to("sub");
        assert!(browser.current.ends_with("sub"));
        assert_eq!(browser.entries.len(), 1);
        assert_eq!(browser.entries[0].name, "inner.txt");
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_go_up() {
        let dir = make_temp_dir("goup");
        fs::create_dir_all(dir.join("child")).unwrap();
        let mut browser = AssetBrowser::new(dir.clone());
        browser.navigate_to("child");
        assert!(browser.current.ends_with("child"));
        browser.go_up();
        // Compare against the canonical root (new() canonicalizes).
        assert_eq!(browser.current, std::fs::canonicalize(&dir).unwrap());
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_go_up_at_root() {
        let dir = make_temp_dir("goup_root");
        let browser_root = std::fs::canonicalize(&dir).unwrap();
        let mut browser = AssetBrowser::new(dir.clone());
        browser.go_up(); // should be no-op
        assert_eq!(browser.current, browser_root);
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_root_guard() {
        // `..` must never take the browser above its root.
        let dir = make_temp_dir("guard");
        let mut browser = AssetBrowser::new(dir.clone());
        browser.navigate_to(".."); // should be rejected
        assert_eq!(browser.current, std::fs::canonicalize(&dir).unwrap());
        let _ = fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_format_size() {
        assert_eq!(format_size(0), "0 B");
        assert_eq!(format_size(500), "500 B");
        assert_eq!(format_size(1024), "1.0 KB");
        assert_eq!(format_size(1536), "1.5 KB");
        assert_eq!(format_size(1048576), "1.0 MB");
    }
    #[test]
    fn test_panel_draws_commands() {
        // Smoke test: one frame of the panel must emit draw commands.
        let dir = make_temp_dir("panel");
        fs::write(dir.join("test.txt"), "hello").unwrap();
        let mut browser = AssetBrowser::new(dir.clone());
        let mut ui = UiContext::new(800.0, 600.0);
        let rect = Rect { x: 0.0, y: 0.0, w: 300.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        asset_browser_panel(&mut ui, &mut browser, &rect);
        ui.end_frame();
        assert!(ui.draw_list.commands.len() > 0);
        let _ = fs::remove_dir_all(&dir);
    }
}

View File

@@ -0,0 +1,558 @@
use crate::ui_context::UiContext;
// Height in pixels of the tab bar strip at the top of every leaf.
const TAB_BAR_HEIGHT: f32 = 20.0;
// Bounds a split ratio may be created with or resized to.
const MIN_RATIO: f32 = 0.1;
const MAX_RATIO: f32 = 0.9;
// Half-width of the invisible hit zone around a split boundary.
const RESIZE_HANDLE_HALF: f32 = 3.0;
// Glyph width assumed for tab hit-testing.
// NOTE(review): draw_chrome measures tabs with ui.font.glyph_width —
// confirm the font's glyph width is 8 px or hit-testing will misalign.
const GLYPH_W: f32 = 8.0;
// Extra horizontal space added to each tab's label width.
const TAB_PADDING: f32 = 8.0;
/// Axis-aligned rectangle in pixel coordinates.
#[derive(Clone, Copy, Debug)]
pub struct Rect {
    pub x: f32,
    pub y: f32,
    pub w: f32,
    pub h: f32,
}

impl Rect {
    /// True when the point (px, py) lies inside the rectangle.
    /// Left/top edges are inclusive, right/bottom edges exclusive.
    pub fn contains(&self, px: f32, py: f32) -> bool {
        let within_x = (self.x..self.x + self.w).contains(&px);
        let within_y = (self.y..self.y + self.h).contains(&py);
        within_x && within_y
    }
}
/// Direction along which a split divides its area.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Axis {
    // Children side by side; the boundary is a vertical line.
    Horizontal,
    // Children stacked; the boundary is a horizontal line.
    Vertical,
}
/// A node of the dock tree: either a tabbed leaf or a binary split.
pub enum DockNode {
    // Panel ids shown as tabs; `active` indexes the visible one.
    Leaf { tabs: Vec<u32>, active: usize },
    // Two-way split; `ratio` is the first child's share along `axis`.
    Split { axis: Axis, ratio: f32, children: [Box<DockNode>; 2] },
}
impl DockNode {
    /// Convenience constructor for a leaf whose first tab is active.
    pub fn leaf(tabs: Vec<u32>) -> Self {
        DockNode::Leaf { tabs, active: 0 }
    }

    /// Build a split node; `ratio` is clamped into the legal range.
    pub fn split(axis: Axis, ratio: f32, a: DockNode, b: DockNode) -> Self {
        let ratio = ratio.clamp(MIN_RATIO, MAX_RATIO);
        let children = [Box::new(a), Box::new(b)];
        DockNode::Split { axis, ratio, children }
    }
}
/// Computed screen geometry for one leaf after a `DockTree::layout` pass.
pub struct LeafLayout {
    // Pre-order index of this leaf within the tree.
    pub leaf_index: usize,
    // Panel ids shown as tabs in this leaf.
    pub tabs: Vec<u32>,
    // Index into `tabs` of the visible panel (clamped at layout time).
    pub active: usize,
    // Screen rect of the tab bar strip.
    pub tab_bar_rect: Rect,
    // Screen rect of the panel content below the tab bar.
    pub content_rect: Rect,
}
/// Computed geometry for one split node, cached for handle hit-testing.
struct SplitLayout {
    // Full rect covered by this split node.
    rect: Rect,
    axis: Axis,
    // Pixel coordinate of the boundary line between the two children.
    boundary: f32,
    // Child indices (0/1) from the root down to this split node.
    path: Vec<usize>,
}
/// In-progress drag of a split boundary.
struct ResizeState {
    // Path from the root to the split being resized.
    path: Vec<usize>,
    axis: Axis,
    // Start coordinate (x or y) of the split rect along `axis`.
    origin: f32,
    // Extent (w or h) of the split rect along `axis`.
    size: f32,
}
/// What a fresh mouse click resolved to during `DockTree::update`.
enum UpdateAction {
    None,
    // Begin dragging a split boundary.
    StartResize(ResizeState),
    // Switch the active tab of the `leaf_index`-th leaf.
    SetActiveTab { leaf_index: usize, new_active: usize },
}
/// Dockable panel layout: a binary tree of splits with tabbed leaves.
/// `layout` must run each frame before `update`/`draw_chrome`, which
/// consume the geometry it caches.
pub struct DockTree {
    root: DockNode,
    // Display names indexed by panel id.
    names: Vec<&'static str>,
    // Leaf geometry from the most recent `layout` call.
    cached_leaves: Vec<LeafLayout>,
    // Split geometry from the most recent `layout` call.
    cached_splits: Vec<SplitLayout>,
    // Active boundary drag, if any.
    resizing: Option<ResizeState>,
    // Mouse-button state from the previous `update` (click edge detection).
    prev_mouse_down: bool,
}
/// Recursively compute leaf/split geometry for `node` within `rect`,
/// appending results to `leaves` and `splits`. `path` carries the child
/// indices from the root to the current node; `leaf_counter` assigns
/// pre-order leaf indices.
///
/// Panics when a leaf has no tabs (asserted invariant).
fn layout_recursive(
    node: &DockNode,
    rect: Rect,
    path: &mut Vec<usize>,
    leaf_counter: &mut usize,
    leaves: &mut Vec<LeafLayout>,
    splits: &mut Vec<SplitLayout>,
) {
    match node {
        DockNode::Leaf { tabs, active } => {
            assert!(!tabs.is_empty(), "DockNode::Leaf must have at least one tab");
            let idx = *leaf_counter;
            *leaf_counter += 1;
            leaves.push(LeafLayout {
                leaf_index: idx,
                tabs: tabs.clone(),
                // Clamp a stale active index back into range.
                active: (*active).min(tabs.len().saturating_sub(1)),
                tab_bar_rect: Rect { x: rect.x, y: rect.y, w: rect.w, h: TAB_BAR_HEIGHT },
                content_rect: Rect {
                    x: rect.x,
                    y: rect.y + TAB_BAR_HEIGHT,
                    w: rect.w,
                    // Guard against leaves shorter than the tab bar.
                    h: (rect.h - TAB_BAR_HEIGHT).max(0.0),
                },
            });
        }
        DockNode::Split { axis, ratio, children } => {
            // Split `rect` into two child rects along `axis`; `boundary` is
            // the pixel line between them (used for resize hit-testing).
            let (r1, r2, boundary) = match axis {
                Axis::Horizontal => {
                    let w1 = rect.w * ratio;
                    let b = rect.x + w1;
                    (
                        Rect { x: rect.x, y: rect.y, w: w1, h: rect.h },
                        Rect { x: b, y: rect.y, w: rect.w - w1, h: rect.h },
                        b,
                    )
                }
                Axis::Vertical => {
                    let h1 = rect.h * ratio;
                    let b = rect.y + h1;
                    (
                        Rect { x: rect.x, y: rect.y, w: rect.w, h: h1 },
                        Rect { x: rect.x, y: b, w: rect.w, h: rect.h - h1 },
                        b,
                    )
                }
            };
            splits.push(SplitLayout { rect, axis: *axis, boundary, path: path.clone() });
            path.push(0);
            layout_recursive(&children[0], r1, path, leaf_counter, leaves, splits);
            path.pop();
            path.push(1);
            layout_recursive(&children[1], r2, path, leaf_counter, leaves, splits);
            path.pop();
        }
    }
}
impl DockTree {
    /// Build a tree from `root`; `names[id]` is the display name of panel `id`.
    pub fn new(root: DockNode, names: Vec<&'static str>) -> Self {
        DockTree {
            root,
            names,
            cached_leaves: Vec::new(),
            cached_splits: Vec::new(),
            resizing: None,
            prev_mouse_down: false,
        }
    }
    /// Process mouse input against the geometry cached by the last
    /// `layout` call: continues or finishes an active boundary drag, and
    /// on a fresh click either starts a resize or switches a tab.
    pub fn update(&mut self, mouse_x: f32, mouse_y: f32, mouse_down: bool) {
        // Click = transition from up to down between two updates.
        let just_clicked = mouse_down && !self.prev_mouse_down;
        self.prev_mouse_down = mouse_down;
        // Active resize in progress
        if let Some(ref state) = self.resizing {
            if mouse_down {
                // New ratio = cursor position relative to the split rect
                // captured at drag start.
                let new_ratio = match state.axis {
                    Axis::Horizontal => (mouse_x - state.origin) / state.size,
                    Axis::Vertical => (mouse_y - state.origin) / state.size,
                };
                let clamped = new_ratio.clamp(MIN_RATIO, MAX_RATIO);
                let path = state.path.clone();
                Self::set_ratio_at_path(&mut self.root, &path, clamped);
            } else {
                // Button released: drag ends.
                self.resizing = None;
            }
            return;
        }
        if !just_clicked { return; }
        let action = self.find_click_action(mouse_x, mouse_y);
        match action {
            UpdateAction::StartResize(state) => { self.resizing = Some(state); }
            UpdateAction::SetActiveTab { leaf_index, new_active } => {
                Self::set_active_nth(&mut self.root, leaf_index, new_active, &mut 0);
            }
            UpdateAction::None => {}
        }
    }
    /// Classify a fresh click: resize handles are tested first (so they
    /// win over tab clicks on the same pixel), then tab bars.
    fn find_click_action(&self, mx: f32, my: f32) -> UpdateAction {
        // Check resize handles first (priority over tab clicks)
        for split in &self.cached_splits {
            // Hit when within RESIZE_HANDLE_HALF of the boundary line and
            // inside the split's extent along the other axis.
            let hit = match split.axis {
                Axis::Horizontal => {
                    mx >= split.boundary - RESIZE_HANDLE_HALF
                        && mx <= split.boundary + RESIZE_HANDLE_HALF
                        && my >= split.rect.y
                        && my < split.rect.y + split.rect.h
                }
                Axis::Vertical => {
                    my >= split.boundary - RESIZE_HANDLE_HALF
                        && my <= split.boundary + RESIZE_HANDLE_HALF
                        && mx >= split.rect.x
                        && mx < split.rect.x + split.rect.w
                }
            };
            if hit {
                let (origin, size) = match split.axis {
                    Axis::Horizontal => (split.rect.x, split.rect.w),
                    Axis::Vertical => (split.rect.y, split.rect.h),
                };
                return UpdateAction::StartResize(ResizeState {
                    path: split.path.clone(),
                    axis: split.axis,
                    origin,
                    size,
                });
            }
        }
        // Check tab bar clicks
        for leaf in &self.cached_leaves {
            // A single tab cannot be switched away from — skip.
            if leaf.tabs.len() <= 1 { continue; }
            if !leaf.tab_bar_rect.contains(mx, my) { continue; }
            let mut tx = leaf.tab_bar_rect.x;
            for (i, &panel_id) in leaf.tabs.iter().enumerate() {
                // Tab width mirrors draw_chrome's formula: name width + padding.
                // NOTE(review): uses the GLYPH_W constant while draw_chrome
                // uses ui.font.glyph_width — confirm both are 8 px.
                let name_len = self.names.get(panel_id as usize).map(|n| n.len()).unwrap_or(1);
                let tab_w = name_len as f32 * GLYPH_W + TAB_PADDING;
                if mx >= tx && mx < tx + tab_w {
                    return UpdateAction::SetActiveTab { leaf_index: leaf.leaf_index, new_active: i };
                }
                tx += tab_w;
            }
        }
        UpdateAction::None
    }
    /// Set the ratio of the split reached by following `path` from `node`.
    /// Silently does nothing if the path does not end at a split.
    fn set_ratio_at_path(node: &mut DockNode, path: &[usize], ratio: f32) {
        if path.is_empty() {
            if let DockNode::Split { ratio: r, .. } = node { *r = ratio; }
            return;
        }
        if let DockNode::Split { children, .. } = node {
            Self::set_ratio_at_path(&mut children[path[0]], &path[1..], ratio);
        }
    }
    /// Set the active tab of the `target`-th leaf in pre-order — the same
    /// numbering `layout_recursive` assigns to `leaf_index`.
    fn set_active_nth(node: &mut DockNode, target: usize, new_active: usize, count: &mut usize) {
        match node {
            DockNode::Leaf { active, .. } => {
                if *count == target { *active = new_active; }
                *count += 1;
            }
            DockNode::Split { children, .. } => {
                Self::set_active_nth(&mut children[0], target, new_active, count);
                Self::set_active_nth(&mut children[1], target, new_active, count);
            }
        }
    }
    /// Recompute geometry for the whole tree within `rect` and return
    /// (active panel id, content rect) for every leaf.
    pub fn layout(&mut self, rect: Rect) -> Vec<(u32, Rect)> {
        self.cached_leaves.clear();
        self.cached_splits.clear();
        let mut path = Vec::new();
        let mut counter = 0;
        layout_recursive(
            &self.root,
            rect,
            &mut path,
            &mut counter,
            &mut self.cached_leaves,
            &mut self.cached_splits,
        );
        self.cached_leaves
            .iter()
            .map(|l| {
                let active = l.active.min(l.tabs.len().saturating_sub(1));
                (l.tabs[active], l.content_rect)
            })
            .collect()
    }
    /// Draw tab bars and split boundary lines using the cached geometry.
    pub fn draw_chrome(&self, ui: &mut UiContext) {
        let glyph_w = ui.font.glyph_width as f32;
        let glyph_h = ui.font.glyph_height as f32;
        const COLOR_TAB_ACTIVE: [u8; 4] = [60, 60, 60, 255];
        const COLOR_TAB_INACTIVE: [u8; 4] = [40, 40, 40, 255];
        const COLOR_TEXT: [u8; 4] = [0xEE, 0xEE, 0xEE, 0xFF];
        const COLOR_SPLIT: [u8; 4] = [30, 30, 30, 255];
        const COLOR_SPLIT_ACTIVE: [u8; 4] = [100, 100, 200, 255];
        const COLOR_SEPARATOR: [u8; 4] = [50, 50, 50, 255];
        // Draw tab bars for each leaf
        for leaf in &self.cached_leaves {
            let bar = &leaf.tab_bar_rect;
            let active = leaf.active.min(leaf.tabs.len().saturating_sub(1));
            let mut tx = bar.x;
            for (i, &panel_id) in leaf.tabs.iter().enumerate() {
                let name = self.names.get(panel_id as usize).copied().unwrap_or("?");
                let tab_w = name.len() as f32 * glyph_w + TAB_PADDING;
                let bg = if i == active { COLOR_TAB_ACTIVE } else { COLOR_TAB_INACTIVE };
                ui.draw_list.add_rect(tx, bar.y, tab_w, bar.h, bg);
                // Inline text rendering (avoids borrow conflict with ui.font + ui.draw_list)
                let text_x = tx + TAB_PADDING * 0.5;
                let text_y = bar.y + (bar.h - glyph_h) * 0.5;
                let mut cx = text_x;
                for ch in name.chars() {
                    let (u0, v0, u1, v1) = ui.font.glyph_uv(ch);
                    ui.draw_list.add_rect_uv(cx, text_y, glyph_w, glyph_h, u0, v0, u1, v1, COLOR_TEXT);
                    cx += glyph_w;
                }
                tx += tab_w;
            }
            // Tab bar bottom separator
            ui.draw_list.add_rect(bar.x, bar.y + bar.h - 1.0, bar.w, 1.0, COLOR_SEPARATOR);
        }
        // Draw split lines
        for split in &self.cached_splits {
            // Highlight the boundary that is currently being dragged.
            let color = if self.resizing.as_ref().map(|r| &r.path) == Some(&split.path) {
                COLOR_SPLIT_ACTIVE
            } else {
                COLOR_SPLIT
            };
            match split.axis {
                Axis::Horizontal => {
                    ui.draw_list.add_rect(split.boundary - 0.5, split.rect.y, 1.0, split.rect.h, color);
                }
                Axis::Vertical => {
                    ui.draw_list.add_rect(split.rect.x, split.boundary - 0.5, split.rect.w, 1.0, color);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ui_context::UiContext;
    #[test]
    fn test_rect_contains() {
        let r = Rect { x: 10.0, y: 20.0, w: 100.0, h: 50.0 };
        assert!(r.contains(50.0, 40.0));
        assert!(!r.contains(5.0, 40.0));
        assert!(!r.contains(50.0, 80.0));
    }
    #[test]
    fn test_layout_single_leaf() {
        // A lone leaf gets the whole rect minus the tab bar strip.
        let mut dock = DockTree::new(DockNode::leaf(vec![0]), vec!["Panel0"]);
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas.len(), 1);
        assert_eq!(areas[0].0, 0);
        let r = &areas[0].1;
        assert!((r.y - 20.0).abs() < 1e-3);
        assert!((r.h - 280.0).abs() < 1e-3);
    }
    #[test]
    fn test_layout_horizontal_split() {
        // ratio 0.25 of 400 px => 100 px left, 300 px right.
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.25, DockNode::leaf(vec![0]), DockNode::leaf(vec![1])),
            vec!["Left", "Right"],
        );
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas.len(), 2);
        let left = areas.iter().find(|(id, _)| *id == 0).unwrap();
        assert!((left.1.w - 100.0).abs() < 1e-3);
        let right = areas.iter().find(|(id, _)| *id == 1).unwrap();
        assert!((right.1.w - 300.0).abs() < 1e-3);
    }
    #[test]
    fn test_layout_vertical_split() {
        // Top content height = 300 * 0.5 - 20 (tab bar) = 130.
        let mut dock = DockTree::new(
            DockNode::split(Axis::Vertical, 0.5, DockNode::leaf(vec![0]), DockNode::leaf(vec![1])),
            vec!["Top", "Bottom"],
        );
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        let top = areas.iter().find(|(id, _)| *id == 0).unwrap();
        assert!((top.1.h - 130.0).abs() < 1e-3);
    }
    #[test]
    fn test_layout_nested_split() {
        let mut dock = DockTree::new(
            DockNode::split(
                Axis::Horizontal,
                0.25,
                DockNode::leaf(vec![0]),
                DockNode::split(Axis::Vertical, 0.5, DockNode::leaf(vec![1]), DockNode::leaf(vec![2])),
            ),
            vec!["A", "B", "C"],
        );
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas.len(), 3);
    }
    #[test]
    fn test_layout_active_tab_only() {
        // Only the active tab's panel id is returned for a leaf.
        let mut dock = DockTree::new(
            DockNode::Leaf { tabs: vec![0, 1, 2], active: 1 },
            vec!["A", "B", "C"],
        );
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas[0].0, 1);
    }
    #[test]
    fn test_active_clamped_if_out_of_bounds() {
        let mut dock = DockTree::new(
            DockNode::Leaf { tabs: vec![0], active: 5 },
            vec!["A"],
        );
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas[0].0, 0);
    }
    #[test]
    #[should_panic]
    fn test_empty_tabs_panics() {
        // layout_recursive asserts every leaf has at least one tab.
        let mut dock = DockTree::new(
            DockNode::Leaf { tabs: vec![], active: 0 },
            vec![],
        );
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
    }
    #[test]
    fn test_tab_click_switches_active() {
        let mut dock = DockTree::new(
            DockNode::Leaf { tabs: vec![0, 1, 2], active: 0 },
            vec!["A", "B", "C"],
        );
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        // Tab "A": w = 1*8+8 = 16, tab "B" starts at x=16. Click at x=20 (inside "B").
        dock.update(20.0, 10.0, true);
        dock.update(20.0, 10.0, false);
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas[0].0, 1);
    }
    #[test]
    fn test_tab_click_no_change_on_single_tab() {
        let mut dock = DockTree::new(DockNode::leaf(vec![0]), vec!["Only"]);
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        dock.update(10.0, 10.0, true);
        dock.update(10.0, 10.0, false);
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        assert_eq!(areas[0].0, 0);
    }
    #[test]
    fn test_resize_horizontal_drag() {
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.5, DockNode::leaf(vec![0]), DockNode::leaf(vec![1])),
            vec!["L", "R"],
        );
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        dock.update(200.0, 150.0, true); // click on boundary at x=200
        dock.update(120.0, 150.0, true); // drag to x=120
        dock.update(120.0, 150.0, false); // release
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        let left = areas.iter().find(|(id, _)| *id == 0).unwrap();
        assert!((left.1.w - 120.0).abs() < 5.0, "left w={}", left.1.w);
    }
    #[test]
    fn test_resize_clamps_ratio() {
        // Dragging to x=5 would be ratio 0.0125; MIN_RATIO clamps to 0.1 => 40 px.
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.5, DockNode::leaf(vec![0]), DockNode::leaf(vec![1])),
            vec!["L", "R"],
        );
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        dock.update(200.0, 150.0, true);
        dock.update(5.0, 150.0, true);
        dock.update(5.0, 150.0, false);
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        let left = areas.iter().find(|(id, _)| *id == 0).unwrap();
        assert!((left.1.w - 40.0).abs() < 1e-3, "left w={}", left.1.w);
    }
    #[test]
    fn test_resize_priority_over_tab_click() {
        // A click that lands on both a tab bar and a boundary must resize.
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.5,
                DockNode::Leaf { tabs: vec![0, 1], active: 0 },
                DockNode::leaf(vec![2]),
            ),
            vec!["A", "B", "C"],
        );
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        dock.update(200.0, 10.0, true); // click at boundary within tab bar
        dock.update(180.0, 10.0, true); // drag
        dock.update(180.0, 10.0, false); // release
        let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        let left = areas.iter().find(|(id, _)| *id == 0 || *id == 1).unwrap();
        assert!((left.1.w - 180.0).abs() < 5.0, "resize should have priority, w={}", left.1.w);
    }
    #[test]
    fn test_draw_chrome_produces_draw_commands() {
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.5,
                DockNode::Leaf { tabs: vec![0, 1], active: 0 },
                DockNode::leaf(vec![2]),
            ),
            vec!["A", "B", "C"],
        );
        let mut ui = UiContext::new(800.0, 600.0);
        ui.begin_frame(0.0, 0.0, false);
        dock.layout(Rect { x: 0.0, y: 0.0, w: 800.0, h: 600.0 });
        dock.draw_chrome(&mut ui);
        assert!(ui.draw_list.commands.len() >= 5);
    }
    #[test]
    fn test_draw_chrome_active_tab_color() {
        let mut dock = DockTree::new(
            DockNode::Leaf { tabs: vec![0, 1], active: 1 },
            vec!["AA", "BB"],
        );
        let mut ui = UiContext::new(800.0, 600.0);
        ui.begin_frame(0.0, 0.0, false);
        dock.layout(Rect { x: 0.0, y: 0.0, w: 400.0, h: 300.0 });
        dock.draw_chrome(&mut ui);
        // First tab bg (inactive "AA"): color [40, 40, 40, 255]
        let first_bg = &ui.draw_list.vertices[0];
        assert_eq!(first_bg.color, [40, 40, 40, 255]);
    }
    #[test]
    fn test_full_frame_cycle() {
        // Smoke test: several layout/update/draw frames produce sane rects.
        let mut dock = DockTree::new(
            DockNode::split(Axis::Horizontal, 0.3,
                DockNode::Leaf { tabs: vec![0, 1], active: 0 },
                DockNode::split(Axis::Vertical, 0.6, DockNode::leaf(vec![2]), DockNode::leaf(vec![3])),
            ),
            vec!["Hierarchy", "Inspector", "Viewport", "Console"],
        );
        let mut ui = UiContext::new(1280.0, 720.0);
        for _ in 0..3 {
            ui.begin_frame(100.0, 100.0, false);
            let areas = dock.layout(Rect { x: 0.0, y: 0.0, w: 1280.0, h: 720.0 });
            dock.update(100.0, 100.0, false);
            dock.draw_chrome(&mut ui);
            assert_eq!(areas.len(), 3);
            for (_, r) in &areas {
                assert!(r.w > 0.0);
                assert!(r.h > 0.0);
            }
            ui.end_frame();
        }
    }
}

View File

@@ -9,15 +9,27 @@ pub struct DrawVertex {
pub color: [u8; 4],
}
/// A scissor rectangle for content clipping, in pixel coordinates.
/// Pushed via `DrawList::push_scissor` and attached to each emitted
/// `DrawCommand` while on top of the stack.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ScissorRect {
    pub x: u32,
    pub y: u32,
    pub w: u32,
    pub h: u32,
}
/// A contiguous run of indices drawn with one render state.
pub struct DrawCommand {
    // First index of this command within `DrawList::indices`.
    pub index_offset: u32,
    // Number of indices this command draws.
    pub index_count: u32,
    /// Optional scissor rect for clipping. None means no clipping.
    pub scissor: Option<ScissorRect>,
}
/// CPU-side geometry buffer for UI rendering: vertices, indices, and the
/// draw commands that consume them.
pub struct DrawList {
    pub vertices: Vec<DrawVertex>,
    pub indices: Vec<u16>,
    pub commands: Vec<DrawCommand>,
    // Stack of nested scissor rects; the top one applies to new commands.
    scissor_stack: Vec<ScissorRect>,
}
impl DrawList {
@@ -26,6 +38,7 @@ impl DrawList {
vertices: Vec::new(),
indices: Vec::new(),
commands: Vec::new(),
scissor_stack: Vec::new(),
}
}
@@ -33,6 +46,23 @@ impl DrawList {
self.vertices.clear();
self.indices.clear();
self.commands.clear();
self.scissor_stack.clear();
}
/// Push a scissor rect onto the stack. All subsequent draw commands will
/// be clipped to this rectangle until `pop_scissor` is called.
/// Only the top of the stack applies; nested rects are not intersected.
pub fn push_scissor(&mut self, x: u32, y: u32, w: u32, h: u32) {
    self.scissor_stack.push(ScissorRect { x, y, w, h });
}
/// Pop the current scissor rect from the stack.
/// A pop without a matching push is a silent no-op.
pub fn pop_scissor(&mut self) {
    self.scissor_stack.pop();
}
/// Returns the current scissor rect (top of stack), or None.
fn current_scissor(&self) -> Option<ScissorRect> {
    self.scissor_stack.last().copied()
}
/// Add a solid-color rectangle. UV is (0,0) for solid color rendering.
@@ -67,6 +97,7 @@ impl DrawList {
self.commands.push(DrawCommand {
index_offset,
index_count: 6,
scissor: self.current_scissor(),
});
}

View File

@@ -0,0 +1,197 @@
use std::collections::HashMap;
/// Default side length, in pixels, of the square glyph atlas.
/// NOTE(review): not referenced in this file — presumably used by the
/// renderer when creating a `GlyphCache`/texture; confirm at the caller.
pub const ATLAS_SIZE: u32 = 1024;
/// Metrics and atlas placement for a single rasterized glyph.
#[derive(Clone, Debug)]
pub struct GlyphInfo {
    pub uv: [f32; 4], // u0, v0, u1, v1
    pub width: f32,
    pub height: f32,
    pub advance: f32,
    pub bearing_x: f32,
    pub bearing_y: f32,
}

/// Shelf-style glyph atlas: glyphs are packed left to right on the
/// current row and wrap to a fresh row (with a 1 px gap) when it fills.
pub struct GlyphCache {
    /// Single-channel atlas pixels, row-major, `atlas_width` per row.
    pub atlas_data: Vec<u8>,
    pub atlas_width: u32,
    pub atlas_height: u32,
    glyphs: HashMap<char, GlyphInfo>,
    cursor_x: u32,
    cursor_y: u32,
    row_height: u32,
    /// Set whenever atlas pixels change; cleared by `clear_dirty`.
    pub dirty: bool,
}

impl GlyphCache {
    /// Create an empty cache backed by a `width` x `height` atlas.
    pub fn new(width: u32, height: u32) -> Self {
        GlyphCache {
            atlas_data: vec![0u8; (width * height) as usize],
            atlas_width: width,
            atlas_height: height,
            glyphs: HashMap::new(),
            cursor_x: 0,
            cursor_y: 0,
            row_height: 0,
            dirty: false,
        }
    }

    /// Look up a previously inserted glyph.
    pub fn get(&self, ch: char) -> Option<&GlyphInfo> {
        self.glyphs.get(&ch)
    }

    /// Insert a rasterized glyph bitmap into the atlas and return a
    /// reference to the cached `GlyphInfo`. Re-inserting a char replaces
    /// its entry (the old atlas pixels are not reclaimed).
    pub fn insert(
        &mut self,
        ch: char,
        bitmap: &[u8],
        bmp_w: u32,
        bmp_h: u32,
        advance: f32,
        bearing_x: f32,
        bearing_y: f32,
    ) -> &GlyphInfo {
        // Zero-area glyphs (e.g. space) carry metrics only, no pixels.
        if bmp_w == 0 || bmp_h == 0 {
            let info = GlyphInfo {
                uv: [0.0, 0.0, 0.0, 0.0],
                width: 0.0,
                height: 0.0,
                advance,
                bearing_x,
                bearing_y,
            };
            return self.store(ch, info);
        }
        // Wrap to the next shelf when the current row cannot fit the glyph.
        if self.cursor_x + bmp_w > self.atlas_width {
            self.cursor_y += self.row_height + 1; // +1 pixel gap between rows
            self.cursor_x = 0;
            self.row_height = 0;
        }
        // Atlas exhausted vertically: cache with zero UVs so the glyph
        // simply doesn't render instead of crashing.
        if self.cursor_y + bmp_h > self.atlas_height {
            let info = GlyphInfo {
                uv: [0.0, 0.0, 0.0, 0.0],
                width: bmp_w as f32,
                height: bmp_h as f32,
                advance,
                bearing_x,
                bearing_y,
            };
            return self.store(ch, info);
        }
        self.blit(bitmap, bmp_w, bmp_h);
        // UVs are the glyph rectangle in normalized atlas coordinates.
        let (aw, ah) = (self.atlas_width as f32, self.atlas_height as f32);
        let info = GlyphInfo {
            uv: [
                self.cursor_x as f32 / aw,
                self.cursor_y as f32 / ah,
                (self.cursor_x + bmp_w) as f32 / aw,
                (self.cursor_y + bmp_h) as f32 / ah,
            ],
            width: bmp_w as f32,
            height: bmp_h as f32,
            advance,
            bearing_x,
            bearing_y,
        };
        self.cursor_x += bmp_w + 1; // +1 pixel gap between glyphs
        self.row_height = self.row_height.max(bmp_h);
        self.dirty = true;
        self.store(ch, info)
    }

    /// Copy a tightly-packed `bmp_w` x `bmp_h` bitmap to the current
    /// cursor position. Rows that would overrun either buffer are skipped.
    fn blit(&mut self, bitmap: &[u8], bmp_w: u32, bmp_h: u32) {
        let w = bmp_w as usize;
        let stride = self.atlas_width as usize;
        for row in 0..bmp_h as usize {
            let src = row * w;
            let dst = (self.cursor_y as usize + row) * stride + self.cursor_x as usize;
            if src + w <= bitmap.len() && dst + w <= self.atlas_data.len() {
                self.atlas_data[dst..dst + w].copy_from_slice(&bitmap[src..src + w]);
            }
        }
    }

    /// Record `info` for `ch` (overwriting any prior entry) and return a
    /// reference to the stored value.
    fn store(&mut self, ch: char, info: GlyphInfo) -> &GlyphInfo {
        self.glyphs.insert(ch, info);
        &self.glyphs[&ch]
    }

    /// Acknowledge that the atlas texture has been re-uploaded.
    pub fn clear_dirty(&mut self) {
        self.dirty = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_insert_and_get() {
        let mut cache = GlyphCache::new(256, 256);
        let bitmap = vec![255u8; 10 * 12]; // 10x12 glyph
        cache.insert('A', &bitmap, 10, 12, 11.0, 0.5, 10.0);
        let info = cache.get('A');
        assert!(info.is_some());
        let info = info.unwrap();
        assert!((info.width - 10.0).abs() < 0.1);
        assert!((info.advance - 11.0).abs() < 0.1);
        assert!(info.uv[0] >= 0.0 && info.uv[2] <= 1.0);
    }
    #[test]
    fn test_cache_hit() {
        // get() must not re-touch atlas pixels (dirty stays false).
        let mut cache = GlyphCache::new(256, 256);
        let bitmap = vec![255u8; 8 * 8];
        cache.insert('B', &bitmap, 8, 8, 9.0, 0.0, 8.0);
        cache.clear_dirty();
        // Second access should be cache hit (no dirty)
        let _info = cache.get('B').unwrap();
        assert!(!cache.dirty);
    }
    #[test]
    fn test_row_wrap() {
        let mut cache = GlyphCache::new(64, 64);
        let bitmap = vec![255u8; 20 * 10];
        // Insert 4 glyphs of width 20 in a 64-wide atlas
        // 3 fit in first row (20+1 + 20+1 + 20 = 62), 4th wraps
        cache.insert('A', &bitmap, 20, 10, 21.0, 0.0, 10.0);
        cache.insert('B', &bitmap, 20, 10, 21.0, 0.0, 10.0);
        cache.insert('C', &bitmap, 20, 10, 21.0, 0.0, 10.0);
        cache.insert('D', &bitmap, 20, 10, 21.0, 0.0, 10.0);
        let d = cache.get('D').unwrap();
        // D should be on a different row (v0 > 0)
        assert!(d.uv[1] > 0.0, "D should be on second row, v0={}", d.uv[1]);
    }
    #[test]
    fn test_uv_range() {
        // All UVs must stay within the normalized [0, 1] atlas range.
        let mut cache = GlyphCache::new(256, 256);
        let bitmap = vec![128u8; 15 * 20];
        cache.insert('X', &bitmap, 15, 20, 16.0, 1.0, 18.0);
        let info = cache.get('X').unwrap();
        assert!(info.uv[0] >= 0.0 && info.uv[0] < 1.0);
        assert!(info.uv[1] >= 0.0 && info.uv[1] < 1.0);
        assert!(info.uv[2] > info.uv[0] && info.uv[2] <= 1.0);
        assert!(info.uv[3] > info.uv[1] && info.uv[3] <= 1.0);
    }
    #[test]
    fn test_zero_size_glyph() {
        // A space has no bitmap but still needs its advance metric.
        let mut cache = GlyphCache::new(256, 256);
        cache.insert(' ', &[], 0, 0, 5.0, 0.0, 0.0);
        let info = cache.get(' ').unwrap();
        assert!((info.advance - 5.0).abs() < 0.1);
        assert!((info.width - 0.0).abs() < 0.1);
    }
}

View File

@@ -0,0 +1,307 @@
use crate::ui_context::UiContext;
use crate::dock::Rect;
use crate::layout::LayoutState;
use voltex_ecs::world::World;
use voltex_ecs::entity::Entity;
use voltex_ecs::scene::Tag;
use voltex_ecs::hierarchy::{roots, Children, Parent};
use voltex_ecs::transform::Transform;
const COLOR_SELECTED: [u8; 4] = [0x44, 0x66, 0x88, 0xFF];
const COLOR_TEXT: [u8; 4] = [0xEE, 0xEE, 0xEE, 0xFF];
const LINE_HEIGHT: f32 = 16.0;
const INDENT: f32 = 16.0;
const PADDING: f32 = 4.0;
/// Count total nodes in subtree (including self).
pub fn count_entity_nodes(world: &World, entity: Entity) -> usize {
    // One for this entity, plus the recursive total of every child subtree.
    let descendants = world
        .get::<Children>(entity)
        .map(|c| {
            c.0.iter()
                .map(|&child| count_entity_nodes(world, child))
                .sum::<usize>()
        })
        .unwrap_or(0);
    1 + descendants
}
/// Draw a single entity row, then recurse into children.
///
/// Renders the selection highlight and handles click-to-select over the full
/// row width, then draws the label one glyph quad at a time. `depth` controls
/// the left text indent; `base_x`/`row_w` span the clickable row.
fn draw_entity_node(
    ui: &mut UiContext,
    world: &World,
    entity: Entity,
    depth: usize,
    selected: &mut Option<Entity>,
    base_x: f32,
    row_w: f32,
) {
    let y = ui.layout.cursor_y;
    let x = base_x + PADDING + depth as f32 * INDENT;
    // Highlight selected
    if *selected == Some(entity) {
        ui.draw_list.add_rect(base_x, y, row_w, LINE_HEIGHT, COLOR_SELECTED);
    }
    // Click detection — the entire row is clickable, not just the text.
    if ui.mouse_clicked && ui.mouse_in_rect(base_x, y, row_w, LINE_HEIGHT) {
        *selected = Some(entity);
    }
    // Build label: "> " marks rows with children; Tag name or entity id.
    let has_children = world.get::<Children>(entity).map_or(false, |c| !c.0.is_empty());
    let prefix = if has_children { "> " } else { "  " };
    let name = if let Some(tag) = world.get::<Tag>(entity) {
        format!("{}{}", prefix, tag.0)
    } else {
        format!("{}Entity({})", prefix, entity.id)
    };
    // Draw text (inline glyph rendering with the fixed-size UI font).
    let gw = ui.font.glyph_width as f32;
    let gh = ui.font.glyph_height as f32;
    // Vertically center the glyphs within the row.
    let text_y = y + (LINE_HEIGHT - gh) * 0.5;
    let mut cx = x;
    for ch in name.chars() {
        let (u0, v0, u1, v1) = ui.font.glyph_uv(ch);
        ui.draw_list.add_rect_uv(cx, text_y, gw, gh, u0, v0, u1, v1, COLOR_TEXT);
        cx += gw;
    }
    ui.layout.cursor_y += LINE_HEIGHT;
    // Recurse children.
    // NOTE(review): the clone looks defensive; if `World::get` returns a plain
    // reference the recursion could re-borrow `world` without cloning —
    // confirm against World's API (it may return a guard).
    if let Some(children) = world.get::<Children>(entity) {
        let child_list: Vec<Entity> = children.0.clone();
        for child in child_list {
            draw_entity_node(ui, world, child, depth + 1, selected, base_x, row_w);
        }
    }
}
/// Hierarchy panel: displays entity tree, handles selection.
pub fn hierarchy_panel(
    ui: &mut UiContext,
    world: &World,
    selected: &mut Option<Entity>,
    rect: &Rect,
) {
    // Reset the layout cursor to the padded top-left corner of the panel,
    // then draw each root entity's subtree in order.
    ui.layout = LayoutState::new(rect.x + PADDING, rect.y + PADDING);
    for entity in roots(world) {
        draw_entity_node(ui, world, entity, 0, selected, rect.x, rect.w);
    }
}
/// Inspector panel: edit Transform, Tag, Parent for selected entity.
/// `tag_buffer` is caller-owned staging buffer for Tag text input.
///
/// Shows "No entity selected" when nothing is selected; each section is only
/// drawn when the entity has the corresponding component.
pub fn inspector_panel(
    ui: &mut UiContext,
    world: &mut World,
    selected: Option<Entity>,
    rect: &Rect,
    tag_buffer: &mut String,
) {
    ui.layout = LayoutState::new(rect.x + PADDING, rect.y + PADDING);
    let entity = match selected {
        Some(e) => e,
        None => {
            ui.text("No entity selected");
            return;
        }
    };
    // --- Transform ---
    if world.has_component::<Transform>(entity) {
        ui.text("-- Transform --");
        // Copy out values (immutable borrow ends with block)
        let (px, py, pz, rx, ry, rz, sx, sy, sz) = {
            let t = world.get::<Transform>(entity).unwrap();
            (t.position.x, t.position.y, t.position.z,
             t.rotation.x, t.rotation.y, t.rotation.z,
             t.scale.x, t.scale.y, t.scale.z)
        };
        // Sliders (no world borrow active). Rotation range is slightly wider
        // than +/- PI so full turns stay reachable despite float rounding.
        let new_px = ui.slider("Pos X", px, -50.0, 50.0);
        let new_py = ui.slider("Pos Y", py, -50.0, 50.0);
        let new_pz = ui.slider("Pos Z", pz, -50.0, 50.0);
        let new_rx = ui.slider("Rot X", rx, -3.15, 3.15);
        let new_ry = ui.slider("Rot Y", ry, -3.15, 3.15);
        let new_rz = ui.slider("Rot Z", rz, -3.15, 3.15);
        let new_sx = ui.slider("Scl X", sx, 0.01, 10.0);
        let new_sy = ui.slider("Scl Y", sy, 0.01, 10.0);
        let new_sz = ui.slider("Scl Z", sz, 0.01, 10.0);
        // Write back (mutable borrow)
        if let Some(t) = world.get_mut::<Transform>(entity) {
            t.position.x = new_px;
            t.position.y = new_py;
            t.position.z = new_pz;
            t.rotation.x = new_rx;
            t.rotation.y = new_ry;
            t.rotation.z = new_rz;
            t.scale.x = new_sx;
            t.scale.y = new_sy;
            t.scale.z = new_sz;
        }
    }
    // --- Tag ---
    if world.has_component::<Tag>(entity) {
        ui.text("-- Tag --");
        // Sync buffer from world.
        // NOTE(review): the `is_empty()` check is redundant (covered by the
        // inequality), and resyncing whenever buffer != stored tag can
        // overwrite in-progress typing unless text_input commits each frame —
        // confirm against text_input's behavior.
        if let Some(tag) = world.get::<Tag>(entity) {
            if tag_buffer.is_empty() || *tag_buffer != tag.0 {
                *tag_buffer = tag.0.clone();
            }
        }
        let input_x = rect.x + PADDING;
        let input_y = ui.layout.cursor_y;
        let input_w = (rect.w - PADDING * 2.0).max(50.0);
        // 8888 is the widget id for the tag input; commit on change.
        if ui.text_input(8888, tag_buffer, input_x, input_y, input_w) {
            if let Some(tag) = world.get_mut::<Tag>(entity) {
                tag.0 = tag_buffer.clone();
            }
        }
        ui.layout.advance_line();
    }
    // --- Parent --- (read-only display)
    if let Some(parent) = world.get::<Parent>(entity) {
        ui.text("-- Parent --");
        let parent_text = format!("Parent: Entity({})", parent.0.id);
        ui.text(&parent_text);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use voltex_ecs::transform::Transform;
    use voltex_math::Vec3;
    /// Build a world with one root ("Root") and two children ("Child1", "Child2").
    fn make_test_world() -> (World, Entity, Entity, Entity) {
        let mut world = World::new();
        let e1 = world.spawn();
        world.add(e1, Transform::from_position(Vec3::new(0.0, 0.0, 0.0)));
        world.add(e1, Tag("Root".to_string()));
        let e2 = world.spawn();
        world.add(e2, Transform::from_position(Vec3::new(1.0, 0.0, 0.0)));
        world.add(e2, Tag("Child1".to_string()));
        let e3 = world.spawn();
        world.add(e3, Transform::from_position(Vec3::new(2.0, 0.0, 0.0)));
        world.add(e3, Tag("Child2".to_string()));
        voltex_ecs::hierarchy::add_child(&mut world, e1, e2);
        voltex_ecs::hierarchy::add_child(&mut world, e1, e3);
        (world, e1, e2, e3)
    }
    #[test]
    fn test_count_nodes() {
        // Root + 2 children = 3 nodes.
        let (world, e1, _, _) = make_test_world();
        assert_eq!(count_entity_nodes(&world, e1), 3);
    }
    #[test]
    fn test_hierarchy_draws_commands() {
        let (world, _, _, _) = make_test_world();
        let mut ui = UiContext::new(800.0, 600.0);
        let mut selected: Option<Entity> = None;
        let rect = Rect { x: 0.0, y: 0.0, w: 200.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        hierarchy_panel(&mut ui, &world, &mut selected, &rect);
        ui.end_frame();
        assert!(ui.draw_list.commands.len() > 0);
    }
    #[test]
    fn test_hierarchy_click_selects() {
        let (world, e1, _, _) = make_test_world();
        let mut ui = UiContext::new(800.0, 600.0);
        let mut selected: Option<Entity> = None;
        let rect = Rect { x: 0.0, y: 0.0, w: 200.0, h: 400.0 };
        // Click on first row (y = PADDING + small offset)
        ui.begin_frame(50.0, PADDING + 2.0, true);
        hierarchy_panel(&mut ui, &world, &mut selected, &rect);
        ui.end_frame();
        assert_eq!(selected, Some(e1));
    }
    #[test]
    fn test_hierarchy_empty_world() {
        // No roots: nothing to click, selection stays None.
        let world = World::new();
        let mut ui = UiContext::new(800.0, 600.0);
        let mut selected: Option<Entity> = None;
        let rect = Rect { x: 0.0, y: 0.0, w: 200.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        hierarchy_panel(&mut ui, &world, &mut selected, &rect);
        ui.end_frame();
        assert!(selected.is_none());
    }
    #[test]
    fn test_inspector_no_selection() {
        let mut world = World::new();
        let mut ui = UiContext::new(800.0, 600.0);
        let rect = Rect { x: 0.0, y: 0.0, w: 300.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        let mut tag_buf = String::new();
        inspector_panel(&mut ui, &mut world, None, &rect, &mut tag_buf);
        ui.end_frame();
        // "No entity selected" text produces draw commands
        assert!(ui.draw_list.commands.len() > 0);
    }
    #[test]
    fn test_inspector_with_transform() {
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(1.0, 2.0, 3.0)));
        let mut ui = UiContext::new(800.0, 600.0);
        let rect = Rect { x: 0.0, y: 0.0, w: 300.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        let mut tag_buf = String::new();
        inspector_panel(&mut ui, &mut world, Some(e), &rect, &mut tag_buf);
        ui.end_frame();
        // Header + 9 sliders produce many draw commands
        assert!(ui.draw_list.commands.len() > 10);
    }
    #[test]
    fn test_inspector_with_tag() {
        let mut world = World::new();
        let e = world.spawn();
        world.add(e, Transform::new());
        world.add(e, Tag("TestTag".to_string()));
        let mut ui = UiContext::new(800.0, 600.0);
        let rect = Rect { x: 0.0, y: 0.0, w: 300.0, h: 400.0 };
        ui.begin_frame(0.0, 0.0, false);
        let mut tag_buf = String::new();
        inspector_panel(&mut ui, &mut world, Some(e), &rect, &mut tag_buf);
        ui.end_frame();
        // tag_buf should be synced from world
        assert_eq!(tag_buf, "TestTag");
    }
}

View File

@@ -4,9 +4,30 @@ pub mod layout;
pub mod renderer;
pub mod ui_context;
pub mod widgets;
pub mod dock;
pub mod inspector;
pub mod orbit_camera;
pub use font::FontAtlas;
pub use draw_list::{DrawVertex, DrawCommand, DrawList};
pub use layout::LayoutState;
pub use renderer::UiRenderer;
pub use ui_context::UiContext;
pub use dock::{DockTree, DockNode, Axis, Rect, LeafLayout};
pub use inspector::{hierarchy_panel, inspector_panel};
pub use orbit_camera::OrbitCamera;
pub mod viewport_texture;
pub use viewport_texture::{ViewportTexture, VIEWPORT_COLOR_FORMAT, VIEWPORT_DEPTH_FORMAT};
pub mod viewport_renderer;
pub use viewport_renderer::ViewportRenderer;
pub mod asset_browser;
pub use asset_browser::{AssetBrowser, asset_browser_panel};
pub mod ttf_parser;
pub mod rasterizer;
pub mod glyph_cache;
pub mod ttf_font;
pub use ttf_font::TtfFont;

View File

@@ -0,0 +1,166 @@
use voltex_math::{Vec3, Mat4};
use std::f32::consts::PI;
const PITCH_LIMIT: f32 = PI / 2.0 - 0.01;
const MIN_DISTANCE: f32 = 0.5;
const MAX_DISTANCE: f32 = 50.0;
const ORBIT_SENSITIVITY: f32 = 0.005;
const ZOOM_FACTOR: f32 = 0.1;
const PAN_SENSITIVITY: f32 = 0.01;
/// Orbit-style editor camera: spherical position around a target point.
pub struct OrbitCamera {
    /// Point the camera orbits around and looks at.
    pub target: Vec3,
    /// Distance from `target`; clamped to [MIN_DISTANCE, MAX_DISTANCE] when
    /// changed via `zoom()` (direct field writes are not clamped).
    pub distance: f32,
    /// Horizontal orbit angle in radians.
    pub yaw: f32,
    /// Vertical orbit angle in radians; clamped to +/-PITCH_LIMIT by `orbit()`.
    pub pitch: f32,
    /// Vertical field of view in radians.
    pub fov_y: f32,
    /// Near clip plane distance.
    pub near: f32,
    /// Far clip plane distance.
    pub far: f32,
}
impl OrbitCamera {
    /// Camera 5 units from the origin with a slight downward tilt.
    pub fn new() -> Self {
        OrbitCamera {
            target: Vec3::ZERO,
            distance: 5.0,
            yaw: 0.0,
            pitch: 0.3,
            fov_y: PI / 4.0,
            near: 0.1,
            far: 100.0,
        }
    }
    /// World-space eye position from spherical coordinates (Y up):
    /// yaw sweeps around the Y axis, pitch lifts toward it.
    pub fn position(&self) -> Vec3 {
        let cp = self.pitch.cos();
        let sp = self.pitch.sin();
        let cy = self.yaw.cos();
        let sy = self.yaw.sin();
        Vec3::new(
            self.target.x + self.distance * cp * sy,
            self.target.y + self.distance * sp,
            self.target.z + self.distance * cp * cy,
        )
    }
    /// Rotate around the target by mouse deltas; pitch is clamped just short
    /// of the poles (PITCH_LIMIT) to keep the look-at basis well-defined.
    pub fn orbit(&mut self, dx: f32, dy: f32) {
        self.yaw += dx * ORBIT_SENSITIVITY;
        self.pitch += dy * ORBIT_SENSITIVITY;
        self.pitch = self.pitch.clamp(-PITCH_LIMIT, PITCH_LIMIT);
    }
    /// Multiplicative zoom: positive `delta` moves closer; distance stays in
    /// [MIN_DISTANCE, MAX_DISTANCE].
    pub fn zoom(&mut self, delta: f32) {
        self.distance *= 1.0 - delta * ZOOM_FACTOR;
        self.distance = self.distance.clamp(MIN_DISTANCE, MAX_DISTANCE);
    }
    /// Pan the target in the camera plane; speed scales with distance so the
    /// perceived on-screen motion stays roughly constant.
    pub fn pan(&mut self, dx: f32, dy: f32) {
        let forward = (self.target - self.position()).normalize();
        let right = forward.cross(Vec3::Y);
        // Near-vertical view: forward ~ parallel to Y degenerates the cross
        // product, so fall back to the world X axis.
        let right = if right.length() < 1e-4 { Vec3::X } else { right.normalize() };
        let up = right.cross(forward).normalize();
        let offset_x = right * (-dx * PAN_SENSITIVITY * self.distance);
        let offset_y = up * (dy * PAN_SENSITIVITY * self.distance);
        self.target = self.target + offset_x + offset_y;
    }
    /// View matrix looking from the orbit position at the target (Y up).
    pub fn view_matrix(&self) -> Mat4 {
        Mat4::look_at(self.position(), self.target, Vec3::Y)
    }
    /// Perspective projection for the given aspect ratio.
    pub fn projection_matrix(&self, aspect: f32) -> Mat4 {
        Mat4::perspective(self.fov_y, aspect, self.near, self.far)
    }
    /// Combined projection * view matrix.
    pub fn view_projection(&self, aspect: f32) -> Mat4 {
        self.projection_matrix(aspect).mul_mat4(&self.view_matrix())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_default_position() {
        // yaw = 0 puts the camera on the +Z side; pitch = 0.3 lifts it above.
        let cam = OrbitCamera::new();
        let pos = cam.position();
        assert!((pos.x).abs() < 1e-3);
        assert!(pos.y > 0.0);
        assert!(pos.z > 0.0);
    }
    #[test]
    fn test_orbit_changes_yaw_pitch() {
        // Deltas are scaled by ORBIT_SENSITIVITY, added linearly.
        let mut cam = OrbitCamera::new();
        let old_yaw = cam.yaw;
        let old_pitch = cam.pitch;
        cam.orbit(100.0, 50.0);
        assert!((cam.yaw - old_yaw - 100.0 * ORBIT_SENSITIVITY).abs() < 1e-6);
        assert!((cam.pitch - old_pitch - 50.0 * ORBIT_SENSITIVITY).abs() < 1e-6);
    }
    #[test]
    fn test_pitch_clamped() {
        let mut cam = OrbitCamera::new();
        cam.orbit(0.0, 100000.0);
        assert!(cam.pitch <= PITCH_LIMIT);
        cam.orbit(0.0, -200000.0);
        assert!(cam.pitch >= -PITCH_LIMIT);
    }
    #[test]
    fn test_zoom_changes_distance() {
        // Positive delta zooms in (distance shrinks).
        let mut cam = OrbitCamera::new();
        let d0 = cam.distance;
        cam.zoom(1.0);
        assert!(cam.distance < d0);
    }
    #[test]
    fn test_zoom_clamped() {
        let mut cam = OrbitCamera::new();
        cam.zoom(1000.0);
        assert!(cam.distance >= MIN_DISTANCE);
        cam.zoom(-10000.0);
        assert!(cam.distance <= MAX_DISTANCE);
    }
    #[test]
    fn test_pan_moves_target() {
        // Horizontal pan must move the target in the ground (XZ) plane.
        let mut cam = OrbitCamera::new();
        let t0 = cam.target;
        cam.pan(10.0, 0.0);
        assert!((cam.target.x - t0.x).abs() > 1e-4 || (cam.target.z - t0.z).abs() > 1e-4);
    }
    #[test]
    fn test_view_matrix_not_zero() {
        // Sanity check: the matrix is populated, not all zeros.
        let cam = OrbitCamera::new();
        let v = cam.view_matrix();
        let sum: f32 = v.cols.iter().flat_map(|c| c.iter()).map(|x| x.abs()).sum();
        assert!(sum > 1.0);
    }
    #[test]
    fn test_projection_matrix() {
        let cam = OrbitCamera::new();
        let p = cam.projection_matrix(16.0 / 9.0);
        assert!(p.cols[0][0] > 0.0);
    }
    #[test]
    fn test_view_projection() {
        // view_projection must equal projection * view, element-wise.
        let cam = OrbitCamera::new();
        let vp = cam.view_projection(1.0);
        let v = cam.view_matrix();
        let p = cam.projection_matrix(1.0);
        let expected = p.mul_mat4(&v);
        for i in 0..4 {
            for j in 0..4 {
                assert!((vp.cols[i][j] - expected.cols[i][j]).abs() < 1e-4,
                    "mismatch at [{i}][{j}]: {} vs {}", vp.cols[i][j], expected.cols[i][j]);
            }
        }
    }
}

View File

@@ -0,0 +1,220 @@
use crate::ttf_parser::GlyphOutline;
/// Result of rasterizing a single glyph.
pub struct RasterResult {
    /// Bitmap width in pixels (rasterize() adds a 1px empty border per side).
    pub width: u32,
    /// Bitmap height in pixels (rasterize() adds a 1px empty border per side).
    pub height: u32,
    /// R8 alpha bitmap, row-major.
    pub bitmap: Vec<u8>,
    /// Left bearing in pixels.
    pub offset_x: f32,
    /// Distance from baseline to top of glyph bbox in pixels.
    pub offset_y: f32,
}
/// Flatten a quadratic bezier into line segments, appended to `edges`.
///
/// Subdivides at t = 0.5 until the curve's midpoint deviates from the chord
/// by less than half a pixel. Recursion depth is capped so degenerate input
/// (e.g. NaN coordinates, for which the flatness test is always false) cannot
/// recurse without bound and overflow the stack.
pub fn flatten_quad(
    x0: f32, y0: f32,
    cx: f32, cy: f32,
    x1: f32, y1: f32,
    edges: &mut Vec<(f32, f32, f32, f32)>,
) {
    flatten_quad_rec(x0, y0, cx, cy, x1, y1, edges, 0);
}

/// Recursive worker for `flatten_quad`; `depth` bounds the subdivision.
fn flatten_quad_rec(
    x0: f32, y0: f32,
    cx: f32, cy: f32,
    x1: f32, y1: f32,
    edges: &mut Vec<(f32, f32, f32, f32)>,
    depth: u32,
) {
    // Error shrinks ~4x per split, so 16 levels is far more than any glyph
    // needs; the cap only fires on pathological/non-finite input.
    const MAX_DEPTH: u32 = 16;
    // Compare the curve midpoint (t = 0.5) against the chord midpoint: if
    // they are within half a pixel, a straight segment is visually exact.
    let mx = (x0 + 2.0 * cx + x1) / 4.0;
    let my = (y0 + 2.0 * cy + y1) / 4.0;
    let lx = (x0 + x1) / 2.0;
    let ly = (y0 + y1) / 2.0;
    let dx = mx - lx;
    let dy = my - ly;
    if depth >= MAX_DEPTH || dx * dx + dy * dy < 0.25 {
        edges.push((x0, y0, x1, y1));
    } else {
        // de Casteljau split at t = 0.5.
        let m01x = (x0 + cx) / 2.0;
        let m01y = (y0 + cy) / 2.0;
        let m12x = (cx + x1) / 2.0;
        let m12y = (cy + y1) / 2.0;
        let midx = (m01x + m12x) / 2.0;
        let midy = (m01y + m12y) / 2.0;
        flatten_quad_rec(x0, y0, m01x, m01y, midx, midy, edges, depth + 1);
        flatten_quad_rec(midx, midy, m12x, m12y, x1, y1, edges, depth + 1);
    }
}
/// Rasterize a glyph outline at the given scale into an alpha bitmap.
///
/// Flattens all contours into line segments, then fills one scanline at a
/// time with the non-zero winding rule. Coverage is binary (0 or 255) — no
/// anti-aliasing — and the bitmap gets a 1px empty border on every side.
pub fn rasterize(outline: &GlyphOutline, scale: f32) -> RasterResult {
    // Handle empty outline (e.g., space)
    if outline.contours.is_empty() || outline.x_max <= outline.x_min {
        return RasterResult {
            width: 0,
            height: 0,
            bitmap: vec![],
            offset_x: 0.0,
            offset_y: 0.0,
        };
    }
    let x_min = outline.x_min as f32;
    let y_min = outline.y_min as f32;
    let x_max = outline.x_max as f32;
    let y_max = outline.y_max as f32;
    // +2 accounts for the 1px border on each side.
    let w = ((x_max - x_min) * scale).ceil() as u32 + 2;
    let h = ((y_max - y_min) * scale).ceil() as u32 + 2;
    // Transform a font-space coordinate to bitmap pixel space.
    // Font Y grows upward, bitmap rows grow downward — hence (y_max - py).
    let transform = |px: f32, py: f32| -> (f32, f32) {
        let bx = (px - x_min) * scale + 1.0;
        let by = (y_max - py) * scale + 1.0;
        (bx, by)
    };
    // Build edges from contours
    let mut edges: Vec<(f32, f32, f32, f32)> = Vec::new();
    for contour in &outline.contours {
        let n = contour.len();
        if n < 2 {
            continue;
        }
        let mut i = 0;
        while i < n {
            let p0 = &contour[i];
            let p1 = &contour[(i + 1) % n];
            if p0.on_curve && p1.on_curve {
                // Line segment
                let (px0, py0) = transform(p0.x, p0.y);
                let (px1, py1) = transform(p1.x, p1.y);
                edges.push((px0, py0, px1, py1));
                i += 1;
            } else if p0.on_curve && !p1.on_curve {
                // Quadratic bezier: p0 -> p1(control) -> p2(on_curve).
                // The outline builder inserts implicit on-curve midpoints
                // between consecutive off-curve points, so p2 should be
                // on-curve here.
                let p2 = &contour[(i + 2) % n];
                let (px0, py0) = transform(p0.x, p0.y);
                let (cx, cy) = transform(p1.x, p1.y);
                let (px1, py1) = transform(p2.x, p2.y);
                flatten_quad(px0, py0, cx, cy, px1, py1, &mut edges);
                i += 2;
            } else {
                // Skip unexpected off-curve start
                i += 1;
            }
        }
    }
    // Scanline fill (non-zero winding), sampling each row at its center.
    let mut bitmap = vec![0u8; (w * h) as usize];
    for row in 0..h {
        let scan_y = row as f32 + 0.5;
        let mut intersections: Vec<(f32, i32)> = Vec::new();
        for &(x0, y0, x1, y1) in &edges {
            // Half-open span test avoids double-counting shared endpoints.
            if (y0 <= scan_y && y1 > scan_y) || (y1 <= scan_y && y0 > scan_y) {
                let t = (scan_y - y0) / (y1 - y0);
                let ix = x0 + t * (x1 - x0);
                // Winding direction: +1 for downward edges, -1 for upward.
                let dir = if y1 > y0 { 1 } else { -1 };
                intersections.push((ix, dir));
            }
        }
        // NOTE(review): partial_cmp().unwrap() panics if a coordinate is NaN.
        intersections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        // Fill using non-zero winding rule
        let mut winding = 0i32;
        let mut fill_start = 0.0f32;
        for &(x, dir) in &intersections {
            let old_winding = winding;
            winding += dir;
            if old_winding == 0 && winding != 0 {
                fill_start = x;
            }
            if old_winding != 0 && winding == 0 {
                let px_start = (fill_start.floor() as i32).max(0) as u32;
                let px_end = (x.ceil() as u32).min(w);
                for px in px_start..px_end {
                    bitmap[(row * w + px) as usize] = 255;
                }
            }
        }
    }
    // Offsets let the caller place the bitmap relative to the glyph origin.
    let offset_x = x_min * scale;
    let offset_y = y_max * scale;
    RasterResult {
        width: w,
        height: h,
        bitmap,
        offset_x,
        offset_y,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ttf_parser::TtfParser;
    /// Load a real system font for integration-style tests.
    /// NOTE(review): paths are Windows-only; on other platforms this returns
    /// None and the tests fail at expect() rather than being skipped.
    fn load_test_font() -> Option<TtfParser> {
        let paths = [
            "C:/Windows/Fonts/arial.ttf",
            "C:/Windows/Fonts/consola.ttf",
        ];
        for p in &paths {
            if let Ok(data) = std::fs::read(p) {
                if let Ok(parser) = TtfParser::parse(data) {
                    return Some(parser);
                }
            }
        }
        None
    }
    #[test]
    fn test_rasterize_produces_bitmap() {
        // 'A' at 32px must yield a non-empty bitmap with positive dimensions.
        let parser = load_test_font().expect("no font");
        let gid = parser.glyph_index(0x41); // 'A'
        let outline = parser.glyph_outline(gid).expect("no outline");
        let scale = 32.0 / parser.units_per_em as f32;
        let result = rasterize(&outline, scale);
        assert!(result.width > 0);
        assert!(result.height > 0);
        assert!(!result.bitmap.is_empty());
    }
    #[test]
    fn test_rasterize_has_filled_pixels() {
        let parser = load_test_font().expect("no font");
        let gid = parser.glyph_index(0x41);
        let outline = parser.glyph_outline(gid).expect("no outline");
        let scale = 32.0 / parser.units_per_em as f32;
        let result = rasterize(&outline, scale);
        let filled = result.bitmap.iter().filter(|&&b| b > 0).count();
        assert!(filled > 0, "bitmap should have filled pixels");
    }
    #[test]
    fn test_rasterize_empty() {
        let parser = load_test_font().expect("no font");
        let gid = parser.glyph_index(0x20); // space
        let outline = parser.glyph_outline(gid);
        if let Some(o) = outline {
            let scale = 32.0 / parser.units_per_em as f32;
            let result = rasterize(&o, scale);
            // Space should be empty or zero-sized
            assert!(
                result.width == 0 || result.bitmap.iter().all(|&b| b == 0)
            );
        }
    }
    #[test]
    fn test_flatten_quad_produces_edges() {
        // A clearly curved bezier must subdivide at least once.
        let mut edges = Vec::new();
        flatten_quad(0.0, 0.0, 5.0, 10.0, 10.0, 0.0, &mut edges);
        assert!(
            edges.len() >= 2,
            "bezier should flatten to multiple segments"
        );
    }
}

View File

@@ -286,8 +286,13 @@ impl UiRenderer {
pass.set_vertex_buffer(0, vertex_buffer.slice(..));
pass.set_index_buffer(index_buffer.slice(..), wgpu::IndexFormat::Uint16);
// Draw each command
// Draw each command (with optional scissor clipping)
for cmd in &draw_list.commands {
if let Some(scissor) = &cmd.scissor {
pass.set_scissor_rect(scissor.x, scissor.y, scissor.w, scissor.h);
} else {
pass.set_scissor_rect(0, 0, screen_w as u32, screen_h as u32);
}
pass.draw_indexed(
cmd.index_offset..cmd.index_offset + cmd.index_count,
0,

View File

@@ -0,0 +1,153 @@
use crate::ttf_parser::TtfParser;
use crate::rasterizer::rasterize;
use crate::glyph_cache::{GlyphCache, GlyphInfo, ATLAS_SIZE};
/// A TTF font bound to a specific pixel size, with an atlas glyph cache.
pub struct TtfFont {
    /// Parsed TTF tables and raw font data.
    parser: TtfParser,
    /// Atlas cache of rasterized glyphs.
    cache: GlyphCache,
    /// Requested pixel size used to derive `scale`.
    pub font_size: f32,
    /// Line advance in pixels: (ascender - descender + line_gap) * scale.
    pub line_height: f32,
    /// Baseline-to-top distance in pixels.
    pub ascender: f32,
    /// Baseline-to-bottom distance in pixels (hhea descender is conventionally
    /// negative — not validated here).
    pub descender: f32,
    /// Font units -> pixels factor: font_size / units_per_em.
    scale: f32,
}
impl TtfFont {
    /// Parse raw TTF bytes and prepare an empty glyph atlas for `font_size` px.
    ///
    /// Vertical metrics are converted from font units to pixels using
    /// `font_size / units_per_em`. Errors propagate from the TTF parser.
    pub fn new(data: &[u8], font_size: f32) -> Result<Self, String> {
        let parser = TtfParser::parse(data.to_vec())?;
        let scale = font_size / parser.units_per_em as f32;
        let ascender = parser.ascender as f32 * scale;
        let descender = parser.descender as f32 * scale;
        let line_height = (parser.ascender - parser.descender + parser.line_gap) as f32 * scale;
        let cache = GlyphCache::new(ATLAS_SIZE, ATLAS_SIZE);
        Ok(TtfFont {
            parser,
            cache,
            font_size,
            line_height,
            ascender,
            descender,
            scale,
        })
    }
    /// Get glyph info for a character, rasterizing on cache miss.
    ///
    /// Two lookups are deliberate: returning the reference from a single
    /// `if let Some(info) = self.cache.get(ch)` would extend the borrow
    /// across the insert on the miss path (borrow-checker limitation).
    pub fn glyph(&mut self, ch: char) -> &GlyphInfo {
        if self.cache.get(ch).is_none() {
            self.insert_glyph(ch);
        }
        // Present by construction: either cached already or just inserted.
        self.cache.get(ch).unwrap()
    }
    /// Cache-miss path: rasterize `ch` and insert it into the atlas
    /// (an advance-only empty glyph when the character has no outline).
    fn insert_glyph(&mut self, ch: char) {
        let glyph_id = self.parser.glyph_index(ch as u32);
        let metrics = self.parser.glyph_metrics(glyph_id);
        let advance = metrics.advance_width as f32 * self.scale;
        let bearing_x = metrics.left_side_bearing as f32 * self.scale;
        let outline = self.parser.glyph_outline(glyph_id);
        match outline {
            Some(ref o) if !o.contours.is_empty() => {
                let result = rasterize(o, self.scale);
                // bearing_y = y_max * scale (distance above baseline)
                // When rendering: glyph_y = baseline_y - bearing_y
                // For our UI: text_y is the top of the line, baseline = text_y + ascender
                // So: glyph_y = text_y + ascender - bearing_y
                let bearing_y = o.y_max as f32 * self.scale;
                self.cache.insert(
                    ch,
                    &result.bitmap,
                    result.width,
                    result.height,
                    advance,
                    bearing_x,
                    bearing_y,
                );
            }
            _ => {
                // No outline (space, etc.) — insert empty glyph
                self.cache.insert(ch, &[], 0, 0, advance, 0.0, 0.0);
            }
        }
    }
    /// Calculate total width of a text string.
    ///
    /// Takes `&mut self` because glyphs may be rasterized on first use.
    pub fn text_width(&mut self, text: &str) -> f32 {
        let mut width = 0.0;
        for ch in text.chars() {
            let info = self.glyph(ch);
            width += info.advance;
        }
        width
    }
    /// Raw R8 atlas pixels, row-major.
    pub fn atlas_data(&self) -> &[u8] {
        &self.cache.atlas_data
    }
    /// Atlas dimensions in pixels: (width, height).
    pub fn atlas_size(&self) -> (u32, u32) {
        (self.cache.atlas_width, self.cache.atlas_height)
    }
    /// True when the atlas changed since the last `clear_dirty`
    /// (a GPU re-upload is needed).
    pub fn is_dirty(&self) -> bool {
        self.cache.dirty
    }
    /// Acknowledge an atlas upload; resets the dirty flag.
    pub fn clear_dirty(&mut self) {
        self.cache.clear_dirty();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Load a real system font at 24px.
    /// NOTE(review): Windows-only paths; elsewhere these tests fail at expect().
    fn load_test_font() -> Option<TtfFont> {
        let paths = ["C:/Windows/Fonts/arial.ttf", "C:/Windows/Fonts/consola.ttf"];
        for p in &paths {
            if let Ok(data) = std::fs::read(p) {
                if let Ok(font) = TtfFont::new(&data, 24.0) {
                    return Some(font);
                }
            }
        }
        None
    }
    #[test]
    fn test_new() {
        // Constructor must carry through the size and positive vertical metrics.
        let font = load_test_font().expect("no test font");
        assert!(font.font_size == 24.0);
        assert!(font.ascender > 0.0);
        assert!(font.line_height > 0.0);
    }
    #[test]
    fn test_glyph_caches() {
        // First access rasterizes (dirty); second access is a pure cache hit.
        let mut font = load_test_font().expect("no test font");
        let _g1 = font.glyph('A');
        assert!(font.is_dirty());
        font.clear_dirty();
        let _g2 = font.glyph('A'); // cache hit
        assert!(!font.is_dirty());
    }
    #[test]
    fn test_text_width() {
        // Width accumulates per-glyph advances, so a longer string is wider.
        let mut font = load_test_font().expect("no test font");
        let w1 = font.text_width("Hello");
        let w2 = font.text_width("Hello World");
        assert!(w1 > 0.0);
        assert!(w2 > w1);
    }
    #[test]
    fn test_space_glyph() {
        let mut font = load_test_font().expect("no test font");
        let info = font.glyph(' ');
        assert!(info.advance > 0.0); // space has advance but no bitmap
        assert!(info.width == 0.0);
    }
}

View File

@@ -0,0 +1,518 @@
use std::collections::HashMap;
// --- Byte helpers (big-endian) ---
fn read_u8(data: &[u8], off: usize) -> u8 {
data[off]
}
fn read_u16(data: &[u8], off: usize) -> u16 {
u16::from_be_bytes([data[off], data[off + 1]])
}
fn read_i16(data: &[u8], off: usize) -> i16 {
i16::from_be_bytes([data[off], data[off + 1]])
}
fn read_u32(data: &[u8], off: usize) -> u32 {
u32::from_be_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]])
}
// --- Data types ---
/// A single contour point in font units.
#[derive(Debug, Clone)]
pub struct OutlinePoint {
    pub x: f32,
    pub y: f32,
    /// True for on-curve points; false for quadratic bezier control points.
    pub on_curve: bool,
}
/// A simple glyph: closed contours plus its bounding box in font units.
#[derive(Debug, Clone)]
pub struct GlyphOutline {
    pub contours: Vec<Vec<OutlinePoint>>,
    pub x_min: i16,
    pub y_min: i16,
    pub x_max: i16,
    pub y_max: i16,
}
/// Horizontal metrics from the hmtx table, in font units.
#[derive(Debug, Clone, Copy)]
pub struct GlyphMetrics {
    pub advance_width: u16,
    pub left_side_bearing: i16,
}
// --- Flag bits for simple glyph parsing (glyf table point flags) ---
const ON_CURVE: u8 = 0x01;
const X_SHORT: u8 = 0x02;
const Y_SHORT: u8 = 0x04;
const REPEAT_FLAG: u8 = 0x08;
const X_SAME_OR_POS: u8 = 0x10;
const Y_SAME_OR_POS: u8 = 0x20;
// --- TtfParser ---
/// Minimal TTF parser: table directory plus the head/hhea/maxp values needed
/// for glyph lookup, outlines, and horizontal metrics.
pub struct TtfParser {
    /// Raw font file bytes; all table offsets index into this buffer.
    data: Vec<u8>,
    pub tables: HashMap<[u8; 4], (u32, u32)>, // tag -> (offset, length)
    /// Font units per em square (from head).
    pub units_per_em: u16,
    /// Glyph count (from maxp).
    pub num_glyphs: u16,
    /// Typographic ascender in font units (from hhea).
    pub ascender: i16,
    /// Typographic descender in font units (from hhea).
    pub descender: i16,
    /// Extra line spacing in font units (from hhea).
    pub line_gap: i16,
    /// Number of full hmtx records (from hhea).
    pub num_h_metrics: u16,
    /// loca table format: 0 = short (u16 offsets / 2), 1 = long (u32), from head.
    pub loca_format: i16,
}
impl TtfParser {
    /// Parse a TTF file from raw bytes.
    ///
    /// Reads the table directory plus the `head`, `hhea`, and `maxp` tables.
    /// Errors if the directory is truncated or a required table is missing.
    /// NOTE(review): table offsets from the directory are trusted from here
    /// on; a corrupt font whose offsets point past the buffer will panic on a
    /// later out-of-range read rather than return Err.
    pub fn parse(data: Vec<u8>) -> Result<Self, String> {
        if data.len() < 12 {
            return Err("File too short for offset table".into());
        }
        let _sf_version = read_u32(&data, 0);
        let num_tables = read_u16(&data, 4) as usize;
        if data.len() < 12 + num_tables * 16 {
            return Err("File too short for table records".into());
        }
        // Table directory: one 16-byte record per table
        // (tag, checksum, offset, length).
        let mut tables = HashMap::new();
        for i in 0..num_tables {
            let rec_off = 12 + i * 16;
            let mut tag = [0u8; 4];
            tag.copy_from_slice(&data[rec_off..rec_off + 4]);
            let offset = read_u32(&data, rec_off + 8);
            let length = read_u32(&data, rec_off + 12);
            tables.insert(tag, (offset, length));
        }
        // Parse head table (unitsPerEm at +18, indexToLocFormat at +50)
        let &(head_off, _) = tables
            .get(b"head")
            .ok_or("Missing head table")?;
        let head_off = head_off as usize;
        let units_per_em = read_u16(&data, head_off + 18);
        let loca_format = read_i16(&data, head_off + 50);
        // Parse hhea table (vertical metrics, numberOfHMetrics at +34)
        let &(hhea_off, _) = tables
            .get(b"hhea")
            .ok_or("Missing hhea table")?;
        let hhea_off = hhea_off as usize;
        let ascender = read_i16(&data, hhea_off + 4);
        let descender = read_i16(&data, hhea_off + 6);
        let line_gap = read_i16(&data, hhea_off + 8);
        let num_h_metrics = read_u16(&data, hhea_off + 34);
        // Parse maxp table (numGlyphs at +4)
        let &(maxp_off, _) = tables
            .get(b"maxp")
            .ok_or("Missing maxp table")?;
        let maxp_off = maxp_off as usize;
        let num_glyphs = read_u16(&data, maxp_off + 4);
        Ok(Self {
            data,
            tables,
            units_per_em,
            num_glyphs,
            ascender,
            descender,
            line_gap,
            num_h_metrics,
            loca_format,
        })
    }
    /// Look up the glyph index for a Unicode codepoint via cmap Format 4.
    ///
    /// Returns 0 (the .notdef glyph) when the cmap table, a usable Format 4
    /// subtable, or a mapping for the codepoint is not found.
    pub fn glyph_index(&self, codepoint: u32) -> u16 {
        let &(cmap_off, _) = match self.tables.get(b"cmap") {
            Some(v) => v,
            None => return 0,
        };
        let cmap_off = cmap_off as usize;
        let num_subtables = read_u16(&self.data, cmap_off + 2) as usize;
        // Find a Format 4 subtable (prefer platform 3 encoding 1, or platform 0)
        let mut fmt4_offset: Option<usize> = None;
        for i in 0..num_subtables {
            let rec = cmap_off + 4 + i * 8;
            let platform_id = read_u16(&self.data, rec);
            let encoding_id = read_u16(&self.data, rec + 2);
            let sub_offset = read_u32(&self.data, rec + 4) as usize;
            let abs_off = cmap_off + sub_offset;
            if abs_off + 2 > self.data.len() {
                continue;
            }
            let format = read_u16(&self.data, abs_off);
            if format == 4 {
                // Prefer Windows Unicode BMP (3,1)
                if platform_id == 3 && encoding_id == 1 {
                    fmt4_offset = Some(abs_off);
                    break;
                }
                // Accept platform 0 as fallback
                if platform_id == 0 && fmt4_offset.is_none() {
                    fmt4_offset = Some(abs_off);
                }
            }
        }
        let sub_off = match fmt4_offset {
            Some(o) => o,
            None => return 0,
        };
        // Parse Format 4: parallel per-segment arrays endCode / startCode /
        // idDelta / idRangeOffset, with a 2-byte reservedPad after endCode.
        let seg_count_x2 = read_u16(&self.data, sub_off + 6) as usize;
        let seg_count = seg_count_x2 / 2;
        let end_code_base = sub_off + 14;
        let start_code_base = end_code_base + seg_count * 2 + 2; // +2 for reservedPad
        let id_delta_base = start_code_base + seg_count * 2;
        let id_range_offset_base = id_delta_base + seg_count * 2;
        // Linear scan for the segment containing the codepoint (segments are
        // sorted by endCode; a binary search would also work).
        for i in 0..seg_count {
            let end_code = read_u16(&self.data, end_code_base + i * 2) as u32;
            let start_code = read_u16(&self.data, start_code_base + i * 2) as u32;
            if end_code >= codepoint && start_code <= codepoint {
                let id_delta = read_i16(&self.data, id_delta_base + i * 2);
                let id_range_offset = read_u16(&self.data, id_range_offset_base + i * 2) as usize;
                if id_range_offset == 0 {
                    // Direct mapping: glyph = codepoint + delta (mod 65536,
                    // via the u16 truncation).
                    return (codepoint as i32 + id_delta as i32) as u16;
                } else {
                    // Indirect: idRangeOffset is a byte offset from its own
                    // position into glyphIdArray (the spec's pointer trick).
                    let offset_in_bytes =
                        id_range_offset + 2 * (codepoint - start_code) as usize;
                    let glyph_addr = id_range_offset_base + i * 2 + offset_in_bytes;
                    if glyph_addr + 1 < self.data.len() {
                        let glyph = read_u16(&self.data, glyph_addr);
                        if glyph != 0 {
                            return (glyph as i32 + id_delta as i32) as u16;
                        }
                    }
                    return 0;
                }
            }
        }
        0
    }
    /// Get the offset of a glyph in the glyf table using loca.
    ///
    /// Returns `(absolute_offset, length)`, or `None` for an out-of-range id
    /// or an empty glyph (equal consecutive loca entries).
    fn glyph_offset(&self, glyph_id: u16) -> Option<(usize, usize)> {
        let &(loca_off, _) = self.tables.get(b"loca")?;
        let &(glyf_off, _) = self.tables.get(b"glyf")?;
        let loca_off = loca_off as usize;
        let glyf_off = glyf_off as usize;
        if glyph_id >= self.num_glyphs {
            return None;
        }
        let (offset, next_offset) = if self.loca_format == 0 {
            // Short format: u16 * 2
            let o = read_u16(&self.data, loca_off + glyph_id as usize * 2) as usize * 2;
            let n = read_u16(&self.data, loca_off + (glyph_id as usize + 1) * 2) as usize * 2;
            (o, n)
        } else {
            // Long format: u32
            let o = read_u32(&self.data, loca_off + glyph_id as usize * 4) as usize;
            let n = read_u32(&self.data, loca_off + (glyph_id as usize + 1) * 4) as usize;
            (o, n)
        };
        if offset == next_offset {
            // Empty glyph (e.g., space)
            return None;
        }
        Some((glyf_off + offset, next_offset - offset))
    }
    /// Parse the outline of a simple glyph.
    ///
    /// Returns `None` for empty glyphs and for compound glyphs (negative
    /// contour count), which are not supported. Runs of consecutive off-curve
    /// points get an implicit on-curve midpoint inserted so downstream code
    /// only ever sees well-formed quadratic segments.
    pub fn glyph_outline(&self, glyph_id: u16) -> Option<GlyphOutline> {
        let (glyph_off, _glyph_len) = match self.glyph_offset(glyph_id) {
            Some(v) => v,
            None => return None, // empty glyph
        };
        let num_contours = read_i16(&self.data, glyph_off);
        if num_contours < 0 {
            // Compound glyph — not supported
            return None;
        }
        let num_contours = num_contours as usize;
        if num_contours == 0 {
            // Bounding box only, no contours.
            return Some(GlyphOutline {
                contours: Vec::new(),
                x_min: read_i16(&self.data, glyph_off + 2),
                y_min: read_i16(&self.data, glyph_off + 4),
                x_max: read_i16(&self.data, glyph_off + 6),
                y_max: read_i16(&self.data, glyph_off + 8),
            });
        }
        let x_min = read_i16(&self.data, glyph_off + 2);
        let y_min = read_i16(&self.data, glyph_off + 4);
        let x_max = read_i16(&self.data, glyph_off + 6);
        let y_max = read_i16(&self.data, glyph_off + 8);
        // endPtsOfContours
        let mut end_pts = Vec::with_capacity(num_contours);
        let mut off = glyph_off + 10;
        for _ in 0..num_contours {
            end_pts.push(read_u16(&self.data, off) as usize);
            off += 2;
        }
        let num_points = end_pts[num_contours - 1] + 1;
        // Skip instructions
        let instruction_length = read_u16(&self.data, off) as usize;
        off += 2 + instruction_length;
        // Parse flags (run-length encoded via REPEAT_FLAG)
        let mut flags = Vec::with_capacity(num_points);
        while flags.len() < num_points {
            let flag = read_u8(&self.data, off);
            off += 1;
            flags.push(flag);
            if flag & REPEAT_FLAG != 0 {
                let repeat_count = read_u8(&self.data, off) as usize;
                off += 1;
                for _ in 0..repeat_count {
                    flags.push(flag);
                }
            }
        }
        // Parse x-coordinates (delta-encoded)
        let mut x_coords = Vec::with_capacity(num_points);
        let mut x: i32 = 0;
        for i in 0..num_points {
            let flag = flags[i];
            if flag & X_SHORT != 0 {
                // 1-byte delta; X_SAME_OR_POS carries the sign.
                let dx = read_u8(&self.data, off) as i32;
                off += 1;
                x += if flag & X_SAME_OR_POS != 0 { dx } else { -dx };
            } else if flag & X_SAME_OR_POS != 0 {
                // delta = 0
            } else {
                let dx = read_i16(&self.data, off) as i32;
                off += 2;
                x += dx;
            }
            x_coords.push(x);
        }
        // Parse y-coordinates (delta-encoded)
        let mut y_coords = Vec::with_capacity(num_points);
        let mut y: i32 = 0;
        for i in 0..num_points {
            let flag = flags[i];
            if flag & Y_SHORT != 0 {
                // 1-byte delta; Y_SAME_OR_POS carries the sign.
                let dy = read_u8(&self.data, off) as i32;
                off += 1;
                y += if flag & Y_SAME_OR_POS != 0 { dy } else { -dy };
            } else if flag & Y_SAME_OR_POS != 0 {
                // delta = 0
            } else {
                let dy = read_i16(&self.data, off) as i32;
                off += 2;
                y += dy;
            }
            y_coords.push(y);
        }
        // Build contours with implicit on-curve point insertion
        let mut contours = Vec::with_capacity(num_contours);
        let mut start = 0;
        for &end in &end_pts {
            let raw_points: Vec<(f32, f32, bool)> = (start..=end)
                .map(|i| {
                    (
                        x_coords[i] as f32,
                        y_coords[i] as f32,
                        flags[i] & ON_CURVE != 0,
                    )
                })
                .collect();
            let mut contour = Vec::new();
            let n = raw_points.len();
            if n == 0 {
                start = end + 1;
                contours.push(contour);
                continue;
            }
            for j in 0..n {
                let (cx, cy, c_on) = raw_points[j];
                let (nx, ny, n_on) = raw_points[(j + 1) % n];
                contour.push(OutlinePoint {
                    x: cx,
                    y: cy,
                    on_curve: c_on,
                });
                // If both current and next are off-curve, insert implicit midpoint
                if !c_on && !n_on {
                    contour.push(OutlinePoint {
                        x: (cx + nx) * 0.5,
                        y: (cy + ny) * 0.5,
                        on_curve: true,
                    });
                }
            }
            start = end + 1;
            contours.push(contour);
        }
        Some(GlyphOutline {
            contours,
            x_min,
            y_min,
            x_max,
            y_max,
        })
    }
    /// Get the horizontal metrics for a glyph.
    ///
    /// Glyphs at or past `num_h_metrics` share the last record's advance
    /// width and read their left side bearing from the trailing lsb array,
    /// per the hmtx table layout. Missing hmtx yields zeroed metrics.
    pub fn glyph_metrics(&self, glyph_id: u16) -> GlyphMetrics {
        let &(hmtx_off, _) = match self.tables.get(b"hmtx") {
            Some(v) => v,
            None => {
                return GlyphMetrics {
                    advance_width: 0,
                    left_side_bearing: 0,
                }
            }
        };
        let hmtx_off = hmtx_off as usize;
        // (the `as u16` cast is redundant — glyph_id is already u16)
        if (glyph_id as u16) < self.num_h_metrics {
            // Full record: 4 bytes = advanceWidth (u16) + lsb (i16).
            let rec = hmtx_off + glyph_id as usize * 4;
            GlyphMetrics {
                advance_width: read_u16(&self.data, rec),
                left_side_bearing: read_i16(&self.data, rec + 2),
            }
        } else {
            // Use last advance_width, lsb from separate array
            let last_aw_off = hmtx_off + (self.num_h_metrics as usize - 1) * 4;
            let advance_width = read_u16(&self.data, last_aw_off);
            let lsb_array_off = hmtx_off + self.num_h_metrics as usize * 4;
            let idx = glyph_id as usize - self.num_h_metrics as usize;
            let left_side_bearing = read_i16(&self.data, lsb_array_off + idx * 2);
            GlyphMetrics {
                advance_width,
                left_side_bearing,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Load a real system font for the integration-style tests below.
    ///
    /// NOTE(review): paths are Windows-specific; on other systems this returns
    /// `None` and every test panics via its `expect("no test font found")`.
    /// Consider bundling a fixture font for portability — TODO confirm.
    fn load_test_font() -> Option<TtfParser> {
        let paths = [
            "C:/Windows/Fonts/arial.ttf",
            "C:/Windows/Fonts/consola.ttf",
        ];
        for path in &paths {
            if let Ok(data) = std::fs::read(path) {
                if let Ok(parser) = TtfParser::parse(data) {
                    return Some(parser);
                }
            }
        }
        None
    }

    // The tables needed for glyph rendering must survive parsing.
    #[test]
    fn test_parse_loads_tables() {
        let parser = load_test_font().expect("no test font found");
        assert!(parser.tables.contains_key(b"head"));
        assert!(parser.tables.contains_key(b"cmap"));
        assert!(parser.tables.contains_key(b"glyf"));
    }

    #[test]
    fn test_head_values() {
        let parser = load_test_font().expect("no test font found");
        assert!(parser.units_per_em > 0);
        // loca offsets are 16-bit (format 0) or 32-bit (format 1); nothing else is valid.
        assert!(parser.loca_format == 0 || parser.loca_format == 1);
    }

    #[test]
    fn test_hhea_values() {
        let parser = load_test_font().expect("no test font found");
        assert!(parser.ascender > 0);
        assert!(parser.num_h_metrics > 0);
    }

    #[test]
    fn test_maxp_values() {
        let parser = load_test_font().expect("no test font found");
        assert!(parser.num_glyphs > 0);
    }

    // Glyph id 0 is the .notdef glyph, so mapped characters must be > 0.
    #[test]
    fn test_cmap_ascii() {
        let parser = load_test_font().expect("no test font found");
        let glyph_a = parser.glyph_index(0x41); // 'A'
        assert!(glyph_a > 0, "glyph index for 'A' should be > 0");
    }

    #[test]
    fn test_cmap_space() {
        let parser = load_test_font().expect("no test font found");
        let glyph = parser.glyph_index(0x20); // space
        assert!(glyph > 0);
    }

    // An unmapped codepoint must fall back to glyph 0 (.notdef).
    #[test]
    fn test_cmap_unmapped() {
        let parser = load_test_font().expect("no test font found");
        let glyph = parser.glyph_index(0xFFFD0); // unlikely codepoint
        assert_eq!(glyph, 0);
    }

    #[test]
    fn test_glyph_outline_has_contours() {
        let parser = load_test_font().expect("no test font found");
        let gid = parser.glyph_index(0x41); // 'A'
        let outline = parser.glyph_outline(gid);
        assert!(outline.is_some());
        let outline = outline.unwrap();
        assert!(!outline.contours.is_empty(), "A should have contours");
    }

    #[test]
    fn test_glyph_metrics() {
        let parser = load_test_font().expect("no test font found");
        let gid = parser.glyph_index(0x41);
        let metrics = parser.glyph_metrics(gid);
        assert!(metrics.advance_width > 0);
    }

    #[test]
    fn test_space_no_contours() {
        let parser = load_test_font().expect("no test font found");
        let gid = parser.glyph_index(0x20);
        let outline = parser.glyph_outline(gid);
        // Space may have no outline or empty contours
        if let Some(o) = outline {
            assert!(o.contours.is_empty());
        }
    }
}

View File

@@ -1,7 +1,19 @@
use std::collections::HashMap;
use crate::draw_list::DrawList;
use crate::font::FontAtlas;
use crate::layout::LayoutState;
/// Key events the UI system understands.
///
/// Fed into the context with `input_key` and consumed by text-input widgets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Key {
    /// Move the text cursor one position left.
    Left,
    /// Move the text cursor one position right.
    Right,
    /// Delete the character before the cursor.
    Backspace,
    /// Delete the character under the cursor.
    Delete,
    /// Jump the cursor to the start of the buffer.
    Home,
    /// Jump the cursor to the end of the buffer.
    End,
}
pub struct UiContext {
pub hot: Option<u64>,
pub active: Option<u64>,
@@ -12,11 +24,25 @@ pub struct UiContext {
pub mouse_down: bool,
pub mouse_clicked: bool,
pub mouse_released: bool,
pub mouse_scroll: f32,
pub screen_width: f32,
pub screen_height: f32,
pub font: FontAtlas,
id_counter: u64,
prev_mouse_down: bool,
// Text input state
pub focused_id: Option<u32>,
pub cursor_pos: usize,
input_chars: Vec<char>,
input_keys: Vec<Key>,
// Scroll panel state
pub scroll_offsets: HashMap<u32, f32>,
// Drag and drop state
pub dragging: Option<(u32, u64)>,
pub drag_start: (f32, f32),
pub(crate) drag_started: bool,
/// TTF font for high-quality text rendering (None = bitmap fallback).
pub ttf_font: Option<crate::ttf_font::TtfFont>,
}
impl UiContext {
@@ -32,11 +58,21 @@ impl UiContext {
mouse_down: false,
mouse_clicked: false,
mouse_released: false,
mouse_scroll: 0.0,
screen_width: screen_w,
screen_height: screen_h,
font: FontAtlas::generate(),
id_counter: 0,
prev_mouse_down: false,
focused_id: None,
cursor_pos: 0,
input_chars: Vec::new(),
input_keys: Vec::new(),
scroll_offsets: HashMap::new(),
dragging: None,
drag_start: (0.0, 0.0),
drag_started: false,
ttf_font: None,
}
}
@@ -60,9 +96,39 @@ impl UiContext {
self.layout = LayoutState::new(0.0, 0.0);
}
/// Feed a character input event (printable ASCII) for text input widgets.
///
/// Control characters and non-ASCII characters are silently dropped.
pub fn input_char(&mut self, ch: char) {
    // Guard: accept only printable ASCII.
    if !ch.is_ascii() || ch.is_ascii_control() {
        return;
    }
    self.input_chars.push(ch);
}
/// Feed a key input event for text input widgets.
///
/// Events queue up until a focused text input drains them; `end_frame`
/// discards anything unconsumed.
pub fn input_key(&mut self, key: Key) {
    self.input_keys.push(key);
}
/// Set mouse scroll delta for this frame (positive = scroll up).
///
/// Consumed by scroll panels during the frame; `end_frame` resets it to 0.
pub fn set_scroll(&mut self, delta: f32) {
    self.mouse_scroll = delta;
}
/// Drain all pending input chars (consumed by text_input widget).
///
/// Hands the queued characters to the caller and leaves an empty queue behind.
pub(crate) fn drain_chars(&mut self) -> Vec<char> {
    std::mem::replace(&mut self.input_chars, Vec::new())
}
/// Drain all pending key events (consumed by text_input widget).
///
/// Hands the queued key events to the caller and leaves an empty queue behind.
pub(crate) fn drain_keys(&mut self) -> Vec<Key> {
    std::mem::replace(&mut self.input_keys, Vec::new())
}
/// End the current frame.
///
/// Resets per-frame input state: the scroll delta and any character/key
/// events that no widget consumed this frame.
pub fn end_frame(&mut self) {
    // Scroll is a per-frame delta; do not let it persist into the next frame.
    self.mouse_scroll = 0.0;
    // Clear any unconsumed input
    self.input_chars.clear();
    self.input_keys.clear();
}
/// Generate a new unique ID for this frame.
@@ -78,4 +144,41 @@ impl UiContext {
&& self.mouse_y >= y
&& self.mouse_y < y + h
}
/// Draw text using TTF font if available, otherwise bitmap font fallback.
///
/// The TTF path positions each glyph from its bearings relative to the
/// ascender baseline; the bitmap path lays glyphs out on a fixed grid.
pub fn draw_text(&mut self, text: &str, x: f32, y: f32, color: [u8; 4]) {
    match self.ttf_font {
        Some(ref mut ttf) => {
            let baseline = ttf.ascender;
            let mut pen_x = x;
            for ch in text.chars() {
                // Clone ends the borrow of `ttf` so `draw_list` can be used below.
                let info = ttf.glyph(ch).clone();
                // Zero-size glyphs (e.g. space) only advance the pen.
                if info.width > 0.0 && info.height > 0.0 {
                    let glyph_x = pen_x + info.bearing_x;
                    let glyph_y = y + baseline - info.bearing_y;
                    let (u0, v0, u1, v1) = (info.uv[0], info.uv[1], info.uv[2], info.uv[3]);
                    self.draw_list
                        .add_rect_uv(glyph_x, glyph_y, info.width, info.height, u0, v0, u1, v1, color);
                }
                pen_x += info.advance;
            }
        }
        None => {
            // Bitmap font fallback: fixed-size grid glyphs.
            let cell_w = self.font.glyph_width as f32;
            let cell_h = self.font.glyph_height as f32;
            let mut pen_x = x;
            for ch in text.chars() {
                let (u0, v0, u1, v1) = self.font.glyph_uv(ch);
                self.draw_list.add_rect_uv(pen_x, y, cell_w, cell_h, u0, v0, u1, v1, color);
                pen_x += cell_w;
            }
        }
    }
}
/// Calculate text width using TTF or bitmap font.
///
/// Bug fix: the bitmap fallback previously multiplied by `text.len()`, which
/// is a *byte* count — non-ASCII text over-reported its width. It now counts
/// `char`s, matching `draw_text`, which advances the pen once per `char`.
pub fn ttf_text_width(&mut self, text: &str) -> f32 {
    if let Some(ref mut ttf) = self.ttf_font {
        ttf.text_width(text)
    } else {
        text.chars().count() as f32 * self.font.glyph_width as f32
    }
}
}

View File

@@ -0,0 +1,182 @@
use bytemuck::{Pod, Zeroable};
// GPU uniform describing the destination rect; #[repr(C)] so the byte layout
// matches the WGSL `RectUniform` struct exactly. Do not reorder fields.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct RectUniform {
    rect: [f32; 4],    // x, y, w, h in pixels
    screen: [f32; 2],  // target size in pixels
    _pad: [f32; 2],    // pad to 32 bytes for uniform-buffer alignment
}
/// Blits an off-screen viewport texture into a rectangle of the final
/// surface. Holds the long-lived GPU objects; the bind group itself is
/// rebuilt per frame in `render` since the source view may change.
pub struct ViewportRenderer {
    pipeline: wgpu::RenderPipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    sampler: wgpu::Sampler,
    uniform_buffer: wgpu::Buffer,
}
impl ViewportRenderer {
    /// Build the blit pipeline: a buffer-less quad shader that samples the
    /// off-screen viewport texture and writes it into a rect of the target.
    pub fn new(device: &wgpu::Device, surface_format: wgpu::TextureFormat) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Viewport Blit Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("viewport_shader.wgsl").into()),
        });
        // Bindings: 0 = rect uniform (vertex stage), 1 = viewport color
        // texture, 2 = linear sampler (both fragment stage).
        let bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Viewport Blit BGL"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::VERTEX,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Texture {
                            multisampled: false,
                            view_dimension: wgpu::TextureViewDimension::D2,
                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                        count: None,
                    },
                ],
            });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Viewport Blit PL"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });
        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: Some("Viewport Blit Pipeline"),
            layout: Some(&pipeline_layout),
            vertex: wgpu::VertexState {
                module: &shader,
                entry_point: Some("vs_main"),
                // No vertex buffers: vs_main derives the quad from vertex_index.
                buffers: &[],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &shader,
                entry_point: Some("fs_main"),
                targets: &[Some(wgpu::ColorTargetState {
                    format: surface_format,
                    // Opaque copy: the blit fully overwrites the rect.
                    blend: Some(wgpu::BlendState::REPLACE),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleList,
                ..Default::default()
            },
            depth_stencil: None,
            multisample: wgpu::MultisampleState::default(),
            multiview_mask: None,
            cache: None,
        });
        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
            label: Some("Viewport Sampler"),
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            ..Default::default()
        });
        let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Viewport Rect Uniform"),
            size: std::mem::size_of::<RectUniform>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        ViewportRenderer {
            pipeline,
            bind_group_layout,
            sampler,
            uniform_buffer,
        }
    }

    /// Blit `viewport_color_view` into the pixel-space rect
    /// (rect_x, rect_y, rect_w, rect_h) of `target_view`.
    ///
    /// Uses LoadOp::Load so UI already drawn around the rect survives, and a
    /// scissor rect so nothing outside the requested area is touched.
    ///
    /// NOTE(review): scissor origin is clamped to >= 0 but width/height are
    /// not clamped to the target size — confirm callers never pass a rect
    /// extending past the target (wgpu validation would reject it).
    pub fn render(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        encoder: &mut wgpu::CommandEncoder,
        target_view: &wgpu::TextureView,
        viewport_color_view: &wgpu::TextureView,
        screen_w: f32,
        screen_h: f32,
        rect_x: f32,
        rect_y: f32,
        rect_w: f32,
        rect_h: f32,
    ) {
        let uniform = RectUniform {
            rect: [rect_x, rect_y, rect_w, rect_h],
            screen: [screen_w, screen_h],
            _pad: [0.0; 2],
        };
        queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[uniform]));
        // Recreated each call because the source texture view can change
        // (e.g. the viewport texture is recreated on resize).
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Viewport Blit BG"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: self.uniform_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::TextureView(viewport_color_view),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: wgpu::BindingResource::Sampler(&self.sampler),
                },
            ],
        });
        let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
            label: Some("Viewport Blit Pass"),
            color_attachments: &[Some(wgpu::RenderPassColorAttachment {
                view: target_view,
                resolve_target: None,
                depth_slice: None,
                ops: wgpu::Operations {
                    // Keep whatever the UI pass already drew outside the rect.
                    load: wgpu::LoadOp::Load,
                    store: wgpu::StoreOp::Store,
                },
            })],
            depth_stencil_attachment: None,
            occlusion_query_set: None,
            timestamp_writes: None,
            multiview_mask: None,
        });
        rpass.set_scissor_rect(
            rect_x.max(0.0) as u32,
            rect_y.max(0.0) as u32,
            rect_w.ceil().max(1.0) as u32,
            rect_h.ceil().max(1.0) as u32,
        );
        rpass.set_pipeline(&self.pipeline);
        rpass.set_bind_group(0, &bind_group, &[]);
        // 6 vertices = two triangles forming the quad.
        rpass.draw(0..6, 0..1);
    }
}

View File

@@ -0,0 +1,42 @@
// Destination rect + target size in pixels; byte layout must match the
// Rust-side #[repr(C)] RectUniform exactly.
struct RectUniform {
    rect: vec4<f32>,   // x, y, w, h in pixels
    screen: vec2<f32>, // target size in pixels
    _pad: vec2<f32>,   // pad to 32 bytes
};
@group(0) @binding(0) var<uniform> u: RectUniform;
@group(0) @binding(1) var t_viewport: texture_2d<f32>;
@group(0) @binding(2) var s_viewport: sampler;

// Per-vertex output: clip-space position plus the [0,1] quad UV.
struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) uv: vec2<f32>,
};
@vertex
fn vs_main(@builtin(vertex_index) idx: u32) -> VertexOutput {
    // Two-triangle quad over the unit square, emitted without a vertex buffer.
    var corners = array<vec2<f32>, 6>(
        vec2<f32>(0.0, 0.0),
        vec2<f32>(1.0, 0.0),
        vec2<f32>(1.0, 1.0),
        vec2<f32>(0.0, 0.0),
        vec2<f32>(1.0, 1.0),
        vec2<f32>(0.0, 1.0),
    );
    let corner = corners[idx];
    // Scale the unit corner into the pixel-space rect, then map to NDC.
    // Y flips: pixel origin is top-left, NDC origin is bottom-left.
    let pixel = u.rect.xy + corner * u.rect.zw;
    let ndc_x = (pixel.x / u.screen.x) * 2.0 - 1.0;
    let ndc_y = 1.0 - (pixel.y / u.screen.y) * 2.0;
    var out: VertexOutput;
    out.position = vec4<f32>(ndc_x, ndc_y, 0.0, 1.0);
    out.uv = corner;
    return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    // Straight copy of the off-screen viewport texture, bilinear-filtered.
    return textureSample(t_viewport, s_viewport, in.uv);
}

View File

@@ -0,0 +1,69 @@
/// Format of the off-screen viewport color target (also sampled by the blit pass).
pub const VIEWPORT_COLOR_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8Unorm;
/// Format of the matching depth attachment.
pub const VIEWPORT_DEPTH_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Depth32Float;
/// Off-screen render target for the 3D viewport: a color texture (rendered
/// to, then sampled by the blit pass) and a matching depth texture.
pub struct ViewportTexture {
    pub color_texture: wgpu::Texture,
    pub color_view: wgpu::TextureView,
    pub depth_texture: wgpu::Texture,
    pub depth_view: wgpu::TextureView,
    // Current size in pixels; both textures always share it.
    pub width: u32,
    pub height: u32,
}
impl ViewportTexture {
    /// Create color + depth targets of the given size (clamped to >= 1 so a
    /// zero-sized window never produces an invalid texture).
    pub fn new(device: &wgpu::Device, width: u32, height: u32) -> Self {
        let w = width.max(1);
        let h = height.max(1);
        // Color target is also sampled by the blit pass, hence TEXTURE_BINDING.
        let (color_texture, color_view) = Self::create_target(
            device,
            "Viewport Color",
            VIEWPORT_COLOR_FORMAT,
            wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
            w,
            h,
        );
        let (depth_texture, depth_view) = Self::create_target(
            device,
            "Viewport Depth",
            VIEWPORT_DEPTH_FORMAT,
            wgpu::TextureUsages::RENDER_ATTACHMENT,
            w,
            h,
        );
        ViewportTexture {
            color_texture,
            color_view,
            depth_texture,
            depth_view,
            width: w,
            height: h,
        }
    }

    /// Create one 2D texture and its default view (shared by color and depth;
    /// the two previously duplicated this descriptor verbatim).
    fn create_target(
        device: &wgpu::Device,
        label: &str,
        format: wgpu::TextureFormat,
        usage: wgpu::TextureUsages,
        w: u32,
        h: u32,
    ) -> (wgpu::Texture, wgpu::TextureView) {
        let texture = device.create_texture(&wgpu::TextureDescriptor {
            label: Some(label),
            size: wgpu::Extent3d {
                width: w,
                height: h,
                depth_or_array_layers: 1,
            },
            mip_level_count: 1,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            format,
            usage,
            view_formats: &[],
        });
        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
        (texture, view)
    }

    /// Recreate both textures if the requested size differs from the current
    /// one. Returns true when a recreation happened (callers must rebuild any
    /// bind groups referencing the old views).
    pub fn ensure_size(&mut self, device: &wgpu::Device, width: u32, height: u32) -> bool {
        let w = width.max(1);
        let h = height.max(1);
        if w == self.width && h == self.height {
            return false;
        }
        *self = Self::new(device, w, h);
        true
    }
}

View File

@@ -1,4 +1,4 @@
use crate::ui_context::UiContext;
use crate::ui_context::{Key, UiContext};
// Color palette
const COLOR_BG: [u8; 4] = [0x2B, 0x2B, 0x2B, 0xFF];
@@ -11,6 +11,13 @@ const COLOR_SLIDER_BG: [u8; 4] = [0x44, 0x44, 0x44, 0xFF];
const COLOR_SLIDER_HANDLE: [u8; 4] = [0x88, 0x88, 0xFF, 0xFF];
const COLOR_CHECK_BG: [u8; 4] = [0x44, 0x44, 0x44, 0xFF];
const COLOR_CHECK_MARK: [u8; 4] = [0x88, 0xFF, 0x88, 0xFF];
// Text-input widget colors.
const COLOR_INPUT_BG: [u8; 4] = [0x22, 0x22, 0x22, 0xFF];
const COLOR_INPUT_BORDER: [u8; 4] = [0x66, 0x66, 0x66, 0xFF];
const COLOR_INPUT_FOCUSED: [u8; 4] = [0x44, 0x88, 0xFF, 0xFF];
const COLOR_CURSOR: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF];
// Scroll panel colors.
const COLOR_SCROLLBAR_BG: [u8; 4] = [0x33, 0x33, 0x33, 0xFF];
// Formatting fix: missing space after the type annotation.
const COLOR_SCROLLBAR_THUMB: [u8; 4] = [0x66, 0x66, 0x77, 0xFF];
// Minimum mouse travel (pixels) before a press is promoted to a drag.
const DRAG_THRESHOLD: f32 = 5.0;
impl UiContext {
/// Draw text at the current cursor position and advance to the next line.
@@ -233,11 +240,231 @@ impl UiContext {
/// Close the current panel.
///
/// NOTE(review): presumably pairs with a `begin_panel` defined earlier in
/// this file — confirm. Currently a no-op.
pub fn end_panel(&mut self) {
    // Nothing for now; future could restore outer cursor state.
}
// ── Text Input Widget ─────────────────────────────────────────────
/// Draw an editable single-line text input. Returns true if the buffer changed.
///
/// `id` must be unique per text input. The widget renders a box at (x, y) with
/// the given `width`. Height is determined by the font glyph height + padding.
///
/// NOTE(review): cursor positions and insert/remove below use *byte* indices.
/// This is safe while `input_char` only admits ASCII, but a caller-supplied
/// `buffer` containing multi-byte chars could panic on a non-char-boundary
/// `remove` — confirm buffers are ASCII-only.
pub fn text_input(&mut self, id: u32, buffer: &mut String, x: f32, y: f32, width: f32) -> bool {
    let gw = self.font.glyph_width as f32;
    let gh = self.font.glyph_height as f32;
    let padding = self.layout.padding;
    let height = gh + padding * 2.0;
    let hovered = self.mouse_in_rect(x, y, width, height);
    // Click to focus / unfocus
    if self.mouse_clicked {
        if hovered {
            self.focused_id = Some(id);
            // Place cursor at end or at click position
            let click_offset = ((self.mouse_x - x - padding) / gw).round() as usize;
            self.cursor_pos = click_offset.min(buffer.len());
        } else if self.focused_id == Some(id) {
            // Clicking anywhere else drops focus from this input only.
            self.focused_id = None;
        }
    }
    let mut changed = false;
    // Process input only if focused
    if self.focused_id == Some(id) {
        // Ensure cursor_pos is valid (buffer may have shrunk since last frame).
        if self.cursor_pos > buffer.len() {
            self.cursor_pos = buffer.len();
        }
        // Process character input
        let chars = self.drain_chars();
        for ch in chars {
            buffer.insert(self.cursor_pos, ch);
            self.cursor_pos += 1;
            changed = true;
        }
        // Process key input
        let keys = self.drain_keys();
        for key in keys {
            match key {
                Key::Backspace => {
                    if self.cursor_pos > 0 {
                        buffer.remove(self.cursor_pos - 1);
                        self.cursor_pos -= 1;
                        changed = true;
                    }
                }
                Key::Delete => {
                    if self.cursor_pos < buffer.len() {
                        buffer.remove(self.cursor_pos);
                        changed = true;
                    }
                }
                Key::Left => {
                    if self.cursor_pos > 0 {
                        self.cursor_pos -= 1;
                    }
                }
                Key::Right => {
                    if self.cursor_pos < buffer.len() {
                        self.cursor_pos += 1;
                    }
                }
                Key::Home => {
                    self.cursor_pos = 0;
                }
                Key::End => {
                    self.cursor_pos = buffer.len();
                }
            }
        }
    }
    // Draw border
    let border_color = if self.focused_id == Some(id) {
        COLOR_INPUT_FOCUSED
    } else {
        COLOR_INPUT_BORDER
    };
    self.draw_list.add_rect(x, y, width, height, border_color);
    // Draw inner background (1px border)
    self.draw_list.add_rect(x + 1.0, y + 1.0, width - 2.0, height - 2.0, COLOR_INPUT_BG);
    // Draw text
    let text_x = x + padding;
    let text_y = y + padding;
    let mut cx = text_x;
    for ch in buffer.chars() {
        let (u0, v0, u1, v1) = self.font.glyph_uv(ch);
        self.draw_list.add_rect_uv(cx, text_y, gw, gh, u0, v0, u1, v1, COLOR_TEXT);
        cx += gw;
    }
    // Draw cursor if focused
    if self.focused_id == Some(id) {
        let cursor_x = text_x + self.cursor_pos as f32 * gw;
        self.draw_list.add_rect(cursor_x, text_y, 1.0, gh, COLOR_CURSOR);
    }
    self.layout.advance_line();
    changed
}
// ── Scroll Panel ──────────────────────────────────────────────────
/// Begin a scrollable panel. Content drawn between begin/end will be clipped
/// to the panel bounds. `content_height` is the total height of the content
/// inside the panel (used to compute scrollbar size).
///
/// Draw order matters: background, track and thumb are drawn *before* the
/// scissor push so the panel chrome itself is never clipped.
pub fn begin_scroll_panel(&mut self, id: u32, x: f32, y: f32, w: f32, h: f32, content_height: f32) {
    let scrollbar_w = 12.0_f32;
    let panel_inner_w = w - scrollbar_w;
    // Handle mouse wheel when hovering over the panel
    let hovered = self.mouse_in_rect(x, y, w, h);
    // Wheel-up (positive delta) decreases the offset, i.e. scrolls content up.
    let scroll_delta = if hovered && self.mouse_scroll.abs() > 0.0 {
        -self.mouse_scroll * 20.0
    } else {
        0.0
    };
    // Get or create scroll offset, apply delta and clamp
    let scroll = self.scroll_offsets.entry(id).or_insert(0.0);
    *scroll += scroll_delta;
    let max_scroll = (content_height - h).max(0.0);
    *scroll = scroll.clamp(0.0, max_scroll);
    let current_scroll = *scroll;
    // Draw panel background
    self.draw_list.add_rect(x, y, w, h, COLOR_PANEL);
    // Draw scrollbar track
    let sb_x = x + panel_inner_w;
    self.draw_list.add_rect(sb_x, y, scrollbar_w, h, COLOR_SCROLLBAR_BG);
    // Draw scrollbar thumb
    if content_height > h {
        let thumb_ratio = h / content_height;
        // Floor of 16px keeps the thumb grabbable for very tall content.
        let thumb_h = (thumb_ratio * h).max(16.0);
        let scroll_ratio = if max_scroll > 0.0 { current_scroll / max_scroll } else { 0.0 };
        let thumb_y = y + scroll_ratio * (h - thumb_h);
        self.draw_list.add_rect(sb_x, thumb_y, scrollbar_w, thumb_h, COLOR_SCROLLBAR_THUMB);
    }
    // Push scissor rect for content clipping
    self.draw_list.push_scissor(x as u32, y as u32, panel_inner_w as u32, h as u32);
    // Set cursor inside panel, offset by scroll
    self.layout = crate::layout::LayoutState::new(x + self.layout.padding, y + self.layout.padding - current_scroll);
}
/// End a scrollable panel. Pops the scissor rect.
///
/// Must be paired with every `begin_scroll_panel`; content drawn afterwards
/// is no longer clipped.
pub fn end_scroll_panel(&mut self) {
    self.draw_list.pop_scissor();
}
// ── Drag and Drop ─────────────────────────────────────────────────
/// Begin dragging an item. Call this when the user presses down on a draggable element.
/// `id` identifies the source, `payload` is an arbitrary u64 value transferred on drop.
///
/// Only records the press; the drag is not considered "started" until the
/// mouse travels past DRAG_THRESHOLD (promotion happens in `end_drag`).
pub fn begin_drag(&mut self, id: u32, payload: u64) {
    if self.mouse_clicked {
        self.dragging = Some((id, payload));
        self.drag_start = (self.mouse_x, self.mouse_y);
        self.drag_started = false;
    }
}
/// Returns true if a drag operation is currently in progress (past the threshold).
///
/// Idiom fix: `if let Some(_) = … { x } else { false }` collapsed to a plain
/// boolean expression — a drag only counts once it crossed the threshold.
pub fn is_dragging(&self) -> bool {
    self.dragging.is_some() && self.drag_started
}
/// End the current drag operation. Returns `Some((source_id, payload))` if a drag
/// was in progress and the mouse was released, otherwise `None`.
///
/// Also promotes a pending press to an active drag once the mouse has moved
/// at least `DRAG_THRESHOLD` pixels from the press position, so this should
/// be called once per frame even while the button is still held.
pub fn end_drag(&mut self) -> Option<(u32, u64)> {
    // Update drag started state based on threshold.
    // (`is_some()` replaces the old `if let Some(_) = …`; the comparison is
    // done on squared distances, which is equivalent for a non-negative
    // threshold and avoids the sqrt.)
    if self.dragging.is_some() && !self.drag_started {
        let dx = self.mouse_x - self.drag_start.0;
        let dy = self.mouse_y - self.drag_start.1;
        if dx * dx + dy * dy >= DRAG_THRESHOLD * DRAG_THRESHOLD {
            self.drag_started = true;
        }
    }
    if self.mouse_released {
        // Only a drag that actually started yields a payload; a plain click
        // (below threshold) is cancelled silently.
        let result = if self.drag_started { self.dragging } else { None };
        self.dragging = None;
        self.drag_started = false;
        result
    } else {
        None
    }
}
/// Declare a drop target region. If a drag is released over this target,
/// returns the payload that was dropped. Otherwise returns `None`.
pub fn drop_target(&mut self, _id: u32, x: f32, y: f32, w: f32, h: f32) -> Option<u64> {
    // A drop happens only when an active (past-threshold) drag is released
    // while the cursor is inside this rect.
    let released_here =
        self.mouse_released && self.drag_started && self.mouse_in_rect(x, y, w, h);
    if released_here {
        self.dragging.map(|(_src_id, payload)| payload)
    } else {
        None
    }
}
}
#[cfg(test)]
mod tests {
use crate::ui_context::UiContext;
use crate::ui_context::{Key, UiContext};
#[test]
fn test_button_returns_false_when_not_clicked() {
@@ -278,4 +505,241 @@ mod tests {
let v2 = ctx.slider("test", -10.0, 0.0, 100.0);
assert!((v2 - 0.0).abs() < 1e-6, "slider should clamp to min: got {}", v2);
}
// ── Text Input Tests ──────────────────────────────────────────────
// Each test drives the widget through whole frames: frame 1 clicks to focus,
// later frames feed chars/keys and assert the buffer + cursor state.
#[test]
fn test_text_input_basic_typing() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::new();
    // Click on the text input to focus it (at x=10, y=10, width=200)
    ctx.begin_frame(15.0, 15.0, true);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    // Now type some characters
    ctx.begin_frame(15.0, 15.0, false);
    ctx.input_char('H');
    ctx.input_char('i');
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(changed);
    assert_eq!(buf, "Hi");
}
#[test]
fn test_text_input_backspace() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::from("abc");
    // Focus — click far right so cursor goes to end (padding=4, gw=8, 3 chars → need x > 10+4+24=38)
    ctx.begin_frame(50.0, 15.0, true);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    // Backspace
    ctx.begin_frame(50.0, 15.0, false);
    ctx.input_key(Key::Backspace);
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(changed);
    assert_eq!(buf, "ab");
}
#[test]
fn test_text_input_cursor_movement() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::from("abc");
    // Focus — click far right so cursor goes to end
    ctx.begin_frame(50.0, 15.0, true);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert_eq!(ctx.cursor_pos, 3); // cursor at end of "abc"
    // Move cursor to beginning with Home
    ctx.begin_frame(15.0, 15.0, false);
    ctx.input_key(Key::Home);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert_eq!(ctx.cursor_pos, 0);
    // Type 'X' at beginning
    ctx.begin_frame(15.0, 15.0, false);
    ctx.input_char('X');
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(changed);
    assert_eq!(buf, "Xabc");
    assert_eq!(ctx.cursor_pos, 1);
}
#[test]
fn test_text_input_delete_key() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::from("abc");
    // Focus
    ctx.begin_frame(15.0, 15.0, true);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    // Move to Home, then Delete — Delete removes the char under the cursor.
    ctx.begin_frame(15.0, 15.0, false);
    ctx.input_key(Key::Home);
    ctx.input_key(Key::Delete);
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(changed);
    assert_eq!(buf, "bc");
}
#[test]
fn test_text_input_arrow_keys() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::from("hello");
    // Focus — click far right so cursor at end
    ctx.begin_frame(100.0, 15.0, true);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    // Left twice from end (pos 5→3)
    ctx.begin_frame(100.0, 15.0, false);
    ctx.input_key(Key::Left);
    ctx.input_key(Key::Left);
    ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert_eq!(ctx.cursor_pos, 3);
    // Type 'X' at position 3
    ctx.begin_frame(100.0, 15.0, false);
    ctx.input_char('X');
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(changed);
    assert_eq!(buf, "helXlo");
}
#[test]
fn test_text_input_no_change_when_not_focused() {
    let mut ctx = UiContext::new(800.0, 600.0);
    let mut buf = String::from("test");
    // Don't click on the input — mouse at (500, 500) far away
    ctx.begin_frame(500.0, 500.0, true);
    ctx.input_char('X');
    let changed = ctx.text_input(1, &mut buf, 10.0, 10.0, 200.0);
    assert!(!changed);
    assert_eq!(buf, "test");
}
// ── Scroll Panel Tests ────────────────────────────────────────────
#[test]
fn test_scroll_offset_clamping() {
    let mut ctx = UiContext::new(800.0, 600.0);
    // Panel at (0,0), 200x100, content_height=300
    // Scroll down a lot
    ctx.begin_frame(100.0, 50.0, false);
    ctx.set_scroll(-100.0); // scroll down
    ctx.begin_scroll_panel(1, 0.0, 0.0, 200.0, 100.0, 300.0);
    ctx.end_scroll_panel();
    // max_scroll = 300 - 100 = 200; scroll should be clamped
    let scroll = ctx.scroll_offsets.get(&1).copied().unwrap_or(0.0);
    assert!(scroll >= 0.0 && scroll <= 200.0, "scroll={}", scroll);
}
#[test]
fn test_scroll_offset_does_not_go_negative() {
    let mut ctx = UiContext::new(800.0, 600.0);
    // Scroll up when already at top
    ctx.begin_frame(100.0, 50.0, false);
    ctx.set_scroll(100.0); // scroll up
    ctx.begin_scroll_panel(1, 0.0, 0.0, 200.0, 100.0, 300.0);
    ctx.end_scroll_panel();
    // Offset is clamped at the top edge rather than going negative.
    let scroll = ctx.scroll_offsets.get(&1).copied().unwrap_or(0.0);
    assert!((scroll - 0.0).abs() < 1e-6, "scroll should be 0, got {}", scroll);
}
#[test]
fn test_scroll_panel_content_clipping() {
    let mut ctx = UiContext::new(800.0, 600.0);
    ctx.begin_frame(100.0, 50.0, false);
    ctx.begin_scroll_panel(1, 10.0, 20.0, 200.0, 100.0, 300.0);
    // Draw some content inside
    ctx.text("Inside scroll");
    // Commands drawn inside should have a scissor rect
    let has_scissor = ctx.draw_list.commands.iter().any(|c| c.scissor.is_some());
    assert!(has_scissor, "commands inside scroll panel should have scissor rects");
    ctx.end_scroll_panel();
    // Commands drawn after end_scroll_panel should NOT have scissor
    let cmds_before = ctx.draw_list.commands.len();
    ctx.text("Outside scroll");
    let new_cmds = &ctx.draw_list.commands[cmds_before..];
    let has_scissor_after = new_cmds.iter().any(|c| c.scissor.is_some());
    assert!(!has_scissor_after, "commands after end_scroll_panel should not have scissor");
}
// ── Drag and Drop Tests ──────────────────────────────────────────
// Note: end_drag is called every frame — it both promotes a press to an
// active drag once past the threshold and resolves the release.
#[test]
fn test_drag_start_and_end() {
    let mut ctx = UiContext::new(800.0, 600.0);
    // Frame 1: mouse down — begin drag
    ctx.begin_frame(100.0, 100.0, true);
    ctx.begin_drag(1, 42);
    assert!(!ctx.is_dragging(), "should not be dragging yet (below threshold)");
    let _ = ctx.end_drag();
    // Frame 2: mouse moved past threshold, still down
    ctx.begin_frame(110.0, 100.0, true);
    let _ = ctx.end_drag();
    assert!(ctx.is_dragging(), "should be dragging after moving past threshold");
    // Frame 3: mouse released
    ctx.begin_frame(120.0, 100.0, false);
    let result = ctx.end_drag();
    assert!(result.is_some());
    let (src_id, payload) = result.unwrap();
    assert_eq!(src_id, 1);
    assert_eq!(payload, 42);
}
#[test]
fn test_drop_on_target() {
    let mut ctx = UiContext::new(800.0, 600.0);
    // Frame 1: begin drag
    ctx.begin_frame(100.0, 100.0, true);
    ctx.begin_drag(1, 99);
    let _ = ctx.end_drag();
    // Frame 2: move past threshold
    ctx.begin_frame(110.0, 100.0, true);
    let _ = ctx.end_drag();
    // Frame 3: release over drop target at (200, 200, 50, 50)
    ctx.begin_frame(220.0, 220.0, false);
    let drop_result = ctx.drop_target(2, 200.0, 200.0, 50.0, 50.0);
    assert_eq!(drop_result, Some(99));
    let _ = ctx.end_drag();
}
#[test]
fn test_drop_outside_target() {
    let mut ctx = UiContext::new(800.0, 600.0);
    // Frame 1: begin drag
    ctx.begin_frame(100.0, 100.0, true);
    ctx.begin_drag(1, 77);
    let _ = ctx.end_drag();
    // Frame 2: move past threshold
    ctx.begin_frame(110.0, 100.0, true);
    let _ = ctx.end_drag();
    // Frame 3: release far from drop target
    ctx.begin_frame(500.0, 500.0, false);
    let drop_result = ctx.drop_target(2, 200.0, 200.0, 50.0, 50.0);
    assert_eq!(drop_result, None);
}
}

View File

@@ -174,6 +174,64 @@ impl Mat4 {
)
}
/// Compute the inverse of this matrix. Returns `None` if the matrix is singular.
///
/// Uses the 2x2 sub-determinant expansion: six determinants taken from the
/// top two rows (s0..s5) and six from the bottom two rows (c0..c5) combine
/// into both the full determinant and every cofactor of the adjugate, so the
/// whole inverse costs a small fixed number of multiplies with no pivoting.
pub fn inverse(&self) -> Option<Self> {
    let m = &self.cols;
    // Flatten to row-major for cofactor expansion
    // m[col][row] — so element (row, col) = m[col][row]
    let e = |r: usize, c: usize| -> f32 { m[c][r] };
    // Compute cofactors using 2x2 determinants
    // s* come from rows 0-1, c* from rows 2-3; the pairing below follows the
    // standard Laplace expansion along the first two rows.
    let s0 = e(0,0) * e(1,1) - e(1,0) * e(0,1);
    let s1 = e(0,0) * e(1,2) - e(1,0) * e(0,2);
    let s2 = e(0,0) * e(1,3) - e(1,0) * e(0,3);
    let s3 = e(0,1) * e(1,2) - e(1,1) * e(0,2);
    let s4 = e(0,1) * e(1,3) - e(1,1) * e(0,3);
    let s5 = e(0,2) * e(1,3) - e(1,2) * e(0,3);
    let c5 = e(2,2) * e(3,3) - e(3,2) * e(2,3);
    let c4 = e(2,1) * e(3,3) - e(3,1) * e(2,3);
    let c3 = e(2,1) * e(3,2) - e(3,1) * e(2,2);
    let c2 = e(2,0) * e(3,3) - e(3,0) * e(2,3);
    let c1 = e(2,0) * e(3,2) - e(3,0) * e(2,2);
    let c0 = e(2,0) * e(3,1) - e(3,0) * e(2,1);
    let det = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0;
    // Near-zero determinant is treated as singular to avoid inf/NaN results.
    if det.abs() < 1e-12 {
        return None;
    }
    let inv_det = 1.0 / det;
    // Adjugate matrix (transposed cofactor matrix), stored column-major
    let inv = Self::from_cols(
        [
            ( e(1,1) * c5 - e(1,2) * c4 + e(1,3) * c3) * inv_det,
            (-e(0,1) * c5 + e(0,2) * c4 - e(0,3) * c3) * inv_det,
            ( e(3,1) * s5 - e(3,2) * s4 + e(3,3) * s3) * inv_det,
            (-e(2,1) * s5 + e(2,2) * s4 - e(2,3) * s3) * inv_det,
        ],
        [
            (-e(1,0) * c5 + e(1,2) * c2 - e(1,3) * c1) * inv_det,
            ( e(0,0) * c5 - e(0,2) * c2 + e(0,3) * c1) * inv_det,
            (-e(3,0) * s5 + e(3,2) * s2 - e(3,3) * s1) * inv_det,
            ( e(2,0) * s5 - e(2,2) * s2 + e(2,3) * s1) * inv_det,
        ],
        [
            ( e(1,0) * c4 - e(1,1) * c2 + e(1,3) * c0) * inv_det,
            (-e(0,0) * c4 + e(0,1) * c2 - e(0,3) * c0) * inv_det,
            ( e(3,0) * s4 - e(3,1) * s2 + e(3,3) * s0) * inv_det,
            (-e(2,0) * s4 + e(2,1) * s2 - e(2,3) * s0) * inv_det,
        ],
        [
            (-e(1,0) * c3 + e(1,1) * c1 - e(1,2) * c0) * inv_det,
            ( e(0,0) * c3 - e(0,1) * c1 + e(0,2) * c0) * inv_det,
            (-e(3,0) * s3 + e(3,1) * s1 - e(3,2) * s0) * inv_det,
            ( e(2,0) * s3 - e(2,1) * s1 + e(2,2) * s0) * inv_det,
        ],
    );
    Some(inv)
}
/// Return the transpose of this matrix.
pub fn transpose(&self) -> Self {
let c = &self.cols;

View File

@@ -0,0 +1,140 @@
/// Simple XOR cipher with rotating key + sequence counter.
///
/// NOTE(review): XOR with a counter-derived pad provides obfuscation and
/// replay ordering, not confidentiality or integrity — there is no MAC, so
/// ciphertext bits can be flipped undetected. Confirm this is acceptable for
/// the threat model before shipping.
pub struct PacketCipher {
    key: Vec<u8>,
    send_counter: u64,
    recv_counter: u64,
}
impl PacketCipher {
    /// Create a cipher with fresh send/recv counters.
    pub fn new(key: &[u8]) -> Self {
        assert!(!key.is_empty(), "encryption key must not be empty");
        PacketCipher {
            key: key.to_vec(),
            send_counter: 0,
            recv_counter: 0,
        }
    }
    /// Encrypt data in-place. Prepends 8-byte sequence number.
    pub fn encrypt(&mut self, plaintext: &[u8]) -> Vec<u8> {
        // Capture and advance the sequence number for this packet.
        let seq = self.send_counter;
        self.send_counter += 1;
        let pad = self.derive_key(seq);
        let mut packet = Vec::with_capacity(8 + plaintext.len());
        packet.extend_from_slice(&seq.to_le_bytes());
        packet.extend(
            plaintext
                .iter()
                .enumerate()
                .map(|(i, &byte)| byte ^ pad[i % pad.len()]),
        );
        packet
    }
    /// Decrypt data. Validates sequence number.
    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, String> {
        // Header carries the little-endian sequence number.
        let header = ciphertext
            .get(0..8)
            .ok_or_else(|| "packet too short".to_string())?;
        let seq = u64::from_le_bytes(header.try_into().unwrap());
        // Anti-replay: sequence must be >= expected
        if seq < self.recv_counter {
            return Err(format!("replay detected: got seq {}, expected >= {}", seq, self.recv_counter));
        }
        self.recv_counter = seq + 1;
        let pad = self.derive_key(seq);
        let plaintext = ciphertext[8..]
            .iter()
            .enumerate()
            .map(|(i, &byte)| byte ^ pad[i % pad.len()])
            .collect();
        Ok(plaintext)
    }
    /// Derive a key from the base key + counter.
    fn derive_key(&self, counter: u64) -> Vec<u8> {
        let salt = counter.to_le_bytes();
        let mut pad = Vec::with_capacity(self.key.len());
        for (i, &k) in self.key.iter().enumerate() {
            pad.push(k.wrapping_add(salt[i % 8]));
        }
        pad
    }
}
/// Simple token-based authentication.
///
/// The token is `player_id(4 LE) + expires_at(8 LE f64)` XOR-masked with the
/// shared secret; validation re-derives it and compares.
pub struct AuthToken {
    pub player_id: u32,
    pub token: Vec<u8>,
    pub expires_at: f64, // timestamp
}

impl AuthToken {
    /// Generate a simple auth token from player_id + secret.
    pub fn generate(player_id: u32, secret: &[u8], expires_at: f64) -> Self {
        // Token body: player id followed by the expiry timestamp, both LE.
        let mut token: Vec<u8> = player_id
            .to_le_bytes()
            .iter()
            .chain(expires_at.to_le_bytes().iter())
            .copied()
            .collect();
        // Mask every byte with the repeating secret (HMAC-like, not a real MAC).
        for (i, byte) in token.iter_mut().enumerate() {
            *byte ^= secret[i % secret.len()];
        }
        AuthToken { player_id, token, expires_at }
    }

    /// Validate token against secret: must not be expired, and must equal the
    /// token re-derived from the same inputs.
    /// NOTE(review): the comparison is not constant-time — acceptable only for
    /// this toy scheme.
    pub fn validate(&self, secret: &[u8], current_time: f64) -> bool {
        if current_time > self.expires_at {
            return false;
        }
        AuthToken::generate(self.player_id, secret, self.expires_at).token == self.token
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // What one endpoint encrypts, a peer constructed with the same key decrypts.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = b"secret_key_1234";
        let mut encryptor = PacketCipher::new(key);
        let mut decryptor = PacketCipher::new(key);
        let msg = b"hello world";
        let encrypted = encryptor.encrypt(msg);
        let decrypted = decryptor.decrypt(&encrypted).unwrap();
        assert_eq!(&decrypted, msg);
    }

    // The XOR stream must actually alter the bytes after the 8-byte seq header.
    #[test]
    fn test_encrypted_differs_from_plain() {
        let mut cipher = PacketCipher::new(b"key");
        let msg = b"test message";
        let encrypted = cipher.encrypt(msg);
        assert_ne!(&encrypted[8..], msg); // ciphertext differs
    }

    // Once a later sequence has been accepted, earlier sequences are rejected.
    #[test]
    fn test_replay_rejected() {
        let key = b"key";
        let mut enc = PacketCipher::new(key);
        let mut dec = PacketCipher::new(key);
        let pkt1 = enc.encrypt(b"first");
        let pkt2 = enc.encrypt(b"second");
        let _ = dec.decrypt(&pkt2).unwrap(); // accept pkt2 (seq=1)
        let result = dec.decrypt(&pkt1); // pkt1 has seq=0 < expected 2
        assert!(result.is_err());
    }

    // A token validates with the correct secret before its expiry time.
    #[test]
    fn test_auth_token_valid() {
        let secret = b"server_secret";
        let token = AuthToken::generate(42, secret, 1000.0);
        assert!(token.validate(secret, 999.0));
    }

    // Validation fails once current_time passes expires_at.
    #[test]
    fn test_auth_token_expired() {
        let secret = b"server_secret";
        let token = AuthToken::generate(42, secret, 1000.0);
        assert!(!token.validate(secret, 1001.0));
    }

    // A token derived from one secret must not validate against another.
    #[test]
    fn test_auth_token_wrong_secret() {
        let token = AuthToken::generate(42, b"correct", 1000.0);
        assert!(!token.validate(b"wronggg", 999.0));
    }
}

View File

@@ -0,0 +1,234 @@
use std::collections::VecDeque;
use crate::snapshot::{EntityState, Snapshot};
/// Buffers recent snapshots and interpolates between them for smooth rendering.
pub struct InterpolationBuffer {
    snapshots: VecDeque<(f64, Snapshot)>,
    /// Render delay behind the latest server time (seconds).
    interp_delay: f64,
    /// Maximum number of snapshots to keep in the buffer.
    max_snapshots: usize,
}

impl InterpolationBuffer {
    /// Create a new interpolation buffer with the given delay in seconds.
    pub fn new(interp_delay: f64) -> Self {
        InterpolationBuffer {
            snapshots: VecDeque::new(),
            interp_delay,
            max_snapshots: 32,
        }
    }

    /// Push a new snapshot with its server timestamp, evicting the oldest
    /// entries once the buffer exceeds its capacity.
    pub fn push(&mut self, server_time: f64, snapshot: Snapshot) {
        self.snapshots.push_back((server_time, snapshot));
        while self.snapshots.len() > self.max_snapshots {
            self.snapshots.pop_front();
        }
    }

    /// Interpolate to produce a snapshot for the given render_time.
    ///
    /// The render_time should be `current_server_time - interp_delay`.
    /// Returns None if there are fewer than 2 snapshots or render_time
    /// is before all buffered snapshots.
    pub fn interpolate(&self, render_time: f64) -> Option<Snapshot> {
        if self.snapshots.len() < 2 {
            return None;
        }
        // Index of the first snapshot strictly newer than render_time; its
        // predecessor (if any) is the lower bracket.
        match self.snapshots.iter().position(|(t, _)| *t > render_time) {
            // render_time precedes every buffered snapshot.
            Some(0) => None,
            Some(idx) => {
                let (t0, snap0) = &self.snapshots[idx - 1];
                let (t1, snap1) = &self.snapshots[idx];
                let dt = t1 - t0;
                if dt <= 0.0 {
                    // Degenerate bracket (equal or reversed timestamps).
                    Some(snap0.clone())
                } else {
                    let alpha = ((render_time - t0) / dt).clamp(0.0, 1.0) as f32;
                    Some(lerp_snapshots(snap0, snap1, alpha))
                }
            }
            // render_time is beyond all snapshots — return the latest.
            None => self.snapshots.back().map(|(_, snap)| snap.clone()),
        }
    }

    /// Get the interpolation delay.
    pub fn delay(&self) -> f64 {
        self.interp_delay
    }
}
/// Scalar linear interpolation: returns `a` at t=0 and `b` at t=1.
fn lerp(a: f32, b: f32, t: f32) -> f32 {
    a + (b - a) * t
}

/// Component-wise linear interpolation of two 3-vectors.
fn lerp_f32x3(a: &[f32; 3], b: &[f32; 3], t: f32) -> [f32; 3] {
    std::array::from_fn(|i| lerp(a[i], b[i], t))
}
/// Interpolate every continuous field of an entity; the id is taken from `a`
/// (callers only pair states with matching ids).
fn lerp_entity(a: &EntityState, b: &EntityState, t: f32) -> EntityState {
    let mut out = a.clone();
    out.position = lerp_f32x3(&a.position, &b.position, t);
    out.rotation = lerp_f32x3(&a.rotation, &b.rotation, t);
    out.velocity = lerp_f32x3(&a.velocity, &b.velocity, t);
    out
}
/// Linearly interpolate between two snapshots.
/// Entities are matched by id. Entities only in one snapshot are included as-is.
fn lerp_snapshots(a: &Snapshot, b: &Snapshot, t: f32) -> Snapshot {
    use std::collections::HashMap;
    let a_map: HashMap<u32, &EntityState> = a.entities.iter().map(|e| (e.id, e)).collect();
    let b_map: HashMap<u32, &EntityState> = b.entities.iter().map(|e| (e.id, e)).collect();
    // Entities present in `a`: lerped when matched in `b`, copied otherwise.
    let mut entities: Vec<EntityState> = a
        .entities
        .iter()
        .map(|ea| match b_map.get(&ea.id) {
            Some(eb) => lerp_entity(ea, eb, t),
            None => ea.clone(),
        })
        .collect();
    // Entities that appear only in `b` ride along unmodified.
    entities.extend(
        b.entities
            .iter()
            .filter(|eb| !a_map.contains_key(&eb.id))
            .cloned(),
    );
    // The tick advances fractionally and truncates toward zero.
    let tick = (a.tick as f64 + (b.tick as f64 - a.tick as f64) * t as f64) as u32;
    Snapshot { tick, entities }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::EntityState;

    // Helper: single-entity snapshot with the entity at `x` on the X axis.
    fn make_snapshot(tick: u32, x: f32) -> Snapshot {
        Snapshot {
            tick,
            entities: vec![EntityState {
                id: 1,
                position: [x, 0.0, 0.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        }
    }

    // render_time equal to the first snapshot's time reproduces its state.
    #[test]
    fn test_exact_match_at_snapshot_time() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(0.1, make_snapshot(1, 10.0));
        let result = buf.interpolate(0.0).expect("should interpolate");
        assert_eq!(result.entities[0].position[0], 0.0);
    }

    // Halfway between brackets yields the halfway position.
    #[test]
    fn test_midpoint_interpolation() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 10.0));
        let result = buf.interpolate(0.5).expect("should interpolate");
        let x = result.entities[0].position[0];
        assert!(
            (x - 5.0).abs() < 0.001,
            "Expected ~5.0 at midpoint, got {}",
            x
        );
    }

    // Non-midpoint alpha: a quarter of the way along the bracket.
    #[test]
    fn test_interpolation_at_quarter() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 100.0));
        let result = buf.interpolate(0.25).unwrap();
        let x = result.entities[0].position[0];
        assert!(
            (x - 25.0).abs() < 0.01,
            "Expected ~25.0 at 0.25, got {}",
            x
        );
    }

    // Past the newest snapshot, the buffer returns the latest state (no extrapolation).
    #[test]
    fn test_extrapolation_returns_latest() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 10.0));
        // render_time beyond all snapshots
        let result = buf.interpolate(2.0).expect("should return latest");
        assert_eq!(result.entities[0].position[0], 10.0);
    }

    // Fewer than two snapshots: no bracket exists, so interpolate yields None.
    #[test]
    fn test_too_few_snapshots_returns_none() {
        let mut buf = InterpolationBuffer::new(0.1);
        assert!(buf.interpolate(0.0).is_none());
        buf.push(0.0, make_snapshot(0, 0.0));
        assert!(buf.interpolate(0.0).is_none());
    }

    // A render_time older than everything buffered also yields None.
    #[test]
    fn test_render_time_before_all_snapshots() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(1.0, make_snapshot(10, 10.0));
        buf.push(2.0, make_snapshot(20, 20.0));
        // render_time before the first snapshot
        let result = buf.interpolate(0.0);
        assert!(result.is_none());
    }

    // With three snapshots, the correct adjacent pair is selected for the bracket.
    #[test]
    fn test_multiple_snapshots_picks_correct_bracket() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(1, 10.0));
        buf.push(2.0, make_snapshot(2, 20.0));
        // Should interpolate between snapshot at t=1 and t=2
        let result = buf.interpolate(1.5).unwrap();
        let x = result.entities[0].position[0];
        assert!(
            (x - 15.0).abs() < 0.01,
            "Expected ~15.0, got {}",
            x
        );
    }
}

View File

@@ -0,0 +1,140 @@
use std::collections::VecDeque;
/// A timestamped snapshot of entity positions for lag compensation.
#[derive(Debug, Clone)]
pub struct HistoryEntry {
    pub tick: u64,
    pub timestamp: f32, // seconds
    pub positions: Vec<([f32; 3], u32)>, // (position, entity_id)
}

/// Stores recent world state history for server-side lag compensation.
pub struct LagCompensation {
    history: VecDeque<HistoryEntry>,
    pub max_history_ms: f32, // max history duration (e.g., 200ms)
}

impl LagCompensation {
    pub fn new(max_history_ms: f32) -> Self {
        LagCompensation { history: VecDeque::new(), max_history_ms }
    }

    /// Record current world state. Entries older than `max_history_ms` behind
    /// the newest entry are pruned from the front of the buffer.
    pub fn record(&mut self, entry: HistoryEntry) {
        self.history.push_back(entry);
        // Everything before this cutoff is too old to be useful.
        let cutoff = self
            .history
            .back()
            .map(|e| e.timestamp - self.max_history_ms / 1000.0)
            .unwrap_or(0.0);
        while let Some(front) = self.history.front() {
            if front.timestamp < cutoff {
                self.history.pop_front();
            } else {
                break;
            }
        }
    }

    /// Find the closest history entry to the given timestamp.
    /// Returns None when the history is empty.
    pub fn rewind(&self, timestamp: f32) -> Option<&HistoryEntry> {
        let mut best: Option<&HistoryEntry> = None;
        let mut best_diff = f32::MAX;
        for entry in &self.history {
            let diff = (entry.timestamp - timestamp).abs();
            if diff < best_diff {
                best_diff = diff;
                best = Some(entry);
            }
        }
        best
    }

    /// Interpolate between the two entries bracketing `timestamp`.
    ///
    /// Positions are matched by entity id (not by list index), so entries whose
    /// position lists are ordered differently still interpolate correctly.
    /// Entities present only in the earlier entry keep their last known
    /// position. Falls back to the closest single entry when no bracket exists.
    pub fn rewind_interpolated(&self, timestamp: f32) -> Option<Vec<([f32; 3], u32)>> {
        if self.history.len() < 2 {
            return self.rewind(timestamp).map(|e| e.positions.clone());
        }
        // Find the two entries bracketing the timestamp.
        let mut before: Option<&HistoryEntry> = None;
        let mut after: Option<&HistoryEntry> = None;
        for entry in &self.history {
            if entry.timestamp <= timestamp {
                before = Some(entry);
            } else {
                after = Some(entry);
                break;
            }
        }
        match (before, after) {
            (Some(b), Some(a)) => {
                // Guard against a zero-width bracket (duplicate timestamps).
                let t = if (a.timestamp - b.timestamp).abs() > 1e-6 {
                    (timestamp - b.timestamp) / (a.timestamp - b.timestamp)
                } else {
                    0.0
                };
                let result = b
                    .positions
                    .iter()
                    .map(|(pos_b, id)| {
                        // BUGFIX: pair by entity id instead of list index — index
                        // pairing lerped unrelated entities whenever the two
                        // entries ordered their position lists differently.
                        match a.positions.iter().find(|(_, id_a)| id_a == id) {
                            Some((pos_a, _)) => (
                                [
                                    pos_b[0] + (pos_a[0] - pos_b[0]) * t,
                                    pos_b[1] + (pos_a[1] - pos_b[1]) * t,
                                    pos_b[2] + (pos_a[2] - pos_b[2]) * t,
                                ],
                                *id,
                            ),
                            // Entity vanished by the later entry: keep its last
                            // known position rather than silently dropping it.
                            None => (*pos_b, *id),
                        }
                    })
                    .collect();
                Some(result)
            }
            _ => self.rewind(timestamp).map(|e| e.positions.clone()),
        }
    }

    pub fn history_len(&self) -> usize { self.history.len() }
}
#[cfg(test)]
mod tests {
    use super::*;

    // rewind at an exact recorded timestamp returns that entry.
    #[test]
    fn test_record_and_rewind() {
        let mut lc = LagCompensation::new(200.0);
        lc.record(HistoryEntry { tick: 1, timestamp: 0.0, positions: vec![([1.0, 0.0, 0.0], 0)] });
        lc.record(HistoryEntry { tick: 2, timestamp: 0.016, positions: vec![([2.0, 0.0, 0.0], 0)] });
        let entry = lc.rewind(0.0).unwrap();
        assert_eq!(entry.tick, 1);
    }

    // Entries falling behind the rolling max_history_ms window get pruned.
    #[test]
    fn test_prune_old() {
        let mut lc = LagCompensation::new(100.0); // 100ms
        for i in 0..20 {
            lc.record(HistoryEntry { tick: i, timestamp: i as f32 * 0.016, positions: vec![] });
        }
        // At t=0.304, cutoff = 0.304 - 0.1 = 0.204
        // Entries before t=0.204 should be pruned
        assert!(lc.history_len() < 20);
    }

    // rewind picks the entry with the smallest absolute time difference.
    #[test]
    fn test_rewind_closest() {
        let mut lc = LagCompensation::new(500.0);
        lc.record(HistoryEntry { tick: 1, timestamp: 0.0, positions: vec![] });
        lc.record(HistoryEntry { tick: 2, timestamp: 0.1, positions: vec![] });
        lc.record(HistoryEntry { tick: 3, timestamp: 0.2, positions: vec![] });
        let entry = lc.rewind(0.09).unwrap();
        assert_eq!(entry.tick, 2); // closest to 0.1
    }

    // Bracketed timestamp lerps positions between the two surrounding entries.
    #[test]
    fn test_rewind_interpolated() {
        let mut lc = LagCompensation::new(500.0);
        lc.record(HistoryEntry { tick: 1, timestamp: 0.0, positions: vec![([0.0, 0.0, 0.0], 0)] });
        lc.record(HistoryEntry { tick: 2, timestamp: 0.1, positions: vec![([10.0, 0.0, 0.0], 0)] });
        let interp = lc.rewind_interpolated(0.05).unwrap();
        assert!((interp[0].0[0] - 5.0).abs() < 0.1); // midpoint
    }

    // With no recorded history there is nothing to rewind to.
    #[test]
    fn test_empty_history() {
        let lc = LagCompensation::new(200.0);
        assert!(lc.rewind(0.0).is_none());
    }
}

View File

@@ -2,8 +2,18 @@ pub mod packet;
pub mod socket;
pub mod server;
pub mod client;
pub mod reliable;
pub mod snapshot;
pub mod interpolation;
pub mod lag_compensation;
pub mod encryption;
pub use packet::Packet;
pub use socket::NetSocket;
pub use server::{NetServer, ServerEvent, ClientInfo};
pub use client::{NetClient, ClientEvent};
pub use reliable::{ReliableChannel, OrderedChannel};
pub use snapshot::{Snapshot, EntityState, serialize_snapshot, deserialize_snapshot, diff_snapshots, apply_diff};
pub use interpolation::InterpolationBuffer;
pub use lag_compensation::LagCompensation;
pub use encryption::{PacketCipher, AuthToken};

View File

@@ -5,6 +5,10 @@ const TYPE_DISCONNECT: u8 = 3;
// Wire-format tag bytes identifying each packet variant.
const TYPE_PING: u8 = 4;
const TYPE_PONG: u8 = 5;
const TYPE_USER_DATA: u8 = 6;
const TYPE_RELIABLE: u8 = 7; // reliable-delivery payload carrying a sequence number
const TYPE_ACK: u8 = 8; // acknowledgement of a reliable sequence
const TYPE_SNAPSHOT: u8 = 9; // full world snapshot
const TYPE_SNAPSHOT_DELTA: u8 = 10; // snapshot delta against a base tick
/// Header size: type_id(1) + payload_len(2 LE) + reserved(1) = 4 bytes
const HEADER_SIZE: usize = 4;
@@ -18,6 +22,10 @@ pub enum Packet {
Ping { timestamp: u64 },
Pong { timestamp: u64 },
UserData { client_id: u32, data: Vec<u8> },
Reliable { sequence: u16, data: Vec<u8> },
Ack { sequence: u16 },
Snapshot { tick: u32, data: Vec<u8> },
SnapshotDelta { base_tick: u32, tick: u32, data: Vec<u8> },
}
impl Packet {
@@ -112,6 +120,38 @@ impl Packet {
let data = payload[4..].to_vec();
Ok(Packet::UserData { client_id, data })
}
TYPE_RELIABLE => {
if payload.len() < 2 {
return Err("Reliable payload too short".to_string());
}
let sequence = u16::from_le_bytes([payload[0], payload[1]]);
let data = payload[2..].to_vec();
Ok(Packet::Reliable { sequence, data })
}
TYPE_ACK => {
if payload.len() < 2 {
return Err("Ack payload too short".to_string());
}
let sequence = u16::from_le_bytes([payload[0], payload[1]]);
Ok(Packet::Ack { sequence })
}
TYPE_SNAPSHOT => {
if payload.len() < 4 {
return Err("Snapshot payload too short".to_string());
}
let tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
let data = payload[4..].to_vec();
Ok(Packet::Snapshot { tick, data })
}
TYPE_SNAPSHOT_DELTA => {
if payload.len() < 8 {
return Err("SnapshotDelta payload too short".to_string());
}
let base_tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
let tick = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]);
let data = payload[8..].to_vec();
Ok(Packet::SnapshotDelta { base_tick, tick, data })
}
_ => Err(format!("Unknown packet type_id: {}", type_id)),
}
}
@@ -124,6 +164,10 @@ impl Packet {
Packet::Ping { .. } => TYPE_PING,
Packet::Pong { .. } => TYPE_PONG,
Packet::UserData { .. } => TYPE_USER_DATA,
Packet::Reliable { .. } => TYPE_RELIABLE,
Packet::Ack { .. } => TYPE_ACK,
Packet::Snapshot { .. } => TYPE_SNAPSHOT,
Packet::SnapshotDelta { .. } => TYPE_SNAPSHOT_DELTA,
}
}
@@ -147,6 +191,26 @@ impl Packet {
buf.extend_from_slice(data);
buf
}
Packet::Reliable { sequence, data } => {
let mut buf = Vec::with_capacity(2 + data.len());
buf.extend_from_slice(&sequence.to_le_bytes());
buf.extend_from_slice(data);
buf
}
Packet::Ack { sequence } => sequence.to_le_bytes().to_vec(),
Packet::Snapshot { tick, data } => {
let mut buf = Vec::with_capacity(4 + data.len());
buf.extend_from_slice(&tick.to_le_bytes());
buf.extend_from_slice(data);
buf
}
Packet::SnapshotDelta { base_tick, tick, data } => {
let mut buf = Vec::with_capacity(8 + data.len());
buf.extend_from_slice(&base_tick.to_le_bytes());
buf.extend_from_slice(&tick.to_le_bytes());
buf.extend_from_slice(data);
buf
}
}
}
}
@@ -200,6 +264,36 @@ mod tests {
});
}
#[test]
fn test_reliable_roundtrip() {
roundtrip(Packet::Reliable {
sequence: 42,
data: vec![0xCA, 0xFE],
});
}
#[test]
fn test_ack_roundtrip() {
roundtrip(Packet::Ack { sequence: 100 });
}
#[test]
fn test_snapshot_roundtrip() {
roundtrip(Packet::Snapshot {
tick: 999,
data: vec![1, 2, 3],
});
}
#[test]
fn test_snapshot_delta_roundtrip() {
roundtrip(Packet::SnapshotDelta {
base_tick: 10,
tick: 15,
data: vec![4, 5, 6],
});
}
#[test]
fn test_invalid_type_returns_error() {
// Build a packet with type_id = 99 (unknown)

View File

@@ -0,0 +1,318 @@
use std::collections::{HashMap, HashSet};
use std::time::{Duration, Instant};
/// A channel that provides reliable delivery over unreliable transport.
///
/// Assigns sequence numbers, tracks ACKs, estimates RTT,
/// and retransmits unacknowledged packets after 2x RTT.
pub struct ReliableChannel {
    next_sequence: u16,
    pending_acks: HashMap<u16, (Instant, Vec<u8>)>,
    received_seqs: HashSet<u16>,
    rtt: Duration,
    /// Outgoing ACK packets that need to be sent by the caller.
    outgoing_acks: Vec<u16>,
}

impl ReliableChannel {
    pub fn new() -> Self {
        ReliableChannel {
            next_sequence: 0,
            pending_acks: HashMap::new(),
            received_seqs: HashSet::new(),
            // Initial guess until the first ACK sample arrives.
            rtt: Duration::from_millis(100),
            outgoing_acks: Vec::new(),
        }
    }

    /// Current smoothed round-trip-time estimate.
    pub fn rtt(&self) -> Duration {
        self.rtt
    }

    /// Number of packets sent but not yet acknowledged.
    pub fn pending_count(&self) -> usize {
        self.pending_acks.len()
    }

    /// Prepare a reliable send. Returns (sequence_number, wrapped_data).
    /// The caller is responsible for actually transmitting the wrapped data.
    pub fn send_reliable(&mut self, data: &[u8]) -> (u16, Vec<u8>) {
        let seq = self.next_sequence;
        self.next_sequence = seq.wrapping_add(1);
        // Wire layout: [seq(2 LE), data...]
        let mut wrapped = Vec::with_capacity(2 + data.len());
        wrapped.extend_from_slice(&seq.to_le_bytes());
        wrapped.extend_from_slice(data);
        // Remember the packet until its ACK comes back.
        self.pending_acks.insert(seq, (Instant::now(), wrapped.clone()));
        (seq, wrapped)
    }

    /// Process a received reliable packet. Returns the payload data if this is
    /// not a duplicate, or None if already received. Queues an ACK to send.
    pub fn receive_and_ack(&mut self, sequence: u16, data: &[u8]) -> Option<Vec<u8>> {
        // ACK unconditionally — the sender may have missed our previous ACK.
        self.outgoing_acks.push(sequence);
        // `insert` reports false when the sequence was already seen.
        if self.received_seqs.insert(sequence) {
            Some(data.to_vec())
        } else {
            None
        }
    }

    /// Process an incoming ACK for a sequence we sent, folding the measured
    /// round trip into the RTT estimate.
    pub fn process_ack(&mut self, sequence: u16) {
        if let Some((sent_at, _)) = self.pending_acks.remove(&sequence) {
            let sample = sent_at.elapsed();
            // Exponential moving average: rtt = 0.875 * rtt + 0.125 * sample
            self.rtt = Duration::from_secs_f64(
                0.875 * self.rtt.as_secs_f64() + 0.125 * sample.as_secs_f64(),
            );
        }
    }

    /// Drain any pending outgoing ACK sequence numbers.
    pub fn drain_acks(&mut self) -> Vec<u16> {
        std::mem::take(&mut self.outgoing_acks)
    }

    /// Check for timed-out packets and return their data for retransmission.
    /// Resets the send_time for retransmitted packets.
    pub fn update(&mut self) -> Vec<Vec<u8>> {
        let deadline = self.rtt * 2;
        let now = Instant::now();
        let mut resend = Vec::new();
        for (sent_at, payload) in self.pending_acks.values_mut() {
            if now.duration_since(*sent_at) >= deadline {
                resend.push(payload.clone());
                // Restart the timeout clock for this packet.
                *sent_at = now;
            }
        }
        resend
    }
}
/// A channel that delivers packets in order, built on top of ReliableChannel.
pub struct OrderedChannel {
    reliable: ReliableChannel,
    next_deliver: u16,
    buffer: HashMap<u16, Vec<u8>>,
}

impl OrderedChannel {
    pub fn new() -> Self {
        OrderedChannel {
            reliable: ReliableChannel::new(),
            next_deliver: 0,
            buffer: HashMap::new(),
        }
    }

    /// Access the underlying reliable channel (e.g., for send_reliable, process_ack, update).
    pub fn reliable(&self) -> &ReliableChannel {
        &self.reliable
    }

    /// Access the underlying reliable channel mutably.
    pub fn reliable_mut(&mut self) -> &mut ReliableChannel {
        &mut self.reliable
    }

    /// Prepare a reliable, ordered send.
    pub fn send(&mut self, data: &[u8]) -> (u16, Vec<u8>) {
        self.reliable.send_reliable(data)
    }

    /// Receive a packet. Buffers out-of-order packets and returns all
    /// packets that can now be delivered in sequence order.
    pub fn receive(&mut self, sequence: u16, data: &[u8]) -> Vec<Vec<u8>> {
        // Duplicates come back as None and are never buffered twice.
        if let Some(payload) = self.reliable.receive_and_ack(sequence, data) {
            self.buffer.insert(sequence, payload);
        }
        // Flush the longest consecutive run starting at next_deliver.
        let mut ready = Vec::new();
        while let Some(payload) = self.buffer.remove(&self.next_deliver) {
            ready.push(payload);
            self.next_deliver = self.next_deliver.wrapping_add(1);
        }
        ready
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ReliableChannel tests ----

    // Full happy path: send, receive, queue ACK, process ACK, pending drains.
    #[test]
    fn test_send_receive_ack_roundtrip() {
        let mut sender = ReliableChannel::new();
        let mut receiver = ReliableChannel::new();
        let original = b"hello world";
        let (seq, _buf) = sender.send_reliable(original);
        assert_eq!(seq, 0);
        assert_eq!(sender.pending_count(), 1);
        // Receiver gets the packet
        let result = receiver.receive_and_ack(seq, original);
        assert_eq!(result, Some(original.to_vec()));
        // Receiver queued an ack
        let acks = receiver.drain_acks();
        assert_eq!(acks, vec![0]);
        // Sender processes the ack
        sender.process_ack(seq);
        assert_eq!(sender.pending_count(), 0);
    }

    // A re-delivered sequence yields None but still queues a second ACK.
    #[test]
    fn test_duplicate_rejection() {
        let mut receiver = ReliableChannel::new();
        let data = b"payload";
        let result1 = receiver.receive_and_ack(0, data);
        assert!(result1.is_some());
        let result2 = receiver.receive_and_ack(0, data);
        assert!(result2.is_none(), "Duplicate should be rejected");
        // But ACK is still queued for both
        let acks = receiver.drain_acks();
        assert_eq!(acks.len(), 2);
    }

    // Sequence numbers are assigned consecutively starting at 0.
    #[test]
    fn test_sequence_numbers_increment() {
        let mut channel = ReliableChannel::new();
        let (s0, _) = channel.send_reliable(b"a");
        let (s1, _) = channel.send_reliable(b"b");
        let (s2, _) = channel.send_reliable(b"c");
        assert_eq!(s0, 0);
        assert_eq!(s1, 1);
        assert_eq!(s2, 2);
        assert_eq!(channel.pending_count(), 3);
    }

    // Unacked packets are handed back for retransmission once 2x RTT elapses.
    #[test]
    fn test_retransmission_on_timeout() {
        let mut channel = ReliableChannel::new();
        // Set a very short RTT so timeout (2*RTT) triggers quickly
        channel.rtt = Duration::from_millis(1);
        let (_seq, _buf) = channel.send_reliable(b"data");
        assert_eq!(channel.pending_count(), 1);
        // Wait for timeout
        std::thread::sleep(Duration::from_millis(10));
        let retransmits = channel.update();
        assert_eq!(retransmits.len(), 1, "Should retransmit 1 packet");
        // Packet is still pending (not acked)
        assert_eq!(channel.pending_count(), 1);
    }

    // Before 2x RTT has passed, update() must not retransmit anything.
    #[test]
    fn test_no_retransmission_before_timeout() {
        let mut channel = ReliableChannel::new();
        // Default RTT = 100ms, so timeout = 200ms
        let (_seq, _buf) = channel.send_reliable(b"data");
        // Immediately check — should not retransmit
        let retransmits = channel.update();
        assert!(retransmits.is_empty());
    }

    // Processing an ACK folds the measured sample into the RTT EWMA.
    #[test]
    fn test_rtt_estimation() {
        let mut channel = ReliableChannel::new();
        let initial_rtt = channel.rtt();
        let (seq, _) = channel.send_reliable(b"x");
        std::thread::sleep(Duration::from_millis(5));
        channel.process_ack(seq);
        // RTT should have changed from initial value
        let new_rtt = channel.rtt();
        assert_ne!(initial_rtt, new_rtt, "RTT should be updated after ACK");
    }

    // Sequence counter wraps from u16::MAX back to 0 without panicking.
    #[test]
    fn test_wrapping_sequence() {
        let mut channel = ReliableChannel::new();
        channel.next_sequence = u16::MAX;
        let (s1, _) = channel.send_reliable(b"a");
        assert_eq!(s1, u16::MAX);
        let (s2, _) = channel.send_reliable(b"b");
        assert_eq!(s2, 0); // wrapped
    }

    // ---- OrderedChannel tests ----

    // Packets arriving in order are delivered immediately, one at a time.
    #[test]
    fn test_ordered_in_order_delivery() {
        let mut channel = OrderedChannel::new();
        let delivered0 = channel.receive(0, b"first");
        assert_eq!(delivered0, vec![b"first".to_vec()]);
        let delivered1 = channel.receive(1, b"second");
        assert_eq!(delivered1, vec![b"second".to_vec()]);
    }

    // Out-of-order arrivals are held back until the missing packet fills the gap.
    #[test]
    fn test_ordered_out_of_order_delivery() {
        let mut channel = OrderedChannel::new();
        // Receive seq 1 first (out of order)
        let delivered = channel.receive(1, b"second");
        assert!(delivered.is_empty(), "Seq 1 should be buffered, waiting for 0");
        // Receive seq 2 (still missing 0)
        let delivered = channel.receive(2, b"third");
        assert!(delivered.is_empty());
        // Receive seq 0 — should deliver 0, 1, 2 in order
        let delivered = channel.receive(0, b"first");
        assert_eq!(delivered.len(), 3);
        assert_eq!(delivered[0], b"first");
        assert_eq!(delivered[1], b"second");
        assert_eq!(delivered[2], b"third");
    }

    // A single missing sequence stalls delivery; filling it flushes the run.
    #[test]
    fn test_ordered_gap_handling() {
        let mut channel = OrderedChannel::new();
        // Deliver 0
        let d = channel.receive(0, b"a");
        assert_eq!(d.len(), 1);
        // Skip 1, deliver 2
        let d = channel.receive(2, b"c");
        assert!(d.is_empty(), "Can't deliver 2 without 1");
        // Now deliver 1 — should flush both 1 and 2
        let d = channel.receive(1, b"b");
        assert_eq!(d.len(), 2);
        assert_eq!(d[0], b"b");
        assert_eq!(d[1], b"c");
    }
}

View File

@@ -0,0 +1,378 @@
/// State of a single entity at a point in time.
#[derive(Debug, Clone, PartialEq)]
pub struct EntityState {
    pub id: u32, // stable entity identifier used for delta matching
    pub position: [f32; 3],
    pub rotation: [f32; 3],
    pub velocity: [f32; 3],
}

/// A snapshot of the world at a given tick.
#[derive(Debug, Clone, PartialEq)]
pub struct Snapshot {
    pub tick: u32, // server simulation tick this state belongs to
    pub entities: Vec<EntityState>,
}

/// Binary size of one entity: id(4) + pos(12) + rot(12) + vel(12) = 40 bytes
const ENTITY_SIZE: usize = 4 + 12 + 12 + 12;
/// Append one f32 to `buf` in little-endian byte order.
fn write_f32_le(buf: &mut Vec<u8>, v: f32) {
    buf.extend_from_slice(&v.to_le_bytes());
}

/// Read one little-endian f32 starting at `offset` (panics if out of bounds).
fn read_f32_le(data: &[u8], offset: usize) -> f32 {
    f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap())
}

/// Append a 3-vector as three consecutive little-endian f32s.
fn write_f32x3(buf: &mut Vec<u8>, v: &[f32; 3]) {
    for &component in v {
        write_f32_le(buf, component);
    }
}

/// Read a 3-vector of little-endian f32s starting at `offset`.
fn read_f32x3(data: &[u8], offset: usize) -> [f32; 3] {
    std::array::from_fn(|i| read_f32_le(data, offset + 4 * i))
}
/// Append one entity in wire order: id(4 LE) + position + rotation + velocity.
fn serialize_entity(buf: &mut Vec<u8>, e: &EntityState) {
    buf.extend_from_slice(&e.id.to_le_bytes());
    write_f32x3(buf, &e.position);
    write_f32x3(buf, &e.rotation);
    write_f32x3(buf, &e.velocity);
}

/// Read one entity starting at `offset`, mirroring `serialize_entity`.
/// NOTE(review): panics if fewer than ENTITY_SIZE bytes remain after `offset` —
/// callers must length-check first (deserialize_snapshot does).
fn deserialize_entity(data: &[u8], offset: usize) -> EntityState {
    let id = u32::from_le_bytes([
        data[offset], data[offset + 1], data[offset + 2], data[offset + 3],
    ]);
    let position = read_f32x3(data, offset + 4);
    let rotation = read_f32x3(data, offset + 16);
    let velocity = read_f32x3(data, offset + 28);
    EntityState { id, position, rotation, velocity }
}
/// Serialize a snapshot into compact binary format.
/// Layout: tick(4 LE) + entity_count(4 LE) + entities...
pub fn serialize_snapshot(snapshot: &Snapshot) -> Vec<u8> {
    let n = snapshot.entities.len();
    // Reserve the exact final size up front: 8-byte header + fixed-size entities.
    let mut out = Vec::with_capacity(8 + ENTITY_SIZE * n);
    out.extend_from_slice(&snapshot.tick.to_le_bytes());
    out.extend_from_slice(&(n as u32).to_le_bytes());
    for entity in &snapshot.entities {
        serialize_entity(&mut out, entity);
    }
    out
}
/// Deserialize a snapshot from binary data.
///
/// Layout: tick(4 LE) + entity_count(4 LE) + entities...
///
/// # Errors
/// Returns `Err` when the buffer is shorter than the 8-byte header, when the
/// declared entity count does not fit in the buffer, or when the implied total
/// size overflows `usize` (hostile `count` values on 32-bit targets).
pub fn deserialize_snapshot(data: &[u8]) -> Result<Snapshot, String> {
    if data.len() < 8 {
        return Err("Snapshot data too short for header".to_string());
    }
    let tick = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    let count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
    // Checked arithmetic: an attacker-supplied count must not wrap the expected
    // size, or the length check below could pass and the per-entity reads
    // would panic out of bounds.
    let expected = count
        .checked_mul(ENTITY_SIZE)
        .and_then(|n| n.checked_add(8))
        .ok_or_else(|| "Snapshot entity count overflows".to_string())?;
    if data.len() < expected {
        return Err(format!(
            "Snapshot data too short: expected {} bytes, got {}",
            expected,
            data.len()
        ));
    }
    let mut entities = Vec::with_capacity(count);
    for i in 0..count {
        entities.push(deserialize_entity(data, 8 + i * ENTITY_SIZE));
    }
    Ok(Snapshot { tick, entities })
}
/// Compute a delta between two snapshots.
/// Format: new_tick(4) + count(4) + [id(4) + flags(1) + changed_fields...]
/// Flags bitmask: 0x01 = position, 0x02 = rotation, 0x04 = velocity, 0x80 = new entity (full)
///
/// NOTE(review): entities present in `old` but absent from `new` are not
/// encoded — the format has no removal flag, so `apply_diff` will keep them.
pub fn diff_snapshots(old: &Snapshot, new: &Snapshot) -> Vec<u8> {
    use std::collections::HashMap;
    // Index the old snapshot by entity id for O(1) matching.
    let old_map: HashMap<u32, &EntityState> = old.entities.iter().map(|e| (e.id, e)).collect();
    let mut entries: Vec<u8> = Vec::new();
    let mut count: u32 = 0;
    for new_ent in &new.entities {
        if let Some(old_ent) = old_map.get(&new_ent.id) {
            // Existing entity: emit only the fields that actually changed.
            let mut flags: u8 = 0;
            let mut fields = Vec::new();
            if new_ent.position != old_ent.position {
                flags |= 0x01;
                write_f32x3(&mut fields, &new_ent.position);
            }
            if new_ent.rotation != old_ent.rotation {
                flags |= 0x02;
                write_f32x3(&mut fields, &new_ent.rotation);
            }
            if new_ent.velocity != old_ent.velocity {
                flags |= 0x04;
                write_f32x3(&mut fields, &new_ent.velocity);
            }
            // Fully unchanged entities are omitted from the delta entirely.
            if flags != 0 {
                entries.extend_from_slice(&new_ent.id.to_le_bytes());
                entries.push(flags);
                entries.extend_from_slice(&fields);
                count += 1;
            }
        } else {
            // New entity — send full state
            entries.extend_from_slice(&new_ent.id.to_le_bytes());
            entries.push(0x80); // "new entity" flag
            write_f32x3(&mut entries, &new_ent.position);
            write_f32x3(&mut entries, &new_ent.rotation);
            write_f32x3(&mut entries, &new_ent.velocity);
            count += 1;
        }
    }
    // Prefix with the new tick and the number of encoded entries.
    let mut buf = Vec::with_capacity(8 + entries.len());
    buf.extend_from_slice(&new.tick.to_le_bytes());
    buf.extend_from_slice(&count.to_le_bytes());
    buf.extend_from_slice(&entries);
    buf
}
/// Apply a delta to a base snapshot to produce an updated snapshot.
///
/// Walks the entry list produced by `diff_snapshots`, validating at every step
/// that enough bytes remain before reading. Entities not mentioned in the diff
/// are carried over from `base` unchanged.
///
/// # Errors
/// Returns `Err` when the diff is truncated at any point, or when a delta
/// entry references an entity id that does not exist in `base`.
pub fn apply_diff(base: &Snapshot, diff: &[u8]) -> Result<Snapshot, String> {
    if diff.len() < 8 {
        return Err("Diff data too short for header".to_string());
    }
    let tick = u32::from_le_bytes([diff[0], diff[1], diff[2], diff[3]]);
    let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]) as usize;
    // Start from a clone of the base
    let mut entities: Vec<EntityState> = base.entities.clone();
    let mut offset = 8;
    for _ in 0..count {
        // Each entry begins with id(4) + flags(1).
        if offset + 5 > diff.len() {
            return Err("Diff truncated at entry header".to_string());
        }
        let id = u32::from_le_bytes([diff[offset], diff[offset + 1], diff[offset + 2], diff[offset + 3]]);
        let flags = diff[offset + 4];
        offset += 5;
        if flags & 0x80 != 0 {
            // New entity — full state (position + rotation + velocity = 36 bytes)
            if offset + 36 > diff.len() {
                return Err("Diff truncated at new entity data".to_string());
            }
            let position = read_f32x3(diff, offset);
            let rotation = read_f32x3(diff, offset + 12);
            let velocity = read_f32x3(diff, offset + 24);
            offset += 36;
            // Add or replace (the id may already exist if base drifted).
            if let Some(ent) = entities.iter_mut().find(|e| e.id == id) {
                ent.position = position;
                ent.rotation = rotation;
                ent.velocity = velocity;
            } else {
                entities.push(EntityState { id, position, rotation, velocity });
            }
        } else {
            // Delta update — find existing entity
            let ent = entities.iter_mut().find(|e| e.id == id)
                .ok_or_else(|| format!("Diff references unknown entity {}", id))?;
            // Changed fields follow in flag order: position, rotation, velocity.
            if flags & 0x01 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at position".to_string());
                }
                ent.position = read_f32x3(diff, offset);
                offset += 12;
            }
            if flags & 0x02 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at rotation".to_string());
                }
                ent.rotation = read_f32x3(diff, offset);
                offset += 12;
            }
            if flags & 0x04 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at velocity".to_string());
                }
                ent.velocity = read_f32x3(diff, offset);
                offset += 12;
            }
        }
    }
    Ok(Snapshot { tick, entities })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor: entity at (px, py, pz) with zero
    // rotation and velocity.
    fn make_entity(id: u32, px: f32, py: f32, pz: f32) -> EntityState {
        EntityState {
            id,
            position: [px, py, pz],
            rotation: [0.0, 0.0, 0.0],
            velocity: [0.0, 0.0, 0.0],
        }
    }

    // serialize -> deserialize must be lossless.
    #[test]
    fn test_snapshot_roundtrip() {
        let snap = Snapshot {
            tick: 42,
            entities: vec![
                make_entity(1, 1.0, 2.0, 3.0),
                make_entity(2, 4.0, 5.0, 6.0),
            ],
        };
        let bytes = serialize_snapshot(&snap);
        let decoded = deserialize_snapshot(&bytes).expect("deserialize failed");
        assert_eq!(snap, decoded);
    }

    // An empty snapshot serializes to just the 8-byte header
    // (tick + entity count) and still round-trips.
    #[test]
    fn test_snapshot_empty() {
        let snap = Snapshot { tick: 0, entities: vec![] };
        let bytes = serialize_snapshot(&snap);
        assert_eq!(bytes.len(), 8); // just header
        let decoded = deserialize_snapshot(&bytes).unwrap();
        assert_eq!(snap, decoded);
    }

    // Identical entity state across ticks => header-only diff with a
    // zero change count.
    #[test]
    fn test_diff_no_changes() {
        let snap = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 1.0, 2.0, 3.0)],
        };
        let snap2 = Snapshot {
            tick: 11,
            entities: vec![make_entity(1, 1.0, 2.0, 3.0)],
        };
        let diff = diff_snapshots(&snap, &snap2);
        // Header only: tick(4) + count(4) = 8, count = 0
        assert_eq!(diff.len(), 8);
        let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]);
        assert_eq!(count, 0);
    }

    // A position-only change must survive diff + apply while leaving the
    // untouched fields alone.
    #[test]
    fn test_diff_position_changed() {
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 2.0, 3.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).expect("apply_diff failed");
        assert_eq!(result.tick, 11);
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].position, [1.0, 2.0, 3.0]);
        assert_eq!(result.entities[0].rotation, [0.0, 0.0, 0.0]); // unchanged
    }

    // An entity present only in the new snapshot is carried as a
    // full-state record and appears after apply.
    #[test]
    fn test_diff_new_entity() {
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![
                make_entity(1, 0.0, 0.0, 0.0),
                make_entity(2, 5.0, 6.0, 7.0),
            ],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).expect("apply_diff failed");
        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[1].id, 2);
        assert_eq!(result.entities[1].position, [5.0, 6.0, 7.0]);
    }

    // All three field groups changed at once must round-trip through
    // diff + apply.
    #[test]
    fn test_diff_multiple_fields_changed() {
        let old = Snapshot {
            tick: 10,
            entities: vec![EntityState {
                id: 1,
                position: [0.0, 0.0, 0.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 1.0, 1.0],
                rotation: [2.0, 2.0, 2.0],
                velocity: [3.0, 3.0, 3.0],
            }],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).unwrap();
        assert_eq!(result.entities[0].position, [1.0, 1.0, 1.0]);
        assert_eq!(result.entities[0].rotation, [2.0, 2.0, 2.0]);
        assert_eq!(result.entities[0].velocity, [3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_diff_is_compact() {
        // Only position changes — diff should be smaller than full snapshot
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 2.0, 3.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let full_bytes = serialize_snapshot(&new);
        let diff_bytes = diff_snapshots(&old, &new);
        assert!(
            diff_bytes.len() < full_bytes.len(),
            "Diff ({} bytes) should be smaller than full snapshot ({} bytes)",
            diff_bytes.len(),
            full_bytes.len()
        );
    }
}

View File

@@ -1,5 +1,6 @@
use voltex_ecs::Entity;
use voltex_math::AABB;
use voltex_math::{AABB, Ray};
use crate::ray::ray_vs_aabb;
#[derive(Debug)]
enum BvhNode {
@@ -72,34 +73,131 @@ impl BvhTree {
idx
}
/// Query all overlapping pairs using recursive tree traversal (replaces N² brute force).
///
/// Returns one `(Entity, Entity)` per overlapping pair of leaf AABBs.
/// Defect fixed: stale brute-force lines (a `leaves` buffer and calls to a
/// removed `collect_leaves` helper) were left interleaved with the recursive
/// call; only the tree traversal remains.
pub fn query_pairs(&self) -> Vec<(Entity, Entity)> {
    let mut pairs = Vec::new();
    if self.nodes.is_empty() {
        return pairs;
    }
    // The root is the last node pushed during construction; pairing the
    // root against itself enumerates every leaf pair exactly once.
    let root = self.nodes.len() - 1;
    self.query_pairs_recursive(root, root, &mut pairs);
    pairs
}
fn collect_leaves(&self, node_idx: usize, out: &mut Vec<(Entity, AABB)>) {
match &self.nodes[node_idx] {
/// Recursive pair query over the subtrees rooted at `a` and `b`.
///
/// Invariant: each unordered pair of distinct leaves is reached exactly
/// once — at their lowest common ancestor, in whichever (a, b) order the
/// tree layout dictates. Pairs must therefore be *normalized* by entity
/// id, never filtered: the previous `ea.id <= eb.id` guard silently
/// dropped any pair whose traversal order disagreed with id order.
fn query_pairs_recursive(&self, a: usize, b: usize, pairs: &mut Vec<(Entity, Entity)>) {
    // Prune: if the two subtree bounds are disjoint, no leaf pair inside
    // them can overlap.
    if !self.node_aabb(a).intersects(self.node_aabb(b)) {
        return;
    }
    match (&self.nodes[a], &self.nodes[b]) {
        (BvhNode::Leaf { entity: ea, .. }, BvhNode::Leaf { entity: eb, .. }) => {
            // Leaf AABBs are exactly the node AABBs already tested above,
            // so no further intersection test is needed here.
            if a != b {
                let pair = if ea.id <= eb.id { (*ea, *eb) } else { (*eb, *ea) };
                pairs.push(pair);
            }
        }
        (BvhNode::Leaf { .. }, BvhNode::Internal { left, right, .. }) => {
            self.query_pairs_recursive(a, *left, pairs);
            self.query_pairs_recursive(a, *right, pairs);
        }
        (BvhNode::Internal { left, right, .. }, BvhNode::Leaf { .. }) => {
            self.query_pairs_recursive(*left, b, pairs);
            self.query_pairs_recursive(*right, b, pairs);
        }
        (BvhNode::Internal { left: la, right: ra, .. }, BvhNode::Internal { left: lb, right: rb, .. }) => {
            let (la, ra, lb, rb) = (*la, *ra, *lb, *rb);
            if a == b {
                // Same node: check children against each other and themselves.
                self.query_pairs_recursive(la, la, pairs);
                self.query_pairs_recursive(ra, ra, pairs);
                self.query_pairs_recursive(la, ra, pairs);
            } else {
                self.query_pairs_recursive(la, lb, pairs);
                self.query_pairs_recursive(la, rb, pairs);
                self.query_pairs_recursive(ra, lb, pairs);
                self.query_pairs_recursive(ra, rb, pairs);
            }
        }
    }
}
/// Bounding box stored on the node at `idx`, whether leaf or internal.
fn node_aabb(&self, idx: usize) -> &AABB {
    match &self.nodes[idx] {
        BvhNode::Leaf { aabb, .. } | BvhNode::Internal { aabb, .. } => aabb,
    }
}
/// Query ray against BVH, returning all (Entity, t) hits sorted by t.
///
/// `max_t` bounds the search: hits whose entry distance exceeds it are
/// discarded during traversal.
pub fn query_ray(&self, ray: &Ray, max_t: f32) -> Vec<(Entity, f32)> {
    let mut hits = Vec::new();
    if self.nodes.is_empty() {
        return hits;
    }
    // The root is always the last node pushed during construction.
    let root = self.nodes.len() - 1;
    self.query_ray_recursive(root, ray, max_t, &mut hits);
    // total_cmp is a total order over f32, so this cannot panic the way
    // partial_cmp().unwrap() would if a distance were ever NaN.
    hits.sort_by(|a, b| a.1.total_cmp(&b.1));
    hits
}
/// Recursive helper for `query_ray`: prunes subtrees whose bounds the ray
/// misses (or only reaches beyond `max_t`) and records leaf hits.
///
/// Defect fixed: stale lines from the removed `collect_leaves` helper
/// (`out.push(...)`, `self.collect_leaves(*left, out)`) referenced an
/// undefined `out` binding and have been dropped.
fn query_ray_recursive(&self, idx: usize, ray: &Ray, max_t: f32, hits: &mut Vec<(Entity, f32)>) {
    // Prune this entire subtree if the ray misses its bounds, or first
    // touches them past the allowed range.
    let aabb = self.node_aabb(idx);
    match ray_vs_aabb(ray, aabb) {
        Some(t) if t <= max_t => {}
        _ => return,
    }
    match &self.nodes[idx] {
        BvhNode::Leaf { entity, aabb } => {
            if let Some(t) = ray_vs_aabb(ray, aabb) {
                if t <= max_t {
                    hits.push((*entity, t));
                }
            }
        }
        BvhNode::Internal { left, right, .. } => {
            self.query_ray_recursive(*left, ray, max_t, hits);
            self.query_ray_recursive(*right, ray, max_t, hits);
        }
    }
}
/// Refit the BVH: update leaf AABBs and propagate changes to parents.
/// `updated` maps entity → new AABB. Leaves not in map are unchanged.
pub fn refit(&mut self, updated: &[(Entity, AABB)]) {
    // checked_sub yields None exactly when the tree has no nodes.
    if let Some(root) = self.nodes.len().checked_sub(1) {
        self.refit_recursive(root, updated);
    }
}
// Recursively refit the subtree rooted at `idx`; returns the subtree's
// (possibly updated) bounds so the parent can merge them.
fn refit_recursive(&mut self, idx: usize, updated: &[(Entity, AABB)]) -> AABB {
    match self.nodes[idx] {
        BvhNode::Leaf { entity, ref mut aabb } => {
            // Linear scan per leaf: O(leaves * updates). Fine for small
            // update batches; revisit if callers pass large slices.
            if let Some((_, new_aabb)) = updated.iter().find(|(e, _)| *e == entity) {
                *aabb = *new_aabb;
            }
            *aabb
        }
        BvhNode::Internal { left, right, aabb: _ } => {
            // Rebind the child indices by value so the borrow of
            // self.nodes ends before the recursive &mut self calls.
            let left = left;
            let right = right;
            let left_aabb = self.refit_recursive(left, updated);
            let right_aabb = self.refit_recursive(right, updated);
            let new_aabb = left_aabb.merged(&right_aabb);
            // Update the internal node's AABB
            if let BvhNode::Internal { ref mut aabb, .. } = self.nodes[idx] {
                *aabb = new_aabb;
            }
            new_aabb
        }
    }
}
@@ -163,4 +261,117 @@ mod tests {
let (a, b) = pairs[0];
assert!((a.id == 0 && b.id == 1) || (a.id == 1 && b.id == 0));
}
// --- query_pairs: verify recursive gives same results as brute force ---
// Property check: two overlapping clusters (entities 0/1/2 near the
// origin, 3/4 near ~10) separated by empty space; the BVH traversal must
// report exactly the pairs the O(n^2) scan finds, up to ordering.
#[test]
fn test_query_pairs_matches_brute_force() {
    let entries = vec![
        (make_entity(0), AABB::new(Vec3::ZERO, Vec3::new(2.0, 2.0, 2.0))),
        (make_entity(1), AABB::new(Vec3::ONE, Vec3::new(3.0, 3.0, 3.0))),
        (make_entity(2), AABB::new(Vec3::new(2.5, 2.5, 2.5), Vec3::new(4.0, 4.0, 4.0))),
        (make_entity(3), AABB::new(Vec3::new(10.0, 10.0, 10.0), Vec3::new(11.0, 11.0, 11.0))),
        (make_entity(4), AABB::new(Vec3::new(10.5, 10.5, 10.5), Vec3::new(12.0, 12.0, 12.0))),
    ];
    let tree = BvhTree::build(&entries);
    let mut pairs = tree.query_pairs();
    // Normalize ordering before comparing against the reference result.
    pairs.sort_by_key(|(a, b)| (a.id.min(b.id), a.id.max(b.id)));
    // Brute force
    let mut brute: Vec<(Entity, Entity)> = Vec::new();
    for i in 0..entries.len() {
        for j in (i + 1)..entries.len() {
            if entries[i].1.intersects(&entries[j].1) {
                let a = entries[i].0;
                let b = entries[j].0;
                brute.push(if a.id <= b.id { (a, b) } else { (b, a) });
            }
        }
    }
    brute.sort_by_key(|(a, b)| (a.id, b.id));
    assert_eq!(pairs.len(), brute.len(), "pair count mismatch: tree={}, brute={}", pairs.len(), brute.len());
    for (tree_pair, brute_pair) in pairs.iter().zip(brute.iter()) {
        let t = (tree_pair.0.id.min(tree_pair.1.id), tree_pair.0.id.max(tree_pair.1.id));
        let b = (brute_pair.0.id, brute_pair.1.id);
        assert_eq!(t, b);
    }
}
// --- query_ray tests ---
#[test]
fn test_query_ray_basic() {
    // One box sitting on the +X axis, one far off to the side.
    let boxes = vec![
        (make_entity(0), AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0))),
        (make_entity(1), AABB::new(Vec3::new(10.0, 10.0, 10.0), Vec3::new(11.0, 11.0, 11.0))),
    ];
    let tree = BvhTree::build(&boxes);
    let hits = tree.query_ray(&Ray::new(Vec3::ZERO, Vec3::X), 100.0);
    // Only the on-axis box should be reported.
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0.id, 0);
}
#[test]
fn test_query_ray_multiple() {
    // Two boxes strung along +X; the ray pierces both.
    let boxes = vec![
        (make_entity(0), AABB::new(Vec3::new(2.0, -1.0, -1.0), Vec3::new(4.0, 1.0, 1.0))),
        (make_entity(1), AABB::new(Vec3::new(6.0, -1.0, -1.0), Vec3::new(8.0, 1.0, 1.0))),
    ];
    let tree = BvhTree::build(&boxes);
    let hits = tree.query_ray(&Ray::new(Vec3::ZERO, Vec3::X), 100.0);
    assert_eq!(hits.len(), 2);
    assert!(hits[0].1 < hits[1].1, "should be sorted by distance");
}
#[test]
fn test_query_ray_miss() {
    // Single box well off the +X axis: nothing should be hit.
    let boxes = vec![
        (make_entity(0), AABB::new(Vec3::new(4.0, 4.0, 4.0), Vec3::new(6.0, 6.0, 6.0))),
    ];
    let tree = BvhTree::build(&boxes);
    assert!(tree.query_ray(&Ray::new(Vec3::ZERO, Vec3::X), 100.0).is_empty());
}
// --- refit tests ---
#[test]
fn test_refit_updates_leaf() {
    // Two well-separated entities to start with.
    let boxes = vec![
        (make_entity(0), AABB::new(Vec3::ZERO, Vec3::ONE)),
        (make_entity(1), AABB::new(Vec3::new(5.0, 5.0, 5.0), Vec3::new(6.0, 6.0, 6.0))),
    ];
    let mut tree = BvhTree::build(&boxes);
    // Initially separated
    assert!(tree.query_pairs().is_empty());
    // Move entity 1 to overlap entity 0
    let moved = AABB::new(Vec3::new(0.5, 0.5, 0.5), Vec3::new(1.5, 1.5, 1.5));
    tree.refit(&[(make_entity(1), moved)]);
    assert_eq!(tree.query_pairs().len(), 1, "after refit, overlapping entities should be found");
}
#[test]
fn test_refit_separates_entities() {
    // Two entities that start out overlapping.
    let boxes = vec![
        (make_entity(0), AABB::new(Vec3::ZERO, Vec3::new(2.0, 2.0, 2.0))),
        (make_entity(1), AABB::new(Vec3::ONE, Vec3::new(3.0, 3.0, 3.0))),
    ];
    let mut tree = BvhTree::build(&boxes);
    // Initially overlapping
    assert_eq!(tree.query_pairs().len(), 1);
    // Move entity 1 far away
    let far = AABB::new(Vec3::new(100.0, 100.0, 100.0), Vec3::new(101.0, 101.0, 101.0));
    tree.refit(&[(make_entity(1), far)]);
    assert!(tree.query_pairs().is_empty(), "after refit, separated entities should not overlap");
}
}

View File

@@ -0,0 +1,118 @@
use voltex_math::{Vec3, AABB, Ray};
use crate::ray::ray_vs_aabb;
/// Swept sphere vs AABB continuous collision detection.
///
/// Conservatively expands the AABB by the sphere radius (corners are
/// squared off rather than rounded), then casts a ray from `start` toward
/// `end`. Returns the parametric time t in [0, 1] of first contact, or
/// None if the sweep never touches the box.
pub fn swept_sphere_vs_aabb(start: Vec3, end: Vec3, radius: f32, aabb: &AABB) -> Option<f32> {
    // Grow the box by the sphere radius on every side.
    let pad = Vec3::new(radius, radius, radius);
    let fat = AABB::new(aabb.min - pad, aabb.max + pad);
    let travel = end - start;
    let travel_len = travel.length();
    // Degenerate sweep: no motion, so the only possible contact is
    // "already touching" at t = 0.
    if travel_len < 1e-10 {
        return if fat.contains_point(start) { Some(0.0) } else { None };
    }
    // ray_vs_aabb reports distance along the unit-length ray; rescale that
    // into the sweep's [0, 1] parameter space and reject contacts that lie
    // past the end point.
    let ray = Ray::new(start, travel * (1.0 / travel_len));
    let hit = ray_vs_aabb(&ray, &fat)?;
    let t = hit / travel_len;
    if t <= 1.0 { Some(t) } else { None }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Loose tolerance suited to the f32 arithmetic under test.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-3
    }

    // Straight sweep through the box; t follows from the expanded bounds.
    #[test]
    fn test_swept_sphere_hits_aabb() {
        let start = Vec3::new(-10.0, 0.0, 0.0);
        let end = Vec3::new(10.0, 0.0, 0.0);
        let radius = 0.5;
        let aabb = AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0));
        let t = swept_sphere_vs_aabb(start, end, radius, &aabb).unwrap();
        // Expanded AABB min.x = 3.5, start.x = -10, direction = 20
        // t = (3.5 - (-10)) / 20 = 13.5 / 20 = 0.675
        assert!(t > 0.0 && t < 1.0);
        assert!(approx(t, 0.675));
    }

    // Sweep passes well above the box: no contact expected.
    #[test]
    fn test_swept_sphere_misses_aabb() {
        let start = Vec3::new(-10.0, 10.0, 0.0);
        let end = Vec3::new(10.0, 10.0, 0.0);
        let radius = 0.5;
        let aabb = AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0));
        assert!(swept_sphere_vs_aabb(start, end, radius, &aabb).is_none());
    }

    // Starting in contact must report t = 0 immediately.
    #[test]
    fn test_swept_sphere_starts_inside() {
        let start = Vec3::new(5.0, 0.0, 0.0);
        let end = Vec3::new(10.0, 0.0, 0.0);
        let radius = 0.5;
        let aabb = AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0));
        let t = swept_sphere_vs_aabb(start, end, radius, &aabb).unwrap();
        assert!(approx(t, 0.0));
    }

    // The motivating CCD case: a discrete step would jump clean over the
    // wall, but the swept test must still register the contact.
    #[test]
    fn test_swept_sphere_tunneling_detection() {
        // Fast sphere that would tunnel through a thin wall
        let start = Vec3::new(-100.0, 0.0, 0.0);
        let end = Vec3::new(100.0, 0.0, 0.0);
        let radius = 0.1;
        // Thin wall at x=0
        let aabb = AABB::new(Vec3::new(-0.05, -10.0, -10.0), Vec3::new(0.05, 10.0, 10.0));
        let t = swept_sphere_vs_aabb(start, end, radius, &aabb);
        assert!(t.is_some(), "should detect tunneling through thin wall");
        let t = t.unwrap();
        assert!(t > 0.0 && t < 1.0);
    }

    // Zero-length sweep takes the degenerate branch of the function.
    #[test]
    fn test_swept_sphere_no_movement() {
        let start = Vec3::new(5.0, 0.0, 0.0);
        let end = Vec3::new(5.0, 0.0, 0.0);
        let radius = 0.5;
        let aabb = AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0));
        // Inside expanded AABB, so should return 0
        let t = swept_sphere_vs_aabb(start, end, radius, &aabb).unwrap();
        assert!(approx(t, 0.0));
    }

    // Contact would only occur past t = 1, so the sweep reports None.
    #[test]
    fn test_swept_sphere_beyond_range() {
        let start = Vec3::new(-10.0, 0.0, 0.0);
        let end = Vec3::new(-5.0, 0.0, 0.0);
        let radius = 0.5;
        let aabb = AABB::new(Vec3::new(4.0, -1.0, -1.0), Vec3::new(6.0, 1.0, 1.0));
        // AABB is at x=4..6, moving from -10 to -5 won't reach it
        assert!(swept_sphere_vs_aabb(start, end, radius, &aabb).is_none());
    }
}

View File

@@ -1,10 +1,37 @@
use voltex_math::{Vec3, AABB};
#[derive(Debug, Clone, Copy)]
/// A convex hull collider defined by a set of vertices.
#[derive(Debug, Clone)]
pub struct ConvexHull {
pub vertices: Vec<Vec3>,
}
impl ConvexHull {
pub fn new(vertices: Vec<Vec3>) -> Self {
ConvexHull { vertices }
}
/// Support function for GJK: returns the vertex farthest in the given direction.
pub fn support(&self, direction: Vec3) -> Vec3 {
let mut best = self.vertices[0];
let mut best_dot = best.dot(direction);
for &v in &self.vertices[1..] {
let d = v.dot(direction);
if d > best_dot {
best_dot = d;
best = v;
}
}
best
}
}
/// Collision shape for an entity; world-space placement is supplied
/// separately (the `aabb` method takes a `position`).
#[derive(Debug, Clone)]
pub enum Collider {
    /// Sphere of the given radius about the entity position.
    Sphere { radius: f32 },
    /// Axis-aligned box described by its half-extents.
    Box { half_extents: Vec3 },
    /// Capsule whose long axis is Y (its AABB pads Y by half_height + radius).
    Capsule { radius: f32, half_height: f32 },
    /// Arbitrary convex vertex set; see `ConvexHull::support` for GJK use.
    ConvexHull(ConvexHull),
}
impl Collider {
@@ -21,6 +48,16 @@ impl Collider {
let r = Vec3::new(*radius, *half_height + *radius, *radius);
AABB::new(position - r, position + r)
}
Collider::ConvexHull(hull) => {
let mut min = position + hull.vertices[0];
let mut max = min;
for &v in &hull.vertices[1..] {
let p = position + v;
min = Vec3::new(min.x.min(p.x), min.y.min(p.y), min.z.min(p.z));
max = Vec3::new(max.x.max(p.x), max.y.max(p.y), max.z.max(p.z));
}
AABB::new(min, max)
}
}
}
}
@@ -52,4 +89,45 @@ mod tests {
assert_eq!(aabb.min, Vec3::new(0.5, 0.5, 2.5));
assert_eq!(aabb.max, Vec3::new(1.5, 3.5, 3.5));
}
#[test]
fn test_convex_hull_support() {
    // Tetrahedron-like point set with a single apex at y = 1.
    let hull = ConvexHull::new(vec![
        Vec3::new(-1.0, -1.0, -1.0),
        Vec3::new(1.0, -1.0, -1.0),
        Vec3::new(0.0, 1.0, 0.0),
        Vec3::new(0.0, -1.0, 1.0),
    ]);
    // Support in +Y direction should return the top vertex
    let apex = hull.support(Vec3::new(0.0, 1.0, 0.0));
    assert!((apex.y - 1.0).abs() < 1e-6);
}
#[test]
fn test_convex_hull_support_negative() {
    let hull = ConvexHull::new(vec![
        Vec3::new(0.0, 0.0, 0.0),
        Vec3::new(1.0, 0.0, 0.0),
        Vec3::new(0.0, 1.0, 0.0),
    ]);
    // Probing toward -X must pick a vertex with the smallest x.
    let picked = hull.support(Vec3::new(-1.0, 0.0, 0.0));
    assert!((picked.x - 0.0).abs() < 1e-6); // origin is farthest in -X
}
#[test]
fn test_convex_hull_in_collider() {
    // Test that ConvexHull variant works with existing collision system
    let cube_corners = vec![
        Vec3::new(-1.0, -1.0, -1.0),
        Vec3::new(1.0, -1.0, -1.0),
        Vec3::new(1.0, 1.0, -1.0),
        Vec3::new(-1.0, 1.0, -1.0),
        Vec3::new(-1.0, -1.0, 1.0),
        Vec3::new(1.0, -1.0, 1.0),
        Vec3::new(1.0, 1.0, 1.0),
        Vec3::new(-1.0, 1.0, 1.0),
    ];
    let hull = ConvexHull::new(cube_corners);
    // Just verify construction works
    assert_eq!(hull.vertices.len(), 8);
}
}

View File

@@ -13,7 +13,7 @@ pub fn detect_collisions(world: &World) -> Vec<ContactPoint> {
let pairs_data: Vec<(Entity, Vec3, Collider)> = world
.query2::<Transform, Collider>()
.into_iter()
.map(|(e, t, c)| (e, t.position, *c))
.map(|(e, t, c)| (e, t.position, c.clone()))
.collect();
if pairs_data.len() < 2 {
@@ -34,7 +34,7 @@ pub fn detect_collisions(world: &World) -> Vec<ContactPoint> {
let mut contacts = Vec::new();
let lookup = |entity: Entity| -> Option<(Vec3, Collider)> {
pairs_data.iter().find(|(e, _, _)| *e == entity).map(|(_, p, c)| (*p, *c))
pairs_data.iter().find(|(e, _, _)| *e == entity).map(|(_, p, c)| (*p, c.clone()))
};
for (ea, eb) in broad_pairs {
@@ -55,8 +55,9 @@ pub fn detect_collisions(world: &World) -> Vec<ContactPoint> {
(Collider::Box { half_extents: ha }, Collider::Box { half_extents: hb }) => {
narrow::box_vs_box(pos_a, *ha, pos_b, *hb)
}
// Any combination involving Capsule uses GJK/EPA
(Collider::Capsule { .. }, _) | (_, Collider::Capsule { .. }) => {
// Any combination involving Capsule or ConvexHull uses GJK/EPA
(Collider::Capsule { .. }, _) | (_, Collider::Capsule { .. })
| (Collider::ConvexHull(_), _) | (_, Collider::ConvexHull(_)) => {
gjk::gjk_epa(&col_a, pos_a, &col_b, pos_b)
}
};

View File

@@ -30,6 +30,9 @@ fn support(collider: &Collider, position: Vec3, direction: Vec3) -> Vec3 {
}
base + direction * (*radius / len)
}
Collider::ConvexHull(hull) => {
position + hull.support(direction)
}
}
}

View File

@@ -1,28 +1,121 @@
use voltex_ecs::World;
use voltex_ecs::Transform;
use voltex_math::Vec3;
use crate::rigid_body::{RigidBody, PhysicsConfig};
use crate::collider::Collider;
use crate::rigid_body::{RigidBody, PhysicsConfig, SLEEP_VELOCITY_THRESHOLD, SLEEP_TIME_THRESHOLD};
/// Compute diagonal inertia tensor for a collider shape.
/// Returns Vec3 where each component is the moment of inertia about that axis.
pub fn inertia_tensor(collider: &Collider, mass: f32) -> Vec3 {
    match collider {
        // Solid sphere: I = (2/5) m r^2, identical about every axis.
        Collider::Sphere { radius } => {
            let uniform = (2.0 / 5.0) * mass * radius * radius;
            Vec3::new(uniform, uniform, uniform)
        }
        // Solid cuboid: I_axis = m/12 * (side_b^2 + side_c^2).
        Collider::Box { half_extents } => {
            let width = half_extents.x * 2.0;
            let height = half_extents.y * 2.0;
            let depth = half_extents.z * 2.0;
            let factor = mass / 12.0;
            Vec3::new(
                factor * (height * height + depth * depth),
                factor * (width * width + depth * depth),
                factor * (width * width + height * height),
            )
        }
        // Approximated as a solid cylinder whose height is the straight
        // section (2 * half_height); the lateral axes share one value.
        Collider::Capsule { radius, half_height } => {
            let r = *radius;
            let h = half_height * 2.0;
            let lateral = mass * (3.0 * r * r + h * h) / 12.0;
            let spin = mass * r * r / 2.0;
            Vec3::new(lateral, spin, lateral)
        }
        // Approximated via the axis-aligned bounding box of the vertices,
        // then treated like the Box case above.
        Collider::ConvexHull(hull) => {
            let seed = (hull.vertices[0], hull.vertices[0]);
            let (lo, hi) = hull.vertices[1..].iter().fold(seed, |(lo, hi), &v| {
                (
                    Vec3::new(lo.x.min(v.x), lo.y.min(v.y), lo.z.min(v.z)),
                    Vec3::new(hi.x.max(v.x), hi.y.max(v.y), hi.z.max(v.z)),
                )
            });
            let size = hi - lo;
            let factor = mass / 12.0;
            Vec3::new(
                factor * (size.y * size.y + size.z * size.z),
                factor * (size.x * size.x + size.z * size.z),
                factor * (size.x * size.x + size.y * size.y),
            )
        }
    }
}
/// Compute inverse inertia (component-wise 1/I). Returns zero for zero-mass or zero-inertia.
pub fn inv_inertia(inertia: Vec3) -> Vec3 {
    // Guard against division by (near-)zero: such axes get zero inverse,
    // which the solver treats as "infinite inertia".
    let invert = |component: f32| if component > 1e-10 { 1.0 / component } else { 0.0 };
    Vec3::new(invert(inertia.x), invert(inertia.y), invert(inertia.z))
}
pub fn integrate(world: &mut World, config: &PhysicsConfig) {
// 1. Collect
let updates: Vec<(voltex_ecs::Entity, Vec3, Vec3)> = world
// 1. Collect linear + angular updates
let updates: Vec<(voltex_ecs::Entity, Vec3, Vec3, Vec3)> = world
.query2::<Transform, RigidBody>()
.into_iter()
.filter(|(_, _, rb)| !rb.is_static())
.filter(|(_, _, rb)| !rb.is_static() && !rb.is_sleeping)
.map(|(entity, transform, rb)| {
let new_velocity = rb.velocity + config.gravity * rb.gravity_scale * config.fixed_dt;
let new_position = transform.position + new_velocity * config.fixed_dt;
(entity, new_velocity, new_position)
let new_rotation = transform.rotation + rb.angular_velocity * config.fixed_dt;
(entity, new_velocity, new_position, new_rotation)
})
.collect();
// 2. Apply
for (entity, new_velocity, new_position) in updates {
for (entity, new_velocity, new_position, new_rotation) in updates {
if let Some(rb) = world.get_mut::<RigidBody>(entity) {
rb.velocity = new_velocity;
}
if let Some(t) = world.get_mut::<Transform>(entity) {
t.position = new_position;
t.rotation = new_rotation;
}
}
// 3. Update sleep timers
update_sleep_timers(world, config);
}
/// Accumulate per-body rest time and put slow bodies to sleep.
///
/// A body falls asleep once its combined linear + angular speed stays
/// below SLEEP_VELOCITY_THRESHOLD for SLEEP_TIME_THRESHOLD seconds; any
/// faster motion resets the timer. Falling asleep zeroes both velocities.
fn update_sleep_timers(world: &mut World, config: &PhysicsConfig) {
    // Pass 1 (immutable): decide each dynamic body's new timer value and
    // whether it crosses the sleep threshold this step.
    let decisions: Vec<(voltex_ecs::Entity, bool, f32)> = world
        .query::<RigidBody>()
        .filter(|(_, rb)| !rb.is_static())
        .map(|(entity, rb)| {
            let speed = rb.velocity.length() + rb.angular_velocity.length();
            if speed < SLEEP_VELOCITY_THRESHOLD {
                let timer = rb.sleep_timer + config.fixed_dt;
                (entity, timer >= SLEEP_TIME_THRESHOLD, timer)
            } else {
                // Still moving: restart the countdown.
                (entity, false, 0.0)
            }
        })
        .collect();
    // Pass 2 (mutable): write the decisions back.
    for (entity, fall_asleep, timer) in decisions {
        if let Some(rb) = world.get_mut::<RigidBody>(entity) {
            rb.sleep_timer = timer;
            if fall_asleep {
                rb.is_sleeping = true;
                rb.velocity = Vec3::ZERO;
                rb.angular_velocity = Vec3::ZERO;
            }
        }
    }
}
@@ -33,7 +126,7 @@ mod tests {
use voltex_ecs::World;
use voltex_ecs::Transform;
use voltex_math::Vec3;
use crate::RigidBody;
use crate::{RigidBody, Collider};
fn approx(a: f32, b: f32) -> bool {
(a - b).abs() < 1e-4
@@ -93,4 +186,146 @@ mod tests {
let expected_x = 5.0 * config.fixed_dt;
assert!(approx(t.position.x, expected_x));
}
// --- Inertia tensor tests ---
#[test]
fn test_inertia_tensor_sphere() {
    // Unit-mass, unit-radius solid sphere: I = 2/5 about every axis.
    let i = inertia_tensor(&Collider::Sphere { radius: 1.0 }, 1.0);
    for component in [i.x, i.y, i.z] {
        assert!(approx(component, 2.0 / 5.0));
    }
}
#[test]
fn test_inertia_tensor_box() {
    // Unit half-extents with mass 12 => factor 1; each axis 2^2 + 2^2 = 8.
    let i = inertia_tensor(&Collider::Box { half_extents: Vec3::ONE }, 12.0);
    assert!(approx(i.x, 8.0));
    assert!(approx(i.y, 8.0));
    assert!(approx(i.z, 8.0));
}
#[test]
fn test_inertia_tensor_capsule() {
    let i = inertia_tensor(&Collider::Capsule { radius: 0.5, half_height: 1.0 }, 1.0);
    // Cylinder approximation with r = 0.5, h = 2.0:
    // lateral axes: m*(3r^2 + h^2)/12 = (0.75 + 4.0)/12
    assert!(approx(i.x, 4.75 / 12.0));
    // spin axis: m*r^2/2 = 0.125
    assert!(approx(i.y, 0.125));
}
// --- Angular velocity integration tests ---
// A torque-free spin should advance rotation.y by w*dt and leave the
// position untouched (no linear velocity, gravity scaled to zero).
#[test]
fn test_spinning_sphere() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::ZERO));
    let mut rb = RigidBody::dynamic(1.0);
    rb.angular_velocity = Vec3::new(0.0, 3.14159, 0.0); // ~PI rad/s around Y
    rb.gravity_scale = 0.0;
    world.add(e, rb);
    let config = PhysicsConfig::default();
    integrate(&mut world, &config);
    let t = world.get::<Transform>(e).unwrap();
    let expected_rot_y = 3.14159 * config.fixed_dt;
    assert!(approx(t.rotation.y, expected_rot_y));
    // Position should not change (no linear velocity, no gravity)
    assert!(approx(t.position.x, 0.0));
    assert!(approx(t.position.y, 0.0));
}
// Angular velocity is not damped by the integrator: two steps must
// accumulate exactly 2*w*dt of rotation.
#[test]
fn test_angular_velocity_persists() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::ZERO));
    let mut rb = RigidBody::dynamic(1.0);
    rb.angular_velocity = Vec3::new(1.0, 0.0, 0.0);
    rb.gravity_scale = 0.0;
    world.add(e, rb);
    let config = PhysicsConfig::default();
    integrate(&mut world, &config);
    integrate(&mut world, &config);
    let t = world.get::<Transform>(e).unwrap();
    let expected_rot_x = 2.0 * config.fixed_dt;
    assert!(approx(t.rotation.x, expected_rot_x));
}
// --- Sleep system tests ---
// A motionless dynamic body must cross the sleep threshold after enough
// fixed steps (60 steps at 1/60 s = one simulated second of rest).
#[test]
fn test_body_sleeps_after_resting() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::ZERO));
    let mut rb = RigidBody::dynamic(1.0);
    rb.velocity = Vec3::ZERO;
    rb.gravity_scale = 0.0;
    world.add(e, rb);
    let config = PhysicsConfig {
        gravity: Vec3::ZERO,
        fixed_dt: 1.0 / 60.0,
        solver_iterations: 4,
    };
    // Integrate many times until sleep timer exceeds threshold
    for _ in 0..60 {
        integrate(&mut world, &config);
    }
    let rb = world.get::<RigidBody>(e).unwrap();
    assert!(rb.is_sleeping, "body should be sleeping after resting");
}
// A body flagged as sleeping is skipped by the integrator and must not
// fall under the default gravity.
#[test]
fn test_sleeping_body_not_integrated() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::new(0.0, 10.0, 0.0)));
    let mut rb = RigidBody::dynamic(1.0);
    rb.is_sleeping = true;
    world.add(e, rb);
    let config = PhysicsConfig::default();
    integrate(&mut world, &config);
    let t = world.get::<Transform>(e).unwrap();
    assert!(approx(t.position.y, 10.0), "sleeping body should not move");
}
// A body moving above the velocity threshold must keep resetting its
// sleep timer and stay awake regardless of how many steps elapse.
#[test]
fn test_moving_body_does_not_sleep() {
    let mut world = World::new();
    let e = world.spawn();
    world.add(e, Transform::from_position(Vec3::ZERO));
    let mut rb = RigidBody::dynamic(1.0);
    rb.velocity = Vec3::new(5.0, 0.0, 0.0);
    rb.gravity_scale = 0.0;
    world.add(e, rb);
    let config = PhysicsConfig {
        gravity: Vec3::ZERO,
        fixed_dt: 1.0 / 60.0,
        solver_iterations: 4,
    };
    for _ in 0..60 {
        integrate(&mut world, &config);
    }
    let rb = world.get::<RigidBody>(e).unwrap();
    assert!(!rb.is_sleeping, "fast-moving body should not sleep");
}
}

View File

@@ -9,12 +9,17 @@ pub mod rigid_body;
pub mod integrator;
pub mod solver;
pub mod raycast;
pub mod ccd;
pub mod mesh_collider;
pub use bvh::BvhTree;
pub use collider::Collider;
pub use collider::{Collider, ConvexHull};
pub use contact::ContactPoint;
pub use collision::detect_collisions;
pub use rigid_body::{RigidBody, PhysicsConfig};
pub use integrator::integrate;
pub use integrator::{integrate, inertia_tensor, inv_inertia};
pub use solver::{resolve_collisions, physics_step};
pub use raycast::{RayHit, raycast};
pub use raycast::{RayHit, raycast, raycast_all};
pub use ray::ray_vs_triangle;
pub use ccd::swept_sphere_vs_aabb;
pub use mesh_collider::{MeshCollider, MeshHit, ray_vs_mesh, ray_vs_mesh_all};

View File

@@ -0,0 +1,143 @@
use voltex_math::{Vec3, Ray};
use crate::ray::ray_vs_triangle;
/// A triangle mesh collider defined by vertices and triangle indices.
#[derive(Debug, Clone)]
pub struct MeshCollider {
    /// Shared vertex pool referenced by `indices`.
    pub vertices: Vec<Vec3>,
    /// Each triple indexes into `vertices`; the edge winding determines
    /// the face normal reported by the ray tests.
    pub indices: Vec<[u32; 3]>,
}
/// Result of a ray-mesh intersection test.
#[derive(Debug, Clone, Copy)]
pub struct MeshHit {
    /// Distance along the ray to the intersection.
    pub distance: f32,
    /// World-space intersection point (`ray.at(distance)`).
    pub point: Vec3,
    /// Face normal of the hit triangle (edge cross product, normalized).
    pub normal: Vec3,
    /// Index into `MeshCollider::indices` of the triangle that was hit.
    pub triangle_index: usize,
}
/// Cast a ray against a triangle mesh. Returns the closest hit, if any.
pub fn ray_vs_mesh(ray: &Ray, mesh: &MeshCollider) -> Option<MeshHit> {
    let mut best: Option<MeshHit> = None;
    for (tri_index, idx) in mesh.indices.iter().enumerate() {
        let a = mesh.vertices[idx[0] as usize];
        let b = mesh.vertices[idx[1] as usize];
        let c = mesh.vertices[idx[2] as usize];
        if let Some((t, normal)) = ray_vs_triangle(ray, a, b, c) {
            // Keep only the nearest intersection seen so far.
            if best.as_ref().map_or(true, |hit| t < hit.distance) {
                best = Some(MeshHit {
                    distance: t,
                    point: ray.at(t),
                    normal,
                    triangle_index: tri_index,
                });
            }
        }
    }
    best
}
/// Cast a ray against a triangle mesh. Returns all hits sorted by distance.
///
/// Unlike `ray_vs_mesh`, every pierced triangle produces a `MeshHit`;
/// back faces are included because `ray_vs_triangle` does not cull them.
pub fn ray_vs_mesh_all(ray: &Ray, mesh: &MeshCollider) -> Vec<MeshHit> {
    let mut hits = Vec::new();
    for (i, tri) in mesh.indices.iter().enumerate() {
        let v0 = mesh.vertices[tri[0] as usize];
        let v1 = mesh.vertices[tri[1] as usize];
        let v2 = mesh.vertices[tri[2] as usize];
        if let Some((t, normal)) = ray_vs_triangle(ray, v0, v1, v2) {
            hits.push(MeshHit {
                distance: t,
                point: ray.at(t),
                normal,
                triangle_index: i,
            });
        }
    }
    // total_cmp imposes a total order on f32, so this cannot panic the way
    // partial_cmp().unwrap() would if a distance were ever NaN.
    hits.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    hits
}
#[cfg(test)]
mod tests {
    use super::*;

    // Vertical ray from one unit above must land on the quad at t = 1.
    #[test]
    fn test_ray_vs_mesh_hit() {
        // Simple quad (2 triangles)
        let mesh = MeshCollider {
            vertices: vec![
                Vec3::new(-1.0, 0.0, -1.0),
                Vec3::new(1.0, 0.0, -1.0),
                Vec3::new(1.0, 0.0, 1.0),
                Vec3::new(-1.0, 0.0, 1.0),
            ],
            indices: vec![[0, 1, 2], [0, 2, 3]],
        };
        let ray = Ray::new(Vec3::new(0.0, 1.0, 0.0), Vec3::new(0.0, -1.0, 0.0));
        let hit = ray_vs_mesh(&ray, &mesh);
        assert!(hit.is_some());
        assert!((hit.unwrap().distance - 1.0).abs() < 0.01);
    }

    // Ray offset far in +X never crosses the lone triangle.
    #[test]
    fn test_ray_vs_mesh_miss() {
        let mesh = MeshCollider {
            vertices: vec![
                Vec3::new(-1.0, 0.0, -1.0),
                Vec3::new(1.0, 0.0, -1.0),
                Vec3::new(0.0, 0.0, 1.0),
            ],
            indices: vec![[0, 1, 2]],
        };
        let ray = Ray::new(Vec3::new(5.0, 1.0, 0.0), Vec3::new(0.0, -1.0, 0.0));
        assert!(ray_vs_mesh(&ray, &mesh).is_none());
    }

    // With stacked triangles, ray_vs_mesh must report the nearer one.
    #[test]
    fn test_ray_vs_mesh_closest() {
        // Two triangles at different heights
        let mesh = MeshCollider {
            vertices: vec![
                Vec3::new(-1.0, 0.0, -1.0),
                Vec3::new(1.0, 0.0, -1.0),
                Vec3::new(0.0, 0.0, 1.0),
                Vec3::new(-1.0, 2.0, -1.0),
                Vec3::new(1.0, 2.0, -1.0),
                Vec3::new(0.0, 2.0, 1.0),
            ],
            indices: vec![[0, 1, 2], [3, 4, 5]],
        };
        let ray = Ray::new(Vec3::new(0.0, 5.0, 0.0), Vec3::new(0.0, -1.0, 0.0));
        let hit = ray_vs_mesh(&ray, &mesh).unwrap();
        assert!((hit.distance - 3.0).abs() < 0.01); // hits y=2 first
    }

    // ray_vs_mesh_all must report every pierced triangle, nearest first.
    #[test]
    fn test_ray_vs_mesh_all() {
        let mesh = MeshCollider {
            vertices: vec![
                Vec3::new(-1.0, 0.0, -1.0),
                Vec3::new(1.0, 0.0, -1.0),
                Vec3::new(0.0, 0.0, 1.0),
                Vec3::new(-1.0, 2.0, -1.0),
                Vec3::new(1.0, 2.0, -1.0),
                Vec3::new(0.0, 2.0, 1.0),
            ],
            indices: vec![[0, 1, 2], [3, 4, 5]],
        };
        let ray = Ray::new(Vec3::new(0.0, 5.0, 0.0), Vec3::new(0.0, -1.0, 0.0));
        let hits = ray_vs_mesh_all(&ray, &mesh);
        assert_eq!(hits.len(), 2);
        assert!(hits[0].distance < hits[1].distance); // sorted
    }
}

View File

@@ -183,6 +183,43 @@ pub fn ray_vs_capsule(ray: &Ray, center: Vec3, radius: f32, half_height: f32) ->
best
}
/// Ray vs Triangle (Möller–Trumbore algorithm).
/// Returns (t, normal) where normal is the triangle face normal.
pub fn ray_vs_triangle(ray: &Ray, v0: Vec3, v1: Vec3, v2: Vec3) -> Option<(f32, Vec3)> {
    let e1 = v1 - v0;
    let e2 = v2 - v0;
    // Determinant of the barycentric system; near-zero means the ray is
    // parallel to the triangle plane (no culling of back faces).
    let p = ray.direction.cross(e2);
    let det = e1.dot(p);
    if det.abs() < 1e-8 {
        return None; // Ray is parallel to triangle
    }
    let inv_det = 1.0 / det;
    let to_origin = ray.origin - v0;
    // First barycentric coordinate.
    let u = inv_det * to_origin.dot(p);
    if u < 0.0 || u > 1.0 {
        return None;
    }
    // Second barycentric coordinate; u + v must stay inside the triangle.
    let q = to_origin.cross(e1);
    let v = inv_det * ray.direction.dot(q);
    if v < 0.0 || u + v > 1.0 {
        return None;
    }
    // Distance along the ray; negative means the triangle is behind it.
    let t = inv_det * e2.dot(q);
    if t < 0.0 {
        return None;
    }
    Some((t, e1.cross(e2).normalize()))
}
#[cfg(test)]
mod tests {
use super::*;
@@ -265,4 +302,61 @@ mod tests {
let (t, _normal) = ray_vs_box(&ray, Vec3::ZERO, Vec3::ONE).unwrap();
assert!(approx(t, 0.0));
}
// --- ray_vs_triangle tests ---
#[test]
fn test_triangle_hit() {
    // Triangle in the z = 5 plane, straddling the Z axis.
    let a = Vec3::new(-1.0, -1.0, 5.0);
    let b = Vec3::new(1.0, -1.0, 5.0);
    let c = Vec3::new(0.0, 1.0, 5.0);
    let (t, normal) = ray_vs_triangle(&Ray::new(Vec3::ZERO, Vec3::Z), a, b, c).unwrap();
    assert!(approx(t, 5.0));
    // Normal must be (anti-)parallel to Z, i.e. facing along the ray axis.
    assert!(normal.z.abs() > 0.9);
}
#[test]
fn test_triangle_miss() {
    let a = Vec3::new(-1.0, -1.0, 5.0);
    let b = Vec3::new(1.0, -1.0, 5.0);
    let c = Vec3::new(0.0, 1.0, 5.0);
    // Parallel ray offset far outside the triangle's extent.
    let wide = Ray::new(Vec3::new(10.0, 10.0, 0.0), Vec3::Z);
    assert!(ray_vs_triangle(&wide, a, b, c).is_none());
}
#[test]
fn test_triangle_behind_ray() {
    // Triangle sits at z = -5 while the ray marches toward +Z.
    let a = Vec3::new(-1.0, -1.0, -5.0);
    let b = Vec3::new(1.0, -1.0, -5.0);
    let c = Vec3::new(0.0, 1.0, -5.0);
    assert!(ray_vs_triangle(&Ray::new(Vec3::ZERO, Vec3::Z), a, b, c).is_none());
}
#[test]
fn test_triangle_parallel() {
    // Triangle lies in the XZ plane; so does the ray's direction.
    let a = Vec3::new(0.0, 0.0, 5.0);
    let b = Vec3::new(1.0, 0.0, 5.0);
    let c = Vec3::new(0.0, 0.0, 6.0);
    let skimming = Ray::new(Vec3::ZERO, Vec3::X);
    assert!(ray_vs_triangle(&skimming, a, b, c).is_none());
}
#[test]
fn test_triangle_edge_hit() {
    // The ray pierces the exact midpoint of the v0-v1 edge; boundary
    // barycentrics must still count as a hit.
    let a = Vec3::new(-1.0, 0.0, 5.0);
    let b = Vec3::new(1.0, 0.0, 5.0);
    let c = Vec3::new(0.0, 2.0, 5.0);
    let hit = ray_vs_triangle(&Ray::new(Vec3::new(0.0, 0.0, 0.0), Vec3::Z), a, b, c);
    assert!(hit.is_some());
    let (t, _) = hit.unwrap();
    assert!(approx(t, 5.0));
}
}

View File

@@ -17,7 +17,7 @@ pub fn raycast(world: &World, ray: &Ray, max_dist: f32) -> Option<RayHit> {
let entities: Vec<(Entity, Vec3, Collider)> = world
.query2::<Transform, Collider>()
.into_iter()
.map(|(e, t, c)| (e, t.position, *c))
.map(|(e, t, c)| (e, t.position, c.clone()))
.collect();
if entities.is_empty() {
@@ -53,6 +53,11 @@ pub fn raycast(world: &World, ray: &Ray, max_dist: f32) -> Option<RayHit> {
Collider::Capsule { radius, half_height } => {
ray_tests::ray_vs_capsule(ray, *pos, *radius, *half_height)
}
Collider::ConvexHull(_) => {
// Use AABB test as approximation for convex hull raycasting
let aabb = collider.aabb(*pos);
ray_tests::ray_vs_aabb(ray, &aabb).map(|t| (t, Vec3::Y))
}
};
if let Some((t, normal)) = result {
@@ -72,6 +77,63 @@ pub fn raycast(world: &World, ray: &Ray, max_dist: f32) -> Option<RayHit> {
closest
}
/// Cast a ray and return ALL hits sorted by distance.
///
/// Performs a two-phase test per entity: a cheap ray-vs-AABB broad phase
/// rejects colliders whose bounds the ray misses (or only reaches beyond
/// `max_dist`), then a shape-specific narrow phase computes the exact hit
/// distance `t` and surface normal.
///
/// Note: `ConvexHull` colliders fall back to their AABB, so the reported
/// normal (`Vec3::Y`) is a placeholder rather than the true surface normal.
pub fn raycast_all(world: &World, ray: &Ray, max_dist: f32) -> Vec<RayHit> {
    // Snapshot (entity, position, collider) so we don't hold a borrow on
    // the world while running the intersection tests.
    let entities: Vec<(Entity, Vec3, Collider)> = world
        .query2::<Transform, Collider>()
        .into_iter()
        .map(|(e, t, c)| (e, t.position, c.clone()))
        .collect();
    if entities.is_empty() {
        return Vec::new();
    }
    let mut hits = Vec::new();
    for (entity, pos, collider) in &entities {
        let aabb = collider.aabb(*pos);
        // Broad phase: ray vs AABB
        match ray_tests::ray_vs_aabb(ray, &aabb) {
            Some(t) if t <= max_dist => {}
            _ => continue,
        };
        // Narrow phase: exact shape test
        let result = match collider {
            Collider::Sphere { radius } => {
                ray_tests::ray_vs_sphere(ray, *pos, *radius)
            }
            Collider::Box { half_extents } => {
                ray_tests::ray_vs_box(ray, *pos, *half_extents)
            }
            Collider::Capsule { radius, half_height } => {
                ray_tests::ray_vs_capsule(ray, *pos, *radius, *half_height)
            }
            Collider::ConvexHull(_) => {
                // Use AABB test as approximation for convex hull raycasting
                let aabb = collider.aabb(*pos);
                ray_tests::ray_vs_aabb(ray, &aabb).map(|t| (t, Vec3::Y))
            }
        };
        if let Some((t, normal)) = result {
            if t <= max_dist {
                hits.push(RayHit {
                    entity: *entity,
                    t,
                    point: ray.at(t),
                    normal,
                });
            }
        }
    }
    // Sort nearest-first. `t` should never be NaN, but treat a degenerate
    // comparison as Equal instead of panicking the whole query.
    hits.sort_by(|a, b| a.t.partial_cmp(&b.t).unwrap_or(std::cmp::Ordering::Equal));
    hits
}
#[cfg(test)]
mod tests {
use super::*;
@@ -150,6 +212,62 @@ mod tests {
assert!(approx(hit.t, 4.0));
}
// --- raycast_all tests ---
#[test]
fn test_raycast_all_multiple_hits() {
    // Three spheres along +X at x = 3, 6 and 10; a ray down the axis
    // should report all of them, ordered nearest-first.
    let mut world = World::new();
    let mut spawned = Vec::new();
    for x in [3.0_f32, 6.0, 10.0] {
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(x, 0.0, 0.0)));
        world.add(e, Collider::Sphere { radius: 0.5 });
        spawned.push(e);
    }
    let ray = Ray::new(Vec3::ZERO, Vec3::X);
    let hits = raycast_all(&world, &ray, 100.0);
    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].entity, spawned[0]);
    assert_eq!(hits[1].entity, spawned[1]);
    assert_eq!(hits[2].entity, spawned[2]);
    assert!(hits[0].t < hits[1].t);
    assert!(hits[1].t < hits[2].t);
}
#[test]
fn test_raycast_all_empty() {
    // A world with no colliders at all yields no hits.
    let world = World::new();
    let hits = raycast_all(&world, &Ray::new(Vec3::ZERO, Vec3::X), 100.0);
    assert!(hits.is_empty());
}
#[test]
fn test_raycast_all_max_dist() {
    // Only the sphere within max_dist should be reported; the one at
    // x = 50 is beyond the 10-unit limit.
    let mut world = World::new();
    let close = world.spawn();
    world.add(close, Transform::from_position(Vec3::new(3.0, 0.0, 0.0)));
    world.add(close, Collider::Sphere { radius: 0.5 });
    let distant = world.spawn();
    world.add(distant, Transform::from_position(Vec3::new(50.0, 0.0, 0.0)));
    world.add(distant, Collider::Sphere { radius: 0.5 });
    let hits = raycast_all(&world, &Ray::new(Vec3::ZERO, Vec3::X), 10.0);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].entity, close);
}
#[test]
fn test_mixed_sphere_box() {
let mut world = World::new();

View File

@@ -1,5 +1,8 @@
use voltex_math::Vec3;
pub const SLEEP_VELOCITY_THRESHOLD: f32 = 0.01;
pub const SLEEP_TIME_THRESHOLD: f32 = 0.5;
#[derive(Debug, Clone, Copy)]
pub struct RigidBody {
pub velocity: Vec3,
@@ -8,6 +11,8 @@ pub struct RigidBody {
pub restitution: f32,
pub gravity_scale: f32,
pub friction: f32, // Coulomb friction coefficient, default 0.5
pub is_sleeping: bool,
pub sleep_timer: f32,
}
impl RigidBody {
@@ -19,6 +24,8 @@ impl RigidBody {
restitution: 0.3,
gravity_scale: 1.0,
friction: 0.5,
is_sleeping: false,
sleep_timer: 0.0,
}
}
@@ -30,6 +37,8 @@ impl RigidBody {
restitution: 0.3,
gravity_scale: 0.0,
friction: 0.5,
is_sleeping: false,
sleep_timer: 0.0,
}
}
@@ -40,11 +49,18 @@ impl RigidBody {
pub fn is_static(&self) -> bool {
self.mass == 0.0
}
/// Wake this body from sleep, resetting its sleep timer so it must
/// re-accumulate idle time before it can sleep again.
pub fn wake(&mut self) {
    self.sleep_timer = 0.0;
    self.is_sleeping = false;
}
}
/// Global tuning parameters for one physics simulation step.
pub struct PhysicsConfig {
    pub gravity: Vec3,          // world-space gravity acceleration (default (0, -9.81, 0))
    pub fixed_dt: f32,          // fixed timestep in seconds (default 1/60)
    pub solver_iterations: u32, // contact-solver iterations per step (default 4)
}
impl Default for PhysicsConfig {
@@ -52,6 +68,7 @@ impl Default for PhysicsConfig {
Self {
gravity: Vec3::new(0.0, -9.81, 0.0),
fixed_dt: 1.0 / 60.0,
solver_iterations: 4,
}
}
}
@@ -69,6 +86,8 @@ mod tests {
assert_eq!(rb.velocity, Vec3::ZERO);
assert_eq!(rb.restitution, 0.3);
assert_eq!(rb.gravity_scale, 1.0);
assert!(!rb.is_sleeping);
assert_eq!(rb.sleep_timer, 0.0);
}
#[test]
@@ -85,5 +104,16 @@ mod tests {
let cfg = PhysicsConfig::default();
assert!((cfg.gravity.y - (-9.81)).abs() < 1e-6);
assert!((cfg.fixed_dt - 1.0 / 60.0).abs() < 1e-6);
assert_eq!(cfg.solver_iterations, 4);
}
#[test]
fn test_wake() {
    // Force a dynamic body into the sleeping state, then verify wake()
    // clears both the flag and the accumulated timer.
    let mut body = RigidBody::dynamic(1.0);
    body.sleep_timer = 1.0;
    body.is_sleeping = true;
    body.wake();
    assert_eq!(body.sleep_timer, 0.0);
    assert!(!body.is_sleeping);
}
}

View File

@@ -2,102 +2,275 @@ use voltex_ecs::{World, Entity};
use voltex_ecs::Transform;
use voltex_math::Vec3;
use crate::collider::Collider;
use crate::contact::ContactPoint;
use crate::rigid_body::{RigidBody, PhysicsConfig};
use crate::collision::detect_collisions;
use crate::integrator::integrate;
use crate::integrator::{integrate, inertia_tensor, inv_inertia};
use crate::ccd;
const POSITION_SLOP: f32 = 0.01;
const POSITION_PERCENT: f32 = 0.4;
pub fn resolve_collisions(world: &mut World, contacts: &[ContactPoint]) {
let mut velocity_changes: Vec<(Entity, Vec3)> = Vec::new();
let mut position_changes: Vec<(Entity, Vec3)> = Vec::new();
pub fn resolve_collisions(world: &mut World, contacts: &[ContactPoint], iterations: u32) {
// Wake sleeping bodies that are in contact
wake_colliding_bodies(world, contacts);
for contact in contacts {
let rb_a = world.get::<RigidBody>(contact.entity_a).copied();
let rb_b = world.get::<RigidBody>(contact.entity_b).copied();
for _iter in 0..iterations {
let mut velocity_changes: Vec<(Entity, Vec3, Vec3)> = Vec::new(); // (entity, dv_linear, dv_angular)
let mut position_changes: Vec<(Entity, Vec3)> = Vec::new();
let (rb_a, rb_b) = match (rb_a, rb_b) {
(Some(a), Some(b)) => (a, b),
_ => continue,
};
for contact in contacts {
let rb_a = world.get::<RigidBody>(contact.entity_a).copied();
let rb_b = world.get::<RigidBody>(contact.entity_b).copied();
let col_a = world.get::<Collider>(contact.entity_a).cloned();
let col_b = world.get::<Collider>(contact.entity_b).cloned();
let pos_a = world.get::<Transform>(contact.entity_a).map(|t| t.position);
let pos_b = world.get::<Transform>(contact.entity_b).map(|t| t.position);
let inv_mass_a = rb_a.inv_mass();
let inv_mass_b = rb_b.inv_mass();
let inv_mass_sum = inv_mass_a + inv_mass_b;
if inv_mass_sum == 0.0 {
continue;
}
let v_rel = rb_a.velocity - rb_b.velocity;
let v_rel_n = v_rel.dot(contact.normal);
// normal points A→B; v_rel_n > 0 means A approaches B → apply impulse
let j = if v_rel_n > 0.0 {
let e = rb_a.restitution.min(rb_b.restitution);
let j = (1.0 + e) * v_rel_n / inv_mass_sum;
velocity_changes.push((contact.entity_a, contact.normal * (-j * inv_mass_a)));
velocity_changes.push((contact.entity_b, contact.normal * (j * inv_mass_b)));
j
} else {
// No separating impulse needed, but use contact depth to derive a
// representative normal force magnitude for friction clamping.
// A simple proxy: treat the penetration as providing a static normal force.
contact.depth / inv_mass_sum
};
// Coulomb friction: tangential impulse clamped to mu * normal impulse
let v_rel_n_scalar = v_rel.dot(contact.normal);
let v_rel_tangent = v_rel - contact.normal * v_rel_n_scalar;
let tangent_len = v_rel_tangent.length();
if tangent_len > 1e-6 {
let tangent = v_rel_tangent * (1.0 / tangent_len);
// Friction coefficient: average of both bodies
let mu = (rb_a.friction + rb_b.friction) * 0.5;
// Coulomb's law: friction impulse <= mu * normal impulse
let jt = -v_rel_tangent.dot(tangent) / inv_mass_sum;
let friction_j = if jt.abs() <= j * mu {
jt // static friction
} else {
j * mu * jt.signum() // dynamic friction (sliding), clamped magnitude
let (rb_a, rb_b) = match (rb_a, rb_b) {
(Some(a), Some(b)) => (a, b),
_ => continue,
};
velocity_changes.push((contact.entity_a, tangent * (friction_j * inv_mass_a)));
velocity_changes.push((contact.entity_b, tangent * (-friction_j * inv_mass_b)));
let inv_mass_a = rb_a.inv_mass();
let inv_mass_b = rb_b.inv_mass();
let inv_mass_sum = inv_mass_a + inv_mass_b;
if inv_mass_sum == 0.0 {
continue;
}
// Compute lever arms for angular impulse
let center_a = pos_a.unwrap_or(Vec3::ZERO);
let center_b = pos_b.unwrap_or(Vec3::ZERO);
let r_a = contact.point_on_a - center_a;
let r_b = contact.point_on_b - center_b;
// Compute inverse inertia
let inv_i_a = col_a.map(|c| inv_inertia(inertia_tensor(&c, rb_a.mass)))
.unwrap_or(Vec3::ZERO);
let inv_i_b = col_b.map(|c| inv_inertia(inertia_tensor(&c, rb_b.mass)))
.unwrap_or(Vec3::ZERO);
// Relative velocity at contact point (including angular contribution)
let v_a = rb_a.velocity + rb_a.angular_velocity.cross(r_a);
let v_b = rb_b.velocity + rb_b.angular_velocity.cross(r_b);
let v_rel = v_a - v_b;
let v_rel_n = v_rel.dot(contact.normal);
// Effective mass including rotational terms
let r_a_cross_n = r_a.cross(contact.normal);
let r_b_cross_n = r_b.cross(contact.normal);
let angular_term_a = Vec3::new(
r_a_cross_n.x * inv_i_a.x,
r_a_cross_n.y * inv_i_a.y,
r_a_cross_n.z * inv_i_a.z,
).cross(r_a).dot(contact.normal);
let angular_term_b = Vec3::new(
r_b_cross_n.x * inv_i_b.x,
r_b_cross_n.y * inv_i_b.y,
r_b_cross_n.z * inv_i_b.z,
).cross(r_b).dot(contact.normal);
let effective_mass = inv_mass_sum + angular_term_a + angular_term_b;
// normal points A→B; v_rel_n > 0 means A approaches B → apply impulse
let j = if v_rel_n > 0.0 {
let e = rb_a.restitution.min(rb_b.restitution);
let j = (1.0 + e) * v_rel_n / effective_mass;
// Linear impulse
velocity_changes.push((contact.entity_a, contact.normal * (-j * inv_mass_a), Vec3::ZERO));
velocity_changes.push((contact.entity_b, contact.normal * (j * inv_mass_b), Vec3::ZERO));
// Angular impulse: torque = r × impulse
let angular_impulse_a = r_a.cross(contact.normal * (-j));
let angular_impulse_b = r_b.cross(contact.normal * j);
let dw_a = Vec3::new(
angular_impulse_a.x * inv_i_a.x,
angular_impulse_a.y * inv_i_a.y,
angular_impulse_a.z * inv_i_a.z,
);
let dw_b = Vec3::new(
angular_impulse_b.x * inv_i_b.x,
angular_impulse_b.y * inv_i_b.y,
angular_impulse_b.z * inv_i_b.z,
);
velocity_changes.push((contact.entity_a, Vec3::ZERO, dw_a));
velocity_changes.push((contact.entity_b, Vec3::ZERO, dw_b));
j
} else {
contact.depth / inv_mass_sum
};
// Coulomb friction: tangential impulse clamped to mu * normal impulse
let v_rel_tangent = v_rel - contact.normal * v_rel_n;
let tangent_len = v_rel_tangent.length();
if tangent_len > 1e-6 {
let tangent = v_rel_tangent * (1.0 / tangent_len);
let mu = (rb_a.friction + rb_b.friction) * 0.5;
let jt = -v_rel_tangent.dot(tangent) / effective_mass;
let friction_j = if jt.abs() <= j * mu {
jt
} else {
j * mu * jt.signum()
};
velocity_changes.push((contact.entity_a, tangent * (friction_j * inv_mass_a), Vec3::ZERO));
velocity_changes.push((contact.entity_b, tangent * (-friction_j * inv_mass_b), Vec3::ZERO));
// Angular friction impulse
let angular_fric_a = r_a.cross(tangent * friction_j);
let angular_fric_b = r_b.cross(tangent * (-friction_j));
let dw_fric_a = Vec3::new(
angular_fric_a.x * inv_i_a.x,
angular_fric_a.y * inv_i_a.y,
angular_fric_a.z * inv_i_a.z,
);
let dw_fric_b = Vec3::new(
angular_fric_b.x * inv_i_b.x,
angular_fric_b.y * inv_i_b.y,
angular_fric_b.z * inv_i_b.z,
);
velocity_changes.push((contact.entity_a, Vec3::ZERO, dw_fric_a));
velocity_changes.push((contact.entity_b, Vec3::ZERO, dw_fric_b));
}
// Position correction only on first iteration
if _iter == 0 {
let correction_mag = (contact.depth - POSITION_SLOP).max(0.0) * POSITION_PERCENT / inv_mass_sum;
if correction_mag > 0.0 {
let correction = contact.normal * correction_mag;
position_changes.push((contact.entity_a, correction * (-inv_mass_a)));
position_changes.push((contact.entity_b, correction * inv_mass_b));
}
}
}
let correction_mag = (contact.depth - POSITION_SLOP).max(0.0) * POSITION_PERCENT / inv_mass_sum;
if correction_mag > 0.0 {
let correction = contact.normal * correction_mag;
position_changes.push((contact.entity_a, correction * (-inv_mass_a)));
position_changes.push((contact.entity_b, correction * inv_mass_b));
// Apply velocity changes
for (entity, dv, dw) in velocity_changes {
if let Some(rb) = world.get_mut::<RigidBody>(entity) {
rb.velocity = rb.velocity + dv;
rb.angular_velocity = rb.angular_velocity + dw;
}
}
// Apply position corrections
for (entity, dp) in position_changes {
if let Some(t) = world.get_mut::<Transform>(entity) {
t.position = t.position + dp;
}
}
}
}
for (entity, dv) in velocity_changes {
fn wake_colliding_bodies(world: &mut World, contacts: &[ContactPoint]) {
let wake_list: Vec<Entity> = contacts
.iter()
.flat_map(|c| {
let mut entities = Vec::new();
if let Some(rb) = world.get::<RigidBody>(c.entity_a) {
if rb.is_sleeping { entities.push(c.entity_a); }
}
if let Some(rb) = world.get::<RigidBody>(c.entity_b) {
if rb.is_sleeping { entities.push(c.entity_b); }
}
entities
})
.collect();
for entity in wake_list {
if let Some(rb) = world.get_mut::<RigidBody>(entity) {
rb.velocity = rb.velocity + dv;
}
}
for (entity, dp) in position_changes {
if let Some(t) = world.get_mut::<Transform>(entity) {
t.position = t.position + dp;
rb.wake();
}
}
}
pub fn physics_step(world: &mut World, config: &PhysicsConfig) {
// CCD: for fast-moving bodies, check for tunneling
apply_ccd(world, config);
integrate(world, config);
let contacts = detect_collisions(world);
resolve_collisions(world, &contacts);
resolve_collisions(world, &contacts, config.solver_iterations);
}
/// Continuous collision detection (CCD) pre-pass.
///
/// For awake, non-static bodies whose displacement this step exceeds their
/// own collider radius (i.e. they could tunnel through thin geometry), a
/// sphere is swept along the motion path against every other collider's
/// AABB. On a predicted hit the body is placed just before the impact point
/// and its velocity is halved to discourage re-tunneling next step.
fn apply_ccd(world: &mut World, config: &PhysicsConfig) {
    // Gather fast-moving bodies and all collider AABBs
    // (snapshotted up-front so the world isn't borrowed during the sweep).
    let bodies: Vec<(Entity, Vec3, Vec3, Collider)> = world
        .query3::<Transform, RigidBody, Collider>()
        .into_iter()
        .filter(|(_, _, rb, _)| !rb.is_static() && !rb.is_sleeping)
        .map(|(e, t, rb, c)| (e, t.position, rb.velocity, c.clone()))
        .collect();
    let all_colliders: Vec<(Entity, voltex_math::AABB)> = world
        .query2::<Transform, Collider>()
        .into_iter()
        .map(|(e, t, c)| (e, c.aabb(t.position)))
        .collect();
    let mut ccd_corrections: Vec<(Entity, Vec3)> = Vec::new();
    for (entity, pos, vel, collider) in &bodies {
        let speed = vel.length();
        // Conservative per-shape radius: the smallest extent, so CCD
        // triggers whenever the step could skip the thinnest dimension.
        let collider_radius = match collider {
            Collider::Sphere { radius } => *radius,
            Collider::Box { half_extents } => half_extents.x.min(half_extents.y).min(half_extents.z),
            Collider::Capsule { radius, .. } => *radius,
            Collider::ConvexHull(hull) => {
                // Use minimum distance from origin to any vertex as approximate radius
                // NOTE(review): an empty hull leaves this at f32::MAX, which
                // silently disables CCD for that body — confirm hulls are
                // always non-empty upstream.
                hull.vertices.iter().map(|v| v.length()).fold(f32::MAX, f32::min)
            }
        };
        // Only apply CCD if displacement > collider radius
        if speed * config.fixed_dt <= collider_radius {
            continue;
        }
        // Non-sphere shapes are approximated by a sphere of that radius
        // for the sweep.
        let sweep_radius = match collider {
            Collider::Sphere { radius } => *radius,
            _ => collider_radius,
        };
        let end = *pos + *vel * config.fixed_dt;
        // Fraction of this step's motion at which the first hit occurs
        // (1.0 = no hit along the full path).
        let mut earliest_t = 1.0f32;
        for (other_entity, other_aabb) in &all_colliders {
            if *other_entity == *entity {
                continue; // don't sweep against our own collider
            }
            if let Some(t) = ccd::swept_sphere_vs_aabb(*pos, end, sweep_radius, other_aabb) {
                if t < earliest_t {
                    earliest_t = t;
                }
            }
        }
        if earliest_t < 1.0 {
            // Place body just before collision point
            // (back off 1% of the path, clamped so we never move backwards).
            let safe_t = (earliest_t - 0.01).max(0.0);
            let safe_pos = *pos + *vel * config.fixed_dt * safe_t;
            ccd_corrections.push((*entity, safe_pos));
        }
    }
    // Apply the deferred corrections now that no queries are in flight.
    for (entity, safe_pos) in ccd_corrections {
        if let Some(t) = world.get_mut::<Transform>(entity) {
            t.position = safe_pos;
        }
        if let Some(rb) = world.get_mut::<RigidBody>(entity) {
            // Reduce velocity to prevent re-tunneling
            rb.velocity = rb.velocity * 0.5;
        }
    }
}
#[cfg(test)]
@@ -138,7 +311,7 @@ mod tests {
let contacts = detect_collisions(&world);
assert_eq!(contacts.len(), 1);
resolve_collisions(&mut world, &contacts);
resolve_collisions(&mut world, &contacts, 1);
let va = world.get::<RigidBody>(a).unwrap().velocity;
let vb = world.get::<RigidBody>(b).unwrap().velocity;
@@ -168,7 +341,7 @@ mod tests {
let contacts = detect_collisions(&world);
assert_eq!(contacts.len(), 1);
resolve_collisions(&mut world, &contacts);
resolve_collisions(&mut world, &contacts, 1);
let ball_rb = world.get::<RigidBody>(ball).unwrap();
let floor_rb = world.get::<RigidBody>(floor).unwrap();
@@ -198,7 +371,7 @@ mod tests {
let contacts = detect_collisions(&world);
assert_eq!(contacts.len(), 1);
resolve_collisions(&mut world, &contacts);
resolve_collisions(&mut world, &contacts, 1);
let pa = world.get::<Transform>(a).unwrap().position;
let pb = world.get::<Transform>(b).unwrap().position;
@@ -234,14 +407,13 @@ mod tests {
#[test]
fn test_friction_slows_sliding() {
// Ball sliding on static floor with friction
let mut world = World::new();
let ball = world.spawn();
world.add(ball, Transform::from_position(Vec3::new(0.0, 0.4, 0.0)));
world.add(ball, Collider::Sphere { radius: 0.5 });
let mut rb = RigidBody::dynamic(1.0);
rb.velocity = Vec3::new(5.0, 0.0, 0.0); // sliding horizontally
rb.velocity = Vec3::new(5.0, 0.0, 0.0);
rb.gravity_scale = 0.0;
rb.friction = 0.5;
world.add(ball, rb);
@@ -253,14 +425,12 @@ mod tests {
floor_rb.friction = 0.5;
world.add(floor, floor_rb);
// Ball center at 0.4, radius 0.5, floor top at 0.0 → overlap 0.1
let contacts = detect_collisions(&world);
if !contacts.is_empty() {
resolve_collisions(&mut world, &contacts);
resolve_collisions(&mut world, &contacts, 1);
}
let ball_v = world.get::<RigidBody>(ball).unwrap().velocity;
// X velocity should be reduced by friction
assert!(ball_v.x < 5.0, "friction should slow horizontal velocity: {}", ball_v.x);
assert!(ball_v.x > 0.0, "should still be moving: {}", ball_v.x);
}
@@ -280,11 +450,125 @@ mod tests {
world.add(b, RigidBody::statik());
let contacts = detect_collisions(&world);
resolve_collisions(&mut world, &contacts);
resolve_collisions(&mut world, &contacts, 1);
let pa = world.get::<Transform>(a).unwrap().position;
let pb = world.get::<Transform>(b).unwrap().position;
assert!(approx(pa.x, 0.0));
assert!(approx(pb.x, 0.5));
}
// --- Angular impulse tests ---
#[test]
fn test_off_center_hit_produces_spin() {
    // Two overlapping unit spheres offset diagonally: A moves along +X
    // and strikes B off-center, so the contact point lies away from both
    // centers and the impulse should induce rotation on at least one body.
    let mut world = World::new();
    let a = world.spawn();
    world.add(a, Transform::from_position(Vec3::new(-0.5, 0.5, 0.0)));
    world.add(a, Collider::Sphere { radius: 1.0 });
    let mut body_a = RigidBody::dynamic(1.0);
    body_a.gravity_scale = 0.0;
    body_a.restitution = 0.5;
    body_a.velocity = Vec3::new(2.0, 0.0, 0.0);
    world.add(a, body_a);
    let b = world.spawn();
    world.add(b, Transform::from_position(Vec3::new(0.5, -0.5, 0.0)));
    world.add(b, Collider::Sphere { radius: 1.0 });
    let mut body_b = RigidBody::dynamic(1.0);
    body_b.restitution = 0.5;
    body_b.gravity_scale = 0.0;
    world.add(b, body_b);
    let contacts = detect_collisions(&world);
    assert!(!contacts.is_empty());
    resolve_collisions(&mut world, &contacts, 4);
    let spin_a = world.get::<RigidBody>(a).unwrap().angular_velocity.length();
    let spin_b = world.get::<RigidBody>(b).unwrap().angular_velocity.length();
    let total_angular = spin_a + spin_b;
    assert!(total_angular > 1e-4, "off-center hit should produce angular velocity, got {}", total_angular);
}
// --- Sequential impulse tests ---
#[test]
fn test_sequential_impulse_stability() {
    // Three boxes resting in a column on a static floor, gravity disabled.
    // With iterative impulse solving they should stay near their initial
    // positions over several simulation steps.
    let mut world = World::new();
    let floor = world.spawn();
    world.add(floor, Transform::from_position(Vec3::new(0.0, -0.5, 0.0)));
    world.add(floor, Collider::Box { half_extents: Vec3::new(10.0, 0.5, 10.0) });
    world.add(floor, RigidBody::statik());
    let mut stack = Vec::new();
    for level in 0..3 {
        let e = world.spawn();
        world.add(e, Transform::from_position(Vec3::new(0.0, 0.5 + level as f32 * 1.0, 0.0)));
        world.add(e, Collider::Box { half_extents: Vec3::new(0.5, 0.5, 0.5) });
        let mut body = RigidBody::dynamic(1.0);
        body.gravity_scale = 0.0; // isolate solver behaviour from gravity
        world.add(e, body);
        stack.push(e);
    }
    let config = PhysicsConfig {
        gravity: Vec3::ZERO,
        fixed_dt: 1.0 / 60.0,
        solver_iterations: 4,
    };
    for _ in 0..5 {
        physics_step(&mut world, &config);
    }
    // Every box should remain roughly at its spawn height.
    for (level, e) in stack.iter().enumerate() {
        let expected_y = 0.5 + level as f32 * 1.0;
        let actual_y = world.get::<Transform>(*e).unwrap().position.y;
        assert!((actual_y - expected_y).abs() < 1.0,
            "box {} moved too much: expected y~{}, got {}", level, expected_y, actual_y);
    }
}
// --- Wake on collision test ---
#[test]
fn test_wake_on_collision() {
    // A sleeping sphere at the origin, and a second sphere overlapping it
    // while moving toward it: resolving the contact must wake the sleeper.
    let mut world = World::new();
    let sleeper = world.spawn();
    world.add(sleeper, Transform::from_position(Vec3::ZERO));
    world.add(sleeper, Collider::Sphere { radius: 1.0 });
    let mut sleeping_body = RigidBody::dynamic(1.0);
    sleeping_body.gravity_scale = 0.0;
    sleeping_body.is_sleeping = true;
    world.add(sleeper, sleeping_body);
    let mover = world.spawn();
    world.add(mover, Transform::from_position(Vec3::new(1.5, 0.0, 0.0)));
    world.add(mover, Collider::Sphere { radius: 1.0 });
    let mut moving_body = RigidBody::dynamic(1.0);
    moving_body.gravity_scale = 0.0;
    moving_body.velocity = Vec3::new(-2.0, 0.0, 0.0);
    world.add(mover, moving_body);
    let contacts = detect_collisions(&world);
    assert!(!contacts.is_empty());
    resolve_collisions(&mut world, &contacts, 1);
    let after = world.get::<RigidBody>(sleeper).unwrap();
    assert!(!after.is_sleeping, "body should wake on collision");
}
}

View File

@@ -0,0 +1,243 @@
/// Pure CPU exposure calculation logic (testable).
///
/// Converts the average log-luminance back to linear space, divides the
/// key value (middle grey) by it, and clamps the result to
/// `[min_exp, max_exp]`.
pub fn calculate_target_exposure(
    avg_log_luminance: f32,
    key_value: f32,
    min_exp: f32,
    max_exp: f32,
) -> f32 {
    // Guard against division by ~zero for fully black scenes.
    let linear_lum = avg_log_luminance.exp().max(0.0001);
    (key_value / linear_lum).clamp(min_exp, max_exp)
}
/// Smooth adaptation over time.
///
/// Exponentially approaches `target` from `current`: with `dt == 0` the
/// value is unchanged, and for large `dt * speed` it converges to `target`.
pub fn adapt_exposure(current: f32, target: f32, dt: f32, speed: f32) -> f32 {
    let blend = 1.0 - (-dt * speed).exp();
    current + (target - current) * blend
}
/// GPU-side auto exposure compute + readback.
pub struct AutoExposure {
    compute_pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    // 8-byte GPU accumulator written by the compute shader
    // (two u32 slots: fixed-point log-luminance sum and pixel count).
    result_buffer: wgpu::Buffer,
    // MAP_READ copy target for CPU readback of `result_buffer`.
    staging_buffer: wgpu::Buffer,
    pub exposure: f32,         // current smoothed exposure value (starts at 1.0)
    pub min_exposure: f32,     // lower clamp for the target exposure
    pub max_exposure: f32,     // upper clamp for the target exposure
    pub adaptation_speed: f32, // exponential adaptation rate used by adapt_exposure()
    pub key_value: f32,        // middle-grey key (0.18 by default)
    // Set by dispatch() once results were copied to staging; cleared when consumed.
    pending_read: bool,
}
impl AutoExposure {
    /// Build the compute pipeline, layouts and the two result buffers.
    ///
    /// Bind group layout (all COMPUTE stage):
    ///   0: HDR color texture (non-filterable float, 2D)
    ///   1: read/write storage buffer — the 8-byte luminance accumulator
    pub fn new(device: &wgpu::Device) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Auto Exposure Compute"),
            source: wgpu::ShaderSource::Wgsl(include_str!("auto_exposure.wgsl").into()),
        });
        let bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Auto Exposure BGL"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Texture {
                            multisampled: false,
                            view_dimension: wgpu::TextureViewDimension::D2,
                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Auto Exposure PL"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });
        let compute_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Auto Exposure Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader,
                entry_point: Some("main"),
                compilation_options: wgpu::PipelineCompilationOptions::default(),
                cache: None,
            });
        // GPU-side accumulator: STORAGE for the shader, COPY_SRC to reach
        // staging, COPY_DST so it can be zeroed via write_buffer.
        let result_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Auto Exposure Result"),
            size: 8,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        // CPU-visible mirror of the accumulator for readback.
        let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Auto Exposure Staging"),
            size: 8,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        AutoExposure {
            compute_pipeline,
            bind_group_layout,
            result_buffer,
            staging_buffer,
            exposure: 1.0,
            min_exposure: 0.1,
            max_exposure: 10.0,
            adaptation_speed: 2.0,
            key_value: 0.18,
            pending_read: false,
        }
    }
    /// Dispatch compute shader to calculate luminance. Call once per frame.
    ///
    /// Zeroes the accumulator, runs one thread per HDR pixel (workgroups
    /// rounded up to the shader's 16x16 tile size), then records a copy of
    /// the result into the staging buffer and marks a readback as pending.
    pub fn dispatch(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        encoder: &mut wgpu::CommandEncoder,
        hdr_view: &wgpu::TextureView,
        hdr_width: u32,
        hdr_height: u32,
    ) {
        // Clear result buffer
        queue.write_buffer(&self.result_buffer, 0, &[0u8; 8]);
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Auto Exposure BG"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(hdr_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: self.result_buffer.as_entire_binding(),
                },
            ],
        });
        {
            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Auto Exposure Pass"),
                timestamp_writes: None,
            });
            cpass.set_pipeline(&self.compute_pipeline);
            cpass.set_bind_group(0, &bind_group, &[]);
            // Ceil-divide so partial edge tiles are still covered.
            let wg_x = (hdr_width + 15) / 16;
            let wg_y = (hdr_height + 15) / 16;
            cpass.dispatch_workgroups(wg_x, wg_y, 1);
        }
        // Copy result to staging for CPU readback
        encoder.copy_buffer_to_buffer(&self.result_buffer, 0, &self.staging_buffer, 0, 8);
        self.pending_read = true;
    }
    /// Read back luminance result and update exposure. Call after queue.submit().
    /// Returns true if exposure was updated.
    ///
    /// NOTE(review): this is currently a stub — it initiates map_async but
    /// never polls the device, never reads the mapped range, and never
    /// unmaps the staging buffer, so it always returns false and may leave
    /// the buffer in a mapped-pending state for the next frame's copy.
    /// Confirm before enabling; use set_average_luminance() meanwhile.
    pub fn update_exposure(&mut self, _dt: f32) -> bool {
        if !self.pending_read {
            return false;
        }
        self.pending_read = false;
        let slice = self.staging_buffer.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });
        // The caller must poll the device for the map to complete.
        // For a full integration, use async or poll in the render loop.
        // For now, return false — use set_average_luminance() for CPU-side updates.
        let _ = rx;
        false
    }
    /// Simple CPU-only exposure update without GPU readback.
    /// Use when you have a luminance estimate from other means.
    ///
    /// `avg_log_lum` is the scene's average log-luminance; the new target
    /// is clamped to [min_exposure, max_exposure] and blended in over `dt`
    /// at `adaptation_speed`.
    pub fn set_average_luminance(&mut self, avg_log_lum: f32, dt: f32) {
        let target = calculate_target_exposure(
            avg_log_lum,
            self.key_value,
            self.min_exposure,
            self.max_exposure,
        );
        self.exposure = adapt_exposure(self.exposure, target, dt, self.adaptation_speed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Middle-grey scene (avg luminance 0.18) maps to exposure 1.0.
    #[test]
    fn test_calculate_target_exposure() {
        let log_lum = 0.18_f32.ln();
        let target = calculate_target_exposure(log_lum, 0.18, 0.1, 10.0);
        assert!((target - 1.0).abs() < 0.01, "target={}", target);
    }

    /// Bright scene: raw target 0.18 / 2.0 = 0.09, clamped up to min 0.1.
    #[test]
    fn test_target_exposure_bright_scene() {
        let target = calculate_target_exposure(2.0_f32.ln(), 0.18, 0.1, 10.0);
        assert!((target - 0.1).abs() < 0.01);
    }

    /// Dark scene: raw target 0.18 / 0.001 = 180, clamped down to max 10.
    #[test]
    fn test_target_exposure_dark_scene() {
        let target = calculate_target_exposure(0.001_f32.ln(), 0.18, 0.1, 10.0);
        assert!((target - 10.0).abs() < 0.01);
    }

    /// dt = 0 must leave the exposure unchanged.
    #[test]
    fn test_adapt_exposure_no_time() {
        assert!((adapt_exposure(1.0, 5.0, 0.0, 2.0) - 1.0).abs() < 0.01);
    }

    /// Repeated 60 fps updates converge toward the target.
    #[test]
    fn test_adapt_exposure_converges() {
        let exp = (0..100).fold(1.0_f32, |e, _| adapt_exposure(e, 5.0, 0.016, 2.0));
        assert!(
            (exp - 5.0).abs() < 0.2,
            "should converge to 5.0, got {}",
            exp
        );
    }

    /// A huge dt jumps essentially all the way to the target.
    #[test]
    fn test_adapt_exposure_large_dt() {
        assert!((adapt_exposure(1.0, 5.0, 100.0, 2.0) - 5.0).abs() < 0.01);
    }
}

View File

@@ -0,0 +1,18 @@
// Luminance accumulation for auto exposure.
//
// result[0]: running sum of per-pixel log-luminance in fixed point
//            (value * 1000 as i32, bitcast to u32 — two's-complement
//            wrapping addition keeps the signed sum correct).
// result[1]: number of pixels accumulated.
@group(0) @binding(0) var hdr_texture: texture_2d<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<atomic<u32>>;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dims = textureDimensions(hdr_texture);
    // Skip threads outside the texture (dispatch rounds up to 16x16 tiles).
    if (gid.x >= dims.x || gid.y >= dims.y) {
        return;
    }
    let color = textureLoad(hdr_texture, vec2<i32>(gid.xy), 0);
    // Rec. 709 luma weights.
    let lum = 0.2126 * color.r + 0.7152 * color.g + 0.0722 * color.b;
    // Clamp away from zero so black pixels don't produce log(0).
    let log_lum = log(max(lum, 0.0001));
    // Fixed-point encode so the float sum can use integer atomics.
    let fixed = i32(log_lum * 1000.0);
    atomicAdd(&result[0], bitcast<u32>(fixed));
    atomicAdd(&result[1], 1u);
}

View File

@@ -0,0 +1,52 @@
/// Bilateral bloom weight: attenuates blur across brightness edges.
///
/// Multiplies the caller-supplied spatial weight by a Gaussian falloff
/// over the luminance difference, so samples across a hard brightness
/// edge contribute little to the blur.
pub fn bilateral_bloom_weight(
    center_luminance: f32,
    sample_luminance: f32,
    spatial_weight: f32,
    sigma_luminance: f32,
) -> f32 {
    let delta = (center_luminance - sample_luminance).abs();
    let luminance_falloff = (-delta * delta / (2.0 * sigma_luminance * sigma_luminance)).exp();
    spatial_weight * luminance_falloff
}
/// Calculate luminance from RGB using the Rec. 709 luma coefficients.
pub fn luminance(r: f32, g: f32, b: f32) -> f32 {
    let (wr, wg, wb) = (0.2126, 0.7152, 0.0722);
    wr * r + wg * g + wb * b
}
/// 5-tap Gaussian weights for 1D bloom blur (sigma ≈ 1.4).
///
/// Symmetric around the center tap and normalized to sum to 1.0.
pub fn gaussian_5tap() -> [f32; 5] {
    let outer = 0.0625;
    let inner = 0.25;
    [outer, inner, 0.375, inner, outer]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal luminances leave the spatial weight untouched.
    #[test]
    fn test_bilateral_bloom_same_luminance() {
        let weight = bilateral_bloom_weight(0.5, 0.5, 1.0, 0.1);
        assert!((weight - 1.0).abs() < 0.01);
    }

    /// A strong luminance edge almost completely suppresses the weight.
    #[test]
    fn test_bilateral_bloom_edge() {
        assert!(bilateral_bloom_weight(0.1, 0.9, 1.0, 0.1) < 0.01);
    }

    /// The Rec. 709 weights sum to 1, so pure white has luminance 1.
    #[test]
    fn test_luminance_white() {
        assert!((luminance(1.0, 1.0, 1.0) - 1.0).abs() < 0.01);
    }

    /// The 5-tap kernel must be normalized.
    #[test]
    fn test_gaussian_5tap_sum() {
        let total: f32 = gaussian_5tap().iter().sum();
        assert!((total - 1.0).abs() < 0.01);
    }
}

View File

@@ -0,0 +1,156 @@
/// Depth/normal-aware (bilateral) blur as a compute pass.
pub struct BilateralBlur {
    pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout, // cached so dispatch() can rebuild bind groups per frame
}
impl BilateralBlur {
    /// Build the blur compute pipeline and its bind group layout.
    ///
    /// Bindings (all COMPUTE stage):
    ///   0: input color texture (non-filterable float, 2D)
    ///   1: depth texture
    ///   2: normal texture (non-filterable float, 2D)
    ///   3: output storage texture (rgba16float, write-only)
    pub fn new(device: &wgpu::Device) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Bilateral Blur Compute"),
            source: wgpu::ShaderSource::Wgsl(include_str!("bilateral_blur.wgsl").into()),
        });
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Bilateral Blur BGL"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture {
                        multisampled: false,
                        view_dimension: wgpu::TextureViewDimension::D2,
                        sample_type: wgpu::TextureSampleType::Float { filterable: false },
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture {
                        multisampled: false,
                        view_dimension: wgpu::TextureViewDimension::D2,
                        sample_type: wgpu::TextureSampleType::Depth,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture {
                        multisampled: false,
                        view_dimension: wgpu::TextureViewDimension::D2,
                        sample_type: wgpu::TextureSampleType::Float { filterable: false },
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::StorageTexture {
                        access: wgpu::StorageTextureAccess::WriteOnly,
                        format: wgpu::TextureFormat::Rgba16Float,
                        view_dimension: wgpu::TextureViewDimension::D2,
                    },
                    count: None,
                },
            ],
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Bilateral Blur PL"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Bilateral Blur Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        BilateralBlur { pipeline, bind_group_layout }
    }
    /// Record a blur dispatch into `encoder`.
    ///
    /// Binds the four views in layout order and launches one thread per
    /// output pixel (workgroups ceil-divided by 16 — assumed to match a
    /// 16x16 workgroup size in bilateral_blur.wgsl; confirm in the shader).
    pub fn dispatch(
        &self,
        device: &wgpu::Device,
        encoder: &mut wgpu::CommandEncoder,
        input_view: &wgpu::TextureView,
        depth_view: &wgpu::TextureView,
        normal_view: &wgpu::TextureView,
        output_view: &wgpu::TextureView,
        width: u32,
        height: u32,
    ) {
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Bilateral Blur BG"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(input_view) },
                wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(depth_view) },
                wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(normal_view) },
                wgpu::BindGroupEntry { binding: 3, resource: wgpu::BindingResource::TextureView(output_view) },
            ],
        });
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Bilateral Blur Pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&self.pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
    }
}
/// Pure CPU bilateral weight calculation (for testing).
///
/// Mirrors the WGSL shader: the weight is the product of three falloffs —
/// a spatial Gaussian, an exponential depth-difference falloff, and a
/// clamped normal-dot-product raised to `sigma_normal`.
pub fn bilateral_weight(
    spatial_dist: f32,
    depth_diff: f32,
    normal_dot: f32,
    sigma_spatial: f32,
    sigma_depth: f32,
    sigma_normal: f32,
) -> f32 {
    let spatial_term = {
        let denom = 2.0 * sigma_spatial * sigma_spatial;
        (-spatial_dist * spatial_dist / denom).exp()
    };
    let depth_term = (-depth_diff.abs() / sigma_depth).exp();
    let normal_term = normal_dot.max(0.0).powf(sigma_normal);
    spatial_term * depth_term * normal_term
}
#[cfg(test)]
mod tests {
    use super::*;
    // All tests use the shader's defaults: sigma_spatial=2, sigma_depth=0.1,
    // sigma_normal=16.
    #[test]
    fn test_bilateral_weight_center() {
        // At center: dist=0, depth_diff=0, normal_dot=1
        let w = bilateral_weight(0.0, 0.0, 1.0, 2.0, 0.1, 16.0);
        assert!((w - 1.0).abs() < 0.01); // all factors = 1
    }
    #[test]
    fn test_bilateral_weight_depth_edge() {
        // Large depth difference -> low weight
        let w = bilateral_weight(0.0, 10.0, 1.0, 2.0, 0.1, 16.0);
        assert!(w < 0.01, "large depth diff should give low weight: {}", w);
    }
    #[test]
    fn test_bilateral_weight_normal_edge() {
        // Perpendicular normals -> low weight
        let w = bilateral_weight(0.0, 0.0, 0.0, 2.0, 0.1, 16.0);
        assert!(w < 0.01, "perpendicular normals should give low weight: {}", w);
    }
    #[test]
    fn test_bilateral_weight_distance() {
        // Far spatial distance -> lower weight than near
        let w_near = bilateral_weight(1.0, 0.0, 1.0, 2.0, 0.1, 16.0);
        let w_far = bilateral_weight(4.0, 0.0, 1.0, 2.0, 0.1, 16.0);
        assert!(w_near > w_far);
    }
}

View File

@@ -0,0 +1,58 @@
// Edge-preserving bilateral blur (single compute pass).
//
// Each neighbour tap is weighted by three factors so that blur does not
// bleed across geometric edges:
//   spatial — Gaussian falloff with pixel distance
//   depth   — exponential falloff with |depth difference|
//   normal  — dot(center_normal, sample_normal) raised to SIGMA_NORMAL
@group(0) @binding(0) var input_tex: texture_2d<f32>;
@group(0) @binding(1) var depth_tex: texture_depth_2d;
@group(0) @binding(2) var normal_tex: texture_2d<f32>;
@group(0) @binding(3) var output_tex: texture_storage_2d<rgba16float, write>;

// 5x5 kernel (radius 2) and falloff controls; kept in sync with the CPU
// reference `bilateral_weight` used by the unit tests.
const KERNEL_RADIUS: i32 = 2;
const SIGMA_SPATIAL: f32 = 2.0;
const SIGMA_DEPTH: f32 = 0.1;
const SIGMA_NORMAL: f32 = 16.0;

// Unnormalized 1D Gaussian; the running weight sum is divided out at the end.
fn gaussian(x: f32, sigma: f32) -> f32 {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dims = textureDimensions(input_tex);
    // Dispatch is rounded up to workgroup multiples; skip out-of-range threads.
    if (gid.x >= dims.x || gid.y >= dims.y) { return; }
    let center = vec2<i32>(gid.xy);
    let center_depth = textureLoad(depth_tex, center, 0);
    let center_normal = textureLoad(normal_tex, center, 0).xyz;
    var sum = vec4<f32>(0.0);
    var weight_sum = 0.0;
    for (var dy = -KERNEL_RADIUS; dy <= KERNEL_RADIUS; dy++) {
        for (var dx = -KERNEL_RADIUS; dx <= KERNEL_RADIUS; dx++) {
            let offset = vec2<i32>(dx, dy);
            let sample_pos = center + offset;
            // Skip taps that fall outside the texture bounds.
            if (sample_pos.x < 0 || sample_pos.y < 0 ||
                sample_pos.x >= i32(dims.x) || sample_pos.y >= i32(dims.y)) {
                continue;
            }
            let spatial_dist = sqrt(f32(dx * dx + dy * dy));
            let w_spatial = gaussian(spatial_dist, SIGMA_SPATIAL);
            let sample_depth = textureLoad(depth_tex, sample_pos, 0);
            let depth_diff = abs(center_depth - sample_depth);
            let w_depth = exp(-depth_diff / SIGMA_DEPTH);
            let sample_normal = textureLoad(normal_tex, sample_pos, 0).xyz;
            let n_dot = max(dot(center_normal, sample_normal), 0.0);
            let w_normal = pow(n_dot, SIGMA_NORMAL);
            let weight = w_spatial * w_depth * w_normal;
            let sample_color = textureLoad(input_tex, sample_pos, 0);
            sum += sample_color * weight;
            weight_sum += weight;
        }
    }
    // Normalize; fall back to black if every tap was rejected (weight_sum == 0).
    let result = select(vec4<f32>(0.0), sum / weight_sum, weight_sum > 0.0);
    textureStore(output_tex, vec2<i32>(gid.xy), result);
}

View File

@@ -0,0 +1,34 @@
/// Tracks which meshes need a BLAS (bottom-level acceleration structure) rebuild.
pub struct BlasTracker {
    /// One `(mesh_id, needs_rebuild)` entry per registered mesh.
    dirty: Vec<(u32, bool)>,
}

impl BlasTracker {
    /// Create an empty tracker with no registered meshes.
    pub fn new() -> Self {
        Self { dirty: Vec::new() }
    }

    /// Register a mesh; it starts in the clean (no rebuild needed) state.
    pub fn register(&mut self, mesh_id: u32) {
        self.dirty.push((mesh_id, false));
    }

    /// Flag a registered mesh as needing a rebuild. Unknown ids are ignored.
    pub fn mark_dirty(&mut self, mesh_id: u32) {
        if let Some((_, flag)) = self.dirty.iter_mut().find(|(id, _)| *id == mesh_id) {
            *flag = true;
        }
    }

    /// Ids of all meshes currently flagged for rebuild, in registration order.
    pub fn dirty_meshes(&self) -> Vec<u32> {
        let mut ids = Vec::new();
        for &(id, needs_rebuild) in &self.dirty {
            if needs_rebuild {
                ids.push(id);
            }
        }
        ids
    }

    /// Reset every rebuild flag (typically after the rebuilds were submitted).
    pub fn clear_dirty(&mut self) {
        self.dirty.iter_mut().for_each(|entry| entry.1 = false);
    }

    /// Total number of registered meshes (dirty or not).
    pub fn mesh_count(&self) -> usize {
        self.dirty.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_register_and_dirty() {
        // Only meshes explicitly marked dirty are reported.
        let mut t = BlasTracker::new();
        t.register(1); t.register(2);
        t.mark_dirty(1);
        assert_eq!(t.dirty_meshes(), vec![1]);
    }
    #[test]
    fn test_clear_dirty() {
        // clear_dirty resets every flag but keeps registrations.
        let mut t = BlasTracker::new();
        t.register(1); t.mark_dirty(1);
        t.clear_dirty();
        assert!(t.dirty_meshes().is_empty());
    }
}

View File

@@ -0,0 +1,89 @@
// GPU Compute shader for BRDF LUT generation (split-sum approximation).
// Workgroup size: 16x16, each thread computes one texel.
// Output: Rg16Float texture with (scale, bias) per texel.
//
// LUT axes: x = N·V in (0,1], y = roughness in (0,1]. At shading time the
// specular IBL term is F0 * scale + bias.
@group(0) @binding(0) var output_tex: texture_storage_2d<rg16float, write>;
const PI: f32 = 3.14159265358979;
const NUM_SAMPLES: u32 = 1024u;

// Van der Corput radical inverse via bit-reversal
fn radical_inverse_vdc(bits_in: u32) -> f32 {
    var bits = bits_in;
    bits = (bits << 16u) | (bits >> 16u);
    bits = ((bits & 0x55555555u) << 1u) | ((bits & 0xAAAAAAAAu) >> 1u);
    bits = ((bits & 0x33333333u) << 2u) | ((bits & 0xCCCCCCCCu) >> 2u);
    bits = ((bits & 0x0F0F0F0Fu) << 4u) | ((bits & 0xF0F0F0F0u) >> 4u);
    bits = ((bits & 0x00FF00FFu) << 8u) | ((bits & 0xFF00FF00u) >> 8u);
    return f32(bits) * 2.3283064365386963e-10; // / 0x100000000
}

// Hammersley low-discrepancy 2D sample
fn hammersley(i: u32, n: u32) -> vec2<f32> {
    return vec2<f32>(f32(i) / f32(n), radical_inverse_vdc(i));
}

// GGX importance-sampled half vector in tangent space (N = (0,0,1))
fn importance_sample_ggx(xi: vec2<f32>, roughness: f32) -> vec3<f32> {
    let a = roughness * roughness;
    let phi = 2.0 * PI * xi.x;
    let cos_theta = sqrt((1.0 - xi.y) / (1.0 + (a * a - 1.0) * xi.y));
    let sin_theta = sqrt(max(1.0 - cos_theta * cos_theta, 0.0));
    return vec3<f32>(cos(phi) * sin_theta, sin(phi) * sin_theta, cos_theta);
}

// Smith geometry function for IBL: k = a^2/2
fn geometry_smith_ibl(n_dot_v: f32, n_dot_l: f32, roughness: f32) -> f32 {
    let a = roughness * roughness;
    let k = a / 2.0;
    let ggx_v = n_dot_v / (n_dot_v * (1.0 - k) + k);
    let ggx_l = n_dot_l / (n_dot_l * (1.0 - k) + k);
    return ggx_v * ggx_l;
}

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dims = textureDimensions(output_tex);
    if gid.x >= dims.x || gid.y >= dims.y {
        return;
    }
    // +0.5 samples at texel centers; NOTE(review): `size` uses dims.x for both
    // axes — assumes a square LUT.
    let size = f32(dims.x);
    let n_dot_v = (f32(gid.x) + 0.5) / size;
    let roughness = clamp((f32(gid.y) + 0.5) / size, 0.0, 1.0);
    let n_dot_v_clamped = clamp(n_dot_v, 0.0, 1.0);
    // View vector in tangent space where N = (0,0,1)
    let v = vec3<f32>(sqrt(max(1.0 - n_dot_v_clamped * n_dot_v_clamped, 0.0)), 0.0, n_dot_v_clamped);
    var scale = 0.0;
    var bias = 0.0;
    for (var i = 0u; i < NUM_SAMPLES; i++) {
        let xi = hammersley(i, NUM_SAMPLES);
        let h = importance_sample_ggx(xi, roughness);
        // dot(V, H)
        let v_dot_h = max(dot(v, h), 0.0);
        // Reflect V around H to get L
        let l = 2.0 * v_dot_h * h - v;
        let n_dot_l = max(l.z, 0.0); // L.z in tangent space
        let n_dot_h = max(h.z, 0.0);
        if n_dot_l > 0.0 {
            let g = geometry_smith_ibl(n_dot_v_clamped, n_dot_l, roughness);
            let g_vis = g * v_dot_h / max(n_dot_h * n_dot_v_clamped, 0.001);
            let fc = pow(1.0 - v_dot_h, 5.0);
            scale += g_vis * (1.0 - fc);
            bias += g_vis * fc;
        }
    }
    // Monte Carlo average over all samples.
    scale /= f32(NUM_SAMPLES);
    bias /= f32(NUM_SAMPLES);
    textureStore(output_tex, vec2<i32>(i32(gid.x), i32(gid.y)), vec4<f32>(scale, bias, 0.0, 1.0));
}

View File

@@ -0,0 +1,246 @@
use bytemuck::{Pod, Zeroable};
use voltex_math::{Mat4, Vec3, Vec4};
/// Number of shadow cascades (near + far).
pub const CSM_CASCADE_COUNT: usize = 2;
/// Resolution of each square cascade shadow map, in texels.
pub const CSM_MAP_SIZE: u32 = 2048;
/// Depth-only format used for the cascade render targets.
pub const CSM_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Depth32Float;

/// Cascaded Shadow Map with 2 cascades.
///
/// Owns one depth texture + view per cascade and a shared comparison
/// sampler (LessEqual) for hardware PCF shadow lookups.
pub struct CascadedShadowMap {
    pub textures: [wgpu::Texture; CSM_CASCADE_COUNT],
    pub views: [wgpu::TextureView; CSM_CASCADE_COUNT],
    pub sampler: wgpu::Sampler,
}

impl CascadedShadowMap {
    /// Allocate the cascade depth textures and the comparison sampler.
    pub fn new(device: &wgpu::Device) -> Self {
        // Each cascade is an independent 2D depth texture usable both as a
        // render target (shadow pass) and as a sampled texture (lighting pass).
        let create_cascade = |label: &str| {
            let texture = device.create_texture(&wgpu::TextureDescriptor {
                label: Some(label),
                size: wgpu::Extent3d {
                    width: CSM_MAP_SIZE,
                    height: CSM_MAP_SIZE,
                    depth_or_array_layers: 1,
                },
                mip_level_count: 1,
                sample_count: 1,
                dimension: wgpu::TextureDimension::D2,
                format: CSM_FORMAT,
                usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
                view_formats: &[],
            });
            let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
            (texture, view)
        };
        let (t0, v0) = create_cascade("CSM Cascade 0");
        let (t1, v1) = create_cascade("CSM Cascade 1");
        // `compare: Some(LessEqual)` makes this a comparison sampler, enabling
        // hardware PCF when sampled with textureSampleCompare in WGSL.
        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
            label: Some("CSM Sampler"),
            address_mode_u: wgpu::AddressMode::ClampToEdge,
            address_mode_v: wgpu::AddressMode::ClampToEdge,
            address_mode_w: wgpu::AddressMode::ClampToEdge,
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            mipmap_filter: wgpu::MipmapFilterMode::Nearest,
            compare: Some(wgpu::CompareFunction::LessEqual),
            ..Default::default()
        });
        Self {
            textures: [t0, t1],
            views: [v0, v1],
            sampler,
        }
    }
}
/// CSM uniform data: 2 light-view-proj matrices, cascade split distance, shadow params.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` so it can be byte-copied into a GPU
/// uniform buffer; the total size must stay a multiple of 16 bytes for
/// WGSL uniform alignment (asserted in the tests below).
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct CsmUniform {
    pub light_view_proj: [[[f32; 4]; 4]; CSM_CASCADE_COUNT], // 128 bytes
    pub cascade_split: f32, // view-space depth where cascade 0 ends and cascade 1 begins
    pub shadow_map_size: f32,
    pub shadow_bias: f32,
    pub _padding: f32, // 16 bytes total for last row
}
/// Compute the 8 corners of a sub-frustum in world space, given the camera's
/// inverse view-projection and sub-frustum near/far in NDC z [0..1] range.
///
/// Corner order: 4 near-plane corners first, then 4 far-plane corners, each
/// quad in (-1,-1), (1,-1), (1,1), (-1,1) order.
fn frustum_corners_world(inv_vp: &Mat4, z_near_ndc: f32, z_far_ndc: f32) -> [Vec3; 8] {
    let quad_xy = [(-1.0f32, -1.0f32), (1.0, -1.0), (1.0, 1.0), (-1.0, 1.0)];
    let mut corners = [Vec3::ZERO; 8];
    for plane in 0..2 {
        let z = if plane == 0 { z_near_ndc } else { z_far_ndc };
        for (i, &(x, y)) in quad_xy.iter().enumerate() {
            // Unproject the NDC corner and apply the perspective divide.
            let clip = inv_vp.mul_vec4(Vec4::new(x, y, z, 1.0));
            corners[plane * 4 + i] = Vec3::new(clip.x / clip.w, clip.y / clip.w, clip.z / clip.w);
        }
    }
    corners
}
/// Compute a tight orthographic light-view-projection matrix for a set of frustum corners.
///
/// `light_dir` is the direction light travels; the result maps the corners'
/// light-space bounding box onto the orthographic clip volume.
fn light_matrix_for_corners(light_dir: Vec3, corners: &[Vec3; 8]) -> Mat4 {
    // Build a light-space view matrix looking in the light direction.
    let center = {
        let mut c = Vec3::ZERO;
        for corner in corners {
            c = c + *corner;
        }
        c * (1.0 / 8.0)
    };
    // Pick a stable up vector that isn't parallel to light_dir.
    let up = if light_dir.cross(Vec3::Y).length_squared() < 1e-6 {
        Vec3::Z
    } else {
        Vec3::Y
    };
    let light_view = Mat4::look_at(
        center - light_dir * 0.5, // eye slightly behind center along light direction
        center,
        up,
    );
    // Transform all corners into light view space and find AABB.
    let mut min_x = f32::MAX;
    let mut max_x = f32::MIN;
    let mut min_y = f32::MAX;
    let mut max_y = f32::MIN;
    let mut min_z = f32::MAX;
    let mut max_z = f32::MIN;
    for corner in corners {
        let v = light_view.mul_vec4(Vec4::from_vec3(*corner, 1.0));
        let p = Vec3::new(v.x, v.y, v.z);
        min_x = min_x.min(p.x);
        max_x = max_x.max(p.x);
        min_y = min_y.min(p.y);
        max_y = max_y.max(p.y);
        min_z = min_z.min(p.z);
        max_z = max_z.max(p.z);
    }
    // Extend the z range to catch shadow casters behind the frustum.
    let z_margin = (max_z - min_z) * 2.0;
    min_z -= z_margin;
    // NOTE(review): the -max_z / -min_z near/far arguments assume a
    // right-handed light view looking down -Z — confirm against
    // Mat4::orthographic's depth convention.
    let light_proj = Mat4::orthographic(min_x, max_x, min_y, max_y, -max_z, -min_z);
    light_proj.mul_mat4(&light_view)
}
/// Compute cascade light-view-projection matrices for 2 cascades.
///
/// - `light_dir`: normalized direction **toward** the light source (opposite of light travel).
///   Internally we negate it to get the light travel direction.
/// - `camera_view`, `camera_proj`: the camera's view and projection matrices.
/// - `near`, `far`: camera near/far planes.
/// - `split`: the view-space depth where cascade 0 ends and cascade 1 begins.
///
/// Returns two light-view-projection matrices, one for each cascade.
///
/// # Panics
/// Panics if the camera view-projection matrix is not invertible.
pub fn compute_cascade_matrices(
    light_dir: Vec3,
    camera_view: &Mat4,
    camera_proj: &Mat4,
    near: f32,
    far: f32,
    split: f32,
) -> [Mat4; CSM_CASCADE_COUNT] {
    let vp = camera_proj.mul_mat4(camera_view);
    let inv_vp = vp.inverse().expect("Camera VP matrix must be invertible");
    // Map view-space depth to NDC z. For wgpu perspective:
    //   ndc_z = (far * (z_view + near)) / (z_view * (far - near))
    // But since z_view is negative in RH, and we want the NDC value, we use
    // the projection matrix directly by projecting (0, 0, -depth, 1).
    let depth_to_ndc = |depth: f32| -> f32 {
        let clip = camera_proj.mul_vec4(Vec4::new(0.0, 0.0, -depth, 1.0));
        clip.z / clip.w
    };
    let ndc_near = depth_to_ndc(near);
    let ndc_split = depth_to_ndc(split);
    let ndc_far = depth_to_ndc(far);
    // Light direction is the direction light travels (away from the source).
    let dir = (-light_dir).normalize();
    // Slice the camera frustum at `split` and fit one tight light matrix
    // around each slice.
    let corners0 = frustum_corners_world(&inv_vp, ndc_near, ndc_split);
    let corners1 = frustum_corners_world(&inv_vp, ndc_split, ndc_far);
    [
        light_matrix_for_corners(dir, &corners0),
        light_matrix_for_corners(dir, &corners1),
    ]
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;
    #[test]
    fn test_csm_uniform_size() {
        // Must be multiple of 16 for WGSL uniform alignment
        assert_eq!(mem::size_of::<CsmUniform>() % 16, 0,
            "CsmUniform must be 16-byte aligned, got {} bytes", mem::size_of::<CsmUniform>());
    }
    #[test]
    fn test_compute_cascade_matrices_produces_valid_matrices() {
        // A typical camera + tilted sun direction must yield non-trivial matrices.
        let light_dir = Vec3::new(0.0, -1.0, -1.0).normalize();
        let view = Mat4::look_at(Vec3::new(0.0, 5.0, 10.0), Vec3::ZERO, Vec3::Y);
        let proj = Mat4::perspective(std::f32::consts::FRAC_PI_4, 16.0 / 9.0, 0.1, 100.0);
        let matrices = compute_cascade_matrices(light_dir, &view, &proj, 0.1, 100.0, 20.0);
        // Both matrices should not be identity (they should be actual projections)
        assert_ne!(matrices[0].cols, Mat4::IDENTITY.cols, "Cascade 0 should not be identity");
        assert_ne!(matrices[1].cols, Mat4::IDENTITY.cols, "Cascade 1 should not be identity");
    }
    #[test]
    fn test_cascade_split_distance() {
        // The split distance should partition the frustum
        let light_dir = Vec3::new(0.0, -1.0, 0.0).normalize();
        let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
        let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 1.0, 50.0);
        let split = 15.0;
        let matrices = compute_cascade_matrices(light_dir, &view, &proj, 1.0, 50.0, split);
        // Both matrices should be different (covering different frustum regions)
        let differ = matrices[0].cols.iter()
            .zip(matrices[1].cols.iter())
            .any(|(a, b)| {
                a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-3)
            });
        assert!(differ, "Cascade matrices should differ for different frustum regions");
    }
    #[test]
    fn test_frustum_corners_world_identity() {
        // With identity inverse VP, corners should be at NDC positions.
        let inv_vp = Mat4::IDENTITY;
        let corners = frustum_corners_world(&inv_vp, 0.0, 1.0);
        // Near plane at z=0 (corners 0..3 are the near quad)
        assert!((corners[0].z - 0.0).abs() < 1e-5);
        // Far plane at z=1 (corners 4..7 are the far quad)
        assert!((corners[4].z - 1.0).abs() < 1e-5);
    }
}

View File

@@ -20,6 +20,10 @@ struct MaterialUniform {
@group(1) @binding(1) var s_albedo: sampler;
@group(1) @binding(2) var t_normal: texture_2d<f32>;
@group(1) @binding(3) var s_normal: sampler;
@group(1) @binding(4) var t_orm: texture_2d<f32>;
@group(1) @binding(5) var s_orm: sampler;
@group(1) @binding(6) var t_emissive: texture_2d<f32>;
@group(1) @binding(7) var s_emissive: sampler;
@group(2) @binding(0) var<uniform> material: MaterialUniform;
@@ -84,11 +88,21 @@ fn fs_main(in: VertexOutput) -> GBufferOutput {
let TBN = mat3x3<f32>(T, B, N_geom);
let N = normalize(TBN * tangent_normal);
// Sample ORM texture: R=AO, G=Roughness, B=Metallic; multiply with material params
let orm_sample = textureSample(t_orm, s_orm, in.uv);
let ao = orm_sample.r * material.ao;
let roughness = orm_sample.g * material.roughness;
let metallic = orm_sample.b * material.metallic;
// Sample emissive texture and compute luminance
let emissive = textureSample(t_emissive, s_emissive, in.uv).rgb;
let emissive_lum = dot(emissive, vec3<f32>(0.299, 0.587, 0.114));
var out: GBufferOutput;
out.position = vec4<f32>(in.world_pos, 1.0);
out.normal = vec4<f32>(N * 0.5 + 0.5, 1.0);
out.albedo = vec4<f32>(albedo, material.base_color.a * tex_color.a);
out.material_data = vec4<f32>(material.metallic, material.roughness, material.ao, 1.0);
out.material_data = vec4<f32>(metallic, roughness, ao, emissive_lum);
return out;
}

View File

@@ -42,6 +42,10 @@ struct ShadowUniform {
light_view_proj: mat4x4<f32>,
shadow_map_size: f32,
shadow_bias: f32,
_padding: vec2<f32>,
sun_direction: vec3<f32>,
turbidity: f32,
sh_coefficients: array<vec4<f32>, 7>,
};
@group(2) @binding(0) var t_shadow: texture_depth_2d;
@@ -260,6 +264,7 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let metallic = mat_sample.r;
let roughness = mat_sample.g;
let ao = mat_sample.b;
let emissive_lum = mat_sample.w;
let V = normalize(camera_uniform.camera_pos - world_pos);
@@ -306,7 +311,10 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let ambient = (diffuse_ibl + specular_ibl) * ao * ssgi_ao + ssgi_indirect;
// Output raw HDR linear colour; tonemap is applied in a separate tonemap pass.
let color = ambient + Lo;
var color = ambient + Lo;
// Add emissive contribution (luminance stored in G-Buffer, modulated by albedo)
color += albedo * emissive_lum;
return vec4<f32>(color, alpha);
}

View File

@@ -0,0 +1,140 @@
use bytemuck::{Pod, Zeroable};
/// GPU-side depth-of-field parameters.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` so the struct can be byte-copied into a
/// uniform buffer; `_pad` keeps it 16 bytes for WGSL alignment.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct DofParams {
    /// World-space distance that is perfectly in focus.
    pub focus_distance: f32,
    /// Distance from `focus_distance` over which blur ramps up to `max_blur`.
    pub focus_range: f32,
    /// Maximum blur radius (circle-of-confusion cap), in pixels.
    pub max_blur: f32,
    /// Padding for 16-byte uniform alignment.
    pub _pad: f32,
}

impl DofParams {
    /// Default parameters: focus at 5 units, 3-unit ramp, 5-pixel max blur.
    pub fn new() -> Self {
        DofParams { focus_distance: 5.0, focus_range: 3.0, max_blur: 5.0, _pad: 0.0 }
    }
}

// Provide `Default` so the type composes with `..Default::default()` and
// derive chains (clippy: new_without_default).
impl Default for DofParams {
    fn default() -> Self {
        Self::new()
    }
}
/// Depth-of-field post-process implemented as a compute pass (see `dof.wgsl`).
pub struct DepthOfField {
    pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    // Uniform buffer holding a `DofParams`; rewritten every dispatch.
    params_buffer: wgpu::Buffer,
    // Toggle checked by the caller; the pass itself does not read it.
    pub enabled: bool,
}
impl DepthOfField {
    /// Build the compute pipeline, bind group layout, and params uniform buffer.
    ///
    /// Bindings (all COMPUTE visibility):
    /// 0 = HDR colour input, 1 = scene depth, 2 = rgba16float output storage
    /// texture, 3 = `DofParams` uniform buffer.
    pub fn new(device: &wgpu::Device) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("DOF Compute"),
            source: wgpu::ShaderSource::Wgsl(include_str!("dof.wgsl").into()),
        });
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("DOF BGL"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0, visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Float { filterable: false } },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1, visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture { multisampled: false, view_dimension: wgpu::TextureViewDimension::D2, sample_type: wgpu::TextureSampleType::Depth },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2, visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::StorageTexture { access: wgpu::StorageTextureAccess::WriteOnly, format: wgpu::TextureFormat::Rgba16Float, view_dimension: wgpu::TextureViewDimension::D2 },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3, visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None },
                    count: None,
                },
            ],
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("DOF PL"), bind_group_layouts: &[&bind_group_layout], immediate_size: 0,
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("DOF Pipeline"), layout: Some(&pipeline_layout),
            module: &shader, entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(), cache: None,
        });
        let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("DOF Params"),
            size: std::mem::size_of::<DofParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        DepthOfField { pipeline, bind_group_layout, params_buffer, enabled: true }
    }
    /// Upload `params` and record one full-screen DOF dispatch into `encoder`.
    ///
    /// The shader uses a 16x16 workgroup, so the dispatch rounds
    /// `width`/`height` up to the next multiple of 16.
    pub fn dispatch(
        &self, device: &wgpu::Device, queue: &wgpu::Queue, encoder: &mut wgpu::CommandEncoder,
        color_view: &wgpu::TextureView, depth_view: &wgpu::TextureView,
        output_view: &wgpu::TextureView, params: &DofParams, width: u32, height: u32,
    ) {
        queue.write_buffer(&self.params_buffer, 0, bytemuck::cast_slice(&[*params]));
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("DOF BG"), layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: wgpu::BindingResource::TextureView(color_view) },
                wgpu::BindGroupEntry { binding: 1, resource: wgpu::BindingResource::TextureView(depth_view) },
                wgpu::BindGroupEntry { binding: 2, resource: wgpu::BindingResource::TextureView(output_view) },
                wgpu::BindGroupEntry { binding: 3, resource: self.params_buffer.as_entire_binding() },
            ],
        });
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("DOF Pass"), timestamp_writes: None });
        cpass.set_pipeline(&self.pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
    }
}
}
/// CPU circle-of-confusion calculation (for testing).
///
/// Mirrors the WGSL implementation in `dof.wgsl`: the blur radius grows
/// linearly with distance from the focus plane and saturates at `max_blur`
/// once the distance exceeds `focus_range`.
pub fn circle_of_confusion(depth: f32, focus_distance: f32, focus_range: f32, max_blur: f32) -> f32 {
    let offset = (depth - focus_distance).abs();
    let normalized = (offset / focus_range).clamp(0.0, 1.0);
    normalized * max_blur
}
#[cfg(test)]
mod tests {
    use super::*;
    // CoC model: blur = clamp(|depth - focus| / range, 0, 1) * max_blur.
    #[test]
    fn test_coc_in_focus() {
        let coc = circle_of_confusion(5.0, 5.0, 3.0, 5.0);
        assert!((coc - 0.0).abs() < 1e-6); // at focus distance
    }
    #[test]
    fn test_coc_near() {
        let coc = circle_of_confusion(2.0, 5.0, 3.0, 5.0);
        assert!((coc - 5.0).abs() < 1e-6); // diff=3, range=3 -> full blur
    }
    #[test]
    fn test_coc_far() {
        let coc = circle_of_confusion(8.0, 5.0, 3.0, 5.0);
        assert!((coc - 5.0).abs() < 1e-6); // diff=3, range=3 -> full blur
    }
    #[test]
    fn test_coc_partial() {
        let coc = circle_of_confusion(6.5, 5.0, 3.0, 5.0);
        assert!((coc - 2.5).abs() < 1e-6); // diff=1.5, range=3 -> 0.5 * 5 = 2.5
    }
    #[test]
    fn test_params_default() {
        // Defaults from DofParams::new().
        let p = DofParams::new();
        assert!((p.focus_distance - 5.0).abs() < 1e-6);
        assert!((p.focus_range - 3.0).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,55 @@
// Depth-of-field compute pass.
//
// Per pixel: derive a circle of confusion (CoC) from scene depth, then apply
// a weighted disc blur whose radius equals the CoC. In-focus pixels
// (CoC < 0.5) are passed through untouched.
struct DofParams {
    focus_distance: f32, // world-space distance in perfect focus
    focus_range: f32,    // distance over which blur ramps to max
    max_blur: f32,       // blur radius cap, in pixels
    _pad: f32,           // 16-byte uniform alignment
};
@group(0) @binding(0) var color_tex: texture_2d<f32>;
@group(0) @binding(1) var depth_tex: texture_depth_2d;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba16float, write>;
@group(0) @binding(3) var<uniform> params: DofParams;

// Blur radius in pixels: 0 at the focus plane, params.max_blur once the
// depth offset exceeds params.focus_range (mirrored by the CPU test helper).
fn circle_of_confusion(depth: f32) -> f32 {
    let diff = abs(depth - params.focus_distance);
    let coc = clamp(diff / params.focus_range, 0.0, 1.0) * params.max_blur;
    return coc;
}

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dims = textureDimensions(color_tex);
    // Dispatch is rounded up to workgroup multiples; skip out-of-range threads.
    if (gid.x >= dims.x || gid.y >= dims.y) { return; }
    let pos = vec2<i32>(gid.xy);
    let depth = textureLoad(depth_tex, pos, 0);
    let coc = circle_of_confusion(depth);
    if (coc < 0.5) {
        // In focus — no blur
        textureStore(output_tex, pos, textureLoad(color_tex, pos, 0));
        return;
    }
    // Disc blur with radius = coc
    let radius = i32(coc);
    var color = vec4<f32>(0.0);
    var weight = 0.0;
    for (var dy = -radius; dy <= radius; dy++) {
        for (var dx = -radius; dx <= radius; dx++) {
            let dist = sqrt(f32(dx * dx + dy * dy));
            // Keep only taps inside the circular aperture.
            if (dist > coc) { continue; }
            let sample_pos = pos + vec2<i32>(dx, dy);
            // Clamp to the texture edge instead of skipping the tap.
            let clamped = clamp(sample_pos, vec2<i32>(0), vec2<i32>(dims) - 1);
            let sample_color = textureLoad(color_tex, clamped, 0);
            // Triangular falloff toward the rim of the disc.
            let w = 1.0 - dist / (coc + 0.001);
            color += sample_color * w;
            weight += w;
        }
    }
    // Normalize; fall back to the unblurred pixel if no tap contributed.
    let result = select(textureLoad(color_tex, pos, 0), color / weight, weight > 0.0);
    textureStore(output_tex, pos, result);
}

View File

@@ -0,0 +1,184 @@
use crate::vertex::MeshVertex;
use crate::hdr::HDR_FORMAT;
use crate::gpu::DEPTH_FORMAT;
use crate::mesh::Mesh;
/// Forward render pass for transparent geometry, drawn after the deferred
/// opaque pass into the same HDR target with alpha blending.
pub struct ForwardPass {
    pipeline: wgpu::RenderPipeline,
}
impl ForwardPass {
    /// Build the transparency pipeline.
    ///
    /// Key choices:
    /// - premultiplied-style alpha blend on the HDR colour target
    /// - no back-face culling (both sides of transparent surfaces visible)
    /// - depth *test* against the opaque depth buffer, but no depth write.
    pub fn new(
        device: &wgpu::Device,
        camera_light_layout: &wgpu::BindGroupLayout,
        texture_layout: &wgpu::BindGroupLayout,
    ) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Forward Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("forward_shader.wgsl").into()),
        });
        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Forward Pipeline Layout"),
            bind_group_layouts: &[camera_light_layout, texture_layout],
            immediate_size: 0,
        });
        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: Some("Forward Pipeline"),
            layout: Some(&layout),
            vertex: wgpu::VertexState {
                module: &shader,
                entry_point: Some("vs_main"),
                buffers: &[MeshVertex::LAYOUT],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &shader,
                entry_point: Some("fs_main"),
                targets: &[Some(wgpu::ColorTargetState {
                    format: HDR_FORMAT,
                    // Classic over-blend: src*alpha + dst*(1-alpha).
                    blend: Some(wgpu::BlendState {
                        color: wgpu::BlendComponent {
                            src_factor: wgpu::BlendFactor::SrcAlpha,
                            dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
                            operation: wgpu::BlendOperation::Add,
                        },
                        alpha: wgpu::BlendComponent {
                            src_factor: wgpu::BlendFactor::One,
                            dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
                            operation: wgpu::BlendOperation::Add,
                        },
                    }),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleList,
                strip_index_format: None,
                front_face: wgpu::FrontFace::Ccw,
                cull_mode: None, // No culling for transparent objects (see both sides)
                polygon_mode: wgpu::PolygonMode::Fill,
                unclipped_depth: false,
                conservative: false,
            },
            depth_stencil: Some(wgpu::DepthStencilState {
                format: DEPTH_FORMAT,
                depth_write_enabled: false, // Don't write depth (preserve opaque depth)
                depth_compare: wgpu::CompareFunction::LessEqual,
                stencil: wgpu::StencilState::default(),
                bias: wgpu::DepthBiasState::default(),
            }),
            multisample: wgpu::MultisampleState {
                count: 1,
                mask: !0,
                alpha_to_coverage_enabled: false,
            },
            multiview_mask: None,
            cache: None,
        });
        ForwardPass { pipeline }
    }
    /// Record the transparency pass, drawing `meshes` in slice order.
    ///
    /// Both attachments use `LoadOp::Load` so the opaque results are blended
    /// over, not cleared. Callers should pre-sort `meshes` back-to-front
    /// (see `sort_transparent_back_to_front`) for correct alpha blending.
    pub fn render<'a>(
        &'a self,
        encoder: &'a mut wgpu::CommandEncoder,
        hdr_view: &wgpu::TextureView,
        depth_view: &wgpu::TextureView,
        camera_light_bg: &'a wgpu::BindGroup,
        texture_bg: &'a wgpu::BindGroup,
        meshes: &'a [&Mesh],
    ) {
        let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
            label: Some("Forward Transparency Pass"),
            color_attachments: &[Some(wgpu::RenderPassColorAttachment {
                view: hdr_view,
                resolve_target: None,
                depth_slice: None,
                ops: wgpu::Operations {
                    load: wgpu::LoadOp::Load,
                    store: wgpu::StoreOp::Store,
                },
            })],
            depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
                view: depth_view,
                depth_ops: Some(wgpu::Operations {
                    load: wgpu::LoadOp::Load,
                    store: wgpu::StoreOp::Store,
                }),
                stencil_ops: None,
            }),
            occlusion_query_set: None,
            timestamp_writes: None,
            multiview_mask: None,
        });
        rpass.set_pipeline(&self.pipeline);
        rpass.set_bind_group(0, camera_light_bg, &[]);
        rpass.set_bind_group(1, texture_bg, &[]);
        for mesh in meshes {
            rpass.set_vertex_buffer(0, mesh.vertex_buffer.slice(..));
            rpass.set_index_buffer(mesh.index_buffer.slice(..), wgpu::IndexFormat::Uint32);
            rpass.draw_indexed(0..mesh.num_indices, 0, 0..1);
        }
    }
}
/// Sort transparent objects back-to-front by distance from camera.
///
/// Farther objects end up first so they are drawn before nearer ones, which
/// is required for correct alpha blending. The sort is stable, so objects at
/// equal distance keep their relative order (deterministic draw order).
///
/// Accepts any mutable slice of `(index, center_position)` pairs; a
/// `&mut Vec<_>` argument coerces to this, so existing callers are unaffected.
pub fn sort_transparent_back_to_front(
    items: &mut [(usize, [f32; 3])], // (index, center_position)
    camera_pos: [f32; 3],
) {
    // `total_cmp` gives a total order even if a distance is NaN (e.g. a NaN
    // position sneaking in), whereas `partial_cmp().unwrap_or(Equal)` hands
    // the sort an inconsistent comparator in that case.
    items.sort_by(|a, b| {
        let da = dist_sq(a.1, camera_pos);
        let db = dist_sq(b.1, camera_pos);
        db.total_cmp(&da)
    });
}

/// Squared Euclidean distance between two points — cheaper than the true
/// distance and order-preserving, which is all sorting needs.
fn dist_sq(a: [f32; 3], b: [f32; 3]) -> f32 {
    let dx = a[0] - b[0];
    let dy = a[1] - b[1];
    let dz = a[2] - b[2];
    dx * dx + dy * dy + dz * dz
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sort_back_to_front() {
        // Farthest object must come first in the draw list.
        let mut items = vec![
            (0, [1.0, 0.0, 0.0]), // close
            (1, [10.0, 0.0, 0.0]), // far
            (2, [5.0, 0.0, 0.0]), // mid
        ];
        sort_transparent_back_to_front(&mut items, [0.0, 0.0, 0.0]);
        assert_eq!(items[0].0, 1); // farthest first
        assert_eq!(items[1].0, 2);
        assert_eq!(items[2].0, 0); // closest last
    }
    #[test]
    fn test_sort_equal_distance() {
        let mut items = vec![
            (0, [1.0, 0.0, 0.0]),
            (1, [0.0, 1.0, 0.0]),
            (2, [0.0, 0.0, 1.0]),
        ];
        // All at distance 1.0 from origin — should not crash
        sort_transparent_back_to_front(&mut items, [0.0, 0.0, 0.0]);
        assert_eq!(items.len(), 3);
    }
    #[test]
    fn test_sort_empty() {
        // Sorting an empty list is a harmless no-op.
        let mut items: Vec<(usize, [f32; 3])> = vec![];
        sort_transparent_back_to_front(&mut items, [0.0, 0.0, 0.0]);
        assert!(items.is_empty());
    }
}

View File

@@ -0,0 +1,67 @@
// Forward shader for transparent objects: textured Blinn-Phong with a single
// directional light, output alpha taken from the per-object camera uniform.
struct CameraUniform {
    view_proj: mat4x4<f32>,
    model: mat4x4<f32>,
    camera_pos: vec3<f32>,
    alpha: f32, // per-object opacity, written straight to the output alpha
};
struct LightUniform {
    direction: vec3<f32>, // direction the light travels (negated for N·L)
    _pad0: f32,
    color: vec3<f32>,
    ambient_strength: f32,
};
@group(0) @binding(0) var<uniform> camera: CameraUniform;
@group(0) @binding(1) var<uniform> light: LightUniform;
@group(1) @binding(0) var t_diffuse: texture_2d<f32>;
@group(1) @binding(1) var s_diffuse: sampler;
struct VertexInput {
    @location(0) position: vec3<f32>,
    @location(1) normal: vec3<f32>,
    @location(2) uv: vec2<f32>,
    @location(3) tangent: vec4<f32>, // unused here; kept for shared vertex layout
};
struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) world_normal: vec3<f32>,
    @location(1) world_pos: vec3<f32>,
    @location(2) uv: vec2<f32>,
};
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
    var out: VertexOutput;
    let world_pos = camera.model * vec4<f32>(in.position, 1.0);
    out.world_pos = world_pos.xyz;
    // NOTE(review): transforming the normal by `model` directly is only
    // correct for uniform scaling — a non-uniform scale needs the
    // inverse-transpose.
    out.world_normal = normalize((camera.model * vec4<f32>(in.normal, 0.0)).xyz);
    out.clip_position = camera.view_proj * world_pos;
    out.uv = in.uv;
    return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let tex_color = textureSample(t_diffuse, s_diffuse, in.uv);
    // Re-normalize after interpolation.
    let normal = normalize(in.world_normal);
    let light_dir = normalize(-light.direction);
    // Diffuse
    let ndotl = max(dot(normal, light_dir), 0.0);
    let diffuse = light.color * ndotl;
    // Specular (Blinn-Phong)
    let view_dir = normalize(camera.camera_pos - in.world_pos);
    let half_dir = normalize(light_dir + view_dir);
    let spec = pow(max(dot(normal, half_dir), 0.0), 32.0);
    let specular = light.color * spec * 0.5;
    // Ambient
    let ambient = light.color * light.ambient_strength;
    let lit = (ambient + diffuse + specular) * tex_color.rgb;
    return vec4<f32>(lit, camera.alpha);
}

View File

@@ -0,0 +1,262 @@
use voltex_math::Vec3;
use crate::light::{LightsUniform, LIGHT_DIRECTIONAL, LIGHT_POINT};
/// A plane in 3D space: normal.dot(point) + d = 0
#[derive(Debug, Clone, Copy)]
pub struct Plane {
    pub normal: Vec3,
    pub d: f32,
}

impl Plane {
    /// Normalize the plane equation so that |normal| == 1.
    ///
    /// Near-degenerate planes (|normal| ~ 0) are returned unchanged to avoid
    /// dividing by zero.
    pub fn normalize(&self) -> Self {
        let len = self.normal.length();
        if len < 1e-10 {
            *self
        } else {
            Self {
                normal: Vec3::new(self.normal.x / len, self.normal.y / len, self.normal.z / len),
                d: self.d / len,
            }
        }
    }

    /// Signed distance from a point to the plane (positive = inside / front).
    pub fn distance(&self, point: Vec3) -> f32 {
        self.normal.dot(point) + self.d
    }
}
/// Six-plane frustum (left, right, bottom, top, near, far).
///
/// All planes point inward: a point is inside when `Plane::distance` is
/// non-negative for every plane.
#[derive(Debug, Clone, Copy)]
pub struct Frustum {
pub planes: [Plane; 6],
}
/// Extract 6 frustum planes from a view-projection matrix using the
/// Gribb-Hartmann method.
///
/// The planes point inward so that a point is inside if distance >= 0 for all planes.
/// Matrix is column-major `[[f32;4];4]` (same as `Mat4::cols`).
pub fn extract_frustum(view_proj: &voltex_math::Mat4) -> Frustum {
    // Column-major storage cols[c][r] means row r of the VP matrix is
    // (cols[0][r], cols[1][r], cols[2][r], cols[3][r]).
    let m = &view_proj.cols;
    let row = |r: usize| -> [f32; 4] { [m[0][r], m[1][r], m[2][r], m[3][r]] };
    let (r0, r1, r2, r3) = (row(0), row(1), row(2), row(3));
    // Gribb-Hartmann: each side plane is row3 plus/minus one of the other rows.
    let sum = |a: [f32; 4], b: [f32; 4]| {
        Plane {
            normal: Vec3::new(a[0] + b[0], a[1] + b[1], a[2] + b[2]),
            d: a[3] + b[3],
        }
        .normalize()
    };
    let diff = |a: [f32; 4], b: [f32; 4]| {
        Plane {
            normal: Vec3::new(a[0] - b[0], a[1] - b[1], a[2] - b[2]),
            d: a[3] - b[3],
        }
        .normalize()
    };
    // Near uses row2 alone because wgpu clip space has z in [0, 1].
    let near = Plane {
        normal: Vec3::new(r2[0], r2[1], r2[2]),
        d: r2[3],
    }
    .normalize();
    Frustum {
        planes: [
            sum(r3, r0),  // left:   row3 + row0
            diff(r3, r0), // right:  row3 - row0
            sum(r3, r1),  // bottom: row3 + row1
            diff(r3, r1), // top:    row3 - row1
            near,
            diff(r3, r2), // far:    row3 - row2
        ],
    }
}
/// Test whether a sphere (center, radius) is at least partially inside the frustum.
pub fn sphere_vs_frustum(center: Vec3, radius: f32, frustum: &Frustum) -> bool {
    // Reject only when the sphere lies entirely behind some plane. The negated
    // comparison (rather than `>= -radius`) keeps NaN distances passing, as in
    // the original loop formulation.
    frustum
        .planes
        .iter()
        .all(|plane| !(plane.distance(center) < -radius))
}
/// Return indices of lights from `lights` that are visible in the given frustum.
///
/// - Directional lights are always included.
/// - Point lights use a bounding sphere (position, range).
/// - Spot lights use a conservative bounding sphere centered at the light position
///   with radius equal to the light range.
pub fn cull_lights(frustum: &Frustum, lights: &LightsUniform) -> Vec<usize> {
    let count = lights.count as usize;
    let mut visible = Vec::with_capacity(count);
    for idx in 0..count {
        let light = &lights.lights[idx];
        // Shared bounding-sphere test: centre at the light position, radius = range.
        let sphere_visible = || {
            let center = Vec3::new(light.position[0], light.position[1], light.position[2]);
            sphere_vs_frustum(center, light.range, frustum)
        };
        if light.light_type == LIGHT_DIRECTIONAL {
            // Directional lights have no position; they always affect the view.
            visible.push(idx);
        } else if light.light_type == LIGHT_POINT {
            if sphere_visible() {
                visible.push(idx);
            }
        } else if sphere_visible() {
            // Spot light — conservative sphere bound at the apex.
            visible.push(idx);
        }
    }
    visible
}
#[cfg(test)]
mod tests {
use super::*;
use voltex_math::Mat4;
use crate::light::{LightData, LightsUniform};
// Absolute-tolerance float comparison used throughout these tests.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
#[test]
fn test_frustum_extraction_identity() {
// Identity VP means clip space = NDC directly.
// For wgpu: x,y in [-1,1], z in [0,1].
let frustum = extract_frustum(&Mat4::IDENTITY);
// All 6 planes should be normalized (length ~1)
for (i, plane) in frustum.planes.iter().enumerate() {
let len = plane.normal.length();
assert!(approx_eq(len, 1.0, 1e-4), "Plane {} normal length = {}", i, len);
}
}
#[test]
fn test_frustum_extraction_perspective() {
let proj = Mat4::perspective(
std::f32::consts::FRAC_PI_2, // 90 deg
1.0,
0.1,
100.0,
);
let view = Mat4::look_at(
Vec3::new(0.0, 0.0, 5.0),
Vec3::ZERO,
Vec3::Y,
);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
// Origin (0,0,0) should be inside the frustum (it's 5 units in front of camera)
assert!(sphere_vs_frustum(Vec3::ZERO, 0.0, &frustum),
"Origin should be inside frustum");
}
#[test]
fn test_sphere_inside_frustum() {
let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 0.1, 100.0);
let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
// Sphere at origin with radius 1 — well inside
assert!(sphere_vs_frustum(Vec3::ZERO, 1.0, &frustum));
}
#[test]
fn test_sphere_outside_frustum() {
let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 0.1, 100.0);
let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
// Sphere far behind the camera
assert!(!sphere_vs_frustum(Vec3::new(0.0, 0.0, 200.0), 1.0, &frustum),
"Sphere far behind camera should be outside");
// Sphere far to the side
assert!(!sphere_vs_frustum(Vec3::new(500.0, 0.0, 0.0), 1.0, &frustum),
"Sphere far to the side should be outside");
}
#[test]
fn test_sphere_partially_inside() {
let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 0.1, 100.0);
let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
// Sphere at far plane boundary but with large radius should be inside
// (camera at z=5 looking toward -z, so z=-96 is ~101 units in front).
assert!(sphere_vs_frustum(Vec3::new(0.0, 0.0, -96.0), 5.0, &frustum));
}
#[test]
fn test_cull_lights_directional_always_included() {
let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 0.1, 50.0);
let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
let mut lights = LightsUniform::new();
lights.add_light(LightData::directional([0.0, -1.0, 0.0], [1.0, 1.0, 1.0], 1.0));
lights.add_light(LightData::point([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 2.0, 5.0));
// Point light far away — should be culled
lights.add_light(LightData::point([500.0, 500.0, 500.0], [0.0, 1.0, 0.0], 2.0, 1.0));
let visible = cull_lights(&frustum, &lights);
// Directional (0) always included, point at origin (1) inside, far point (2) culled
assert!(visible.contains(&0), "Directional light must always be included");
assert!(visible.contains(&1), "Point light at origin should be visible");
assert!(!visible.contains(&2), "Far point light should be culled");
}
#[test]
fn test_cull_lights_spot() {
let proj = Mat4::perspective(std::f32::consts::FRAC_PI_2, 1.0, 0.1, 50.0);
let view = Mat4::look_at(Vec3::new(0.0, 0.0, 5.0), Vec3::ZERO, Vec3::Y);
let vp = proj.mul_mat4(&view);
let frustum = extract_frustum(&vp);
let mut lights = LightsUniform::new();
// Spot light inside frustum
lights.add_light(LightData::spot(
[0.0, 2.0, 0.0], [0.0, -1.0, 0.0], [1.0, 1.0, 1.0], 3.0, 10.0, 15.0, 30.0,
));
// Spot light far away
lights.add_light(LightData::spot(
[300.0, 300.0, 300.0], [0.0, -1.0, 0.0], [1.0, 1.0, 1.0], 3.0, 5.0, 15.0, 30.0,
));
let visible = cull_lights(&frustum, &lights);
assert!(visible.contains(&0), "Near spot should be visible");
assert!(!visible.contains(&1), "Far spot should be culled");
}
}

View File

@@ -0,0 +1,132 @@
/// Octahedral normal encoding: vec3 normal → vec2 (compact).
///
/// Projects the normal onto the octahedron |x|+|y|+|z| = 1 and, for the lower
/// hemisphere (z < 0), folds the result into the outer triangles of the unit
/// square. Decode with `decode_octahedral`.
///
/// A degenerate (near-zero) input would otherwise divide by zero and poison
/// the G-Buffer with NaNs; it is encoded as `[0.0, 0.0]` (the +Z pole) instead.
pub fn encode_octahedral(n: [f32; 3]) -> [f32; 2] {
    let sum = n[0].abs() + n[1].abs() + n[2].abs();
    if sum <= f32::EPSILON {
        // Zero-length "normal": fall back to a valid encoding instead of NaN.
        return [0.0, 0.0];
    }
    let mut oct = [n[0] / sum, n[1] / sum];
    if n[2] < 0.0 {
        // Lower hemisphere: fold across the square's diagonals.
        let ox = oct[0];
        let oy = oct[1];
        oct[0] = (1.0 - oy.abs()) * if ox >= 0.0 { 1.0 } else { -1.0 };
        oct[1] = (1.0 - ox.abs()) * if oy >= 0.0 { 1.0 } else { -1.0 };
    }
    oct
}
/// Decode an octahedral-encoded vec2 back into a unit-length normal.
pub fn decode_octahedral(oct: [f32; 2]) -> [f32; 3] {
    let z = 1.0 - oct[0].abs() - oct[1].abs();
    let (x, y) = if z < 0.0 {
        // Lower hemisphere: unfold the outer triangles back across the diagonals.
        let sx = if oct[0] >= 0.0 { 1.0 } else { -1.0 };
        let sy = if oct[1] >= 0.0 { 1.0 } else { -1.0 };
        ((1.0 - oct[1].abs()) * sx, (1.0 - oct[0].abs()) * sy)
    } else {
        (oct[0], oct[1])
    };
    // Renormalize — the octahedron projection does not preserve length.
    let len = (x * x + y * y + z * z).sqrt();
    [x / len, y / len, z / len]
}
/// Reconstruct world position from depth + UV + inverse view-projection matrix.
///
/// Returns `[0.0; 3]` when the transformed point has a (near-)zero w, which
/// would make the perspective divide blow up.
pub fn reconstruct_position(
    uv: [f32; 2],
    depth: f32,
    inv_view_proj: &[[f32; 4]; 4],
) -> [f32; 3] {
    // UV (0..1, y-down) → NDC (-1..1, y-up); depth passes through unchanged.
    let clip = [uv[0] * 2.0 - 1.0, 1.0 - uv[1] * 2.0, depth, 1.0];
    // world = inv_view_proj * clip, with column-major matrix storage.
    // Columns are walked in ascending order so per-component summation order
    // matches the classic row-outer formulation exactly.
    let mut world = [0.0f32; 4];
    for (j, col) in inv_view_proj.iter().enumerate() {
        for (i, w) in world.iter_mut().enumerate() {
            *w += col[i] * clip[j];
        }
    }
    if world[3].abs() > 1e-8 {
        [world[0] / world[3], world[1] / world[3], world[2] / world[3]]
    } else {
        [0.0; 3]
    }
}
/// Compressed G-Buffer format recommendations.
pub const COMPRESSED_NORMAL_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rg16Float; // octahedral-encoded normal (see `encode_octahedral`)
pub const COMPRESSED_ALBEDO_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
pub const COMPRESSED_MATERIAL_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8Unorm;
// Position: reconstructed from depth, no texture needed (see `reconstruct_position`)
#[cfg(test)]
mod tests {
use super::*;
// Encode/decode round-trips should stay within ~1-2% per axis.
#[test]
fn test_octahedral_roundtrip_positive_z() {
let n = [0.0, 0.0, 1.0];
let enc = encode_octahedral(n);
let dec = decode_octahedral(enc);
for i in 0..3 {
assert!(
(n[i] - dec[i]).abs() < 0.01,
"axis {}: {} vs {}",
i,
n[i],
dec[i]
);
}
}
#[test]
fn test_octahedral_roundtrip_negative_z() {
// Exercises the lower-hemisphere fold (n.z < 0).
let n = [0.0, 0.0, -1.0];
let enc = encode_octahedral(n);
let dec = decode_octahedral(enc);
for i in 0..3 {
assert!((n[i] - dec[i]).abs() < 0.01);
}
}
#[test]
fn test_octahedral_roundtrip_diagonal() {
// Unit vector along (1,1,1).
let s = 1.0 / 3.0_f32.sqrt();
let n = [s, s, s];
let enc = encode_octahedral(n);
let dec = decode_octahedral(enc);
for i in 0..3 {
assert!((n[i] - dec[i]).abs() < 0.01);
}
}
#[test]
fn test_octahedral_roundtrip_axes() {
for n in [
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0],
[0.0, -1.0, 0.0],
] {
let dec = decode_octahedral(encode_octahedral(n));
for i in 0..3 {
assert!((n[i] - dec[i]).abs() < 0.02, "{:?} → {:?}", n, dec);
}
}
}
#[test]
fn test_reconstruct_position_identity() {
let identity = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let pos = reconstruct_position([0.5, 0.5], 0.5, &identity);
// UV(0.5,0.5) → NDC(0,0), depth=0.5 → (0, 0, 0.5) in clip space
assert!((pos[0] - 0.0).abs() < 0.01);
assert!((pos[1] - 0.0).abs() < 0.01);
assert!((pos[2] - 0.5).abs() < 0.01);
}
}

View File

@@ -0,0 +1,982 @@
use crate::json_parser::{self, JsonValue};
use crate::vertex::MeshVertex;
use crate::obj::compute_tangents;
/// Fully parsed glTF/GLB asset: geometry plus scene graph, skins and animations.
pub struct GltfData {
pub meshes: Vec<GltfMesh>,
pub nodes: Vec<GltfNode>,
pub skins: Vec<GltfSkin>,
pub animations: Vec<GltfAnimation>,
}
/// One mesh *primitive* flattened into vertex/index buffers.
pub struct GltfMesh {
pub vertices: Vec<MeshVertex>,
pub indices: Vec<u32>,
pub name: Option<String>,
pub material: Option<GltfMaterial>,
// Per-vertex joint indices (JOINTS_0); present only for skinned meshes.
pub joints: Option<Vec<[u16; 4]>>,
// Per-vertex joint weights (WEIGHTS_0); present only for skinned meshes.
pub weights: Option<Vec<[f32; 4]>>,
}
/// Subset of pbrMetallicRoughness factors (textures are not parsed here).
pub struct GltfMaterial {
pub base_color: [f32; 4],
pub metallic: f32,
pub roughness: f32,
}
/// Scene-graph node with a decomposed TRS transform.
#[derive(Debug, Clone)]
pub struct GltfNode {
pub name: Option<String>,
pub children: Vec<usize>,
pub translation: [f32; 3],
pub rotation: [f32; 4], // quaternion [x,y,z,w]
pub scale: [f32; 3],
pub mesh: Option<usize>,
pub skin: Option<usize>,
}
/// Skin: joint node indices plus column-major inverse bind matrices.
#[derive(Debug, Clone)]
pub struct GltfSkin {
pub name: Option<String>,
pub joints: Vec<usize>,
pub inverse_bind_matrices: Vec<[[f32; 4]; 4]>,
pub skeleton: Option<usize>,
}
/// Animation: a named set of channels, each targeting one node property.
#[derive(Debug, Clone)]
pub struct GltfAnimation {
pub name: Option<String>,
pub channels: Vec<GltfChannel>,
}
/// One animation channel: keyframe times plus the sampler's flattened
/// raw output values.
#[derive(Debug, Clone)]
pub struct GltfChannel {
pub target_node: usize,
pub target_path: AnimationPath,
pub interpolation: Interpolation,
pub times: Vec<f32>,
pub values: Vec<f32>,
}
/// Which node property an animation channel drives.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AnimationPath {
    Translation,
    Rotation,
    Scale,
}

/// glTF sampler interpolation modes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Interpolation {
    Linear,
    Step,
    CubicSpline,
}

/// Map a glTF channel target path string to [`AnimationPath`].
/// Unknown strings fall back to `Translation`.
pub fn parse_animation_path(s: &str) -> AnimationPath {
    if s == "rotation" {
        AnimationPath::Rotation
    } else if s == "scale" {
        AnimationPath::Scale
    } else {
        // "translation" and anything unrecognised.
        AnimationPath::Translation
    }
}

/// Map a glTF sampler interpolation string to [`Interpolation`].
/// Unknown strings fall back to `Linear`.
pub fn parse_interpolation(s: &str) -> Interpolation {
    if s == "STEP" {
        Interpolation::Step
    } else if s == "CUBICSPLINE" {
        Interpolation::CubicSpline
    } else {
        // "LINEAR" and anything unrecognised.
        Interpolation::Linear
    }
}
// GLB container constants (glTF 2.0 binary layout; little-endian FourCC codes).
const GLB_MAGIC: u32 = 0x46546C67; // "glTF"
const GLB_VERSION: u32 = 2;
const CHUNK_JSON: u32 = 0x4E4F534A; // "JSON"
const CHUNK_BIN: u32 = 0x004E4942; // "BIN\0"
/// Parse glTF from raw bytes, auto-detecting GLB (binary) vs plain JSON.
pub fn parse_gltf(data: &[u8]) -> Result<GltfData, String> {
    // Need at least the 4-byte magic to sniff the container format.
    if data.len() < 4 {
        return Err("Data too short".into());
    }
    let magic = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    match () {
        _ if magic == GLB_MAGIC => parse_glb(data),
        _ if data[0] == b'{' => parse_gltf_json(data),
        _ => Err("Unknown glTF format: not GLB or JSON".into()),
    }
}
/// Parse a binary .glb container: 12-byte header followed by chunks.
///
/// Collects the JSON chunk and the binary (BIN) chunk, then defers to
/// `extract_meshes`. Unknown chunk types are skipped.
fn parse_glb(data: &[u8]) -> Result<GltfData, String> {
if data.len() < 12 {
return Err("GLB header too short".into());
}
// Header layout: magic (checked by the caller), version, total length.
let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
if version != GLB_VERSION {
return Err(format!("Unsupported GLB version: {} (expected 2)", version));
}
let _total_len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
// Parse chunks
let mut pos = 12;
let mut json_str = String::new();
let mut bin_data: Vec<u8> = Vec::new();
// Each chunk: u32 length, u32 type, then `length` payload bytes.
while pos + 8 <= data.len() {
let chunk_len = u32::from_le_bytes([data[pos], data[pos+1], data[pos+2], data[pos+3]]) as usize;
let chunk_type = u32::from_le_bytes([data[pos+4], data[pos+5], data[pos+6], data[pos+7]]);
pos += 8;
if pos + chunk_len > data.len() {
return Err("Chunk extends past data".into());
}
match chunk_type {
CHUNK_JSON => {
json_str = std::str::from_utf8(&data[pos..pos + chunk_len])
.map_err(|_| "Invalid UTF-8 in JSON chunk".to_string())?
.to_string();
}
CHUNK_BIN => {
bin_data = data[pos..pos + chunk_len].to_vec();
}
_ => {} // skip unknown chunks
}
pos += chunk_len;
// Chunks are 4-byte aligned
pos = (pos + 3) & !3;
}
if json_str.is_empty() {
return Err("No JSON chunk found in GLB".into());
}
let json = json_parser::parse_json(&json_str)?;
let buffers = vec![bin_data]; // GLB has one implicit binary buffer
extract_meshes(&json, &buffers)
}
/// Parse a plain-JSON .gltf document.
///
/// Only embedded base64 `data:` URI buffers are supported; external buffer
/// files are rejected with an error.
fn parse_gltf_json(data: &[u8]) -> Result<GltfData, String> {
    let json_str = std::str::from_utf8(data).map_err(|_| "Invalid UTF-8".to_string())?;
    let json = json_parser::parse_json(json_str)?;
    let mut buffers = Vec::new();
    if let Some(bufs) = json.get("buffers").and_then(|v| v.as_array()) {
        for buf in bufs {
            let decoded = match buf.get("uri").and_then(|v| v.as_str()) {
                // A buffer without a URI gets an empty placeholder.
                None => Vec::new(),
                Some(uri) => {
                    // Accept either of the two common embedded MIME types.
                    let b64 = uri
                        .strip_prefix("data:application/octet-stream;base64,")
                        .or_else(|| uri.strip_prefix("data:application/gltf-buffer;base64,"));
                    match b64 {
                        Some(b64) => decode_base64(b64)?,
                        None => return Err(format!("External buffer URIs not supported: {}", uri)),
                    }
                }
            };
            buffers.push(decoded);
        }
    }
    extract_meshes(&json, &buffers)
}
/// Minimal base64 decoder for embedded glTF data URIs.
///
/// Accepts padded and unpadded input; whitespace (LF/CR/space) is ignored.
fn decode_base64(input: &str) -> Result<Vec<u8>, String> {
    // 6-bit value for one base64 symbol; '=' padding decodes as 0.
    fn sextet(c: u8) -> Result<u8, String> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            b'=' => Ok(0), // padding
            _ => Err(format!("Invalid base64 character: {}", c as char)),
        }
    }
    // Strip whitespace that may appear in wrapped data URIs.
    let bytes: Vec<u8> = input
        .bytes()
        .filter(|&b| !matches!(b, b'\n' | b'\r' | b' '))
        .collect();
    let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
    for chunk in bytes.chunks(4) {
        // Missing symbols in a short trailing chunk decode as 0.
        let get = |i: usize| chunk.get(i).copied().map_or(Ok(0), sextet);
        let (b0, b1, b2, b3) = (get(0)?, get(1)?, get(2)?, get(3)?);
        out.push((b0 << 2) | (b1 >> 4));
        // The 2nd/3rd output bytes exist only when not cut off by padding.
        if chunk.len() > 2 && chunk[2] != b'=' {
            out.push((b1 << 4) | (b2 >> 2));
        }
        if chunk.len() > 3 && chunk[3] != b'=' {
            out.push((b2 << 6) | b3);
        }
    }
    Ok(out)
}
/// Walk the glTF document and build `GltfData`.
///
/// Every mesh *primitive* becomes its own `GltfMesh` (primitives of one glTF
/// mesh share the mesh name). Missing optional attributes get defaults:
/// +Y normals, zero UVs, and tangents computed from the geometry.
fn extract_meshes(json: &JsonValue, buffers: &[Vec<u8>]) -> Result<GltfData, String> {
let empty_arr: Vec<JsonValue> = Vec::new();
let accessors = json.get("accessors").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let buffer_views = json.get("bufferViews").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let materials_json = json.get("materials").and_then(|v| v.as_array());
let mut meshes = Vec::new();
let mesh_list = json.get("meshes").and_then(|v| v.as_array())
.ok_or("No meshes in glTF")?;
for mesh_val in mesh_list {
let name = mesh_val.get("name").and_then(|v| v.as_str()).map(|s| s.to_string());
let primitives = mesh_val.get("primitives").and_then(|v| v.as_array())
.ok_or("Mesh has no primitives")?;
for prim in primitives {
let attrs = prim.get("attributes").and_then(|v| v.as_object())
.ok_or("Primitive has no attributes")?;
// Read position data (required)
let pos_idx = attrs.iter().find(|(k, _)| k == "POSITION")
.and_then(|(_, v)| v.as_u32())
.ok_or("Missing POSITION attribute")? as usize;
let positions = read_accessor_vec3(accessors, buffer_views, buffers, pos_idx)?;
// Read normals (optional) — default to +Y when absent
let normals = if let Some(idx) = attrs.iter().find(|(k, _)| k == "NORMAL").and_then(|(_, v)| v.as_u32()) {
read_accessor_vec3(accessors, buffer_views, buffers, idx as usize)?
} else {
vec![[0.0, 1.0, 0.0]; positions.len()]
};
// Read UVs (optional)
let uvs = if let Some(idx) = attrs.iter().find(|(k, _)| k == "TEXCOORD_0").and_then(|(_, v)| v.as_u32()) {
read_accessor_vec2(accessors, buffer_views, buffers, idx as usize)?
} else {
vec![[0.0, 0.0]; positions.len()]
};
// Read tangents (optional)
let tangents = if let Some(idx) = attrs.iter().find(|(k, _)| k == "TANGENT").and_then(|(_, v)| v.as_u32()) {
Some(read_accessor_vec4(accessors, buffer_views, buffers, idx as usize)?)
} else {
None
};
// Read indices
let indices = if let Some(idx) = prim.get("indices").and_then(|v| v.as_u32()) {
read_accessor_indices(accessors, buffer_views, buffers, idx as usize)?
} else {
// No indices — generate sequential
(0..positions.len() as u32).collect()
};
// Assemble vertices (attribute arrays are indexed in lockstep)
let mut vertices: Vec<MeshVertex> = Vec::with_capacity(positions.len());
for i in 0..positions.len() {
vertices.push(MeshVertex {
position: positions[i],
normal: normals[i],
uv: uvs[i],
tangent: tangents.as_ref().map_or([0.0; 4], |t| t[i]),
});
}
// Read JOINTS_0 (optional)
let joints = if let Some(idx) = attrs.iter().find(|(k, _)| k == "JOINTS_0").and_then(|(_, v)| v.as_u32()) {
Some(read_accessor_joints(accessors, buffer_views, buffers, idx as usize)?)
} else {
None
};
// Read WEIGHTS_0 (optional)
let weights = if let Some(idx) = attrs.iter().find(|(k, _)| k == "WEIGHTS_0").and_then(|(_, v)| v.as_u32()) {
Some(read_accessor_vec4(accessors, buffer_views, buffers, idx as usize)?)
} else {
None
};
// Compute tangents if not provided
if tangents.is_none() {
compute_tangents(&mut vertices, &indices);
}
// Read material (silently absent if the index is out of range)
let material = prim.get("material")
.and_then(|v| v.as_u32())
.and_then(|idx| materials_json?.get(idx as usize))
.and_then(|mat| extract_material(mat));
meshes.push(GltfMesh { vertices, indices, name: name.clone(), material, joints, weights });
}
}
let nodes = parse_nodes(json);
let skins = parse_skins(json, accessors, buffer_views, buffers);
let animations = parse_animations(json, accessors, buffer_views, buffers);
Ok(GltfData { meshes, nodes, skins, animations })
}
/// Resolve an accessor to its backing byte buffer and the combined byte
/// offset (bufferView offset + accessor offset) into it.
fn get_buffer_data<'a>(
    accessor: &JsonValue,
    buffer_views: &[JsonValue],
    buffers: &'a [Vec<u8>],
) -> Result<(&'a [u8], usize), String> {
    // Shorthand: optional u32 field read as usize, defaulting to 0.
    let idx_or_zero =
        |v: &JsonValue, key: &str| v.get(key).and_then(|x| x.as_u32()).unwrap_or(0) as usize;
    let bv_idx = accessor
        .get("bufferView")
        .and_then(|v| v.as_u32())
        .ok_or("Accessor missing bufferView")? as usize;
    let bv = buffer_views.get(bv_idx).ok_or("BufferView index out of range")?;
    let buffer = buffers
        .get(idx_or_zero(bv, "buffer"))
        .ok_or("Buffer index out of range")?;
    let offset = idx_or_zero(bv, "byteOffset") + idx_or_zero(accessor, "byteOffset");
    Ok((buffer, offset))
}
/// Read a tightly-packed float32 VEC3 accessor (positions, normals).
fn read_accessor_vec3(
    accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<[f32; 3]>, String> {
    let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
    let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
    let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
    let mut out = Vec::with_capacity(count);
    for i in 0..count {
        // 12 bytes per element: three little-endian f32s.
        let start = offset + i * 12;
        let bytes = buffer
            .get(start..start + 12)
            .ok_or("Buffer overflow reading vec3")?;
        let f = |k: usize| f32::from_le_bytes([bytes[k], bytes[k + 1], bytes[k + 2], bytes[k + 3]]);
        out.push([f(0), f(4), f(8)]);
    }
    Ok(out)
}
/// Read a tightly-packed float32 VEC2 accessor (e.g. TEXCOORD_0).
fn read_accessor_vec2(
    accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<[f32; 2]>, String> {
    let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
    let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
    let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
    let mut out = Vec::with_capacity(count);
    for i in 0..count {
        // 8 bytes per element: two little-endian f32s.
        let start = offset + i * 8;
        let bytes = buffer
            .get(start..start + 8)
            .ok_or("Buffer overflow reading vec2")?;
        let f = |k: usize| f32::from_le_bytes([bytes[k], bytes[k + 1], bytes[k + 2], bytes[k + 3]]);
        out.push([f(0), f(4)]);
    }
    Ok(out)
}
/// Read a tightly-packed float32 VEC4 accessor (tangents, weights).
fn read_accessor_vec4(
    accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<[f32; 4]>, String> {
    let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
    let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
    let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
    let mut out = Vec::with_capacity(count);
    for i in 0..count {
        // 16 bytes per element: four little-endian f32s.
        let start = offset + i * 16;
        let bytes = buffer
            .get(start..start + 16)
            .ok_or("Buffer overflow reading vec4")?;
        let f = |k: usize| f32::from_le_bytes([bytes[k], bytes[k + 1], bytes[k + 2], bytes[k + 3]]);
        out.push([f(0), f(4), f(8), f(12)]);
    }
    Ok(out)
}
/// Read an index accessor, widening UNSIGNED_BYTE/SHORT/INT values to u32.
fn read_accessor_indices(
accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<u32>, String> {
let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
let comp_type = acc.get("componentType").and_then(|v| v.as_u32()).ok_or("Missing componentType")?;
let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
let mut result = Vec::with_capacity(count);
// componentType values are the glTF/GL numeric constants.
match comp_type {
5121 => { // UNSIGNED_BYTE
for i in 0..count {
if offset + i >= buffer.len() { return Err("Buffer overflow reading u8 indices".into()); }
result.push(buffer[offset + i] as u32);
}
}
5123 => { // UNSIGNED_SHORT
for i in 0..count {
let o = offset + i * 2;
if o + 2 > buffer.len() { return Err("Buffer overflow reading u16 indices".into()); }
result.push(u16::from_le_bytes([buffer[o], buffer[o+1]]) as u32);
}
}
5125 => { // UNSIGNED_INT
for i in 0..count {
let o = offset + i * 4;
if o + 4 > buffer.len() { return Err("Buffer overflow reading u32 indices".into()); }
result.push(u32::from_le_bytes([buffer[o], buffer[o+1], buffer[o+2], buffer[o+3]]));
}
}
_ => return Err(format!("Unsupported index component type: {}", comp_type)),
}
Ok(result)
}
/// Read a JOINTS_0 accessor: four joint indices per vertex, stored as
/// UNSIGNED_BYTE or UNSIGNED_SHORT, widened to u16.
fn read_accessor_joints(
accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<[u16; 4]>, String> {
let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
// Default to UNSIGNED_SHORT when componentType is absent.
let comp_type = acc.get("componentType").and_then(|v| v.as_u32()).unwrap_or(5123);
let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
let mut result = Vec::with_capacity(count);
match comp_type {
5121 => { // UNSIGNED_BYTE
for i in 0..count {
let o = offset + i * 4;
if o + 4 > buffer.len() { return Err("Buffer overflow reading joints u8".into()); }
result.push([
buffer[o] as u16, buffer[o+1] as u16,
buffer[o+2] as u16, buffer[o+3] as u16,
]);
}
}
5123 => { // UNSIGNED_SHORT
for i in 0..count {
let o = offset + i * 8;
if o + 8 > buffer.len() { return Err("Buffer overflow reading joints u16".into()); }
result.push([
u16::from_le_bytes([buffer[o], buffer[o+1]]),
u16::from_le_bytes([buffer[o+2], buffer[o+3]]),
u16::from_le_bytes([buffer[o+4], buffer[o+5]]),
u16::from_le_bytes([buffer[o+6], buffer[o+7]]),
]);
}
}
_ => return Err(format!("Unsupported joints component type: {}", comp_type)),
}
Ok(result)
}
/// Read a MAT4 accessor: 64 bytes per matrix, little-endian float32,
/// column-major element order (used for inverse bind matrices).
fn read_accessor_mat4(
accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<[[f32; 4]; 4]>, String> {
let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
let mut result = Vec::with_capacity(count);
for i in 0..count {
let o = offset + i * 64;
if o + 64 > buffer.len() { return Err("Buffer overflow reading mat4".into()); }
let mut mat = [[0.0f32; 4]; 4];
for col in 0..4 {
for row in 0..4 {
// Element (col, row) lives at byte (col*4 + row) * 4 — column-major.
let b = o + (col * 4 + row) * 4;
mat[col][row] = f32::from_le_bytes([buffer[b], buffer[b+1], buffer[b+2], buffer[b+3]]);
}
}
result.push(mat);
}
Ok(result)
}
/// Read an accessor as a flat f32 list (`count` × components-per-element,
/// where the component count follows the accessor's "type" string).
///
/// NOTE(review): assumes float32 components — `componentType` is never
/// checked here, so normalized-integer outputs would be misread; confirm
/// that animation samplers in supported assets always use floats.
fn read_accessor_floats(
accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>], idx: usize,
) -> Result<Vec<f32>, String> {
let acc = accessors.get(idx).ok_or("Accessor index out of range")?;
let count = acc.get("count").and_then(|v| v.as_u32()).ok_or("Missing count")? as usize;
let acc_type = acc.get("type").and_then(|v| v.as_str()).unwrap_or("SCALAR");
let components = match acc_type {
"SCALAR" => 1,
"VEC2" => 2,
"VEC3" => 3,
"VEC4" => 4,
_ => 1,
};
let total = count * components;
let (buffer, offset) = get_buffer_data(acc, buffer_views, buffers)?;
let mut result = Vec::with_capacity(total);
for i in 0..total {
let o = offset + i * 4;
if o + 4 > buffer.len() { return Err("Buffer overflow reading floats".into()); }
result.push(f32::from_le_bytes([buffer[o], buffer[o+1], buffer[o+2], buffer[o+3]]));
}
Ok(result)
}
/// Parse the "nodes" array. Absent TRS fields default to the glTF identity
/// transform: T = (0,0,0), R = (0,0,0,1), S = (1,1,1). Malformed entries
/// degrade to those defaults rather than failing.
fn parse_nodes(json: &JsonValue) -> Vec<GltfNode> {
let empty_arr: Vec<JsonValue> = Vec::new();
let nodes_arr = json.get("nodes").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let mut nodes = Vec::with_capacity(nodes_arr.len());
for node_val in nodes_arr {
let name = node_val.get("name").and_then(|v| v.as_str()).map(|s| s.to_string());
let children = node_val.get("children").and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_u32().map(|n| n as usize)).collect())
.unwrap_or_default();
let translation = node_val.get("translation").and_then(|v| v.as_array())
.map(|arr| [
arr.get(0).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
arr.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
arr.get(2).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
])
.unwrap_or([0.0, 0.0, 0.0]);
// Quaternion stored as [x, y, z, w]; missing w defaults to 1 (identity).
let rotation = node_val.get("rotation").and_then(|v| v.as_array())
.map(|arr| [
arr.get(0).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
arr.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
arr.get(2).and_then(|v| v.as_f64()).unwrap_or(0.0) as f32,
arr.get(3).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32,
])
.unwrap_or([0.0, 0.0, 0.0, 1.0]);
let scale = node_val.get("scale").and_then(|v| v.as_array())
.map(|arr| [
arr.get(0).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32,
arr.get(1).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32,
arr.get(2).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32,
])
.unwrap_or([1.0, 1.0, 1.0]);
let mesh = node_val.get("mesh").and_then(|v| v.as_u32()).map(|n| n as usize);
let skin = node_val.get("skin").and_then(|v| v.as_u32()).map(|n| n as usize);
nodes.push(GltfNode { name, children, translation, rotation, scale, mesh, skin });
}
nodes
}
/// Parse the "skins" array. An absent or unreadable inverseBindMatrices
/// accessor yields an empty matrix list rather than an error.
fn parse_skins(
json: &JsonValue, accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>],
) -> Vec<GltfSkin> {
let empty_arr: Vec<JsonValue> = Vec::new();
let skins_arr = json.get("skins").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let mut skins = Vec::with_capacity(skins_arr.len());
for skin_val in skins_arr {
let name = skin_val.get("name").and_then(|v| v.as_str()).map(|s| s.to_string());
let joints = skin_val.get("joints").and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_u32().map(|n| n as usize)).collect())
.unwrap_or_default();
let skeleton = skin_val.get("skeleton").and_then(|v| v.as_u32()).map(|n| n as usize);
let inverse_bind_matrices = skin_val.get("inverseBindMatrices")
.and_then(|v| v.as_u32())
.and_then(|idx| read_accessor_mat4(accessors, buffer_views, buffers, idx as usize).ok())
.unwrap_or_default();
skins.push(GltfSkin { name, joints, inverse_bind_matrices, skeleton });
}
skins
}
/// Parse the "animations" array. Each channel is denormalized with a clone of
/// its sampler's keyframe data (times/values), so channels sharing a sampler
/// duplicate that data. Channels without a "target" object are dropped.
fn parse_animations(
json: &JsonValue, accessors: &[JsonValue], buffer_views: &[JsonValue], buffers: &[Vec<u8>],
) -> Vec<GltfAnimation> {
let empty_arr: Vec<JsonValue> = Vec::new();
let anims_arr = json.get("animations").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let mut animations = Vec::with_capacity(anims_arr.len());
for anim_val in anims_arr {
let name = anim_val.get("name").and_then(|v| v.as_str()).map(|s| s.to_string());
let samplers_arr = anim_val.get("samplers").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
let channels_arr = anim_val.get("channels").and_then(|v| v.as_array()).unwrap_or(&empty_arr);
// Parse samplers: each has input, output, interpolation
struct Sampler {
times: Vec<f32>,
values: Vec<f32>,
interpolation: Interpolation,
}
let mut samplers = Vec::with_capacity(samplers_arr.len());
for s in samplers_arr {
let interp_str = s.get("interpolation").and_then(|v| v.as_str()).unwrap_or("LINEAR");
let interpolation = parse_interpolation(interp_str);
// Unreadable input/output accessors degrade to empty keyframe lists.
let times = s.get("input").and_then(|v| v.as_u32())
.and_then(|idx| read_accessor_floats(accessors, buffer_views, buffers, idx as usize).ok())
.unwrap_or_default();
let values = s.get("output").and_then(|v| v.as_u32())
.and_then(|idx| read_accessor_floats(accessors, buffer_views, buffers, idx as usize).ok())
.unwrap_or_default();
samplers.push(Sampler { times, values, interpolation });
}
// Parse channels
let mut channels = Vec::with_capacity(channels_arr.len());
for ch in channels_arr {
let sampler_idx = ch.get("sampler").and_then(|v| v.as_u32()).unwrap_or(0) as usize;
let target = match ch.get("target") {
Some(t) => t,
None => continue,
};
let target_node = target.get("node").and_then(|v| v.as_u32()).unwrap_or(0) as usize;
let path_str = target.get("path").and_then(|v| v.as_str()).unwrap_or("translation");
let target_path = parse_animation_path(path_str);
// Channels referencing a missing sampler are silently skipped.
if let Some(sampler) = samplers.get(sampler_idx) {
channels.push(GltfChannel {
target_node,
target_path,
interpolation: sampler.interpolation,
times: sampler.times.clone(),
values: sampler.values.clone(),
});
}
}
animations.push(GltfAnimation { name, channels });
}
animations
}
/// Extract pbrMetallicRoughness factors from a material; returns `None` when
/// the material has no pbrMetallicRoughness object. All factors default to
/// 1.0, matching the glTF spec defaults.
fn extract_material(mat: &JsonValue) -> Option<GltfMaterial> {
    let pbr = mat.get("pbrMetallicRoughness")?;
    // Scalar factor with default 1.0.
    let factor = |key: &str| pbr.get(key).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32;
    let base_color = match pbr.get("baseColorFactor").and_then(|v| v.as_array()) {
        Some(arr) => {
            let c = |i: usize| arr.get(i).and_then(|v| v.as_f64()).unwrap_or(1.0) as f32;
            [c(0), c(1), c(2), c(3)]
        }
        None => [1.0, 1.0, 1.0, 1.0],
    };
    Some(GltfMaterial {
        base_color,
        metallic: factor("metallicFactor"),
        roughness: factor("roughnessFactor"),
    })
}
// Helper functions for tests
#[allow(dead_code)]
/// Decode `count` little-endian f32 values from `buffer` starting at `offset`.
/// Panics if the buffer is too short (test helper only).
fn read_floats(buffer: &[u8], offset: usize, count: usize) -> Vec<f32> {
    buffer[offset..offset + count * 4]
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
        .collect()
}
#[allow(dead_code)]
/// Decode `count` little-endian u16 indices from `buffer` starting at
/// `offset`, widening each to u32. Panics if the buffer is too short
/// (test helper only).
fn read_indices_u16(buffer: &[u8], offset: usize, count: usize) -> Vec<u32> {
    buffer[offset..offset + count * 2]
        .chunks_exact(2)
        .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]) as u32)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- GLB container failure cases ---

    #[test]
    fn test_glb_header_magic() {
        // Invalid magic
        let data = [0u8; 12];
        assert!(parse_gltf(&data).is_err());
    }

    #[test]
    fn test_glb_header_version() {
        // Valid magic but wrong version
        let mut data = Vec::new();
        data.extend_from_slice(&0x46546C67u32.to_le_bytes()); // magic "glTF"
        data.extend_from_slice(&1u32.to_le_bytes()); // version 1 (we need 2)
        data.extend_from_slice(&12u32.to_le_bytes()); // length
        assert!(parse_gltf(&data).is_err());
    }

    // --- Embedded base64 buffer decoding ---

    #[test]
    fn test_base64_decode() {
        let encoded = "SGVsbG8="; // "Hello"
        let decoded = decode_base64(encoded).unwrap();
        assert_eq!(decoded, b"Hello");
    }

    #[test]
    fn test_base64_decode_no_padding() {
        let encoded = "SGVsbG8"; // "Hello" without padding
        let decoded = decode_base64(encoded).unwrap();
        assert_eq!(decoded, b"Hello");
    }

    // --- Raw accessor helpers ---

    #[test]
    fn test_read_f32_accessor() {
        // Simulate a buffer with 3 float32 values
        let buffer: Vec<u8> = [1.0f32, 2.0, 3.0].iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let data = read_floats(&buffer, 0, 3);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_read_u16_indices() {
        let buffer: Vec<u8> = [0u16, 1, 2].iter()
            .flat_map(|i| i.to_le_bytes())
            .collect();
        let indices = read_indices_u16(&buffer, 0, 3);
        assert_eq!(indices, vec![0u32, 1, 2]);
    }

    // --- End-to-end GLB parsing ---

    #[test]
    fn test_parse_minimal_glb() {
        let glb = build_minimal_glb_triangle();
        let data = parse_gltf(&glb).unwrap();
        assert_eq!(data.meshes.len(), 1);
        let mesh = &data.meshes[0];
        assert_eq!(mesh.vertices.len(), 3);
        assert_eq!(mesh.indices.len(), 3);
        // Verify positions
        assert_eq!(mesh.vertices[0].position, [0.0, 0.0, 0.0]);
        assert_eq!(mesh.vertices[1].position, [1.0, 0.0, 0.0]);
        assert_eq!(mesh.vertices[2].position, [0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_parse_glb_with_material() {
        let glb = build_glb_with_material();
        let data = parse_gltf(&glb).unwrap();
        let mesh = &data.meshes[0];
        let mat = mesh.material.as_ref().unwrap();
        assert!((mat.base_color[0] - 1.0).abs() < 0.01);
        assert!((mat.metallic - 0.5).abs() < 0.01);
        assert!((mat.roughness - 0.8).abs() < 0.01);
    }

    /// Build a minimal GLB with one triangle.
    /// Layout: 12-byte header, 8-byte JSON chunk header + padded JSON,
    /// 8-byte BIN chunk header + binary payload. Both chunks are padded to
    /// 4-byte alignment as required by the GLB container spec.
    fn build_minimal_glb_triangle() -> Vec<u8> {
        // Binary buffer: 3 positions (vec3) + 3 indices (u16)
        let mut bin = Vec::new();
        // Positions: 3 * vec3 = 36 bytes
        for &v in &[0.0f32, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] {
            bin.extend_from_slice(&v.to_le_bytes());
        }
        // Indices: 3 * u16 = 6 bytes + 2 padding = 8 bytes
        for &i in &[0u16, 1, 2] {
            bin.extend_from_slice(&i.to_le_bytes());
        }
        bin.extend_from_slice(&[0, 0]); // padding to 4-byte alignment
        // componentType 5126 = FLOAT, 5123 = UNSIGNED_SHORT.
        let json_str = format!(r#"{{
"asset": {{"version": "2.0"}},
"buffers": [{{"byteLength": {}}}],
"bufferViews": [
{{"buffer": 0, "byteOffset": 0, "byteLength": 36}},
{{"buffer": 0, "byteOffset": 36, "byteLength": 6}}
],
"accessors": [
{{"bufferView": 0, "componentType": 5126, "count": 3, "type": "VEC3",
"max": [1.0, 1.0, 0.0], "min": [0.0, 0.0, 0.0]}},
{{"bufferView": 1, "componentType": 5123, "count": 3, "type": "SCALAR"}}
],
"meshes": [{{
"name": "Triangle",
"primitives": [{{
"attributes": {{"POSITION": 0}},
"indices": 1
}}]
}}]
}}"#, bin.len());
        let json_bytes = json_str.as_bytes();
        // Pad JSON to 4-byte alignment
        let json_padded_len = (json_bytes.len() + 3) & !3;
        let mut json_padded = json_bytes.to_vec();
        while json_padded.len() < json_padded_len {
            json_padded.push(b' ');
        }
        let total_len = 12 + 8 + json_padded.len() + 8 + bin.len();
        let mut glb = Vec::with_capacity(total_len);
        // Header
        glb.extend_from_slice(&0x46546C67u32.to_le_bytes()); // magic
        glb.extend_from_slice(&2u32.to_le_bytes()); // version
        glb.extend_from_slice(&(total_len as u32).to_le_bytes());
        // JSON chunk
        glb.extend_from_slice(&(json_padded.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x4E4F534Au32.to_le_bytes()); // "JSON"
        glb.extend_from_slice(&json_padded);
        // BIN chunk
        glb.extend_from_slice(&(bin.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x004E4942u32.to_le_bytes()); // "BIN\0"
        glb.extend_from_slice(&bin);
        glb
    }

    // --- Animation enum parsing ---

    #[test]
    fn test_animation_path_parsing() {
        assert_eq!(parse_animation_path("translation"), AnimationPath::Translation);
        assert_eq!(parse_animation_path("rotation"), AnimationPath::Rotation);
        assert_eq!(parse_animation_path("scale"), AnimationPath::Scale);
    }

    #[test]
    fn test_interpolation_parsing() {
        assert_eq!(parse_interpolation("LINEAR"), Interpolation::Linear);
        assert_eq!(parse_interpolation("STEP"), Interpolation::Step);
        assert_eq!(parse_interpolation("CUBICSPLINE"), Interpolation::CubicSpline);
    }

    #[test]
    fn test_gltf_data_has_new_fields() {
        let glb = build_minimal_glb_triangle();
        let data = parse_gltf(&glb).unwrap();
        assert_eq!(data.meshes.len(), 1);
        // No nodes/skins/animations in minimal GLB — should be empty, not crash
        assert!(data.nodes.is_empty());
        assert!(data.skins.is_empty());
        assert!(data.animations.is_empty());
        // Joints/weights should be None
        assert!(data.meshes[0].joints.is_none());
        assert!(data.meshes[0].weights.is_none());
    }

    #[test]
    fn test_parse_glb_with_node() {
        let glb = build_glb_with_node();
        let data = parse_gltf(&glb).unwrap();
        assert_eq!(data.meshes.len(), 1);
        assert_eq!(data.nodes.len(), 1);
        let node = &data.nodes[0];
        assert_eq!(node.name.as_deref(), Some("RootNode"));
        assert_eq!(node.mesh, Some(0));
        assert!((node.translation[0] - 1.0).abs() < 0.001);
        assert!((node.translation[1] - 2.0).abs() < 0.001);
        assert!((node.translation[2] - 3.0).abs() < 0.001);
        // Scale was omitted in the JSON, so it must default to identity.
        assert_eq!(node.scale, [1.0, 1.0, 1.0]);
    }

    /// Same triangle as `build_minimal_glb_triangle`, plus a single node
    /// referencing mesh 0 with a translation.
    fn build_glb_with_node() -> Vec<u8> {
        let mut bin = Vec::new();
        for &v in &[0.0f32, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] {
            bin.extend_from_slice(&v.to_le_bytes());
        }
        for &i in &[0u16, 1, 2] {
            bin.extend_from_slice(&i.to_le_bytes());
        }
        bin.extend_from_slice(&[0, 0]);
        let json_str = format!(r#"{{
"asset": {{"version": "2.0"}},
"buffers": [{{"byteLength": {}}}],
"bufferViews": [
{{"buffer": 0, "byteOffset": 0, "byteLength": 36}},
{{"buffer": 0, "byteOffset": 36, "byteLength": 6}}
],
"accessors": [
{{"bufferView": 0, "componentType": 5126, "count": 3, "type": "VEC3",
"max": [1.0, 1.0, 0.0], "min": [0.0, 0.0, 0.0]}},
{{"bufferView": 1, "componentType": 5123, "count": 3, "type": "SCALAR"}}
],
"nodes": [{{
"name": "RootNode",
"mesh": 0,
"translation": [1.0, 2.0, 3.0]
}}],
"meshes": [{{
"name": "Triangle",
"primitives": [{{
"attributes": {{"POSITION": 0}},
"indices": 1
}}]
}}]
}}"#, bin.len());
        let json_bytes = json_str.as_bytes();
        let json_padded_len = (json_bytes.len() + 3) & !3;
        let mut json_padded = json_bytes.to_vec();
        while json_padded.len() < json_padded_len {
            json_padded.push(b' ');
        }
        let total_len = 12 + 8 + json_padded.len() + 8 + bin.len();
        let mut glb = Vec::with_capacity(total_len);
        glb.extend_from_slice(&0x46546C67u32.to_le_bytes());
        glb.extend_from_slice(&2u32.to_le_bytes());
        glb.extend_from_slice(&(total_len as u32).to_le_bytes());
        glb.extend_from_slice(&(json_padded.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x4E4F534Au32.to_le_bytes());
        glb.extend_from_slice(&json_padded);
        glb.extend_from_slice(&(bin.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x004E4942u32.to_le_bytes());
        glb.extend_from_slice(&bin);
        glb
    }

    /// Build a GLB with one triangle and a material.
    fn build_glb_with_material() -> Vec<u8> {
        let mut bin = Vec::new();
        for &v in &[0.0f32, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] {
            bin.extend_from_slice(&v.to_le_bytes());
        }
        for &i in &[0u16, 1, 2] {
            bin.extend_from_slice(&i.to_le_bytes());
        }
        bin.extend_from_slice(&[0, 0]); // padding
        let json_str = format!(r#"{{
"asset": {{"version": "2.0"}},
"buffers": [{{"byteLength": {}}}],
"bufferViews": [
{{"buffer": 0, "byteOffset": 0, "byteLength": 36}},
{{"buffer": 0, "byteOffset": 36, "byteLength": 6}}
],
"accessors": [
{{"bufferView": 0, "componentType": 5126, "count": 3, "type": "VEC3",
"max": [1.0, 1.0, 0.0], "min": [0.0, 0.0, 0.0]}},
{{"bufferView": 1, "componentType": 5123, "count": 3, "type": "SCALAR"}}
],
"materials": [{{
"pbrMetallicRoughness": {{
"baseColorFactor": [1.0, 0.0, 0.0, 1.0],
"metallicFactor": 0.5,
"roughnessFactor": 0.8
}}
}}],
"meshes": [{{
"name": "Triangle",
"primitives": [{{
"attributes": {{"POSITION": 0}},
"indices": 1,
"material": 0
}}]
}}]
}}"#, bin.len());
        let json_bytes = json_str.as_bytes();
        let json_padded_len = (json_bytes.len() + 3) & !3;
        let mut json_padded = json_bytes.to_vec();
        while json_padded.len() < json_padded_len {
            json_padded.push(b' ');
        }
        let total_len = 12 + 8 + json_padded.len() + 8 + bin.len();
        let mut glb = Vec::with_capacity(total_len);
        glb.extend_from_slice(&0x46546C67u32.to_le_bytes());
        glb.extend_from_slice(&2u32.to_le_bytes());
        glb.extend_from_slice(&(total_len as u32).to_le_bytes());
        glb.extend_from_slice(&(json_padded.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x4E4F534Au32.to_le_bytes());
        glb.extend_from_slice(&json_padded);
        glb.extend_from_slice(&(bin.len() as u32).to_le_bytes());
        glb.extend_from_slice(&0x004E4942u32.to_le_bytes());
        glb.extend_from_slice(&bin);
        glb
    }
}

View File

@@ -0,0 +1,60 @@
/// Calculate half-resolution dimensions (rounded up).
pub fn half_resolution(width: u32, height: u32) -> (u32, u32) {
    // Adding 1 before the integer divide rounds odd dimensions up.
    let halve = |dim: u32| (dim + 1) / 2;
    (halve(width), halve(height))
}
/// Bilinear upscale weight calculation for a 2x2 tap pattern.
/// Order: [top-left, top-right, bottom-left, bottom-right]; weights sum to 1.
pub fn bilinear_weights(frac_x: f32, frac_y: f32) -> [f32; 4] {
    let inv_x = 1.0 - frac_x;
    let inv_y = 1.0 - frac_y;
    [inv_x * inv_y, frac_x * inv_y, inv_x * frac_y, frac_x * frac_y]
}
/// Depth-aware upscale: reject samples with large depth discontinuity.
/// Returns a binary weight: 0.0 when the depth gap exceeds `threshold`,
/// 1.0 otherwise (comparison direction preserved so NaN inputs accept).
pub fn depth_aware_weight(center_depth: f32, sample_depth: f32, threshold: f32) -> f32 {
    match (center_depth - sample_depth).abs() > threshold {
        true => 0.0,
        false => 1.0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_half_resolution() {
        // Even, odd, and degenerate dimensions.
        let cases = [
            ((1920, 1080), (960, 540)),
            ((1921, 1081), (961, 541)),
            ((1, 1), (1, 1)),
        ];
        for ((w, h), expected) in cases {
            assert_eq!(half_resolution(w, h), expected);
        }
    }

    #[test]
    fn test_bilinear_weights_center() {
        // At the exact center all four taps share the weight equally.
        for &wi in bilinear_weights(0.5, 0.5).iter() {
            assert!((wi - 0.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_bilinear_weights_corner() {
        let w = bilinear_weights(0.0, 0.0);
        assert!((w[0] - 1.0).abs() < 1e-6); // top-left = full weight
        assert!(w[1].abs() < 1e-6);
    }

    #[test]
    fn test_bilinear_weights_sum_to_one() {
        let total: f32 = bilinear_weights(0.3, 0.7).iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_depth_aware_accept() {
        assert!((depth_aware_weight(0.5, 0.51, 0.1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_depth_aware_reject() {
        assert!(depth_aware_weight(0.5, 0.9, 0.1).abs() < 1e-6);
    }
}

View File

@@ -9,6 +9,7 @@ pub struct IblResources {
}
impl IblResources {
/// CPU fallback: generates the BRDF LUT on the CPU and uploads as Rgba8Unorm.
pub fn new(device: &wgpu::Device, queue: &wgpu::Queue) -> Self {
let size = BRDF_LUT_SIZE;
@@ -79,4 +80,120 @@ impl IblResources {
}
}
/// GPU compute path: generates the BRDF LUT via a compute shader in Rg16Float format.
/// Higher precision than the CPU Rgba8Unorm path.
///
/// Creates an Rg16Float storage texture, dispatches `brdf_lut_compute.wgsl`
/// (one 16x16 workgroup tile per dispatch cell) to fill it, and returns the
/// texture plus a clamped linear sampler. The command buffer is submitted to
/// `queue` before returning, but completion is not awaited here.
pub fn new_gpu(device: &wgpu::Device, queue: &wgpu::Queue) -> Self {
    let size = BRDF_LUT_SIZE;
    let extent = wgpu::Extent3d {
        width: size,
        height: size,
        depth_or_array_layers: 1,
    };
    // Create Rg16Float storage texture
    // STORAGE_BINDING for the compute write, TEXTURE_BINDING for later sampling.
    let brdf_lut_texture = device.create_texture(&wgpu::TextureDescriptor {
        label: Some("BrdfLutTexture_GPU"),
        size: extent,
        mip_level_count: 1,
        sample_count: 1,
        dimension: wgpu::TextureDimension::D2,
        format: wgpu::TextureFormat::Rg16Float,
        usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
        view_formats: &[],
    });
    let storage_view =
        brdf_lut_texture.create_view(&wgpu::TextureViewDescriptor::default());
    // Create compute pipeline
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("BRDF LUT Compute Shader"),
        source: wgpu::ShaderSource::Wgsl(
            include_str!("brdf_lut_compute.wgsl").into(),
        ),
    });
    // Single binding: the write-only storage texture the shader fills.
    let bind_group_layout =
        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("BRDF LUT Compute BGL"),
            entries: &[wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::StorageTexture {
                    access: wgpu::StorageTextureAccess::WriteOnly,
                    format: wgpu::TextureFormat::Rg16Float,
                    view_dimension: wgpu::TextureViewDimension::D2,
                },
                count: None,
            }],
        });
    let pipeline_layout =
        device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("BRDF LUT Compute Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });
    let pipeline =
        device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("BRDF LUT Compute Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("BRDF LUT Compute Bind Group"),
        layout: &bind_group_layout,
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: wgpu::BindingResource::TextureView(&storage_view),
        }],
    });
    // Dispatch compute shader
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("BRDF LUT Compute Encoder"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("BRDF LUT Compute Pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // Dispatch enough workgroups to cover size x size texels (16x16 per workgroup)
        // NOTE(review): assumes the WGSL entry point uses @workgroup_size(16, 16)
        // — confirm against brdf_lut_compute.wgsl.
        let wg_x = (size + 15) / 16;
        let wg_y = (size + 15) / 16;
        pass.dispatch_workgroups(wg_x, wg_y, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    let brdf_lut_view =
        brdf_lut_texture.create_view(&wgpu::TextureViewDescriptor::default());
    // Clamp addressing avoids wrap-around artifacts at the LUT edges.
    let brdf_lut_sampler = device.create_sampler(&wgpu::SamplerDescriptor {
        label: Some("BrdfLutSampler"),
        address_mode_u: wgpu::AddressMode::ClampToEdge,
        address_mode_v: wgpu::AddressMode::ClampToEdge,
        address_mode_w: wgpu::AddressMode::ClampToEdge,
        mag_filter: wgpu::FilterMode::Linear,
        min_filter: wgpu::FilterMode::Linear,
        mipmap_filter: wgpu::MipmapFilterMode::Nearest,
        ..Default::default()
    });
    Self {
        brdf_lut_texture,
        brdf_lut_view,
        brdf_lut_sampler,
    }
}
}

View File

@@ -0,0 +1,81 @@
// Instanced forward shader: per-vertex mesh attributes plus a per-instance
// 4x4 model matrix (four vec4 columns) and an RGBA tint color.

// Per-frame camera data; `camera_pos` drives the specular view vector.
struct CameraUniform {
    view_proj: mat4x4<f32>,
    model: mat4x4<f32>,
    camera_pos: vec3<f32>,
    _pad: f32,
};
// Single directional light with a scalar ambient term.
struct LightUniform {
    direction: vec3<f32>,
    _pad0: f32,
    color: vec3<f32>,
    ambient_strength: f32,
};
@group(0) @binding(0) var<uniform> camera: CameraUniform;
@group(0) @binding(1) var<uniform> light: LightUniform;
@group(1) @binding(0) var t_diffuse: texture_2d<f32>;
@group(1) @binding(1) var s_diffuse: sampler;
struct VertexInput {
    @location(0) position: vec3<f32>,
    @location(1) normal: vec3<f32>,
    @location(2) uv: vec2<f32>,
    // tangent matches the shared mesh vertex layout but is not read here.
    @location(3) tangent: vec4<f32>,
};
// Per-instance data at locations 4-8: model matrix columns, then tint.
struct InstanceInput {
    @location(4) model_0: vec4<f32>,
    @location(5) model_1: vec4<f32>,
    @location(6) model_2: vec4<f32>,
    @location(7) model_3: vec4<f32>,
    @location(8) color: vec4<f32>,
};
struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) world_normal: vec3<f32>,
    @location(1) world_pos: vec3<f32>,
    @location(2) uv: vec2<f32>,
    @location(3) inst_color: vec4<f32>,
};
@vertex
fn vs_main(vertex: VertexInput, instance: InstanceInput) -> VertexOutput {
    // Reassemble the per-instance model matrix from its four columns.
    let inst_model = mat4x4<f32>(
        instance.model_0,
        instance.model_1,
        instance.model_2,
        instance.model_3,
    );
    var out: VertexOutput;
    let world_pos = inst_model * vec4<f32>(vertex.position, 1.0);
    out.world_pos = world_pos.xyz;
    // NOTE(review): the normal is transformed by the full instance matrix
    // (w = 0), not an inverse-transpose — only correct if instances never
    // use non-uniform scale; confirm.
    out.world_normal = normalize((inst_model * vec4<f32>(vertex.normal, 0.0)).xyz);
    out.clip_position = camera.view_proj * world_pos;
    out.uv = vertex.uv;
    out.inst_color = instance.color;
    return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let tex_color = textureSample(t_diffuse, s_diffuse, in.uv);
    let normal = normalize(in.world_normal);
    // Light direction points from the light toward the scene; negate for N.L.
    let light_dir = normalize(-light.direction);
    let ndotl = max(dot(normal, light_dir), 0.0);
    let diffuse = light.color * ndotl;
    // Blinn-Phong specular: half-vector, fixed shininess 32, 0.3 strength.
    let view_dir = normalize(camera.camera_pos - in.world_pos);
    let half_dir = normalize(light_dir + view_dir);
    let spec = pow(max(dot(normal, half_dir), 0.0), 32.0);
    let specular = light.color * spec * 0.3;
    let ambient = light.color * light.ambient_strength;
    // Modulate lighting by the texture and the per-instance tint.
    let lit = (ambient + diffuse + specular) * tex_color.rgb * in.inst_color.rgb;
    return vec4<f32>(lit, tex_color.a * in.inst_color.a);
}

View File

@@ -0,0 +1,164 @@
use bytemuck::{Pod, Zeroable};
use crate::vertex::MeshVertex;
use crate::gpu::DEPTH_FORMAT;
/// Per-instance vertex data: a 4x4 model matrix plus an RGBA tint color.
/// `#[repr(C)]` plus `Pod`/`Zeroable` let it be cast to bytes and uploaded
/// to the GPU verbatim (80 bytes per instance).
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct InstanceData {
    pub model: [[f32; 4]; 4], // 64 bytes
    pub color: [f32; 4], // 16 bytes
}
impl InstanceData {
    /// Vertex buffer layout for instance data: five vec4 attributes at
    /// shader locations 4-8 (the four model-matrix columns, then the color),
    /// advanced once per instance. Must stay in sync with `InstanceInput`
    /// in the instanced shader.
    pub const LAYOUT: wgpu::VertexBufferLayout<'static> = wgpu::VertexBufferLayout {
        array_stride: std::mem::size_of::<InstanceData>() as u64,
        step_mode: wgpu::VertexStepMode::Instance,
        attributes: &[
            wgpu::VertexAttribute { offset: 0, shader_location: 4, format: wgpu::VertexFormat::Float32x4 },
            wgpu::VertexAttribute { offset: 16, shader_location: 5, format: wgpu::VertexFormat::Float32x4 },
            wgpu::VertexAttribute { offset: 32, shader_location: 6, format: wgpu::VertexFormat::Float32x4 },
            wgpu::VertexAttribute { offset: 48, shader_location: 7, format: wgpu::VertexFormat::Float32x4 },
            wgpu::VertexAttribute { offset: 64, shader_location: 8, format: wgpu::VertexFormat::Float32x4 },
        ],
    };
}
/// Growable GPU vertex buffer holding `InstanceData` records.
pub struct InstanceBuffer {
    pub buffer: wgpu::Buffer,
    pub capacity: usize, // allocated size, in instances
    pub count: usize,    // instances valid for drawing after the last update
}
impl InstanceBuffer {
    /// Create a vertex buffer with room for `capacity` instances.
    pub fn new(device: &wgpu::Device, capacity: usize) -> Self {
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Instance Buffer"),
            size: (capacity * std::mem::size_of::<InstanceData>()) as u64,
            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        InstanceBuffer { buffer, capacity, count: 0 }
    }
    /// Upload `instances` to the GPU, reallocating (next power of two) when
    /// the current capacity is exceeded. An empty slice just resets `count`;
    /// no write is issued.
    pub fn update(&mut self, device: &wgpu::Device, queue: &wgpu::Queue, instances: &[InstanceData]) {
        self.count = instances.len();
        if instances.is_empty() { return; }
        if instances.len() > self.capacity {
            // Grow geometrically so repeated small increases don't
            // reallocate every frame.
            self.capacity = instances.len().next_power_of_two();
            self.buffer = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Instance Buffer"),
                size: (self.capacity * std::mem::size_of::<InstanceData>()) as u64,
                usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
        }
        queue.write_buffer(&self.buffer, 0, bytemuck::cast_slice(instances));
    }
}
/// Build the render pipeline for instanced mesh drawing.
///
/// Uses `instanced_shader.wgsl` with two vertex buffers (per-vertex mesh
/// attributes + per-instance data), opaque REPLACE blending, back-face
/// culling, and depth testing against the shared `DEPTH_FORMAT`.
/// Bind group 0 is the camera/light layout, group 1 the diffuse texture.
pub fn create_instanced_pipeline(
    device: &wgpu::Device,
    format: wgpu::TextureFormat,
    camera_light_layout: &wgpu::BindGroupLayout,
    texture_layout: &wgpu::BindGroupLayout,
) -> wgpu::RenderPipeline {
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Instanced Shader"),
        source: wgpu::ShaderSource::Wgsl(include_str!("instanced_shader.wgsl").into()),
    });
    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Instanced Pipeline Layout"),
        bind_group_layouts: &[camera_light_layout, texture_layout],
        immediate_size: 0,
    });
    device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
        label: Some("Instanced Pipeline"),
        layout: Some(&layout),
        vertex: wgpu::VertexState {
            module: &shader,
            entry_point: Some("vs_main"),
            // Slot 0: per-vertex mesh data; slot 1: per-instance data.
            buffers: &[MeshVertex::LAYOUT, InstanceData::LAYOUT],
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        },
        fragment: Some(wgpu::FragmentState {
            module: &shader,
            entry_point: Some("fs_main"),
            targets: &[Some(wgpu::ColorTargetState {
                format,
                // Opaque pass: overwrite, no alpha blending.
                blend: Some(wgpu::BlendState::REPLACE),
                write_mask: wgpu::ColorWrites::ALL,
            })],
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        }),
        primitive: wgpu::PrimitiveState {
            topology: wgpu::PrimitiveTopology::TriangleList,
            strip_index_format: None,
            front_face: wgpu::FrontFace::Ccw,
            cull_mode: Some(wgpu::Face::Back),
            polygon_mode: wgpu::PolygonMode::Fill,
            unclipped_depth: false,
            conservative: false,
        },
        depth_stencil: Some(wgpu::DepthStencilState {
            format: DEPTH_FORMAT,
            depth_write_enabled: true,
            depth_compare: wgpu::CompareFunction::Less,
            stencil: wgpu::StencilState::default(),
            bias: wgpu::DepthBiasState::default(),
        }),
        multisample: wgpu::MultisampleState { count: 1, mask: !0, alpha_to_coverage_enabled: false },
        multiview_mask: None,
        cache: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_instance_data_size() {
        // 4x4 f32 matrix (64 B) + RGBA color (16 B) = 80 B.
        assert_eq!(std::mem::size_of::<InstanceData>(), 80);
    }

    #[test]
    fn test_instance_data_layout_attributes() {
        let layout = &InstanceData::LAYOUT;
        assert_eq!(layout.attributes.len(), 5);
        assert_eq!(layout.array_stride, 80);
        assert_eq!(layout.step_mode, wgpu::VertexStepMode::Instance);
    }

    #[test]
    fn test_instance_data_layout_locations() {
        // Shader locations are consecutive starting at 4.
        for (idx, attr) in InstanceData::LAYOUT.attributes.iter().enumerate() {
            assert_eq!(attr.shader_location, 4 + idx as u32);
        }
    }

    #[test]
    fn test_instance_data_layout_offsets() {
        // Five vec4 attributes packed back to back: 16-byte stride each.
        for (idx, attr) in InstanceData::LAYOUT.attributes.iter().enumerate() {
            assert_eq!(attr.offset, (idx * 16) as u64);
        }
    }

    #[test]
    fn test_instance_data_pod() {
        let identity = InstanceData {
            model: [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            color: [1.0, 1.0, 1.0, 1.0],
        };
        let raw: &[u8] = bytemuck::bytes_of(&identity);
        assert_eq!(raw.len(), 80);
    }
}

View File

@@ -0,0 +1,959 @@
/// Baseline JPEG decoder. Supports SOF0 (sequential DCT, Huffman).
/// Returns RGBA pixel data like parse_png.
/// Supports grayscale (1-component) and YCbCr (3-component) with
/// chroma subsampling (4:4:4, 4:2:2, 4:2:0).
///
/// Walks the marker stream, collecting frame parameters (SOF0), Huffman
/// tables (DHT), quantization tables (DQT) and the restart interval (DRI)
/// until the first SOS, then hands off to `decode_scan` and converts its
/// RGB output to RGBA. Returns `Err` for malformed or unsupported streams.
pub fn parse_jpg(data: &[u8]) -> Result<(Vec<u8>, u32, u32), String> {
    // Every JPEG stream must open with the SOI marker (FF D8).
    if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
        return Err("Invalid JPEG: missing SOI marker".into());
    }
    let mut pos = 2;
    // Frame state collected from header segments before the scan.
    let mut width: u16 = 0;
    let mut height: u16 = 0;
    let mut num_components: u8 = 0;
    let mut components: Vec<JpegComponent> = Vec::new();
    let mut qt_tables: [[u16; 64]; 4] = [[0; 64]; 4];
    let mut dc_tables: [Option<HuffTable>; 4] = [None, None, None, None];
    let mut ac_tables: [Option<HuffTable>; 4] = [None, None, None, None];
    let mut found_sof = false;
    let mut restart_interval: u16 = 0;
    // Marker-dispatch loop: step segment to segment until SOS or EOI.
    while pos + 1 < data.len() {
        if data[pos] != 0xFF {
            return Err(format!("Expected marker at position {}", pos));
        }
        // Skip padding 0xFF bytes
        while pos + 1 < data.len() && data[pos + 1] == 0xFF {
            pos += 1;
        }
        if pos + 1 >= data.len() {
            return Err("Unexpected end of data".into());
        }
        let marker = data[pos + 1];
        pos += 2;
        match marker {
            0xD8 => {} // SOI (already handled)
            0xD9 => break, // EOI
            0xDA => {
                // SOS — Start of Scan
                // NOTE(review): `components` still has dc_table/ac_table = 0
                // here — presumably decode_scan parses the SOS header itself
                // and assigns them; confirm against its implementation.
                if !found_sof {
                    return Err("SOS before SOF".into());
                }
                let (rgb, scan_end) = decode_scan(
                    data, pos, width, height, num_components,
                    &components, &qt_tables, &dc_tables, &ac_tables,
                    restart_interval,
                )?;
                let _ = scan_end;
                // Convert RGB to RGBA
                let w = width as u32;
                let h = height as u32;
                let mut rgba = Vec::with_capacity((w * h * 4) as usize);
                for pixel in rgb.chunks_exact(3) {
                    rgba.push(pixel[0]);
                    rgba.push(pixel[1]);
                    rgba.push(pixel[2]);
                    rgba.push(255); // opaque alpha
                }
                return Ok((rgba, w, h));
            }
            0xC0 => {
                // SOF0 — Baseline DCT
                let (sof, len) = parse_sof(data, pos)?;
                width = sof.width;
                height = sof.height;
                num_components = sof.num_components;
                components = sof.components;
                found_sof = true;
                pos += len;
            }
            0xC4 => {
                // DHT — Define Huffman Table
                let len = parse_dht(data, pos, &mut dc_tables, &mut ac_tables)?;
                pos += len;
            }
            0xDB => {
                // DQT — Define Quantization Table
                let len = parse_dqt(data, pos, &mut qt_tables)?;
                pos += len;
            }
            0xDD => {
                // DRI — Define Restart Interval
                if pos + 4 > data.len() {
                    return Err("DRI too short".into());
                }
                let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
                restart_interval =
                    u16::from_be_bytes([data[pos + 2], data[pos + 3]]);
                pos += seg_len;
            }
            0xD0..=0xD7 => {
                // RST markers — handled inside scan decoder
            }
            0xE0..=0xEF | 0xFE => {
                // APP0-APP15, COM — skip
                if pos + 2 > data.len() {
                    return Err("Segment too short".into());
                }
                let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
                pos += seg_len;
            }
            _ => {
                // Unknown marker with length — skip
                // NOTE(review): standalone markers other than RST (e.g. TEM
                // 0x01) would be mis-read as having a length field here.
                if pos + 2 > data.len() {
                    return Err(format!("Unknown marker 0x{:02X}", marker));
                }
                let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
                pos += seg_len;
            }
        }
    }
    Err("No image data found (missing SOS)".into())
}
// ---------------------------------------------------------------------------
// Data structures
// ---------------------------------------------------------------------------
/// One frame component (e.g. Y, Cb or Cr) as declared in the SOF header.
#[derive(Clone)]
struct JpegComponent {
    #[allow(dead_code)]
    id: u8,       // component identifier byte from SOF
    h_sample: u8, // horizontal sampling factor
    v_sample: u8, // vertical sampling factor
    qt_id: u8,    // quantization table index
    dc_table: u8, // DC Huffman table index (0 after SOF; presumably set from the SOS header — confirm)
    ac_table: u8, // AC Huffman table index (same caveat as dc_table)
}
/// Frame header values extracted from an SOF0 segment by `parse_sof`.
struct SofData {
    width: u16,
    height: u16,
    num_components: u8,
    components: Vec<JpegComponent>,
}
/// Canonical Huffman decode table built by `parse_dht`. Arrays are indexed
/// by i where the code length is i + 1 bits.
struct HuffTable {
    symbols: Vec<u8>,   // symbols ordered by increasing code length
    offsets: [u16; 17], // index into `symbols` of the first code of each length
    maxcode: [i32; 17], // largest code of each length, -1 when none exist
    mincode: [u16; 17], // smallest code of each length
}
// ---------------------------------------------------------------------------
// Marker parsers
// ---------------------------------------------------------------------------
/// Parse a DQT (Define Quantization Table) segment starting at `pos` (the
/// first byte of the segment length field). One segment may carry several
/// tables; each is stored into `qt_tables` by its id. Returns the segment
/// length so the caller can advance past it.
fn parse_dqt(data: &[u8], pos: usize, qt_tables: &mut [[u16; 64]; 4]) -> Result<usize, String> {
    if pos + 2 > data.len() {
        return Err("DQT too short".into());
    }
    let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
    if pos + seg_len > data.len() {
        return Err("DQT segment extends past data".into());
    }
    let seg_end = pos + seg_len;
    let mut off = pos + 2;
    while off < seg_end {
        // Pq (high nibble) selects 8- vs 16-bit entries; Tq (low) is the id.
        let pq_tq = data[off];
        off += 1;
        let table_id = (pq_tq & 0x0F) as usize;
        if table_id >= 4 {
            return Err(format!("DQT table id {} out of range", table_id));
        }
        if (pq_tq >> 4) == 0 {
            // 64 one-byte entries.
            if off + 64 > seg_end {
                return Err("DQT 8-bit data too short".into());
            }
            for (slot, &byte) in qt_tables[table_id].iter_mut().zip(&data[off..off + 64]) {
                *slot = byte as u16;
            }
            off += 64;
        } else {
            // 64 big-endian two-byte entries.
            if off + 128 > seg_end {
                return Err("DQT 16-bit data too short".into());
            }
            for (i, pair) in data[off..off + 128].chunks_exact(2).enumerate() {
                qt_tables[table_id][i] = u16::from_be_bytes([pair[0], pair[1]]);
            }
            off += 128;
        }
    }
    Ok(seg_len)
}
/// Parse an SOF0 (Baseline DCT Start-of-Frame) segment starting at `pos`
/// (the first byte of the segment length field).
///
/// Returns the frame header plus the segment length so the caller can
/// advance past the segment. Only 8-bit sample precision is supported.
/// `dc_table`/`ac_table` of each component are initialized to 0 here.
fn parse_sof(data: &[u8], pos: usize) -> Result<(SofData, usize), String> {
    if pos + 2 > data.len() {
        return Err("SOF too short".into());
    }
    let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
    if pos + seg_len > data.len() {
        return Err("SOF segment extends past data".into());
    }
    // Fix: a valid SOF is at least length(2) + precision(1) + height(2) +
    // width(2) + component count(1) = 8 bytes. Without this check, a
    // malformed seg_len < 8 at the end of `data` passed the bounds check
    // above but the fixed-offset reads below indexed past the buffer and
    // panicked instead of returning Err.
    if seg_len < 8 {
        return Err("SOF too short".into());
    }
    let precision = data[pos + 2];
    if precision != 8 {
        return Err(format!("Unsupported sample precision: {}", precision));
    }
    let height = u16::from_be_bytes([data[pos + 3], data[pos + 4]]);
    let width = u16::from_be_bytes([data[pos + 5], data[pos + 6]]);
    let num_comp = data[pos + 7];
    let mut components = Vec::with_capacity(num_comp as usize);
    let mut off = pos + 8;
    for _ in 0..num_comp {
        // Each component record is 3 bytes: id, sampling factors, quant id.
        if off + 3 > pos + seg_len {
            return Err("SOF component data too short".into());
        }
        let id = data[off];
        let sampling = data[off + 1];
        components.push(JpegComponent {
            id,
            h_sample: sampling >> 4,   // horizontal factor (high nibble)
            v_sample: sampling & 0x0F, // vertical factor (low nibble)
            qt_id: data[off + 2],
            dc_table: 0,
            ac_table: 0,
        });
        off += 3;
    }
    Ok((
        SofData {
            width,
            height,
            num_components: num_comp,
            components,
        },
        seg_len,
    ))
}
/// Parse a DHT (Define Huffman Table) segment starting at `pos` (the first
/// byte of the segment length field). A segment may define several tables;
/// each is stored into `dc_tables` or `ac_tables` by class and id.
/// Returns the segment length so the caller can advance past it.
fn parse_dht(
    data: &[u8],
    pos: usize,
    dc_tables: &mut [Option<HuffTable>; 4],
    ac_tables: &mut [Option<HuffTable>; 4],
) -> Result<usize, String> {
    if pos + 2 > data.len() {
        return Err("DHT too short".into());
    }
    let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
    if pos + seg_len > data.len() {
        return Err("DHT segment extends past data".into());
    }
    let mut off = pos + 2;
    let seg_end = pos + seg_len;
    while off < seg_end {
        // Tc (high nibble) = table class: 0 = DC, otherwise AC.
        // Th (low nibble) = table id.
        let tc_th = data[off];
        let table_class = tc_th >> 4;
        let table_id = (tc_th & 0x0F) as usize;
        off += 1;
        if table_id >= 4 {
            return Err(format!("DHT table id {} out of range", table_id));
        }
        if off + 16 > seg_end {
            return Err("DHT counts too short".into());
        }
        // counts[i] = number of codes with bit length i + 1.
        let mut counts = [0u8; 16];
        counts.copy_from_slice(&data[off..off + 16]);
        off += 16;
        let total_symbols: usize = counts.iter().map(|&c| c as usize).sum();
        if off + total_symbols > seg_end {
            return Err("DHT symbols too short".into());
        }
        // Symbols are listed in order of increasing code length.
        let symbols: Vec<u8> = data[off..off + total_symbols].to_vec();
        off += total_symbols;
        // Build lookup tables
        // Canonical Huffman assignment: codes of each length are consecutive
        // starting where the previous length's codes ended, shifted left one
        // bit. For length i + 1: offsets[i] is the index of its first symbol,
        // mincode/maxcode bracket its code values (maxcode stays -1 when no
        // codes of that length exist, so decode_huffman skips it).
        let mut offsets = [0u16; 17];
        let mut maxcode = [-1i32; 17];
        let mut mincode = [0u16; 17];
        let mut code: u16 = 0;
        let mut sym_offset: u16 = 0;
        for i in 0..16 {
            offsets[i] = sym_offset;
            if counts[i] > 0 {
                mincode[i] = code;
                maxcode[i] = (code + counts[i] as u16 - 1) as i32;
                sym_offset += counts[i] as u16;
            }
            code = (code + counts[i] as u16) << 1;
        }
        offsets[16] = sym_offset;
        let table = HuffTable {
            symbols,
            offsets,
            maxcode,
            mincode,
        };
        if table_class == 0 {
            dc_tables[table_id] = Some(table);
        } else {
            ac_tables[table_id] = Some(table);
        }
    }
    Ok(seg_len)
}
// ---------------------------------------------------------------------------
// BitReader — MSB-first bit reading with JPEG byte stuffing
// ---------------------------------------------------------------------------
/// MSB-first bit reader over entropy-coded scan data. Transparently removes
/// JPEG byte stuffing (0xFF 0x00 -> literal 0xFF) while reading.
struct BitReader<'a> {
    data: &'a [u8],
    pos: usize,   // index of the next byte to fetch from `data`
    bit_pos: u8,  // bits remaining in `current`; 0 means fetch a fresh byte
    current: u8,  // byte currently being consumed, MSB first
}
impl<'a> BitReader<'a> {
    /// Begin reading at byte index `start` of `data`.
    fn new(data: &'a [u8], start: usize) -> Self {
        Self {
            data,
            pos: start,
            bit_pos: 0,
            current: 0,
        }
    }
    /// Fetch one raw byte, applying JPEG byte-stuffing rules:
    /// 0xFF 0x00 yields a literal 0xFF; an embedded RST marker
    /// (0xFF 0xD0..=0xD7) is skipped and the next real byte returned;
    /// any other marker inside scan data is an error.
    fn read_byte(&mut self) -> Result<u8, String> {
        if self.pos >= self.data.len() {
            return Err("Unexpected end of scan data".into());
        }
        let byte = self.data[self.pos];
        self.pos += 1;
        if byte == 0xFF {
            if self.pos >= self.data.len() {
                return Err("Unexpected end after 0xFF".into());
            }
            let next = self.data[self.pos];
            if next == 0x00 {
                self.pos += 1; // skip stuffed 0x00
            } else if (0xD0..=0xD7).contains(&next) {
                // RST marker — skip marker byte and read next actual byte
                self.pos += 1;
                return self.read_byte();
            } else {
                return Err("Marker found in scan data".into());
            }
        }
        Ok(byte)
    }
    /// Refill `current` when all its bits have been consumed.
    fn ensure_bits(&mut self) -> Result<(), String> {
        if self.bit_pos == 0 {
            self.current = self.read_byte()?;
            self.bit_pos = 8;
        }
        Ok(())
    }
    /// Read a single bit, most significant first.
    fn read_bit(&mut self) -> Result<u8, String> {
        self.ensure_bits()?;
        self.bit_pos -= 1;
        Ok((self.current >> self.bit_pos) & 1)
    }
    /// Read `count` bits (MSB first) into the low bits of a u16.
    fn read_bits(&mut self, count: u8) -> Result<u16, String> {
        let mut val: u16 = 0;
        for _ in 0..count {
            val = (val << 1) | self.read_bit()? as u16;
        }
        Ok(val)
    }
    /// Decode one Huffman symbol by extending the code one bit at a time
    /// until it falls inside the [mincode, maxcode] range for some length
    /// (index `len` covers codes of length len + 1 bits).
    fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, String> {
        let mut code: u16 = 0;
        for len in 0..16 {
            code = (code << 1) | self.read_bit()? as u16;
            if table.maxcode[len] >= 0 && code as i32 <= table.maxcode[len] {
                let idx = table.offsets[len] as usize + (code - table.mincode[len]) as usize;
                return Ok(table.symbols[idx]);
            }
        }
        Err("Invalid Huffman code".into())
    }
    /// Skip to next byte-aligned position (and handle RST markers)
    /// Discards any buffered bits; the next read fetches a fresh byte.
    fn align_to_byte(&mut self) {
        self.bit_pos = 0;
        self.current = 0;
    }
    /// Find and skip RST marker in the byte stream
    fn skip_to_rst_marker(&mut self) -> Result<(), String> {
        // Align to byte boundary
        self.align_to_byte();
        // Look for 0xFF 0xDn marker
        loop {
            if self.pos >= self.data.len() {
                return Err("Unexpected end looking for RST marker".into());
            }
            if self.data[self.pos] == 0xFF && self.pos + 1 < self.data.len() {
                let next = self.data[self.pos + 1];
                if (0xD0..=0xD7).contains(&next) {
                    self.pos += 2;
                    return Ok(());
                }
            }
            self.pos += 1;
        }
    }
    /// Current byte position — where the scan's entropy data ended.
    fn scan_end_pos(&self) -> usize {
        self.pos
    }
}
// ---------------------------------------------------------------------------
// IDCT
// ---------------------------------------------------------------------------
/// Zig-zag order for 8x8 block.
/// `ZIGZAG[i]` is the row-major index of the i-th coefficient in JPEG
/// transmission order, so `block[ZIGZAG[i]] = coeffs[i]` de-zigzags a block.
const ZIGZAG: [usize; 64] = [
    0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, 48, 41, 34, 27,
    20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63,
];
/// 2-D inverse DCT of one 8x8 block.
///
/// `coeffs` arrives in zig-zag (transmission) order; the result is a row-major
/// block of sample values, still centered at 0 (level shift happens later).
fn idct(coeffs: &[i32; 64]) -> [i32; 64] {
    // De-zigzag into natural row-major order.
    let mut rows = [0.0f64; 64];
    for (&dest, &coeff) in ZIGZAG.iter().zip(coeffs.iter()) {
        rows[dest] = coeff as f64;
    }
    // Separable IDCT: 1-D transform over each row...
    for r in 0..8 {
        idct_1d(&mut rows, r * 8);
    }
    // ...transpose so columns become rows...
    let mut cols = [0.0f64; 64];
    for r in 0..8 {
        for c in 0..8 {
            cols[c * 8 + r] = rows[r * 8 + c];
        }
    }
    // ...transform again along the (former) columns...
    for r in 0..8 {
        idct_1d(&mut cols, r * 8);
    }
    // ...and transpose back while rounding to integers.
    let mut out = [0i32; 64];
    for r in 0..8 {
        for c in 0..8 {
            out[r * 8 + c] = cols[c * 8 + r].round() as i32;
        }
    }
    out
}
/// In-place 8-point 1-D inverse DCT on `data[off..off + 8]`.
///
/// Naive O(n^2) formulation with orthonormal scaling: C(0) = 1/sqrt(2) and an
/// overall factor of 1/2 per pass (two passes give the standard 2-D 1/4 scale).
fn idct_1d(data: &mut [f64], off: usize) {
    use std::f64::consts::PI;
    let inv_sqrt2 = 1.0 / 2.0f64.sqrt();
    let transformed: Vec<f64> = (0..8)
        .map(|x| {
            let acc: f64 = (0..8)
                .map(|u| {
                    let scale = if u == 0 { inv_sqrt2 } else { 1.0 };
                    let angle = (2.0 * x as f64 + 1.0) * u as f64 * PI / 16.0;
                    scale * data[off + u] * angle.cos()
                })
                .sum();
            acc / 2.0
        })
        .collect();
    data[off..off + 8].copy_from_slice(&transformed);
}
// ---------------------------------------------------------------------------
// Scan decoder
// ---------------------------------------------------------------------------
/// Parse an SOS header and decode one baseline scan into an interleaved RGB buffer.
///
/// `pos` points at the SOS segment payload (just past the FF DA marker).
/// Returns the RGB pixels plus the byte offset where the entropy-coded data
/// ended, so the caller can resume marker parsing.
///
/// # Errors
/// Returns `Err` on truncated SOS headers, out-of-range table ids, missing
/// Huffman tables, or a corrupt entropy-coded stream.
#[allow(clippy::too_many_arguments)]
fn decode_scan(
    data: &[u8],
    pos: usize,
    width: u16,
    height: u16,
    num_components: u8,
    components: &[JpegComponent],
    qt_tables: &[[u16; 64]; 4],
    dc_tables: &[Option<HuffTable>; 4],
    ac_tables: &[Option<HuffTable>; 4],
    restart_interval: u16,
) -> Result<(Vec<u8>, usize), String> {
    // Parse SOS header: 2-byte length + 1-byte component count must be present.
    if pos + 3 > data.len() {
        return Err("SOS too short".into());
    }
    let seg_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
    let ns = data[pos + 2] as usize;
    let mut scan_components = components.to_vec();
    // A scan naming more components than SOF declared would index out of bounds.
    if ns > scan_components.len() {
        return Err("SOS declares more components than SOF".into());
    }
    let mut off = pos + 3;
    for i in 0..ns {
        // Each component spec is 2 bytes: selector + packed DC/AC table ids.
        if off + 2 > data.len() {
            return Err("SOS component spec truncated".into());
        }
        let _cs = data[off]; // component selector
        let td_ta = data[off + 1];
        // Only 4 table slots exist; reject stream-supplied ids above 3 instead
        // of panicking on the array index below.
        if (td_ta >> 4) > 3 || (td_ta & 0x0F) > 3 {
            return Err("SOS Huffman table id out of range".into());
        }
        scan_components[i].dc_table = td_ta >> 4;
        scan_components[i].ac_table = td_ta & 0x0F;
        off += 2;
    }
    let scan_data_start = pos + seg_len;
    if scan_data_start > data.len() {
        return Err("SOS segment length exceeds file size".into());
    }
    let mut reader = BitReader::new(data, scan_data_start);
    // Calculate MCU dimensions from the maximum sampling factors.
    let max_h = scan_components
        .iter()
        .take(num_components as usize)
        .map(|c| c.h_sample)
        .max()
        .unwrap_or(1);
    let max_v = scan_components
        .iter()
        .take(num_components as usize)
        .map(|c| c.v_sample)
        .max()
        .unwrap_or(1);
    // Compute MCU counts in usize so widths/heights near u16::MAX cannot
    // overflow the `+ mcu_width - 1` rounding term.
    let mcu_width = max_h as usize * 8;
    let mcu_height = max_v as usize * 8;
    let mcus_x = (width as usize + mcu_width - 1) / mcu_width;
    let mcus_y = (height as usize + mcu_height - 1) / mcu_height;
    let mut dc_pred = vec![0i32; num_components as usize];
    let mut rgb = vec![0u8; (width as usize) * (height as usize) * 3];
    let mut mcu_count: u16 = 0;
    for mcu_row in 0..mcus_y {
        for mcu_col in 0..mcus_x {
            // At each restart interval the DC predictors reset and the bit
            // stream re-aligns just past an RST marker.
            if restart_interval > 0 && mcu_count > 0 && mcu_count % restart_interval == 0 {
                for dc in dc_pred.iter_mut() {
                    *dc = 0;
                }
                reader.skip_to_rst_marker()?;
            }
            // Decode every 8x8 block of every component in this MCU
            // (baseline interleaved order).
            let mut mcu_blocks: Vec<Vec<[i32; 64]>> = Vec::new();
            for (ci, comp) in scan_components
                .iter()
                .enumerate()
                .take(num_components as usize)
            {
                let blocks_h = comp.h_sample as usize;
                let blocks_v = comp.v_sample as usize;
                let mut blocks = Vec::with_capacity(blocks_h * blocks_v);
                for _ in 0..(blocks_h * blocks_v) {
                    let block = decode_block(
                        &mut reader,
                        dc_tables[comp.dc_table as usize]
                            .as_ref()
                            .ok_or("Missing DC Huffman table")?,
                        ac_tables[comp.ac_table as usize]
                            .as_ref()
                            .ok_or("Missing AC Huffman table")?,
                        &mut dc_pred[ci],
                        &qt_tables[comp.qt_id as usize],
                    )?;
                    blocks.push(block);
                }
                mcu_blocks.push(blocks);
            }
            assemble_mcu(
                &mcu_blocks,
                &scan_components,
                num_components,
                max_h,
                max_v,
                mcu_col,
                mcu_row,
                width as usize,
                height as usize,
                &mut rgb,
            );
            mcu_count = mcu_count.wrapping_add(1);
        }
    }
    Ok((rgb, reader.scan_end_pos()))
}
/// Decode one 8x8 block: Huffman DC delta + AC run-length pairs, dequantize,
/// then inverse-DCT.
///
/// `dc_pred` is the running DC predictor for this component and is updated
/// in place (JPEG codes DC values differentially).
fn decode_block(
    reader: &mut BitReader,
    dc_table: &HuffTable,
    ac_table: &HuffTable,
    dc_pred: &mut i32,
    qt: &[u16; 64],
) -> Result<[i32; 64], String> {
    /// JPEG "EXTEND" (ITU T.81 F.2.2.1): reinterpret a `size`-bit magnitude as a
    /// signed value — codes whose MSB is 0 encode negative numbers.
    /// Callers must guarantee `size > 0`.
    fn extend_sign(bits: i32, size: u8) -> i32 {
        if bits < (1 << (size - 1)) {
            bits - (1 << size) + 1
        } else {
            bits
        }
    }
    let mut coeffs = [0i32; 64];
    // DC coefficient: the Huffman symbol is the magnitude category (bit count).
    let dc_len = reader.decode_huffman(dc_table)?;
    let dc_val = if dc_len > 0 {
        extend_sign(reader.read_bits(dc_len)? as i32, dc_len)
    } else {
        0
    };
    *dc_pred += dc_val;
    coeffs[0] = *dc_pred * qt[0] as i32;
    // AC coefficients: each symbol packs (zero-run length, magnitude bit count).
    let mut k = 1;
    while k < 64 {
        let rs = reader.decode_huffman(ac_table)?;
        let run = (rs >> 4) as usize;
        let size = (rs & 0x0F) as u8;
        if size == 0 {
            if run == 0 {
                break;
            } // EOB
            if run == 15 {
                k += 16;
                continue;
            } // ZRL (16 zeros)
            break;
        }
        k += run;
        if k >= 64 {
            break;
        }
        let val = extend_sign(reader.read_bits(size)? as i32, size);
        coeffs[k] = val * qt[k] as i32;
        k += 1;
    }
    Ok(idct(&coeffs))
}
// ---------------------------------------------------------------------------
// MCU assembly + color conversion
// ---------------------------------------------------------------------------
/// Write one decoded MCU into the output RGB buffer.
///
/// Walks every pixel the MCU covers, samples each component (honoring chroma
/// subsampling), applies the +128 level shift and the YCbCr -> RGB conversion,
/// and skips pixels that fall outside the image (right/bottom edge MCUs).
#[allow(clippy::too_many_arguments)]
fn assemble_mcu(
    mcu_blocks: &[Vec<[i32; 64]>],
    components: &[JpegComponent],
    num_components: u8,
    max_h: u8,
    max_v: u8,
    mcu_col: usize,
    mcu_row: usize,
    img_width: usize,
    img_height: usize,
    rgb: &mut [u8],
) {
    let mcu_w = max_h as usize * 8;
    let mcu_h = max_v as usize * 8;
    let origin_x = mcu_col * mcu_w;
    let origin_y = mcu_row * mcu_h;
    let grayscale = num_components == 1;
    for py in 0..mcu_h {
        let y = origin_y + py;
        if y >= img_height {
            continue;
        }
        for px in 0..mcu_w {
            let x = origin_x + px;
            if x >= img_width {
                continue;
            }
            let offset = (y * img_width + x) * 3;
            let pixel = if grayscale {
                // Grayscale: IDCT output is centered at 0; +128 undoes the level shift.
                let v =
                    sample_component(&mcu_blocks[0], &components[0], max_h, max_v, px, py) + 128;
                let g = v.clamp(0, 255) as u8;
                [g, g, g]
            } else {
                // YCbCr -> RGB. Y gets the +128 level shift; Cb/Cr stay centered
                // at 0 because the encoder already subtracted 128 from them.
                let yy = sample_component(&mcu_blocks[0], &components[0], max_h, max_v, px, py)
                    as f32
                    + 128.0;
                let cb = sample_component(&mcu_blocks[1], &components[1], max_h, max_v, px, py)
                    as f32;
                let cr = sample_component(&mcu_blocks[2], &components[2], max_h, max_v, px, py)
                    as f32;
                [
                    (yy + 1.402 * cr).round().clamp(0.0, 255.0) as u8,
                    (yy - 0.344136 * cb - 0.714136 * cr).round().clamp(0.0, 255.0) as u8,
                    (yy + 1.772 * cb).round().clamp(0.0, 255.0) as u8,
                ]
            };
            rgb[offset..offset + 3].copy_from_slice(&pixel);
        }
    }
}
/// Fetch one component's dequantized sample for MCU-local pixel (px, py).
///
/// Accounts for chroma subsampling: a component with sampling factors
/// (h_sample, v_sample) inside a (max_h, max_v) MCU stores its 8x8 blocks at
/// proportionally reduced resolution. Out-of-range block indices sample as 0.
fn sample_component(
    blocks: &[[i32; 64]],
    comp: &JpegComponent,
    max_h: u8,
    max_v: u8,
    px: usize,
    py: usize,
) -> i32 {
    let (sx, sy) = (comp.h_sample as usize, comp.v_sample as usize);
    // Which of this component's blocks the pixel lands in...
    let block_col = px * sx / (max_h as usize * 8);
    let block_row = py * sy / (max_v as usize * 8);
    // ...and the texel within that block.
    let tx = (px * sx / max_h as usize) % 8;
    let ty = (py * sy / max_v as usize) % 8;
    blocks
        .get(block_row * sx + block_col)
        .map_or(0, |block| block[ty * 8 + tx])
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // --- Error paths: malformed or truncated streams must fail, never panic. ---
    #[test]
    fn test_invalid_signature() {
        let data = [0u8; 10];
        assert!(parse_jpg(&data).is_err());
    }
    #[test]
    fn test_empty_data() {
        assert!(parse_jpg(&[]).is_err());
    }
    #[test]
    fn test_soi_only() {
        // SOI marker alone, with no frame or scan data.
        let data = [0xFF, 0xD8];
        assert!(parse_jpg(&data).is_err());
    }
    // --- Segment parsers. ---
    #[test]
    fn test_parse_dqt_8bit() {
        // DQT layout: 2-byte length + precision/id byte + 64 8-bit entries.
        let mut seg = Vec::new();
        seg.extend_from_slice(&67u16.to_be_bytes()); // length = 67
        seg.push(0x00); // precision=0 (8-bit), table_id=0
        for i in 0..64u8 {
            seg.push(i + 1);
        }
        let mut qt_tables = [[0u16; 64]; 4];
        let len = parse_dqt(&seg, 0, &mut qt_tables).unwrap();
        assert_eq!(len, 67);
        assert_eq!(qt_tables[0][0], 1);
        assert_eq!(qt_tables[0][63], 64);
    }
    #[test]
    fn test_parse_dht() {
        // DHT layout: class/id byte, 16 code-length counts, then the symbols.
        let mut seg = Vec::new();
        let mut body = Vec::new();
        body.push(0x00); // class=0 (DC), id=0
        // counts: 1 symbol at length 1, 0 for lengths 2-16
        body.push(1);
        for _ in 1..16 {
            body.push(0);
        }
        body.push(0x05); // the symbol
        seg.extend_from_slice(&((body.len() + 2) as u16).to_be_bytes());
        seg.extend_from_slice(&body);
        let mut dc_tables: [Option<HuffTable>; 4] = [None, None, None, None];
        let mut ac_tables: [Option<HuffTable>; 4] = [None, None, None, None];
        let len = parse_dht(&seg, 0, &mut dc_tables, &mut ac_tables).unwrap();
        assert_eq!(len, seg.len());
        assert!(dc_tables[0].is_some());
        let table = dc_tables[0].as_ref().unwrap();
        assert_eq!(table.symbols[0], 0x05);
    }
    // --- Bit reader: MSB-first order and 0xFF byte stuffing. ---
    #[test]
    fn test_bit_reader_basic() {
        // 0xA5 = 10100101
        let data = [0xA5];
        let mut reader = BitReader::new(&data, 0);
        assert_eq!(reader.read_bits(1).unwrap(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0);
        assert_eq!(reader.read_bits(3).unwrap(), 0b100);
        assert_eq!(reader.read_bits(3).unwrap(), 0b101);
    }
    #[test]
    fn test_bit_reader_byte_stuffing() {
        // JPEG byte stuffing: 0xFF 0x00 -> single 0xFF byte
        let data = [0xFF, 0x00, 0x80];
        let mut reader = BitReader::new(&data, 0);
        let val = reader.read_bits(8).unwrap();
        assert_eq!(val, 0xFF);
        let val2 = reader.read_bits(1).unwrap();
        assert_eq!(val2, 1); // 0x80 = 10000000
    }
    // --- IDCT sanity checks. ---
    #[test]
    fn test_idct_dc_only() {
        // A DC-only block decodes to a flat field of DC/8 (2-D scale factor).
        let mut block = [0i32; 64];
        block[0] = 800; // after dequantization
        let result = idct(&block);
        let expected = 100; // 800/8 = 100
        for &v in &result {
            assert!(
                (v - expected).abs() <= 1,
                "DC-only IDCT: expected ~{}, got {}",
                expected,
                v
            );
        }
    }
    #[test]
    fn test_idct_known_values() {
        // The AC basis functions have zero mean, so the block average is DC/8.
        let mut block = [0i32; 64];
        block[0] = 640;
        block[1] = 100;
        let result = idct(&block);
        let avg: i32 = result.iter().sum::<i32>() / 64;
        assert!((avg - 80).abs() <= 2);
    }
    /// Build a minimal valid 8x8 Baseline JPEG (grayscale) for testing.
    /// DC diff = 0 => Y = 0 => after +128 level shift => pixel = 128.
    fn build_minimal_jpeg_8x8() -> Vec<u8> {
        let mut out = Vec::new();
        // SOI
        out.extend_from_slice(&[0xFF, 0xD8]);
        // DQT — all-ones quantization table (id=0)
        out.extend_from_slice(&[0xFF, 0xDB]);
        let mut dqt = Vec::new();
        dqt.extend_from_slice(&0x0043u16.to_be_bytes()); // length = 67
        dqt.push(0x00); // 8-bit, table 0
        for _ in 0..64 {
            dqt.push(1);
        }
        out.extend_from_slice(&dqt);
        // SOF0 — 8x8, 1 component (grayscale)
        out.extend_from_slice(&[0xFF, 0xC0]);
        out.extend_from_slice(&0x000Bu16.to_be_bytes()); // length = 11
        out.push(8); // precision
        out.extend_from_slice(&8u16.to_be_bytes()); // height
        out.extend_from_slice(&8u16.to_be_bytes()); // width
        out.push(1); // 1 component
        out.push(1); // component ID
        out.push(0x11); // h_sample=1, v_sample=1
        out.push(0); // qt table 0
        // DHT — DC table (class=0, id=0): 1 symbol at length 1, symbol = 0x00
        out.extend_from_slice(&[0xFF, 0xC4]);
        let mut dht_body = Vec::new();
        dht_body.push(0x00); // DC, id=0
        dht_body.push(1); // 1 symbol at length 1
        for _ in 1..16 {
            dht_body.push(0);
        }
        dht_body.push(0x00); // symbol: category 0 (DC diff = 0)
        let dht_len = (dht_body.len() + 2) as u16;
        out.extend_from_slice(&dht_len.to_be_bytes());
        out.extend_from_slice(&dht_body);
        // DHT — AC table (class=1, id=0): 1 symbol at length 1, symbol = 0x00 (EOB)
        out.extend_from_slice(&[0xFF, 0xC4]);
        let mut dht_ac = Vec::new();
        dht_ac.push(0x10); // AC, id=0
        dht_ac.push(1); // 1 symbol at length 1
        for _ in 1..16 {
            dht_ac.push(0);
        }
        dht_ac.push(0x00); // symbol: 0x00 = EOB
        let dht_ac_len = (dht_ac.len() + 2) as u16;
        out.extend_from_slice(&dht_ac_len.to_be_bytes());
        out.extend_from_slice(&dht_ac);
        // SOS
        out.extend_from_slice(&[0xFF, 0xDA]);
        out.extend_from_slice(&0x0008u16.to_be_bytes()); // length=8
        out.push(1); // 1 component
        out.push(1); // component id=1
        out.push(0x00); // DC table 0, AC table 0
        out.push(0); // Ss
        out.push(63); // Se
        out.push(0); // Ah=0, Al=0
        // Scan data: DC=0 (code=0, 1 bit), AC=EOB (code=0, 1 bit)
        // Bits: 0 (DC diff=0) + 0 (EOB) = 0b00 -> padded to byte: 0x00
        out.push(0x00);
        out.push(0x00);
        // EOI
        out.extend_from_slice(&[0xFF, 0xD9]);
        out
    }
    #[test]
    fn test_grayscale_flat() {
        // End-to-end decode of the synthetic JPEG; parse_jpg returns RGBA here.
        let jpg_data = build_minimal_jpeg_8x8();
        let (rgba, w, h) = parse_jpg(&jpg_data).unwrap();
        assert_eq!(w, 8);
        assert_eq!(h, 8);
        assert_eq!(rgba.len(), 8 * 8 * 4);
        // Grayscale mid-gray: all pixels should be ~128
        for i in (0..rgba.len()).step_by(4) {
            assert_eq!(rgba[i], rgba[i + 1]); // R == G
            assert_eq!(rgba[i + 1], rgba[i + 2]); // G == B
            assert_eq!(rgba[i + 3], 255); // alpha
        }
    }
    #[test]
    fn test_invalid_marker() {
        let data = [0xFF, 0xD8, 0x00]; // SOI then garbage
        assert!(parse_jpg(&data).is_err());
    }
}

View File

@@ -0,0 +1,297 @@
/// Minimal JSON parser for glTF. No external dependencies.
#[derive(Debug, Clone, PartialEq)]
pub enum JsonValue {
    Null,
    Bool(bool),
    Number(f64),
    String(String),
    Array(Vec<JsonValue>),
    Object(Vec<(String, JsonValue)>), // preserve order
}
impl JsonValue {
    /// Borrow the ordered key/value pairs if this is an object.
    pub fn as_object(&self) -> Option<&[(String, JsonValue)]> {
        if let JsonValue::Object(pairs) = self {
            Some(pairs)
        } else {
            None
        }
    }
    /// Borrow the elements if this is an array.
    pub fn as_array(&self) -> Option<&[JsonValue]> {
        if let JsonValue::Array(items) = self {
            Some(items)
        } else {
            None
        }
    }
    /// Borrow the string contents if this is a string.
    pub fn as_str(&self) -> Option<&str> {
        if let JsonValue::String(text) = self {
            Some(text)
        } else {
            None
        }
    }
    /// The numeric value if this is a number.
    pub fn as_f64(&self) -> Option<f64> {
        if let JsonValue::Number(n) = self {
            Some(*n)
        } else {
            None
        }
    }
    /// The number truncated to u32 (saturating `as` cast semantics).
    pub fn as_u32(&self) -> Option<u32> {
        Some(self.as_f64()? as u32)
    }
    /// The boolean value if this is a bool.
    pub fn as_bool(&self) -> Option<bool> {
        if let JsonValue::Bool(flag) = self {
            Some(*flag)
        } else {
            None
        }
    }
    /// Look up `key` in an object (linear scan; first match wins).
    pub fn get(&self, key: &str) -> Option<&JsonValue> {
        for (k, v) in self.as_object()? {
            if k == key {
                return Some(v);
            }
        }
        None
    }
    /// Element `i` of an array, if present.
    pub fn index(&self, i: usize) -> Option<&JsonValue> {
        self.as_array().and_then(|items| items.get(i))
    }
}
/// Parse a complete JSON document into a `JsonValue` tree.
///
/// Trailing whitespace after the document is accepted (GLB JSON chunks are
/// space-padded), but any other trailing bytes are rejected so inputs like
/// `"{} garbage"` no longer parse silently.
pub fn parse_json(input: &str) -> Result<JsonValue, String> {
    let mut parser = JsonParser::new(input);
    let val = parser.parse_value()?;
    // A valid document is exactly one value; reject e.g. `{} {}` or `1 2`.
    parser.skip_whitespace();
    if parser.pos < parser.input.len() {
        return Err(format!("Unexpected trailing data at byte {}", parser.pos));
    }
    Ok(val)
}
/// Hand-rolled recursive-descent JSON parser over the raw input bytes.
struct JsonParser<'a> {
    input: &'a [u8], // original input as bytes (indexed directly for peeking)
    pos: usize,      // current byte offset into `input`
}
impl<'a> JsonParser<'a> {
    /// Start parsing at the beginning of `input`.
    fn new(input: &'a str) -> Self {
        Self { input: input.as_bytes(), pos: 0 }
    }
    /// Advance past JSON-insignificant whitespace.
    fn skip_whitespace(&mut self) {
        while self.pos < self.input.len() {
            match self.input[self.pos] {
                b' ' | b'\t' | b'\n' | b'\r' => self.pos += 1,
                _ => break,
            }
        }
    }
    /// Next byte without consuming it.
    fn peek(&self) -> Option<u8> {
        self.input.get(self.pos).copied()
    }
    /// Consume and return the next byte, or error at end of input.
    fn advance(&mut self) -> Result<u8, String> {
        if self.pos >= self.input.len() {
            return Err("Unexpected end of JSON".into());
        }
        let b = self.input[self.pos];
        self.pos += 1;
        Ok(b)
    }
    /// Consume one byte and require it to equal `ch`.
    fn expect(&mut self, ch: u8) -> Result<(), String> {
        let b = self.advance()?;
        if b != ch {
            return Err(format!("Expected '{}', got '{}'", ch as char, b as char));
        }
        Ok(())
    }
    /// Parse any JSON value, dispatching on the first non-whitespace byte.
    fn parse_value(&mut self) -> Result<JsonValue, String> {
        self.skip_whitespace();
        match self.peek() {
            Some(b'"') => self.parse_string().map(JsonValue::String),
            Some(b'{') => self.parse_object(),
            Some(b'[') => self.parse_array(),
            Some(b't') => self.parse_literal("true", JsonValue::Bool(true)),
            Some(b'f') => self.parse_literal("false", JsonValue::Bool(false)),
            Some(b'n') => self.parse_literal("null", JsonValue::Null),
            Some(b'-') | Some(b'0'..=b'9') => self.parse_number(),
            Some(ch) => Err(format!("Unexpected character: '{}'", ch as char)),
            None => Err("Unexpected end of JSON".into()),
        }
    }
    /// Read exactly four hex digits of a `\uXXXX` escape and return the code unit.
    fn parse_hex4(&mut self) -> Result<u32, String> {
        let mut hex = String::with_capacity(4);
        for _ in 0..4 {
            hex.push(self.advance()? as char);
        }
        u32::from_str_radix(&hex, 16).map_err(|_| format!("Invalid unicode escape: {}", hex))
    }
    /// Parse a double-quoted string with escape handling.
    ///
    /// Non-ASCII input is copied through as whole UTF-8 sequences (the input
    /// originated from a `&str`, so sequences are valid); byte-at-a-time
    /// `b as char` would mangle them into Latin-1. `\uXXXX` escapes combine
    /// UTF-16 surrogate pairs; escapes that decode to no valid scalar value
    /// are dropped (lenient).
    fn parse_string(&mut self) -> Result<String, String> {
        self.expect(b'"')?;
        let mut s = String::new();
        loop {
            let b = self.advance()?;
            match b {
                b'"' => return Ok(s),
                b'\\' => {
                    let esc = self.advance()?;
                    match esc {
                        b'"' => s.push('"'),
                        b'\\' => s.push('\\'),
                        b'/' => s.push('/'),
                        b'b' => s.push('\u{08}'),
                        b'f' => s.push('\u{0C}'),
                        b'n' => s.push('\n'),
                        b'r' => s.push('\r'),
                        b't' => s.push('\t'),
                        b'u' => {
                            let mut code = self.parse_hex4()?;
                            // High surrogate: try to combine with a following
                            // \uDC00-\uDFFF low surrogate into one code point.
                            if (0xD800..=0xDBFF).contains(&code)
                                && self.peek() == Some(b'\\')
                                && self.input.get(self.pos + 1) == Some(&b'u')
                            {
                                let saved = self.pos;
                                self.pos += 2; // consume the "\u"
                                let lo = self.parse_hex4()?;
                                if (0xDC00..=0xDFFF).contains(&lo) {
                                    code = 0x10000 + ((code - 0xD800) << 10) + (lo - 0xDC00);
                                } else {
                                    // Not a low surrogate; rewind and let the
                                    // next loop iteration handle it on its own.
                                    self.pos = saved;
                                }
                            }
                            // Unpaired surrogates have no char representation
                            // and are silently dropped (matches prior behavior).
                            if let Some(ch) = char::from_u32(code) {
                                s.push(ch);
                            }
                        }
                        _ => return Err(format!("Invalid escape: \\{}", esc as char)),
                    }
                }
                _ => {
                    if b < 0x80 {
                        s.push(b as char);
                    } else {
                        // Lead byte of a multi-byte UTF-8 sequence: copy the
                        // whole sequence intact.
                        let start = self.pos - 1;
                        let extra = if b >= 0xF0 {
                            3
                        } else if b >= 0xE0 {
                            2
                        } else {
                            1
                        };
                        for _ in 0..extra {
                            self.advance()?;
                        }
                        let chunk = std::str::from_utf8(&self.input[start..self.pos])
                            .map_err(|_| "Invalid UTF-8 in string".to_string())?;
                        s.push_str(chunk);
                    }
                }
            }
        }
    }
    /// Parse a JSON number (integer, optional fraction, optional exponent) as f64.
    fn parse_number(&mut self) -> Result<JsonValue, String> {
        let start = self.pos;
        if self.peek() == Some(b'-') {
            self.pos += 1;
        }
        while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() {
            self.pos += 1;
        }
        // Optional fractional part.
        if self.pos < self.input.len() && self.input[self.pos] == b'.' {
            self.pos += 1;
            while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() {
                self.pos += 1;
            }
        }
        // Optional exponent part.
        if self.pos < self.input.len() && (self.input[self.pos] == b'e' || self.input[self.pos] == b'E') {
            self.pos += 1;
            if self.pos < self.input.len() && (self.input[self.pos] == b'+' || self.input[self.pos] == b'-') {
                self.pos += 1;
            }
            while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() {
                self.pos += 1;
            }
        }
        let s = std::str::from_utf8(&self.input[start..self.pos])
            .map_err(|_| "Invalid UTF-8 in number".to_string())?;
        let n: f64 = s.parse().map_err(|_| format!("Invalid number: {}", s))?;
        Ok(JsonValue::Number(n))
    }
    /// Parse `{ "key": value, ... }`, preserving key order.
    fn parse_object(&mut self) -> Result<JsonValue, String> {
        self.expect(b'{')?;
        self.skip_whitespace();
        let mut pairs = Vec::new();
        if self.peek() == Some(b'}') {
            self.pos += 1;
            return Ok(JsonValue::Object(pairs));
        }
        loop {
            self.skip_whitespace();
            let key = self.parse_string()?;
            self.skip_whitespace();
            self.expect(b':')?;
            let val = self.parse_value()?;
            pairs.push((key, val));
            self.skip_whitespace();
            match self.peek() {
                Some(b',') => { self.pos += 1; }
                Some(b'}') => { self.pos += 1; return Ok(JsonValue::Object(pairs)); }
                _ => return Err("Expected ',' or '}' in object".into()),
            }
        }
    }
    /// Parse `[ value, ... ]`.
    fn parse_array(&mut self) -> Result<JsonValue, String> {
        self.expect(b'[')?;
        self.skip_whitespace();
        let mut items = Vec::new();
        if self.peek() == Some(b']') {
            self.pos += 1;
            return Ok(JsonValue::Array(items));
        }
        loop {
            let val = self.parse_value()?;
            items.push(val);
            self.skip_whitespace();
            match self.peek() {
                Some(b',') => { self.pos += 1; }
                Some(b']') => { self.pos += 1; return Ok(JsonValue::Array(items)); }
                _ => return Err("Expected ',' or ']' in array".into()),
            }
        }
    }
    /// Match a fixed keyword (`true` / `false` / `null`) and return `value`.
    fn parse_literal(&mut self, expected: &str, value: JsonValue) -> Result<JsonValue, String> {
        for &b in expected.as_bytes() {
            let actual = self.advance()?;
            if actual != b {
                return Err(format!("Expected '{}', got '{}'", b as char, actual as char));
            }
        }
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_parse_null() {
        assert_eq!(parse_json("null").unwrap(), JsonValue::Null);
    }
    #[test]
    fn test_parse_bool() {
        assert_eq!(parse_json("true").unwrap(), JsonValue::Bool(true));
        assert_eq!(parse_json("false").unwrap(), JsonValue::Bool(false));
    }
    #[test]
    fn test_parse_number() {
        // Integers and signed decimals both come back as f64.
        match parse_json("42").unwrap() {
            JsonValue::Number(n) => assert!((n - 42.0).abs() < 1e-10),
            other => panic!("Expected Number, got {:?}", other),
        }
        match parse_json("-3.14").unwrap() {
            JsonValue::Number(n) => assert!((n - (-3.14)).abs() < 1e-10),
            other => panic!("Expected Number, got {:?}", other),
        }
    }
    #[test]
    fn test_parse_string() {
        assert_eq!(parse_json("\"hello\"").unwrap(), JsonValue::String("hello".into()));
    }
    #[test]
    fn test_parse_string_escapes() {
        // Backslash escapes decode to their control characters.
        assert_eq!(
            parse_json(r#""hello\nworld""#).unwrap(),
            JsonValue::String("hello\nworld".into())
        );
    }
    #[test]
    fn test_parse_array() {
        let val = parse_json("[1, 2, 3]").unwrap();
        match val {
            JsonValue::Array(arr) => assert_eq!(arr.len(), 3),
            other => panic!("Expected Array, got {:?}", other),
        }
    }
    #[test]
    fn test_parse_object() {
        // Object key order is preserved (Vec of pairs, not a map).
        let val = parse_json(r#"{"name": "test", "value": 42}"#).unwrap();
        match val {
            JsonValue::Object(map) => {
                assert_eq!(map.len(), 2);
                assert_eq!(map[0].0, "name");
            }
            other => panic!("Expected Object, got {:?}", other),
        }
    }
    #[test]
    fn test_parse_nested() {
        // A glTF-shaped document: objects inside arrays inside objects.
        let json = r#"{"meshes": [{"name": "Cube", "primitives": [{"attributes": {"POSITION": 0}}]}]}"#;
        let val = parse_json(json).unwrap();
        assert!(matches!(val, JsonValue::Object(_)));
    }
    #[test]
    fn test_parse_empty_array() {
        assert_eq!(parse_json("[]").unwrap(), JsonValue::Array(vec![]));
    }
    #[test]
    fn test_parse_empty_object() {
        assert_eq!(parse_json("{}").unwrap(), JsonValue::Object(vec![]));
    }
}

View File

@@ -1,5 +1,9 @@
pub mod json_parser;
pub mod gltf;
pub mod deflate;
pub mod png;
pub mod jpg;
pub mod progressive_jpeg;
pub mod gpu;
pub mod light;
pub mod obj;
@@ -13,29 +17,55 @@ pub mod sphere;
pub mod pbr_pipeline;
pub mod shadow;
pub mod shadow_pipeline;
pub mod csm;
pub mod point_shadow;
pub mod spot_shadow;
pub mod frustum;
pub mod brdf_lut;
pub mod ibl;
pub mod sh;
pub mod gbuffer;
pub mod gbuffer_compress;
pub mod fullscreen_quad;
pub mod deferred_pipeline;
pub mod ssgi;
pub mod rt_accel;
pub mod rt_shadow;
pub mod rt_reflections;
pub mod rt_ao;
pub mod rt_point_shadow;
pub mod hdr;
pub mod bloom;
pub mod tonemap;
pub mod forward_pass;
pub mod auto_exposure;
pub mod instancing;
pub mod bilateral_blur;
pub mod temporal_accum;
pub mod taa;
pub mod ssr;
pub mod motion_blur;
pub mod dof;
pub use motion_blur::MotionBlur;
pub use dof::DepthOfField;
pub use gpu::{GpuContext, DEPTH_FORMAT};
pub use light::{CameraUniform, LightUniform, LightData, LightsUniform, MAX_LIGHTS, LIGHT_DIRECTIONAL, LIGHT_POINT, LIGHT_SPOT};
pub use mesh::Mesh;
pub use vertex::{Vertex, MeshVertex};
pub use camera::{Camera, FpsController};
pub use texture::{GpuTexture, pbr_texture_bind_group_layout, create_pbr_texture_bind_group};
pub use texture::{GpuTexture, pbr_texture_bind_group_layout, create_pbr_texture_bind_group, pbr_full_texture_bind_group_layout, create_pbr_full_texture_bind_group};
pub use material::MaterialUniform;
pub use sphere::generate_sphere;
pub use pbr_pipeline::create_pbr_pipeline;
pub use shadow::{ShadowMap, ShadowUniform, ShadowPassUniform, SHADOW_MAP_SIZE, SHADOW_FORMAT};
pub use shadow_pipeline::{create_shadow_pipeline, shadow_pass_bind_group_layout};
pub use csm::{CascadedShadowMap, CsmUniform, compute_cascade_matrices, CSM_CASCADE_COUNT, CSM_MAP_SIZE, CSM_FORMAT};
pub use point_shadow::{PointShadowMap, point_shadow_view_matrices, point_shadow_projection, POINT_SHADOW_SIZE, POINT_SHADOW_FORMAT};
pub use spot_shadow::{SpotShadowMap, spot_shadow_matrix};
pub use frustum::{Plane, Frustum, extract_frustum, sphere_vs_frustum, cull_lights};
pub use ibl::IblResources;
pub use sh::{compute_sh_coefficients, pack_sh_coefficients, evaluate_sh_cpu};
pub use gbuffer::GBuffer;
pub use fullscreen_quad::{create_fullscreen_vertex_buffer, FullscreenVertex};
pub use deferred_pipeline::{
@@ -50,7 +80,32 @@ pub use deferred_pipeline::{
pub use ssgi::{SsgiResources, SsgiUniform, SSGI_OUTPUT_FORMAT};
pub use rt_accel::{RtAccel, RtInstance, BlasMeshData, mat4_to_tlas_transform};
pub use rt_shadow::{RtShadowResources, RtShadowUniform, RT_SHADOW_FORMAT};
pub use rt_reflections::RtReflections;
pub use rt_ao::RtAo;
pub use rt_point_shadow::RtPointShadow;
pub use hdr::{HdrTarget, HDR_FORMAT};
pub use bloom::{BloomResources, BloomUniform, mip_sizes, BLOOM_MIP_COUNT};
pub use tonemap::{TonemapUniform, aces_tonemap};
pub use forward_pass::{ForwardPass, sort_transparent_back_to_front};
pub use auto_exposure::AutoExposure;
pub use instancing::{InstanceData, InstanceBuffer, create_instanced_pipeline};
pub use bilateral_blur::BilateralBlur;
pub use temporal_accum::TemporalAccumulation;
pub use taa::Taa;
pub use ssr::Ssr;
pub mod stencil_opt;
pub mod half_res_ssgi;
pub mod bilateral_bloom;
pub use png::parse_png;
pub use jpg::parse_jpg;
pub use gltf::{parse_gltf, GltfData, GltfMesh, GltfMaterial};
pub mod soft_rt_shadow;
pub mod blas_update;
pub use blas_update::BlasTracker;
pub mod rt_fallback;
pub use rt_fallback::{RtCapabilities, RenderMode};
pub mod light_probes;
pub use light_probes::{LightProbe, LightProbeGrid};
pub mod light_volumes;
pub use light_volumes::LightVolume;

View File

@@ -0,0 +1,63 @@
/// A single irradiance probe: world position plus spherical-harmonic coefficients.
/// Storage reserves the full L2 set (9 RGB coefficients); evaluation currently
/// uses only the L0 + L1 bands.
pub struct LightProbe {
    pub position: [f32; 3],
    pub sh_coefficients: [[f32; 3]; 9], // L2 SH, 9 RGB coefficients
}
impl LightProbe {
    /// Probe at `position` with all-zero SH (no stored light).
    pub fn new(position: [f32; 3]) -> Self {
        LightProbe { position, sh_coefficients: [[0.0; 3]; 9] }
    }
    /// Evaluate stored irradiance for direction `normal` using L0 + L1 bands.
    /// Coefficients are assumed pre-scaled; no SH basis constants applied here.
    pub fn evaluate_irradiance(&self, normal: [f32; 3]) -> [f32; 3] {
        // L0 (ambient) term.
        let mut result = [
            self.sh_coefficients[0][0],
            self.sh_coefficients[0][1],
            self.sh_coefficients[0][2],
        ];
        // L1 (directional) terms, in the standard Y(1,-1)=y, Y(1,0)=z, Y(1,1)=x order.
        let (nx, ny, nz) = (normal[0], normal[1], normal[2]);
        for c in 0..3 {
            result[c] += self.sh_coefficients[1][c] * ny;
            result[c] += self.sh_coefficients[2][c] * nz;
            result[c] += self.sh_coefficients[3][c] * nx;
        }
        result
    }
}
/// A flat collection of probes queried by nearest-neighbor lookup.
pub struct LightProbeGrid {
    probes: Vec<LightProbe>,
}
impl Default for LightProbeGrid {
    fn default() -> Self {
        Self::new()
    }
}
impl LightProbeGrid {
    /// Empty grid.
    pub fn new() -> Self {
        LightProbeGrid { probes: Vec::new() }
    }
    /// Append a probe.
    pub fn add(&mut self, probe: LightProbe) {
        self.probes.push(probe);
    }
    /// The probe closest to `pos`, or `None` if the grid is empty.
    pub fn nearest(&self, pos: [f32; 3]) -> Option<&LightProbe> {
        // total_cmp gives f32 a total order, so a NaN distance (e.g. from a
        // probe with non-finite position) can no longer panic the way
        // partial_cmp().unwrap() would.
        self.probes
            .iter()
            .min_by(|a, b| dist_sq(a.position, pos).total_cmp(&dist_sq(b.position, pos)))
    }
    /// Number of probes in the grid.
    pub fn len(&self) -> usize {
        self.probes.len()
    }
    /// True when the grid holds no probes.
    pub fn is_empty(&self) -> bool {
        self.probes.is_empty()
    }
}
/// Squared Euclidean distance between two points.
fn dist_sq(a: [f32; 3], b: [f32; 3]) -> f32 {
    let dx = a[0] - b[0];
    let dy = a[1] - b[1];
    let dz = a[2] - b[2];
    dx * dx + dy * dy + dz * dz
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_probe_evaluate() {
        // An L0-only (ambient) probe returns positive irradiance for any normal.
        let mut p = LightProbe::new([0.0; 3]);
        p.sh_coefficients[0] = [0.5, 0.5, 0.5]; // ambient
        let irr = p.evaluate_irradiance([0.0, 1.0, 0.0]);
        assert!(irr[0] > 0.0);
    }
    #[test]
    fn test_grid_nearest() {
        // Query at x=1 is closer to the origin probe than the one at x=10.
        let mut grid = LightProbeGrid::new();
        grid.add(LightProbe::new([0.0, 0.0, 0.0]));
        grid.add(LightProbe::new([10.0, 0.0, 0.0]));
        let nearest = grid.nearest([1.0, 0.0, 0.0]).unwrap();
        assert!((nearest.position[0] - 0.0).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,54 @@
/// Light volume shapes for deferred lighting optimization.
#[derive(Debug, Clone)]
pub enum LightVolume {
    Sphere { center: [f32; 3], radius: f32 },
    Cone { apex: [f32; 3], direction: [f32; 3], angle: f32, range: f32 },
    Fullscreen,
}
impl LightVolume {
    /// Bounding sphere for a point light with the given influence radius.
    pub fn point_light(center: [f32; 3], radius: f32) -> Self {
        LightVolume::Sphere { center, radius }
    }
    /// Bounding cone for a spot light: apex, axis, half-angle (radians), reach.
    pub fn spot_light(apex: [f32; 3], dir: [f32; 3], angle: f32, range: f32) -> Self {
        LightVolume::Cone { apex, direction: dir, angle, range }
    }
    /// Directional lights affect the whole screen.
    pub fn directional() -> Self {
        LightVolume::Fullscreen
    }
    /// Whether `point` lies inside this volume.
    pub fn contains_point(&self, point: [f32; 3]) -> bool {
        match self {
            LightVolume::Fullscreen => true,
            LightVolume::Sphere { center, radius } => {
                let delta = [
                    point[0] - center[0],
                    point[1] - center[1],
                    point[2] - center[2],
                ];
                let dist2 = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
                dist2 <= radius * radius
            }
            LightVolume::Cone { apex, direction, angle, range } => {
                let delta = [point[0] - apex[0], point[1] - apex[1], point[2] - apex[2]];
                let dist2 = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
                let dist = dist2.sqrt();
                if dist > *range {
                    // Beyond the cone's reach.
                    false
                } else if dist < 1e-6 {
                    // Degenerate case: the apex itself counts as inside.
                    true
                } else {
                    // Cosine of the angle between the apex->point direction and
                    // the cone axis, compared against cos(half-angle).
                    let cos_axis = (delta[0] * direction[0]
                        + delta[1] * direction[1]
                        + delta[2] * direction[2])
                        / dist;
                    cos_axis >= angle.cos()
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sphere_contains() {
        // Points inside the radius pass; outside fail.
        let v = LightVolume::point_light([0.0; 3], 5.0);
        assert!(v.contains_point([3.0, 0.0, 0.0]));
        assert!(!v.contains_point([6.0, 0.0, 0.0]));
    }
    #[test]
    fn test_fullscreen() {
        // Directional volumes contain every point.
        assert!(LightVolume::directional().contains_point([999.0, 999.0, 999.0]));
    }
    #[test]
    fn test_cone_contains() {
        // Half-angle 0.5 rad, axis -Z, range 10.
        let v = LightVolume::spot_light([0.0;3], [0.0,0.0,-1.0], 0.5, 10.0);
        assert!(v.contains_point([0.0, 0.0, -5.0])); // on axis
        assert!(!v.contains_point([0.0, 0.0, 5.0])); // behind
    }
}

View File

@@ -0,0 +1,177 @@
use bytemuck::{Pod, Zeroable};
/// CPU-side uniform block for the motion-blur pass. Field order, types, and the
/// trailing padding must match the `MotionBlurParams` struct in motion_blur.wgsl.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct MotionBlurParams {
    pub inv_view_proj: [[f32; 4]; 4],  // inverse view-projection (depth -> world)
    pub prev_view_proj: [[f32; 4]; 4], // previous frame's view-projection
    pub num_samples: u32,              // samples taken along the velocity vector
    pub strength: f32,                 // velocity scale factor
    pub _pad: [f32; 2],                // pads the struct to a 16-byte multiple
}
impl MotionBlurParams {
    /// Defaults: zero matrices (callers fill them in per frame before upload),
    /// 8 samples, full strength.
    pub fn new() -> Self {
        MotionBlurParams {
            inv_view_proj: [[0.0; 4]; 4],
            prev_view_proj: [[0.0; 4]; 4],
            num_samples: 8,
            strength: 1.0,
            _pad: [0.0; 2],
        }
    }
}
impl Default for MotionBlurParams {
    fn default() -> Self {
        Self::new()
    }
}
/// Full-screen motion-blur post-process implemented as a compute pass.
pub struct MotionBlur {
    pipeline: wgpu::ComputePipeline,          // compute pipeline running motion_blur.wgsl
    bind_group_layout: wgpu::BindGroupLayout, // layout for color/depth/output/params
    params_buffer: wgpu::Buffer,              // uniform buffer holding MotionBlurParams
    // NOTE(review): `dispatch` below does not consult this flag; callers appear
    // responsible for gating the pass — confirm at call sites.
    pub enabled: bool,
}
impl MotionBlur {
    /// Build the compute pipeline, bind group layout, and params uniform buffer.
    pub fn new(device: &wgpu::Device) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Motion Blur Compute"),
            source: wgpu::ShaderSource::Wgsl(include_str!("motion_blur.wgsl").into()),
        });
        // Bindings mirror motion_blur.wgsl:
        //   0 = scene color (sampled, non-filterable float)
        //   1 = scene depth
        //   2 = blurred output (rgba16float storage, write-only)
        //   3 = MotionBlurParams uniform
        let bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Motion Blur BGL"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Texture {
                            multisampled: false,
                            view_dimension: wgpu::TextureViewDimension::D2,
                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Texture {
                            multisampled: false,
                            view_dimension: wgpu::TextureViewDimension::D2,
                            sample_type: wgpu::TextureSampleType::Depth,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::StorageTexture {
                            access: wgpu::StorageTextureAccess::WriteOnly,
                            format: wgpu::TextureFormat::Rgba16Float,
                            view_dimension: wgpu::TextureViewDimension::D2,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Motion Blur PL"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Motion Blur Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        // Uniform buffer sized exactly to MotionBlurParams; refilled each dispatch.
        let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Motion Blur Params"),
            size: std::mem::size_of::<MotionBlurParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        MotionBlur {
            pipeline,
            bind_group_layout,
            params_buffer,
            enabled: true,
        }
    }
    /// Upload `params` and record one full-screen motion-blur compute pass into
    /// `encoder`, reading `color_view`/`depth_view` and writing `output_view`.
    ///
    /// NOTE(review): a fresh bind group is created per call — fine for views
    /// that change every frame, but worth caching if they are stable.
    pub fn dispatch(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        encoder: &mut wgpu::CommandEncoder,
        color_view: &wgpu::TextureView,
        depth_view: &wgpu::TextureView,
        output_view: &wgpu::TextureView,
        params: &MotionBlurParams,
        width: u32,
        height: u32,
    ) {
        queue.write_buffer(&self.params_buffer, 0, bytemuck::cast_slice(&[*params]));
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Motion Blur BG"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(color_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::TextureView(depth_view),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: wgpu::BindingResource::TextureView(output_view),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: self.params_buffer.as_entire_binding(),
                },
            ],
        });
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Motion Blur Pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&self.pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        // One 16x16 workgroup per tile, matching @workgroup_size in the shader.
        cpass.dispatch_workgroups((width + 15) / 16, (height + 15) / 16, 1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_params_default() {
        // CPU-side defaults produced by MotionBlurParams::new().
        let p = MotionBlurParams::new();
        assert_eq!(p.num_samples, 8);
        assert!((p.strength - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_params_size_aligned() {
        // Uniform structs must be a 16-byte multiple; _pad exists for this.
        let size = std::mem::size_of::<MotionBlurParams>();
        assert_eq!(size % 16, 0);
    }
}

View File

@@ -0,0 +1,49 @@
// Uniforms for the motion-blur pass. Field order, types, and padding must match
// the Rust-side MotionBlurParams struct (motion_blur.rs).
struct MotionBlurParams {
    inv_view_proj: mat4x4<f32>,
    prev_view_proj: mat4x4<f32>,
    num_samples: u32,
    strength: f32,
    _pad: vec2<f32>,
};
// 0 = scene color (input), 1 = scene depth (input),
// 2 = blurred color (output), 3 = parameters.
@group(0) @binding(0) var color_tex: texture_2d<f32>;
@group(0) @binding(1) var depth_tex: texture_depth_2d;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba16float, write>;
@group(0) @binding(3) var<uniform> params: MotionBlurParams;
// Camera motion blur: reconstruct the world position from depth, reproject it
// with the previous frame's view-projection to get a per-pixel screen-space
// velocity, then average samples taken along that velocity.
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let dims = textureDimensions(color_tex);
    if (gid.x >= dims.x || gid.y >= dims.y) { return; }
    let pos = vec2<i32>(gid.xy);
    // Pixel-center UV.
    let uv = (vec2<f32>(gid.xy) + 0.5) / vec2<f32>(dims);
    // Reconstruct world position from depth (NDC y is flipped relative to UV).
    let depth = textureLoad(depth_tex, pos, 0);
    let ndc = vec4<f32>(uv.x * 2.0 - 1.0, 1.0 - uv.y * 2.0, depth, 1.0);
    let world = params.inv_view_proj * ndc;
    let world_pos = world.xyz / world.w;
    // Project into the previous frame's clip space.
    let prev_clip = params.prev_view_proj * vec4<f32>(world_pos, 1.0);
    let prev_ndc = prev_clip.xyz / prev_clip.w;
    let prev_uv = vec2<f32>(prev_ndc.x * 0.5 + 0.5, 1.0 - (prev_ndc.y * 0.5 + 0.5));
    // Screen-space velocity = current_uv - prev_uv, scaled by user strength.
    let velocity = (uv - prev_uv) * params.strength;
    // Clamp the sample count to >= 2: the t = i / (n - 1) parameterization
    // below would produce 0/0 = NaN for n == 1, and n == 0 would both skip
    // the loop and divide the accumulator by zero.
    let n = max(params.num_samples, 2u);
    // Average samples along the velocity direction, centered on this pixel.
    var color = vec4<f32>(0.0);
    for (var i = 0u; i < n; i++) {
        let t = f32(i) / f32(n - 1u) - 0.5; // -0.5 to 0.5
        let sample_uv = uv + velocity * t;
        let sample_pos = vec2<i32>(sample_uv * vec2<f32>(dims));
        // Clamp to the texture bounds so off-screen samples read the edge texel.
        let clamped = clamp(sample_pos, vec2<i32>(0), vec2<i32>(dims) - 1);
        color += textureLoad(color_tex, clamped, 0);
    }
    color /= f32(n);
    textureStore(output_tex, pos, color);
}

View File

@@ -36,6 +36,10 @@ struct MaterialUniform {
@group(1) @binding(1) var s_diffuse: sampler;
@group(1) @binding(2) var t_normal: texture_2d<f32>;
@group(1) @binding(3) var s_normal: sampler;
@group(1) @binding(4) var t_orm: texture_2d<f32>;
@group(1) @binding(5) var s_orm: sampler;
@group(1) @binding(6) var t_emissive: texture_2d<f32>;
@group(1) @binding(7) var s_emissive: sampler;
@group(2) @binding(0) var<uniform> material: MaterialUniform;
@@ -43,6 +47,10 @@ struct ShadowUniform {
light_view_proj: mat4x4<f32>,
shadow_map_size: f32,
shadow_bias: f32,
_padding: vec2<f32>,
sun_direction: vec3<f32>,
turbidity: f32,
sh_coefficients: array<vec4<f32>, 7>,
};
@group(3) @binding(0) var t_shadow: texture_depth_2d;
@@ -229,30 +237,93 @@ fn calculate_shadow(light_space_pos: vec4<f32>) -> f32 {
return shadow_val / 9.0;
}
// Procedural environment sampling for IBL.
// Hosek-Wilkie inspired procedural sky model driven by sun direction and
// turbidity from the shadow uniform.
// NOTE: the rendered diff had interleaved the removed pre-change lines
// (an unused `t` and duplicate `env` assignments); this is the post-change body.
fn sample_environment(direction: vec3<f32>, roughness: f32) -> vec3<f32> {
    let sun_dir = normalize(shadow.sun_direction);
    let turb = clamp(shadow.turbidity, 1.5, 10.0);
    var env: vec3<f32>;
    if direction.y > 0.0 {
        // Rayleigh scattering: blue zenith, warm horizon; turbidity desaturates
        // the zenith and brightens the horizon.
        let zenith_color = vec3<f32>(0.15, 0.3, 0.8) * (1.0 / (turb * 0.15 + 0.5));
        let horizon_color = vec3<f32>(0.7, 0.6, 0.5) * (1.0 + turb * 0.04);
        let elevation = direction.y;
        let sky_gradient = mix(horizon_color, zenith_color, pow(elevation, 0.4));
        // Mie scattering: haze concentrated around the sun direction.
        let cos_sun = max(dot(direction, sun_dir), 0.0);
        let mie_strength = turb * 0.02;
        let mie = mie_strength * pow(cos_sun, 8.0) * vec3<f32>(1.0, 0.9, 0.7);
        // Sun disk: bright spot with a very sharp falloff.
        let sun_disk = pow(max(cos_sun, 0.0), 2048.0) * vec3<f32>(10.0, 9.0, 7.0);
        // Combine
        env = sky_gradient + mie + sun_disk;
    } else {
        // Ground: dark, warm.
        let horizon_color = vec3<f32>(0.6, 0.55, 0.45);
        let ground_color = vec3<f32>(0.1, 0.08, 0.06);
        env = mix(horizon_color, ground_color, pow(-direction.y, 0.4));
    }
    // Roughness blur: blend toward an average color for rough surfaces.
    let avg = vec3<f32>(0.3, 0.35, 0.4);
    return mix(env, avg, roughness * roughness);
}
// Evaluate L2 spherical harmonics irradiance at the given normal direction.
// 9 RGB coefficients (27 floats) are packed tightly into 7 vec4s (28 floats,
// last component padding).
fn evaluate_sh(normal: vec3<f32>, coeffs: array<vec4<f32>, 7>) -> vec3<f32> {
    // Unpack the 9 RGB coefficients from the tight vec4 packing:
    // c0 = [0].xyz, c1 = [0].w + [1].xy, c2 = [1].zw + [2].x, c3 = [2].yzw, ...
    let c0 = coeffs[0].xyz;
    let c1 = vec3<f32>(coeffs[0].w, coeffs[1].x, coeffs[1].y);
    let c2 = vec3<f32>(coeffs[1].z, coeffs[1].w, coeffs[2].x);
    let c3 = coeffs[2].yzw;
    let c4 = coeffs[3].xyz;
    let c5 = vec3<f32>(coeffs[3].w, coeffs[4].x, coeffs[4].y);
    let c6 = vec3<f32>(coeffs[4].z, coeffs[4].w, coeffs[5].x);
    let c7 = coeffs[5].yzw;
    let c8 = coeffs[6].xyz;
    // Real SH basis functions, bands L0..L2, evaluated at the normal.
    let nx = normal.x;
    let ny = normal.y;
    let nz = normal.z;
    let b00 = 0.282095; // L=0, M=0
    let b1n1 = 0.488603 * ny; // L=1, M=-1
    let b10 = 0.488603 * nz; // L=1, M=0
    let b1p1 = 0.488603 * nx; // L=1, M=1
    let b2n2 = 1.092548 * nx * ny; // L=2, M=-2
    let b2n1 = 1.092548 * ny * nz; // L=2, M=-1
    let b20 = 0.315392 * (3.0 * nz * nz - 1.0); // L=2, M=0
    let b2p1 = 1.092548 * nx * nz; // L=2, M=1
    let b2p2 = 0.546274 * (nx * nx - ny * ny); // L=2, M=2
    let irradiance =
        c0 * b00 + c1 * b1n1 + c2 * b10 + c3 * b1p1 +
        c4 * b2n2 + c5 * b2n1 + c6 * b20 + c7 * b2p1 + c8 * b2p2;
    // Clamp: SH reconstruction can ring negative.
    return max(irradiance, vec3<f32>(0.0));
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let tex_color = textureSample(t_diffuse, s_diffuse, in.uv);
let albedo = material.base_color.rgb * tex_color.rgb;
let metallic = material.metallic;
let roughness = material.roughness;
let ao = material.ao;
// Sample ORM texture: R=AO, G=Roughness, B=Metallic; multiply with material params
let orm_sample = textureSample(t_orm, s_orm, in.uv);
let ao = orm_sample.r * material.ao;
let roughness = orm_sample.g * material.roughness;
let metallic = orm_sample.b * material.metallic;
// Sample emissive texture
let emissive = textureSample(t_emissive, s_emissive, in.uv).rgb;
// Normal mapping via TBN matrix
let T = normalize(in.world_tangent);
@@ -291,8 +362,14 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let NdotV_ibl = max(dot(N, V), 0.0);
let R = reflect(-V, N);
// Diffuse IBL
let irradiance = sample_environment(N, 1.0);
// Diffuse IBL: use SH irradiance if SH coefficients are set, else fallback to procedural
var irradiance: vec3<f32>;
let sh_test = shadow.sh_coefficients[0].x + shadow.sh_coefficients[0].y + shadow.sh_coefficients[0].z;
if abs(sh_test) > 0.0001 {
irradiance = evaluate_sh(N, shadow.sh_coefficients);
} else {
irradiance = sample_environment(N, 1.0);
}
let F_env = fresnel_schlick(NdotV_ibl, F0);
let kd_ibl = (vec3<f32>(1.0) - F_env) * (1.0 - metallic);
let diffuse_ibl = kd_ibl * albedo * irradiance;
@@ -304,7 +381,7 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let ambient = (diffuse_ibl + specular_ibl) * ao;
var color = ambient + Lo;
var color = ambient + Lo + emissive;
// Reinhard tone mapping
color = color / (color + vec3<f32>(1.0));

View File

@@ -0,0 +1,178 @@
use voltex_math::{Mat4, Vec3};
/// Resolution in texels of each cube shadow face (512x512).
pub const POINT_SHADOW_SIZE: u32 = 512;
/// Depth format used for the point-light shadow cube map.
pub const POINT_SHADOW_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Depth32Float;
/// Depth cube map for omnidirectional point light shadows.
///
/// Uses a single cube texture with 6 faces (512x512 each).
pub struct PointShadowMap {
/// The underlying 6-layer depth texture backing the cube map.
pub texture: wgpu::Texture,
/// Cube view over all 6 layers, used for sampling in the shader.
pub view: wgpu::TextureView,
/// Per-face views for rendering into each cube face.
pub face_views: [wgpu::TextureView; 6],
/// Comparison sampler (`LessEqual`) for depth shadow lookups.
pub sampler: wgpu::Sampler,
}
impl PointShadowMap {
/// Creates the cube depth texture, a cube view for sampling, six per-face
/// views for rendering, and a comparison sampler.
pub fn new(device: &wgpu::Device) -> Self {
// Six array layers, one per cube face, each POINT_SHADOW_SIZE squared.
let texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("Point Shadow Cube Texture"),
size: wgpu::Extent3d {
width: POINT_SHADOW_SIZE,
height: POINT_SHADOW_SIZE,
depth_or_array_layers: 6,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
format: POINT_SHADOW_FORMAT,
// Rendered into (one face at a time) and sampled in the lighting shader.
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
// Full cube view for sampling in shader
let view = texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("Point Shadow Cube View"),
dimension: Some(wgpu::TextureViewDimension::Cube),
..Default::default()
});
// Per-face views for rendering
let face_labels = ["+X", "-X", "+Y", "-Y", "+Z", "-Z"];
// Each face view targets exactly one array layer as a 2D depth attachment.
let face_views = std::array::from_fn(|i| {
texture.create_view(&wgpu::TextureViewDescriptor {
label: Some(&format!("Point Shadow Face {}", face_labels[i])),
dimension: Some(wgpu::TextureViewDimension::D2),
base_array_layer: i as u32,
array_layer_count: Some(1),
..Default::default()
})
});
// Comparison sampler: LessEqual compare with linear filtering.
let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
label: Some("Point Shadow Sampler"),
address_mode_u: wgpu::AddressMode::ClampToEdge,
address_mode_v: wgpu::AddressMode::ClampToEdge,
address_mode_w: wgpu::AddressMode::ClampToEdge,
mag_filter: wgpu::FilterMode::Linear,
min_filter: wgpu::FilterMode::Linear,
mipmap_filter: wgpu::MipmapFilterMode::Nearest,
compare: Some(wgpu::CompareFunction::LessEqual),
..Default::default()
});
Self {
texture,
view,
face_views,
sampler,
}
}
}
/// Compute the 6 view matrices for rendering into a point light shadow cube map.
///
/// Order: +X, -X, +Y, -Y, +Z, -Z (matching wgpu cube face order).
///
/// Each matrix is a look-at view matrix from `light_pos` toward the
/// corresponding axis direction, with an appropriate up vector.
pub fn point_shadow_view_matrices(light_pos: Vec3) -> [Mat4; 6] {
// (look direction, up) per face, in wgpu cube face order. The Y faces use
// Z-axis up vectors because their look direction is parallel to Y.
let faces = [
(Vec3::X, -Vec3::Y),   // +X: look right
(-Vec3::X, -Vec3::Y),  // -X: look left
(Vec3::Y, Vec3::Z),    // +Y: look up
(-Vec3::Y, -Vec3::Z),  // -Y: look down
(Vec3::Z, -Vec3::Y),   // +Z: look forward
(-Vec3::Z, -Vec3::Y),  // -Z: look backward
];
faces.map(|(dir, up)| Mat4::look_at(light_pos, light_pos + dir, up))
}
/// Compute the perspective projection for point light shadow rendering.
///
/// 90 degree FOV, 1:1 aspect ratio.
/// `near` and `far` control the shadow range (typically 0.1 and light.range).
pub fn point_shadow_projection(near: f32, far: f32) -> Mat4 {
// A 90-degree vertical FOV with square aspect covers exactly one cube face.
const FOV_Y: f32 = std::f32::consts::FRAC_PI_2;
const ASPECT: f32 = 1.0;
Mat4::perspective(FOV_Y, ASPECT, near, far)
}
#[cfg(test)]
mod tests {
use super::*;
use voltex_math::Vec4;

/// Loose float comparison for matrix-transform results.
fn approx_eq(a: f32, b: f32) -> bool {
(a - b).abs() < 1e-4
}

#[test]
fn test_point_shadow_view_matrices_count() {
// One view matrix per cube face.
assert_eq!(point_shadow_view_matrices(Vec3::ZERO).len(), 6);
}

#[test]
fn test_point_shadow_view_directions() {
let views = point_shadow_view_matrices(Vec3::new(0.0, 0.0, 0.0));
// +X face: a point on the +X axis should land centered in view space —
// in front of the camera (negative z) with x and y near zero.
let in_view = views[0].mul_vec4(Vec4::new(1.0, 0.0, 0.0, 1.0));
assert!(in_view.z < 0.0, "+X face: point at +X should be in front, got z={}", in_view.z);
assert!(approx_eq(in_view.x, 0.0), "+X face: expected x~0, got {}", in_view.x);
assert!(approx_eq(in_view.y, 0.0), "+X face: expected y~0, got {}", in_view.y);
// -X face: mirror check with a point on the -X axis.
let in_view_neg = views[1].mul_vec4(Vec4::new(-1.0, 0.0, 0.0, 1.0));
assert!(in_view_neg.z < 0.0, "-X face: point at -X should be in front, got z={}", in_view_neg.z);
}

#[test]
fn test_point_shadow_view_offset_position() {
let light = Vec3::new(5.0, 10.0, -3.0);
// The light's own position must map to the view-space origin on every face.
let origin = Vec4::from_vec3(light, 1.0);
for (i, view) in point_shadow_view_matrices(light).iter().enumerate() {
let v = view.mul_vec4(origin);
assert!(approx_eq(v.x, 0.0), "Face {}: origin x should be 0, got {}", i, v.x);
assert!(approx_eq(v.y, 0.0), "Face {}: origin y should be 0, got {}", i, v.y);
assert!(approx_eq(v.z, 0.0), "Face {}: origin z should be 0, got {}", i, v.z);
}
}

#[test]
fn test_point_shadow_projection_90fov() {
let proj = point_shadow_projection(0.1, 100.0);
// With a 90-degree FOV and square aspect, view-space x == -z lies exactly
// on the right frustum edge, so it should project to NDC x = 1.
let near = 0.1f32;
let clip = proj.mul_vec4(Vec4::new(near, 0.0, -near, 1.0));
let ndc_x = clip.x / clip.w;
assert!(approx_eq(ndc_x, 1.0), "Expected NDC x=1.0, got {}", ndc_x);
}

#[test]
fn test_point_shadow_projection_near_plane() {
let proj = point_shadow_projection(0.1, 100.0);
// The near plane is expected to map to NDC z = 0 (wgpu depth convention).
let clip = proj.mul_vec4(Vec4::new(0.0, 0.0, -0.1, 1.0));
let ndc_z = clip.z / clip.w;
assert!(approx_eq(ndc_z, 0.0), "Near plane should map to NDC z=0, got {}", ndc_z);
}
}

Some files were not shown because too many files have changed in this diff Show More