Compare commits
11 Commits
ff9fc798c3
...
3842d48390
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3842d48390 | ||
|
|
687ccfb982 | ||
|
|
8a66103d3d | ||
|
|
15c95841be | ||
|
|
625a4500bc | ||
|
|
3843f18b31 | ||
|
|
17973fc3dc | ||
|
|
da2dd5bbfc | ||
|
|
fca6ed023e | ||
|
|
24e8b1e359 | ||
|
|
a907213d57 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ src-tauri/target/
|
|||||||
src-tauri/binaries/
|
src-tauri/binaries/
|
||||||
*.log
|
*.log
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.claude/worktrees/
|
||||||
|
|||||||
2
src-tauri/Cargo.lock
generated
2
src-tauri/Cargo.lock
generated
@ -8913,9 +8913,11 @@ dependencies = [
|
|||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
|
"tokio-util",
|
||||||
"ureq",
|
"ureq",
|
||||||
"uuid",
|
"uuid",
|
||||||
"x509-cert",
|
"x509-cert",
|
||||||
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@ -12,11 +12,15 @@ crate-type = ["lib", "cdylib", "staticlib"]
|
|||||||
name = "wraith-mcp-bridge"
|
name = "wraith-mcp-bridge"
|
||||||
path = "src/bin/wraith_mcp_bridge.rs"
|
path = "src/bin/wraith_mcp_bridge.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
devtools = ["tauri/devtools"]
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tauri-build = { version = "2", features = [] }
|
tauri-build = { version = "2", features = [] }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tauri = { version = "2", features = ["devtools"] }
|
tauri = { version = "2", features = [] }
|
||||||
tauri-plugin-shell = "2"
|
tauri-plugin-shell = "2"
|
||||||
tauri-plugin-updater = "2"
|
tauri-plugin-updater = "2"
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
@ -33,6 +37,8 @@ uuid = { version = "1", features = ["v4"] }
|
|||||||
base64 = "0.22"
|
base64 = "0.22"
|
||||||
dashmap = "6"
|
dashmap = "6"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-util = "0.7"
|
||||||
|
zeroize = { version = "1", features = ["derive"] }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
env_logger = "0.11"
|
env_logger = "0.11"
|
||||||
|
|||||||
@ -38,19 +38,22 @@ struct JsonRpcError {
|
|||||||
message: String,
|
message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mcp_port() -> Result<u16, String> {
|
fn get_data_dir() -> Result<std::path::PathBuf, String> {
|
||||||
// Check standard locations for the port file
|
if let Ok(appdata) = std::env::var("APPDATA") {
|
||||||
let port_file = if let Ok(appdata) = std::env::var("APPDATA") {
|
Ok(std::path::PathBuf::from(appdata).join("Wraith"))
|
||||||
std::path::PathBuf::from(appdata).join("Wraith").join("mcp-port")
|
|
||||||
} else if let Ok(home) = std::env::var("HOME") {
|
} else if let Ok(home) = std::env::var("HOME") {
|
||||||
if cfg!(target_os = "macos") {
|
if cfg!(target_os = "macos") {
|
||||||
std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith").join("mcp-port")
|
Ok(std::path::PathBuf::from(home).join("Library").join("Application Support").join("Wraith"))
|
||||||
} else {
|
} else {
|
||||||
std::path::PathBuf::from(home).join(".local").join("share").join("wraith").join("mcp-port")
|
Ok(std::path::PathBuf::from(home).join(".local").join("share").join("wraith"))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err("Cannot determine data directory".to_string());
|
Err("Cannot determine data directory".to_string())
|
||||||
};
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mcp_port() -> Result<u16, String> {
|
||||||
|
let port_file = get_data_dir()?.join("mcp-port");
|
||||||
|
|
||||||
let port_str = std::fs::read_to_string(&port_file)
|
let port_str = std::fs::read_to_string(&port_file)
|
||||||
.map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?;
|
.map_err(|e| format!("Cannot read MCP port file at {}: {} — is Wraith running?", port_file.display(), e))?;
|
||||||
@ -59,6 +62,15 @@ fn get_mcp_port() -> Result<u16, String> {
|
|||||||
.map_err(|e| format!("Invalid port in MCP port file: {}", e))
|
.map_err(|e| format!("Invalid port in MCP port file: {}", e))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_mcp_token() -> Result<String, String> {
|
||||||
|
let token_file = get_data_dir()?.join("mcp-token");
|
||||||
|
|
||||||
|
let token = std::fs::read_to_string(&token_file)
|
||||||
|
.map_err(|e| format!("Cannot read MCP token file at {}: {} — is Wraith running?", token_file.display(), e))?;
|
||||||
|
|
||||||
|
Ok(token.trim().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_initialize(id: Value) -> JsonRpcResponse {
|
fn handle_initialize(id: Value) -> JsonRpcResponse {
|
||||||
JsonRpcResponse {
|
JsonRpcResponse {
|
||||||
jsonrpc: "2.0".to_string(),
|
jsonrpc: "2.0".to_string(),
|
||||||
@ -304,12 +316,13 @@ fn handle_tools_list(id: Value) -> JsonRpcResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result<Value, String> {
|
fn call_wraith(port: u16, token: &str, endpoint: &str, body: Value) -> Result<Value, String> {
|
||||||
let url = format!("http://127.0.0.1:{}{}", port, endpoint);
|
let url = format!("http://127.0.0.1:{}{}", port, endpoint);
|
||||||
let body_str = serde_json::to_string(&body).unwrap_or_default();
|
let body_str = serde_json::to_string(&body).unwrap_or_default();
|
||||||
|
|
||||||
let mut resp = ureq::post(url)
|
let mut resp = ureq::post(url)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", &format!("Bearer {}", token))
|
||||||
.send(body_str.as_bytes())
|
.send(body_str.as_bytes())
|
||||||
.map_err(|e| format!("HTTP request to Wraith failed: {}", e))?;
|
.map_err(|e| format!("HTTP request to Wraith failed: {}", e))?;
|
||||||
|
|
||||||
@ -327,40 +340,40 @@ fn call_wraith(port: u16, endpoint: &str, body: Value) -> Result<Value, String>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_tool_call(id: Value, port: u16, tool_name: &str, args: &Value) -> JsonRpcResponse {
|
fn handle_tool_call(id: Value, port: u16, token: &str, tool_name: &str, args: &Value) -> JsonRpcResponse {
|
||||||
let result = match tool_name {
|
let result = match tool_name {
|
||||||
"list_sessions" => call_wraith(port, "/mcp/sessions", serde_json::json!({})),
|
"list_sessions" => call_wraith(port, token, "/mcp/sessions", serde_json::json!({})),
|
||||||
"terminal_type" => call_wraith(port, "/mcp/terminal/type", args.clone()),
|
"terminal_type" => call_wraith(port, token, "/mcp/terminal/type", args.clone()),
|
||||||
"terminal_read" => call_wraith(port, "/mcp/terminal/read", args.clone()),
|
"terminal_read" => call_wraith(port, token, "/mcp/terminal/read", args.clone()),
|
||||||
"terminal_execute" => call_wraith(port, "/mcp/terminal/execute", args.clone()),
|
"terminal_execute" => call_wraith(port, token, "/mcp/terminal/execute", args.clone()),
|
||||||
"sftp_list" => call_wraith(port, "/mcp/sftp/list", args.clone()),
|
"sftp_list" => call_wraith(port, token, "/mcp/sftp/list", args.clone()),
|
||||||
"sftp_read" => call_wraith(port, "/mcp/sftp/read", args.clone()),
|
"sftp_read" => call_wraith(port, token, "/mcp/sftp/read", args.clone()),
|
||||||
"sftp_write" => call_wraith(port, "/mcp/sftp/write", args.clone()),
|
"sftp_write" => call_wraith(port, token, "/mcp/sftp/write", args.clone()),
|
||||||
"network_scan" => call_wraith(port, "/mcp/tool/scan-network", args.clone()),
|
"network_scan" => call_wraith(port, token, "/mcp/tool/scan-network", args.clone()),
|
||||||
"port_scan" => call_wraith(port, "/mcp/tool/scan-ports", args.clone()),
|
"port_scan" => call_wraith(port, token, "/mcp/tool/scan-ports", args.clone()),
|
||||||
"ping" => call_wraith(port, "/mcp/tool/ping", args.clone()),
|
"ping" => call_wraith(port, token, "/mcp/tool/ping", args.clone()),
|
||||||
"traceroute" => call_wraith(port, "/mcp/tool/traceroute", args.clone()),
|
"traceroute" => call_wraith(port, token, "/mcp/tool/traceroute", args.clone()),
|
||||||
"dns_lookup" => call_wraith(port, "/mcp/tool/dns", args.clone()),
|
"dns_lookup" => call_wraith(port, token, "/mcp/tool/dns", args.clone()),
|
||||||
"whois" => call_wraith(port, "/mcp/tool/whois", args.clone()),
|
"whois" => call_wraith(port, token, "/mcp/tool/whois", args.clone()),
|
||||||
"wake_on_lan" => call_wraith(port, "/mcp/tool/wol", args.clone()),
|
"wake_on_lan" => call_wraith(port, token, "/mcp/tool/wol", args.clone()),
|
||||||
"bandwidth_test" => call_wraith(port, "/mcp/tool/bandwidth", args.clone()),
|
"bandwidth_test" => call_wraith(port, token, "/mcp/tool/bandwidth", args.clone()),
|
||||||
"subnet_calc" => call_wraith(port, "/mcp/tool/subnet", args.clone()),
|
"subnet_calc" => call_wraith(port, token, "/mcp/tool/subnet", args.clone()),
|
||||||
"generate_ssh_key" => call_wraith(port, "/mcp/tool/keygen", args.clone()),
|
"generate_ssh_key" => call_wraith(port, token, "/mcp/tool/keygen", args.clone()),
|
||||||
"generate_password" => call_wraith(port, "/mcp/tool/passgen", args.clone()),
|
"generate_password" => call_wraith(port, token, "/mcp/tool/passgen", args.clone()),
|
||||||
"docker_ps" => call_wraith(port, "/mcp/docker/ps", args.clone()),
|
"docker_ps" => call_wraith(port, token, "/mcp/docker/ps", args.clone()),
|
||||||
"docker_action" => call_wraith(port, "/mcp/docker/action", args.clone()),
|
"docker_action" => call_wraith(port, token, "/mcp/docker/action", args.clone()),
|
||||||
"docker_exec" => call_wraith(port, "/mcp/docker/exec", args.clone()),
|
"docker_exec" => call_wraith(port, token, "/mcp/docker/exec", args.clone()),
|
||||||
"service_status" => call_wraith(port, "/mcp/service/status", args.clone()),
|
"service_status" => call_wraith(port, token, "/mcp/service/status", args.clone()),
|
||||||
"process_list" => call_wraith(port, "/mcp/process/list", args.clone()),
|
"process_list" => call_wraith(port, token, "/mcp/process/list", args.clone()),
|
||||||
"git_status" => call_wraith(port, "/mcp/git/status", args.clone()),
|
"git_status" => call_wraith(port, token, "/mcp/git/status", args.clone()),
|
||||||
"git_pull" => call_wraith(port, "/mcp/git/pull", args.clone()),
|
"git_pull" => call_wraith(port, token, "/mcp/git/pull", args.clone()),
|
||||||
"git_log" => call_wraith(port, "/mcp/git/log", args.clone()),
|
"git_log" => call_wraith(port, token, "/mcp/git/log", args.clone()),
|
||||||
"rdp_click" => call_wraith(port, "/mcp/rdp/click", args.clone()),
|
"rdp_click" => call_wraith(port, token, "/mcp/rdp/click", args.clone()),
|
||||||
"rdp_type" => call_wraith(port, "/mcp/rdp/type", args.clone()),
|
"rdp_type" => call_wraith(port, token, "/mcp/rdp/type", args.clone()),
|
||||||
"rdp_clipboard" => call_wraith(port, "/mcp/rdp/clipboard", args.clone()),
|
"rdp_clipboard" => call_wraith(port, token, "/mcp/rdp/clipboard", args.clone()),
|
||||||
"ssh_connect" => call_wraith(port, "/mcp/ssh/connect", args.clone()),
|
"ssh_connect" => call_wraith(port, token, "/mcp/ssh/connect", args.clone()),
|
||||||
"terminal_screenshot" => {
|
"terminal_screenshot" => {
|
||||||
let result = call_wraith(port, "/mcp/screenshot", args.clone());
|
let result = call_wraith(port, token, "/mcp/screenshot", args.clone());
|
||||||
// Screenshot returns base64 PNG — wrap as image content for multimodal AI
|
// Screenshot returns base64 PNG — wrap as image content for multimodal AI
|
||||||
return match result {
|
return match result {
|
||||||
Ok(b64) => JsonRpcResponse {
|
Ok(b64) => JsonRpcResponse {
|
||||||
@ -420,6 +433,14 @@ fn main() {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let token = match get_mcp_token() {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("wraith-mcp-bridge: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let stdin = io::stdin();
|
let stdin = io::stdin();
|
||||||
let mut stdout = io::stdout();
|
let mut stdout = io::stdout();
|
||||||
|
|
||||||
@ -458,7 +479,7 @@ fn main() {
|
|||||||
let args = request.params.get("arguments")
|
let args = request.params.get("arguments")
|
||||||
.cloned()
|
.cloned()
|
||||||
.unwrap_or(Value::Object(serde_json::Map::new()));
|
.unwrap_or(Value::Object(serde_json::Map::new()));
|
||||||
handle_tool_call(request.id, port, tool_name, &args)
|
handle_tool_call(request.id, port, &token, tool_name, &args)
|
||||||
}
|
}
|
||||||
"notifications/initialized" | "notifications/cancelled" => {
|
"notifications/initialized" | "notifications/cancelled" => {
|
||||||
// Notifications don't get responses
|
// Notifications don't get responses
|
||||||
|
|||||||
@ -3,34 +3,16 @@ use tauri::State;
|
|||||||
use crate::credentials::Credential;
|
use crate::credentials::Credential;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Guard helper: lock the credentials mutex and return a ref to the inner
|
|
||||||
/// `CredentialService`, or a "Vault is locked" error if the vault has not
|
|
||||||
/// been unlocked for this session.
|
|
||||||
///
|
|
||||||
/// This is a macro rather than a function because returning a `MutexGuard`
|
|
||||||
/// from a helper function would require lifetime annotations that complicate
|
|
||||||
/// the tauri command signatures unnecessarily.
|
|
||||||
macro_rules! require_unlocked {
|
|
||||||
($state:expr) => {{
|
|
||||||
let guard = $state
|
|
||||||
.credentials
|
|
||||||
.lock()
|
|
||||||
.map_err(|_| "Credentials mutex was poisoned".to_string())?;
|
|
||||||
if guard.is_none() {
|
|
||||||
return Err("Vault is locked — call unlock before accessing credentials".into());
|
|
||||||
}
|
|
||||||
// SAFETY: we just checked `is_none` above, so `unwrap` cannot panic.
|
|
||||||
guard
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return all credentials ordered by name.
|
/// Return all credentials ordered by name.
|
||||||
///
|
///
|
||||||
/// Secret values (passwords, private keys) are never included — only metadata.
|
/// Secret values (passwords, private keys) are never included — only metadata.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credential>, String> {
|
pub async fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credential>, String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard.as_ref().unwrap().list()
|
let svc = guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
|
svc.list()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store a new username/password credential.
|
/// Store a new username/password credential.
|
||||||
@ -39,18 +21,18 @@ pub fn list_credentials(state: State<'_, AppState>) -> Result<Vec<Credential>, S
|
|||||||
/// Returns the created credential record (without the plaintext password).
|
/// Returns the created credential record (without the plaintext password).
|
||||||
/// `domain` is `None` for non-domain credentials; `Some("")` is treated as NULL.
|
/// `domain` is `None` for non-domain credentials; `Some("")` is treated as NULL.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn create_password(
|
pub async fn create_password(
|
||||||
name: String,
|
name: String,
|
||||||
username: String,
|
username: String,
|
||||||
password: String,
|
password: String,
|
||||||
domain: Option<String>,
|
domain: Option<String>,
|
||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
) -> Result<Credential, String> {
|
) -> Result<Credential, String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard
|
let svc = guard
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
.create_password(name, username, password, domain)
|
svc.create_password(name, username, password, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store a new SSH private key credential.
|
/// Store a new SSH private key credential.
|
||||||
@ -59,18 +41,18 @@ pub fn create_password(
|
|||||||
/// Pass `None` for `passphrase` when the key has no passphrase.
|
/// Pass `None` for `passphrase` when the key has no passphrase.
|
||||||
/// Returns the created credential record without any secret material.
|
/// Returns the created credential record without any secret material.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn create_ssh_key(
|
pub async fn create_ssh_key(
|
||||||
name: String,
|
name: String,
|
||||||
username: String,
|
username: String,
|
||||||
private_key_pem: String,
|
private_key_pem: String,
|
||||||
passphrase: Option<String>,
|
passphrase: Option<String>,
|
||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
) -> Result<Credential, String> {
|
) -> Result<Credential, String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard
|
let svc = guard
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
.create_ssh_key(name, username, private_key_pem, passphrase)
|
svc.create_ssh_key(name, username, private_key_pem, passphrase)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete a credential by id.
|
/// Delete a credential by id.
|
||||||
@ -78,21 +60,30 @@ pub fn create_ssh_key(
|
|||||||
/// For SSH key credentials, the associated `ssh_keys` row is also deleted.
|
/// For SSH key credentials, the associated `ssh_keys` row is also deleted.
|
||||||
/// Returns `Err` if the vault is locked or the id does not exist.
|
/// Returns `Err` if the vault is locked or the id does not exist.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn delete_credential(id: i64, state: State<'_, AppState>) -> Result<(), String> {
|
pub async fn delete_credential(id: i64, state: State<'_, AppState>) -> Result<(), String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard.as_ref().unwrap().delete(id)
|
let svc = guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
|
svc.delete(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrypt and return the password for a credential.
|
/// Decrypt and return the password for a credential.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn decrypt_password(credential_id: i64, state: State<'_, AppState>) -> Result<String, String> {
|
pub async fn decrypt_password(credential_id: i64, state: State<'_, AppState>) -> Result<String, String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard.as_ref().unwrap().decrypt_password(credential_id)
|
let svc = guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
|
svc.decrypt_password(credential_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrypt and return the SSH private key and passphrase.
|
/// Decrypt and return the SSH private key and passphrase.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn decrypt_ssh_key(ssh_key_id: i64, state: State<'_, AppState>) -> Result<(String, String), String> {
|
pub async fn decrypt_ssh_key(ssh_key_id: i64, state: State<'_, AppState>) -> Result<(String, String), String> {
|
||||||
let guard = require_unlocked!(state);
|
let guard = state.credentials.lock().await;
|
||||||
guard.as_ref().unwrap().decrypt_ssh_key(ssh_key_id)
|
let svc = guard
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| "Vault is locked — call unlock before accessing credentials".to_string())?;
|
||||||
|
svc.decrypt_ssh_key(ssh_key_id)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
use tauri::State;
|
use tauri::State;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
use crate::utils::shell_escape;
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
@ -84,14 +85,15 @@ pub async fn docker_list_volumes(session_id: String, state: State<'_, AppState>)
|
|||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result<String, String> {
|
pub async fn docker_action(session_id: String, action: String, target: String, state: State<'_, AppState>) -> Result<String, String> {
|
||||||
let session = state.ssh.get_session(&session_id).ok_or("Session not found")?;
|
let session = state.ssh.get_session(&session_id).ok_or("Session not found")?;
|
||||||
|
let t = shell_escape(&target);
|
||||||
let cmd = match action.as_str() {
|
let cmd = match action.as_str() {
|
||||||
"start" => format!("docker start {} 2>&1", target),
|
"start" => format!("docker start {} 2>&1", t),
|
||||||
"stop" => format!("docker stop {} 2>&1", target),
|
"stop" => format!("docker stop {} 2>&1", t),
|
||||||
"restart" => format!("docker restart {} 2>&1", target),
|
"restart" => format!("docker restart {} 2>&1", t),
|
||||||
"remove" => format!("docker rm -f {} 2>&1", target),
|
"remove" => format!("docker rm -f {} 2>&1", t),
|
||||||
"logs" => format!("docker logs --tail 100 {} 2>&1", target),
|
"logs" => format!("docker logs --tail 100 {} 2>&1", t),
|
||||||
"remove-image" => format!("docker rmi {} 2>&1", target),
|
"remove-image" => format!("docker rmi {} 2>&1", t),
|
||||||
"remove-volume" => format!("docker volume rm {} 2>&1", target),
|
"remove-volume" => format!("docker volume rm {} 2>&1", t),
|
||||||
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
|
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
|
||||||
"system-prune" => "docker system prune -f 2>&1".to_string(),
|
"system-prune" => "docker system prune -f 2>&1".to_string(),
|
||||||
"system-prune-all" => "docker system prune -a -f 2>&1".to_string(),
|
"system-prune-all" => "docker system prune -a -f 2>&1".to_string(),
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
//! delegate to the `RdpService` via `State<AppState>`.
|
//! delegate to the `RdpService` via `State<AppState>`.
|
||||||
|
|
||||||
use tauri::{AppHandle, State};
|
use tauri::{AppHandle, State};
|
||||||
|
use tauri::ipc::Response;
|
||||||
|
|
||||||
use crate::rdp::{RdpConfig, RdpSessionInfo};
|
use crate::rdp::{RdpConfig, RdpSessionInfo};
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
@ -18,16 +19,18 @@ pub fn connect_rdp(
|
|||||||
state.rdp.connect(config, app_handle)
|
state.rdp.connect(config, app_handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the current frame buffer as raw RGBA bytes (binary IPC — no base64).
|
/// Get the current frame buffer as raw RGBA bytes via binary IPC.
|
||||||
///
|
///
|
||||||
|
/// Uses `tauri::ipc::Response` to return raw bytes without JSON serialization.
|
||||||
/// Pixel format: RGBA, 4 bytes per pixel, row-major, top-left origin.
|
/// Pixel format: RGBA, 4 bytes per pixel, row-major, top-left origin.
|
||||||
/// Returns empty Vec if frame hasn't changed since last call.
|
/// Returns empty payload if frame hasn't changed since last call.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub async fn rdp_get_frame(
|
pub async fn rdp_get_frame(
|
||||||
session_id: String,
|
session_id: String,
|
||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
) -> Result<Vec<u8>, String> {
|
) -> Result<Response, String> {
|
||||||
state.rdp.get_frame(&session_id).await
|
let frame = state.rdp.get_frame(&session_id).await?;
|
||||||
|
Ok(Response::new(frame))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a mouse event to an RDP session.
|
/// Send a mouse event to an RDP session.
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use tauri::State;
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
use crate::utils::shell_escape;
|
||||||
|
|
||||||
// ── Ping ─────────────────────────────────────────────────────────────────────
|
// ── Ping ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ pub async fn tool_ping(
|
|||||||
let session = state.ssh.get_session(&session_id)
|
let session = state.ssh.get_session(&session_id)
|
||||||
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
||||||
let n = count.unwrap_or(4);
|
let n = count.unwrap_or(4);
|
||||||
let cmd = format!("ping -c {} {} 2>&1", n, target);
|
let cmd = format!("ping -c {} {} 2>&1", n, shell_escape(&target));
|
||||||
let output = exec_on_session(&session.handle, &cmd).await?;
|
let output = exec_on_session(&session.handle, &cmd).await?;
|
||||||
Ok(PingResult { target, output })
|
Ok(PingResult { target, output })
|
||||||
}
|
}
|
||||||
@ -39,7 +40,8 @@ pub async fn tool_traceroute(
|
|||||||
) -> Result<String, String> {
|
) -> Result<String, String> {
|
||||||
let session = state.ssh.get_session(&session_id)
|
let session = state.ssh.get_session(&session_id)
|
||||||
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
||||||
let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", target, target);
|
let t = shell_escape(&target);
|
||||||
|
let cmd = format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t);
|
||||||
exec_on_session(&session.handle, &cmd).await
|
exec_on_session(&session.handle, &cmd).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,14 +67,16 @@ pub async fn tool_wake_on_lan(
|
|||||||
let cmd = format!(
|
let cmd = format!(
|
||||||
r#"python3 -c "
|
r#"python3 -c "
|
||||||
import socket, struct
|
import socket, struct
|
||||||
mac = bytes.fromhex('{mac_clean}')
|
mac = bytes.fromhex({mac_clean_escaped})
|
||||||
pkt = b'\xff'*6 + mac*16
|
pkt = b'\xff'*6 + mac*16
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||||
s.sendto(pkt, ('255.255.255.255', 9))
|
s.sendto(pkt, ('255.255.255.255', 9))
|
||||||
s.close()
|
s.close()
|
||||||
print('WoL packet sent to {mac_address}')
|
print('WoL packet sent to {mac_display_escaped}')
|
||||||
" 2>&1 || echo "python3 not available — install python3 on remote host for WoL""#
|
" 2>&1 || echo "python3 not available — install python3 on remote host for WoL""#,
|
||||||
|
mac_clean_escaped = shell_escape(&mac_clean),
|
||||||
|
mac_display_escaped = shell_escape(&mac_address),
|
||||||
);
|
);
|
||||||
|
|
||||||
exec_on_session(&session.handle, &cmd).await
|
exec_on_session(&session.handle, &cmd).await
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use tauri::State;
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
use crate::utils::shell_escape;
|
||||||
|
|
||||||
// ── DNS Lookup ───────────────────────────────────────────────────────────────
|
// ── DNS Lookup ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@ -16,10 +17,11 @@ pub async fn tool_dns_lookup(
|
|||||||
) -> Result<String, String> {
|
) -> Result<String, String> {
|
||||||
let session = state.ssh.get_session(&session_id)
|
let session = state.ssh.get_session(&session_id)
|
||||||
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
||||||
let rtype = record_type.unwrap_or_else(|| "A".to_string());
|
let d = shell_escape(&domain);
|
||||||
|
let rt = shell_escape(&record_type.unwrap_or_else(|| "A".to_string()));
|
||||||
let cmd = format!(
|
let cmd = format!(
|
||||||
r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#,
|
r#"dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null"#,
|
||||||
domain, rtype, rtype, domain, rtype, domain
|
d, rt, rt, d, rt, d
|
||||||
);
|
);
|
||||||
exec_on_session(&session.handle, &cmd).await
|
exec_on_session(&session.handle, &cmd).await
|
||||||
}
|
}
|
||||||
@ -34,7 +36,7 @@ pub async fn tool_whois(
|
|||||||
) -> Result<String, String> {
|
) -> Result<String, String> {
|
||||||
let session = state.ssh.get_session(&session_id)
|
let session = state.ssh.get_session(&session_id)
|
||||||
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
||||||
let cmd = format!("whois {} 2>&1 | head -80", target);
|
let cmd = format!("whois {} 2>&1 | head -80", shell_escape(&target));
|
||||||
exec_on_session(&session.handle, &cmd).await
|
exec_on_session(&session.handle, &cmd).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,9 +52,10 @@ pub async fn tool_bandwidth_iperf(
|
|||||||
let session = state.ssh.get_session(&session_id)
|
let session = state.ssh.get_session(&session_id)
|
||||||
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
.ok_or_else(|| format!("SSH session {} not found", session_id))?;
|
||||||
let dur = duration.unwrap_or(5);
|
let dur = duration.unwrap_or(5);
|
||||||
|
let s = shell_escape(&server);
|
||||||
let cmd = format!(
|
let cmd = format!(
|
||||||
"iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'",
|
"iperf3 -c {} -t {} --json 2>/dev/null || iperf3 -c {} -t {} 2>&1 || echo 'iperf3 not installed — run: apt install iperf3 / brew install iperf3'",
|
||||||
server, dur, server, dur
|
s, dur, s, dur
|
||||||
);
|
);
|
||||||
exec_on_session(&session.handle, &cmd).await
|
exec_on_session(&session.handle, &cmd).await
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
use tauri::State;
|
use tauri::State;
|
||||||
|
use zeroize::Zeroize;
|
||||||
|
|
||||||
use crate::vault::{self, VaultService};
|
use crate::vault::{self, VaultService};
|
||||||
use crate::credentials::CredentialService;
|
use crate::credentials::CredentialService;
|
||||||
@ -21,14 +22,15 @@ pub fn is_first_run(state: State<'_, AppState>) -> bool {
|
|||||||
/// Returns `Err` if the vault has already been set up or if any storage
|
/// Returns `Err` if the vault has already been set up or if any storage
|
||||||
/// operation fails.
|
/// operation fails.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn create_vault(password: String, state: State<'_, AppState>) -> Result<(), String> {
|
pub async fn create_vault(mut password: String, state: State<'_, AppState>) -> Result<(), String> {
|
||||||
|
let result = async {
|
||||||
if !state.is_first_run() {
|
if !state.is_first_run() {
|
||||||
return Err("Vault already exists — use unlock instead of create".into());
|
return Err("Vault already exists — use unlock instead of create".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let salt = vault::generate_salt();
|
let salt = vault::generate_salt();
|
||||||
let key = vault::derive_key(&password, &salt);
|
let key = vault::derive_key(&password, &salt);
|
||||||
let vs = VaultService::new(key);
|
let vs = VaultService::new(key.clone());
|
||||||
|
|
||||||
// Persist the salt so we can re-derive the key on future unlocks.
|
// Persist the salt so we can re-derive the key on future unlocks.
|
||||||
state.settings.set("vault_salt", &hex::encode(salt))?;
|
state.settings.set("vault_salt", &hex::encode(salt))?;
|
||||||
@ -39,10 +41,14 @@ pub fn create_vault(password: String, state: State<'_, AppState>) -> Result<(),
|
|||||||
|
|
||||||
// Activate the vault and credentials service for this session.
|
// Activate the vault and credentials service for this session.
|
||||||
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
|
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
|
||||||
*state.credentials.lock().unwrap() = Some(cred_svc);
|
*state.credentials.lock().await = Some(cred_svc);
|
||||||
*state.vault.lock().unwrap() = Some(vs);
|
*state.vault.lock().await = Some(vs);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
}.await;
|
||||||
|
|
||||||
|
password.zeroize();
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unlock an existing vault using the master password.
|
/// Unlock an existing vault using the master password.
|
||||||
@ -52,7 +58,8 @@ pub fn create_vault(password: String, state: State<'_, AppState>) -> Result<(),
|
|||||||
///
|
///
|
||||||
/// Returns `Err("Incorrect master password")` if the password is wrong.
|
/// Returns `Err("Incorrect master password")` if the password is wrong.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn unlock(password: String, state: State<'_, AppState>) -> Result<(), String> {
|
pub async fn unlock(mut password: String, state: State<'_, AppState>) -> Result<(), String> {
|
||||||
|
let result = async {
|
||||||
let salt_hex = state
|
let salt_hex = state
|
||||||
.settings
|
.settings
|
||||||
.get("vault_salt")
|
.get("vault_salt")
|
||||||
@ -62,7 +69,7 @@ pub fn unlock(password: String, state: State<'_, AppState>) -> Result<(), String
|
|||||||
.map_err(|e| format!("Stored vault salt is corrupt: {e}"))?;
|
.map_err(|e| format!("Stored vault salt is corrupt: {e}"))?;
|
||||||
|
|
||||||
let key = vault::derive_key(&password, &salt);
|
let key = vault::derive_key(&password, &salt);
|
||||||
let vs = VaultService::new(key);
|
let vs = VaultService::new(key.clone());
|
||||||
|
|
||||||
// Verify the password by decrypting the check value.
|
// Verify the password by decrypting the check value.
|
||||||
let check_blob = state
|
let check_blob = state
|
||||||
@ -80,14 +87,18 @@ pub fn unlock(password: String, state: State<'_, AppState>) -> Result<(), String
|
|||||||
|
|
||||||
// Activate the vault and credentials service for this session.
|
// Activate the vault and credentials service for this session.
|
||||||
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
|
let cred_svc = CredentialService::new(state.db.clone(), VaultService::new(key));
|
||||||
*state.credentials.lock().unwrap() = Some(cred_svc);
|
*state.credentials.lock().await = Some(cred_svc);
|
||||||
*state.vault.lock().unwrap() = Some(vs);
|
*state.vault.lock().await = Some(vs);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
}.await;
|
||||||
|
|
||||||
|
password.zeroize();
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns `true` if the vault is currently unlocked for this session.
|
/// Returns `true` if the vault is currently unlocked for this session.
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn is_unlocked(state: State<'_, AppState>) -> bool {
|
pub async fn is_unlocked(state: State<'_, AppState>) -> Result<bool, String> {
|
||||||
state.is_unlocked()
|
Ok(state.is_unlocked().await)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,7 @@ use crate::db::Database;
|
|||||||
// ── domain types ──────────────────────────────────────────────────────────────
|
// ── domain types ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ConnectionGroup {
|
pub struct ConnectionGroup {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
|||||||
@ -21,9 +21,9 @@ pub mod pty;
|
|||||||
pub mod mcp;
|
pub mod mcp;
|
||||||
pub mod scanner;
|
pub mod scanner;
|
||||||
pub mod commands;
|
pub mod commands;
|
||||||
|
pub mod utils;
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
use db::Database;
|
use db::Database;
|
||||||
use vault::VaultService;
|
use vault::VaultService;
|
||||||
@ -41,10 +41,10 @@ use mcp::error_watcher::ErrorWatcher;
|
|||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub db: Database,
|
pub db: Database,
|
||||||
pub vault: Mutex<Option<VaultService>>,
|
pub vault: tokio::sync::Mutex<Option<VaultService>>,
|
||||||
pub settings: SettingsService,
|
pub settings: SettingsService,
|
||||||
pub connections: ConnectionService,
|
pub connections: ConnectionService,
|
||||||
pub credentials: Mutex<Option<CredentialService>>,
|
pub credentials: tokio::sync::Mutex<Option<CredentialService>>,
|
||||||
pub ssh: SshService,
|
pub ssh: SshService,
|
||||||
pub sftp: SftpService,
|
pub sftp: SftpService,
|
||||||
pub rdp: RdpService,
|
pub rdp: RdpService,
|
||||||
@ -60,17 +60,18 @@ impl AppState {
|
|||||||
std::fs::create_dir_all(&data_dir)?;
|
std::fs::create_dir_all(&data_dir)?;
|
||||||
let database = Database::open(&data_dir.join("wraith.db"))?;
|
let database = Database::open(&data_dir.join("wraith.db"))?;
|
||||||
database.migrate()?;
|
database.migrate()?;
|
||||||
|
let settings = SettingsService::new(database.clone());
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
db: database.clone(),
|
db: database.clone(),
|
||||||
vault: Mutex::new(None),
|
vault: tokio::sync::Mutex::new(None),
|
||||||
settings: SettingsService::new(database.clone()),
|
|
||||||
connections: ConnectionService::new(database.clone()),
|
connections: ConnectionService::new(database.clone()),
|
||||||
credentials: Mutex::new(None),
|
credentials: tokio::sync::Mutex::new(None),
|
||||||
ssh: SshService::new(database.clone()),
|
ssh: SshService::new(database.clone()),
|
||||||
sftp: SftpService::new(),
|
sftp: SftpService::new(),
|
||||||
rdp: RdpService::new(),
|
rdp: RdpService::new(),
|
||||||
theme: ThemeService::new(database.clone()),
|
theme: ThemeService::new(database),
|
||||||
workspace: WorkspaceService::new(SettingsService::new(database.clone())),
|
workspace: WorkspaceService::new(settings.clone()),
|
||||||
|
settings,
|
||||||
pty: PtyService::new(),
|
pty: PtyService::new(),
|
||||||
scrollback: ScrollbackRegistry::new(),
|
scrollback: ScrollbackRegistry::new(),
|
||||||
error_watcher: std::sync::Arc::new(ErrorWatcher::new()),
|
error_watcher: std::sync::Arc::new(ErrorWatcher::new()),
|
||||||
@ -85,8 +86,8 @@ impl AppState {
|
|||||||
self.settings.get("vault_salt").unwrap_or_default().is_empty()
|
self.settings.get("vault_salt").unwrap_or_default().is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_unlocked(&self) -> bool {
|
pub async fn is_unlocked(&self) -> bool {
|
||||||
self.vault.lock().unwrap().is_some()
|
self.vault.lock().await.is_some()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -40,13 +40,25 @@ impl ScrollbackBuffer {
|
|||||||
|
|
||||||
/// Append bytes to the buffer. Old data is overwritten when full.
|
/// Append bytes to the buffer. Old data is overwritten when full.
|
||||||
pub fn push(&self, bytes: &[u8]) {
|
pub fn push(&self, bytes: &[u8]) {
|
||||||
let mut buf = self.inner.lock().unwrap();
|
if bytes.is_empty() {
|
||||||
for &b in bytes {
|
return;
|
||||||
let pos = buf.write_pos;
|
|
||||||
buf.data[pos] = b;
|
|
||||||
buf.write_pos = (pos + 1) % buf.capacity;
|
|
||||||
buf.total_written += 1;
|
|
||||||
}
|
}
|
||||||
|
let mut buf = self.inner.lock().unwrap();
|
||||||
|
let cap = buf.capacity;
|
||||||
|
// If input exceeds capacity, only keep the last `cap` bytes
|
||||||
|
let data = if bytes.len() > cap {
|
||||||
|
&bytes[bytes.len() - cap..]
|
||||||
|
} else {
|
||||||
|
bytes
|
||||||
|
};
|
||||||
|
let write_pos = buf.write_pos;
|
||||||
|
let first_len = (cap - write_pos).min(data.len());
|
||||||
|
buf.data[write_pos..write_pos + first_len].copy_from_slice(&data[..first_len]);
|
||||||
|
if first_len < data.len() {
|
||||||
|
buf.data[..data.len() - first_len].copy_from_slice(&data[first_len..]);
|
||||||
|
}
|
||||||
|
buf.write_pos = (write_pos + data.len()) % cap;
|
||||||
|
buf.total_written += bytes.len();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read the last `n` lines from the buffer, with ANSI escape codes stripped.
|
/// Read the last `n` lines from the buffer, with ANSI escape codes stripped.
|
||||||
@ -192,4 +204,42 @@ mod tests {
|
|||||||
buf.push(b"ABCD"); // 4 more, wraps
|
buf.push(b"ABCD"); // 4 more, wraps
|
||||||
assert_eq!(buf.total_written(), 12);
|
assert_eq!(buf.total_written(), 12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_empty_is_noop() {
|
||||||
|
let buf = ScrollbackBuffer::with_capacity(8);
|
||||||
|
buf.push(b"hello");
|
||||||
|
buf.push(b"");
|
||||||
|
assert_eq!(buf.total_written(), 5);
|
||||||
|
assert!(buf.read_raw().contains("hello"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_larger_than_capacity() {
|
||||||
|
let buf = ScrollbackBuffer::with_capacity(4);
|
||||||
|
buf.push(b"ABCDEFGH"); // 8 bytes into 4-byte buffer
|
||||||
|
let raw = buf.read_raw();
|
||||||
|
assert_eq!(raw, "EFGH"); // only last 4 bytes kept
|
||||||
|
assert_eq!(buf.total_written(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_exact_capacity() {
|
||||||
|
let buf = ScrollbackBuffer::with_capacity(8);
|
||||||
|
buf.push(b"12345678");
|
||||||
|
let raw = buf.read_raw();
|
||||||
|
assert_eq!(raw, "12345678");
|
||||||
|
assert_eq!(buf.total_written(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_wrap_around_boundary() {
|
||||||
|
let buf = ScrollbackBuffer::with_capacity(8);
|
||||||
|
buf.push(b"123456"); // write_pos = 6
|
||||||
|
buf.push(b"ABCD"); // wraps: 2 at end, 2 at start
|
||||||
|
let raw = buf.read_raw();
|
||||||
|
// Buffer: [C, D, 3, 4, 5, 6, A, B], write_pos=2
|
||||||
|
// Read from pos 2: "3456AB" + wrap: no, read from write_pos to end then start
|
||||||
|
assert_eq!(raw, "3456ABCD");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,7 +5,14 @@
|
|||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{extract::State as AxumState, routing::post, Json, Router};
|
use axum::{
|
||||||
|
extract::State as AxumState,
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
middleware::{self, Next},
|
||||||
|
response::Response,
|
||||||
|
routing::post,
|
||||||
|
Json, Router,
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
@ -13,6 +20,7 @@ use crate::mcp::ScrollbackRegistry;
|
|||||||
use crate::rdp::RdpService;
|
use crate::rdp::RdpService;
|
||||||
use crate::sftp::SftpService;
|
use crate::sftp::SftpService;
|
||||||
use crate::ssh::session::SshService;
|
use crate::ssh::session::SshService;
|
||||||
|
use crate::utils::shell_escape;
|
||||||
|
|
||||||
/// Shared state passed to axum handlers.
|
/// Shared state passed to axum handlers.
|
||||||
pub struct McpServerState {
|
pub struct McpServerState {
|
||||||
@ -22,6 +30,27 @@ pub struct McpServerState {
|
|||||||
pub scrollback: ScrollbackRegistry,
|
pub scrollback: ScrollbackRegistry,
|
||||||
pub app_handle: tauri::AppHandle,
|
pub app_handle: tauri::AppHandle,
|
||||||
pub error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
|
pub error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
|
||||||
|
pub bearer_token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware that validates the `Authorization: Bearer <token>` header.
|
||||||
|
async fn auth_middleware(
|
||||||
|
AxumState(state): AxumState<Arc<McpServerState>>,
|
||||||
|
req: Request<axum::body::Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
let auth_header = req
|
||||||
|
.headers()
|
||||||
|
.get("authorization")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
let expected = format!("Bearer {}", state.bearer_token);
|
||||||
|
if auth_header != expected {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(next.run(req).await)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -279,29 +308,31 @@ struct ToolPassgenRequest { length: Option<usize>, uppercase: Option<bool>, lowe
|
|||||||
|
|
||||||
async fn handle_tool_ping(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
async fn handle_tool_ping(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("ping -c 4 {} 2>&1", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tool_traceroute(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
async fn handle_tool_traceroute(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
let t = shell_escape(&req.target);
|
||||||
|
match tool_exec(&session.handle, &format!("traceroute {} 2>&1 || tracert {} 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tool_dns(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolDnsRequest>) -> Json<McpResponse<String>> {
|
async fn handle_tool_dns(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolDnsRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
let rt = req.record_type.unwrap_or_else(|| "A".to_string());
|
let rt = shell_escape(&req.record_type.unwrap_or_else(|| "A".to_string()));
|
||||||
match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", req.domain, rt, rt, req.domain, rt, req.domain)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
let d = shell_escape(&req.domain);
|
||||||
|
match tool_exec(&session.handle, &format!("dig {} {} +short 2>/dev/null || nslookup -type={} {} 2>/dev/null || host -t {} {} 2>/dev/null", d, rt, rt, d, rt, d)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tool_whois(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
async fn handle_tool_whois(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("whois {} 2>&1 | head -80", shell_escape(&req.target))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tool_wol(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolWolRequest>) -> Json<McpResponse<String>> {
|
async fn handle_tool_wol(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolWolRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
let mac_clean = req.mac_address.replace([':', '-'], "");
|
let mac_clean = req.mac_address.replace([':', '-'], "");
|
||||||
let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex('{}');pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, mac_clean, req.mac_address);
|
let cmd = format!(r#"python3 -c "import socket;mac=bytes.fromhex({});pkt=b'\xff'*6+mac*16;s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM);s.setsockopt(socket.SOL_SOCKET,socket.SO_BROADCAST,1);s.sendto(pkt,('255.255.255.255',9));s.close();print('WoL sent to {}')" 2>&1"#, shell_escape(&mac_clean), shell_escape(&req.mac_address));
|
||||||
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -382,12 +413,13 @@ async fn handle_docker_ps(AxumState(state): AxumState<Arc<McpServerState>>, Json
|
|||||||
|
|
||||||
async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerActionRequest>) -> Json<McpResponse<String>> {
|
async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerActionRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
|
let t = shell_escape(&req.target);
|
||||||
let cmd = match req.action.as_str() {
|
let cmd = match req.action.as_str() {
|
||||||
"start" => format!("docker start {} 2>&1", req.target),
|
"start" => format!("docker start {} 2>&1", t),
|
||||||
"stop" => format!("docker stop {} 2>&1", req.target),
|
"stop" => format!("docker stop {} 2>&1", t),
|
||||||
"restart" => format!("docker restart {} 2>&1", req.target),
|
"restart" => format!("docker restart {} 2>&1", t),
|
||||||
"remove" => format!("docker rm -f {} 2>&1", req.target),
|
"remove" => format!("docker rm -f {} 2>&1", t),
|
||||||
"logs" => format!("docker logs --tail 100 {} 2>&1", req.target),
|
"logs" => format!("docker logs --tail 100 {} 2>&1", t),
|
||||||
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
|
"builder-prune" => "docker builder prune -f 2>&1".to_string(),
|
||||||
"system-prune" => "docker system prune -f 2>&1".to_string(),
|
"system-prune" => "docker system prune -f 2>&1".to_string(),
|
||||||
_ => return err_response(format!("Unknown action: {}", req.action)),
|
_ => return err_response(format!("Unknown action: {}", req.action)),
|
||||||
@ -397,7 +429,7 @@ async fn handle_docker_action(AxumState(state): AxumState<Arc<McpServerState>>,
|
|||||||
|
|
||||||
async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerExecRequest>) -> Json<McpResponse<String>> {
|
async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<DockerExecRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
let cmd = format!("docker exec {} {} 2>&1", req.container, req.command);
|
let cmd = format!("docker exec {} {} 2>&1", shell_escape(&req.container), shell_escape(&req.command));
|
||||||
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &cmd).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,12 +437,13 @@ async fn handle_docker_exec(AxumState(state): AxumState<Arc<McpServerState>>, Js
|
|||||||
|
|
||||||
async fn handle_service_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
async fn handle_service_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", req.target, req.target)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
let t = shell_escape(&req.target);
|
||||||
|
match tool_exec(&session.handle, &format!("systemctl status {} --no-pager 2>&1 || service {} status 2>&1", t, t)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_process_list(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
async fn handle_process_list(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<ToolSessionTarget>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", req.target) };
|
let filter = if req.target.is_empty() { "aux --sort=-%cpu | head -30".to_string() } else { format!("aux | grep -i {} | grep -v grep", shell_escape(&req.target)) };
|
||||||
match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("ps {}", filter)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -421,17 +454,17 @@ struct GitRequest { session_id: String, path: String }
|
|||||||
|
|
||||||
async fn handle_git_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
async fn handle_git_status(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("cd {} && git status --short --branch 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_git_pull(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
async fn handle_git_pull(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("cd {} && git pull 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_git_log(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
async fn handle_git_log(AxumState(state): AxumState<Arc<McpServerState>>, Json(req): Json<GitRequest>) -> Json<McpResponse<String>> {
|
||||||
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
let session = match state.ssh.get_session(&req.session_id) { Some(s) => s, None => return err_response(format!("Session {} not found", req.session_id)) };
|
||||||
match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", req.path)).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
match tool_exec(&session.handle, &format!("cd {} && git log --oneline -20 2>&1", shell_escape(&req.path))).await { Ok(o) => ok_response(o), Err(e) => err_response(e) }
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Session creation handlers ────────────────────────────────────────────────
|
// ── Session creation handlers ────────────────────────────────────────────────
|
||||||
@ -533,7 +566,15 @@ pub async fn start_mcp_server(
|
|||||||
app_handle: tauri::AppHandle,
|
app_handle: tauri::AppHandle,
|
||||||
error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
|
error_watcher: std::sync::Arc<crate::mcp::error_watcher::ErrorWatcher>,
|
||||||
) -> Result<u16, String> {
|
) -> Result<u16, String> {
|
||||||
let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher });
|
// Generate a cryptographically random bearer token for authentication
|
||||||
|
use rand::Rng;
|
||||||
|
let bearer_token: String = rand::rng()
|
||||||
|
.sample_iter(&rand::distr::Alphanumeric)
|
||||||
|
.take(64)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let state = Arc::new(McpServerState { ssh, rdp, sftp, scrollback, app_handle, error_watcher, bearer_token: bearer_token.clone() });
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/mcp/sessions", post(handle_list_sessions))
|
.route("/mcp/sessions", post(handle_list_sessions))
|
||||||
@ -567,6 +608,7 @@ pub async fn start_mcp_server(
|
|||||||
.route("/mcp/rdp/type", post(handle_rdp_type))
|
.route("/mcp/rdp/type", post(handle_rdp_type))
|
||||||
.route("/mcp/rdp/clipboard", post(handle_rdp_clipboard))
|
.route("/mcp/rdp/clipboard", post(handle_rdp_clipboard))
|
||||||
.route("/mcp/ssh/connect", post(handle_ssh_connect))
|
.route("/mcp/ssh/connect", post(handle_ssh_connect))
|
||||||
|
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
|
||||||
.with_state(state);
|
.with_state(state);
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await
|
let listener = TcpListener::bind("127.0.0.1:0").await
|
||||||
@ -577,10 +619,23 @@ pub async fn start_mcp_server(
|
|||||||
.port();
|
.port();
|
||||||
|
|
||||||
// Write port to well-known location
|
// Write port to well-known location
|
||||||
let port_file = crate::data_directory().join("mcp-port");
|
let data_dir = crate::data_directory();
|
||||||
|
let port_file = data_dir.join("mcp-port");
|
||||||
std::fs::write(&port_file, port.to_string())
|
std::fs::write(&port_file, port.to_string())
|
||||||
.map_err(|e| format!("Failed to write MCP port file: {}", e))?;
|
.map_err(|e| format!("Failed to write MCP port file: {}", e))?;
|
||||||
|
|
||||||
|
// Write bearer token to a separate file with restrictive permissions
|
||||||
|
let token_file = data_dir.join("mcp-token");
|
||||||
|
std::fs::write(&token_file, &bearer_token)
|
||||||
|
.map_err(|e| format!("Failed to write MCP token file: {}", e))?;
|
||||||
|
|
||||||
|
// Set owner-only read/write permissions (Unix)
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
let _ = std::fs::set_permissions(&token_file, std::fs::Permissions::from_mode(0o600));
|
||||||
|
}
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.ok();
|
axum::serve(listener, app).await.ok();
|
||||||
});
|
});
|
||||||
|
|||||||
@ -12,6 +12,7 @@ use serde::Serialize;
|
|||||||
use tokio::sync::Mutex as TokioMutex;
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
|
||||||
use crate::ssh::session::SshClient;
|
use crate::ssh::session::SshClient;
|
||||||
|
use crate::utils::shell_escape;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
@ -72,9 +73,10 @@ pub async fn scan_network(
|
|||||||
// 1. Ping sweep the subnet to populate ARP cache
|
// 1. Ping sweep the subnet to populate ARP cache
|
||||||
// 2. Read ARP table for IP/MAC pairs
|
// 2. Read ARP table for IP/MAC pairs
|
||||||
// 3. Try reverse DNS for hostnames
|
// 3. Try reverse DNS for hostnames
|
||||||
|
let escaped_subnet = shell_escape(subnet);
|
||||||
let script = format!(r#"
|
let script = format!(r#"
|
||||||
OS=$(uname -s 2>/dev/null)
|
OS=$(uname -s 2>/dev/null)
|
||||||
SUBNET="{subnet}"
|
SUBNET={escaped_subnet}
|
||||||
|
|
||||||
# Ping sweep (background, fast)
|
# Ping sweep (background, fast)
|
||||||
if [ "$OS" = "Linux" ]; then
|
if [ "$OS" = "Linux" ]; then
|
||||||
@ -151,6 +153,12 @@ pub async fn scan_ports(
|
|||||||
target: &str,
|
target: &str,
|
||||||
ports: &[u16],
|
ports: &[u16],
|
||||||
) -> Result<Vec<PortResult>, String> {
|
) -> Result<Vec<PortResult>, String> {
|
||||||
|
// Validate target — /dev/tcp requires a bare hostname/IP, not a shell-quoted value.
|
||||||
|
// Only allow alphanumeric, dots, hyphens, and colons (for IPv6).
|
||||||
|
if !target.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':') {
|
||||||
|
return Err(format!("Invalid target for port scan: {}", target));
|
||||||
|
}
|
||||||
|
|
||||||
// Use bash /dev/tcp for port scanning — no nmap required
|
// Use bash /dev/tcp for port scanning — no nmap required
|
||||||
let port_checks: Vec<String> = ports.iter()
|
let port_checks: Vec<String> = ports.iter()
|
||||||
.map(|p| format!(
|
.map(|p| format!(
|
||||||
|
|||||||
@ -8,6 +8,7 @@ use crate::db::Database;
|
|||||||
///
|
///
|
||||||
/// All operations acquire the shared DB mutex for their duration and
|
/// All operations acquire the shared DB mutex for their duration and
|
||||||
/// return immediately — no async needed for a local SQLite store.
|
/// return immediately — no async needed for a local SQLite store.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct SettingsService {
|
pub struct SettingsService {
|
||||||
db: Database,
|
db: Database,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,7 @@ use russh::ChannelMsg;
|
|||||||
use tauri::{AppHandle, Emitter};
|
use tauri::{AppHandle, Emitter};
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
use tokio::sync::Mutex as TokioMutex;
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::ssh::session::SshClient;
|
use crate::ssh::session::SshClient;
|
||||||
|
|
||||||
@ -39,13 +40,15 @@ impl CwdTracker {
|
|||||||
/// Spawn a background tokio task that polls `pwd` every 2 seconds on a
|
/// Spawn a background tokio task that polls `pwd` every 2 seconds on a
|
||||||
/// separate exec channel.
|
/// separate exec channel.
|
||||||
///
|
///
|
||||||
/// The task runs until the SSH connection is closed or the channel cannot
|
/// The task runs until cancelled via the `CancellationToken`, or until the
|
||||||
/// be opened. CWD changes are emitted as `ssh:cwd:{session_id}` events.
|
/// SSH connection is closed or the channel cannot be opened.
|
||||||
|
/// CWD changes are emitted as `ssh:cwd:{session_id}` events.
|
||||||
pub fn start(
|
pub fn start(
|
||||||
&self,
|
&self,
|
||||||
handle: Arc<TokioMutex<Handle<SshClient>>>,
|
handle: Arc<TokioMutex<Handle<SshClient>>>,
|
||||||
app_handle: AppHandle,
|
app_handle: AppHandle,
|
||||||
session_id: String,
|
session_id: String,
|
||||||
|
cancel: CancellationToken,
|
||||||
) {
|
) {
|
||||||
let sender = self._sender.clone();
|
let sender = self._sender.clone();
|
||||||
|
|
||||||
@ -56,6 +59,10 @@ impl CwdTracker {
|
|||||||
let mut previous_cwd = String::new();
|
let mut previous_cwd = String::new();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
if cancel.is_cancelled() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Open a fresh exec channel for each `pwd` invocation.
|
// Open a fresh exec channel for each `pwd` invocation.
|
||||||
// Some SSH servers do not allow multiple exec requests on a
|
// Some SSH servers do not allow multiple exec requests on a
|
||||||
// single channel, so we open a new one each time.
|
// single channel, so we open a new one each time.
|
||||||
@ -119,8 +126,11 @@ impl CwdTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait 2 seconds before the next poll.
|
// Wait 2 seconds before the next poll, or cancel.
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
tokio::select! {
|
||||||
|
_ = tokio::time::sleep(tokio::time::Duration::from_secs(2)) => {}
|
||||||
|
_ = cancel.cancelled() => { break; }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("CWD tracker for session {} stopped", session_id);
|
debug!("CWD tracker for session {} stopped", session_id);
|
||||||
|
|||||||
@ -6,11 +6,13 @@
|
|||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use log::warn;
|
||||||
use russh::client::Handle;
|
use russh::client::Handle;
|
||||||
use russh::ChannelMsg;
|
use russh::ChannelMsg;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tauri::{AppHandle, Emitter};
|
use tauri::{AppHandle, Emitter};
|
||||||
use tokio::sync::Mutex as TokioMutex;
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::ssh::session::SshClient;
|
use crate::ssh::session::SshClient;
|
||||||
|
|
||||||
@ -30,26 +32,53 @@ pub struct SystemStats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Spawn a background task that polls system stats every 5 seconds.
|
/// Spawn a background task that polls system stats every 5 seconds.
|
||||||
|
///
|
||||||
|
/// The task runs until cancelled via the `CancellationToken`, or until the
|
||||||
|
/// SSH connection is closed.
|
||||||
pub fn start_monitor(
|
pub fn start_monitor(
|
||||||
handle: Arc<TokioMutex<Handle<SshClient>>>,
|
handle: Arc<TokioMutex<Handle<SshClient>>>,
|
||||||
app_handle: AppHandle,
|
app_handle: AppHandle,
|
||||||
session_id: String,
|
session_id: String,
|
||||||
|
cancel: CancellationToken,
|
||||||
) {
|
) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Brief delay to let the shell start up
|
// Brief delay to let the shell start up
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
let mut consecutive_timeouts: u32 = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
if cancel.is_cancelled() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
let stats = collect_stats(&handle).await;
|
let stats = collect_stats(&handle).await;
|
||||||
|
|
||||||
if let Some(stats) = stats {
|
match stats {
|
||||||
|
Some(stats) => {
|
||||||
|
consecutive_timeouts = 0;
|
||||||
let _ = app_handle.emit(
|
let _ = app_handle.emit(
|
||||||
&format!("ssh:monitor:{}", session_id),
|
&format!("ssh:monitor:{}", session_id),
|
||||||
&stats,
|
&stats,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
None => {
|
||||||
|
consecutive_timeouts += 1;
|
||||||
|
if consecutive_timeouts >= 3 {
|
||||||
|
warn!(
|
||||||
|
"SSH monitor for session {}: 3 consecutive failures, stopping",
|
||||||
|
session_id
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
// Wait 5 seconds before the next poll, or cancel.
|
||||||
|
tokio::select! {
|
||||||
|
_ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
|
||||||
|
_ = cancel.cancelled() => { break; }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -125,7 +154,24 @@ fn parse_stats(raw: &str) -> Option<SystemStats> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Execute a command on a separate exec channel with a 10-second timeout.
|
||||||
async fn exec_command(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> {
|
async fn exec_command(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> {
|
||||||
|
let result = tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(10),
|
||||||
|
exec_command_inner(handle, cmd),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(output) => output,
|
||||||
|
Err(_) => {
|
||||||
|
warn!("SSH monitor exec_command timed out after 10s");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exec_command_inner(handle: &Arc<TokioMutex<Handle<SshClient>>>, cmd: &str) -> Option<String> {
|
||||||
let mut channel = {
|
let mut channel = {
|
||||||
let h = handle.lock().await;
|
let h = handle.lock().await;
|
||||||
h.channel_open_session().await.ok()?
|
h.channel_open_session().await.ok()?
|
||||||
|
|||||||
@ -17,6 +17,7 @@ use crate::mcp::error_watcher::ErrorWatcher;
|
|||||||
use crate::sftp::SftpService;
|
use crate::sftp::SftpService;
|
||||||
use crate::ssh::cwd::CwdTracker;
|
use crate::ssh::cwd::CwdTracker;
|
||||||
use crate::ssh::host_key::{HostKeyResult, HostKeyStore};
|
use crate::ssh::host_key::{HostKeyResult, HostKeyStore};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
pub enum AuthMethod {
|
pub enum AuthMethod {
|
||||||
Password(String),
|
Password(String),
|
||||||
@ -47,6 +48,7 @@ pub struct SshSession {
|
|||||||
pub handle: Arc<TokioMutex<Handle<SshClient>>>,
|
pub handle: Arc<TokioMutex<Handle<SshClient>>>,
|
||||||
pub command_tx: mpsc::UnboundedSender<ChannelCommand>,
|
pub command_tx: mpsc::UnboundedSender<ChannelCommand>,
|
||||||
pub cwd_tracker: Option<CwdTracker>,
|
pub cwd_tracker: Option<CwdTracker>,
|
||||||
|
pub cancel_token: CancellationToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SshClient {
|
pub struct SshClient {
|
||||||
@ -135,10 +137,11 @@ impl SshService {
|
|||||||
let channel_id = channel.id();
|
let channel_id = channel.id();
|
||||||
let handle = Arc::new(TokioMutex::new(handle));
|
let handle = Arc::new(TokioMutex::new(handle));
|
||||||
let (command_tx, mut command_rx) = mpsc::unbounded_channel::<ChannelCommand>();
|
let (command_tx, mut command_rx) = mpsc::unbounded_channel::<ChannelCommand>();
|
||||||
|
let cancel_token = CancellationToken::new();
|
||||||
let cwd_tracker = CwdTracker::new();
|
let cwd_tracker = CwdTracker::new();
|
||||||
cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone());
|
cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone(), cancel_token.clone());
|
||||||
|
|
||||||
let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel_id, handle: handle.clone(), command_tx: command_tx.clone(), cwd_tracker: Some(cwd_tracker) });
|
let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel_id, handle: handle.clone(), command_tx: command_tx.clone(), cwd_tracker: Some(cwd_tracker), cancel_token: cancel_token.clone() });
|
||||||
self.sessions.insert(session_id.clone(), session);
|
self.sessions.insert(session_id.clone(), session);
|
||||||
|
|
||||||
{ let h = handle.lock().await;
|
{ let h = handle.lock().await;
|
||||||
@ -158,7 +161,7 @@ impl SshService {
|
|||||||
error_watcher.watch(&session_id);
|
error_watcher.watch(&session_id);
|
||||||
|
|
||||||
// Start remote monitoring if enabled (runs on a separate exec channel)
|
// Start remote monitoring if enabled (runs on a separate exec channel)
|
||||||
crate::ssh::monitor::start_monitor(handle.clone(), app_handle.clone(), session_id.clone());
|
crate::ssh::monitor::start_monitor(handle.clone(), app_handle.clone(), session_id.clone(), cancel_token.clone());
|
||||||
|
|
||||||
// Inject OSC 7 CWD reporting hook into the user's shell.
|
// Inject OSC 7 CWD reporting hook into the user's shell.
|
||||||
// This enables SFTP CWD following on all platforms (Linux, macOS, FreeBSD).
|
// This enables SFTP CWD following on all platforms (Linux, macOS, FreeBSD).
|
||||||
@ -246,6 +249,8 @@ impl SshService {
|
|||||||
|
|
||||||
pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> {
|
pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> {
|
||||||
let (_, session) = self.sessions.remove(session_id).ok_or_else(|| format!("Session {} not found", session_id))?;
|
let (_, session) = self.sessions.remove(session_id).ok_or_else(|| format!("Session {} not found", session_id))?;
|
||||||
|
// Cancel background tasks (CWD tracker, monitor) before tearing down the connection.
|
||||||
|
session.cancel_token.cancel();
|
||||||
let _ = session.command_tx.send(ChannelCommand::Shutdown);
|
let _ = session.command_tx.send(ChannelCommand::Shutdown);
|
||||||
{ let handle = session.handle.lock().await; let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await; }
|
{ let handle = session.handle.lock().await; let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await; }
|
||||||
sftp_service.remove_client(session_id);
|
sftp_service.remove_client(session_id);
|
||||||
|
|||||||
19
src-tauri/src/utils.rs
Normal file
19
src-tauri/src/utils.rs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
//! Shared utility functions.
|
||||||
|
|
||||||
|
/// Escape a string for safe interpolation into a POSIX shell command.
|
||||||
|
///
|
||||||
|
/// Wraps the input in single quotes and escapes any embedded single quotes
|
||||||
|
/// using the `'\''` technique. This prevents command injection when building
|
||||||
|
/// shell commands from user-supplied values.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use wraith_lib::utils::shell_escape;
|
||||||
|
/// assert_eq!(shell_escape("hello"), "'hello'");
|
||||||
|
/// assert_eq!(shell_escape("it's"), "'it'\\''s'");
|
||||||
|
/// assert_eq!(shell_escape(";rm -rf /"), "';rm -rf /'");
|
||||||
|
/// ```
|
||||||
|
pub fn shell_escape(input: &str) -> String {
|
||||||
|
format!("'{}'", input.replace('\'', "'\\''"))
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ use aes_gcm::{
|
|||||||
Aes256Gcm, Key, Nonce,
|
Aes256Gcm, Key, Nonce,
|
||||||
};
|
};
|
||||||
use argon2::{Algorithm, Argon2, Params, Version};
|
use argon2::{Algorithm, Argon2, Params, Version};
|
||||||
|
use zeroize::Zeroizing;
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// VaultService
|
// VaultService
|
||||||
@ -21,18 +22,18 @@ use argon2::{Algorithm, Argon2, Params, Version};
|
|||||||
/// The version prefix allows a future migration to a different algorithm
|
/// The version prefix allows a future migration to a different algorithm
|
||||||
/// without breaking existing stored blobs.
|
/// without breaking existing stored blobs.
|
||||||
pub struct VaultService {
|
pub struct VaultService {
|
||||||
key: [u8; 32],
|
key: Zeroizing<[u8; 32]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VaultService {
|
impl VaultService {
|
||||||
pub fn new(key: [u8; 32]) -> Self {
|
pub fn new(key: Zeroizing<[u8; 32]>) -> Self {
|
||||||
Self { key }
|
Self { key }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt `plaintext` and return a `v1:{iv_hex}:{sealed_hex}` blob.
|
/// Encrypt `plaintext` and return a `v1:{iv_hex}:{sealed_hex}` blob.
|
||||||
pub fn encrypt(&self, plaintext: &str) -> Result<String, String> {
|
pub fn encrypt(&self, plaintext: &str) -> Result<String, String> {
|
||||||
// Build the AES-256-GCM cipher from our key.
|
// Build the AES-256-GCM cipher from our key.
|
||||||
let key = Key::<Aes256Gcm>::from_slice(&self.key);
|
let key = Key::<Aes256Gcm>::from_slice(&*self.key);
|
||||||
let cipher = Aes256Gcm::new(key);
|
let cipher = Aes256Gcm::new(key);
|
||||||
|
|
||||||
// Generate a random 12-byte nonce (96-bit is the GCM standard).
|
// Generate a random 12-byte nonce (96-bit is the GCM standard).
|
||||||
@ -71,7 +72,7 @@ impl VaultService {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let key = Key::<Aes256Gcm>::from_slice(&self.key);
|
let key = Key::<Aes256Gcm>::from_slice(&*self.key);
|
||||||
let cipher = Aes256Gcm::new(key);
|
let cipher = Aes256Gcm::new(key);
|
||||||
let nonce = Nonce::from_slice(&iv_bytes);
|
let nonce = Nonce::from_slice(&iv_bytes);
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ impl VaultService {
|
|||||||
/// t = 3 iterations
|
/// t = 3 iterations
|
||||||
/// m = 65536 KiB (64 MiB) memory
|
/// m = 65536 KiB (64 MiB) memory
|
||||||
/// p = 4 parallelism lanes
|
/// p = 4 parallelism lanes
|
||||||
pub fn derive_key(password: &str, salt: &[u8]) -> [u8; 32] {
|
pub fn derive_key(password: &str, salt: &[u8]) -> Zeroizing<[u8; 32]> {
|
||||||
let params = Params::new(
|
let params = Params::new(
|
||||||
65536, // m_cost: 64 MiB
|
65536, // m_cost: 64 MiB
|
||||||
3, // t_cost: iterations
|
3, // t_cost: iterations
|
||||||
@ -106,9 +107,9 @@ pub fn derive_key(password: &str, salt: &[u8]) -> [u8; 32] {
|
|||||||
|
|
||||||
let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
|
let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
|
||||||
|
|
||||||
let mut output_key = [0u8; 32];
|
let mut output_key = Zeroizing::new([0u8; 32]);
|
||||||
argon2
|
argon2
|
||||||
.hash_password_into(password.as_bytes(), salt, &mut output_key)
|
.hash_password_into(password.as_bytes(), salt, &mut *output_key)
|
||||||
.expect("Argon2id key derivation failed");
|
.expect("Argon2id key derivation failed");
|
||||||
|
|
||||||
output_key
|
output_key
|
||||||
|
|||||||
@ -22,9 +22,9 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"security": {
|
"security": {
|
||||||
"csp": null
|
"csp": "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' asset: https://asset.localhost data:; connect-src 'self' ipc: http://ipc.localhost"
|
||||||
},
|
},
|
||||||
"withGlobalTauri": true
|
"withGlobalTauri": false
|
||||||
},
|
},
|
||||||
"bundle": {
|
"bundle": {
|
||||||
"active": true,
|
"active": true,
|
||||||
|
|||||||
@ -88,7 +88,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted } from "vue";
|
import { ref, onMounted, onBeforeUnmount } from "vue";
|
||||||
import { invoke } from "@tauri-apps/api/core";
|
import { invoke } from "@tauri-apps/api/core";
|
||||||
import { useSessionStore, type Session } from "@/stores/session.store";
|
import { useSessionStore, type Session } from "@/stores/session.store";
|
||||||
import { useConnectionStore } from "@/stores/connection.store";
|
import { useConnectionStore } from "@/stores/connection.store";
|
||||||
@ -151,16 +151,10 @@ function closeMenuTab(): void {
|
|||||||
if (session) sessionStore.closeSession(session.id);
|
if (session) sessionStore.closeSession(session.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen for reattach events from detached windows
|
|
||||||
import { listen } from "@tauri-apps/api/event";
|
import { listen } from "@tauri-apps/api/event";
|
||||||
listen<{ sessionId: string; name: string; protocol: string }>("session:reattach", (event) => {
|
import type { UnlistenFn } from "@tauri-apps/api/event";
|
||||||
const { sessionId } = event.payload;
|
|
||||||
const session = sessionStore.sessions.find(s => s.id === sessionId);
|
let unlistenReattach: UnlistenFn | null = null;
|
||||||
if (session) {
|
|
||||||
session.active = true;
|
|
||||||
sessionStore.activateSession(sessionId);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
try {
|
try {
|
||||||
@ -168,6 +162,19 @@ onMounted(async () => {
|
|||||||
} catch {
|
} catch {
|
||||||
availableShells.value = [];
|
availableShells.value = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unlistenReattach = await listen<{ sessionId: string; name: string; protocol: string }>("session:reattach", (event) => {
|
||||||
|
const { sessionId } = event.payload;
|
||||||
|
const session = sessionStore.sessions.find(s => s.id === sessionId);
|
||||||
|
if (session) {
|
||||||
|
session.active = true;
|
||||||
|
sessionStore.activateSession(sessionId);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
onBeforeUnmount(() => {
|
||||||
|
unlistenReattach?.();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Drag-and-drop tab reordering
|
// Drag-and-drop tab reordering
|
||||||
|
|||||||
@ -184,7 +184,7 @@ export interface UseRdpReturn {
|
|||||||
* Composable that manages an RDP session's rendering and input.
|
* Composable that manages an RDP session's rendering and input.
|
||||||
*
|
*
|
||||||
* Uses Tauri's invoke() to call Rust commands:
|
* Uses Tauri's invoke() to call Rust commands:
|
||||||
* rdp_get_frame → base64 RGBA string
|
* rdp_get_frame → raw RGBA ArrayBuffer (binary IPC)
|
||||||
* rdp_send_mouse → fire-and-forget
|
* rdp_send_mouse → fire-and-forget
|
||||||
* rdp_send_key → fire-and-forget
|
* rdp_send_key → fire-and-forget
|
||||||
* rdp_send_clipboard → fire-and-forget
|
* rdp_send_clipboard → fire-and-forget
|
||||||
@ -195,6 +195,7 @@ export function useRdp(): UseRdpReturn {
|
|||||||
const clipboardSync = ref(false);
|
const clipboardSync = ref(false);
|
||||||
|
|
||||||
let animFrameId: number | null = null;
|
let animFrameId: number | null = null;
|
||||||
|
let unlistenFrame: (() => void) | null = null;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch the current frame from the Rust RDP backend.
|
* Fetch the current frame from the Rust RDP backend.
|
||||||
@ -208,16 +209,16 @@ export function useRdp(): UseRdpReturn {
|
|||||||
width: number,
|
width: number,
|
||||||
height: number,
|
height: number,
|
||||||
): Promise<ImageData | null> {
|
): Promise<ImageData | null> {
|
||||||
let raw: number[];
|
let raw: ArrayBuffer;
|
||||||
try {
|
try {
|
||||||
raw = await invoke<number[]>("rdp_get_frame", { sessionId });
|
raw = await invoke<ArrayBuffer>("rdp_get_frame", { sessionId });
|
||||||
} catch {
|
} catch {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!raw || raw.length === 0) return null;
|
if (!raw || raw.byteLength === 0) return null;
|
||||||
|
|
||||||
// Binary IPC — Tauri returns Vec<u8> as number array
|
// Binary IPC — tauri::ipc::Response delivers raw bytes as ArrayBuffer
|
||||||
const bytes = new Uint8ClampedArray(raw);
|
const bytes = new Uint8ClampedArray(raw);
|
||||||
|
|
||||||
const expected = width * height * 4;
|
const expected = width * height * 4;
|
||||||
@ -315,8 +316,7 @@ export function useRdp(): UseRdpReturn {
|
|||||||
listen(`rdp:frame:${sessionId}`, () => {
|
listen(`rdp:frame:${sessionId}`, () => {
|
||||||
onFrameReady();
|
onFrameReady();
|
||||||
}).then((unlisten) => {
|
}).then((unlisten) => {
|
||||||
// Store unlisten so we can clean up
|
unlistenFrame = unlisten;
|
||||||
(canvas as any).__wraith_unlisten = unlisten;
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -332,6 +332,10 @@ export function useRdp(): UseRdpReturn {
|
|||||||
cancelAnimationFrame(animFrameId);
|
cancelAnimationFrame(animFrameId);
|
||||||
animFrameId = null;
|
animFrameId = null;
|
||||||
}
|
}
|
||||||
|
if (unlistenFrame !== null) {
|
||||||
|
unlistenFrame();
|
||||||
|
unlistenFrame = null;
|
||||||
|
}
|
||||||
connected.value = false;
|
connected.value = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import { defineStore } from "pinia";
|
|||||||
import { ref, computed } from "vue";
|
import { ref, computed } from "vue";
|
||||||
import { invoke } from "@tauri-apps/api/core";
|
import { invoke } from "@tauri-apps/api/core";
|
||||||
import { listen } from "@tauri-apps/api/event";
|
import { listen } from "@tauri-apps/api/event";
|
||||||
|
import type { UnlistenFn } from "@tauri-apps/api/event";
|
||||||
import { useConnectionStore } from "@/stores/connection.store";
|
import { useConnectionStore } from "@/stores/connection.store";
|
||||||
import type { ThemeDefinition } from "@/components/common/ThemePicker.vue";
|
import type { ThemeDefinition } from "@/components/common/ThemePicker.vue";
|
||||||
|
|
||||||
@ -39,10 +40,14 @@ export const useSessionStore = defineStore("session", () => {
|
|||||||
|
|
||||||
const sessionCount = computed(() => sessions.value.length);
|
const sessionCount = computed(() => sessions.value.length);
|
||||||
|
|
||||||
|
const sessionUnlisteners = new Map<string, Array<UnlistenFn>>();
|
||||||
|
|
||||||
// Listen for backend close/exit events to update session status
|
// Listen for backend close/exit events to update session status
|
||||||
function setupStatusListeners(sessionId: string): void {
|
async function setupStatusListeners(sessionId: string): Promise<void> {
|
||||||
listen(`ssh:close:${sessionId}`, () => markDisconnected(sessionId));
|
const unlisteners: UnlistenFn[] = [];
|
||||||
listen(`ssh:exit:${sessionId}`, () => markDisconnected(sessionId));
|
unlisteners.push(await listen(`ssh:close:${sessionId}`, () => markDisconnected(sessionId)));
|
||||||
|
unlisteners.push(await listen(`ssh:exit:${sessionId}`, () => markDisconnected(sessionId)));
|
||||||
|
sessionUnlisteners.set(sessionId, unlisteners);
|
||||||
}
|
}
|
||||||
|
|
||||||
function markDisconnected(sessionId: string): void {
|
function markDisconnected(sessionId: string): void {
|
||||||
@ -92,6 +97,12 @@ export const useSessionStore = defineStore("session", () => {
|
|||||||
console.error("Failed to disconnect session:", err);
|
console.error("Failed to disconnect session:", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const unlisteners = sessionUnlisteners.get(id);
|
||||||
|
if (unlisteners) {
|
||||||
|
unlisteners.forEach((fn) => fn());
|
||||||
|
sessionUnlisteners.delete(id);
|
||||||
|
}
|
||||||
|
|
||||||
sessions.value.splice(idx, 1);
|
sessions.value.splice(idx, 1);
|
||||||
|
|
||||||
if (activeSessionId.value === id) {
|
if (activeSessionId.value === id) {
|
||||||
@ -325,7 +336,8 @@ export const useSessionStore = defineStore("session", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Listen for PTY close
|
// Listen for PTY close
|
||||||
listen(`pty:close:${sessionId}`, () => markDisconnected(sessionId));
|
const unlistenPty = await listen(`pty:close:${sessionId}`, () => markDisconnected(sessionId));
|
||||||
|
sessionUnlisteners.set(sessionId, [unlistenPty]);
|
||||||
|
|
||||||
activeSessionId.value = sessionId;
|
activeSessionId.value = sessionId;
|
||||||
} catch (err: unknown) {
|
} catch (err: unknown) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user