wraith/src-tauri/src/ssh/session.rs
Vantz Stockwell 44c79decf3
All checks were successful
Build & Sign Wraith / Build Windows + Sign (push) Successful in 2m58s
fix: SFTP preserves position on tab switch + CWD following on macOS
SFTP tab switch fix:
- Removed :key on FileTree that destroyed component on every switch
- useSftp now accepts a reactive Ref<string> sessionId
- Watches sessionId changes and reinitializes without destroying state
- Per-session path memory via sessionPaths map — switching back to a
  tab restores exactly where you were browsing

CWD following fix (macOS + all platforms):
- Injects OSC 7 prompt hook into the shell after SSH connect
- zsh: precmd() emits \e]7;file://host/path\e\\
- bash: PROMPT_COMMAND emits the same sequence
- Sent via the PTY channel so it configures the interactive shell
- The passive OSC 7 parser in the output loop picks it up
- SFTP sidebar auto-navigates to the current working directory

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 00:41:50 -04:00

449 lines
19 KiB
Rust

//! SSH session manager — connects, authenticates, manages PTY channels.
use std::sync::Arc;
use async_trait::async_trait;
use base64::Engine;
use dashmap::DashMap;
use russh::client::{self, Handle};
use russh::{ChannelId, ChannelMsg, CryptoVec, Disconnect};
use serde::Serialize;
use tauri::{AppHandle, Emitter};
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::mpsc;
use crate::db::Database;
use crate::mcp::ScrollbackRegistry;
use crate::mcp::error_watcher::ErrorWatcher;
use crate::sftp::SftpService;
use crate::ssh::cwd::CwdTracker;
use crate::ssh::host_key::{HostKeyResult, HostKeyStore};
/// Credentials used to authenticate an SSH connection.
pub enum AuthMethod {
    /// Password authentication.
    Password(String),
    /// Public-key authentication.
    Key {
        /// PEM text of the private key, or a path to a key file (resolved at connect time).
        private_key_pem: String,
        /// Passphrase for an encrypted key, if the key needs one.
        passphrase: Option<String>,
    },
}
/// Requests delivered to the output-loop task, which exclusively owns the PTY `Channel`.
pub enum ChannelCommand {
    /// Forward a terminal resize (window-change) to the remote PTY.
    Resize {
        /// New width in columns.
        cols: u32,
        /// New height in rows.
        rows: u32,
    },
    /// Send EOF, close the channel, and stop the output loop.
    Shutdown,
}
/// Serializable summary of one live session, as produced by `SshService::list_sessions`.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SessionInfo {
    /// Session id (the key in the service's session table).
    pub id: String,
    /// Remote host name or address.
    pub hostname: String,
    /// Remote SSH port.
    pub port: u16,
    /// Username the session authenticated as.
    pub username: String,
}
/// A live interactive SSH session, as stored in the service's session table.
pub struct SshSession {
    /// Session id generated at connect time.
    pub id: String,
    /// Remote host the session is connected to.
    pub hostname: String,
    /// Remote SSH port.
    pub port: u16,
    /// Authenticated username.
    pub username: String,
    /// Id of the channel running the PTY shell; input is written to it through `handle`.
    pub channel_id: ChannelId,
    /// Shared client handle, also used for SFTP setup, monitoring, and disconnect.
    pub handle: Arc<TokioMutex<Handle<SshClient>>>,
    /// Sends resize / shutdown requests to the task that owns the PTY channel.
    pub command_tx: mpsc::UnboundedSender<ChannelCommand>,
    /// Working-directory tracker started for this session, if any.
    pub cwd_tracker: Option<CwdTracker>,
}
/// russh client handler holding what host-key verification needs for one connection.
pub struct SshClient {
    /// Persistent store of previously accepted host keys.
    host_key_store: HostKeyStore,
    /// Host being connected to (looked up together with `port`).
    hostname: String,
    /// Port being connected to.
    port: u16,
}
#[async_trait]
impl client::Handler for SshClient {
    type Error = russh::Error;

