Compare commits

..

No commits in common. "3842d483901a0d23feb3e3d8efda458ed677b02e" and "ff9fc798c3fff68b9c3ebf63485ca04dae9907be" have entirely different histories.

25 changed files with 229 additions and 493 deletions

1
.gitignore vendored
View File

@ -4,4 +4,3 @@ src-tauri/target/
src-tauri/binaries/ src-tauri/binaries/
*.log *.log
.DS_Store .DS_Store
.claude/worktrees/

2
src-tauri/Cargo.lock generated
View File

@ -8913,11 +8913,9 @@ dependencies = [
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util",
"ureq", "ureq",
"uuid", "uuid",
"x509-cert", "x509-cert",
"zeroize",
] ]
[[package]] [[package]]

View File

@ -12,15 +12,11 @@ crate-type = ["lib", "cdylib", "staticlib"]
name = "wraith-mcp-bridge" name = "wraith-mcp-bridge"
path = "src/bin/wraith_mcp_bridge.rs" path = "src/bin/wraith_mcp_bridge.rs"
[features]
default = []
devtools = ["tauri/devtools"]
[build-dependencies] [build-dependencies]
tauri-build = { version = "2", features = [] } tauri-build = { version = "2", features = [] }
[dependencies] [dependencies]
tauri = { version = "2", features = [] } tauri = { version = "2", features = ["devtools"] }
tauri-plugin-shell = "2" tauri-plugin-shell = "2"
tauri-plugin-updater = "2" tauri-plugin-updater = "2"
anyhow = "1" anyhow = "1"
@ -37,8 +33,6 @@ uuid = { version = "1", features = ["v4"] }
base64 = "0.22" base64 = "0.22"
dashmap = "6" dashmap = "6"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-util = "0.7"
zeroize = { version = "1", features = ["derive"] }
async-trait = "0.1" async-trait = "0.1"
log = "0.4" log = "0.4"
env_logger = "0.11" env_logger = "0.11"

View File

