From 0ef750de6937af6117fc8dec88d23a485129fe63 Mon Sep 17 00:00:00 2001 From: tolelom <98kimsungmin@naver.com> Date: Wed, 25 Mar 2026 21:03:52 +0900 Subject: [PATCH] feat(net): add reliability layer, state sync, and client interpolation - ReliableChannel: sequence numbers, ACK, retransmission, RTT estimation - OrderedChannel: in-order delivery with out-of-order buffering - Snapshot serialization with delta compression (per-field bitmask) - InterpolationBuffer: linear interpolation between server snapshots - New packet types: Reliable, Ack, Snapshot, SnapshotDelta Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/voltex_net/src/interpolation.rs | 234 +++++++++++++++ crates/voltex_net/src/lib.rs | 6 + crates/voltex_net/src/packet.rs | 94 ++++++ crates/voltex_net/src/reliable.rs | 318 +++++++++++++++++++++ crates/voltex_net/src/snapshot.rs | 378 +++++++++++++++++++++++++ 5 files changed, 1030 insertions(+) create mode 100644 crates/voltex_net/src/interpolation.rs create mode 100644 crates/voltex_net/src/reliable.rs create mode 100644 crates/voltex_net/src/snapshot.rs diff --git a/crates/voltex_net/src/interpolation.rs b/crates/voltex_net/src/interpolation.rs new file mode 100644 index 0000000..30f27d9 --- /dev/null +++ b/crates/voltex_net/src/interpolation.rs @@ -0,0 +1,234 @@ +use std::collections::VecDeque; + +use crate::snapshot::{EntityState, Snapshot}; + +/// Buffers recent snapshots and interpolates between them for smooth rendering. +pub struct InterpolationBuffer { + snapshots: VecDeque<(f64, Snapshot)>, + /// Render delay behind the latest server time (seconds). + interp_delay: f64, + /// Maximum number of snapshots to keep in the buffer. + max_snapshots: usize, +} + +impl InterpolationBuffer { + /// Create a new interpolation buffer with the given delay in seconds. + pub fn new(interp_delay: f64) -> Self { + InterpolationBuffer { + snapshots: VecDeque::new(), + interp_delay, + max_snapshots: 32, + } + } + + /// Push a new snapshot with its server timestamp. + pub fn push(&mut self, server_time: f64, snapshot: Snapshot) { + self.snapshots.push_back((server_time, snapshot)); + // Evict old snapshots beyond the buffer limit + while self.snapshots.len() > self.max_snapshots { + self.snapshots.pop_front(); + } + } + + /// Interpolate to produce a snapshot for the given render_time. + /// + /// The render_time should be `current_server_time - interp_delay`. + /// Returns None if there are fewer than 2 snapshots or render_time + /// is before all buffered snapshots. + pub fn interpolate(&self, render_time: f64) -> Option { + if self.snapshots.len() < 2 { + return None; + } + + // Find two bracketing snapshots: the last one <= render_time and the first one > render_time + let mut before = None; + let mut after = None; + + for (i, (time, _)) in self.snapshots.iter().enumerate() { + if *time <= render_time { + before = Some(i); + } else { + after = Some(i); + break; + } + } + + match (before, after) { + (Some(b), Some(a)) => { + let (t0, snap0) = &self.snapshots[b]; + let (t1, snap1) = &self.snapshots[a]; + let dt = t1 - t0; + if dt <= 0.0 { + return Some(snap0.clone()); + } + let alpha = ((render_time - t0) / dt).clamp(0.0, 1.0) as f32; + Some(lerp_snapshots(snap0, snap1, alpha)) + } + (Some(b), None) => { + // render_time is beyond all snapshots — return the latest + Some(self.snapshots[b].1.clone()) + } + _ => None, + } + } + + /// Get the interpolation delay. + pub fn delay(&self) -> f64 { + self.interp_delay + } +} + +fn lerp(a: f32, b: f32, t: f32) -> f32 { + a + (b - a) * t +} + +fn lerp_f32x3(a: &[f32; 3], b: &[f32; 3], t: f32) -> [f32; 3] { + [lerp(a[0], b[0], t), lerp(a[1], b[1], t), lerp(a[2], b[2], t)] +} + +fn lerp_entity(a: &EntityState, b: &EntityState, t: f32) -> EntityState { + EntityState { + id: a.id, + position: lerp_f32x3(&a.position, &b.position, t), + rotation: lerp_f32x3(&a.rotation, &b.rotation, t), + velocity: lerp_f32x3(&a.velocity, &b.velocity, t), + } +} + +/// Linearly interpolate between two snapshots. +/// Entities are matched by id. Entities only in one snapshot are included as-is. +fn lerp_snapshots(a: &Snapshot, b: &Snapshot, t: f32) -> Snapshot { + use std::collections::HashMap; + + let a_map: HashMap = a.entities.iter().map(|e| (e.id, e)).collect(); + let b_map: HashMap = b.entities.iter().map(|e| (e.id, e)).collect(); + + let mut entities = Vec::new(); + + // Interpolate matched entities, include a-only entities + for ea in &a.entities { + if let Some(eb) = b_map.get(&ea.id) { + entities.push(lerp_entity(ea, eb, t)); + } else { + entities.push(ea.clone()); + } + } + + // Include b-only entities + for eb in &b.entities { + if !a_map.contains_key(&eb.id) { + entities.push(eb.clone()); + } + } + + // Interpolate tick + let tick = (a.tick as f64 + (b.tick as f64 - a.tick as f64) * t as f64) as u32; + + Snapshot { tick, entities } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::snapshot::EntityState; + + fn make_snapshot(tick: u32, x: f32) -> Snapshot { + Snapshot { + tick, + entities: vec![EntityState { + id: 1, + position: [x, 0.0, 0.0], + rotation: [0.0, 0.0, 0.0], + velocity: [0.0, 0.0, 0.0], + }], + } + } + + #[test] + fn test_exact_match_at_snapshot_time() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(0.0, make_snapshot(0, 0.0)); + buf.push(0.1, make_snapshot(1, 10.0)); + + let result = buf.interpolate(0.0).expect("should interpolate"); + assert_eq!(result.entities[0].position[0], 0.0); + } + + #[test] + fn test_midpoint_interpolation() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(0.0, make_snapshot(0, 0.0)); + buf.push(1.0, make_snapshot(10, 10.0)); + + let result = buf.interpolate(0.5).expect("should interpolate"); + let x = result.entities[0].position[0]; + assert!( + (x - 5.0).abs() < 0.001, + "Expected ~5.0 at midpoint, got {}", + x + ); + } + + #[test] + fn test_interpolation_at_quarter() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(0.0, make_snapshot(0, 0.0)); + buf.push(1.0, make_snapshot(10, 100.0)); + + let result = buf.interpolate(0.25).unwrap(); + let x = result.entities[0].position[0]; + assert!( + (x - 25.0).abs() < 0.01, + "Expected ~25.0 at 0.25, got {}", + x + ); + } + + #[test] + fn test_extrapolation_returns_latest() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(0.0, make_snapshot(0, 0.0)); + buf.push(1.0, make_snapshot(10, 10.0)); + + // render_time beyond all snapshots + let result = buf.interpolate(2.0).expect("should return latest"); + assert_eq!(result.entities[0].position[0], 10.0); + } + + #[test] + fn test_too_few_snapshots_returns_none() { + let mut buf = InterpolationBuffer::new(0.1); + assert!(buf.interpolate(0.0).is_none()); + + buf.push(0.0, make_snapshot(0, 0.0)); + assert!(buf.interpolate(0.0).is_none()); + } + + #[test] + fn test_render_time_before_all_snapshots() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(1.0, make_snapshot(10, 10.0)); + buf.push(2.0, make_snapshot(20, 20.0)); + + // render_time before the first snapshot + let result = buf.interpolate(0.0); + assert!(result.is_none()); + } + + #[test] + fn test_multiple_snapshots_picks_correct_bracket() { + let mut buf = InterpolationBuffer::new(0.1); + buf.push(0.0, make_snapshot(0, 0.0)); + buf.push(1.0, make_snapshot(1, 10.0)); + buf.push(2.0, make_snapshot(2, 20.0)); + + // Should interpolate between snapshot at t=1 and t=2 + let result = buf.interpolate(1.5).unwrap(); + let x = result.entities[0].position[0]; + assert!( + (x - 15.0).abs() < 0.01, + "Expected ~15.0, got {}", + x + ); + } +} diff --git a/crates/voltex_net/src/lib.rs b/crates/voltex_net/src/lib.rs index dc640de..472d057 100644 --- a/crates/voltex_net/src/lib.rs +++ b/crates/voltex_net/src/lib.rs @@ -2,8 +2,14 @@ pub mod packet; pub mod socket; pub mod server; pub mod client; +pub mod reliable; +pub mod snapshot; +pub mod interpolation; pub use packet::Packet; pub use socket::NetSocket; pub use server::{NetServer, ServerEvent, ClientInfo}; pub use client::{NetClient, ClientEvent}; +pub use reliable::{ReliableChannel, OrderedChannel}; +pub use snapshot::{Snapshot, EntityState, serialize_snapshot, deserialize_snapshot, diff_snapshots, apply_diff}; +pub use interpolation::InterpolationBuffer; diff --git a/crates/voltex_net/src/packet.rs b/crates/voltex_net/src/packet.rs index 2d081b9..8ae5423 100644 --- a/crates/voltex_net/src/packet.rs +++ b/crates/voltex_net/src/packet.rs @@ -5,6 +5,10 @@ const TYPE_DISCONNECT: u8 = 3; const TYPE_PING: u8 = 4; const TYPE_PONG: u8 = 5; const TYPE_USER_DATA: u8 = 6; +const TYPE_RELIABLE: u8 = 7; +const TYPE_ACK: u8 = 8; +const TYPE_SNAPSHOT: u8 = 9; +const TYPE_SNAPSHOT_DELTA: u8 = 10; /// Header size: type_id(1) + payload_len(2 LE) + reserved(1) = 4 bytes const HEADER_SIZE: usize = 4; @@ -18,6 +22,10 @@ pub enum Packet { Ping { timestamp: u64 }, Pong { timestamp: u64 }, UserData { client_id: u32, data: Vec }, + Reliable { sequence: u16, data: Vec }, + Ack { sequence: u16 }, + Snapshot { tick: u32, data: Vec }, + SnapshotDelta { base_tick: u32, tick: u32, data: Vec }, } impl Packet { @@ -112,6 +120,38 @@ impl Packet { let data = payload[4..].to_vec(); Ok(Packet::UserData { client_id, data }) } + TYPE_RELIABLE => { + if payload.len() < 2 { + return Err("Reliable payload too short".to_string()); + } + let sequence = u16::from_le_bytes([payload[0], payload[1]]); + let data = payload[2..].to_vec(); + Ok(Packet::Reliable { sequence, data }) + } + TYPE_ACK => { + if payload.len() < 2 { + return Err("Ack payload too short".to_string()); + } + let sequence = u16::from_le_bytes([payload[0], payload[1]]); + Ok(Packet::Ack { sequence }) + } + TYPE_SNAPSHOT => { + if payload.len() < 4 { + return Err("Snapshot payload too short".to_string()); + } + let tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]); + let data = payload[4..].to_vec(); + Ok(Packet::Snapshot { tick, data }) + } + TYPE_SNAPSHOT_DELTA => { + if payload.len() < 8 { + return Err("SnapshotDelta payload too short".to_string()); + } + let base_tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]); + let tick = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]); + let data = payload[8..].to_vec(); + Ok(Packet::SnapshotDelta { base_tick, tick, data }) + } _ => Err(format!("Unknown packet type_id: {}", type_id)), } } @@ -124,6 +164,10 @@ impl Packet { Packet::Ping { .. } => TYPE_PING, Packet::Pong { .. } => TYPE_PONG, Packet::UserData { .. } => TYPE_USER_DATA, + Packet::Reliable { .. } => TYPE_RELIABLE, + Packet::Ack { .. } => TYPE_ACK, + Packet::Snapshot { .. } => TYPE_SNAPSHOT, + Packet::SnapshotDelta { .. } => TYPE_SNAPSHOT_DELTA, } } @@ -147,6 +191,26 @@ impl Packet { buf.extend_from_slice(data); buf } + Packet::Reliable { sequence, data } => { + let mut buf = Vec::with_capacity(2 + data.len()); + buf.extend_from_slice(&sequence.to_le_bytes()); + buf.extend_from_slice(data); + buf + } + Packet::Ack { sequence } => sequence.to_le_bytes().to_vec(), + Packet::Snapshot { tick, data } => { + let mut buf = Vec::with_capacity(4 + data.len()); + buf.extend_from_slice(&tick.to_le_bytes()); + buf.extend_from_slice(data); + buf + } + Packet::SnapshotDelta { base_tick, tick, data } => { + let mut buf = Vec::with_capacity(8 + data.len()); + buf.extend_from_slice(&base_tick.to_le_bytes()); + buf.extend_from_slice(&tick.to_le_bytes()); + buf.extend_from_slice(data); + buf + } } } } @@ -200,6 +264,36 @@ mod tests { }); } + #[test] + fn test_reliable_roundtrip() { + roundtrip(Packet::Reliable { + sequence: 42, + data: vec![0xCA, 0xFE], + }); + } + + #[test] + fn test_ack_roundtrip() { + roundtrip(Packet::Ack { sequence: 100 }); + } + + #[test] + fn test_snapshot_roundtrip() { + roundtrip(Packet::Snapshot { + tick: 999, + data: vec![1, 2, 3], + }); + } + + #[test] + fn test_snapshot_delta_roundtrip() { + roundtrip(Packet::SnapshotDelta { + base_tick: 10, + tick: 15, + data: vec![4, 5, 6], + }); + } + #[test] fn test_invalid_type_returns_error() { // Build a packet with type_id = 99 (unknown) diff --git a/crates/voltex_net/src/reliable.rs b/crates/voltex_net/src/reliable.rs new file mode 100644 index 0000000..69f62c3 --- /dev/null +++ b/crates/voltex_net/src/reliable.rs @@ -0,0 +1,318 @@ +use std::collections::{HashMap, HashSet}; +use std::time::{Duration, Instant}; + +/// A channel that provides reliable delivery over unreliable transport. +/// +/// Assigns sequence numbers, tracks ACKs, estimates RTT, +/// and retransmits unacknowledged packets after 2x RTT. +pub struct ReliableChannel { + next_sequence: u16, + pending_acks: HashMap)>, + received_seqs: HashSet, + rtt: Duration, + /// Outgoing ACK packets that need to be sent by the caller. + outgoing_acks: Vec, +} + +impl ReliableChannel { + pub fn new() -> Self { + ReliableChannel { + next_sequence: 0, + pending_acks: HashMap::new(), + received_seqs: HashSet::new(), + rtt: Duration::from_millis(100), // initial estimate + outgoing_acks: Vec::new(), + } + } + + /// Returns the current RTT estimate. + pub fn rtt(&self) -> Duration { + self.rtt + } + + /// Returns the number of packets awaiting acknowledgement. + pub fn pending_count(&self) -> usize { + self.pending_acks.len() + } + + /// Prepare a reliable send. Returns (sequence_number, wrapped_data). + /// The caller is responsible for actually transmitting the wrapped data. + pub fn send_reliable(&mut self, data: &[u8]) -> (u16, Vec) { + let seq = self.next_sequence; + self.next_sequence = self.next_sequence.wrapping_add(1); + + // Build the reliable packet payload: [seq(2 LE), data...] + let mut buf = Vec::with_capacity(2 + data.len()); + buf.extend_from_slice(&seq.to_le_bytes()); + buf.extend_from_slice(data); + + self.pending_acks.insert(seq, (Instant::now(), buf.clone())); + + (seq, buf) + } + + /// Process a received reliable packet. Returns the payload data if this is + /// not a duplicate, or None if already received. Queues an ACK to send. + pub fn receive_and_ack(&mut self, sequence: u16, data: &[u8]) -> Option> { + // Always queue an ACK, even for duplicates + self.outgoing_acks.push(sequence); + + if self.received_seqs.contains(&sequence) { + return None; // duplicate + } + + self.received_seqs.insert(sequence); + Some(data.to_vec()) + } + + /// Process an incoming ACK for a sequence we sent. + pub fn process_ack(&mut self, sequence: u16) { + if let Some((send_time, _)) = self.pending_acks.remove(&sequence) { + let sample = send_time.elapsed(); + // Exponential moving average: rtt = 0.875 * rtt + 0.125 * sample + self.rtt = Duration::from_secs_f64( + 0.875 * self.rtt.as_secs_f64() + 0.125 * sample.as_secs_f64(), + ); + } + } + + /// Drain any pending outgoing ACK sequence numbers. + pub fn drain_acks(&mut self) -> Vec { + std::mem::take(&mut self.outgoing_acks) + } + + /// Check for timed-out packets and return their data for retransmission. + /// Resets the send_time for retransmitted packets. + pub fn update(&mut self) -> Vec> { + let timeout = self.rtt * 2; + let now = Instant::now(); + let mut retransmits = Vec::new(); + + for (_, (send_time, data)) in self.pending_acks.iter_mut() { + if now.duration_since(*send_time) >= timeout { + retransmits.push(data.clone()); + *send_time = now; + } + } + + retransmits + } +} + +/// A channel that delivers packets in order, built on top of ReliableChannel. +pub struct OrderedChannel { + reliable: ReliableChannel, + next_deliver: u16, + buffer: HashMap>, +} + +impl OrderedChannel { + pub fn new() -> Self { + OrderedChannel { + reliable: ReliableChannel::new(), + next_deliver: 0, + buffer: HashMap::new(), + } + } + + /// Access the underlying reliable channel (e.g., for send_reliable, process_ack, update). + pub fn reliable(&self) -> &ReliableChannel { + &self.reliable + } + + /// Access the underlying reliable channel mutably. + pub fn reliable_mut(&mut self) -> &mut ReliableChannel { + &mut self.reliable + } + + /// Prepare a reliable, ordered send. + pub fn send(&mut self, data: &[u8]) -> (u16, Vec) { + self.reliable.send_reliable(data) + } + + /// Receive a packet. Buffers out-of-order packets and returns all + /// packets that can now be delivered in sequence order. + pub fn receive(&mut self, sequence: u16, data: &[u8]) -> Vec> { + let payload = self.reliable.receive_and_ack(sequence, data); + + if let Some(payload) = payload { + self.buffer.insert(sequence, payload); + } + + // Deliver as many consecutive packets as possible + let mut delivered = Vec::new(); + while let Some(data) = self.buffer.remove(&self.next_deliver) { + delivered.push(data); + self.next_deliver = self.next_deliver.wrapping_add(1); + } + + delivered + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ---- ReliableChannel tests ---- + + #[test] + fn test_send_receive_ack_roundtrip() { + let mut sender = ReliableChannel::new(); + let mut receiver = ReliableChannel::new(); + + let original = b"hello world"; + let (seq, _buf) = sender.send_reliable(original); + assert_eq!(seq, 0); + assert_eq!(sender.pending_count(), 1); + + // Receiver gets the packet + let result = receiver.receive_and_ack(seq, original); + assert_eq!(result, Some(original.to_vec())); + + // Receiver queued an ack + let acks = receiver.drain_acks(); + assert_eq!(acks, vec![0]); + + // Sender processes the ack + sender.process_ack(seq); + assert_eq!(sender.pending_count(), 0); + } + + #[test] + fn test_duplicate_rejection() { + let mut receiver = ReliableChannel::new(); + + let data = b"payload"; + let result1 = receiver.receive_and_ack(0, data); + assert!(result1.is_some()); + + let result2 = receiver.receive_and_ack(0, data); + assert!(result2.is_none(), "Duplicate should be rejected"); + + // But ACK is still queued for both + let acks = receiver.drain_acks(); + assert_eq!(acks.len(), 2); + } + + #[test] + fn test_sequence_numbers_increment() { + let mut channel = ReliableChannel::new(); + + let (s0, _) = channel.send_reliable(b"a"); + let (s1, _) = channel.send_reliable(b"b"); + let (s2, _) = channel.send_reliable(b"c"); + + assert_eq!(s0, 0); + assert_eq!(s1, 1); + assert_eq!(s2, 2); + assert_eq!(channel.pending_count(), 3); + } + + #[test] + fn test_retransmission_on_timeout() { + let mut channel = ReliableChannel::new(); + // Set a very short RTT so timeout (2*RTT) triggers quickly + channel.rtt = Duration::from_millis(1); + + let (_seq, _buf) = channel.send_reliable(b"data"); + assert_eq!(channel.pending_count(), 1); + + // Wait for timeout + std::thread::sleep(Duration::from_millis(10)); + + let retransmits = channel.update(); + assert_eq!(retransmits.len(), 1, "Should retransmit 1 packet"); + // Packet is still pending (not acked) + assert_eq!(channel.pending_count(), 1); + } + + #[test] + fn test_no_retransmission_before_timeout() { + let mut channel = ReliableChannel::new(); + // Default RTT = 100ms, so timeout = 200ms + let (_seq, _buf) = channel.send_reliable(b"data"); + + // Immediately check — should not retransmit + let retransmits = channel.update(); + assert!(retransmits.is_empty()); + } + + #[test] + fn test_rtt_estimation() { + let mut channel = ReliableChannel::new(); + let initial_rtt = channel.rtt(); + + let (seq, _) = channel.send_reliable(b"x"); + std::thread::sleep(Duration::from_millis(5)); + channel.process_ack(seq); + + // RTT should have changed from initial value + let new_rtt = channel.rtt(); + assert_ne!(initial_rtt, new_rtt, "RTT should be updated after ACK"); + } + + #[test] + fn test_wrapping_sequence() { + let mut channel = ReliableChannel::new(); + channel.next_sequence = u16::MAX; + + let (s1, _) = channel.send_reliable(b"a"); + assert_eq!(s1, u16::MAX); + + let (s2, _) = channel.send_reliable(b"b"); + assert_eq!(s2, 0); // wrapped + } + + // ---- OrderedChannel tests ---- + + #[test] + fn test_ordered_in_order_delivery() { + let mut channel = OrderedChannel::new(); + + let delivered0 = channel.receive(0, b"first"); + assert_eq!(delivered0, vec![b"first".to_vec()]); + + let delivered1 = channel.receive(1, b"second"); + assert_eq!(delivered1, vec![b"second".to_vec()]); + } + + #[test] + fn test_ordered_out_of_order_delivery() { + let mut channel = OrderedChannel::new(); + + // Receive seq 1 first (out of order) + let delivered = channel.receive(1, b"second"); + assert!(delivered.is_empty(), "Seq 1 should be buffered, waiting for 0"); + + // Receive seq 2 (still missing 0) + let delivered = channel.receive(2, b"third"); + assert!(delivered.is_empty()); + + // Receive seq 0 — should deliver 0, 1, 2 in order + let delivered = channel.receive(0, b"first"); + assert_eq!(delivered.len(), 3); + assert_eq!(delivered[0], b"first"); + assert_eq!(delivered[1], b"second"); + assert_eq!(delivered[2], b"third"); + } + + #[test] + fn test_ordered_gap_handling() { + let mut channel = OrderedChannel::new(); + + // Deliver 0 + let d = channel.receive(0, b"a"); + assert_eq!(d.len(), 1); + + // Skip 1, deliver 2 + let d = channel.receive(2, b"c"); + assert!(d.is_empty(), "Can't deliver 2 without 1"); + + // Now deliver 1 — should flush both 1 and 2 + let d = channel.receive(1, b"b"); + assert_eq!(d.len(), 2); + assert_eq!(d[0], b"b"); + assert_eq!(d[1], b"c"); + } +} diff --git a/crates/voltex_net/src/snapshot.rs b/crates/voltex_net/src/snapshot.rs new file mode 100644 index 0000000..8a1ee8e --- /dev/null +++ b/crates/voltex_net/src/snapshot.rs @@ -0,0 +1,378 @@ +/// State of a single entity at a point in time. +#[derive(Debug, Clone, PartialEq)] +pub struct EntityState { + pub id: u32, + pub position: [f32; 3], + pub rotation: [f32; 3], + pub velocity: [f32; 3], +} + +/// A snapshot of the world at a given tick. +#[derive(Debug, Clone, PartialEq)] +pub struct Snapshot { + pub tick: u32, + pub entities: Vec, +} + +/// Binary size of one entity: id(4) + pos(12) + rot(12) + vel(12) = 40 bytes +const ENTITY_SIZE: usize = 4 + 12 + 12 + 12; + +fn write_f32_le(buf: &mut Vec, v: f32) { + buf.extend_from_slice(&v.to_le_bytes()); +} + +fn read_f32_le(data: &[u8], offset: usize) -> f32 { + f32::from_le_bytes([data[offset], data[offset + 1], data[offset + 2], data[offset + 3]]) +} + +fn write_f32x3(buf: &mut Vec, v: &[f32; 3]) { + write_f32_le(buf, v[0]); + write_f32_le(buf, v[1]); + write_f32_le(buf, v[2]); +} + +fn read_f32x3(data: &[u8], offset: usize) -> [f32; 3] { + [ + read_f32_le(data, offset), + read_f32_le(data, offset + 4), + read_f32_le(data, offset + 8), + ] +} + +fn serialize_entity(buf: &mut Vec, e: &EntityState) { + buf.extend_from_slice(&e.id.to_le_bytes()); + write_f32x3(buf, &e.position); + write_f32x3(buf, &e.rotation); + write_f32x3(buf, &e.velocity); +} + +fn deserialize_entity(data: &[u8], offset: usize) -> EntityState { + let id = u32::from_le_bytes([ + data[offset], data[offset + 1], data[offset + 2], data[offset + 3], + ]); + let position = read_f32x3(data, offset + 4); + let rotation = read_f32x3(data, offset + 16); + let velocity = read_f32x3(data, offset + 28); + EntityState { id, position, rotation, velocity } +} + +/// Serialize a snapshot into compact binary format. +/// Layout: tick(4 LE) + entity_count(4 LE) + entities... +pub fn serialize_snapshot(snapshot: &Snapshot) -> Vec { + let count = snapshot.entities.len() as u32; + let mut buf = Vec::with_capacity(8 + ENTITY_SIZE * snapshot.entities.len()); + buf.extend_from_slice(&snapshot.tick.to_le_bytes()); + buf.extend_from_slice(&count.to_le_bytes()); + for e in &snapshot.entities { + serialize_entity(&mut buf, e); + } + buf +} + +/// Deserialize a snapshot from binary data. +pub fn deserialize_snapshot(data: &[u8]) -> Result { + if data.len() < 8 { + return Err("Snapshot data too short for header".to_string()); + } + let tick = u32::from_le_bytes([data[0], data[1], data[2], data[3]]); + let count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize; + + let expected = 8 + count * ENTITY_SIZE; + if data.len() < expected { + return Err(format!( + "Snapshot data too short: expected {} bytes, got {}", + expected, + data.len() + )); + } + + let mut entities = Vec::with_capacity(count); + for i in 0..count { + entities.push(deserialize_entity(data, 8 + i * ENTITY_SIZE)); + } + + Ok(Snapshot { tick, entities }) +} + +/// Compute a delta between two snapshots. +/// Format: new_tick(4) + count(4) + [id(4) + flags(1) + changed_fields...] +/// Flags bitmask: 0x01 = position, 0x02 = rotation, 0x04 = velocity, 0x80 = new entity (full) +pub fn diff_snapshots(old: &Snapshot, new: &Snapshot) -> Vec { + use std::collections::HashMap; + + let old_map: HashMap = old.entities.iter().map(|e| (e.id, e)).collect(); + + let mut entries: Vec = Vec::new(); + let mut count: u32 = 0; + + for new_ent in &new.entities { + if let Some(old_ent) = old_map.get(&new_ent.id) { + let mut flags: u8 = 0; + let mut fields = Vec::new(); + + if new_ent.position != old_ent.position { + flags |= 0x01; + write_f32x3(&mut fields, &new_ent.position); + } + if new_ent.rotation != old_ent.rotation { + flags |= 0x02; + write_f32x3(&mut fields, &new_ent.rotation); + } + if new_ent.velocity != old_ent.velocity { + flags |= 0x04; + write_f32x3(&mut fields, &new_ent.velocity); + } + + if flags != 0 { + entries.extend_from_slice(&new_ent.id.to_le_bytes()); + entries.push(flags); + entries.extend_from_slice(&fields); + count += 1; + } + } else { + // New entity — send full state + entries.extend_from_slice(&new_ent.id.to_le_bytes()); + entries.push(0x80); // "new entity" flag + write_f32x3(&mut entries, &new_ent.position); + write_f32x3(&mut entries, &new_ent.rotation); + write_f32x3(&mut entries, &new_ent.velocity); + count += 1; + } + } + + let mut buf = Vec::with_capacity(8 + entries.len()); + buf.extend_from_slice(&new.tick.to_le_bytes()); + buf.extend_from_slice(&count.to_le_bytes()); + buf.extend_from_slice(&entries); + buf +} + +/// Apply a delta to a base snapshot to produce an updated snapshot. +pub fn apply_diff(base: &Snapshot, diff: &[u8]) -> Result { + if diff.len() < 8 { + return Err("Diff data too short for header".to_string()); + } + + let tick = u32::from_le_bytes([diff[0], diff[1], diff[2], diff[3]]); + let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]) as usize; + + // Start from a clone of the base + let mut entities: Vec = base.entities.clone(); + let mut offset = 8; + + for _ in 0..count { + if offset + 5 > diff.len() { + return Err("Diff truncated at entry header".to_string()); + } + let id = u32::from_le_bytes([diff[offset], diff[offset + 1], diff[offset + 2], diff[offset + 3]]); + let flags = diff[offset + 4]; + offset += 5; + + if flags & 0x80 != 0 { + // New entity — full state + if offset + 36 > diff.len() { + return Err("Diff truncated at new entity data".to_string()); + } + let position = read_f32x3(diff, offset); + let rotation = read_f32x3(diff, offset + 12); + let velocity = read_f32x3(diff, offset + 24); + offset += 36; + + // Add or replace + if let Some(ent) = entities.iter_mut().find(|e| e.id == id) { + ent.position = position; + ent.rotation = rotation; + ent.velocity = velocity; + } else { + entities.push(EntityState { id, position, rotation, velocity }); + } + } else { + // Delta update — find existing entity + let ent = entities.iter_mut().find(|e| e.id == id) + .ok_or_else(|| format!("Diff references unknown entity {}", id))?; + + if flags & 0x01 != 0 { + if offset + 12 > diff.len() { + return Err("Diff truncated at position".to_string()); + } + ent.position = read_f32x3(diff, offset); + offset += 12; + } + if flags & 0x02 != 0 { + if offset + 12 > diff.len() { + return Err("Diff truncated at rotation".to_string()); + } + ent.rotation = read_f32x3(diff, offset); + offset += 12; + } + if flags & 0x04 != 0 { + if offset + 12 > diff.len() { + return Err("Diff truncated at velocity".to_string()); + } + ent.velocity = read_f32x3(diff, offset); + offset += 12; + } + } + } + + Ok(Snapshot { tick, entities }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_entity(id: u32, px: f32, py: f32, pz: f32) -> EntityState { + EntityState { + id, + position: [px, py, pz], + rotation: [0.0, 0.0, 0.0], + velocity: [0.0, 0.0, 0.0], + } + } + + #[test] + fn test_snapshot_roundtrip() { + let snap = Snapshot { + tick: 42, + entities: vec![ + make_entity(1, 1.0, 2.0, 3.0), + make_entity(2, 4.0, 5.0, 6.0), + ], + }; + + let bytes = serialize_snapshot(&snap); + let decoded = deserialize_snapshot(&bytes).expect("deserialize failed"); + assert_eq!(snap, decoded); + } + + #[test] + fn test_snapshot_empty() { + let snap = Snapshot { tick: 0, entities: vec![] }; + let bytes = serialize_snapshot(&snap); + assert_eq!(bytes.len(), 8); // just header + let decoded = deserialize_snapshot(&bytes).unwrap(); + assert_eq!(snap, decoded); + } + + #[test] + fn test_diff_no_changes() { + let snap = Snapshot { + tick: 10, + entities: vec![make_entity(1, 1.0, 2.0, 3.0)], + }; + let snap2 = Snapshot { + tick: 11, + entities: vec![make_entity(1, 1.0, 2.0, 3.0)], + }; + + let diff = diff_snapshots(&snap, &snap2); + // Header only: tick(4) + count(4) = 8, count = 0 + assert_eq!(diff.len(), 8); + let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]); + assert_eq!(count, 0); + } + + #[test] + fn test_diff_position_changed() { + let old = Snapshot { + tick: 10, + entities: vec![make_entity(1, 0.0, 0.0, 0.0)], + }; + let new = Snapshot { + tick: 11, + entities: vec![EntityState { + id: 1, + position: [1.0, 2.0, 3.0], + rotation: [0.0, 0.0, 0.0], + velocity: [0.0, 0.0, 0.0], + }], + }; + + let diff = diff_snapshots(&old, &new); + let result = apply_diff(&old, &diff).expect("apply_diff failed"); + + assert_eq!(result.tick, 11); + assert_eq!(result.entities.len(), 1); + assert_eq!(result.entities[0].position, [1.0, 2.0, 3.0]); + assert_eq!(result.entities[0].rotation, [0.0, 0.0, 0.0]); // unchanged + } + + #[test] + fn test_diff_new_entity() { + let old = Snapshot { + tick: 10, + entities: vec![make_entity(1, 0.0, 0.0, 0.0)], + }; + let new = Snapshot { + tick: 11, + entities: vec![ + make_entity(1, 0.0, 0.0, 0.0), + make_entity(2, 5.0, 6.0, 7.0), + ], + }; + + let diff = diff_snapshots(&old, &new); + let result = apply_diff(&old, &diff).expect("apply_diff failed"); + + assert_eq!(result.entities.len(), 2); + assert_eq!(result.entities[1].id, 2); + assert_eq!(result.entities[1].position, [5.0, 6.0, 7.0]); + } + + #[test] + fn test_diff_multiple_fields_changed() { + let old = Snapshot { + tick: 10, + entities: vec![EntityState { + id: 1, + position: [0.0, 0.0, 0.0], + rotation: [0.0, 0.0, 0.0], + velocity: [0.0, 0.0, 0.0], + }], + }; + let new = Snapshot { + tick: 11, + entities: vec![EntityState { + id: 1, + position: [1.0, 1.0, 1.0], + rotation: [2.0, 2.0, 2.0], + velocity: [3.0, 3.0, 3.0], + }], + }; + + let diff = diff_snapshots(&old, &new); + let result = apply_diff(&old, &diff).unwrap(); + + assert_eq!(result.entities[0].position, [1.0, 1.0, 1.0]); + assert_eq!(result.entities[0].rotation, [2.0, 2.0, 2.0]); + assert_eq!(result.entities[0].velocity, [3.0, 3.0, 3.0]); + } + + #[test] + fn test_diff_is_compact() { + // Only position changes — diff should be smaller than full snapshot + let old = Snapshot { + tick: 10, + entities: vec![make_entity(1, 0.0, 0.0, 0.0)], + }; + let new = Snapshot { + tick: 11, + entities: vec![EntityState { + id: 1, + position: [1.0, 2.0, 3.0], + rotation: [0.0, 0.0, 0.0], + velocity: [0.0, 0.0, 0.0], + }], + }; + + let full_bytes = serialize_snapshot(&new); + let diff_bytes = diff_snapshots(&old, &new); + assert!( + diff_bytes.len() < full_bytes.len(), + "Diff ({} bytes) should be smaller than full snapshot ({} bytes)", + diff_bytes.len(), + full_bytes.len() + ); + } +}