feat(net): add reliability layer, state sync, and client interpolation

- ReliableChannel: sequence numbers, ACK, retransmission, RTT estimation
- OrderedChannel: in-order delivery with out-of-order buffering
- Snapshot serialization with delta compression (per-field bitmask)
- InterpolationBuffer: linear interpolation between server snapshots
- New packet types: Reliable, Ack, Snapshot, SnapshotDelta

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 21:03:52 +09:00
parent dccea21bfe
commit 0ef750de69
5 changed files with 1030 additions and 0 deletions

View File

@@ -0,0 +1,234 @@
use std::collections::VecDeque;
use crate::snapshot::{EntityState, Snapshot};
/// Buffers recent snapshots and interpolates between them for smooth rendering.
pub struct InterpolationBuffer {
    // (server_time, snapshot) pairs; `interpolate` scans this front-to-back
    // and assumes it is sorted by time in push order.
    snapshots: VecDeque<(f64, Snapshot)>,
    /// Render delay behind the latest server time (seconds).
    interp_delay: f64,
    /// Maximum number of snapshots to keep in the buffer.
    max_snapshots: usize,
}
impl InterpolationBuffer {
    /// Create a new interpolation buffer with the given delay in seconds.
    pub fn new(interp_delay: f64) -> Self {
        InterpolationBuffer {
            snapshots: VecDeque::new(),
            interp_delay,
            max_snapshots: 32,
        }
    }

    /// Push a new snapshot with its server timestamp.
    ///
    /// Snapshots that are not strictly newer than the newest buffered one are
    /// discarded: UDP delivery can reorder or duplicate datagrams, and
    /// `interpolate` requires the buffer to be sorted by time.
    pub fn push(&mut self, server_time: f64, snapshot: Snapshot) {
        if let Some(&(latest, _)) = self.snapshots.back() {
            if server_time <= latest {
                return; // out-of-order or duplicate snapshot — ignore
            }
        }
        self.snapshots.push_back((server_time, snapshot));
        // Evict old snapshots beyond the buffer limit.
        while self.snapshots.len() > self.max_snapshots {
            self.snapshots.pop_front();
        }
    }

    /// Interpolate to produce a snapshot for the given render_time.
    ///
    /// The render_time should be `current_server_time - interp_delay`.
    /// Returns None if there are fewer than 2 snapshots or render_time
    /// is before all buffered snapshots. If render_time is past the newest
    /// snapshot, the newest snapshot is returned as-is (no extrapolation).
    pub fn interpolate(&self, render_time: f64) -> Option<Snapshot> {
        if self.snapshots.len() < 2 {
            return None;
        }
        // Find two bracketing snapshots: the last one <= render_time and
        // the first one > render_time.
        let mut before = None;
        let mut after = None;
        for (i, (time, _)) in self.snapshots.iter().enumerate() {
            if *time <= render_time {
                before = Some(i);
            } else {
                after = Some(i);
                break;
            }
        }
        match (before, after) {
            (Some(b), Some(a)) => {
                let (t0, snap0) = &self.snapshots[b];
                let (t1, snap1) = &self.snapshots[a];
                let dt = t1 - t0;
                // Defensive: identical timestamps would divide by zero below.
                if dt <= 0.0 {
                    return Some(snap0.clone());
                }
                let alpha = ((render_time - t0) / dt).clamp(0.0, 1.0) as f32;
                Some(lerp_snapshots(snap0, snap1, alpha))
            }
            (Some(b), None) => {
                // render_time is beyond all snapshots — return the latest
                Some(self.snapshots[b].1.clone())
            }
            // render_time precedes every buffered snapshot.
            _ => None,
        }
    }