@ -38,22 +38,19 @@ struct JsonRpcError {
message: String, message: String,
} }
fn get_data_dir() -> Result<std::path::PathBuf, String> { fn get_mcp_port() -> Result<u16, String> {
if let Ok(appdata) = std::env::var("APPDATA") { // Check standard locations for the port file
Ok(std::path::PathBuf::from(appdata).join("Wraith")) let port_file = if let Ok(appdata) = std::env::var("APPDATA") {
std::path::PathBuf::from(appdata).join("Wraith").join("mcp-port")
} else if let Ok(home) = std::env::var("HOME") { } else if let Ok(home) = std::env::var("HOME") {
if cfg!(target_os = "macos") { if cfg!(target_os = "macos") {
Ok(std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith")) std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith").join("mcp-port")
} else { } else {
Ok(std::path::PathBuf::from(home).join(".local").join("share").join("wraith")) std::path::PathBuf::from(home).join(".local").join("share").join("wraith").join("mcp-port")
} }
} else { } else {
Err("Cannot determine data directory".to_string()) return Err("Cannot determine data directory".to_string());
} };
}
fn get_mcp_port() -> Result<u16, String> {
let port_file = get_data_dir()?.join("mcp-port");
let port_str = std::fs::read_to_string(&port_file) let port_str = std::fs::read_to_string(&port_file)
.map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?; .map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?;
@ -62,15 +59,6 @@ fn get_mcp_port() -> Result<u16, String> {
.map_err(|e| format!("Invalid port in MCP port file: {}", e)) .map_err(|e| format!("Invalid port in MCP port file: {}", e))
} }
fn get_mcp_token() -> Result<String, String> {
let token_file = get_data_dir()?.join("mcp-token");
let token = std::fs::read_to_string(&token_file)
.map_err(|e| format!("Cannot read MCP token file at {}: {} — is Wraith running?", token_file.display(), e))?;
Ok(token.trim().to_string())
}
fn handle_initialize(id: Value) -> JsonRpcResponse { fn handle_initialize(id: Value) -> JsonRpcResponse {
JsonRpcResponse { JsonRpcResponse {
jsonrpc: "2.0".to_string(), jsonrpc: "2.0".to_string(),
@ -316,13 +304,12 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse {
} }
} }
fn call_wraith(port: u16, token: &str, endpoint: &str, body: Value) -> Result<Value, String> { fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result<Value, String> {
let url = format!("http://127.0.0.1:{}{}", port, endpoint); let url = format!("http://127.0.0.1:{}{}", port, endpoint);
let body_str = serde_json::to_string(&body).unwrap_or_default(); let body_str = serde_json::to_string(&body).unwrap_or_default();
let mut resp = ureq::post(url) let mut resp = ureq::post(url)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Authorization", &format!("Bearer {}", token))
.send(body_str.as_bytes()) .send(body_str.as_bytes())
.map_err(|e| format!("HTTP request to Wraith failed: {}", e))?; .map_err(|e| format!("HTTP request to Wraith failed: {}", e))?;
@ -340,40 +327,40 @@ fn call_wraith(port: u16, token: &str, endpoint: &str, body: Value) -> Result<Va
} }
} }
fn handle_tool_call(id: Value, port: u16, token: &str, tool_name: &str, args: &Value) -> JsonRpcResponse { fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> JsonRpcResponse {
let result = match tool_name { let result = match tool_name {
"list_sessions" => call_wraith(port, token, "/mcp/sessions", serde_json::json!({})), "list_sessions" => call_wraith(port, "/mcp/sessions", serde_json::json!({})),
"terminal_type" => call_wraith(port, token, "/mcp/terminal/type", args.clone()), "terminal_type" => call_wraith(port, "/mcp/terminal/type", args.clone()),
"terminal_read" => call_wraith(port, token, "/mcp/terminal/read", args.clone()), "terminal_read" => call_wraith(port, "/mcp/terminal/read", args.clone()),
"terminal_execute" => call_wraith(port, token, "/mcp/terminal/execute", args.clone()), "terminal_execute" => call_wraith(port, "/mcp/terminal/execute", args.clone()),
"sftp_list" => call_wraith(port, token, "/mcp/sftp/list", args.clone()), "sftp_list" => call_wraith(port, "/mcp/sftp/list", args.clone()),
"sftp_read" => call_wraith(port, token, "/mcp/sftp/read", args.clone()), "sftp_read" => call_wraith(port, "/mcp/sftp/read", args.clone()),
"sftp_write" => call_wraith(port, token, "/mcp/sftp/write", args.clone()), "sftp_write" => call_wraith(port, "/mcp/sftp/write", args.clone()),
"network_scan" => call_wraith(port, token, "/mcp/tool/scan-network", args.clone()), "network_scan" => call_wraith(port, "/mcp/tool/scan-network", args.clone()),
"port_scan" => call_wraith(port, token, "/mcp/tool/scan-ports", args.clone()), "port_scan" => call_wraith(port, "/mcp/tool/scan-ports", args.clone()),
"ping" => call_wraith(port, token, "/mcp/tool/ping", args.clone()), "ping" => call_wraith(port, "/mcp/tool/ping", args.clone()),
"traceroute" => call_wraith(port, token, "/mcp/tool/traceroute", args.clone()), "traceroute" => call_wraith(port, "/mcp/tool/traceroute", args.clone()),
"dns_lookup" => call_wraith(port, token, "/mcp/tool/dns", args.clone()), "dns_lookup" => call_wraith(port, "/mcp/tool/dns", args.clone()),
"whois" => call_wraith(port, token, "/mcp/tool/whois", args.clone()), "whois" => call_wraith(port, "/mcp/tool/whois", args.clone()),
"wake_on_lan" => call_wraith(port, token, "/mcp/tool/wol", args.clone()), "wake_on_lan" => call_wraith(port, "/mcp/tool/wol", args.clone()),
"bandwidth_test" => call_wraith(port, token, "/mcp/tool/bandwidth", args.clone()), "bandwidth_test" => call_wraith(port, "/mcp/tool/bandwidth", args.clone()),
"subnet_calc" => call_wraith(port, token, "/mcp/tool/subnet", args.clone()), "subnet_calc" => call_wraith(port, "/mcp/tool/subnet", args.clone()),
"generate_ssh_key" => call_wraith(port, token, "/mcp/tool/keygen", args.clone()), "generate_ssh_key" => call_wraith(port, "/mcp/tool/keygen", args.clone()),
"generate_password" => call_wraith(port, token, "/mcp/tool/passgen", args.clone()), "generate_password" => call_wraith(port, "/mcp/tool/passgen", args.clone()),
"docker_ps" => call_wraith(port, token, "/mcp/docker/ps", args.clone()), "docker_ps" => call_wraith(port, "/mcp/docker/ps", args.clone()),
"docker_action" => call_wraith(port, token, "/mcp/docker/action", args.clone()), "docker_action" => call_wraith(port, "/mcp/docker/action", args.clone()),
"docker_exec" => call_wraith(port, token, "/mcp/docker/exec", args.clone()), "docker_exec" => call_wraith(port, "/mcp/docker/exec", args.clone()),
"service_status" => call_wraith(port, token, "/mcp/service/status", args.clone()), "service_status" => call_wraith(port, "/mcp/service/status", args.clone()),
"process_list" => call_wraith(port, token, "/mcp/process/list", args.clone()), "process_list" => call_wraith(port, "/mcp/process/list", args.clone()),
"git_status" => call_wraith(port, token, "/mcp/git/status", args.clone()), "git_status" => call_wraith(port, "/mcp/git/status", args.clone()),
"git_pull" => call_wraith(port, token, "/mcp/git/pull", args.clone()), "git_pull" => call_wraith(port, "/mcp/git/pull", args.clone()),
"git_log" => call_wraith(port, token, "/mcp/git/log", args.clone()), "git_log" => call_wraith(port, "/mcp/git/log", args.clone()),
"rdp_click" => call_wraith(port, token, "/mcp/rdp/click", args.clone()), "rdp_click" => call_wraith(port, "/mcp/rdp/click", args.clone()),
"rdp_type" => call_wraith(port, token, "/mcp/rdp/type", args.clone()), "rdp_type" => call_wraith(port, "/mcp/rdp/type", args.clone()),
"rdp_clipboard" => call_wraith(port, token, "/mcp/rdp/clipboard", args.clone()), "rdp_clipboard" => call_wraith(port, "/mcp/rdp/clipboard", args.clone()),
"ssh_connect" => call_wraith(port, token, "/mcp/ssh/connect", args.clone()), "ssh_connect" => call_wraith(port, "/mcp/ssh/connect", args.clone()),
"terminal_screenshot" => { "terminal_screenshot" => {
let result = call_wraith(port, token, "/mcp/screenshot", args.clone()); let result = call_wraith(port, "/mcp/screenshot", args.clone());
// Screenshot returns base64 PNG — wrap as image content for multimodal AI // Screenshot returns base64 PNG — wrap as image content for multimodal AI
return match result { return match result {
Ok(b64) => JsonRpcResponse { Ok(b64) => JsonRpcResponse {
@ -433,14 +420,6 @@ fn main() {
} }
}; };
let token = match get_mcp_token() {
Ok(t) => t,
Err(e) => {
eprintln!("wraith-mcp-bridge: {}", e);
std::process::exit(1);
}
};
let stdin = io::stdin(); let stdin = io::stdin();
let mut stdout = io::stdout(); let mut stdout = io::stdout();
@ -479,7 +458,7 @@ fn main() {
let args = request.params.get("arguments") let args = request.params.get("arguments")
.cloned() .cloned()
.unwrap_or(Value::Object(serde_json::Map::new())); .unwrap_or(Value::Object(serde_json::Map::new()));
handle_tool_call(request.id, port, &token, tool_name, &args) handle_tool_call(request.id, port, tool_name, &args)
} }
"notifications/initialized" | "notifications/cancelled" => { "notifications/initialized" | "notifications/cancelled" => {
// Notifications don't get responses // Notifications don't get responses

View File

@ -3,16 +3,34 @@ use tauri::State;
use crate::credentials::Credential; use crate::credentials::Credential;
use crate::AppState; use crate::AppState;
/// Guard helper: lock the credentials mutex and return a ref to the inner
/// `CredentialService`, or a "Vault is locked" error if the vault has not
/// been unlocked for this session.
///
/// This is a macro rather than a function because returning a `MutexGuard`
/// from a helper function would require lifetime annotations that complicate
/// the tauri command signatures unnecessarily.
macro_rules! require_unlocked {
($state:expr) => {{
let guard = $state
.credentials
.lock()
.map_err(|_| "Credentials mutex was poisoned".to_string())?;
if guard.is_none() {
return Err("Vault is locked — call unlock before accessing credentials".into());
}
// SAFETY: we just checked `is_none` above, so `unwrap` cannot panic.
guard
}};
}
/// Return all credentials ordered by name. /// Return all credentials ordered by name.
/// ///
/// Secret values (passwords, private keys) are never included — only metadata. /// Secret values (passwords, private keys) are never included — only metadata.
#[tauri::command] #[tauri::command]
pub async fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credential>, String> { pub fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credential>, String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard.as_ref().unwrap().list()
.as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
svc.list()
} }
/// Store a new username/password credential. /// Store a new username/password credential.
@ -21,18 +39,18 @@ pub async fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credenti
/// Returns the created credential record (without the plaintext password). /// Returns the created credential record (without the plaintext password).
/// `domain` is `None` for non-domain credentials; `Some("")` is treated as NULL. /// `domain` is `None` for non-domain credentials; `Some("")` is treated as NULL.
#[tauri::command] #[tauri::command]
pub async fn create_password( pub fn create_password(
name: String, name: String,
username: String, username: String,
password: String, password: String,
domain: Option<String>, domain: Option<String>,
state: State<'_, AppState>, state: State<'_, AppState>,
) -> Result<Credential, String> { ) -> Result<Credential, String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard
.as_ref() .as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?; .unwrap()
svc.create_password(name, username, password, domain) .create_password(name, username, password, domain)
} }
/// Store a new SSH private key credential. /// Store a new SSH private key credential.
@ -41,18 +59,18 @@ pub async fn create_password(
/// Pass `None` for `passphrase` when the key has no passphrase. /// Pass `None` for `passphrase` when the key has no passphrase.
/// Returns the created credential record without any secret material. /// Returns the created credential record without any secret material.
#[tauri::command] #[tauri::command]
pub async fn create_ssh_key( pub fn create_ssh_key(
name: String, name: String,
username: String, username: String,
private_key_pem: String, private_key_pem: String,
passphrase: Option<String>, passphrase: Option<String>,
state: State<'_, AppState>, state: State<'_, AppState>,
) -> Result<Credential, String> { ) -> Result<Credential, String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard
.as_ref() .as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?; .unwrap()
svc.create_ssh_key(name, username, private_key_pem, passphrase) .create_ssh_key(name, username, private_key_pem, passphrase)
} }
/// Delete a credential by id. /// Delete a credential by id.
@ -60,30 +78,21 @@ pub async fn create_ssh_key(
/// For SSH key credentials, the associated `ssh_keys` row is also deleted. /// For SSH key credentials, the associated `ssh_keys` row is also deleted.
/// Returns `Err` if the vault is locked or the id does not exist. /// Returns `Err` if the vault is locked or the id does not exist.
#[tauri::command] #[tauri::command]
pub async fn delete_credential(id: i64, state: State<'_, AppState>) -> Result<(), String> { pub fn delete_credential(id: i64, state: State<'_, AppState>) -> Result<(), String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard.as_ref().unwrap().delete(id)
.as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
svc.delete(id)
} }
/// Decrypt and return the password for a credential. /// Decrypt and return the password for a credential.
#[tauri::command] #[tauri::command]
pub async fn decrypt_password(credential_id: i64, state: State<'_, AppState>) -> Result<String, String> { pub fn decrypt_password(credential_id: i64, state: State<'_, AppState>) -> Result<String, String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard.as_ref().unwrap().decrypt_password(credential_id)
.as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
svc.decrypt_password(credential_id)
} }
/// Decrypt and return the SSH private key and passphrase. /// Decrypt and return the SSH private key and passphrase.
#[tauri::command] #[tauri::command]
pub async fn decrypt_ssh_key(ssh_key_id: i64, state: State<'_, AppState>) -> Result<(String, String), String> { pub fn decrypt_ssh_key(ssh_key_id: i64, state: State<'_, AppState>) -> Result<(String, String), String> {
let guard = state.credentials.lock().await; let guard = require_unlocked!(state);
let svc = guard guard.as_ref().unwrap().decrypt_ssh_key(ssh_key_id)
.as_ref()
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
svc.decrypt_ssh_key(ssh_key_id)
} }

View File

@ -3,7 +3,6 @@
use tauri::State; use tauri::State;
use serde::Serialize; use serde::Serialize;
use crate::AppState; use crate::AppState;
use crate::utils::shell_escape;
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -85,15 +84,14 @@ pub async fn docker_list_volumes(session_id: String, state: State<'_, AppState>)
#[tauri::command] #[tauri::command]
pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result<String, String> { pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result<String, String> {
let session = state.ssh.get_session(&session_id).ok_or("Session not found")?; let session = state.ssh.get_session(&session_id).ok_or("Session not found")?;
let t = shell_escape(&target);
let cmd = match action.as_str() { let cmd = match action.as_str() {
"start" => format!("docker start {} 2>&1", t), "start" => format!("docker start {} 2>&1", target),
"stop" => format!("docker stop {} 2>&1", t), "stop" => format!("docker stop {} 2>&1", target),
"restart" => format!("docker restart {} 2>&1", t), "restart" => format!("docker restart {} 2>&1", target),
"remove" => format!("docker rm -f {} 2>&1", t), "remove" => format!("docker rm -f {} 2>&1", target),
"logs" => format!("docker logs --tail 100 {} 2>&1", t), "logs" => format!("docker logs --tail 100 {} 2>&1", target),
"remove-image" => format!("docker rmi {} 2>&1", t), "remove-image" => format!("docker rmi {} 2>&1", target),
"remove-volume" => format!("docker volume rm {} 2>&1", t), "remove-volume" => format!("docker volume rm {} 2>&1", target),
"builder-prune" => "docker builder prune -f 2>&1".to_string(), "builder-prune" => "docker builder prune -f 2>&1".to_string(),
"system-prune" => "docker system prune -f 2>&1".to_string(), "system-prune" => "docker system prune -f 2>&1".to_string(),
"system-prune-all" => "docker system prune -a -f 2>&1".to_string(), "system-prune-all" => "docker system prune -a -f 2>&1".to_string(),

View File

@ -4,7 +4,6 @@
//! delegate to the `RdpService` via `State<AppState>`. //! delegate to the `RdpService` via `State<AppState>`.
use tauri::{AppHandle, State}; use tauri::{AppHandle, State};
use tauri::ipc::Response;
use crate::rdp::{RdpConfig, RdpSessionInfo}; use crate::rdp::{RdpConfig, RdpSessionInfo};
use crate::AppState; use crate::AppState;
@ -19,18 +18,16 @@ pub fn connect_rdp(
state.rdp.connect(config, app_handle) state.rdp.connect(config, app_handle)
} }
/// Get the current frame buffer as raw RGBA bytes via binary IPC. /// Get the current frame buffer as raw RGBA bytes (binary IPC — no base64).
/// ///
/// Uses `tauri::ipc::Response` to return raw bytes without JSON serialization.
/// Pixel format: RGBA, 4 bytes per pixel, row-major, top-left origin. /// Pixel format: RGBA, 4 bytes per pixel, row-major, top-left origin.
/// Returns empty payload if frame hasn't changed since last call. /// Returns empty Vec if frame hasn't changed since last call.
#[tauri::command] #[tauri::command]
pub async fn rdp_get_frame( pub async fn rdp_get_frame(
session_id: String, session_id: String,
state: State<'_, AppState>, state: State<'_, AppState>,
) -> Result<Response, String> { ) -> Result<Vec<u8>, String> {
let frame = state.rdp.get_frame(&session_id).await?; state.rdp.get_frame(&session_id).await
Ok(Response::new(frame))
} }
/// Send a mouse event to an RDP session. /// Send a mouse event to an RDP session.

View File

@ -4,7 +4,6 @@ use tauri::State;
use serde::Serialize; use serde::Serialize;
use crate::AppState; use crate::AppState;
use crate::utils::shell_escape;
// ── Ping ───────────────────────────────────────────────────────────────────── // ── Ping ─────────────────────────────────────────────────────────────────────
@ -26,7 +25,7 @@ pub async fn tool_ping(
let session = state.ssh.get_session(&session_id) let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?; .ok_or_else(|| format!("SSH session {} not found", session_id))?;
let n = count.unwrap_or(4); let n = count.unwrap_or(4);
let cmd = format!("ping -c {} {} 2>&1", n, shell_escape(&target)); let cmd = format!("ping -c {} {} 2>&1", n, target);
let output = exec_on_session(&session.handle, &cmd).await?; let output = exec_on_session(&session.handle, &cmd).await?;
Ok(PingResult { target, output }) Ok(PingResult { target, output })
} }
@ -40,8 +39,7 @@ pub async fn tool_traceroute(
) -> Result<String, String> { ) -> Result<String, String> {
let session = state.ssh.get_session(&session_id) let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?; .ok_or_else(|| format!("SSH session {} not found", session_id))?;
let t = shell_escape(&target); let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", target, target);
let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t);
exec_on_session(&session.handle, &cmd).await exec_on_session(&session.handle, &cmd).await
} }
@ -67,16 +65,14 @@ pub async fn tool_wake_on_lan(
let cmd = format!( let cmd = format!(
r#"python3 -c " r#"python3 -c "
import socket, struct import socket, struct
mac = bytes.fromhex({mac_clean_escaped}) mac = bytes.fromhex('{mac_clean}')
pkt = b'\xff'*6 + mac*16 pkt = b'\xff'*6 + mac*16
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
s.sendto(pkt, ('255.255.255.255', 9)) s.sendto(pkt, ('255.255.255.255', 9))
s.close() s.close()
print('WoL packet sent to {mac_display_escaped}') print('WoL packet sent to {mac_address}')
" 2>&1 || echo "python3 not available install python3 on remote host for WoL""#, " 2>&1 || echo "python3 not available install python3 on remote host for WoL""#
mac_clean_escaped = shell_escape(&mac_clean),
mac_display_escaped = shell_escape(&mac_address),
); );
exec_on_session(&session.handle, &cmd).await exec_on_session(&session.handle, &cmd).await

View File

@ -4,7 +4,6 @@ use tauri::State;
use serde::Serialize; use serde::Serialize;
use crate::AppState; use crate::AppState;
use crate::utils::shell_escape;
// ── DNS Lookup ─────────────────────────────────────────────────────────────── // ── DNS Lookup ───────────────────────────────────────────────────────────────
@ -17,11 +16,10 @@ pub async fn tool_dns_lookup(
) -> Result<String, String> { ) -> Result<String, String> {
let session = state.ssh.get_session(&session_id) let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?; .ok_or_else(|| format!("SSH session {} not found", session_id))?;
let d = shell_escape(&domain); let rtype = record_type.unwrap_or_else(|| "A".to_string());
let rt = shell_escape(&record_type.unwrap_or_else(|| "A".to_string()));
let cmd = format!( let cmd = format!(
r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#, r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#,
d, rt, rt, d, rt, d domain, rtype, rtype, domain, rtype, domain
); );
exec_on_session(&session.handle, &cmd).await exec_on_session(&session.handle, &cmd).await
} }
@ -36,7 +34,7 @@ pub async fn tool_whois(
) -> Result<String, String> { ) -> Result<String, String> {
let session = state.ssh.get_session(&session_id) let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?; .ok_or_else(|| format!("SSH session {} not found", session_id))?;
let cmd = format!("whois {} 2>&1 | head -80", shell_escape(&target)); let cmd = format!("whois {} 2>&1 | head -80", target);
exec_on_session(&session.handle, &cmd).await exec_on_session(&session.handle, &cmd).await
} }
@ -52,10 +50,9 @@ pub async fn tool_bandwidth_iperf(
let session = state.ssh.get_session(&session_id) let session = state.ssh.get_session(&session_id)
.ok_or_else(|| format!("SSH session {} not found", session_id))?; .ok_or_else(|| format!("SSH session {} not found", session_id))?;
let dur = duration.unwrap_or(5); let dur = duration.unwrap_or(5);
let s = shell_escape(&server);
let cmd = format!( let cmd = format!(
"iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'", "iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'",
s, dur, s, dur server, dur, server, dur
); );
exec_on_session(&session.handle, &cmd).await exec_on_session(&session.handle, &cmd).await
} }

View File

@ -1,5 +1,4 @@
use tauri::State; use tauri::State;
use zeroize::Zeroize;
use crate::vault::{self, VaultService}; use crate::vault::{self, VaultService};
use crate::credentials::CredentialService; use crate::credentials::CredentialService;
@ -22,15 +21,14 @@ pub fn is_first_run(state: State<'_, AppState>) -> bool {
/// Returns `Err` if the vault has already been set up or if any storage /// Returns `Err` if the vault has already been set up or if any storage
/// operation fails. /// operation fails.
#[tauri::command] #[tauri::command]
pub async fn create_vault(mut password: String, state: State<'_, AppState>) -> Result<(), String> { pub fn create_vault(password: String, state: State<'_, AppState>) -> Result<(), String> {
let result = async {
if !state.is_first_run() { if !state.is_first_run() {
return Err("Vault already exists — use unlock instead of create".into()); return Err("Vault already exists — use unlock instead of create".into());
} }
let salt = vault::generate_salt(); let salt = vault::generate_salt();
let key = vault::derive_key(&password, &salt); let key = vault::derive_key(&password, &salt);
let vs = VaultService::new(key.clone()); let vs = VaultService::new(key);
// Persist the salt so we can re-derive the key on future unlocks. // Persist the salt so we can re-derive the key on future unlocks.
state.settings.set("vault_salt", &hex::encode(salt))?; state.settings.set("vault_salt", &hex::encode(salt))?;
@ -41,14 +39,10 @@ pub async fn create_vault(mut password: String, state: State<'_, AppState>) -> R
// Activate the vault and credentials service for this session. // Activate the vault and credentials service for this session.
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key)); let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
*state.credentials.lock().await = Some(cred_svc); *state.credentials.lock().unwrap() = Some(cred_svc);
*state.vault.lock().await = Some(vs); *state.vault.lock().unwrap() = Some(vs);
Ok(()) Ok(())
}.await;
password.zeroize();
result
} }
/// Unlock an existing vault using the master password. /// Unlock an existing vault using the master password.
@ -58,8 +52,7 @@ pub async fn create_vault(mut password: String, state: State<'_, AppState>) -> R
/// ///
/// Returns `Err("Incorrect master password")` if the password is wrong. /// Returns `Err("Incorrect master password")` if the password is wrong.
#[tauri::command] #[tauri::command]
pub async fn unlock(mut password: String, state: State<'_, AppState>) -> Result<(), String> { pub fn unlock(password: String, state: State<'_, AppState>) -> Result<(), String> {
let result = async {
let salt_hex = state let salt_hex = state
.settings .settings
.get("vault_salt") .get("vault_salt")
@ -69,7 +62,7 @@ pub async fn unlock(mut password: String, state: State<'_, AppState>) -> Result<
.map_err(|e| format!("Stored vault salt is corrupt: {e}"))?; .map_err(|e| format!("Stored vault salt is corrupt: {e}"))?;
let key = vault::derive_key(&password, &salt); let key = vault::derive_key(&password, &salt);
let vs = VaultService::new(key.clone()); let vs = VaultService::new(key);
// Verify the password by decrypting the check value. // Verify the password by decrypting the check value.
let check_blob = state let check_blob = state
@ -87,18 +80,14 @@ pub async fn unlock(mut password: String, state: State<'_, AppState>) -> Result<
// Activate the vault and credentials service for this session. // Activate the vault and credentials service for this session.
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key)); let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
*state.credentials.lock().await = Some(cred_svc); *state.credentials.lock().unwrap() = Some(cred_svc);
*state.vault.lock().await = Some(vs); *state.vault.lock().unwrap() = Some(vs);
Ok(()) Ok(())
}.await;
password.zeroize();
result
} }
/// Returns `true` if the vault is currently unlocked for this session. /// Returns `true` if the vault is currently unlocked for this session.
#[tauri::command] #[tauri::command]
pub async fn is_unlocked(state: State<'_, AppState>) -> Result<bool, String> { pub fn is_unlocked(state: State<'_, AppState>) -> bool {
Ok(state.is_unlocked().await) state.is_unlocked()
} }

