diff --git a/backend/src/main.ts b/backend/src/main.ts index 426104c..69401c2 100644 --- a/backend/src/main.ts +++ b/backend/src/main.ts @@ -1,13 +1,14 @@ import { NestFactory } from '@nestjs/core'; import { ValidationPipe } from '@nestjs/common'; -import { WsAdapter } from '@nestjs/platform-ws'; import { AppModule } from './app.module'; +import { WebSocketServer } from 'ws'; +import { TerminalGateway } from './terminal/terminal.gateway'; +import { SftpGateway } from './terminal/sftp.gateway'; async function bootstrap() { const app = await NestFactory.create(AppModule); app.setGlobalPrefix('api'); app.useGlobalPipes(new ValidationPipe({ whitelist: true, transform: true })); - app.useWebSocketAdapter(new WsAdapter(app)); app.enableCors({ origin: process.env.NODE_ENV === 'production' ? false : 'http://localhost:3001', credentials: true, @@ -15,10 +16,40 @@ async function bootstrap() { await app.listen(3000); console.log('Wraith backend running on port 3000'); - // Debug: monitor WebSocket upgrades at the HTTP server level + // Manual WebSocket handling — bypasses NestJS WsAdapter entirely const server = app.getHttpServer(); - server.on('upgrade', (req: any, socket: any) => { - console.log(`[HTTP-UPGRADE] ${req.method} ${req.url} from ${req.headers.origin || 'unknown'}`); + const terminalGateway = app.get(TerminalGateway); + const sftpGateway = app.get(SftpGateway); + + const terminalWss = new WebSocketServer({ noServer: true }); + const sftpWss = new WebSocketServer({ noServer: true }); + + terminalWss.on('connection', (ws, req) => { + console.log(`[WS] Terminal connection from ${req.url}`); + terminalGateway.handleConnection(ws, req); + }); + + sftpWss.on('connection', (ws, req) => { + console.log(`[WS] SFTP connection from ${req.url}`); + sftpGateway.handleConnection(ws, req); + }); + + server.on('upgrade', (req: any, socket: any, head: any) => { + const pathname = req.url?.split('?')[0]; + console.log(`[HTTP-UPGRADE] path=${pathname}`); + + if (pathname === '/api/ws/terminal') { + terminalWss.handleUpgrade(req, socket, head, (ws) => { + terminalWss.emit('connection', ws, req); + }); + } else if (pathname === '/api/ws/sftp') { + sftpWss.handleUpgrade(req, socket, head, (ws) => { + sftpWss.emit('connection', ws, req); + }); + } else { + console.log(`[HTTP-UPGRADE] Unknown WS path: ${pathname}, destroying socket`); + socket.destroy(); + } }); } bootstrap(); diff --git a/backend/src/terminal/sftp.gateway.ts b/backend/src/terminal/sftp.gateway.ts index c1b8e60..c63957c 100644 --- a/backend/src/terminal/sftp.gateway.ts +++ b/backend/src/terminal/sftp.gateway.ts @@ -1,14 +1,11 @@ -import { WebSocketGateway, WebSocketServer, OnGatewayConnection, OnGatewayDisconnect } from '@nestjs/websockets'; -import { Logger } from '@nestjs/common'; -import { Server } from 'ws'; +import { Injectable, Logger } from '@nestjs/common'; import { WsAuthGuard } from '../auth/ws-auth.guard'; import { SshConnectionService } from './ssh-connection.service'; const MAX_EDIT_SIZE = 5 * 1024 * 1024; // 5MB -@WebSocketGateway({ path: '/api/ws/sftp' }) -export class SftpGateway implements OnGatewayConnection, OnGatewayDisconnect { - @WebSocketServer() server: Server; +@Injectable() +export class SftpGateway { private readonly logger = new Logger(SftpGateway.name); constructor( @@ -34,8 +31,6 @@ export class SftpGateway implements OnGatewayConnection, OnGatewayDisconnect { }); } - handleDisconnect() {} - private async handleMessage(client: any, msg: any) { const { sessionId } = msg; if (!sessionId) { @@ -103,7 +98,6 @@ export class SftpGateway implements OnGatewayConnection, OnGatewayDisconnect { break; } case 'delete': { - // Try unlink (file), fallback to rmdir (directory) sftp.unlink(msg.path, (err: any) => { if (err) { sftp.rmdir(msg.path, (err2: any) => { @@ -140,7 +134,6 @@ export class SftpGateway implements OnGatewayConnection, OnGatewayDisconnect { break; } case 'download': { - // Stream file data to client in chunks const readStream = sftp.createReadStream(msg.path); sftp.stat(msg.path, (err: any, stats: any) => { if (err) return this.send(client, { type: 'error', message: err.message }); diff --git a/backend/src/terminal/terminal.gateway.ts b/backend/src/terminal/terminal.gateway.ts index 4642ff2..35b7612 100644 --- a/backend/src/terminal/terminal.gateway.ts +++ b/backend/src/terminal/terminal.gateway.ts @@ -1,12 +1,9 @@ -import { WebSocketGateway, WebSocketServer, OnGatewayConnection, OnGatewayDisconnect } from '@nestjs/websockets'; -import { Logger } from '@nestjs/common'; -import { Server } from 'ws'; +import { Injectable, Logger } from '@nestjs/common'; import { WsAuthGuard } from '../auth/ws-auth.guard'; import { SshConnectionService } from './ssh-connection.service'; -@WebSocketGateway({ path: '/api/ws/terminal' }) -export class TerminalGateway implements OnGatewayConnection, OnGatewayDisconnect { - @WebSocketServer() server: Server; +@Injectable() +export class TerminalGateway { private readonly logger = new Logger(TerminalGateway.name); private clientSessions = new Map(); // ws client → sessionIds @@ -36,12 +33,13 @@ export class TerminalGateway implements OnGatewayConnection, OnGatewayDisconnect this.send(client, { type: 'error', message: err.message }); } }); - } - handleDisconnect(client: any) { - const sessions = this.clientSessions.get(client) || []; - sessions.forEach((sid) => this.ssh.disconnect(sid)); - this.clientSessions.delete(client); + client.on('close', () => { + this.logger.log(`[WS] Client disconnected`); + const sessions = this.clientSessions.get(client) || []; + sessions.forEach((sid) => this.ssh.disconnect(sid)); + this.clientSessions.delete(client); + }); } private async handleMessage(client: any, msg: any) { diff --git a/backend/src/terminal/terminal.module.ts b/backend/src/terminal/terminal.module.ts index 77a7f52..99e9df8 100644 --- a/backend/src/terminal/terminal.module.ts +++ b/backend/src/terminal/terminal.module.ts @@ -9,6 +9,6 @@ import { AuthModule } from '../auth/auth.module'; @Module({ imports: [VaultModule, ConnectionsModule, AuthModule], providers: [SshConnectionService, TerminalGateway, SftpGateway], - exports: [SshConnectionService], + exports: [SshConnectionService, TerminalGateway, SftpGateway], }) export class TerminalModule {}