    /// Get the interpolation delay (seconds).
    pub fn delay(&self) -> f64 {
        self.interp_delay
    }
}
/// Linearly interpolate from `start` to `end` by factor `t`
/// (t = 0 yields `start`, t = 1 yields `end`).
fn lerp(start: f32, end: f32, t: f32) -> f32 {
    let delta = end - start;
    start + delta * t
}
/// Component-wise linear interpolation of two 3-vectors.
fn lerp_f32x3(a: &[f32; 3], b: &[f32; 3], t: f32) -> [f32; 3] {
    // Same arithmetic as `lerp`, inlined per component.
    let mut out = [0.0f32; 3];
    for i in 0..3 {
        out[i] = a[i] + (b[i] - a[i]) * t;
    }
    out
}
/// Interpolate every continuous field of an entity; the id is taken from `a`.
fn lerp_entity(a: &EntityState, b: &EntityState, t: f32) -> EntityState {
    let position = lerp_f32x3(&a.position, &b.position, t);
    let rotation = lerp_f32x3(&a.rotation, &b.rotation, t);
    let velocity = lerp_f32x3(&a.velocity, &b.velocity, t);
    EntityState {
        id: a.id,
        position,
        rotation,
        velocity,
    }
}
/// Linearly interpolate between two snapshots.
/// Entities are matched by id. Entities only in one snapshot are included as-is.
fn lerp_snapshots(a: &Snapshot, b: &Snapshot, t: f32) -> Snapshot {
    use std::collections::HashMap;
    let from_a: HashMap<u32, &EntityState> = a.entities.iter().map(|e| (e.id, e)).collect();
    let from_b: HashMap<u32, &EntityState> = b.entities.iter().map(|e| (e.id, e)).collect();

    // Entities present in `a` (in a's order): blend when `b` also has them,
    // otherwise carry them over unchanged.
    let mut entities: Vec<EntityState> = a
        .entities
        .iter()
        .map(|ea| match from_b.get(&ea.id) {
            Some(eb) => lerp_entity(ea, eb, t),
            None => ea.clone(),
        })
        .collect();

    // Entities that exist only in `b` are appended unchanged, in b's order.
    entities.extend(
        b.entities
            .iter()
            .filter(|eb| !from_a.contains_key(&eb.id))
            .cloned(),
    );

    // Blend the tick index too so consumers see it advance smoothly.
    let tick = (a.tick as f64 + (b.tick as f64 - a.tick as f64) * t as f64) as u32;
    Snapshot { tick, entities }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::EntityState;

    /// Fixture: a single-entity snapshot whose x position is `x`.
    fn make_snapshot(tick: u32, x: f32) -> Snapshot {
        Snapshot {
            tick,
            entities: vec![EntityState {
                id: 1,
                position: [x, 0.0, 0.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        }
    }

    // A render_time equal to a buffered timestamp reproduces that snapshot.
    #[test]
    fn test_exact_match_at_snapshot_time() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(0.1, make_snapshot(1, 10.0));
        let result = buf.interpolate(0.0).expect("should interpolate");
        assert_eq!(result.entities[0].position[0], 0.0);
    }

    #[test]
    fn test_midpoint_interpolation() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 10.0));
        let result = buf.interpolate(0.5).expect("should interpolate");
        let x = result.entities[0].position[0];
        assert!(
            (x - 5.0).abs() < 0.001,
            "Expected ~5.0 at midpoint, got {}",
            x
        );
    }

    #[test]
    fn test_interpolation_at_quarter() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 100.0));
        let result = buf.interpolate(0.25).unwrap();
        let x = result.entities[0].position[0];
        assert!(
            (x - 25.0).abs() < 0.01,
            "Expected ~25.0 at 0.25, got {}",
            x
        );
    }

    // Beyond the newest snapshot the buffer clamps (returns latest), it
    // does not extrapolate.
    #[test]
    fn test_extrapolation_returns_latest() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(10, 10.0));
        // render_time beyond all snapshots
        let result = buf.interpolate(2.0).expect("should return latest");
        assert_eq!(result.entities[0].position[0], 10.0);
    }

    // Fewer than two snapshots: no bracket exists, so no result.
    #[test]
    fn test_too_few_snapshots_returns_none() {
        let mut buf = InterpolationBuffer::new(0.1);
        assert!(buf.interpolate(0.0).is_none());
        buf.push(0.0, make_snapshot(0, 0.0));
        assert!(buf.interpolate(0.0).is_none());
    }

    #[test]
    fn test_render_time_before_all_snapshots() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(1.0, make_snapshot(10, 10.0));
        buf.push(2.0, make_snapshot(20, 20.0));
        // render_time before the first snapshot
        let result = buf.interpolate(0.0);
        assert!(result.is_none());
    }

    #[test]
    fn test_multiple_snapshots_picks_correct_bracket() {
        let mut buf = InterpolationBuffer::new(0.1);
        buf.push(0.0, make_snapshot(0, 0.0));
        buf.push(1.0, make_snapshot(1, 10.0));
        buf.push(2.0, make_snapshot(2, 20.0));
        // Should interpolate between snapshot at t=1 and t=2
        let result = buf.interpolate(1.5).unwrap();
        let x = result.entities[0].position[0];
        assert!(
            (x - 15.0).abs() < 0.01,
            "Expected ~15.0, got {}",
            x
        );
    }
}

View File

@@ -2,8 +2,14 @@ pub mod packet;
pub mod socket;
pub mod server;
pub mod client;
pub mod reliable;
pub mod snapshot;
pub mod interpolation;
pub use packet::Packet;
pub use socket::NetSocket;
pub use server::{NetServer, ServerEvent, ClientInfo};
pub use client::{NetClient, ClientEvent};
pub use reliable::{ReliableChannel, OrderedChannel};
pub use snapshot::{Snapshot, EntityState, serialize_snapshot, deserialize_snapshot, diff_snapshots, apply_diff};
pub use interpolation::InterpolationBuffer;

View File

