feat: SSH service — connect, PTY, shell I/O with goroutine pipes
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fad5692c00
commit
c48c0de042
259
internal/ssh/service.go
Normal file
259
internal/ssh/service.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OutputHandler is called when data is read from an SSH session's stdout.
|
||||||
|
// In production this will emit Wails events; for testing, a simple callback.
|
||||||
|
type OutputHandler func(sessionID string, data []byte)
|
||||||
|
|
||||||
|
// SSHSession represents an active SSH connection with its PTY shell session.
|
||||||
|
type SSHSession struct {
|
||||||
|
ID string
|
||||||
|
Client *ssh.Client
|
||||||
|
Session *ssh.Session
|
||||||
|
Stdin io.WriteCloser
|
||||||
|
ConnID int64
|
||||||
|
Hostname string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Connected time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHService manages SSH connections and their associated sessions.
|
||||||
|
type SSHService struct {
|
||||||
|
sessions map[string]*SSHSession
|
||||||
|
mu sync.RWMutex
|
||||||
|
db *sql.DB
|
||||||
|
outputHandler OutputHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSHService creates a new SSHService. The outputHandler is called when data
|
||||||
|
// arrives from a session's stdout. Pass nil if output handling is not needed.
|
||||||
|
func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService {
|
||||||
|
return &SSHService{
|
||||||
|
sessions: make(map[string]*SSHSession),
|
||||||
|
db: db,
|
||||||
|
outputHandler: outputHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect dials an SSH server, opens a session with a PTY and shell, and
|
||||||
|
// launches a goroutine to read stdout. Returns the session ID.
|
||||||
|
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
|
||||||
|
addr := fmt.Sprintf("%s:%d", hostname, port)
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: username,
|
||||||
|
Auth: authMethods,
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := ssh.Dial("tcp", addr, config)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("ssh dial %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("new session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modes := ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 1,
|
||||||
|
ssh.TTY_OP_ISPEED: 14400,
|
||||||
|
ssh.TTY_OP_OSPEED: 14400,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("request pty: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("start shell: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
sshSession := &SSHSession{
|
||||||
|
ID: sessionID,
|
||||||
|
Client: client,
|
||||||
|
Session: session,
|
||||||
|
Stdin: stdin,
|
||||||
|
Hostname: hostname,
|
||||||
|
Port: port,
|
||||||
|
Username: username,
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.sessions[sessionID] = sshSession
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Launch goroutine to read stdout and forward data via the output handler
|
||||||
|
go s.readLoop(sessionID, stdout)
|
||||||
|
|
||||||
|
return sessionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop continuously reads from the session stdout and calls the output
|
||||||
|
// handler with data. It stops when the reader returns an error (typically EOF
|
||||||
|
// when the session closes).
|
||||||
|
func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n > 0 && s.outputHandler != nil {
|
||||||
|
data := make([]byte, n)
|
||||||
|
copy(data, buf[:n])
|
||||||
|
s.outputHandler(sessionID, data)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write sends data to the session's stdin.
|
||||||
|
func (s *SSHService) Write(sessionID string, data string) error {
|
||||||
|
s.mu.RLock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Stdin == nil {
|
||||||
|
return fmt.Errorf("session %s stdin is closed", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := sess.Stdin.Write([]byte(data))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to session %s: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize sends a window-change request to the remote PTY.
|
||||||
|
func (s *SSHService) Resize(sessionID string, cols, rows int) error {
|
||||||
|
s.mu.RLock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Session == nil {
|
||||||
|
return fmt.Errorf("session %s is closed", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sess.Session.WindowChange(rows, cols); err != nil {
|
||||||
|
return fmt.Errorf("resize session %s: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect closes the SSH session and client, and removes it from tracking.
|
||||||
|
func (s *SSHService) Disconnect(sessionID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Stdin != nil {
|
||||||
|
sess.Stdin.Close()
|
||||||
|
}
|
||||||
|
if sess.Session != nil {
|
||||||
|
sess.Session.Close()
|
||||||
|
}
|
||||||
|
if sess.Client != nil {
|
||||||
|
sess.Client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession returns the SSHSession for the given ID, or false if not found.
|
||||||
|
func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
return sess, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSessions returns all active SSH sessions.
|
||||||
|
func (s *SSHService) ListSessions() []*SSHSession {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
list := make([]*SSHSession, 0, len(s.sessions))
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
list = append(list, sess)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildPasswordAuth creates an ssh.AuthMethod for password authentication.
|
||||||
|
func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod {
|
||||||
|
return ssh.Password(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key.
|
||||||
|
// If the key is encrypted, pass the passphrase; otherwise pass an empty string.
|
||||||
|
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) {
|
||||||
|
var signer ssh.Signer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if passphrase != "" {
|
||||||
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
|
||||||
|
} else {
|
||||||
|
signer, err = ssh.ParsePrivateKey(privateKey)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.PublicKeys(signer), nil
|
||||||
|
}
|
||||||
148
internal/ssh/service_test.go
Normal file
148
internal/ssh/service_test.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSSHService(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
if svc == nil {
|
||||||
|
t.Fatal("NewSSHService returned nil")
|
||||||
|
}
|
||||||
|
if len(svc.ListSessions()) != 0 {
|
||||||
|
t.Error("new service should have no sessions")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPasswordAuth(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
auth := svc.BuildPasswordAuth("mypassword")
|
||||||
|
if auth == nil {
|
||||||
|
t.Error("BuildPasswordAuth returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKeyAuth(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Generate a test Ed25519 key
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey error: %v", err)
|
||||||
|
}
|
||||||
|
pemBlock, err := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalPrivateKey error: %v", err)
|
||||||
|
}
|
||||||
|
keyBytes := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
auth, err := svc.BuildKeyAuth(keyBytes, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildKeyAuth error: %v", err)
|
||||||
|
}
|
||||||
|
if auth == nil {
|
||||||
|
t.Error("BuildKeyAuth returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKeyAuthInvalidKey(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
_, err := svc.BuildKeyAuth([]byte("not a key"), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("BuildKeyAuth should fail with invalid key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionTracking(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Manually add a session to test tracking
|
||||||
|
svc.mu.Lock()
|
||||||
|
svc.sessions["test-123"] = &SSHSession{
|
||||||
|
ID: "test-123",
|
||||||
|
Hostname: "192.168.1.4",
|
||||||
|
Port: 22,
|
||||||
|
Username: "vstockwell",
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
svc.mu.Unlock()
|
||||||
|
|
||||||
|
s, ok := svc.GetSession("test-123")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("session not found")
|
||||||
|
}
|
||||||
|
if s.Hostname != "192.168.1.4" {
|
||||||
|
t.Errorf("Hostname = %q, want %q", s.Hostname, "192.168.1.4")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions := svc.ListSessions()
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Errorf("ListSessions() = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSessionNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
_, ok := svc.GetSession("nonexistent")
|
||||||
|
if ok {
|
||||||
|
t.Error("GetSession should return false for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Write("nonexistent", "data")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Write should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResizeNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Resize("nonexistent", 80, 24)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Resize should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisconnectNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Disconnect("nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Disconnect should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisconnectRemovesSession(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Manually add a session with nil Client/Session/Stdin (no real connection)
|
||||||
|
svc.mu.Lock()
|
||||||
|
svc.sessions["test-dc"] = &SSHSession{
|
||||||
|
ID: "test-dc",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Username: "admin",
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
svc.mu.Unlock()
|
||||||
|
|
||||||
|
if err := svc.Disconnect("test-dc"); err != nil {
|
||||||
|
t.Fatalf("Disconnect error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := svc.GetSession("test-dc")
|
||||||
|
if ok {
|
||||||
|
t.Error("session should be removed after Disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(svc.ListSessions()) != 0 {
|
||||||
|
t.Error("ListSessions should be empty after Disconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user