    /// Verify the server's host key against the local host-key store.
    ///
    /// * Unknown host: the key is stored (best effort; a storage failure is ignored) and accepted.
    /// * Known host with a matching fingerprint: accepted.
    /// * Known host with a different fingerprint, or a store lookup error: rejected.
    ///
    /// Never returns `Err`; rejection is signalled with `Ok(false)`.
    async fn check_server_key(
        &mut self,
        server_public_key: &ssh_key::PublicKey,
    ) -> Result<bool, Self::Error> {
        let key_type = server_public_key.algorithm().to_string();
        let fingerprint = server_public_key
            .fingerprint(ssh_key::HashAlg::Sha256)
            .to_string();
        let raw_key = server_public_key.to_openssh().unwrap_or_default();

        let accepted = match self
            .host_key_store
            .verify(&self.hostname, self.port, &key_type, &fingerprint)
        {
            Ok(HostKeyResult::New) => {
                // First sighting of this host: remember the key, but don't fail the connection
                // if persisting it does not work.
                let _ = self.host_key_store.store(
                    &self.hostname,
                    self.port,
                    &key_type,
                    &fingerprint,
                    &raw_key,
                );
                true
            }
            Ok(HostKeyResult::Match) => true,
            Ok(HostKeyResult::Changed) | Err(_) => false,
        };
        Ok(accepted)
    }
}
/// Registry of live SSH sessions plus the logic to open, drive, and tear them down.
///
/// Cloning is cheap and every clone shares one session table (the map sits behind an [`Arc`]),
/// so a clone moved into another task observes sessions opened or closed through any other
/// clone. A bare `DashMap` field would deep-copy the table on `clone()`, silently forking it.
#[derive(Clone)]
pub struct SshService {
    sessions: Arc<DashMap<String, Arc<SshSession>>>,
    db: Database,
}

impl SshService {
    /// Create a service with no sessions; `db` backs the host-key store used when connecting.
    pub fn new(db: Database) -> Self {
        Self {
            sessions: Arc::new(DashMap::new()),
            db,
        }
    }

