diff --git a/src-tauri/src/bin/wraith_mcp_bridge.rs b/src-tauri/src/bin/wraith_mcp_bridge.rs index 6ed4bc5..48afee8 100644 --- a/src-tauri/src/bin/wraith_mcp_bridge.rs +++ b/src-tauri/src/bin/wraith_mcp_bridge.rs @@ -156,6 +156,61 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse { "required": ["session_id", "path", "content"] } }, + { + "name": "network_scan", + "description": "Discover all devices on a remote network subnet via ARP + ping sweep", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "subnet": { "type": "string", "description": "First 3 octets, e.g. 192.168.1" } }, "required": ["session_id", "subnet"] } + }, + { + "name": "port_scan", + "description": "Scan TCP ports on a target host through an SSH session", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "target": { "type": "string" }, "ports": { "type": "array", "items": { "type": "number" }, "description": "Specific ports. Omit for quick scan of 24 common ports." } }, "required": ["session_id", "target"] } + }, + { + "name": "ping", + "description": "Ping a host through an SSH session", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "target": { "type": "string" } }, "required": ["session_id", "target"] } + }, + { + "name": "traceroute", + "description": "Traceroute to a host through an SSH session", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "target": { "type": "string" } }, "required": ["session_id", "target"] } + }, + { + "name": "dns_lookup", + "description": "DNS lookup for a domain through an SSH session", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "domain": { "type": "string" }, "record_type": { "type": "string", "description": "A, AAAA, MX, NS, TXT, CNAME, SOA, SRV, PTR" } }, "required": ["session_id", "domain"] } + }, + { + "name": "whois", + "description": "Whois lookup for a domain or IP through an SSH session", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "target": { "type": "string" } }, "required": ["session_id", "target"] } + }, + { + "name": "wake_on_lan", + "description": "Send Wake-on-LAN magic packet through an SSH session to wake a device", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" }, "mac_address": { "type": "string", "description": "MAC address (AA:BB:CC:DD:EE:FF)" } }, "required": ["session_id", "mac_address"] } + }, + { + "name": "bandwidth_test", + "description": "Run an internet speed test on a remote host through SSH", + "inputSchema": { "type": "object", "properties": { "session_id": { "type": "string" } }, "required": ["session_id"] } + }, + { + "name": "subnet_calc", + "description": "Calculate subnet details from CIDR notation (no SSH needed)", + "inputSchema": { "type": "object", "properties": { "cidr": { "type": "string", "description": "e.g. 192.168.1.0/24" } }, "required": ["cidr"] } + }, + { + "name": "generate_ssh_key", + "description": "Generate an SSH key pair (ed25519 or RSA)", + "inputSchema": { "type": "object", "properties": { "key_type": { "type": "string", "description": "ed25519 or rsa" }, "comment": { "type": "string" } }, "required": ["key_type"] } + }, + { + "name": "generate_password", + "description": "Generate a cryptographically secure random password", + "inputSchema": { "type": "object", "properties": { "length": { "type": "number" }, "uppercase": { "type": "boolean" }, "lowercase": { "type": "boolean" }, "digits": { "type": "boolean" }, "symbols": { "type": "boolean" } } } + }, { "name": "list_sessions", "description": "List all active Wraith sessions (SSH, RDP, PTY) with connection details", @@ -201,6 +256,17 @@ fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> Json "sftp_list" => call_wraith(port, "/mcp/sftp/list", args.clone()), "sftp_read" => call_wraith(port, "/mcp/sftp/read", args.clone()), "sftp_write" => call_wraith(port, "/mcp/sftp/write", args.clone()), + "network_scan" => call_wraith(port, "/mcp/tool/scan-network", args.clone()), + "port_scan" => call_wraith(port, "/mcp/tool/scan-ports", args.clone()), + "ping" => call_wraith(port, "/mcp/tool/ping", args.clone()), + "traceroute" => call_wraith(port, "/mcp/tool/traceroute", args.clone()), + "dns_lookup" => call_wraith(port, "/mcp/tool/dns", args.clone()), + "whois" => call_wraith(port, "/mcp/tool/whois", args.clone()), + "wake_on_lan" => call_wraith(port, "/mcp/tool/wol", args.clone()), + "bandwidth_test" => call_wraith(port, "/mcp/tool/bandwidth", args.clone()), + "subnet_calc" => call_wraith(port, "/mcp/tool/subnet", args.clone()), + "generate_ssh_key" => call_wraith(port, "/mcp/tool/keygen", args.clone()), + "generate_password" => call_wraith(port, "/mcp/tool/passgen", args.clone()), "terminal_screenshot" => { let result = call_wraith(port, "/mcp/screenshot", args.clone()); // Screenshot returns base64 PNG — wrap as image content for multimodal AI diff --git a/src-tauri/src/commands/tools_commands.rs b/src-tauri/src/commands/tools_commands.rs index c1bebda..531beca 100644 --- a/src-tauri/src/commands/tools_commands.rs +++ b/src-tauri/src/commands/tools_commands.rs @@ -94,6 +94,13 @@ pub struct GeneratedKey { pub fn tool_generate_ssh_key( key_type: String, comment: Option, +) -> Result { + tool_generate_ssh_key_inner(&key_type, comment) +} + +pub fn tool_generate_ssh_key_inner( + key_type: &str, + comment: Option, ) -> Result { use ssh_key::{Algorithm, HashAlg, LineEnding}; @@ -136,6 +143,16 @@ pub fn tool_generate_password( lowercase: Option, digits: Option, symbols: Option, +) -> Result { + tool_generate_password_inner(length, uppercase, lowercase, digits, symbols) +} + +pub fn tool_generate_password_inner( + length: Option, + uppercase: Option, + lowercase: Option, + digits: Option, + symbols: Option, ) -> Result { use rand::Rng; diff --git a/src-tauri/src/commands/tools_commands_r2.rs b/src-tauri/src/commands/tools_commands_r2.rs index 5438525..d1307f1 100644 --- a/src-tauri/src/commands/tools_commands_r2.rs +++ b/src-tauri/src/commands/tools_commands_r2.rs @@ -103,9 +103,12 @@ pub struct SubnetInfo { /// Pure Rust subnet calculator — no SSH session needed. #[tauri::command] -pub fn tool_subnet_calc( - cidr: String, -) -> Result { +pub fn tool_subnet_calc(cidr: String) -> Result { + tool_subnet_calc_inner(&cidr) +} + +pub fn tool_subnet_calc_inner(cidr: &str) -> Result { + let cidr = cidr.to_string(); let parts: Vec<&str> = cidr.split('/').collect(); if parts.len() != 2 { return Err("Expected CIDR notation: e.g. 192.168.1.0/24".to_string()); diff --git a/src-tauri/src/mcp/server.rs b/src-tauri/src/mcp/server.rs index a0dce4e..334304d 100644 --- a/src-tauri/src/mcp/server.rs +++ b/src-tauri/src/mcp/server.rs @@ -223,6 +223,122 @@ async fn handle_terminal_execute( } } +// ── Tool handlers (all tools exposed to AI via MCP) ────────────────────────── + +#[derive(Deserialize)] +struct ToolSessionTarget { session_id: String, target: String } + +#[derive(Deserialize)] +struct ToolSessionOnly { session_id: String } + +#[derive(Deserialize)] +struct ToolDnsRequest { session_id: String, domain: String, record_type: Option } + +#[derive(Deserialize)] +struct ToolWolRequest { session_id: String, mac_address: String } + +#[derive(Deserialize)] +struct ToolScanNetworkRequest { session_id: String, subnet: String } + +#[derive(Deserialize)] +struct ToolScanPortsRequest { session_id: String, target: String, ports: Option> } + +#[derive(Deserialize)] +struct ToolSubnetRequest { cidr: String } + +#[derive(Deserialize)] +struct ToolKeygenRequest { key_type: String, comment: Option } + +#[derive(Deserialize)] +struct ToolPassgenRequest { length: Option, uppercase: Option, lowercase: Option, digits: Option, symbols: Option } + +async fn handle_tool_ping(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_traceroute(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_dns(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + let rt = req.record_type.unwrap_or_else(|| "A".to_string()); + match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_whois(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_wol(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + let mac_clean = req.mac_address.replace([':', '-'], ""); + let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address); + match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_scan_network(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + match crate::scanner::scan_network(&session.handle, &req.subnet).await { + Ok(hosts) => ok_response(serde_json::to_value(hosts).unwrap_or_default()), + Err(e) => err_response(e), + } +} + +async fn handle_tool_scan_ports(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + let result = if let Some(ports) = req.ports { + crate::scanner::scan_ports(&session.handle, &req.target, &ports).await + } else { + crate::scanner::quick_port_scan(&session.handle, &req.target).await + }; + match result { Ok(r) => ok_response(serde_json::to_value(r).unwrap_or_default()), Err(e) => err_response(e) } +} + +async fn handle_tool_subnet(_state: AxumState>, Json(req): Json) -> Json> { + match crate::commands::tools_commands_r2::tool_subnet_calc_inner(&req.cidr) { + Ok(info) => ok_response(serde_json::to_value(info).unwrap_or_default()), + Err(e) => err_response(e), + } +} + +async fn handle_tool_bandwidth(AxumState(state): AxumState>, Json(req): Json) -> Json> { + let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; + let cmd = r#"if command -v speedtest-cli >/dev/null 2>&1; then speedtest-cli --simple 2>&1; elif command -v curl >/dev/null 2>&1; then curl -o /dev/null -w "Download: %{speed_download} bytes/sec\n" https://speed.cloudflare.com/__down?bytes=25000000 2>/dev/null; else echo "No speedtest tool found"; fi"#; + match tool_exec(&session.handle, cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } +} + +async fn handle_tool_keygen(_state: AxumState>, Json(req): Json) -> Json> { + match crate::commands::tools_commands::tool_generate_ssh_key_inner(&req.key_type, req.comment) { + Ok(key) => ok_response(serde_json::to_value(key).unwrap_or_default()), + Err(e) => err_response(e), + } +} + +async fn handle_tool_passgen(_state: AxumState>, Json(req): Json) -> Json> { + match crate::commands::tools_commands::tool_generate_password_inner(req.length, req.uppercase, req.lowercase, req.digits, req.symbols) { + Ok(pw) => ok_response(pw), + Err(e) => err_response(e), + } +} + +async fn tool_exec(handle: &std::sync::Arc>>, cmd: &str) -> Result { + let mut channel = { let h = handle.lock().await; h.channel_open_session().await.map_err(|e| format!("Exec failed: {}", e))? }; + channel.exec(true, cmd).await.map_err(|e| format!("Exec failed: {}", e))?; + let mut output = String::new(); + loop { + match channel.wait().await { + Some(russh::ChannelMsg::Data { ref data }) => { if let Ok(t) = std::str::from_utf8(data.as_ref()) { output.push_str(t); } } + Some(russh::ChannelMsg::Eof) | Some(russh::ChannelMsg::Close) | None => break, + _ => {} + } + } + Ok(output) +} + /// Start the MCP HTTP server and write the port to disk. pub async fn start_mcp_server( ssh: SshService, @@ -240,6 +356,17 @@ pub async fn start_mcp_server( .route("/mcp/sftp/list", post(handle_sftp_list)) .route("/mcp/sftp/read", post(handle_sftp_read)) .route("/mcp/sftp/write", post(handle_sftp_write)) + .route("/mcp/tool/ping", post(handle_tool_ping)) + .route("/mcp/tool/traceroute", post(handle_tool_traceroute)) + .route("/mcp/tool/dns", post(handle_tool_dns)) + .route("/mcp/tool/whois", post(handle_tool_whois)) + .route("/mcp/tool/wol", post(handle_tool_wol)) + .route("/mcp/tool/scan-network", post(handle_tool_scan_network)) + .route("/mcp/tool/scan-ports", post(handle_tool_scan_ports)) + .route("/mcp/tool/subnet", post(handle_tool_subnet)) + .route("/mcp/tool/bandwidth", post(handle_tool_bandwidth)) + .route("/mcp/tool/keygen", post(handle_tool_keygen)) + .route("/mcp/tool/passgen", post(handle_tool_passgen)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await