@@ -5,6 +5,10 @@ const TYPE_DISCONNECT: u8 = 3;
const TYPE_PING: u8 = 4;
const TYPE_PONG: u8 = 5;
const TYPE_USER_DATA: u8 = 6;
// Wire discriminants for the reliability / state-sync packet types.
// They extend the existing 0..=6 ids; values are written into every packet
// header, so they must never be renumbered.
const TYPE_RELIABLE: u8 = 7;
const TYPE_ACK: u8 = 8;
const TYPE_SNAPSHOT: u8 = 9;
const TYPE_SNAPSHOT_DELTA: u8 = 10;
/// Header size: type_id(1) + payload_len(2 LE) + reserved(1) = 4 bytes
const HEADER_SIZE: usize = 4;
@@ -18,6 +22,10 @@ pub enum Packet {
Ping { timestamp: u64 },
Pong { timestamp: u64 },
UserData { client_id: u32, data: Vec<u8> },
Reliable { sequence: u16, data: Vec<u8> },
Ack { sequence: u16 },
Snapshot { tick: u32, data: Vec<u8> },
SnapshotDelta { base_tick: u32, tick: u32, data: Vec<u8> },
}
impl Packet {
@@ -112,6 +120,38 @@ impl Packet {
let data = payload[4..].to_vec();
Ok(Packet::UserData { client_id, data })
}
TYPE_RELIABLE => {
if payload.len() < 2 {
return Err("Reliable payload too short".to_string());
}
let sequence = u16::from_le_bytes([payload[0], payload[1]]);
let data = payload[2..].to_vec();
Ok(Packet::Reliable { sequence, data })
}
TYPE_ACK => {
if payload.len() < 2 {
return Err("Ack payload too short".to_string());
}
let sequence = u16::from_le_bytes([payload[0], payload[1]]);
Ok(Packet::Ack { sequence })
}
TYPE_SNAPSHOT => {
if payload.len() < 4 {
return Err("Snapshot payload too short".to_string());
}
let tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
let data = payload[4..].to_vec();
Ok(Packet::Snapshot { tick, data })
}
TYPE_SNAPSHOT_DELTA => {
if payload.len() < 8 {
return Err("SnapshotDelta payload too short".to_string());
}
let base_tick = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
let tick = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]);
let data = payload[8..].to_vec();
Ok(Packet::SnapshotDelta { base_tick, tick, data })
}
_ => Err(format!("Unknown packet type_id: {}", type_id)),
}
}
@@ -124,6 +164,10 @@ impl Packet {
Packet::Ping { .. } => TYPE_PING,
Packet::Pong { .. } => TYPE_PONG,
Packet::UserData { .. } => TYPE_USER_DATA,
Packet::Reliable { .. } => TYPE_RELIABLE,
Packet::Ack { .. } => TYPE_ACK,
Packet::Snapshot { .. } => TYPE_SNAPSHOT,
Packet::SnapshotDelta { .. } => TYPE_SNAPSHOT_DELTA,
}
}
@@ -147,6 +191,26 @@ impl Packet {
buf.extend_from_slice(data);
buf
}
Packet::Reliable { sequence, data } => {
let mut buf = Vec::with_capacity(2 + data.len());
buf.extend_from_slice(&sequence.to_le_bytes());
buf.extend_from_slice(data);
buf
}
Packet::Ack { sequence } => sequence.to_le_bytes().to_vec(),
Packet::Snapshot { tick, data } => {
let mut buf = Vec::with_capacity(4 + data.len());
buf.extend_from_slice(&tick.to_le_bytes());
buf.extend_from_slice(data);
buf
}
Packet::SnapshotDelta { base_tick, tick, data } => {
let mut buf = Vec::with_capacity(8 + data.len());
buf.extend_from_slice(&base_tick.to_le_bytes());
buf.extend_from_slice(&tick.to_le_bytes());
buf.extend_from_slice(data);
buf
}
}
}
}
@@ -200,6 +264,36 @@ mod tests {
});
}
#[test]
fn test_reliable_roundtrip() {
roundtrip(Packet::Reliable {
sequence: 42,
data: vec![0xCA, 0xFE],
});
}
#[test]
fn test_ack_roundtrip() {
roundtrip(Packet::Ack { sequence: 100 });
}
#[test]
fn test_snapshot_roundtrip() {
roundtrip(Packet::Snapshot {
tick: 999,
data: vec![1, 2, 3],
});
}
#[test]
fn test_snapshot_delta_roundtrip() {
roundtrip(Packet::SnapshotDelta {
base_tick: 10,
tick: 15,
data: vec![4, 5, 6],
});
}
#[test]
fn test_invalid_type_returns_error() {
// Build a packet with type_id = 99 (unknown)

View File

@@ -0,0 +1,318 @@
use std::collections::{HashMap, HashSet};
use std::time::{Duration, Instant};
/// A channel that provides reliable delivery over unreliable transport.
///
/// Assigns sequence numbers, tracks ACKs, estimates RTT,
/// and retransmits unacknowledged packets after 2x RTT.
///
/// NOTE(review): `received_seqs` is never pruned, so once the u16 sequence
/// space wraps (after 65536 packets) re-used sequence numbers are rejected
/// as duplicates. Long-lived connections need a sliding dedup window —
/// confirm expected session length.
pub struct ReliableChannel {
    /// Sequence number to assign to the next outgoing packet.
    next_sequence: u16,
    /// Unacked sends: sequence -> (time last (re)sent, wrapped bytes).
    pending_acks: HashMap<u16, (Instant, Vec<u8>)>,
    /// Every sequence number seen so far, for duplicate suppression.
    received_seqs: HashSet<u16>,
    /// Smoothed RTT estimate (exponential moving average).
    rtt: Duration,
    /// Outgoing ACK packets that need to be sent by the caller.
    outgoing_acks: Vec<u16>,
}

impl ReliableChannel {
    /// Create a channel with an initial 100 ms RTT estimate.
    pub fn new() -> Self {
        ReliableChannel {
            next_sequence: 0,
            pending_acks: HashMap::new(),
            received_seqs: HashSet::new(),
            rtt: Duration::from_millis(100), // initial estimate
            outgoing_acks: Vec::new(),
        }
    }

    /// Returns the current RTT estimate.
    pub fn rtt(&self) -> Duration {
        self.rtt
    }

    /// Returns the number of packets awaiting acknowledgement.
    pub fn pending_count(&self) -> usize {
        self.pending_acks.len()
    }

    /// Prepare a reliable send. Returns (sequence_number, wrapped_data).
    /// The caller is responsible for actually transmitting the wrapped data.
    pub fn send_reliable(&mut self, data: &[u8]) -> (u16, Vec<u8>) {
        let seq = self.next_sequence;
        self.next_sequence = self.next_sequence.wrapping_add(1);
        // Build the reliable packet payload: [seq(2 LE), data...]
        let mut buf = Vec::with_capacity(2 + data.len());
        buf.extend_from_slice(&seq.to_le_bytes());
        buf.extend_from_slice(data);
        // Keep the fully wrapped bytes so retransmits reuse the same sequence.
        self.pending_acks.insert(seq, (Instant::now(), buf.clone()));
        (seq, buf)
    }

    /// Process a received reliable packet. Returns the payload data if this is
    /// not a duplicate, or None if already received. Queues an ACK to send.
    pub fn receive_and_ack(&mut self, sequence: u16, data: &[u8]) -> Option<Vec<u8>> {
        // Always queue an ACK, even for duplicates: the sender may have
        // retransmitted because our earlier ACK was lost.
        self.outgoing_acks.push(sequence);
        if self.received_seqs.contains(&sequence) {
            return None; // duplicate
        }
        self.received_seqs.insert(sequence);
        Some(data.to_vec())
    }

    /// Process an incoming ACK for a sequence we sent.
    pub fn process_ack(&mut self, sequence: u16) {
        if let Some((send_time, _)) = self.pending_acks.remove(&sequence) {
            let sample = send_time.elapsed();
            // Exponential moving average: rtt = 0.875 * rtt + 0.125 * sample.
            // Note: for a retransmitted packet the sample is measured from
            // its *last* send, since `update` resets send_time.
            self.rtt = Duration::from_secs_f64(
                0.875 * self.rtt.as_secs_f64() + 0.125 * sample.as_secs_f64(),
            );
        }
    }

    /// Drain any pending outgoing ACK sequence numbers.
    pub fn drain_acks(&mut self) -> Vec<u16> {
        std::mem::take(&mut self.outgoing_acks)
    }

    /// Check for timed-out packets and return their data for retransmission.
    /// Resets the send_time for retransmitted packets.
    pub fn update(&mut self) -> Vec<Vec<u8>> {
        let timeout = self.rtt * 2;
        let now = Instant::now();
        let mut retransmits = Vec::new();
        for (_, (send_time, data)) in self.pending_acks.iter_mut() {
            if now.duration_since(*send_time) >= timeout {
                retransmits.push(data.clone());
                *send_time = now; // restart this packet's retransmit timer
            }
        }
        retransmits
    }
}

impl Default for ReliableChannel {
    /// Equivalent to [`ReliableChannel::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A channel that delivers packets in order, built on top of ReliableChannel.
pub struct OrderedChannel {
    reliable: ReliableChannel,
    next_deliver: u16,
    buffer: HashMap<u16, Vec<u8>>,
}

impl OrderedChannel {
    /// Create an ordered channel expecting sequence 0 first.
    pub fn new() -> Self {
        Self {
            reliable: ReliableChannel::new(),
            buffer: HashMap::new(),
            next_deliver: 0,
        }
    }

    /// Access the underlying reliable channel (e.g., for send_reliable, process_ack, update).
    pub fn reliable(&self) -> &ReliableChannel {
        &self.reliable
    }

    /// Access the underlying reliable channel mutably.
    pub fn reliable_mut(&mut self) -> &mut ReliableChannel {
        &mut self.reliable
    }

    /// Prepare a reliable, ordered send.
    pub fn send(&mut self, data: &[u8]) -> (u16, Vec<u8>) {
        self.reliable.send_reliable(data)
    }

    /// Receive a packet. Buffers out-of-order packets and returns all
    /// packets that can now be delivered in sequence order.
    pub fn receive(&mut self, sequence: u16, data: &[u8]) -> Vec<Vec<u8>> {
        // Duplicate suppression and ACK queuing happen in the reliable layer;
        // only first-time payloads get buffered.
        if let Some(payload) = self.reliable.receive_and_ack(sequence, data) {
            self.buffer.insert(sequence, payload);
        }
        // Flush the longest run of consecutive sequences starting at next_deliver.
        let mut ready = Vec::new();
        while let Some(bytes) = self.buffer.remove(&self.next_deliver) {
            ready.push(bytes);
            self.next_deliver = self.next_deliver.wrapping_add(1);
        }
        ready
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ReliableChannel tests ----

    // Full happy path: send -> receive -> ACK queued -> ACK processed.
    #[test]
    fn test_send_receive_ack_roundtrip() {
        let mut sender = ReliableChannel::new();
        let mut receiver = ReliableChannel::new();
        let original = b"hello world";
        let (seq, _buf) = sender.send_reliable(original);
        assert_eq!(seq, 0);
        assert_eq!(sender.pending_count(), 1);
        // Receiver gets the packet
        let result = receiver.receive_and_ack(seq, original);
        assert_eq!(result, Some(original.to_vec()));
        // Receiver queued an ack
        let acks = receiver.drain_acks();
        assert_eq!(acks, vec![0]);
        // Sender processes the ack
        sender.process_ack(seq);
        assert_eq!(sender.pending_count(), 0);
    }

    #[test]
    fn test_duplicate_rejection() {
        let mut receiver = ReliableChannel::new();
        let data = b"payload";
        let result1 = receiver.receive_and_ack(0, data);
        assert!(result1.is_some());
        let result2 = receiver.receive_and_ack(0, data);
        assert!(result2.is_none(), "Duplicate should be rejected");
        // But ACK is still queued for both
        let acks = receiver.drain_acks();
        assert_eq!(acks.len(), 2);
    }

    #[test]
    fn test_sequence_numbers_increment() {
        let mut channel = ReliableChannel::new();
        let (s0, _) = channel.send_reliable(b"a");
        let (s1, _) = channel.send_reliable(b"b");
        let (s2, _) = channel.send_reliable(b"c");
        assert_eq!(s0, 0);
        assert_eq!(s1, 1);
        assert_eq!(s2, 2);
        assert_eq!(channel.pending_count(), 3);
    }

    // Direct field access (channel.rtt) works because tests live in the
    // same module as the private fields.
    #[test]
    fn test_retransmission_on_timeout() {
        let mut channel = ReliableChannel::new();
        // Set a very short RTT so timeout (2*RTT) triggers quickly
        channel.rtt = Duration::from_millis(1);
        let (_seq, _buf) = channel.send_reliable(b"data");
        assert_eq!(channel.pending_count(), 1);
        // Wait for timeout
        std::thread::sleep(Duration::from_millis(10));
        let retransmits = channel.update();
        assert_eq!(retransmits.len(), 1, "Should retransmit 1 packet");
        // Packet is still pending (not acked)
        assert_eq!(channel.pending_count(), 1);
    }

    #[test]
    fn test_no_retransmission_before_timeout() {
        let mut channel = ReliableChannel::new();
        // Default RTT = 100ms, so timeout = 200ms
        let (_seq, _buf) = channel.send_reliable(b"data");
        // Immediately check — should not retransmit
        let retransmits = channel.update();
        assert!(retransmits.is_empty());
    }

    // Only checks that the EMA moved, not the exact value (timing-dependent).
    #[test]
    fn test_rtt_estimation() {
        let mut channel = ReliableChannel::new();
        let initial_rtt = channel.rtt();
        let (seq, _) = channel.send_reliable(b"x");
        std::thread::sleep(Duration::from_millis(5));
        channel.process_ack(seq);
        // RTT should have changed from initial value
        let new_rtt = channel.rtt();
        assert_ne!(initial_rtt, new_rtt, "RTT should be updated after ACK");
    }

    #[test]
    fn test_wrapping_sequence() {
        let mut channel = ReliableChannel::new();
        channel.next_sequence = u16::MAX;
        let (s1, _) = channel.send_reliable(b"a");
        assert_eq!(s1, u16::MAX);
        let (s2, _) = channel.send_reliable(b"b");
        assert_eq!(s2, 0); // wrapped
    }

    // ---- OrderedChannel tests ----

    #[test]
    fn test_ordered_in_order_delivery() {
        let mut channel = OrderedChannel::new();
        let delivered0 = channel.receive(0, b"first");
        assert_eq!(delivered0, vec![b"first".to_vec()]);
        let delivered1 = channel.receive(1, b"second");
        assert_eq!(delivered1, vec![b"second".to_vec()]);
    }

    // Out-of-order arrivals are held back until the gap closes, then all
    // flush at once in sequence order.
    #[test]
    fn test_ordered_out_of_order_delivery() {
        let mut channel = OrderedChannel::new();
        // Receive seq 1 first (out of order)
        let delivered = channel.receive(1, b"second");
        assert!(delivered.is_empty(), "Seq 1 should be buffered, waiting for 0");
        // Receive seq 2 (still missing 0)
        let delivered = channel.receive(2, b"third");
        assert!(delivered.is_empty());
        // Receive seq 0 — should deliver 0, 1, 2 in order
        let delivered = channel.receive(0, b"first");
        assert_eq!(delivered.len(), 3);
        assert_eq!(delivered[0], b"first");
        assert_eq!(delivered[1], b"second");
        assert_eq!(delivered[2], b"third");
    }

    #[test]
    fn test_ordered_gap_handling() {
        let mut channel = OrderedChannel::new();
        // Deliver 0
        let d = channel.receive(0, b"a");
        assert_eq!(d.len(), 1);
        // Skip 1, deliver 2
        let d = channel.receive(2, b"c");
        assert!(d.is_empty(), "Can't deliver 2 without 1");
        // Now deliver 1 — should flush both 1 and 2
        let d = channel.receive(1, b"b");
        assert_eq!(d.len(), 2);
        assert_eq!(d[0], b"b");
        assert_eq!(d[1], b"c");
    }
}

View File

@@ -0,0 +1,378 @@
/// State of a single entity at a point in time.
#[derive(Debug, Clone, PartialEq)]
pub struct EntityState {
    /// Stable identifier used to match this entity across snapshots.
    pub id: u32,
    /// Position (x, y, z).
    pub position: [f32; 3],
    /// Orientation; presumably Euler angles — TODO confirm units/convention.
    pub rotation: [f32; 3],
    /// Linear velocity (x, y, z).
    pub velocity: [f32; 3],
}
/// A snapshot of the world at a given tick.
#[derive(Debug, Clone, PartialEq)]
pub struct Snapshot {
    /// Simulation tick this snapshot was captured at.
    pub tick: u32,
    /// All replicated entities at this tick.
    pub entities: Vec<EntityState>,
}
/// Binary size of one entity: id(4) + pos(12) + rot(12) + vel(12) = 40 bytes
/// (wire-format constant — changing it breaks compatibility with peers).
const ENTITY_SIZE: usize = 4 + 12 + 12 + 12;
/// Append `v` to `buf` as 4 little-endian bytes.
fn write_f32_le(buf: &mut Vec<u8>, v: f32) {
    let raw = v.to_le_bytes();
    buf.extend_from_slice(&raw);
}
/// Read a little-endian f32 from `data` starting at `offset`.
/// Panics if fewer than 4 bytes are available (callers validate lengths first).
fn read_f32_le(data: &[u8], offset: usize) -> f32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[offset..offset + 4]);
    f32::from_le_bytes(raw)
}
/// Append a 3-component vector as 12 little-endian bytes (x, y, z order).
fn write_f32x3(buf: &mut Vec<u8>, v: &[f32; 3]) {
    for component in v {
        buf.extend_from_slice(&component.to_le_bytes());
    }
}
/// Read three consecutive little-endian f32 values starting at `offset`.
fn read_f32x3(data: &[u8], offset: usize) -> [f32; 3] {
    let mut out = [0.0f32; 3];
    for (i, slot) in out.iter_mut().enumerate() {
        let at = offset + 4 * i;
        *slot = f32::from_le_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]]);
    }
    out
}
/// Append one entity in wire order: id, position, rotation, velocity.
fn serialize_entity(buf: &mut Vec<u8>, e: &EntityState) {
    let id_bytes = e.id.to_le_bytes();
    buf.extend_from_slice(&id_bytes);
    // Field order must match `deserialize_entity`'s fixed offsets.
    for field in [&e.position, &e.rotation, &e.velocity] {
        write_f32x3(buf, field);
    }
}
/// Decode one entity from `data` at `offset` (inverse of `serialize_entity`).
/// Panics on truncated input; callers validate the length beforehand.
fn deserialize_entity(data: &[u8], offset: usize) -> EntityState {
    let mut id_raw = [0u8; 4];
    id_raw.copy_from_slice(&data[offset..offset + 4]);
    EntityState {
        id: u32::from_le_bytes(id_raw),
        position: read_f32x3(data, offset + 4),
        rotation: read_f32x3(data, offset + 16),
        velocity: read_f32x3(data, offset + 28),
    }
}
/// Serialize a snapshot into compact binary format.
/// Layout: tick(4 LE) + entity_count(4 LE) + entities...
pub fn serialize_snapshot(snapshot: &Snapshot) -> Vec<u8> {
    let entity_count = snapshot.entities.len();
    // Exact final size is known up front, so reserve it once.
    let mut out = Vec::with_capacity(8 + ENTITY_SIZE * entity_count);
    out.extend_from_slice(&snapshot.tick.to_le_bytes());
    out.extend_from_slice(&(entity_count as u32).to_le_bytes());
    snapshot
        .entities
        .iter()
        .for_each(|e| serialize_entity(&mut out, e));
    out
}
/// Deserialize a snapshot from binary data.
///
/// Layout: tick(4 LE) + entity_count(4 LE) + entities (ENTITY_SIZE bytes
/// each). Returns an error instead of panicking on truncated input, so it
/// is safe to call on untrusted network data.
pub fn deserialize_snapshot(data: &[u8]) -> Result<Snapshot, String> {
    if data.len() < 8 {
        return Err("Snapshot data too short for header".to_string());
    }
    let tick = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    let count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
    // Checked arithmetic: on 32-bit targets a hostile `count` could make
    // `8 + count * ENTITY_SIZE` wrap around and defeat the length check below.
    let expected = count
        .checked_mul(ENTITY_SIZE)
        .and_then(|n| n.checked_add(8))
        .ok_or_else(|| "Snapshot entity count overflows".to_string())?;
    if data.len() < expected {
        return Err(format!(
            "Snapshot data too short: expected {} bytes, got {}",
            expected,
            data.len()
        ));
    }
    let mut entities = Vec::with_capacity(count);
    for i in 0..count {
        entities.push(deserialize_entity(data, 8 + i * ENTITY_SIZE));
    }
    Ok(Snapshot { tick, entities })
}
/// Compute a delta between two snapshots.
/// Format: new_tick(4) + count(4) + [id(4) + flags(1) + changed_fields...]
/// Flags bitmask: 0x01 = position, 0x02 = rotation, 0x04 = velocity, 0x80 = new entity (full)
///
/// Changed fields are written in flag-bit order (position, rotation,
/// velocity); `apply_diff` reads them back in the same order.
///
/// NOTE(review): entities present in `old` but absent from `new` are not
/// encoded, so `apply_diff` keeps them alive — entity removal needs a
/// separate mechanism. Confirm this is intended before relying on deltas
/// for despawns.
pub fn diff_snapshots(old: &Snapshot, new: &Snapshot) -> Vec<u8> {
    use std::collections::HashMap;
    // Index the old snapshot by entity id for O(1) matching.
    let old_map: HashMap<u32, &EntityState> = old.entities.iter().map(|e| (e.id, e)).collect();
    let mut entries: Vec<u8> = Vec::new();
    let mut count: u32 = 0;
    for new_ent in &new.entities {
        if let Some(old_ent) = old_map.get(&new_ent.id) {
            // Existing entity: emit only the fields that changed.
            let mut flags: u8 = 0;
            let mut fields = Vec::new();
            if new_ent.position != old_ent.position {
                flags |= 0x01;
                write_f32x3(&mut fields, &new_ent.position);
            }
            if new_ent.rotation != old_ent.rotation {
                flags |= 0x02;
                write_f32x3(&mut fields, &new_ent.rotation);
            }
            if new_ent.velocity != old_ent.velocity {
                flags |= 0x04;
                write_f32x3(&mut fields, &new_ent.velocity);
            }
            // Fully unchanged entities are omitted from the delta entirely.
            if flags != 0 {
                entries.extend_from_slice(&new_ent.id.to_le_bytes());
                entries.push(flags);
                entries.extend_from_slice(&fields);
                count += 1;
            }
        } else {
            // New entity — send full state
            entries.extend_from_slice(&new_ent.id.to_le_bytes());
            entries.push(0x80); // "new entity" flag
            write_f32x3(&mut entries, &new_ent.position);
            write_f32x3(&mut entries, &new_ent.rotation);
            write_f32x3(&mut entries, &new_ent.velocity);
            count += 1;
        }
    }
    // Header (tick + entry count) followed by the packed entries.
    let mut buf = Vec::with_capacity(8 + entries.len());
    buf.extend_from_slice(&new.tick.to_le_bytes());
    buf.extend_from_slice(&count.to_le_bytes());
    buf.extend_from_slice(&entries);
    buf
}
/// Apply a delta to a base snapshot to produce an updated snapshot.
///
/// Reads fields back in the same flag-bit order `diff_snapshots` writes
/// them (position, rotation, velocity). Entities not mentioned in the diff
/// are carried over from `base` unchanged; there is no removal encoding,
/// so a delta can never make an entity disappear. Returns an error (never
/// panics) on truncated or inconsistent diff data.
pub fn apply_diff(base: &Snapshot, diff: &[u8]) -> Result<Snapshot, String> {
    if diff.len() < 8 {
        return Err("Diff data too short for header".to_string());
    }
    let tick = u32::from_le_bytes([diff[0], diff[1], diff[2], diff[3]]);
    let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]) as usize;
    // Start from a clone of the base
    let mut entities: Vec<EntityState> = base.entities.clone();
    let mut offset = 8;
    for _ in 0..count {
        // Each entry starts with id(4 LE) + flags(1).
        if offset + 5 > diff.len() {
            return Err("Diff truncated at entry header".to_string());
        }
        let id = u32::from_le_bytes([diff[offset], diff[offset + 1], diff[offset + 2], diff[offset + 3]]);
        let flags = diff[offset + 4];
        offset += 5;
        if flags & 0x80 != 0 {
            // New entity — full state
            if offset + 36 > diff.len() {
                return Err("Diff truncated at new entity data".to_string());
            }
            let position = read_f32x3(diff, offset);
            let rotation = read_f32x3(diff, offset + 12);
            let velocity = read_f32x3(diff, offset + 24);
            offset += 36;
            // Add or replace
            if let Some(ent) = entities.iter_mut().find(|e| e.id == id) {
                ent.position = position;
                ent.rotation = rotation;
                ent.velocity = velocity;
            } else {
                entities.push(EntityState { id, position, rotation, velocity });
            }
        } else {
            // Delta update — find existing entity
            let ent = entities.iter_mut().find(|e| e.id == id)
                .ok_or_else(|| format!("Diff references unknown entity {}", id))?;
            // Each present field is 12 bytes (3 x f32 LE), in flag-bit order.
            if flags & 0x01 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at position".to_string());
                }
                ent.position = read_f32x3(diff, offset);
                offset += 12;
            }
            if flags & 0x02 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at rotation".to_string());
                }
                ent.rotation = read_f32x3(diff, offset);
                offset += 12;
            }
            if flags & 0x04 != 0 {
                if offset + 12 > diff.len() {
                    return Err("Diff truncated at velocity".to_string());
                }
                ent.velocity = read_f32x3(diff, offset);
                offset += 12;
            }
        }
    }
    Ok(Snapshot { tick, entities })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: an entity at position (px, py, pz) with zero rotation/velocity.
    fn make_entity(id: u32, px: f32, py: f32, pz: f32) -> EntityState {
        EntityState {
            id,
            position: [px, py, pz],
            rotation: [0.0, 0.0, 0.0],
            velocity: [0.0, 0.0, 0.0],
        }
    }

    #[test]
    fn test_snapshot_roundtrip() {
        let snap = Snapshot {
            tick: 42,
            entities: vec![
                make_entity(1, 1.0, 2.0, 3.0),
                make_entity(2, 4.0, 5.0, 6.0),
            ],
        };
        let bytes = serialize_snapshot(&snap);
        let decoded = deserialize_snapshot(&bytes).expect("deserialize failed");
        assert_eq!(snap, decoded);
    }

    #[test]
    fn test_snapshot_empty() {
        let snap = Snapshot { tick: 0, entities: vec![] };
        let bytes = serialize_snapshot(&snap);
        assert_eq!(bytes.len(), 8); // just header
        let decoded = deserialize_snapshot(&bytes).unwrap();
        assert_eq!(snap, decoded);
    }

    // Identical entity state across ticks produces a header-only diff.
    #[test]
    fn test_diff_no_changes() {
        let snap = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 1.0, 2.0, 3.0)],
        };
        let snap2 = Snapshot {
            tick: 11,
            entities: vec![make_entity(1, 1.0, 2.0, 3.0)],
        };
        let diff = diff_snapshots(&snap, &snap2);
        // Header only: tick(4) + count(4) = 8, count = 0
        assert_eq!(diff.len(), 8);
        let count = u32::from_le_bytes([diff[4], diff[5], diff[6], diff[7]]);
        assert_eq!(count, 0);
    }

    // diff + apply must reproduce the changed field and leave others alone.
    #[test]
    fn test_diff_position_changed() {
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 2.0, 3.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).expect("apply_diff failed");
        assert_eq!(result.tick, 11);
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].position, [1.0, 2.0, 3.0]);
        assert_eq!(result.entities[0].rotation, [0.0, 0.0, 0.0]); // unchanged
    }

    // An entity absent from the base is carried in the diff as full state (0x80).
    #[test]
    fn test_diff_new_entity() {
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![
                make_entity(1, 0.0, 0.0, 0.0),
                make_entity(2, 5.0, 6.0, 7.0),
            ],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).expect("apply_diff failed");
        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[1].id, 2);
        assert_eq!(result.entities[1].position, [5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_diff_multiple_fields_changed() {
        let old = Snapshot {
            tick: 10,
            entities: vec![EntityState {
                id: 1,
                position: [0.0, 0.0, 0.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 1.0, 1.0],
                rotation: [2.0, 2.0, 2.0],
                velocity: [3.0, 3.0, 3.0],
            }],
        };
        let diff = diff_snapshots(&old, &new);
        let result = apply_diff(&old, &diff).unwrap();
        assert_eq!(result.entities[0].position, [1.0, 1.0, 1.0]);
        assert_eq!(result.entities[0].rotation, [2.0, 2.0, 2.0]);
        assert_eq!(result.entities[0].velocity, [3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_diff_is_compact() {
        // Only position changes — diff should be smaller than full snapshot
        let old = Snapshot {
            tick: 10,
            entities: vec![make_entity(1, 0.0, 0.0, 0.0)],
        };
        let new = Snapshot {
            tick: 11,
            entities: vec![EntityState {
                id: 1,
                position: [1.0, 2.0, 3.0],
                rotation: [0.0, 0.0, 0.0],
                velocity: [0.0, 0.0, 0.0],
            }],
        };
        let full_bytes = serialize_snapshot(&new);
        let diff_bytes = diff_snapshots(&old, &new);
        assert!(
            diff_bytes.len() < full_bytes.len(),
            "Diff ({} bytes) should be smaller than full snapshot ({} bytes)",
            diff_bytes.len(),
            full_bytes.len()
        );
    }
}