pub mod input; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use base64::Engine; use dashmap::DashMap; use log::{error, info, warn}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::mpsc; use tokio::sync::Mutex as TokioMutex; use ironrdp::connector::{self, ClientConnector, ConnectionResult, Credentials, DesktopSize}; use ironrdp::graphics::image_processing::PixelFormat; use ironrdp::input::{self as rdp_input, MouseButton, MousePosition, Operation, Scancode, WheelRotations}; use ironrdp::pdu::gcc::KeyboardType; use ironrdp::pdu::rdp::capability_sets::MajorPlatformType; use ironrdp::pdu::rdp::client_info::{PerformanceFlags, TimezoneInfo}; use ironrdp::session::image::DecodedImage; use ironrdp::session::{ActiveStage, ActiveStageOutput}; use ironrdp_tokio::reqwest::ReqwestNetworkClient; use ironrdp_tokio::{split_tokio_framed, FramedWrite, TokioFramed}; use self::input::mouse_flags; // ── Public types ────────────────────────────────────────────────────────────── #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RdpConfig { pub hostname: String, pub port: u16, pub username: String, pub password: String, pub domain: Option, pub width: u16, pub height: u16, } #[derive(Debug, Serialize, Clone)] #[serde(rename_all = "camelCase")] pub struct RdpSessionInfo { pub id: String, pub hostname: String, pub width: u16, pub height: u16, pub connected: bool, } enum InputEvent { Mouse { x: u16, y: u16, flags: u32, }, Key { scancode: u16, pressed: bool, }, Clipboard(String), Disconnect, } struct RdpSessionHandle { id: String, hostname: String, width: u16, height: u16, frame_buffer: Arc>>, frame_dirty: Arc, input_tx: mpsc::UnboundedSender, } pub struct RdpService { sessions: DashMap>, } impl RdpService { pub fn new() -> Self { Self { sessions: DashMap::new(), } } pub fn connect(&self, config: RdpConfig) -> Result { let session_id = uuid::Uuid::new_v4().to_string(); wraith_log!("[RDP] Connecting to {}:{} as {} (session {})", config.hostname, config.port, config.username, session_id); let width = config.width; let height = config.height; let hostname = config.hostname.clone(); let buf_size = (width as usize) * (height as usize) * 4; let mut initial_buf = vec![0u8; buf_size]; for pixel in initial_buf.chunks_exact_mut(4) { pixel[3] = 255; } let frame_buffer = Arc::new(TokioMutex::new(initial_buf)); let frame_dirty = Arc::new(AtomicBool::new(false)); let (input_tx, input_rx) = mpsc::unbounded_channel(); let handle = Arc::new(RdpSessionHandle { id: session_id.clone(), hostname: hostname.clone(), width, height, frame_buffer: frame_buffer.clone(), frame_dirty: frame_dirty.clone(), input_tx, }); self.sessions.insert(session_id.clone(), handle); let sid = session_id.clone(); let sessions_ref = self.sessions.clone(); let (ready_tx, ready_rx) = std::sync::mpsc::channel::>(); std::thread::spawn(move || { let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); rt.block_on(async move { let connector_config = match build_connector_config(&config) { Ok(c) => c, Err(e) => { let _ = ready_tx.send(Err(format!("Failed to build RDP config: {}", e))); sessions_ref.remove(&sid); return; } }; let (connection_result, framed) = match tokio::time::timeout(std::time::Duration::from_secs(15), establish_connection(connector_config, &config.hostname, config.port)).await { Ok(Ok(r)) => r, Ok(Err(e)) => { let _ = ready_tx.send(Err(format!("RDP connection failed: {}", e))); sessions_ref.remove(&sid); return; } Err(_) => { let _ = ready_tx.send(Err("RDP connection timed out after 15s".to_string())); sessions_ref.remove(&sid); return; } }; info!("RDP connection established to {}:{} (session {})", config.hostname, config.port, sid); let _ = ready_tx.send(Ok(())); if let Err(e) = run_active_session( connection_result, framed, frame_buffer, frame_dirty, input_rx, width as u16, height as u16, ) .await { error!("RDP session {} error: {}", sid, e); } info!("RDP session {} ended", sid); sessions_ref.remove(&sid); }); })); if let Err(panic) = result { let msg = if let Some(s) = panic.downcast_ref::() { s.clone() } else if let Some(s) = panic.downcast_ref::<&str>() { s.to_string() } else { "unknown panic".to_string() }; let _ = crate::write_log(&crate::data_directory().join("wraith.log"), &format!("RDP thread PANIC: {}", msg)); // ready_tx is dropped here, which triggers the "died unexpectedly" error } }); match ready_rx.recv() { Ok(Ok(())) => {} Ok(Err(e)) => { self.sessions.remove(&session_id); return Err(e); } Err(_) => { self.sessions.remove(&session_id); return Err("RDP connection thread panicked — check wraith.log for details".into()); } } Ok(session_id) } pub async fn get_frame(&self, session_id: &str) -> Result { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; if !handle.frame_dirty.swap(false, Ordering::Relaxed) { return Ok(String::new()); } let buf = handle.frame_buffer.lock().await; let encoded = base64::engine::general_purpose::STANDARD.encode(&*buf); Ok(encoded) } pub async fn get_frame_raw(&self, session_id: &str) -> Result, String> { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; let buf = handle.frame_buffer.lock().await; Ok(buf.clone()) } /// Capture the current RDP frame as a base64-encoded PNG. pub async fn screenshot_png_base64(&self, session_id: &str) -> Result { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; let width = handle.width as u32; let height = handle.height as u32; let buf = handle.frame_buffer.lock().await; // Encode RGBA raw bytes to PNG let mut png_data = Vec::new(); { let mut encoder = png::Encoder::new(&mut png_data, width, height); encoder.set_color(png::ColorType::Rgba); encoder.set_depth(png::BitDepth::Eight); let mut writer = encoder.write_header() .map_err(|e| format!("PNG header error: {}", e))?; writer.write_image_data(&buf) .map_err(|e| format!("PNG encode error: {}", e))?; } Ok(base64::engine::general_purpose::STANDARD.encode(&png_data)) } pub fn send_clipboard(&self, session_id: &str, text: &str) -> Result<(), String> { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; handle.input_tx.send(InputEvent::Clipboard(text.to_string())).map_err(|_| format!("RDP session {} input channel closed", session_id)) } pub fn send_mouse(&self, session_id: &str, x: u16, y: u16, flags: u32) -> Result<(), String> { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; handle.input_tx.send(InputEvent::Mouse { x, y, flags }).map_err(|_| format!("RDP session {} input channel closed", session_id)) } pub fn send_key(&self, session_id: &str, scancode: u16, pressed: bool) -> Result<(), String> { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; handle.input_tx.send(InputEvent::Key { scancode, pressed }).map_err(|_| format!("RDP session {} input channel closed", session_id)) } pub fn disconnect(&self, session_id: &str) -> Result<(), String> { let handle = self.sessions.get(session_id).ok_or_else(|| format!("RDP session {} not found", session_id))?; let _ = handle.input_tx.send(InputEvent::Disconnect); drop(handle); self.sessions.remove(session_id); info!("RDP session {} disconnect requested", session_id); Ok(()) } pub fn list_sessions(&self) -> Vec { self.sessions.iter().map(|entry| { let h = entry.value(); RdpSessionInfo { id: h.id.clone(), hostname: h.hostname.clone(), width: h.width, height: h.height, connected: !h.input_tx.is_closed() } }).collect() } } impl Clone for RdpService { fn clone(&self) -> Self { Self { sessions: self.sessions.clone() } } } fn build_connector_config(config: &RdpConfig) -> Result { Ok(connector::Config { credentials: Credentials::UsernamePassword { username: config.username.clone(), password: config.password.clone() }, domain: config.domain.clone(), enable_tls: false, enable_credssp: true, keyboard_type: KeyboardType::IbmEnhanced, keyboard_subtype: 0, keyboard_layout: 0, keyboard_functional_keys_count: 12, ime_file_name: String::new(), dig_product_id: String::new(), desktop_size: DesktopSize { width: config.width, height: config.height }, bitmap: None, client_build: 0, client_name: "Wraith Desktop".to_owned(), client_dir: r"C:\Windows\System32\mstscax.dll".to_owned(), #[cfg(windows)] platform: MajorPlatformType::WINDOWS, #[cfg(target_os = "macos")] platform: MajorPlatformType::MACINTOSH, #[cfg(target_os = "linux")] platform: MajorPlatformType::UNIX, #[cfg(not(any(windows, target_os = "macos", target_os = "linux")))] platform: MajorPlatformType::UNIX, enable_server_pointer: true, pointer_software_rendering: true, request_data: None, autologon: false, enable_audio_playback: false, performance_flags: PerformanceFlags::default(), desktop_scale_factor: 0, hardware_id: None, license_cache: None, timezone_info: TimezoneInfo::default(), }) } trait AsyncReadWrite: AsyncRead + AsyncWrite + 'static {} impl AsyncReadWrite for T {} type UpgradedFramed = TokioFramed>; async fn establish_connection(config: connector::Config, hostname: &str, port: u16) -> Result<(ConnectionResult, UpgradedFramed), String> { let addr = format!("{}:{}", hostname, port); let stream = TcpStream::connect(&addr).await.map_err(|e| format!("TCP connect to {} failed: {}", addr, e))?; let client_addr = stream.local_addr().map_err(|e| format!("Failed to get local address: {}", e))?; let mut framed = TokioFramed::new(stream); let mut connector = ClientConnector::new(config, client_addr); let should_upgrade = ironrdp_tokio::connect_begin(&mut framed, &mut connector).await.map_err(|e| format!("RDP connect_begin failed: {}", e))?; let (initial_stream, leftover_bytes) = framed.into_inner(); let (tls_stream, tls_cert) = ironrdp_tls::upgrade(initial_stream, hostname).await.map_err(|e| format!("TLS upgrade failed: {}", e))?; let upgraded = ironrdp_tokio::mark_as_upgraded(should_upgrade, &mut connector); let erased_stream: Box = Box::new(tls_stream); let mut upgraded_framed = TokioFramed::new_with_leftover(erased_stream, leftover_bytes); let server_public_key = ironrdp_tls::extract_tls_server_public_key(&tls_cert).ok_or_else(|| "Failed to extract TLS server public key".to_string())?.to_owned(); let connection_result = ironrdp_tokio::connect_finalize(upgraded, connector, &mut upgraded_framed, &mut ReqwestNetworkClient::new(), hostname.into(), server_public_key, None).await.map_err(|e| format!("RDP connect_finalize failed: {}", e))?; Ok((connection_result, upgraded_framed)) } async fn run_active_session(connection_result: ConnectionResult, framed: UpgradedFramed, frame_buffer: Arc>>, frame_dirty: Arc, mut input_rx: mpsc::UnboundedReceiver, width: u16, height: u16) -> Result<(), String> { let (mut reader, mut writer) = split_tokio_framed(framed); let mut image = DecodedImage::new(PixelFormat::RgbA32, width, height); let mut active_stage = ActiveStage::new(connection_result); let mut input_db = rdp_input::Database::new(); loop { let outputs = tokio::select! { frame = reader.read_pdu() => { let (action, payload) = frame.map_err(|e| format!("Failed to read RDP frame: {}", e))?; active_stage.process(&mut image, action, &payload).map_err(|e| format!("Failed to process RDP frame: {}", e))? } input_event = input_rx.recv() => { match input_event { Some(InputEvent::Disconnect) | None => { if let Ok(outputs) = active_stage.graceful_shutdown() { for out in outputs { if let ActiveStageOutput::ResponseFrame(frame) = out { let _ = writer.write_all(&frame).await; } } } return Ok(()); } Some(InputEvent::Mouse { x, y, flags }) => { let ops = translate_mouse_flags(x, y, flags); let events = input_db.apply(ops); active_stage.process_fastpath_input(&mut image, &events).map_err(|e| format!("Failed to process mouse input: {}", e))? } Some(InputEvent::Key { scancode, pressed }) => { let sc = Scancode::from_u16(scancode); let op = if pressed { Operation::KeyPressed(sc) } else { Operation::KeyReleased(sc) }; let events = input_db.apply([op]); active_stage.process_fastpath_input(&mut image, &events).map_err(|e| format!("Failed to process keyboard input: {}", e))? } Some(InputEvent::Clipboard(text)) => { let shift_sc = Scancode::from_u16(0x002A); let mut all_outputs = Vec::new(); for ch in text.chars() { if let Some((sc_val, shift)) = char_to_scancode(ch) { let sc = Scancode::from_u16(sc_val); if shift { let evts = input_db.apply([Operation::KeyPressed(shift_sc)]); all_outputs.extend(active_stage.process_fastpath_input(&mut image, &evts).map_err(|e| format!("clipboard input error: {}", e))?); } let evts = input_db.apply([Operation::KeyPressed(sc)]); all_outputs.extend(active_stage.process_fastpath_input(&mut image, &evts).map_err(|e| format!("clipboard input error: {}", e))?); let evts = input_db.apply([Operation::KeyReleased(sc)]); all_outputs.extend(active_stage.process_fastpath_input(&mut image, &evts).map_err(|e| format!("clipboard input error: {}", e))?); if shift { let evts = input_db.apply([Operation::KeyReleased(shift_sc)]); all_outputs.extend(active_stage.process_fastpath_input(&mut image, &evts).map_err(|e| format!("clipboard input error: {}", e))?); } } } all_outputs } } } }; for out in outputs { match out { ActiveStageOutput::ResponseFrame(frame) => { writer.write_all(&frame).await.map_err(|e| format!("Failed to write RDP response frame: {}", e))?; } ActiveStageOutput::GraphicsUpdate(_region) => { let mut buf = frame_buffer.lock().await; let src = image.data(); if src.len() == buf.len() { buf.copy_from_slice(src); } else { *buf = src.to_vec(); } frame_dirty.store(true, Ordering::Relaxed); } ActiveStageOutput::Terminate(reason) => { info!("RDP session terminated: {:?}", reason); return Ok(()); } ActiveStageOutput::DeactivateAll(_) => { warn!("RDP server sent DeactivateAll — reconnection not yet implemented"); return Ok(()); } _ => {} } } } } /// Map an ASCII character to (scancode, needs_shift) for RDP keystroke injection. fn char_to_scancode(ch: char) -> Option<(u16, bool)> { match ch { 'a'..='z' => { let offsets: &[u16] = &[ 0x1E, 0x30, 0x2E, 0x20, 0x12, 0x21, 0x22, 0x23, 0x17, 0x24, 0x25, 0x26, 0x32, 0x31, 0x18, 0x19, 0x10, 0x13, 0x1F, 0x14, 0x16, 0x2F, 0x11, 0x2D, 0x15, 0x2C, ]; Some((offsets[(ch as u8 - b'a') as usize], false)) } 'A'..='Z' => { char_to_scancode(ch.to_ascii_lowercase()).map(|(sc, _)| (sc, true)) } '0' => Some((0x0B, false)), '1'..='9' => Some(((ch as u16 - '0' as u16) + 1, false)), ')' => Some((0x0B, true)), '!' => Some((0x02, true)), '@' => Some((0x03, true)), '#' => Some((0x04, true)), '$' => Some((0x05, true)), '%' => Some((0x06, true)), '^' => Some((0x07, true)), '&' => Some((0x08, true)), '*' => Some((0x09, true)), '(' => Some((0x0A, true)), '-' => Some((0x0C, false)), '_' => Some((0x0C, true)), '=' => Some((0x0D, false)), '+' => Some((0x0D, true)), '[' => Some((0x1A, false)), '{' => Some((0x1A, true)), ']' => Some((0x1B, false)), '}' => Some((0x1B, true)), '\\' => Some((0x2B, false)), '|' => Some((0x2B, true)), ';' => Some((0x27, false)), ':' => Some((0x27, true)), '\'' => Some((0x28, false)), '"' => Some((0x28, true)), ',' => Some((0x33, false)), '<' => Some((0x33, true)), '.' => Some((0x34, false)), '>' => Some((0x34, true)), '/' => Some((0x35, false)), '?' => Some((0x35, true)), '`' => Some((0x29, false)), '~' => Some((0x29, true)), ' ' => Some((0x39, false)), '\n' | '\r' => Some((0x1C, false)), '\t' => Some((0x0F, false)), _ => None, } } fn translate_mouse_flags(x: u16, y: u16, flags: u32) -> Vec { let mut ops = Vec::new(); let pos = MousePosition { x, y }; if flags & mouse_flags::MOVE != 0 { ops.push(Operation::MouseMove(pos)); } let is_down = flags & mouse_flags::DOWN != 0; if flags & mouse_flags::BUTTON1 != 0 { if is_down { ops.push(Operation::MouseButtonPressed(MouseButton::Left)); } else { ops.push(Operation::MouseButtonReleased(MouseButton::Left)); } } if flags & mouse_flags::BUTTON2 != 0 { if is_down { ops.push(Operation::MouseButtonPressed(MouseButton::Right)); } else { ops.push(Operation::MouseButtonReleased(MouseButton::Right)); } } if flags & mouse_flags::BUTTON3 != 0 { if is_down { ops.push(Operation::MouseButtonPressed(MouseButton::Middle)); } else { ops.push(Operation::MouseButtonReleased(MouseButton::Middle)); } } if flags & mouse_flags::WHEEL != 0 { let units: i16 = if flags & mouse_flags::WHEEL_NEG != 0 { -120 } else { 120 }; ops.push(Operation::WheelRotations(WheelRotations { is_vertical: true, rotation_units: units })); } if flags & mouse_flags::HWHEEL != 0 { let units: i16 = if flags & mouse_flags::WHEEL_NEG != 0 { -120 } else { 120 }; ops.push(Operation::WheelRotations(WheelRotations { is_vertical: false, rotation_units: units })); } if ops.is_empty() { ops.push(Operation::MouseMove(pos)); } ops }