diff --git a/src-tauri/src/bin/wraith_mcp_bridge.rs b/src-tauri/src/bin/wraith_mcp_bridge.rs index 68e9423..261490b 100644 --- a/src-tauri/src/bin/wraith_mcp_bridge.rs +++ b/src-tauri/src/bin/wraith_mcp_bridge.rs @@ -38,19 +38,22 @@ struct JsonRpcError { message: String, } -fn get_mcp_port() -> Result { - // Check standard locations for the port file - let port_file = if let Ok(appdata) = std::env::var("APPDATA") { - std::path::PathBuf::from(appdata).join("Wraith").join("mcp-port") +fn get_data_dir() -> Result { + if let Ok(appdata) = std::env::var("APPDATA") { + Ok(std::path::PathBuf::from(appdata).join("Wraith")) } else if let Ok(home) = std::env::var("HOME") { if cfg!(target_os = "macos") { - std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith").join("mcp-port") + Ok(std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith")) } else { - std::path::PathBuf::from(home).join(".local").join("share").join("wraith").join("mcp-port") + Ok(std::path::PathBuf::from(home).join(".local").join("share").join("wraith")) } } else { - return Err("Cannot determine data directory".to_string()); - }; + Err("Cannot determine data directory".to_string()) + } +} + +fn get_mcp_port() -> Result { + let port_file = get_data_dir()?.join("mcp-port"); let port_str = std::fs::read_to_string(&port_file) .map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?; @@ -59,6 +62,15 @@ fn get_mcp_port() -> Result { .map_err(|e| format!("Invalid port in MCP port file: {}", e)) } +fn get_mcp_token() -> Result { + let token_file = get_data_dir()?.join("mcp-token"); + + let token = std::fs::read_to_string(&token_file) + .map_err(|e| format!("Cannot read MCP token file at {}: {} — is Wraith running?", token_file.display(), e))?; + + Ok(token.trim().to_string()) +} + fn handle_initialize(id: Value) -> JsonRpcResponse { JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -304,12 +316,13 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse { } } -fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result { +fn call_wraith(port: u16, token: &str, endpoint: &str, body: Value) -> Result { let url = format!("http://127.0.0.1:{}{}", port, endpoint); let body_str = serde_json::to_string(&body).unwrap_or_default(); let mut resp = ureq::post(url) .header("Content-Type", "application/json") + .header("Authorization", &format!("Bearer {}", token)) .send(body_str.as_bytes()) .map_err(|e| format!("HTTP request to Wraith failed: {}", e))?; @@ -327,40 +340,40 @@ fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result } } -fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> JsonRpcResponse { +fn handle_tool_call(id: Value, port: u16, token: &str, tool_name: &str, args: &Value) -> JsonRpcResponse { let result = match tool_name { - "list_sessions" => call_wraith(port, "/mcp/sessions", serde_json::json!({})), - "terminal_type" => call_wraith(port, "/mcp/terminal/type", args.clone()), - "terminal_read" => call_wraith(port, "/mcp/terminal/read", args.clone()), - "terminal_execute" => call_wraith(port, "/mcp/terminal/execute", args.clone()), - "sftp_list" => call_wraith(port, "/mcp/sftp/list", args.clone()), - "sftp_read" => call_wraith(port, "/mcp/sftp/read", args.clone()), - "sftp_write" => call_wraith(port, "/mcp/sftp/write", args.clone()), - "network_scan" => call_wraith(port, "/mcp/tool/scan-network", args.clone()), - "port_scan" => call_wraith(port, "/mcp/tool/scan-ports", args.clone()), - "ping" => call_wraith(port, "/mcp/tool/ping", args.clone()), - "traceroute" => call_wraith(port, "/mcp/tool/traceroute", args.clone()), - "dns_lookup" => call_wraith(port, "/mcp/tool/dns", args.clone()), - "whois" => call_wraith(port, "/mcp/tool/whois", args.clone()), - "wake_on_lan" => call_wraith(port, "/mcp/tool/wol", args.clone()), - "bandwidth_test" => call_wraith(port, "/mcp/tool/bandwidth", args.clone()), - "subnet_calc" => call_wraith(port, "/mcp/tool/subnet", args.clone()), - "generate_ssh_key" => call_wraith(port, "/mcp/tool/keygen", args.clone()), - "generate_password" => call_wraith(port, "/mcp/tool/passgen", args.clone()), - "docker_ps" => call_wraith(port, "/mcp/docker/ps", args.clone()), - "docker_action" => call_wraith(port, "/mcp/docker/action", args.clone()), - "docker_exec" => call_wraith(port, "/mcp/docker/exec", args.clone()), - "service_status" => call_wraith(port, "/mcp/service/status", args.clone()), - "process_list" => call_wraith(port, "/mcp/process/list", args.clone()), - "git_status" => call_wraith(port, "/mcp/git/status", args.clone()), - "git_pull" => call_wraith(port, "/mcp/git/pull", args.clone()), - "git_log" => call_wraith(port, "/mcp/git/log", args.clone()), - "rdp_click" => call_wraith(port, "/mcp/rdp/click", args.clone()), - "rdp_type" => call_wraith(port, "/mcp/rdp/type", args.clone()), - "rdp_clipboard" => call_wraith(port, "/mcp/rdp/clipboard", args.clone()), - "ssh_connect" => call_wraith(port, "/mcp/ssh/connect", args.clone()), + "list_sessions" => call_wraith(port, token, "/mcp/sessions", serde_json::json!({})), + "terminal_type" => call_wraith(port, token, "/mcp/terminal/type", args.clone()), + "terminal_read" => call_wraith(port, token, "/mcp/terminal/read", args.clone()), + "terminal_execute" => call_wraith(port, token, "/mcp/terminal/execute", args.clone()), + "sftp_list" => call_wraith(port, token, "/mcp/sftp/list", args.clone()), + "sftp_read" => call_wraith(port, token, "/mcp/sftp/read", args.clone()), + "sftp_write" => call_wraith(port, token, "/mcp/sftp/write", args.clone()), + "network_scan" => call_wraith(port, token, "/mcp/tool/scan-network", args.clone()), + "port_scan" => call_wraith(port, token, "/mcp/tool/scan-ports", args.clone()), + "ping" => call_wraith(port, token, "/mcp/tool/ping", args.clone()), + "traceroute" => call_wraith(port, token, "/mcp/tool/traceroute", args.clone()), + "dns_lookup" => call_wraith(port, token, "/mcp/tool/dns", args.clone()), + "whois" => call_wraith(port, token, "/mcp/tool/whois", args.clone()), + "wake_on_lan" => call_wraith(port, token, "/mcp/tool/wol", args.clone()), + "bandwidth_test" => call_wraith(port, token, "/mcp/tool/bandwidth", args.clone()), + "subnet_calc" => call_wraith(port, token, "/mcp/tool/subnet", args.clone()), + "generate_ssh_key" => call_wraith(port, token, "/mcp/tool/keygen", args.clone()), + "generate_password" => call_wraith(port, token, "/mcp/tool/passgen", args.clone()), + "docker_ps" => call_wraith(port, token, "/mcp/docker/ps", args.clone()), + "docker_action" => call_wraith(port, token, "/mcp/docker/action", args.clone()), + "docker_exec" => call_wraith(port, token, "/mcp/docker/exec", args.clone()), + "service_status" => call_wraith(port, token, "/mcp/service/status", args.clone()), + "process_list" => call_wraith(port, token, "/mcp/process/list", args.clone()), + "git_status" => call_wraith(port, token, "/mcp/git/status", args.clone()), + "git_pull" => call_wraith(port, token, "/mcp/git/pull", args.clone()), + "git_log" => call_wraith(port, token, "/mcp/git/log", args.clone()), + "rdp_click" => call_wraith(port, token, "/mcp/rdp/click", args.clone()), + "rdp_type" => call_wraith(port, token, "/mcp/rdp/type", args.clone()), + "rdp_clipboard" => call_wraith(port, token, "/mcp/rdp/clipboard", args.clone()), + "ssh_connect" => call_wraith(port, token, "/mcp/ssh/connect", args.clone()), "terminal_screenshot" => { - let result = call_wraith(port, "/mcp/screenshot", args.clone()); + let result = call_wraith(port, token, "/mcp/screenshot", args.clone()); // Screenshot returns base64 PNG — wrap as image content for multimodal AI return match result { Ok(b64) => JsonRpcResponse { @@ -420,6 +433,14 @@ fn main() { } }; + let token = match get_mcp_token() { + Ok(t) => t, + Err(e) => { + eprintln!("wraith-mcp-bridge: {}", e); + std::process::exit(1); + } + }; + let stdin = io::stdin(); let mut stdout = io::stdout(); @@ -458,7 +479,7 @@ fn main() { let args = request.params.get("arguments") .cloned() .unwrap_or(Value::Object(serde_json::Map::new())); - handle_tool_call(request.id, port, tool_name, &args) + handle_tool_call(request.id, port, &token, tool_name, &args) } "notifications/initialized" | "notifications/cancelled" => { // Notifications don't get responses diff --git a/src-tauri/src/commands/docker_commands.rs b/src-tauri/src/commands/docker_commands.rs index 36c6ce9..b860765 100644 --- a/src-tauri/src/commands/docker_commands.rs +++ b/src-tauri/src/commands/docker_commands.rs @@ -3,6 +3,7 @@ use tauri::State; use serde::Serialize; use crate::AppState; +use crate::utils::shell_escape; #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] @@ -84,14 +85,15 @@ pub async fn docker_list_volumes(session_id: String, state: State<'_, AppState>) #[tauri::command] pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result { let session = state.ssh.get_session(&session_id).ok_or("Session not found")?; + let t = shell_escape(&target); let cmd = match action.as_str() { - "start" => format!("docker start {} 2>&1", target), - "stop" => format!("docker stop {} 2>&1", target), - "restart" => format!("docker restart {} 2>&1", target), - "remove" => format!("docker rm -f {} 2>&1", target), - "logs" => format!("docker logs --tail 100 {} 2>&1", target), - "remove-image" => format!("docker rmi {} 2>&1", target), - "remove-volume" => format!("docker volume rm {} 2>&1", target), + "start" => format!("docker start {} 2>&1", t), + "stop" => format!("docker stop {} 2>&1", t), + "restart" => format!("docker restart {} 2>&1", t), + "remove" => format!("docker rm -f {} 2>&1", t), + "logs" => format!("docker logs --tail 100 {} 2>&1", t), + "remove-image" => format!("docker rmi {} 2>&1", t), + "remove-volume" => format!("docker volume rm {} 2>&1", t), "builder-prune" => "docker builder prune -f 2>&1".to_string(), "system-prune" => "docker system prune -f 2>&1".to_string(), "system-prune-all" => "docker system prune -a -f 2>&1".to_string(), diff --git a/src-tauri/src/commands/tools_commands.rs b/src-tauri/src/commands/tools_commands.rs index 531beca..7a0c1cb 100644 --- a/src-tauri/src/commands/tools_commands.rs +++ b/src-tauri/src/commands/tools_commands.rs @@ -4,6 +4,7 @@ use tauri::State; use serde::Serialize; use crate::AppState; +use crate::utils::shell_escape; // ── Ping ───────────────────────────────────────────────────────────────────── @@ -25,7 +26,7 @@ pub async fn tool_ping( let session = state.ssh.get_session(&session_id) .ok_or_else(|| format!("SSH session {} not found", session_id))?; let n = count.unwrap_or(4); - let cmd = format!("ping -c {} {} 2>&1", n, target); + let cmd = format!("ping -c {} {} 2>&1", n, shell_escape(&target)); let output = exec_on_session(&session.handle, &cmd).await?; Ok(PingResult { target, output }) } @@ -39,7 +40,8 @@ pub async fn tool_traceroute( ) -> Result { let session = state.ssh.get_session(&session_id) .ok_or_else(|| format!("SSH session {} not found", session_id))?; - let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", target, target); + let t = shell_escape(&target); + let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t); exec_on_session(&session.handle, &cmd).await } @@ -65,14 +67,16 @@ pub async fn tool_wake_on_lan( let cmd = format!( r#"python3 -c " import socket, struct -mac = bytes.fromhex('{mac_clean}') +mac = bytes.fromhex({mac_clean_escaped}) pkt = b'\xff'*6 + mac*16 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) s.sendto(pkt, ('255.255.255.255', 9)) s.close() -print('WoL packet sent to {mac_address}') -" 2>&1 || echo "python3 not available — install python3 on remote host for WoL""# +print('WoL packet sent to {mac_display_escaped}') +" 2>&1 || echo "python3 not available — install python3 on remote host for WoL""#, + mac_clean_escaped = shell_escape(&mac_clean), + mac_display_escaped = shell_escape(&mac_address), ); exec_on_session(&session.handle, &cmd).await diff --git a/src-tauri/src/commands/tools_commands_r2.rs b/src-tauri/src/commands/tools_commands_r2.rs index d1307f1..e487f9a 100644 --- a/src-tauri/src/commands/tools_commands_r2.rs +++ b/src-tauri/src/commands/tools_commands_r2.rs @@ -4,6 +4,7 @@ use tauri::State; use serde::Serialize; use crate::AppState; +use crate::utils::shell_escape; // ── DNS Lookup ─────────────────────────────────────────────────────────────── @@ -16,10 +17,11 @@ pub async fn tool_dns_lookup( ) -> Result { let session = state.ssh.get_session(&session_id) .ok_or_else(|| format!("SSH session {} not found", session_id))?; - let rtype = record_type.unwrap_or_else(|| "A".to_string()); + let d = shell_escape(&domain); + let rt = shell_escape(&record_type.unwrap_or_else(|| "A".to_string())); let cmd = format!( r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#, - domain, rtype, rtype, domain, rtype, domain + d, rt, rt, d, rt, d ); exec_on_session(&session.handle, &cmd).await } @@ -34,7 +36,7 @@ pub async fn tool_whois( ) -> Result { let session = state.ssh.get_session(&session_id) .ok_or_else(|| format!("SSH session {} not found", session_id))?; - let cmd = format!("whois {} 2>&1 | head -80", target); + let cmd = format!("whois {} 2>&1 | head -80", shell_escape(&target)); exec_on_session(&session.handle, &cmd).await } @@ -50,9 +52,10 @@ pub async fn tool_bandwidth_iperf( let session = state.ssh.get_session(&session_id) .ok_or_else(|| format!("SSH session {} not found", session_id))?; let dur = duration.unwrap_or(5); + let s = shell_escape(&server); let cmd = format!( "iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'", - server, dur, server, dur + s, dur, s, dur ); exec_on_session(&session.handle, &cmd).await } diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 21a9a59..621b1d9 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -21,6 +21,7 @@ pub mod pty; pub mod mcp; pub mod scanner; pub mod commands; +pub mod utils; use std::path::PathBuf; use std::sync::Mutex; diff --git a/src-tauri/src/mcp/server.rs b/src-tauri/src/mcp/server.rs index 8f30231..112a5f8 100644 --- a/src-tauri/src/mcp/server.rs +++ b/src-tauri/src/mcp/server.rs @@ -5,7 +5,14 @@ use std::sync::Arc; -use axum::{extract::State as AxumState, routing::post, Json, Router}; +use axum::{ + extract::State as AxumState, + http::{Request, StatusCode}, + middleware::{self, Next}, + response::Response, + routing::post, + Json, Router, +}; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; @@ -13,6 +20,7 @@ use crate::mcp::ScrollbackRegistry; use crate::rdp::RdpService; use crate::sftp::SftpService; use crate::ssh::session::SshService; +use crate::utils::shell_escape; /// Shared state passed to axum handlers. pub struct McpServerState { @@ -22,6 +30,27 @@ pub struct McpServerState { pub scrollback: ScrollbackRegistry, pub app_handle: tauri::AppHandle, pub error_watcher: std::sync::Arc, + pub bearer_token: String, +} + +/// Middleware that validates the `Authorization: Bearer ` header. +async fn auth_middleware( + AxumState(state): AxumState>, + req: Request, + next: Next, +) -> Result { + let auth_header = req + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let expected = format!("Bearer {}", state.bearer_token); + if auth_header != expected { + return Err(StatusCode::UNAUTHORIZED); + } + + Ok(next.run(req).await) } #[derive(Deserialize)] @@ -279,29 +308,31 @@ struct ToolPassgenRequest { length: Option, uppercase: Option, lowe async fn handle_tool_ping(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_traceroute(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + let t = shell_escape(&req.target); + match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_dns(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - let rt = req.record_type.unwrap_or_else(|| "A".to_string()); - match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + let rt = shell_escape(&req.record_type.unwrap_or_else(|| "A".to_string())); + let d = shell_escape(&req.domain); + match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", d, rt, rt, d, rt, d)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_whois(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_wol(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let mac_clean = req.mac_address.replace([':', '-'], ""); - let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address); + let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex({});pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, shell_escape(&mac_clean), shell_escape(&req.mac_address)); match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } @@ -382,12 +413,13 @@ async fn handle_docker_ps(AxumState(state): AxumState>, Json async fn handle_docker_action(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + let t = shell_escape(&req.target); let cmd = match req.action.as_str() { - "start" => format!("docker start {} 2>&1", req.target), - "stop" => format!("docker stop {} 2>&1", req.target), - "restart" => format!("docker restart {} 2>&1", req.target), - "remove" => format!("docker rm -f {} 2>&1", req.target), - "logs" => format!("docker logs --tail 100 {} 2>&1", req.target), + "start" => format!("docker start {} 2>&1", t), + "stop" => format!("docker stop {} 2>&1", t), + "restart" => format!("docker restart {} 2>&1", t), + "remove" => format!("docker rm -f {} 2>&1", t), + "logs" => format!("docker logs --tail 100 {} 2>&1", t), "builder-prune" => "docker builder prune -f 2>&1".to_string(), "system-prune" => "docker system prune -f 2>&1".to_string(), _ => return err_response(format!("Unknown action: {}", req.action)), @@ -397,7 +429,7 @@ async fn handle_docker_action(AxumState(state): AxumState>, async fn handle_docker_exec(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - let cmd = format!("docker exec {} {} 2>&1", req.container, req.command); + let cmd = format!("docker exec {} {} 2>&1", shell_escape(&req.container), shell_escape(&req.command)); match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } @@ -405,12 +437,13 @@ async fn handle_docker_exec(AxumState(state): AxumState>, Js async fn handle_service_status(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + let t = shell_escape(&req.target); + match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_process_list(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", req.target) }; + let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", shell_escape(&req.target)) }; match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } @@ -421,17 +454,17 @@ struct GitRequest { session_id: String, path: String } async fn handle_git_status(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_git_pull(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_git_log(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; - match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } + match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } // ── Session creation handlers ──────────────────────────────────────────────── @@ -533,7 +566,15 @@ pub async fn start_mcp_server( app_handle: tauri::AppHandle, error_watcher: std::sync::Arc, ) -> Result { - let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher }); + // Generate a cryptographically random bearer token for authentication + use rand::Rng; + let bearer_token: String = rand::rng() + .sample_iter(&rand::distr::Alphanumeric) + .take(64) + .map(char::from) + .collect(); + + let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher, bearer_token: bearer_token.clone() }); let app = Router::new() .route("/mcp/sessions", post(handle_list_sessions)) @@ -567,6 +608,7 @@ pub async fn start_mcp_server( .route("/mcp/rdp/type", post(handle_rdp_type)) .route("/mcp/rdp/clipboard", post(handle_rdp_clipboard)) .route("/mcp/ssh/connect", post(handle_ssh_connect)) + .layer(middleware::from_fn_with_state(state.clone(), auth_middleware)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await @@ -577,10 +619,23 @@ pub async fn start_mcp_server( .port(); // Write port to well-known location - let port_file = crate::data_directory().join("mcp-port"); + let data_dir = crate::data_directory(); + let port_file = data_dir.join("mcp-port"); std::fs::write(&port_file, port.to_string()) .map_err(|e| format!("Failed to write MCP port file: {}", e))?; + // Write bearer token to a separate file with restrictive permissions + let token_file = data_dir.join("mcp-token"); + std::fs::write(&token_file, &bearer_token) + .map_err(|e| format!("Failed to write MCP token file: {}", e))?; + + // Set owner-only read/write permissions (Unix) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions(&token_file, std::fs::Permissions::from_mode(0o600)); + } + tokio::spawn(async move { axum::serve(listener, app).await.ok(); }); diff --git a/src-tauri/src/scanner/mod.rs b/src-tauri/src/scanner/mod.rs index 1698ec5..4ce44d7 100644 --- a/src-tauri/src/scanner/mod.rs +++ b/src-tauri/src/scanner/mod.rs @@ -12,6 +12,7 @@ use serde::Serialize; use tokio::sync::Mutex as TokioMutex; use crate::ssh::session::SshClient; +use crate::utils::shell_escape; #[derive(Debug, Serialize, Clone)] #[serde(rename_all = "camelCase")] @@ -72,9 +73,10 @@ pub async fn scan_network( // 1. Ping sweep the subnet to populate ARP cache // 2. Read ARP table for IP/MAC pairs // 3. Try reverse DNS for hostnames + let escaped_subnet = shell_escape(subnet); let script = format!(r#" OS=$(uname -s 2>/dev/null) -SUBNET="{subnet}" +SUBNET={escaped_subnet} # Ping sweep (background, fast) if [ "$OS" = "Linux" ]; then @@ -151,6 +153,12 @@ pub async fn scan_ports( target: &str, ports: &[u16], ) -> Result, String> { + // Validate target — /dev/tcp requires a bare hostname/IP, not a shell-quoted value. + // Only allow alphanumeric, dots, hyphens, and colons (for IPv6). + if !target.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':') { + return Err(format!("Invalid target for port scan: {}", target)); + } + // Use bash /dev/tcp for port scanning — no nmap required let port_checks: Vec = ports.iter() .map(|p| format!( diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs new file mode 100644 index 0000000..8629b17 --- /dev/null +++ b/src-tauri/src/utils.rs @@ -0,0 +1,19 @@ +//! Shared utility functions. + +/// Escape a string for safe interpolation into a POSIX shell command. +/// +/// Wraps the input in single quotes and escapes any embedded single quotes +/// using the `'\''` technique. This prevents command injection when building +/// shell commands from user-supplied values. +/// +/// # Examples +/// +/// ``` +/// # use wraith_lib::utils::shell_escape; +/// assert_eq!(shell_escape("hello"), "'hello'"); +/// assert_eq!(shell_escape("it's"), "'it'\\''s'"); +/// assert_eq!(shell_escape(";rm -rf /"), "';rm -rf /'"); +/// ``` +pub fn shell_escape(input: &str) -> String { + format!("'{}'", input.replace('\'', "'\\''")) +}