    /// Connect to `hostname:port`, authenticate as `username`, and start an interactive shell
    /// on an `xterm-256color` PTY sized `cols` x `rows`.
    ///
    /// On success the session is registered and its id returned. Additionally (best effort):
    /// an SFTP subsystem channel is opened and registered with `sftp_service`; a scrollback
    /// buffer and error watch are created for the session; remote monitoring and CWD tracking
    /// are started; and an OSC 7 reporting hook is typed into the shell (see [`OSC7_HOOK`]).
    ///
    /// A background task then forwards channel traffic to the frontend as Tauri events:
    /// `ssh:data:{id}` (base64 payload), `ssh:cwd:{id}` (path from OSC 7),
    /// `ssh:exit:{id}` (exit status) and `ssh:close:{id}`.
    ///
    /// # Errors
    /// Connect or authentication timeout (10 s each), transport or authentication failure,
    /// credentials rejected by the server, an undecodable private key, or failure to open the
    /// session channel / PTY / shell.
    pub async fn connect(&self, app_handle: AppHandle, hostname: &str, port: u16, username: &str, auth: AuthMethod, cols: u32, rows: u32, sftp_service: &SftpService, scrollback: &ScrollbackRegistry, error_watcher: &ErrorWatcher) -> Result<String, String> {
        let session_id = uuid::Uuid::new_v4().to_string();
        let config = Arc::new(russh::client::Config::default());
        let handler = SshClient {
            host_key_store: HostKeyStore::new(self.db.clone()),
            hostname: hostname.to_string(),
            port,
        };
        let mut handle = tokio::time::timeout(
            std::time::Duration::from_secs(10),
            client::connect(config, (hostname, port), handler),
        )
        .await
        .map_err(|_| format!("SSH connection to {}:{} timed out after 10s", hostname, port))?
        .map_err(|e| format!("SSH connection to {}:{} failed: {}", hostname, port, e))?;

        let auth_success = match auth {
            AuthMethod::Password(ref password) => tokio::time::timeout(
                std::time::Duration::from_secs(10),
                handle.authenticate_password(username, password),
            )
            .await
            .map_err(|_| "SSH password authentication timed out after 10s".to_string())?
            .map_err(|e| format!("SSH authentication error: {}", e))?,
            AuthMethod::Key { ref private_key_pem, ref passphrase } => {
                let pem = resolve_private_key(private_key_pem)?;
                let key = match russh::keys::decode_secret_key(&pem, passphrase.as_deref()) {
                    Ok(k) => k,
                    // EC keys in legacy SEC1 format: decrypt if needed and convert to PKCS#8.
                    Err(_) if pem.contains("BEGIN EC PRIVATE KEY") => {
                        let converted = convert_ec_key_to_pkcs8(&pem, passphrase.as_deref())?;
                        russh::keys::decode_secret_key(&converted, None).map_err(|e| {
                            format!("Failed to decode converted EC key: {}", e)
                        })?
                    }
                    Err(e) => {
                        let first_line = pem.lines().next().unwrap_or("<empty>");
                        return Err(format!(
                            "Failed to decode private key (header: '{}'): {}",
                            first_line, e
                        ));
                    }
                };
                tokio::time::timeout(
                    std::time::Duration::from_secs(10),
                    handle.authenticate_publickey(username, Arc::new(key)),
                )
                .await
                .map_err(|_| "SSH key authentication timed out after 10s".to_string())?
                .map_err(|e| format!("SSH authentication error: {}", e))?
            }
        };
        if !auth_success {
            return Err("Authentication failed: server rejected credentials".to_string());
        }

        let mut channel = handle
            .channel_open_session()
            .await
            .map_err(|e| format!("Failed to open session channel: {}", e))?;
        channel
            .request_pty(true, "xterm-256color", cols, rows, 0, 0, &[])
            .await
            .map_err(|e| format!("Failed to request PTY: {}", e))?;
        channel
            .request_shell(true)
            .await
            .map_err(|e| format!("Failed to start shell: {}", e))?;
        let channel_id = channel.id();
        let handle = Arc::new(TokioMutex::new(handle));
        let (command_tx, mut command_rx) = mpsc::unbounded_channel::<ChannelCommand>();

        let cwd_tracker = CwdTracker::new();
        cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone());
        let session = Arc::new(SshSession {
            id: session_id.clone(),
            hostname: hostname.to_string(),
            port,
            username: username.to_string(),
            channel_id,
            handle: handle.clone(),
            command_tx: command_tx.clone(),
            cwd_tracker: Some(cwd_tracker),
        });
        self.sessions.insert(session_id.clone(), session);

        // SFTP is optional: any failure here just leaves the session without a file browser.
        {
            let h = handle.lock().await;
            if let Ok(sftp_channel) = h.channel_open_session().await {
                if sftp_channel.request_subsystem(true, "sftp").await.is_ok() {
                    if let Ok(sftp_client) =
                        russh_sftp::client::SftpSession::new(sftp_channel.into_stream()).await
                    {
                        sftp_service.register_client(&session_id, sftp_client);
                    }
                }
            }
        }

        // Create scrollback buffer for MCP terminal_read
        let scrollback_buf = scrollback.create(&session_id);
        error_watcher.watch(&session_id);
        // Start remote monitoring if enabled (runs on a separate exec channel)
        crate::ssh::monitor::start_monitor(handle.clone(), app_handle.clone(), session_id.clone());

        // Type the OSC 7 hook into the interactive shell so it reports its CWD before each
        // prompt. Best effort: if this write fails the session still works, just without
        // CWD following.
        {
            let h = handle.lock().await;
            let _ = h
                .data(channel_id, CryptoVec::from_slice(OSC7_HOOK.as_bytes()))
                .await;
        }

