feat: SFTP service + credential service with encrypted key/password storage

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Vantz Stockwell 2026-03-17 06:55:18 -04:00
parent a75e21138e
commit 6e25a646d3
30 changed files with 3603 additions and 0 deletions

58
go.mod Normal file
View File

@ -0,0 +1,58 @@
module github.com/vstockwell/wraith
go 1.26.1
require (
github.com/google/uuid v1.6.0
github.com/pkg/sftp v1.13.10
github.com/wailsapp/wails/v3 v3.0.0-alpha.74
golang.org/x/crypto v0.49.0
modernc.org/sqlite v1.46.2
)
require (
dario.cat/mergo v1.0.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/adrg/xdg v0.5.3 // indirect
github.com/bep/debounce v1.2.1 // indirect
github.com/cloudflare/circl v1.6.3 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.9.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
github.com/go-git/go-billy/v5 v5.7.0 // indirect
github.com/go-git/go-git/v5 v5.16.4 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/godbus/dbus/v5 v5.2.2 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
github.com/leaanthony/u v1.1.1 // indirect
github.com/lmittmann/tint v1.1.2 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/samber/lo v1.52.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/skeema/knownhosts v1.3.2 // indirect
github.com/wailsapp/go-webview2 v1.0.23 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)

197
go.sum Normal file
View File

@ -0,0 +1,197 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78=
github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE=
github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI=
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic=
github.com/go-git/go-billy/v5 v5.7.0 h1:83lBUJhGWhYp0ngzCMSgllhUSuoHP1iEWYjsPl9nwqM=
github.com/go-git/go-billy/v5 v5.7.0/go.mod h1:/1IUejTKH8xipsAcdfcSAlUlo2J7lkYV8GTKxAT/L3E=
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4=
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
github.com/go-git/go-git/v5 v5.16.4 h1:7ajIEZHZJULcyJebDLo99bGgS0jRrOxzZG4uCk2Yb2Y=
github.com/go-git/go-git/v5 v5.16.4/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 h1:njuLRcjAuMKr7kI3D85AXWkw6/+v9PwtV6M6o11sWHQ=
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs=
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed1YDKpEz01A=
github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU=
github.com/leaanthony/u v1.1.1 h1:TUFjwDGlNX+WuwVEzDqQwC2lOv0P4uhTQw7CMFdiK7M=
github.com/leaanthony/u v1.1.1/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI=
github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skeema/knownhosts v1.3.2 h1:EDL9mgf4NzwMXCTfaxSD/o/a5fxDw/xL9nkU28JjdBg=
github.com/skeema/knownhosts v1.3.2/go.mod h1:bEg3iQAuw+jyiw+484wwFJoKSLwcfd7fqRy+N0QTiow=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/wailsapp/go-webview2 v1.0.23 h1:jmv8qhz1lHibCc79bMM/a/FqOnnzOGEisLav+a0b9P0=
github.com/wailsapp/go-webview2 v1.0.23/go.mod h1:qJmWAmAmaniuKGZPWwne+uor3AHMB5PFhqiK0Bbj8kc=
github.com/wailsapp/wails/v3 v3.0.0-alpha.74 h1:wRm1EiDQtxDisXk46NtpiBH90STwfKp36NrTDwOEdxw=
github.com/wailsapp/wails/v3 v3.0.0-alpha.74/go.mod h1:4saK4A4K9970X+X7RkMwP2lyGbLogcUz54wVeq4C/V8=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU=
golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200810151505-1b9f1253b3ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE=
modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

167
internal/app/app.go Normal file
View File

@ -0,0 +1,167 @@
package app
import (
"database/sql"
"encoding/hex"
"fmt"
"log/slog"
"os"
"path/filepath"
"github.com/vstockwell/wraith/internal/connections"
"github.com/vstockwell/wraith/internal/db"
"github.com/vstockwell/wraith/internal/plugin"
"github.com/vstockwell/wraith/internal/session"
"github.com/vstockwell/wraith/internal/settings"
"github.com/vstockwell/wraith/internal/theme"
"github.com/vstockwell/wraith/internal/vault"
)
// WraithApp is the main application struct that wires together all services
// and exposes vault management methods to the frontend via Wails bindings.
type WraithApp struct {
db *sql.DB
Vault *vault.VaultService
Settings *settings.SettingsService
Connections *connections.ConnectionService
Themes *theme.ThemeService
Sessions *session.Manager
Plugins *plugin.Registry
unlocked bool
}
// New creates and initializes the WraithApp, opening the database, running
// migrations, creating all services, and seeding built-in themes.
func New() (*WraithApp, error) {
dataDir := dataDirectory()
dbPath := filepath.Join(dataDir, "wraith.db")
slog.Info("opening database", "path", dbPath)
database, err := db.Open(dbPath)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
if err := db.Migrate(database); err != nil {
return nil, fmt.Errorf("run migrations: %w", err)
}
settingsSvc := settings.NewSettingsService(database)
connSvc := connections.NewConnectionService(database)
themeSvc := theme.NewThemeService(database)
sessionMgr := session.NewManager()
pluginReg := plugin.NewRegistry()
// Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent)
if err := themeSvc.SeedBuiltins(); err != nil {
slog.Warn("failed to seed themes", "error", err)
}
return &WraithApp{
db: database,
Settings: settingsSvc,
Connections: connSvc,
Themes: themeSvc,
Sessions: sessionMgr,
Plugins: pluginReg,
}, nil
}
// dataDirectory returns the path where Wraith stores its data.
// On Windows with APPDATA set, it uses %APPDATA%\Wraith.
// On macOS/Linux with XDG_DATA_HOME or HOME, it uses the appropriate path.
// Falls back to the current working directory for development.
func dataDirectory() string {
// Windows
if appData := os.Getenv("APPDATA"); appData != "" {
return filepath.Join(appData, "Wraith")
}
// macOS / Linux: use XDG_DATA_HOME or fallback to ~/.local/share
if home, err := os.UserHomeDir(); err == nil {
if xdg := os.Getenv("XDG_DATA_HOME"); xdg != "" {
return filepath.Join(xdg, "wraith")
}
return filepath.Join(home, ".local", "share", "wraith")
}
// Dev fallback
return "."
}
// IsFirstRun checks whether the vault has been set up by looking for vault_salt in settings.
func (a *WraithApp) IsFirstRun() bool {
salt, _ := a.Settings.Get("vault_salt")
return salt == ""
}
// CreateVault sets up the vault with a master password. It generates a salt,
// derives an encryption key, and stores the salt and a check value in settings.
func (a *WraithApp) CreateVault(password string) error {
salt, err := vault.GenerateSalt()
if err != nil {
return fmt.Errorf("generate salt: %w", err)
}
key := vault.DeriveKey(password, salt)
a.Vault = vault.NewVaultService(key)
// Store salt as hex in settings
if err := a.Settings.Set("vault_salt", hex.EncodeToString(salt)); err != nil {
return fmt.Errorf("store salt: %w", err)
}
// Encrypt a known check value — used to verify the password on unlock
check, err := a.Vault.Encrypt("wraith-vault-check")
if err != nil {
return fmt.Errorf("encrypt check value: %w", err)
}
if err := a.Settings.Set("vault_check", check); err != nil {
return fmt.Errorf("store check value: %w", err)
}
a.unlocked = true
slog.Info("vault created successfully")
return nil
}
// Unlock verifies the master password against the stored check value and
// initializes the vault service for decryption.
func (a *WraithApp) Unlock(password string) error {
saltHex, err := a.Settings.Get("vault_salt")
if err != nil || saltHex == "" {
return fmt.Errorf("vault not set up — call CreateVault first")
}
salt, err := hex.DecodeString(saltHex)
if err != nil {
return fmt.Errorf("decode salt: %w", err)
}
key := vault.DeriveKey(password, salt)
vs := vault.NewVaultService(key)
// Verify by decrypting the stored check value
checkEncrypted, err := a.Settings.Get("vault_check")
if err != nil || checkEncrypted == "" {
return fmt.Errorf("vault check value missing")
}
checkPlain, err := vs.Decrypt(checkEncrypted)
if err != nil {
return fmt.Errorf("incorrect master password")
}
if checkPlain != "wraith-vault-check" {
return fmt.Errorf("incorrect master password")
}
a.Vault = vs
a.unlocked = true
slog.Info("vault unlocked successfully")
return nil
}
// IsUnlocked returns whether the vault is currently unlocked.
func (a *WraithApp) IsUnlocked() bool {
return a.unlocked
}

View File

