diff --git a/src-tauri/src/bin/wraith_mcp_bridge.rs b/src-tauri/src/bin/wraith_mcp_bridge.rs index 908b77b..68e9423 100644 --- a/src-tauri/src/bin/wraith_mcp_bridge.rs +++ b/src-tauri/src/bin/wraith_mcp_bridge.rs @@ -279,6 +279,17 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse { "description": "Set the clipboard content on a remote RDP session", "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "text": { "type": "string" } }, "required": ["session_id", "text"] } }, + { + "name": "ssh_connect", + "description": "Open a new SSH connection through Wraith. Returns the session ID for use with other tools.", + "inputSchema": { "type": "object", "properties": { + "hostname": { "type": "string" }, + "port": { "type": "number", "description": "Default: 22" }, + "username": { "type": "string" }, + "password": { "type": "string", "description": "Password (for password auth)" }, + "private_key_path": { "type": "string", "description": "Path to SSH private key file on the local machine" } + }, "required": ["hostname", "username"] } + }, { "name": "list_sessions", "description": "List all active Wraith sessions (SSH, RDP, PTY) with connection details", @@ -347,6 +358,7 @@ fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> Json "rdp_click" => call_wraith(port, "/mcp/rdp/click", args.clone()), "rdp_type" => call_wraith(port, "/mcp/rdp/type", args.clone()), "rdp_clipboard" => call_wraith(port, "/mcp/rdp/clipboard", args.clone()), + "ssh_connect" => call_wraith(port, "/mcp/ssh/connect", args.clone()), "terminal_screenshot" => { let result = call_wraith(port, "/mcp/screenshot", args.clone()); // Screenshot returns base64 PNG — wrap as image content for multimodal AI diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 206a4e7..21a9a59 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -155,7 +155,9 @@ pub fn run() { let _ = write_log(&log_file, "Setup: cloned services OK"); // Error watcher — std::thread, no tokio needed + let watcher_for_mcp = watcher.clone(); let app_handle = app.handle().clone(); + let app_handle_for_mcp = app.handle().clone(); let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { mcp::error_watcher::start_error_watcher(watcher, scrollback.clone(), app_handle); })); @@ -165,7 +167,7 @@ pub fn run() { let log_file2 = log_file.clone(); let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { tauri::async_runtime::spawn(async move { - match mcp::server::start_mcp_server(ssh, rdp, sftp, scrollback).await { + match mcp::server::start_mcp_server(ssh, rdp, sftp, scrollback, app_handle_for_mcp, watcher_for_mcp).await { Ok(port) => { let _ = write_log(&log_file2, &format!("MCP server started on localhost:{}", port)); } Err(e) => { let _ = write_log(&log_file2, &format!("MCP server FAILED: {}", e)); } } diff --git a/src-tauri/src/mcp/server.rs b/src-tauri/src/mcp/server.rs index 5b37cc2..4ece5f7 100644 --- a/src-tauri/src/mcp/server.rs +++ b/src-tauri/src/mcp/server.rs @@ -20,6 +20,8 @@ pub struct McpServerState { pub rdp: RdpService, pub sftp: SftpService, pub scrollback: ScrollbackRegistry, + pub app_handle: tauri::AppHandle, + pub error_watcher: std::sync::Arc, } #[derive(Deserialize)] @@ -432,6 +434,48 @@ async fn handle_git_log(AxumState(state): AxumState>, Json(r match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } } +// ── Session creation handlers ──────────────────────────────────────────────── + +#[derive(Deserialize)] +struct SshConnectRequest { + hostname: String, + port: Option, + username: String, + password: Option, + private_key_path: Option, +} + +async fn handle_ssh_connect(AxumState(state): AxumState>, Json(req): Json) -> Json> { + use crate::ssh::session::AuthMethod; + + let port = req.port.unwrap_or(22); + let auth = if let Some(key_path) = req.private_key_path { + // Read key file + let pem = match std::fs::read_to_string(&key_path) { + Ok(p) => p, + Err(e) => return err_response(format!("Failed to read key file {}: {}", key_path, e)), + }; + AuthMethod::Key { private_key_pem: pem, passphrase: req.password } + } else { + AuthMethod::Password(req.password.unwrap_or_default()) + }; + + match state.ssh.connect( + state.app_handle.clone(), + &req.hostname, + port, + &req.username, + auth, + 120, 40, + &state.sftp, + &state.scrollback, + &state.error_watcher, + ).await { + Ok(session_id) => ok_response(session_id), + Err(e) => err_response(e), + } +} + // ── RDP interaction handlers ───────────────────────────────────────────────── #[derive(Deserialize)] @@ -476,8 +520,10 @@ pub async fn start_mcp_server( rdp: RdpService, sftp: SftpService, scrollback: ScrollbackRegistry, + app_handle: tauri::AppHandle, + error_watcher: std::sync::Arc, ) -> Result { - let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback }); + let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher }); let app = Router::new() .route("/mcp/sessions", post(handle_list_sessions)) @@ -510,6 +556,7 @@ pub async fn start_mcp_server( .route("/mcp/rdp/click", post(handle_rdp_click)) .route("/mcp/rdp/type", post(handle_rdp_type)) .route("/mcp/rdp/clipboard", post(handle_rdp_clipboard)) + .route("/mcp/ssh/connect", post(handle_ssh_connect)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await