diff --git a/src-tauri/src/commands/rdp_commands.rs b/src-tauri/src/commands/rdp_commands.rs index b95bdf3..3834166 100644 --- a/src-tauri/src/commands/rdp_commands.rs +++ b/src-tauri/src/commands/rdp_commands.rs @@ -3,7 +3,6 @@ //! Mirrors the pattern used by `ssh_commands.rs` — thin command wrappers that //! delegate to the `RdpService` via `State`. -use serde::Deserialize; use tauri::State; use crate::rdp::{RdpConfig, RdpSessionInfo}; diff --git a/src-tauri/src/rdp/mod.rs b/src-tauri/src/rdp/mod.rs index 99a6ea0..6c9bb12 100644 --- a/src-tauri/src/rdp/mod.rs +++ b/src-tauri/src/rdp/mod.rs @@ -7,7 +7,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use base64::Engine; use dashmap::DashMap; -use log::{debug, error, info, warn}; +use log::{error, info, warn}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; diff --git a/src-tauri/src/ssh/session.rs b/src-tauri/src/ssh/session.rs index 34f2081..36dbc1f 100644 --- a/src-tauri/src/ssh/session.rs +++ b/src-tauri/src/ssh/session.rs @@ -4,12 +4,12 @@ use std::sync::Arc; use async_trait::async_trait; use base64::Engine; use dashmap::DashMap; -use log::{debug, error, info, warn}; -use russh::client::{self, Handle, Msg}; -use russh::{Channel, ChannelMsg, Disconnect}; +use russh::client::{self, Handle}; +use russh::{ChannelId, ChannelMsg, CryptoVec, Disconnect}; use serde::Serialize; use tauri::{AppHandle, Emitter}; use tokio::sync::Mutex as TokioMutex; +use tokio::sync::mpsc; use crate::db::Database; use crate::sftp::SftpService; @@ -21,6 +21,12 @@ pub enum AuthMethod { Key { private_key_pem: String, passphrase: Option }, } +/// Commands sent to the output loop that owns the Channel. +pub enum ChannelCommand { + Resize { cols: u32, rows: u32 }, + Shutdown, +} + #[derive(Debug, Serialize, Clone)] #[serde(rename_all = "camelCase")] pub struct SessionInfo { @@ -35,8 +41,9 @@ pub struct SshSession { pub hostname: String, pub port: u16, pub username: String, - pub channel: Arc>>, + pub channel_id: ChannelId, pub handle: Arc>>, + pub command_tx: mpsc::UnboundedSender, pub cwd_tracker: Option, } @@ -103,16 +110,17 @@ impl SshService { if !auth_success { return Err("Authentication failed: server rejected credentials".to_string()); } - let channel = handle.channel_open_session().await.map_err(|e| format!("Failed to open session channel: {}", e))?; + let mut channel = handle.channel_open_session().await.map_err(|e| format!("Failed to open session channel: {}", e))?; channel.request_pty(true, "xterm-256color", cols, rows, 0, 0, &[]).await.map_err(|e| format!("Failed to request PTY: {}", e))?; channel.request_shell(true).await.map_err(|e| format!("Failed to start shell: {}", e))?; + let channel_id = channel.id(); let handle = Arc::new(TokioMutex::new(handle)); - let channel = Arc::new(TokioMutex::new(channel)); + let (command_tx, mut command_rx) = mpsc::unbounded_channel::(); let cwd_tracker = CwdTracker::new(); cwd_tracker.start(handle.clone(), app_handle.clone(), session_id.clone()); - let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel: channel.clone(), handle: handle.clone(), cwd_tracker: Some(cwd_tracker) }); + let session = Arc::new(SshSession { id: session_id.clone(), hostname: hostname.to_string(), port, username: username.to_string(), channel_id, handle: handle.clone(), command_tx: command_tx.clone(), cwd_tracker: Some(cwd_tracker) }); self.sessions.insert(session_id.clone(), session); { let h = handle.lock().await; @@ -125,30 +133,46 @@ impl SshService { } } + // Output reader loop — owns the Channel exclusively. + // Writes go through Handle::data() so no shared mutex is needed. let sid = session_id.clone(); - let chan = channel.clone(); let app = app_handle.clone(); tokio::spawn(async move { loop { - let msg = { let mut ch = chan.lock().await; ch.wait().await }; - match msg { - Some(ChannelMsg::Data { ref data }) => { - let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref()); - let _ = app.emit(&format!("ssh:data:{}", sid), encoded); + tokio::select! { + msg = channel.wait() => { + match msg { + Some(ChannelMsg::Data { ref data }) => { + let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref()); + let _ = app.emit(&format!("ssh:data:{}", sid), encoded); + } + Some(ChannelMsg::ExtendedData { ref data, .. }) => { + let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref()); + let _ = app.emit(&format!("ssh:data:{}", sid), encoded); + } + Some(ChannelMsg::ExitStatus { exit_status }) => { + let _ = app.emit(&format!("ssh:exit:{}", sid), exit_status); + break; + } + Some(ChannelMsg::Close) | None => { + let _ = app.emit(&format!("ssh:close:{}", sid), ()); + break; + } + _ => {} + } } - Some(ChannelMsg::ExtendedData { ref data, .. }) => { - let encoded = base64::engine::general_purpose::STANDARD.encode(data.as_ref()); - let _ = app.emit(&format!("ssh:data:{}", sid), encoded); + cmd = command_rx.recv() => { + match cmd { + Some(ChannelCommand::Resize { cols, rows }) => { + let _ = channel.window_change(cols, rows, 0, 0).await; + } + Some(ChannelCommand::Shutdown) | None => { + let _ = channel.eof().await; + let _ = channel.close().await; + break; + } + } } - Some(ChannelMsg::ExitStatus { exit_status }) => { - let _ = app.emit(&format!("ssh:exit:{}", sid), exit_status); - break; - } - Some(ChannelMsg::Close) | None => { - let _ = app.emit(&format!("ssh:close:{}", sid), ()); - break; - } - _ => {} } } }); @@ -158,19 +182,21 @@ impl SshService { pub async fn write(&self, session_id: &str, data: &[u8]) -> Result<(), String> { let session = self.sessions.get(session_id).ok_or_else(|| format!("Session {} not found", session_id))?; - let channel = session.channel.lock().await; - channel.data(&data[..]).await.map_err(|e| format!("Failed to write to session {}: {}", session_id, e)) + let handle = session.handle.lock().await; + handle.data(session.channel_id, CryptoVec::from_slice(data)) + .await + .map_err(|_| format!("Failed to write to session {}", session_id)) } pub async fn resize(&self, session_id: &str, cols: u32, rows: u32) -> Result<(), String> { let session = self.sessions.get(session_id).ok_or_else(|| format!("Session {} not found", session_id))?; - let channel = session.channel.lock().await; - channel.window_change(cols, rows, 0, 0).await.map_err(|e| format!("Failed to resize session {}: {}", session_id, e)) + session.command_tx.send(ChannelCommand::Resize { cols, rows }) + .map_err(|_| format!("Failed to resize session {}: channel closed", session_id)) } pub async fn disconnect(&self, session_id: &str, sftp_service: &SftpService) -> Result<(), String> { let (_, session) = self.sessions.remove(session_id).ok_or_else(|| format!("Session {} not found", session_id))?; - { let channel = session.channel.lock().await; let _ = channel.eof().await; let _ = channel.close().await; } + let _ = session.command_tx.send(ChannelCommand::Shutdown); { let handle = session.handle.lock().await; let _ = handle.disconnect(Disconnect::ByApplication, "", "en").await; } sftp_service.remove_client(session_id); Ok(())