//! Tiny HTTP server for MCP bridge communication. //! //! Runs on localhost:0 (random port) at Tauri startup. The port is written //! to ~/.wraith/mcp-port so the bridge binary can find it. use std::sync::Arc; use axum::{extract::State as AxumState, routing::post, Json, Router}; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use crate::mcp::ScrollbackRegistry; use crate::rdp::RdpService; use crate::sftp::SftpService; use crate::ssh::session::SshService; /// Shared state passed to axum handlers. pub struct McpServerState { pub ssh: SshService, pub rdp: RdpService, pub sftp: SftpService, pub scrollback: ScrollbackRegistry, pub app_handle: tauri::AppHandle, pub error_watcher: std::sync::Arc, } #[derive(Deserialize)] struct TerminalReadRequest { session_id: String, lines: Option, } #[derive(Deserialize)] struct ScreenshotRequest { session_id: String, } #[derive(Deserialize)] struct SftpListRequest { session_id: String, path: String, } #[derive(Deserialize)] struct SftpReadRequest { session_id: String, path: String, } #[derive(Deserialize)] struct SftpWriteRequest { session_id: String, path: String, content: String, } #[derive(Deserialize)] struct TerminalTypeRequest { session_id: String, text: String, press_enter: Option, } #[derive(Deserialize)] struct TerminalExecuteRequest { session_id: String, command: String, timeout_ms: Option, } #[derive(Serialize)] struct McpResponse { ok: bool, data: Option, error: Option, } fn ok_response(data: T) -> Json> { Json(McpResponse { ok: true, data: Some(data), error: None }) } fn err_response(msg: String) -> Json> { Json(McpResponse { ok: false, data: None, error: Some(msg) }) } async fn handle_list_sessions( AxumState(state): AxumState>, ) -> Json>> { let mut sessions: Vec = state.ssh.list_sessions() .into_iter() .map(|s| serde_json::json!({ "id": s.id, "type": "ssh", "name": format!("{}@{}:{}", s.username, s.hostname, s.port), "host": s.hostname, "username": s.username, })) .collect(); // Include RDP sessions for s in state.rdp.list_sessions() { sessions.push(serde_json::json!({ "id": s.id, "type": "rdp", "name": s.hostname.clone(), "host": s.hostname, "width": s.width, "height": s.height, })); } ok_response(sessions) } async fn handle_sftp_list( AxumState(state): AxumState>, Json(req): Json, ) -> Json>> { match state.sftp.list(&req.session_id, &req.path).await { Ok(entries) => { let items: Vec = entries.into_iter().map(|e| { serde_json::json!({ "name": e.name, "path": e.path, "size": e.size, "is_dir": e.is_dir, "modified": e.mod_time, }) }).collect(); ok_response(items) } Err(e) => err_response(e), } } async fn handle_sftp_read( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { match state.sftp.read_file(&req.session_id, &req.path).await { Ok(content) => ok_response(content), Err(e) => err_response(e), } } async fn handle_sftp_write( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { match state.sftp.write_file(&req.session_id, &req.path, &req.content).await { Ok(()) => ok_response("OK".to_string()), Err(e) => err_response(e), } } async fn handle_screenshot( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { match state.rdp.screenshot_png_base64(&req.session_id).await { Ok(b64) => ok_response(b64), Err(e) => err_response(e), } } async fn handle_terminal_type( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { let text = if req.press_enter.unwrap_or(true) { format!("{}\r", req.text) } else { req.text.clone() }; match state.ssh.write(&req.session_id, text.as_bytes()).await { Ok(()) => ok_response("sent".to_string()), Err(e) => err_response(e), } } async fn handle_terminal_read( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { let n = req.lines.unwrap_or(50); match state.scrollback.get(&req.session_id) { Some(buf) => ok_response(buf.read_lines(n)), None => err_response(format!("No scrollback buffer for session {}", req.session_id)), } } async fn handle_terminal_execute( AxumState(state): AxumState>, Json(req): Json, ) -> Json> { let timeout = req.timeout_ms.unwrap_or(5000); let marker = "__WRAITH_MCP_DONE__"; let buf = match state.scrollback.get(&req.session_id) { Some(b) => b, None => return err_response(format!("No scrollback buffer for session {}", req.session_id)), }; let before = buf.total_written(); let full_cmd = format!("{}\recho {}\r", req.command, marker); if let Err(e) = state.ssh.write(&req.session_id, full_cmd.as_bytes()).await { return err_response(e); } let start = std::time::Instant::now(); let timeout_dur = std::time::Duration::from_millis(timeout); loop { if start.elapsed() > timeout_dur { let raw = buf.read_raw(); let total = buf.total_written(); let new_bytes = total.saturating_sub(before); let output = if new_bytes > 0 && raw.len() >= new_bytes { &raw[raw.len() - new_bytes.min(raw.len())..] } else { "" }; return ok_response(format!("[timeout after {}ms]\n{}", timeout, output)); } let raw = buf.read_raw(); if raw.contains(marker) { let total = buf.total_written(); let new_bytes = total.saturating_sub(before); let output = if new_bytes > 0 && raw.len() >= new_bytes { raw[raw.len() - new_bytes.min(raw.len())..].to_string() } else { String::new() }; let clean = output .lines() .filter(|line| !line.contains(marker)) .collect::>() .join("\n"); return ok_response(clean.trim().to_string()); } tokio::time::sleep(std::time::Duration::from_millis(50)).await; } } // ── Tool handlers (all tools exposed to AI via MCP) ────────────────────────── #[derive(Deserialize)] struct ToolSessionTarget { session_id: String, target: String } #[derive(Deserialize)] struct ToolSessionOnly { session_id: String } #[derive(Deserialize)] struct ToolDnsRequest { session_id: String, domain: String, record_type: Option } #[derive(Deserialize)] struct ToolWolRequest { session_id: String, mac_address: String } #[derive(Deserialize)] struct ToolScanNetworkRequest { session_id: String, subnet: String } #[derive(Deserialize)] struct ToolScanPortsRequest { session_id: String, target: String, ports: Option> } #[derive(Deserialize)] struct ToolSubnetRequest { cidr: String } #[derive(Deserialize)] struct ToolKeygenRequest { key_type: String, comment: Option } #[derive(Deserialize)] struct ToolPassgenRequest { length: Option, uppercase: Option, lowercase: Option, digits: Option, symbols: Option } async fn handle_tool_ping(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_traceroute(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_dns(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let rt = req.record_type.unwrap_or_else(|| "A".to_string()); match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_whois(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_wol(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let mac_clean = req.mac_address.replace([':', '-'], ""); let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address); match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_scan_network(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match crate::scanner::scan_network(&session.handle, &req.subnet).await { Ok(hosts) => ok_response(serde_json::to_value(hosts).unwrap_or_default()), Err(e) => err_response(e), } } async fn handle_tool_scan_ports(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let result = if let Some(ports) = req.ports { crate::scanner::scan_ports(&session.handle, &req.target, &ports).await } else { crate::scanner::quick_port_scan(&session.handle, &req.target).await }; match result { Ok(r) => ok_response(serde_json::to_value(r).unwrap_or_default()), Err(e) => err_response(e) } } async fn handle_tool_subnet(_state: AxumState>, Json(req): Json) -> Json> { match crate::commands::tools_commands_r2::tool_subnet_calc_inner(&req.cidr) { Ok(info) => ok_response(serde_json::to_value(info).unwrap_or_default()), Err(e) => err_response(e), } } async fn handle_tool_bandwidth(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let cmd = r#"if command -v speedtest-cli >/dev/null 2>&1; then speedtest-cli --simple 2>&1; elif command -v curl >/dev/null 2>&1; then curl -o /dev/null -w "Download: %{speed_download} bytes/sec\n" https://speed.cloudflare.com/__down?bytes=25000000 2>/dev/null; else echo "No speedtest tool found"; fi"#; match tool_exec(&session.handle, cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_tool_keygen(_state: AxumState>, Json(req): Json) -> Json> { match crate::commands::tools_commands::tool_generate_ssh_key_inner(&req.key_type, req.comment) { Ok(key) => ok_response(serde_json::to_value(key).unwrap_or_default()), Err(e) => err_response(e), } } async fn handle_tool_passgen(_state: AxumState>, Json(req): Json) -> Json> { match crate::commands::tools_commands::tool_generate_password_inner(req.length, req.uppercase, req.lowercase, req.digits, req.symbols) { Ok(pw) => ok_response(pw), Err(e) => err_response(e), } } async fn tool_exec(handle: &std::sync::Arc>>, cmd: &str) -> Result { let mut channel = { let h = handle.lock().await; h.channel_open_session().await.map_err(|e| format!("Exec failed: {}", e))? }; channel.exec(true, cmd).await.map_err(|e| format!("Exec failed: {}", e))?; let mut output = String::new(); loop { match channel.wait().await { Some(russh::ChannelMsg::Data { ref data }) => { if let Ok(t) = std::str::from_utf8(data.as_ref()) { output.push_str(t); } } Some(russh::ChannelMsg::Eof) | Some(russh::ChannelMsg::Close) | None => break, _ => {} } } Ok(output) } // ── Docker handlers ────────────────────────────────────────────────────────── #[derive(Deserialize)] struct DockerActionRequest { session_id: String, action: String, target: String } #[derive(Deserialize)] struct DockerListRequest { session_id: String } #[derive(Deserialize)] struct DockerExecRequest { session_id: String, container: String, command: String } async fn handle_docker_ps(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, "docker ps -a --format '{{.Names}}|{{.Image}}|{{.Status}}|{{.Ports}}' 2>&1").await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_docker_action(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let cmd = match req.action.as_str() { "start" => format!("docker start {} 2>&1", req.target), "stop" => format!("docker stop {} 2>&1", req.target), "restart" => format!("docker restart {} 2>&1", req.target), "remove" => format!("docker rm -f {} 2>&1", req.target), "logs" => format!("docker logs --tail 100 {} 2>&1", req.target), "builder-prune" => "docker builder prune -f 2>&1".to_string(), "system-prune" => "docker system prune -f 2>&1".to_string(), _ => return err_response(format!("Unknown action: {}", req.action)), }; match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_docker_exec(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let cmd = format!("docker exec {} {} 2>&1", req.container, req.command); match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } // ── Service/process handlers ───────────────────────────────────────────────── async fn handle_service_status(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_process_list(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", req.target) }; match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } // ── Git handlers ───────────────────────────────────────────────────────────── #[derive(Deserialize)] struct GitRequest { session_id: String, path: String } async fn handle_git_status(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_git_pull(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } async fn handle_git_log(AxumState(state): AxumState>, Json(req): Json) -> Json> { let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } // ── Session creation handlers ──────────────────────────────────────────────── #[derive(Deserialize)] struct SshConnectRequest { hostname: String, port: Option, username: String, password: Option, private_key_path: Option, } async fn handle_ssh_connect(AxumState(state): AxumState>, Json(req): Json) -> Json> { use crate::ssh::session::AuthMethod; let port = req.port.unwrap_or(22); let auth = if let Some(key_path) = req.private_key_path { // Read key file let pem = match std::fs::read_to_string(&key_path) { Ok(p) => p, Err(e) => return err_response(format!("Failed to read key file {}: {}", key_path, e)), }; AuthMethod::Key { private_key_pem: pem, passphrase: req.password } } else { AuthMethod::Password(req.password.unwrap_or_default()) }; match state.ssh.connect( state.app_handle.clone(), &req.hostname, port, &req.username, auth, 120, 40, &state.sftp, &state.scrollback, &state.error_watcher, ).await { Ok(session_id) => ok_response(session_id), Err(e) => err_response(e), } } // ── RDP interaction handlers ───────────────────────────────────────────────── #[derive(Deserialize)] struct RdpClickRequest { session_id: String, x: u16, y: u16, button: Option } #[derive(Deserialize)] struct RdpTypeRequest { session_id: String, text: String } #[derive(Deserialize)] struct RdpClipboardRequest { session_id: String, text: String } async fn handle_rdp_click(AxumState(state): AxumState>, Json(req): Json) -> Json> { use crate::rdp::input::mouse_flags; let button_flag = match req.button.as_deref().unwrap_or("left") { "right" => mouse_flags::BUTTON2, "middle" => mouse_flags::BUTTON3, _ => mouse_flags::BUTTON1, }; // Move to position if let Err(e) = state.rdp.send_mouse(&req.session_id, req.x, req.y, mouse_flags::MOVE) { return err_response(e); } // Click down if let Err(e) = state.rdp.send_mouse(&req.session_id, req.x, req.y, button_flag | mouse_flags::DOWN) { return err_response(e); } // Click up if let Err(e) = state.rdp.send_mouse(&req.session_id, req.x, req.y, button_flag) { return err_response(e); } ok_response(format!("clicked ({}, {})", req.x, req.y)) } async fn handle_rdp_type(AxumState(state): AxumState>, Json(req): Json) -> Json> { // Set clipboard then simulate Ctrl+V to paste (most reliable for arbitrary text) if let Err(e) = state.rdp.send_clipboard(&req.session_id, &req.text) { return err_response(e); } // Small delay for clipboard to propagate, then Ctrl+V tokio::time::sleep(std::time::Duration::from_millis(50)).await; // Ctrl down let _ = state.rdp.send_key(&req.session_id, 0x001D, true); // V down let _ = state.rdp.send_key(&req.session_id, 0x002F, true); // V up let _ = state.rdp.send_key(&req.session_id, 0x002F, false); // Ctrl up let _ = state.rdp.send_key(&req.session_id, 0x001D, false); ok_response(format!("typed {} chars via clipboard paste", req.text.len())) } async fn handle_rdp_clipboard(AxumState(state): AxumState>, Json(req): Json) -> Json> { if let Err(e) = state.rdp.send_clipboard(&req.session_id, &req.text) { return err_response(e); } ok_response("clipboard set".to_string()) } /// Start the MCP HTTP server and write the port to disk. pub async fn start_mcp_server( ssh: SshService, rdp: RdpService, sftp: SftpService, scrollback: ScrollbackRegistry, app_handle: tauri::AppHandle, error_watcher: std::sync::Arc, ) -> Result { let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher }); let app = Router::new() .route("/mcp/sessions", post(handle_list_sessions)) .route("/mcp/terminal/type", post(handle_terminal_type)) .route("/mcp/terminal/read", post(handle_terminal_read)) .route("/mcp/terminal/execute", post(handle_terminal_execute)) .route("/mcp/screenshot", post(handle_screenshot)) .route("/mcp/sftp/list", post(handle_sftp_list)) .route("/mcp/sftp/read", post(handle_sftp_read)) .route("/mcp/sftp/write", post(handle_sftp_write)) .route("/mcp/tool/ping", post(handle_tool_ping)) .route("/mcp/tool/traceroute", post(handle_tool_traceroute)) .route("/mcp/tool/dns", post(handle_tool_dns)) .route("/mcp/tool/whois", post(handle_tool_whois)) .route("/mcp/tool/wol", post(handle_tool_wol)) .route("/mcp/tool/scan-network", post(handle_tool_scan_network)) .route("/mcp/tool/scan-ports", post(handle_tool_scan_ports)) .route("/mcp/tool/subnet", post(handle_tool_subnet)) .route("/mcp/tool/bandwidth", post(handle_tool_bandwidth)) .route("/mcp/tool/keygen", post(handle_tool_keygen)) .route("/mcp/tool/passgen", post(handle_tool_passgen)) .route("/mcp/docker/ps", post(handle_docker_ps)) .route("/mcp/docker/action", post(handle_docker_action)) .route("/mcp/docker/exec", post(handle_docker_exec)) .route("/mcp/service/status", post(handle_service_status)) .route("/mcp/process/list", post(handle_process_list)) .route("/mcp/git/status", post(handle_git_status)) .route("/mcp/git/pull", post(handle_git_pull)) .route("/mcp/git/log", post(handle_git_log)) .route("/mcp/rdp/click", post(handle_rdp_click)) .route("/mcp/rdp/type", post(handle_rdp_type)) .route("/mcp/rdp/clipboard", post(handle_rdp_clipboard)) .route("/mcp/ssh/connect", post(handle_ssh_connect)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await .map_err(|e| format!("Failed to bind MCP server: {}", e))?; let port = listener.local_addr() .map_err(|e| format!("Failed to get MCP server port: {}", e))? .port(); // Write port to well-known location let port_file = crate::data_directory().join("mcp-port"); std::fs::write(&port_file, port.to_string()) .map_err(|e| format!("Failed to write MCP port file: {}", e))?; tokio::spawn(async move { axum::serve(listener, app).await.ok(); }); Ok(port) }