diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index b37a5de..f42e7de 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -356,6 +356,58 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base16ct" version = "0.2.0" @@ -2596,6 +2648,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hybrid-array" version = "0.4.8" @@ -2621,6 +2679,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -3495,6 +3554,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -5974,6 +6039,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_repr" version = "0.1.20" @@ -7402,6 +7478,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -7591,6 +7668,35 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf8-zero", + "webpki-roots", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -7622,6 +7728,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -8768,6 +8880,7 @@ dependencies = [ "anyhow", "argon2", "async-trait", + "axum", "base64 0.22.1", "block-padding 0.3.3", "cbc 0.1.2", @@ -8799,6 +8912,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-rustls", + "ureq", "uuid", "x509-cert", ] diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index b56a310..1f03188 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -7,6 +7,10 @@ edition = "2024" name = "wraith_lib" crate-type = ["lib", "cdylib", "staticlib"] +[[bin]] +name = "wraith-mcp-bridge" +path = "src/bin/wraith_mcp_bridge.rs" + [build-dependencies] tauri-build = { version = "2", features = [] } @@ -48,6 +52,10 @@ sec1 = { version = "0.7", features = ["pem"] } # Local PTY for AI copilot panel portable-pty = "0.8" +# MCP HTTP server (for bridge binary communication) +axum = "0.8" +ureq = "3" + # RDP (IronRDP) ironrdp = { version = "0.14", features = ["connector", "session", "graphics", "input"] } ironrdp-tokio = { version = "0.8", features = ["reqwest-rustls-ring"] } diff --git a/src-tauri/src/bin/wraith_mcp_bridge.rs b/src-tauri/src/bin/wraith_mcp_bridge.rs new file mode 100644 index 0000000..7c2f041 --- /dev/null +++ b/src-tauri/src/bin/wraith_mcp_bridge.rs @@ -0,0 +1,245 @@ +//! Wraith MCP Bridge — stdio JSON-RPC proxy to Wraith's HTTP API. +//! +//! This binary is spawned by AI CLIs (Claude Code, Gemini CLI) as an MCP +//! server. It reads JSON-RPC requests from stdin, translates them to HTTP +//! calls against the running Wraith instance, and writes responses to stdout. +//! +//! The Wraith instance's MCP HTTP port is read from the data directory's +//! `mcp-port` file. + +use std::io::{self, BufRead, Write}; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Deserialize)] +#[allow(dead_code)] +struct JsonRpcRequest { + jsonrpc: String, + id: Value, + method: String, + #[serde(default)] + params: Value, +} + +#[derive(Serialize)] +struct JsonRpcResponse { + jsonrpc: String, + id: Value, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(Serialize)] +struct JsonRpcError { + code: i32, + message: String, +} + +fn get_mcp_port() -> Result { + // Check standard locations for the port file + let port_file = if let Ok(appdata) = std::env::var("APPDATA") { + std::path::PathBuf::from(appdata).join("Wraith").join("mcp-port") + } else if let Ok(home) = std::env::var("HOME") { + if cfg!(target_os = "macos") { + std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith").join("mcp-port") + } else { + std::path::PathBuf::from(home).join(".local").join("share").join("wraith").join("mcp-port") + } + } else { + return Err("Cannot determine data directory".to_string()); + }; + + let port_str = std::fs::read_to_string(&port_file) + .map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?; + + port_str.trim().parse::() + .map_err(|e| format!("Invalid port in MCP port file: {}", e)) +} + +fn handle_initialize(id: Value) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: Some(serde_json::json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": { + "name": "wraith-terminal", + "version": "1.0.0" + } + })), + error: None, + } +} + +fn handle_tools_list(id: Value) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: Some(serde_json::json!({ + "tools": [ + { + "name": "terminal_read", + "description": "Read recent terminal output from an active SSH or PTY session (ANSI codes stripped)", + "inputSchema": { + "type": "object", + "properties": { + "session_id": { "type": "string", "description": "The session ID to read from. Use list_sessions to find IDs." }, + "lines": { "type": "number", "description": "Number of recent lines to return (default: 50)" } + }, + "required": ["session_id"] + } + }, + { + "name": "terminal_execute", + "description": "Execute a command in an active SSH session and return the output", + "inputSchema": { + "type": "object", + "properties": { + "session_id": { "type": "string", "description": "The SSH session ID to execute in" }, + "command": { "type": "string", "description": "The command to run" }, + "timeout_ms": { "type": "number", "description": "Max wait time in ms (default: 5000)" } + }, + "required": ["session_id", "command"] + } + }, + { + "name": "list_sessions", + "description": "List all active Wraith sessions (SSH, RDP, PTY) with connection details", + "inputSchema": { + "type": "object", + "properties": {} + } + } + ] + })), + error: None, + } +} + +fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result { + let url = format!("http://127.0.0.1:{}{}", port, endpoint); + let body_str = serde_json::to_string(&body).unwrap_or_default(); + + let mut resp = ureq::post(url) + .header("Content-Type", "application/json") + .send(body_str.as_bytes()) + .map_err(|e| format!("HTTP request to Wraith failed: {}", e))?; + + let resp_str = resp.body_mut().read_to_string() + .map_err(|e| format!("Failed to read Wraith response: {}", e))?; + + let json: Value = serde_json::from_str(&resp_str) + .map_err(|e| format!("Failed to parse Wraith response: {}", e))?; + + if json.get("ok").and_then(|v| v.as_bool()) == Some(true) { + Ok(json.get("data").cloned().unwrap_or(Value::Null)) + } else { + let err_msg = json.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error"); + Err(err_msg.to_string()) + } +} + +fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> JsonRpcResponse { + let result = match tool_name { + "list_sessions" => call_wraith(port, "/mcp/sessions", serde_json::json!({})), + "terminal_read" => call_wraith(port, "/mcp/terminal/read", args.clone()), + "terminal_execute" => call_wraith(port, "/mcp/terminal/execute", args.clone()), + _ => Err(format!("Unknown tool: {}", tool_name)), + }; + + match result { + Ok(data) => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: Some(serde_json::json!({ + "content": [{ + "type": "text", + "text": if data.is_string() { + data.as_str().unwrap().to_string() + } else { + serde_json::to_string_pretty(&data).unwrap_or_default() + } + }] + })), + error: None, + }, + Err(e) => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { code: -32000, message: e }), + }, + } +} + +fn main() { + let port = match get_mcp_port() { + Ok(p) => p, + Err(e) => { + eprintln!("wraith-mcp-bridge: {}", e); + std::process::exit(1); + } + }; + + let stdin = io::stdin(); + let mut stdout = io::stdout(); + + for line in stdin.lock().lines() { + let line = match line { + Ok(l) => l, + Err(_) => break, + }; + + if line.trim().is_empty() { + continue; + } + + let request: JsonRpcRequest = match serde_json::from_str(&line) { + Ok(r) => r, + Err(e) => { + let err_resp = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Value::Null, + result: None, + error: Some(JsonRpcError { code: -32700, message: format!("Parse error: {}", e) }), + }; + let _ = writeln!(stdout, "{}", serde_json::to_string(&err_resp).unwrap()); + let _ = stdout.flush(); + continue; + } + }; + + let response = match request.method.as_str() { + "initialize" => handle_initialize(request.id), + "tools/list" => handle_tools_list(request.id), + "tools/call" => { + let tool_name = request.params.get("name") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let args = request.params.get("arguments") + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + handle_tool_call(request.id, port, tool_name, &args) + } + "notifications/initialized" | "notifications/cancelled" => { + // Notifications don't get responses + continue; + } + _ => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { code: -32601, message: format!("Method not found: {}", request.method) }), + }, + }; + + let _ = writeln!(stdout, "{}", serde_json::to_string(&response).unwrap()); + let _ = stdout.flush(); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 508945a..c3e4b71 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -99,7 +99,21 @@ pub fn run() { window.open_devtools(); } } - let _ = app; + + // Start MCP HTTP server for bridge binary communication + { + use tauri::Manager; + let state = app.state::(); + let ssh_clone = state.ssh.clone(); + let scrollback_clone = state.scrollback.clone(); + tauri::async_runtime::spawn(async move { + match mcp::server::start_mcp_server(ssh_clone, scrollback_clone).await { + Ok(port) => log::info!("MCP server started on localhost:{}", port), + Err(e) => log::error!("Failed to start MCP server: {}", e), + } + }); + } + Ok(()) }) .invoke_handler(tauri::generate_handler![ diff --git a/src-tauri/src/mcp/mod.rs b/src-tauri/src/mcp/mod.rs index ad7e9d5..617d9c3 100644 --- a/src-tauri/src/mcp/mod.rs +++ b/src-tauri/src/mcp/mod.rs @@ -5,6 +5,7 @@ //! sessions. pub mod scrollback; +pub mod server; use std::sync::Arc; @@ -14,6 +15,7 @@ use crate::mcp::scrollback::ScrollbackBuffer; /// Registry of scrollback buffers keyed by session ID. /// Shared between SSH/PTY output loops (writers) and MCP tools (readers). +#[derive(Clone)] pub struct ScrollbackRegistry { buffers: DashMap>, } diff --git a/src-tauri/src/mcp/server.rs b/src-tauri/src/mcp/server.rs new file mode 100644 index 0000000..9427c2c --- /dev/null +++ b/src-tauri/src/mcp/server.rs @@ -0,0 +1,164 @@ +//! Tiny HTTP server for MCP bridge communication. +//! +//! Runs on localhost:0 (random port) at Tauri startup. The port is written +//! to ~/.wraith/mcp-port so the bridge binary can find it. + +use std::sync::Arc; + +use axum::{extract::State as AxumState, routing::post, Json, Router}; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpListener; + +use crate::mcp::ScrollbackRegistry; +use crate::ssh::session::SshService; + +/// Shared state passed to axum handlers. +pub struct McpServerState { + pub ssh: SshService, + pub scrollback: ScrollbackRegistry, +} + +#[derive(Deserialize)] +struct TerminalReadRequest { + session_id: String, + lines: Option, +} + +#[derive(Deserialize)] +struct TerminalExecuteRequest { + session_id: String, + command: String, + timeout_ms: Option, +} + +#[derive(Serialize)] +struct McpResponse { + ok: bool, + data: Option, + error: Option, +} + +fn ok_response(data: T) -> Json> { + Json(McpResponse { ok: true, data: Some(data), error: None }) +} + +fn err_response(msg: String) -> Json> { + Json(McpResponse { ok: false, data: None, error: Some(msg) }) +} + +async fn handle_list_sessions( + AxumState(state): AxumState>, +) -> Json>> { + let sessions: Vec = state.ssh.list_sessions() + .into_iter() + .map(|s| serde_json::json!({ + "id": s.id, + "type": "ssh", + "name": format!("{}@{}:{}", s.username, s.hostname, s.port), + "host": s.hostname, + "username": s.username, + })) + .collect(); + ok_response(sessions) +} + +async fn handle_terminal_read( + AxumState(state): AxumState>, + Json(req): Json, +) -> Json> { + let n = req.lines.unwrap_or(50); + match state.scrollback.get(&req.session_id) { + Some(buf) => ok_response(buf.read_lines(n)), + None => err_response(format!("No scrollback buffer for session {}", req.session_id)), + } +} + +async fn handle_terminal_execute( + AxumState(state): AxumState>, + Json(req): Json, +) -> Json> { + let timeout = req.timeout_ms.unwrap_or(5000); + let marker = "__WRAITH_MCP_DONE__"; + + let buf = match state.scrollback.get(&req.session_id) { + Some(b) => b, + None => return err_response(format!("No scrollback buffer for session {}", req.session_id)), + }; + + let before = buf.total_written(); + let full_cmd = format!("{}\necho {}\n", req.command, marker); + + if let Err(e) = state.ssh.write(&req.session_id, full_cmd.as_bytes()).await { + return err_response(e); + } + + let start = std::time::Instant::now(); + let timeout_dur = std::time::Duration::from_millis(timeout); + + loop { + if start.elapsed() > timeout_dur { + let raw = buf.read_raw(); + let total = buf.total_written(); + let new_bytes = total.saturating_sub(before); + let output = if new_bytes > 0 && raw.len() >= new_bytes { + &raw[raw.len() - new_bytes.min(raw.len())..] + } else { + "" + }; + return ok_response(format!("[timeout after {}ms]\n{}", timeout, output)); + } + + let raw = buf.read_raw(); + if raw.contains(marker) { + let total = buf.total_written(); + let new_bytes = total.saturating_sub(before); + let output = if new_bytes > 0 && raw.len() >= new_bytes { + raw[raw.len() - new_bytes.min(raw.len())..].to_string() + } else { + String::new() + }; + + let clean = output + .lines() + .filter(|line| !line.contains(marker)) + .collect::>() + .join("\n"); + + return ok_response(clean.trim().to_string()); + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } +} + +/// Start the MCP HTTP server and write the port to disk. +pub async fn start_mcp_server( + ssh: SshService, + scrollback: ScrollbackRegistry, +) -> Result { + let state = Arc::new(McpServerState { ssh, scrollback }); + + let app = Router::new() + .route("/mcp/sessions", post(handle_list_sessions)) + .route("/mcp/terminal/read", post(handle_terminal_read)) + .route("/mcp/terminal/execute", post(handle_terminal_execute)) + .with_state(state); + + let listener = TcpListener::bind("127.0.0.1:0").await + .map_err(|e| format!("Failed to bind MCP server: {}", e))?; + + let port = listener.local_addr() + .map_err(|e| format!("Failed to get MCP server port: {}", e))? + .port(); + + // Write port to well-known location + let port_file = crate::data_directory().join("mcp-port"); + std::fs::write(&port_file, port.to_string()) + .map_err(|e| format!("Failed to write MCP port file: {}", e))?; + + tokio::spawn(async move { + axum::serve(listener, app).await.ok(); + }); + + Ok(port) +} diff --git a/src-tauri/src/pty/mod.rs b/src-tauri/src/pty/mod.rs index 416afd1..83769f1 100644 --- a/src-tauri/src/pty/mod.rs +++ b/src-tauri/src/pty/mod.rs @@ -91,7 +91,10 @@ impl PtyService { .openpty(PtySize { rows, cols, pixel_width: 0, pixel_height: 0 }) .map_err(|e| format!("Failed to open PTY: {}", e))?; - let cmd = CommandBuilder::new(shell_path); + let mut cmd = CommandBuilder::new(shell_path); + + // Auto-inject MCP server config so AI CLIs discover the bridge + cmd.env("WRAITH_MCP_BRIDGE", "wraith-mcp-bridge"); let child = pair.slave .spawn_command(cmd) diff --git a/src-tauri/src/ssh/session.rs b/src-tauri/src/ssh/session.rs index 543c483..9d68628 100644 --- a/src-tauri/src/ssh/session.rs +++ b/src-tauri/src/ssh/session.rs @@ -73,6 +73,7 @@ impl client::Handler for SshClient { } } +#[derive(Clone)] pub struct SshService { sessions: DashMap>, db: Database,