View File

@ -19,7 +19,6 @@ use crate::db::Database;
// ── domain types ────────────────────────────────────────────────────────────── // ── domain types ──────────────────────────────────────────────────────────────
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionGroup { pub struct ConnectionGroup {
pub id: i64, pub id: i64,
pub name: String, pub name: String,

View File

@ -21,9 +21,9 @@ pub mod pty;
pub mod mcp; pub mod mcp;
pub mod scanner; pub mod scanner;
pub mod commands; pub mod commands;
pub mod utils;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Mutex;
use db::Database; use db::Database;
use vault::VaultService; use vault::VaultService;
@ -41,10 +41,10 @@ use mcp::error_watcher::ErrorWatcher;
pub struct AppState { pub struct AppState {
pub db: Database, pub db: Database,
pub vault: tokio::sync::Mutex<Option<VaultService>>, pub vault: Mutex<Option<VaultService>>,
pub settings: SettingsService, pub settings: SettingsService,
pub connections: ConnectionService, pub connections: ConnectionService,
pub credentials: tokio::sync::Mutex<Option<CredentialService>>, pub credentials: Mutex<Option<CredentialService>>,
pub ssh: SshService, pub ssh: SshService,
pub sftp: SftpService, pub sftp: SftpService,
pub rdp: RdpService, pub rdp: RdpService,
@ -60,18 +60,17 @@ impl AppState {
std::fs::create_dir_all(&data_dir)?; std::fs::create_dir_all(&data_dir)?;
let database = Database::open(&data_dir.join("wraith.db"))?; let database = Database::open(&data_dir.join("wraith.db"))?;
database.migrate()?; database.migrate()?;
let settings = SettingsService::new(database.clone());
Ok(Self { Ok(Self {
db: database.clone(), db: database.clone(),
vault: tokio::sync::Mutex::new(None), vault: Mutex::new(None),
settings: SettingsService::new(database.clone()),
connections: ConnectionService::new(database.clone()), connections: ConnectionService::new(database.clone()),
credentials: tokio::sync::Mutex::new(None), credentials: Mutex::new(None),
ssh: SshService::new(database.clone()), ssh: SshService::new(database.clone()),
sftp: SftpService::new(), sftp: SftpService::new(),
rdp: RdpService::new(), rdp: RdpService::new(),
theme: ThemeService::new(database), theme: ThemeService::new(database.clone()),
workspace: WorkspaceService::new(settings.clone()), workspace: WorkspaceService::new(SettingsService::new(database.clone())),
settings,
pty: PtyService::new(), pty: PtyService::new(),
scrollback: ScrollbackRegistry::new(), scrollback: ScrollbackRegistry::new(),
error_watcher: std::sync::Arc::new(ErrorWatcher::new()), error_watcher: std::sync::Arc::new(ErrorWatcher::new()),
@ -86,8 +85,8 @@ impl AppState {
self.settings.get("vault_salt").unwrap_or_default().is_empty() self.settings.get("vault_salt").unwrap_or_default().is_empty()
} }
pub async fn is_unlocked(&self) -> bool { pub fn is_unlocked(&self) -> bool {
self.vault.lock().await.is_some() self.vault.lock().unwrap().is_some()
} }
} }

View File

@ -40,25 +40,13 @@ impl ScrollbackBuffer {
/// Append bytes to the buffer. Old data is overwritten when full. /// Append bytes to the buffer. Old data is overwritten when full.
pub fn push(&self, bytes: &[u8]) { pub fn push(&self, bytes: &[u8]) {
if bytes.is_empty() {
return;
}
let mut buf = self.inner.lock().unwrap(); let mut buf = self.inner.lock().unwrap();
let cap = buf.capacity; for &b in bytes {
// If input exceeds capacity, only keep the last `cap` bytes let pos = buf.write_pos;
let data = if bytes.len() > cap { buf.data[pos] = b;
&bytes[bytes.len() - cap..] buf.write_pos = (pos + 1) % buf.capacity;
} else { buf.total_written += 1;
bytes
};
let write_pos = buf.write_pos;
let first_len = (cap - write_pos).min(data.len());
buf.data[write_pos..write_pos + first_len].copy_from_slice(&data[..first_len]);
if first_len < data.len() {
buf.data[..data.len() - first_len].copy_from_slice(&data[first_len..]);
} }
buf.write_pos = (write_pos + data.len()) % cap;
buf.total_written += bytes.len();
} }
/// Read the last `n` lines from the buffer, with ANSI escape codes stripped. /// Read the last `n` lines from the buffer, with ANSI escape codes stripped.
@ -204,42 +192,4 @@ mod tests {
buf.push(b"ABCD"); // 4 more, wraps buf.push(b"ABCD"); // 4 more, wraps
assert_eq!(buf.total_written(), 12); assert_eq!(buf.total_written(), 12);
} }
#[test]
fn push_empty_is_noop() {
let buf = ScrollbackBuffer::with_capacity(8);
buf.push(b"hello");
buf.push(b"");
assert_eq!(buf.total_written(), 5);
assert!(buf.read_raw().contains("hello"));
}
#[test]
fn push_larger_than_capacity() {
let buf = ScrollbackBuffer::with_capacity(4);
buf.push(b"ABCDEFGH"); // 8 bytes into 4-byte buffer
let raw = buf.read_raw();
assert_eq!(raw, "EFGH"); // only last 4 bytes kept
assert_eq!(buf.total_written(), 8);
}
#[test]
fn push_exact_capacity() {
let buf = ScrollbackBuffer::with_capacity(8);
buf.push(b"12345678");
let raw = buf.read_raw();
assert_eq!(raw, "12345678");
assert_eq!(buf.total_written(), 8);
}
#[test]
fn push_wrap_around_boundary() {
let buf = ScrollbackBuffer::with_capacity(8);
buf.push(b"123456"); // write_pos = 6
buf.push(b"ABCD"); // wraps: 2 at end, 2 at start
let raw = buf.read_raw();
// Buffer: [C, D, 3, 4, 5, 6, A, B], write_pos=2
// Read from pos 2: "3456AB" + wrap: no, read from write_pos to end then start
assert_eq!(raw, "3456ABCD");
}
} }

View File

@ -5,14 +5,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{ use axum::{extract::State as AxumState, routing::post, Json, Router};
extract::State as AxumState,
http::{Request, StatusCode},
middleware::{self, Next},
response::Response,
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::TcpListener; use tokio::net::TcpListener;
@ -20,7 +13,6 @@ use crate::mcp::ScrollbackRegistry;
use crate::rdp::RdpService; use crate::rdp::RdpService;
use crate::sftp::SftpService; use crate::sftp::SftpService;
use crate::ssh::session::SshService; use crate::ssh::session::SshService;
use crate::utils::shell_escape;
/// Shared state passed to axum handlers. /// Shared state passed to axum handlers.
pub struct McpServerState { pub struct McpServerState {
@ -30,27 +22,6 @@ pub struct McpServerState {
pub scrollback: ScrollbackRegistry, pub scrollback: ScrollbackRegistry,
pub app_handle: tauri::AppHandle, pub app_handle: tauri::AppHandle,
pub error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>, pub error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
pub bearer_token: String,
}
/// Middleware that validates the `Authorization: Bearer <token>` header.
async fn auth_middleware(
AxumState(state): AxumState<Arc<McpServerState>>,
req: Request<axum::body::Body>,
next: Next,
) -> Result<Response, StatusCode> {
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let expected = format!("Bearer {}", state.bearer_token);
if auth_header != expected {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -308,31 +279,29 @@ struct ToolPassgenRequest { length: Option<usize>, uppercase: Option<bool>, lowe
async fn handle_tool_ping(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> { async fn handle_tool_ping(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_tool_traceroute(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> { async fn handle_tool_traceroute(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let t = shell_escape(&req.target); match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_tool_dns(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolDnsRequest>) -> Json<McpResponse<String>> { async fn handle_tool_dns(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolDnsRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let rt = shell_escape(&req.record_type.unwrap_or_else(|| "A".to_string())); let rt = req.record_type.unwrap_or_else(|| "A".to_string());
let d = shell_escape(&req.domain); match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", d, rt, rt, d, rt, d)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_tool_whois(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> { async fn handle_tool_whois(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_tool_wol(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolWolRequest>) -> Json<McpResponse<String>> { async fn handle_tool_wol(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolWolRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let mac_clean = req.mac_address.replace([':', '-'], ""); let mac_clean = req.mac_address.replace([':', '-'], "");
let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex({});pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, shell_escape(&mac_clean), shell_escape(&req.mac_address)); let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address);
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
@ -413,13 +382,12 @@ async fn handle_docker_ps(AxumState(state): AxumState<Arc<McpServerState>>, Json
async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerActionRequest>) -> Json<McpResponse<String>> { async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerActionRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let t = shell_escape(&req.target);
let cmd = match req.action.as_str() { let cmd = match req.action.as_str() {
"start" => format!("docker start {} 2>&1", t), "start" => format!("docker start {} 2>&1", req.target),
"stop" => format!("docker stop {} 2>&1", t), "stop" => format!("docker stop {} 2>&1", req.target),
"restart" => format!("docker restart {} 2>&1", t), "restart" => format!("docker restart {} 2>&1", req.target),
"remove" => format!("docker rm -f {} 2>&1", t), "remove" => format!("docker rm -f {} 2>&1", req.target),
"logs" => format!("docker logs --tail 100 {} 2>&1", t), "logs" => format!("docker logs --tail 100 {} 2>&1", req.target),
"builder-prune" => "docker builder prune -f 2>&1".to_string(), "builder-prune" => "docker builder prune -f 2>&1".to_string(),
"system-prune" => "docker system prune -f 2>&1".to_string(), "system-prune" => "docker system prune -f 2>&1".to_string(),
_ => return err_response(format!("Unknown action: {}", req.action)), _ => return err_response(format!("Unknown action: {}", req.action)),
@ -429,7 +397,7 @@ async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>,
async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerExecRequest>) -> Json<McpResponse<String>> { async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerExecRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let cmd = format!("docker exec {} {} 2>&1", shell_escape(&req.container), shell_escape(&req.command)); let cmd = format!("docker exec {} {} 2>&1", req.container, req.command);
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
@ -437,13 +405,12 @@ async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Js
async fn handle_service_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> { async fn handle_service_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let t = shell_escape(&req.target); match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_process_list(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> { async fn handle_process_list(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", shell_escape(&req.target)) }; let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", req.target) };
match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
@ -454,17 +421,17 @@ struct GitRequest { session_id: String, path: String }
async fn handle_git_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> { async fn handle_git_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_git_pull(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> { async fn handle_git_pull(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
async fn handle_git_log(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> { async fn handle_git_log(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) }; let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) } match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
} }
// ── Session creation handlers ──────────────────────────────────────────────── // ── Session creation handlers ────────────────────────────────────────────────
@ -566,15 +533,7 @@ pub async fn start_mcp_server(
app_handle: tauri::AppHandle, app_handle: tauri::AppHandle,
error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>, error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
) -> Result<u16, String> { ) -> Result<u16, String> {
// Generate a cryptographically random bearer token for authentication let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher });
use rand::Rng;
let bearer_token: String = rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
.take(64)
.map(char::from)
.collect();
let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher, bearer_token: bearer_token.clone() });
let app = Router::new() let app = Router::new()
.route("/mcp/sessions", post(handle_list_sessions)) .route("/mcp/sessions", post(handle_list_sessions))
@ -608,7 +567,6 @@ pub async fn start_mcp_server(
.route("/mcp/rdp/type", post(handle_rdp_type)) .route("/mcp/rdp/type", post(handle_rdp_type))
.route("/mcp/rdp/clipboard", post(handle_rdp_clipboard)) .route("/mcp/rdp/clipboard", post(handle_rdp_clipboard))
.route("/mcp/ssh/connect", post(handle_ssh_connect)) .route("/mcp/ssh/connect", post(handle_ssh_connect))
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
.with_state(state); .with_state(state);
let listener = TcpListener::bind("127.0.0.1:0").await let listener = TcpListener::bind("127.0.0.1:0").await
@ -619,23 +577,10 @@ pub async fn start_mcp_server(
.port(); .port();
// Write port to well-known location // Write port to well-known location
let data_dir = crate::data_directory(); let port_file = crate::data_directory().join("mcp-port");
let port_file = data_dir.join("mcp-port");
std::fs::write(&port_file, port.to_string()) std::fs::write(&port_file, port.to_string())
.map_err(|e| format!("Failed to write MCP port file: {}", e))?; .map_err(|e| format!("Failed to write MCP port file: {}", e))?;
// Write bearer token to a separate file with restrictive permissions
let token_file = data_dir.join("mcp-token");
std::fs::write(&token_file, &bearer_token)
.map_err(|e| format!("Failed to write MCP token file: {}", e))?;
// Set owner-only read/write permissions (Unix)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&token_file, std::fs::Permissions::from_mode(0o600));
}
tokio::spawn(async move { tokio::spawn(async move {
axum::serve(listener, app).await.ok(); axum::serve(listener, app).await.ok();
}); });

View File

@ -12,7 +12,6 @@ use serde::Serialize;
use tokio::sync::Mutex as TokioMutex; use tokio::sync::Mutex as TokioMutex;
use crate::ssh::session::SshClient; use crate::ssh::session::SshClient;
use crate::utils::shell_escape;
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -73,10 +72,9 @@ pub async fn scan_network(
// 1. Ping sweep the subnet to populate ARP cache // 1. Ping sweep the subnet to populate ARP cache
// 2. Read ARP table for IP/MAC pairs // 2. Read ARP table for IP/MAC pairs
// 3. Try reverse DNS for hostnames // 3. Try reverse DNS for hostnames
let escaped_subnet = shell_escape(subnet);
let script = format!(r#" let script = format!(r#"
OS=$(uname -s 2>/dev/null) OS=$(uname -s 2>/dev/null)
SUBNET={escaped_subnet} SUBNET="{subnet}"
# Ping sweep (background, fast) # Ping sweep (background, fast)
if [ "$OS" = "Linux" ]; then if [ "$OS" = "Linux" ]; then
@ -153,12 +151,6 @@ pub async fn scan_ports(
target: &str, target: &str,
ports: &[u16], ports: &[u16],
) -> Result<Vec<PortResult>, String> { ) -> Result<Vec<PortResult>, String> {
// Validate target — /dev/tcp requires a bare hostname/IP, not a shell-quoted value.
// Only allow alphanumeric, dots, hyphens, and colons (for IPv6).
if !target.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':') {
return Err(format!("Invalid target for port scan: {}", target));
}
// Use bash /dev/tcp for port scanning — no nmap required // Use bash /dev/tcp for port scanning — no nmap required
let port_checks: Vec<String> = ports.iter() let port_checks: Vec<String> = ports.iter()
.map(|p| format!( .map(|p| format!(

View File

@ -8,7 +8,6 @@ use crate::db::Database;
/// ///
/// All operations acquire the shared DB mutex for their duration and /// All operations acquire the shared DB mutex for their duration and
/// return immediately — no async needed for a local SQLite store. /// return immediately — no async needed for a local SQLite store.
#[derive(Clone)]
pub struct SettingsService { pub struct SettingsService {
db: Database, db: Database,
} }

View File

@ -16,7 +16,6 @@ use russh::ChannelMsg;
use tauri::{AppHandle, Emitter}; use tauri::{AppHandle, Emitter};
use tokio::sync::watch; use tokio::sync::watch;
use tokio::sync::Mutex as TokioMutex; use tokio::sync::Mutex as TokioMutex;
use tokio_util::sync::CancellationToken;
use crate::ssh::session::SshClient; use crate::ssh::session::SshClient;
@ -40,15 +39,13 @@ impl CwdTracker {
/// Spawn a background tokio task that polls `pwd` every 2 seconds on a /// Spawn a background tokio task that polls `pwd` every 2 seconds on a
/// separate exec channel. /// separate exec channel.
/// ///
/// The task runs until cancelled via the `CancellationToken`, or until the /// The task runs until the SSH connection is closed or the channel cannot
/// SSH connection is closed or the channel cannot be opened. /// be opened. CWD changes are emitted as `ssh:cwd:{session_id}` events.
/// CWD changes are emitted as `ssh:cwd:{session_id}` events.
pub fn start( pub fn start(
&self, &self,
handle: Arc<TokioMutex<Handle<SshClient>>>, handle: Arc<TokioMutex<Handle<SshClient>>>,
app_handle: AppHandle, app_handle: AppHandle,
session_id: String, session_id: String,
cancel: CancellationToken,
) { ) {
let sender = self._sender.clone(); let sender = self._sender.clone();
@ -59,10 +56,6 @@ impl CwdTracker {
let mut previous_cwd = String::new(); let mut previous_cwd = String::new();
loop { loop {
if cancel.is_cancelled() {
break;
}
// Open a fresh exec channel for each `pwd` invocation. // Open a fresh exec channel for each `pwd` invocation.
// Some SSH servers do not allow multiple exec requests on a // Some SSH servers do not allow multiple exec requests on a
// single channel, so we open a new one each time. // single channel, so we open a new one each time.
@ -126,11 +119,8 @@ impl CwdTracker {
} }
} }
// Wait 2 seconds before the next poll, or cancel. // Wait 2 seconds before the next poll.
tokio::select! { tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
_ = tokio::time::sleep(tokio::time::Duration::from_secs(2)) => {}
_ = cancel.cancelled() => { break; }
}
} }
debug!("CWD tracker for session {} stopped", session_id); debug!("CWD tracker for session {} stopped", session_id);

View File

@ -6,13 +6,11 @@
use std::sync::Arc; use std::sync::Arc;
use log::warn;
use russh::client::Handle; use russh::client::Handle;
use russh::ChannelMsg; use russh::ChannelMsg;
use serde::Serialize; use serde::Serialize;
use tauri::{AppHandle, Emitter}; use tauri::{AppHandle, Emitter};
use tokio::sync::Mutex as TokioMutex; use tokio::sync::Mutex as TokioMutex;
use tokio_util::sync::CancellationToken;
use crate::ssh::session::SshClient; use crate::ssh::session::SshClient;
@ -32,53 +30,26 @@ pub struct SystemStats {
} }
/// Spawn a background task that polls system stats every 5 seconds. /// Spawn a background task that polls system stats every 5 seconds.
///
/// The task runs until cancelled via the `CancellationToken`, or until the
/// SSH connection is closed.
pub fn start_monitor( pub fn start_monitor(
handle: Arc<TokioMutex<Handle<SshClient>>>, handle: Arc<TokioMutex<Handle<SshClient>>>,
app_handle: AppHandle, app_handle: AppHandle,
session_id: String, session_id: String,
cancel: CancellationToken,
) { ) {
tokio::spawn(async move { tokio::spawn(async move {
// Brief delay to let the shell start up // Brief delay to let the shell start up
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
let mut consecutive_timeouts: u32 = 0;
loop { loop {
if cancel.is_cancelled() {
break;
}
let stats = collect_stats(&handle).await; let stats = collect_stats(&handle).await;
match stats { if let Some(stats) = stats {
Some(stats) => {
consecutive_timeouts = 0;
let _ = app_handle.emit( let _ = app_handle.emit(
&format!("ssh:monitor:{}", session_id), &format!("ssh:monitor:{}", session_id),
&stats, &stats,
); );
} }
None => {
consecutive_timeouts += 1;
if consecutive_timeouts >= 3 {
warn!(
"SSH monitor for session {}: 3 consecutive failures, stopping",
session_id
);
break;
}
}
}
// Wait 5 seconds before the next poll, or cancel. tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
tokio::select! {
_ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
_ = cancel.cancelled() => { break; }
}
} }
}); });
} }
@ -154,24 +125,7 @@ fn parse_stats(raw: &str) -> Option<SystemStats> {
}) })
} }
/// Execute a command on a separate exec channel with a 10-second timeout.
async fn exec_command(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> { async fn exec_command(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> {
let result = tokio::time::timeout(
std::time::Duration::from_secs(10),
exec_command_inner(handle, cmd),
)
.await;
match result {
Ok(output) => output,
Err(_) => {
warn!("SSH monitor exec_command timed out after 10s");
None
}
}
}
async fn exec_command_inner(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> {
let mut channel = { let mut channel = {
let h = handle.lock().await; let h = handle.lock().await;
h.channel_open_session().await.ok()? h.channel_open_session().await.ok()?

View File

@ -17,7 +17,6 @@ use crate::mcp::error_watcher::ErrorWatcher;
use crate::sftp::SftpService; use crate::sftp::SftpService;
use crate::ssh::cwd::CwdTracker; use crate::ssh::cwd::CwdTracker;
use crate::ssh::host_key::{HostKeyResult, HostKeyStore}; use crate::ssh::host_key::{HostKeyResult, HostKeyStore};
use tokio_util::sync::CancellationToken;
pub enum AuthMethod { pub enum AuthMethod {
Password(String), Password(String),
@ -48,7 +47,6 @@ pub struct SshSession {
pub handle: Arc<TokioMutex<Handle<SshClient>>>, pub handle: Arc<TokioMutex<Handle<SshClient>>>,
pub command_tx: mpsc::UnboundedSender<ChannelCommand>, pub command_tx: mpsc::UnboundedSender<ChannelCommand>,
pub cwd_tracker: Option<CwdTracker>, pub cwd_tracker: Option<CwdTracker>,
pub cancel_token: CancellationToken,
} }
pub struct SshClient { pub struct SshClient {
@ -137,11 +135,10 @@ impl SshService {
let channel_id = channel.id(); let channel_id = channel.id();
let handle = Arc::new(TokioMutex::new(handle)); let handle = Arc::new(TokioMutex::new(handle));
let (command_tx, mut command_rx) = mpsc::unbounded_channel::<ChannelCommand>(); let (command_tx, mut command_rx) = mpsc::unbounded_channel::<ChannelCommand>();
let cancel_token = CancellationToken::new();
let cwd_tracker = CwdTracker::new(); let cwd_tracker = CwdTracker::new();
cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone(), cancel_token.clone()); cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone());
let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel_id, handle: handle.clone(), command_tx: command_tx.clone(), cwd_tracker: Some(cwd_tracker), cancel_token: cancel_token.clone() }); let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel_id, handle: handle.clone(), command_tx: command_tx.clone(), cwd_tracker: Some(cwd_tracker) });
self.sessions.insert(session_id.clone(), session); self.sessions.insert(session_id.clone(), session);
{ let h = handle.lock().await; { let h = handle.lock().await;
@ -161,7 +158,7 @@ impl SshService {
error_watcher.watch(&session_id); error_watcher.watch(&session_id);
// Start remote monitoring if enabled (runs on a separate exec channel) // Start remote monitoring if enabled (runs on a separate exec channel)
crate::ssh::monitor::start_monitor(handle.clone(), app_handle.clone(), session_id.clone(), cancel_token.clone()); crate::ssh::monitor::start_monitor(handle.clone(), app_handle.clone(), session_id.clone());
// Inject OSC 7 CWD reporting hook into the user's shell. // Inject OSC 7 CWD reporting hook into the user's shell.
// This enables SFTP CWD following on all platforms (Linux, macOS, FreeBSD). // This enables SFTP CWD following on all platforms (Linux, macOS, FreeBSD).
@ -249,8 +246,6 @@ impl SshService {
pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> { pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> {
let (_, session) = self.sessions.remove(session_id).ok_or_else(|| format!("Session {} not found", session_id))?; let (_, session) = self.sessions.remove(session_id).ok_or_else(|| format!("Session {} not found", session_id))?;
// Cancel background tasks (CWD tracker, monitor) before tearing down the connection.
session.cancel_token.cancel();
let _ = session.command_tx.send(ChannelCommand::Shutdown); let _ = session.command_tx.send(ChannelCommand::Shutdown);
{ let handle = session.handle.lock().await; let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await; } { let handle = session.handle.lock().await; let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await; }
sftp_service.remove_client(session_id); sftp_service.remove_client(session_id);

View File

@ -1,19 +0,0 @@
//! Shared utility functions.
/// Escape a string for safe interpolation into a POSIX shell command.
///
/// Wraps the input in single quotes and escapes any embedded single quotes
/// using the `'\''` technique. This prevents command injection when building
/// shell commands from user-supplied values.
///
/// # Examples
///
/// ```
/// # use wraith_lib::utils::shell_escape;
/// assert_eq!(shell_escape("hello"), "'hello'");
/// assert_eq!(shell_escape("it's"), "'it'\\''s'");
/// assert_eq!(shell_escape(";rm -rf /"), "';rm -rf /'");
/// ```
pub fn shell_escape(input: &str) -> String {
format!("'{}'", input.replace('\'', "'\\''"))
}

View File

@ -4,7 +4,6 @@ use aes_gcm::{
Aes256Gcm, Key, Nonce, Aes256Gcm, Key, Nonce,
}; };
use argon2::{Algorithm, Argon2, Params, Version}; use argon2::{Algorithm, Argon2, Params, Version};
use zeroize::Zeroizing;
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// VaultService // VaultService
@ -22,18 +21,18 @@ use zeroize::Zeroizing;
/// The version prefix allows a future migration to a different algorithm /// The version prefix allows a future migration to a different algorithm
/// without breaking existing stored blobs. /// without breaking existing stored blobs.
pub struct VaultService { pub struct VaultService {
key: Zeroizing<[u8; 32]>, key: [u8; 32],
} }
impl VaultService { impl VaultService {
pub fn new(key: Zeroizing<[u8; 32]>) -> Self { pub fn new(key: [u8; 32]) -> Self {
Self { key } Self { key }
} }
/// Encrypt `plaintext` and return a `v1:{iv_hex}:{sealed_hex}` blob. /// Encrypt `plaintext` and return a `v1:{iv_hex}:{sealed_hex}` blob.
pub fn encrypt(&self, plaintext: &str) -> Result<String, String> { pub fn encrypt(&self, plaintext: &str) -> Result<String, String> {
// Build the AES-256-GCM cipher from our key. // Build the AES-256-GCM cipher from our key.
let key = Key::<Aes256Gcm>::from_slice(&*self.key); let key = Key::<Aes256Gcm>::from_slice(&self.key);
let cipher = Aes256Gcm::new(key); let cipher = Aes256Gcm::new(key);
// Generate a random 12-byte nonce (96-bit is the GCM standard). // Generate a random 12-byte nonce (96-bit is the GCM standard).
@ -72,7 +71,7 @@ impl VaultService {
)); ));
} }
let key = Key::<Aes256Gcm>::from_slice(&*self.key); let key = Key::<Aes256Gcm>::from_slice(&self.key);
let cipher = Aes256Gcm::new(key); let cipher = Aes256Gcm::new(key);
let nonce = Nonce::from_slice(&iv_bytes); let nonce = Nonce::from_slice(&iv_bytes);
@ -96,7 +95,7 @@ impl VaultService {
/// t = 3 iterations /// t = 3 iterations
/// m = 65536 KiB (64 MiB) memory /// m = 65536 KiB (64 MiB) memory
/// p = 4 parallelism lanes /// p = 4 parallelism lanes
pub fn derive_key(password: &str, salt: &[u8]) -> Zeroizing<[u8; 32]> { pub fn derive_key(password: &str, salt: &[u8]) -> [u8; 32] {
let params = Params::new( let params = Params::new(
65536, // m_cost: 64 MiB 65536, // m_cost: 64 MiB
3, // t_cost: iterations 3, // t_cost: iterations
@ -107,9 +106,9 @@ pub fn derive_key(password: &str, salt: &[u8]) -> Zeroizing<[u8; 32]> {
let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params); let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
let mut output_key = Zeroizing::new([0u8; 32]); let mut output_key = [0u8; 32];
argon2 argon2
.hash_password_into(password.as_bytes(), salt, &mut *output_key) .hash_password_into(password.as_bytes(), salt, &mut output_key)
.expect("Argon2id key derivation failed"); .expect("Argon2id key derivation failed");
output_key output_key

View File

@ -22,9 +22,9 @@
} }
], ],
"security": { "security": {
"csp": "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' asset: https://asset.localhost data:; connect-src 'self' ipc: http://ipc.localhost" "csp": null
}, },
"withGlobalTauri": false "withGlobalTauri": true
}, },
"bundle": { "bundle": {
"active": true, "active": true,

View File

@ -88,7 +88,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, onBeforeUnmount } from "vue"; import { ref, onMounted } from "vue";
import { invoke } from "@tauri-apps/api/core"; import { invoke } from "@tauri-apps/api/core";
import { useSessionStore, type Session } from "@/stores/session.store"; import { useSessionStore, type Session } from "@/stores/session.store";
import { useConnectionStore } from "@/stores/connection.store"; import { useConnectionStore } from "@/stores/connection.store";
@ -151,19 +151,9 @@ function closeMenuTab(): void {
if (session) sessionStore.closeSession(session.id); if (session) sessionStore.closeSession(session.id);
} }
// Listen for reattach events from detached windows
import { listen } from "@tauri-apps/api/event"; import { listen } from "@tauri-apps/api/event";
import type { UnlistenFn } from "@tauri-apps/api/event"; listen<{ sessionId: string; name: string; protocol: string }>("session:reattach", (event) => {
let unlistenReattach: UnlistenFn | null = null;
onMounted(async () => {
try {
availableShells.value = await invoke<ShellInfo[]>("list_available_shells");
} catch {
availableShells.value = [];
}
unlistenReattach = await listen<{ sessionId: string; name: string; protocol: string }>("session:reattach", (event) => {
const { sessionId } = event.payload; const { sessionId } = event.payload;
const session = sessionStore.sessions.find(s => s.id === sessionId); const session = sessionStore.sessions.find(s => s.id === sessionId);
if (session) { if (session) {
@ -171,10 +161,13 @@ onMounted(async () => {
sessionStore.activateSession(sessionId); sessionStore.activateSession(sessionId);
} }
}); });
});
onBeforeUnmount(() => { onMounted(async () => {
unlistenReattach?.(); try {
availableShells.value = await invoke<ShellInfo[]>("list_available_shells");
} catch {
availableShells.value = [];
}
}); });
// Drag-and-drop tab reordering // Drag-and-drop tab reordering

View File

@ -184,7 +184,7 @@ export interface UseRdpReturn {
* Composable that manages an RDP session's rendering and input. * Composable that manages an RDP session's rendering and input.
* *
* Uses Tauri's invoke() to call Rust commands: * Uses Tauri's invoke() to call Rust commands:
* rdp_get_frame raw RGBA ArrayBuffer (binary IPC) * rdp_get_frame base64 RGBA string
* rdp_send_mouse fire-and-forget * rdp_send_mouse fire-and-forget
* rdp_send_key fire-and-forget * rdp_send_key fire-and-forget
* rdp_send_clipboard fire-and-forget * rdp_send_clipboard fire-and-forget
@ -195,7 +195,6 @@ export function useRdp(): UseRdpReturn {
const clipboardSync = ref(false); const clipboardSync = ref(false);
let animFrameId: number | null = null; let animFrameId: number | null = null;
let unlistenFrame: (() => void) | null = null;
/** /**
* Fetch the current frame from the Rust RDP backend. * Fetch the current frame from the Rust RDP backend.
@ -209,16 +208,16 @@ export function useRdp(): UseRdpReturn {
width: number, width: number,
height: number, height: number,
): Promise<ImageData | null> { ): Promise<ImageData | null> {
let raw: ArrayBuffer; let raw: number[];
try { try {
raw = await invoke<ArrayBuffer>("rdp_get_frame", { sessionId }); raw = await invoke<number[]>("rdp_get_frame", { sessionId });
} catch { } catch {
return null; return null;
} }
if (!raw || raw.byteLength === 0) return null; if (!raw || raw.length === 0) return null;
// Binary IPC — tauri::ipc::Response delivers raw bytes as ArrayBuffer // Binary IPC — Tauri returns Vec<u8> as number array
const bytes = new Uint8ClampedArray(raw); const bytes = new Uint8ClampedArray(raw);
const expected = width * height * 4; const expected = width * height * 4;
@ -316,7 +315,8 @@ export function useRdp(): UseRdpReturn {
listen(`rdp:frame:${sessionId}`, () => { listen(`rdp:frame:${sessionId}`, () => {
onFrameReady(); onFrameReady();
}).then((unlisten) => { }).then((unlisten) => {
unlistenFrame = unlisten; // Store unlisten so we can clean up
(canvas as any).__wraith_unlisten = unlisten;
}); });
}); });
@ -332,10 +332,6 @@ export function useRdp(): UseRdpReturn {
cancelAnimationFrame(animFrameId); cancelAnimationFrame(animFrameId);
animFrameId = null; animFrameId = null;
} }
if (unlistenFrame !== null) {
unlistenFrame();
unlistenFrame = null;
}
connected.value = false; connected.value = false;
} }

View File

@ -2,7 +2,6 @@ import { defineStore } from "pinia";
import { ref, computed } from "vue"; import { ref, computed } from "vue";
import { invoke } from "@tauri-apps/api/core"; import { invoke } from "@tauri-apps/api/core";
import { listen } from "@tauri-apps/api/event"; import { listen } from "@tauri-apps/api/event";
import type { UnlistenFn } from "@tauri-apps/api/event";
import { useConnectionStore } from "@/stores/connection.store"; import { useConnectionStore } from "@/stores/connection.store";
import type { ThemeDefinition } from "@/components/common/ThemePicker.vue"; import type { ThemeDefinition } from "@/components/common/ThemePicker.vue";
@ -40,14 +39,10 @@ export const useSessionStore = defineStore("session", () => {
const sessionCount = computed(() => sessions.value.length); const sessionCount = computed(() => sessions.value.length);
const sessionUnlisteners = new Map<string, Array<UnlistenFn>>();
// Listen for backend close/exit events to update session status // Listen for backend close/exit events to update session status
async function setupStatusListeners(sessionId: string): Promise<void> { function setupStatusListeners(sessionId: string): void {
const unlisteners: UnlistenFn[] = []; listen(`ssh:close:${sessionId}`, () => markDisconnected(sessionId));
unlisteners.push(await listen(`ssh:close:${sessionId}`, () => markDisconnected(sessionId))); listen(`ssh:exit:${sessionId}`, () => markDisconnected(sessionId));
unlisteners.push(await listen(`ssh:exit:${sessionId}`, () => markDisconnected(sessionId)));
sessionUnlisteners.set(sessionId, unlisteners);
} }
function markDisconnected(sessionId: string): void { function markDisconnected(sessionId: string): void {
@ -97,12 +92,6 @@ export const useSessionStore = defineStore("session", () => {
console.error("Failed to disconnect session:", err); console.error("Failed to disconnect session:", err);
} }
const unlisteners = sessionUnlisteners.get(id);
if (unlisteners) {
unlisteners.forEach((fn) => fn());
sessionUnlisteners.delete(id);
}
sessions.value.splice(idx, 1); sessions.value.splice(idx, 1);
if (activeSessionId.value === id) { if (activeSessionId.value === id) {
@ -336,8 +325,7 @@ export const useSessionStore = defineStore("session", () => {
}); });
// Listen for PTY close // Listen for PTY close
const unlistenPty = await listen(`pty:close:${sessionId}`, () => markDisconnected(sessionId)); listen(`pty:close:${sessionId}`, () => markDisconnected(sessionId));
sessionUnlisteners.set(sessionId, [unlistenPty]);
activeSessionId.value = sessionId; activeSessionId.value = sessionId;
} catch (err: unknown) { } catch (err: unknown) {