        // Output reader loop — owns the Channel exclusively. Writes do not need the Channel:
        // they go through Handle::data() on the shared handle instead.
        let sid = session_id.clone();
        let app = app_handle.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    msg = channel.wait() => {
                        match msg {
                            Some(ChannelMsg::Data { ref data }) => {
                                scrollback_buf.push(data.as_ref());
                                // Passive OSC 7 CWD detection — scan without modifying stream
                                if let Some(cwd) = extract_osc7_cwd(data.as_ref()) {
                                    let _ = app.emit(&format!("ssh:cwd:{}", sid), &cwd);
                                }
                                let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref());
                                let _ = app.emit(&format!("ssh:data:{}", sid), encoded);
                            }
                            Some(ChannelMsg::ExtendedData { ref data, .. }) => {
                                scrollback_buf.push(data.as_ref());
                                let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref());
                                let _ = app.emit(&format!("ssh:data:{}", sid), encoded);
                            }
                            Some(ChannelMsg::ExitStatus { exit_status }) => {
                                let _ = app.emit(&format!("ssh:exit:{}", sid), exit_status);
                                break;
                            }
                            Some(ChannelMsg::Close) | None => {
                                let _ = app.emit(&format!("ssh:close:{}", sid), ());
                                break;
                            }
                            _ => {}
                        }
                    }
                    cmd = command_rx.recv() => {
                        match cmd {
                            Some(ChannelCommand::Resize { cols, rows }) => {
                                let _ = channel.window_change(cols, rows, 0, 0).await;
                            }
                            Some(ChannelCommand::Shutdown) | None => {
                                let _ = channel.eof().await;
                                let _ = channel.close().await;
                                break;
                            }
                        }
                    }
                }
            }
        });
        Ok(session_id)
    }

    /// Send `data` to the session's PTY (i.e. as keyboard input to the remote shell).
    ///
    /// # Errors
    /// Unknown session id, or the write on the SSH handle failed.
    pub async fn write(&self, session_id: &str, data: &[u8]) -> Result<(), String> {
        // Clone the Arc out of the map so no DashMap shard lock is held across the awaits
        // below; holding one would block concurrent insert/remove on the same shard.
        let session = self
            .get_session(session_id)
            .ok_or_else(|| format!("Session {} not found", session_id))?;
        let handle = session.handle.lock().await;
        handle
            .data(session.channel_id, CryptoVec::from_slice(data))
            .await
            .map_err(|_| format!("Failed to write to session {}", session_id))
    }

    /// Ask the session's output loop to resize the remote PTY.
    ///
    /// # Errors
    /// Unknown session id, or the output loop has already exited.
    pub async fn resize(&self, session_id: &str, cols: u32, rows: u32) -> Result<(), String> {
        let session = self
            .get_session(session_id)
            .ok_or_else(|| format!("Session {} not found", session_id))?;
        session
            .command_tx
            .send(ChannelCommand::Resize { cols, rows })
            .map_err(|_| format!("Failed to resize session {}: channel closed", session_id))
    }

    /// Remove the session, stop its output loop, disconnect the SSH connection, and drop its
    /// SFTP client. Shutdown and disconnect failures are ignored.
    ///
    /// # Errors
    /// Unknown session id.
    pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> {
        let (_, session) = self
            .sessions
            .remove(session_id)
            .ok_or_else(|| format!("Session {} not found", session_id))?;
        let _ = session.command_tx.send(ChannelCommand::Shutdown);
        {
            let handle = session.handle.lock().await;
            let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await;
        }
        sftp_service.remove_client(session_id);
        Ok(())
    }

    /// Look up a session by id, returning a shared reference that holds no map lock.
    pub fn get_session(&self, session_id: &str) -> Option<Arc<SshSession>> {
        self.sessions.get(session_id).map(|entry| entry.clone())
    }

    /// Snapshot of all live sessions, in no particular order.
    pub fn list_sessions(&self) -> Vec<SessionInfo> {
        self.sessions
            .iter()
            .map(|entry| {
                let s = entry.value();
                SessionInfo {
                    id: s.id.clone(),
                    hostname: s.hostname.clone(),
                    port: s.port,
                    username: s.username.clone(),
                }
            })
            .collect()
    }
}

