wraith/internal/ssh/service.go
Vantz Stockwell 8a096d7f7b
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
Wraith v0.1.0 — Desktop SSH + RDP + SFTP Client
Go + Wails v3 + Vue 3 + SQLite + FreeRDP3 (purego)
183 tests, 76 source files, 9,910 lines of code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 08:19:29 -04:00

260 lines
6.3 KiB
Go

package ssh
import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/ssh"
)
// OutputHandler is called when data is read from an SSH session's stdout.
//
// sessionID is the ID that SSHService.Connect returned for the session. data
// holds the bytes from a single read. The read loop copies them into a fresh
// slice for every call, so the handler may retain data after it returns.
// Calls are made from the per-session read goroutine started by Connect, not
// from the goroutine that called Connect.
//
// In production this will emit Wails events; for testing, a simple callback.
type OutputHandler func(sessionID string, data []byte)
// SSHSession represents an active SSH connection with its PTY shell session.
// Instances are created and registered by SSHService.Connect.
// SSHService.Disconnect removes them from the service and closes their
// resources.
type SSHSession struct {
	// ID is the UUID assigned by Connect; it is the key in SSHService.sessions.
	ID string
	// Client is the underlying SSH connection; closed by Disconnect.
	Client *ssh.Client
	// Session is the channel running the PTY shell; Resize targets it.
	Session *ssh.Session
	// Stdin is the shell's stdin pipe; Write targets it.
	Stdin io.WriteCloser
	// NOTE(review): ConnID is not set by Connect — presumably filled in by a
	// caller (e.g. a saved-connection row ID); verify.
	ConnID   int64
	Hostname string
	Port     int
	Username string
	// Connected is the time Connect registered the session.
	Connected time.Time
	// mu serializes Write, Resize and Disconnect on this session's Stdin and
	// Session fields. It is unexported, so code outside this package cannot
	// take it when reading those fields.
	mu sync.Mutex
}
// SSHService manages SSH connections and their associated sessions.
// Construct it with NewSSHService; the zero value has a nil sessions map.
type SSHService struct {
	// sessions maps session IDs returned by Connect to live sessions.
	sessions map[string]*SSHSession
	// mu guards the sessions map only; per-session state is guarded by
	// SSHSession.mu.
	mu sync.RWMutex
	// db is stored by NewSSHService but not read by the methods shown here.
	// NOTE(review): presumably used elsewhere in the package — confirm.
	db *sql.DB
	// outputHandler receives stdout chunks from readLoop; it may be nil, in
	// which case output is read and discarded.
	outputHandler OutputHandler
}
// NewSSHService returns a ready-to-use SSHService with an empty session table.
//
// db is kept on the service for later use. outputHandler receives every chunk
// read from a session's stdout; pass nil when output handling is not needed.
func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService {
	svc := new(SSHService)
	svc.db = db
	svc.outputHandler = outputHandler
	svc.sessions = map[string]*SSHSession{}
	return svc
}
// Connect dials an SSH server and opens a session with an xterm-256color PTY
// of size cols x rows running an interactive shell. It registers the session
// under a fresh UUID and starts a goroutine that forwards stdout to the
// service's output handler. It returns the session ID.
//
// hostname may be a DNS name, an IPv4 address, or a bare IPv6 literal such as
// "::1"; the dial address is built with net.JoinHostPort, so IPv6 literals
// are bracketed correctly. A 15-second timeout covers both the TCP connect and
// the SSH handshake.
//
// If any setup step fails, everything opened so far is closed. The returned
// error names the step that failed.
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
	const connectTimeout = 15 * time.Second

	// Plain "%s:%d" formatting yields an unparseable address for IPv6
	// literals ("::1:22"); JoinHostPort produces "[::1]:22".
	addr := net.JoinHostPort(hostname, strconv.Itoa(port))
	config := &ssh.ClientConfig{
		User: username,
		Auth: authMethods,
		// SECURITY(review): host keys are not verified, so connections are
		// open to man-in-the-middle attacks. Replace with a known_hosts-backed
		// callback (e.g. golang.org/x/crypto/ssh/knownhosts) plus a
		// trust-on-first-use prompt.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         connectTimeout,
	}

	// ClientConfig.Timeout bounds only the TCP connect, so a server that
	// accepts the connection but stalls the handshake would hang ssh.Dial
	// forever. Dial by hand and hold a deadline on the raw conn until the
	// handshake completes.
	conn, err := net.DialTimeout("tcp", addr, connectTimeout)
	if err != nil {
		return "", fmt.Errorf("ssh dial %s: %w", addr, err)
	}
	if err := conn.SetDeadline(time.Now().Add(connectTimeout)); err != nil {
		conn.Close()
		return "", fmt.Errorf("ssh dial %s: %w", addr, err)
	}
	clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		conn.Close() // harmless if the handshake already closed it
		return "", fmt.Errorf("ssh dial %s: %w", addr, err)
	}
	// Clear the deadline so the long-lived interactive session is not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		clientConn.Close()
		return "", fmt.Errorf("ssh dial %s: %w", addr, err)
	}
	client := ssh.NewClient(clientConn, chans, reqs)

	session, err := client.NewSession()
	if err != nil {
		client.Close()
		return "", fmt.Errorf("new session: %w", err)
	}

	// fail tears down the session and connection after a later setup step
	// fails, and reports which step it was.
	fail := func(step string, err error) (string, error) {
		session.Close()
		client.Close()
		return "", fmt.Errorf("%s: %w", step, err)
	}

	modes := ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	// RequestPty takes height before width: (rows, cols).
	if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
		return fail("request pty", err)
	}
	// The pipes must be obtained before Shell starts the remote command.
	stdin, err := session.StdinPipe()
	if err != nil {
		return fail("stdin pipe", err)
	}
	stdout, err := session.StdoutPipe()
	if err != nil {
		return fail("stdout pipe", err)
	}
	if err := session.Shell(); err != nil {
		return fail("start shell", err)
	}

	sessionID := uuid.NewString()
	sshSession := &SSHSession{
		ID:        sessionID,
		Client:    client,
		Session:   session,
		Stdin:     stdin,
		Hostname:  hostname,
		Port:      port,
		Username:  username,
		Connected: time.Now(),
	}

	s.mu.Lock()
	s.sessions[sessionID] = sshSession
	s.mu.Unlock()

	// Forward stdout to the output handler until the remote side closes.
	go s.readLoop(sessionID, stdout)
	return sessionID, nil
}
// readLoop pumps a session's stdout into the output handler until the reader
// returns an error (typically io.EOF once the session closes).
//
// Each chunk is copied out of the reusable 32 KiB read buffer before it is
// handed to the handler, so handlers may retain the slice. Bytes returned with
// a final error are still delivered. When the loop ends it only returns; the
// session stays registered until Disconnect is called.
func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
	chunk := make([]byte, 32*1024)
	for {
		n, readErr := reader.Read(chunk)
		if n > 0 && s.outputHandler != nil {
			out := make([]byte, n)
			copy(out, chunk)
			s.outputHandler(sessionID, out)
		}
		if readErr != nil {
			return
		}
	}
}
// Write sends data verbatim to the session's stdin.
//
// It fails if sessionID is not registered, if the session has no stdin pipe,
// or if the underlying write fails. Writes to one session are serialized by
// that session's mutex.
func (s *SSHService) Write(sessionID string, data string) error {
	sess, found := s.GetSession(sessionID)
	if !found {
		return fmt.Errorf("session %s not found", sessionID)
	}

	sess.mu.Lock()
	defer sess.mu.Unlock()

	w := sess.Stdin
	if w == nil {
		return fmt.Errorf("session %s stdin is closed", sessionID)
	}
	if _, err := w.Write([]byte(data)); err != nil {
		return fmt.Errorf("write to session %s: %w", sessionID, err)
	}
	return nil
}
// Resize asks the remote PTY to change its window size to cols x rows.
//
// It fails if sessionID is not registered, if the session has no channel, or
// if the window-change request fails.
func (s *SSHService) Resize(sessionID string, cols, rows int) error {
	sess, found := s.GetSession(sessionID)
	if !found {
		return fmt.Errorf("session %s not found", sessionID)
	}

	sess.mu.Lock()
	defer sess.mu.Unlock()

	pty := sess.Session
	if pty == nil {
		return fmt.Errorf("session %s is closed", sessionID)
	}
	// WindowChange takes height before width, i.e. (rows, cols).
	if err := pty.WindowChange(rows, cols); err != nil {
		return fmt.Errorf("resize session %s: %w", sessionID, err)
	}
	return nil
}
// Disconnect unregisters a session and closes its stdin, channel and client,
// in that order.
//
// The session is removed from the table before anything is closed, so later
// Write/Resize calls report "not found". Teardown is best-effort: close errors
// are discarded, and the only error returned is for an unknown sessionID.
func (s *SSHService) Disconnect(sessionID string) error {
	s.mu.Lock()
	sess, found := s.sessions[sessionID]
	if found {
		delete(s.sessions, sessionID)
	}
	s.mu.Unlock()

	if !found {
		return fmt.Errorf("session %s not found", sessionID)
	}

	sess.mu.Lock()
	defer sess.mu.Unlock()

	if sess.Stdin != nil {
		_ = sess.Stdin.Close()
	}
	if sess.Session != nil {
		_ = sess.Session.Close()
	}
	if sess.Client != nil {
		_ = sess.Client.Close()
	}
	return nil
}
// GetSession looks up a registered session by ID. The boolean reports whether
// it was found.
//
// The returned pointer is the live, shared session. Its mutex is unexported,
// so callers outside this package cannot synchronize reads of its fields.
func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) {
	s.mu.RLock()
	sess, found := s.sessions[sessionID]
	s.mu.RUnlock()
	return sess, found
}
// ListSessions returns a snapshot of every registered session.
//
// The slice is freshly allocated, but its elements are the live, shared
// session pointers. Order is unspecified because it follows map iteration.
func (s *SSHService) ListSessions() []*SSHSession {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snapshot := make([]*SSHSession, len(s.sessions))
	i := 0
	for _, sess := range s.sessions {
		snapshot[i] = sess
		i++
	}
	return snapshot
}
// BuildPasswordAuth wraps a plaintext password as an ssh.AuthMethod suitable
// for passing to Connect.
func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod {
	method := ssh.Password(password)
	return method
}
// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key.
//
// If the key is encrypted, pass its passphrase. For an unencrypted key, pass
// an empty string. A non-empty passphrase supplied with an unencrypted key is
// ignored rather than treated as an error.
// ssh.ParsePrivateKeyWithPassphrase rejects unencrypted keys, so the key is
// parsed plainly first and decrypted only when that reports a missing
// passphrase.
//
// An encrypted key with an empty passphrase yields an error wrapping
// *ssh.PassphraseMissingError. Callers can detect it with errors.As and prompt
// for a passphrase.
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) {
	signer, err := ssh.ParsePrivateKey(privateKey)
	var missing *ssh.PassphraseMissingError
	if errors.As(err, &missing) && passphrase != "" {
		// The key is encrypted and a passphrase was supplied: decrypt it.
		signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
	}
	if err != nil {
		return nil, fmt.Errorf("parse private key: %w", err)
	}
	return ssh.PublicKeys(signer), nil
}