fix: SEC-1/SEC-2 shell escape utility + MCP bearer token auth

- New shell_escape() utility for safe command interpolation
- Applied across all MCP tools, docker, scanner, network commands
- MCP server generates random bearer token at startup
- Token written to mcp-token file with 0600 permissions
- All MCP HTTP requests require Authorization header
- Bridge binary reads token and sends on every request

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Vantz Stockwell 2026-03-29 16:40:13 -04:00
parent 1b7b1a0051
commit 17973fc3dc
8 changed files with 192 additions and 79 deletions

View File

@ -38,19 +38,22 @@ struct JsonRpcError {
message: String,
}
fn get_mcp_port() -> Result<u16, String> {
// Check standard locations for the port file
let port_file = if let Ok(appdata) = std::env::var("APPDATA") {
std::path::PathBuf::from(appdata).join("Wraith").join("mcp-port")
fn get_data_dir() -> Result<std::path::PathBuf, String> {
if let Ok(appdata) = std::env::var("APPDATA") {
Ok(std::path::PathBuf::from(appdata).join("Wraith"))
} else if let Ok(home) = std::env::var("HOME") {
if cfg!(target_os = "macos") {
std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith").join("mcp-port")
Ok(std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith"))
} else {
std::path::PathBuf::from(home).join(".local").join("share").join("wraith").join("mcp-port")
Ok(std::path::PathBuf::from(home).join(".local").join("share").join("wraith"))
}
} else {
return Err("Cannot determine data directory".to_string());
};
Err("Cannot determine data directory".to_string())
}
}
fn get_mcp_port() -> Result<u16, String> {
let port_file = get_data_dir()?.join("mcp-port");
let port_str = std::fs::read_to_string(&port_file)
.map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?;
@ -59,6 +62,15 @@ fn get_mcp_port() -> Result<u16, String> {
.map_err(|e| format!("Invalid port in MCP port file: {}", e))
}
fn get_mcp_token() -> Result<String, String> {
let token_file = get_data_dir()?.join("mcp-token");
let token = std::fs::read_to_string(&token_file)
.map_err(|e| format!("Cannot read MCP token file at {}: {} — is Wraith running?", token_file.display(), e))?;
Ok(token.trim().to_string())
}
fn handle_initialize(id: Value) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
@ -304,12 +316,13 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse {
}
}
fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result<Value, String> {
fn call_wraith(port: u16, token: &str, endpoint: &str, body: Value) -> Result<Value, String> {
let url = format!("http://127.0.0.1:{}{}", port, endpoint);
let body_str = serde_json::to_string(&body).unwrap_or_default();
let mut resp = ureq::post(url)
.header("Content-Type", "application/json")
.header("Authorization", &format!("Bearer {}", token))
.send(body_str.as_bytes())
.map_err(|e| format!("HTTP request to Wraith failed: {}", e))?;
@ -327,40 +340,40 @@ fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result<Value, String>
}
}
fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> JsonRpcResponse {
fn handle_tool_call(id: Value, port: u16, token: &str, tool_name: &str, args: &Value) -> JsonRpcResponse {
let result = match tool_name {
"list_sessions" => call_wraith(port, "/mcp/sessions", serde_json::json!({})),
"terminal_type" => call_wraith(port, "/mcp/terminal/type", args.clone()),
"terminal_read" => call_wraith(port, "/mcp/terminal/read", args.clone()),
"terminal_execute" => call_wraith(port, "/mcp/terminal/execute", args.clone()),
"sftp_list" => call_wraith(port, "/mcp/sftp/list", args.clone()),
"sftp_read" => call_wraith(port, "/mcp/sftp/read", args.clone()),
"sftp_write" => call_wraith(port, "/mcp/sftp/write", args.clone()),
"network_scan" => call_wraith(port, "/mcp/tool/scan-network", args.clone()),
"port_scan" => call_wraith(port, "/mcp/tool/scan-ports", args.clone()),
"ping" => call_wraith(port, "/mcp/tool/ping", args.clone()),
"traceroute" => call_wraith(port, "/mcp/tool/traceroute", args.clone()),
"dns_lookup" => call_wraith(port, "/mcp/tool/dns", args.clone()),
"whois" => call_wraith(port, "/mcp/tool/whois", args.clone()),
"wake_on_lan" => call_wraith(port, "/mcp/tool/wol", args.clone()),
"bandwidth_test" => call_wraith(port, "/mcp/tool/bandwidth", args.clone()),
"subnet_calc" => call_wraith(port, "/mcp/tool/subnet", args.clone()),
"generate_ssh_key" => call_wraith(port, "/mcp/tool/keygen", args.clone()),
"generate_password" => call_wraith(port, "/mcp/tool/passgen", args.clone()),
"docker_ps" => call_wraith(port, "/mcp/docker/ps", args.clone()),
"docker_action" => call_wraith(port, "/mcp/docker/action", args.clone()),
"docker_exec" => call_wraith(port, "/mcp/docker/exec", args.clone()),
"service_status" => call_wraith(port, "/mcp/service/status", args.clone()),
"process_list" => call_wraith(port, "/mcp/process/list", args.clone()),
"git_status" => call_wraith(port, "/mcp/git/status", args.clone()),
"git_pull" => call_wraith(port, "/mcp/git/pull", args.clone()),
"git_log" => call_wraith(port, "/mcp/git/log", args.clone()),
"rdp_click" => call_wraith(port, "/mcp/rdp/click", args.clone()),
"rdp_type" => call_wraith(port, "/mcp/rdp/type", args.clone()),
"rdp_clipboard" => call_wraith(port, "/mcp/rdp/clipboard", args.clone()),
"ssh_connect" => call_wraith(port, "/mcp/ssh/connect", args.clone()),
"list_sessions" => call_wraith(port, token, "/mcp/sessions", serde_json::json!({})),
"terminal_type" => call_wraith(port, token, "/mcp/terminal/type", args.clone()),
"terminal_read" => call_wraith(port, token, "/mcp/terminal/read", args.clone()),
"terminal_execute" => call_wraith(port, token, "/mcp/terminal/execute", args.clone()),
"sftp_list" => call_wraith(port, token, "/mcp/sftp/list", args.clone()),
"sftp_read" => call_wraith(port, token, "/mcp/sftp/read", args.clone()),
"sftp_write" => call_wraith(port, token, "/mcp/sftp/write", args.clone()),
"network_scan" => call_wraith(port, token, "/mcp/tool/scan-network", args.clone()),
"port_scan" => call_wraith(port, token, "/mcp/tool/scan-ports", args.clone()),
"ping" => call_wraith(port, token, "/mcp/tool/ping", args.clone()),
"traceroute" => call_wraith(port, token, "/mcp/tool/traceroute", args.clone()),
"dns_lookup" => call_wraith(port, token, "/mcp/tool/dns", args.clone()),
"whois" => call_wraith(port, token, "/mcp/tool/whois", args.clone()),
"wake_on_lan" => call_wraith(port, token, "/mcp/tool/wol", args.clone()),
"bandwidth_test" => call_wraith(port, token, "/mcp/tool/bandwidth", args.clone()),
"subnet_calc" => call_wraith(port, token, "/mcp/tool/subnet", args.clone()),
"generate_ssh_key" => call_wraith(port, token, "/mcp/tool/keygen", args.clone()),
"generate_password" => call_wraith(port, token, "/mcp/tool/passgen", args.clone()),
"docker_ps" => call_wraith(port, token, "/mcp/docker/ps", args.clone()),
"docker_action" => call_wraith(port, token, "/mcp/docker/action", args.clone()),
"docker_exec" => call_wraith(port, token, "/mcp/docker/exec", args.clone()),
"service_status" => call_wraith(port, token, "/mcp/service/status", args.clone()),
"process_list" => call_wraith(port, token, "/mcp/process/list", args.clone()),
"git_status" => call_wraith(port, token, "/mcp/git/status", args.clone()),
"git_pull" => call_wraith(port, token, "/mcp/git/pull", args.clone()),
"git_log" => call_wraith(port, token, "/mcp/git/log", args.clone()),
"rdp_click" => call_wraith(port, token, "/mcp/rdp/click", args.clone()),
"rdp_type" => call_wraith(port, token, "/mcp/rdp/type", args.clone()),
"rdp_clipboard" => call_wraith(port, token, "/mcp/rdp/clipboard", args.clone()),
"ssh_connect" => call_wraith(port, token, "/mcp/ssh/connect", args.clone()),
"terminal_screenshot" => {
let result = call_wraith(port, "/mcp/screenshot", args.clone());
let result = call_wraith(port, token, "/mcp/screenshot", args.clone());
// Screenshot returns base64 PNG — wrap as image content for multimodal AI
return match result {
Ok(b64) => JsonRpcResponse {
@ -420,6 +433,14 @@ fn main() {
}
};
let token = match get_mcp_token() {
Ok(t) => t,
Err(e) => {
eprintln!("wraith-mcp-bridge: {}", e);
std::process::exit(1);
}
};
let stdin = io::stdin();
let mut stdout = io::stdout();
@ -458,7 +479,7 @@ fn main() {
let args = request.params.get("arguments")
.cloned()
.unwrap_or(Value::Object(serde_json::Map::new()));
handle_tool_call(request.id, port, tool_name, &args)
handle_tool_call(request.id, port, &token, tool_name, &args)
}
"notifications/initialized" | "notifications/cancelled" => {
// Notifications don't get responses

View File

@ -3,6 +3,7 @@
use tauri::State;
use serde::Serialize;
use crate::AppState;
use crate::utils::shell_escape;
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
@ -84,14 +85,15 @@ pub async fn docker_list_volumes(session_id: String, state: State<'_, AppState>)
#[tauri::command]
pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result<String, String> {
let session = state.ssh.get_session(&session_id).ok_or("Session not found")?;
let t = shell_escape(&target);
let cmd = match action.as_str() {
"start" => format!("docker start {} 2>&1", target),
"stop" => format!("docker stop {} 2>&1", target),
"restart" => format!("docker restart {} 2>&1", target),
"remove" => format!("docker rm -f {} 2>&1", target),
"logs" => format!("docker logs --tail 100 {} 2>&1", target),
"remove-image" => format!("docker rmi {} 2>&1", target),
"remove-volume" => format!("docker volume rm {} 2>&1", target),
"start" => format!("docker start {} 2>&1", t),
"stop" => format!("docker stop {} 2>&1", t),
"restart" => format!("docker restart {} 2>&1", t),
"remove" => format!("docker rm -f {} 2>&1", t),
"logs" => format!("docker logs --tail 100 {} 2>&1", t),
"remove-image" => format!("docker rmi {} 2>&1", t),
"remove-volume" => format!("docker volume rm {} 2>&1", t),
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
"system-prune" => "docker system prune -f 2>&1".to_string(),
"system-prune-all" => "docker system prune -a -f 2>&1".to_string(),

View File

@ -4,6 +4,7 @@ use tauri::State;
use serde::Serialize;
use crate::AppState;
use crate::utils::shell_escape;
// ── Ping ─────────────────────────────────────────────────────────────────────
@ -25,7 +26,7 @@ pub async fn tool_ping(
let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
let n = count.unwrap_or(4);
let cmd = format!("ping -c {} {} 2>&1", n, target);
let cmd = format!("ping -c {} {} 2>&1", n, shell_escape(&target));
let output = exec_on_session(&session.handle, &cmd).await?;
Ok(PingResult { target, output })
}
@ -39,7 +40,8 @@ pub async fn tool_traceroute(
) -> Result<String, String> {
let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", target, target);
let t = shell_escape(&target);
let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t);
exec_on_session(&session.handle, &cmd).await
}
@ -65,14 +67,16 @@ pub async fn tool_wake_on_lan(
let cmd = format!(
r#"python3 -c "
import socket, struct
mac = bytes.fromhex('{mac_clean}')
mac = bytes.fromhex({mac_clean_escaped})
pkt = b'\xff'*6 + mac*16
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
s.sendto(pkt, ('255.255.255.255', 9))
s.close()
print('WoL packet sent to {mac_address}')
" 2>&1 || echo "python3 not available install python3 on remote host for WoL""#
print('WoL packet sent to {mac_display_escaped}')
" 2>&1 || echo "python3 not available install python3 on remote host for WoL""#,
mac_clean_escaped = shell_escape(&mac_clean),
mac_display_escaped = shell_escape(&mac_address),
);
exec_on_session(&session.handle, &cmd).await

View File

@ -4,6 +4,7 @@ use tauri::State;
use serde::Serialize;
use crate::AppState;
use crate::utils::shell_escape;
// ── DNS Lookup ───────────────────────────────────────────────────────────────
@ -16,10 +17,11 @@ pub async fn tool_dns_lookup(
) -> Result<String, String> {
let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
let rtype = record_type.unwrap_or_else(|| "A".to_string());
let d = shell_escape(&domain);
let rt = shell_escape(&record_type.unwrap_or_else(|| "A".to_string()));
let cmd = format!(
r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#,
domain, rtype, rtype, domain, rtype, domain
d, rt, rt, d, rt, d
);
exec_on_session(&session.handle, &cmd).await
}
@ -34,7 +36,7 @@ pub async fn tool_whois(
) -> Result<String, String> {
let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
let cmd = format!("whois {} 2>&1 | head -80", target);
let cmd = format!("whois {} 2>&1 | head -80", shell_escape(&target));
exec_on_session(&session.handle, &cmd).await
}
@ -50,9 +52,10 @@ pub async fn tool_bandwidth_iperf(
let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
let dur = duration.unwrap_or(5);
let s = shell_escape(&server);
let cmd = format!(
"iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'",
server, dur, server, dur
s, dur, s, dur
);
exec_on_session(&session.handle, &cmd).await
}

View File

@ -21,6 +21,7 @@ pub mod pty;
pub mod mcp;
pub mod scanner;
pub mod commands;
pub mod utils;
use std::path::PathBuf;
use std::sync::Mutex;

View File

@ -5,7 +5,14 @@
use std::sync::Arc;
use axum::{extract::State as AxumState, routing::post, Json, Router};
use axum::{
extract::State as AxumState,
http::{Request, StatusCode},
middleware::{self, Next},
response::Response,
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
@ -13,6 +20,7 @@ use crate::mcp::ScrollbackRegistry;
use crate::rdp::RdpService;
use crate::sftp::SftpService;
use crate::ssh::session::SshService;
use crate::utils::shell_escape;
/// Shared state passed to axum handlers.
pub struct McpServerState {
@ -22,6 +30,27 @@ pub struct McpServerState {
pub scrollback: ScrollbackRegistry,
pub app_handle: tauri::AppHandle,
pub error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
pub bearer_token: String,
}
/// Middleware that validates the `Authorization: Bearer <token>` header.
async fn auth_middleware(
AxumState(state): AxumState<Arc<McpServerState>>,
req: Request<axum::body::Body>,
next: Next,
) -> Result<Response, StatusCode> {
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let expected = format!("Bearer {}", state.bearer_token);
if auth_header != expected {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
}
#[derive(Deserialize)]
@ -279,29 +308,31 @@ struct ToolPassgenRequest { length: Option<usize>, uppercase: Option<bool>, lowe
async fn handle_tool_ping(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_tool_traceroute(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
let t = shell_escape(&req.target);
match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_tool_dns(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolDnsRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let rt = req.record_type.unwrap_or_else(|| "A".to_string());
match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
let rt = shell_escape(&req.record_type.unwrap_or_else(|| "A".to_string()));
let d = shell_escape(&req.domain);
match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", d, rt, rt, d, rt, d)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_tool_whois(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_tool_wol(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolWolRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let mac_clean = req.mac_address.replace([':', '-'], "");
let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address);
let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex({});pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, shell_escape(&mac_clean), shell_escape(&req.mac_address));
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
@ -382,12 +413,13 @@ async fn handle_docker_ps(AxumState(state): AxumState<Arc<McpServerState>>, Json
async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerActionRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let t = shell_escape(&req.target);
let cmd = match req.action.as_str() {
"start" => format!("docker start {} 2>&1", req.target),
"stop" => format!("docker stop {} 2>&1", req.target),
"restart" => format!("docker restart {} 2>&1", req.target),
"remove" => format!("docker rm -f {} 2>&1", req.target),
"logs" => format!("docker logs --tail 100 {} 2>&1", req.target),
"start" => format!("docker start {} 2>&1", t),
"stop" => format!("docker stop {} 2>&1", t),
"restart" => format!("docker restart {} 2>&1", t),
"remove" => format!("docker rm -f {} 2>&1", t),
"logs" => format!("docker logs --tail 100 {} 2>&1", t),
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
"system-prune" => "docker system prune -f 2>&1".to_string(),
_ => return err_response(format!("Unknown action: {}", req.action)),
@ -397,7 +429,7 @@ async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>,
async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerExecRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let cmd = format!("docker exec {} {} 2>&1", req.container, req.command);
let cmd = format!("docker exec {} {} 2>&1", shell_escape(&req.container), shell_escape(&req.command));
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
@ -405,12 +437,13 @@ async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Js
async fn handle_service_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
let t = shell_escape(&req.target);
match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_process_list(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", req.target) };
let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", shell_escape(&req.target)) };
match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
@ -421,17 +454,17 @@ struct GitRequest { session_id: String, path: String }
async fn handle_git_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_git_pull(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
async fn handle_git_log(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
}
// ── Session creation handlers ────────────────────────────────────────────────
@ -533,7 +566,15 @@ pub async fn start_mcp_server(
app_handle: tauri::AppHandle,
error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
) -> Result<u16, String> {
let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher });
// Generate a cryptographically random bearer token for authentication
use rand::Rng;
let bearer_token: String = rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
.take(64)
.map(char::from)
.collect();
let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher, bearer_token: bearer_token.clone() });
let app = Router::new()
.route("/mcp/sessions", post(handle_list_sessions))
@ -567,6 +608,7 @@ pub async fn start_mcp_server(
.route("/mcp/rdp/type", post(handle_rdp_type))
.route("/mcp/rdp/clipboard", post(handle_rdp_clipboard))
.route("/mcp/ssh/connect", post(handle_ssh_connect))
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
.with_state(state);
let listener = TcpListener::bind("127.0.0.1:0").await
@ -577,10 +619,23 @@ pub async fn start_mcp_server(
.port();
// Write port to well-known location
let port_file = crate::data_directory().join("mcp-port");
let data_dir = crate::data_directory();
let port_file = data_dir.join("mcp-port");
std::fs::write(&port_file, port.to_string())
.map_err(|e| format!("Failed to write MCP port file: {}", e))?;
// Write bearer token to a separate file with restrictive permissions
let token_file = data_dir.join("mcp-token");
std::fs::write(&token_file, &bearer_token)
.map_err(|e| format!("Failed to write MCP token file: {}", e))?;
// Set owner-only read/write permissions (Unix)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&token_file, std::fs::Permissions::from_mode(0o600));
}
tokio::spawn(async move {
axum::serve(listener, app).await.ok();
});

View File

@ -12,6 +12,7 @@ use serde::Serialize;
use tokio::sync::Mutex as TokioMutex;
use crate::ssh::session::SshClient;
use crate::utils::shell_escape;
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
@ -72,9 +73,10 @@ pub async fn scan_network(
// 1. Ping sweep the subnet to populate ARP cache
// 2. Read ARP table for IP/MAC pairs
// 3. Try reverse DNS for hostnames
let escaped_subnet = shell_escape(subnet);
let script = format!(r#"
OS=$(uname -s 2>/dev/null)
SUBNET="{subnet}"
SUBNET={escaped_subnet}
# Ping sweep (background, fast)
if [ "$OS" = "Linux" ]; then
@ -151,6 +153,12 @@ pub async fn scan_ports(
target: &str,
ports: &[u16],
) -> Result<Vec<PortResult>, String> {
// Validate target — /dev/tcp requires a bare hostname/IP, not a shell-quoted value.
// Only allow alphanumeric, dots, hyphens, and colons (for IPv6).
if !target.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':') {
return Err(format!("Invalid target for port scan: {}", target));
}
// Use bash /dev/tcp for port scanning — no nmap required
let port_checks: Vec<String> = ports.iter()
.map(|p| format!(

19
src-tauri/src/utils.rs Normal file
View File

@ -0,0 +1,19 @@
//! Shared utility functions.
/// Escape a string for safe interpolation into a POSIX shell command.
///
/// Wraps the input in single quotes and escapes any embedded single quotes
/// using the `'\''` technique. This prevents command injection when building
/// shell commands from user-supplied values.
///
/// # Examples
///
/// ```
/// # use wraith_lib::utils::shell_escape;
/// assert_eq!(shell_escape("hello"), "'hello'");
/// assert_eq!(shell_escape("it's"), "'it'\\''s'");
/// assert_eq!(shell_escape(";rm -rf /"), "';rm -rf /'");
/// ```
pub fn shell_escape(input: &str) -> String {
format!("'{}'", input.replace('\'', "'\\''"))
}