/// Shell snippet typed into a fresh PTY so the remote shell emits an OSC 7 sequence
/// (`ESC ] 7 ; file://HOST/PATH ESC \`) before every prompt, enabling SFTP CWD following.
///
/// The hook is added next to the user's existing prompt hooks instead of replacing them:
/// * zsh: `add-zsh-hook precmd` (the user's own `precmd` / `precmd_functions` keep working);
/// * bash: prepended to `PROMPT_COMMAND`; the function re-returns the previous `$?`, so later
///   `PROMPT_COMMAND` entries still see the last command's exit status.
///
/// The leading space keeps the line out of history in shells configured to ignore
/// space-prefixed commands (e.g. bash `HISTCONTROL=ignorespace`). The snippet only uses
/// POSIX-parseable syntax, so other Bourne-style shells parse it and run neither branch.
const OSC7_HOOK: &str = concat!(
    r#" if [ -n "$ZSH_VERSION" ]; then "#,
    r#"__wraith_osc7() { printf '\033]7;file://%s%s\033\\' "$HOST" "$PWD"; }; "#,
    r#"autoload -Uz add-zsh-hook; add-zsh-hook precmd __wraith_osc7; "#,
    r#"elif [ -n "$BASH_VERSION" ]; then "#,
    r#"__wraith_osc7() { local s=$?; printf '\033]7;file://%s%s\033\\' "$HOSTNAME" "$PWD"; return $s; }; "#,
    r#"PROMPT_COMMAND="__wraith_osc7${PROMPT_COMMAND:+;$PROMPT_COMMAND}"; "#,
    r#"fi"#,
    "\n"
);
/// Decrypt a legacy PEM-encrypted EC key and re-encode as unencrypted PKCS#8.
/// Handles -----BEGIN EC PRIVATE KEY----- with Proc-Type/DEK-Info headers.
/// Uses the same MD5-based EVP_BytesToKey KDF that OpenSSL/russh use for RSA.
///
/// A SEC1 key without a `Proc-Type: 4,ENCRYPTED` line is converted without decryption.
/// Only `DEK-Info: AES-128-CBC` encryption is supported (the IV is read by `extract_dek_iv`).
///
/// # Errors
/// Returns a message if the PEM cannot be parsed or is not tagged `EC PRIVATE KEY`, if the key
/// is encrypted but `passphrase` is `None`, if the DEK-Info IV is missing or invalid, if the
/// AES-CBC decryption/unpadding fails (reported as a possible wrong passphrase), if the SEC1 DER
/// is malformed or has no curve OID, or if encoding the PKCS#8 structure fails.
fn convert_ec_key_to_pkcs8(pem_text: &str, passphrase: Option<&str>) -> Result<String, String> {
use aes::cipher::{BlockDecryptMut, KeyIvInit};
// Parse PEM to extract headers and base64 body
let parsed = pem::parse(pem_text)
.map_err(|e| format!("Failed to parse PEM: {}", e))?;
if parsed.tag() != "EC PRIVATE KEY" {
return Err(format!("Expected EC PRIVATE KEY, got {}", parsed.tag()));
}
// Still-encrypted (or plain) SEC1 DER body from the PEM.
let der_bytes = parsed.contents();
// Check if the PEM has encryption headers (Proc-Type: 4,ENCRYPTED)
let is_encrypted = pem_text.contains("Proc-Type: 4,ENCRYPTED");
let decrypted = if is_encrypted {
let pass = passphrase
.ok_or_else(|| "EC key is encrypted but no passphrase provided".to_string())?;
// Extract IV from DEK-Info header
let iv = extract_dek_iv(pem_text)?;
// EVP_BytesToKey: key = MD5(password + iv[:8])
// The first 8 IV bytes serve as the salt; a single MD5 output is 16 bytes, exactly the
// AES-128 key size, so one round suffices.
let mut ctx = md5::Context::new();
ctx.consume(pass.as_bytes());
ctx.consume(&iv[..8]);
let key_bytes = ctx.compute();
// Decrypt AES-128-CBC
let decryptor = cbc::Decryptor::<aes::Aes128>::new_from_slices(&key_bytes.0, &iv)
.map_err(|e| format!("AES init failed: {}", e))?;
let mut buf = der_bytes.to_vec();
// A PKCS#7 padding failure is the usual symptom of a wrong passphrase.
let decrypted = decryptor
.decrypt_padded_mut::<block_padding::Pkcs7>(&mut buf)
.map_err(|_| "Decryption failed — wrong passphrase?".to_string())?;
decrypted.to_vec()
} else {
der_bytes.to_vec()
};
// Parse SEC1 DER → re-encode as PKCS#8 PEM
use sec1::der::Decode;
let ec_key = sec1::EcPrivateKey::from_der(&decrypted)
.map_err(|e| format!("Failed to parse EC key DER: {}", e))?;
// Build PKCS#8 wrapper around the SEC1 key
// The OID for the curve is embedded in the SEC1 parameters field
// (the `let` pattern below is irrefutable: NamedCurve is the only EcParameters variant).
let oid = ec_key.parameters
.map(|p| { let sec1::EcParameters::NamedCurve(oid) = p; oid })
.ok_or_else(|| "EC key missing curve OID in parameters".to_string())?;
// Re-encode as PKCS#8 OneAsymmetricKey
use pkcs8::der::Encode;
let inner_der = ec_key.to_der()
.map_err(|e| format!("Failed to re-encode EC key: {}", e))?;
// AlgorithmIdentifier: id-ecPublicKey (1.2.840.10045.2.1) with the named curve OID as its
// parameter, so the key decoder can tell which curve the inner SEC1 key uses.
let algorithm = pkcs8::AlgorithmIdentifierRef {
oid: pkcs8::ObjectIdentifier::new("1.2.840.10045.2.1")
.map_err(|e| format!("Bad EC OID: {}", e))?,
parameters: Some(
pkcs8::der::asn1::AnyRef::new(pkcs8::der::Tag::ObjectIdentifier, oid.as_bytes())
.map_err(|e| format!("Bad curve param: {}", e))?
),
};
let pkcs8_info = pkcs8::PrivateKeyInfo {
algorithm,
private_key: &inner_der,
public_key: None,
};
let pkcs8_der = pkcs8_info.to_der()
.map_err(|e| format!("Failed to encode PKCS#8: {}", e))?;
// Wrap in PEM
let pkcs8_pem = pem::encode(&pem::Pem::new("PRIVATE KEY", pkcs8_der));
Ok(pkcs8_pem)
}
/// Read the IV from the first `DEK-Info: AES-128-CBC,<hex>` header line of `pem_text`.
///
/// # Errors
/// No such header line exists, its hex payload is malformed, or it does not decode to exactly
/// 16 bytes. Only the first matching line is considered.
fn extract_dek_iv(pem_text: &str) -> Result<[u8; 16], String> {
    let Some(iv_hex) = pem_text
        .lines()
        .find_map(|line| line.strip_prefix("DEK-Info: AES-128-CBC,"))
    else {
        return Err("No DEK-Info: AES-128-CBC header found in encrypted PEM".to_string());
    };

    let decoded = hex::decode(iv_hex.trim())
        .map_err(|e| format!("Invalid DEK-Info IV hex: {}", e))?;
    if decoded.len() == 16 {
        let mut iv = [0u8; 16];
        iv.copy_from_slice(&decoded);
        Ok(iv)
    } else {
        Err(format!("IV must be 16 bytes, got {}", decoded.len()))
    }
}
/// Passively extract the working directory from an OSC 7 escape sequence in terminal output.
///
/// Recognised form: `ESC ] 7 ; file://HOST/PATH`, terminated by BEL (`\x07`) or ST (`ESC \`).
/// The data stream itself is not modified, and only the first OSC 7 sequence in `data` is
/// examined.
///
/// Returns the percent-decoded `/PATH` (the host is discarded), or `None` when there is no
/// sequence or it carries no path. If the sequence is not terminated inside `data` (it may be
/// split across reads), everything after the marker up to the end of the chunk is used.
///
/// Invalid UTF-8 elsewhere in the chunk (e.g. a multi-byte character split across reads) is
/// replaced rather than causing the whole chunk to be ignored.
fn extract_osc7_cwd(data: &[u8]) -> Option<String> {
    const MARKER: &str = "\x1b]7;file://";
    let text = String::from_utf8_lossy(data);
    let start = text.find(MARKER)?;
    let after_marker = &text[start + MARKER.len()..];

    // The payload ends at whichever terminator comes FIRST. A later BEL (for example from an
    // OSC 0 window-title update in the same chunk) must not extend the path.
    let bel = after_marker.find('\x07');
    let st = after_marker.find("\x1b\\");
    let end = match (bel, st) {
        (Some(a), Some(b)) => a.min(b),
        (Some(i), None) | (None, Some(i)) => i,
        (None, None) => after_marker.len(),
    };
    let payload = &after_marker[..end];

    // Skip the host: the path begins at the first '/' inside this sequence's payload.
    let path_start = payload.find('/')?;
    Some(percent_decode(&payload[path_start..]))
}

