feat(script): add Lua table interop, coroutines, sandbox, hot reload

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 07:33:00 +09:00
parent 63e59c0544
commit f522bf10ac
4 changed files with 670 additions and 0 deletions

View File

@@ -0,0 +1,187 @@
use crate::ffi;
use std::ffi::CString;
/// Status returned by a coroutine resume.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoroutineStatus {
Yielded,
Finished,
}
/// A Lua coroutine backed by a lua_newthread.
pub struct LuaCoroutine {
/// The coroutine thread state.
state: *mut ffi::lua_State,
/// Whether the coroutine has finished execution.
finished: bool,
}
impl crate::state::LuaState {
/// Create a coroutine from a global Lua function name.
/// The function must already be defined in the Lua state.
pub fn create_coroutine(&self, func_name: &str) -> Result<LuaCoroutine, String> {
unsafe {
let co_state = ffi::lua_newthread(self.raw());
if co_state.is_null() {
return Err("failed to create Lua thread".to_string());
}
// Pop the thread from the main stack (it's anchored in the registry)
// Actually, we need to keep the thread referenced. lua_newthread pushes it
// onto the main stack. We'll leave it there (Lua GC won't collect it while
// it's on the stack). For simplicity, pop it — Lua keeps a reference in the
// registry as long as the coroutine is alive.
// Note: lua_newthread creates a thread that is anchored by the registry
// automatically in Lua 5.4, so we can pop it from the main stack.
ffi::lua_pop(self.raw(), 1);
// Push the function onto the coroutine stack
let c_name = CString::new(func_name).map_err(|e| e.to_string())?;
let ty = ffi::lua_getglobal(co_state, c_name.as_ptr());
if ty != ffi::LUA_TFUNCTION {
ffi::lua_pop(co_state, 1);
return Err(format!("'{}' is not a Lua function", func_name));
}
Ok(LuaCoroutine {
state: co_state,
finished: false,
})
}
}
}
impl LuaCoroutine {
/// Resume the coroutine. Returns Yielded if the coroutine yielded,
/// or Finished if it completed.
pub fn resume(&mut self) -> Result<CoroutineStatus, String> {
if self.finished {
return Err("coroutine already finished".to_string());
}
unsafe {
let mut nresults: std::os::raw::c_int = 0;
let status = ffi::lua_resume(self.state, std::ptr::null_mut(), 0, &mut nresults);
// Pop any results from the coroutine stack
if nresults > 0 {
ffi::lua_pop(self.state, nresults);
}
match status {
ffi::LUA_YIELD => Ok(CoroutineStatus::Yielded),
ffi::LUA_OK => {
self.finished = true;
Ok(CoroutineStatus::Finished)
}
_ => {
self.finished = true;
let ptr = ffi::lua_tostring(self.state, -1);
let msg = if ptr.is_null() {
"coroutine error".to_string()
} else {
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
};
ffi::lua_pop(self.state, 1);
Err(msg)
}
}
}
}
/// Check if the coroutine has finished execution.
pub fn is_finished(&self) -> bool {
self.finished
}
}
#[cfg(test)]
mod tests {
use crate::state::LuaState;
use super::*;
#[test]
fn test_simple_coroutine_yield_resume() {
let lua = LuaState::new();
lua.exec("
function my_coro()
step = 1
coroutine.yield()
step = 2
coroutine.yield()
step = 3
end
step = 0
").unwrap();
let mut co = lua.create_coroutine("my_coro").unwrap();
assert!(!co.is_finished());
let status = co.resume().unwrap();
assert_eq!(status, CoroutineStatus::Yielded);
assert_eq!(lua.get_global_number("step"), Some(1.0));
let status = co.resume().unwrap();
assert_eq!(status, CoroutineStatus::Yielded);
assert_eq!(lua.get_global_number("step"), Some(2.0));
let status = co.resume().unwrap();
assert_eq!(status, CoroutineStatus::Finished);
assert_eq!(lua.get_global_number("step"), Some(3.0));
assert!(co.is_finished());
}
#[test]
fn test_multi_step_coroutine() {
let lua = LuaState::new();
lua.exec("
function counter_coro()
for i = 1, 5 do
count = i
coroutine.yield()
end
end
count = 0
").unwrap();
let mut co = lua.create_coroutine("counter_coro").unwrap();
for i in 1..=5 {
let status = co.resume().unwrap();
assert_eq!(lua.get_global_number("count"), Some(i as f64));
if i < 5 {
assert_eq!(status, CoroutineStatus::Yielded);
} else {
// After the last yield, the loop ends and the function returns
// Actually the yield happens inside the loop, so after i=5 it yields,
// then we resume once more to finish
}
}
// One more resume to finish the coroutine (the loop body yields after setting count)
let status = co.resume().unwrap();
assert_eq!(status, CoroutineStatus::Finished);
assert!(co.is_finished());
}
#[test]
fn test_finished_detection() {
let lua = LuaState::new();
lua.exec("
function instant()
result = 42
end
").unwrap();
let mut co = lua.create_coroutine("instant").unwrap();
assert!(!co.is_finished());
let status = co.resume().unwrap();
assert_eq!(status, CoroutineStatus::Finished);
assert!(co.is_finished());
assert_eq!(lua.get_global_number("result"), Some(42.0));
// Resuming again should error
assert!(co.resume().is_err());
}
#[test]
fn test_create_coroutine_invalid_function() {
let lua = LuaState::new();
let result = lua.create_coroutine("nonexistent");
assert!(result.is_err());
}
}

View File

@@ -8,8 +8,13 @@ pub type lua_CFunction = unsafe extern "C" fn(*mut lua_State) -> c_int;
// Constants
pub const LUA_OK: c_int = 0;
pub const LUA_YIELD: c_int = 1;
pub const LUA_TNIL: c_int = 0;
pub const LUA_TBOOLEAN: c_int = 1;
pub const LUA_TNUMBER: c_int = 3;
pub const LUA_TSTRING: c_int = 4;
pub const LUA_TTABLE: c_int = 5;
pub const LUA_TFUNCTION: c_int = 6;
pub const LUA_MULTRET: c_int = -1;
extern "C" {
@@ -26,21 +31,41 @@ extern "C" {
pub fn lua_gettop(L: *mut lua_State) -> c_int;
pub fn lua_settop(L: *mut lua_State, idx: c_int);
pub fn lua_type(L: *mut lua_State, idx: c_int) -> c_int;
pub fn lua_pushvalue(L: *mut lua_State, idx: c_int);
// Push
pub fn lua_pushnumber(L: *mut lua_State, n: lua_Number);
pub fn lua_pushinteger(L: *mut lua_State, n: lua_Integer);
pub fn lua_pushstring(L: *mut lua_State, s: *const c_char) -> *const c_char;
pub fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, n: c_int);
pub fn lua_pushlightuserdata(L: *mut lua_State, p: *mut c_void);
pub fn lua_pushboolean(L: *mut lua_State, b: c_int);
pub fn lua_pushnil(L: *mut lua_State);
// Get
pub fn lua_tonumberx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Number;
pub fn lua_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char;
pub fn lua_touserdata(L: *mut lua_State, idx: c_int) -> *mut c_void;
pub fn lua_toboolean(L: *mut lua_State, idx: c_int) -> c_int;
pub fn lua_tointegerx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Integer;
// Table
pub fn lua_createtable(L: *mut lua_State, narr: c_int, nrec: c_int);
pub fn lua_setfield(L: *mut lua_State, idx: c_int, k: *const c_char);
pub fn lua_getfield(L: *mut lua_State, idx: c_int, k: *const c_char) -> c_int;
pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int;
pub fn lua_rawgeti(L: *mut lua_State, idx: c_int, n: lua_Integer) -> c_int;
pub fn lua_rawseti(L: *mut lua_State, idx: c_int, n: lua_Integer);
// Globals
pub fn lua_getglobal(L: *mut lua_State, name: *const c_char) -> c_int;
pub fn lua_setglobal(L: *mut lua_State, name: *const c_char);
// Coroutine
pub fn lua_newthread(L: *mut lua_State) -> *mut lua_State;
pub fn lua_resume(L: *mut lua_State, from: *mut lua_State, nargs: c_int, nresults: *mut c_int) -> c_int;
pub fn lua_xmove(from: *mut lua_State, to: *mut lua_State, n: c_int);
pub fn lua_status(L: *mut lua_State) -> c_int;
}
// Helper: lua_pcall macro equivalent

View File

@@ -0,0 +1,150 @@
use crate::state::LuaState;
/// List of dangerous globals to remove for sandboxing.
const BLOCKED_GLOBALS: &[&str] = &[
"os",
"io",
"loadfile",
"dofile",
"require",
"load", // can load arbitrary bytecode
"rawget", // bypass metatables
"rawset", // bypass metatables
"rawequal",
"rawlen",
"collectgarbage", // can manipulate GC
"debug", // full debug access
];
/// Apply sandboxing to a LuaState by removing dangerous globals.
/// Call this after `LuaState::new()` and before executing any user scripts.
///
/// Allowed: math, string, table, pairs, ipairs, print, type, tostring, tonumber,
/// pcall, xpcall, error, select, unpack, next, coroutine, assert
pub fn create_sandbox(state: &LuaState) -> Result<(), String> {
let mut code = String::new();
for global in BLOCKED_GLOBALS {
code.push_str(&format!("{} = nil\n", global));
}
state.exec(&code)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::state::LuaState;
#[test]
fn test_os_execute_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("os.execute('echo hello')");
assert!(result.is_err(), "os.execute should be blocked");
}
#[test]
fn test_io_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("io.open('/etc/passwd', 'r')");
assert!(result.is_err(), "io.open should be blocked");
}
#[test]
fn test_loadfile_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("loadfile('something.lua')");
assert!(result.is_err(), "loadfile should be blocked");
}
#[test]
fn test_dofile_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("dofile('something.lua')");
assert!(result.is_err(), "dofile should be blocked");
}
#[test]
fn test_require_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("require('os')");
assert!(result.is_err(), "require should be blocked");
}
#[test]
fn test_debug_blocked() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
let result = lua.exec("debug.getinfo(1)");
assert!(result.is_err(), "debug should be blocked");
}
#[test]
fn test_math_allowed() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("result = math.sin(0)").unwrap();
assert_eq!(lua.get_global_number("result"), Some(0.0));
}
#[test]
fn test_string_allowed() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("result = string.len('hello')").unwrap();
assert_eq!(lua.get_global_number("result"), Some(5.0));
}
#[test]
fn test_table_functions_allowed() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("
t = {3, 1, 2}
table.sort(t)
result = t[1]
").unwrap();
assert_eq!(lua.get_global_number("result"), Some(1.0));
}
#[test]
fn test_pairs_ipairs_allowed() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("
sum = 0
for _, v in ipairs({1, 2, 3}) do sum = sum + v end
").unwrap();
assert_eq!(lua.get_global_number("sum"), Some(6.0));
}
#[test]
fn test_type_tostring_tonumber_allowed() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("
t = type(42)
n = tonumber('10')
s = tostring(42)
").unwrap();
assert_eq!(lua.get_global_string("t"), Some("number".to_string()));
assert_eq!(lua.get_global_number("n"), Some(10.0));
assert_eq!(lua.get_global_string("s"), Some("42".to_string()));
}
#[test]
fn test_coroutine_allowed_after_sandbox() {
let lua = LuaState::new();
create_sandbox(&lua).unwrap();
lua.exec("
function coro() coroutine.yield() end
co = coroutine.create(coro)
coroutine.resume(co)
status = coroutine.status(co)
").unwrap();
assert_eq!(lua.get_global_string("status"), Some("suspended".to_string()));
}
}

