wraith/src-tauri/src/mcp/server.rs
Vantz Stockwell bc608b0683
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Failing after 15s
feat: copilot QoL — resizable panel, SFTP tools, context, error watcher
Resizable panel:
- Drag handle on left border of copilot panel
- Pointer events for smooth resize (320px–1200px range)

SFTP MCP tools:
- sftp_list: list remote directories
- sftp_read: read remote files
- sftp_write: write remote files
- Full HTTP endpoints + bridge tool definitions

Active session context:
- mcp_get_session_context command returns last 20 lines of scrollback
- Frontend can call on tab switch to keep AI informed

Error watcher:
- Background scanner runs every 2 seconds across all sessions
- 20+ error patterns (permission denied, OOM, segfault, disk full, etc.)
- Emits mcp:error events to frontend with session ID and matched line
- Sessions auto-registered with watcher on connect

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 23:30:12 -04:00

263 lines
7.7 KiB
Rust

//! Tiny HTTP server for MCP bridge communication.
//!
//! Binds to `127.0.0.1:0` (an OS-assigned random port) at Tauri startup. The
//! chosen port is written to the `mcp-port` file in the application data
//! directory (`crate::data_directory()`) so the bridge binary can find it.
use std::sync::Arc;
use axum::{extract::State as AxumState, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use crate::mcp::ScrollbackRegistry;
use crate::rdp::RdpService;
use crate::sftp::SftpService;
use crate::ssh::session::SshService;
/// State shared by every axum handler in this module.
///
/// `start_mcp_server` wraps one instance in an `Arc` and installs it as the
/// router state, so handlers receive `AxumState<Arc<McpServerState>>`.
pub struct McpServerState {
    /// SSH sessions: used for session listing and for writing command input.
    pub ssh: SshService,
    /// RDP sessions: used for session listing and base64 PNG screenshots.
    pub rdp: RdpService,
    /// SFTP operations (list / read / write), addressed by session id.
    pub sftp: SftpService,
    /// Per-session terminal scrollback read by the terminal handlers.
    pub scrollback: ScrollbackRegistry,
}
/// JSON body for `POST /mcp/terminal/read`.
#[derive(Deserialize)]
struct TerminalReadRequest {
    /// Session whose scrollback should be read.
    session_id: String,
    /// Number of trailing lines to return; the handler uses 50 when absent.
    lines: Option<usize>,
}

/// JSON body for `POST /mcp/screenshot`.
#[derive(Deserialize)]
struct ScreenshotRequest {
    /// RDP session to capture.
    session_id: String,
}

/// JSON body for `POST /mcp/sftp/list`.
#[derive(Deserialize)]
struct SftpListRequest {
    /// Session providing the SFTP channel.
    session_id: String,
    /// Remote directory to list.
    path: String,
}

/// JSON body for `POST /mcp/sftp/read`.
#[derive(Deserialize)]
struct SftpReadRequest {
    /// Session providing the SFTP channel.
    session_id: String,
    /// Remote file to read.
    path: String,
}

/// JSON body for `POST /mcp/sftp/write`.
#[derive(Deserialize)]
struct SftpWriteRequest {
    /// Session providing the SFTP channel.
    session_id: String,
    /// Remote file to write.
    path: String,
    /// Text to store at `path`.
    content: String,
}

/// JSON body for `POST /mcp/terminal/execute`.
#[derive(Deserialize)]
struct TerminalExecuteRequest {
    /// SSH session to run the command in.
    session_id: String,
    /// Command line sent to the remote shell.
    command: String,
    /// How long to wait for completion; the handler uses 5000 ms when absent.
    timeout_ms: Option<u64>,
}
/// JSON envelope returned by every MCP endpoint.
///
/// Success responses carry `ok: true` and `data`; failures carry `ok: false`
/// and a human-readable `error`. The absent side is serialized as `null`.
#[derive(Serialize)]
struct McpResponse<T: Serialize> {
    /// Whether the request succeeded.
    ok: bool,
    /// Payload on success.
    data: Option<T>,
    /// Error message on failure.
    error: Option<String>,
}
/// Wraps `data` in a successful `McpResponse` (`ok: true`, no error).
fn ok_response<T: Serialize>(data: T) -> Json<McpResponse<T>> {
    let envelope = McpResponse {
        ok: true,
        data: Some(data),
        error: None,
    };
    Json(envelope)
}
/// Wraps `msg` in a failed `McpResponse` (`ok: false`, no data).
fn err_response<T: Serialize>(msg: String) -> Json<McpResponse<T>> {
    let envelope = McpResponse {
        ok: false,
        data: None,
        error: Some(msg),
    };
    Json(envelope)
}
/// `POST /mcp/sessions` — list every open SSH and RDP session.
///
/// SSH sessions come first and are reported as
/// `{id, type: "ssh", name: "user@host:port", host, username}`. RDP sessions follow as
/// `{id, type: "rdp", name: host, host, width, height}`.
async fn handle_list_sessions(
    AxumState(state): AxumState<Arc<McpServerState>>,
) -> Json<McpResponse<Vec<serde_json::Value>>> {
    let ssh_entries = state.ssh.list_sessions().into_iter().map(|session| {
        let label = format!("{}@{}:{}", session.username, session.hostname, session.port);
        serde_json::json!({
            "id": session.id,
            "type": "ssh",
            "name": label,
            "host": session.hostname,
            "username": session.username,
        })
    });

    let rdp_entries = state.rdp.list_sessions().into_iter().map(|session| {
        serde_json::json!({
            "id": session.id,
            "type": "rdp",
            "name": session.hostname.clone(),
            "host": session.hostname,
            "width": session.width,
            "height": session.height,
        })
    });

    let sessions: Vec<serde_json::Value> = ssh_entries.chain(rdp_entries).collect();
    ok_response(sessions)
}
/// `POST /mcp/sftp/list` — list the entries of a remote directory over SFTP.
///
/// Each entry is reported as `{name, path, size, is_dir, modified}`, where
/// `modified` carries the entry's `mod_time`. An SFTP error is passed through
/// unchanged as the response's error message.
async fn handle_sftp_list(
    AxumState(state): AxumState<Arc<McpServerState>>,
    Json(req): Json<SftpListRequest>,
) -> Json<McpResponse<Vec<serde_json::Value>>> {
    let entries = match state.sftp.list(&req.session_id, &req.path).await {
        Ok(found) => found,
        Err(message) => return err_response(message),
    };

    let mut listing: Vec<serde_json::Value> = Vec::new();
    for entry in entries {
        listing.push(serde_json::json!({
            "name": entry.name,
            "path": entry.path,
            "size": entry.size,
            "is_dir": entry.is_dir,
            "modified": entry.mod_time,
        }));
    }
    ok_response(listing)
}
async fn handle_sftp_read(
AxumState(state): AxumState<Arc<McpServerState>>,
Json(req): Json<SftpReadRequest>,
) -> Json<McpResponse<String>> {
match state.sftp.read_file(&req.session_id, &req.path).await {
Ok(content) => ok_response(content),
Err(e) => err_response(e),
}
}
/// `POST /mcp/sftp/write` — write `content` to a remote file.
///
/// Responds with the data string `"OK"` on success; an SFTP error is passed
/// through unchanged as the response's error message.
async fn handle_sftp_write(
    AxumState(state): AxumState<Arc<McpServerState>>,
    Json(req): Json<SftpWriteRequest>,
) -> Json<McpResponse<String>> {
    let written = state
        .sftp
        .write_file(&req.session_id, &req.path, &req.content)
        .await;
    if let Err(reason) = written {
        return err_response(reason);
    }
    ok_response(String::from("OK"))
}
/// `POST /mcp/screenshot` — capture an RDP session as a base64-encoded PNG.
///
/// A capture error is passed through unchanged as the response's error message.
async fn handle_screenshot(
    AxumState(state): AxumState<Arc<McpServerState>>,
    Json(req): Json<ScreenshotRequest>,
) -> Json<McpResponse<String>> {
    let capture = state.rdp.screenshot_png_base64(&req.session_id).await;
    match capture {
        Err(reason) => err_response(reason),
        Ok(png_base64) => ok_response(png_base64),
    }
}
/// `POST /mcp/terminal/read` — return the last `lines` lines (default 50) of a
/// session's scrollback.
///
/// Fails when no scrollback buffer is registered for the session.
async fn handle_terminal_read(
    AxumState(state): AxumState<Arc<McpServerState>>,
    Json(req): Json<TerminalReadRequest>,
) -> Json<McpResponse<String>> {
    let line_count = req.lines.unwrap_or(50);
    let Some(buffer) = state.scrollback.get(&req.session_id) else {
        return err_response(format!("No scrollback buffer for session {}", req.session_id));
    };
    ok_response(buffer.read_lines(line_count))
}
/// Returns the last `new_bytes` bytes of `raw`, or all of `raw` when it holds
/// fewer than that (older data has already been discarded).
///
/// The start index is moved forward to the next UTF-8 character boundary, so
/// slicing can never panic on multi-byte text.
fn output_since(raw: &str, new_bytes: usize) -> &str {
    let mut start = raw.len().saturating_sub(new_bytes);
    // `raw.len()` is always a char boundary, so this loop terminates.
    while !raw.is_char_boundary(start) {
        start += 1;
    }
    &raw[start..]
}

/// True when `line` consists solely of `marker`.
///
/// Trailing whitespace is ignored. Only the text after the line's last `\r` is
/// considered: shells may emit control output such as `ESC[?2004l` followed by
/// `\r` before a command's output on the same line. The terminal's echo of the
/// typed `echo <marker>` command therefore does not match.
fn is_marker_line(line: &str, marker: &str) -> bool {
    let visible = line.trim_end().rsplit('\r').next().unwrap_or("");
    visible.trim() == marker
}

/// `POST /mcp/terminal/execute` — run `command` in an SSH session and return
/// the output it produces.
///
/// The command is written to the session followed by `echo __WRAITH_MCP_DONE__`.
/// The session's scrollback is then polled every 50 ms. Polling stops when a
/// line consisting solely of the marker appears in output written *after* the
/// command was sent, or when `timeout_ms` (default 5000) elapses.
///
/// Responses:
/// - success: the captured output, with every line containing the marker
///   removed and surrounding whitespace trimmed;
/// - timeout: still a success, with data `"[timeout after {ms}ms]\n{partial output}"`;
/// - error: no scrollback buffer for the session, or the SSH write failed.
async fn handle_terminal_execute(
    AxumState(state): AxumState<Arc<McpServerState>>,
    Json(req): Json<TerminalExecuteRequest>,
) -> Json<McpResponse<String>> {
    const MARKER: &str = "__WRAITH_MCP_DONE__";
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);

    let timeout = req.timeout_ms.unwrap_or(5000);
    let buf = match state.scrollback.get(&req.session_id) {
        Some(b) => b,
        None => return err_response(format!("No scrollback buffer for session {}", req.session_id)),
    };

    // Snapshot the write counter first so that only output produced after this
    // command is examined. Scanning the whole buffer would match markers left
    // behind by earlier invocations.
    // NOTE(review): assumes `total_written()` counts the same bytes that
    // `read_raw()` returns — confirm against the scrollback implementation.
    let before = buf.total_written();
    let full_cmd = format!("{}\necho {}\n", req.command, MARKER);
    if let Err(e) = state.ssh.write(&req.session_id, full_cmd.as_bytes()).await {
        return err_response(e);
    }

    let start = std::time::Instant::now();
    let timeout_dur = std::time::Duration::from_millis(timeout);
    loop {
        let raw = buf.read_raw();
        let new_bytes = buf.total_written().saturating_sub(before);
        let output = output_since(&raw, new_bytes);

        // Require a whole line equal to the marker. The echoed
        // `echo __WRAITH_MCP_DONE__` command line also contains the marker text
        // and would otherwise end the wait before the command has run.
        if output.lines().any(|line| is_marker_line(line, MARKER)) {
            let clean = output
                .lines()
                .filter(|line| !line.contains(MARKER))
                .collect::<Vec<_>>()
                .join("\n");
            return ok_response(clean.trim().to_string());
        }

        if start.elapsed() > timeout_dur {
            return ok_response(format!("[timeout after {}ms]\n{}", timeout, output));
        }
        tokio::time::sleep(POLL_INTERVAL).await;
    }
}
/// Start the MCP HTTP server and write the port to disk.
///
/// Binds `127.0.0.1:0` so the OS picks a free port, and registers every `/mcp/*`
/// route (all are `POST`). It then writes the chosen port as decimal text to
/// `mcp-port` in `crate::data_directory()`, creating that directory if needed.
/// Finally it serves requests on a spawned background task and returns the port.
///
/// NOTE(review): the endpoints perform no authentication. Any local process
/// that reads the port file can drive SSH/SFTP sessions through them.
///
/// # Errors
/// Returns a message if binding the listener, reading its local address,
/// creating the data directory, or writing the port file fails. Errors from the
/// running server after startup are reported on stderr instead.
pub async fn start_mcp_server(
    ssh: SshService,
    rdp: RdpService,
    sftp: SftpService,
    scrollback: ScrollbackRegistry,
) -> Result<u16, String> {
    let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback });
    let app = Router::new()
        .route("/mcp/sessions", post(handle_list_sessions))
        .route("/mcp/terminal/read", post(handle_terminal_read))
        .route("/mcp/terminal/execute", post(handle_terminal_execute))
        .route("/mcp/screenshot", post(handle_screenshot))
        .route("/mcp/sftp/list", post(handle_sftp_list))
        .route("/mcp/sftp/read", post(handle_sftp_read))
        .route("/mcp/sftp/write", post(handle_sftp_write))
        .with_state(state);

    let listener = TcpListener::bind("127.0.0.1:0").await
        .map_err(|e| format!("Failed to bind MCP server: {}", e))?;
    let port = listener.local_addr()
        .map_err(|e| format!("Failed to get MCP server port: {}", e))?
        .port();

    // Write the port to a well-known location. Create the directory first so a
    // fresh install (no data directory yet) does not fail at startup.
    let data_dir = crate::data_directory();
    std::fs::create_dir_all(&data_dir)
        .map_err(|e| format!("Failed to create data directory {}: {}", data_dir.display(), e))?;
    let port_file = data_dir.join("mcp-port");
    std::fs::write(&port_file, port.to_string())
        .map_err(|e| format!("Failed to write MCP port file: {}", e))?;

    tokio::spawn(async move {
        // Report rather than silently discard a server failure; otherwise the
        // bridge would just see connection errors with no diagnostic.
        if let Err(e) = axum::serve(listener, app).await {
            eprintln!("MCP server stopped with error: {}", e);
        }
    });
    Ok(port)
}