diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 6e4853f..6ca1874 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -24,6 +24,7 @@ jobs: os: - ubuntu-latest - macos-latest + - windows-latest steps: - name: Check out repository diff --git a/README.md b/README.md index d24cc24..b5f444d 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,11 @@ including: * [`openai_oauth`](./examples/openai_oauth.rs): OpenAI OAuth-backed provider setup. +The builtin runtime shell uses `/bin/sh` on Unix hosts and `cmd.exe` on +Windows hosts. The OpenAI OAuth example keeps `PersistentTokenStoreKind::Auto` +platform-native as well: macOS uses Keychain, while Windows and Linux use the +file-backed store. + ## Getting Started If you want to explore the workspace after cloning the repository, the quickest diff --git a/mentra/Cargo.toml b/mentra/Cargo.toml index 1cc03ca..f25f025 100644 --- a/mentra/Cargo.toml +++ b/mentra/Cargo.toml @@ -41,3 +41,6 @@ libc = "0.2" regex = "1.12.2" rand = { version = "0.9.2", optional = true } ring = { version = "0.17.14", optional = true } + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.61.2", features = ["Win32_Foundation", "Win32_System_Threading"] } diff --git a/mentra/README.md b/mentra/README.md index e4396ac..7f398a5 100644 --- a/mentra/README.md +++ b/mentra/README.md @@ -143,11 +143,14 @@ Mentra's builtin runtime tools are available by default, but command execution i - foreground shell execution is disabled by default - background command execution is disabled by default - `RuntimePolicy::permissive()` enables both shell and background command execution +- builtin shell commands run through `/bin/sh -c` on Unix and `cmd.exe /C` on Windows - runtime policy still enforces hard limits such as working-directory roots, file read/write roots, allowed environment variables, timeouts, output caps, and background task limits - semantic review is opt-in through 
`RuntimeBuilder::with_tool_authorizer(...)` Use the default policy when you want a safer runtime surface, and opt into `RuntimePolicy::permissive()` only when you are intentionally building a coding-agent or automation workflow that should be able to act on the local workspace. +If you need different command semantics, such as PowerShell on Windows or a sandboxed executor, replace the default local executor with `RuntimeBuilder::with_executor(...)`. + ## Tool Authorization Mentra can run a caller-provided authorization pass before any tool executes. This is the recommended integration point for LLM-based security review, human approval, or custom policy engines. diff --git a/mentra/src/agent/config.rs b/mentra/src/agent/config.rs index feb8fd1..990d3fa 100644 --- a/mentra/src/agent/config.rs +++ b/mentra/src/agent/config.rs @@ -262,11 +262,17 @@ mod tests { use crate::provider::{ReasoningEffort, ReasoningOptions}; + fn test_path(label: &str) -> PathBuf { + std::env::temp_dir() + .join("mentra-agent-config-tests") + .join(label) + } + #[test] fn explicit_paths_override_defaults() { - let tasks_dir = PathBuf::from("/tmp/custom-tasks"); - let team_dir = PathBuf::from("/tmp/custom-team"); - let transcript_dir = PathBuf::from("/tmp/custom-transcripts"); + let tasks_dir = test_path("custom-tasks"); + let team_dir = test_path("custom-team"); + let transcript_dir = test_path("custom-transcripts"); let config = AgentConfig { task: TaskConfig { diff --git a/mentra/src/agent/tests/runtime_snapshot.rs b/mentra/src/agent/tests/runtime_snapshot.rs index 8c2bc68..62d23e7 100644 --- a/mentra/src/agent/tests/runtime_snapshot.rs +++ b/mentra/src/agent/tests/runtime_snapshot.rs @@ -1,13 +1,24 @@ -use tokio::sync::watch; +use std::{ + sync::atomic::{AtomicU64, Ordering}, + time::{SystemTime, UNIX_EPOCH}, +}; + +use tokio::{ + sync::watch, + time::{Duration, timeout}, +}; use crate::{ BackgroundTaskStatus, BuiltinProvider, ContentBlock, Role, agent::{AgentSnapshot, AgentStatus}, 
provider::{ContentBlockDelta, ContentBlockStart, ProviderEvent}, - runtime::{Runtime, RuntimePolicy}, + runtime::{Runtime, RuntimePolicy, SqliteRuntimeStore}, }; -use super::support::{ScriptedProvider, controlled_stream, model_info, ok_stream}; +use super::support::{ + ScriptedProvider, background_success_command, command_input_json, controlled_stream, + model_info, ok_stream, +}; #[tokio::test] async fn snapshot_progresses_during_streaming() { @@ -85,6 +96,7 @@ async fn snapshot_progresses_during_streaming() { #[tokio::test] async fn snapshot_updates_when_background_task_finishes() { + let command = background_success_command("bg-done", 50); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, @@ -105,9 +117,7 @@ async fn snapshot_updates_when_background_task_finishes() { }, ProviderEvent::ContentBlockDelta { index: 0, - delta: ContentBlockDelta::ToolUseInputJson( - r#"{"command":"sleep 0.05; printf bg-done"}"#.to_string(), - ), + delta: ContentBlockDelta::ToolUseInputJson(command_input_json(&command)), }, ProviderEvent::ContentBlockStopped { index: 0 }, ProviderEvent::MessageStopped, @@ -133,6 +143,7 @@ async fn snapshot_updates_when_background_task_finishes() { ); let runtime = Runtime::builder() + .with_store(temp_store("snapshot-background-finish")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -147,39 +158,59 @@ async fn snapshot_updates_when_background_task_finishes() { .await .unwrap(); - wait_for_background_status(&mut snapshot, BackgroundTaskStatus::Running).await; wait_for_background_status(&mut snapshot, BackgroundTaskStatus::Finished).await; assert_eq!(snapshot.borrow().background_tasks.len(), 1); - assert_eq!( + assert!( snapshot.borrow().background_tasks[0] .output_preview - .as_deref(), - Some("bg-done") + .as_deref() + .is_some_and(|preview| preview.contains("bg-done")) ); } +static NEXT_TEMP_ID: AtomicU64 = AtomicU64::new(1); + +fn 
temp_store(label: &str) -> SqliteRuntimeStore { + let unique = NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed); + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time") + .as_nanos(); + SqliteRuntimeStore::new(std::env::temp_dir().join(format!( + "mentra-runtime-store-{label}-{timestamp}-{unique}.sqlite" + ))) +} + async fn wait_for_status(receiver: &mut watch::Receiver, status: AgentStatus) { - loop { - if receiver.borrow().status == status { - return; + timeout(Duration::from_secs(90), async { + loop { + if receiver.borrow().status == status { + return; + } + receiver.changed().await.unwrap(); } - receiver.changed().await.unwrap(); - } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for agent status {status:?}")); } async fn wait_for_background_status( receiver: &mut watch::Receiver, status: BackgroundTaskStatus, ) { - loop { - if receiver - .borrow() - .background_tasks - .iter() - .any(|task| task.status == status) - { - return; + timeout(Duration::from_secs(90), async { + loop { + if receiver + .borrow() + .background_tasks + .iter() + .any(|task| task.status == status) + { + return; + } + receiver.changed().await.unwrap(); } - receiver.changed().await.unwrap(); - } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for background status {status:?}")); } diff --git a/mentra/src/agent/tests/runtime_tools.rs b/mentra/src/agent/tests/runtime_tools.rs index 0cbf762..51024d6 100644 --- a/mentra/src/agent/tests/runtime_tools.rs +++ b/mentra/src/agent/tests/runtime_tools.rs @@ -32,8 +32,9 @@ use crate::{ }; use super::support::{ - ProbeTool, ScriptedProvider, StaticTool, StreamScript, controlled_stream, erroring_stream, - model_info, ok_stream, + ProbeTool, ScriptedProvider, StaticTool, StreamScript, background_failure_command, + background_success_command, command_input_json, command_input_with_working_directory_json, + controlled_stream, erroring_stream, model_info, ok_stream, shell_pwd_command, }; 
#[tokio::test] @@ -296,17 +297,14 @@ async fn malformed_tool_json_is_reported_back_to_model_instead_of_aborting() { #[tokio::test] async fn background_run_tool_starts_task_and_continues_the_turn() { + let command = background_success_command("bg-done", 200); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.2; printf bg-done"}"#, - ), + tool_use_stream(&model.id, "tool-bg", "background_run", &input), text_stream(&model.id, "continued"), ], ); @@ -331,10 +329,7 @@ async fn background_run_tool_starts_task_and_continues_the_turn() { agent.history()[2], Message::user(ContentBlock::ToolResult { tool_use_id: "tool-bg".to_string(), - content: format!( - "Started background task bg-1 in {cwd} for `sleep 0.2; printf bg-done`" - ) - .into(), + content: format!("Started background task bg-1 in {cwd} for `{command}`").into(), is_error: false, }) ); @@ -351,23 +346,20 @@ async fn background_run_tool_starts_task_and_continues_the_turn() { assert!(events.iter().any(|event| matches!( event, AgentEvent::BackgroundTaskStarted { task } - if task.id == "bg-1" && task.command == "sleep 0.2; printf bg-done" + if task.id == "bg-1" && task.command == command ))); } #[tokio::test] async fn completed_background_results_are_injected_on_next_send() { + let command = background_success_command("bg-done", 50); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.05; printf bg-done"}"#, - ), + tool_use_stream(&model.id, "tool-bg", "background_run", &input), text_stream(&model.id, "continued"), text_stream(&model.id, "next turn"), 
], @@ -375,6 +367,7 @@ async fn completed_background_results_are_injected_on_next_send() { let provider_handle = provider.clone(); let runtime = Runtime::builder() + .with_store(temp_store("bg-results-next-send")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -400,23 +393,20 @@ async fn completed_background_results_are_injected_on_next_send() { let injected = latest_background_results_text(&requests[2]).expect("background results"); assert!(injected.contains("")); assert!(injected.contains("[bg:bg-1] status=finished")); - assert!(injected.contains("command=\"sleep 0.05; printf bg-done\"")); - assert!(injected.contains("output=\"bg-done\"")); + assert!(injected.contains(&format!("command=\"{command}\""))); + assert!(injected.contains("output=\"bg-done")); } #[tokio::test] async fn teammate_auto_wakes_after_background_task_finishes() { + let command = background_success_command("bg-done", 50); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.05; printf bg-done"}"#, - ), + tool_use_stream(&model.id, "tool-bg", "background_run", &input), text_stream(&model.id, "started"), text_stream(&model.id, "processed background result"), ], @@ -472,17 +462,14 @@ async fn teammate_auto_wakes_after_background_task_finishes() { #[tokio::test] async fn check_background_reports_single_task_and_lists_all_tasks() { + let command = background_success_command("bg-done", 50); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.05; printf bg-done"}"#, - ), + tool_use_stream(&model.id, 
"tool-bg", "background_run", &input), text_stream(&model.id, "started"), multi_tool_use_stream( &model.id, @@ -496,6 +483,7 @@ async fn check_background_reports_single_task_and_lists_all_tasks() { ); let runtime = Runtime::builder() + .with_store(temp_store("bg-check-reports")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -518,36 +506,43 @@ async fn check_background_reports_single_task_and_lists_all_tasks() { .unwrap(); let cwd = agent.config().workspace.base_dir.display().to_string(); - assert_eq!( - agent.history()[7], - Message { - role: Role::User, - content: vec![ - ContentBlock::ToolResult { - tool_use_id: "check-one".to_string(), - content: format!("[finished] cwd={cwd}\nsleep 0.05; printf bg-done\nbg-done") - .into(), - is_error: false, - }, - ContentBlock::ToolResult { - tool_use_id: "check-all".to_string(), - content: format!("bg-1: [finished] cwd={cwd} sleep 0.05; printf bg-done") - .into(), - is_error: false, - }, - ], + let message = &agent.history()[7]; + assert_eq!(message.role, Role::User); + assert_eq!(message.content.len(), 2); + match (&message.content[0], &message.content[1]) { + ( + ContentBlock::ToolResult { + tool_use_id: check_one_id, + content: check_one_content, + is_error: false, + }, + ContentBlock::ToolResult { + tool_use_id: check_all_id, + content: check_all_content, + is_error: false, + }, + ) => { + assert_eq!(check_one_id, "check-one"); + assert_eq!(check_all_id, "check-all"); + assert!( + check_one_content.contains(&format!("[finished] cwd={cwd}\n{command}\nbg-done")) + ); + assert!(check_all_content.contains(&format!("bg-1: [finished] cwd={cwd} {command}"))); } - ); + other => panic!("unexpected tool result payloads: {other:?}"), + } } #[tokio::test] async fn task_working_directory_routes_shell_for_teammate() { + let command = shell_pwd_command(); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( 
BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream(&model.id, "pwd", "shell", r#"{"command":"pwd"}"#), + tool_use_stream(&model.id, "pwd", "shell", &input), text_stream(&model.id, "done"), ], ); @@ -613,12 +608,14 @@ async fn task_working_directory_routes_shell_for_teammate() { #[tokio::test] async fn teammate_shell_without_working_directory_uses_base_dir() { + let command = shell_pwd_command(); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream(&model.id, "pwd", "shell", r#"{"command":"pwd"}"#), + tool_use_stream(&model.id, "pwd", "shell", &input), text_stream(&model.id, "handled"), ], ); @@ -667,17 +664,14 @@ async fn teammate_shell_without_working_directory_uses_base_dir() { #[tokio::test] async fn shell_working_directory_overrides_default_routing() { + let command = shell_pwd_command(); + let input = command_input_with_working_directory_json(&command, "custom"); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "pwd", - "shell", - r#"{"command":"pwd","workingDirectory":"custom"}"#, - ), + tool_use_stream(&model.id, "pwd", "shell", &input), text_stream(&model.id, "done"), ], ); @@ -1542,6 +1536,10 @@ async fn run_options_cancelled_run_stops_before_provider_request() { #[tokio::test] async fn completed_background_results_are_batched_in_completion_order() { + let first_command = background_success_command("first", 20); + let second_command = background_success_command("second", 50); + let first_input = command_input_json(&first_command); + let second_input = command_input_json(&second_command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, @@ -1550,16 +1548,8 @@ 
async fn completed_background_results_are_batched_in_completion_order() { multi_tool_use_stream( &model.id, &[ - ( - "tool-bg-1", - "background_run", - r#"{"command":"sleep 0.02; printf first"}"#, - ), - ( - "tool-bg-2", - "background_run", - r#"{"command":"sleep 0.05; printf second"}"#, - ), + ("tool-bg-1", "background_run", first_input.as_str()), + ("tool-bg-2", "background_run", second_input.as_str()), ], ), text_stream(&model.id, "continued"), @@ -1569,6 +1559,7 @@ async fn completed_background_results_are_batched_in_completion_order() { let provider_handle = provider.clone(); let runtime = Runtime::builder() + .with_store(temp_store("bg-results-batched-order")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -1596,23 +1587,20 @@ async fn completed_background_results_are_batched_in_completion_order() { let first = injected.find("[bg:bg-1]").expect("first task line"); let second = injected.find("[bg:bg-2]").expect("second task line"); assert!(first < second); - assert!(injected.contains("output=\"first\"")); - assert!(injected.contains("output=\"second\"")); + assert!(injected.contains("output=\"first")); + assert!(injected.contains("output=\"second")); } #[tokio::test] async fn failed_background_results_surface_in_snapshot_events_and_notifications() { + let command = background_failure_command("boom", 7, 50); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.05; echo boom >&2; exit 7"}"#, - ), + tool_use_stream(&model.id, "tool-bg", "background_run", &input), text_stream(&model.id, "continued"), text_stream(&model.id, "next turn"), ], @@ -1620,6 +1608,7 @@ async fn failed_background_results_surface_in_snapshot_events_and_notifications( let provider_handle = provider.clone(); let runtime = 
Runtime::builder() + .with_store(temp_store("bg-failure-results")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -1664,22 +1653,19 @@ async fn failed_background_results_surface_in_snapshot_events_and_notifications( let requests = provider_handle.recorded_requests().await; let injected = latest_background_results_text(&requests[2]).expect("background results"); assert!(injected.contains("status=failed")); - assert!(injected.contains("output=\"boom\"")); + assert!(injected.contains("output=\"boom")); } #[tokio::test] async fn drained_background_notifications_are_requeued_after_failed_run() { + let command = background_success_command("bg-done", 50); + let input = command_input_json(&command); let model = model_info("model", BuiltinProvider::Anthropic); let provider = ScriptedProvider::new( BuiltinProvider::Anthropic, vec![model.clone()], vec![ - tool_use_stream( - &model.id, - "tool-bg", - "background_run", - r#"{"command":"sleep 0.05; printf bg-done"}"#, - ), + tool_use_stream(&model.id, "tool-bg", "background_run", &input), text_stream(&model.id, "continued"), erroring_stream( vec![ProviderEvent::MessageStarted { @@ -1695,6 +1681,7 @@ async fn drained_background_notifications_are_requeued_after_failed_run() { let provider_handle = provider.clone(); let runtime = Runtime::builder() + .with_store(temp_store("bg-requeue-failed-run")) .with_policy(RuntimePolicy::permissive()) .with_provider_instance(provider) .build() @@ -4244,23 +4231,27 @@ fn collect_events(receiver: &mut tokio::sync::broadcast::Receiver) - events } +const SHORT_WAIT_ATTEMPTS: usize = 200; +const BACKGROUND_WAIT_ATTEMPTS: usize = 6000; +const POLL_INTERVAL_MS: u64 = 10; + async fn wait_for_pending_team_messages(agent: &Agent, expected_count: usize) { - for _ in 0..200 { + for _ in 0..SHORT_WAIT_ATTEMPTS { if agent.watch_snapshot().borrow().pending_team_messages == expected_count { return; } - sleep(Duration::from_millis(10)).await; + 
sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for {expected_count} pending team messages"); } async fn wait_for_background_task_count(agent: &Agent, expected_count: usize) { - for _ in 0..200 { + for _ in 0..BACKGROUND_WAIT_ATTEMPTS { if agent.watch_snapshot().borrow().background_tasks.len() == expected_count { return; } - sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for {expected_count} background tasks"); @@ -4271,14 +4262,14 @@ async fn wait_for_background_tasks( expected_count: usize, status: BackgroundTaskStatus, ) { - for _ in 0..200 { + for _ in 0..BACKGROUND_WAIT_ATTEMPTS { let background_tasks = agent.watch_snapshot().borrow().background_tasks.clone(); if background_tasks.len() == expected_count && background_tasks.iter().all(|task| task.status == status) { return; } - sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for {expected_count} background tasks to reach {status:?}"); @@ -4552,11 +4543,11 @@ fn write_skill(root: &Path, name: &str, content: &str) { } async fn wait_for_recorded_requests(provider: &ScriptedProvider, expected: usize) { - for _ in 0..500 { + for _ in 0..BACKGROUND_WAIT_ATTEMPTS { if provider.recorded_requests().await.len() >= expected { return; } - sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for {expected} recorded requests"); @@ -4568,7 +4559,7 @@ async fn wait_for_background_task_status( task_id: &str, expected_status: BackgroundTaskStatus, ) { - for _ in 0..500 { + for _ in 0..BACKGROUND_WAIT_ATTEMPTS { let tasks = ::load_background_tasks( store, agent_id, @@ -4580,7 +4571,7 @@ async fn wait_for_background_task_status( { return; } - sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for 
background task {task_id} to reach {expected_status:?}"); @@ -4591,7 +4582,7 @@ async fn wait_for_background_task_record( agent_id: &str, expected_count: usize, ) { - for _ in 0..500 { + for _ in 0..BACKGROUND_WAIT_ATTEMPTS { let tasks = ::load_background_tasks( store, agent_id, @@ -4600,7 +4591,7 @@ async fn wait_for_background_task_record( if tasks.len() == expected_count { return; } - sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; } panic!("timed out waiting for {expected_count} background task records"); diff --git a/mentra/src/agent/tests/support.rs b/mentra/src/agent/tests/support.rs index 8fc63f0..0afbef6 100644 --- a/mentra/src/agent/tests/support.rs +++ b/mentra/src/agent/tests/support.rs @@ -91,6 +91,93 @@ pub(super) fn ok_stream(events: Vec) -> StreamScript { StreamScript::Buffered(events.into_iter().map(Ok).collect()) } +pub(super) fn command_input_json(command: &str) -> String { + json!({ "command": command }).to_string() +} + +pub(super) fn command_input_with_working_directory_json( + command: &str, + working_directory: &str, +) -> String { + json!({ + "command": command, + "workingDirectory": working_directory, + }) + .to_string() +} + +pub(super) fn shell_pwd_command() -> String { + #[cfg(unix)] + { + "pwd".to_string() + } + + #[cfg(windows)] + { + "cd".to_string() + } +} + +pub(super) fn background_success_command(output: &str, delay_ms: u64) -> String { + #[cfg(unix)] + { + format!( + "sleep {}; printf {}", + delay_seconds(delay_ms), + shell_single_quoted(output) + ) + } + + #[cfg(windows)] + { + let delay_seconds = (delay_ms / 1000).saturating_add(1); + format!( + "ping -n {delay_seconds} 127.0.0.1 >NUL & echo {output}", + output = cmd_echo_literal(output) + ) + } +} + +pub(super) fn background_failure_command(stderr: &str, exit_code: i32, delay_ms: u64) -> String { + #[cfg(unix)] + { + format!( + "sleep {}; printf {} >&2; exit {exit_code}", + delay_seconds(delay_ms), + shell_single_quoted(stderr) 
+ ) + } + + #[cfg(windows)] + { + let delay_seconds = (delay_ms / 1000).saturating_add(1); + format!( + "ping -n {delay_seconds} 127.0.0.1 >NUL & echo {stderr} 1>&2 & exit /b {exit_code}", + stderr = cmd_echo_literal(stderr) + ) + } +} + +#[cfg(unix)] +fn delay_seconds(delay_ms: u64) -> String { + format!("{:.3}", delay_ms as f64 / 1000.0) +} + +#[cfg(unix)] +fn shell_single_quoted(value: &str) -> String { + format!("'{}'", value.replace('\'', r"'\''")) +} + +#[cfg(windows)] +fn cmd_echo_literal(value: &str) -> String { + value + .replace('^', "^^") + .replace('&', "^&") + .replace('|', "^|") + .replace('<', "^<") + .replace('>', "^>") +} + pub(super) fn erroring_stream(events: Vec, error: ProviderError) -> StreamScript { let mut items = events.into_iter().map(Ok).collect::>(); items.push(Err(error)); diff --git a/mentra/src/default_paths.rs b/mentra/src/default_paths.rs index 0fc8a01..eba50e5 100644 --- a/mentra/src/default_paths.rs +++ b/mentra/src/default_paths.rs @@ -78,10 +78,16 @@ fn workspace_hash(path: &Path) -> String { mod tests { use super::*; + fn test_path(label: &str) -> PathBuf { + std::env::temp_dir() + .join("mentra-default-paths-tests") + .join(label) + } + #[test] fn uses_platform_data_directory_when_available() { - let workspace = PathBuf::from("/tmp/workspaces/release-check"); - let data_dir = PathBuf::from("/Users/example/Library/Application Support"); + let workspace = test_path("release-check-workspace"); + let data_dir = test_path("release-check-data"); let paths = workspace_default_paths_for(workspace.clone(), Some(data_dir.clone())); @@ -105,7 +111,7 @@ mod tests { #[test] fn falls_back_to_workspace_dot_directory_without_platform_data_dir() { - let workspace = PathBuf::from("/tmp/workspaces/fallback-check"); + let workspace = test_path("fallback-check-workspace"); let paths = workspace_default_paths_for(workspace.clone(), None); @@ -120,8 +126,8 @@ mod tests { #[test] fn same_workspace_produces_shared_root_for_all_default_paths() { - let 
workspace = PathBuf::from("/tmp/workspaces/shared-root"); - let data_dir = PathBuf::from("/var/data"); + let workspace = test_path("shared-root-workspace"); + let data_dir = test_path("shared-root-data"); let paths = workspace_default_paths_for(workspace, Some(data_dir)); diff --git a/mentra/src/runtime/control/command.rs b/mentra/src/runtime/control/command.rs index 9a796d7..7b5d4fb 100644 --- a/mentra/src/runtime/control/command.rs +++ b/mentra/src/runtime/control/command.rs @@ -4,6 +4,9 @@ use std::{ time::Duration, }; +#[cfg(windows)] +use std::process::Command as StdCommand; + use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::{ @@ -93,10 +96,9 @@ impl RuntimeExecutor for LocalRuntimeExecutor { CommandSpec::Shell { command } => command, }; - let mut process = Command::new("bash"); + let mut process = Command::new(platform_shell_program()); process - .arg("-c") - .arg(&command) + .args(platform_shell_args(&command)) .current_dir(&cwd) .env_clear() .envs(env) @@ -224,9 +226,45 @@ fn kill_entire_process_tree(child: &mut Child) -> io::Result<()> { } } + #[cfg(windows)] + { + if let Some(pid) = child.id() { + let status = StdCommand::new("taskkill") + .args(["/PID", &pid.to_string(), "/T", "/F"]) + .status()?; + if status.success() { + return Ok(()); + } + + if child.try_wait()?.is_some() { + return Ok(()); + } + } + } + child.start_kill() } +#[cfg(unix)] +fn platform_shell_program() -> &'static str { + "/bin/sh" +} + +#[cfg(windows)] +fn platform_shell_program() -> &'static str { + "cmd.exe" +} + +#[cfg(unix)] +fn platform_shell_args(command: &str) -> [&str; 2] { + ["-c", command] +} + +#[cfg(windows)] +fn platform_shell_args(command: &str) -> [&str; 2] { + ["/C", command] +} + pub async fn read_limited_file(path: &Path, max_lines: Option) -> Result { let file = tokio::fs::File::open(path) .await @@ -255,20 +293,83 @@ pub async fn read_limited_file(path: &Path, max_lines: Option) -> Result< mod tests { use super::*; + #[cfg(unix)] + fn 
stdout_and_stderr_command() -> String { + "printf 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'; printf 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb' >&2" + .to_string() + } + + #[cfg(windows)] + fn stdout_and_stderr_command() -> String { + powershell_encoded_command( + "$stdout='a' * 32; $stderr='b' * 32; [Console]::Out.Write($stdout); [Console]::Error.Write($stderr)", + ) + } + + #[cfg(unix)] + fn missing_secret_command() -> String { + "printf '%s' \"${SECRET:-missing}\"".to_string() + } + + #[cfg(windows)] + fn missing_secret_command() -> String { + powershell_encoded_command( + "if ($null -ne $env:SECRET -and $env:SECRET.Length -gt 0) { [Console]::Out.Write($env:SECRET) } else { [Console]::Out.Write('missing') }", + ) + } + + #[cfg(unix)] + fn timeout_command() -> String { + "sleep 1".to_string() + } + + #[cfg(windows)] + fn timeout_command() -> String { + powershell_encoded_command("Start-Sleep -Seconds 1") + } + + #[cfg(unix)] + fn minimal_shell_env() -> Vec<(String, String)> { + vec![( + "PATH".to_string(), + std::env::var("PATH").expect("path available"), + )] + } + + #[cfg(windows)] + fn minimal_shell_env() -> Vec<(String, String)> { + ["PATH", "PATHEXT", "SystemRoot", "COMSPEC", "TEMP", "TMP"] + .into_iter() + .filter_map(|name| { + std::env::var(name) + .ok() + .map(|value| (name.to_string(), value)) + }) + .collect() + } + + #[cfg(windows)] + fn powershell_encoded_command(script: &str) -> String { + use base64::Engine as _; + + let utf16 = script + .encode_utf16() + .flat_map(|unit| unit.to_le_bytes()) + .collect::>(); + let encoded = base64::engine::general_purpose::STANDARD.encode(utf16); + format!("powershell.exe -NoProfile -EncodedCommand {encoded}") + } + #[tokio::test] async fn caps_stdout_and_stderr_independently() { let output = LocalRuntimeExecutor .run(CommandRequest { spec: CommandSpec::Shell { - command: "printf 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'; printf 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb' >&2" - .to_string(), + command: stdout_and_stderr_command(), }, cwd: 
std::env::temp_dir(), timeout: Duration::from_secs(5), - env: vec![( - "PATH".to_string(), - std::env::var("PATH").expect("path available"), - )], + env: minimal_shell_env(), max_output_bytes_per_stream: 8, }) .await @@ -285,14 +386,11 @@ mod tests { let output = LocalRuntimeExecutor .run(CommandRequest { spec: CommandSpec::Shell { - command: "printf '%s' \"${SECRET:-missing}\"".to_string(), + command: missing_secret_command(), }, cwd: std::env::temp_dir(), timeout: Duration::from_secs(5), - env: vec![( - "PATH".to_string(), - std::env::var("PATH").expect("path available"), - )], + env: minimal_shell_env(), max_output_bytes_per_stream: 1024, }) .await @@ -306,14 +404,11 @@ mod tests { let output = LocalRuntimeExecutor .run(CommandRequest { spec: CommandSpec::Shell { - command: "sleep 1".to_string(), + command: timeout_command(), }, cwd: std::env::temp_dir(), timeout: Duration::from_millis(50), - env: vec![( - "PATH".to_string(), - std::env::var("PATH").expect("path available"), - )], + env: minimal_shell_env(), max_output_bytes_per_stream: 1024, }) .await diff --git a/mentra/src/runtime/control/policy.rs b/mentra/src/runtime/control/policy.rs index 1f46de1..de47982 100644 --- a/mentra/src/runtime/control/policy.rs +++ b/mentra/src/runtime/control/policy.rs @@ -27,7 +27,7 @@ impl Default for RuntimePolicy { allowed_working_roots: Vec::new(), allowed_read_roots: Vec::new(), allowed_write_roots: Vec::new(), - allowed_env_vars: vec!["PATH".to_string()], + allowed_env_vars: default_allowed_env_vars(), background_task_limit: Some(8), default_command_timeout: Duration::from_secs(30), max_command_timeout: Duration::from_secs(30), @@ -36,6 +36,26 @@ impl Default for RuntimePolicy { } } +fn default_allowed_env_vars() -> Vec { + #[cfg(windows)] + { + let mut vars = vec!["PATH".to_string()]; + vars.extend([ + "PATHEXT".to_string(), + "SystemRoot".to_string(), + "COMSPEC".to_string(), + "TEMP".to_string(), + "TMP".to_string(), + ]); + vars + } + + #[cfg(not(windows))] + { + 
vec!["PATH".to_string()] + } +} + impl RuntimePolicy { /// Returns a permissive policy that enables shell and background execution. pub fn permissive() -> Self { @@ -207,12 +227,13 @@ impl RuntimePolicy { } fn path_is_allowed(path: &Path, default_root: &Path, extra_roots: &[PathBuf]) -> bool { + let candidate_path = canonicalize_policy_root(path); let default_root = canonicalize_policy_root(default_root); - path.starts_with(&default_root) + candidate_path.starts_with(&default_root) || extra_roots .iter() .map(|root| canonicalize_policy_root(root)) - .any(|root| path.starts_with(root)) + .any(|root| candidate_path.starts_with(root)) } fn canonicalize_policy_root(path: &Path) -> PathBuf { @@ -316,9 +337,15 @@ mod tests { time::{SystemTime, UNIX_EPOCH}, }; + fn test_path(label: &str) -> PathBuf { + std::env::temp_dir() + .join("mentra-runtime-policy-tests") + .join(label) + } + #[test] fn shell_roots_and_background_switches_short_circuit() { - let cwd = PathBuf::from("/tmp/repo"); + let cwd = test_path("repo"); let policy = RuntimePolicy::default() .allow_shell_commands(true) .allow_background_commands(false); @@ -330,8 +357,8 @@ mod tests { #[test] fn authorize_command_execution_rejects_working_directory_outside_roots() { - let base_dir = PathBuf::from("/tmp/repo"); - let cwd = PathBuf::from("/tmp/other"); + let base_dir = test_path("repo"); + let cwd = test_path("other"); let policy = RuntimePolicy::default().allow_shell_commands(true); let error = policy diff --git a/mentra/src/runtime/store.rs b/mentra/src/runtime/store.rs index 64ca30c..7eefda2 100644 --- a/mentra/src/runtime/store.rs +++ b/mentra/src/runtime/store.rs @@ -385,15 +385,14 @@ impl BackgroundStore for SqliteRuntimeStore { let conn = self.open()?; conn.execute( r#" - INSERT INTO background_jobs (id, agent_id, payload_json, notification_state, created_at, updated_at) + INSERT INTO background_jobs (agent_id, id, payload_json, notification_state, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5, ?5) - 
ON CONFLICT(id) DO UPDATE SET - agent_id = excluded.agent_id, + ON CONFLICT(agent_id, id) DO UPDATE SET payload_json = excluded.payload_json, notification_state = excluded.notification_state, updated_at = excluded.updated_at "#, - params![task.id, agent_id, to_json(task)?, notification_state, now_secs()], + params![agent_id, task.id, to_json(task)?, notification_state, now_secs()], ) .map_err(sqlite_error)?; Ok(()) @@ -422,8 +421,8 @@ impl BackgroundStore for SqliteRuntimeStore { }; for (id, _) in &jobs { tx.execute( - "UPDATE background_jobs SET notification_state = ?2 WHERE id = ?1", - params![id, DELIVERY_INFLIGHT], + "UPDATE background_jobs SET notification_state = ?3 WHERE agent_id = ?1 AND id = ?2", + params![agent_id, id, DELIVERY_INFLIGHT], ) .map_err(sqlite_error)?; } @@ -642,12 +641,13 @@ impl SqliteRuntimeStore { created_at INTEGER NOT NULL ); CREATE TABLE IF NOT EXISTS background_jobs ( - id TEXT PRIMARY KEY, agent_id TEXT NOT NULL, + id TEXT NOT NULL, payload_json TEXT NOT NULL, notification_state INTEGER NOT NULL, created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL + updated_at INTEGER NOT NULL, + PRIMARY KEY (agent_id, id) ); CREATE TABLE IF NOT EXISTS audit_events ( id TEXT PRIMARY KEY, @@ -682,7 +682,50 @@ impl SqliteRuntimeStore { ); "#, ) - .map_err(sqlite_error) + .map_err(sqlite_error)?; + self.migrate_background_jobs_schema(conn) + } + + fn migrate_background_jobs_schema(&self, conn: &Connection) -> Result<(), RuntimeError> { + let Some(schema_sql) = conn + .query_row( + "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'background_jobs'", + [], + |row| row.get::<_, String>(0), + ) + .optional() + .map_err(sqlite_error)? 
+ else { + return Ok(()); + }; + + if schema_sql.contains("PRIMARY KEY (agent_id, id)") + || schema_sql.contains("PRIMARY KEY(agent_id, id)") + { + return Ok(()); + } + + conn.execute_batch( + r#" + ALTER TABLE background_jobs RENAME TO background_jobs_legacy; + CREATE TABLE background_jobs ( + agent_id TEXT NOT NULL, + id TEXT NOT NULL, + payload_json TEXT NOT NULL, + notification_state INTEGER NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (agent_id, id) + ); + INSERT INTO background_jobs (agent_id, id, payload_json, notification_state, created_at, updated_at) + SELECT agent_id, id, payload_json, notification_state, created_at, updated_at + FROM background_jobs_legacy; + DROP TABLE background_jobs_legacy; + "#, + ) + .map_err(sqlite_error)?; + + Ok(()) } fn write_agent( @@ -797,21 +840,25 @@ impl AgentStore for SqliteRuntimeStore { { let mut stmt = tx - .prepare("SELECT id, payload_json FROM background_jobs") + .prepare("SELECT agent_id, id, payload_json FROM background_jobs") .map_err(sqlite_error)?; let rows = stmt .query_map([], |row| { - Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + )) }) .map_err(sqlite_error)?; for row in rows { - let (id, payload) = row.map_err(sqlite_error)?; + let (agent_id, id, payload) = row.map_err(sqlite_error)?; let mut task: BackgroundTaskSummary = from_json(&payload)?; if task.status == BackgroundTaskStatus::Running { task.status = BackgroundTaskStatus::Interrupted; tx.execute( - "UPDATE background_jobs SET payload_json = ?2, notification_state = ?3, updated_at = ?4 WHERE id = ?1", - params![id, to_json(&task)?, DELIVERY_PENDING, now_secs()], + "UPDATE background_jobs SET payload_json = ?3, notification_state = ?4, updated_at = ?5 WHERE agent_id = ?1 AND id = ?2", + params![agent_id, id, to_json(&task)?, DELIVERY_PENDING, now_secs()], ) .map_err(sqlite_error)?; } @@ -1420,13 +1467,14 @@ fn 
prune_stale_runtime_leases(tx: &rusqlite::Transaction<'_>) -> Result<(), Runt
 
 fn runtime_owner_is_stale(owner: &str) -> bool {
     let Some(pid) = owner
         .strip_prefix("runtime-")
-        .and_then(|value| value.parse::<i32>().ok())
+        .and_then(|value| value.parse::<u32>().ok())
     else {
         return false;
     };
 
     #[cfg(unix)]
     {
+        let pid = pid as i32;
         let result = unsafe { libc::kill(pid, 0) };
         if result == 0 {
             return false;
         }
@@ -1439,9 +1487,43 @@ fn runtime_owner_is_stale(owner: &str) -> bool {
         }
     }
 
-    #[cfg(not(unix))]
+    #[cfg(windows)]
+    {
+        use windows_sys::Win32::{
+            Foundation::{CloseHandle, STILL_ACTIVE},
+            System::Threading::{
+                GetExitCodeProcess, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION,
+            },
+        };
+
+        const ERROR_ACCESS_DENIED: i32 = 5;
+        const ERROR_INVALID_PARAMETER: i32 = 87;
+
+        unsafe {
+            let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid);
+            if handle.is_null() {
+                return match std::io::Error::last_os_error().raw_os_error() {
+                    Some(ERROR_INVALID_PARAMETER) => true,
+                    Some(ERROR_ACCESS_DENIED) => false,
+                    _ => false,
+                };
+            }
+
+            let mut exit_code = 0u32;
+            let result = GetExitCodeProcess(handle, &mut exit_code);
+            let close_result = CloseHandle(handle);
+            debug_assert_ne!(close_result, 0, "process handle should close");
+
+            if result == 0 {
+                return false;
+            }
+
+            exit_code != STILL_ACTIVE as u32
+        }
+    }
+
+    #[cfg(not(any(unix, windows)))]
     {
-        let _ = pid;
         false
     }
 }
@@ -1536,6 +1618,54 @@ mod tests {
         assert!(acquired);
     }
 
+    #[test]
+    fn background_tasks_are_scoped_per_agent() {
+        let store = SqliteRuntimeStore::new(
+            std::env::temp_dir().join(format!("mentra-store-background-{}.sqlite", now_nanos())),
+        );
+
+        store
+            .upsert_background_task(
+                "agent-a",
+                &BackgroundTaskSummary {
+                    id: "bg-1".to_string(),
+                    command: "echo a".to_string(),
+                    cwd: std::env::temp_dir().join("a"),
+                    status: BackgroundTaskStatus::Running,
+                    output_preview: None,
+                },
+                DELIVERY_ACKED,
+            )
+            .expect("seed agent a background task");
+        store
+            .upsert_background_task(
+                "agent-b",
&BackgroundTaskSummary { + id: "bg-1".to_string(), + command: "echo b".to_string(), + cwd: std::env::temp_dir().join("b"), + status: BackgroundTaskStatus::Finished, + output_preview: Some("done".to_string()), + }, + DELIVERY_PENDING, + ) + .expect("seed agent b background task"); + + let agent_a_tasks = store + .load_background_tasks("agent-a") + .expect("load agent a background tasks"); + let agent_b_tasks = store + .load_background_tasks("agent-b") + .expect("load agent b background tasks"); + + assert_eq!(agent_a_tasks.len(), 1); + assert_eq!(agent_a_tasks[0].command, "echo a"); + assert_eq!(agent_a_tasks[0].status, BackgroundTaskStatus::Running); + assert_eq!(agent_b_tasks.len(), 1); + assert_eq!(agent_b_tasks[0].command, "echo b"); + assert_eq!(agent_b_tasks[0].status, BackgroundTaskStatus::Finished); + } + #[test] fn fts_query_returns_none_when_input_has_no_searchable_terms() { assert_eq!(fts_query("... --- \"\""), None); diff --git a/mentra/src/tool/files/workspace.rs b/mentra/src/tool/files/workspace.rs index b62affc..76f44be 100644 --- a/mentra/src/tool/files/workspace.rs +++ b/mentra/src/tool/files/workspace.rs @@ -753,10 +753,10 @@ impl WorkspaceEditor { if rendered.is_empty() { ".".to_string() } else { - rendered + normalize_display_path(rendered) } } else { - path.display().to_string() + normalize_display_path(path.display().to_string()) } } @@ -767,12 +767,16 @@ impl WorkspaceEditor { .unwrap_or_else(|| ".".to_string()) } else { path.strip_prefix(root) - .map(|relative| relative.display().to_string()) + .map(|relative| normalize_display_path(relative.display().to_string())) .unwrap_or_else(|_| self.display_path(path)) } } } +fn normalize_display_path(path: String) -> String { + path.replace('\\', "/") +} + fn normalize_path(path: PathBuf) -> Result { let mut normalized = if path.is_absolute() { PathBuf::new()