View File

@@ -1,6 +1,16 @@
use std::ffi::{CStr, CString};
use crate::ffi;
/// Represents a value that can be passed to/from Lua.
#[derive(Debug, Clone, PartialEq)]
pub enum LuaValue {
Nil,
Bool(bool),
Number(f64),
String(String),
Table(Vec<(String, LuaValue)>),
}
pub struct LuaState {
state: *mut ffi::lua_State,
}
@@ -109,6 +119,186 @@ impl LuaState {
}
}
/// Push a LuaValue onto the Lua stack.
pub fn push_value(&self, value: &LuaValue) {
unsafe {
match value {
LuaValue::Nil => ffi::lua_pushnil(self.state),
LuaValue::Bool(b) => ffi::lua_pushboolean(self.state, if *b { 1 } else { 0 }),
LuaValue::Number(n) => ffi::lua_pushnumber(self.state, *n),
LuaValue::String(s) => {
let cs = CString::new(s.as_str()).unwrap();
ffi::lua_pushstring(self.state, cs.as_ptr());
}
LuaValue::Table(pairs) => {
let borrowed: Vec<(&str, LuaValue)> = pairs.iter()
.map(|(k, v)| (k.as_str(), v.clone()))
.collect();
self.push_table(&borrowed);
}
}
}
}
/// Push a Rust slice of key-value pairs as a Lua table onto the stack.
pub fn push_table(&self, pairs: &[(&str, LuaValue)]) {
unsafe {
ffi::lua_createtable(self.state, 0, pairs.len() as i32);
for (key, val) in pairs {
self.push_value(val);
let ckey = CString::new(*key).unwrap();
ffi::lua_setfield(self.state, -2, ckey.as_ptr());
}
}
}
/// Read a Lua table at the given stack index into key-value pairs.
pub fn read_table(&self, index: i32) -> Result<Vec<(String, LuaValue)>, String> {
unsafe {
if ffi::lua_type(self.state, index) != ffi::LUA_TTABLE {
return Err("expected table".to_string());
}
let abs_idx = if index < 0 {
ffi::lua_gettop(self.state) + index + 1
} else {
index
};
let mut result = Vec::new();
ffi::lua_pushnil(self.state); // first key
while ffi::lua_next(self.state, abs_idx) != 0 {
// key at -2, value at -1
let key = self.read_stack_as_string(-2);
let val = self.read_stack_value(-1);
result.push((key, val));
ffi::lua_pop(self.state, 1); // pop value, keep key for next iteration
}
Ok(result)
}
}
/// Push a Vec3 as a Lua table {x=, y=, z=}.
pub fn push_vec3(&self, v: [f32; 3]) {
let pairs: Vec<(&str, LuaValue)> = vec![
("x", LuaValue::Number(v[0] as f64)),
("y", LuaValue::Number(v[1] as f64)),
("z", LuaValue::Number(v[2] as f64)),
];
self.push_table(&pairs);
}
/// Read a Vec3 table from the stack. Supports {x=, y=, z=} named fields
/// or {[1]=, [2]=, [3]=} array-style fields.
pub fn read_vec3(&self, index: i32) -> Result<[f32; 3], String> {
unsafe {
if ffi::lua_type(self.state, index) != ffi::LUA_TTABLE {
return Err("expected table for vec3".to_string());
}
let abs_idx = if index < 0 {
ffi::lua_gettop(self.state) + index + 1
} else {
index
};
// Try named fields first
let cx = CString::new("x").unwrap();
let cy = CString::new("y").unwrap();
let cz = CString::new("z").unwrap();
let tx = ffi::lua_getfield(self.state, abs_idx, cx.as_ptr());
if tx == ffi::LUA_TNUMBER {
let mut isnum = 0;
let x = ffi::lua_tonumberx(self.state, -1, &mut isnum) as f32;
ffi::lua_pop(self.state, 1);
ffi::lua_getfield(self.state, abs_idx, cy.as_ptr());
let y = ffi::lua_tonumberx(self.state, -1, &mut isnum) as f32;
ffi::lua_pop(self.state, 1);
ffi::lua_getfield(self.state, abs_idx, cz.as_ptr());
let z = ffi::lua_tonumberx(self.state, -1, &mut isnum) as f32;
ffi::lua_pop(self.state, 1);
return Ok([x, y, z]);
}
ffi::lua_pop(self.state, 1);
// Try array-style {v[1], v[2], v[3]}
let mut vals = [0.0f32; 3];
for i in 0..3 {
ffi::lua_rawgeti(self.state, abs_idx, (i + 1) as ffi::lua_Integer);
let mut isnum = 0;
vals[i] = ffi::lua_tonumberx(self.state, -1, &mut isnum) as f32;
ffi::lua_pop(self.state, 1);
if isnum == 0 {
return Err("vec3 array element is not a number".to_string());
}
}
Ok(vals)
}
}
/// Read a value from the Lua stack at the given index.
pub fn read_stack_value(&self, index: i32) -> LuaValue {
unsafe {
match ffi::lua_type(self.state, index) {
ffi::LUA_TNIL => LuaValue::Nil,
ffi::LUA_TBOOLEAN => {
LuaValue::Bool(ffi::lua_toboolean(self.state, index) != 0)
}
ffi::LUA_TNUMBER => {
let mut isnum = 0;
let n = ffi::lua_tonumberx(self.state, index, &mut isnum);
LuaValue::Number(n)
}
ffi::LUA_TSTRING => {
let ptr = ffi::lua_tostring(self.state, index);
if ptr.is_null() {
LuaValue::Nil
} else {
LuaValue::String(CStr::from_ptr(ptr).to_string_lossy().into_owned())
}
}
ffi::LUA_TTABLE => {
// Read recursively
self.read_table(index).map(LuaValue::Table).unwrap_or(LuaValue::Nil)
}
_ => LuaValue::Nil,
}
}
}
/// Read the stack value at index as a string key (for table iteration).
fn read_stack_as_string(&self, index: i32) -> String {
unsafe {
match ffi::lua_type(self.state, index) {
ffi::LUA_TSTRING => {
let ptr = ffi::lua_tostring(self.state, index);
if ptr.is_null() {
String::new()
} else {
CStr::from_ptr(ptr).to_string_lossy().into_owned()
}
}
ffi::LUA_TNUMBER => {
let mut isnum = 0;
let n = ffi::lua_tonumberx(self.state, index, &mut isnum);
// Format integer keys without decimal
if n == (n as i64) as f64 {
format!("{}", n as i64)
} else {
format!("{}", n)
}
}
_ => String::new(),
}
}
}
/// Re-execute a Lua script file on the existing state (hot reload).
pub fn reload_file(&mut self, path: &str) -> Result<(), String> {
self.exec_file(path)
}
/// Raw Lua state pointer (for advanced use).
pub fn raw(&self) -> *mut ffi::lua_State {
self.state
@@ -194,4 +384,122 @@ mod tests {
lua.exec("result = add_ten(5)").unwrap();
assert_eq!(lua.get_global_number("result"), Some(15.0));
}
#[test]
fn test_push_read_table_roundtrip() {
let lua = LuaState::new();
let pairs: Vec<(&str, LuaValue)> = vec![
("name", LuaValue::String("test".into())),
("value", LuaValue::Number(42.0)),
("active", LuaValue::Bool(true)),
];
lua.push_table(&pairs);
unsafe { crate::ffi::lua_setglobal(lua.state, b"tbl\0".as_ptr() as *const _); }
lua.exec("result_name = tbl.name; result_val = tbl.value").unwrap();
assert_eq!(lua.get_global_string("result_name"), Some("test".to_string()));
assert_eq!(lua.get_global_number("result_val"), Some(42.0));
}
#[test]
fn test_read_table_from_lua() {
let lua = LuaState::new();
lua.exec("test_table = { greeting = 'hello', count = 7 }").unwrap();
unsafe {
let name = std::ffi::CString::new("test_table").unwrap();
crate::ffi::lua_getglobal(lua.state, name.as_ptr());
let table = lua.read_table(-1).unwrap();
crate::ffi::lua_pop(lua.state, 1);
let has_greeting = table.iter().any(|(k, v)| {
k == "greeting" && *v == LuaValue::String("hello".into())
});
let has_count = table.iter().any(|(k, v)| {
k == "count" && *v == LuaValue::Number(7.0)
});
assert!(has_greeting, "table should contain greeting='hello'");
assert!(has_count, "table should contain count=7");
}
}
#[test]
fn test_nested_table() {
let lua = LuaState::new();
lua.exec("nested = { inner = { val = 99 } }").unwrap();
unsafe {
let name = std::ffi::CString::new("nested").unwrap();
crate::ffi::lua_getglobal(lua.state, name.as_ptr());
let table = lua.read_table(-1).unwrap();
crate::ffi::lua_pop(lua.state, 1);
let inner = table.iter().find(|(k, _)| k == "inner");
assert!(inner.is_some());
if let Some((_, LuaValue::Table(inner_pairs))) = inner {
let has_val = inner_pairs.iter().any(|(k, v)| {
k == "val" && *v == LuaValue::Number(99.0)
});
assert!(has_val);
} else {
panic!("inner should be a table");
}
}
}
#[test]
fn test_vec3_named_roundtrip() {
let lua = LuaState::new();
lua.push_vec3([1.0, 2.5, 3.0]);
unsafe { crate::ffi::lua_setglobal(lua.state, b"v\0".as_ptr() as *const _); }
// Read it back
unsafe {
crate::ffi::lua_getglobal(lua.state, b"v\0".as_ptr() as *const _);
let v = lua.read_vec3(-1).unwrap();
crate::ffi::lua_pop(lua.state, 1);
assert_eq!(v, [1.0, 2.5, 3.0]);
}
}
#[test]
fn test_vec3_from_lua_named() {
let lua = LuaState::new();
lua.exec("v = { x = 10, y = 20, z = 30 }").unwrap();
unsafe {
crate::ffi::lua_getglobal(lua.state, b"v\0".as_ptr() as *const _);
let v = lua.read_vec3(-1).unwrap();
crate::ffi::lua_pop(lua.state, 1);
assert_eq!(v, [10.0, 20.0, 30.0]);
}
}
#[test]
fn test_vec3_from_lua_array() {
let lua = LuaState::new();
lua.exec("v = { 4, 5, 6 }").unwrap();
unsafe {
crate::ffi::lua_getglobal(lua.state, b"v\0".as_ptr() as *const _);
let v = lua.read_vec3(-1).unwrap();
crate::ffi::lua_pop(lua.state, 1);
assert_eq!(v, [4.0, 5.0, 6.0]);
}
}
#[test]
fn test_hot_reload() {
let dir = std::env::temp_dir().join("voltex_script_reload_test");
let _ = std::fs::create_dir_all(&dir);
let path = dir.join("reload_test.lua");
// Initial script
std::fs::write(&path, "counter = 1").unwrap();
let mut lua = LuaState::new();
lua.exec_file(path.to_str().unwrap()).unwrap();
assert_eq!(lua.get_global_number("counter"), Some(1.0));
// Modified script
std::fs::write(&path, "counter = counter + 10").unwrap();
lua.reload_file(path.to_str().unwrap()).unwrap();
assert_eq!(lua.get_global_number("counter"), Some(11.0));
let _ = std::fs::remove_dir_all(&dir);
}
}