/// Decode `%XX` escapes (as used in `file://` URLs) in `input`.
///
/// Escapes are decoded to raw bytes and the result is read as UTF-8, so multi-byte characters
/// such as `%C3%A9` ("é") decode correctly; byte sequences that are not valid UTF-8 become
/// U+FFFD. A `%` not followed by exactly two hex digits is kept literally.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let hex_at = |i: usize| bytes.get(i).and_then(|&b| (b as char).to_digit(16));
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            if let (Some(hi), Some(lo)) = (hex_at(i + 1), hex_at(i + 2)) {
                // Two hex digits: value is at most 0xFF, so the cast cannot truncate.
                out.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
/// Resolve a private-key argument into key text.
///
/// * Input that (after trimming whitespace and a UTF-8 BOM) starts with `-----BEGIN ` is
///   treated as PEM content and returned as-is.
/// * Otherwise it is treated as a file path. A leading `~` (alone, or followed by `/` or `\`)
///   expands to the home directory from `HOME`, falling back to `USERPROFILE` (Windows).
///   The file's contents are returned with surrounding whitespace and any BOM stripped.
/// * Input that is neither PEM nor an existing file, and contains no path separator, is
///   passed through unchanged so the key decoder can report its own error.
///
/// # Errors
/// The file exists but cannot be read, or the input contains `/` or `\` but does not name a
/// file (`"Private key file not found: ..."`, quoting the input as given).
fn resolve_private_key(input: &str) -> Result<String, String> {
    let input = trim_key_text(input);
    if input.starts_with("-----BEGIN ") {
        return Ok(input.to_string());
    }
    // Doesn't look like PEM — try as file path
    let expanded = expand_tilde(input);
    let path = std::path::Path::new(&expanded);
    if path.is_file() {
        std::fs::read_to_string(path)
            .map(|s| trim_key_text(&s).to_string())
            .map_err(|e| format!("Failed to read private key file '{}': {}", path.display(), e))
    } else if input.contains('/') || input.contains('\\') {
        Err(format!("Private key file not found: {}", input))
    } else {
        // Neither PEM nor a path — pass through and let russh give its error
        Ok(input.to_string())
    }
}

/// Trim surrounding whitespace and a leading UTF-8 BOM. `U+FEFF` is not whitespace, so
/// `str::trim` alone keeps it, and the BOM may sit before or after leading whitespace.
fn trim_key_text(text: &str) -> &str {
    text.trim().trim_start_matches('\u{feff}').trim()
}

/// Expand a leading `~`, `~/...` or `~\...` to the user's home directory (`HOME`, then
/// `USERPROFILE`). `~user` forms, or inputs when no home variable is set, are returned unchanged.
fn expand_tilde(path: &str) -> String {
    let rest = match path.strip_prefix('~') {
        Some(rest) if rest.is_empty() || rest.starts_with('/') || rest.starts_with('\\') => rest,
        _ => return path.to_string(),
    };
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|var| std::env::var(var).ok())
        .find(|value| !value.is_empty());
    match home {
        Some(home) => format!("{}{}", home, rest),
        None => path.to_string(),
    }
}