diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..16d8ffc --- /dev/null +++ b/go.mod @@ -0,0 +1,58 @@ +module github.com/vstockwell/wraith + +go 1.26.1 + +require ( + github.com/google/uuid v1.6.0 + github.com/pkg/sftp v1.13.10 + github.com/wailsapp/wails/v3 v3.0.0-alpha.74 + golang.org/x/crypto v0.49.0 + modernc.org/sqlite v1.46.2 +) + +require ( + dario.cat/mergo v1.0.2 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/ProtonMail/go-crypto v1.3.0 // indirect + github.com/adrg/xdg v0.5.3 // indirect + github.com/bep/debounce v1.2.1 // indirect + github.com/cloudflare/circl v1.6.3 // indirect + github.com/coder/websocket v1.8.14 // indirect + github.com/cyphar/filepath-securejoin v0.6.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/ebitengine/purego v0.9.1 // indirect + github.com/emirpasic/gods v1.18.1 // indirect + github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect + github.com/go-git/go-billy/v5 v5.7.0 // indirect + github.com/go-git/go-git/v5 v5.16.4 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect + github.com/godbus/dbus/v5 v5.2.2 // indirect + github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect + github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect + github.com/kevinburke/ssh_config v1.4.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kr/fs v0.1.0 // indirect + github.com/leaanthony/go-ansi-parser v1.6.1 // indirect + github.com/leaanthony/u v1.1.1 // indirect + github.com/lmittmann/tint v1.1.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/pjbgf/sha1cd v0.5.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/samber/lo v1.52.0 // indirect + github.com/sergi/go-diff v1.4.0 // indirect + github.com/skeema/knownhosts v1.3.2 // indirect + github.com/wailsapp/go-webview2 v1.0.23 // indirect + github.com/xanzy/ssh-agent v0.3.3 // indirect + golang.org/x/net v0.51.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + gopkg.in/warnings.v0 v0.1.2 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0868101 --- /dev/null +++ b/go.sum @@ -0,0 +1,197 @@ +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= +github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= +github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= +github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= +github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= +github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= +github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE= +github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= +github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= +github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= +github.com/go-git/go-billy/v5 v5.7.0 h1:83lBUJhGWhYp0ngzCMSgllhUSuoHP1iEWYjsPl9nwqM= +github.com/go-git/go-billy/v5 v5.7.0/go.mod h1:/1IUejTKH8xipsAcdfcSAlUlo2J7lkYV8GTKxAT/L3E= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= +github.com/go-git/go-git/v5 v5.16.4 h1:7ajIEZHZJULcyJebDLo99bGgS0jRrOxzZG4uCk2Yb2Y= +github.com/go-git/go-git/v5 v5.16.4/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= +github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= +github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= +github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= +github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 h1:njuLRcjAuMKr7kI3D85AXWkw6/+v9PwtV6M6o11sWHQ= +github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs= +github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ= +github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed1YDKpEz01A= +github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU= +github.com/leaanthony/u v1.1.1 h1:TUFjwDGlNX+WuwVEzDqQwC2lOv0P4uhTQw7CMFdiK7M= +github.com/leaanthony/u v1.1.1/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI= +github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w= +github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= +github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= +github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ= +github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= +github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= +github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= +github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= +github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= +github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/skeema/knownhosts v1.3.2 h1:EDL9mgf4NzwMXCTfaxSD/o/a5fxDw/xL9nkU28JjdBg= +github.com/skeema/knownhosts v1.3.2/go.mod h1:bEg3iQAuw+jyiw+484wwFJoKSLwcfd7fqRy+N0QTiow= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/wailsapp/go-webview2 v1.0.23 h1:jmv8qhz1lHibCc79bMM/a/FqOnnzOGEisLav+a0b9P0= +github.com/wailsapp/go-webview2 v1.0.23/go.mod h1:qJmWAmAmaniuKGZPWwne+uor3AHMB5PFhqiK0Bbj8kc= +github.com/wailsapp/wails/v3 v3.0.0-alpha.74 h1:wRm1EiDQtxDisXk46NtpiBH90STwfKp36NrTDwOEdxw= +github.com/wailsapp/wails/v3 v3.0.0-alpha.74/go.mod h1:4saK4A4K9970X+X7RkMwP2lyGbLogcUz54wVeq4C/V8= +github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= +github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200810151505-1b9f1253b3ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE= +modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000..aac1439 --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,167 @@ +package app + +import ( + "database/sql" + "encoding/hex" + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/vstockwell/wraith/internal/connections" + "github.com/vstockwell/wraith/internal/db" + "github.com/vstockwell/wraith/internal/plugin" + "github.com/vstockwell/wraith/internal/session" + "github.com/vstockwell/wraith/internal/settings" + "github.com/vstockwell/wraith/internal/theme" + "github.com/vstockwell/wraith/internal/vault" +) + +// WraithApp is the main application struct that wires together all services +// and exposes vault management methods to the frontend via Wails bindings. +type WraithApp struct { + db *sql.DB + Vault *vault.VaultService + Settings *settings.SettingsService + Connections *connections.ConnectionService + Themes *theme.ThemeService + Sessions *session.Manager + Plugins *plugin.Registry + unlocked bool +} + +// New creates and initializes the WraithApp, opening the database, running +// migrations, creating all services, and seeding built-in themes. +func New() (*WraithApp, error) { + dataDir := dataDirectory() + dbPath := filepath.Join(dataDir, "wraith.db") + + slog.Info("opening database", "path", dbPath) + database, err := db.Open(dbPath) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + + if err := db.Migrate(database); err != nil { + return nil, fmt.Errorf("run migrations: %w", err) + } + + settingsSvc := settings.NewSettingsService(database) + connSvc := connections.NewConnectionService(database) + themeSvc := theme.NewThemeService(database) + sessionMgr := session.NewManager() + pluginReg := plugin.NewRegistry() + + // Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent) + if err := themeSvc.SeedBuiltins(); err != nil { + slog.Warn("failed to seed themes", "error", err) + } + + return &WraithApp{ + db: database, + Settings: settingsSvc, + Connections: connSvc, + Themes: themeSvc, + Sessions: sessionMgr, + Plugins: pluginReg, + }, nil +} + +// dataDirectory returns the path where Wraith stores its data. +// On Windows with APPDATA set, it uses %APPDATA%\Wraith. +// On macOS/Linux with XDG_DATA_HOME or HOME, it uses the appropriate path. +// Falls back to the current working directory for development. +func dataDirectory() string { + // Windows + if appData := os.Getenv("APPDATA"); appData != "" { + return filepath.Join(appData, "Wraith") + } + + // macOS / Linux: use XDG_DATA_HOME or fallback to ~/.local/share + if home, err := os.UserHomeDir(); err == nil { + if xdg := os.Getenv("XDG_DATA_HOME"); xdg != "" { + return filepath.Join(xdg, "wraith") + } + return filepath.Join(home, ".local", "share", "wraith") + } + + // Dev fallback + return "." +} + +// IsFirstRun checks whether the vault has been set up by looking for vault_salt in settings. +func (a *WraithApp) IsFirstRun() bool { + salt, _ := a.Settings.Get("vault_salt") + return salt == "" +} + +// CreateVault sets up the vault with a master password. It generates a salt, +// derives an encryption key, and stores the salt and a check value in settings. +func (a *WraithApp) CreateVault(password string) error { + salt, err := vault.GenerateSalt() + if err != nil { + return fmt.Errorf("generate salt: %w", err) + } + + key := vault.DeriveKey(password, salt) + a.Vault = vault.NewVaultService(key) + + // Store salt as hex in settings + if err := a.Settings.Set("vault_salt", hex.EncodeToString(salt)); err != nil { + return fmt.Errorf("store salt: %w", err) + } + + // Encrypt a known check value — used to verify the password on unlock + check, err := a.Vault.Encrypt("wraith-vault-check") + if err != nil { + return fmt.Errorf("encrypt check value: %w", err) + } + if err := a.Settings.Set("vault_check", check); err != nil { + return fmt.Errorf("store check value: %w", err) + } + + a.unlocked = true + slog.Info("vault created successfully") + return nil +} + +// Unlock verifies the master password against the stored check value and +// initializes the vault service for decryption. +func (a *WraithApp) Unlock(password string) error { + saltHex, err := a.Settings.Get("vault_salt") + if err != nil || saltHex == "" { + return fmt.Errorf("vault not set up — call CreateVault first") + } + + salt, err := hex.DecodeString(saltHex) + if err != nil { + return fmt.Errorf("decode salt: %w", err) + } + + key := vault.DeriveKey(password, salt) + vs := vault.NewVaultService(key) + + // Verify by decrypting the stored check value + checkEncrypted, err := a.Settings.Get("vault_check") + if err != nil || checkEncrypted == "" { + return fmt.Errorf("vault check value missing") + } + + checkPlain, err := vs.Decrypt(checkEncrypted) + if err != nil { + return fmt.Errorf("incorrect master password") + } + if checkPlain != "wraith-vault-check" { + return fmt.Errorf("incorrect master password") + } + + a.Vault = vs + a.unlocked = true + slog.Info("vault unlocked successfully") + return nil +} + +// IsUnlocked returns whether the vault is currently unlocked. +func (a *WraithApp) IsUnlocked() bool { + return a.unlocked +} diff --git a/internal/connections/search.go b/internal/connections/search.go new file mode 100644 index 0000000..74fd03c --- /dev/null +++ b/internal/connections/search.go @@ -0,0 +1,40 @@ +package connections + +import "fmt" + +func (s *ConnectionService) Search(query string) ([]Connection, error) { + like := "%" + query + "%" + rows, err := s.db.Query( + `SELECT id, name, hostname, port, protocol, group_id, credential_id, + COALESCE(color,''), tags, COALESCE(notes,''), COALESCE(options,'{}'), + sort_order, last_connected, created_at, updated_at + FROM connections + WHERE name LIKE ? COLLATE NOCASE + OR hostname LIKE ? COLLATE NOCASE + OR tags LIKE ? COLLATE NOCASE + OR notes LIKE ? COLLATE NOCASE + ORDER BY last_connected DESC NULLS LAST, name`, + like, like, like, like, + ) + if err != nil { + return nil, fmt.Errorf("search connections: %w", err) + } + defer rows.Close() + return scanConnections(rows) +} + +func (s *ConnectionService) FilterByTag(tag string) ([]Connection, error) { + rows, err := s.db.Query( + `SELECT c.id, c.name, c.hostname, c.port, c.protocol, c.group_id, c.credential_id, + COALESCE(c.color,''), c.tags, COALESCE(c.notes,''), COALESCE(c.options,'{}'), + c.sort_order, c.last_connected, c.created_at, c.updated_at + FROM connections c, json_each(c.tags) AS t + WHERE t.value = ? + ORDER BY c.name`, tag, + ) + if err != nil { + return nil, fmt.Errorf("filter by tag: %w", err) + } + defer rows.Close() + return scanConnections(rows) +} diff --git a/internal/connections/search_test.go b/internal/connections/search_test.go new file mode 100644 index 0000000..b7acf1f --- /dev/null +++ b/internal/connections/search_test.go @@ -0,0 +1,53 @@ +package connections + +import "testing" + +func TestSearchByName(t *testing.T) { + svc := setupTestService(t) + svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"}) + svc.CreateConnection(CreateConnectionInput{Name: "Docker", Hostname: "155.254.29.221", Port: 22, Protocol: "ssh"}) + + results, err := svc.Search("asg") + if err != nil { + t.Fatalf("Search() error: %v", err) + } + if len(results) != 1 { + t.Fatalf("len(results) = %d, want 1", len(results)) + } + if results[0].Name != "Asgard" { + t.Errorf("Name = %q, want %q", results[0].Name, "Asgard") + } +} + +func TestSearchByHostname(t *testing.T) { + svc := setupTestService(t) + svc.CreateConnection(CreateConnectionInput{Name: "Asgard", Hostname: "192.168.1.4", Port: 22, Protocol: "ssh"}) + + results, _ := svc.Search("192.168") + if len(results) != 1 { + t.Errorf("len(results) = %d, want 1", len(results)) + } +} + +func TestSearchByTag(t *testing.T) { + svc := setupTestService(t) + svc.CreateConnection(CreateConnectionInput{Name: "ProdServer", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}}) + svc.CreateConnection(CreateConnectionInput{Name: "DevServer", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev", "Linux"}}) + + results, _ := svc.Search("Prod") + if len(results) != 1 { + t.Errorf("len(results) = %d, want 1", len(results)) + } +} + +func TestFilterByTag(t *testing.T) { + svc := setupTestService(t) + svc.CreateConnection(CreateConnectionInput{Name: "A", Hostname: "10.0.0.1", Port: 22, Protocol: "ssh", Tags: []string{"Prod"}}) + svc.CreateConnection(CreateConnectionInput{Name: "B", Hostname: "10.0.0.2", Port: 22, Protocol: "ssh", Tags: []string{"Dev"}}) + svc.CreateConnection(CreateConnectionInput{Name: "C", Hostname: "10.0.0.3", Port: 22, Protocol: "ssh", Tags: []string{"Prod", "Linux"}}) + + results, _ := svc.FilterByTag("Prod") + if len(results) != 2 { + t.Errorf("len(results) = %d, want 2", len(results)) + } +} diff --git a/internal/connections/service.go b/internal/connections/service.go new file mode 100644 index 0000000..5ca5f23 --- /dev/null +++ b/internal/connections/service.go @@ -0,0 +1,361 @@ +package connections + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type Group struct { + ID int64 `json:"id"` + Name string `json:"name"` + ParentID *int64 `json:"parentId"` + SortOrder int `json:"sortOrder"` + Icon string `json:"icon"` + CreatedAt time.Time `json:"createdAt"` + Children []Group `json:"children,omitempty"` +} + +type Connection struct { + ID int64 `json:"id"` + Name string `json:"name"` + Hostname string `json:"hostname"` + Port int `json:"port"` + Protocol string `json:"protocol"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color string `json:"color"` + Tags []string `json:"tags"` + Notes string `json:"notes"` + Options string `json:"options"` + SortOrder int `json:"sortOrder"` + LastConnected *time.Time `json:"lastConnected"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type CreateConnectionInput struct { + Name string `json:"name"` + Hostname string `json:"hostname"` + Port int `json:"port"` + Protocol string `json:"protocol"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color string `json:"color"` + Tags []string `json:"tags"` + Notes string `json:"notes"` + Options string `json:"options"` +} + +type UpdateConnectionInput struct { + Name *string `json:"name"` + Hostname *string `json:"hostname"` + Port *int `json:"port"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color *string `json:"color"` + Tags []string `json:"tags"` + Notes *string `json:"notes"` + Options *string `json:"options"` +} + +type ConnectionService struct { + db *sql.DB +} + +func NewConnectionService(db *sql.DB) *ConnectionService { + return &ConnectionService{db: db} +} + +// ---------- Group CRUD ---------- + +func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) { + result, err := s.db.Exec( + "INSERT INTO groups (name, parent_id) VALUES (?, ?)", + name, parentID, + ) + if err != nil { + return nil, fmt.Errorf("create group: %w", err) + } + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get group id: %w", err) + } + + var g Group + var icon sql.NullString + err = s.db.QueryRow( + "SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id, + ).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get created group: %w", err) + } + if icon.Valid { + g.Icon = icon.String + } + return &g, nil +} + +func (s *ConnectionService) ListGroups() ([]Group, error) { + rows, err := s.db.Query( + "SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name", + ) + if err != nil { + return nil, fmt.Errorf("list groups: %w", err) + } + defer rows.Close() + + groupMap := make(map[int64]*Group) + var allGroups []*Group + + for rows.Next() { + var g Group + var icon sql.NullString + if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil { + return nil, fmt.Errorf("scan group: %w", err) + } + if icon.Valid { + g.Icon = icon.String + } + groupMap[g.ID] = &g + allGroups = append(allGroups, &g) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate groups: %w", err) + } + + // Build tree: attach children to parents, collect roots + var roots []Group + for _, g := range allGroups { + if g.ParentID != nil { + if parent, ok := groupMap[*g.ParentID]; ok { + parent.Children = append(parent.Children, *g) + } + } else { + roots = append(roots, *g) + } + } + + // Re-attach children to root copies (since we copied into roots) + for i := range roots { + if orig, ok := groupMap[roots[i].ID]; ok { + roots[i].Children = orig.Children + } + } + + if roots == nil { + roots = []Group{} + } + return roots, nil +} + +func (s *ConnectionService) DeleteGroup(id int64) error { + _, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete group: %w", err) + } + return nil +} + +// ---------- Connection CRUD ---------- + +func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) { + tags, err := json.Marshal(input.Tags) + if err != nil { + return nil, fmt.Errorf("marshal tags: %w", err) + } + if input.Tags == nil { + tags = []byte("[]") + } + + options := input.Options + if options == "" { + options = "{}" + } + + result, err := s.db.Exec( + `INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.Name, input.Hostname, input.Port, input.Protocol, + input.GroupID, input.CredentialID, input.Color, + string(tags), input.Notes, options, + ) + if err != nil { + return nil, fmt.Errorf("create connection: %w", err) + } + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get connection id: %w", err) + } + return s.GetConnection(id) +} + +func (s *ConnectionService) GetConnection(id int64) (*Connection, error) { + row := s.db.QueryRow( + `SELECT id, name, hostname, port, protocol, group_id, credential_id, + color, tags, notes, options, sort_order, last_connected, created_at, updated_at + FROM connections WHERE id = ?`, id, + ) + + var c Connection + var tagsJSON string + var color, notes, options sql.NullString + var lastConnected sql.NullTime + + err := row.Scan( + &c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol, + &c.GroupID, &c.CredentialID, + &color, &tagsJSON, ¬es, &options, + &c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("get connection: %w", err) + } + + if color.Valid { + c.Color = color.String + } + if notes.Valid { + c.Notes = notes.String + } + if options.Valid { + c.Options = options.String + } + if lastConnected.Valid { + c.LastConnected = &lastConnected.Time + } + + if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil { + c.Tags = []string{} + } + if c.Tags == nil { + c.Tags = []string{} + } + + return &c, nil +} + +func (s *ConnectionService) ListConnections() ([]Connection, error) { + rows, err := s.db.Query( + `SELECT id, name, hostname, port, protocol, group_id, credential_id, + color, tags, notes, options, sort_order, last_connected, created_at, updated_at + FROM connections ORDER BY sort_order, name`, + ) + if err != nil { + return nil, fmt.Errorf("list connections: %w", err) + } + defer rows.Close() + + return scanConnections(rows) +} + +// scanConnections is a shared helper used by ListConnections and (later) Search. +func scanConnections(rows *sql.Rows) ([]Connection, error) { + var conns []Connection + + for rows.Next() { + var c Connection + var tagsJSON string + var color, notes, options sql.NullString + var lastConnected sql.NullTime + + if err := rows.Scan( + &c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol, + &c.GroupID, &c.CredentialID, + &color, &tagsJSON, ¬es, &options, + &c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan connection: %w", err) + } + + if color.Valid { + c.Color = color.String + } + if notes.Valid { + c.Notes = notes.String + } + if options.Valid { + c.Options = options.String + } + if lastConnected.Valid { + c.LastConnected = &lastConnected.Time + } + + if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil { + c.Tags = []string{} + } + if c.Tags == nil { + c.Tags = []string{} + } + + conns = append(conns, c) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate connections: %w", err) + } + + if conns == nil { + conns = []Connection{} + } + return conns, nil +} + +func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) { + setClauses := []string{"updated_at = CURRENT_TIMESTAMP"} + args := []interface{}{} + + if input.Name != nil { + setClauses = append(setClauses, "name = ?") + args = append(args, *input.Name) + } + if input.Hostname != nil { + setClauses = append(setClauses, "hostname = ?") + args = append(args, *input.Hostname) + } + if input.Port != nil { + setClauses = append(setClauses, "port = ?") + args = append(args, *input.Port) + } + if input.GroupID != nil { + setClauses = append(setClauses, "group_id = ?") + args = append(args, *input.GroupID) + } + if input.CredentialID != nil { + setClauses = append(setClauses, "credential_id = ?") + args = append(args, *input.CredentialID) + } + if input.Tags != nil { + tags, _ := json.Marshal(input.Tags) + setClauses = append(setClauses, "tags = ?") + args = append(args, string(tags)) + } + if input.Notes != nil { + setClauses = append(setClauses, "notes = ?") + args = append(args, *input.Notes) + } + if input.Color != nil { + setClauses = append(setClauses, "color = ?") + args = append(args, *input.Color) + } + if input.Options != nil { + setClauses = append(setClauses, "options = ?") + args = append(args, *input.Options) + } + + args = append(args, id) + query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", ")) + if _, err := s.db.Exec(query, args...); err != nil { + return nil, fmt.Errorf("update connection: %w", err) + } + return s.GetConnection(id) +} + +func (s *ConnectionService) DeleteConnection(id int64) error { + _, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete connection: %w", err) + } + return nil +} diff --git a/internal/connections/service_test.go b/internal/connections/service_test.go new file mode 100644 index 0000000..7be0d96 --- /dev/null +++ b/internal/connections/service_test.go @@ -0,0 +1,234 @@ +package connections + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func strPtr(s string) *string { return &s } + +func setupTestService(t *testing.T) *ConnectionService { + t.Helper() + dir := t.TempDir() + database, err := db.Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("db.Open() error: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("db.Migrate() error: %v", err) + } + t.Cleanup(func() { database.Close() }) + return NewConnectionService(database) +} + +func TestCreateGroup(t *testing.T) { + svc := setupTestService(t) + + g, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup() error: %v", err) + } + if g.ID == 0 { + t.Error("expected non-zero ID") + } + if g.Name != "Servers" { + t.Errorf("Name = %q, want %q", g.Name, "Servers") + } + if g.ParentID != nil { + t.Errorf("ParentID = %v, want nil", g.ParentID) + } +} + +func TestCreateSubGroup(t *testing.T) { + svc := setupTestService(t) + + parent, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup(parent) error: %v", err) + } + + child, err := svc.CreateGroup("Production", &parent.ID) + if err != nil { + t.Fatalf("CreateGroup(child) error: %v", err) + } + if child.ParentID == nil { + t.Fatal("expected non-nil ParentID") + } + if *child.ParentID != parent.ID { + t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID) + } +} + +func TestListGroups(t *testing.T) { + svc := setupTestService(t) + + parent, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup(parent) error: %v", err) + } + if _, err := svc.CreateGroup("Production", &parent.ID); err != nil { + t.Fatalf("CreateGroup(child) error: %v", err) + } + + groups, err := svc.ListGroups() + if err != nil { + t.Fatalf("ListGroups() error: %v", err) + } + if len(groups) != 1 { + t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups)) + } + if groups[0].Name != "Servers" { + t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers") + } + if len(groups[0].Children) != 1 { + t.Fatalf("len(children) = %d, want 1", len(groups[0].Children)) + } + if groups[0].Children[0].Name != "Production" { + t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production") + } +} + +func TestDeleteGroup(t *testing.T) { + svc := setupTestService(t) + + g, err := svc.CreateGroup("ToDelete", nil) + if err != nil { + t.Fatalf("CreateGroup() error: %v", err) + } + + if err := svc.DeleteGroup(g.ID); err != nil { + t.Fatalf("DeleteGroup() error: %v", err) + } + + groups, err := svc.ListGroups() + if err != nil { + t.Fatalf("ListGroups() error: %v", err) + } + if len(groups) != 0 { + t.Errorf("len(groups) = %d, want 0 after delete", len(groups)) + } +} + +func TestCreateConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Web Server", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + Tags: []string{"Prod", "Linux"}, + Options: `{"keepAliveInterval": 60}`, + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + if conn.ID == 0 { + t.Error("expected non-zero ID") + } + if conn.Name != "Web Server" { + t.Errorf("Name = %q, want %q", conn.Name, "Web Server") + } + if len(conn.Tags) != 2 { + t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags)) + } + if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" { + t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags) + } + if conn.Options != `{"keepAliveInterval": 60}` { + t.Errorf("Options = %q, want JSON blob", conn.Options) + } +} + +func TestListConnections(t *testing.T) { + svc := setupTestService(t) + + if _, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Server A", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + }); err != nil { + t.Fatalf("CreateConnection(A) error: %v", err) + } + if _, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Server B", + Hostname: "10.0.0.2", + Port: 3389, + Protocol: "rdp", + }); err != nil { + t.Fatalf("CreateConnection(B) error: %v", err) + } + + conns, err := svc.ListConnections() + if err != nil { + t.Fatalf("ListConnections() error: %v", err) + } + if len(conns) != 2 { + t.Fatalf("len(conns) = %d, want 2", len(conns)) + } +} + +func TestUpdateConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Old Name", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + Tags: []string{"Dev"}, + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + + updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{ + Name: strPtr("New Name"), + Tags: []string{"Prod", "Linux"}, + }) + if err != nil { + t.Fatalf("UpdateConnection() error: %v", err) + } + if updated.Name != "New Name" { + t.Errorf("Name = %q, want %q", updated.Name, "New Name") + } + if len(updated.Tags) != 2 { + t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags)) + } + if updated.Tags[0] != "Prod" { + t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod") + } + // Hostname should remain unchanged + if updated.Hostname != "10.0.0.1" { + t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1") + } +} + +func TestDeleteConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "ToDelete", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + + if err := svc.DeleteConnection(conn.ID); err != nil { + t.Fatalf("DeleteConnection() error: %v", err) + } + + conns, err := svc.ListConnections() + if err != nil { + t.Fatalf("ListConnections() error: %v", err) + } + if len(conns) != 0 { + t.Errorf("len(conns) = %d, want 0 after delete", len(conns)) + } +} diff --git a/internal/credentials/service.go b/internal/credentials/service.go new file mode 100644 index 0000000..617de89 --- /dev/null +++ b/internal/credentials/service.go @@ -0,0 +1,408 @@ +package credentials + +import ( + "crypto/ed25519" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "database/sql" + "encoding/base64" + "encoding/pem" + "fmt" + + "github.com/vstockwell/wraith/internal/vault" + "golang.org/x/crypto/ssh" +) + +// Credential represents a stored credential (password or SSH key reference). +type Credential struct { + ID int64 `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + Domain string `json:"domain"` + Type string `json:"type"` // "password" or "ssh_key" + SSHKeyID *int64 `json:"sshKeyId"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` +} + +// SSHKey represents a stored SSH key. +type SSHKey struct { + ID int64 `json:"id"` + Name string `json:"name"` + KeyType string `json:"keyType"` + Fingerprint string `json:"fingerprint"` + PublicKey string `json:"publicKey"` + CreatedAt string `json:"createdAt"` +} + +// CredentialService provides CRUD for credentials with vault encryption. +type CredentialService struct { + db *sql.DB + vault *vault.VaultService +} + +// NewCredentialService creates a new CredentialService. +func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService { + return &CredentialService{db: db, vault: vault} +} + +// CreatePassword creates a password credential (password encrypted via vault). +func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) { + encrypted, err := s.vault.Encrypt(password) + if err != nil { + return nil, fmt.Errorf("encrypt password: %w", err) + } + + result, err := s.db.Exec( + `INSERT INTO credentials (name, username, domain, type, encrypted_value) + VALUES (?, ?, ?, 'password', ?)`, + name, username, domain, encrypted, + ) + if err != nil { + return nil, fmt.Errorf("insert credential: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get credential id: %w", err) + } + + return s.getCredential(id) +} + +// CreateSSHKey imports an SSH key (private key encrypted via vault). +func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) { + // Parse the private key to detect type and extract public key + keyType := DetectKeyType(privateKeyPEM) + + // Parse the key to get the public key for fingerprinting. + // Try without passphrase first (handles unencrypted keys even when a + // passphrase is provided for storage), then fall back to using the + // passphrase for encrypted PEM keys. + var signer ssh.Signer + var err error + signer, err = ssh.ParsePrivateKey(privateKeyPEM) + if err != nil && passphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase)) + } + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + pubKey := signer.PublicKey() + fingerprint := ssh.FingerprintSHA256(pubKey) + publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey)) + + // Encrypt private key via vault + encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM)) + if err != nil { + return nil, fmt.Errorf("encrypt private key: %w", err) + } + + // Encrypt passphrase via vault (if provided) + var encryptedPassphrase sql.NullString + if passphrase != "" { + ep, err := s.vault.Encrypt(passphrase) + if err != nil { + return nil, fmt.Errorf("encrypt passphrase: %w", err) + } + encryptedPassphrase = sql.NullString{String: ep, Valid: true} + } + + result, err := s.db.Exec( + `INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted) + VALUES (?, ?, ?, ?, ?, ?)`, + name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase, + ) + if err != nil { + return nil, fmt.Errorf("insert ssh key: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get ssh key id: %w", err) + } + + return s.getSSHKey(id) +} + +// ListCredentials returns all credentials WITHOUT encrypted values. +func (s *CredentialService) ListCredentials() ([]Credential, error) { + rows, err := s.db.Query( + `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at + FROM credentials ORDER BY name`, + ) + if err != nil { + return nil, fmt.Errorf("list credentials: %w", err) + } + defer rows.Close() + + var creds []Credential + for rows.Next() { + var c Credential + var username, domain sql.NullString + if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan credential: %w", err) + } + if username.Valid { + c.Username = username.String + } + if domain.Valid { + c.Domain = domain.String + } + creds = append(creds, c) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate credentials: %w", err) + } + + if creds == nil { + creds = []Credential{} + } + return creds, nil +} + +// ListSSHKeys returns all SSH keys WITHOUT private key data. +func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) { + rows, err := s.db.Query( + `SELECT id, name, key_type, fingerprint, public_key, created_at + FROM ssh_keys ORDER BY name`, + ) + if err != nil { + return nil, fmt.Errorf("list ssh keys: %w", err) + } + defer rows.Close() + + var keys []SSHKey + for rows.Next() { + var k SSHKey + var keyType, fingerprint, publicKey sql.NullString + if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil { + return nil, fmt.Errorf("scan ssh key: %w", err) + } + if keyType.Valid { + k.KeyType = keyType.String + } + if fingerprint.Valid { + k.Fingerprint = fingerprint.String + } + if publicKey.Valid { + k.PublicKey = publicKey.String + } + keys = append(keys, k) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate ssh keys: %w", err) + } + + if keys == nil { + keys = []SSHKey{} + } + return keys, nil +} + +// DecryptPassword returns the decrypted password for a credential. +func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) { + var encrypted sql.NullString + err := s.db.QueryRow( + "SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'", + credentialID, + ).Scan(&encrypted) + if err != nil { + return "", fmt.Errorf("get encrypted password: %w", err) + } + if !encrypted.Valid { + return "", fmt.Errorf("no encrypted value for credential %d", credentialID) + } + + password, err := s.vault.Decrypt(encrypted.String) + if err != nil { + return "", fmt.Errorf("decrypt password: %w", err) + } + return password, nil +} + +// DecryptSSHKey returns the decrypted private key + passphrase. +func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) { + var encryptedKey string + var encryptedPassphrase sql.NullString + err = s.db.QueryRow( + "SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?", + sshKeyID, + ).Scan(&encryptedKey, &encryptedPassphrase) + if err != nil { + return nil, "", fmt.Errorf("get encrypted ssh key: %w", err) + } + + decryptedKey, err := s.vault.Decrypt(encryptedKey) + if err != nil { + return nil, "", fmt.Errorf("decrypt private key: %w", err) + } + + if encryptedPassphrase.Valid { + passphrase, err = s.vault.Decrypt(encryptedPassphrase.String) + if err != nil { + return nil, "", fmt.Errorf("decrypt passphrase: %w", err) + } + } + + return []byte(decryptedKey), passphrase, nil +} + +// DeleteCredential removes a credential. +func (s *CredentialService) DeleteCredential(id int64) error { + _, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete credential: %w", err) + } + return nil +} + +// DeleteSSHKey removes an SSH key. +func (s *CredentialService) DeleteSSHKey(id int64) error { + _, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete ssh key: %w", err) + } + return nil +} + +// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa). +func DetectKeyType(pemData []byte) string { + block, _ := pem.Decode(pemData) + if block == nil { + return "unknown" + } + + // Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks) + if block.Type == "OPENSSH PRIVATE KEY" { + return detectOpenSSHKeyType(block.Bytes) + } + + // Try PKCS8 format + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + switch key.(type) { + case *rsa.PrivateKey: + return "rsa" + case ed25519.PrivateKey: + return "ed25519" + case *ecdsa.PrivateKey: + return "ecdsa" + } + } + + // Try RSA PKCS1 + if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return "rsa" + } + + // Try EC + if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return "ecdsa" + } + + return "unknown" +} + +// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type. +func detectOpenSSHKeyType(data []byte) string { + // OpenSSH private key format: "openssh-key-v1\0" magic, then fields. + // We parse the key using ssh package to determine the type. + // Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer. + pemBlock := &pem.Block{ + Type: "OPENSSH PRIVATE KEY", + Bytes: data, + } + pemBytes := pem.EncodeToMemory(pemBlock) + + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return "unknown" + } + + return classifyPublicKey(signer.PublicKey()) +} + +// classifyPublicKey determines the key type from an ssh.PublicKey. +func classifyPublicKey(pub ssh.PublicKey) string { + keyType := pub.Type() + switch keyType { + case "ssh-rsa": + return "rsa" + case "ssh-ed25519": + return "ed25519" + case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521": + return "ecdsa" + default: + return keyType + } +} + +// getCredential retrieves a single credential by ID. +func (s *CredentialService) getCredential(id int64) (*Credential, error) { + var c Credential + var username, domain sql.NullString + err := s.db.QueryRow( + `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at + FROM credentials WHERE id = ?`, id, + ).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("get credential: %w", err) + } + if username.Valid { + c.Username = username.String + } + if domain.Valid { + c.Domain = domain.String + } + return &c, nil +} + +// getSSHKey retrieves a single SSH key by ID (without private key data). +func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) { + var k SSHKey + var keyType, fingerprint, publicKey sql.NullString + err := s.db.QueryRow( + `SELECT id, name, key_type, fingerprint, public_key, created_at + FROM ssh_keys WHERE id = ?`, id, + ).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get ssh key: %w", err) + } + if keyType.Valid { + k.KeyType = keyType.String + } + if fingerprint.Valid { + k.Fingerprint = fingerprint.String + } + if publicKey.Valid { + k.PublicKey = publicKey.String + } + return &k, nil +} + +// generateFingerprint generates an SSH fingerprint string from a public key. +func generateFingerprint(pubKey ssh.PublicKey) string { + return ssh.FingerprintSHA256(pubKey) +} + +// marshalPublicKey returns the authorized_keys format of an SSH public key. +func marshalPublicKey(pubKey ssh.PublicKey) string { + return base64.StdEncoding.EncodeToString(pubKey.Marshal()) +} + +// ecdsaCurveName returns the name for an ECDSA curve. +func ecdsaCurveName(curve elliptic.Curve) string { + switch curve { + case elliptic.P256(): + return "nistp256" + case elliptic.P384(): + return "nistp384" + case elliptic.P521(): + return "nistp521" + default: + return "unknown" + } +} diff --git a/internal/credentials/service_test.go b/internal/credentials/service_test.go new file mode 100644 index 0000000..8c7211e --- /dev/null +++ b/internal/credentials/service_test.go @@ -0,0 +1,176 @@ +package credentials + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" + "github.com/vstockwell/wraith/internal/vault" + "golang.org/x/crypto/ssh" +) + +func setupCredentialService(t *testing.T) *CredentialService { + t.Helper() + d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + if err := db.Migrate(d); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + + salt := []byte("test-salt-exactly-32-bytes-long!") + key := vault.DeriveKey("testpassword", salt) + vs := vault.NewVaultService(key) + + return NewCredentialService(d, vs) +} + +func TestCreatePasswordCredential(t *testing.T) { + svc := setupCredentialService(t) + cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "") + if err != nil { + t.Fatal(err) + } + if cred.Name != "Test Cred" { + t.Error("wrong name") + } + if cred.Type != "password" { + t.Error("wrong type") + } +} + +func TestDecryptPassword(t *testing.T) { + svc := setupCredentialService(t) + cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "") + password, err := svc.DecryptPassword(cred.ID) + if err != nil { + t.Fatal(err) + } + if password != "mypassword" { + t.Errorf("got %q, want mypassword", password) + } +} + +func TestListCredentialsExcludesSecrets(t *testing.T) { + svc := setupCredentialService(t) + svc.CreatePassword("Cred1", "user1", "pass1", "") + svc.CreatePassword("Cred2", "user2", "pass2", "") + creds, err := svc.ListCredentials() + if err != nil { + t.Fatal(err) + } + if len(creds) != 2 { + t.Errorf("got %d, want 2", len(creds)) + } +} + +func TestCreateSSHKey(t *testing.T) { + svc := setupCredentialService(t) + // Generate a test key + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, err := svc.CreateSSHKey("My Key", keyPEM, "") + if err != nil { + t.Fatal(err) + } + if key.KeyType != "ed25519" { + t.Errorf("KeyType = %q, want ed25519", key.KeyType) + } + if key.Fingerprint == "" { + t.Error("fingerprint should not be empty") + } +} + +func TestDecryptSSHKey(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass") + decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID) + if err != nil { + t.Fatal(err) + } + if len(decryptedKey) == 0 { + t.Error("decrypted key should not be empty") + } + if passphrase != "testpass" { + t.Errorf("passphrase = %q, want testpass", passphrase) + } +} + +func TestDetectKeyType(t *testing.T) { + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + if got := DetectKeyType(keyPEM); got != "ed25519" { + t.Errorf("got %q", got) + } +} + +func TestDeleteCredential(t *testing.T) { + svc := setupCredentialService(t) + cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "") + err := svc.DeleteCredential(cred.ID) + if err != nil { + t.Fatal(err) + } + creds, _ := svc.ListCredentials() + if len(creds) != 0 { + t.Errorf("got %d credentials, want 0", len(creds)) + } +} + +func TestDeleteSSHKey(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "") + err := svc.DeleteSSHKey(key.ID) + if err != nil { + t.Fatal(err) + } + keys, _ := svc.ListSSHKeys() + if len(keys) != 0 { + t.Errorf("got %d keys, want 0", len(keys)) + } +} + +func TestListSSHKeys(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + svc.CreateSSHKey("Key1", keyPEM, "") + svc.CreateSSHKey("Key2", keyPEM, "") + + keys, err := svc.ListSSHKeys() + if err != nil { + t.Fatal(err) + } + if len(keys) != 2 { + t.Errorf("got %d keys, want 2", len(keys)) + } +} + +func TestCreatePasswordWithDomain(t *testing.T) { + svc := setupCredentialService(t) + cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com") + if err != nil { + t.Fatal(err) + } + if cred.Domain != "example.com" { + t.Errorf("domain = %q, want example.com", cred.Domain) + } +} diff --git a/internal/db/migrations.go b/internal/db/migrations.go new file mode 100644 index 0000000..4e9e43b --- /dev/null +++ b/internal/db/migrations.go @@ -0,0 +1,34 @@ +package db + +import ( + "database/sql" + "embed" + "fmt" + "sort" +) + +//go:embed migrations/*.sql +var migrationFiles embed.FS + +func Migrate(db *sql.DB) error { + entries, err := migrationFiles.ReadDir("migrations") + if err != nil { + return fmt.Errorf("read migrations: %w", err) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() < entries[j].Name() + }) + + for _, entry := range entries { + content, err := migrationFiles.ReadFile("migrations/" + entry.Name()) + if err != nil { + return fmt.Errorf("read migration %s: %w", entry.Name(), err) + } + if _, err := db.Exec(string(content)); err != nil { + return fmt.Errorf("execute migration %s: %w", entry.Name(), err) + } + } + + return nil +} diff --git a/internal/db/migrations/001_initial.sql b/internal/db/migrations/001_initial.sql new file mode 100644 index 0000000..6ce3e4e --- /dev/null +++ b/internal/db/migrations/001_initial.sql @@ -0,0 +1,101 @@ +-- 001_initial.sql +CREATE TABLE IF NOT EXISTS groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + parent_id INTEGER REFERENCES groups(id) ON DELETE SET NULL, + sort_order INTEGER DEFAULT 0, + icon TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS ssh_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + key_type TEXT, + fingerprint TEXT, + public_key TEXT, + encrypted_private_key TEXT NOT NULL, + passphrase_encrypted TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + username TEXT, + domain TEXT, + type TEXT NOT NULL CHECK(type IN ('password','ssh_key')), + encrypted_value TEXT, + ssh_key_id INTEGER REFERENCES ssh_keys(id) ON DELETE SET NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS connections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + hostname TEXT NOT NULL, + port INTEGER NOT NULL DEFAULT 22, + protocol TEXT NOT NULL CHECK(protocol IN ('ssh','rdp')), + group_id INTEGER REFERENCES groups(id) ON DELETE SET NULL, + credential_id INTEGER REFERENCES credentials(id) ON DELETE SET NULL, + color TEXT, + tags TEXT DEFAULT '[]', + notes TEXT, + options TEXT DEFAULT '{}', + sort_order INTEGER DEFAULT 0, + last_connected DATETIME, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS themes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + foreground TEXT NOT NULL, + background TEXT NOT NULL, + cursor TEXT NOT NULL, + black TEXT NOT NULL, + red TEXT NOT NULL, + green TEXT NOT NULL, + yellow TEXT NOT NULL, + blue TEXT NOT NULL, + magenta TEXT NOT NULL, + cyan TEXT NOT NULL, + white TEXT NOT NULL, + bright_black TEXT NOT NULL, + bright_red TEXT NOT NULL, + bright_green TEXT NOT NULL, + bright_yellow TEXT NOT NULL, + bright_blue TEXT NOT NULL, + bright_magenta TEXT NOT NULL, + bright_cyan TEXT NOT NULL, + bright_white TEXT NOT NULL, + selection_bg TEXT, + selection_fg TEXT, + is_builtin BOOLEAN DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS connection_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + connection_id INTEGER NOT NULL REFERENCES connections(id) ON DELETE CASCADE, + protocol TEXT NOT NULL, + connected_at DATETIME DEFAULT CURRENT_TIMESTAMP, + disconnected_at DATETIME, + duration_secs INTEGER +); + +CREATE TABLE IF NOT EXISTS host_keys ( + hostname TEXT NOT NULL, + port INTEGER NOT NULL, + key_type TEXT NOT NULL, + fingerprint TEXT NOT NULL, + raw_key TEXT, + first_seen DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (hostname, port, key_type) +); + +CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); diff --git a/internal/db/sqlite.go b/internal/db/sqlite.go new file mode 100644 index 0000000..b69c75b --- /dev/null +++ b/internal/db/sqlite.go @@ -0,0 +1,39 @@ +package db + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + + _ "modernc.org/sqlite" +) + +func Open(dbPath string) (*sql.DB, error) { + dir := filepath.Dir(dbPath) + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("create db directory: %w", err) + } + + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + + if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + db.Close() + return nil, fmt.Errorf("set WAL mode: %w", err) + } + + if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil { + db.Close() + return nil, fmt.Errorf("set busy_timeout: %w", err) + } + + if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil { + db.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + + return db, nil +} diff --git a/internal/db/sqlite_test.go b/internal/db/sqlite_test.go new file mode 100644 index 0000000..af21fe8 --- /dev/null +++ b/internal/db/sqlite_test.go @@ -0,0 +1,85 @@ +package db + +import ( + "os" + "path/filepath" + "testing" +) + +func TestOpenCreatesDatabase(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := Open(dbPath) + if err != nil { + t.Fatalf("Open() error: %v", err) + } + defer db.Close() + + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + t.Fatal("database file was not created") + } +} + +func TestOpenSetsWALMode(t *testing.T) { + dir := t.TempDir() + db, err := Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("Open() error: %v", err) + } + defer db.Close() + + var mode string + err = db.QueryRow("PRAGMA journal_mode").Scan(&mode) + if err != nil { + t.Fatalf("PRAGMA query error: %v", err) + } + if mode != "wal" { + t.Errorf("journal_mode = %q, want %q", mode, "wal") + } +} + +func TestOpenSetsBusyTimeout(t *testing.T) { + dir := t.TempDir() + db, err := Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("Open() error: %v", err) + } + defer db.Close() + + var timeout int + err = db.QueryRow("PRAGMA busy_timeout").Scan(&timeout) + if err != nil { + t.Fatalf("PRAGMA query error: %v", err) + } + if timeout != 5000 { + t.Errorf("busy_timeout = %d, want %d", timeout, 5000) + } +} + +func TestMigrateCreatesAllTables(t *testing.T) { + dir := t.TempDir() + db, err := Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("Open() error: %v", err) + } + defer db.Close() + + if err := Migrate(db); err != nil { + t.Fatalf("Migrate() error: %v", err) + } + + expectedTables := []string{ + "groups", "connections", "credentials", "ssh_keys", + "themes", "connection_history", "host_keys", "settings", + } + for _, table := range expectedTables { + var name string + err := db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", table, + ).Scan(&name) + if err != nil { + t.Errorf("table %q not found: %v", table, err) + } + } +} diff --git a/internal/plugin/interfaces.go b/internal/plugin/interfaces.go new file mode 100644 index 0000000..2235e1c --- /dev/null +++ b/internal/plugin/interfaces.go @@ -0,0 +1,57 @@ +package plugin + +type ProtocolHandler interface { + Name() string + Connect(config map[string]interface{}) (Session, error) + Disconnect(sessionID string) error +} + +type Session interface { + ID() string + Protocol() string + Write(data []byte) error + Close() error +} + +type Importer interface { + Name() string + FileExtensions() []string + Parse(data []byte) (*ImportResult, error) +} + +type ImportResult struct { + Groups []ImportGroup `json:"groups"` + Connections []ImportConnection `json:"connections"` + HostKeys []ImportHostKey `json:"hostKeys"` + Theme *ImportTheme `json:"theme,omitempty"` +} + +type ImportGroup struct { + Name string `json:"name"` + ParentName string `json:"parentName,omitempty"` +} + +type ImportConnection struct { + Name string `json:"name"` + Hostname string `json:"hostname"` + Port int `json:"port"` + Protocol string `json:"protocol"` + Username string `json:"username"` + GroupName string `json:"groupName"` + Notes string `json:"notes"` +} + +type ImportHostKey struct { + Hostname string `json:"hostname"` + Port int `json:"port"` + KeyType string `json:"keyType"` + Fingerprint string `json:"fingerprint"` +} + +type ImportTheme struct { + Name string `json:"name"` + Foreground string `json:"foreground"` + Background string `json:"background"` + Cursor string `json:"cursor"` + Colors [16]string `json:"colors"` +} diff --git a/internal/plugin/registry.go b/internal/plugin/registry.go new file mode 100644 index 0000000..d42a99f --- /dev/null +++ b/internal/plugin/registry.go @@ -0,0 +1,47 @@ +package plugin + +import "fmt" + +type Registry struct { + protocols map[string]ProtocolHandler + importers map[string]Importer +} + +func NewRegistry() *Registry { + return &Registry{ + protocols: make(map[string]ProtocolHandler), + importers: make(map[string]Importer), + } +} + +func (r *Registry) RegisterProtocol(handler ProtocolHandler) { + r.protocols[handler.Name()] = handler +} + +func (r *Registry) RegisterImporter(imp Importer) { + r.importers[imp.Name()] = imp +} + +func (r *Registry) GetProtocol(name string) (ProtocolHandler, error) { + h, ok := r.protocols[name] + if !ok { + return nil, fmt.Errorf("protocol handler %q not registered", name) + } + return h, nil +} + +func (r *Registry) GetImporter(name string) (Importer, error) { + imp, ok := r.importers[name] + if !ok { + return nil, fmt.Errorf("importer %q not registered", name) + } + return imp, nil +} + +func (r *Registry) ListProtocols() []string { + names := make([]string, 0, len(r.protocols)) + for name := range r.protocols { + names = append(names, name) + } + return names +} diff --git a/internal/session/manager.go b/internal/session/manager.go new file mode 100644 index 0000000..74e760c --- /dev/null +++ b/internal/session/manager.go @@ -0,0 +1,96 @@ +package session + +import ( + "fmt" + "sync" + + "github.com/google/uuid" +) + +const MaxSessions = 32 + +type Manager struct { + mu sync.RWMutex + sessions map[string]*SessionInfo +} + +func NewManager() *Manager { + return &Manager{ + sessions: make(map[string]*SessionInfo), + } +} + +func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.sessions) >= MaxSessions { + return nil, fmt.Errorf("maximum sessions (%d) reached", MaxSessions) + } + + s := &SessionInfo{ + ID: uuid.NewString(), + ConnectionID: connectionID, + Protocol: protocol, + State: StateConnecting, + TabPosition: len(m.sessions), + } + m.sessions[s.ID] = s + return s, nil +} + +func (m *Manager) Get(id string) (*SessionInfo, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + s, ok := m.sessions[id] + return s, ok +} + +func (m *Manager) List() []*SessionInfo { + m.mu.RLock() + defer m.mu.RUnlock() + list := make([]*SessionInfo, 0, len(m.sessions)) + for _, s := range m.sessions { + list = append(list, s) + } + return list +} + +func (m *Manager) SetState(id string, state SessionState) error { + m.mu.Lock() + defer m.mu.Unlock() + s, ok := m.sessions[id] + if !ok { + return fmt.Errorf("session %s not found", id) + } + s.State = state + return nil +} + +func (m *Manager) Detach(id string) error { + return m.SetState(id, StateDetached) +} + +func (m *Manager) Reattach(id, windowID string) error { + m.mu.Lock() + defer m.mu.Unlock() + s, ok := m.sessions[id] + if !ok { + return fmt.Errorf("session %s not found", id) + } + s.State = StateConnected + s.WindowID = windowID + return nil +} + +func (m *Manager) Remove(id string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.sessions, id) +} + +func (m *Manager) Count() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.sessions) +} diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go new file mode 100644 index 0000000..86b4370 --- /dev/null +++ b/internal/session/manager_test.go @@ -0,0 +1,64 @@ +package session + +import "testing" + +func TestCreateSession(t *testing.T) { + m := NewManager() + s, err := m.Create(1, "ssh") + if err != nil { + t.Fatalf("Create() error: %v", err) + } + if s.ID == "" { + t.Error("session ID should not be empty") + } + if s.State != StateConnecting { + t.Errorf("State = %q, want %q", s.State, StateConnecting) + } +} + +func TestMaxSessions(t *testing.T) { + m := NewManager() + for i := 0; i < MaxSessions; i++ { + _, err := m.Create(int64(i), "ssh") + if err != nil { + t.Fatalf("Create() error at %d: %v", i, err) + } + } + _, err := m.Create(999, "ssh") + if err == nil { + t.Error("Create() should fail at max sessions") + } +} + +func TestDetachReattach(t *testing.T) { + m := NewManager() + s, _ := m.Create(1, "ssh") + m.SetState(s.ID, StateConnected) + + if err := m.Detach(s.ID); err != nil { + t.Fatalf("Detach() error: %v", err) + } + + got, _ := m.Get(s.ID) + if got.State != StateDetached { + t.Errorf("State = %q, want %q", got.State, StateDetached) + } + + if err := m.Reattach(s.ID, "window-1"); err != nil { + t.Fatalf("Reattach() error: %v", err) + } + + got, _ = m.Get(s.ID) + if got.State != StateConnected { + t.Errorf("State = %q, want %q", got.State, StateConnected) + } +} + +func TestRemoveSession(t *testing.T) { + m := NewManager() + s, _ := m.Create(1, "ssh") + m.Remove(s.ID) + if m.Count() != 0 { + t.Error("session should have been removed") + } +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..3850231 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,22 @@ +package session + +import "time" + +type SessionState string + +const ( + StateConnecting SessionState = "connecting" + StateConnected SessionState = "connected" + StateDisconnected SessionState = "disconnected" + StateDetached SessionState = "detached" +) + +type SessionInfo struct { + ID string `json:"id"` + ConnectionID int64 `json:"connectionId"` + Protocol string `json:"protocol"` + State SessionState `json:"state"` + WindowID string `json:"windowId"` + TabPosition int `json:"tabPosition"` + ConnectedAt time.Time `json:"connectedAt"` +} diff --git a/internal/settings/service.go b/internal/settings/service.go new file mode 100644 index 0000000..9669207 --- /dev/null +++ b/internal/settings/service.go @@ -0,0 +1,41 @@ +package settings + +import "database/sql" + +type SettingsService struct { + db *sql.DB +} + +func NewSettingsService(db *sql.DB) *SettingsService { + return &SettingsService{db: db} +} + +func (s *SettingsService) Get(key string) (string, error) { + var value string + err := s.db.QueryRow("SELECT value FROM settings WHERE key = ?", key).Scan(&value) + if err == sql.ErrNoRows { + return "", nil + } + return value, err +} + +func (s *SettingsService) GetDefault(key, defaultValue string) string { + val, err := s.Get(key) + if err != nil || val == "" { + return defaultValue + } + return val +} + +func (s *SettingsService) Set(key, value string) error { + _, err := s.db.Exec( + "INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?", + key, value, value, + ) + return err +} + +func (s *SettingsService) Delete(key string) error { + _, err := s.db.Exec("DELETE FROM settings WHERE key = ?", key) + return err +} diff --git a/internal/settings/service_test.go b/internal/settings/service_test.go new file mode 100644 index 0000000..9d6e27f --- /dev/null +++ b/internal/settings/service_test.go @@ -0,0 +1,70 @@ +package settings + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func setupTestDB(t *testing.T) *SettingsService { + t.Helper() + d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + if err := db.Migrate(d); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + return NewSettingsService(d) +} + +func TestSetAndGet(t *testing.T) { + s := setupTestDB(t) + + if err := s.Set("theme", "dracula"); err != nil { + t.Fatalf("Set() error: %v", err) + } + + val, err := s.Get("theme") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + if val != "dracula" { + t.Errorf("Get() = %q, want %q", val, "dracula") + } +} + +func TestGetMissing(t *testing.T) { + s := setupTestDB(t) + + val, err := s.Get("nonexistent") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + if val != "" { + t.Errorf("Get() = %q, want empty string", val) + } +} + +func TestSetOverwrites(t *testing.T) { + s := setupTestDB(t) + + s.Set("key", "value1") + s.Set("key", "value2") + + val, _ := s.Get("key") + if val != "value2" { + t.Errorf("Get() = %q, want %q", val, "value2") + } +} + +func TestGetWithDefault(t *testing.T) { + s := setupTestDB(t) + + val := s.GetDefault("missing", "fallback") + if val != "fallback" { + t.Errorf("GetDefault() = %q, want %q", val, "fallback") + } +} diff --git a/internal/sftp/service.go b/internal/sftp/service.go new file mode 100644 index 0000000..a027fdb --- /dev/null +++ b/internal/sftp/service.go @@ -0,0 +1,238 @@ +package sftp + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "sync" + + "github.com/pkg/sftp" +) + +const MaxEditFileSize = 5 * 1024 * 1024 // 5MB + +// FileEntry represents a file or directory in a remote filesystem. +type FileEntry struct { + Name string `json:"name"` + Path string `json:"path"` + Size int64 `json:"size"` + IsDir bool `json:"isDir"` + Permissions string `json:"permissions"` + ModTime string `json:"modTime"` +} + +// SFTPService manages SFTP clients keyed by session ID. +type SFTPService struct { + clients map[string]*sftp.Client + mu sync.RWMutex +} + +// NewSFTPService creates a new SFTPService. +func NewSFTPService() *SFTPService { + return &SFTPService{ + clients: make(map[string]*sftp.Client), + } +} + +// RegisterClient stores an SFTP client for a session. +func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) { + s.mu.Lock() + defer s.mu.Unlock() + s.clients[sessionID] = client +} + +// RemoveClient removes and closes an SFTP client. +func (s *SFTPService) RemoveClient(sessionID string) { + s.mu.Lock() + client, ok := s.clients[sessionID] + if ok { + delete(s.clients, sessionID) + } + s.mu.Unlock() + + if ok && client != nil { + client.Close() + } +} + +// getClient returns the SFTP client for a session or an error if not found. +func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) { + s.mu.RLock() + defer s.mu.RUnlock() + client, ok := s.clients[sessionID] + if !ok { + return nil, fmt.Errorf("no SFTP client for session %s", sessionID) + } + return client, nil +} + +// SortEntries sorts file entries with directories first, then alphabetically by name. +func SortEntries(entries []FileEntry) { + sort.Slice(entries, func(i, j int) bool { + if entries[i].IsDir != entries[j].IsDir { + return entries[i].IsDir + } + return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name) + }) +} + +// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry. +func fileInfoToEntry(info os.FileInfo, path string) FileEntry { + return FileEntry{ + Name: info.Name(), + Path: path, + Size: info.Size(), + IsDir: info.IsDir(), + Permissions: info.Mode().Perm().String(), + ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"), + } +} + +// List returns directory contents sorted (dirs first, then files alphabetically). +func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) { + client, err := s.getClient(sessionID) + if err != nil { + return nil, err + } + + infos, err := client.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("read directory %s: %w", path, err) + } + + entries := make([]FileEntry, 0, len(infos)) + for _, info := range infos { + entryPath := path + if !strings.HasSuffix(entryPath, "/") { + entryPath += "/" + } + entryPath += info.Name() + entries = append(entries, fileInfoToEntry(info, entryPath)) + } + + SortEntries(entries) + return entries, nil +} + +// ReadFile reads a file (max 5MB). Returns content as string. +func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) { + client, err := s.getClient(sessionID) + if err != nil { + return "", err + } + + info, err := client.Stat(path) + if err != nil { + return "", fmt.Errorf("stat %s: %w", path, err) + } + + if info.IsDir() { + return "", fmt.Errorf("%s is a directory", path) + } + + if info.Size() > MaxEditFileSize { + return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize) + } + + f, err := client.Open(path) + if err != nil { + return "", fmt.Errorf("open %s: %w", path, err) + } + defer f.Close() + + data, err := io.ReadAll(f) + if err != nil { + return "", fmt.Errorf("read %s: %w", path, err) + } + + return string(data), nil +} + +// WriteFile writes content to a file. +func (s *SFTPService) WriteFile(sessionID string, path string, content string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + f, err := client.Create(path) + if err != nil { + return fmt.Errorf("create %s: %w", path, err) + } + defer f.Close() + + if _, err := f.Write([]byte(content)); err != nil { + return fmt.Errorf("write %s: %w", path, err) + } + + return nil +} + +// Mkdir creates a directory. +func (s *SFTPService) Mkdir(sessionID string, path string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + if err := client.Mkdir(path); err != nil { + return fmt.Errorf("mkdir %s: %w", path, err) + } + return nil +} + +// Delete removes a file or empty directory. +func (s *SFTPService) Delete(sessionID string, path string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + info, err := client.Stat(path) + if err != nil { + return fmt.Errorf("stat %s: %w", path, err) + } + + if info.IsDir() { + if err := client.RemoveDirectory(path); err != nil { + return fmt.Errorf("remove directory %s: %w", path, err) + } + } else { + if err := client.Remove(path); err != nil { + return fmt.Errorf("remove %s: %w", path, err) + } + } + + return nil +} + +// Rename renames/moves a file. +func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + if err := client.Rename(oldPath, newPath); err != nil { + return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err) + } + return nil +} + +// Stat returns info about a file/directory. +func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) { + client, err := s.getClient(sessionID) + if err != nil { + return nil, err + } + + info, err := client.Stat(path) + if err != nil { + return nil, fmt.Errorf("stat %s: %w", path, err) + } + + entry := fileInfoToEntry(info, path) + return &entry, nil +} diff --git a/internal/sftp/service_test.go b/internal/sftp/service_test.go new file mode 100644 index 0000000..6ce5975 --- /dev/null +++ b/internal/sftp/service_test.go @@ -0,0 +1,119 @@ +package sftp + +import ( + "testing" +) + +func TestNewSFTPService(t *testing.T) { + svc := NewSFTPService() + if svc == nil { + t.Fatal("nil") + } +} + +func TestListWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.List("nonexistent", "/") + if err == nil { + t.Error("should error without client") + } +} + +func TestReadFileWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.ReadFile("nonexistent", "/etc/hosts") + if err == nil { + t.Error("should error without client") + } +} + +func TestWriteFileWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.WriteFile("nonexistent", "/tmp/test", "data") + if err == nil { + t.Error("should error without client") + } +} + +func TestMkdirWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Mkdir("nonexistent", "/tmp/newdir") + if err == nil { + t.Error("should error without client") + } +} + +func TestDeleteWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Delete("nonexistent", "/tmp/file") + if err == nil { + t.Error("should error without client") + } +} + +func TestRenameWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Rename("nonexistent", "/old", "/new") + if err == nil { + t.Error("should error without client") + } +} + +func TestStatWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.Stat("nonexistent", "/tmp") + if err == nil { + t.Error("should error without client") + } +} + +func TestFileEntrySorting(t *testing.T) { + // Test that SortEntries puts dirs first, then alpha + entries := []FileEntry{ + {Name: "zebra.txt", IsDir: false}, + {Name: "alpha", IsDir: true}, + {Name: "beta.conf", IsDir: false}, + {Name: "omega", IsDir: true}, + } + SortEntries(entries) + if entries[0].Name != "alpha" { + t.Errorf("[0] = %s, want alpha", entries[0].Name) + } + if entries[1].Name != "omega" { + t.Errorf("[1] = %s, want omega", entries[1].Name) + } + if entries[2].Name != "beta.conf" { + t.Errorf("[2] = %s, want beta.conf", entries[2].Name) + } + if entries[3].Name != "zebra.txt" { + t.Errorf("[3] = %s, want zebra.txt", entries[3].Name) + } +} + +func TestSortEntriesEmpty(t *testing.T) { + entries := []FileEntry{} + SortEntries(entries) + if len(entries) != 0 { + t.Errorf("expected empty slice, got %d entries", len(entries)) + } +} + +func TestSortEntriesCaseInsensitive(t *testing.T) { + entries := []FileEntry{ + {Name: "Zebra", IsDir: false}, + {Name: "alpha", IsDir: false}, + } + SortEntries(entries) + if entries[0].Name != "alpha" { + t.Errorf("[0] = %s, want alpha", entries[0].Name) + } + if entries[1].Name != "Zebra" { + t.Errorf("[1] = %s, want Zebra", entries[1].Name) + } +} + +func TestMaxEditFileSize(t *testing.T) { + if MaxEditFileSize != 5*1024*1024 { + t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024) + } +} diff --git a/internal/ssh/service.go b/internal/ssh/service.go new file mode 100644 index 0000000..be88557 --- /dev/null +++ b/internal/ssh/service.go @@ -0,0 +1,259 @@ +package ssh + +import ( + "database/sql" + "fmt" + "io" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/ssh" +) + +// OutputHandler is called when data is read from an SSH session's stdout. +// In production this will emit Wails events; for testing, a simple callback. +type OutputHandler func(sessionID string, data []byte) + +// SSHSession represents an active SSH connection with its PTY shell session. +type SSHSession struct { + ID string + Client *ssh.Client + Session *ssh.Session + Stdin io.WriteCloser + ConnID int64 + Hostname string + Port int + Username string + Connected time.Time + mu sync.Mutex +} + +// SSHService manages SSH connections and their associated sessions. +type SSHService struct { + sessions map[string]*SSHSession + mu sync.RWMutex + db *sql.DB + outputHandler OutputHandler +} + +// NewSSHService creates a new SSHService. The outputHandler is called when data +// arrives from a session's stdout. Pass nil if output handling is not needed. +func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService { + return &SSHService{ + sessions: make(map[string]*SSHSession), + db: db, + outputHandler: outputHandler, + } +} + +// Connect dials an SSH server, opens a session with a PTY and shell, and +// launches a goroutine to read stdout. Returns the session ID. +func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) { + addr := fmt.Sprintf("%s:%d", hostname, port) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 15 * time.Second, + } + + client, err := ssh.Dial("tcp", addr, config) + if err != nil { + return "", fmt.Errorf("ssh dial %s: %w", addr, err) + } + + session, err := client.NewSession() + if err != nil { + client.Close() + return "", fmt.Errorf("new session: %w", err) + } + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("request pty: %w", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("stdin pipe: %w", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("stdout pipe: %w", err) + } + + if err := session.Shell(); err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("start shell: %w", err) + } + + sessionID := uuid.NewString() + sshSession := &SSHSession{ + ID: sessionID, + Client: client, + Session: session, + Stdin: stdin, + Hostname: hostname, + Port: port, + Username: username, + Connected: time.Now(), + } + + s.mu.Lock() + s.sessions[sessionID] = sshSession + s.mu.Unlock() + + // Launch goroutine to read stdout and forward data via the output handler + go s.readLoop(sessionID, stdout) + + return sessionID, nil +} + +// readLoop continuously reads from the session stdout and calls the output +// handler with data. It stops when the reader returns an error (typically EOF +// when the session closes). +func (s *SSHService) readLoop(sessionID string, reader io.Reader) { + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 && s.outputHandler != nil { + data := make([]byte, n) + copy(data, buf[:n]) + s.outputHandler(sessionID, data) + } + if err != nil { + break + } + } +} + +// Write sends data to the session's stdin. +func (s *SSHService) Write(sessionID string, data string) error { + s.mu.RLock() + sess, ok := s.sessions[sessionID] + s.mu.RUnlock() + + if !ok { + return fmt.Errorf("session %s not found", sessionID) + } + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Stdin == nil { + return fmt.Errorf("session %s stdin is closed", sessionID) + } + + _, err := sess.Stdin.Write([]byte(data)) + if err != nil { + return fmt.Errorf("write to session %s: %w", sessionID, err) + } + return nil +} + +// Resize sends a window-change request to the remote PTY. +func (s *SSHService) Resize(sessionID string, cols, rows int) error { + s.mu.RLock() + sess, ok := s.sessions[sessionID] + s.mu.RUnlock() + + if !ok { + return fmt.Errorf("session %s not found", sessionID) + } + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Session == nil { + return fmt.Errorf("session %s is closed", sessionID) + } + + if err := sess.Session.WindowChange(rows, cols); err != nil { + return fmt.Errorf("resize session %s: %w", sessionID, err) + } + return nil +} + +// Disconnect closes the SSH session and client, and removes it from tracking. +func (s *SSHService) Disconnect(sessionID string) error { + s.mu.Lock() + sess, ok := s.sessions[sessionID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("session %s not found", sessionID) + } + delete(s.sessions, sessionID) + s.mu.Unlock() + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Stdin != nil { + sess.Stdin.Close() + } + if sess.Session != nil { + sess.Session.Close() + } + if sess.Client != nil { + sess.Client.Close() + } + + return nil +} + +// GetSession returns the SSHSession for the given ID, or false if not found. +func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[sessionID] + return sess, ok +} + +// ListSessions returns all active SSH sessions. +func (s *SSHService) ListSessions() []*SSHSession { + s.mu.RLock() + defer s.mu.RUnlock() + list := make([]*SSHSession, 0, len(s.sessions)) + for _, sess := range s.sessions { + list = append(list, sess) + } + return list +} + +// BuildPasswordAuth creates an ssh.AuthMethod for password authentication. +func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod { + return ssh.Password(password) +} + +// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key. +// If the key is encrypted, pass the passphrase; otherwise pass an empty string. +func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) { + var signer ssh.Signer + var err error + + if passphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase)) + } else { + signer, err = ssh.ParsePrivateKey(privateKey) + } + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.PublicKeys(signer), nil +} diff --git a/internal/ssh/service_test.go b/internal/ssh/service_test.go new file mode 100644 index 0000000..0a880b4 --- /dev/null +++ b/internal/ssh/service_test.go @@ -0,0 +1,148 @@ +package ssh + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestNewSSHService(t *testing.T) { + svc := NewSSHService(nil, nil) + if svc == nil { + t.Fatal("NewSSHService returned nil") + } + if len(svc.ListSessions()) != 0 { + t.Error("new service should have no sessions") + } +} + +func TestBuildPasswordAuth(t *testing.T) { + svc := NewSSHService(nil, nil) + auth := svc.BuildPasswordAuth("mypassword") + if auth == nil { + t.Error("BuildPasswordAuth returned nil") + } +} + +func TestBuildKeyAuth(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Generate a test Ed25519 key + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey error: %v", err) + } + pemBlock, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + t.Fatalf("MarshalPrivateKey error: %v", err) + } + keyBytes := pem.EncodeToMemory(pemBlock) + + auth, err := svc.BuildKeyAuth(keyBytes, "") + if err != nil { + t.Fatalf("BuildKeyAuth error: %v", err) + } + if auth == nil { + t.Error("BuildKeyAuth returned nil") + } +} + +func TestBuildKeyAuthInvalidKey(t *testing.T) { + svc := NewSSHService(nil, nil) + _, err := svc.BuildKeyAuth([]byte("not a key"), "") + if err == nil { + t.Error("BuildKeyAuth should fail with invalid key") + } +} + +func TestSessionTracking(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Manually add a session to test tracking + svc.mu.Lock() + svc.sessions["test-123"] = &SSHSession{ + ID: "test-123", + Hostname: "192.168.1.4", + Port: 22, + Username: "vstockwell", + Connected: time.Now(), + } + svc.mu.Unlock() + + s, ok := svc.GetSession("test-123") + if !ok { + t.Fatal("session not found") + } + if s.Hostname != "192.168.1.4" { + t.Errorf("Hostname = %q, want %q", s.Hostname, "192.168.1.4") + } + + sessions := svc.ListSessions() + if len(sessions) != 1 { + t.Errorf("ListSessions() = %d, want 1", len(sessions)) + } +} + +func TestGetSessionNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + _, ok := svc.GetSession("nonexistent") + if ok { + t.Error("GetSession should return false for nonexistent session") + } +} + +func TestWriteNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Write("nonexistent", "data") + if err == nil { + t.Error("Write should fail for nonexistent session") + } +} + +func TestResizeNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Resize("nonexistent", 80, 24) + if err == nil { + t.Error("Resize should fail for nonexistent session") + } +} + +func TestDisconnectNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Disconnect("nonexistent") + if err == nil { + t.Error("Disconnect should fail for nonexistent session") + } +} + +func TestDisconnectRemovesSession(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Manually add a session with nil Client/Session/Stdin (no real connection) + svc.mu.Lock() + svc.sessions["test-dc"] = &SSHSession{ + ID: "test-dc", + Hostname: "10.0.0.1", + Port: 22, + Username: "admin", + Connected: time.Now(), + } + svc.mu.Unlock() + + if err := svc.Disconnect("test-dc"); err != nil { + t.Fatalf("Disconnect error: %v", err) + } + + _, ok := svc.GetSession("test-dc") + if ok { + t.Error("session should be removed after Disconnect") + } + + if len(svc.ListSessions()) != 0 { + t.Error("ListSessions should be empty after Disconnect") + } +} diff --git a/internal/theme/builtins.go b/internal/theme/builtins.go new file mode 100644 index 0000000..c216685 --- /dev/null +++ b/internal/theme/builtins.go @@ -0,0 +1,94 @@ +package theme + +type Theme struct { + ID int64 `json:"id"` + Name string `json:"name"` + Foreground string `json:"foreground"` + Background string `json:"background"` + Cursor string `json:"cursor"` + Black string `json:"black"` + Red string `json:"red"` + Green string `json:"green"` + Yellow string `json:"yellow"` + Blue string `json:"blue"` + Magenta string `json:"magenta"` + Cyan string `json:"cyan"` + White string `json:"white"` + BrightBlack string `json:"brightBlack"` + BrightRed string `json:"brightRed"` + BrightGreen string `json:"brightGreen"` + BrightYellow string `json:"brightYellow"` + BrightBlue string `json:"brightBlue"` + BrightMagenta string `json:"brightMagenta"` + BrightCyan string `json:"brightCyan"` + BrightWhite string `json:"brightWhite"` + SelectionBg string `json:"selectionBg,omitempty"` + SelectionFg string `json:"selectionFg,omitempty"` + IsBuiltin bool `json:"isBuiltin"` +} + +var BuiltinThemes = []Theme{ + { + Name: "Dracula", IsBuiltin: true, + Foreground: "#f8f8f2", Background: "#282a36", Cursor: "#f8f8f2", + Black: "#21222c", Red: "#ff5555", Green: "#50fa7b", Yellow: "#f1fa8c", + Blue: "#bd93f9", Magenta: "#ff79c6", Cyan: "#8be9fd", White: "#f8f8f2", + BrightBlack: "#6272a4", BrightRed: "#ff6e6e", BrightGreen: "#69ff94", + BrightYellow: "#ffffa5", BrightBlue: "#d6acff", BrightMagenta: "#ff92df", + BrightCyan: "#a4ffff", BrightWhite: "#ffffff", + }, + { + Name: "Nord", IsBuiltin: true, + Foreground: "#d8dee9", Background: "#2e3440", Cursor: "#d8dee9", + Black: "#3b4252", Red: "#bf616a", Green: "#a3be8c", Yellow: "#ebcb8b", + Blue: "#81a1c1", Magenta: "#b48ead", Cyan: "#88c0d0", White: "#e5e9f0", + BrightBlack: "#4c566a", BrightRed: "#bf616a", BrightGreen: "#a3be8c", + BrightYellow: "#ebcb8b", BrightBlue: "#81a1c1", BrightMagenta: "#b48ead", + BrightCyan: "#8fbcbb", BrightWhite: "#eceff4", + }, + { + Name: "Monokai", IsBuiltin: true, + Foreground: "#f8f8f2", Background: "#272822", Cursor: "#f8f8f0", + Black: "#272822", Red: "#f92672", Green: "#a6e22e", Yellow: "#f4bf75", + Blue: "#66d9ef", Magenta: "#ae81ff", Cyan: "#a1efe4", White: "#f8f8f2", + BrightBlack: "#75715e", BrightRed: "#f92672", BrightGreen: "#a6e22e", + BrightYellow: "#f4bf75", BrightBlue: "#66d9ef", BrightMagenta: "#ae81ff", + BrightCyan: "#a1efe4", BrightWhite: "#f9f8f5", + }, + { + Name: "One Dark", IsBuiltin: true, + Foreground: "#abb2bf", Background: "#282c34", Cursor: "#528bff", + Black: "#282c34", Red: "#e06c75", Green: "#98c379", Yellow: "#e5c07b", + Blue: "#61afef", Magenta: "#c678dd", Cyan: "#56b6c2", White: "#abb2bf", + BrightBlack: "#545862", BrightRed: "#e06c75", BrightGreen: "#98c379", + BrightYellow: "#e5c07b", BrightBlue: "#61afef", BrightMagenta: "#c678dd", + BrightCyan: "#56b6c2", BrightWhite: "#c8ccd4", + }, + { + Name: "Solarized Dark", IsBuiltin: true, + Foreground: "#839496", Background: "#002b36", Cursor: "#839496", + Black: "#073642", Red: "#dc322f", Green: "#859900", Yellow: "#b58900", + Blue: "#268bd2", Magenta: "#d33682", Cyan: "#2aa198", White: "#eee8d5", + BrightBlack: "#002b36", BrightRed: "#cb4b16", BrightGreen: "#586e75", + BrightYellow: "#657b83", BrightBlue: "#839496", BrightMagenta: "#6c71c4", + BrightCyan: "#93a1a1", BrightWhite: "#fdf6e3", + }, + { + Name: "Gruvbox Dark", IsBuiltin: true, + Foreground: "#ebdbb2", Background: "#282828", Cursor: "#ebdbb2", + Black: "#282828", Red: "#cc241d", Green: "#98971a", Yellow: "#d79921", + Blue: "#458588", Magenta: "#b16286", Cyan: "#689d6a", White: "#a89984", + BrightBlack: "#928374", BrightRed: "#fb4934", BrightGreen: "#b8bb26", + BrightYellow: "#fabd2f", BrightBlue: "#83a598", BrightMagenta: "#d3869b", + BrightCyan: "#8ec07c", BrightWhite: "#ebdbb2", + }, + { + Name: "MobaXTerm Classic", IsBuiltin: true, + Foreground: "#ececec", Background: "#242424", Cursor: "#b4b4c0", + Black: "#000000", Red: "#aa4244", Green: "#7e8d53", Yellow: "#e4b46d", + Blue: "#6e9aba", Magenta: "#9e5085", Cyan: "#80d5cf", White: "#cccccc", + BrightBlack: "#808080", BrightRed: "#cc7b7d", BrightGreen: "#a5b17c", + BrightYellow: "#ecc995", BrightBlue: "#96b6cd", BrightMagenta: "#c083ac", + BrightCyan: "#a9e2de", BrightWhite: "#cccccc", + }, +} diff --git a/internal/theme/service.go b/internal/theme/service.go new file mode 100644 index 0000000..3a2bfdf --- /dev/null +++ b/internal/theme/service.go @@ -0,0 +1,85 @@ +package theme + +import ( + "database/sql" + "fmt" +) + +type ThemeService struct { + db *sql.DB +} + +func NewThemeService(db *sql.DB) *ThemeService { + return &ThemeService{db: db} +} + +func (s *ThemeService) SeedBuiltins() error { + for _, t := range BuiltinThemes { + _, err := s.db.Exec( + `INSERT OR IGNORE INTO themes (name, foreground, background, cursor, + black, red, green, yellow, blue, magenta, cyan, white, + bright_black, bright_red, bright_green, bright_yellow, bright_blue, + bright_magenta, bright_cyan, bright_white, selection_bg, selection_fg, is_builtin) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,1)`, + t.Name, t.Foreground, t.Background, t.Cursor, + t.Black, t.Red, t.Green, t.Yellow, t.Blue, t.Magenta, t.Cyan, t.White, + t.BrightBlack, t.BrightRed, t.BrightGreen, t.BrightYellow, t.BrightBlue, + t.BrightMagenta, t.BrightCyan, t.BrightWhite, t.SelectionBg, t.SelectionFg, + ) + if err != nil { + return fmt.Errorf("seed theme %s: %w", t.Name, err) + } + } + return nil +} + +func (s *ThemeService) List() ([]Theme, error) { + rows, err := s.db.Query( + `SELECT id, name, foreground, background, cursor, + black, red, green, yellow, blue, magenta, cyan, white, + bright_black, bright_red, bright_green, bright_yellow, bright_blue, + bright_magenta, bright_cyan, bright_white, + COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin + FROM themes ORDER BY is_builtin DESC, name`) + if err != nil { + return nil, err + } + defer rows.Close() + + var themes []Theme + for rows.Next() { + var t Theme + if err := rows.Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor, + &t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White, + &t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue, + &t.BrightMagenta, &t.BrightCyan, &t.BrightWhite, + &t.SelectionBg, &t.SelectionFg, &t.IsBuiltin); err != nil { + return nil, err + } + themes = append(themes, t) + } + if err := rows.Err(); err != nil { + return nil, err + } + return themes, nil +} + +func (s *ThemeService) GetByName(name string) (*Theme, error) { + var t Theme + err := s.db.QueryRow( + `SELECT id, name, foreground, background, cursor, + black, red, green, yellow, blue, magenta, cyan, white, + bright_black, bright_red, bright_green, bright_yellow, bright_blue, + bright_magenta, bright_cyan, bright_white, + COALESCE(selection_bg,''), COALESCE(selection_fg,''), is_builtin + FROM themes WHERE name = ?`, name, + ).Scan(&t.ID, &t.Name, &t.Foreground, &t.Background, &t.Cursor, + &t.Black, &t.Red, &t.Green, &t.Yellow, &t.Blue, &t.Magenta, &t.Cyan, &t.White, + &t.BrightBlack, &t.BrightRed, &t.BrightGreen, &t.BrightYellow, &t.BrightBlue, + &t.BrightMagenta, &t.BrightCyan, &t.BrightWhite, + &t.SelectionBg, &t.SelectionFg, &t.IsBuiltin) + if err != nil { + return nil, fmt.Errorf("get theme %s: %w", name, err) + } + return &t, nil +} diff --git a/internal/theme/service_test.go b/internal/theme/service_test.go new file mode 100644 index 0000000..2355885 --- /dev/null +++ b/internal/theme/service_test.go @@ -0,0 +1,60 @@ +package theme + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func setupTestDB(t *testing.T) *ThemeService { + t.Helper() + d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + if err := db.Migrate(d); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + return NewThemeService(d) +} + +func TestSeedBuiltins(t *testing.T) { + svc := setupTestDB(t) + if err := svc.SeedBuiltins(); err != nil { + t.Fatalf("SeedBuiltins() error: %v", err) + } + + themes, err := svc.List() + if err != nil { + t.Fatalf("List() error: %v", err) + } + if len(themes) != len(BuiltinThemes) { + t.Errorf("len(themes) = %d, want %d", len(themes), len(BuiltinThemes)) + } +} + +func TestSeedBuiltinsIdempotent(t *testing.T) { + svc := setupTestDB(t) + svc.SeedBuiltins() + svc.SeedBuiltins() + + themes, _ := svc.List() + if len(themes) != len(BuiltinThemes) { + t.Errorf("len(themes) = %d after double seed, want %d", len(themes), len(BuiltinThemes)) + } +} + +func TestGetByName(t *testing.T) { + svc := setupTestDB(t) + svc.SeedBuiltins() + + theme, err := svc.GetByName("Dracula") + if err != nil { + t.Fatalf("GetByName() error: %v", err) + } + if theme.Background != "#282a36" { + t.Errorf("Background = %q, want %q", theme.Background, "#282a36") + } +} diff --git a/internal/vault/service.go b/internal/vault/service.go new file mode 100644 index 0000000..b82f89d --- /dev/null +++ b/internal/vault/service.go @@ -0,0 +1,95 @@ +package vault + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + argonTime = 3 + argonMemory = 64 * 1024 + argonThreads = 4 + argonKeyLen = 32 +) + +func DeriveKey(password string, salt []byte) []byte { + return argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen) +} + +func GenerateSalt() ([]byte, error) { + salt := make([]byte, 32) + if _, err := rand.Read(salt); err != nil { + return nil, fmt.Errorf("generate salt: %w", err) + } + return salt, nil +} + +type VaultService struct { + key []byte +} + +func NewVaultService(key []byte) *VaultService { + return &VaultService{key: key} +} + +func (v *VaultService) Encrypt(plaintext string) (string, error) { + block, err := aes.NewCipher(v.key) + if err != nil { + return "", fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("create GCM: %w", err) + } + + iv := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(iv); err != nil { + return "", fmt.Errorf("generate IV: %w", err) + } + + sealed := gcm.Seal(nil, iv, []byte(plaintext), nil) + + return fmt.Sprintf("v1:%s:%s", hex.EncodeToString(iv), hex.EncodeToString(sealed)), nil +} + +func (v *VaultService) Decrypt(encrypted string) (string, error) { + parts := strings.SplitN(encrypted, ":", 3) + if len(parts) != 3 || parts[0] != "v1" { + return "", errors.New("invalid encrypted format: expected v1:{iv}:{sealed}") + } + + iv, err := hex.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode IV: %w", err) + } + + sealed, err := hex.DecodeString(parts[2]) + if err != nil { + return "", fmt.Errorf("decode sealed data: %w", err) + } + + block, err := aes.NewCipher(v.key) + if err != nil { + return "", fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("create GCM: %w", err) + } + + plaintext, err := gcm.Open(nil, iv, sealed, nil) + if err != nil { + return "", fmt.Errorf("decrypt: %w", err) + } + + return string(plaintext), nil +} diff --git a/internal/vault/service_test.go b/internal/vault/service_test.go new file mode 100644 index 0000000..5558048 --- /dev/null +++ b/internal/vault/service_test.go @@ -0,0 +1,104 @@ +package vault + +import ( + "strings" + "testing" +) + +func TestDeriveKeyConsistent(t *testing.T) { + salt := []byte("test-salt-exactly-32-bytes-long!") + key1 := DeriveKey("mypassword", salt) + key2 := DeriveKey("mypassword", salt) + + if len(key1) != 32 { + t.Errorf("key length = %d, want 32", len(key1)) + } + if string(key1) != string(key2) { + t.Error("same password+salt produced different keys") + } +} + +func TestDeriveKeyDifferentPasswords(t *testing.T) { + salt := []byte("test-salt-exactly-32-bytes-long!") + key1 := DeriveKey("password1", salt) + key2 := DeriveKey("password2", salt) + + if string(key1) == string(key2) { + t.Error("different passwords produced same key") + } +} + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!")) + vs := NewVaultService(key) + + plaintext := "super-secret-ssh-key-data" + encrypted, err := vs.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encrypt() error: %v", err) + } + + if !strings.HasPrefix(encrypted, "v1:") { + t.Errorf("encrypted does not start with v1: prefix: %q", encrypted[:10]) + } + + decrypted, err := vs.Decrypt(encrypted) + if err != nil { + t.Fatalf("Decrypt() error: %v", err) + } + + if decrypted != plaintext { + t.Errorf("Decrypt() = %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptProducesDifferentCiphertexts(t *testing.T) { + key := DeriveKey("testpassword", []byte("test-salt-exactly-32-bytes-long!")) + vs := NewVaultService(key) + + enc1, _ := vs.Encrypt("same-data") + enc2, _ := vs.Encrypt("same-data") + + if enc1 == enc2 { + t.Error("two encryptions of same data produced identical ciphertext (IV reuse)") + } +} + +func TestDecryptWrongKey(t *testing.T) { + key1 := DeriveKey("password1", []byte("test-salt-exactly-32-bytes-long!")) + key2 := DeriveKey("password2", []byte("test-salt-exactly-32-bytes-long!")) + + vs1 := NewVaultService(key1) + vs2 := NewVaultService(key2) + + encrypted, _ := vs1.Encrypt("secret") + _, err := vs2.Decrypt(encrypted) + if err == nil { + t.Error("Decrypt() with wrong key should return error") + } +} + +func TestDecryptInvalidFormat(t *testing.T) { + key := DeriveKey("test", []byte("test-salt-exactly-32-bytes-long!")) + vs := NewVaultService(key) + + _, err := vs.Decrypt("not-valid-format") + if err == nil { + t.Error("Decrypt() with invalid format should return error") + } +} + +func TestGenerateSalt(t *testing.T) { + salt1, err := GenerateSalt() + if err != nil { + t.Fatalf("GenerateSalt() error: %v", err) + } + if len(salt1) != 32 { + t.Errorf("salt length = %d, want 32", len(salt1)) + } + + salt2, _ := GenerateSalt() + if string(salt1) == string(salt2) { + t.Error("two calls to GenerateSalt produced identical salt") + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..9e46232 --- /dev/null +++ b/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "embed" + "log" + "log/slog" + + wraithapp "github.com/vstockwell/wraith/internal/app" + "github.com/wailsapp/wails/v3/pkg/application" +) + +// version is set at build time via -ldflags "-X main.version=..." +var version = "dev" + +//go:embed all:frontend/dist +var assets embed.FS + +func main() { + slog.Info("starting Wraith") + + wraith, err := wraithapp.New() + if err != nil { + log.Fatalf("failed to initialize: %v", err) + } + + app := application.New(application.Options{ + Name: "Wraith", + Description: "SSH + RDP + SFTP Desktop Client", + Services: []application.Service{ + application.NewService(wraith), + application.NewService(wraith.Connections), + application.NewService(wraith.Themes), + application.NewService(wraith.Settings), + }, + Assets: application.AssetOptions{ + Handler: application.BundledAssetFileServer(assets), + }, + }) + + app.Window.NewWithOptions(application.WebviewWindowOptions{ + Title: "Wraith", + Width: 1400, + Height: 900, + URL: "/", + BackgroundColour: application.NewRGBA(13, 17, 23, 255), + }) + + if err := app.Run(); err != nil { + log.Fatal(err) + } +}