feat: SFTP service + credential service with encrypted key/password storage
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a75e21138e
commit
6e25a646d3
58
go.mod
Normal file
58
go.mod
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
module github.com/vstockwell/wraith
|
||||||
|
|
||||||
|
go 1.26.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/pkg/sftp v1.13.10
|
||||||
|
github.com/wailsapp/wails/v3 v3.0.0-alpha.74
|
||||||
|
golang.org/x/crypto v0.49.0
|
||||||
|
modernc.org/sqlite v1.46.2
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
dario.cat/mergo v1.0.2 // indirect
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
|
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||||
|
github.com/adrg/xdg v0.5.3 // indirect
|
||||||
|
github.com/bep/debounce v1.2.1 // indirect
|
||||||
|
github.com/cloudflare/circl v1.6.3 // indirect
|
||||||
|
github.com/coder/websocket v1.8.14 // indirect
|
||||||
|
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/ebitengine/purego v0.9.1 // indirect
|
||||||
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
|
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||||
|
github.com/go-git/go-billy/v5 v5.7.0 // indirect
|
||||||
|
github.com/go-git/go-git/v5 v5.16.4 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||||
|
github.com/godbus/dbus/v5 v5.2.2 // indirect
|
||||||
|
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||||
|
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||||
|
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
|
||||||
|
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/kr/fs v0.1.0 // indirect
|
||||||
|
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
|
||||||
|
github.com/leaanthony/u v1.1.1 // indirect
|
||||||
|
github.com/lmittmann/tint v1.1.2 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
|
github.com/samber/lo v1.52.0 // indirect
|
||||||
|
github.com/sergi/go-diff v1.4.0 // indirect
|
||||||
|
github.com/skeema/knownhosts v1.3.2 // indirect
|
||||||
|
github.com/wailsapp/go-webview2 v1.0.23 // indirect
|
||||||
|
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||||
|
golang.org/x/net v0.51.0 // indirect
|
||||||
|
golang.org/x/sys v0.42.0 // indirect
|
||||||
|
golang.org/x/text v0.35.0 // indirect
|
||||||
|
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||||
|
modernc.org/libc v1.70.0 // indirect
|
||||||
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
)
|
||||||
197
go.sum
Normal file
197
go.sum
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
|
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||||
|
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||||
|
github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78=
|
||||||
|
github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ=
|
||||||
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||||
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||||
|
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||||
|
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||||
|
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
|
||||||
|
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
|
||||||
|
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
|
||||||
|
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
|
||||||
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
|
github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE=
|
||||||
|
github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
|
||||||
|
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
|
||||||
|
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||||
|
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||||
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
|
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
|
||||||
|
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
|
||||||
|
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI=
|
||||||
|
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic=
|
||||||
|
github.com/go-git/go-billy/v5 v5.7.0 h1:83lBUJhGWhYp0ngzCMSgllhUSuoHP1iEWYjsPl9nwqM=
|
||||||
|
github.com/go-git/go-billy/v5 v5.7.0/go.mod h1:/1IUejTKH8xipsAcdfcSAlUlo2J7lkYV8GTKxAT/L3E=
|
||||||
|
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4=
|
||||||
|
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
|
||||||
|
github.com/go-git/go-git/v5 v5.16.4 h1:7ajIEZHZJULcyJebDLo99bGgS0jRrOxzZG4uCk2Yb2Y=
|
||||||
|
github.com/go-git/go-git/v5 v5.16.4/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8=
|
||||||
|
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
||||||
|
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
||||||
|
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||||
|
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||||
|
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
||||||
|
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||||
|
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
||||||
|
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
|
||||||
|
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
|
||||||
|
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 h1:njuLRcjAuMKr7kI3D85AXWkw6/+v9PwtV6M6o11sWHQ=
|
||||||
|
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs=
|
||||||
|
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
||||||
|
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||||
|
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed1YDKpEz01A=
|
||||||
|
github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU=
|
||||||
|
github.com/leaanthony/u v1.1.1 h1:TUFjwDGlNX+WuwVEzDqQwC2lOv0P4uhTQw7CMFdiK7M=
|
||||||
|
github.com/leaanthony/u v1.1.1/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI=
|
||||||
|
github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
|
||||||
|
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
|
||||||
|
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||||
|
github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
|
||||||
|
github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||||
|
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||||
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
|
||||||
|
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
|
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
|
||||||
|
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
|
||||||
|
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||||
|
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||||
|
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||||
|
github.com/skeema/knownhosts v1.3.2 h1:EDL9mgf4NzwMXCTfaxSD/o/a5fxDw/xL9nkU28JjdBg=
|
||||||
|
github.com/skeema/knownhosts v1.3.2/go.mod h1:bEg3iQAuw+jyiw+484wwFJoKSLwcfd7fqRy+N0QTiow=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/wailsapp/go-webview2 v1.0.23 h1:jmv8qhz1lHibCc79bMM/a/FqOnnzOGEisLav+a0b9P0=
|
||||||
|
github.com/wailsapp/go-webview2 v1.0.23/go.mod h1:qJmWAmAmaniuKGZPWwne+uor3AHMB5PFhqiK0Bbj8kc=
|
||||||
|
github.com/wailsapp/wails/v3 v3.0.0-alpha.74 h1:wRm1EiDQtxDisXk46NtpiBH90STwfKp36NrTDwOEdxw=
|
||||||
|
github.com/wailsapp/wails/v3 v3.0.0-alpha.74/go.mod h1:4saK4A4K9970X+X7RkMwP2lyGbLogcUz54wVeq4C/V8=
|
||||||
|
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||||
|
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
|
||||||
|
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
|
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||||
|
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||||
|
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU=
|
||||||
|
golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU=
|
||||||
|
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||||
|
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||||
|
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
|
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||||
|
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200810151505-1b9f1253b3ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||||
|
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||||
|
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||||
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||||
|
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||||
|
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||||
|
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
|
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
|
||||||
|
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
|
||||||
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
|
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||||
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||||
|
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
|
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||||
|
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||||
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE=
|
||||||
|
modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
167
internal/app/app.go
Normal file
167
internal/app/app.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/connections"
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
"github.com/vstockwell/wraith/internal/plugin"
|
||||||
|
"github.com/vstockwell/wraith/internal/session"
|
||||||
|
"github.com/vstockwell/wraith/internal/settings"
|
||||||
|
"github.com/vstockwell/wraith/internal/theme"
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WraithApp is the main application struct that wires together all services
|
||||||
|
// and exposes vault management methods to the frontend via Wails bindings.
|
||||||
|
type WraithApp struct {
|
||||||
|
db *sql.DB
|
||||||
|
Vault *vault.VaultService
|
||||||
|
Settings *settings.SettingsService
|
||||||
|
Connections *connections.ConnectionService
|
||||||
|
Themes *theme.ThemeService
|
||||||
|
Sessions *session.Manager
|
||||||
|
Plugins *plugin.Registry
|
||||||
|
unlocked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates and initializes the WraithApp, opening the database, running
|
||||||
|
// migrations, creating all services, and seeding built-in themes.
|
||||||
|
func New() (*WraithApp, error) {
|
||||||
|
dataDir := dataDirectory()
|
||||||
|
dbPath := filepath.Join(dataDir, "wraith.db")
|
||||||
|
|
||||||
|
slog.Info("opening database", "path", dbPath)
|
||||||
|
database, err := db.Open(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
return nil, fmt.Errorf("run migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
settingsSvc := settings.NewSettingsService(database)
|
||||||
|
connSvc := connections.NewConnectionService(database)
|
||||||
|
themeSvc := theme.NewThemeService(database)
|
||||||
|
sessionMgr := session.NewManager()
|
||||||
|
pluginReg := plugin.NewRegistry()
|
||||||
|
|
||||||
|
// Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent)
|
||||||
|
if err := themeSvc.SeedBuiltins(); err != nil {
|
||||||
|
slog.Warn("failed to seed themes", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WraithApp{
|
||||||
|
db: database,
|
||||||
|
Settings: settingsSvc,
|
||||||
|
Connections: connSvc,
|
||||||
|
Themes: themeSvc,
|
||||||
|
Sessions: sessionMgr,
|
||||||
|
Plugins: pluginReg,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dataDirectory returns the path where Wraith stores its data.
|
||||||
|
// On Windows with APPDATA set, it uses %APPDATA%\Wraith.
|
||||||
|
// On macOS/Linux with XDG_DATA_HOME or HOME, it uses the appropriate path.
|
||||||
|
// Falls back to the current working directory for development.
|
||||||
|
func dataDirectory() string {
|
||||||
|
// Windows
|
||||||
|
if appData := os.Getenv("APPDATA"); appData != "" {
|
||||||
|
return filepath.Join(appData, "Wraith")
|
||||||
|
}
|
||||||
|
|
||||||
|
// macOS / Linux: use XDG_DATA_HOME or fallback to ~/.local/share
|
||||||
|
if home, err := os.UserHomeDir(); err == nil {
|
||||||
|
if xdg := os.Getenv("XDG_DATA_HOME"); xdg != "" {
|
||||||
|
return filepath.Join(xdg, "wraith")
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".local", "share", "wraith")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dev fallback
|
||||||
|
return "."
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFirstRun checks whether the vault has been set up by looking for vault_salt in settings.
|
||||||
|
func (a *WraithApp) IsFirstRun() bool {
|
||||||
|
salt, _ := a.Settings.Get("vault_salt")
|
||||||
|
return salt == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVault sets up the vault with a master password. It generates a salt,
|
||||||
|
// derives an encryption key, and stores the salt and a check value in settings.
|
||||||
|
func (a *WraithApp) CreateVault(password string) error {
|
||||||
|
salt, err := vault.GenerateSalt()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := vault.DeriveKey(password, salt)
|
||||||
|
a.Vault = vault.NewVaultService(key)
|
||||||
|
|
||||||
|
// Store salt as hex in settings
|
||||||
|
if err := a.Settings.Set("vault_salt", hex.EncodeToString(salt)); err != nil {
|
||||||
|
return fmt.Errorf("store salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt a known check value — used to verify the password on unlock
|
||||||
|
check, err := a.Vault.Encrypt("wraith-vault-check")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encrypt check value: %w", err)
|
||||||
|
}
|
||||||
|
if err := a.Settings.Set("vault_check", check); err != nil {
|
||||||
|
return fmt.Errorf("store check value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.unlocked = true
|
||||||
|
slog.Info("vault created successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlock verifies the master password against the stored check value and
|
||||||
|
// initializes the vault service for decryption.
|
||||||
|
func (a *WraithApp) Unlock(password string) error {
|
||||||
|
saltHex, err := a.Settings.Get("vault_salt")
|
||||||
|
if err != nil || saltHex == "" {
|
||||||
|
return fmt.Errorf("vault not set up — call CreateVault first")
|
||||||
|
}
|
||||||
|
|
||||||
|
salt, err := hex.DecodeString(saltHex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := vault.DeriveKey(password, salt)
|
||||||
|
vs := vault.NewVaultService(key)
|
||||||
|
|
||||||
|
// Verify by decrypting the stored check value
|
||||||
|
checkEncrypted, err := a.Settings.Get("vault_check")
|
||||||
|
if err != nil || checkEncrypted == "" {
|
||||||
|
return fmt.Errorf("vault check value missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
checkPlain, err := vs.Decrypt(checkEncrypted)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("incorrect master password")
|
||||||
|
}
|
||||||
|
if checkPlain != "wraith-vault-check" {
|
||||||
|
return fmt.Errorf("incorrect master password")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Vault = vs
|
||||||
|
a.unlocked = true
|
||||||
|
slog.Info("vault unlocked successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnlocked returns whether the vault is currently unlocked.
|
||||||
|
func (a *WraithApp) IsUnlocked() bool {
|
||||||
|
return a.unlocked
|
||||||
|
}
|
||||||
40
internal/connections/search.go
Normal file
40
internal/connections/search.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package connections
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func (s *ConnectionService) Search(query string) ([]Connection, error) {
|
||||||
|
like := "%" + query + "%"
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
|
||||||
|
COALESCE(color,''), tags, COALESCE(notes,''), COALESCE(options,'{}'),
|
||||||
|
sort_order, last_connected, created_at, updated_at
|
||||||
|
FROM connections
|
||||||
|
WHERE name LIKE ? COLLATE NOCASE
|
||||||
|
OR hostname LIKE ? COLLATE NOCASE
|
||||||
|
OR tags LIKE ? COLLATE NOCASE
|
||||||
|
OR notes LIKE ? COLLATE NOCASE
|
||||||
|
ORDER BY last_connected DESC NULLS LAST, name`,
|
||||||
|
like, like, like, like,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("search connections: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanConnections(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) FilterByTag(tag string) ([]Connection, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT c.id, c.name, c.hostname, c.port, c.protocol, c.group_id, c.credential_id,
|
||||||
|
COALESCE(c.color,''), c.tags, COALESCE(c.notes,''), COALESCE(c.options,'{}'),
|
||||||
|
c.sort_order, c.last_connected, c.created_at, c.updated_at
|
||||||
|
FROM connections c, json_each(c.tags) AS t
|
||||||
|
WHERE t.value = ?
|
||||||
|
ORDER BY c.name`, tag,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("filter by tag: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanConnections(rows)
|
||||||
|
}
|
||||||
53
internal/connections/search_test.go
Normal file
53
internal/connections/search_test.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package connections
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSearchByName(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"})
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "Docker", Hostname: "155.254.29.221", Port: 22, Protocol: "ssh"})
|
||||||
|
|
||||||
|
results, err := svc.Search("asg")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Search() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("len(results) = %d, want 1", len(results))
|
||||||
|
}
|
||||||
|
if results[0].Name != "Asgard" {
|
||||||
|
t.Errorf("Name = %q, want %q", results[0].Name, "Asgard")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchByHostname(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"})
|
||||||
|
|
||||||
|
results, _ := svc.Search("192.168")
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Errorf("len(results) = %d, want 1", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchByTag(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "ProdServer", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}})
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "DevServer", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev", "Linux"}})
|
||||||
|
|
||||||
|
results, _ := svc.Search("Prod")
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Errorf("len(results) = %d, want 1", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterByTag(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "A", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod"}})
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "B", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev"}})
|
||||||
|
svc.CreateConnection(CreateConnectionInput{Name: "C", Hostname: "10.0.0.3", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}})
|
||||||
|
|
||||||
|
results, _ := svc.FilterByTag("Prod")
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("len(results) = %d, want 2", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
361
internal/connections/service.go
Normal file
361
internal/connections/service.go
Normal file
@ -0,0 +1,361 @@
|
|||||||
|
package connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ParentID *int64 `json:"parentId"`
|
||||||
|
SortOrder int `json:"sortOrder"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Children []Group `json:"children,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
GroupID *int64 `json:"groupId"`
|
||||||
|
CredentialID *int64 `json:"credentialId"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
Options string `json:"options"`
|
||||||
|
SortOrder int `json:"sortOrder"`
|
||||||
|
LastConnected *time.Time `json:"lastConnected"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateConnectionInput struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
GroupID *int64 `json:"groupId"`
|
||||||
|
CredentialID *int64 `json:"credentialId"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
Options string `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateConnectionInput struct {
|
||||||
|
Name *string `json:"name"`
|
||||||
|
Hostname *string `json:"hostname"`
|
||||||
|
Port *int `json:"port"`
|
||||||
|
GroupID *int64 `json:"groupId"`
|
||||||
|
CredentialID *int64 `json:"credentialId"`
|
||||||
|
Color *string `json:"color"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Notes *string `json:"notes"`
|
||||||
|
Options *string `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectionService struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionService(db *sql.DB) *ConnectionService {
|
||||||
|
return &ConnectionService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Group CRUD ----------
|
||||||
|
|
||||||
|
func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) {
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
"INSERT INTO groups (name, parent_id) VALUES (?, ?)",
|
||||||
|
name, parentID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create group: %w", err)
|
||||||
|
}
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get group id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var g Group
|
||||||
|
var icon sql.NullString
|
||||||
|
err = s.db.QueryRow(
|
||||||
|
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id,
|
||||||
|
).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get created group: %w", err)
|
||||||
|
}
|
||||||
|
if icon.Valid {
|
||||||
|
g.Icon = icon.String
|
||||||
|
}
|
||||||
|
return &g, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) ListGroups() ([]Group, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list groups: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
groupMap := make(map[int64]*Group)
|
||||||
|
var allGroups []*Group
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var g Group
|
||||||
|
var icon sql.NullString
|
||||||
|
if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan group: %w", err)
|
||||||
|
}
|
||||||
|
if icon.Valid {
|
||||||
|
g.Icon = icon.String
|
||||||
|
}
|
||||||
|
groupMap[g.ID] = &g
|
||||||
|
allGroups = append(allGroups, &g)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build tree: attach children to parents, collect roots
|
||||||
|
var roots []Group
|
||||||
|
for _, g := range allGroups {
|
||||||
|
if g.ParentID != nil {
|
||||||
|
if parent, ok := groupMap[*g.ParentID]; ok {
|
||||||
|
parent.Children = append(parent.Children, *g)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
roots = append(roots, *g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-attach children to root copies (since we copied into roots)
|
||||||
|
for i := range roots {
|
||||||
|
if orig, ok := groupMap[roots[i].ID]; ok {
|
||||||
|
roots[i].Children = orig.Children
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if roots == nil {
|
||||||
|
roots = []Group{}
|
||||||
|
}
|
||||||
|
return roots, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) DeleteGroup(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete group: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Connection CRUD ----------
|
||||||
|
|
||||||
|
func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) {
|
||||||
|
tags, err := json.Marshal(input.Tags)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal tags: %w", err)
|
||||||
|
}
|
||||||
|
if input.Tags == nil {
|
||||||
|
tags = []byte("[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
options := input.Options
|
||||||
|
if options == "" {
|
||||||
|
options = "{}"
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
`INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
input.Name, input.Hostname, input.Port, input.Protocol,
|
||||||
|
input.GroupID, input.CredentialID, input.Color,
|
||||||
|
string(tags), input.Notes, options,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create connection: %w", err)
|
||||||
|
}
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get connection id: %w", err)
|
||||||
|
}
|
||||||
|
return s.GetConnection(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) GetConnection(id int64) (*Connection, error) {
|
||||||
|
row := s.db.QueryRow(
|
||||||
|
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
|
||||||
|
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
|
||||||
|
FROM connections WHERE id = ?`, id,
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Connection
|
||||||
|
var tagsJSON string
|
||||||
|
var color, notes, options sql.NullString
|
||||||
|
var lastConnected sql.NullTime
|
||||||
|
|
||||||
|
err := row.Scan(
|
||||||
|
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
|
||||||
|
&c.GroupID, &c.CredentialID,
|
||||||
|
&color, &tagsJSON, ¬es, &options,
|
||||||
|
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if color.Valid {
|
||||||
|
c.Color = color.String
|
||||||
|
}
|
||||||
|
if notes.Valid {
|
||||||
|
c.Notes = notes.String
|
||||||
|
}
|
||||||
|
if options.Valid {
|
||||||
|
c.Options = options.String
|
||||||
|
}
|
||||||
|
if lastConnected.Valid {
|
||||||
|
c.LastConnected = &lastConnected.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
|
||||||
|
c.Tags = []string{}
|
||||||
|
}
|
||||||
|
if c.Tags == nil {
|
||||||
|
c.Tags = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) ListConnections() ([]Connection, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
|
||||||
|
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
|
||||||
|
FROM connections ORDER BY sort_order, name`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list connections: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanConnections(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanConnections is a shared helper used by ListConnections and (later) Search.
|
||||||
|
func scanConnections(rows *sql.Rows) ([]Connection, error) {
|
||||||
|
var conns []Connection
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var c Connection
|
||||||
|
var tagsJSON string
|
||||||
|
var color, notes, options sql.NullString
|
||||||
|
var lastConnected sql.NullTime
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
|
||||||
|
&c.GroupID, &c.CredentialID,
|
||||||
|
&color, &tagsJSON, ¬es, &options,
|
||||||
|
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if color.Valid {
|
||||||
|
c.Color = color.String
|
||||||
|
}
|
||||||
|
if notes.Valid {
|
||||||
|
c.Notes = notes.String
|
||||||
|
}
|
||||||
|
if options.Valid {
|
||||||
|
c.Options = options.String
|
||||||
|
}
|
||||||
|
if lastConnected.Valid {
|
||||||
|
c.LastConnected = &lastConnected.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
|
||||||
|
c.Tags = []string{}
|
||||||
|
}
|
||||||
|
if c.Tags == nil {
|
||||||
|
c.Tags = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
conns = append(conns, c)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate connections: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conns == nil {
|
||||||
|
conns = []Connection{}
|
||||||
|
}
|
||||||
|
return conns, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) {
|
||||||
|
setClauses := []string{"updated_at = CURRENT_TIMESTAMP"}
|
||||||
|
args := []interface{}{}
|
||||||
|
|
||||||
|
if input.Name != nil {
|
||||||
|
setClauses = append(setClauses, "name = ?")
|
||||||
|
args = append(args, *input.Name)
|
||||||
|
}
|
||||||
|
if input.Hostname != nil {
|
||||||
|
setClauses = append(setClauses, "hostname = ?")
|
||||||
|
args = append(args, *input.Hostname)
|
||||||
|
}
|
||||||
|
if input.Port != nil {
|
||||||
|
setClauses = append(setClauses, "port = ?")
|
||||||
|
args = append(args, *input.Port)
|
||||||
|
}
|
||||||
|
if input.GroupID != nil {
|
||||||
|
setClauses = append(setClauses, "group_id = ?")
|
||||||
|
args = append(args, *input.GroupID)
|
||||||
|
}
|
||||||
|
if input.CredentialID != nil {
|
||||||
|
setClauses = append(setClauses, "credential_id = ?")
|
||||||
|
args = append(args, *input.CredentialID)
|
||||||
|
}
|
||||||
|
if input.Tags != nil {
|
||||||
|
tags, _ := json.Marshal(input.Tags)
|
||||||
|
setClauses = append(setClauses, "tags = ?")
|
||||||
|
args = append(args, string(tags))
|
||||||
|
}
|
||||||
|
if input.Notes != nil {
|
||||||
|
setClauses = append(setClauses, "notes = ?")
|
||||||
|
args = append(args, *input.Notes)
|
||||||
|
}
|
||||||
|
if input.Color != nil {
|
||||||
|
setClauses = append(setClauses, "color = ?")
|
||||||
|
args = append(args, *input.Color)
|
||||||
|
}
|
||||||
|
if input.Options != nil {
|
||||||
|
setClauses = append(setClauses, "options = ?")
|
||||||
|
args = append(args, *input.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, id)
|
||||||
|
query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", "))
|
||||||
|
if _, err := s.db.Exec(query, args...); err != nil {
|
||||||
|
return nil, fmt.Errorf("update connection: %w", err)
|
||||||
|
}
|
||||||
|
return s.GetConnection(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConnectionService) DeleteConnection(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete connection: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
234
internal/connections/service_test.go
Normal file
234
internal/connections/service_test.go
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
package connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func strPtr(s string) *string { return &s }
|
||||||
|
|
||||||
|
func setupTestService(t *testing.T) *ConnectionService {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Open() error: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
t.Fatalf("db.Migrate() error: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { database.Close() })
|
||||||
|
return NewConnectionService(database)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGroup(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
g, err := svc.CreateGroup("Servers", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateGroup() error: %v", err)
|
||||||
|
}
|
||||||
|
if g.ID == 0 {
|
||||||
|
t.Error("expected non-zero ID")
|
||||||
|
}
|
||||||
|
if g.Name != "Servers" {
|
||||||
|
t.Errorf("Name = %q, want %q", g.Name, "Servers")
|
||||||
|
}
|
||||||
|
if g.ParentID != nil {
|
||||||
|
t.Errorf("ParentID = %v, want nil", g.ParentID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSubGroup(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
parent, err := svc.CreateGroup("Servers", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateGroup(parent) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
child, err := svc.CreateGroup("Production", &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateGroup(child) error: %v", err)
|
||||||
|
}
|
||||||
|
if child.ParentID == nil {
|
||||||
|
t.Fatal("expected non-nil ParentID")
|
||||||
|
}
|
||||||
|
if *child.ParentID != parent.ID {
|
||||||
|
t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListGroups(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
parent, err := svc.CreateGroup("Servers", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateGroup(parent) error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := svc.CreateGroup("Production", &parent.ID); err != nil {
|
||||||
|
t.Fatalf("CreateGroup(child) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := svc.ListGroups()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListGroups() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(groups) != 1 {
|
||||||
|
t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups))
|
||||||
|
}
|
||||||
|
if groups[0].Name != "Servers" {
|
||||||
|
t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers")
|
||||||
|
}
|
||||||
|
if len(groups[0].Children) != 1 {
|
||||||
|
t.Fatalf("len(children) = %d, want 1", len(groups[0].Children))
|
||||||
|
}
|
||||||
|
if groups[0].Children[0].Name != "Production" {
|
||||||
|
t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGroup(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
g, err := svc.CreateGroup("ToDelete", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateGroup() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := svc.DeleteGroup(g.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteGroup() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := svc.ListGroups()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListGroups() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(groups) != 0 {
|
||||||
|
t.Errorf("len(groups) = %d, want 0 after delete", len(groups))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateConnection(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||||
|
Name: "Web Server",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Protocol: "ssh",
|
||||||
|
Tags: []string{"Prod", "Linux"},
|
||||||
|
Options: `{"keepAliveInterval": 60}`,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConnection() error: %v", err)
|
||||||
|
}
|
||||||
|
if conn.ID == 0 {
|
||||||
|
t.Error("expected non-zero ID")
|
||||||
|
}
|
||||||
|
if conn.Name != "Web Server" {
|
||||||
|
t.Errorf("Name = %q, want %q", conn.Name, "Web Server")
|
||||||
|
}
|
||||||
|
if len(conn.Tags) != 2 {
|
||||||
|
t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags))
|
||||||
|
}
|
||||||
|
if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" {
|
||||||
|
t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags)
|
||||||
|
}
|
||||||
|
if conn.Options != `{"keepAliveInterval": 60}` {
|
||||||
|
t.Errorf("Options = %q, want JSON blob", conn.Options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListConnections(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
if _, err := svc.CreateConnection(CreateConnectionInput{
|
||||||
|
Name: "Server A",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Protocol: "ssh",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("CreateConnection(A) error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := svc.CreateConnection(CreateConnectionInput{
|
||||||
|
Name: "Server B",
|
||||||
|
Hostname: "10.0.0.2",
|
||||||
|
Port: 3389,
|
||||||
|
Protocol: "rdp",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("CreateConnection(B) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conns, err := svc.ListConnections()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListConnections() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(conns) != 2 {
|
||||||
|
t.Fatalf("len(conns) = %d, want 2", len(conns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConnection(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||||
|
Name: "Old Name",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Protocol: "ssh",
|
||||||
|
Tags: []string{"Dev"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConnection() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{
|
||||||
|
Name: strPtr("New Name"),
|
||||||
|
Tags: []string{"Prod", "Linux"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateConnection() error: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Name != "New Name" {
|
||||||
|
t.Errorf("Name = %q, want %q", updated.Name, "New Name")
|
||||||
|
}
|
||||||
|
if len(updated.Tags) != 2 {
|
||||||
|
t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags))
|
||||||
|
}
|
||||||
|
if updated.Tags[0] != "Prod" {
|
||||||
|
t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod")
|
||||||
|
}
|
||||||
|
// Hostname should remain unchanged
|
||||||
|
if updated.Hostname != "10.0.0.1" {
|
||||||
|
t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteConnection(t *testing.T) {
|
||||||
|
svc := setupTestService(t)
|
||||||
|
|
||||||
|
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||||
|
Name: "ToDelete",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Protocol: "ssh",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConnection() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := svc.DeleteConnection(conn.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteConnection() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conns, err := svc.ListConnections()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListConnections() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(conns) != 0 {
|
||||||
|
t.Errorf("len(conns) = %d, want 0 after delete", len(conns))
|
||||||
|
}
|
||||||
|
}
|
||||||
408
internal/credentials/service.go
Normal file
408
internal/credentials/service.go
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Credential represents a stored credential (password or SSH key reference).
|
||||||
|
type Credential struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
Type string `json:"type"` // "password" or "ssh_key"
|
||||||
|
SSHKeyID *int64 `json:"sshKeyId"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
UpdatedAt string `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHKey represents a stored SSH key.
|
||||||
|
type SSHKey struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
KeyType string `json:"keyType"`
|
||||||
|
Fingerprint string `json:"fingerprint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CredentialService provides CRUD for credentials with vault encryption.
|
||||||
|
type CredentialService struct {
|
||||||
|
db *sql.DB
|
||||||
|
vault *vault.VaultService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredentialService creates a new CredentialService.
|
||||||
|
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
|
||||||
|
return &CredentialService{db: db, vault: vault}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePassword creates a password credential (password encrypted via vault).
|
||||||
|
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
|
||||||
|
encrypted, err := s.vault.Encrypt(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
`INSERT INTO credentials (name, username, domain, type, encrypted_value)
|
||||||
|
VALUES (?, ?, ?, 'password', ?)`,
|
||||||
|
name, username, domain, encrypted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("insert credential: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get credential id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getCredential(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSSHKey imports an SSH key (private key encrypted via vault).
|
||||||
|
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
|
||||||
|
// Parse the private key to detect type and extract public key
|
||||||
|
keyType := DetectKeyType(privateKeyPEM)
|
||||||
|
|
||||||
|
// Parse the key to get the public key for fingerprinting.
|
||||||
|
// Try without passphrase first (handles unencrypted keys even when a
|
||||||
|
// passphrase is provided for storage), then fall back to using the
|
||||||
|
// passphrase for encrypted PEM keys.
|
||||||
|
var signer ssh.Signer
|
||||||
|
var err error
|
||||||
|
signer, err = ssh.ParsePrivateKey(privateKeyPEM)
|
||||||
|
if err != nil && passphrase != "" {
|
||||||
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := signer.PublicKey()
|
||||||
|
fingerprint := ssh.FingerprintSHA256(pubKey)
|
||||||
|
publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||||
|
|
||||||
|
// Encrypt private key via vault
|
||||||
|
encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt passphrase via vault (if provided)
|
||||||
|
var encryptedPassphrase sql.NullString
|
||||||
|
if passphrase != "" {
|
||||||
|
ep, err := s.vault.Encrypt(passphrase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt passphrase: %w", err)
|
||||||
|
}
|
||||||
|
encryptedPassphrase = sql.NullString{String: ep, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||||
|
name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("insert ssh key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get ssh key id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getSSHKey(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListCredentials returns all credentials WITHOUT encrypted values.
|
||||||
|
func (s *CredentialService) ListCredentials() ([]Credential, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||||
|
FROM credentials ORDER BY name`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list credentials: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var creds []Credential
|
||||||
|
for rows.Next() {
|
||||||
|
var c Credential
|
||||||
|
var username, domain sql.NullString
|
||||||
|
if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan credential: %w", err)
|
||||||
|
}
|
||||||
|
if username.Valid {
|
||||||
|
c.Username = username.String
|
||||||
|
}
|
||||||
|
if domain.Valid {
|
||||||
|
c.Domain = domain.String
|
||||||
|
}
|
||||||
|
creds = append(creds, c)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if creds == nil {
|
||||||
|
creds = []Credential{}
|
||||||
|
}
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSSHKeys returns all SSH keys WITHOUT private key data.
|
||||||
|
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||||
|
FROM ssh_keys ORDER BY name`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list ssh keys: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var keys []SSHKey
|
||||||
|
for rows.Next() {
|
||||||
|
var k SSHKey
|
||||||
|
var keyType, fingerprint, publicKey sql.NullString
|
||||||
|
if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan ssh key: %w", err)
|
||||||
|
}
|
||||||
|
if keyType.Valid {
|
||||||
|
k.KeyType = keyType.String
|
||||||
|
}
|
||||||
|
if fingerprint.Valid {
|
||||||
|
k.Fingerprint = fingerprint.String
|
||||||
|
}
|
||||||
|
if publicKey.Valid {
|
||||||
|
k.PublicKey = publicKey.String
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate ssh keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys == nil {
|
||||||
|
keys = []SSHKey{}
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptPassword returns the decrypted password for a credential.
|
||||||
|
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
|
||||||
|
var encrypted sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
|
||||||
|
credentialID,
|
||||||
|
).Scan(&encrypted)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get encrypted password: %w", err)
|
||||||
|
}
|
||||||
|
if !encrypted.Valid {
|
||||||
|
return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err := s.vault.Decrypt(encrypted.String)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decrypt password: %w", err)
|
||||||
|
}
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptSSHKey returns the decrypted private key + passphrase.
|
||||||
|
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
|
||||||
|
var encryptedKey string
|
||||||
|
var encryptedPassphrase sql.NullString
|
||||||
|
err = s.db.QueryRow(
|
||||||
|
"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
|
||||||
|
sshKeyID,
|
||||||
|
).Scan(&encryptedKey, &encryptedPassphrase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("get encrypted ssh key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedKey, err := s.vault.Decrypt(encryptedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("decrypt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encryptedPassphrase.Valid {
|
||||||
|
passphrase, err = s.vault.Decrypt(encryptedPassphrase.String)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("decrypt passphrase: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(decryptedKey), passphrase, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteCredential removes a credential.
|
||||||
|
func (s *CredentialService) DeleteCredential(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete credential: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSSHKey removes an SSH key.
|
||||||
|
func (s *CredentialService) DeleteSSHKey(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete ssh key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
|
||||||
|
func DetectKeyType(pemData []byte) string {
|
||||||
|
block, _ := pem.Decode(pemData)
|
||||||
|
if block == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks)
|
||||||
|
if block.Type == "OPENSSH PRIVATE KEY" {
|
||||||
|
return detectOpenSSHKeyType(block.Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try PKCS8 format
|
||||||
|
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||||
|
switch key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return "rsa"
|
||||||
|
case ed25519.PrivateKey:
|
||||||
|
return "ed25519"
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return "ecdsa"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try RSA PKCS1
|
||||||
|
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||||
|
return "rsa"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try EC
|
||||||
|
if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||||
|
return "ecdsa"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
|
||||||
|
func detectOpenSSHKeyType(data []byte) string {
|
||||||
|
// OpenSSH private key format: "openssh-key-v1\0" magic, then fields.
|
||||||
|
// We parse the key using ssh package to determine the type.
|
||||||
|
// Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer.
|
||||||
|
pemBlock := &pem.Block{
|
||||||
|
Type: "OPENSSH PRIVATE KEY",
|
||||||
|
Bytes: data,
|
||||||
|
}
|
||||||
|
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
return classifyPublicKey(signer.PublicKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyPublicKey determines the key type from an ssh.PublicKey.
|
||||||
|
func classifyPublicKey(pub ssh.PublicKey) string {
|
||||||
|
keyType := pub.Type()
|
||||||
|
switch keyType {
|
||||||
|
case "ssh-rsa":
|
||||||
|
return "rsa"
|
||||||
|
case "ssh-ed25519":
|
||||||
|
return "ed25519"
|
||||||
|
case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
|
||||||
|
return "ecdsa"
|
||||||
|
default:
|
||||||
|
return keyType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCredential retrieves a single credential by ID.
|
||||||
|
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
|
||||||
|
var c Credential
|
||||||
|
var username, domain sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||||
|
FROM credentials WHERE id = ?`, id,
|
||||||
|
).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get credential: %w", err)
|
||||||
|
}
|
||||||
|
if username.Valid {
|
||||||
|
c.Username = username.String
|
||||||
|
}
|
||||||
|
if domain.Valid {
|
||||||
|
c.Domain = domain.String
|
||||||
|
}
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSSHKey retrieves a single SSH key by ID (without private key data).
|
||||||
|
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
|
||||||
|
var k SSHKey
|
||||||
|
var keyType, fingerprint, publicKey sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||||
|
FROM ssh_keys WHERE id = ?`, id,
|
||||||
|
).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get ssh key: %w", err)
|
||||||
|
}
|
||||||
|
if keyType.Valid {
|
||||||
|
k.KeyType = keyType.String
|
||||||
|
}
|
||||||
|
if fingerprint.Valid {
|
||||||
|
k.Fingerprint = fingerprint.String
|
||||||
|
}
|
||||||
|
if publicKey.Valid {
|
||||||
|
k.PublicKey = publicKey.String
|
||||||
|
}
|
||||||
|
return &k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateFingerprint generates an SSH fingerprint string from a public key.
|
||||||
|
func generateFingerprint(pubKey ssh.PublicKey) string {
|
||||||
|
return ssh.FingerprintSHA256(pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalPublicKey returns the authorized_keys format of an SSH public key.
|
||||||
|
func marshalPublicKey(pubKey ssh.PublicKey) string {
|
||||||
|
return base64.StdEncoding.EncodeToString(pubKey.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ecdsaCurveName returns the name for an ECDSA curve.
|
||||||
|
func ecdsaCurveName(curve elliptic.Curve) string {
|
||||||
|
switch curve {
|
||||||
|
case elliptic.P256():
|
||||||
|
return "nistp256"
|
||||||
|
case elliptic.P384():
|
||||||
|
return "nistp384"
|
||||||
|
case elliptic.P521():
|
||||||
|
return "nistp521"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
176
internal/credentials/service_test.go
Normal file
176
internal/credentials/service_test.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/pem"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupCredentialService(t *testing.T) *CredentialService {
|
||||||
|
t.Helper()
|
||||||
|
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(d); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { d.Close() })
|
||||||
|
|
||||||
|
salt := []byte("test-salt-exactly-32-bytes-long!")
|
||||||
|
key := vault.DeriveKey("testpassword", salt)
|
||||||
|
vs := vault.NewVaultService(key)
|
||||||
|
|
||||||
|
return NewCredentialService(d, vs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePasswordCredential(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cred.Name != "Test Cred" {
|
||||||
|
t.Error("wrong name")
|
||||||
|
}
|
||||||
|
if cred.Type != "password" {
|
||||||
|
t.Error("wrong type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptPassword(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "")
|
||||||
|
password, err := svc.DecryptPassword(cred.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if password != "mypassword" {
|
||||||
|
t.Errorf("got %q, want mypassword", password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListCredentialsExcludesSecrets(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
svc.CreatePassword("Cred1", "user1", "pass1", "")
|
||||||
|
svc.CreatePassword("Cred2", "user2", "pass2", "")
|
||||||
|
creds, err := svc.ListCredentials()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(creds) != 2 {
|
||||||
|
t.Errorf("got %d, want 2", len(creds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
// Generate a test key
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, err := svc.CreateSSHKey("My Key", keyPEM, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if key.KeyType != "ed25519" {
|
||||||
|
t.Errorf("KeyType = %q, want ed25519", key.KeyType)
|
||||||
|
}
|
||||||
|
if key.Fingerprint == "" {
|
||||||
|
t.Error("fingerprint should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass")
|
||||||
|
decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(decryptedKey) == 0 {
|
||||||
|
t.Error("decrypted key should not be empty")
|
||||||
|
}
|
||||||
|
if passphrase != "testpass" {
|
||||||
|
t.Errorf("passphrase = %q, want testpass", passphrase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectKeyType(t *testing.T) {
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
if got := DetectKeyType(keyPEM); got != "ed25519" {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCredential(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "")
|
||||||
|
err := svc.DeleteCredential(cred.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
creds, _ := svc.ListCredentials()
|
||||||
|
if len(creds) != 0 {
|
||||||
|
t.Errorf("got %d credentials, want 0", len(creds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "")
|
||||||
|
err := svc.DeleteSSHKey(key.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keys, _ := svc.ListSSHKeys()
|
||||||
|
if len(keys) != 0 {
|
||||||
|
t.Errorf("got %d keys, want 0", len(keys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListSSHKeys(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
svc.CreateSSHKey("Key1", keyPEM, "")
|
||||||
|
svc.CreateSSHKey("Key2", keyPEM, "")
|
||||||
|
|
||||||
|
keys, err := svc.ListSSHKeys()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(keys) != 2 {
|
||||||
|
t.Errorf("got %d keys, want 2", len(keys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePasswordWithDomain(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cred.Domain != "example.com" {
|
||||||
|
t.Errorf("domain = %q, want example.com", cred.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
34
internal/db/migrations.go
Normal file
34
internal/db/migrations.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed migrations/*.sql
|
||||||
|
var migrationFiles embed.FS
|
||||||
|
|
||||||
|
func Migrate(db *sql.DB) error {
|
||||||
|
entries, err := migrationFiles.ReadDir("migrations")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Name() < entries[j].Name()
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
content, err := migrationFiles.ReadFile("migrations/" + entry.Name())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(string(content)); err != nil {
|
||||||
|
return fmt.Errorf("execute migration %s: %w", entry.Name(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
101
internal/db/migrations/001_initial.sql
Normal file
101
internal/db/migrations/001_initial.sql
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
-- 001_initial.sql
|
||||||
|
CREATE TABLE IF NOT EXISTS groups (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
parent_id INTEGER REFERENCES groups(id) ON DELETE SET NULL,
|
||||||
|
sort_order INTEGER DEFAULT 0,
|
||||||
|
icon TEXT,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ssh_keys (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
key_type TEXT,
|
||||||
|
fingerprint TEXT,
|
||||||
|
public_key TEXT,
|
||||||
|
encrypted_private_key TEXT NOT NULL,
|
||||||
|
passphrase_encrypted TEXT,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS credentials (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
username TEXT,
|
||||||
|
domain TEXT,
|
||||||
|
type TEXT NOT NULL CHECK(type IN ('password','ssh_key')),
|
||||||
|
encrypted_value TEXT,
|
||||||
|
ssh_key_id INTEGER REFERENCES ssh_keys(id) ON DELETE SET NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS connections (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
hostname TEXT NOT NULL,
|
||||||
|
port INTEGER NOT NULL DEFAULT 22,
|
||||||
|
protocol TEXT NOT NULL CHECK(protocol IN ('ssh','rdp')),
|
||||||
|
group_id INTEGER REFERENCES groups(id) ON DELETE SET NULL,
|
||||||
|
credential_id INTEGER REFERENCES credentials(id) ON DELETE SET NULL,
|
||||||
|
color TEXT,
|
||||||
|
tags TEXT DEFAULT '[]',
|
||||||
|
notes TEXT,
|
||||||
|
options TEXT DEFAULT '{}',
|
||||||
|
sort_order INTEGER DEFAULT 0,
|
||||||
|
last_connected DATETIME,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS themes (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
foreground TEXT NOT NULL,
|
||||||
|
background TEXT NOT NULL,
|
||||||
|
cursor TEXT NOT NULL,
|
||||||
|
black TEXT NOT NULL,
|
||||||
|
red TEXT NOT NULL,
|
||||||
|
green TEXT NOT NULL,
|
||||||
|
yellow TEXT NOT NULL,
|
||||||
|
blue TEXT NOT NULL,
|
||||||
|
magenta TEXT NOT NULL,
|
||||||
|
cyan TEXT NOT NULL,
|
||||||
|
white TEXT NOT NULL,
|
||||||
|
bright_black TEXT NOT NULL,
|
||||||
|
bright_red TEXT NOT NULL,
|
||||||
|
bright_green TEXT NOT NULL,
|
||||||
|
bright_yellow TEXT NOT NULL,
|
||||||
|
bright_blue TEXT NOT NULL,
|
||||||
|
bright_magenta TEXT NOT NULL,
|
||||||
|
bright_cyan TEXT NOT NULL,
|
||||||
|
bright_white TEXT NOT NULL,
|
||||||
|
selection_bg TEXT,
|
||||||
|
selection_fg TEXT,
|
||||||
|
is_builtin BOOLEAN DEFAULT 0
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS connection_history (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
connection_id INTEGER NOT NULL REFERENCES connections(id) ON DELETE CASCADE,
|
||||||
|
protocol TEXT NOT NULL,
|
||||||
|
connected_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
disconnected_at DATETIME,
|
||||||
|
duration_secs INTEGER
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS host_keys (
|
||||||
|
hostname TEXT NOT NULL,
|
||||||
|
port INTEGER NOT NULL,
|
||||||
|
key_type TEXT NOT NULL,
|
||||||
|
fingerprint TEXT NOT NULL,
|
||||||
|
raw_key TEXT,
|
||||||
|
first_seen DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY (hostname, port, key_type)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS settings (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value TEXT NOT NULL
|
||||||
|
);
|
||||||
39
internal/db/sqlite.go
Normal file
39
internal/db/sqlite.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Open(dbPath string) (*sql.DB, error) {
|
||||||
|
dir := filepath.Dir(dbPath)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return nil, fmt.Errorf("create db directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("set WAL mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("set busy_timeout: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("enable foreign keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
85
internal/db/sqlite_test.go
Normal file
85
internal/db/sqlite_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenCreatesDatabase(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
dbPath := filepath.Join(dir, "test.db")
|
||||||
|
|
||||||
|
db, err := Open(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open() error: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||||
|
t.Fatal("database file was not created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenSetsWALMode(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := Open(filepath.Join(dir, "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open() error: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var mode string
|
||||||
|
err = db.QueryRow("PRAGMA journal_mode").Scan(&mode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PRAGMA query error: %v", err)
|
||||||
|
}
|
||||||
|
if mode != "wal" {
|
||||||
|
t.Errorf("journal_mode = %q, want %q", mode, "wal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenSetsBusyTimeout(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := Open(filepath.Join(dir, "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open() error: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var timeout int
|
||||||
|
err = db.QueryRow("PRAGMA busy_timeout").Scan(&timeout)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PRAGMA query error: %v", err)
|
||||||
|
}
|
||||||
|
if timeout != 5000 {
|
||||||
|
t.Errorf("busy_timeout = %d, want %d", timeout, 5000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateCreatesAllTables(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := Open(filepath.Join(dir, "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open() error: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err := Migrate(db); err != nil {
|
||||||
|
t.Fatalf("Migrate() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedTables := []string{
|
||||||
|
"groups", "connections", "credentials", "ssh_keys",
|
||||||
|
"themes", "connection_history", "host_keys", "settings",
|
||||||
|
}
|
||||||
|
for _, table := range expectedTables {
|
||||||
|
var name string
|
||||||
|
err := db.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", table,
|
||||||
|
).Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("table %q not found: %v", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
57
internal/plugin/interfaces.go
Normal file
57
internal/plugin/interfaces.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package plugin
|
||||||
|
|
||||||
|
type ProtocolHandler interface {
|
||||||
|
Name() string
|
||||||
|
Connect(config map[string]interface{}) (Session, error)
|
||||||
|
Disconnect(sessionID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session interface {
|
||||||
|
ID() string
|
||||||
|
Protocol() string
|
||||||
|
Write(data []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Importer interface {
|
||||||
|
Name() string
|
||||||
|
FileExtensions() []string
|
||||||
|
Parse(data []byte) (*ImportResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImportResult struct {
|
||||||
|
Groups []ImportGroup `json:"groups"`
|
||||||
|
Connections []ImportConnection `json:"connections"`
|
||||||
|
HostKeys []ImportHostKey `json:"hostKeys"`
|
||||||
|
Theme *ImportTheme `json:"theme,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImportGroup struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ParentName string `json:"parentName,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImportConnection struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
GroupName string `json:"groupName"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImportHostKey struct {
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
KeyType string `json:"keyType"`
|
||||||
|
Fingerprint string `json:"fingerprint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImportTheme struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Foreground string `json:"foreground"`
|
||||||
|
Background string `json:"background"`
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Colors [16]string `json:"colors"`
|
||||||
|
}
|
||||||
47
internal/plugin/registry.go
Normal file
47
internal/plugin/registry.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package plugin
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
protocols map[string]ProtocolHandler
|
||||||
|
importers map[string]Importer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{
|
||||||
|
protocols: make(map[string]ProtocolHandler),
|
||||||
|
importers: make(map[string]Importer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) RegisterProtocol(handler ProtocolHandler) {
|
||||||
|
r.protocols[handler.Name()] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) RegisterImporter(imp Importer) {
|
||||||
|
r.importers[imp.Name()] = imp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) GetProtocol(name string) (ProtocolHandler, error) {
|
||||||
|
h, ok := r.protocols[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("protocol handler %q not registered", name)
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) GetImporter(name string) (Importer, error) {
|
||||||
|
imp, ok := r.importers[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("importer %q not registered", name)
|
||||||
|
}
|
||||||
|
return imp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) ListProtocols() []string {
|
||||||
|
names := make([]string, 0, len(r.protocols))
|
||||||
|
for name := range r.protocols {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
96
internal/session/manager.go
Normal file
96
internal/session/manager.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxSessions = 32
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*SessionInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return &Manager{
|
||||||
|
sessions: make(map[string]*SessionInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if len(m.sessions) >= MaxSessions {
|
||||||
|
return nil, fmt.Errorf("maximum sessions (%d) reached", MaxSessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &SessionInfo{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
ConnectionID: connectionID,
|
||||||
|
Protocol: protocol,
|
||||||
|
State: StateConnecting,
|
||||||
|
TabPosition: len(m.sessions),
|
||||||
|
}
|
||||||
|
m.sessions[s.ID] = s
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Get(id string) (*SessionInfo, bool) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
s, ok := m.sessions[id]
|
||||||
|
return s, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) List() []*SessionInfo {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
list := make([]*SessionInfo, 0, len(m.sessions))
|
||||||
|
for _, s := range m.sessions {
|
||||||
|
list = append(list, s)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetState(id string, state SessionState) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
s, ok := m.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", id)
|
||||||
|
}
|
||||||
|
s.State = state
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Detach(id string) error {
|
||||||
|
return m.SetState(id, StateDetached)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Reattach(id, windowID string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
s, ok := m.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", id)
|
||||||
|
}
|
||||||
|
s.State = StateConnected
|
||||||
|
s.WindowID = windowID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Remove(id string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.sessions, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Count() int {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return len(m.sessions)
|
||||||
|
}
|
||||||
64
internal/session/manager_test.go
Normal file
64
internal/session/manager_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestCreateSession(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
s, err := m.Create(1, "ssh")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error: %v", err)
|
||||||
|
}
|
||||||
|
if s.ID == "" {
|
||||||
|
t.Error("session ID should not be empty")
|
||||||
|
}
|
||||||
|
if s.State != StateConnecting {
|
||||||
|
t.Errorf("State = %q, want %q", s.State, StateConnecting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxSessions(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
for i := 0; i < MaxSessions; i++ {
|
||||||
|
_, err := m.Create(int64(i), "ssh")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error at %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := m.Create(999, "ssh")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Create() should fail at max sessions")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetachReattach(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
s, _ := m.Create(1, "ssh")
|
||||||
|
m.SetState(s.ID, StateConnected)
|
||||||
|
|
||||||
|
if err := m.Detach(s.ID); err != nil {
|
||||||
|
t.Fatalf("Detach() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := m.Get(s.ID)
|
||||||
|
if got.State != StateDetached {
|
||||||
|
t.Errorf("State = %q, want %q", got.State, StateDetached)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Reattach(s.ID, "window-1"); err != nil {
|
||||||
|
t.Fatalf("Reattach() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ = m.Get(s.ID)
|
||||||
|
if got.State != StateConnected {
|
||||||
|
t.Errorf("State = %q, want %q", got.State, StateConnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveSession(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
s, _ := m.Create(1, "ssh")
|
||||||
|
m.Remove(s.ID)
|
||||||
|
if m.Count() != 0 {
|
||||||
|
t.Error("session should have been removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
22
internal/session/session.go
Normal file
22
internal/session/session.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type SessionState string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateConnecting SessionState = "connecting"
|
||||||
|
StateConnected SessionState = "connected"
|
||||||
|
StateDisconnected SessionState = "disconnected"
|
||||||
|
StateDetached SessionState = "detached"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ConnectionID int64 `json:"connectionId"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
State SessionState `json:"state"`
|
||||||
|
WindowID string `json:"windowId"`
|
||||||
|
TabPosition int `json:"tabPosition"`
|
||||||
|
ConnectedAt time.Time `json:"connectedAt"`
|
||||||
|
}
|
||||||
41
internal/settings/service.go
Normal file
41
internal/settings/service.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
type SettingsService struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSettingsService(db *sql.DB) *SettingsService {
|
||||||
|
return &SettingsService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingsService) Get(key string) (string, error) {
|
||||||
|
var value string
|
||||||
|
err := s.db.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingsService) GetDefault(key, defaultValue string) string {
|
||||||
|
val, err := s.Get(key)
|
||||||
|
if err != nil || val == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingsService) Set(key, value string) error {
|
||||||
|
_, err := s.db.Exec(
|
||||||
|
"INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?",
|
||||||
|
key, value, value,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingsService) Delete(key string) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM settings WHERE key = ?", key)
|
||||||
|
return err
|
||||||
|
}
|
||||||
70
internal/settings/service_test.go
Normal file
70
internal/settings/service_test.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestDB(t *testing.T) *SettingsService {
|
||||||
|
t.Helper()
|
||||||
|
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(d); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { d.Close() })
|
||||||
|
return NewSettingsService(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGet(t *testing.T) {
|
||||||
|
s := setupTestDB(t)
|
||||||
|
|
||||||
|
if err := s.Set("theme", "dracula"); err != nil {
|
||||||
|
t.Fatalf("Set() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := s.Get("theme")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
if val != "dracula" {
|
||||||
|
t.Errorf("Get() = %q, want %q", val, "dracula")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMissing(t *testing.T) {
|
||||||
|
s := setupTestDB(t)
|
||||||
|
|
||||||
|
val, err := s.Get("nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
if val != "" {
|
||||||
|
t.Errorf("Get() = %q, want empty string", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetOverwrites(t *testing.T) {
|
||||||
|
s := setupTestDB(t)
|
||||||
|
|
||||||
|
s.Set("key", "value1")
|
||||||
|
s.Set("key", "value2")
|
||||||
|
|
||||||
|
val, _ := s.Get("key")
|
||||||
|
if val != "value2" {
|
||||||
|
t.Errorf("Get() = %q, want %q", val, "value2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWithDefault(t *testing.T) {
|
||||||
|
s := setupTestDB(t)
|
||||||
|
|
||||||
|
val := s.GetDefault("missing", "fallback")
|
||||||
|
if val != "fallback" {
|
||||||
|
t.Errorf("GetDefault() = %q, want %q", val, "fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
238
internal/sftp/service.go
Normal file
238
internal/sftp/service.go
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
package sftp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxEditFileSize = 5 * 1024 * 1024 // 5MB
|
||||||
|
|
||||||
|
// FileEntry represents a file or directory in a remote filesystem.
|
||||||
|
type FileEntry struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
IsDir bool `json:"isDir"`
|
||||||
|
Permissions string `json:"permissions"`
|
||||||
|
ModTime string `json:"modTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFTPService manages SFTP clients keyed by session ID.
|
||||||
|
type SFTPService struct {
|
||||||
|
clients map[string]*sftp.Client
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSFTPService creates a new SFTPService.
|
||||||
|
func NewSFTPService() *SFTPService {
|
||||||
|
return &SFTPService{
|
||||||
|
clients: make(map[string]*sftp.Client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClient stores an SFTP client for a session.
|
||||||
|
func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.clients[sessionID] = client
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveClient removes and closes an SFTP client.
|
||||||
|
func (s *SFTPService) RemoveClient(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
client, ok := s.clients[sessionID]
|
||||||
|
if ok {
|
||||||
|
delete(s.clients, sessionID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if ok && client != nil {
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClient returns the SFTP client for a session or an error if not found.
|
||||||
|
func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
client, ok := s.clients[sessionID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no SFTP client for session %s", sessionID)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortEntries sorts file entries with directories first, then alphabetically by name.
|
||||||
|
func SortEntries(entries []FileEntry) {
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
if entries[i].IsDir != entries[j].IsDir {
|
||||||
|
return entries[i].IsDir
|
||||||
|
}
|
||||||
|
return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry.
|
||||||
|
func fileInfoToEntry(info os.FileInfo, path string) FileEntry {
|
||||||
|
return FileEntry{
|
||||||
|
Name: info.Name(),
|
||||||
|
Path: path,
|
||||||
|
Size: info.Size(),
|
||||||
|
IsDir: info.IsDir(),
|
||||||
|
Permissions: info.Mode().Perm().String(),
|
||||||
|
ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns directory contents sorted (dirs first, then files alphabetically).
|
||||||
|
func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
infos, err := client.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read directory %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]FileEntry, 0, len(infos))
|
||||||
|
for _, info := range infos {
|
||||||
|
entryPath := path
|
||||||
|
if !strings.HasSuffix(entryPath, "/") {
|
||||||
|
entryPath += "/"
|
||||||
|
}
|
||||||
|
entryPath += info.Name()
|
||||||
|
entries = append(entries, fileInfoToEntry(info, entryPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
SortEntries(entries)
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile reads a file (max 5MB). Returns content as string.
|
||||||
|
func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
return "", fmt.Errorf("%s is a directory", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Size() > MaxEditFileSize {
|
||||||
|
return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("open %s: %w", path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFile writes content to a file.
|
||||||
|
func (s *SFTPService) WriteFile(sessionID string, path string, content string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create %s: %w", path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := f.Write([]byte(content)); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mkdir creates a directory.
|
||||||
|
func (s *SFTPService) Mkdir(sessionID string, path string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Mkdir(path); err != nil {
|
||||||
|
return fmt.Errorf("mkdir %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a file or empty directory.
|
||||||
|
func (s *SFTPService) Delete(sessionID string, path string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
if err := client.RemoveDirectory(path); err != nil {
|
||||||
|
return fmt.Errorf("remove directory %s: %w", path, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := client.Remove(path); err != nil {
|
||||||
|
return fmt.Errorf("remove %s: %w", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rename renames/moves a file.
|
||||||
|
func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Rename(oldPath, newPath); err != nil {
|
||||||
|
return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stat returns info about a file/directory.
|
||||||
|
func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := fileInfoToEntry(info, path)
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
119
internal/sftp/service_test.go
Normal file
119
internal/sftp/service_test.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package sftp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSFTPService(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
if svc == nil {
|
||||||
|
t.Fatal("nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.List("nonexistent", "/")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFileWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.ReadFile("nonexistent", "/etc/hosts")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFileWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.WriteFile("nonexistent", "/tmp/test", "data")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMkdirWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Mkdir("nonexistent", "/tmp/newdir")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Delete("nonexistent", "/tmp/file")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenameWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Rename("nonexistent", "/old", "/new")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.Stat("nonexistent", "/tmp")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileEntrySorting(t *testing.T) {
|
||||||
|
// Test that SortEntries puts dirs first, then alpha
|
||||||
|
entries := []FileEntry{
|
||||||
|
{Name: "zebra.txt", IsDir: false},
|
||||||
|
{Name: "alpha", IsDir: true},
|
||||||
|
{Name: "beta.conf", IsDir: false},
|
||||||
|
{Name: "omega", IsDir: true},
|
||||||
|
}
|
||||||
|
SortEntries(entries)
|
||||||
|
if entries[0].Name != "alpha" {
|
||||||
|
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[1].Name != "omega" {
|
||||||
|
t.Errorf("[1] = %s, want omega", entries[1].Name)
|
||||||
|
}
|
||||||
|
if entries[2].Name != "beta.conf" {
|
||||||
|
t.Errorf("[2] = %s, want beta.conf", entries[2].Name)
|
||||||
|
}
|
||||||
|
if entries[3].Name != "zebra.txt" {
|
||||||
|
t.Errorf("[3] = %s, want zebra.txt", entries[3].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortEntriesEmpty(t *testing.T) {
|
||||||
|
entries := []FileEntry{}
|
||||||
|
SortEntries(entries)
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("expected empty slice, got %d entries", len(entries))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortEntriesCaseInsensitive(t *testing.T) {
|
||||||
|
entries := []FileEntry{
|
||||||
|
{Name: "Zebra", IsDir: false},
|
||||||
|
{Name: "alpha", IsDir: false},
|
||||||
|
}
|
||||||
|
SortEntries(entries)
|
||||||
|
if entries[0].Name != "alpha" {
|
||||||
|
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[1].Name != "Zebra" {
|
||||||
|
t.Errorf("[1] = %s, want Zebra", entries[1].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxEditFileSize(t *testing.T) {
|
||||||
|
if MaxEditFileSize != 5*1024*1024 {
|
||||||
|
t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024)
|
||||||
|
}
|
||||||
|
}
|
||||||
259
internal/ssh/service.go
Normal file
259
internal/ssh/service.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OutputHandler is called when data is read from an SSH session's stdout.
|
||||||
|
// In production this will emit Wails events; for testing, a simple callback.
|
||||||
|
type OutputHandler func(sessionID string, data []byte)
|
||||||
|
|
||||||
|
// SSHSession represents an active SSH connection with its PTY shell session.
|
||||||
|
type SSHSession struct {
|
||||||
|
ID string
|
||||||
|
Client *ssh.Client
|
||||||
|
Session *ssh.Session
|
||||||
|
Stdin io.WriteCloser
|
||||||
|
ConnID int64
|
||||||
|
Hostname string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Connected time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHService manages SSH connections and their associated sessions.
|
||||||
|
type SSHService struct {
|
||||||
|
sessions map[string]*SSHSession
|
||||||
|
mu sync.RWMutex
|
||||||
|
db *sql.DB
|
||||||
|
outputHandler OutputHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSHService creates a new SSHService. The outputHandler is called when data
|
||||||
|
// arrives from a session's stdout. Pass nil if output handling is not needed.
|
||||||
|
func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService {
|
||||||
|
return &SSHService{
|
||||||
|
sessions: make(map[string]*SSHSession),
|
||||||
|
db: db,
|
||||||
|
outputHandler: outputHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect dials an SSH server, opens a session with a PTY and shell, and
|
||||||
|
// launches a goroutine to read stdout. Returns the session ID.
|
||||||
|
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
|
||||||
|
addr := fmt.Sprintf("%s:%d", hostname, port)
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: username,
|
||||||
|
Auth: authMethods,
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := ssh.Dial("tcp", addr, config)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("ssh dial %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("new session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modes := ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 1,
|
||||||
|
ssh.TTY_OP_ISPEED: 14400,
|
||||||
|
ssh.TTY_OP_OSPEED: 14400,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("request pty: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
session.Close()
|
||||||
|
client.Close()
|
||||||
|
return "", fmt.Errorf("start shell: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
sshSession := &SSHSession{
|
||||||
|
ID: sessionID,
|
||||||
|
Client: client,
|
||||||
|
Session: session,
|
||||||
|
Stdin: stdin,
|
||||||
|
Hostname: hostname,
|
||||||
|
Port: port,
|
||||||
|
Username: username,
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.sessions[sessionID] = sshSession
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Launch goroutine to read stdout and forward data via the output handler
|
||||||
|
go s.readLoop(sessionID, stdout)
|
||||||
|
|
||||||
|
return sessionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop continuously reads from the session stdout and calls the output
|
||||||
|
// handler with data. It stops when the reader returns an error (typically EOF
|
||||||
|
// when the session closes).
|
||||||
|
func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n > 0 && s.outputHandler != nil {
|
||||||
|
data := make([]byte, n)
|
||||||
|
copy(data, buf[:n])
|
||||||
|
s.outputHandler(sessionID, data)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write sends data to the session's stdin.
|
||||||
|
func (s *SSHService) Write(sessionID string, data string) error {
|
||||||
|
s.mu.RLock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Stdin == nil {
|
||||||
|
return fmt.Errorf("session %s stdin is closed", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := sess.Stdin.Write([]byte(data))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to session %s: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize sends a window-change request to the remote PTY.
|
||||||
|
func (s *SSHService) Resize(sessionID string, cols, rows int) error {
|
||||||
|
s.mu.RLock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Session == nil {
|
||||||
|
return fmt.Errorf("session %s is closed", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sess.Session.WindowChange(rows, cols); err != nil {
|
||||||
|
return fmt.Errorf("resize session %s: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect closes the SSH session and client, and removes it from tracking.
|
||||||
|
func (s *SSHService) Disconnect(sessionID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("session %s not found", sessionID)
|
||||||
|
}
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
sess.mu.Lock()
|
||||||
|
defer sess.mu.Unlock()
|
||||||
|
|
||||||
|
if sess.Stdin != nil {
|
||||||
|
sess.Stdin.Close()
|
||||||
|
}
|
||||||
|
if sess.Session != nil {
|
||||||
|
sess.Session.Close()
|
||||||
|
}
|
||||||
|
if sess.Client != nil {
|
||||||
|
sess.Client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession returns the SSHSession for the given ID, or false if not found.
|
||||||
|
func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
sess, ok := s.sessions[sessionID]
|
||||||
|
return sess, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSessions returns all active SSH sessions.
|
||||||
|
func (s *SSHService) ListSessions() []*SSHSession {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
list := make([]*SSHSession, 0, len(s.sessions))
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
list = append(list, sess)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildPasswordAuth creates an ssh.AuthMethod for password authentication.
|
||||||
|
func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod {
|
||||||
|
return ssh.Password(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key.
|
||||||
|
// If the key is encrypted, pass the passphrase; otherwise pass an empty string.
|
||||||
|
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) {
|
||||||
|
var signer ssh.Signer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if passphrase != "" {
|
||||||
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
|
||||||
|
} else {
|
||||||
|
signer, err = ssh.ParsePrivateKey(privateKey)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.PublicKeys(signer), nil
|
||||||
|
}
|
||||||
148
internal/ssh/service_test.go
Normal file
148
internal/ssh/service_test.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSSHService(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
if svc == nil {
|
||||||
|
t.Fatal("NewSSHService returned nil")
|
||||||
|
}
|
||||||
|
if len(svc.ListSessions()) != 0 {
|
||||||
|
t.Error("new service should have no sessions")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPasswordAuth(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
auth := svc.BuildPasswordAuth("mypassword")
|
||||||
|
if auth == nil {
|
||||||
|
t.Error("BuildPasswordAuth returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKeyAuth(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Generate a test Ed25519 key
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey error: %v", err)
|
||||||
|
}
|
||||||
|
pemBlock, err := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalPrivateKey error: %v", err)
|
||||||
|
}
|
||||||
|
keyBytes := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
auth, err := svc.BuildKeyAuth(keyBytes, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildKeyAuth error: %v", err)
|
||||||
|
}
|
||||||
|
if auth == nil {
|
||||||
|
t.Error("BuildKeyAuth returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKeyAuthInvalidKey(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
_, err := svc.BuildKeyAuth([]byte("not a key"), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("BuildKeyAuth should fail with invalid key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionTracking(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Manually add a session to test tracking
|
||||||
|
svc.mu.Lock()
|
||||||
|
svc.sessions["test-123"] = &SSHSession{
|
||||||
|
ID: "test-123",
|
||||||
|
Hostname: "192.168.1.4",
|
||||||
|
Port: 22,
|
||||||
|
Username: "vstockwell",
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
svc.mu.Unlock()
|
||||||
|
|
||||||
|
s, ok := svc.GetSession("test-123")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("session not found")
|
||||||
|
}
|
||||||
|
if s.Hostname != "192.168.1.4" {
|
||||||
|
t.Errorf("Hostname = %q, want %q", s.Hostname, "192.168.1.4")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions := svc.ListSessions()
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Errorf("ListSessions() = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSessionNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
_, ok := svc.GetSession("nonexistent")
|
||||||
|
if ok {
|
||||||
|
t.Error("GetSession should return false for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Write("nonexistent", "data")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Write should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResizeNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Resize("nonexistent", 80, 24)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Resize should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisconnectNotFound(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
err := svc.Disconnect("nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Disconnect should fail for nonexistent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisconnectRemovesSession(t *testing.T) {
|
||||||
|
svc := NewSSHService(nil, nil)
|
||||||
|
|
||||||
|
// Manually add a session with nil Client/Session/Stdin (no real connection)
|
||||||
|
svc.mu.Lock()
|
||||||
|
svc.sessions["test-dc"] = &SSHSession{
|
||||||
|
ID: "test-dc",
|
||||||
|
Hostname: "10.0.0.1",
|
||||||
|
Port: 22,
|
||||||
|
Username: "admin",
|
||||||
|
Connected: time.Now(),
|
||||||
|
}
|
||||||
|
svc.mu.Unlock()
|
||||||
|
|
||||||
|
if err := svc.Disconnect("test-dc"); err != nil {
|
||||||
|
t.Fatalf("Disconnect error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := svc.GetSession("test-dc")
|
||||||
|
if ok {
|
||||||
|
t.Error("session should be removed after Disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(svc.ListSessions()) != 0 {
|
||||||
|
t.Error("ListSessions should be empty after Disconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/theme/builtins.go
Normal file
94
internal/theme/builtins.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package theme
|
||||||
|
|
||||||
|
type Theme struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Foreground string `json:"foreground"`
|
||||||
|
Background string `json:"background"`
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Black string `json:"black"`
|
||||||
|
Red string `json:"red"`
|
||||||
|
Green string `json:"green"`
|
||||||
|
Yellow string `json:"yellow"`
|
||||||
|
Blue string `json:"blue"`
|
||||||
|
Magenta string `json:"magenta"`
|
||||||
|
Cyan string `json:"cyan"`
|
||||||
|
White string `json:"white"`
|
||||||
|
BrightBlack string `json:"brightBlack"`
|
||||||
|
BrightRed string `json:"brightRed"`
|
||||||
|
BrightGreen string `json:"brightGreen"`
|
||||||
|
BrightYellow string `json:"brightYellow"`
|
||||||
|
BrightBlue string `json:"brightBlue"`
|
||||||
|
BrightMagenta string `json:"brightMagenta"`
|
||||||
|
BrightCyan string `json:"brightCyan"`
|
||||||
|
BrightWhite string `json:"brightWhite"`
|
||||||
|
SelectionBg string `json:"selectionBg,omitempty"`
|
||||||
|
SelectionFg string `json:"selectionFg,omitempty"`
|
||||||
|
IsBuiltin bool `json:"isBuiltin"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var BuiltinThemes = []Theme{
|
||||||
|
{
|
||||||
|
Name: "Dracula", IsBuiltin: true,
|
||||||
|
Foreground: "#f8f8f2", Background: "#282a36", Cursor: "#f8f8f2",
|
||||||
|
Black: "#21222c", Red: "#ff5555", Green: "#50fa7b", Yellow: "#f1fa8c",
|
||||||
|
Blue: "#bd93f9", Magenta: "#ff79c6", Cyan: "#8be9fd", White: "#f8f8f2",
|
||||||
|
BrightBlack: "#6272a4", BrightRed: "#ff6e6e", BrightGreen: "#69ff94",
|
||||||
|
BrightYellow: "#ffffa5", BrightBlue: "#d6acff", BrightMagenta: "#ff92df",
|
||||||
|
BrightCyan: "#a4ffff", BrightWhite: "#ffffff",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Nord", IsBuiltin: true,
|
||||||
|
Foreground: "#d8dee9", Background: "#2e3440", Cursor: "#d8dee9",
|
||||||
|
Black: "#3b4252", Red: "#bf616a", Green: "#a3be8c", Yellow: "#ebcb8b",
|
||||||
|
Blue: "#81a1c1", Magenta: "#b48ead", Cyan: "#88c0d0", White: "#e5e9f0",
|
||||||
|
BrightBlack: "#4c566a", BrightRed: "#bf616a", BrightGreen: "#a3be8c",
|
||||||
|
BrightYellow: "#ebcb8b", BrightBlue: "#81a1c1", BrightMagenta: "#b48ead",
|
||||||
|
BrightCyan: "#8fbcbb", BrightWhite: "#eceff4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Monokai", IsBuiltin: true,
|
||||||
|
Foreground: "#f8f8f2", Background: "#272822", Cursor: "#f8f8f0",
|
||||||
|
Black: "#272822", Red: "#f92672", Green: "#a6e22e", Yellow: "#f4bf75",
|
||||||
|
Blue: "#66d9ef", Magenta: "#ae81ff", Cyan: "#a1efe4", White: "#f8f8f2",
|
||||||
|
BrightBlack: "#75715e", BrightRed: "#f92672", BrightGreen: "#a6e22e",
|
||||||
|
BrightYellow: "#f4bf75", BrightBlue: "#66d9ef", BrightMagenta: "#ae81ff",
|
||||||
|
BrightCyan: "#a1efe4", BrightWhite: "#f9f8f5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "One Dark", IsBuiltin: true,
|
||||||
|
Foreground: "#abb2bf", Background: "#282c34", Cursor: "#528bff",
|
||||||
|
Black: "#282c34", Red: "#e06c75", Green: "#98c379", Yellow: "#e5c07b",
|
||||||
|
Blue: "#61afef", Magenta: "#c678dd", Cyan: "#56b6c2", White: "#abb2bf",
|
||||||
|
BrightBlack: "#545862", BrightRed: "#e06c75", BrightGreen: "#98c379",
|
||||||
|
BrightYellow: "#e5c07b", BrightBlue: "#61afef", BrightMagenta: "#c678dd",
|
||||||
|
BrightCyan: "#56b6c2", BrightWhite: "#c8ccd4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Solarized Dark", IsBuiltin: true,
|
||||||
|
Foreground: "#839496", Background: "#002b36", Cursor: "#839496",
|
||||||
|
Black: "#073642", Red: "#dc322f", Green: "#859900", Yellow: "#b58900",
|
||||||
|
Blue: "#268bd2", Magenta: "#d33682", Cyan: "#2aa198", White: "#eee8d5",
|
||||||
|
BrightBlack: "#002b36", BrightRed: "#cb4b16", BrightGreen: "#586e75",
|
||||||
|
BrightYellow: "#657b83", BrightBlue: "#839496", BrightMagenta: "#6c71c4",
|
||||||
|
BrightCyan: "#93a1a1", BrightWhite: "#fdf6e3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Gruvbox Dark", IsBuiltin: true,
|
||||||
|
Foreground: "#ebdbb2", Background: "#282828", Cursor: "#ebdbb2",
|
||||||
|
Black: "#282828", Red: "#cc241d", Green: "#98971a", Yellow: "#d79921",
|
||||||
|
Blue: "#458588", Magenta: "#b16286", Cyan: "#689d6a", White: "#a89984",
|
||||||
|
BrightBlack: "#928374", BrightRed: "#fb4934", BrightGreen: "#b8bb26",
|
||||||
|
BrightYellow: "#fabd2f", BrightBlue: "#83a598", BrightMagenta: "#d3869b",
|
||||||
|
BrightCyan: "#8ec07c", BrightWhite: "#ebdbb2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "MobaXTerm Classic", IsBuiltin: true,
|
||||||
|
Foreground: "#ececec", Background: "#242424", Cursor: "#b4b4c0",
|
||||||
|
Black: "#000000", Red: "#aa4244", Green: "#7e8d53", Yellow: "#e4b46d",
|
||||||
|
Blue: "#6e9aba", Magenta: "#9e5085", Cyan: "#80d5cf", White: "#cccccc",
|
||||||
|
BrightBlack: "#808080", BrightRed: "#cc7b7d", BrightGreen: "#a5b17c",
|
||||||
|
BrightYellow: "#ecc995", BrightBlue: "#96b6cd", BrightMagenta: "#c083ac",
|
||||||
|
BrightCyan: "#a9e2de", BrightWhite: "#cccccc",
|
||||||
|
},
|
||||||
|
}
|
||||||
85
internal/theme/service.go
Normal file
85
internal/theme/service.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package theme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ThemeService struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewThemeService(db *sql.DB) *ThemeService {
|
||||||
|
return &ThemeService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ThemeService) SeedBuiltins() error {
|
||||||
|
for _, t := range BuiltinThemes {
|
||||||
|
_, err := s.db.Exec(
|
||||||
|
`INSERT OR IGNORE INTO themes (name, foreground, background, cursor,
|
||||||
|
black, red, green, yellow, blue, magenta, cyan, white,
|
||||||
|
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
|
||||||
|
bright_magenta, bright_cyan, bright_white, selection_bg, selection_fg, is_builtin)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,1)`,
|
||||||
|
t.Name, t.Foreground, t.Background, t.Cursor,
|
||||||
|
t.Black, t.Red, t.Green, t.Yellow, t.Blue, t.Magenta, t.Cyan, t.White,
|
||||||
|
t.BrightBlack, t.BrightRed, t.BrightGreen, t.BrightYellow, t.BrightBlue,
|
||||||
|
t.BrightMagenta, t.BrightCyan, t.BrightWhite, t.SelectionBg, t.SelectionFg,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("seed theme %s: %w", t.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ThemeService) List() ([]Theme, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, foreground, background, cursor,
|
||||||
|
black, red, green, yellow, blue, magenta, cyan, white,
|
||||||
|
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
|
||||||
|
bright_magenta, bright_cyan, bright_white,
|
||||||
|
COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin
|
||||||
|
FROM themes ORDER BY is_builtin DESC, name`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var themes []Theme
|
||||||
|
for rows.Next() {
|
||||||
|
var t Theme
|
||||||
|
if err := rows.Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor,
|
||||||
|
&t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White,
|
||||||
|
&t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue,
|
||||||
|
&t.BrightMagenta, &t.BrightCyan, &t.BrightWhite,
|
||||||
|
&t.SelectionBg, &t.SelectionFg, &t.IsBuiltin); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
themes = append(themes, t)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return themes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ThemeService) GetByName(name string) (*Theme, error) {
|
||||||
|
var t Theme
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
`SELECT id, name, foreground, background, cursor,
|
||||||
|
black, red, green, yellow, blue, magenta, cyan, white,
|
||||||
|
bright_black, bright_red, bright_green, bright_yellow, bright_blue,
|
||||||
|
bright_magenta, bright_cyan, bright_white,
|
||||||
|
COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin
|
||||||
|
FROM themes WHERE name = ?`, name,
|
||||||
|
).Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor,
|
||||||
|
&t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White,
|
||||||
|
&t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue,
|
||||||
|
&t.BrightMagenta, &t.BrightCyan, &t.BrightWhite,
|
||||||
|
&t.SelectionBg, &t.SelectionFg, &t.IsBuiltin)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get theme %s: %w", name, err)
|
||||||
|
}
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
60
internal/theme/service_test.go
Normal file
60
internal/theme/service_test.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package theme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestDB(t *testing.T) *ThemeService {
|
||||||
|
t.Helper()
|
||||||
|
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(d); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { d.Close() })
|
||||||
|
return NewThemeService(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSeedBuiltins(t *testing.T) {
|
||||||
|
svc := setupTestDB(t)
|
||||||
|
if err := svc.SeedBuiltins(); err != nil {
|
||||||
|
t.Fatalf("SeedBuiltins() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
themes, err := svc.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(themes) != len(BuiltinThemes) {
|
||||||
|
t.Errorf("len(themes) = %d, want %d", len(themes), len(BuiltinThemes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSeedBuiltinsIdempotent(t *testing.T) {
|
||||||
|
svc := setupTestDB(t)
|
||||||
|
svc.SeedBuiltins()
|
||||||
|
svc.SeedBuiltins()
|
||||||
|
|
||||||
|
themes, _ := svc.List()
|
||||||
|
if len(themes) != len(BuiltinThemes) {
|
||||||
|
t.Errorf("len(themes) = %d after double seed, want %d", len(themes), len(BuiltinThemes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetByName(t *testing.T) {
|
||||||
|
svc := setupTestDB(t)
|
||||||
|
svc.SeedBuiltins()
|
||||||
|
|
||||||
|
theme, err := svc.GetByName("Dracula")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByName() error: %v", err)
|
||||||
|
}
|
||||||
|
if theme.Background != "#282a36" {
|
||||||
|
t.Errorf("Background = %q, want %q", theme.Background, "#282a36")
|
||||||
|
}
|
||||||
|
}
|
||||||
95
internal/vault/service.go
Normal file
95
internal/vault/service.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package vault
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
argonTime = 3
|
||||||
|
argonMemory = 64 * 1024
|
||||||
|
argonThreads = 4
|
||||||
|
argonKeyLen = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
func DeriveKey(password string, salt []byte) []byte {
|
||||||
|
return argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateSalt() ([]byte, error) {
|
||||||
|
salt := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(salt); err != nil {
|
||||||
|
return nil, fmt.Errorf("generate salt: %w", err)
|
||||||
|
}
|
||||||
|
return salt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type VaultService struct {
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVaultService(key []byte) *VaultService {
|
||||||
|
return &VaultService{key: key}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VaultService) Encrypt(plaintext string) (string, error) {
|
||||||
|
block, err := aes.NewCipher(v.key)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iv := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := rand.Read(iv); err != nil {
|
||||||
|
return "", fmt.Errorf("generate IV: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sealed := gcm.Seal(nil, iv, []byte(plaintext), nil)
|
||||||
|
|
||||||
|
return fmt.Sprintf("v1:%s:%s", hex.EncodeToString(iv), hex.EncodeToString(sealed)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VaultService) Decrypt(encrypted string) (string, error) {
|
||||||
|
parts := strings.SplitN(encrypted, ":", 3)
|
||||||
|
if len(parts) != 3 || parts[0] != "v1" {
|
||||||
|
return "", errors.New("invalid encrypted format: expected v1:{iv}:{sealed}")
|
||||||
|
}
|
||||||
|
|
||||||
|
iv, err := hex.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode IV: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sealed, err := hex.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode sealed data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(v.key)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := gcm.Open(nil, iv, sealed, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decrypt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plaintext), nil
|
||||||
|
}
|
||||||
104
internal/vault/service_test.go
Normal file
104
internal/vault/service_test.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package vault
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeriveKeyConsistent(t *testing.T) {
|
||||||
|
salt := []byte("test-salt-exactly-32-bytes-long!")
|
||||||
|
key1 := DeriveKey("mypassword", salt)
|
||||||
|
key2 := DeriveKey("mypassword", salt)
|
||||||
|
|
||||||
|
if len(key1) != 32 {
|
||||||
|
t.Errorf("key length = %d, want 32", len(key1))
|
||||||
|
}
|
||||||
|
if string(key1) != string(key2) {
|
||||||
|
t.Error("same password+salt produced different keys")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveKeyDifferentPasswords(t *testing.T) {
|
||||||
|
salt := []byte("test-salt-exactly-32-bytes-long!")
|
||||||
|
key1 := DeriveKey("password1", salt)
|
||||||
|
key2 := DeriveKey("password2", salt)
|
||||||
|
|
||||||
|
if string(key1) == string(key2) {
|
||||||
|
t.Error("different passwords produced same key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptRoundTrip(t *testing.T) {
|
||||||
|
key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!"))
|
||||||
|
vs := NewVaultService(key)
|
||||||
|
|
||||||
|
plaintext := "super-secret-ssh-key-data"
|
||||||
|
encrypted, err := vs.Encrypt(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encrypt() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(encrypted, "v1:") {
|
||||||
|
t.Errorf("encrypted does not start with v1: prefix: %q", encrypted[:10])
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := vs.Decrypt(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decrypt() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decrypted != plaintext {
|
||||||
|
t.Errorf("Decrypt() = %q, want %q", decrypted, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
|
||||||
|
key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!"))
|
||||||
|
vs := NewVaultService(key)
|
||||||
|
|
||||||
|
enc1, _ := vs.Encrypt("same-data")
|
||||||
|
enc2, _ := vs.Encrypt("same-data")
|
||||||
|
|
||||||
|
if enc1 == enc2 {
|
||||||
|
t.Error("two encryptions of same data produced identical ciphertext (IV reuse)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWrongKey(t *testing.T) {
|
||||||
|
key1 := DeriveKey("password1", []byte("test-salt-exactly-32-bytes-long!"))
|
||||||
|
key2 := DeriveKey("password2", []byte("test-salt-exactly-32-bytes-long!"))
|
||||||
|
|
||||||
|
vs1 := NewVaultService(key1)
|
||||||
|
vs2 := NewVaultService(key2)
|
||||||
|
|
||||||
|
encrypted, _ := vs1.Encrypt("secret")
|
||||||
|
_, err := vs2.Decrypt(encrypted)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Decrypt() with wrong key should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptInvalidFormat(t *testing.T) {
|
||||||
|
key := DeriveKey("test", []byte("test-salt-exactly-32-bytes-long!"))
|
||||||
|
vs := NewVaultService(key)
|
||||||
|
|
||||||
|
_, err := vs.Decrypt("not-valid-format")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Decrypt() with invalid format should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSalt(t *testing.T) {
|
||||||
|
salt1, err := GenerateSalt()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateSalt() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(salt1) != 32 {
|
||||||
|
t.Errorf("salt length = %d, want 32", len(salt1))
|
||||||
|
}
|
||||||
|
|
||||||
|
salt2, _ := GenerateSalt()
|
||||||
|
if string(salt1) == string(salt2) {
|
||||||
|
t.Error("two calls to GenerateSalt produced identical salt")
|
||||||
|
}
|
||||||
|
}
|
||||||
51
main.go
Normal file
51
main.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
wraithapp "github.com/vstockwell/wraith/internal/app"
|
||||||
|
"github.com/wailsapp/wails/v3/pkg/application"
|
||||||
|
)
|
||||||
|
|
||||||
|
// version is set at build time via -ldflags "-X main.version=..."
|
||||||
|
var version = "dev"
|
||||||
|
|
||||||
|
//go:embed all:frontend/dist
|
||||||
|
var assets embed.FS
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
slog.Info("starting Wraith")
|
||||||
|
|
||||||
|
wraith, err := wraithapp.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to initialize: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app := application.New(application.Options{
|
||||||
|
Name: "Wraith",
|
||||||
|
Description: "SSH + RDP + SFTP Desktop Client",
|
||||||
|
Services: []application.Service{
|
||||||
|
application.NewService(wraith),
|
||||||
|
application.NewService(wraith.Connections),
|
||||||
|
application.NewService(wraith.Themes),
|
||||||
|
application.NewService(wraith.Settings),
|
||||||
|
},
|
||||||
|
Assets: application.AssetOptions{
|
||||||
|
Handler: application.BundledAssetFileServer(assets),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Window.NewWithOptions(application.WebviewWindowOptions{
|
||||||
|
Title: "Wraith",
|
||||||
|
Width: 1400,
|
||||||
|
Height: 900,
|
||||||
|
URL: "/",
|
||||||
|
BackgroundColour: application.NewRGBA(13, 17, 23, 255),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := app.Run(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user