diff --git a/.gitignore b/.gitignore index cf34e00..d5920a3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ # Go artifacts +backend/api/api backend/sshbackend/sshbackend backend/dummybackend/dummybackend +backend/sshappbackend/local-code/remote-bin +backend/sshappbackend/local-code/sshappbackend backend/externalbackendlauncher/externalbackendlauncher -backend/api/api frontend/frontend # Backup artifacts diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 57562cf..0000000 --- a/.prettierrc +++ /dev/null @@ -1,16 +0,0 @@ -{ - "arrowParens": "avoid", - "bracketSpacing": true, - "htmlWhitespaceSensitivity": "css", - "insertPragma": false, - "jsxSingleQuote": false, - "printWidth": 80, - "proseWrap": "always", - "quoteProps": "as-needed", - "requirePragma": false, - "semi": true, - "singleQuote": false, - "tabWidth": 2, - "trailingComma": "all", - "useTabs": false -} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index c1fa49f..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,49 +0,0 @@ -# Changelog - -## [v1.1.2](https://github.com/imterah/nextnet/tree/v1.1.2) (2024-09-29) - -## [v1.1.1](https://github.com/imterah/nextnet/tree/v1.1.1) (2024-09-29) - -## [v1.1.0](https://github.com/imterah/nextnet/tree/v1.1.0) (2024-09-22) - -**Fixed bugs:** - -- Desktop app fails to build on macOS w/ `nix-shell` [\#1](https://github.com/imterah/nextnet/issues/1) - -**Merged pull requests:** - -- chore\(deps\): bump find-my-way from 8.1.0 to 8.2.2 in /api [\#17](https://github.com/imterah/nextnet/pull/17) -- chore\(deps\): bump axios from 1.6.8 to 1.7.4 in /lom [\#16](https://github.com/imterah/nextnet/pull/16) -- chore\(deps\): bump micromatch from 4.0.5 to 4.0.8 in /lom [\#15](https://github.com/imterah/nextnet/pull/15) -- chore\(deps\): bump braces from 3.0.2 to 3.0.3 in /lom [\#13](https://github.com/imterah/nextnet/pull/13) -- chore\(deps-dev\): bump braces from 3.0.2 to 3.0.3 in /api [\#11](https://github.com/imterah/nextnet/pull/11) -- chore\(deps\): bump ws from 8.17.0 to 8.17.1 in /api [\#10](https://github.com/imterah/nextnet/pull/10) - -## [v1.0.1](https://github.com/imterah/nextnet/tree/v1.0.1) (2024-05-18) - -**Merged pull requests:** - -- Adds public key authentication [\#6](https://github.com/imterah/nextnet/pull/6) -- Add support for eslint [\#5](https://github.com/imterah/nextnet/pull/5) - -## [v1.0.0](https://github.com/imterah/nextnet/tree/v1.0.0) (2024-05-10) - -## [v0.1.1](https://github.com/imterah/nextnet/tree/v0.1.1) (2024-05-05) - -## [v0.1.0](https://github.com/imterah/nextnet/tree/v0.1.0) (2024-05-05) - -**Implemented enhancements:** - -- \(potentially\) Migrate nix shell to nix flake [\#2](https://github.com/imterah/nextnet/issues/2) - -**Closed issues:** - -- add precommit hooks [\#3](https://github.com/imterah/nextnet/issues/3) - -**Merged pull requests:** - -- Reimplements PassyFire as a possible backend [\#4](https://github.com/imterah/nextnet/pull/4) - - - -\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/Dockerfile b/Dockerfile index 81710a7..ae9c525 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,4 +7,5 @@ WORKDIR /app COPY --from=build /build/backend/backends.prod.json /app/backends.json COPY --from=build /build/backend/api/api /app/hermes COPY --from=build /build/backend/sshbackend/sshbackend /app/sshbackend +COPY --from=build /build/backend/sshappbackend/local-code/sshappbackend /app/sshappbackend ENTRYPOINT ["/app/hermes", "--backends-path", "/app/backends.json"] diff --git a/LICENSE b/LICENSE index 8914588..a085e23 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2024, Greyson +Copyright (c) 2024, Tera Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md index 01d5885..323855d 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ 1. Copy and change the default password (or username & db name too) from the template file `prod-docker.env`: ```bash - sed "s/POSTGRES_PASSWORD=hermes/POSTGRES_PASSWORD=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" prod-docker.env > .env + sed -e "s/POSTGRES_PASSWORD=hermes/POSTGRES_PASSWORD=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" -e "s/JWT_SECRET=hermes/JWT_SECRET=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" prod-docker.env > .env ``` 2. Build the docker stack: `docker compose --env-file .env up -d` diff --git a/backend/api/backendruntime/core.go b/backend/api/backendruntime/core.go index e06a2a6..eac5934 100644 --- a/backend/api/backendruntime/core.go +++ b/backend/api/backendruntime/core.go @@ -6,10 +6,10 @@ var ( AvailableBackends []*Backend RunningBackends map[uint]*Runtime TempDir string - isDevelopmentMode bool + shouldLog bool ) func init() { RunningBackends = make(map[uint]*Runtime) - isDevelopmentMode = os.Getenv("HERMES_DEVELOPMENT_MODE") != "" + shouldLog = os.Getenv("HERMES_DEVELOPMENT_MODE") != "" || os.Getenv("HERMES_BACKEND_LOGGING_ENABLED") != "" || os.Getenv("HERMES_LOG_LEVEL") == "debug" } diff --git a/backend/api/backendruntime/runtime.go b/backend/api/backendruntime/runtime.go index a65c077..8d5ca7a 100644 --- a/backend/api/backendruntime/runtime.go +++ b/backend/api/backendruntime/runtime.go @@ -15,8 +15,11 @@ import ( "github.com/charmbracelet/log" ) -func handleCommand(commandType string, command interface{}, sock net.Conn, rtcChan chan interface{}) error { - bytes, err := commonbackend.Marshal(commandType, command) +// TODO TODO TODO(imterah): +// This code is a mess. This NEEDS to be rearchitected and refactored to work better. Or at the very least, this code needs to be documented heavily. + +func handleCommand(command interface{}, sock net.Conn, rtcChan chan interface{}) error { + bytes, err := commonbackend.Marshal(command) if err != nil { log.Warnf("Failed to marshal message: %s", err.Error()) @@ -32,7 +35,7 @@ func handleCommand(commandType string, command interface{}, sock net.Conn, rtcCh return fmt.Errorf("failed to write message: %s", err.Error()) } - _, data, err := commonbackend.Unmarshal(sock) + data, err := commonbackend.Unmarshal(sock) if err != nil { log.Warnf("Failed to unmarshal message: %s", err.Error()) @@ -117,9 +120,7 @@ func (runtime *Runtime) goRoutineHandler() error { // To be safe here, we have to use the proper (yet annoying) facilities to prevent cross-talk, since we're in // a goroutine, and can't talk directly. This actually has benefits, as the OuterLoop should exit on its own, if we // encounter a critical error. - statusResponse, err := runtime.ProcessCommand(&commonbackend.BackendStatusRequest{ - Type: "backendStatusRequest", - }) + statusResponse, err := runtime.ProcessCommand(&commonbackend.BackendStatusRequest{}) if err != nil { log.Warnf("Failed to get response for backend (in backend runtime keep alive): %s", err.Error()) @@ -162,119 +163,28 @@ func (runtime *Runtime) goRoutineHandler() error { OuterLoop: for { + _ = <-runtime.startProcessingNotification + runtime.isRuntimeCurrentlyProcessing = true + for chanIndex, messageData := range runtime.messageBuffer { if messageData == nil { continue } - switch command := messageData.Message.(type) { - case *commonbackend.AddProxy: - err := handleCommand("addProxy", command, sock, messageData.Channel) + err := handleCommand(messageData.Message, sock, messageData.Channel) - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) + if err != nil { + log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } + if strings.HasPrefix(err.Error(), "failed to write message") { + break OuterLoop } - case *commonbackend.BackendStatusRequest: - err := handleCommand("backendStatusRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.CheckClientParameters: - err := handleCommand("checkClientParameters", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.CheckServerParameters: - err := handleCommand("checkServerParameters", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyConnectionsRequest: - err := handleCommand("proxyConnectionsRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyInstanceRequest: - err := handleCommand("proxyInstanceRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyStatusRequest: - err := handleCommand("proxyStatusRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.RemoveProxy: - err := handleCommand("removeProxy", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.Start: - err := handleCommand("start", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.Stop: - err := handleCommand("stop", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - default: - log.Warnf("Recieved unknown command type from channel: %T", command) - messageData.Channel <- fmt.Errorf("unknown command recieved") } runtime.messageBuffer[chanIndex] = nil } + + runtime.isRuntimeCurrentlyProcessing = false } sock.Close() @@ -333,6 +243,7 @@ func (runtime *Runtime) Start() error { runtime.messageBuffer = make([]*messageForBuf, 10) runtime.messageBufferLock = sync.Mutex{} + runtime.startProcessingNotification = make(chan bool) runtime.processRestartNotification = make(chan bool, 1) runtime.logger = &writeLogger{ @@ -420,6 +331,10 @@ SchedulingLoop: schedulingAttempts++ } + if !runtime.isRuntimeCurrentlyProcessing { + runtime.startProcessingNotification <- true + } + // Fetch response and close Channel response, ok := <-commandChannel diff --git a/backend/api/backendruntime/struct.go b/backend/api/backendruntime/struct.go index 9d057ad..cd4b3b8 100644 --- a/backend/api/backendruntime/struct.go +++ b/backend/api/backendruntime/struct.go @@ -16,15 +16,18 @@ type Backend struct { type messageForBuf struct { Channel chan interface{} + // TODO(imterah): could this be refactored to just be a []byte instead? Look into this Message interface{} } type Runtime struct { - isRuntimeRunning bool - logger *writeLogger - currentProcess *exec.Cmd - currentListener net.Listener - processRestartNotification chan bool + isRuntimeRunning bool + isRuntimeCurrentlyProcessing bool + startProcessingNotification chan bool + logger *writeLogger + currentProcess *exec.Cmd + currentListener net.Listener + processRestartNotification chan bool messageBufferLock sync.Mutex messageBuffer []*messageForBuf @@ -42,7 +45,7 @@ type writeLogger struct { func (writer writeLogger) Write(p []byte) (n int, err error) { logSplit := strings.Split(string(p), "\n") - if isDevelopmentMode { + if shouldLog { for _, logLine := range logSplit { if logLine == "" { continue diff --git a/backend/api/backup.go b/backend/api/backup.go deleted file mode 100644 index d5a9b86..0000000 --- a/backend/api/backup.go +++ /dev/null @@ -1,298 +0,0 @@ -package main - -import ( - "compress/gzip" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "os" - "strings" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "github.com/charmbracelet/log" - "github.com/go-playground/validator/v10" - "github.com/urfave/cli/v2" - "gorm.io/gorm" -) - -// Data structures -type BackupBackend struct { - ID uint `json:"id" validate:"required"` - - Name string `json:"name" validate:"required"` - Description *string `json:"description"` - Backend string `json:"backend" validate:"required"` - BackendParameters string `json:"connectionDetails" validate:"required"` -} - -type BackupProxy struct { - ID uint `json:"id" validate:"required"` - BackendID uint `json:"destProviderID" validate:"required"` - - Name string `json:"name" validate:"required"` - Description *string `json:"description"` - Protocol string `json:"protocol" validate:"required"` - SourceIP string `json:"sourceIP" validate:"required"` - SourcePort uint16 `json:"sourcePort" validate:"required"` - DestinationPort uint16 `json:"destPort" validate:"required"` - AutoStart bool `json:"enabled" validate:"required"` -} - -type BackupPermission struct { - ID uint `json:"id" validate:"required"` - - PermissionNode string `json:"permission" validate:"required"` - HasPermission bool `json:"has" validate:"required"` - UserID uint `json:"userID" validate:"required"` -} - -type BackupUser struct { - ID uint `json:"id" validate:"required"` - - Email string `json:"email" validate:"required"` - Username *string `json:"username"` - Name string `json:"name" validate:"required"` - Password string `json:"password" validate:"required"` - IsBot *bool `json:"isRootServiceAccount"` - - Token *string `json:"rootToken" validate:"required"` -} - -type BackupData struct { - Backends []*BackupBackend `json:"destinationProviders" validate:"required"` - Proxies []*BackupProxy `json:"forwardRules" validate:"required"` - Permissions []*BackupPermission `json:"allPermissions" validate:"required"` - Users []*BackupUser `json:"users" validate:"required"` -} - -// From https://stackoverflow.com/questions/54461423/efficient-way-to-remove-all-non-alphanumeric-characters-from-large-text -// Strips all alphanumeric characters from a string -func stripAllAlphanumeric(s string) string { - var result strings.Builder - - for i := 0; i < len(s); i++ { - b := s[i] - if ('a' <= b && b <= 'z') || - ('A' <= b && b <= 'Z') || - ('0' <= b && b <= '9') { - result.WriteByte(b) - } else { - result.WriteByte('_') - } - } - - return result.String() -} - -func backupRestoreEntrypoint(cCtx *cli.Context) error { - log.Info("Decompressing backup...") - - backupFile, err := os.Open(cCtx.String("backup-path")) - - if err != nil { - return fmt.Errorf("failed to open backup: %s", err.Error()) - } - - reader, err := gzip.NewReader(backupFile) - - if err != nil { - return fmt.Errorf("failed to initialize Gzip (compression) reader: %s", err.Error()) - } - - backupDataBytes, err := io.ReadAll(reader) - - if err != nil { - return fmt.Errorf("failed to read backup contents: %s", err.Error()) - } - - log.Info("Decompressed backup. Cleaning up...") - - err = reader.Close() - - if err != nil { - return fmt.Errorf("failed to close Gzip reader: %s", err.Error()) - } - - err = backupFile.Close() - - if err != nil { - return fmt.Errorf("failed to close backup: %s", err.Error()) - } - - log.Info("Parsing backup into internal structures...") - - backupData := &BackupData{} - - err = json.Unmarshal(backupDataBytes, backupData) - - if err != nil { - return fmt.Errorf("failed to parse backup: %s", err.Error()) - } - - if err := validator.New().Struct(backupData); err != nil { - return fmt.Errorf("failed to validate backup: %s", err.Error()) - } - - log.Info("Initializing database and opening it...") - - err = dbcore.InitializeDatabase(&gorm.Config{}) - - if err != nil { - log.Fatalf("Failed to initialize database: %s", err) - } - - log.Info("Running database migrations...") - - if err := dbcore.DoDatabaseMigrations(dbcore.DB); err != nil { - return fmt.Errorf("Failed to run database migrations: %s", err) - } - - log.Info("Restoring database...") - bestEffortOwnerUIDFromBackup := -1 - - log.Info("Attempting to find user to use as owner of resources...") - - for _, user := range backupData.Users { - foundUser := false - failedAdministrationCheck := false - - for _, permission := range backupData.Permissions { - if permission.UserID != user.ID { - continue - } - - foundUser = true - - if !strings.HasPrefix(permission.PermissionNode, "routes.") && permission.PermissionNode != "permissions.see" && !permission.HasPermission { - log.Infof("User with email '%s' and ID of '%d' failed administration check (lacks all permissions required). Attempting to find better user", user.Email, user.ID) - failedAdministrationCheck = true - - break - } - } - - if !foundUser { - log.Warnf("User with email '%s' and ID of '%d' lacks any permissions!", user.Email, user.ID) - continue - } - - if failedAdministrationCheck { - continue - } - - log.Infof("Using user with email '%s', and ID of '%d'", user.Email, user.ID) - bestEffortOwnerUIDFromBackup = int(user.ID) - - break - } - - if bestEffortOwnerUIDFromBackup == -1 { - log.Warnf("Could not find Administrative level user to use as the owner of resources. Using user with email '%s', and ID of '%d'", backupData.Users[0].Email, backupData.Users[0].ID) - bestEffortOwnerUIDFromBackup = int(backupData.Users[0].ID) - } - - var bestEffortOwnerUID uint - - for _, user := range backupData.Users { - log.Infof("Migrating user with email '%s' and ID of '%d'", user.Email, user.ID) - tokens := make([]dbcore.Token, 0) - permissions := make([]dbcore.Permission, 0) - - if user.Token != nil { - tokens = append(tokens, dbcore.Token{ - Token: *user.Token, - DisableExpiry: true, - CreationIPAddr: "127.0.0.1", // We don't know the creation IP address... - }) - } - - for _, permission := range backupData.Permissions { - if permission.UserID != user.ID { - continue - } - - permissions = append(permissions, dbcore.Permission{ - PermissionNode: permission.PermissionNode, - HasPermission: permission.HasPermission, - }) - } - - username := "" - - if user.Username == nil { - username = strings.ToLower(stripAllAlphanumeric(user.Email)) - log.Warnf("User with ID of '%d' doesn't have a username. Derived username from email is '%s' (email is '%s')", user.ID, username, user.Email) - } else { - username = *user.Username - } - - userDatabase := &dbcore.User{ - Email: user.Email, - Username: username, - Name: user.Name, - Password: base64.StdEncoding.EncodeToString([]byte(user.Password)), - IsBot: user.IsBot, - - Tokens: tokens, - Permissions: permissions, - } - - if err := dbcore.DB.Create(userDatabase).Error; err != nil { - log.Errorf("Failed to create user: %s", err.Error()) - continue - } - - if uint(bestEffortOwnerUIDFromBackup) == user.ID { - bestEffortOwnerUID = userDatabase.ID - } - } - - for _, backend := range backupData.Backends { - log.Infof("Migrating backend ID '%d' with name '%s'", backend.ID, backend.Name) - - backendDatabase := &dbcore.Backend{ - UserID: bestEffortOwnerUID, - Name: backend.Name, - Description: backend.Description, - Backend: backend.Backend, - BackendParameters: base64.StdEncoding.EncodeToString([]byte(backend.BackendParameters)), - } - - if err := dbcore.DB.Create(backendDatabase).Error; err != nil { - log.Errorf("Failed to create backend: %s", err.Error()) - continue - } - - log.Infof("Migrating proxies for backend ID '%d'", backend.ID) - - for _, proxy := range backupData.Proxies { - if proxy.BackendID != backend.ID { - continue - } - - log.Infof("Migrating proxy ID '%d' with name '%s'", proxy.ID, proxy.Name) - - proxyDatabase := &dbcore.Proxy{ - BackendID: backendDatabase.ID, - UserID: bestEffortOwnerUID, - - Name: proxy.Name, - Description: proxy.Description, - Protocol: proxy.Protocol, - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestinationPort: proxy.DestinationPort, - AutoStart: proxy.AutoStart, - } - - if err := dbcore.DB.Create(proxyDatabase).Error; err != nil { - log.Errorf("Failed to create proxy: %s", err.Error()) - } - } - } - - log.Info("Successfully upgraded to Hermes from NextNet.") - - return nil -} diff --git a/backend/api/controllers/v1/backends/create.go b/backend/api/controllers/v1/backends/create.go index 0d8c614..314dc3e 100644 --- a/backend/api/controllers/v1/backends/create.go +++ b/backend/api/controllers/v1/backends/create.go @@ -7,13 +7,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendCreationRequest struct { @@ -24,132 +23,114 @@ type BackendCreationRequest struct { BackendParameters interface{} `json:"connectionDetails" validate:"required"` } -func CreateBackend(c *gin.Context) { - var req BackendCreationRequest +func SetupCreateBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/create", func(c *gin.Context) { + var req BackendCreationRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) + user, err := state.JWT.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "backends.add") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "backends.add") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var backendParameters []byte - return - } + switch parameters := req.BackendParameters.(type) { + case string: + backendParameters = []byte(parameters) + case map[string]interface{}: + backendParameters, err = json.Marshal(parameters) - var backendParameters []byte + if err != nil { + log.Warnf("Failed to marshal JSON recieved as BackendParameters: %s", err.Error()) - switch parameters := req.BackendParameters.(type) { - case string: - backendParameters = []byte(parameters) - case map[string]interface{}: - backendParameters, err = json.Marshal(parameters) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to prepare parameters", + }) + + return + } + default: + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid type for connectionDetails (recieved %T)", parameters), + }) + + return + } + + var backendRuntimeFilePath string + + for _, runtime := range backendruntime.AvailableBackends { + if runtime.Name == req.Backend { + backendRuntimeFilePath = runtime.Path + } + } + + if backendRuntimeFilePath == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Unsupported backend recieved", + }) + + return + } + + backend := backendruntime.NewBackend(backendRuntimeFilePath) + err = backend.Start() if err != nil { - log.Warnf("Failed to marshal JSON recieved as BackendParameters: %s", err.Error()) + log.Warnf("Failed to start backend: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to prepare parameters", + "error": "Failed to start backend", }) return } - default: - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Invalid type for connectionDetails (recieved %T)", parameters), + + backendParamCheckResponse, err := backend.ProcessCommand(&commonbackend.CheckServerParameters{ + Arguments: backendParameters, }) - return - } - - var backendRuntimeFilePath string - - for _, runtime := range backendruntime.AvailableBackends { - if runtime.Name == req.Backend { - backendRuntimeFilePath = runtime.Path - } - } - - if backendRuntimeFilePath == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Unsupported backend recieved", - }) - - return - } - - backend := backendruntime.NewBackend(backendRuntimeFilePath) - err = backend.Start() - - if err != nil { - log.Warnf("Failed to start backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to start backend", - }) - - return - } - - backendParamCheckResponse, err := backend.ProcessCommand(&commonbackend.CheckServerParameters{ - Type: "checkServerParameters", - Arguments: backendParameters, - }) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - err = backend.Stop() - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendParamCheckResponse.(type) { - case *commonbackend.CheckParametersResponse: - if responseMessage.InResponseTo != "checkServerParameters" { - log.Errorf("Got illegal response to CheckServerParameters: %s", responseMessage.InResponseTo) + log.Warnf("Failed to get response for backend: %s", err.Error()) err = backend.Stop() @@ -164,108 +145,126 @@ func CreateBackend(c *gin.Context) { return } - if !responseMessage.IsValid { + switch responseMessage := backendParamCheckResponse.(type) { + case *commonbackend.CheckParametersResponse: + if responseMessage.InResponseTo != "checkServerParameters" { + log.Errorf("Got illegal response to CheckServerParameters: %s", responseMessage.InResponseTo) + + err = backend.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + } + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", + }) + + return + } + + if !responseMessage.IsValid { + err = backend.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + } + + var errorMessage string + + if responseMessage.Message == "" { + errorMessage = "Unkown error while trying to parse connectionDetails" + } else { + errorMessage = fmt.Sprintf("Invalid backend parameters: %s", responseMessage.Message) + } + + c.JSON(http.StatusBadRequest, gin.H{ + "error": errorMessage, + }) + + return + } + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + } + + log.Info("Passed backend checks successfully") + + backendInDatabase := &db.Backend{ + UserID: user.ID, + Name: req.Name, + Description: req.Description, + Backend: req.Backend, + BackendParameters: base64.StdEncoding.EncodeToString(backendParameters), + } + + if result := state.DB.DB.Create(&backendInDatabase); result.Error != nil { + log.Warnf("Failed to create backend: %s", result.Error.Error()) + err = backend.Stop() if err != nil { log.Warnf("Failed to stop backend: %s", err.Error()) } - var errorMessage string - - if responseMessage.Message == "" { - errorMessage = "Unkown error while trying to parse connectionDetails" - } else { - errorMessage = fmt.Sprintf("Invalid backend parameters: %s", responseMessage.Message) - } - - c.JSON(http.StatusBadRequest, gin.H{ - "error": errorMessage, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add backend into database", }) return } - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) - } - log.Info("Passed backend checks successfully") - - backendInDatabase := &dbcore.Backend{ - UserID: user.ID, - Name: req.Name, - Description: req.Description, - Backend: req.Backend, - BackendParameters: base64.StdEncoding.EncodeToString(backendParameters), - } - - if result := dbcore.DB.Create(&backendInDatabase); result.Error != nil { - log.Warnf("Failed to create backend: %s", result.Error.Error()) - - err = backend.Stop() - - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add backend into database", + backendStartResponse, err := backend.ProcessCommand(&commonbackend.Start{ + Arguments: backendParameters, }) - return - } - - backendStartResponse, err := backend.ProcessCommand(&commonbackend.Start{ - Type: "start", - Arguments: backendParameters, - }) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - err = backend.Stop() - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } + log.Warnf("Failed to get response for backend: %s", err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendStartResponse.(type) { - case *commonbackend.BackendStatusResponse: - if !responseMessage.IsRunning { err = backend.Stop() if err != nil { - log.Warnf("Failed to start backend: %s", err.Error()) + log.Warnf("Failed to stop backend: %s", err.Error()) } - var errorMessage string - - if responseMessage.Message == "" { - errorMessage = "Unkown error while trying to start the backend" - } else { - errorMessage = fmt.Sprintf("Failed to start backend: %s", responseMessage.Message) - } - - c.JSON(http.StatusBadRequest, gin.H{ - "error": errorMessage, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", }) return } - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) - } - backendruntime.RunningBackends[backendInDatabase.ID] = backend + switch responseMessage := backendStartResponse.(type) { + case *commonbackend.BackendStatusResponse: + if !responseMessage.IsRunning { + err = backend.Stop() - c.JSON(http.StatusOK, gin.H{ - "success": true, + if err != nil { + log.Warnf("Failed to start backend: %s", err.Error()) + } + + var errorMessage string + + if responseMessage.Message == "" { + errorMessage = "Unkown error while trying to start the backend" + } else { + errorMessage = fmt.Sprintf("Failed to start backend: %s", responseMessage.Message) + } + + c.JSON(http.StatusBadRequest, gin.H{ + "error": errorMessage, + }) + + return + } + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + } + + backendruntime.RunningBackends[backendInDatabase.ID] = backend + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/controllers/v1/backends/lookup.go b/backend/api/controllers/v1/backends/lookup.go index 9819df7..6cbb386 100644 --- a/backend/api/controllers/v1/backends/lookup.go +++ b/backend/api/controllers/v1/backends/lookup.go @@ -7,12 +7,11 @@ import ( "strings" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendLookupRequest struct { @@ -38,95 +37,80 @@ type LookupResponse struct { Data []*SanitizedBackend `json:"data"` } -func LookupBackend(c *gin.Context) { - var req BackendLookupRequest +func SetupLookupBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/lookup", func(c *gin.Context) { + var req BackendLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "backends.visible") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - backends := []dbcore.Backend{} - queryString := []string{} - queryParameters := []interface{}{} + user, err := state.JWT.GetUserFromJWT(req.Token) - if req.BackendID != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.BackendID) - } + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - if req.Description != nil { - queryString = append(queryString, "description = ?") - queryParameters = append(queryParameters, req.Description) - } + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) - if req.Backend != nil { - queryString = append(queryString, "is_bot = ?") - queryParameters = append(queryParameters, req.Backend) - } + return + } + } - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&backends).Error; err != nil { - log.Warnf("Failed to get backends: %s", err.Error()) + if !permissions.UserHasPermission(user, "backends.visible") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get backends", - }) + return + } - return - } + backends := []db.Backend{} + queryString := []string{} + queryParameters := []interface{}{} - sanitizedBackends := make([]*SanitizedBackend, len(backends)) - hasSecretVisibility := permissions.UserHasPermission(user, "backends.secretVis") + if req.BackendID != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.BackendID) + } - for backendIndex, backend := range backends { - foundBackend, ok := backendruntime.RunningBackends[backend.ID] + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } - if !ok { - log.Warnf("Failed to get backend #%d controller", backend.ID) + if req.Description != nil { + queryString = append(queryString, "description = ?") + queryParameters = append(queryParameters, req.Description) + } + + if req.Backend != nil { + queryString = append(queryString, "is_bot = ?") + queryParameters = append(queryParameters, req.Backend) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&backends).Error; err != nil { + log.Warnf("Failed to get backends: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ "error": "Failed to get backends", @@ -135,29 +119,46 @@ func LookupBackend(c *gin.Context) { return } - sanitizedBackends[backendIndex] = &SanitizedBackend{ - BackendID: backend.ID, - OwnerID: backend.UserID, - Name: backend.Name, - Description: backend.Description, - Backend: backend.Backend, - Logs: foundBackend.Logs, - } + sanitizedBackends := make([]*SanitizedBackend, len(backends)) + hasSecretVisibility := permissions.UserHasPermission(user, "backends.secretVis") - if backend.UserID == user.ID || hasSecretVisibility { - backendParametersBytes, err := base64.StdEncoding.DecodeString(backend.BackendParameters) + for backendIndex, backend := range backends { + foundBackend, ok := backendruntime.RunningBackends[backend.ID] - if err != nil { - log.Warnf("Failed to decode base64 backend parameters: %s", err.Error()) + if !ok { + log.Warnf("Failed to get backend #%d controller", backend.ID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get backends", + }) + + return } - backendParameters := string(backendParametersBytes) - sanitizedBackends[backendIndex].BackendParameters = &backendParameters - } - } + sanitizedBackends[backendIndex] = &SanitizedBackend{ + BackendID: backend.ID, + OwnerID: backend.UserID, + Name: backend.Name, + Description: backend.Description, + Backend: backend.Backend, + Logs: foundBackend.Logs, + } - c.JSON(http.StatusOK, &LookupResponse{ - Success: true, - Data: sanitizedBackends, + if backend.UserID == user.ID || hasSecretVisibility { + backendParametersBytes, err := base64.StdEncoding.DecodeString(backend.BackendParameters) + + if err != nil { + log.Warnf("Failed to decode base64 backend parameters: %s", err.Error()) + } + + backendParameters := string(backendParametersBytes) + sanitizedBackends[backendIndex].BackendParameters = &backendParameters + } + } + + c.JSON(http.StatusOK, &LookupResponse{ + Success: true, + Data: sanitizedBackends, + }) }) } diff --git a/backend/api/controllers/v1/backends/remove.go b/backend/api/controllers/v1/backends/remove.go index f9dbd8c..338ccbd 100644 --- a/backend/api/controllers/v1/backends/remove.go +++ b/backend/api/controllers/v1/backends/remove.go @@ -5,12 +5,11 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendRemovalRequest struct { @@ -18,106 +17,108 @@ type BackendRemovalRequest struct { BackendID uint `json:"id" validate:"required"` } -func RemoveBackend(c *gin.Context) { - var req BackendRemovalRequest +func SetupRemoveBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/remove", func(c *gin.Context) { + var req BackendRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "backends.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - var backend *dbcore.Backend - backendRequest := dbcore.DB.Where("id = ?", req.BackendID).Find(&backend) - - if backendRequest.Error != nil { - log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if backend exists", - }) - - return - } - - backendExists := backendRequest.RowsAffected > 0 - - if !backendExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Backend doesn't exist", - }) - - return - } - - if err := dbcore.DB.Delete(backend).Error; err != nil { - log.Warnf("failed to delete backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to delete backend", - }) - - return - } - - backendInstance, ok := backendruntime.RunningBackends[req.BackendID] - - if ok { - err = backendInstance.Stop() + user, err := state.JWT.GetUserFromJWT(req.Token) if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Backend deleted, but failed to stop", + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "backends.remove") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", }) - delete(backendruntime.RunningBackends, req.BackendID) return } - delete(backendruntime.RunningBackends, req.BackendID) - } + var backend *db.Backend + backendRequest := state.DB.DB.Where("id = ?", req.BackendID).Find(&backend) - c.JSON(http.StatusOK, gin.H{ - "success": true, + if backendRequest.Error != nil { + log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if backend exists", + }) + + return + } + + backendExists := backendRequest.RowsAffected > 0 + + if !backendExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Backend doesn't exist", + }) + + return + } + + if err := state.DB.DB.Delete(backend).Error; err != nil { + log.Warnf("failed to delete backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to delete backend", + }) + + return + } + + backendInstance, ok := backendruntime.RunningBackends[req.BackendID] + + if ok { + err = backendInstance.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Backend deleted, but failed to stop", + }) + + delete(backendruntime.RunningBackends, req.BackendID) + return + } + + delete(backendruntime.RunningBackends, req.BackendID) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/controllers/v1/proxies/connections.go b/backend/api/controllers/v1/proxies/connections.go index f46c284..7ea42fb 100644 --- a/backend/api/controllers/v1/proxies/connections.go +++ b/backend/api/controllers/v1/proxies/connections.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ConnectionsRequest struct { @@ -37,129 +36,130 @@ type ConnectionsResponse struct { Data []*SanitizedConnection `json:"data"` } -func GetConnections(c *gin.Context) { - var req ConnectionsRequest +func SetupGetConnections(state *state.State) { + state.Engine.POST("/api/v1/forward/connections", func(c *gin.Context) { + var req ConnectionsRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "routes.visibleConn") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - var proxy dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.Id).First(&proxy) + user, err := state.JWT.GetUserFromJWT(req.Token) - if proxyRequest.Error != nil { - log.Warnf("failed to find proxy: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find forward entry", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "No forward entry found", - }) - - return - } - - backendRuntime, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", - }) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyConnectionsResponse: - sanitizedConnections := []*SanitizedConnection{} - - for _, connection := range responseMessage.Connections { - if connection.SourceIP == proxy.SourceIP && connection.SourcePort == proxy.SourcePort && proxy.DestinationPort == proxy.DestinationPort { - sanitizedConnections = append(sanitizedConnections, &SanitizedConnection{ - ClientIP: connection.ClientIP, - Port: connection.ClientPort, - - ConnectionDetails: &ConnectionDetailsForConnection{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - }, + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return } } - c.JSON(http.StatusOK, &ConnectionsResponse{ - Success: true, - Data: sanitizedConnections, - }) - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) + if !permissions.UserHasPermission(user, "routes.visibleConn") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got illegal response type", - }) - } + return + } + + var proxy db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.Id).First(&proxy) + + if proxyRequest.Error != nil { + log.Warnf("failed to find proxy: %s", proxyRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find forward entry", + }) + + return + } + + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "No forward entry found", + }) + + return + } + + backendRuntime, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{}) + + if err != nil { + log.Warnf("Failed to get response for backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyConnectionsResponse: + sanitizedConnections := []*SanitizedConnection{} + + for _, connection := range responseMessage.Connections { + if connection.SourceIP == proxy.SourceIP && connection.SourcePort == proxy.SourcePort && proxy.DestinationPort == proxy.DestinationPort { + sanitizedConnections = append(sanitizedConnections, &SanitizedConnection{ + ClientIP: connection.ClientIP, + Port: connection.ClientPort, + + ConnectionDetails: &ConnectionDetailsForConnection{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + }, + }) + } + } + + c.JSON(http.StatusOK, &ConnectionsResponse{ + Success: true, + Data: sanitizedConnections, + }) + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got illegal response type", + }) + } + }) } diff --git a/backend/api/controllers/v1/proxies/create.go b/backend/api/controllers/v1/proxies/create.go index 20ad144..d790c49 100644 --- a/backend/api/controllers/v1/proxies/create.go +++ b/backend/api/controllers/v1/proxies/create.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyCreationRequest struct { @@ -26,151 +25,153 @@ type ProxyCreationRequest struct { AutoStart *bool `json:"autoStart"` } -func CreateProxy(c *gin.Context) { - var req ProxyCreationRequest +func SetupCreateProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/create", func(c *gin.Context) { + var req ProxyCreationRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - if !permissions.UserHasPermission(user, "routes.add") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - - if req.Protocol != "tcp" && req.Protocol != "udp" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Protocol must be either 'tcp' or 'udp'", - }) - - return - } - - var backend dbcore.Backend - backendRequest := dbcore.DB.Where("id = ?", req.ProviderID).First(&backend) - - if backendRequest.Error != nil { - log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if backend exists", - }) - } - - backendExists := backendRequest.RowsAffected > 0 - - if !backendExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Could not find backend", - }) - } - - autoStart := false - - if req.AutoStart != nil { - autoStart = *req.AutoStart - } - - proxy := &dbcore.Proxy{ - UserID: user.ID, - BackendID: req.ProviderID, - Name: req.Name, - Description: req.Description, - Protocol: req.Protocol, - SourceIP: req.SourceIP, - SourcePort: req.SourcePort, - DestinationPort: req.DestinationPort, - AutoStart: autoStart, - } - - if result := dbcore.DB.Create(proxy); result.Error != nil { - log.Warnf("failed to create proxy: %s", result.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add forward rule to database", - }) - - return - } - - if autoStart { - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "id": proxy.ID, + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.add") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", }) return } - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if !responseMessage.IsActive { - log.Warnf("Failed to start proxy for backend #%d", proxy.BackendID) - } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - } - } + if req.Protocol != "tcp" && req.Protocol != "udp" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Protocol must be either 'tcp' or 'udp'", + }) - c.JSON(http.StatusOK, gin.H{ - "success": true, - "id": proxy.ID, + return + } + + var backend db.Backend + backendRequest := state.DB.DB.Where("id = ?", req.ProviderID).First(&backend) + + if backendRequest.Error != nil { + log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if backend exists", + }) + } + + backendExists := backendRequest.RowsAffected > 0 + + if !backendExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Could not find backend", + }) + } + + autoStart := false + + if req.AutoStart != nil { + autoStart = *req.AutoStart + } + + proxy := &db.Proxy{ + UserID: user.ID, + BackendID: req.ProviderID, + Name: req.Name, + Description: req.Description, + Protocol: req.Protocol, + SourceIP: req.SourceIP, + SourcePort: req.SourcePort, + DestinationPort: req.DestinationPort, + AutoStart: autoStart, + } + + if result := state.DB.DB.Create(proxy); result.Error != nil { + log.Warnf("failed to create proxy: %s", result.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add forward rule to database", + }) + + return + } + + if autoStart { + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "id": proxy.ID, + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, + }) + + if err != nil { + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyStatusResponse: + if !responseMessage.IsActive { + log.Warnf("Failed to start proxy for backend #%d", proxy.BackendID) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "id": proxy.ID, + }) }) } diff --git a/backend/api/controllers/v1/proxies/lookup.go b/backend/api/controllers/v1/proxies/lookup.go index 3b0d4c3..bf2c3ea 100644 --- a/backend/api/controllers/v1/proxies/lookup.go +++ b/backend/api/controllers/v1/proxies/lookup.go @@ -5,12 +5,11 @@ import ( "net/http" "strings" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyLookupRequest struct { @@ -43,141 +42,143 @@ type ProxyLookupResponse struct { Data []*SanitizedProxy `json:"data"` } -func LookupProxy(c *gin.Context) { - var req ProxyLookupRequest +func SetupLookupProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/lookup", func(c *gin.Context) { + var req ProxyLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.visible") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) + + return + } + + if req.Protocol != nil { + if *req.Protocol != "tcp" && *req.Protocol != "udp" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Protocol specified in body must either be 'tcp' or 'udp'", + }) + + return + } + } + + proxies := []db.Proxy{} + + queryString := []string{} + queryParameters := []interface{}{} + + if req.Id != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.Id) + } + + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.Description != nil { + queryString = append(queryString, "description = ?") + queryParameters = append(queryParameters, req.Description) + } + + if req.SourceIP != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.SourcePort != nil { + queryString = append(queryString, "source_port = ?") + queryParameters = append(queryParameters, req.SourcePort) + } + + if req.DestinationPort != nil { + queryString = append(queryString, "destination_port = ?") + queryParameters = append(queryParameters, req.DestinationPort) + } + + if req.ProviderID != nil { + queryString = append(queryString, "backend_id = ?") + queryParameters = append(queryParameters, req.ProviderID) + } + + if req.AutoStart != nil { + queryString = append(queryString, "auto_start = ?") + queryParameters = append(queryParameters, req.AutoStart) + } + + if req.Protocol != nil { + queryString = append(queryString, "protocol = ?") + queryParameters = append(queryParameters, req.Protocol) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&proxies).Error; err != nil { + log.Warnf("failed to get proxies: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Failed to get proxies", }) return } - } - if !permissions.UserHasPermission(user, "routes.visible") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + sanitizedProxies := make([]*SanitizedProxy, len(proxies)) - return - } - - if req.Protocol != nil { - if *req.Protocol != "tcp" && *req.Protocol != "udp" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Protocol specified in body must either be 'tcp' or 'udp'", - }) - - return + for proxyIndex, proxy := range proxies { + sanitizedProxies[proxyIndex] = &SanitizedProxy{ + Id: proxy.ID, + Name: proxy.Name, + Description: proxy.Description, + Protcol: proxy.Protocol, + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestinationPort: proxy.DestinationPort, + ProviderID: proxy.BackendID, + AutoStart: proxy.AutoStart, + } } - } - proxies := []dbcore.Proxy{} - - queryString := []string{} - queryParameters := []interface{}{} - - if req.Id != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.Id) - } - - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } - - if req.Description != nil { - queryString = append(queryString, "description = ?") - queryParameters = append(queryParameters, req.Description) - } - - if req.SourceIP != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } - - if req.SourcePort != nil { - queryString = append(queryString, "source_port = ?") - queryParameters = append(queryParameters, req.SourcePort) - } - - if req.DestinationPort != nil { - queryString = append(queryString, "destination_port = ?") - queryParameters = append(queryParameters, req.DestinationPort) - } - - if req.ProviderID != nil { - queryString = append(queryString, "backend_id = ?") - queryParameters = append(queryParameters, req.ProviderID) - } - - if req.AutoStart != nil { - queryString = append(queryString, "auto_start = ?") - queryParameters = append(queryParameters, req.AutoStart) - } - - if req.Protocol != nil { - queryString = append(queryString, "protocol = ?") - queryParameters = append(queryParameters, req.Protocol) - } - - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&proxies).Error; err != nil { - log.Warnf("failed to get proxies: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get proxies", + c.JSON(http.StatusOK, &ProxyLookupResponse{ + Success: true, + Data: sanitizedProxies, }) - - return - } - - sanitizedProxies := make([]*SanitizedProxy, len(proxies)) - - for proxyIndex, proxy := range proxies { - sanitizedProxies[proxyIndex] = &SanitizedProxy{ - Id: proxy.ID, - Name: proxy.Name, - Description: proxy.Description, - Protcol: proxy.Protocol, - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestinationPort: proxy.DestinationPort, - ProviderID: proxy.BackendID, - AutoStart: proxy.AutoStart, - } - } - - c.JSON(http.StatusOK, &ProxyLookupResponse{ - Success: true, - Data: sanitizedProxies, }) } diff --git a/backend/api/controllers/v1/proxies/remove.go b/backend/api/controllers/v1/proxies/remove.go index 8087e29..304c5c7 100644 --- a/backend/api/controllers/v1/proxies/remove.go +++ b/backend/api/controllers/v1/proxies/remove.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyRemovalRequest struct { @@ -19,135 +18,133 @@ type ProxyRemovalRequest struct { ID uint `validate:"required" json:"id"` } -func RemoveProxy(c *gin.Context) { - var req ProxyRemovalRequest +func SetupRemoveProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/remove", func(c *gin.Context) { + var req ProxyRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.remove") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "routes.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - return - } + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - if err := dbcore.DB.Delete(proxy).Error; err != nil { - log.Warnf("failed to delete proxy: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to delete forward rule", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - Type: "removeProxy", - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get response from backend. Proxy was still successfully deleted", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if responseMessage.IsActive { c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to stop proxy. Proxy was still successfully deleted", + "error": "Failed to find if forward rule exists", }) return } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + if err := state.DB.DB.Delete(proxy).Error; err != nil { + log.Warnf("failed to delete proxy: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to delete forward rule", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + if err != nil { + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get response from backend. Proxy was still successfully deleted", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyStatusResponse: + if responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to stop proxy. Proxy was still successfully deleted", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was still successfully deleted", + }) + } }) } diff --git a/backend/api/controllers/v1/proxies/start.go b/backend/api/controllers/v1/proxies/start.go index d0cd5e0..1680ddf 100644 --- a/backend/api/controllers/v1/proxies/start.go +++ b/backend/api/controllers/v1/proxies/start.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyStartRequest struct { @@ -19,125 +18,119 @@ type ProxyStartRequest struct { ID uint `validate:"required" json:"id"` } -func StartProxy(c *gin.Context) { - var req ProxyStartRequest +func SetupStartProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/start", func(c *gin.Context) { + var req ProxyStartRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.start") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - if !permissions.UserHasPermission(user, "routes.start") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - switch responseMessage := backendResponse.(type) { - case error: - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", - }) - - return - case *commonbackend.ProxyStatusResponse: - if !responseMessage.IsActive { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to start proxy", + "error": "Missing permissions", }) return } - break - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if forward rule exists", + }) + + return + } + + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + switch responseMessage := backendResponse.(type) { + case error: + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + case *commonbackend.ProxyStatusResponse: + if !responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to start proxy", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was likely still successfully started", + }) + } }) } diff --git a/backend/api/controllers/v1/proxies/stop.go b/backend/api/controllers/v1/proxies/stop.go index 1f5f525..27d63ce 100644 --- a/backend/api/controllers/v1/proxies/stop.go +++ b/backend/api/controllers/v1/proxies/stop.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyStopRequest struct { @@ -19,125 +18,119 @@ type ProxyStopRequest struct { ID uint `validate:"required" json:"id"` } -func StopProxy(c *gin.Context) { - var req ProxyStopRequest +func SetupStopProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/stop", func(c *gin.Context) { + var req ProxyStartRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.stop") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "routes.stop") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - return - } + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - Type: "removeProxy", - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if responseMessage.IsActive { c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to stop proxy", + "error": "Failed to find if forward rule exists", }) return } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + switch responseMessage := backendResponse.(type) { + case error: + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + case *commonbackend.ProxyStatusResponse: + if responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to stop proxy", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was likely still successfully stopped", + }) + } }) } diff --git a/backend/api/controllers/v1/users/create.go b/backend/api/controllers/v1/users/create.go index e92858a..f39aa95 100644 --- a/backend/api/controllers/v1/users/create.go +++ b/backend/api/controllers/v1/users/create.go @@ -7,11 +7,9 @@ import ( "net/http" "strings" - "github.com/go-playground/validator/v10" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" permissionHelper "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" @@ -22,142 +20,141 @@ type UserCreationRequest struct { Email string `validate:"required"` Password string `validate:"required"` Username string `validate:"required"` - - // TODO: implement support - ExistingUserToken string `json:"token"` - IsBot bool + IsBot bool } -func CreateUser(c *gin.Context) { - if !signupEnabled && !unsafeSignup { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Signing up is not enabled at this time.", - }) +func SetupCreateUser(state *state.State) { + state.Engine.POST("/api/v1/users/create", func(c *gin.Context) { + if !signupEnabled && !unsafeSignup { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Signing up is not enabled at this time.", + }) - return - } - - var req UserCreationRequest - - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - var user *dbcore.User - userRequest := dbcore.DB.Where("email = ? OR username = ?", req.Email, req.Username).Find(&user) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if userExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User already exists", - }) - - return - } - - passwordHashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) - - if err != nil { - log.Warnf("Failed to generate password for client upon signup: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate password hash", - }) - - return - } - - permissions := []dbcore.Permission{} - - for _, permission := range permissionHelper.DefaultPermissionNodes { - permissionEnabledState := false - - if unsafeSignup || strings.HasPrefix(permission, "routes.") || permission == "permissions.see" { - permissionEnabledState = true + return } - permissions = append(permissions, dbcore.Permission{ - PermissionNode: permission, - HasPermission: permissionEnabledState, - }) - } + var req UserCreationRequest - tokenRandomData := make([]byte, 80) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - if _, err := rand.Read(tokenRandomData); err != nil { - log.Warnf("Failed to read random data to use as token: %s", err.Error()) + return + } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user = &dbcore.User{ - Email: req.Email, - Username: req.Username, - Name: req.Name, - IsBot: &req.IsBot, - Password: base64.StdEncoding.EncodeToString(passwordHashed), - Permissions: permissions, - Tokens: []dbcore.Token{ - { - Token: base64.StdEncoding.EncodeToString(tokenRandomData), - DisableExpiry: forceNoExpiryTokens, - CreationIPAddr: c.ClientIP(), + var user *db.User + userRequest := state.DB.DB.Where("email = ? OR username = ?", req.Email, req.Username).Find(&user) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User already exists", + }) + + return + } + + passwordHashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + + if err != nil { + log.Warnf("Failed to generate password for client upon signup: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate password hash", + }) + + return + } + + permissions := []db.Permission{} + + for _, permission := range permissionHelper.DefaultPermissionNodes { + permissionEnabledState := false + + if unsafeSignup || strings.HasPrefix(permission, "routes.") || permission == "permissions.see" { + permissionEnabledState = true + } + + permissions = append(permissions, db.Permission{ + PermissionNode: permission, + HasPermission: permissionEnabledState, + }) + } + + tokenRandomData := make([]byte, 80) + + if _, err := rand.Read(tokenRandomData); err != nil { + log.Warnf("Failed to read random data to use as token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + user = &db.User{ + Email: req.Email, + Username: req.Username, + Name: req.Name, + IsBot: &req.IsBot, + Password: base64.StdEncoding.EncodeToString(passwordHashed), + Permissions: permissions, + Tokens: []db.Token{ + { + Token: base64.StdEncoding.EncodeToString(tokenRandomData), + DisableExpiry: forceNoExpiryTokens, + CreationIPAddr: c.ClientIP(), + }, }, - }, - } + } - if result := dbcore.DB.Create(&user); result.Error != nil { - log.Warnf("Failed to create user: %s", result.Error.Error()) + if result := state.DB.DB.Create(&user); result.Error != nil { + log.Warnf("Failed to create user: %s", result.Error.Error()) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add user into database", + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add user into database", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, + "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, - "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) } diff --git a/backend/api/controllers/v1/users/login.go b/backend/api/controllers/v1/users/login.go index b6fd586..ea4f2f3 100644 --- a/backend/api/controllers/v1/users/login.go +++ b/backend/api/controllers/v1/users/login.go @@ -6,11 +6,10 @@ import ( "fmt" "net/http" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" "golang.org/x/crypto/bcrypt" ) @@ -21,137 +20,139 @@ type UserLoginRequest struct { Password string `validate:"required"` } -func LoginUser(c *gin.Context) { - var req UserLoginRequest +func SetupLoginUser(state *state.State) { + state.Engine.POST("/api/v1/users/login", func(c *gin.Context) { + var req UserLoginRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) + + return + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + if req.Email == nil && req.Username == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Missing both email and username in body", + }) + + return + } + + userFindRequestArguments := make([]interface{}, 1) + userFindRequest := "" + + if req.Email != nil { + userFindRequestArguments[0] = &req.Email + userFindRequest += "email = ?" + } + + if req.Username != nil { + userFindRequestArguments[0] = &req.Username + userFindRequest += "username = ?" + } + + var user *db.User + userRequest := state.DB.DB.Where(userFindRequest, userFindRequestArguments...).Find(&user) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User not found", + }) + + return + } + + decodedPassword := make([]byte, base64.StdEncoding.DecodedLen(len(user.Password))) + _, err := base64.StdEncoding.Decode(decodedPassword, []byte(user.Password)) + + if err != nil { + log.Warnf("failed to decode password in database: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse database result for password", + }) + + return + } + + err = bcrypt.CompareHashAndPassword(decodedPassword, []byte(req.Password)) + + if err != nil { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Invalid password", + }) + + return + } + + tokenRandomData := make([]byte, 80) + + if _, err := rand.Read(tokenRandomData); err != nil { + log.Warnf("Failed to read random data to use as token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + token := &db.Token{ + UserID: user.ID, + + Token: base64.StdEncoding.EncodeToString(tokenRandomData), + DisableExpiry: forceNoExpiryTokens, + CreationIPAddr: c.ClientIP(), + } + + if result := state.DB.DB.Create(&token); result.Error != nil { + log.Warnf("Failed to create user: %s", result.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add refresh token into database", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, + "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - if req.Email == nil && req.Username == nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Missing both email and username in body", - }) - - return - } - - userFindRequestArguments := make([]interface{}, 1) - userFindRequest := "" - - if req.Email != nil { - userFindRequestArguments[0] = &req.Email - userFindRequest += "email = ?" - } - - if req.Username != nil { - userFindRequestArguments[0] = &req.Username - userFindRequest += "username = ?" - } - - var user *dbcore.User - userRequest := dbcore.DB.Where(userFindRequest, userFindRequestArguments...).Find(&user) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User not found", - }) - - return - } - - decodedPassword := make([]byte, base64.StdEncoding.DecodedLen(len(user.Password))) - _, err := base64.StdEncoding.Decode(decodedPassword, []byte(user.Password)) - - if err != nil { - log.Warnf("failed to decode password in database: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse database result for password", - }) - - return - } - - err = bcrypt.CompareHashAndPassword(decodedPassword, []byte(req.Password)) - - if err != nil { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Invalid password", - }) - - return - } - - tokenRandomData := make([]byte, 80) - - if _, err := rand.Read(tokenRandomData); err != nil { - log.Warnf("Failed to read random data to use as token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - token := &dbcore.Token{ - UserID: user.ID, - - Token: base64.StdEncoding.EncodeToString(tokenRandomData), - DisableExpiry: forceNoExpiryTokens, - CreationIPAddr: c.ClientIP(), - } - - if result := dbcore.DB.Create(&token); result.Error != nil { - log.Warnf("Failed to create user: %s", result.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add refresh token into database", - }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, - "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) } diff --git a/backend/api/controllers/v1/users/lookup.go b/backend/api/controllers/v1/users/lookup.go index 04782e5..f5c14fc 100644 --- a/backend/api/controllers/v1/users/lookup.go +++ b/backend/api/controllers/v1/users/lookup.go @@ -5,12 +5,11 @@ import ( "net/http" "strings" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserLookupRequest struct { @@ -35,102 +34,104 @@ type LookupResponse struct { Data []*SanitizedUsers `json:"data"` } -func LookupUser(c *gin.Context) { - var req UserLookupRequest +func SetupLookupUser(state *state.State) { + state.Engine.POST("/api/v1/users/lookup", func(c *gin.Context) { + var req UserLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + users := []db.User{} + queryString := []string{} + queryParameters := []interface{}{} + + if !permissions.UserHasPermission(user, "users.lookup") { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, user.ID) + } else if permissions.UserHasPermission(user, "users.lookup") && req.UID != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.UID) + } + + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.Email != nil { + queryString = append(queryString, "email = ?") + queryParameters = append(queryParameters, req.Email) + } + + if req.IsBot != nil { + queryString = append(queryString, "is_bot = ?") + queryParameters = append(queryParameters, req.IsBot) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&users).Error; err != nil { + log.Warnf("Failed to get users: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Failed to get users", }) return } - } - users := []dbcore.User{} - queryString := []string{} - queryParameters := []interface{}{} + sanitizedUsers := make([]*SanitizedUsers, len(users)) - if !permissions.UserHasPermission(user, "users.lookup") { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, user.ID) - } else if permissions.UserHasPermission(user, "users.lookup") && req.UID != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.UID) - } + for userIndex, user := range users { + isBot := false - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } + if user.IsBot != nil { + isBot = *user.IsBot + } - if req.Email != nil { - queryString = append(queryString, "email = ?") - queryParameters = append(queryParameters, req.Email) - } + sanitizedUsers[userIndex] = &SanitizedUsers{ + UID: user.ID, + Name: user.Name, + Email: user.Email, + Username: user.Username, + IsBot: isBot, + } + } - if req.IsBot != nil { - queryString = append(queryString, "is_bot = ?") - queryParameters = append(queryParameters, req.IsBot) - } - - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&users).Error; err != nil { - log.Warnf("Failed to get users: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get users", + c.JSON(http.StatusOK, &LookupResponse{ + Success: true, + Data: sanitizedUsers, }) - - return - } - - sanitizedUsers := make([]*SanitizedUsers, len(users)) - - for userIndex, user := range users { - isBot := false - - if user.IsBot != nil { - isBot = *user.IsBot - } - - sanitizedUsers[userIndex] = &SanitizedUsers{ - UID: user.ID, - Name: user.Name, - Email: user.Email, - Username: user.Username, - IsBot: isBot, - } - } - - c.JSON(http.StatusOK, &LookupResponse{ - Success: true, - Data: sanitizedUsers, }) } diff --git a/backend/api/controllers/v1/users/refresh.go b/backend/api/controllers/v1/users/refresh.go index 59fcf80..ce9fbaf 100644 --- a/backend/api/controllers/v1/users/refresh.go +++ b/backend/api/controllers/v1/users/refresh.go @@ -5,113 +5,114 @@ import ( "net/http" "time" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserRefreshRequest struct { Token string `validate:"required"` } -func RefreshUserToken(c *gin.Context) { - var req UserRefreshRequest +func SetupRefreshUserToken(state *state.State) { + state.Engine.POST("/api/v1/users/refresh", func(c *gin.Context) { + var req UserRefreshRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - var tokenInDatabase *dbcore.Token - tokenRequest := dbcore.DB.Where("token = ?", req.Token).Find(&tokenInDatabase) - - if tokenRequest.Error != nil { - log.Warnf("failed to find if token exists or not: %s", tokenRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if token exists", - }) - - return - } - - tokenExists := tokenRequest.RowsAffected > 0 - - if !tokenExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Token not found", - }) - - return - } - - // First, we check to make sure that the key expiry is disabled before checking if the key is expired. - // Then, we check if the IP addresses differ, or if it has been 7 days since the token has been created. - if !tokenInDatabase.DisableExpiry && (c.ClientIP() != tokenInDatabase.CreationIPAddr || time.Now().Before(tokenInDatabase.CreatedAt.Add((24*7)*time.Hour))) { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Token has expired", - }) - - tx := dbcore.DB.Delete(tokenInDatabase) - - if tx.Error != nil { - log.Warnf("Failed to delete expired token from database: %s", tx.Error.Error()) + return } - return - } + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - // Get the user to check if the user exists before doing anything - var user *dbcore.User - userRequest := dbcore.DB.Where("id = ?", tokenInDatabase.UserID).Find(&user) + return + } - if tokenRequest.Error != nil { - log.Warnf("failed to find if token user or not: %s", userRequest.Error.Error()) + var tokenInDatabase *db.Token + tokenRequest := state.DB.DB.Where("token = ?", req.Token).Find(&tokenInDatabase) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find user", + if tokenRequest.Error != nil { + log.Warnf("failed to find if token exists or not: %s", tokenRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if token exists", + }) + + return + } + + tokenExists := tokenRequest.RowsAffected > 0 + + if !tokenExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Token not found", + }) + + return + } + + // First, we check to make sure that the key expiry is disabled before checking if the key is expired. + // Then, we check if the IP addresses differ, or if it has been 7 days since the token has been created. + if !tokenInDatabase.DisableExpiry && (c.ClientIP() != tokenInDatabase.CreationIPAddr || time.Now().Before(tokenInDatabase.CreatedAt.Add((24*7)*time.Hour))) { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Token has expired", + }) + + tx := state.DB.DB.Delete(tokenInDatabase) + + if tx.Error != nil { + log.Warnf("Failed to delete expired token from database: %s", tx.Error.Error()) + } + + return + } + + // Get the user to check if the user exists before doing anything + var user *db.User + userRequest := state.DB.DB.Where("id = ?", tokenInDatabase.UserID).Find(&user) + + if tokenRequest.Error != nil { + log.Warnf("failed to find if token user or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find user", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "User not found", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "User not found", - }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, }) } diff --git a/backend/api/controllers/v1/users/remove.go b/backend/api/controllers/v1/users/remove.go index 5446755..59e460c 100644 --- a/backend/api/controllers/v1/users/remove.go +++ b/backend/api/controllers/v1/users/remove.go @@ -4,12 +4,11 @@ import ( "fmt" "net/http" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserRemovalRequest struct { @@ -17,89 +16,91 @@ type UserRemovalRequest struct { UID *uint `json:"uid"` } -func RemoveUser(c *gin.Context) { - var req UserRemovalRequest +func SetupRemoveUser(state *state.State) { + state.Engine.POST("/api/v1/users/remove", func(c *gin.Context) { + var req UserRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - uid := user.ID - - if req.UID != nil { - uid = *req.UID - - if uid != user.ID && !permissions.UserHasPermission(user, "users.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - } - - // Make sure the user exists first if we have a custom UserID - - if uid != user.ID { - var customUser *dbcore.User - userRequest := dbcore.DB.Where("id = ?", uid).Find(customUser) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { + if err := c.BindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "error": "User doesn't exist", + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - dbcore.DB.Select("Tokens", "Permissions", "Proxys", "Backends").Where("id = ?", uid).Delete(user) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - c.JSON(http.StatusOK, gin.H{ - "success": true, + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + uid := user.ID + + if req.UID != nil { + uid = *req.UID + + if uid != user.ID && !permissions.UserHasPermission(user, "users.remove") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) + + return + } + } + + // Make sure the user exists first if we have a custom UserID + + if uid != user.ID { + var customUser *db.User + userRequest := state.DB.DB.Where("id = ?", uid).Find(customUser) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User doesn't exist", + }) + + return + } + } + + state.DB.DB.Select("Tokens", "Permissions", "Proxys", "Backends").Where("id = ?", uid).Delete(user) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/db/db.go b/backend/api/db/db.go new file mode 100644 index 0000000..295bff8 --- /dev/null +++ b/backend/api/db/db.go @@ -0,0 +1,77 @@ +package db + +import ( + "fmt" + "os" + + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type DB struct { + DB *gorm.DB +} + +func New(backend, params string) (*DB, error) { + var err error + + dialector, err := initDialector(backend, params) + + if err != nil { + return nil, fmt.Errorf("failed to initialize physical database: %s", err) + } + + database, err := gorm.Open(dialector) + + if err != nil { + return nil, fmt.Errorf("failed to open database: %s", err) + } + + return &DB{DB: database}, nil +} + +func (db *DB) DoMigrations() error { + if err := db.DB.AutoMigrate(&Proxy{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Backend{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Permission{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Token{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&User{}); err != nil { + return err + } + + return nil +} + +func initDialector(backend, params string) (gorm.Dialector, error) { + switch backend { + case "sqlite": + if params == "" { + return nil, fmt.Errorf("sqlite database file not specified") + } + + return sqlite.Open(params), nil + case "postgresql": + if params == "" { + return nil, fmt.Errorf("postgres DSN not specified") + } + + return postgres.Open(params), nil + case "": + return nil, fmt.Errorf("no database backend specified in environment variables") + default: + return nil, fmt.Errorf("unknown database backend specified: %s", os.Getenv(backend)) + } +} diff --git a/backend/api/db/models.go b/backend/api/db/models.go new file mode 100644 index 0000000..290cd6e --- /dev/null +++ b/backend/api/db/models.go @@ -0,0 +1,66 @@ +package db + +import ( + "gorm.io/gorm" +) + +type Backend struct { + gorm.Model + + UserID uint + + Name string + Description *string + Backend string + BackendParameters string + + Proxies []Proxy +} + +type Proxy struct { + gorm.Model + + BackendID uint + UserID uint + + Name string + Description *string + Protocol string + SourceIP string + SourcePort uint16 + DestinationPort uint16 + AutoStart bool +} + +type Permission struct { + gorm.Model + + PermissionNode string + HasPermission bool + UserID uint +} + +type Token struct { + gorm.Model + + UserID uint + + Token string + DisableExpiry bool + CreationIPAddr string +} + +type User struct { + gorm.Model + + Email string `gorm:"unique"` + Username string `gorm:"unique"` + Name string + Password string + IsBot *bool + + Permissions []Permission + OwnedProxies []Proxy + OwnedBackends []Backend + Tokens []Token +} diff --git a/backend/api/dbcore/db.go b/backend/api/dbcore/db.go deleted file mode 100644 index b5e0676..0000000 --- a/backend/api/dbcore/db.go +++ /dev/null @@ -1,142 +0,0 @@ -package dbcore - -import ( - "fmt" - "os" - - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -type Backend struct { - gorm.Model - - UserID uint - - Name string - Description *string - Backend string - BackendParameters string - - Proxies []Proxy -} - -type Proxy struct { - gorm.Model - - BackendID uint - UserID uint - - Name string - Description *string - Protocol string - SourceIP string - SourcePort uint16 - DestinationPort uint16 - AutoStart bool -} - -type Permission struct { - gorm.Model - - PermissionNode string - HasPermission bool - UserID uint -} - -type Token struct { - gorm.Model - - UserID uint - - Token string - DisableExpiry bool - CreationIPAddr string -} - -type User struct { - gorm.Model - - Email string `gorm:"unique"` - Username string `gorm:"unique"` - Name string - Password string - IsBot *bool - - Permissions []Permission - OwnedProxies []Proxy - OwnedBackends []Backend - Tokens []Token -} - -var DB *gorm.DB - -func InitializeDatabaseDialector() (gorm.Dialector, error) { - databaseBackend := os.Getenv("HERMES_DATABASE_BACKEND") - - switch databaseBackend { - case "sqlite": - filePath := os.Getenv("HERMES_SQLITE_FILEPATH") - - if filePath == "" { - return nil, fmt.Errorf("sqlite database file not specified (missing HERMES_SQLITE_FILEPATH)") - } - - return sqlite.Open(filePath), nil - case "postgresql": - postgresDSN := os.Getenv("HERMES_POSTGRES_DSN") - - if postgresDSN == "" { - return nil, fmt.Errorf("postgres DSN not specified (missing HERMES_POSTGRES_DSN)") - } - - return postgres.Open(postgresDSN), nil - case "": - return nil, fmt.Errorf("no database backend specified in environment variables (missing HERMES_DATABASE_BACKEND)") - default: - return nil, fmt.Errorf("unknown database backend specified: %s", os.Getenv(databaseBackend)) - } -} - -func InitializeDatabase(config *gorm.Config) error { - var err error - - dialector, err := InitializeDatabaseDialector() - - if err != nil { - return fmt.Errorf("failed to initialize physical database: %s", err) - } - - DB, err = gorm.Open(dialector, config) - - if err != nil { - return fmt.Errorf("failed to open database: %s", err) - } - - return nil -} - -func DoDatabaseMigrations(db *gorm.DB) error { - if err := db.AutoMigrate(&Proxy{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Backend{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Permission{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Token{}); err != nil { - return err - } - - if err := db.AutoMigrate(&User{}); err != nil { - return err - } - - return nil -} diff --git a/backend/api/jwt/jwt.go b/backend/api/jwt/jwt.go new file mode 100644 index 0000000..40e011b --- /dev/null +++ b/backend/api/jwt/jwt.go @@ -0,0 +1,107 @@ +package jwt + +import ( + "errors" + "fmt" + "strconv" + "time" + + "git.terah.dev/imterah/hermes/backend/api/db" + "github.com/golang-jwt/jwt/v5" +) + +var ( + DevelopmentModeTimings = time.Duration(60*24) * time.Minute + NormalModeTimings = time.Duration(3) * time.Minute +) + +type JWTCore struct { + Key []byte + Database *db.DB + TimeMultiplier time.Duration +} + +func New(key []byte, database *db.DB, timeMultiplier time.Duration) *JWTCore { + jwtCore := &JWTCore{ + Key: key, + Database: database, + TimeMultiplier: timeMultiplier, + } + + return jwtCore +} + +func (jwtCore *JWTCore) Parse(tokenString string, options ...jwt.ParserOption) (*jwt.Token, error) { + return jwt.Parse(tokenString, jwtCore.jwtKeyCallback, options...) +} + +func (jwtCore *JWTCore) GetUserFromJWT(token string) (*db.User, error) { + if jwtCore.Database == nil { + return nil, fmt.Errorf("database is not initialized") + } + + parsedJWT, err := jwtCore.Parse(token) + + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, fmt.Errorf("token is expired") + } else { + return nil, err + } + } + + audience, err := parsedJWT.Claims.GetAudience() + + if err != nil { + return nil, err + } + + if len(audience) < 1 { + return nil, fmt.Errorf("audience is too small") + } + + uid, err := strconv.Atoi(audience[0]) + + if err != nil { + return nil, err + } + + user := &db.User{} + userRequest := jwtCore.Database.DB.Preload("Permissions").Where("id = ?", uint(uid)).Find(&user) + + if userRequest.Error != nil { + return user, fmt.Errorf("failed to find if user exists or not: %s", userRequest.Error.Error()) + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + return user, fmt.Errorf("user does not exist") + } + + return user, nil +} + +func (jwtCore *JWTCore) Generate(uid uint) (string, error) { + currentJWTTime := jwt.NewNumericDate(time.Now()) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(jwtCore.TimeMultiplier)), + IssuedAt: currentJWTTime, + NotBefore: currentJWTTime, + // Convert the user ID to a string, and then set it as the audience parameters only value (there's only 1 user per key) + Audience: []string{strconv.Itoa(int(uid))}, + }) + + signedToken, err := token.SignedString(jwtCore.Key) + + if err != nil { + return "", err + } + + return signedToken, nil +} + +func (jwtCore *JWTCore) jwtKeyCallback(*jwt.Token) (any, error) { + return jwtCore.Key, nil +} diff --git a/backend/api/jwtcore/jwt.go b/backend/api/jwtcore/jwt.go deleted file mode 100644 index ed2ec56..0000000 --- a/backend/api/jwtcore/jwt.go +++ /dev/null @@ -1,117 +0,0 @@ -package jwtcore - -import ( - "encoding/base64" - "errors" - "fmt" - "os" - "strconv" - "time" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "github.com/golang-jwt/jwt/v5" -) - -var ( - JWTKey []byte - developmentMode bool -) - -func SetupJWT() error { - var err error - jwtDataString := os.Getenv("HERMES_JWT_SECRET") - - if jwtDataString == "" { - return fmt.Errorf("JWT secret isn't set (missing HERMES_JWT_SECRET)") - } - - if os.Getenv("HERMES_JWT_BASE64_ENCODED") != "" { - JWTKey, err = base64.StdEncoding.DecodeString(jwtDataString) - - if err != nil { - return fmt.Errorf("failed to decode base64 JWT: %s", err.Error()) - } - } else { - JWTKey = []byte(jwtDataString) - } - - if os.Getenv("HERMES_DEVELOPMENT_MODE") != "" { - developmentMode = true - } - - return nil -} - -func Parse(tokenString string, options ...jwt.ParserOption) (*jwt.Token, error) { - return jwt.Parse(tokenString, JWTKeyCallback, options...) -} - -func GetUserFromJWT(token string) (*dbcore.User, error) { - parsedJWT, err := Parse(token) - - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - return nil, fmt.Errorf("token is expired") - } else { - return nil, err - } - } - - audience, err := parsedJWT.Claims.GetAudience() - - if err != nil { - return nil, err - } - - if len(audience) < 1 { - return nil, fmt.Errorf("audience is too small") - } - - uid, err := strconv.Atoi(audience[0]) - - if err != nil { - return nil, err - } - - user := &dbcore.User{} - userRequest := dbcore.DB.Preload("Permissions").Where("id = ?", uint(uid)).Find(&user) - - if userRequest.Error != nil { - return user, fmt.Errorf("failed to find if user exists or not: %s", userRequest.Error.Error()) - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - return user, fmt.Errorf("user does not exist") - } - - return user, nil -} - -func Generate(uid uint) (string, error) { - timeMultiplier := 3 - - if developmentMode { - timeMultiplier = 60 * 24 - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(timeMultiplier) * time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Audience: []string{strconv.Itoa(int(uid))}, - }) - - signedToken, err := token.SignedString(JWTKey) - - if err != nil { - return "", err - } - - return signedToken, nil -} - -func JWTKeyCallback(*jwt.Token) (interface{}, error) { - return JWTKey, nil -} diff --git a/backend/api/main.go b/backend/api/main.go index 2969e9b..1438af8 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -9,18 +9,19 @@ import ( "path" "path/filepath" "strings" + "time" "git.terah.dev/imterah/hermes/backend/api/backendruntime" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/backends" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/proxies" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/users" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/jwt" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" "github.com/urfave/cli/v2" - "gorm.io/gorm" ) func apiEntrypoint(cCtx *cli.Context) error { @@ -34,7 +35,26 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Info("Hermes is initializing...") log.Debug("Initializing database and opening it...") - err := dbcore.InitializeDatabase(&gorm.Config{}) + databaseBackendName := os.Getenv("HERMES_DATABASE_BACKEND") + var databaseBackendParams string + + if databaseBackendName == "sqlite" { + databaseBackendParams = os.Getenv("HERMES_SQLITE_FILEPATH") + + if databaseBackendParams == "" { + log.Fatal("HERMES_SQLITE_FILEPATH is not set") + } + } else if databaseBackendName == "postgresql" { + databaseBackendParams = os.Getenv("HERMES_POSTGRES_DSN") + + if databaseBackendParams == "" { + log.Fatal("HERMES_POSTGRES_DSN is not set") + } + } else { + log.Fatalf("Unsupported database backend: %s", databaseBackendName) + } + + dbInstance, err := db.New(databaseBackendName, databaseBackendParams) if err != nil { log.Fatalf("Failed to initialize database: %s", err) @@ -42,16 +62,38 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Debug("Running database migrations...") - if err := dbcore.DoDatabaseMigrations(dbcore.DB); err != nil { + if err := dbInstance.DoMigrations(); err != nil { return fmt.Errorf("Failed to run database migrations: %s", err) } log.Debug("Initializing the JWT subsystem...") - if err := jwtcore.SetupJWT(); err != nil { - return fmt.Errorf("Failed to initialize the JWT subsystem: %s", err.Error()) + jwtDataString := os.Getenv("HERMES_JWT_SECRET") + var jwtKey []byte + var jwtValidityTimeDuration time.Duration + + if jwtDataString == "" { + log.Fatalf("HERMES_JWT_SECRET is not set") } + if os.Getenv("HERMES_JWT_BASE64_ENCODED") != "" { + jwtKey, err = base64.StdEncoding.DecodeString(jwtDataString) + + if err != nil { + log.Fatalf("Failed to decode base64 JWT: %s", err.Error()) + } + } else { + jwtKey = []byte(jwtDataString) + } + + if developmentMode { + jwtValidityTimeDuration = jwt.DevelopmentModeTimings + } else { + jwtValidityTimeDuration = jwt.NormalModeTimings + } + + jwtInstance := jwt.New(jwtKey, dbInstance, jwtValidityTimeDuration) + log.Debug("Initializing the backend subsystem...") backendMetadataPath := cCtx.String("backends-path") @@ -76,9 +118,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Debug("Enumerating backends...") - backendList := []dbcore.Backend{} + backendList := []db.Backend{} - if err := dbcore.DB.Find(&backendList).Error; err != nil { + if err := dbInstance.DB.Find(&backendList).Error; err != nil { return fmt.Errorf("Failed to enumerate backends: %s", err.Error()) } @@ -108,8 +150,7 @@ func apiEntrypoint(cCtx *cli.Context) error { return } - marshalledStartCommand, err := commonbackend.Marshal("start", &commonbackend.Start{ - Type: "start", + marshalledStartCommand, err := commonbackend.Marshal(&commonbackend.Start{ Arguments: backendParameters, }) @@ -123,7 +164,7 @@ func apiEntrypoint(cCtx *cli.Context) error { return } - _, backendResponse, err := commonbackend.Unmarshal(conn) + backendResponse, err := commonbackend.Unmarshal(conn) if err != nil { log.Errorf("Failed to get start command response for backend #%d: %s", backend.ID, err.Error()) @@ -142,9 +183,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Warnf("Backend #%d has reinitialized! Starting up auto-starting proxies...", backend.ID) - autoStartProxies := []dbcore.Proxy{} + autoStartProxies := []db.Proxy{} - if err := dbcore.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { + if err := dbInstance.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { log.Errorf("Failed to query proxies to autostart: %s", err.Error()) return } @@ -152,8 +193,7 @@ func apiEntrypoint(cCtx *cli.Context) error { for _, proxy := range autoStartProxies { log.Infof("Starting up route #%d for backend #%d: %s", proxy.ID, backend.ID, proxy.Name) - marhalledCommand, err := commonbackend.Marshal("addProxy", &commonbackend.AddProxy{ - Type: "addProxy", + marhalledCommand, err := commonbackend.Marshal(&commonbackend.AddProxy{ SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, @@ -170,7 +210,7 @@ func apiEntrypoint(cCtx *cli.Context) error { continue } - _, backendResponse, err := commonbackend.Unmarshal(conn) + backendResponse, err := commonbackend.Unmarshal(conn) if err != nil { log.Errorf("Failed to get response for backend #%d and route #%d: %s", proxy.BackendID, proxy.ID, err.Error()) @@ -204,7 +244,6 @@ func apiEntrypoint(cCtx *cli.Context) error { } backendStartResponse, err := backendInstance.ProcessCommand(&commonbackend.Start{ - Type: "start", Arguments: backendParameters, }) @@ -246,9 +285,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Infof("Successfully initialized backend #%d", backend.ID) - autoStartProxies := []dbcore.Proxy{} + autoStartProxies := []db.Proxy{} - if err := dbcore.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { + if err := dbInstance.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { log.Errorf("Failed to query proxies to autostart: %s", err.Error()) continue } @@ -257,7 +296,6 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Infof("Starting up route #%d for backend #%d: %s", proxy.ID, backend.ID, proxy.Name) backendResponse, err := backendInstance.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, @@ -313,23 +351,25 @@ func apiEntrypoint(cCtx *cli.Context) error { engine.SetTrustedProxies(nil) } + state := state.New(dbInstance, jwtInstance, engine) + // Initialize routes - engine.POST("/api/v1/users/create", users.CreateUser) - engine.POST("/api/v1/users/login", users.LoginUser) - engine.POST("/api/v1/users/refresh", users.RefreshUserToken) - engine.POST("/api/v1/users/remove", users.RemoveUser) - engine.POST("/api/v1/users/lookup", users.LookupUser) + users.SetupCreateUser(state) + users.SetupLoginUser(state) + users.SetupRefreshUserToken(state) + users.SetupRemoveUser(state) + users.SetupLookupUser(state) - engine.POST("/api/v1/backends/create", backends.CreateBackend) - engine.POST("/api/v1/backends/remove", backends.RemoveBackend) - engine.POST("/api/v1/backends/lookup", backends.LookupBackend) + backends.SetupCreateBackend(state) + backends.SetupRemoveBackend(state) + backends.SetupLookupBackend(state) - engine.POST("/api/v1/forward/create", proxies.CreateProxy) - engine.POST("/api/v1/forward/lookup", proxies.LookupProxy) - engine.POST("/api/v1/forward/remove", proxies.RemoveProxy) - engine.POST("/api/v1/forward/start", proxies.StartProxy) - engine.POST("/api/v1/forward/stop", proxies.StopProxy) - engine.POST("/api/v1/forward/connections", proxies.GetConnections) + proxies.SetupCreateProxy(state) + proxies.SetupRemoveProxy(state) + proxies.SetupLookupProxy(state) + proxies.SetupStartProxy(state) + proxies.SetupStopProxy(state) + proxies.SetupGetConnections(state) log.Infof("Listening on '%s'", listeningAddress) err = engine.Run(listeningAddress) @@ -366,22 +406,6 @@ func main() { app := &cli.App{ Name: "hermes", Usage: "port forwarding across boundaries", - Commands: []*cli.Command{ - { - Name: "import", - Usage: "imports from legacy NextNet/Hermes source", - Aliases: []string{"i"}, - Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "backup-path", - Aliases: []string{"bp"}, - Usage: "path to the backup file", - Required: true, - }, - }, - Action: backupRestoreEntrypoint, - }, - }, Flags: []cli.Flag{ &cli.StringFlag{ Name: "backends-path", diff --git a/backend/api/permissions/permission_nodes.go b/backend/api/permissions/permission_nodes.go index d682d69..f13e80a 100644 --- a/backend/api/permissions/permission_nodes.go +++ b/backend/api/permissions/permission_nodes.go @@ -1,6 +1,6 @@ package permissions -import "git.terah.dev/imterah/hermes/backend/api/dbcore" +import "git.terah.dev/imterah/hermes/backend/api/db" var DefaultPermissionNodes []string = []string{ "routes.add", @@ -27,7 +27,7 @@ var DefaultPermissionNodes []string = []string{ "users.edit", } -func UserHasPermission(user *dbcore.User, node string) bool { +func UserHasPermission(user *db.User, node string) bool { for _, permission := range user.Permissions { if permission.PermissionNode == node && permission.HasPermission { return true diff --git a/backend/api/state/state.go b/backend/api/state/state.go new file mode 100644 index 0000000..754ff56 --- /dev/null +++ b/backend/api/state/state.go @@ -0,0 +1,24 @@ +package state + +import ( + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/jwt" + "github.com/gin-gonic/gin" + "github.com/go-playground/validator/v10" +) + +type State struct { + DB *db.DB + JWT *jwt.JWTCore + Engine *gin.Engine + Validator *validator.Validate +} + +func New(db *db.DB, jwt *jwt.JWTCore, engine *gin.Engine) *State { + return &State{ + DB: db, + JWT: jwt, + Engine: engine, + Validator: validator.New(), + } +} diff --git a/backend/backends.dev.json b/backend/backends.dev.json index 9ff52ca..f314c69 100644 --- a/backend/backends.dev.json +++ b/backend/backends.dev.json @@ -3,6 +3,10 @@ "name": "ssh", "path": "./sshbackend/sshbackend" }, + { + "name": "sshapp", + "path": "./sshappbackend/local-code/sshappbackend" + }, { "name": "dummy", "path": "./dummybackend/dummybackend" diff --git a/backend/backends.prod.json b/backend/backends.prod.json index 9a9a09e..0ccfedc 100644 --- a/backend/backends.prod.json +++ b/backend/backends.prod.json @@ -2,5 +2,9 @@ { "name": "ssh", "path": "./sshbackend" + }, + { + "name": "sshapp", + "path": "./sshappbackend" } ] diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index 802e8c7..afa3147 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -1,7 +1,6 @@ package backendutil import ( - "fmt" "net" "os" @@ -18,7 +17,7 @@ type BackendApplicationHelper struct { func (helper *BackendApplicationHelper) Start() error { log.Debug("BackendApplicationHelper is starting") - err := configureAndLaunchBackgroundProfilingTasks() + err := ConfigureProfiling() if err != nil { return err @@ -35,21 +34,15 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") for { - commandType, commandRaw, err := commonbackend.Unmarshal(helper.socket) + commandRaw, err := commonbackend.Unmarshal(helper.socket) if err != nil { return err } - switch commandType { - case "start": - command, ok := commandRaw.(*commonbackend.Start) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StartBackend(command.Arguments) + switch command := commandRaw.(type) { + case *commonbackend.Start: + ok, err := helper.Backend.StartBackend(command.Arguments) var ( message string @@ -64,13 +57,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -78,13 +70,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "backendStatusRequest": - _, ok := commandRaw.(*commonbackend.BackendStatusRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.BackendStatusRequest: ok, err := helper.Backend.GetBackendStatus() var ( @@ -100,13 +86,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -114,14 +99,8 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "stop": - _, ok := commandRaw.(*commonbackend.Stop) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopBackend() + case *commonbackend.Stop: + ok, err := helper.Backend.StopBackend() var ( message string @@ -136,13 +115,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: !ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -150,26 +128,19 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "addProxy": - command, ok := commandRaw.(*commonbackend.AddProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StartProxy(command) + case *commonbackend.AddProxy: + ok, err := helper.Backend.StartProxy(command) var hasAnyFailed bool - if !ok { - log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) - hasAnyFailed = true - } else if err != nil { + if err != nil { log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) hasAnyFailed = true + } else if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true } response := &commonbackend.ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: command.SourceIP, SourcePort: command.SourcePort, DestPort: command.DestPort, @@ -177,7 +148,7 @@ func (helper *BackendApplicationHelper) Start() error { IsActive: !hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -185,26 +156,19 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "removeProxy": - command, ok := commandRaw.(*commonbackend.RemoveProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopProxy(command) + case *commonbackend.RemoveProxy: + ok, err := helper.Backend.StopProxy(command) var hasAnyFailed bool - if !ok { - log.Warnf("failed to remove proxy (%s:%d -> remote:%d): RemoveProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) - hasAnyFailed = true - } else if err != nil { + if err != nil { log.Warnf("failed to remove proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) hasAnyFailed = true + } else if !ok { + log.Warnf("failed to remove proxy (%s:%d -> remote:%d): RemoveProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true } response := &commonbackend.ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: command.SourceIP, SourcePort: command.SourcePort, DestPort: command.DestPort, @@ -212,7 +176,7 @@ func (helper *BackendApplicationHelper) Start() error { IsActive: hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -220,21 +184,14 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "proxyConnectionsRequest": - _, ok := commandRaw.(*commonbackend.ProxyConnectionsRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.ProxyConnectionsRequest: connections := helper.Backend.GetAllClientConnections() serverParams := &commonbackend.ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: connections, } - byteData, err := commonbackend.Marshal(serverParams.Type, serverParams) + byteData, err := commonbackend.Marshal(serverParams) if err != nil { return err @@ -243,18 +200,11 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "checkClientParameters": - command, ok := commandRaw.(*commonbackend.CheckClientParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckClientParameters: resp := helper.Backend.CheckParametersForConnections(command) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkClientParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err @@ -263,18 +213,11 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "checkServerParameters": - command, ok := commandRaw.(*commonbackend.CheckServerParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckServerParameters: resp := helper.Backend.CheckParametersForBackend(command.Arguments) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkServerParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err @@ -283,6 +226,8 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } + default: + log.Warnf("Unsupported command recieved: %T", command) } } } diff --git a/backend/backendutil/profiling_disabled.go b/backend/backendutil/profiling_disabled.go index d93cbdd..8538407 100644 --- a/backend/backendutil/profiling_disabled.go +++ b/backend/backendutil/profiling_disabled.go @@ -4,6 +4,6 @@ package backendutil var endProfileFunc func() -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { return nil } diff --git a/backend/backendutil/profiling_enabled.go b/backend/backendutil/profiling_enabled.go index a3be527..6fcb189 100644 --- a/backend/backendutil/profiling_enabled.go +++ b/backend/backendutil/profiling_enabled.go @@ -15,7 +15,7 @@ import ( "golang.org/x/exp/rand" ) -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { profilingMode, err := os.ReadFile("/tmp/hermes.backendlauncher.profilebackends") if err != nil && errors.Is(err, os.ErrNotExist) { diff --git a/backend/build.sh b/backend/build.sh index ecb7537..cee4440 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -1,20 +1,43 @@ #!/usr/bin/env bash -pushd sshbackend -GOOS=linux go build . -strip sshbackend -popd +pushd sshbackend > /dev/null +echo "building sshbackend" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd dummybackend -GOOS=linux go build . -strip dummybackend -popd +pushd dummybackend > /dev/null +echo "building dummybackend" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd externalbackendlauncher -go build . -strip externalbackendlauncher -popd +pushd externalbackendlauncher > /dev/null +echo "building externalbackendlauncher" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd api -GOOS=linux go build . -strip api -popd +if [ ! -d "sshappbackend/local-code/remote-bin" ]; then + mkdir "sshappbackend/local-code/remote-bin" +fi + +pushd sshappbackend/remote-code > /dev/null +echo "building sshappbackend/remote-code" +# Disable dynamic linking by disabling CGo. +# We need to make the remote code as generic as possible, so we do this +echo " - building for arm64" +CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . +echo " - building for arm" +CGO_ENABLED=0 GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . +echo " - building for amd64" +CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . +echo " - building for i386" +CGO_ENABLED=0 GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . +popd > /dev/null + +pushd sshappbackend/local-code > /dev/null +echo "building sshappbackend/local-code" +go build -ldflags="-s -w" -trimpath -o sshappbackend . +popd > /dev/null + +pushd api > /dev/null +echo "building api" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null diff --git a/backend/commonbackend/constants.go b/backend/commonbackend/constants.go index cdb68f2..6d5362b 100644 --- a/backend/commonbackend/constants.go +++ b/backend/commonbackend/constants.go @@ -1,16 +1,13 @@ package commonbackend type Start struct { - Type string // Will be 'start' always Arguments []byte } type Stop struct { - Type string // Will be 'stop' always } type AddProxy struct { - Type string // Will be 'addProxy' always SourceIP string SourcePort uint16 DestPort uint16 @@ -18,7 +15,6 @@ type AddProxy struct { } type RemoveProxy struct { - Type string // Will be 'removeProxy' always SourceIP string SourcePort uint16 DestPort uint16 @@ -26,7 +22,6 @@ type RemoveProxy struct { } type ProxyStatusRequest struct { - Type string // Will be 'proxyStatusRequest' always SourceIP string SourcePort uint16 DestPort uint16 @@ -34,7 +29,6 @@ type ProxyStatusRequest struct { } type ProxyStatusResponse struct { - Type string // Will be 'proxyStatusResponse' always SourceIP string SourcePort uint16 DestPort uint16 @@ -50,27 +44,22 @@ type ProxyInstance struct { } type ProxyInstanceResponse struct { - Type string // Will be 'proxyConnectionResponse' always Proxies []*ProxyInstance // List of connections } type ProxyInstanceRequest struct { - Type string // Will be 'proxyConnectionRequest' always } type BackendStatusResponse struct { - Type string // Will be 'backendStatusResponse' always IsRunning bool // True if running, false if not running StatusCode int // Either the 'Success' or 'Failure' constant Message string // String message from the client (ex. failed to dial TCP) } type BackendStatusRequest struct { - Type string // Will be 'backendStatusRequest' always } type ProxyConnectionsRequest struct { - Type string // Will be 'proxyConnectionsRequest' always } // Client's connection to a specific proxy @@ -83,12 +72,10 @@ type ProxyClientConnection struct { } type ProxyConnectionsResponse struct { - Type string // Will be 'proxyConnectionsResponse' always Connections []*ProxyClientConnection // List of connections } type CheckClientParameters struct { - Type string // Will be 'checkClientParameters' always SourceIP string SourcePort uint16 DestPort uint16 @@ -96,13 +83,11 @@ type CheckClientParameters struct { } type CheckServerParameters struct { - Type string // Will be 'checkServerParameters' always Arguments []byte } // Sent as a response to either CheckClientParameters or CheckBackendParameters type CheckParametersResponse struct { - Type string // Will be 'checkParametersResponse' always InResponseTo string // Will be either 'checkClientParameters' or 'checkServerParameters' IsValid bool // If true, valid, and if false, invalid Message string // String message from the client (ex. failed to unmarshal JSON: x is not defined) diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go index 6baf02e..7203ee3 100644 --- a/backend/commonbackend/marshal.go +++ b/backend/commonbackend/marshal.go @@ -84,37 +84,19 @@ func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { return proxyBlock, nil } -func Marshal(commandType string, command interface{}) ([]byte, error) { - switch commandType { - case "start": - startCommand, ok := command.(*Start) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - startCommandBytes := make([]byte, 1+2+len(startCommand.Arguments)) +func Marshal(command interface{}) ([]byte, error) { + switch command := command.(type) { + case *Start: + startCommandBytes := make([]byte, 1+2+len(command.Arguments)) startCommandBytes[0] = StartID - binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(startCommand.Arguments))) - copy(startCommandBytes[3:], startCommand.Arguments) + binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(command.Arguments))) + copy(startCommandBytes[3:], command.Arguments) return startCommandBytes, nil - case "stop": - _, ok := command.(*Stop) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *Stop: return []byte{StopID}, nil - case "addProxy": - addConnectionCommand, ok := command.(*AddProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(addConnectionCommand.SourceIP) + case *AddProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -134,14 +116,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { copy(addConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], addConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], addConnectionCommand.DestPort) + binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if addConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if addConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -150,14 +132,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { addConnectionBytes[6+len(ipBytes)] = protocol return addConnectionBytes, nil - case "removeProxy": - removeConnectionCommand, ok := command.(*RemoveProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(removeConnectionCommand.SourceIP) + case *RemoveProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -175,14 +151,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[0] = RemoveProxyID removeConnectionBytes[1] = ipVer copy(removeConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], removeConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], removeConnectionCommand.DestPort) + binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if removeConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if removeConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -191,17 +167,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[6+len(ipBytes)] = protocol return removeConnectionBytes, nil - case "proxyConnectionsResponse": - allConnectionsCommand, ok := command.(*ProxyConnectionsResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - connectionsArray := make([][]byte, len(allConnectionsCommand.Connections)) + case *ProxyConnectionsResponse: + connectionsArray := make([][]byte, len(command.Connections)) totalSize := 0 - for connIndex, conn := range allConnectionsCommand.Connections { + for connIndex, conn := range command.Connections { connectionsArray[connIndex] = marshalIndividualConnectionStruct(conn) totalSize += len(connectionsArray[connIndex]) + 1 } @@ -223,14 +193,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "checkClientParameters": - checkClientCommand, ok := command.(*CheckClientParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(checkClientCommand.SourceIP) + case *CheckClientParameters: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -248,14 +212,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[0] = CheckClientParametersID checkClientBytes[1] = ipVer copy(checkClientBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], checkClientCommand.SourcePort) - binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], checkClientCommand.DestPort) + binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if checkClientCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if checkClientCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -264,31 +228,19 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[6+len(ipBytes)] = protocol return checkClientBytes, nil - case "checkServerParameters": - checkServerCommand, ok := command.(*CheckServerParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - serverCommandBytes := make([]byte, 1+2+len(checkServerCommand.Arguments)) + case *CheckServerParameters: + serverCommandBytes := make([]byte, 1+2+len(command.Arguments)) serverCommandBytes[0] = CheckServerParametersID - binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(checkServerCommand.Arguments))) - copy(serverCommandBytes[3:], checkServerCommand.Arguments) + binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(command.Arguments))) + copy(serverCommandBytes[3:], command.Arguments) return serverCommandBytes, nil - case "checkParametersResponse": - checkParametersCommand, ok := command.(*CheckParametersResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *CheckParametersResponse: var checkMethod uint8 - if checkParametersCommand.InResponseTo == "checkClientParameters" { + if command.InResponseTo == "checkClientParameters" { checkMethod = CheckClientParametersID - } else if checkParametersCommand.InResponseTo == "checkServerParameters" { + } else if command.InResponseTo == "checkServerParameters" { checkMethod = CheckServerParametersID } else { return nil, fmt.Errorf("invalid mode recieved (must be either checkClientParameters or checkServerParameters)") @@ -296,68 +248,50 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { var isValid uint8 - if checkParametersCommand.IsValid { + if command.IsValid { isValid = 1 } - checkResponseBytes := make([]byte, 3+2+len(checkParametersCommand.Message)) + checkResponseBytes := make([]byte, 3+2+len(command.Message)) checkResponseBytes[0] = CheckParametersResponseID checkResponseBytes[1] = checkMethod checkResponseBytes[2] = isValid - binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(checkParametersCommand.Message))) + binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(command.Message))) - if len(checkParametersCommand.Message) != 0 { - copy(checkResponseBytes[5:], []byte(checkParametersCommand.Message)) + if len(command.Message) != 0 { + copy(checkResponseBytes[5:], []byte(command.Message)) } return checkResponseBytes, nil - case "backendStatusResponse": - backendStatusResponse, ok := command.(*BackendStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusResponse: var isRunning uint8 - if backendStatusResponse.IsRunning { + if command.IsRunning { isRunning = 1 } else { isRunning = 0 } - statusResponseBytes := make([]byte, 3+2+len(backendStatusResponse.Message)) + statusResponseBytes := make([]byte, 3+2+len(command.Message)) statusResponseBytes[0] = BackendStatusResponseID statusResponseBytes[1] = isRunning - statusResponseBytes[2] = byte(backendStatusResponse.StatusCode) + statusResponseBytes[2] = byte(command.StatusCode) - binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(backendStatusResponse.Message))) + binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(command.Message))) - if len(backendStatusResponse.Message) != 0 { - copy(statusResponseBytes[5:], []byte(backendStatusResponse.Message)) + if len(command.Message) != 0 { + copy(statusResponseBytes[5:], []byte(command.Message)) } return statusResponseBytes, nil - case "backendStatusRequest": - _, ok := command.(*BackendStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusRequest: statusRequestBytes := make([]byte, 1) statusRequestBytes[0] = BackendStatusRequestID return statusRequestBytes, nil - case "proxyStatusRequest": - proxyStatusRequest, ok := command.(*ProxyStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusRequest.SourceIP) + case *ProxyStatusRequest: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -370,37 +304,31 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusRequestBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - proxyStatusRequestBytes[0] = ProxyStatusRequestID - proxyStatusRequestBytes[1] = ipVer + commandBytes[0] = ProxyStatusRequestID + commandBytes[1] = ipVer - copy(proxyStatusRequestBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusRequest.SourcePort) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusRequest.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusRequest.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusRequest.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusRequestBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol - return proxyStatusRequestBytes, nil - case "proxyStatusResponse": - proxyStatusResponse, ok := command.(*ProxyStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusResponse.SourceIP) + return commandBytes, nil + case *ProxyStatusResponse: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -413,50 +341,44 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusResponseBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) - proxyStatusResponseBytes[0] = ProxyStatusResponseID - proxyStatusResponseBytes[1] = ipVer + commandBytes[0] = ProxyStatusResponseID + commandBytes[1] = ipVer - copy(proxyStatusResponseBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusResponse.SourcePort) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusResponse.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusResponse.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusResponse.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusResponseBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol var isActive uint8 - if proxyStatusResponse.IsActive { + if command.IsActive { isActive = 1 } else { isActive = 0 } - proxyStatusResponseBytes[7+len(ipBytes)] = isActive + commandBytes[7+len(ipBytes)] = isActive - return proxyStatusResponseBytes, nil - case "proxyInstanceResponse": - proxyConectionResponse, ok := command.(*ProxyInstanceResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - proxyArray := make([][]byte, len(proxyConectionResponse.Proxies)) + return commandBytes, nil + case *ProxyInstanceResponse: + proxyArray := make([][]byte, len(command.Proxies)) totalSize := 0 - for proxyIndex, proxy := range proxyConectionResponse.Proxies { + for proxyIndex, proxy := range command.Proxies { var err error proxyArray[proxyIndex], err = marshalIndividualProxyStruct(proxy) @@ -485,23 +407,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "proxyInstanceRequest": - _, ok := command.(*ProxyInstanceRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyInstanceRequest: return []byte{ProxyInstanceRequestID}, nil - case "proxyConnectionsRequest": - _, ok := command.(*ProxyConnectionsRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyConnectionsRequest: return []byte{ProxyConnectionsRequestID}, nil } - return nil, fmt.Errorf("couldn't match command name") + return nil, fmt.Errorf("couldn't match command type") } diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go index 1c93f94..c2b6375 100644 --- a/backend/commonbackend/marshalling_test.go +++ b/backend/commonbackend/marshalling_test.go @@ -9,13 +9,12 @@ import ( var logLevel = os.Getenv("HERMES_LOG_LEVEL") -func TestStartCommandMarshalSupport(t *testing.T) { +func TestStart(t *testing.T) { commandInput := &Start{ - Type: "start", Arguments: []byte("Hello from automated testing"), } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -26,39 +25,27 @@ func TestStartCommandMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Start) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) } } -func TestStopCommandMarshalSupport(t *testing.T) { - commandInput := &Stop{ - Type: "stop", - } +func TestStop(t *testing.T) { + commandInput := &Stop{} - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -69,39 +56,28 @@ func TestStopCommandMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Stop) + _, ok := commandUnmarshalledRaw.(*Stop) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } -func TestAddConnectionCommandMarshalSupport(t *testing.T) { +func TestAddConnection(t *testing.T) { commandInput := &AddProxy{ - Type: "addProxy", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -112,28 +88,18 @@ func TestAddConnectionCommandMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*AddProxy) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -155,16 +121,15 @@ func TestAddConnectionCommandMarshalSupport(t *testing.T) { } } -func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { +func TestRemoveConnection(t *testing.T) { commandInput := &RemoveProxy{ - Type: "removeProxy", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -175,28 +140,18 @@ func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -218,9 +173,8 @@ func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { } } -func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { +func TestGetAllConnections(t *testing.T) { commandInput := &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: []*ProxyClientConnection{ { SourceIP: "127.0.0.1", @@ -246,7 +200,7 @@ func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { }, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -257,28 +211,18 @@ func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for commandIndex, originalConnection := range commandInput.Connections { remoteConnection := commandUnmarshalled.Connections[commandIndex] @@ -309,16 +253,15 @@ func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { } } -func TestCheckClientParametersMarshalSupport(t *testing.T) { +func TestCheckClientParameters(t *testing.T) { commandInput := &CheckClientParameters{ - Type: "checkClientParameters", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -329,28 +272,18 @@ func TestCheckClientParametersMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckClientParameters) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -372,13 +305,12 @@ func TestCheckClientParametersMarshalSupport(t *testing.T) { } } -func TestCheckServerParametersMarshalSupport(t *testing.T) { +func TestCheckServerParameters(t *testing.T) { commandInput := &CheckServerParameters{ - Type: "checkServerParameters", Arguments: []byte("Hello from automated testing"), } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -389,42 +321,31 @@ func TestCheckServerParametersMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckServerParameters) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) } } -func TestCheckParametersResponseMarshalSupport(t *testing.T) { +func TestCheckParametersResponse(t *testing.T) { commandInput := &CheckParametersResponse{ - Type: "checkParametersResponse", InResponseTo: "checkClientParameters", IsValid: true, Message: "Hello from automated testing", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -435,28 +356,18 @@ func TestCheckParametersResponseMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckParametersResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.InResponseTo != commandUnmarshalled.InResponseTo { t.Fail() log.Printf("InResponseTo's are not equal (orig: %s, unmsh: %s)", commandInput.InResponseTo, commandUnmarshalled.InResponseTo) @@ -473,12 +384,9 @@ func TestCheckParametersResponseMarshalSupport(t *testing.T) { } } -func TestBackendStatusRequestMarshalSupport(t *testing.T) { - commandInput := &BackendStatusRequest{ - Type: "backendStatusRequest", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) +func TestBackendStatusRequest(t *testing.T) { + commandInput := &BackendStatusRequest{} + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -489,38 +397,27 @@ func TestBackendStatusRequestMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusRequest) + _, ok := commandUnmarshalledRaw.(*BackendStatusRequest) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } -func TestBackendStatusResponseMarshalSupport(t *testing.T) { +func TestBackendStatusResponse(t *testing.T) { commandInput := &BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: true, StatusCode: StatusFailure, Message: "Hello from automated testing", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -531,28 +428,18 @@ func TestBackendStatusResponseMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.IsRunning != commandUnmarshalled.IsRunning { t.Fail() log.Printf("IsRunning's are not equal (orig: %t, unmsh: %t)", commandInput.IsRunning, commandUnmarshalled.IsRunning) @@ -569,16 +456,15 @@ func TestBackendStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyStatusRequestMarshalSupport(t *testing.T) { +func TestProxyStatusRequest(t *testing.T) { commandInput := &ProxyStatusRequest{ - Type: "proxyStatusRequest", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -589,28 +475,18 @@ func TestProxyStatusRequestMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -632,9 +508,8 @@ func TestProxyStatusRequestMarshalSupport(t *testing.T) { } } -func TestProxyStatusResponseMarshalSupport(t *testing.T) { +func TestProxyStatusResponse(t *testing.T) { commandInput := &ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, @@ -642,7 +517,7 @@ func TestProxyStatusResponseMarshalSupport(t *testing.T) { IsActive: true, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -653,28 +528,18 @@ func TestProxyStatusResponseMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -701,12 +566,10 @@ func TestProxyStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyConnectionRequestMarshalSupport(t *testing.T) { - commandInput := &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - } +func TestProxyConnectionRequest(t *testing.T) { + commandInput := &ProxyInstanceRequest{} - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -717,32 +580,21 @@ func TestProxyConnectionRequestMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) + _, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } -func TestProxyConnectionResponseMarshalSupport(t *testing.T) { +func TestProxyConnectionResponse(t *testing.T) { commandInput := &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", Proxies: []*ProxyInstance{ { SourceIP: "192.168.0.168", @@ -765,7 +617,7 @@ func TestProxyConnectionResponseMarshalSupport(t *testing.T) { }, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -776,28 +628,18 @@ func TestProxyConnectionResponseMarshalSupport(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for proxyIndex, originalProxy := range commandInput.Proxies { remoteProxy := commandUnmarshalled.Proxies[proxyIndex] diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go index b8500dd..6bb5af4 100644 --- a/backend/commonbackend/unmarshal.go +++ b/backend/commonbackend/unmarshal.go @@ -142,11 +142,11 @@ func unmarshalIndividualProxyStruct(conn io.Reader) (*ProxyInstance, error) { }, nil } -func Unmarshal(conn io.Reader) (string, interface{}, error) { +func Unmarshal(conn io.Reader) (interface{}, error) { commandType := make([]byte, 1) if _, err := conn.Read(commandType); err != nil { - return "", nil, fmt.Errorf("couldn't read command") + return nil, fmt.Errorf("couldn't read command") } switch commandType[0] { @@ -154,28 +154,25 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { argumentsLength := make([]byte, 2) if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") + return nil, fmt.Errorf("couldn't read argument length") } arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") + return nil, fmt.Errorf("couldn't read arguments") } - return "start", &Start{ - Type: "start", + return &Start{ Arguments: arguments, }, nil case StopID: - return "stop", &Stop{ - Type: "stop", - }, nil + return &Stop{}, nil case AddProxyID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -185,45 +182,44 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "addProxy", &AddProxy{ - Type: "addProxy", + return &AddProxy{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -233,7 +229,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -243,45 +239,44 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "removeProxy", &RemoveProxy{ - Type: "removeProxy", + return &RemoveProxy{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -301,13 +296,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { break } - return "", nil, err + return nil, err } connections = append(connections, connection) if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") + return nil, fmt.Errorf("couldn't read delimiter") } if delimiter[0] == '\r' { @@ -321,15 +316,14 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } } - return "proxyConnectionsResponse", &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", + return &ProxyConnectionsResponse{ Connections: connections, }, errorReturn case CheckClientParametersID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -339,45 +333,44 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "checkClientParameters", &CheckClientParameters{ - Type: "checkClientParameters", + return &CheckClientParameters{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -387,24 +380,23 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { argumentsLength := make([]byte, 2) if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") + return nil, fmt.Errorf("couldn't read argument length") } arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") + return nil, fmt.Errorf("couldn't read arguments") } - return "checkServerParameters", &CheckServerParameters{ - Type: "checkServerParameters", + return &CheckServerParameters{ Arguments: arguments, }, nil case CheckParametersResponseID: checkMethodByte := make([]byte, 1) if _, err := conn.Read(checkMethodByte); err != nil { - return "", nil, fmt.Errorf("couldn't read check method byte") + return nil, fmt.Errorf("couldn't read check method byte") } var checkMethod string @@ -414,19 +406,19 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if checkMethodByte[0] == CheckServerParametersID { checkMethod = "checkServerParameters" } else { - return "", nil, fmt.Errorf("invalid check method recieved") + return nil, fmt.Errorf("invalid check method recieved") } isValid := make([]byte, 1) if _, err := conn.Read(isValid); err != nil { - return "", nil, fmt.Errorf("couldn't read isValid byte") + return nil, fmt.Errorf("couldn't read isValid byte") } messageLengthBytes := make([]byte, 2) if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") + return nil, fmt.Errorf("couldn't read message length") } messageLength := binary.BigEndian.Uint16(messageLengthBytes) @@ -436,14 +428,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { messageBytes := make([]byte, messageLength) if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") + return nil, fmt.Errorf("couldn't read message") } message = string(messageBytes) } - return "checkParametersResponse", &CheckParametersResponse{ - Type: "checkParametersResponse", + return &CheckParametersResponse{ InResponseTo: checkMethod, IsValid: isValid[0] == 1, Message: message, @@ -452,19 +443,19 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { isRunning := make([]byte, 1) if _, err := conn.Read(isRunning); err != nil { - return "", nil, fmt.Errorf("couldn't read isRunning field") + return nil, fmt.Errorf("couldn't read isRunning field") } statusCode := make([]byte, 1) if _, err := conn.Read(statusCode); err != nil { - return "", nil, fmt.Errorf("couldn't read status code field") + return nil, fmt.Errorf("couldn't read status code field") } messageLengthBytes := make([]byte, 2) if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") + return nil, fmt.Errorf("couldn't read message length") } messageLength := binary.BigEndian.Uint16(messageLengthBytes) @@ -474,27 +465,24 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { messageBytes := make([]byte, messageLength) if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") + return nil, fmt.Errorf("couldn't read message") } message = string(messageBytes) } - return "backendStatusResponse", &BackendStatusResponse{ - Type: "backendStatusResponse", + return &BackendStatusResponse{ IsRunning: isRunning[0] == 1, StatusCode: int(statusCode[0]), Message: message, }, nil case BackendStatusRequestID: - return "backendStatusRequest", &BackendStatusRequest{ - Type: "backendStatusRequest", - }, nil + return &BackendStatusRequest{}, nil case ProxyStatusRequestID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -504,45 +492,44 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "proxyStatusRequest", &ProxyStatusRequest{ - Type: "proxyStatusRequest", + return &ProxyStatusRequest{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -552,7 +539,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -562,51 +549,50 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } isActive := make([]byte, 1) if _, err := conn.Read(isActive); err != nil { - return "", nil, fmt.Errorf("couldn't read isActive field") + return nil, fmt.Errorf("couldn't read isActive field") } - return "proxyStatusResponse", &ProxyStatusResponse{ - Type: "proxyStatusResponse", + return &ProxyStatusResponse{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -614,9 +600,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { IsActive: isActive[0] == 1, }, nil case ProxyInstanceRequestID: - return "proxyInstanceRequest", &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - }, nil + return &ProxyInstanceRequest{}, nil case ProxyInstanceResponseID: proxies := []*ProxyInstance{} delimiter := make([]byte, 1) @@ -631,13 +615,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { break } - return "", nil, err + return nil, err } proxies = append(proxies, proxy) if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") + return nil, fmt.Errorf("couldn't read delimiter") } if delimiter[0] == '\r' { @@ -651,15 +635,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } } - return "proxyInstanceResponse", &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", + return &ProxyInstanceResponse{ Proxies: proxies, }, errorReturn case ProxyConnectionsRequestID: - return "proxyConnectionsRequest", &ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", - }, nil + return &ProxyConnectionsRequest{}, nil } - return "", nil, fmt.Errorf("couldn't match command ID") + return nil, fmt.Errorf("couldn't match command ID") } diff --git a/backend/externalbackendlauncher/main.go b/backend/externalbackendlauncher/main.go index 4c66c6a..c196866 100644 --- a/backend/externalbackendlauncher/main.go +++ b/backend/externalbackendlauncher/main.go @@ -21,11 +21,8 @@ type ProxyInstance struct { Protocol string `json:"protocol"` } -type WriteLogger struct { - UseError bool -} +type WriteLogger struct{} -// TODO: deprecate UseError switching func (writer WriteLogger) Write(p []byte) (n int, err error) { logSplit := strings.Split(string(p), "\n") @@ -34,11 +31,7 @@ func (writer WriteLogger) Write(p []byte) (n int, err error) { continue } - if writer.UseError { - log.Errorf("application: %s", line) - } else { - log.Infof("application: %s", line) - } + log.Infof("application: %s", line) } return len(p), err @@ -119,11 +112,10 @@ func entrypoint(cCtx *cli.Context) error { defer sock.Close() startCommand := &commonbackend.Start{ - Type: "start", Arguments: backendParameters, } - startMarshalledCommand, err := commonbackend.Marshal("start", startCommand) + startMarshalledCommand, err := commonbackend.Marshal(startCommand) if err != nil { log.Errorf("failed to generate start command: %s", err.Error()) @@ -135,18 +127,13 @@ func entrypoint(cCtx *cli.Context) error { continue } - commandType, commandRaw, err := commonbackend.Unmarshal(sock) + commandRaw, err := commonbackend.Unmarshal(sock) if err != nil { log.Errorf("failed to read from/unmarshal from socket: %s", err.Error()) continue } - if commandType != "backendStatusResponse" { - log.Errorf("recieved commandType '%s', expecting 'backendStatusResponse'", commandType) - continue - } - command, ok := commandRaw.(*commonbackend.BackendStatusResponse) if !ok { @@ -175,14 +162,13 @@ func entrypoint(cCtx *cli.Context) error { log.Infof("initializing proxy %s:%d -> remote:%d", proxy.SourceIP, proxy.SourcePort, proxy.DestPort) proxyAddCommand := &commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestPort, Protocol: proxy.Protocol, } - marshalledProxyCommand, err := commonbackend.Marshal("addProxy", proxyAddCommand) + marshalledProxyCommand, err := commonbackend.Marshal(proxyAddCommand) if err != nil { log.Errorf("failed to generate start command: %s", err.Error()) @@ -196,7 +182,7 @@ func entrypoint(cCtx *cli.Context) error { continue } - commandType, commandRaw, err := commonbackend.Unmarshal(sock) + commandRaw, err := commonbackend.Unmarshal(sock) if err != nil { log.Errorf("failed to read from/unmarshal from socket: %s", err.Error()) @@ -204,12 +190,6 @@ func entrypoint(cCtx *cli.Context) error { continue } - if commandType != "proxyStatusResponse" { - log.Errorf("recieved commandType '%s', expecting 'proxyStatusResponse'", commandType) - hasAnyFailed = true - continue - } - command, ok := commandRaw.(*commonbackend.ProxyStatusResponse) if !ok { @@ -242,13 +222,8 @@ func entrypoint(cCtx *cli.Context) error { log.Debug("entering execution loop (in main goroutine)...") - stdout := WriteLogger{ - UseError: false, - } - - stderr := WriteLogger{ - UseError: true, - } + stdout := WriteLogger{} + stderr := WriteLogger{} for { log.Info("starting process...") diff --git a/backend/sshappbackend/datacommands/constants.go b/backend/sshappbackend/datacommands/constants.go new file mode 100644 index 0000000..6385e98 --- /dev/null +++ b/backend/sshappbackend/datacommands/constants.go @@ -0,0 +1,90 @@ +package datacommands + +// DO NOT USE +type ProxyStatusRequest struct { + ProxyID uint16 +} + +type ProxyStatusResponse struct { + ProxyID uint16 + IsActive bool +} + +type RemoveProxy struct { + ProxyID uint16 +} + +type ProxyInstanceResponse struct { + Proxies []uint16 +} + +type ProxyConnectionsRequest struct { + ProxyID uint16 +} + +type ProxyConnectionsResponse struct { + Connections []uint16 +} + +type TCPConnectionOpened struct { + ProxyID uint16 + ConnectionID uint16 +} + +type TCPConnectionClosed struct { + ProxyID uint16 + ConnectionID uint16 +} + +type TCPProxyData struct { + ProxyID uint16 + ConnectionID uint16 + DataLength uint16 +} + +type UDPProxyData struct { + ProxyID uint16 + ClientIP string + ClientPort uint16 + DataLength uint16 +} + +type ProxyInformationRequest struct { + ProxyID uint16 +} + +type ProxyInformationResponse struct { + Exists bool + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyConnectionInformationRequest struct { + ProxyID uint16 + ConnectionID uint16 +} + +type ProxyConnectionInformationResponse struct { + Exists bool + ClientIP string + ClientPort uint16 +} + +const ( + ProxyStatusRequestID = iota + 100 + ProxyStatusResponseID + RemoveProxyID + ProxyInstanceResponseID + ProxyConnectionsRequestID + ProxyConnectionsResponseID + TCPConnectionOpenedID + TCPConnectionClosedID + TCPProxyDataID + UDPProxyDataID + ProxyInformationRequestID + ProxyInformationResponseID + ProxyConnectionInformationRequestID + ProxyConnectionInformationResponseID +) diff --git a/backend/sshappbackend/datacommands/marshal.go b/backend/sshappbackend/datacommands/marshal.go new file mode 100644 index 0000000..a2c13bf --- /dev/null +++ b/backend/sshappbackend/datacommands/marshal.go @@ -0,0 +1,323 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "net" +) + +// Example size and protocol constants — adjust as needed. +const ( + IPv4Size = 4 + IPv6Size = 16 + + TCP = 1 + UDP = 2 +) + +// Marshal takes a command (pointer to one of our structs) and converts it to a byte slice. +func Marshal(command interface{}) ([]byte, error) { + switch cmd := command.(type) { + // ProxyStatusRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyStatusRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyStatusRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyStatusResponse: 1 byte for the command ID, 2 bytes for ProxyID, and 1 byte for IsActive. + case *ProxyStatusResponse: + buf := make([]byte, 1+2+1) + + buf[0] = ProxyStatusResponseID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + if cmd.IsActive { + buf[3] = 1 + } else { + buf[3] = 0 + } + + return buf, nil + + // RemoveProxy: 1 byte for the command ID + 2 bytes for the ProxyID. + case *RemoveProxy: + buf := make([]byte, 1+2) + + buf[0] = RemoveProxyID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyConnectionsRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyConnectionsRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Connections + 2 bytes for each + // number in the Connection array. + case *ProxyConnectionsResponse: + buf := make([]byte, 1+((len(cmd.Connections)+1)*2)) + + buf[0] = ProxyConnectionsResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Connections))) + + for connectionIndex, connection := range cmd.Connections { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Proxies + 2 bytes for each + // number in the Proxies array. + case *ProxyInstanceResponse: + buf := make([]byte, 1+((len(cmd.Proxies)+1)*2)) + + buf[0] = ProxyInstanceResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Proxies))) + + for connectionIndex, connection := range cmd.Proxies { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // TCPConnectionOpened: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionOpened: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionOpenedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPConnectionClosed: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionClosed: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionClosedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case *TCPProxyData: + buf := make([]byte, 1+2+2+2) + + buf[0] = TCPProxyDataID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + binary.BigEndian.PutUint16(buf[5:], cmd.DataLength) + + return buf, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case *UDPProxyData: + ip := net.ParseIP(cmd.ClientIP) + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 2 + // ProxyID + 1 + // IP version + len(ipBytes) + // client IP bytes + 2 + // ClientPort + 2 // DataLength + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = UDPProxyDataID + offset++ + + binary.BigEndian.PutUint16(buf[offset:], cmd.ProxyID) + offset += 2 + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DataLength) + + return buf, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case *ProxyInformationRequest: + buf := make([]byte, 1+2) + buf[0] = ProxyInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + return buf, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + // (For simplicity, this marshaller always writes the IP and port info even if !Exists.) + case *ProxyInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.SourceIP) + + if ip == nil { + return nil, fmt.Errorf("invalid source IP: %v", cmd.SourceIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.SourceIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 + // SourcePort + 2 + // DestPort + 1 // Protocol + + buf := make([]byte, totalSize) + + offset := 0 + buf[offset] = ProxyInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.SourcePort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DestPort) + offset += 2 + + // Encode protocol as 1 byte. + switch cmd.Protocol { + case "tcp": + buf[offset] = TCP + case "udp": + buf[offset] = UDP + default: + return nil, fmt.Errorf("invalid protocol: %v", cmd.Protocol) + } + + // offset++ (not needed since we are at the end) + return buf, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *ProxyConnectionInformationRequest: + buf := make([]byte, 1+2+2) + + buf[0] = ProxyConnectionInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes ClientPort. + // This marshaller only writes the rest of the data if Exists. + case *ProxyConnectionInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyConnectionInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.ClientIP) + + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 // ClientPort + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = ProxyConnectionInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + + return buf, nil + + default: + return nil, fmt.Errorf("unsupported command type") + } +} diff --git a/backend/sshappbackend/datacommands/marshalling_test.go b/backend/sshappbackend/datacommands/marshalling_test.go new file mode 100644 index 0000000..5b2e5ab --- /dev/null +++ b/backend/sshappbackend/datacommands/marshalling_test.go @@ -0,0 +1,652 @@ +package datacommands + +import ( + "bytes" + "log" + "os" + "testing" +) + +var logLevel = os.Getenv("HERMES_LOG_LEVEL") + +func TestProxyStatusRequest(t *testing.T) { + commandInput := &ProxyStatusRequest{ + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyStatusResponse(t *testing.T) { + commandInput := &ProxyStatusResponse{ + ProxyID: 19132, + IsActive: true, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.IsActive != commandUnmarshalled.IsActive { + t.Fail() + log.Printf("IsActive's are not equal (orig: '%t', unmsh: '%t')", commandInput.IsActive, commandUnmarshalled.IsActive) + } +} + +func TestRemoveProxy(t *testing.T) { + commandInput := &RemoveProxy{ + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsRequest(t *testing.T) { + commandInput := &ProxyConnectionsRequest{ + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsResponse(t *testing.T) { + commandInput := &ProxyConnectionsResponse{ + Connections: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) + + if !ok { + t.Fatal("failed typecast") + } + + for connectionIndex, originalConnection := range commandInput.Connections { + remoteConnection := commandUnmarshalled.Connections[connectionIndex] + + if originalConnection != remoteConnection { + t.Fail() + log.Printf("(in #%d) SourceIP's are not equal (orig: %d, unmsh: %d)", connectionIndex, originalConnection, connectionIndex) + } + } +} + +func TestProxyInstanceResponse(t *testing.T) { + commandInput := &ProxyInstanceResponse{ + Proxies: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) + + if !ok { + t.Fatal("failed typecast") + } + + for proxyIndex, originalProxy := range commandInput.Proxies { + remoteProxy := commandUnmarshalled.Proxies[proxyIndex] + + if originalProxy != remoteProxy { + t.Fail() + log.Printf("(in #%d) Proxy IDs are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy, remoteProxy) + } + } +} + +func TestTCPConnectionOpened(t *testing.T) { + commandInput := &TCPConnectionOpened{ + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionOpened) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPConnectionClosed(t *testing.T) { + commandInput := &TCPConnectionClosed{ + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionClosed) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPProxyData(t *testing.T) { + commandInput := &TCPProxyData{ + ProxyID: 19132, + ConnectionID: 25565, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestUDPProxyData(t *testing.T) { + commandInput := &UDPProxyData{ + ProxyID: 19132, + ClientIP: "68.51.23.54", + ClientPort: 28173, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*UDPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("ClientIP's are not equal (orig: '%s', unmsh: '%s')", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: '%d', unmsh: '%d')", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestProxyInformationRequest(t *testing.T) { + commandInput := &ProxyInformationRequest{ + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyInformationResponseExists(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Exists: true, + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestProxyInformationResponseNoExist(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} + +func TestProxyConnectionInformationRequest(t *testing.T) { + commandInput := &ProxyConnectionInformationRequest{ + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestProxyConnectionInformationResponseExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Exists: true, + ClientIP: "192.168.0.139", + ClientPort: 19132, + } + + commandMarshalled, err := Marshal(commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: %d, unmsh: %d)", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } +} + +func TestProxyConnectionInformationResponseNoExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} diff --git a/backend/sshappbackend/datacommands/unmarshal.go b/backend/sshappbackend/datacommands/unmarshal.go new file mode 100644 index 0000000..d9d0523 --- /dev/null +++ b/backend/sshappbackend/datacommands/unmarshal.go @@ -0,0 +1,422 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +// Unmarshal reads from the provided connection and returns +// the message type (as a string), the unmarshalled struct, or an error. +func Unmarshal(conn io.Reader) (interface{}, error) { + // Every command starts with a 1-byte command ID. + header := make([]byte, 1) + + if _, err := io.ReadFull(conn, header); err != nil { + return nil, fmt.Errorf("couldn't read command ID: %w", err) + } + + cmdID := header[0] + switch cmdID { + // ProxyStatusRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyStatusRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyStatusRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return &ProxyStatusRequest{ + ProxyID: proxyID, + }, nil + + // ProxyStatusResponse: 1 byte ID + 2 bytes ProxyID + 1 byte IsActive. + case ProxyStatusResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyStatusResponse ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyStatusResponse IsActive: %w", err) + } + + isActive := boolBuf[0] != 0 + + return &ProxyStatusResponse{ + ProxyID: proxyID, + IsActive: isActive, + }, nil + + // RemoveProxy: 1 byte ID + 2 bytes ProxyID. + case RemoveProxyID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read RemoveProxy ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return &RemoveProxy{ + ProxyID: proxyID, + }, nil + + // ProxyConnectionsRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyConnectionsRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionsRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return &ProxyConnectionsRequest{ + ProxyID: proxyID, + }, nil + + // ProxyConnectionsResponse: 1 byte ID + 2 bytes Connections length + 2 bytes for each Connection in Connections. + case ProxyConnectionsResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + connections := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range connections { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + connections[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return &ProxyConnectionsResponse{ + Connections: connections, + }, failedDuringReading + + // ProxyInstanceResponse: 1 byte ID + 2 bytes Proxies length + 2 bytes for each Proxy in Proxies. + case ProxyInstanceResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + proxies := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range proxies { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + proxies[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return &ProxyInstanceResponse{ + Proxies: proxies, + }, failedDuringReading + + // TCPConnectionOpened: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionOpenedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read TCPConnectionOpened fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return &TCPConnectionOpened{ + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPConnectionClosed: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionClosedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read TCPConnectionClosed fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return &TCPConnectionClosed{ + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case TCPProxyDataID: + buf := make([]byte, 2+2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read TCPProxyData fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + dataLength := binary.BigEndian.Uint16(buf[4:6]) + + return &TCPProxyData{ + ProxyID: proxyID, + ConnectionID: connectionID, + DataLength: dataLength, + }, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case UDPProxyDataID: + // Read 2 bytes ProxyID + 2 bytes ConnectionID. + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read UDPProxyData ProxyID/ConnectionID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return nil, fmt.Errorf("couldn't read UDPProxyData IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return nil, fmt.Errorf("invalid IP version received: %v", ipVerBuf[0]) + } + + // Read the IP bytes. + ipBytes := make([]byte, ipSize) + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return nil, fmt.Errorf("couldn't read UDPProxyData IP bytes: %w", err) + } + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return nil, fmt.Errorf("couldn't read UDPProxyData ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + // Read DataLength. + dataLengthBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, dataLengthBuf); err != nil { + return nil, fmt.Errorf("couldn't read UDPProxyData DataLength: %w", err) + } + + dataLength := binary.BigEndian.Uint16(dataLengthBuf) + + return &UDPProxyData{ + ProxyID: proxyID, + ClientIP: clientIP, + ClientPort: clientPort, + DataLength: dataLength, + }, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyInformationRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return &ProxyInformationRequest{ + ProxyID: proxyID, + }, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + case ProxyInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return &ProxyInformationResponse{ + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationResponse IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return nil, fmt.Errorf("invalid IP version in ProxyInformationResponse: %v", ipVerBuf[0]) + } + + // Read the source IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationResponse IP bytes: %w", err) + } + + sourceIP := net.IP(ipBytes).String() + + // Read SourcePort and DestPort. + portsBuf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, portsBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationResponse ports: %w", err) + } + + sourcePort := binary.BigEndian.Uint16(portsBuf[0:2]) + destPort := binary.BigEndian.Uint16(portsBuf[2:4]) + + // Read protocol. + protoBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, protoBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyInformationResponse protocol: %w", err) + } + + var protocol string + + if protoBuf[0] == TCP { + protocol = "tcp" + } else if protoBuf[0] == UDP { + protocol = "udp" + } else { + return nil, fmt.Errorf("invalid protocol value in ProxyInformationResponse: %d", protoBuf[0]) + } + + return &ProxyInformationResponse{ + Exists: exists, + SourceIP: sourceIP, + SourcePort: sourcePort, + DestPort: destPort, + Protocol: protocol, + }, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case ProxyConnectionInformationRequestID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationRequest fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return &ProxyConnectionInformationRequest{ + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + 1 byte IP version + IP bytes + 2 bytes ClientPort. + case ProxyConnectionInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return &ProxyConnectionInformationResponse{ + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP version: %w", err) + } + + if ipVerBuf[0] != 4 && ipVerBuf[0] != 6 { + return nil, fmt.Errorf("invalid IP version in ProxyConnectionInformationResponse: %v", ipVerBuf[0]) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else { + ipSize = IPv6Size + } + + // Read IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP bytes: %w", err) + } + + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + return &ProxyConnectionInformationResponse{ + Exists: exists, + ClientIP: clientIP, + ClientPort: clientPort, + }, nil + default: + return nil, fmt.Errorf("unknown command id: %v", cmdID) + } +} diff --git a/backend/sshappbackend/gaslighter/gaslighter.go b/backend/sshappbackend/gaslighter/gaslighter.go new file mode 100644 index 0000000..ecccec7 --- /dev/null +++ b/backend/sshappbackend/gaslighter/gaslighter.go @@ -0,0 +1,30 @@ +package gaslighter + +import "io" + +type Gaslighter struct { + Byte byte + HasGaslit bool + ProxiedReader io.Reader +} + +func (gaslighter *Gaslighter) Read(p []byte) (n int, err error) { + if gaslighter.HasGaslit { + return gaslighter.ProxiedReader.Read(p) + } + + if len(p) == 0 { + return 0, nil + } + + p[0] = gaslighter.Byte + gaslighter.HasGaslit = true + + if len(p) > 1 { + n, err := gaslighter.ProxiedReader.Read(p[1:]) + + return n + 1, err + } else { + return 1, nil + } +} diff --git a/backend/sshappbackend/local-code/fs.go b/backend/sshappbackend/local-code/fs.go new file mode 100644 index 0000000..7fe43e4 --- /dev/null +++ b/backend/sshappbackend/local-code/fs.go @@ -0,0 +1,8 @@ +package main + +import ( + "embed" +) + +//go:embed remote-bin +var binFiles embed.FS diff --git a/backend/sshappbackend/local-code/logger.go b/backend/sshappbackend/local-code/logger.go new file mode 100644 index 0000000..d8ed3f9 --- /dev/null +++ b/backend/sshappbackend/local-code/logger.go @@ -0,0 +1,23 @@ +package main + +import ( + "strings" + + "github.com/charmbracelet/log" +) + +type WriteLogger struct{} + +func (writer WriteLogger) Write(p []byte) (n int, err error) { + logSplit := strings.Split(string(p), "\n") + + for _, line := range logSplit { + if line == "" { + continue + } + + log.Infof("Process: %s", line) + } + + return len(p), err +} diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go new file mode 100644 index 0000000..34a85d0 --- /dev/null +++ b/backend/sshappbackend/local-code/main.go @@ -0,0 +1,868 @@ +package main + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand/v2" + "net" + "os" + "strings" + "sync" + "time" + + "git.terah.dev/imterah/hermes/backend/backendutil" + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" + "git.terah.dev/imterah/hermes/backend/sshappbackend/local-code/porttranslation" + "github.com/charmbracelet/log" + "github.com/go-playground/validator/v10" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var validatorInstance *validator.Validate + +type TCPProxy struct { + proxyInformation *commonbackend.AddProxy + connections map[uint16]net.Conn +} + +type UDPProxy struct { + proxyInformation *commonbackend.AddProxy + portTranslation *porttranslation.PortTranslation +} + +type SSHAppBackendData struct { + IP string `json:"ip" validate:"required"` + Port uint16 `json:"port" validate:"required"` + Username string `json:"username" validate:"required"` + PrivateKey string `json:"privateKey" validate:"required"` + ListenOnIPs []string `json:"listenOnIPs"` +} + +type SSHAppBackend struct { + config *SSHAppBackendData + conn *ssh.Client + listener net.Listener + currentSock net.Conn + + tcpProxies map[uint16]*TCPProxy + udpProxies map[uint16]*UDPProxy + + // globalNonCriticalMessageLock: Locks all messages that don't need low-latency transmissions & high + // speed behind a lock. This ensures safety when it comes to handling messages correctly. + globalNonCriticalMessageLock sync.Mutex + // globalNonCriticalMessageChan: Channel for handling messages that need a reply / aren't critical. + globalNonCriticalMessageChan chan interface{} +} + +func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { + log.Info("SSHAppBackend is initializing...") + + if validatorInstance == nil { + validatorInstance = validator.New() + } + + backend.globalNonCriticalMessageChan = make(chan interface{}) + backend.tcpProxies = map[uint16]*TCPProxy{} + backend.udpProxies = map[uint16]*UDPProxy{} + + var backendData SSHAppBackendData + + if err := json.Unmarshal(configBytes, &backendData); err != nil { + return false, err + } + + if err := validatorInstance.Struct(&backendData); err != nil { + return false, err + } + + backend.config = &backendData + + if len(backend.config.ListenOnIPs) == 0 { + backend.config.ListenOnIPs = []string{"0.0.0.0"} + } + + signer, err := ssh.ParsePrivateKey([]byte(backendData.PrivateKey)) + + if err != nil { + log.Warnf("Failed to initialize: %s", err.Error()) + return false, err + } + + auth := ssh.PublicKeys(signer) + + config := &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + User: backendData.Username, + Auth: []ssh.AuthMethod{ + auth, + }, + } + + conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backendData.IP, backendData.Port), config) + + if err != nil { + log.Warnf("Failed to initialize: %s", err.Error()) + return false, err + } + + backend.conn = conn + + log.Debug("SSHAppBackend has connected successfully.") + log.Debug("Getting CPU architecture...") + + session, err := backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + var stdoutBuf bytes.Buffer + session.Stdout = &stdoutBuf + + err = session.Run("uname -m") + + if err != nil { + log.Warnf("Failed to run uname command: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + cpuArchBytes := make([]byte, stdoutBuf.Len()) + stdoutBuf.Read(cpuArchBytes) + + cpuArch := string(cpuArchBytes) + cpuArch = cpuArch[:len(cpuArch)-1] + + var backendBinary string + + // Ordered in (subjective) popularity + if cpuArch == "x86_64" { + backendBinary = "remote-bin/rt-amd64" + } else if cpuArch == "aarch64" { + backendBinary = "remote-bin/rt-arm64" + } else if cpuArch == "arm" { + backendBinary = "remote-bin/rt-arm" + } else if len(cpuArch) == 4 && string(cpuArch[0]) == "i" && strings.HasSuffix(cpuArch, "86") { + backendBinary = "remote-bin/rt-386" + } else { + log.Warn("Failed to determine executable to use: CPU architecture not compiled/supported currently") + conn.Close() + backend.conn = nil + return false, fmt.Errorf("CPU architecture not compiled/supported currently") + } + + log.Debug("Checking if we need to copy the application...") + + var binary []byte + needsToCopyBinary := true + + session, err = backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + session.Stdout = &stdoutBuf + + err = session.Start("[ -f /tmp/sshappbackend.runtime ] && md5sum /tmp/sshappbackend.runtime | cut -d \" \" -f 1") + + if err != nil { + log.Warnf("Failed to calculate hash of possibly existing backend: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + fileExists := stdoutBuf.Len() != 0 + + if fileExists { + remoteMD5HashStringBuf := make([]byte, stdoutBuf.Len()) + stdoutBuf.Read(remoteMD5HashStringBuf) + + remoteMD5HashString := string(remoteMD5HashStringBuf) + remoteMD5HashString = remoteMD5HashString[:len(remoteMD5HashString)-1] + + remoteMD5Hash, err := hex.DecodeString(remoteMD5HashString) + + if err != nil { + log.Warnf("Failed to decode hex: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + binary, err = binFiles.ReadFile(backendBinary) + + if err != nil { + log.Warnf("Failed to read file in the embedded FS: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, fmt.Errorf("(embedded FS): %s", err.Error()) + } + + localMD5Hash := md5.Sum(binary) + + log.Infof("remote: %s, local: %s", remoteMD5HashString, hex.EncodeToString(localMD5Hash[:])) + + if bytes.Compare(localMD5Hash[:], remoteMD5Hash) == 0 { + needsToCopyBinary = false + } + } + + if needsToCopyBinary { + log.Debug("Copying binary...") + + sftpInstance, err := sftp.NewClient(conn) + + if err != nil { + log.Warnf("Failed to initialize SFTP: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + defer sftpInstance.Close() + + if len(binary) == 0 { + binary, err = binFiles.ReadFile(backendBinary) + + if err != nil { + log.Warnf("Failed to read file in the embedded FS: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, fmt.Errorf("(embedded FS): %s", err.Error()) + } + } + + var file *sftp.File + + if fileExists { + file, err = sftpInstance.OpenFile("/tmp/sshappbackend.runtime", os.O_WRONLY) + } else { + file, err = sftpInstance.Create("/tmp/sshappbackend.runtime") + } + + if err != nil { + log.Warnf("Failed to create (or open) file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + _, err = file.Write(binary) + + if err != nil { + log.Warnf("Failed to write file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + err = file.Chmod(0755) + + if err != nil { + log.Warnf("Failed to change permissions on file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + log.Debug("Done copying file.") + sftpInstance.Close() + } else { + log.Debug("Skipping copying as there's a copy on disk already.") + } + + log.Debug("Initializing Unix socket...") + + socketPath := fmt.Sprintf("/tmp/sock-%d.sock", rand.Uint()) + listener, err := conn.ListenUnix(socketPath) + + if err != nil { + log.Warnf("Failed to listen on socket: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + log.Debug("Starting process...") + + session, err = backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + backend.listener = listener + + session.Stdout = WriteLogger{} + session.Stderr = WriteLogger{} + + go func() { + for { + err := session.Run(fmt.Sprintf("HERMES_LOG_LEVEL=\"%s\" HERMES_API_SOCK=\"%s\" /tmp/sshappbackend.runtime", os.Getenv("HERMES_LOG_LEVEL"), socketPath)) + + if err != nil && !errors.Is(err, &ssh.ExitError{}) && !errors.Is(err, &ssh.ExitMissingError{}) { + log.Errorf("Critically failed during execution of remote code: %s", err.Error()) + return + } else { + log.Warn("Remote code failed for an unknown reason. Restarting...") + } + } + }() + + go backend.sockServerHandler() + + log.Debug("Started process. Waiting for Unix socket connection...") + + for backend.currentSock == nil { + time.Sleep(10 * time.Millisecond) + } + + log.Debug("Detected connection. Sending initialization command...") + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&commonbackend.Start{ + Arguments: []byte{}, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*commonbackend.BackendStatusResponse) + + if !ok { + return false, fmt.Errorf("recieved invalid response type: %T", proxyStatusRaw) + } + + if proxyStatus.StatusCode == commonbackend.StatusFailure { + if proxyStatus.Message == "" { + return false, fmt.Errorf("failed to initialize backend in remote code") + } else { + return false, fmt.Errorf("failed to initialize backend in remote code: %s", proxyStatus.Message) + } + } + + log.Info("SSHAppBackend has initialized successfully.") + + return true, nil +} + +func (backend *SSHAppBackend) StopBackend() (bool, error) { + err := backend.conn.Close() + + if err != nil { + return false, err + } + + return true, nil +} + +func (backend *SSHAppBackend) GetBackendStatus() (bool, error) { + return backend.conn != nil, nil +} + +func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool, error) { + proxyStatusRaw, err := backend.SendNonCriticalMessage(command) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + return false, fmt.Errorf("recieved invalid response type: %T", proxyStatusRaw) + } + + if !proxyStatus.IsActive { + return false, fmt.Errorf("failed to initialize proxy in remote code") + } + + if command.Protocol == "tcp" { + backend.tcpProxies[proxyStatus.ProxyID] = &TCPProxy{ + proxyInformation: command, + } + + backend.tcpProxies[proxyStatus.ProxyID].connections = map[uint16]net.Conn{} + } else if command.Protocol == "udp" { + backend.udpProxies[proxyStatus.ProxyID] = &UDPProxy{ + proxyInformation: command, + portTranslation: &porttranslation.PortTranslation{}, + } + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.UDPAddr = &net.UDPAddr{ + IP: net.ParseIP(command.SourceIP), + Port: int(command.SourcePort), + } + + udpMessageCommand := &datacommands.UDPProxyData{} + udpMessageCommand.ProxyID = proxyStatus.ProxyID + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.WriteFrom = func(ip string, port uint16, data []byte) { + udpMessageCommand.ClientIP = ip + udpMessageCommand.ClientPort = port + udpMessageCommand.DataLength = uint16(len(data)) + + marshalledCommand, err := datacommands.Marshal(udpMessageCommand) + + if err != nil { + log.Warnf("Failed to marshal UDP message header") + return + } + + if _, err := backend.currentSock.Write(marshalledCommand); err != nil { + log.Warnf("Failed to write UDP message header") + return + } + + if _, err := backend.currentSock.Write(data); err != nil { + log.Warnf("Failed to write UDP message") + return + } + } + + go func() { + for { + time.Sleep(3 * time.Minute) + + // Checks if the proxy still exists before continuing + _, ok := backend.udpProxies[proxyStatus.ProxyID] + + if !ok { + return + } + + // Then attempt to run cleanup tasks + log.Debug("Running UDP proxy cleanup tasks (invoking CleanupPorts() on portTranslation)") + backend.udpProxies[proxyStatus.ProxyID].portTranslation.CleanupPorts() + } + }() + } + + return true, nil +} + +func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, error) { + if command.Protocol == "tcp" { + for proxyIndex, proxy := range backend.tcpProxies { + if proxy.proxyInformation.DestPort != command.DestPort { + continue + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyIndex, + } + + for connectionIndex, connection := range proxy.connections { + connection.Close() + delete(proxy.connections, connectionIndex) + + onDisconnect.ConnectionID = connectionIndex + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.currentSock.Write(disconnectionCommandMarshalled) + } + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&datacommands.RemoveProxy{ + ProxyID: proxyIndex, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + log.Warn("Failed to stop proxy: typecast failed") + return true, fmt.Errorf("failed to stop proxy: typecast failed") + } + + if proxyStatus.IsActive { + log.Warn("Failed to stop proxy: still running") + return true, fmt.Errorf("failed to stop proxy: still running") + } + } + } else if command.Protocol == "udp" { + for proxyIndex, proxy := range backend.udpProxies { + if proxy.proxyInformation.DestPort != command.DestPort { + continue + } + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&datacommands.RemoveProxy{ + ProxyID: proxyIndex, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + log.Warn("Failed to stop proxy: typecast failed") + return true, fmt.Errorf("failed to stop proxy: typecast failed") + } + + if proxyStatus.IsActive { + log.Warn("Failed to stop proxy: still running") + return true, fmt.Errorf("failed to stop proxy: still running") + } + + proxy.portTranslation.StopAllPorts() + delete(backend.udpProxies, proxyIndex) + } + } + + return false, fmt.Errorf("could not find the proxy") +} + +func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { + connections := []*commonbackend.ProxyClientConnection{} + informationRequest := &datacommands.ProxyConnectionInformationRequest{} + + for proxyID, tcpProxy := range backend.tcpProxies { + informationRequest.ProxyID = proxyID + + for connectionID := range tcpProxy.connections { + informationRequest.ConnectionID = connectionID + + proxyStatusRaw, err := backend.SendNonCriticalMessage(informationRequest) + + if err != nil { + log.Warnf("Failed to get connection information for Proxy ID: %d, Connection ID: %d: %s", proxyID, connectionID, err.Error()) + return connections + } + + connectionStatus, ok := proxyStatusRaw.(*datacommands.ProxyConnectionInformationResponse) + + if !ok { + log.Warn("Failed to get connection response: typecast failed") + return connections + } + + if !connectionStatus.Exists { + log.Warnf("Connection with proxy ID: %d, Connection ID: %d is reported to not exist!", proxyID, connectionID) + tcpProxy.connections[connectionID].Close() + } + + connections = append(connections, &commonbackend.ProxyClientConnection{ + SourceIP: tcpProxy.proxyInformation.SourceIP, + SourcePort: tcpProxy.proxyInformation.SourcePort, + DestPort: tcpProxy.proxyInformation.DestPort, + ClientIP: connectionStatus.ClientIP, + ClientPort: connectionStatus.ClientPort, + }) + } + } + + return connections +} + +// We don't have any parameter limitations, so we should be good. +func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} + +func (backend *SSHAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { + var backendData SSHAppBackendData + + if validatorInstance == nil { + validatorInstance = validator.New() + } + + if err := json.Unmarshal(arguments, &backendData); err != nil { + return &commonbackend.CheckParametersResponse{ + IsValid: false, + Message: fmt.Sprintf("could not read json: %s", err.Error()), + } + } + + if err := validatorInstance.Struct(&backendData); err != nil { + return &commonbackend.CheckParametersResponse{ + IsValid: false, + Message: fmt.Sprintf("failed validation of parameters: %s", err.Error()), + } + } + + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} + +func (backend *SSHAppBackend) OnTCPConnectionOpened(proxyID, connectionID uint16) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", backend.tcpProxies[proxyID].proxyInformation.SourceIP, backend.tcpProxies[proxyID].proxyInformation.SourcePort)) + + if err != nil { + log.Warnf("failed to dial sock: %s", err.Error()) + } + + go func() { + dataBuf := make([]byte, 65535) + + tcpData := &datacommands.TCPProxyData{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + for { + len, err := conn.Read(dataBuf) + + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } else if err.Error() != "EOF" { + log.Warnf("failed to read from sock: %s", err.Error()) + } + + conn.Close() + break + } + + tcpData.DataLength = uint16(len) + marshalledMessageCommand, err := datacommands.Marshal(tcpData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.currentSock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.currentSock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + + conn.Close() + break + } + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.currentSock.Write(disconnectionCommandMarshalled) + }() + + backend.tcpProxies[proxyID].connections[connectionID] = conn +} + +func (backend *SSHAppBackend) OnTCPConnectionClosed(proxyID, connectionID uint16) { + proxy, ok := backend.tcpProxies[proxyID] + + if !ok { + log.Warn("Could not find TCP proxy") + } + + connection, ok := proxy.connections[connectionID] + + if !ok { + log.Warn("Could not find connection in TCP proxy") + } + + connection.Close() + delete(proxy.connections, connectionID) +} + +func (backend *SSHAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + proxy, ok := backend.tcpProxies[message.ProxyID] + + if !ok { + log.Warn("Could not find TCP proxy") + } + + connection, ok := proxy.connections[message.ConnectionID] + + if !ok { + log.Warn("Could not find connection in TCP proxy") + } + + connection.Write(data) +} + +func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + proxy, ok := backend.udpProxies[message.ProxyID] + + if !ok { + log.Warn("Could not find UDP proxy") + } + + if _, err := proxy.portTranslation.WriteTo(message.ClientIP, message.ClientPort, data); err != nil { + log.Warnf("Failed to write to UDP: %s", err.Error()) + } +} + +func (backend *SSHAppBackend) SendNonCriticalMessage(iface interface{}) (interface{}, error) { + if backend.currentSock == nil { + return nil, fmt.Errorf("socket connection not initialized yet") + } + + bytes, err := datacommands.Marshal(iface) + + if err != nil && err.Error() == "unsupported command type" { + bytes, err = commonbackend.Marshal(iface) + + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + backend.globalNonCriticalMessageLock.Lock() + + if _, err := backend.currentSock.Write(bytes); err != nil { + backend.globalNonCriticalMessageLock.Unlock() + return nil, fmt.Errorf("failed to write message: %s", err.Error()) + } + + reply, ok := <-backend.globalNonCriticalMessageChan + + if !ok { + backend.globalNonCriticalMessageLock.Unlock() + return nil, fmt.Errorf("failed to get reply back: chan not OK") + } + + backend.globalNonCriticalMessageLock.Unlock() + return reply, nil +} + +func (backend *SSHAppBackend) sockServerHandler() { + for { + conn, err := backend.listener.Accept() + + if err != nil { + log.Warnf("Failed to accept remote connection: %s", err.Error()) + } + + log.Debug("Successfully connected.") + + backend.currentSock = conn + + commandID := make([]byte, 1) + + gaslighter := &gaslighter.Gaslighter{} + gaslighter.ProxiedReader = conn + + dataBuffer := make([]byte, 65535) + + var commandRaw interface{} + + for { + if _, err := conn.Read(commandID); err != nil { + log.Warnf("Failed to read command ID: %s", err.Error()) + return + } + + gaslighter.Byte = commandID[0] + gaslighter.HasGaslit = false + + if gaslighter.Byte > 100 { + commandRaw, err = datacommands.Unmarshal(gaslighter) + } else { + commandRaw, err = commonbackend.Unmarshal(gaslighter) + } + + if err != nil { + log.Warnf("Failed to parse command: %s", err.Error()) + } + + switch command := commandRaw.(type) { + case *datacommands.TCPConnectionOpened: + backend.OnTCPConnectionOpened(command.ProxyID, command.ConnectionID) + case *datacommands.TCPConnectionClosed: + backend.OnTCPConnectionClosed(command.ProxyID, command.ConnectionID) + case *datacommands.TCPProxyData: + if _, err := io.ReadFull(conn, dataBuffer[:command.DataLength]); err != nil { + log.Warnf("Failed to read entire data buffer: %s", err.Error()) + break + } + + backend.HandleTCPMessage(command, dataBuffer[:command.DataLength]) + case *datacommands.UDPProxyData: + if _, err := io.ReadFull(conn, dataBuffer[:command.DataLength]); err != nil { + log.Warnf("Failed to read entire data buffer: %s", err.Error()) + break + } + + backend.HandleUDPMessage(command, dataBuffer[:command.DataLength]) + default: + select { + case backend.globalNonCriticalMessageChan <- command: + default: + } + } + } + } +} + +func main() { + logLevel := os.Getenv("HERMES_LOG_LEVEL") + + if logLevel != "" { + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + + case "info": + log.SetLevel(log.InfoLevel) + + case "warn": + log.SetLevel(log.WarnLevel) + + case "error": + log.SetLevel(log.ErrorLevel) + + case "fatal": + log.SetLevel(log.FatalLevel) + } + } + + backend := &SSHAppBackend{} + + application := backendutil.NewHelper(backend) + err := application.Start() + + if err != nil { + log.Fatalf("failed execution in application: %s", err.Error()) + } +} diff --git a/backend/sshappbackend/local-code/porttranslation/translation.go b/backend/sshappbackend/local-code/porttranslation/translation.go new file mode 100644 index 0000000..b8c0454 --- /dev/null +++ b/backend/sshappbackend/local-code/porttranslation/translation.go @@ -0,0 +1,112 @@ +package porttranslation + +import ( + "fmt" + "net" + "sync" + "time" +) + +type connectionData struct { + udpConn *net.UDPConn + buf []byte + hasBeenAliveFor time.Time +} + +type PortTranslation struct { + UDPAddr *net.UDPAddr + WriteFrom func(ip string, port uint16, data []byte) + + newConnectionLock sync.Mutex + connections map[string]map[uint16]*connectionData +} + +func (translation *PortTranslation) CleanupPorts() { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + anyAreAlive := false + + for connectionPortIndex, connectionData := range connectionPorts { + if time.Now().Before(connectionData.hasBeenAliveFor.Add(3 * time.Minute)) { + anyAreAlive = true + continue + } + + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + if !anyAreAlive { + delete(translation.connections, connectionIPIndex) + } + } +} + +func (translation *PortTranslation) StopAllPorts() { + if translation.connections == nil { + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + for connectionPortIndex, connectionData := range connectionPorts { + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + delete(translation.connections, connectionIPIndex) + } + + translation.connections = nil +} + +func (translation *PortTranslation) WriteTo(ip string, port uint16, data []byte) (int, error) { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + } + + connectionPortData, ok := translation.connections[ip] + + if !ok { + translation.connections[ip] = map[uint16]*connectionData{} + connectionPortData = translation.connections[ip] + } + + connectionStruct, ok := connectionPortData[port] + + if !ok { + connectionPortData[port] = &connectionData{} + connectionStruct = connectionPortData[port] + + udpConn, err := net.DialUDP("udp", nil, translation.UDPAddr) + + if err != nil { + return 0, fmt.Errorf("failed to initialize UDP socket: %s", err.Error()) + } + + connectionStruct.udpConn = udpConn + connectionStruct.buf = make([]byte, 65535) + + go func() { + for { + n, err := udpConn.Read(connectionStruct.buf) + + if err != nil { + udpConn.Close() + delete(connectionPortData, port) + + return + } + + connectionStruct.hasBeenAliveFor = time.Now() + translation.WriteFrom(ip, port, connectionStruct.buf[:n]) + } + }() + } + + connectionStruct.hasBeenAliveFor = time.Now() + return connectionStruct.udpConn.Write(data) +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go new file mode 100644 index 0000000..2747f28 --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -0,0 +1,306 @@ +package backendutil_custom + +import ( + "io" + "net" + "os" + + "git.terah.dev/imterah/hermes/backend/backendutil" + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" + "github.com/charmbracelet/log" +) + +type BackendApplicationHelper struct { + Backend BackendInterface + SocketPath string + + socket net.Conn +} + +func (helper *BackendApplicationHelper) Start() error { + log.Debug("BackendApplicationHelper is starting") + err := backendutil.ConfigureProfiling() + + if err != nil { + return err + } + + log.Debug("Currently waiting for Unix socket connection...") + + helper.socket, err = net.Dial("unix", helper.SocketPath) + + if err != nil { + return err + } + + helper.Backend.OnSocketConnection(helper.socket) + + log.Debug("Sucessfully connected") + + gaslighter := &gaslighter.Gaslighter{} + gaslighter.ProxiedReader = helper.socket + + commandID := make([]byte, 1) + + for { + if _, err := helper.socket.Read(commandID); err != nil { + return err + } + + gaslighter.Byte = commandID[0] + gaslighter.HasGaslit = false + + var commandRaw interface{} + + if gaslighter.Byte > 100 { + commandRaw, err = datacommands.Unmarshal(gaslighter) + } else { + commandRaw, err = commonbackend.Unmarshal(gaslighter) + } + + if err != nil { + return err + } + + switch command := commandRaw.(type) { + case *datacommands.ProxyConnectionsRequest: + connections := helper.Backend.GetAllClientConnections(command.ProxyID) + + serverParams := &datacommands.ProxyConnectionsResponse{ + Connections: connections, + } + + byteData, err := datacommands.Marshal(serverParams) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case *datacommands.RemoveProxy: + ok, err := helper.Backend.StopProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to remove proxy (ID %d): RemoveProxy returned into failure state", command.ProxyID) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to remove proxy (ID %d): %s", command.ProxyID, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + ProxyID: command.ProxyID, + IsActive: hasAnyFailed, + } + + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *datacommands.ProxyInformationRequest: + response := helper.Backend.ResolveProxy(command.ProxyID) + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *datacommands.ProxyConnectionInformationRequest: + response := helper.Backend.ResolveConnection(command.ProxyID, command.ConnectionID) + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *datacommands.TCPConnectionClosed: + helper.Backend.OnTCPConnectionClosed(command.ProxyID, command.ConnectionID) + case *datacommands.TCPProxyData: + bytes := make([]byte, command.DataLength) + _, err := io.ReadFull(helper.socket, bytes) + + if err != nil { + log.Warn("failed to read TCP data") + } + + helper.Backend.HandleTCPMessage(command, bytes) + case *datacommands.UDPProxyData: + bytes := make([]byte, command.DataLength) + _, err := io.ReadFull(helper.socket, bytes) + + if err != nil { + log.Warn("failed to read TCP data") + } + + helper.Backend.HandleUDPMessage(command, bytes) + case *commonbackend.Start: + ok, err := helper.Backend.StartBackend(command.Arguments) + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.Stop: + ok, err := helper.Backend.StopBackend() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: !ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.BackendStatusRequest: + ok, err := helper.Backend.GetBackendStatus() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.AddProxy: + id, ok, err := helper.Backend.StartProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + ProxyID: id, + IsActive: !hasAnyFailed, + } + + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.CheckClientParameters: + resp := helper.Backend.CheckParametersForConnections(command) + resp.InResponseTo = "checkClientParameters" + + byteData, err := commonbackend.Marshal(resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case *commonbackend.CheckServerParameters: + resp := helper.Backend.CheckParametersForBackend(command.Arguments) + resp.InResponseTo = "checkServerParameters" + + byteData, err := commonbackend.Marshal(resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + default: + log.Warnf("Unsupported command recieved: %T", command) + } + } +} + +func NewHelper(backend BackendInterface) *BackendApplicationHelper { + socketPath, ok := os.LookupEnv("HERMES_API_SOCK") + + if !ok { + log.Warn("HERMES_API_SOCK is not defined! This will cause an issue unless the backend manually overwrites it") + } + + helper := &BackendApplicationHelper{ + Backend: backend, + SocketPath: socketPath, + } + + return helper +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/structure.go b/backend/sshappbackend/remote-code/backendutil_custom/structure.go new file mode 100644 index 0000000..65c5a23 --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/structure.go @@ -0,0 +1,26 @@ +package backendutil_custom + +import ( + "net" + + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" +) + +type BackendInterface interface { + StartBackend(arguments []byte) (bool, error) + StopBackend() (bool, error) + GetBackendStatus() (bool, error) + StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) + StopProxy(command *datacommands.RemoveProxy) (bool, error) + GetAllProxies() []uint16 + ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse + GetAllClientConnections(proxyID uint16) []uint16 + ResolveConnection(proxyID, connectionID uint16) *datacommands.ProxyConnectionInformationResponse + CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse + CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse + OnTCPConnectionClosed(proxyID, connectionID uint16) + HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) + HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) + OnSocketConnection(sock net.Conn) +} diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go new file mode 100644 index 0000000..d56a7a3 --- /dev/null +++ b/backend/sshappbackend/remote-code/main.go @@ -0,0 +1,460 @@ +package main + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "strings" + "sync" + + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/remote-code/backendutil_custom" + "github.com/charmbracelet/log" +) + +type TCPProxy struct { + connectionIDIndex uint16 + connectionIDLock sync.Mutex + + proxyInformation *commonbackend.AddProxy + connections map[uint16]net.Conn + server net.Listener +} + +type UDPProxy struct { + server *net.UDPConn + proxyInformation *commonbackend.AddProxy +} + +type SSHRemoteAppBackend struct { + proxyIDIndex uint16 + proxyIDLock sync.Mutex + + tcpProxies map[uint16]*TCPProxy + udpProxies map[uint16]*UDPProxy + + isRunning bool + + sock net.Conn +} + +func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { + backend.tcpProxies = map[uint16]*TCPProxy{} + backend.udpProxies = map[uint16]*UDPProxy{} + + backend.isRunning = true + + return true, nil +} + +func (backend *SSHRemoteAppBackend) StopBackend() (bool, error) { + for tcpProxyIndex, tcpProxy := range backend.tcpProxies { + for _, tcpConnection := range tcpProxy.connections { + tcpConnection.Close() + } + + tcpProxy.server.Close() + delete(backend.tcpProxies, tcpProxyIndex) + } + + for udpProxyIndex, udpProxy := range backend.udpProxies { + udpProxy.server.Close() + delete(backend.udpProxies, udpProxyIndex) + } + + backend.isRunning = false + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetBackendStatus() (bool, error) { + return backend.isRunning, nil +} + +func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) { + // Allocate a new proxy ID + backend.proxyIDLock.Lock() + proxyID := backend.proxyIDIndex + backend.proxyIDIndex++ + backend.proxyIDLock.Unlock() + + if command.Protocol == "tcp" { + backend.tcpProxies[proxyID] = &TCPProxy{ + connections: map[uint16]net.Conn{}, + proxyInformation: command, + } + + server, err := net.Listen("tcp", fmt.Sprintf(":%d", command.DestPort)) + + if err != nil { + return 0, false, fmt.Errorf("failed to open server: %s", err.Error()) + } + + backend.tcpProxies[proxyID].server = server + + go func() { + for { + conn, err := server.Accept() + + if err != nil { + log.Warnf("failed to accept connection: %s", err.Error()) + return + } + + go func() { + backend.tcpProxies[proxyID].connectionIDLock.Lock() + connectionID := backend.tcpProxies[proxyID].connectionIDIndex + backend.tcpProxies[proxyID].connectionIDIndex++ + backend.tcpProxies[proxyID].connectionIDLock.Unlock() + + backend.tcpProxies[proxyID].connections[connectionID] = conn + + dataBuf := make([]byte, 65535) + + onConnection := &datacommands.TCPConnectionOpened{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + connectionCommandMarshalled, err := datacommands.Marshal(onConnection) + + if err != nil { + log.Errorf("failed to marshal connection message: %s", err.Error()) + } + + backend.sock.Write(connectionCommandMarshalled) + + tcpData := &datacommands.TCPProxyData{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + for { + len, err := conn.Read(dataBuf) + + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } else if err.Error() != "EOF" { + log.Warnf("failed to read from sock: %s", err.Error()) + } + + conn.Close() + break + } + + tcpData.DataLength = uint16(len) + marshalledMessageCommand, err := datacommands.Marshal(tcpData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.sock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.sock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + + conn.Close() + break + } + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.sock.Write(disconnectionCommandMarshalled) + }() + } + }() + } else if command.Protocol == "udp" { + backend.udpProxies[proxyID] = &UDPProxy{ + proxyInformation: command, + } + + server, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(0, 0, 0, 0), + Port: int(command.DestPort), + }) + + if err != nil { + return 0, false, fmt.Errorf("failed to open server: %s", err.Error()) + } + + backend.udpProxies[proxyID].server = server + dataBuf := make([]byte, 65535) + + udpProxyData := &datacommands.UDPProxyData{ + ProxyID: proxyID, + } + + go func() { + for { + len, addr, err := server.ReadFromUDP(dataBuf) + + if err != nil { + log.Warnf("failed to read from UDP socket: %s", err.Error()) + continue + } + + udpProxyData.ClientIP = addr.IP.String() + udpProxyData.ClientPort = uint16(addr.Port) + udpProxyData.DataLength = uint16(len) + + marshalledMessageCommand, err := datacommands.Marshal(udpProxyData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + continue + } + + if _, err := backend.sock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + continue + } + + if _, err := backend.sock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + continue + } + } + }() + } + + return proxyID, true, nil +} + +func (backend *SSHRemoteAppBackend) StopProxy(command *datacommands.RemoveProxy) (bool, error) { + tcpProxy, ok := backend.tcpProxies[command.ProxyID] + + if !ok { + udpProxy, ok := backend.udpProxies[command.ProxyID] + + if !ok { + return ok, fmt.Errorf("could not find proxy") + } + + udpProxy.server.Close() + delete(backend.udpProxies, command.ProxyID) + } else { + for _, tcpConnection := range tcpProxy.connections { + tcpConnection.Close() + } + + tcpProxy.server.Close() + delete(backend.tcpProxies, command.ProxyID) + } + + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetAllProxies() []uint16 { + proxyList := make([]uint16, len(backend.tcpProxies)+len(backend.udpProxies)) + + currentPos := 0 + + for tcpProxy := range backend.tcpProxies { + proxyList[currentPos] = tcpProxy + currentPos += 1 + } + + for udpProxy := range backend.udpProxies { + proxyList[currentPos] = udpProxy + currentPos += 1 + } + + return proxyList +} + +func (backend *SSHRemoteAppBackend) ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse { + var proxyInformation *commonbackend.AddProxy + response := &datacommands.ProxyInformationResponse{} + + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + udpProxy, ok := backend.udpProxies[proxyID] + + if !ok { + response.Exists = false + return response + } + + proxyInformation = udpProxy.proxyInformation + } else { + proxyInformation = tcpProxy.proxyInformation + } + + response.Exists = true + response.SourceIP = proxyInformation.SourceIP + response.SourcePort = proxyInformation.SourcePort + response.DestPort = proxyInformation.DestPort + response.Protocol = proxyInformation.Protocol + + return response +} + +func (backend *SSHRemoteAppBackend) GetAllClientConnections(proxyID uint16) []uint16 { + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + return []uint16{} + } + + connectionsArray := make([]uint16, len(tcpProxy.connections)) + currentPos := 0 + + for connectionIndex := range tcpProxy.connections { + connectionsArray[currentPos] = connectionIndex + currentPos++ + } + + return connectionsArray +} + +func (backend *SSHRemoteAppBackend) ResolveConnection(proxyID, connectionID uint16) *datacommands.ProxyConnectionInformationResponse { + response := &datacommands.ProxyConnectionInformationResponse{} + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + response.Exists = false + return response + } + + connection, ok := tcpProxy.connections[connectionID] + + if !ok { + response.Exists = false + return response + } + + addr := connection.RemoteAddr().String() + ip := addr[:strings.LastIndex(addr, ":")] + port, err := strconv.Atoi(addr[strings.LastIndex(addr, ":")+1:]) + + if err != nil { + log.Warnf("failed to parse client port: %s", err.Error()) + response.Exists = false + + return response + } + + response.ClientIP = ip + response.ClientPort = uint16(port) + + return response +} + +func (backend *SSHRemoteAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} + +func (backend *SSHRemoteAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} + +func (backend *SSHRemoteAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + tcpProxy, ok := backend.tcpProxies[message.ProxyID] + + if !ok { + log.Warnf("could not find tcp proxy (ID %d)", message.ProxyID) + return + } + + connection, ok := tcpProxy.connections[message.ConnectionID] + + if !ok { + log.Warnf("could not find tcp proxy (ID %d) with connection ID (%d)", message.ProxyID, message.ConnectionID) + return + } + + connection.Write(data) +} + +func (backend *SSHRemoteAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + udpProxy, ok := backend.udpProxies[message.ProxyID] + + if !ok { + return + } + + udpProxy.server.WriteToUDP(data, &net.UDPAddr{ + IP: net.ParseIP(message.ClientIP), + Port: int(message.ClientPort), + }) +} + +func (backend *SSHRemoteAppBackend) OnTCPConnectionClosed(proxyID, connectionID uint16) { + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + return + } + + connection, ok := tcpProxy.connections[connectionID] + + if !ok { + return + } + + connection.Close() + delete(tcpProxy.connections, connectionID) +} + +func (backend *SSHRemoteAppBackend) OnSocketConnection(sock net.Conn) { + backend.sock = sock +} + +func main() { + logLevel := os.Getenv("HERMES_LOG_LEVEL") + + if logLevel != "" { + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + + case "info": + log.SetLevel(log.InfoLevel) + + case "warn": + log.SetLevel(log.WarnLevel) + + case "error": + log.SetLevel(log.ErrorLevel) + + case "fatal": + log.SetLevel(log.FatalLevel) + } + } + + backend := &SSHRemoteAppBackend{} + + application := backendutil_custom.NewHelper(backend) + err := application.Start() + + if err != nil { + log.Fatalf("failed execution in application: %s", err.Error()) + } +} diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index baf1994..d46b330 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os" + "slices" "strconv" "strings" "sync" @@ -18,6 +19,34 @@ import ( "golang.org/x/crypto/ssh" ) +var validatorInstance *validator.Validate + +type ConnWithTimeout struct { + net.Conn + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (c *ConnWithTimeout) Read(b []byte) (int, error) { + err := c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Read(b) +} + +func (c *ConnWithTimeout) Write(b []byte) (int, error) { + err := c.Conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Write(b) +} + type SSHListener struct { SourceIP string SourcePort uint16 @@ -26,31 +55,46 @@ type SSHListener struct { Listeners []net.Listener } +type SSHBackendData struct { + IP string `json:"ip" validate:"required"` + Port uint16 `json:"port" validate:"required"` + Username string `json:"username" validate:"required"` + PrivateKey string `json:"privateKey" validate:"required"` + DisablePIDCheck bool `json:"disablePIDCheck"` + ListenOnIPs []string `json:"listenOnIPs"` +} + type SSHBackend struct { config *SSHBackendData conn *ssh.Client clients []*commonbackend.ProxyClientConnection proxies []*SSHListener arrayPropMutex sync.Mutex -} - -type SSHBackendData struct { - IP string `json:"ip" validate:"required"` - Port uint16 `json:"port" validate:"required"` - Username string `json:"username" validate:"required"` - PrivateKey string `json:"privateKey" validate:"required"` - ListenOnIPs []string `json:"listenOnIPs"` + pid int + isReady bool + inReinitLoop bool } func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { log.Info("SSHBackend is initializing...") + + if validatorInstance == nil { + validatorInstance = validator.New() + } + + if backend.inReinitLoop { + for !backend.isReady { + time.Sleep(100 * time.Millisecond) + } + } + var backendData SSHBackendData if err := json.Unmarshal(bytes, &backendData); err != nil { return false, err } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return false, err } @@ -76,16 +120,70 @@ func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backendData.IP, backendData.Port), config) + addr := fmt.Sprintf("%s:%d", backendData.IP, backendData.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return false, err } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + return false, err + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + if !backendData.DisablePIDCheck { + if backend.pid != 0 { + session, err := client.NewSession() + + if err != nil { + return false, err + } + + err = session.Run(fmt.Sprintf("kill -9 %d", backend.pid)) + + if err != nil { + log.Warnf("Failed to kill process: %s", err.Error()) + } + } + + session, err := client.NewSession() + + if err != nil { + return false, err + } + + // Get the parent PID of the shell so we can kill it if we disconnect + output, err := session.Output("ps --no-headers -fp $$ | awk '{print $3}'") + + if err != nil { + return false, err + } + + // Strip the new line and convert to int + backend.pid, err = strconv.Atoi(string(output)[:len(output)-1]) + + if err != nil { + return false, err + } + } + + go backend.backendDisconnectHandler() + go backend.backendKeepaliveHandler() log.Info("SSHBackend has initialized successfully.") - go backend.backendDisconnectHandler() return true, nil } @@ -203,8 +301,7 @@ func (backend *SSHBackend) StartProxy(command *commonbackend.AddProxy) (bool, er // Splice out the clientInstance by clientIndex // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.clients = append(backend.clients[:clientIndex], backend.clients[clientIndex+1:]...) + backend.clients = slices.Delete(backend.clients, clientIndex, clientIndex+1) return } } @@ -288,10 +385,8 @@ func (backend *SSHBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, } // Splice out the proxy instance by proxyIndex - // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.proxies = append(backend.proxies[:proxyIndex], backend.proxies[proxyIndex+1:]...) + backend.proxies = slices.Delete(backend.proxies, proxyIndex, proxyIndex+1) return true, nil } } @@ -322,6 +417,10 @@ func (backend *SSHBackend) CheckParametersForConnections(clientParameters *commo func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { var backendData SSHBackendData + if validatorInstance == nil { + validatorInstance = validator.New() + } + if err := json.Unmarshal(arguments, &backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, @@ -329,7 +428,7 @@ func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonba } } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, Message: fmt.Sprintf("failed validation of parameters: %s", err.Error()), @@ -341,17 +440,34 @@ func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonba } } -func (backend *SSHBackend) backendDisconnectHandler() { +func (backend *SSHBackend) backendKeepaliveHandler() { for { if backend.conn != nil { - err := backend.conn.Wait() + _, _, err := backend.conn.SendRequest("keepalive@openssh.com", true, nil) - if err == nil || err.Error() != "EOF" { - continue + if err != nil { + log.Warn("Keepalive message failed!") + return } } - log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + time.Sleep(5 * time.Second) + } +} + +func (backend *SSHBackend) backendDisconnectHandler() { + for { + if backend.conn != nil { + backend.conn.Wait() + backend.conn.Close() + + backend.isReady = false + backend.inReinitLoop = true + + log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + } else { + log.Info("Retrying reconnection in 5 seconds...") + } time.Sleep(5 * time.Second) @@ -376,14 +492,74 @@ func (backend *SSHBackend) backendDisconnectHandler() { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port), config) + addr := fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { - log.Errorf("Failed to connect to the server: %s", err.Error()) - return + log.Errorf("Failed to establish connection to the server: %s", err.Error()) + continue } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + log.Errorf("Failed to create SSH client connection: %s", err.Error()) + rawTCPConn.Close() + continue + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + if !backend.config.DisablePIDCheck { + if backend.pid != 0 { + session, err := client.NewSession() + + if err != nil { + log.Warnf("Failed to create SSH command session: %s", err.Error()) + return + } + + err = session.Run(fmt.Sprintf("kill -9 %d", backend.pid)) + + if err != nil { + log.Warnf("Failed to kill process: %s", err.Error()) + } + } + + session, err := client.NewSession() + + if err != nil { + log.Warnf("Failed to create SSH command session: %s", err.Error()) + return + } + + // Get the parent PID of the shell so we can kill it if we disconnect + output, err := session.Output("ps --no-headers -fp $$ | awk '{print $3}'") + + if err != nil { + log.Warnf("Failed to execute command to fetch PID: %s", err.Error()) + return + } + + // Strip the new line and convert to int + backend.pid, err = strconv.Atoi(string(output)[:len(output)-1]) + + if err != nil { + log.Warnf("Failed to parse PID: %s", err.Error()) + return + } + } + + go backend.backendKeepaliveHandler() log.Info("SSHBackend has reconnected successfully. Attempting to set up proxies again...") diff --git a/docker-compose.yml b/docker-compose.yml index 7ac3334..a549035 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,27 +6,12 @@ services: environment: DATABASE_URL: postgresql://${POSTGRES_USERNAME}:${POSTGRES_PASSWORD}@nextnet-postgres:5432/${POSTGRES_DB}?schema=nextnet HERMES_POSTGRES_DSN: postgres://${POSTGRES_USERNAME}:${POSTGRES_PASSWORD}@nextnet-postgres:5432/${POSTGRES_DB} + HERMES_JWT_SECRET: ${JWT_SECRET} HERMES_DATABASE_BACKEND: postgresql depends_on: - db ports: - 3000:3000 - - # WARN: The LOM is deprecated and likely broken currently. - # - # NOTE: For this to work correctly, the nextnet-api must be version > 0.1.1 - # or have a version with backported username support, incl. logins - lom: - image: ghcr.io/imterah/hermes-lom:latest - container_name: hermes-lom - restart: always - ports: - - 2222:2222 - depends_on: - - api - volumes: - - ssh_key_data:/app/keys - db: image: postgres:17.2 container_name: nextnet-postgres @@ -37,7 +22,6 @@ services: POSTGRES_USER: ${POSTGRES_USERNAME} volumes: - postgres_data:/var/lib/postgresql/data - volumes: postgres_data: ssh_key_data: diff --git a/go.mod b/go.mod index 0942d36..b390542 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/go-playground/validator/v10 v10.23.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/pkg/sftp v1.13.7 github.com/urfave/cli/v2 v2.27.5 golang.org/x/crypto v0.31.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/term v0.28.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.11 @@ -38,6 +40,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect @@ -56,7 +59,6 @@ require ( github.com/ugorji/go/codec v1.2.12 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect golang.org/x/arch v0.12.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.29.0 // indirect diff --git a/go.sum b/go.sum index da8d74e..dd30942 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -86,6 +88,8 @@ github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pkg/sftp v1.13.7 h1:uv+I3nNJvlKZIQGSr8JVQLNHFU9YhhNpvC14Y6KgmSM= +github.com/pkg/sftp v1.13.7/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -114,23 +118,61 @@ github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w= github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.0 h1:mjIs9gYtt56AzC4ZaffQuh88TZurBGhIJMBZGSxNerQ= google.golang.org/protobuf v1.36.0/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/prod-docker.env b/prod-docker.env index fe5da79..954c20f 100644 --- a/prod-docker.env +++ b/prod-docker.env @@ -1,5 +1,4 @@ -# These are default values, please change these! - POSTGRES_USERNAME=hermes POSTGRES_PASSWORD=hermes POSTGRES_DB=hermes +JWT_SECRET=hermes diff --git a/shell.nix b/shell.nix index 17c5989..ea1aa8e 100644 --- a/shell.nix +++ b/shell.nix @@ -3,7 +3,6 @@ }: pkgs.mkShell { buildInputs = with pkgs; [ # api/ - nodejs go gopls ];