@ -0,0 +1,40 @@
package connections
import "fmt"
func (s *ConnectionService) Search(query string) ([]Connection, error) {
like := "%" + query + "%"
rows, err := s.db.Query(
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
COALESCE(color,''), tags, COALESCE(notes,''), COALESCE(options,'{}'),
sort_order, last_connected, created_at, updated_at
FROM connections
WHERE name LIKE ? COLLATE NOCASE
OR hostname LIKE ? COLLATE NOCASE
OR tags LIKE ? COLLATE NOCASE
OR notes LIKE ? COLLATE NOCASE
ORDER BY last_connected DESC NULLS LAST, name`,
like, like, like, like,
)
if err != nil {
return nil, fmt.Errorf("search connections: %w", err)
}
defer rows.Close()
return scanConnections(rows)
}
func (s *ConnectionService) FilterByTag(tag string) ([]Connection, error) {
rows, err := s.db.Query(
`SELECT c.id, c.name, c.hostname, c.port, c.protocol, c.group_id, c.credential_id,
COALESCE(c.color,''), c.tags, COALESCE(c.notes,''), COALESCE(c.options,'{}'),
c.sort_order, c.last_connected, c.created_at, c.updated_at
FROM connections c, json_each(c.tags) AS t
WHERE t.value = ?
ORDER BY c.name`, tag,
)
if err != nil {
return nil, fmt.Errorf("filter by tag: %w", err)
}
defer rows.Close()
return scanConnections(rows)
}

View File

@ -0,0 +1,53 @@
package connections
import "testing"
func TestSearchByName(t *testing.T) {
svc := setupTestService(t)
svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"})
svc.CreateConnection(CreateConnectionInput{Name: "Docker", Hostname: "155.254.29.221", Port: 22, Protocol: "ssh"})
results, err := svc.Search("asg")
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if len(results) != 1 {
t.Fatalf("len(results) = %d, want 1", len(results))
}
if results[0].Name != "Asgard" {
t.Errorf("Name = %q, want %q", results[0].Name, "Asgard")
}
}
func TestSearchByHostname(t *testing.T) {
svc := setupTestService(t)
svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"})
results, _ := svc.Search("192.168")
if len(results) != 1 {
t.Errorf("len(results) = %d, want 1", len(results))
}
}
func TestSearchByTag(t *testing.T) {
svc := setupTestService(t)
svc.CreateConnection(CreateConnectionInput{Name: "ProdServer", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}})
svc.CreateConnection(CreateConnectionInput{Name: "DevServer", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev", "Linux"}})
results, _ := svc.Search("Prod")
if len(results) != 1 {
t.Errorf("len(results) = %d, want 1", len(results))
}
}
func TestFilterByTag(t *testing.T) {
svc := setupTestService(t)
svc.CreateConnection(CreateConnectionInput{Name: "A", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod"}})
svc.CreateConnection(CreateConnectionInput{Name: "B", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev"}})
svc.CreateConnection(CreateConnectionInput{Name: "C", Hostname: "10.0.0.3", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}})
results, _ := svc.FilterByTag("Prod")
if len(results) != 2 {
t.Errorf("len(results) = %d, want 2", len(results))
}
}

View File

@ -0,0 +1,361 @@
package connections
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parentId"`
SortOrder int `json:"sortOrder"`
Icon string `json:"icon"`
CreatedAt time.Time `json:"createdAt"`
Children []Group `json:"children,omitempty"`
}
type Connection struct {
ID int64 `json:"id"`
Name string `json:"name"`
Hostname string `json:"hostname"`
Port int `json:"port"`
Protocol string `json:"protocol"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color string `json:"color"`
Tags []string `json:"tags"`
Notes string `json:"notes"`
Options string `json:"options"`
SortOrder int `json:"sortOrder"`
LastConnected *time.Time `json:"lastConnected"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type CreateConnectionInput struct {
Name string `json:"name"`
Hostname string `json:"hostname"`
Port int `json:"port"`
Protocol string `json:"protocol"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color string `json:"color"`
Tags []string `json:"tags"`
Notes string `json:"notes"`
Options string `json:"options"`
}
type UpdateConnectionInput struct {
Name *string `json:"name"`
Hostname *string `json:"hostname"`
Port *int `json:"port"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color *string `json:"color"`
Tags []string `json:"tags"`
Notes *string `json:"notes"`
Options *string `json:"options"`
}
type ConnectionService struct {
db *sql.DB
}
func NewConnectionService(db *sql.DB) *ConnectionService {
return &ConnectionService{db: db}
}
// ---------- Group CRUD ----------
func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) {
result, err := s.db.Exec(
"INSERT INTO groups (name, parent_id) VALUES (?, ?)",
name, parentID,
)
if err != nil {
return nil, fmt.Errorf("create group: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get group id: %w", err)
}
var g Group
var icon sql.NullString
err = s.db.QueryRow(
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id,
).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt)
if err != nil {
return nil, fmt.Errorf("get created group: %w", err)
}
if icon.Valid {
g.Icon = icon.String
}
return &g, nil
}
func (s *ConnectionService) ListGroups() ([]Group, error) {
rows, err := s.db.Query(
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name",
)
if err != nil {
return nil, fmt.Errorf("list groups: %w", err)
}
defer rows.Close()
groupMap := make(map[int64]*Group)
var allGroups []*Group
for rows.Next() {
var g Group
var icon sql.NullString
if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil {
return nil, fmt.Errorf("scan group: %w", err)
}
if icon.Valid {
g.Icon = icon.String
}
groupMap[g.ID] = &g
allGroups = append(allGroups, &g)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate groups: %w", err)
}
// Build tree: attach children to parents, collect roots
var roots []Group
for _, g := range allGroups {
if g.ParentID != nil {
if parent, ok := groupMap[*g.ParentID]; ok {
parent.Children = append(parent.Children, *g)
}
} else {
roots = append(roots, *g)
}
}
// Re-attach children to root copies (since we copied into roots)
for i := range roots {
if orig, ok := groupMap[roots[i].ID]; ok {
roots[i].Children = orig.Children
}
}
if roots == nil {
roots = []Group{}
}
return roots, nil
}
func (s *ConnectionService) DeleteGroup(id int64) error {
_, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
}
// ---------- Connection CRUD ----------
func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) {
tags, err := json.Marshal(input.Tags)
if err != nil {
return nil, fmt.Errorf("marshal tags: %w", err)
}
if input.Tags == nil {
tags = []byte("[]")
}
options := input.Options
if options == "" {
options = "{}"
}
result, err := s.db.Exec(
`INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
input.Name, input.Hostname, input.Port, input.Protocol,
input.GroupID, input.CredentialID, input.Color,
string(tags), input.Notes, options,
)
if err != nil {
return nil, fmt.Errorf("create connection: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get connection id: %w", err)
}
return s.GetConnection(id)
}
func (s *ConnectionService) GetConnection(id int64) (*Connection, error) {
row := s.db.QueryRow(
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections WHERE id = ?`, id,
)
var c Connection
var tagsJSON string
var color, notes, options sql.NullString
var lastConnected sql.NullTime
err := row.Scan(
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
&c.GroupID, &c.CredentialID,
&color, &tagsJSON, &notes, &options,
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("get connection: %w", err)
}
if color.Valid {
c.Color = color.String
}
if notes.Valid {
c.Notes = notes.String
}
if options.Valid {
c.Options = options.String
}
if lastConnected.Valid {
c.LastConnected = &lastConnected.Time
}
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
c.Tags = []string{}
}
if c.Tags == nil {
c.Tags = []string{}
}
return &c, nil
}
func (s *ConnectionService) ListConnections() ([]Connection, error) {
rows, err := s.db.Query(
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections ORDER BY sort_order, name`,
)
if err != nil {
return nil, fmt.Errorf("list connections: %w", err)
}
defer rows.Close()
return scanConnections(rows)
}
// scanConnections is a shared helper used by ListConnections and (later) Search.
func scanConnections(rows *sql.Rows) ([]Connection, error) {
var conns []Connection
for rows.Next() {
var c Connection
var tagsJSON string
var color, notes, options sql.NullString
var lastConnected sql.NullTime
if err := rows.Scan(
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
&c.GroupID, &c.CredentialID,
&color, &tagsJSON, &notes, &options,
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan connection: %w", err)
}
if color.Valid {
c.Color = color.String
}
if notes.Valid {
c.Notes = notes.String
}
if options.Valid {
c.Options = options.String
}
if lastConnected.Valid {
c.LastConnected = &lastConnected.Time
}
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
c.Tags = []string{}
}
if c.Tags == nil {
c.Tags = []string{}
}
conns = append(conns, c)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate connections: %w", err)
}
if conns == nil {
conns = []Connection{}
}
return conns, nil
}
func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) {
setClauses := []string{"updated_at = CURRENT_TIMESTAMP"}
args := []interface{}{}
if input.Name != nil {
setClauses = append(setClauses, "name = ?")
args = append(args, *input.Name)
}
if input.Hostname != nil {
setClauses = append(setClauses, "hostname = ?")
args = append(args, *input.Hostname)
}
if input.Port != nil {
setClauses = append(setClauses, "port = ?")
args = append(args, *input.Port)
}
if input.GroupID != nil {
setClauses = append(setClauses, "group_id = ?")
args = append(args, *input.GroupID)
}
if input.CredentialID != nil {
setClauses = append(setClauses, "credential_id = ?")
args = append(args, *input.CredentialID)
}
if input.Tags != nil {
tags, _ := json.Marshal(input.Tags)
setClauses = append(setClauses, "tags = ?")
args = append(args, string(tags))
}
if input.Notes != nil {
setClauses = append(setClauses, "notes = ?")
args = append(args, *input.Notes)
}
if input.Color != nil {
setClauses = append(setClauses, "color = ?")
args = append(args, *input.Color)
}
if input.Options != nil {
setClauses = append(setClauses, "options = ?")
args = append(args, *input.Options)
}
args = append(args, id)
query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", "))
if _, err := s.db.Exec(query, args...); err != nil {
return nil, fmt.Errorf("update connection: %w", err)
}
return s.GetConnection(id)
}
func (s *ConnectionService) DeleteConnection(id int64) error {
_, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete connection: %w", err)
}
return nil
}

View File

@ -0,0 +1,234 @@
package connections
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func strPtr(s string) *string { return &s }
func setupTestService(t *testing.T) *ConnectionService {
t.Helper()
dir := t.TempDir()
database, err := db.Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("db.Open() error: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("db.Migrate() error: %v", err)
}
t.Cleanup(func() { database.Close() })
return NewConnectionService(database)
}
func TestCreateGroup(t *testing.T) {
svc := setupTestService(t)
g, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup() error: %v", err)
}
if g.ID == 0 {
t.Error("expected non-zero ID")
}
if g.Name != "Servers" {
t.Errorf("Name = %q, want %q", g.Name, "Servers")
}
if g.ParentID != nil {
t.Errorf("ParentID = %v, want nil", g.ParentID)
}
}
func TestCreateSubGroup(t *testing.T) {
svc := setupTestService(t)
parent, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup(parent) error: %v", err)
}
child, err := svc.CreateGroup("Production", &parent.ID)
if err != nil {
t.Fatalf("CreateGroup(child) error: %v", err)
}
if child.ParentID == nil {
t.Fatal("expected non-nil ParentID")
}
if *child.ParentID != parent.ID {
t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID)
}
}
func TestListGroups(t *testing.T) {
svc := setupTestService(t)
parent, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup(parent) error: %v", err)
}
if _, err := svc.CreateGroup("Production", &parent.ID); err != nil {
t.Fatalf("CreateGroup(child) error: %v", err)
}
groups, err := svc.ListGroups()
if err != nil {
t.Fatalf("ListGroups() error: %v", err)
}
if len(groups) != 1 {
t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups))
}
if groups[0].Name != "Servers" {
t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers")
}
if len(groups[0].Children) != 1 {
t.Fatalf("len(children) = %d, want 1", len(groups[0].Children))
}
if groups[0].Children[0].Name != "Production" {
t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production")
}
}
func TestDeleteGroup(t *testing.T) {
svc := setupTestService(t)
g, err := svc.CreateGroup("ToDelete", nil)
if err != nil {
t.Fatalf("CreateGroup() error: %v", err)
}
if err := svc.DeleteGroup(g.ID); err != nil {
t.Fatalf("DeleteGroup() error: %v", err)
}
groups, err := svc.ListGroups()
if err != nil {
t.Fatalf("ListGroups() error: %v", err)
}
if len(groups) != 0 {
t.Errorf("len(groups) = %d, want 0 after delete", len(groups))
}
}
func TestCreateConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "Web Server",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
Tags: []string{"Prod", "Linux"},
Options: `{"keepAliveInterval": 60}`,
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
if conn.ID == 0 {
t.Error("expected non-zero ID")
}
if conn.Name != "Web Server" {
t.Errorf("Name = %q, want %q", conn.Name, "Web Server")
}
if len(conn.Tags) != 2 {
t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags))
}
if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" {
t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags)
}
if conn.Options != `{"keepAliveInterval": 60}` {
t.Errorf("Options = %q, want JSON blob", conn.Options)
}
}
func TestListConnections(t *testing.T) {
svc := setupTestService(t)
if _, err := svc.CreateConnection(CreateConnectionInput{
Name: "Server A",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
}); err != nil {
t.Fatalf("CreateConnection(A) error: %v", err)
}
if _, err := svc.CreateConnection(CreateConnectionInput{
Name: "Server B",
Hostname: "10.0.0.2",
Port: 3389,
Protocol: "rdp",
}); err != nil {
t.Fatalf("CreateConnection(B) error: %v", err)
}
conns, err := svc.ListConnections()
if err != nil {
t.Fatalf("ListConnections() error: %v", err)
}
if len(conns) != 2 {
t.Fatalf("len(conns) = %d, want 2", len(conns))
}
}
func TestUpdateConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "Old Name",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
Tags: []string{"Dev"},
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{
Name: strPtr("New Name"),
Tags: []string{"Prod", "Linux"},
})
if err != nil {
t.Fatalf("UpdateConnection() error: %v", err)
}
if updated.Name != "New Name" {
t.Errorf("Name = %q, want %q", updated.Name, "New Name")
}
if len(updated.Tags) != 2 {
t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags))
}
if updated.Tags[0] != "Prod" {
t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod")
}
// Hostname should remain unchanged
if updated.Hostname != "10.0.0.1" {
t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1")
}
}
func TestDeleteConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "ToDelete",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
if err := svc.DeleteConnection(conn.ID); err != nil {
t.Fatalf("DeleteConnection() error: %v", err)
}
conns, err := svc.ListConnections()
if err != nil {
t.Fatalf("ListConnections() error: %v", err)
}
if len(conns) != 0 {
t.Errorf("len(conns) = %d, want 0 after delete", len(conns))
}
}

View File

@ -0,0 +1,408 @@
package credentials
import (
"crypto/ed25519"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"fmt"
"github.com/vstockwell/wraith/internal/vault"
"golang.org/x/crypto/ssh"
)
// Credential represents a stored credential (password or SSH key reference).
type Credential struct {
ID int64 `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Domain string `json:"domain"`
Type string `json:"type"` // "password" or "ssh_key"
SSHKeyID *int64 `json:"sshKeyId"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
// SSHKey represents a stored SSH key.
type SSHKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
KeyType string `json:"keyType"`
Fingerprint string `json:"fingerprint"`
PublicKey string `json:"publicKey"`
CreatedAt string `json:"createdAt"`
}
// CredentialService provides CRUD for credentials with vault encryption.
type CredentialService struct {
db *sql.DB
vault *vault.VaultService
}
// NewCredentialService creates a new CredentialService.
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
return &CredentialService{db: db, vault: vault}
}
// CreatePassword creates a password credential (password encrypted via vault).
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
encrypted, err := s.vault.Encrypt(password)
if err != nil {
return nil, fmt.Errorf("encrypt password: %w", err)
}
result, err := s.db.Exec(
`INSERT INTO credentials (name, username, domain, type, encrypted_value)
VALUES (?, ?, ?, 'password', ?)`,
name, username, domain, encrypted,
)
if err != nil {
return nil, fmt.Errorf("insert credential: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get credential id: %w", err)
}
return s.getCredential(id)
}
// CreateSSHKey imports an SSH key (private key encrypted via vault).
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
// Parse the private key to detect type and extract public key
keyType := DetectKeyType(privateKeyPEM)
// Parse the key to get the public key for fingerprinting.
// Try without passphrase first (handles unencrypted keys even when a
// passphrase is provided for storage), then fall back to using the
// passphrase for encrypted PEM keys.
var signer ssh.Signer
var err error
signer, err = ssh.ParsePrivateKey(privateKeyPEM)
if err != nil && passphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
}
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
pubKey := signer.PublicKey()
fingerprint := ssh.FingerprintSHA256(pubKey)
publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
// Encrypt private key via vault
encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
if err != nil {
return nil, fmt.Errorf("encrypt private key: %w", err)
}
// Encrypt passphrase via vault (if provided)
var encryptedPassphrase sql.NullString
if passphrase != "" {
ep, err := s.vault.Encrypt(passphrase)
if err != nil {
return nil, fmt.Errorf("encrypt passphrase: %w", err)
}
encryptedPassphrase = sql.NullString{String: ep, Valid: true}
}
result, err := s.db.Exec(
`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
VALUES (?, ?, ?, ?, ?, ?)`,
name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
)
if err != nil {
return nil, fmt.Errorf("insert ssh key: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get ssh key id: %w", err)
}
return s.getSSHKey(id)
}
// ListCredentials returns all credentials WITHOUT encrypted values.
func (s *CredentialService) ListCredentials() ([]Credential, error) {
rows, err := s.db.Query(
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials ORDER BY name`,
)
if err != nil {
return nil, fmt.Errorf("list credentials: %w", err)
}
defer rows.Close()
var creds []Credential
for rows.Next() {
var c Credential
var username, domain sql.NullString
if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan credential: %w", err)
}
if username.Valid {
c.Username = username.String
}
if domain.Valid {
c.Domain = domain.String
}
creds = append(creds, c)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate credentials: %w", err)
}
if creds == nil {
creds = []Credential{}
}
return creds, nil
}
// ListSSHKeys returns all SSH keys WITHOUT private key data.
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
rows, err := s.db.Query(
`SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys ORDER BY name`,
)
if err != nil {
return nil, fmt.Errorf("list ssh keys: %w", err)
}
defer rows.Close()
var keys []SSHKey
for rows.Next() {
var k SSHKey
var keyType, fingerprint, publicKey sql.NullString
if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil {
return nil, fmt.Errorf("scan ssh key: %w", err)
}
if keyType.Valid {
k.KeyType = keyType.String
}
if fingerprint.Valid {
k.Fingerprint = fingerprint.String
}
if publicKey.Valid {
k.PublicKey = publicKey.String
}
keys = append(keys, k)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate ssh keys: %w", err)
}
if keys == nil {
keys = []SSHKey{}
}
return keys, nil
}
// DecryptPassword returns the decrypted password for a credential.
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
var encrypted sql.NullString
err := s.db.QueryRow(
"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
credentialID,
).Scan(&encrypted)
if err != nil {
return "", fmt.Errorf("get encrypted password: %w", err)
}
if !encrypted.Valid {
return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
}
password, err := s.vault.Decrypt(encrypted.String)
if err != nil {
return "", fmt.Errorf("decrypt password: %w", err)
}
return password, nil
}
// DecryptSSHKey returns the decrypted private key + passphrase.
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
var encryptedKey string
var encryptedPassphrase sql.NullString
err = s.db.QueryRow(
"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
sshKeyID,
).Scan(&encryptedKey, &encryptedPassphrase)
if err != nil {
return nil, "", fmt.Errorf("get encrypted ssh key: %w", err)
}
decryptedKey, err := s.vault.Decrypt(encryptedKey)
if err != nil {
return nil, "", fmt.Errorf("decrypt private key: %w", err)
}
if encryptedPassphrase.Valid {
passphrase, err = s.vault.Decrypt(encryptedPassphrase.String)
if err != nil {
return nil, "", fmt.Errorf("decrypt passphrase: %w", err)
}
}
return []byte(decryptedKey), passphrase, nil
}
// DeleteCredential removes a credential.
func (s *CredentialService) DeleteCredential(id int64) error {
_, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete credential: %w", err)
}
return nil
}
// DeleteSSHKey removes an SSH key.
func (s *CredentialService) DeleteSSHKey(id int64) error {
_, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete ssh key: %w", err)
}
return nil
}
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
func DetectKeyType(pemData []byte) string {
block, _ := pem.Decode(pemData)
if block == nil {
return "unknown"
}
// Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks)
if block.Type == "OPENSSH PRIVATE KEY" {
return detectOpenSSHKeyType(block.Bytes)
}
// Try PKCS8 format
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
switch key.(type) {
case *rsa.PrivateKey:
return "rsa"
case ed25519.PrivateKey:
return "ed25519"
case *ecdsa.PrivateKey:
return "ecdsa"
}
}
// Try RSA PKCS1
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return "rsa"
}
// Try EC
if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return "ecdsa"
}
return "unknown"
}
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
func detectOpenSSHKeyType(data []byte) string {
// OpenSSH private key format: "openssh-key-v1\0" magic, then fields.
// We parse the key using ssh package to determine the type.
// Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer.
pemBlock := &pem.Block{
Type: "OPENSSH PRIVATE KEY",
Bytes: data,
}
pemBytes := pem.EncodeToMemory(pemBlock)
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return "unknown"
}
return classifyPublicKey(signer.PublicKey())
}
// classifyPublicKey determines the key type from an ssh.PublicKey.
func classifyPublicKey(pub ssh.PublicKey) string {
keyType := pub.Type()
switch keyType {
case "ssh-rsa":
return "rsa"
case "ssh-ed25519":
return "ed25519"
case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
return "ecdsa"
default:
return keyType
}
}
// getCredential retrieves a single credential by ID.
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
var c Credential
var username, domain sql.NullString
err := s.db.QueryRow(
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials WHERE id = ?`, id,
).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("get credential: %w", err)
}
if username.Valid {
c.Username = username.String
}
if domain.Valid {
c.Domain = domain.String
}
return &c, nil
}
// getSSHKey retrieves a single SSH key by ID (without private key data).
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
var k SSHKey
var keyType, fingerprint, publicKey sql.NullString
err := s.db.QueryRow(
`SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys WHERE id = ?`, id,
).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt)
if err != nil {
return nil, fmt.Errorf("get ssh key: %w", err)
}
if keyType.Valid {
k.KeyType = keyType.String
}
if fingerprint.Valid {
k.Fingerprint = fingerprint.String
}
if publicKey.Valid {
k.PublicKey = publicKey.String
}
return &k, nil
}
// generateFingerprint generates an SSH fingerprint string from a public key.
func generateFingerprint(pubKey ssh.PublicKey) string {
return ssh.FingerprintSHA256(pubKey)
}
// marshalPublicKey returns the authorized_keys format of an SSH public key.
func marshalPublicKey(pubKey ssh.PublicKey) string {
return base64.StdEncoding.EncodeToString(pubKey.Marshal())
}
// ecdsaCurveName returns the name for an ECDSA curve.
func ecdsaCurveName(curve elliptic.Curve) string {
switch curve {
case elliptic.P256():
return "nistp256"
case elliptic.P384():
return "nistp384"
case elliptic.P521():
return "nistp521"
default:
return "unknown"
}
}

View File

@ -0,0 +1,176 @@
package credentials
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
"github.com/vstockwell/wraith/internal/vault"
"golang.org/x/crypto/ssh"
)
func setupCredentialService(t *testing.T) *CredentialService {
t.Helper()
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
if err := db.Migrate(d); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
salt := []byte("test-salt-exactly-32-bytes-long!")
key := vault.DeriveKey("testpassword", salt)
vs := vault.NewVaultService(key)
return NewCredentialService(d, vs)
}
func TestCreatePasswordCredential(t *testing.T) {
svc := setupCredentialService(t)
cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "")
if err != nil {
t.Fatal(err)
}
if cred.Name != "Test Cred" {
t.Error("wrong name")
}
if cred.Type != "password" {
t.Error("wrong type")
}
}
func TestDecryptPassword(t *testing.T) {
svc := setupCredentialService(t)
cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "")
password, err := svc.DecryptPassword(cred.ID)
if err != nil {
t.Fatal(err)
}
if password != "mypassword" {
t.Errorf("got %q, want mypassword", password)
}
}
func TestListCredentialsExcludesSecrets(t *testing.T) {
svc := setupCredentialService(t)
svc.CreatePassword("Cred1", "user1", "pass1", "")
svc.CreatePassword("Cred2", "user2", "pass2", "")
creds, err := svc.ListCredentials()
if err != nil {
t.Fatal(err)
}
if len(creds) != 2 {
t.Errorf("got %d, want 2", len(creds))
}
}
func TestCreateSSHKey(t *testing.T) {
svc := setupCredentialService(t)
// Generate a test key
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, err := svc.CreateSSHKey("My Key", keyPEM, "")
if err != nil {
t.Fatal(err)
}
if key.KeyType != "ed25519" {
t.Errorf("KeyType = %q, want ed25519", key.KeyType)
}
if key.Fingerprint == "" {
t.Error("fingerprint should not be empty")
}
}
func TestDecryptSSHKey(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass")
decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID)
if err != nil {
t.Fatal(err)
}
if len(decryptedKey) == 0 {
t.Error("decrypted key should not be empty")
}
if passphrase != "testpass" {
t.Errorf("passphrase = %q, want testpass", passphrase)
}
}
func TestDetectKeyType(t *testing.T) {
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
if got := DetectKeyType(keyPEM); got != "ed25519" {
t.Errorf("got %q", got)
}
}
func TestDeleteCredential(t *testing.T) {
svc := setupCredentialService(t)
cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "")
err := svc.DeleteCredential(cred.ID)
if err != nil {
t.Fatal(err)
}
creds, _ := svc.ListCredentials()
if len(creds) != 0 {
t.Errorf("got %d credentials, want 0", len(creds))
}
}
func TestDeleteSSHKey(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "")
err := svc.DeleteSSHKey(key.ID)
if err != nil {
t.Fatal(err)
}
keys, _ := svc.ListSSHKeys()
if len(keys) != 0 {
t.Errorf("got %d keys, want 0", len(keys))
}
}
func TestListSSHKeys(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
svc.CreateSSHKey("Key1", keyPEM, "")
svc.CreateSSHKey("Key2", keyPEM, "")
keys, err := svc.ListSSHKeys()
if err != nil {
t.Fatal(err)
}
if len(keys) != 2 {
t.Errorf("got %d keys, want 2", len(keys))
}
}
func TestCreatePasswordWithDomain(t *testing.T) {
svc := setupCredentialService(t)
cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com")
if err != nil {
t.Fatal(err)
}
if cred.Domain != "example.com" {
t.Errorf("domain = %q, want example.com", cred.Domain)
}
}

34
internal/db/migrations.go Normal file
View File

@ -0,0 +1,34 @@
package db
import (
"database/sql"
"embed"
"fmt"
"sort"
)
//go:embed migrations/*.sql
var migrationFiles embed.FS
func Migrate(db *sql.DB) error {
entries, err := migrationFiles.ReadDir("migrations")
if err != nil {
return fmt.Errorf("read migrations: %w", err)
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Name() < entries[j].Name()
})
for _, entry := range entries {
content, err := migrationFiles.ReadFile("migrations/" + entry.Name())
if err != nil {
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
}
if _, err := db.Exec(string(content)); err != nil {
return fmt.Errorf("execute migration %s: %w", entry.Name(), err)
}
}
return nil
}

View File

@ -0,0 +1,101 @@
-- 001_initial.sql
CREATE TABLE IF NOT EXISTS groups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
parent_id INTEGER REFERENCES groups(id) ON DELETE SET NULL,
sort_order INTEGER DEFAULT 0,
icon TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS ssh_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
key_type TEXT,
fingerprint TEXT,
public_key TEXT,
encrypted_private_key TEXT NOT NULL,
passphrase_encrypted TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS credentials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
username TEXT,
domain TEXT,
type TEXT NOT NULL CHECK(type IN ('password','ssh_key')),
encrypted_value TEXT,
ssh_key_id INTEGER REFERENCES ssh_keys(id) ON DELETE SET NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS connections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
hostname TEXT NOT NULL,
port INTEGER NOT NULL DEFAULT 22,
protocol TEXT NOT NULL CHECK(protocol IN ('ssh','rdp')),
group_id INTEGER REFERENCES groups(id) ON DELETE SET NULL,
credential_id INTEGER REFERENCES credentials(id) ON DELETE SET NULL,
color TEXT,
tags TEXT DEFAULT '[]',
notes TEXT,
options TEXT DEFAULT '{}',
sort_order INTEGER DEFAULT 0,
last_connected DATETIME,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS themes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
foreground TEXT NOT NULL,
background TEXT NOT NULL,
cursor TEXT NOT NULL,
black TEXT NOT NULL,
red TEXT NOT NULL,
green TEXT NOT NULL,
yellow TEXT NOT NULL,
blue TEXT NOT NULL,
magenta TEXT NOT NULL,
cyan TEXT NOT NULL,
white TEXT NOT NULL,
bright_black TEXT NOT NULL,
bright_red TEXT NOT NULL,
bright_green TEXT NOT NULL,
bright_yellow TEXT NOT NULL,
bright_blue TEXT NOT NULL,
bright_magenta TEXT NOT NULL,
bright_cyan TEXT NOT NULL,
bright_white TEXT NOT NULL,
selection_bg TEXT,
selection_fg TEXT,
is_builtin BOOLEAN DEFAULT 0
);
CREATE TABLE IF NOT EXISTS connection_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
connection_id INTEGER NOT NULL REFERENCES connections(id) ON DELETE CASCADE,
protocol TEXT NOT NULL,
connected_at DATETIME DEFAULT CURRENT_TIMESTAMP,
disconnected_at DATETIME,
duration_secs INTEGER
);
CREATE TABLE IF NOT EXISTS host_keys (
hostname TEXT NOT NULL,
port INTEGER NOT NULL,
key_type TEXT NOT NULL,
fingerprint TEXT NOT NULL,
raw_key TEXT,
first_seen DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (hostname, port, key_type)
);
CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);

39
internal/db/sqlite.go Normal file
View File

@ -0,0 +1,39 @@
package db
import (
"database/sql"
"fmt"
"os"
"path/filepath"
_ "modernc.org/sqlite"
)
func Open(dbPath string) (*sql.DB, error) {
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("create db directory: %w", err)
}
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
db.Close()
return nil, fmt.Errorf("set WAL mode: %w", err)
}
if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
db.Close()
return nil, fmt.Errorf("set busy_timeout: %w", err)
}
if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
db.Close()
return nil, fmt.Errorf("enable foreign keys: %w", err)
}
return db, nil
}

View File

@ -0,0 +1,85 @@
package db
import (
"os"
"path/filepath"
"testing"
)
func TestOpenCreatesDatabase(t *testing.T) {
dir := t.TempDir()
dbPath := filepath.Join(dir, "test.db")
db, err := Open(dbPath)
if err != nil {
t.Fatalf("Open() error: %v", err)
}
defer db.Close()
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
t.Fatal("database file was not created")
}
}
func TestOpenSetsWALMode(t *testing.T) {
dir := t.TempDir()
db, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("Open() error: %v", err)
}
defer db.Close()
var mode string
err = db.QueryRow("PRAGMA journal_mode").Scan(&mode)
if err != nil {
t.Fatalf("PRAGMA query error: %v", err)
}
if mode != "wal" {
t.Errorf("journal_mode = %q, want %q", mode, "wal")
}
}
func TestOpenSetsBusyTimeout(t *testing.T) {
dir := t.TempDir()
db, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("Open() error: %v", err)
}
defer db.Close()
var timeout int
err = db.QueryRow("PRAGMA busy_timeout").Scan(&timeout)
if err != nil {
t.Fatalf("PRAGMA query error: %v", err)
}
if timeout != 5000 {
t.Errorf("busy_timeout = %d, want %d", timeout, 5000)
}
}
func TestMigrateCreatesAllTables(t *testing.T) {
dir := t.TempDir()
db, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("Open() error: %v", err)
}
defer db.Close()
if err := Migrate(db); err != nil {
t.Fatalf("Migrate() error: %v", err)
}
expectedTables := []string{
"groups", "connections", "credentials", "ssh_keys",
"themes", "connection_history", "host_keys", "settings",
}
for _, table := range expectedTables {
var name string
err := db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", table,
).Scan(&name)
if err != nil {
t.Errorf("table %q not found: %v", table, err)
}
}
}

View File

@ -0,0 +1,57 @@
package plugin
type ProtocolHandler interface {
Name() string
Connect(config map[string]interface{}) (Session, error)
Disconnect(sessionID string) error
}
type Session interface {
ID() string
Protocol() string
Write(data []byte) error
Close() error
}
type Importer interface {
Name() string
FileExtensions() []string
Parse(data []byte) (*ImportResult, error)
}
type ImportResult struct {
Groups []ImportGroup `json:"groups"`
Connections []ImportConnection `json:"connections"`
HostKeys []ImportHostKey `json:"hostKeys"`
Theme *ImportTheme `json:"theme,omitempty"`
}
type ImportGroup struct {
Name string `json:"name"`
ParentName string `json:"parentName,omitempty"`
}
type ImportConnection struct {
Name string `json:"name"`
Hostname string `json:"hostname"`
Port int `json:"port"`
Protocol string `json:"protocol"`
Username string `json:"username"`
GroupName string `json:"groupName"`
Notes string `json:"notes"`
}
type ImportHostKey struct {
Hostname string `json:"hostname"`
Port int `json:"port"`
KeyType string `json:"keyType"`
Fingerprint string `json:"fingerprint"`
}
type ImportTheme struct {
Name string `json:"name"`
Foreground string `json:"foreground"`
Background string `json:"background"`
Cursor string `json:"cursor"`
Colors [16]string `json:"colors"`
}

View File

@ -0,0 +1,47 @@
package plugin
import "fmt"
type Registry struct {
protocols map[string]ProtocolHandler
importers map[string]Importer
}
func NewRegistry() *Registry {
return &Registry{
protocols: make(map[string]ProtocolHandler),
importers: make(map[string]Importer),
}
}
func (r *Registry) RegisterProtocol(handler ProtocolHandler) {
r.protocols[handler.Name()] = handler
}
func (r *Registry) RegisterImporter(imp Importer) {
r.importers[imp.Name()] = imp
}
func (r *Registry) GetProtocol(name string) (ProtocolHandler, error) {
h, ok := r.protocols[name]
if !ok {
return nil, fmt.Errorf("protocol handler %q not registered", name)
}
return h, nil
}
func (r *Registry) GetImporter(name string) (Importer, error) {
imp, ok := r.importers[name]
if !ok {
return nil, fmt.Errorf("importer %q not registered", name)
}
return imp, nil
}
func (r *Registry) ListProtocols() []string {
names := make([]string, 0, len(r.protocols))
for name := range r.protocols {
names = append(names, name)
}
return names
}

View File

@ -0,0 +1,96 @@
package session
import (
"fmt"
"sync"
"github.com/google/uuid"
)
const MaxSessions = 32
type Manager struct {
mu sync.RWMutex
sessions map[string]*SessionInfo
}
func NewManager() *Manager {
return &Manager{
sessions: make(map[string]*SessionInfo),
}
}
func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.sessions) >= MaxSessions {
return nil, fmt.Errorf("maximum sessions (%d) reached", MaxSessions)
}
s := &SessionInfo{
ID: uuid.NewString(),
ConnectionID: connectionID,
Protocol: protocol,
State: StateConnecting,
TabPosition: len(m.sessions),
}
m.sessions[s.ID] = s
return s, nil
}
func (m *Manager) Get(id string) (*SessionInfo, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
s, ok := m.sessions[id]
return s, ok
}
func (m *Manager) List() []*SessionInfo {
m.mu.RLock()
defer m.mu.RUnlock()
list := make([]*SessionInfo, 0, len(m.sessions))
for _, s := range m.sessions {
list = append(list, s)
}
return list
}
func (m *Manager) SetState(id string, state SessionState) error {
m.mu.Lock()
defer m.mu.Unlock()
s, ok := m.sessions[id]
if !ok {
return fmt.Errorf("session %s not found", id)
}
s.State = state
return nil
}
func (m *Manager) Detach(id string) error {
return m.SetState(id, StateDetached)
}
func (m *Manager) Reattach(id, windowID string) error {
m.mu.Lock()
defer m.mu.Unlock()
s, ok := m.sessions[id]
if !ok {
return fmt.Errorf("session %s not found", id)
}
s.State = StateConnected
s.WindowID = windowID
return nil
}
func (m *Manager) Remove(id string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.sessions, id)
}
func (m *Manager) Count() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.sessions)
}

View File

@ -0,0 +1,64 @@
package session
import "testing"
func TestCreateSession(t *testing.T) {
m := NewManager()
s, err := m.Create(1, "ssh")
if err != nil {
t.Fatalf("Create() error: %v", err)
}
if s.ID == "" {
t.Error("session ID should not be empty")
}
if s.State != StateConnecting {
t.Errorf("State = %q, want %q", s.State, StateConnecting)
}
}
func TestMaxSessions(t *testing.T) {
m := NewManager()
for i := 0; i < MaxSessions; i++ {
_, err := m.Create(int64(i), "ssh")
if err != nil {
t.Fatalf("Create() error at %d: %v", i, err)
}
}
_, err := m.Create(999, "ssh")
if err == nil {
t.Error("Create() should fail at max sessions")
}
}
func TestDetachReattach(t *testing.T) {
m := NewManager()
s, _ := m.Create(1, "ssh")
m.SetState(s.ID, StateConnected)
if err := m.Detach(s.ID); err != nil {
t.Fatalf("Detach() error: %v", err)
}
got, _ := m.Get(s.ID)
if got.State != StateDetached {
t.Errorf("State = %q, want %q", got.State, StateDetached)
}
if err := m.Reattach(s.ID, "window-1"); err != nil {
t.Fatalf("Reattach() error: %v", err)
}
got, _ = m.Get(s.ID)
if got.State != StateConnected {
t.Errorf("State = %q, want %q", got.State, StateConnected)
}
}
func TestRemoveSession(t *testing.T) {
m := NewManager()
s, _ := m.Create(1, "ssh")
m.Remove(s.ID)
if m.Count() != 0 {
t.Error("session should have been removed")
}
}

View File

@ -0,0 +1,22 @@
package session
import "time"
type SessionState string
const (
StateConnecting SessionState = "connecting"
StateConnected SessionState = "connected"
StateDisconnected SessionState = "disconnected"
StateDetached SessionState = "detached"
)
type SessionInfo struct {
ID string `json:"id"`
ConnectionID int64 `json:"connectionId"`
Protocol string `json:"protocol"`
State SessionState `json:"state"`
WindowID string `json:"windowId"`
TabPosition int `json:"tabPosition"`
ConnectedAt time.Time `json:"connectedAt"`
}

View File

@ -0,0 +1,41 @@
package settings
import "database/sql"
type SettingsService struct {
db *sql.DB
}
func NewSettingsService(db *sql.DB) *SettingsService {
return &SettingsService{db: db}
}
func (s *SettingsService) Get(key string) (string, error) {
var value string
err := s.db.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
if err == sql.ErrNoRows {
return "", nil
}
return value, err
}
func (s *SettingsService) GetDefault(key, defaultValue string) string {
val, err := s.Get(key)
if err != nil || val == "" {
return defaultValue
}
return val
}
func (s *SettingsService) Set(key, value string) error {
_, err := s.db.Exec(
"INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?",
key, value, value,
)
return err
}
func (s *SettingsService) Delete(key string) error {
_, err := s.db.Exec("DELETE FROM settings WHERE key = ?", key)
return err
}

View File

@ -0,0 +1,70 @@
package settings
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func setupTestDB(t *testing.T) *SettingsService {
t.Helper()
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
if err := db.Migrate(d); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
return NewSettingsService(d)
}
func TestSetAndGet(t *testing.T) {
s := setupTestDB(t)
if err := s.Set("theme", "dracula"); err != nil {
t.Fatalf("Set() error: %v", err)
}
val, err := s.Get("theme")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
if val != "dracula" {
t.Errorf("Get() = %q, want %q", val, "dracula")
}
}
func TestGetMissing(t *testing.T) {
s := setupTestDB(t)
val, err := s.Get("nonexistent")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
if val != "" {
t.Errorf("Get() = %q, want empty string", val)
}
}
func TestSetOverwrites(t *testing.T) {
s := setupTestDB(t)
s.Set("key", "value1")
s.Set("key", "value2")
val, _ := s.Get("key")
if val != "value2" {
t.Errorf("Get() = %q, want %q", val, "value2")
}
}
func TestGetWithDefault(t *testing.T) {
s := setupTestDB(t)
val := s.GetDefault("missing", "fallback")
if val != "fallback" {
t.Errorf("GetDefault() = %q, want %q", val, "fallback")
}
}

238
internal/sftp/service.go Normal file
View File

@ -0,0 +1,238 @@
package sftp
import (
"fmt"
"io"
"os"
"sort"
"strings"
"sync"
"github.com/pkg/sftp"
)
const MaxEditFileSize = 5 * 1024 * 1024 // 5MB
// FileEntry represents a file or directory in a remote filesystem.
type FileEntry struct {
Name string `json:"name"`
Path string `json:"path"`
Size int64 `json:"size"`
IsDir bool `json:"isDir"`
Permissions string `json:"permissions"`
ModTime string `json:"modTime"`
}
// SFTPService manages SFTP clients keyed by session ID.
type SFTPService struct {
clients map[string]*sftp.Client
mu sync.RWMutex
}
// NewSFTPService creates a new SFTPService.
func NewSFTPService() *SFTPService {
return &SFTPService{
clients: make(map[string]*sftp.Client),
}
}
// RegisterClient stores an SFTP client for a session.
func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients[sessionID] = client
}
// RemoveClient removes and closes an SFTP client.
func (s *SFTPService) RemoveClient(sessionID string) {
s.mu.Lock()
client, ok := s.clients[sessionID]
if ok {
delete(s.clients, sessionID)
}
s.mu.Unlock()
if ok && client != nil {
client.Close()
}
}
// getClient returns the SFTP client for a session or an error if not found.
func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client, ok := s.clients[sessionID]
if !ok {
return nil, fmt.Errorf("no SFTP client for session %s", sessionID)
}
return client, nil
}
// SortEntries sorts file entries with directories first, then alphabetically by name.
func SortEntries(entries []FileEntry) {
sort.Slice(entries, func(i, j int) bool {
if entries[i].IsDir != entries[j].IsDir {
return entries[i].IsDir
}
return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name)
})
}
// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry.
func fileInfoToEntry(info os.FileInfo, path string) FileEntry {
return FileEntry{
Name: info.Name(),
Path: path,
Size: info.Size(),
IsDir: info.IsDir(),
Permissions: info.Mode().Perm().String(),
ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"),
}
}
// List returns directory contents sorted (dirs first, then files alphabetically).
func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) {
client, err := s.getClient(sessionID)
if err != nil {
return nil, err
}
infos, err := client.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("read directory %s: %w", path, err)
}
entries := make([]FileEntry, 0, len(infos))
for _, info := range infos {
entryPath := path
if !strings.HasSuffix(entryPath, "/") {
entryPath += "/"
}
entryPath += info.Name()
entries = append(entries, fileInfoToEntry(info, entryPath))
}
SortEntries(entries)
return entries, nil
}
// ReadFile reads a file (max 5MB). Returns content as string.
func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) {
client, err := s.getClient(sessionID)
if err != nil {
return "", err
}
info, err := client.Stat(path)
if err != nil {
return "", fmt.Errorf("stat %s: %w", path, err)
}
if info.IsDir() {
return "", fmt.Errorf("%s is a directory", path)
}
if info.Size() > MaxEditFileSize {
return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize)
}
f, err := client.Open(path)
if err != nil {
return "", fmt.Errorf("open %s: %w", path, err)
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return "", fmt.Errorf("read %s: %w", path, err)
}
return string(data), nil
}
// WriteFile writes content to a file.
func (s *SFTPService) WriteFile(sessionID string, path string, content string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
f, err := client.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", path, err)
}
defer f.Close()
if _, err := f.Write([]byte(content)); err != nil {
return fmt.Errorf("write %s: %w", path, err)
}
return nil
}
// Mkdir creates a directory.
func (s *SFTPService) Mkdir(sessionID string, path string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
if err := client.Mkdir(path); err != nil {
return fmt.Errorf("mkdir %s: %w", path, err)
}
return nil
}
// Delete removes a file or empty directory.
func (s *SFTPService) Delete(sessionID string, path string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
info, err := client.Stat(path)
if err != nil {
return fmt.Errorf("stat %s: %w", path, err)
}
if info.IsDir() {
if err := client.RemoveDirectory(path); err != nil {
return fmt.Errorf("remove directory %s: %w", path, err)
}
} else {
if err := client.Remove(path); err != nil {
return fmt.Errorf("remove %s: %w", path, err)
}
}
return nil
}
// Rename renames/moves a file.
func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
if err := client.Rename(oldPath, newPath); err != nil {
return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err)
}
return nil
}
// Stat returns info about a file/directory.
func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) {
client, err := s.getClient(sessionID)
if err != nil {
return nil, err
}
info, err := client.Stat(path)
if err != nil {
return nil, fmt.Errorf("stat %s: %w", path, err)
}
entry := fileInfoToEntry(info, path)
return &entry, nil
}

View File

@ -0,0 +1,119 @@
package sftp
import (
"testing"
)
func TestNewSFTPService(t *testing.T) {
svc := NewSFTPService()
if svc == nil {
t.Fatal("nil")
}
}
func TestListWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.List("nonexistent", "/")
if err == nil {
t.Error("should error without client")
}
}
func TestReadFileWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.ReadFile("nonexistent", "/etc/hosts")
if err == nil {
t.Error("should error without client")
}
}
func TestWriteFileWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.WriteFile("nonexistent", "/tmp/test", "data")
if err == nil {
t.Error("should error without client")
}
}
func TestMkdirWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Mkdir("nonexistent", "/tmp/newdir")
if err == nil {
t.Error("should error without client")
}
}
func TestDeleteWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Delete("nonexistent", "/tmp/file")
if err == nil {
t.Error("should error without client")
}
}
func TestRenameWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Rename("nonexistent", "/old", "/new")
if err == nil {
t.Error("should error without client")
}
}
func TestStatWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.Stat("nonexistent", "/tmp")
if err == nil {
t.Error("should error without client")
}
}
func TestFileEntrySorting(t *testing.T) {
// Test that SortEntries puts dirs first, then alpha
entries := []FileEntry{
{Name: "zebra.txt", IsDir: false},
{Name: "alpha", IsDir: true},
{Name: "beta.conf", IsDir: false},
{Name: "omega", IsDir: true},
}
SortEntries(entries)
if entries[0].Name != "alpha" {
t.Errorf("[0] = %s, want alpha", entries[0].Name)
}
if entries[1].Name != "omega" {
t.Errorf("[1] = %s, want omega", entries[1].Name)
}
if entries[2].Name != "beta.conf" {
t.Errorf("[2] = %s, want beta.conf", entries[2].Name)
}
if entries[3].Name != "zebra.txt" {
t.Errorf("[3] = %s, want zebra.txt", entries[3].Name)
}
}
func TestSortEntriesEmpty(t *testing.T) {
entries := []FileEntry{}
SortEntries(entries)
if len(entries) != 0 {
t.Errorf("expected empty slice, got %d entries", len(entries))
}
}
func TestSortEntriesCaseInsensitive(t *testing.T) {
entries := []FileEntry{
{Name: "Zebra", IsDir: false},
{Name: "alpha", IsDir: false},
}
SortEntries(entries)
if entries[0].Name != "alpha" {
t.Errorf("[0] = %s, want alpha", entries[0].Name)
}
if entries[1].Name != "Zebra" {
t.Errorf("[1] = %s, want Zebra", entries[1].Name)
}
}
func TestMaxEditFileSize(t *testing.T) {
if MaxEditFileSize != 5*1024*1024 {
t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024)
}
}

259
internal/ssh/service.go Normal file
View File

@ -0,0 +1,259 @@
package ssh
import (
"database/sql"
"fmt"
"io"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
)
// OutputHandler is called when data is read from an SSH session's stdout.
// In production this will emit Wails events; for testing, a simple callback.
type OutputHandler func(sessionID string, data []byte)
// SSHSession represents an active SSH connection with its PTY shell session.
type SSHSession struct {
ID string
Client *ssh.Client
Session *ssh.Session
Stdin io.WriteCloser
ConnID int64
Hostname string
Port int
Username string
Connected time.Time
mu sync.Mutex
}
// SSHService manages SSH connections and their associated sessions.
type SSHService struct {
sessions map[string]*SSHSession
mu sync.RWMutex
db *sql.DB
outputHandler OutputHandler
}
// NewSSHService creates a new SSHService. The outputHandler is called when data
// arrives from a session's stdout. Pass nil if output handling is not needed.
func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService {
return &SSHService{
sessions: make(map[string]*SSHSession),
db: db,
outputHandler: outputHandler,
}
}
// Connect dials an SSH server, opens a session with a PTY and shell, and
// launches a goroutine to read stdout. Returns the session ID.
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
addr := fmt.Sprintf("%s:%d", hostname, port)
config := &ssh.ClientConfig{
User: username,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 15 * time.Second,
}
client, err := ssh.Dial("tcp", addr, config)
if err != nil {
return "", fmt.Errorf("ssh dial %s: %w", addr, err)
}
session, err := client.NewSession()
if err != nil {
client.Close()
return "", fmt.Errorf("new session: %w", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
session.Close()
client.Close()
return "", fmt.Errorf("request pty: %w", err)
}
stdin, err := session.StdinPipe()
if err != nil {
session.Close()
client.Close()
return "", fmt.Errorf("stdin pipe: %w", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
session.Close()
client.Close()
return "", fmt.Errorf("stdout pipe: %w", err)
}
if err := session.Shell(); err != nil {
session.Close()
client.Close()
return "", fmt.Errorf("start shell: %w", err)
}
sessionID := uuid.NewString()
sshSession := &SSHSession{
ID: sessionID,
Client: client,
Session: session,
Stdin: stdin,
Hostname: hostname,
Port: port,
Username: username,
Connected: time.Now(),
}
s.mu.Lock()
s.sessions[sessionID] = sshSession
s.mu.Unlock()
// Launch goroutine to read stdout and forward data via the output handler
go s.readLoop(sessionID, stdout)
return sessionID, nil
}
// readLoop continuously reads from the session stdout and calls the output
// handler with data. It stops when the reader returns an error (typically EOF
// when the session closes).
func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
buf := make([]byte, 32*1024)
for {
n, err := reader.Read(buf)
if n > 0 && s.outputHandler != nil {
data := make([]byte, n)
copy(data, buf[:n])
s.outputHandler(sessionID, data)
}
if err != nil {
break
}
}
}
// Write sends data to the session's stdin.
func (s *SSHService) Write(sessionID string, data string) error {
s.mu.RLock()
sess, ok := s.sessions[sessionID]
s.mu.RUnlock()
if !ok {
return fmt.Errorf("session %s not found", sessionID)
}
sess.mu.Lock()
defer sess.mu.Unlock()
if sess.Stdin == nil {
return fmt.Errorf("session %s stdin is closed", sessionID)
}
_, err := sess.Stdin.Write([]byte(data))
if err != nil {
return fmt.Errorf("write to session %s: %w", sessionID, err)
}
return nil
}
// Resize sends a window-change request to the remote PTY.
func (s *SSHService) Resize(sessionID string, cols, rows int) error {
s.mu.RLock()
sess, ok := s.sessions[sessionID]
s.mu.RUnlock()
if !ok {
return fmt.Errorf("session %s not found", sessionID)
}
sess.mu.Lock()
defer sess.mu.Unlock()
if sess.Session == nil {
return fmt.Errorf("session %s is closed", sessionID)
}
if err := sess.Session.WindowChange(rows, cols); err != nil {
return fmt.Errorf("resize session %s: %w", sessionID, err)
}
return nil
}
// Disconnect closes the SSH session and client, and removes it from tracking.
func (s *SSHService) Disconnect(sessionID string) error {
s.mu.Lock()
sess, ok := s.sessions[sessionID]
if !ok {
s.mu.Unlock()
return fmt.Errorf("session %s not found", sessionID)
}
delete(s.sessions, sessionID)
s.mu.Unlock()
sess.mu.Lock()
defer sess.mu.Unlock()
if sess.Stdin != nil {
sess.Stdin.Close()
}
if sess.Session != nil {
sess.Session.Close()
}
if sess.Client != nil {
sess.Client.Close()
}
return nil
}
// GetSession returns the SSHSession for the given ID, or false if not found.
func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.sessions[sessionID]
return sess, ok
}
// ListSessions returns all active SSH sessions.
func (s *SSHService) ListSessions() []*SSHSession {
s.mu.RLock()
defer s.mu.RUnlock()
list := make([]*SSHSession, 0, len(s.sessions))
for _, sess := range s.sessions {
list = append(list, sess)
}
return list
}
// BuildPasswordAuth creates an ssh.AuthMethod for password authentication.
func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod {
return ssh.Password(password)
}
// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key.
// If the key is encrypted, pass the passphrase; otherwise pass an empty string.
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) {
var signer ssh.Signer
var err error
if passphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
} else {
signer, err = ssh.ParsePrivateKey(privateKey)
}
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return ssh.PublicKeys(signer), nil
}

View File

@ -0,0 +1,148 @@
package ssh
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestNewSSHService(t *testing.T) {
svc := NewSSHService(nil, nil)
if svc == nil {
t.Fatal("NewSSHService returned nil")
}
if len(svc.ListSessions()) != 0 {
t.Error("new service should have no sessions")
}
}
func TestBuildPasswordAuth(t *testing.T) {
svc := NewSSHService(nil, nil)
auth := svc.BuildPasswordAuth("mypassword")
if auth == nil {
t.Error("BuildPasswordAuth returned nil")
}
}
func TestBuildKeyAuth(t *testing.T) {
svc := NewSSHService(nil, nil)
// Generate a test Ed25519 key
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("GenerateKey error: %v", err)
}
pemBlock, err := ssh.MarshalPrivateKey(priv, "")
if err != nil {
t.Fatalf("MarshalPrivateKey error: %v", err)
}
keyBytes := pem.EncodeToMemory(pemBlock)
auth, err := svc.BuildKeyAuth(keyBytes, "")
if err != nil {
t.Fatalf("BuildKeyAuth error: %v", err)
}
if auth == nil {
t.Error("BuildKeyAuth returned nil")
}
}
func TestBuildKeyAuthInvalidKey(t *testing.T) {
svc := NewSSHService(nil, nil)
_, err := svc.BuildKeyAuth([]byte("not a key"), "")
if err == nil {
t.Error("BuildKeyAuth should fail with invalid key")
}
}
func TestSessionTracking(t *testing.T) {
svc := NewSSHService(nil, nil)
// Manually add a session to test tracking
svc.mu.Lock()
svc.sessions["test-123"] = &SSHSession{
ID: "test-123",
Hostname: "192.168.1.4",
Port: 22,
Username: "vstockwell",
Connected: time.Now(),
}
svc.mu.Unlock()
s, ok := svc.GetSession("test-123")
if !ok {
t.Fatal("session not found")
}
if s.Hostname != "192.168.1.4" {
t.Errorf("Hostname = %q, want %q", s.Hostname, "192.168.1.4")
}
sessions := svc.ListSessions()
if len(sessions) != 1 {
t.Errorf("ListSessions() = %d, want 1", len(sessions))
}
}
func TestGetSessionNotFound(t *testing.T) {
svc := NewSSHService(nil, nil)
_, ok := svc.GetSession("nonexistent")
if ok {
t.Error("GetSession should return false for nonexistent session")
}
}
func TestWriteNotFound(t *testing.T) {
svc := NewSSHService(nil, nil)
err := svc.Write("nonexistent", "data")
if err == nil {
t.Error("Write should fail for nonexistent session")
}
}
func TestResizeNotFound(t *testing.T) {
svc := NewSSHService(nil, nil)
err := svc.Resize("nonexistent", 80, 24)
if err == nil {
t.Error("Resize should fail for nonexistent session")
}
}
func TestDisconnectNotFound(t *testing.T) {
svc := NewSSHService(nil, nil)
err := svc.Disconnect("nonexistent")
if err == nil {
t.Error("Disconnect should fail for nonexistent session")
}
}
func TestDisconnectRemovesSession(t *testing.T) {
svc := NewSSHService(nil, nil)
// Manually add a session with nil Client/Session/Stdin (no real connection)
svc.mu.Lock()
svc.sessions["test-dc"] = &SSHSession{
ID: "test-dc",
Hostname: "10.0.0.1",
Port: 22,
Username: "admin",
Connected: time.Now(),
}
svc.mu.Unlock()
if err := svc.Disconnect("test-dc"); err != nil {
t.Fatalf("Disconnect error: %v", err)
}
_, ok := svc.GetSession("test-dc")
if ok {
t.Error("session should be removed after Disconnect")
}
if len(svc.ListSessions()) != 0 {
t.Error("ListSessions should be empty after Disconnect")
}
}

View File

@ -0,0 +1,94 @@
package theme
type Theme struct {
ID int64 `json:"id"`
Name string `json:"name"`
Foreground string `json:"foreground"`
Background string `json:"background"`
Cursor string `json:"cursor"`
Black string `json:"black"`
Red string `json:"red"`
Green string `json:"green"`
Yellow string `json:"yellow"`
Blue string `json:"blue"`
Magenta string `json:"magenta"`
Cyan string `json:"cyan"`
White string `json:"white"`
BrightBlack string `json:"brightBlack"`
BrightRed string `json:"brightRed"`
BrightGreen string `json:"brightGreen"`
BrightYellow string `json:"brightYellow"`
BrightBlue string `json:"brightBlue"`
BrightMagenta string `json:"brightMagenta"`
BrightCyan string `json:"brightCyan"`
BrightWhite string `json:"brightWhite"`
SelectionBg string `json:"selectionBg,omitempty"`
SelectionFg string `json:"selectionFg,omitempty"`
IsBuiltin bool `json:"isBuiltin"`
}
var BuiltinThemes = []Theme{
{
Name: "Dracula", IsBuiltin: true,
Foreground: "#f8f8f2", Background: "#282a36", Cursor: "#f8f8f2",
Black: "#21222c", Red: "#ff5555", Green: "#50fa7b", Yellow: "#f1fa8c",
Blue: "#bd93f9", Magenta: "#ff79c6", Cyan: "#8be9fd", White: "#f8f8f2",
BrightBlack: "#6272a4", BrightRed: "#ff6e6e", BrightGreen: "#69ff94",
BrightYellow: "#ffffa5", BrightBlue: "#d6acff", BrightMagenta: "#ff92df",
BrightCyan: "#a4ffff", BrightWhite: "#ffffff",
},
{
Name: "Nord", IsBuiltin: true,
Foreground: "#d8dee9", Background: "#2e3440", Cursor: "#d8dee9",
Black: "#3b4252", Red: "#bf616a", Green: "#a3be8c", Yellow: "#ebcb8b",
Blue: "#81a1c1", Magenta: "#b48ead", Cyan: "#88c0d0", White: "#e5e9f0",
BrightBlack: "#4c566a", BrightRed: "#bf616a", BrightGreen: "#a3be8c",
BrightYellow: "#ebcb8b", BrightBlue: "#81a1c1", BrightMagenta: "#b48ead",
BrightCyan: "#8fbcbb", BrightWhite: "#eceff4",
},
{
Name: "Monokai", IsBuiltin: true,
Foreground: "#f8f8f2", Background: "#272822", Cursor: "#f8f8f0",
Black: "#272822", Red: "#f92672", Green: "#a6e22e", Yellow: "#f4bf75",
Blue: "#66d9ef", Magenta: "#ae81ff", Cyan: "#a1efe4", White: "#f8f8f2",
BrightBlack: "#75715e", BrightRed: "#f92672", BrightGreen: "#a6e22e",
BrightYellow: "#f4bf75", BrightBlue: "#66d9ef", BrightMagenta: "#ae81ff",
BrightCyan: "#a1efe4", BrightWhite: "#f9f8f5",
},
{
Name: "One Dark", IsBuiltin: true,
Foreground: "#abb2bf", Background: "#282c34", Cursor: "#528bff",
Black: "#282c34", Red: "#e06c75", Green: "#98c379", Yellow: "#e5c07b",
Blue: "#61afef", Magenta: "#c678dd", Cyan: "#56b6c2", White: "#abb2bf",
BrightBlack: "#545862", BrightRed: "#e06c75", BrightGreen: "#98c379",
BrightYellow: "#e5c07b", BrightBlue: "#61afef", BrightMagenta: "#c678dd",
BrightCyan: "#56b6c2", BrightWhite: "#c8ccd4",
},
{
Name: "Solarized Dark", IsBuiltin: true,
Foreground: "#839496", Background: "#002b36", Cursor: "#839496",
Black: "#073642", Red: "#dc322f", Green: "#859900", Yellow: "#b58900",
Blue: "#268bd2", Magenta: "#d33682", Cyan: "#2aa198", White: "#eee8d5",
BrightBlack: "#002b36", BrightRed: "#cb4b16", BrightGreen: "#586e75",
BrightYellow: "#657b83", BrightBlue: "#839496", BrightMagenta: "#6c71c4",
BrightCyan: "#93a1a1", BrightWhite: "#fdf6e3",
},
{
Name: "Gruvbox Dark", IsBuiltin: true,
Foreground: "#ebdbb2", Background: "#282828", Cursor: "#ebdbb2",
Black: "#282828", Red: "#cc241d", Green: "#98971a", Yellow: "#d79921",
Blue: "#458588", Magenta: "#b16286", Cyan: "#689d6a", White: "#a89984",
BrightBlack: "#928374", BrightRed: "#fb4934", BrightGreen: "#b8bb26",
BrightYellow: "#fabd2f", BrightBlue: "#83a598", BrightMagenta: "#d3869b",
BrightCyan: "#8ec07c", BrightWhite: "#ebdbb2",
},
{
Name: "MobaXTerm Classic", IsBuiltin: true,
Foreground: "#ececec", Background: "#242424", Cursor: "#b4b4c0",
Black: "#000000", Red: "#aa4244", Green: "#7e8d53", Yellow: "#e4b46d",
Blue: "#6e9aba", Magenta: "#9e5085", Cyan: "#80d5cf", White: "#cccccc",
BrightBlack: "#808080", BrightRed: "#cc7b7d", BrightGreen: "#a5b17c",
BrightYellow: "#ecc995", BrightBlue: "#96b6cd", BrightMagenta: "#c083ac",
BrightCyan: "#a9e2de", BrightWhite: "#cccccc",
},
}

85
internal/theme/service.go Normal file
View File

@ -0,0 +1,85 @@
package theme
import (
"database/sql"
"fmt"
)
type ThemeService struct {
db *sql.DB
}
func NewThemeService(db *sql.DB) *ThemeService {
return &ThemeService{db: db}
}
func (s *ThemeService) SeedBuiltins() error {
for _, t := range BuiltinThemes {
_, err := s.db.Exec(
`INSERT OR IGNORE INTO themes (name, foreground, background, cursor,
black, red, green, yellow, blue, magenta, cyan, white,
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
bright_magenta, bright_cyan, bright_white, selection_bg, selection_fg, is_builtin)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,1)`,
t.Name, t.Foreground, t.Background, t.Cursor,
t.Black, t.Red, t.Green, t.Yellow, t.Blue, t.Magenta, t.Cyan, t.White,
t.BrightBlack, t.BrightRed, t.BrightGreen, t.BrightYellow, t.BrightBlue,
t.BrightMagenta, t.BrightCyan, t.BrightWhite, t.SelectionBg, t.SelectionFg,
)
if err != nil {
return fmt.Errorf("seed theme %s: %w", t.Name, err)
}
}
return nil
}
func (s *ThemeService) List() ([]Theme, error) {
rows, err := s.db.Query(
`SELECT id, name, foreground, background, cursor,
black, red, green, yellow, blue, magenta, cyan, white,
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
bright_magenta, bright_cyan, bright_white,
COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin
FROM themes ORDER BY is_builtin DESC, name`)
if err != nil {
return nil, err
}
defer rows.Close()
var themes []Theme
for rows.Next() {
var t Theme
if err := rows.Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor,
&t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White,
&t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue,
&t.BrightMagenta, &t.BrightCyan, &t.BrightWhite,
&t.SelectionBg, &t.SelectionFg, &t.IsBuiltin); err != nil {
return nil, err
}
themes = append(themes, t)
}
if err := rows.Err(); err != nil {
return nil, err
}
return themes, nil
}
func (s *ThemeService) GetByName(name string) (*Theme, error) {
var t Theme
err := s.db.QueryRow(
`SELECT id, name, foreground, background, cursor,
black, red, green, yellow, blue, magenta, cyan, white,
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
bright_magenta, bright_cyan, bright_white,
COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin
FROM themes WHERE name = ?`, name,
).Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor,
&t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White,
&t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue,
&t.BrightMagenta, &t.BrightCyan, &t.BrightWhite,
&t.SelectionBg, &t.SelectionFg, &t.IsBuiltin)
if err != nil {
return nil, fmt.Errorf("get theme %s: %w", name, err)
}
return &t, nil
}

View File

@ -0,0 +1,60 @@
package theme
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func setupTestDB(t *testing.T) *ThemeService {
t.Helper()
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
if err := db.Migrate(d); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
return NewThemeService(d)
}
func TestSeedBuiltins(t *testing.T) {
svc := setupTestDB(t)
if err := svc.SeedBuiltins(); err != nil {
t.Fatalf("SeedBuiltins() error: %v", err)
}
themes, err := svc.List()
if err != nil {
t.Fatalf("List() error: %v", err)
}
if len(themes) != len(BuiltinThemes) {
t.Errorf("len(themes) = %d, want %d", len(themes), len(BuiltinThemes))
}
}
func TestSeedBuiltinsIdempotent(t *testing.T) {
svc := setupTestDB(t)
svc.SeedBuiltins()
svc.SeedBuiltins()
themes, _ := svc.List()
if len(themes) != len(BuiltinThemes) {
t.Errorf("len(themes) = %d after double seed, want %d", len(themes), len(BuiltinThemes))
}
}
func TestGetByName(t *testing.T) {
svc := setupTestDB(t)
svc.SeedBuiltins()
theme, err := svc.GetByName("Dracula")
if err != nil {
t.Fatalf("GetByName() error: %v", err)
}
if theme.Background != "#282a36" {
t.Errorf("Background = %q, want %q", theme.Background, "#282a36")
}
}

95
internal/vault/service.go Normal file
View File

@ -0,0 +1,95 @@
package vault
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
)
const (
argonTime = 3
argonMemory = 64 * 1024
argonThreads = 4
argonKeyLen = 32
)
func DeriveKey(password string, salt []byte) []byte {
return argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
}
func GenerateSalt() ([]byte, error) {
salt := make([]byte, 32)
if _, err := rand.Read(salt); err != nil {
return nil, fmt.Errorf("generate salt: %w", err)
}
return salt, nil
}
type VaultService struct {
key []byte
}
func NewVaultService(key []byte) *VaultService {
return &VaultService{key: key}
}
func (v *VaultService) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(v.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create GCM: %w", err)
}
iv := make([]byte, gcm.NonceSize())
if _, err := rand.Read(iv); err != nil {
return "", fmt.Errorf("generate IV: %w", err)
}
sealed := gcm.Seal(nil, iv, []byte(plaintext), nil)
return fmt.Sprintf("v1:%s:%s", hex.EncodeToString(iv), hex.EncodeToString(sealed)), nil
}
func (v *VaultService) Decrypt(encrypted string) (string, error) {
parts := strings.SplitN(encrypted, ":", 3)
if len(parts) != 3 || parts[0] != "v1" {
return "", errors.New("invalid encrypted format: expected v1:{iv}:{sealed}")
}
iv, err := hex.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("decode IV: %w", err)
}
sealed, err := hex.DecodeString(parts[2])
if err != nil {
return "", fmt.Errorf("decode sealed data: %w", err)
}
block, err := aes.NewCipher(v.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create GCM: %w", err)
}
plaintext, err := gcm.Open(nil, iv, sealed, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@ -0,0 +1,104 @@
package vault
import (
"strings"
"testing"
)
func TestDeriveKeyConsistent(t *testing.T) {
salt := []byte("test-salt-exactly-32-bytes-long!")
key1 := DeriveKey("mypassword", salt)
key2 := DeriveKey("mypassword", salt)
if len(key1) != 32 {
t.Errorf("key length = %d, want 32", len(key1))
}
if string(key1) != string(key2) {
t.Error("same password+salt produced different keys")
}
}
func TestDeriveKeyDifferentPasswords(t *testing.T) {
salt := []byte("test-salt-exactly-32-bytes-long!")
key1 := DeriveKey("password1", salt)
key2 := DeriveKey("password2", salt)
if string(key1) == string(key2) {
t.Error("different passwords produced same key")
}
}
func TestEncryptDecryptRoundTrip(t *testing.T) {
key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!"))
vs := NewVaultService(key)
plaintext := "super-secret-ssh-key-data"
encrypted, err := vs.Encrypt(plaintext)
if err != nil {
t.Fatalf("Encrypt() error: %v", err)
}
if !strings.HasPrefix(encrypted, "v1:") {
t.Errorf("encrypted does not start with v1: prefix: %q", encrypted[:10])
}
decrypted, err := vs.Decrypt(encrypted)
if err != nil {
t.Fatalf("Decrypt() error: %v", err)
}
if decrypted != plaintext {
t.Errorf("Decrypt() = %q, want %q", decrypted, plaintext)
}
}
func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!"))
vs := NewVaultService(key)
enc1, _ := vs.Encrypt("same-data")
enc2, _ := vs.Encrypt("same-data")
if enc1 == enc2 {
t.Error("two encryptions of same data produced identical ciphertext (IV reuse)")
}
}
func TestDecryptWrongKey(t *testing.T) {
key1 := DeriveKey("password1", []byte("test-salt-exactly-32-bytes-long!"))
key2 := DeriveKey("password2", []byte("test-salt-exactly-32-bytes-long!"))
vs1 := NewVaultService(key1)
vs2 := NewVaultService(key2)
encrypted, _ := vs1.Encrypt("secret")
_, err := vs2.Decrypt(encrypted)
if err == nil {
t.Error("Decrypt() with wrong key should return error")
}
}
func TestDecryptInvalidFormat(t *testing.T) {
key := DeriveKey("test", []byte("test-salt-exactly-32-bytes-long!"))
vs := NewVaultService(key)
_, err := vs.Decrypt("not-valid-format")
if err == nil {
t.Error("Decrypt() with invalid format should return error")
}
}
func TestGenerateSalt(t *testing.T) {
salt1, err := GenerateSalt()
if err != nil {
t.Fatalf("GenerateSalt() error: %v", err)
}
if len(salt1) != 32 {
t.Errorf("salt length = %d, want 32", len(salt1))
}
salt2, _ := GenerateSalt()
if string(salt1) == string(salt2) {
t.Error("two calls to GenerateSalt produced identical salt")
}
}

51
main.go Normal file
View File

@ -0,0 +1,51 @@
package main
import (
"embed"
"log"
"log/slog"
wraithapp "github.com/vstockwell/wraith/internal/app"
"github.com/wailsapp/wails/v3/pkg/application"
)
// version is set at build time via -ldflags "-X main.version=..."
var version = "dev"
//go:embed all:frontend/dist
var assets embed.FS
func main() {
slog.Info("starting Wraith")
wraith, err := wraithapp.New()
if err != nil {
log.Fatalf("failed to initialize: %v", err)
}
app := application.New(application.Options{
Name: "Wraith",
Description: "SSH + RDP + SFTP Desktop Client",
Services: []application.Service{
application.NewService(wraith),
application.NewService(wraith.Connections),
application.NewService(wraith.Themes),
application.NewService(wraith.Settings),
},
Assets: application.AssetOptions{
Handler: application.BundledAssetFileServer(assets),
},
})
app.Window.NewWithOptions(application.WebviewWindowOptions{
Title: "Wraith",
Width: 1400,
Height: 900,
URL: "/",
BackgroundColour: application.NewRGBA(13, 17, 23, 255),
})
if err := app.Run(); err != nil {
log.Fatal(err)
}
}