From 0efda4b2833cecdd39bb1cae7db75930216ad677 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 10 Jan 2025 16:23:26 -0500 Subject: [PATCH 01/24] feature: Add profiling documentation for backends based on BackendUtil. --- backend/backendutil/application.go | 7 +- backend/backendutil/profiling_disabled.go | 9 +++ backend/backendutil/profiling_enabled.go | 91 +++++++++++++++++++++++ docs/profiling.md | 6 ++ 4 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 backend/backendutil/profiling_disabled.go create mode 100644 backend/backendutil/profiling_enabled.go create mode 100644 docs/profiling.md diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index ccb21f2..802e8c7 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -18,9 +18,14 @@ type BackendApplicationHelper struct { func (helper *BackendApplicationHelper) Start() error { log.Debug("BackendApplicationHelper is starting") + err := configureAndLaunchBackgroundProfilingTasks() + + if err != nil { + return err + } + log.Debug("Currently waiting for Unix socket connection...") - var err error helper.socket, err = net.Dial("unix", helper.SocketPath) if err != nil { diff --git a/backend/backendutil/profiling_disabled.go b/backend/backendutil/profiling_disabled.go new file mode 100644 index 0000000..d93cbdd --- /dev/null +++ b/backend/backendutil/profiling_disabled.go @@ -0,0 +1,9 @@ +//go:build !debug + +package backendutil + +var endProfileFunc func() + +func configureAndLaunchBackgroundProfilingTasks() error { + return nil +} diff --git a/backend/backendutil/profiling_enabled.go b/backend/backendutil/profiling_enabled.go new file mode 100644 index 0000000..a3be527 --- /dev/null +++ b/backend/backendutil/profiling_enabled.go @@ -0,0 +1,91 @@ +//go:build debug + +package backendutil + +import ( + "errors" + "fmt" + "os" + "os/signal" + "runtime/pprof" + "syscall" + "time" + + "github.com/charmbracelet/log" + "golang.org/x/exp/rand" +) + +func configureAndLaunchBackgroundProfilingTasks() error { + profilingMode, err := os.ReadFile("/tmp/hermes.backendlauncher.profilebackends") + + if err != nil && errors.Is(err, os.ErrNotExist) { + return nil + } + + switch string(profilingMode) { + case "cpu": + log.Debug("Starting CPU profiling as a background task") + go doCPUProfiling() + case "mem": + log.Debug("Starting memory profiling as a background task") + go doMemoryProfiling() + default: + log.Warnf("Unknown profiling mode: %s", string(profilingMode)) + return nil + } + + return nil +} + +func doCPUProfiling() { + // (imterah) WTF? why isn't this being seeded on its own? according to Go docs, this should be seeded automatically... + rand.Seed(uint64(time.Now().UnixNano())) + + profileFileName := fmt.Sprintf("/tmp/hermes.backendlauncher.cpu.prof.%d", rand.Int()) + profileFile, err := os.Create(profileFileName) + + if err != nil { + log.Fatalf("Failed to create CPU profiling file: %s", err.Error()) + } + + log.Debugf("Writing CPU usage profile to '%s'. Will capture when Ctrl+C/SIGTERM is recieved.", profileFileName) + pprof.StartCPUProfile(profileFile) + + exitNotification := make(chan os.Signal, 1) + signal.Notify(exitNotification, os.Interrupt, syscall.SIGTERM) + <-exitNotification + + log.Debug("Recieved SIGTERM. Cleaning up and exiting...") + + pprof.StopCPUProfile() + profileFile.Close() + + log.Debug("Exiting...") + os.Exit(0) +} + +func doMemoryProfiling() { + // (imterah) WTF? why isn't this being seeded on its own? according to Go docs, this should be seeded automatically... + rand.Seed(uint64(time.Now().UnixNano())) + + profileFileName := fmt.Sprintf("/tmp/hermes.backendlauncher.mem.prof.%d", rand.Int()) + profileFile, err := os.Create(profileFileName) + + if err != nil { + log.Fatalf("Failed to create memory profiling file: %s", err.Error()) + } + + log.Debugf("Writing memory profile to '%s'. Will capture when Ctrl+C/SIGTERM is recieved.", profileFileName) + + exitNotification := make(chan os.Signal, 1) + signal.Notify(exitNotification, os.Interrupt, syscall.SIGTERM) + <-exitNotification + + log.Debug("Recieved SIGTERM. Cleaning up and exiting...") + + pprof.WriteHeapProfile(profileFile) + profileFile.Close() + + log.Debug("Exiting...") + os.Exit(0) +} diff --git a/docs/profiling.md b/docs/profiling.md new file mode 100644 index 0000000..06ceb7d --- /dev/null +++ b/docs/profiling.md @@ -0,0 +1,6 @@ +# Profiling +To profile any backend code based on `backendutil`, follow these steps: +1. Rebuild the backend with the `debug` flag: `cd $BACKEND_HERE; GOOS=linux go build -tags debug .; cd ..` +2. Copy the binary to the target machine (if applicable), and stop the API server. +3. If you want to profile the CPU utilization, write `cpu` to the file `/tmp/hermes.backendlauncher.profilebackends`: `echo -n "cpu" > /tmp/hermes.backendlauncher.profilebackends`. Else, replace `cpu` with `mem`. +4. Start the API server, with development mode and debug logging enabled. From 48adfc88db93a0eb6e3827d4fb86d90b179c047f Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 10 Jan 2025 16:37:38 -0500 Subject: [PATCH 02/24] fix: Fixes performance regression introduced in 4cb648cd66 / v2.1.0. (closes #7) --- backend/sshbackend/main.go | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index f63b8c3..baf1994 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -221,13 +221,19 @@ func (backend *SSHBackend) StartProxy(command *commonbackend.AddProxy) (bool, er for { len, err := forwardedConn.Read(forwardedBuffer) - if err != nil && err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { - log.Errorf("failed to read from forwarded connection: %s", err.Error()) + if err != nil { + if err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { + log.Errorf("failed to read from forwarded connection: %s", err.Error()) + } + return } - if _, err = sourceConn.Write(forwardedBuffer[:len]); err != nil && err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { - log.Errorf("failed to write to source connection: %s", err.Error()) + if _, err = sourceConn.Write(forwardedBuffer[:len]); err != nil { + if err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { + log.Errorf("failed to write to source connection: %s", err.Error()) + } + return } } @@ -239,13 +245,19 @@ func (backend *SSHBackend) StartProxy(command *commonbackend.AddProxy) (bool, er for { len, err := sourceConn.Read(sourceBuffer) - if err != nil && err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { - log.Errorf("failed to read from source connection: %s", err.Error()) + if err != nil { + if err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { + log.Errorf("failed to read from source connection: %s", err.Error()) + } + return } - if _, err = forwardedConn.Write(sourceBuffer[:len]); err != nil && err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { - log.Errorf("failed to write to forwarded connection: %s", err.Error()) + if _, err = forwardedConn.Write(sourceBuffer[:len]); err != nil { + if err.Error() != "EOF" && !errors.Is(err, net.ErrClosed) { + log.Errorf("failed to write to forwarded connection: %s", err.Error()) + } + return } } From 737ba2887f1a9f58eaee98f67d7c75e9f22fc7c8 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 10 Jan 2025 20:34:50 -0500 Subject: [PATCH 03/24] chore: Delete old commonbackend code. --- backend/commonbackend/constants.go | 147 ---- backend/commonbackend/marshal.go | 507 ------------- backend/commonbackend/marshalling_test.go | 824 ---------------------- backend/commonbackend/unmarshal.go | 665 ----------------- 4 files changed, 2143 deletions(-) delete mode 100644 backend/commonbackend/constants.go delete mode 100644 backend/commonbackend/marshal.go delete mode 100644 backend/commonbackend/marshalling_test.go delete mode 100644 backend/commonbackend/unmarshal.go diff --git a/backend/commonbackend/constants.go b/backend/commonbackend/constants.go deleted file mode 100644 index cdb68f2..0000000 --- a/backend/commonbackend/constants.go +++ /dev/null @@ -1,147 +0,0 @@ -package commonbackend - -type Start struct { - Type string // Will be 'start' always - Arguments []byte -} - -type Stop struct { - Type string // Will be 'stop' always -} - -type AddProxy struct { - Type string // Will be 'addProxy' always - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' -} - -type RemoveProxy struct { - Type string // Will be 'removeProxy' always - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' -} - -type ProxyStatusRequest struct { - Type string // Will be 'proxyStatusRequest' always - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' -} - -type ProxyStatusResponse struct { - Type string // Will be 'proxyStatusResponse' always - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' - IsActive bool -} - -type ProxyInstance struct { - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' -} - -type ProxyInstanceResponse struct { - Type string // Will be 'proxyConnectionResponse' always - Proxies []*ProxyInstance // List of connections -} - -type ProxyInstanceRequest struct { - Type string // Will be 'proxyConnectionRequest' always -} - -type BackendStatusResponse struct { - Type string // Will be 'backendStatusResponse' always - IsRunning bool // True if running, false if not running - StatusCode int // Either the 'Success' or 'Failure' constant - Message string // String message from the client (ex. failed to dial TCP) -} - -type BackendStatusRequest struct { - Type string // Will be 'backendStatusRequest' always -} - -type ProxyConnectionsRequest struct { - Type string // Will be 'proxyConnectionsRequest' always -} - -// Client's connection to a specific proxy -type ProxyClientConnection struct { - SourceIP string - SourcePort uint16 - DestPort uint16 - ClientIP string - ClientPort uint16 -} - -type ProxyConnectionsResponse struct { - Type string // Will be 'proxyConnectionsResponse' always - Connections []*ProxyClientConnection // List of connections -} - -type CheckClientParameters struct { - Type string // Will be 'checkClientParameters' always - SourceIP string - SourcePort uint16 - DestPort uint16 - Protocol string // Will be either 'tcp' or 'udp' -} - -type CheckServerParameters struct { - Type string // Will be 'checkServerParameters' always - Arguments []byte -} - -// Sent as a response to either CheckClientParameters or CheckBackendParameters -type CheckParametersResponse struct { - Type string // Will be 'checkParametersResponse' always - InResponseTo string // Will be either 'checkClientParameters' or 'checkServerParameters' - IsValid bool // If true, valid, and if false, invalid - Message string // String message from the client (ex. failed to unmarshal JSON: x is not defined) -} - -const ( - StartID = iota - StopID - AddProxyID - RemoveProxyID - ProxyConnectionsResponseID - CheckClientParametersID - CheckServerParametersID - CheckParametersResponseID - ProxyConnectionsRequestID - BackendStatusResponseID - BackendStatusRequestID - ProxyStatusRequestID - ProxyStatusResponseID - ProxyInstanceResponseID - ProxyInstanceRequestID -) - -const ( - TCP = iota - UDP -) - -const ( - StatusSuccess = iota - StatusFailure -) - -const ( - // IP versions - IPv4 = 4 - IPv6 = 6 - - // TODO: net has these constants defined already. We should switch to these - IPv4Size = 4 - IPv6Size = 16 -) diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go deleted file mode 100644 index 6baf02e..0000000 --- a/backend/commonbackend/marshal.go +++ /dev/null @@ -1,507 +0,0 @@ -package commonbackend - -import ( - "encoding/binary" - "fmt" - "net" -) - -func marshalIndividualConnectionStruct(conn *ProxyClientConnection) []byte { - sourceIPOriginal := net.ParseIP(conn.SourceIP) - clientIPOriginal := net.ParseIP(conn.ClientIP) - - var serverIPVer uint8 - var sourceIP []byte - - if sourceIPOriginal.To4() == nil { - serverIPVer = IPv6 - sourceIP = sourceIPOriginal.To16() - } else { - serverIPVer = IPv4 - sourceIP = sourceIPOriginal.To4() - } - - var clientIPVer uint8 - var clientIP []byte - - if clientIPOriginal.To4() == nil { - clientIPVer = IPv6 - clientIP = clientIPOriginal.To16() - } else { - clientIPVer = IPv4 - clientIP = clientIPOriginal.To4() - } - - connectionBlock := make([]byte, 8+len(sourceIP)+len(clientIP)) - - connectionBlock[0] = serverIPVer - copy(connectionBlock[1:len(sourceIP)+1], sourceIP) - - binary.BigEndian.PutUint16(connectionBlock[1+len(sourceIP):3+len(sourceIP)], conn.SourcePort) - binary.BigEndian.PutUint16(connectionBlock[3+len(sourceIP):5+len(sourceIP)], conn.DestPort) - - connectionBlock[5+len(sourceIP)] = clientIPVer - copy(connectionBlock[6+len(sourceIP):6+len(sourceIP)+len(clientIP)], clientIP) - binary.BigEndian.PutUint16(connectionBlock[6+len(sourceIP)+len(clientIP):8+len(sourceIP)+len(clientIP)], conn.ClientPort) - - return connectionBlock -} - -func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { - sourceIPOriginal := net.ParseIP(conn.SourceIP) - - var sourceIPVer uint8 - var sourceIP []byte - - if sourceIPOriginal.To4() == nil { - sourceIPVer = IPv6 - sourceIP = sourceIPOriginal.To16() - } else { - sourceIPVer = IPv4 - sourceIP = sourceIPOriginal.To4() - } - - proxyBlock := make([]byte, 6+len(sourceIP)) - - proxyBlock[0] = sourceIPVer - copy(proxyBlock[1:len(sourceIP)+1], sourceIP) - - binary.BigEndian.PutUint16(proxyBlock[1+len(sourceIP):3+len(sourceIP)], conn.SourcePort) - binary.BigEndian.PutUint16(proxyBlock[3+len(sourceIP):5+len(sourceIP)], conn.DestPort) - - var protocolVersion uint8 - - if conn.Protocol == "tcp" { - protocolVersion = TCP - } else if conn.Protocol == "udp" { - protocolVersion = UDP - } else { - return proxyBlock, fmt.Errorf("invalid protocol recieved") - } - - proxyBlock[5+len(sourceIP)] = protocolVersion - - return proxyBlock, nil -} - -func Marshal(commandType string, command interface{}) ([]byte, error) { - switch commandType { - case "start": - startCommand, ok := command.(*Start) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - startCommandBytes := make([]byte, 1+2+len(startCommand.Arguments)) - startCommandBytes[0] = StartID - binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(startCommand.Arguments))) - copy(startCommandBytes[3:], startCommand.Arguments) - - return startCommandBytes, nil - case "stop": - _, ok := command.(*Stop) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - return []byte{StopID}, nil - case "addProxy": - addConnectionCommand, ok := command.(*AddProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(addConnectionCommand.SourceIP) - - var ipVer uint8 - var ipBytes []byte - - if sourceIP.To4() == nil { - ipBytes = sourceIP.To16() - ipVer = IPv6 - } else { - ipBytes = sourceIP.To4() - ipVer = IPv4 - } - - addConnectionBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - - addConnectionBytes[0] = AddProxyID - addConnectionBytes[1] = ipVer - - copy(addConnectionBytes[2:2+len(ipBytes)], ipBytes) - - binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], addConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], addConnectionCommand.DestPort) - - var protocol uint8 - - if addConnectionCommand.Protocol == "tcp" { - protocol = TCP - } else if addConnectionCommand.Protocol == "udp" { - protocol = UDP - } else { - return nil, fmt.Errorf("invalid protocol") - } - - addConnectionBytes[6+len(ipBytes)] = protocol - - return addConnectionBytes, nil - case "removeProxy": - removeConnectionCommand, ok := command.(*RemoveProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(removeConnectionCommand.SourceIP) - - var ipVer uint8 - var ipBytes []byte - - if sourceIP.To4() == nil { - ipBytes = sourceIP.To16() - ipVer = IPv6 - } else { - ipBytes = sourceIP.To4() - ipVer = IPv4 - } - - removeConnectionBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - - removeConnectionBytes[0] = RemoveProxyID - removeConnectionBytes[1] = ipVer - copy(removeConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], removeConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], removeConnectionCommand.DestPort) - - var protocol uint8 - - if removeConnectionCommand.Protocol == "tcp" { - protocol = TCP - } else if removeConnectionCommand.Protocol == "udp" { - protocol = UDP - } else { - return nil, fmt.Errorf("invalid protocol") - } - - removeConnectionBytes[6+len(ipBytes)] = protocol - - return removeConnectionBytes, nil - case "proxyConnectionsResponse": - allConnectionsCommand, ok := command.(*ProxyConnectionsResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - connectionsArray := make([][]byte, len(allConnectionsCommand.Connections)) - totalSize := 0 - - for connIndex, conn := range allConnectionsCommand.Connections { - connectionsArray[connIndex] = marshalIndividualConnectionStruct(conn) - totalSize += len(connectionsArray[connIndex]) + 1 - } - - if totalSize == 0 { - totalSize = 1 - } - - connectionCommandArray := make([]byte, totalSize+1) - connectionCommandArray[0] = ProxyConnectionsResponseID - - currentPosition := 1 - - for _, connection := range connectionsArray { - copy(connectionCommandArray[currentPosition:currentPosition+len(connection)], connection) - connectionCommandArray[currentPosition+len(connection)] = '\r' - currentPosition += len(connection) + 1 - } - - connectionCommandArray[totalSize] = '\n' - return connectionCommandArray, nil - case "checkClientParameters": - checkClientCommand, ok := command.(*CheckClientParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(checkClientCommand.SourceIP) - - var ipVer uint8 - var ipBytes []byte - - if sourceIP.To4() == nil { - ipBytes = sourceIP.To16() - ipVer = IPv6 - } else { - ipBytes = sourceIP.To4() - ipVer = IPv4 - } - - checkClientBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - - checkClientBytes[0] = CheckClientParametersID - checkClientBytes[1] = ipVer - copy(checkClientBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], checkClientCommand.SourcePort) - binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], checkClientCommand.DestPort) - - var protocol uint8 - - if checkClientCommand.Protocol == "tcp" { - protocol = TCP - } else if checkClientCommand.Protocol == "udp" { - protocol = UDP - } else { - return nil, fmt.Errorf("invalid protocol") - } - - checkClientBytes[6+len(ipBytes)] = protocol - - return checkClientBytes, nil - case "checkServerParameters": - checkServerCommand, ok := command.(*CheckServerParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - serverCommandBytes := make([]byte, 1+2+len(checkServerCommand.Arguments)) - serverCommandBytes[0] = CheckServerParametersID - binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(checkServerCommand.Arguments))) - copy(serverCommandBytes[3:], checkServerCommand.Arguments) - - return serverCommandBytes, nil - case "checkParametersResponse": - checkParametersCommand, ok := command.(*CheckParametersResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - var checkMethod uint8 - - if checkParametersCommand.InResponseTo == "checkClientParameters" { - checkMethod = CheckClientParametersID - } else if checkParametersCommand.InResponseTo == "checkServerParameters" { - checkMethod = CheckServerParametersID - } else { - return nil, fmt.Errorf("invalid mode recieved (must be either checkClientParameters or checkServerParameters)") - } - - var isValid uint8 - - if checkParametersCommand.IsValid { - isValid = 1 - } - - checkResponseBytes := make([]byte, 3+2+len(checkParametersCommand.Message)) - checkResponseBytes[0] = CheckParametersResponseID - checkResponseBytes[1] = checkMethod - checkResponseBytes[2] = isValid - - binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(checkParametersCommand.Message))) - - if len(checkParametersCommand.Message) != 0 { - copy(checkResponseBytes[5:], []byte(checkParametersCommand.Message)) - } - - return checkResponseBytes, nil - case "backendStatusResponse": - backendStatusResponse, ok := command.(*BackendStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - var isRunning uint8 - - if backendStatusResponse.IsRunning { - isRunning = 1 - } else { - isRunning = 0 - } - - statusResponseBytes := make([]byte, 3+2+len(backendStatusResponse.Message)) - statusResponseBytes[0] = BackendStatusResponseID - statusResponseBytes[1] = isRunning - statusResponseBytes[2] = byte(backendStatusResponse.StatusCode) - - binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(backendStatusResponse.Message))) - - if len(backendStatusResponse.Message) != 0 { - copy(statusResponseBytes[5:], []byte(backendStatusResponse.Message)) - } - - return statusResponseBytes, nil - case "backendStatusRequest": - _, ok := command.(*BackendStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - statusRequestBytes := make([]byte, 1) - statusRequestBytes[0] = BackendStatusRequestID - - return statusRequestBytes, nil - case "proxyStatusRequest": - proxyStatusRequest, ok := command.(*ProxyStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusRequest.SourceIP) - - var ipVer uint8 - var ipBytes []byte - - if sourceIP.To4() == nil { - ipBytes = sourceIP.To16() - ipVer = IPv6 - } else { - ipBytes = sourceIP.To4() - ipVer = IPv4 - } - - proxyStatusRequestBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - - proxyStatusRequestBytes[0] = ProxyStatusRequestID - proxyStatusRequestBytes[1] = ipVer - - copy(proxyStatusRequestBytes[2:2+len(ipBytes)], ipBytes) - - binary.BigEndian.PutUint16(proxyStatusRequestBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusRequest.SourcePort) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusRequest.DestPort) - - var protocol uint8 - - if proxyStatusRequest.Protocol == "tcp" { - protocol = TCP - } else if proxyStatusRequest.Protocol == "udp" { - protocol = UDP - } else { - return nil, fmt.Errorf("invalid protocol") - } - - proxyStatusRequestBytes[6+len(ipBytes)] = protocol - - return proxyStatusRequestBytes, nil - case "proxyStatusResponse": - proxyStatusResponse, ok := command.(*ProxyStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusResponse.SourceIP) - - var ipVer uint8 - var ipBytes []byte - - if sourceIP.To4() == nil { - ipBytes = sourceIP.To16() - ipVer = IPv6 - } else { - ipBytes = sourceIP.To4() - ipVer = IPv4 - } - - proxyStatusResponseBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) - - proxyStatusResponseBytes[0] = ProxyStatusResponseID - proxyStatusResponseBytes[1] = ipVer - - copy(proxyStatusResponseBytes[2:2+len(ipBytes)], ipBytes) - - binary.BigEndian.PutUint16(proxyStatusResponseBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusResponse.SourcePort) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusResponse.DestPort) - - var protocol uint8 - - if proxyStatusResponse.Protocol == "tcp" { - protocol = TCP - } else if proxyStatusResponse.Protocol == "udp" { - protocol = UDP - } else { - return nil, fmt.Errorf("invalid protocol") - } - - proxyStatusResponseBytes[6+len(ipBytes)] = protocol - - var isActive uint8 - - if proxyStatusResponse.IsActive { - isActive = 1 - } else { - isActive = 0 - } - - proxyStatusResponseBytes[7+len(ipBytes)] = isActive - - return proxyStatusResponseBytes, nil - case "proxyInstanceResponse": - proxyConectionResponse, ok := command.(*ProxyInstanceResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - proxyArray := make([][]byte, len(proxyConectionResponse.Proxies)) - totalSize := 0 - - for proxyIndex, proxy := range proxyConectionResponse.Proxies { - var err error - proxyArray[proxyIndex], err = marshalIndividualProxyStruct(proxy) - - if err != nil { - return nil, err - } - - totalSize += len(proxyArray[proxyIndex]) + 1 - } - - if totalSize == 0 { - totalSize = 1 - } - - connectionCommandArray := make([]byte, totalSize+1) - connectionCommandArray[0] = ProxyInstanceResponseID - - currentPosition := 1 - - for _, connection := range proxyArray { - copy(connectionCommandArray[currentPosition:currentPosition+len(connection)], connection) - connectionCommandArray[currentPosition+len(connection)] = '\r' - currentPosition += len(connection) + 1 - } - - connectionCommandArray[totalSize] = '\n' - - return connectionCommandArray, nil - case "proxyInstanceRequest": - _, ok := command.(*ProxyInstanceRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - return []byte{ProxyInstanceRequestID}, nil - case "proxyConnectionsRequest": - _, ok := command.(*ProxyConnectionsRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - return []byte{ProxyConnectionsRequestID}, nil - } - - return nil, fmt.Errorf("couldn't match command name") -} diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go deleted file mode 100644 index 1c93f94..0000000 --- a/backend/commonbackend/marshalling_test.go +++ /dev/null @@ -1,824 +0,0 @@ -package commonbackend - -import ( - "bytes" - "log" - "os" - "testing" -) - -var logLevel = os.Getenv("HERMES_LOG_LEVEL") - -func TestStartCommandMarshalSupport(t *testing.T) { - commandInput := &Start{ - Type: "start", - Arguments: []byte("Hello from automated testing"), - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Start) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { - log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) - } -} - -func TestStopCommandMarshalSupport(t *testing.T) { - commandInput := &Stop{ - Type: "stop", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Stop) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } -} - -func TestAddConnectionCommandMarshalSupport(t *testing.T) { - commandInput := &AddProxy{ - Type: "addProxy", - SourceIP: "192.168.0.139", - SourcePort: 19132, - DestPort: 19132, - Protocol: "tcp", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*AddProxy) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.SourceIP != commandUnmarshalled.SourceIP { - t.Fail() - log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) - } - - if commandInput.SourcePort != commandUnmarshalled.SourcePort { - t.Fail() - log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) - } - - if commandInput.DestPort != commandUnmarshalled.DestPort { - t.Fail() - log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) - } - - if commandInput.Protocol != commandUnmarshalled.Protocol { - t.Fail() - log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) - } -} - -func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { - commandInput := &RemoveProxy{ - Type: "removeProxy", - SourceIP: "192.168.0.139", - SourcePort: 19132, - DestPort: 19132, - Protocol: "tcp", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.SourceIP != commandUnmarshalled.SourceIP { - t.Fail() - log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) - } - - if commandInput.SourcePort != commandUnmarshalled.SourcePort { - t.Fail() - log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) - } - - if commandInput.DestPort != commandUnmarshalled.DestPort { - t.Fail() - log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) - } - - if commandInput.Protocol != commandUnmarshalled.Protocol { - t.Fail() - log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) - } -} - -func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { - commandInput := &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", - Connections: []*ProxyClientConnection{ - { - SourceIP: "127.0.0.1", - SourcePort: 19132, - DestPort: 19132, - ClientIP: "127.0.0.1", - ClientPort: 12321, - }, - { - SourceIP: "127.0.0.1", - SourcePort: 19132, - DestPort: 19132, - ClientIP: "192.168.0.168", - ClientPort: 23457, - }, - { - SourceIP: "127.0.0.1", - SourcePort: 19132, - DestPort: 19132, - ClientIP: "68.42.203.47", - ClientPort: 38721, - }, - }, - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - for commandIndex, originalConnection := range commandInput.Connections { - remoteConnection := commandUnmarshalled.Connections[commandIndex] - - if originalConnection.SourceIP != remoteConnection.SourceIP { - t.Fail() - log.Printf("(in #%d) SourceIP's are not equal (orig: %s, unmsh: %s)", commandIndex, originalConnection.SourceIP, remoteConnection.SourceIP) - } - - if originalConnection.SourcePort != remoteConnection.SourcePort { - t.Fail() - log.Printf("(in #%d) SourcePort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.SourcePort, remoteConnection.SourcePort) - } - - if originalConnection.DestPort != remoteConnection.DestPort { - t.Fail() - log.Printf("(in #%d) DestPort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.DestPort, remoteConnection.DestPort) - } - - if originalConnection.ClientIP != remoteConnection.ClientIP { - t.Fail() - log.Printf("(in #%d) ClientIP's are not equal (orig: %s, unmsh: %s)", commandIndex, originalConnection.ClientIP, remoteConnection.ClientIP) - } - - if originalConnection.ClientPort != remoteConnection.ClientPort { - t.Fail() - log.Printf("(in #%d) ClientPort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.ClientPort, remoteConnection.ClientPort) - } - } -} - -func TestCheckClientParametersMarshalSupport(t *testing.T) { - commandInput := &CheckClientParameters{ - Type: "checkClientParameters", - SourceIP: "192.168.0.139", - SourcePort: 19132, - DestPort: 19132, - Protocol: "tcp", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckClientParameters) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.SourceIP != commandUnmarshalled.SourceIP { - t.Fail() - log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) - } - - if commandInput.SourcePort != commandUnmarshalled.SourcePort { - t.Fail() - log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) - } - - if commandInput.DestPort != commandUnmarshalled.DestPort { - t.Fail() - log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) - } - - if commandInput.Protocol != commandUnmarshalled.Protocol { - t.Fail() - log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) - } -} - -func TestCheckServerParametersMarshalSupport(t *testing.T) { - commandInput := &CheckServerParameters{ - Type: "checkServerParameters", - Arguments: []byte("Hello from automated testing"), - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckServerParameters) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { - log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) - } -} - -func TestCheckParametersResponseMarshalSupport(t *testing.T) { - commandInput := &CheckParametersResponse{ - Type: "checkParametersResponse", - InResponseTo: "checkClientParameters", - IsValid: true, - Message: "Hello from automated testing", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckParametersResponse) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.InResponseTo != commandUnmarshalled.InResponseTo { - t.Fail() - log.Printf("InResponseTo's are not equal (orig: %s, unmsh: %s)", commandInput.InResponseTo, commandUnmarshalled.InResponseTo) - } - - if commandInput.IsValid != commandUnmarshalled.IsValid { - t.Fail() - log.Printf("IsValid's are not equal (orig: %t, unmsh: %t)", commandInput.IsValid, commandUnmarshalled.IsValid) - } - - if commandInput.Message != commandUnmarshalled.Message { - t.Fail() - log.Printf("Messages are not equal (orig: %s, unmsh: %s)", commandInput.Message, commandUnmarshalled.Message) - } -} - -func TestBackendStatusRequestMarshalSupport(t *testing.T) { - commandInput := &BackendStatusRequest{ - Type: "backendStatusRequest", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusRequest) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } -} - -func TestBackendStatusResponseMarshalSupport(t *testing.T) { - commandInput := &BackendStatusResponse{ - Type: "backendStatusResponse", - IsRunning: true, - StatusCode: StatusFailure, - Message: "Hello from automated testing", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusResponse) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.IsRunning != commandUnmarshalled.IsRunning { - t.Fail() - log.Printf("IsRunning's are not equal (orig: %t, unmsh: %t)", commandInput.IsRunning, commandUnmarshalled.IsRunning) - } - - if commandInput.StatusCode != commandUnmarshalled.StatusCode { - t.Fail() - log.Printf("StatusCodes are not equal (orig: %d, unmsh: %d)", commandInput.StatusCode, commandUnmarshalled.StatusCode) - } - - if commandInput.Message != commandUnmarshalled.Message { - t.Fail() - log.Printf("Messages are not equal (orig: %s, unmsh: %s)", commandInput.Message, commandUnmarshalled.Message) - } -} - -func TestProxyStatusRequestMarshalSupport(t *testing.T) { - commandInput := &ProxyStatusRequest{ - Type: "proxyStatusRequest", - SourceIP: "192.168.0.139", - SourcePort: 19132, - DestPort: 19132, - Protocol: "tcp", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.SourceIP != commandUnmarshalled.SourceIP { - t.Fail() - log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) - } - - if commandInput.SourcePort != commandUnmarshalled.SourcePort { - t.Fail() - log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) - } - - if commandInput.DestPort != commandUnmarshalled.DestPort { - t.Fail() - log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) - } - - if commandInput.Protocol != commandUnmarshalled.Protocol { - t.Fail() - log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) - } -} - -func TestProxyStatusResponseMarshalSupport(t *testing.T) { - commandInput := &ProxyStatusResponse{ - Type: "proxyStatusResponse", - SourceIP: "192.168.0.139", - SourcePort: 19132, - DestPort: 19132, - Protocol: "tcp", - IsActive: true, - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - if commandInput.SourceIP != commandUnmarshalled.SourceIP { - t.Fail() - log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) - } - - if commandInput.SourcePort != commandUnmarshalled.SourcePort { - t.Fail() - log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) - } - - if commandInput.DestPort != commandUnmarshalled.DestPort { - t.Fail() - log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) - } - - if commandInput.Protocol != commandUnmarshalled.Protocol { - t.Fail() - log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) - } - - if commandInput.IsActive != commandUnmarshalled.IsActive { - t.Fail() - log.Printf("IsActive's are not equal (orig: %t, unmsh: %t)", commandInput.IsActive, commandUnmarshalled.IsActive) - } -} - -func TestProxyConnectionRequestMarshalSupport(t *testing.T) { - commandInput := &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - if err != nil { - t.Fatal(err.Error()) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } -} - -func TestProxyConnectionResponseMarshalSupport(t *testing.T) { - commandInput := &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", - Proxies: []*ProxyInstance{ - { - SourceIP: "192.168.0.168", - SourcePort: 25565, - DestPort: 25565, - Protocol: "tcp", - }, - { - SourceIP: "127.0.0.1", - SourcePort: 19132, - DestPort: 19132, - Protocol: "udp", - }, - { - SourceIP: "68.42.203.47", - SourcePort: 22, - DestPort: 2222, - Protocol: "tcp", - }, - }, - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) - - if err != nil { - t.Fatal(err.Error()) - } - - if logLevel == "debug" { - log.Printf("Generated array contents: %v", commandMarshalled) - } - - buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) - - if err != nil { - t.Fatal(err.Error()) - } - - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) - - if !ok { - t.Fatal("failed typecast") - } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - - for proxyIndex, originalProxy := range commandInput.Proxies { - remoteProxy := commandUnmarshalled.Proxies[proxyIndex] - - if originalProxy.SourceIP != remoteProxy.SourceIP { - t.Fail() - log.Printf("(in #%d) SourceIP's are not equal (orig: %s, unmsh: %s)", proxyIndex, originalProxy.SourceIP, remoteProxy.SourceIP) - } - - if originalProxy.SourcePort != remoteProxy.SourcePort { - t.Fail() - log.Printf("(in #%d) SourcePort's are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy.SourcePort, remoteProxy.SourcePort) - } - - if originalProxy.DestPort != remoteProxy.DestPort { - t.Fail() - log.Printf("(in #%d) DestPort's are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy.DestPort, remoteProxy.DestPort) - } - - if originalProxy.Protocol != remoteProxy.Protocol { - t.Fail() - log.Printf("(in #%d) ClientIP's are not equal (orig: %s, unmsh: %s)", proxyIndex, originalProxy.Protocol, remoteProxy.Protocol) - } - } -} diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go deleted file mode 100644 index b8500dd..0000000 --- a/backend/commonbackend/unmarshal.go +++ /dev/null @@ -1,665 +0,0 @@ -package commonbackend - -import ( - "encoding/binary" - "fmt" - "io" - "net" -) - -func unmarshalIndividualConnectionStruct(conn io.Reader) (*ProxyClientConnection, error) { - serverIPVersion := make([]byte, 1) - - if _, err := conn.Read(serverIPVersion); err != nil { - return nil, fmt.Errorf("couldn't read server IP version") - } - - var serverIPSize uint8 - - if serverIPVersion[0] == 4 { - serverIPSize = IPv4Size - } else if serverIPVersion[0] == 6 { - serverIPSize = IPv6Size - } else if serverIPVersion[0] == '\n' { - return nil, fmt.Errorf("no data found") - } else { - return nil, fmt.Errorf("invalid server IP version recieved") - } - - serverIP := make(net.IP, serverIPSize) - - if _, err := conn.Read(serverIP); err != nil { - return nil, fmt.Errorf("couldn't read server IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return nil, fmt.Errorf("couldn't read source port") - } - - destinationPort := make([]byte, 2) - - if _, err := conn.Read(destinationPort); err != nil { - return nil, fmt.Errorf("couldn't read source port") - } - - clientIPVersion := make([]byte, 1) - - if _, err := conn.Read(clientIPVersion); err != nil { - return nil, fmt.Errorf("couldn't read server IP version") - } - - var clientIPSize uint8 - - if clientIPVersion[0] == 4 { - clientIPSize = IPv4Size - } else if clientIPVersion[0] == 6 { - clientIPSize = IPv6Size - } else { - return nil, fmt.Errorf("invalid server IP version recieved") - } - - clientIP := make(net.IP, clientIPSize) - - if _, err := conn.Read(clientIP); err != nil { - return nil, fmt.Errorf("couldn't read server IP") - } - - clientPort := make([]byte, 2) - - if _, err := conn.Read(clientPort); err != nil { - return nil, fmt.Errorf("couldn't read source port") - } - - return &ProxyClientConnection{ - SourceIP: serverIP.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destinationPort), - ClientIP: clientIP.String(), - ClientPort: binary.BigEndian.Uint16(clientPort), - }, nil -} - -func unmarshalIndividualProxyStruct(conn io.Reader) (*ProxyInstance, error) { - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else if ipVersion[0] == '\n' { - return nil, fmt.Errorf("no data found") - } else { - return nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[0] == UDP { - protocol = "udp" - } else { - return nil, fmt.Errorf("invalid protocol") - } - - return &ProxyInstance{ - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - }, nil -} - -func Unmarshal(conn io.Reader) (string, interface{}, error) { - commandType := make([]byte, 1) - - if _, err := conn.Read(commandType); err != nil { - return "", nil, fmt.Errorf("couldn't read command") - } - - switch commandType[0] { - case StartID: - argumentsLength := make([]byte, 2) - - if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") - } - - arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) - - if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") - } - - return "start", &Start{ - Type: "start", - Arguments: arguments, - }, nil - case StopID: - return "stop", &Stop{ - Type: "stop", - }, nil - case AddProxyID: - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else { - return "", nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[1] == UDP { - protocol = "udp" - } else { - return "", nil, fmt.Errorf("invalid protocol") - } - - return "addProxy", &AddProxy{ - Type: "addProxy", - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - }, nil - case RemoveProxyID: - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else { - return "", nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[1] == UDP { - protocol = "udp" - } else { - return "", nil, fmt.Errorf("invalid protocol") - } - - return "removeProxy", &RemoveProxy{ - Type: "removeProxy", - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - }, nil - case ProxyConnectionsResponseID: - connections := []*ProxyClientConnection{} - delimiter := make([]byte, 1) - var errorReturn error - - // Infinite loop because we don't know the length - for { - connection, err := unmarshalIndividualConnectionStruct(conn) - - if err != nil { - if err.Error() == "no data found" { - break - } - - return "", nil, err - } - - connections = append(connections, connection) - - if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") - } - - if delimiter[0] == '\r' { - continue - } else if delimiter[0] == '\n' { - break - } else { - // WTF? This shouldn't happen. Break out and return, but give an error - errorReturn = fmt.Errorf("invalid delimiter recieved while processing stream") - break - } - } - - return "proxyConnectionsResponse", &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", - Connections: connections, - }, errorReturn - case CheckClientParametersID: - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else { - return "", nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[1] == UDP { - protocol = "udp" - } else { - return "", nil, fmt.Errorf("invalid protocol") - } - - return "checkClientParameters", &CheckClientParameters{ - Type: "checkClientParameters", - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - }, nil - case CheckServerParametersID: - argumentsLength := make([]byte, 2) - - if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") - } - - arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) - - if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") - } - - return "checkServerParameters", &CheckServerParameters{ - Type: "checkServerParameters", - Arguments: arguments, - }, nil - case CheckParametersResponseID: - checkMethodByte := make([]byte, 1) - - if _, err := conn.Read(checkMethodByte); err != nil { - return "", nil, fmt.Errorf("couldn't read check method byte") - } - - var checkMethod string - - if checkMethodByte[0] == CheckClientParametersID { - checkMethod = "checkClientParameters" - } else if checkMethodByte[0] == CheckServerParametersID { - checkMethod = "checkServerParameters" - } else { - return "", nil, fmt.Errorf("invalid check method recieved") - } - - isValid := make([]byte, 1) - - if _, err := conn.Read(isValid); err != nil { - return "", nil, fmt.Errorf("couldn't read isValid byte") - } - - messageLengthBytes := make([]byte, 2) - - if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") - } - - messageLength := binary.BigEndian.Uint16(messageLengthBytes) - var message string - - if messageLength != 0 { - messageBytes := make([]byte, messageLength) - - if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") - } - - message = string(messageBytes) - } - - return "checkParametersResponse", &CheckParametersResponse{ - Type: "checkParametersResponse", - InResponseTo: checkMethod, - IsValid: isValid[0] == 1, - Message: message, - }, nil - case BackendStatusResponseID: - isRunning := make([]byte, 1) - - if _, err := conn.Read(isRunning); err != nil { - return "", nil, fmt.Errorf("couldn't read isRunning field") - } - - statusCode := make([]byte, 1) - - if _, err := conn.Read(statusCode); err != nil { - return "", nil, fmt.Errorf("couldn't read status code field") - } - - messageLengthBytes := make([]byte, 2) - - if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") - } - - messageLength := binary.BigEndian.Uint16(messageLengthBytes) - var message string - - if messageLength != 0 { - messageBytes := make([]byte, messageLength) - - if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") - } - - message = string(messageBytes) - } - - return "backendStatusResponse", &BackendStatusResponse{ - Type: "backendStatusResponse", - IsRunning: isRunning[0] == 1, - StatusCode: int(statusCode[0]), - Message: message, - }, nil - case BackendStatusRequestID: - return "backendStatusRequest", &BackendStatusRequest{ - Type: "backendStatusRequest", - }, nil - case ProxyStatusRequestID: - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else { - return "", nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[1] == UDP { - protocol = "udp" - } else { - return "", nil, fmt.Errorf("invalid protocol") - } - - return "proxyStatusRequest", &ProxyStatusRequest{ - Type: "proxyStatusRequest", - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - }, nil - case ProxyStatusResponseID: - ipVersion := make([]byte, 1) - - if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") - } - - var ipSize uint8 - - if ipVersion[0] == 4 { - ipSize = IPv4Size - } else if ipVersion[0] == 6 { - ipSize = IPv6Size - } else { - return "", nil, fmt.Errorf("invalid IP version recieved") - } - - ip := make(net.IP, ipSize) - - if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") - } - - sourcePort := make([]byte, 2) - - if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") - } - - destPort := make([]byte, 2) - - if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") - } - - protocolBytes := make([]byte, 1) - - if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") - } - - var protocol string - - if protocolBytes[0] == TCP { - protocol = "tcp" - } else if protocolBytes[1] == UDP { - protocol = "udp" - } else { - return "", nil, fmt.Errorf("invalid protocol") - } - - isActive := make([]byte, 1) - - if _, err := conn.Read(isActive); err != nil { - return "", nil, fmt.Errorf("couldn't read isActive field") - } - - return "proxyStatusResponse", &ProxyStatusResponse{ - Type: "proxyStatusResponse", - SourceIP: ip.String(), - SourcePort: binary.BigEndian.Uint16(sourcePort), - DestPort: binary.BigEndian.Uint16(destPort), - Protocol: protocol, - IsActive: isActive[0] == 1, - }, nil - case ProxyInstanceRequestID: - return "proxyInstanceRequest", &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - }, nil - case ProxyInstanceResponseID: - proxies := []*ProxyInstance{} - delimiter := make([]byte, 1) - var errorReturn error - - // Infinite loop because we don't know the length - for { - proxy, err := unmarshalIndividualProxyStruct(conn) - - if err != nil { - if err.Error() == "no data found" { - break - } - - return "", nil, err - } - - proxies = append(proxies, proxy) - - if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") - } - - if delimiter[0] == '\r' { - continue - } else if delimiter[0] == '\n' { - break - } else { - // WTF? This shouldn't happen. Break out and return, but give an error - errorReturn = fmt.Errorf("invalid delimiter recieved while processing stream") - break - } - } - - return "proxyInstanceResponse", &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", - Proxies: proxies, - }, errorReturn - case ProxyConnectionsRequestID: - return "proxyConnectionsRequest", &ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", - }, nil - } - - return "", nil, fmt.Errorf("couldn't match command ID") -} From 4101ce70071386ac69de439d784afe49629d6fa6 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 10 Jan 2025 20:44:02 -0500 Subject: [PATCH 04/24] Revert "chore: Delete old commonbackend code." Wrong branch oops This reverts commit 737ba2887f1a9f58eaee98f67d7c75e9f22fc7c8. --- backend/commonbackend/constants.go | 147 ++++ backend/commonbackend/marshal.go | 507 +++++++++++++ backend/commonbackend/marshalling_test.go | 824 ++++++++++++++++++++++ backend/commonbackend/unmarshal.go | 665 +++++++++++++++++ 4 files changed, 2143 insertions(+) create mode 100644 backend/commonbackend/constants.go create mode 100644 backend/commonbackend/marshal.go create mode 100644 backend/commonbackend/marshalling_test.go create mode 100644 backend/commonbackend/unmarshal.go diff --git a/backend/commonbackend/constants.go b/backend/commonbackend/constants.go new file mode 100644 index 0000000..cdb68f2 --- /dev/null +++ b/backend/commonbackend/constants.go @@ -0,0 +1,147 @@ +package commonbackend + +type Start struct { + Type string // Will be 'start' always + Arguments []byte +} + +type Stop struct { + Type string // Will be 'stop' always +} + +type AddProxy struct { + Type string // Will be 'addProxy' always + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type RemoveProxy struct { + Type string // Will be 'removeProxy' always + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyStatusRequest struct { + Type string // Will be 'proxyStatusRequest' always + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyStatusResponse struct { + Type string // Will be 'proxyStatusResponse' always + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' + IsActive bool +} + +type ProxyInstance struct { + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyInstanceResponse struct { + Type string // Will be 'proxyConnectionResponse' always + Proxies []*ProxyInstance // List of connections +} + +type ProxyInstanceRequest struct { + Type string // Will be 'proxyConnectionRequest' always +} + +type BackendStatusResponse struct { + Type string // Will be 'backendStatusResponse' always + IsRunning bool // True if running, false if not running + StatusCode int // Either the 'Success' or 'Failure' constant + Message string // String message from the client (ex. failed to dial TCP) +} + +type BackendStatusRequest struct { + Type string // Will be 'backendStatusRequest' always +} + +type ProxyConnectionsRequest struct { + Type string // Will be 'proxyConnectionsRequest' always +} + +// Client's connection to a specific proxy +type ProxyClientConnection struct { + SourceIP string + SourcePort uint16 + DestPort uint16 + ClientIP string + ClientPort uint16 +} + +type ProxyConnectionsResponse struct { + Type string // Will be 'proxyConnectionsResponse' always + Connections []*ProxyClientConnection // List of connections +} + +type CheckClientParameters struct { + Type string // Will be 'checkClientParameters' always + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type CheckServerParameters struct { + Type string // Will be 'checkServerParameters' always + Arguments []byte +} + +// Sent as a response to either CheckClientParameters or CheckBackendParameters +type CheckParametersResponse struct { + Type string // Will be 'checkParametersResponse' always + InResponseTo string // Will be either 'checkClientParameters' or 'checkServerParameters' + IsValid bool // If true, valid, and if false, invalid + Message string // String message from the client (ex. failed to unmarshal JSON: x is not defined) +} + +const ( + StartID = iota + StopID + AddProxyID + RemoveProxyID + ProxyConnectionsResponseID + CheckClientParametersID + CheckServerParametersID + CheckParametersResponseID + ProxyConnectionsRequestID + BackendStatusResponseID + BackendStatusRequestID + ProxyStatusRequestID + ProxyStatusResponseID + ProxyInstanceResponseID + ProxyInstanceRequestID +) + +const ( + TCP = iota + UDP +) + +const ( + StatusSuccess = iota + StatusFailure +) + +const ( + // IP versions + IPv4 = 4 + IPv6 = 6 + + // TODO: net has these constants defined already. We should switch to these + IPv4Size = 4 + IPv6Size = 16 +) diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go new file mode 100644 index 0000000..6baf02e --- /dev/null +++ b/backend/commonbackend/marshal.go @@ -0,0 +1,507 @@ +package commonbackend + +import ( + "encoding/binary" + "fmt" + "net" +) + +func marshalIndividualConnectionStruct(conn *ProxyClientConnection) []byte { + sourceIPOriginal := net.ParseIP(conn.SourceIP) + clientIPOriginal := net.ParseIP(conn.ClientIP) + + var serverIPVer uint8 + var sourceIP []byte + + if sourceIPOriginal.To4() == nil { + serverIPVer = IPv6 + sourceIP = sourceIPOriginal.To16() + } else { + serverIPVer = IPv4 + sourceIP = sourceIPOriginal.To4() + } + + var clientIPVer uint8 + var clientIP []byte + + if clientIPOriginal.To4() == nil { + clientIPVer = IPv6 + clientIP = clientIPOriginal.To16() + } else { + clientIPVer = IPv4 + clientIP = clientIPOriginal.To4() + } + + connectionBlock := make([]byte, 8+len(sourceIP)+len(clientIP)) + + connectionBlock[0] = serverIPVer + copy(connectionBlock[1:len(sourceIP)+1], sourceIP) + + binary.BigEndian.PutUint16(connectionBlock[1+len(sourceIP):3+len(sourceIP)], conn.SourcePort) + binary.BigEndian.PutUint16(connectionBlock[3+len(sourceIP):5+len(sourceIP)], conn.DestPort) + + connectionBlock[5+len(sourceIP)] = clientIPVer + copy(connectionBlock[6+len(sourceIP):6+len(sourceIP)+len(clientIP)], clientIP) + binary.BigEndian.PutUint16(connectionBlock[6+len(sourceIP)+len(clientIP):8+len(sourceIP)+len(clientIP)], conn.ClientPort) + + return connectionBlock +} + +func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { + sourceIPOriginal := net.ParseIP(conn.SourceIP) + + var sourceIPVer uint8 + var sourceIP []byte + + if sourceIPOriginal.To4() == nil { + sourceIPVer = IPv6 + sourceIP = sourceIPOriginal.To16() + } else { + sourceIPVer = IPv4 + sourceIP = sourceIPOriginal.To4() + } + + proxyBlock := make([]byte, 6+len(sourceIP)) + + proxyBlock[0] = sourceIPVer + copy(proxyBlock[1:len(sourceIP)+1], sourceIP) + + binary.BigEndian.PutUint16(proxyBlock[1+len(sourceIP):3+len(sourceIP)], conn.SourcePort) + binary.BigEndian.PutUint16(proxyBlock[3+len(sourceIP):5+len(sourceIP)], conn.DestPort) + + var protocolVersion uint8 + + if conn.Protocol == "tcp" { + protocolVersion = TCP + } else if conn.Protocol == "udp" { + protocolVersion = UDP + } else { + return proxyBlock, fmt.Errorf("invalid protocol recieved") + } + + proxyBlock[5+len(sourceIP)] = protocolVersion + + return proxyBlock, nil +} + +func Marshal(commandType string, command interface{}) ([]byte, error) { + switch commandType { + case "start": + startCommand, ok := command.(*Start) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + startCommandBytes := make([]byte, 1+2+len(startCommand.Arguments)) + startCommandBytes[0] = StartID + binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(startCommand.Arguments))) + copy(startCommandBytes[3:], startCommand.Arguments) + + return startCommandBytes, nil + case "stop": + _, ok := command.(*Stop) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + return []byte{StopID}, nil + case "addProxy": + addConnectionCommand, ok := command.(*AddProxy) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + sourceIP := net.ParseIP(addConnectionCommand.SourceIP) + + var ipVer uint8 + var ipBytes []byte + + if sourceIP.To4() == nil { + ipBytes = sourceIP.To16() + ipVer = IPv6 + } else { + ipBytes = sourceIP.To4() + ipVer = IPv4 + } + + addConnectionBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + + addConnectionBytes[0] = AddProxyID + addConnectionBytes[1] = ipVer + + copy(addConnectionBytes[2:2+len(ipBytes)], ipBytes) + + binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], addConnectionCommand.SourcePort) + binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], addConnectionCommand.DestPort) + + var protocol uint8 + + if addConnectionCommand.Protocol == "tcp" { + protocol = TCP + } else if addConnectionCommand.Protocol == "udp" { + protocol = UDP + } else { + return nil, fmt.Errorf("invalid protocol") + } + + addConnectionBytes[6+len(ipBytes)] = protocol + + return addConnectionBytes, nil + case "removeProxy": + removeConnectionCommand, ok := command.(*RemoveProxy) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + sourceIP := net.ParseIP(removeConnectionCommand.SourceIP) + + var ipVer uint8 + var ipBytes []byte + + if sourceIP.To4() == nil { + ipBytes = sourceIP.To16() + ipVer = IPv6 + } else { + ipBytes = sourceIP.To4() + ipVer = IPv4 + } + + removeConnectionBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + + removeConnectionBytes[0] = RemoveProxyID + removeConnectionBytes[1] = ipVer + copy(removeConnectionBytes[2:2+len(ipBytes)], ipBytes) + binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], removeConnectionCommand.SourcePort) + binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], removeConnectionCommand.DestPort) + + var protocol uint8 + + if removeConnectionCommand.Protocol == "tcp" { + protocol = TCP + } else if removeConnectionCommand.Protocol == "udp" { + protocol = UDP + } else { + return nil, fmt.Errorf("invalid protocol") + } + + removeConnectionBytes[6+len(ipBytes)] = protocol + + return removeConnectionBytes, nil + case "proxyConnectionsResponse": + allConnectionsCommand, ok := command.(*ProxyConnectionsResponse) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + connectionsArray := make([][]byte, len(allConnectionsCommand.Connections)) + totalSize := 0 + + for connIndex, conn := range allConnectionsCommand.Connections { + connectionsArray[connIndex] = marshalIndividualConnectionStruct(conn) + totalSize += len(connectionsArray[connIndex]) + 1 + } + + if totalSize == 0 { + totalSize = 1 + } + + connectionCommandArray := make([]byte, totalSize+1) + connectionCommandArray[0] = ProxyConnectionsResponseID + + currentPosition := 1 + + for _, connection := range connectionsArray { + copy(connectionCommandArray[currentPosition:currentPosition+len(connection)], connection) + connectionCommandArray[currentPosition+len(connection)] = '\r' + currentPosition += len(connection) + 1 + } + + connectionCommandArray[totalSize] = '\n' + return connectionCommandArray, nil + case "checkClientParameters": + checkClientCommand, ok := command.(*CheckClientParameters) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + sourceIP := net.ParseIP(checkClientCommand.SourceIP) + + var ipVer uint8 + var ipBytes []byte + + if sourceIP.To4() == nil { + ipBytes = sourceIP.To16() + ipVer = IPv6 + } else { + ipBytes = sourceIP.To4() + ipVer = IPv4 + } + + checkClientBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + + checkClientBytes[0] = CheckClientParametersID + checkClientBytes[1] = ipVer + copy(checkClientBytes[2:2+len(ipBytes)], ipBytes) + binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], checkClientCommand.SourcePort) + binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], checkClientCommand.DestPort) + + var protocol uint8 + + if checkClientCommand.Protocol == "tcp" { + protocol = TCP + } else if checkClientCommand.Protocol == "udp" { + protocol = UDP + } else { + return nil, fmt.Errorf("invalid protocol") + } + + checkClientBytes[6+len(ipBytes)] = protocol + + return checkClientBytes, nil + case "checkServerParameters": + checkServerCommand, ok := command.(*CheckServerParameters) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + serverCommandBytes := make([]byte, 1+2+len(checkServerCommand.Arguments)) + serverCommandBytes[0] = CheckServerParametersID + binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(checkServerCommand.Arguments))) + copy(serverCommandBytes[3:], checkServerCommand.Arguments) + + return serverCommandBytes, nil + case "checkParametersResponse": + checkParametersCommand, ok := command.(*CheckParametersResponse) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + var checkMethod uint8 + + if checkParametersCommand.InResponseTo == "checkClientParameters" { + checkMethod = CheckClientParametersID + } else if checkParametersCommand.InResponseTo == "checkServerParameters" { + checkMethod = CheckServerParametersID + } else { + return nil, fmt.Errorf("invalid mode recieved (must be either checkClientParameters or checkServerParameters)") + } + + var isValid uint8 + + if checkParametersCommand.IsValid { + isValid = 1 + } + + checkResponseBytes := make([]byte, 3+2+len(checkParametersCommand.Message)) + checkResponseBytes[0] = CheckParametersResponseID + checkResponseBytes[1] = checkMethod + checkResponseBytes[2] = isValid + + binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(checkParametersCommand.Message))) + + if len(checkParametersCommand.Message) != 0 { + copy(checkResponseBytes[5:], []byte(checkParametersCommand.Message)) + } + + return checkResponseBytes, nil + case "backendStatusResponse": + backendStatusResponse, ok := command.(*BackendStatusResponse) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + var isRunning uint8 + + if backendStatusResponse.IsRunning { + isRunning = 1 + } else { + isRunning = 0 + } + + statusResponseBytes := make([]byte, 3+2+len(backendStatusResponse.Message)) + statusResponseBytes[0] = BackendStatusResponseID + statusResponseBytes[1] = isRunning + statusResponseBytes[2] = byte(backendStatusResponse.StatusCode) + + binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(backendStatusResponse.Message))) + + if len(backendStatusResponse.Message) != 0 { + copy(statusResponseBytes[5:], []byte(backendStatusResponse.Message)) + } + + return statusResponseBytes, nil + case "backendStatusRequest": + _, ok := command.(*BackendStatusRequest) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + statusRequestBytes := make([]byte, 1) + statusRequestBytes[0] = BackendStatusRequestID + + return statusRequestBytes, nil + case "proxyStatusRequest": + proxyStatusRequest, ok := command.(*ProxyStatusRequest) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + sourceIP := net.ParseIP(proxyStatusRequest.SourceIP) + + var ipVer uint8 + var ipBytes []byte + + if sourceIP.To4() == nil { + ipBytes = sourceIP.To16() + ipVer = IPv6 + } else { + ipBytes = sourceIP.To4() + ipVer = IPv4 + } + + proxyStatusRequestBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + + proxyStatusRequestBytes[0] = ProxyStatusRequestID + proxyStatusRequestBytes[1] = ipVer + + copy(proxyStatusRequestBytes[2:2+len(ipBytes)], ipBytes) + + binary.BigEndian.PutUint16(proxyStatusRequestBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusRequest.SourcePort) + binary.BigEndian.PutUint16(proxyStatusRequestBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusRequest.DestPort) + + var protocol uint8 + + if proxyStatusRequest.Protocol == "tcp" { + protocol = TCP + } else if proxyStatusRequest.Protocol == "udp" { + protocol = UDP + } else { + return nil, fmt.Errorf("invalid protocol") + } + + proxyStatusRequestBytes[6+len(ipBytes)] = protocol + + return proxyStatusRequestBytes, nil + case "proxyStatusResponse": + proxyStatusResponse, ok := command.(*ProxyStatusResponse) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + sourceIP := net.ParseIP(proxyStatusResponse.SourceIP) + + var ipVer uint8 + var ipBytes []byte + + if sourceIP.To4() == nil { + ipBytes = sourceIP.To16() + ipVer = IPv6 + } else { + ipBytes = sourceIP.To4() + ipVer = IPv4 + } + + proxyStatusResponseBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) + + proxyStatusResponseBytes[0] = ProxyStatusResponseID + proxyStatusResponseBytes[1] = ipVer + + copy(proxyStatusResponseBytes[2:2+len(ipBytes)], ipBytes) + + binary.BigEndian.PutUint16(proxyStatusResponseBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusResponse.SourcePort) + binary.BigEndian.PutUint16(proxyStatusResponseBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusResponse.DestPort) + + var protocol uint8 + + if proxyStatusResponse.Protocol == "tcp" { + protocol = TCP + } else if proxyStatusResponse.Protocol == "udp" { + protocol = UDP + } else { + return nil, fmt.Errorf("invalid protocol") + } + + proxyStatusResponseBytes[6+len(ipBytes)] = protocol + + var isActive uint8 + + if proxyStatusResponse.IsActive { + isActive = 1 + } else { + isActive = 0 + } + + proxyStatusResponseBytes[7+len(ipBytes)] = isActive + + return proxyStatusResponseBytes, nil + case "proxyInstanceResponse": + proxyConectionResponse, ok := command.(*ProxyInstanceResponse) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + proxyArray := make([][]byte, len(proxyConectionResponse.Proxies)) + totalSize := 0 + + for proxyIndex, proxy := range proxyConectionResponse.Proxies { + var err error + proxyArray[proxyIndex], err = marshalIndividualProxyStruct(proxy) + + if err != nil { + return nil, err + } + + totalSize += len(proxyArray[proxyIndex]) + 1 + } + + if totalSize == 0 { + totalSize = 1 + } + + connectionCommandArray := make([]byte, totalSize+1) + connectionCommandArray[0] = ProxyInstanceResponseID + + currentPosition := 1 + + for _, connection := range proxyArray { + copy(connectionCommandArray[currentPosition:currentPosition+len(connection)], connection) + connectionCommandArray[currentPosition+len(connection)] = '\r' + currentPosition += len(connection) + 1 + } + + connectionCommandArray[totalSize] = '\n' + + return connectionCommandArray, nil + case "proxyInstanceRequest": + _, ok := command.(*ProxyInstanceRequest) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + return []byte{ProxyInstanceRequestID}, nil + case "proxyConnectionsRequest": + _, ok := command.(*ProxyConnectionsRequest) + + if !ok { + return nil, fmt.Errorf("failed to typecast") + } + + return []byte{ProxyConnectionsRequestID}, nil + } + + return nil, fmt.Errorf("couldn't match command name") +} diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go new file mode 100644 index 0000000..1c93f94 --- /dev/null +++ b/backend/commonbackend/marshalling_test.go @@ -0,0 +1,824 @@ +package commonbackend + +import ( + "bytes" + "log" + "os" + "testing" +) + +var logLevel = os.Getenv("HERMES_LOG_LEVEL") + +func TestStartCommandMarshalSupport(t *testing.T) { + commandInput := &Start{ + Type: "start", + Arguments: []byte("Hello from automated testing"), + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*Start) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { + log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) + } +} + +func TestStopCommandMarshalSupport(t *testing.T) { + commandInput := &Stop{ + Type: "stop", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*Stop) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } +} + +func TestAddConnectionCommandMarshalSupport(t *testing.T) { + commandInput := &AddProxy{ + Type: "addProxy", + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*AddProxy) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { + commandInput := &RemoveProxy{ + Type: "removeProxy", + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { + commandInput := &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: []*ProxyClientConnection{ + { + SourceIP: "127.0.0.1", + SourcePort: 19132, + DestPort: 19132, + ClientIP: "127.0.0.1", + ClientPort: 12321, + }, + { + SourceIP: "127.0.0.1", + SourcePort: 19132, + DestPort: 19132, + ClientIP: "192.168.0.168", + ClientPort: 23457, + }, + { + SourceIP: "127.0.0.1", + SourcePort: 19132, + DestPort: 19132, + ClientIP: "68.42.203.47", + ClientPort: 38721, + }, + }, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for commandIndex, originalConnection := range commandInput.Connections { + remoteConnection := commandUnmarshalled.Connections[commandIndex] + + if originalConnection.SourceIP != remoteConnection.SourceIP { + t.Fail() + log.Printf("(in #%d) SourceIP's are not equal (orig: %s, unmsh: %s)", commandIndex, originalConnection.SourceIP, remoteConnection.SourceIP) + } + + if originalConnection.SourcePort != remoteConnection.SourcePort { + t.Fail() + log.Printf("(in #%d) SourcePort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.SourcePort, remoteConnection.SourcePort) + } + + if originalConnection.DestPort != remoteConnection.DestPort { + t.Fail() + log.Printf("(in #%d) DestPort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.DestPort, remoteConnection.DestPort) + } + + if originalConnection.ClientIP != remoteConnection.ClientIP { + t.Fail() + log.Printf("(in #%d) ClientIP's are not equal (orig: %s, unmsh: %s)", commandIndex, originalConnection.ClientIP, remoteConnection.ClientIP) + } + + if originalConnection.ClientPort != remoteConnection.ClientPort { + t.Fail() + log.Printf("(in #%d) ClientPort's are not equal (orig: %d, unmsh: %d)", commandIndex, originalConnection.ClientPort, remoteConnection.ClientPort) + } + } +} + +func TestCheckClientParametersMarshalSupport(t *testing.T) { + commandInput := &CheckClientParameters{ + Type: "checkClientParameters", + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckClientParameters) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestCheckServerParametersMarshalSupport(t *testing.T) { + commandInput := &CheckServerParameters{ + Type: "checkServerParameters", + Arguments: []byte("Hello from automated testing"), + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckServerParameters) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { + log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) + } +} + +func TestCheckParametersResponseMarshalSupport(t *testing.T) { + commandInput := &CheckParametersResponse{ + Type: "checkParametersResponse", + InResponseTo: "checkClientParameters", + IsValid: true, + Message: "Hello from automated testing", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckParametersResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.InResponseTo != commandUnmarshalled.InResponseTo { + t.Fail() + log.Printf("InResponseTo's are not equal (orig: %s, unmsh: %s)", commandInput.InResponseTo, commandUnmarshalled.InResponseTo) + } + + if commandInput.IsValid != commandUnmarshalled.IsValid { + t.Fail() + log.Printf("IsValid's are not equal (orig: %t, unmsh: %t)", commandInput.IsValid, commandUnmarshalled.IsValid) + } + + if commandInput.Message != commandUnmarshalled.Message { + t.Fail() + log.Printf("Messages are not equal (orig: %s, unmsh: %s)", commandInput.Message, commandUnmarshalled.Message) + } +} + +func TestBackendStatusRequestMarshalSupport(t *testing.T) { + commandInput := &BackendStatusRequest{ + Type: "backendStatusRequest", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } +} + +func TestBackendStatusResponseMarshalSupport(t *testing.T) { + commandInput := &BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: true, + StatusCode: StatusFailure, + Message: "Hello from automated testing", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.IsRunning != commandUnmarshalled.IsRunning { + t.Fail() + log.Printf("IsRunning's are not equal (orig: %t, unmsh: %t)", commandInput.IsRunning, commandUnmarshalled.IsRunning) + } + + if commandInput.StatusCode != commandUnmarshalled.StatusCode { + t.Fail() + log.Printf("StatusCodes are not equal (orig: %d, unmsh: %d)", commandInput.StatusCode, commandUnmarshalled.StatusCode) + } + + if commandInput.Message != commandUnmarshalled.Message { + t.Fail() + log.Printf("Messages are not equal (orig: %s, unmsh: %s)", commandInput.Message, commandUnmarshalled.Message) + } +} + +func TestProxyStatusRequestMarshalSupport(t *testing.T) { + commandInput := &ProxyStatusRequest{ + Type: "proxyStatusRequest", + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestProxyStatusResponseMarshalSupport(t *testing.T) { + commandInput := &ProxyStatusResponse{ + Type: "proxyStatusResponse", + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + IsActive: true, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } + + if commandInput.IsActive != commandUnmarshalled.IsActive { + t.Fail() + log.Printf("IsActive's are not equal (orig: %t, unmsh: %t)", commandInput.IsActive, commandUnmarshalled.IsActive) + } +} + +func TestProxyConnectionRequestMarshalSupport(t *testing.T) { + commandInput := &ProxyInstanceRequest{ + Type: "proxyInstanceRequest", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } +} + +func TestProxyConnectionResponseMarshalSupport(t *testing.T) { + commandInput := &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: []*ProxyInstance{ + { + SourceIP: "192.168.0.168", + SourcePort: 25565, + DestPort: 25565, + Protocol: "tcp", + }, + { + SourceIP: "127.0.0.1", + SourcePort: 19132, + DestPort: 19132, + Protocol: "udp", + }, + { + SourceIP: "68.42.203.47", + SourcePort: 22, + DestPort: 2222, + Protocol: "tcp", + }, + }, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for proxyIndex, originalProxy := range commandInput.Proxies { + remoteProxy := commandUnmarshalled.Proxies[proxyIndex] + + if originalProxy.SourceIP != remoteProxy.SourceIP { + t.Fail() + log.Printf("(in #%d) SourceIP's are not equal (orig: %s, unmsh: %s)", proxyIndex, originalProxy.SourceIP, remoteProxy.SourceIP) + } + + if originalProxy.SourcePort != remoteProxy.SourcePort { + t.Fail() + log.Printf("(in #%d) SourcePort's are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy.SourcePort, remoteProxy.SourcePort) + } + + if originalProxy.DestPort != remoteProxy.DestPort { + t.Fail() + log.Printf("(in #%d) DestPort's are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy.DestPort, remoteProxy.DestPort) + } + + if originalProxy.Protocol != remoteProxy.Protocol { + t.Fail() + log.Printf("(in #%d) ClientIP's are not equal (orig: %s, unmsh: %s)", proxyIndex, originalProxy.Protocol, remoteProxy.Protocol) + } + } +} diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go new file mode 100644 index 0000000..b8500dd --- /dev/null +++ b/backend/commonbackend/unmarshal.go @@ -0,0 +1,665 @@ +package commonbackend + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +func unmarshalIndividualConnectionStruct(conn io.Reader) (*ProxyClientConnection, error) { + serverIPVersion := make([]byte, 1) + + if _, err := conn.Read(serverIPVersion); err != nil { + return nil, fmt.Errorf("couldn't read server IP version") + } + + var serverIPSize uint8 + + if serverIPVersion[0] == 4 { + serverIPSize = IPv4Size + } else if serverIPVersion[0] == 6 { + serverIPSize = IPv6Size + } else if serverIPVersion[0] == '\n' { + return nil, fmt.Errorf("no data found") + } else { + return nil, fmt.Errorf("invalid server IP version recieved") + } + + serverIP := make(net.IP, serverIPSize) + + if _, err := conn.Read(serverIP); err != nil { + return nil, fmt.Errorf("couldn't read server IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return nil, fmt.Errorf("couldn't read source port") + } + + destinationPort := make([]byte, 2) + + if _, err := conn.Read(destinationPort); err != nil { + return nil, fmt.Errorf("couldn't read source port") + } + + clientIPVersion := make([]byte, 1) + + if _, err := conn.Read(clientIPVersion); err != nil { + return nil, fmt.Errorf("couldn't read server IP version") + } + + var clientIPSize uint8 + + if clientIPVersion[0] == 4 { + clientIPSize = IPv4Size + } else if clientIPVersion[0] == 6 { + clientIPSize = IPv6Size + } else { + return nil, fmt.Errorf("invalid server IP version recieved") + } + + clientIP := make(net.IP, clientIPSize) + + if _, err := conn.Read(clientIP); err != nil { + return nil, fmt.Errorf("couldn't read server IP") + } + + clientPort := make([]byte, 2) + + if _, err := conn.Read(clientPort); err != nil { + return nil, fmt.Errorf("couldn't read source port") + } + + return &ProxyClientConnection{ + SourceIP: serverIP.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destinationPort), + ClientIP: clientIP.String(), + ClientPort: binary.BigEndian.Uint16(clientPort), + }, nil +} + +func unmarshalIndividualProxyStruct(conn io.Reader) (*ProxyInstance, error) { + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else if ipVersion[0] == '\n' { + return nil, fmt.Errorf("no data found") + } else { + return nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[0] == UDP { + protocol = "udp" + } else { + return nil, fmt.Errorf("invalid protocol") + } + + return &ProxyInstance{ + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + }, nil +} + +func Unmarshal(conn io.Reader) (string, interface{}, error) { + commandType := make([]byte, 1) + + if _, err := conn.Read(commandType); err != nil { + return "", nil, fmt.Errorf("couldn't read command") + } + + switch commandType[0] { + case StartID: + argumentsLength := make([]byte, 2) + + if _, err := conn.Read(argumentsLength); err != nil { + return "", nil, fmt.Errorf("couldn't read argument length") + } + + arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) + + if _, err := conn.Read(arguments); err != nil { + return "", nil, fmt.Errorf("couldn't read arguments") + } + + return "start", &Start{ + Type: "start", + Arguments: arguments, + }, nil + case StopID: + return "stop", &Stop{ + Type: "stop", + }, nil + case AddProxyID: + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return "", nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return "", nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return "", nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return "", nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[1] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol") + } + + return "addProxy", &AddProxy{ + Type: "addProxy", + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + }, nil + case RemoveProxyID: + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return "", nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return "", nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return "", nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return "", nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[1] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol") + } + + return "removeProxy", &RemoveProxy{ + Type: "removeProxy", + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + }, nil + case ProxyConnectionsResponseID: + connections := []*ProxyClientConnection{} + delimiter := make([]byte, 1) + var errorReturn error + + // Infinite loop because we don't know the length + for { + connection, err := unmarshalIndividualConnectionStruct(conn) + + if err != nil { + if err.Error() == "no data found" { + break + } + + return "", nil, err + } + + connections = append(connections, connection) + + if _, err := conn.Read(delimiter); err != nil { + return "", nil, fmt.Errorf("couldn't read delimiter") + } + + if delimiter[0] == '\r' { + continue + } else if delimiter[0] == '\n' { + break + } else { + // WTF? This shouldn't happen. Break out and return, but give an error + errorReturn = fmt.Errorf("invalid delimiter recieved while processing stream") + break + } + } + + return "proxyConnectionsResponse", &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: connections, + }, errorReturn + case CheckClientParametersID: + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return "", nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return "", nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return "", nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return "", nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[1] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol") + } + + return "checkClientParameters", &CheckClientParameters{ + Type: "checkClientParameters", + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + }, nil + case CheckServerParametersID: + argumentsLength := make([]byte, 2) + + if _, err := conn.Read(argumentsLength); err != nil { + return "", nil, fmt.Errorf("couldn't read argument length") + } + + arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) + + if _, err := conn.Read(arguments); err != nil { + return "", nil, fmt.Errorf("couldn't read arguments") + } + + return "checkServerParameters", &CheckServerParameters{ + Type: "checkServerParameters", + Arguments: arguments, + }, nil + case CheckParametersResponseID: + checkMethodByte := make([]byte, 1) + + if _, err := conn.Read(checkMethodByte); err != nil { + return "", nil, fmt.Errorf("couldn't read check method byte") + } + + var checkMethod string + + if checkMethodByte[0] == CheckClientParametersID { + checkMethod = "checkClientParameters" + } else if checkMethodByte[0] == CheckServerParametersID { + checkMethod = "checkServerParameters" + } else { + return "", nil, fmt.Errorf("invalid check method recieved") + } + + isValid := make([]byte, 1) + + if _, err := conn.Read(isValid); err != nil { + return "", nil, fmt.Errorf("couldn't read isValid byte") + } + + messageLengthBytes := make([]byte, 2) + + if _, err := conn.Read(messageLengthBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read message length") + } + + messageLength := binary.BigEndian.Uint16(messageLengthBytes) + var message string + + if messageLength != 0 { + messageBytes := make([]byte, messageLength) + + if _, err := conn.Read(messageBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read message") + } + + message = string(messageBytes) + } + + return "checkParametersResponse", &CheckParametersResponse{ + Type: "checkParametersResponse", + InResponseTo: checkMethod, + IsValid: isValid[0] == 1, + Message: message, + }, nil + case BackendStatusResponseID: + isRunning := make([]byte, 1) + + if _, err := conn.Read(isRunning); err != nil { + return "", nil, fmt.Errorf("couldn't read isRunning field") + } + + statusCode := make([]byte, 1) + + if _, err := conn.Read(statusCode); err != nil { + return "", nil, fmt.Errorf("couldn't read status code field") + } + + messageLengthBytes := make([]byte, 2) + + if _, err := conn.Read(messageLengthBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read message length") + } + + messageLength := binary.BigEndian.Uint16(messageLengthBytes) + var message string + + if messageLength != 0 { + messageBytes := make([]byte, messageLength) + + if _, err := conn.Read(messageBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read message") + } + + message = string(messageBytes) + } + + return "backendStatusResponse", &BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: isRunning[0] == 1, + StatusCode: int(statusCode[0]), + Message: message, + }, nil + case BackendStatusRequestID: + return "backendStatusRequest", &BackendStatusRequest{ + Type: "backendStatusRequest", + }, nil + case ProxyStatusRequestID: + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return "", nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return "", nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return "", nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return "", nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[1] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol") + } + + return "proxyStatusRequest", &ProxyStatusRequest{ + Type: "proxyStatusRequest", + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + }, nil + case ProxyStatusResponseID: + ipVersion := make([]byte, 1) + + if _, err := conn.Read(ipVersion); err != nil { + return "", nil, fmt.Errorf("couldn't read ip version") + } + + var ipSize uint8 + + if ipVersion[0] == 4 { + ipSize = IPv4Size + } else if ipVersion[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version recieved") + } + + ip := make(net.IP, ipSize) + + if _, err := conn.Read(ip); err != nil { + return "", nil, fmt.Errorf("couldn't read source IP") + } + + sourcePort := make([]byte, 2) + + if _, err := conn.Read(sourcePort); err != nil { + return "", nil, fmt.Errorf("couldn't read source port") + } + + destPort := make([]byte, 2) + + if _, err := conn.Read(destPort); err != nil { + return "", nil, fmt.Errorf("couldn't read destination port") + } + + protocolBytes := make([]byte, 1) + + if _, err := conn.Read(protocolBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read protocol") + } + + var protocol string + + if protocolBytes[0] == TCP { + protocol = "tcp" + } else if protocolBytes[1] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol") + } + + isActive := make([]byte, 1) + + if _, err := conn.Read(isActive); err != nil { + return "", nil, fmt.Errorf("couldn't read isActive field") + } + + return "proxyStatusResponse", &ProxyStatusResponse{ + Type: "proxyStatusResponse", + SourceIP: ip.String(), + SourcePort: binary.BigEndian.Uint16(sourcePort), + DestPort: binary.BigEndian.Uint16(destPort), + Protocol: protocol, + IsActive: isActive[0] == 1, + }, nil + case ProxyInstanceRequestID: + return "proxyInstanceRequest", &ProxyInstanceRequest{ + Type: "proxyInstanceRequest", + }, nil + case ProxyInstanceResponseID: + proxies := []*ProxyInstance{} + delimiter := make([]byte, 1) + var errorReturn error + + // Infinite loop because we don't know the length + for { + proxy, err := unmarshalIndividualProxyStruct(conn) + + if err != nil { + if err.Error() == "no data found" { + break + } + + return "", nil, err + } + + proxies = append(proxies, proxy) + + if _, err := conn.Read(delimiter); err != nil { + return "", nil, fmt.Errorf("couldn't read delimiter") + } + + if delimiter[0] == '\r' { + continue + } else if delimiter[0] == '\n' { + break + } else { + // WTF? This shouldn't happen. Break out and return, but give an error + errorReturn = fmt.Errorf("invalid delimiter recieved while processing stream") + break + } + } + + return "proxyInstanceResponse", &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: proxies, + }, errorReturn + case ProxyConnectionsRequestID: + return "proxyConnectionsRequest", &ProxyConnectionsRequest{ + Type: "proxyConnectionsRequest", + }, nil + } + + return "", nil, fmt.Errorf("couldn't match command ID") +} From a35602a6f24ad6e2c03fcc2e1dc9f9a3446a0f06 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 24 Jan 2025 13:26:25 -0500 Subject: [PATCH 05/24] chore: Initialize sshappbackend. --- .gitignore | 4 +- backend/build.sh | 53 ++++++++++++++++------- backend/sshappbackend/local-code/main.go | 7 +++ backend/sshappbackend/remote-code/main.go | 7 +++ 4 files changed, 54 insertions(+), 17 deletions(-) create mode 100644 backend/sshappbackend/local-code/main.go create mode 100644 backend/sshappbackend/remote-code/main.go diff --git a/.gitignore b/.gitignore index cf34e00..7e60755 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ # Go artifacts +backend/api/api backend/sshbackend/sshbackend backend/dummybackend/dummybackend +backend/sshappbackend/remote-code/bin +backend/sshappbackend/local-code/sshappbackend backend/externalbackendlauncher/externalbackendlauncher -backend/api/api frontend/frontend # Backup artifacts diff --git a/backend/build.sh b/backend/build.sh index ecb7537..6e52fef 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -1,20 +1,41 @@ #!/usr/bin/env bash -pushd sshbackend -GOOS=linux go build . -strip sshbackend -popd +pushd sshbackend > /dev/null +echo "building sshbackend" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd dummybackend -GOOS=linux go build . -strip dummybackend -popd +pushd dummybackend > /dev/null +echo "building dummybackend" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd externalbackendlauncher -go build . -strip externalbackendlauncher -popd +pushd externalbackendlauncher > /dev/null +echo "building externalbackendlauncher" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null -pushd api -GOOS=linux go build . -strip api -popd +pushd sshappbackend/remote-code > /dev/null +echo "building sshappbackend/remote-code" +if [ ! -d bin ]; then + mkdir bin +fi + +echo " - building for arm64" +GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o bin/rt-arm64 . +echo " - building for arm" +GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o bin/rt-arm . +echo " - building for amd64" +GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o bin/rt-amd64 . +echo " - building for i386" +GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o bin/rt-386 . +popd > /dev/null + +pushd sshappbackend/local-code > /dev/null +echo "building sshappbackend/local-code" +go build -ldflags="-s -w" -trimpath -o sshappbackend . +popd > /dev/null + +pushd api > /dev/null +echo "building api" +go build -ldflags="-s -w" -trimpath . +popd > /dev/null diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go new file mode 100644 index 0000000..2f8a01e --- /dev/null +++ b/backend/sshappbackend/local-code/main.go @@ -0,0 +1,7 @@ +package main + +import "fmt" + +func main() { + fmt.Println("lokkuh code") +} diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go new file mode 100644 index 0000000..8a3f1cc --- /dev/null +++ b/backend/sshappbackend/remote-code/main.go @@ -0,0 +1,7 @@ +package main + +import "fmt" + +func main() { + fmt.Println("remottuh code") +} From ede4d528aab16c414ab7fc2546d2dd589188bb90 Mon Sep 17 00:00:00 2001 From: imterah Date: Mon, 27 Jan 2025 07:36:29 -0500 Subject: [PATCH 06/24] feature: Adds basic backend starting for sshappbackend. --- .gitignore | 2 +- backend/build.sh | 8 +- backend/externalbackendlauncher/main.go | 20 +- backend/sshappbackend/local-code/fs.go | 8 + backend/sshappbackend/local-code/logger.go | 23 ++ backend/sshappbackend/local-code/main.go | 352 ++++++++++++++++++++- go.mod | 2 + go.sum | 42 +++ 8 files changed, 434 insertions(+), 23 deletions(-) create mode 100644 backend/sshappbackend/local-code/fs.go create mode 100644 backend/sshappbackend/local-code/logger.go diff --git a/.gitignore b/.gitignore index 7e60755..d5920a3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ backend/api/api backend/sshbackend/sshbackend backend/dummybackend/dummybackend -backend/sshappbackend/remote-code/bin +backend/sshappbackend/local-code/remote-bin backend/sshappbackend/local-code/sshappbackend backend/externalbackendlauncher/externalbackendlauncher frontend/frontend diff --git a/backend/build.sh b/backend/build.sh index 6e52fef..aa5d2ca 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -21,13 +21,13 @@ if [ ! -d bin ]; then fi echo " - building for arm64" -GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o bin/rt-arm64 . +GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . echo " - building for arm" -GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o bin/rt-arm . +GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . echo " - building for amd64" -GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o bin/rt-amd64 . +GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . echo " - building for i386" -GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o bin/rt-386 . +GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . popd > /dev/null pushd sshappbackend/local-code > /dev/null diff --git a/backend/externalbackendlauncher/main.go b/backend/externalbackendlauncher/main.go index 4c66c6a..d33d6cd 100644 --- a/backend/externalbackendlauncher/main.go +++ b/backend/externalbackendlauncher/main.go @@ -21,11 +21,8 @@ type ProxyInstance struct { Protocol string `json:"protocol"` } -type WriteLogger struct { - UseError bool -} +type WriteLogger struct{} -// TODO: deprecate UseError switching func (writer WriteLogger) Write(p []byte) (n int, err error) { logSplit := strings.Split(string(p), "\n") @@ -34,11 +31,7 @@ func (writer WriteLogger) Write(p []byte) (n int, err error) { continue } - if writer.UseError { - log.Errorf("application: %s", line) - } else { - log.Infof("application: %s", line) - } + log.Infof("application: %s", line) } return len(p), err @@ -242,13 +235,8 @@ func entrypoint(cCtx *cli.Context) error { log.Debug("entering execution loop (in main goroutine)...") - stdout := WriteLogger{ - UseError: false, - } - - stderr := WriteLogger{ - UseError: true, - } + stdout := WriteLogger{} + stderr := WriteLogger{} for { log.Info("starting process...") diff --git a/backend/sshappbackend/local-code/fs.go b/backend/sshappbackend/local-code/fs.go new file mode 100644 index 0000000..7fe43e4 --- /dev/null +++ b/backend/sshappbackend/local-code/fs.go @@ -0,0 +1,8 @@ +package main + +import ( + "embed" +) + +//go:embed remote-bin +var binFiles embed.FS diff --git a/backend/sshappbackend/local-code/logger.go b/backend/sshappbackend/local-code/logger.go new file mode 100644 index 0000000..d8ed3f9 --- /dev/null +++ b/backend/sshappbackend/local-code/logger.go @@ -0,0 +1,23 @@ +package main + +import ( + "strings" + + "github.com/charmbracelet/log" +) + +type WriteLogger struct{} + +func (writer WriteLogger) Write(p []byte) (n int, err error) { + logSplit := strings.Split(string(p), "\n") + + for _, line := range logSplit { + if line == "" { + continue + } + + log.Infof("Process: %s", line) + } + + return len(p), err +} diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index 2f8a01e..98a2370 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -1,7 +1,355 @@ package main -import "fmt" +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "strings" + "sync" + + "git.terah.dev/imterah/hermes/backend/backendutil" + "git.terah.dev/imterah/hermes/backend/commonbackend" + "github.com/charmbracelet/log" + "github.com/go-playground/validator/v10" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +type SSHAppBackendData struct { + IP string `json:"ip" validate:"required"` + Port uint16 `json:"port" validate:"required"` + Username string `json:"username" validate:"required"` + PrivateKey string `json:"privateKey" validate:"required"` + ListenOnIPs []string `json:"listenOnIPs"` +} + +type SSHAppBackend struct { + config *SSHAppBackendData + conn *ssh.Client + clients []*commonbackend.ProxyClientConnection + arrayPropMutex sync.Mutex +} + +func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { + log.Info("SSHAppBackend is initializing...") + var backendData SSHAppBackendData + + if err := json.Unmarshal(configBytes, &backendData); err != nil { + return false, err + } + + if err := validator.New().Struct(&backendData); err != nil { + return false, err + } + + backend.config = &backendData + + if len(backend.config.ListenOnIPs) == 0 { + backend.config.ListenOnIPs = []string{"0.0.0.0"} + } + + signer, err := ssh.ParsePrivateKey([]byte(backendData.PrivateKey)) + + if err != nil { + log.Warnf("Failed to initialize: %s", err.Error()) + return false, err + } + + auth := ssh.PublicKeys(signer) + + config := &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + User: backendData.Username, + Auth: []ssh.AuthMethod{ + auth, + }, + } + + conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backendData.IP, backendData.Port), config) + + if err != nil { + log.Warnf("Failed to initialize: %s", err.Error()) + return false, err + } + + backend.conn = conn + + log.Info("SSHAppBackend has connected successfully.") + log.Info("Getting CPU architecture...") + + session, err := backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + var stdoutBuf bytes.Buffer + session.Stdout = &stdoutBuf + + err = session.Run("uname -m") + + if err != nil { + log.Warnf("Failed to run uname command: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + cpuArchBytes := make([]byte, stdoutBuf.Len()) + stdoutBuf.Read(cpuArchBytes) + + cpuArch := string(cpuArchBytes) + cpuArch = cpuArch[:len(cpuArch)-1] + + var backendBinary string + + // Ordered in (subjective) popularity + if cpuArch == "x86_64" { + backendBinary = "remote-bin/rt-amd64" + } else if cpuArch == "aarch64" { + backendBinary = "remote-bin/rt-arm64" + } else if cpuArch == "arm" { + backendBinary = "remote-bin/rt-arm" + } else if len(cpuArch) == 4 && string(cpuArch[0]) == "i" && strings.HasSuffix(cpuArch, "86") { + backendBinary = "remote-bin/rt-386" + } else { + log.Warn("Failed to determine executable to use: CPU architecture not compiled/supported currently") + conn.Close() + backend.conn = nil + return false, fmt.Errorf("CPU architecture not compiled/supported currently") + } + + log.Info("Checking if we need to copy the application...") + + var binary []byte + needsToCopyBinary := true + + session, err = backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + session.Stdout = &stdoutBuf + + err = session.Start("[ -f /tmp/sshappbackend.runtime ] && md5sum /tmp/sshappbackend.runtime | cut -d \" \" -f 1") + + if err != nil { + log.Warnf("Failed to calculate hash of possibly existing backend: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + fileExists := stdoutBuf.Len() != 0 + + if fileExists { + remoteMD5HashStringBuf := make([]byte, stdoutBuf.Len()) + stdoutBuf.Read(remoteMD5HashStringBuf) + + remoteMD5HashString := string(remoteMD5HashStringBuf) + remoteMD5HashString = remoteMD5HashString[:len(remoteMD5HashString)-1] + + remoteMD5Hash, err := hex.DecodeString(remoteMD5HashString) + + if err != nil { + log.Warnf("Failed to decode hex: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + binary, err = binFiles.ReadFile(backendBinary) + + if err != nil { + log.Warnf("Failed to read file in the embedded FS: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, fmt.Errorf("(embedded FS): %s", err.Error()) + } + + localMD5Hash := md5.Sum(binary) + + log.Infof("remote: %s, local: %s", remoteMD5HashString, hex.EncodeToString(localMD5Hash[:])) + + if bytes.Compare(localMD5Hash[:], remoteMD5Hash) == 0 { + needsToCopyBinary = false + } + } + + if needsToCopyBinary { + log.Info("Copying binary...") + sftpInstance, err := sftp.NewClient(conn) + + if err != nil { + log.Warnf("Failed to initialize SFTP: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + defer sftpInstance.Close() + + if len(binary) == 0 { + binary, err = binFiles.ReadFile(backendBinary) + + if err != nil { + log.Warnf("Failed to read file in the embedded FS: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, fmt.Errorf("(embedded FS): %s", err.Error()) + } + } + + var file *sftp.File + + if fileExists { + file, err = sftpInstance.Create("/tmp/sshappbackend.runtime") + } else { + file, err = sftpInstance.OpenFile("/tmp/sshappbackend.runtime", os.O_WRONLY) + } + + if err != nil { + log.Warnf("Failed to create file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + _, err = file.Write(binary) + + if err != nil { + log.Warnf("Failed to write file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + err = file.Chmod(775) + + if err != nil { + log.Warnf("Failed to change permissions on file: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + log.Info("Done copying file.") + } else { + log.Info("Skipping copying as there's a copy on disk already.") + } + + log.Info("Starting process...") + + session, err = backend.conn.NewSession() + + if err != nil { + log.Warnf("Failed to create session: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + session.Stdout = WriteLogger{} + session.Stderr = WriteLogger{} + + go session.Run("/tmp/sshappbackend.runtime") + log.Info("SSHAppBackend has initialized successfully.") + + return true, nil +} + +func (backend *SSHAppBackend) StopBackend() (bool, error) { + err := backend.conn.Close() + + if err != nil { + return false, err + } + + return true, nil +} + +func (backend *SSHAppBackend) GetBackendStatus() (bool, error) { + return backend.conn != nil, nil +} + +func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool, error) { + return true, nil +} + +func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, error) { + return false, fmt.Errorf("could not find the proxy") +} + +func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { + return backend.clients +} + +func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} + +func (backend *SSHAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { + var backendData SSHAppBackendData + + if err := json.Unmarshal(arguments, &backendData); err != nil { + return &commonbackend.CheckParametersResponse{ + IsValid: false, + Message: fmt.Sprintf("could not read json: %s", err.Error()), + } + } + + if err := validator.New().Struct(&backendData); err != nil { + return &commonbackend.CheckParametersResponse{ + IsValid: false, + Message: fmt.Sprintf("failed validation of parameters: %s", err.Error()), + } + } + + return &commonbackend.CheckParametersResponse{ + IsValid: true, + } +} func main() { - fmt.Println("lokkuh code") + logLevel := os.Getenv("HERMES_LOG_LEVEL") + + if logLevel != "" { + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + + case "info": + log.SetLevel(log.InfoLevel) + + case "warn": + log.SetLevel(log.WarnLevel) + + case "error": + log.SetLevel(log.ErrorLevel) + + case "fatal": + log.SetLevel(log.FatalLevel) + } + } + + backend := &SSHAppBackend{} + + application := backendutil.NewHelper(backend) + err := application.Start() + + if err != nil { + log.Fatalf("failed execution in application: %s", err.Error()) + } } diff --git a/go.mod b/go.mod index 0942d36..98f0e12 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect @@ -49,6 +50,7 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pkg/sftp v1.13.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index da8d74e..dd30942 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -86,6 +88,8 @@ github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pkg/sftp v1.13.7 h1:uv+I3nNJvlKZIQGSr8JVQLNHFU9YhhNpvC14Y6KgmSM= +github.com/pkg/sftp v1.13.7/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -114,23 +118,61 @@ github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w= github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.0 h1:mjIs9gYtt56AzC4ZaffQuh88TZurBGhIJMBZGSxNerQ= google.golang.org/protobuf v1.36.0/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 17e1491f96e614197c03c782a02e2cedcaadf2d3 Mon Sep 17 00:00:00 2001 From: imterah Date: Sun, 16 Feb 2025 15:02:50 -0500 Subject: [PATCH 07/24] feature: Adds basic data command support. --- backend/backendutil/application.go | 2 +- backend/backendutil/profiling_disabled.go | 2 +- backend/backendutil/profiling_enabled.go | 2 +- backend/build.sh | 10 +- backend/commonbackend/marshal.go | 250 ++---- backend/commonbackend/marshalling_test.go | 28 +- .../sshappbackend/datacommands/constants.go | 103 +++ backend/sshappbackend/datacommands/marshal.go | 323 +++++++ .../datacommands/marshalling_test.go | 828 ++++++++++++++++++ .../sshappbackend/datacommands/unmarshal.go | 435 +++++++++ .../backendutil_custom/application.go | 310 +++++++ .../backendutil_custom/structure.go | 22 + backend/sshappbackend/remote-code/main.go | 117 ++- go.mod | 4 +- 14 files changed, 2241 insertions(+), 195 deletions(-) create mode 100644 backend/sshappbackend/datacommands/constants.go create mode 100644 backend/sshappbackend/datacommands/marshal.go create mode 100644 backend/sshappbackend/datacommands/marshalling_test.go create mode 100644 backend/sshappbackend/datacommands/unmarshal.go create mode 100644 backend/sshappbackend/remote-code/backendutil_custom/application.go create mode 100644 backend/sshappbackend/remote-code/backendutil_custom/structure.go diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index 802e8c7..c6797ce 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -18,7 +18,7 @@ type BackendApplicationHelper struct { func (helper *BackendApplicationHelper) Start() error { log.Debug("BackendApplicationHelper is starting") - err := configureAndLaunchBackgroundProfilingTasks() + err := ConfigureProfiling() if err != nil { return err diff --git a/backend/backendutil/profiling_disabled.go b/backend/backendutil/profiling_disabled.go index d93cbdd..8538407 100644 --- a/backend/backendutil/profiling_disabled.go +++ b/backend/backendutil/profiling_disabled.go @@ -4,6 +4,6 @@ package backendutil var endProfileFunc func() -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { return nil } diff --git a/backend/backendutil/profiling_enabled.go b/backend/backendutil/profiling_enabled.go index a3be527..6fcb189 100644 --- a/backend/backendutil/profiling_enabled.go +++ b/backend/backendutil/profiling_enabled.go @@ -15,7 +15,7 @@ import ( "golang.org/x/exp/rand" ) -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { profilingMode, err := os.ReadFile("/tmp/hermes.backendlauncher.profilebackends") if err != nil && errors.Is(err, os.ErrNotExist) { diff --git a/backend/build.sh b/backend/build.sh index aa5d2ca..413455a 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -20,14 +20,16 @@ if [ ! -d bin ]; then mkdir bin fi +# Disable dynamic linking by disabling CGo. +# We need to make the remote code as generic as possible, so we do this echo " - building for arm64" -GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . +CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . echo " - building for arm" -GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . +CGO_ENABLED=0 GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . echo " - building for amd64" -GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . +CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . echo " - building for i386" -GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . +CGO_ENABLED=0 GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . popd > /dev/null pushd sshappbackend/local-code > /dev/null diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go index 6baf02e..4496494 100644 --- a/backend/commonbackend/marshal.go +++ b/backend/commonbackend/marshal.go @@ -84,37 +84,19 @@ func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { return proxyBlock, nil } -func Marshal(commandType string, command interface{}) ([]byte, error) { - switch commandType { - case "start": - startCommand, ok := command.(*Start) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - startCommandBytes := make([]byte, 1+2+len(startCommand.Arguments)) +func Marshal(_ string, command interface{}) ([]byte, error) { + switch command := command.(type) { + case *Start: + startCommandBytes := make([]byte, 1+2+len(command.Arguments)) startCommandBytes[0] = StartID - binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(startCommand.Arguments))) - copy(startCommandBytes[3:], startCommand.Arguments) + binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(command.Arguments))) + copy(startCommandBytes[3:], command.Arguments) return startCommandBytes, nil - case "stop": - _, ok := command.(*Stop) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *Stop: return []byte{StopID}, nil - case "addProxy": - addConnectionCommand, ok := command.(*AddProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(addConnectionCommand.SourceIP) + case *AddProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -134,14 +116,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { copy(addConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], addConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], addConnectionCommand.DestPort) + binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if addConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if addConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -150,14 +132,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { addConnectionBytes[6+len(ipBytes)] = protocol return addConnectionBytes, nil - case "removeProxy": - removeConnectionCommand, ok := command.(*RemoveProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(removeConnectionCommand.SourceIP) + case *RemoveProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -175,14 +151,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[0] = RemoveProxyID removeConnectionBytes[1] = ipVer copy(removeConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], removeConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], removeConnectionCommand.DestPort) + binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if removeConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if removeConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -191,17 +167,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[6+len(ipBytes)] = protocol return removeConnectionBytes, nil - case "proxyConnectionsResponse": - allConnectionsCommand, ok := command.(*ProxyConnectionsResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - connectionsArray := make([][]byte, len(allConnectionsCommand.Connections)) + case *ProxyConnectionsResponse: + connectionsArray := make([][]byte, len(command.Connections)) totalSize := 0 - for connIndex, conn := range allConnectionsCommand.Connections { + for connIndex, conn := range command.Connections { connectionsArray[connIndex] = marshalIndividualConnectionStruct(conn) totalSize += len(connectionsArray[connIndex]) + 1 } @@ -223,14 +193,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "checkClientParameters": - checkClientCommand, ok := command.(*CheckClientParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(checkClientCommand.SourceIP) + case *CheckClientParameters: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -248,14 +212,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[0] = CheckClientParametersID checkClientBytes[1] = ipVer copy(checkClientBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], checkClientCommand.SourcePort) - binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], checkClientCommand.DestPort) + binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if checkClientCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if checkClientCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -264,31 +228,19 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[6+len(ipBytes)] = protocol return checkClientBytes, nil - case "checkServerParameters": - checkServerCommand, ok := command.(*CheckServerParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - serverCommandBytes := make([]byte, 1+2+len(checkServerCommand.Arguments)) + case *CheckServerParameters: + serverCommandBytes := make([]byte, 1+2+len(command.Arguments)) serverCommandBytes[0] = CheckServerParametersID - binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(checkServerCommand.Arguments))) - copy(serverCommandBytes[3:], checkServerCommand.Arguments) + binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(command.Arguments))) + copy(serverCommandBytes[3:], command.Arguments) return serverCommandBytes, nil - case "checkParametersResponse": - checkParametersCommand, ok := command.(*CheckParametersResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *CheckParametersResponse: var checkMethod uint8 - if checkParametersCommand.InResponseTo == "checkClientParameters" { + if command.InResponseTo == "checkClientParameters" { checkMethod = CheckClientParametersID - } else if checkParametersCommand.InResponseTo == "checkServerParameters" { + } else if command.InResponseTo == "checkServerParameters" { checkMethod = CheckServerParametersID } else { return nil, fmt.Errorf("invalid mode recieved (must be either checkClientParameters or checkServerParameters)") @@ -296,68 +248,50 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { var isValid uint8 - if checkParametersCommand.IsValid { + if command.IsValid { isValid = 1 } - checkResponseBytes := make([]byte, 3+2+len(checkParametersCommand.Message)) + checkResponseBytes := make([]byte, 3+2+len(command.Message)) checkResponseBytes[0] = CheckParametersResponseID checkResponseBytes[1] = checkMethod checkResponseBytes[2] = isValid - binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(checkParametersCommand.Message))) + binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(command.Message))) - if len(checkParametersCommand.Message) != 0 { - copy(checkResponseBytes[5:], []byte(checkParametersCommand.Message)) + if len(command.Message) != 0 { + copy(checkResponseBytes[5:], []byte(command.Message)) } return checkResponseBytes, nil - case "backendStatusResponse": - backendStatusResponse, ok := command.(*BackendStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusResponse: var isRunning uint8 - if backendStatusResponse.IsRunning { + if command.IsRunning { isRunning = 1 } else { isRunning = 0 } - statusResponseBytes := make([]byte, 3+2+len(backendStatusResponse.Message)) + statusResponseBytes := make([]byte, 3+2+len(command.Message)) statusResponseBytes[0] = BackendStatusResponseID statusResponseBytes[1] = isRunning - statusResponseBytes[2] = byte(backendStatusResponse.StatusCode) + statusResponseBytes[2] = byte(command.StatusCode) - binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(backendStatusResponse.Message))) + binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(command.Message))) - if len(backendStatusResponse.Message) != 0 { - copy(statusResponseBytes[5:], []byte(backendStatusResponse.Message)) + if len(command.Message) != 0 { + copy(statusResponseBytes[5:], []byte(command.Message)) } return statusResponseBytes, nil - case "backendStatusRequest": - _, ok := command.(*BackendStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusRequest: statusRequestBytes := make([]byte, 1) statusRequestBytes[0] = BackendStatusRequestID return statusRequestBytes, nil - case "proxyStatusRequest": - proxyStatusRequest, ok := command.(*ProxyStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusRequest.SourceIP) + case *ProxyStatusRequest: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -370,37 +304,31 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusRequestBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - proxyStatusRequestBytes[0] = ProxyStatusRequestID - proxyStatusRequestBytes[1] = ipVer + commandBytes[0] = ProxyStatusRequestID + commandBytes[1] = ipVer - copy(proxyStatusRequestBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusRequest.SourcePort) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusRequest.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusRequest.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusRequest.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusRequestBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol - return proxyStatusRequestBytes, nil - case "proxyStatusResponse": - proxyStatusResponse, ok := command.(*ProxyStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusResponse.SourceIP) + return commandBytes, nil + case *ProxyStatusResponse: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -413,50 +341,44 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusResponseBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) - proxyStatusResponseBytes[0] = ProxyStatusResponseID - proxyStatusResponseBytes[1] = ipVer + commandBytes[0] = ProxyStatusResponseID + commandBytes[1] = ipVer - copy(proxyStatusResponseBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusResponse.SourcePort) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusResponse.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusResponse.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusResponse.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusResponseBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol var isActive uint8 - if proxyStatusResponse.IsActive { + if command.IsActive { isActive = 1 } else { isActive = 0 } - proxyStatusResponseBytes[7+len(ipBytes)] = isActive + commandBytes[7+len(ipBytes)] = isActive - return proxyStatusResponseBytes, nil - case "proxyInstanceResponse": - proxyConectionResponse, ok := command.(*ProxyInstanceResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - proxyArray := make([][]byte, len(proxyConectionResponse.Proxies)) + return commandBytes, nil + case *ProxyInstanceResponse: + proxyArray := make([][]byte, len(command.Proxies)) totalSize := 0 - for proxyIndex, proxy := range proxyConectionResponse.Proxies { + for proxyIndex, proxy := range command.Proxies { var err error proxyArray[proxyIndex], err = marshalIndividualProxyStruct(proxy) @@ -485,23 +407,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "proxyInstanceRequest": - _, ok := command.(*ProxyInstanceRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyInstanceRequest: return []byte{ProxyInstanceRequestID}, nil - case "proxyConnectionsRequest": - _, ok := command.(*ProxyConnectionsRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyConnectionsRequest: return []byte{ProxyConnectionsRequestID}, nil } - return nil, fmt.Errorf("couldn't match command name") + return nil, fmt.Errorf("couldn't match command type") } diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go index 1c93f94..e834041 100644 --- a/backend/commonbackend/marshalling_test.go +++ b/backend/commonbackend/marshalling_test.go @@ -9,7 +9,7 @@ import ( var logLevel = os.Getenv("HERMES_LOG_LEVEL") -func TestStartCommandMarshalSupport(t *testing.T) { +func TestStart(t *testing.T) { commandInput := &Start{ Type: "start", Arguments: []byte("Hello from automated testing"), @@ -53,7 +53,7 @@ func TestStartCommandMarshalSupport(t *testing.T) { } } -func TestStopCommandMarshalSupport(t *testing.T) { +func TestStop(t *testing.T) { commandInput := &Stop{ Type: "stop", } @@ -92,7 +92,7 @@ func TestStopCommandMarshalSupport(t *testing.T) { } } -func TestAddConnectionCommandMarshalSupport(t *testing.T) { +func TestAddConnection(t *testing.T) { commandInput := &AddProxy{ Type: "addProxy", SourceIP: "192.168.0.139", @@ -155,7 +155,7 @@ func TestAddConnectionCommandMarshalSupport(t *testing.T) { } } -func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { +func TestRemoveConnection(t *testing.T) { commandInput := &RemoveProxy{ Type: "removeProxy", SourceIP: "192.168.0.139", @@ -218,7 +218,7 @@ func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { } } -func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { +func TestGetAllConnections(t *testing.T) { commandInput := &ProxyConnectionsResponse{ Type: "proxyConnectionsResponse", Connections: []*ProxyClientConnection{ @@ -309,7 +309,7 @@ func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { } } -func TestCheckClientParametersMarshalSupport(t *testing.T) { +func TestCheckClientParameters(t *testing.T) { commandInput := &CheckClientParameters{ Type: "checkClientParameters", SourceIP: "192.168.0.139", @@ -372,7 +372,7 @@ func TestCheckClientParametersMarshalSupport(t *testing.T) { } } -func TestCheckServerParametersMarshalSupport(t *testing.T) { +func TestCheckServerParameters(t *testing.T) { commandInput := &CheckServerParameters{ Type: "checkServerParameters", Arguments: []byte("Hello from automated testing"), @@ -416,7 +416,7 @@ func TestCheckServerParametersMarshalSupport(t *testing.T) { } } -func TestCheckParametersResponseMarshalSupport(t *testing.T) { +func TestCheckParametersResponse(t *testing.T) { commandInput := &CheckParametersResponse{ Type: "checkParametersResponse", InResponseTo: "checkClientParameters", @@ -473,7 +473,7 @@ func TestCheckParametersResponseMarshalSupport(t *testing.T) { } } -func TestBackendStatusRequestMarshalSupport(t *testing.T) { +func TestBackendStatusRequest(t *testing.T) { commandInput := &BackendStatusRequest{ Type: "backendStatusRequest", } @@ -512,7 +512,7 @@ func TestBackendStatusRequestMarshalSupport(t *testing.T) { } } -func TestBackendStatusResponseMarshalSupport(t *testing.T) { +func TestBackendStatusResponse(t *testing.T) { commandInput := &BackendStatusResponse{ Type: "backendStatusResponse", IsRunning: true, @@ -569,7 +569,7 @@ func TestBackendStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyStatusRequestMarshalSupport(t *testing.T) { +func TestProxyStatusRequest(t *testing.T) { commandInput := &ProxyStatusRequest{ Type: "proxyStatusRequest", SourceIP: "192.168.0.139", @@ -632,7 +632,7 @@ func TestProxyStatusRequestMarshalSupport(t *testing.T) { } } -func TestProxyStatusResponseMarshalSupport(t *testing.T) { +func TestProxyStatusResponse(t *testing.T) { commandInput := &ProxyStatusResponse{ Type: "proxyStatusResponse", SourceIP: "192.168.0.139", @@ -701,7 +701,7 @@ func TestProxyStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyConnectionRequestMarshalSupport(t *testing.T) { +func TestProxyConnectionRequest(t *testing.T) { commandInput := &ProxyInstanceRequest{ Type: "proxyInstanceRequest", } @@ -740,7 +740,7 @@ func TestProxyConnectionRequestMarshalSupport(t *testing.T) { } } -func TestProxyConnectionResponseMarshalSupport(t *testing.T) { +func TestProxyConnectionResponse(t *testing.T) { commandInput := &ProxyInstanceResponse{ Type: "proxyInstanceResponse", Proxies: []*ProxyInstance{ diff --git a/backend/sshappbackend/datacommands/constants.go b/backend/sshappbackend/datacommands/constants.go new file mode 100644 index 0000000..7b7620b --- /dev/null +++ b/backend/sshappbackend/datacommands/constants.go @@ -0,0 +1,103 @@ +package datacommands + +type ProxyStatusRequest struct { + Type string + ProxyID uint16 +} + +type ProxyStatusResponse struct { + Type string + ProxyID uint16 + IsActive bool +} + +type RemoveProxy struct { + Type string + ProxyID uint16 +} + +type ProxyInstanceResponse struct { + Type string + Proxies []uint16 +} + +type ProxyConnectionsRequest struct { + Type string + ProxyID uint16 +} + +type ProxyConnectionsResponse struct { + Type string + Connections []uint16 +} + +type TCPConnectionOpened struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type TCPConnectionClosed struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type TCPProxyData struct { + Type string + ProxyID uint16 + ConnectionID uint16 + DataLength uint16 +} + +type UDPProxyData struct { + Type string + ProxyID uint16 + ClientIP string + ClientPort uint16 + DataLength uint16 +} + +type ProxyInformationRequest struct { + Type string + ProxyID uint16 +} + +type ProxyInformationResponse struct { + Type string + Exists bool + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyConnectionInformationRequest struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type ProxyConnectionInformationResponse struct { + Type string + Exists bool + ClientIP string + ClientPort uint16 +} + +const ( + ProxyStatusRequestID = iota + 100 + ProxyStatusResponseID + RemoveProxyID + ProxyInstanceResponseID + ProxyConnectionsRequestID + ProxyConnectionsResponseID + TCPConnectionOpenedID + TCPConnectionClosedID + TCPProxyDataID + UDPProxyDataID + ProxyInformationRequestID + ProxyInformationResponseID + ProxyConnectionInformationRequestID + ProxyConnectionInformationResponseID +) diff --git a/backend/sshappbackend/datacommands/marshal.go b/backend/sshappbackend/datacommands/marshal.go new file mode 100644 index 0000000..bacbda1 --- /dev/null +++ b/backend/sshappbackend/datacommands/marshal.go @@ -0,0 +1,323 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "net" +) + +// Example size and protocol constants — adjust as needed. +const ( + IPv4Size = 4 + IPv6Size = 16 + + TCP = 1 + UDP = 2 +) + +// Marshal takes a command (pointer to one of our structs) and converts it to a byte slice. +func Marshal(_ string, command interface{}) ([]byte, error) { + switch cmd := command.(type) { + // ProxyStatusRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyStatusRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyStatusRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyStatusResponse: 1 byte for the command ID, 2 bytes for ProxyID, and 1 byte for IsActive. + case *ProxyStatusResponse: + buf := make([]byte, 1+2+1) + + buf[0] = ProxyStatusResponseID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + if cmd.IsActive { + buf[3] = 1 + } else { + buf[3] = 0 + } + + return buf, nil + + // RemoveProxy: 1 byte for the command ID + 2 bytes for the ProxyID. + case *RemoveProxy: + buf := make([]byte, 1+2) + + buf[0] = RemoveProxyID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyConnectionsRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyConnectionsRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Connections + 2 bytes for each + // number in the Connection array. + case *ProxyConnectionsResponse: + buf := make([]byte, 1+((len(cmd.Connections)+1)*2)) + + buf[0] = ProxyConnectionsResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Connections))) + + for connectionIndex, connection := range cmd.Connections { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Proxies + 2 bytes for each + // number in the Proxies array. + case *ProxyInstanceResponse: + buf := make([]byte, 1+((len(cmd.Proxies)+1)*2)) + + buf[0] = ProxyInstanceResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Proxies))) + + for connectionIndex, connection := range cmd.Proxies { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // TCPConnectionOpened: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionOpened: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionOpenedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPConnectionClosed: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionClosed: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionClosedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case *TCPProxyData: + buf := make([]byte, 1+2+2+2) + + buf[0] = TCPProxyDataID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + binary.BigEndian.PutUint16(buf[5:], cmd.DataLength) + + return buf, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case *UDPProxyData: + ip := net.ParseIP(cmd.ClientIP) + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 2 + // ProxyID + 1 + // IP version + len(ipBytes) + // client IP bytes + 2 + // ClientPort + 2 // DataLength + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = UDPProxyDataID + offset++ + + binary.BigEndian.PutUint16(buf[offset:], cmd.ProxyID) + offset += 2 + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DataLength) + + return buf, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case *ProxyInformationRequest: + buf := make([]byte, 1+2) + buf[0] = ProxyInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + return buf, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + // (For simplicity, this marshaller always writes the IP and port info even if !Exists.) + case *ProxyInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.SourceIP) + + if ip == nil { + return nil, fmt.Errorf("invalid source IP: %v", cmd.SourceIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.SourceIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 + // SourcePort + 2 + // DestPort + 1 // Protocol + + buf := make([]byte, totalSize) + + offset := 0 + buf[offset] = ProxyInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.SourcePort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DestPort) + offset += 2 + + // Encode protocol as 1 byte. + switch cmd.Protocol { + case "tcp": + buf[offset] = TCP + case "udp": + buf[offset] = UDP + default: + return nil, fmt.Errorf("invalid protocol: %v", cmd.Protocol) + } + + // offset++ (not needed since we are at the end) + return buf, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *ProxyConnectionInformationRequest: + buf := make([]byte, 1+2+2) + + buf[0] = ProxyConnectionInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes ClientPort. + // This marshaller only writes the rest of the data if Exists. + case *ProxyConnectionInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyConnectionInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.ClientIP) + + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 // ClientPort + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = ProxyConnectionInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + + return buf, nil + + default: + return nil, fmt.Errorf("unsupported command type") + } +} diff --git a/backend/sshappbackend/datacommands/marshalling_test.go b/backend/sshappbackend/datacommands/marshalling_test.go new file mode 100644 index 0000000..81e7101 --- /dev/null +++ b/backend/sshappbackend/datacommands/marshalling_test.go @@ -0,0 +1,828 @@ +package datacommands + +import ( + "bytes" + "log" + "os" + "testing" +) + +var logLevel = os.Getenv("HERMES_LOG_LEVEL") + +func TestProxyStatusRequest(t *testing.T) { + commandInput := &ProxyStatusRequest{ + Type: "proxyStatusRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyStatusResponse(t *testing.T) { + commandInput := &ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: 19132, + IsActive: true, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.IsActive != commandUnmarshalled.IsActive { + t.Fail() + log.Printf("IsActive's are not equal (orig: '%t', unmsh: '%t')", commandInput.IsActive, commandUnmarshalled.IsActive) + } +} + +func TestRemoveProxy(t *testing.T) { + commandInput := &RemoveProxy{ + Type: "removeProxy", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsRequest(t *testing.T) { + commandInput := &ProxyConnectionsRequest{ + Type: "proxyConnectionsRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsResponse(t *testing.T) { + commandInput := &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for connectionIndex, originalConnection := range commandInput.Connections { + remoteConnection := commandUnmarshalled.Connections[connectionIndex] + + if originalConnection != remoteConnection { + t.Fail() + log.Printf("(in #%d) SourceIP's are not equal (orig: %d, unmsh: %d)", connectionIndex, originalConnection, connectionIndex) + } + } +} + +func TestProxyInstanceResponse(t *testing.T) { + commandInput := &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for proxyIndex, originalProxy := range commandInput.Proxies { + remoteProxy := commandUnmarshalled.Proxies[proxyIndex] + + if originalProxy != remoteProxy { + t.Fail() + log.Printf("(in #%d) Proxy IDs are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy, remoteProxy) + } + } +} + +func TestTCPConnectionOpened(t *testing.T) { + commandInput := &TCPConnectionOpened{ + Type: "tcpConnectionOpened", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionOpened) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPConnectionClosed(t *testing.T) { + commandInput := &TCPConnectionClosed{ + Type: "tcpConnectionClosed", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionClosed) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPProxyData(t *testing.T) { + commandInput := &TCPProxyData{ + Type: "tcpProxyData", + ProxyID: 19132, + ConnectionID: 25565, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestUDPProxyData(t *testing.T) { + commandInput := &UDPProxyData{ + Type: "udpProxyData", + ProxyID: 19132, + ClientIP: "68.51.23.54", + ClientPort: 28173, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*UDPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("ClientIP's are not equal (orig: '%s', unmsh: '%s')", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: '%d', unmsh: '%d')", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestProxyInformationRequest(t *testing.T) { + commandInput := &ProxyInformationRequest{ + Type: "proxyInformationRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyInformationResponseExists(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: true, + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestProxyInformationResponseNoExist(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} + +func TestProxyConnectionInformationRequest(t *testing.T) { + commandInput := &ProxyConnectionInformationRequest{ + Type: "proxyConnectionInformationRequest", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestProxyConnectionInformationResponseExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: true, + ClientIP: "192.168.0.139", + ClientPort: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: %d, unmsh: %d)", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } +} + +func TestProxyConnectionInformationResponseNoExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} diff --git a/backend/sshappbackend/datacommands/unmarshal.go b/backend/sshappbackend/datacommands/unmarshal.go new file mode 100644 index 0000000..7e97c3f --- /dev/null +++ b/backend/sshappbackend/datacommands/unmarshal.go @@ -0,0 +1,435 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +// Unmarshal reads from the provided connection and returns +// the message type (as a string), the unmarshalled struct, or an error. +func Unmarshal(conn io.Reader) (string, interface{}, error) { + // Every command starts with a 1-byte command ID. + header := make([]byte, 1) + if _, err := io.ReadFull(conn, header); err != nil { + return "", nil, fmt.Errorf("couldn't read command ID: %w", err) + } + + cmdID := header[0] + switch cmdID { + // ProxyStatusRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyStatusRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyStatusRequest", &ProxyStatusRequest{ + Type: "proxyStatusRequest", + ProxyID: proxyID, + }, nil + + // ProxyStatusResponse: 1 byte ID + 2 bytes ProxyID + 1 byte IsActive. + case ProxyStatusResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse IsActive: %w", err) + } + + isActive := boolBuf[0] != 0 + + return "proxyStatusResponse", &ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: proxyID, + IsActive: isActive, + }, nil + + // RemoveProxy: 1 byte ID + 2 bytes ProxyID. + case RemoveProxyID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read RemoveProxy ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "removeProxy", &RemoveProxy{ + Type: "removeProxy", + ProxyID: proxyID, + }, nil + + // ProxyConnectionsRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyConnectionsRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyConnectionsRequest", &ProxyConnectionsRequest{ + Type: "proxyConnectionsRequest", + ProxyID: proxyID, + }, nil + + // ProxyConnectionsResponse: 1 byte ID + 2 bytes Connections length + 2 bytes for each Connection in Connections. + case ProxyConnectionsResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + connections := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range connections { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + connections[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return "proxyConnectionsResponse", &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: connections, + }, failedDuringReading + + // ProxyInstanceResponse: 1 byte ID + 2 bytes Proxies length + 2 bytes for each Proxy in Proxies. + case ProxyInstanceResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + proxies := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range proxies { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + proxies[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return "proxyInstanceResponse", &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: proxies, + }, failedDuringReading + + // TCPConnectionOpened: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionOpenedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPConnectionOpened fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "tcpConnectionOpened", &TCPConnectionOpened{ + Type: "tcpConnectionOpened", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPConnectionClosed: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionClosedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPConnectionClosed fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "tcpConnectionClosed", &TCPConnectionClosed{ + Type: "tcpConnectionClosed", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case TCPProxyDataID: + buf := make([]byte, 2+2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPProxyData fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + dataLength := binary.BigEndian.Uint16(buf[4:6]) + + return "tcpProxyData", &TCPProxyData{ + Type: "tcpProxyData", + ProxyID: proxyID, + ConnectionID: connectionID, + DataLength: dataLength, + }, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case UDPProxyDataID: + // Read 2 bytes ProxyID + 2 bytes ConnectionID. + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData ProxyID/ConnectionID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version received: %v", ipVerBuf[0]) + } + + // Read the IP bytes. + ipBytes := make([]byte, ipSize) + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData IP bytes: %w", err) + } + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + // Read DataLength. + dataLengthBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, dataLengthBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData DataLength: %w", err) + } + + dataLength := binary.BigEndian.Uint16(dataLengthBuf) + + return "udpProxyData", &UDPProxyData{ + Type: "udpProxyData", + ProxyID: proxyID, + ClientIP: clientIP, + ClientPort: clientPort, + DataLength: dataLength, + }, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyInformationRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyInformationRequest", &ProxyInformationRequest{ + Type: "proxyInformationRequest", + ProxyID: proxyID, + }, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + case ProxyInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return "proxyInformationResponse", &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version in ProxyInformationResponse: %v", ipVerBuf[0]) + } + + // Read the source IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP bytes: %w", err) + } + + sourceIP := net.IP(ipBytes).String() + + // Read SourcePort and DestPort. + portsBuf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, portsBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse ports: %w", err) + } + + sourcePort := binary.BigEndian.Uint16(portsBuf[0:2]) + destPort := binary.BigEndian.Uint16(portsBuf[2:4]) + + // Read protocol. + protoBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, protoBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse protocol: %w", err) + } + var protocol string + if protoBuf[0] == TCP { + protocol = "tcp" + } else if protoBuf[0] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol value in ProxyInformationResponse: %d", protoBuf[0]) + } + + return "proxyInformationResponse", &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: exists, + SourceIP: sourceIP, + SourcePort: sourcePort, + DestPort: destPort, + Protocol: protocol, + }, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case ProxyConnectionInformationRequestID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationRequest fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "proxyConnectionInformationRequest", &ProxyConnectionInformationRequest{ + Type: "proxyConnectionInformationRequest", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + 1 byte IP version + IP bytes + 2 bytes ClientPort. + case ProxyConnectionInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP version: %w", err) + } + + if ipVerBuf[0] != 4 && ipVerBuf[0] != 6 { + return "", nil, fmt.Errorf("invalid IP version in ProxyConnectionInformationResponse: %v", ipVerBuf[0]) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else { + ipSize = IPv6Size + } + + // Read IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP bytes: %w", err) + } + + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: exists, + ClientIP: clientIP, + ClientPort: clientPort, + }, nil + default: + return "", nil, fmt.Errorf("unknown command id: %v", cmdID) + } +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go new file mode 100644 index 0000000..6d46a7b --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -0,0 +1,310 @@ +package backendutil_custom + +import ( + "fmt" + "net" + "os" + + "git.terah.dev/imterah/hermes/backend/backendutil" + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "github.com/charmbracelet/log" +) + +type BackendApplicationHelper struct { + Backend BackendInterface + SocketPath string + + socket net.Conn +} + +func (helper *BackendApplicationHelper) Start() error { + log.Debug("BackendApplicationHelper is starting") + err := backendutil.ConfigureProfiling() + + if err != nil { + return err + } + + log.Debug("Currently waiting for Unix socket connection...") + + helper.socket, err = net.Dial("unix", helper.SocketPath) + + if err != nil { + return err + } + + log.Debug("Sucessfully connected") + + for { + commandType, commandRaw, err := datacommands.Unmarshal(helper.socket) + + if err != nil && err.Error() != "couldn't match command ID" { + return err + } + + switch commandType { + case "proxyConnectionsRequest": + proxyConnectionRequest, ok := commandRaw.(*datacommands.ProxyConnectionsRequest) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + connections := helper.Backend.GetAllClientConnections(proxyConnectionRequest.ProxyID) + + serverParams := &datacommands.ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: connections, + } + + byteData, err := datacommands.Marshal(serverParams.Type, serverParams) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case "removeProxy": + command, ok := commandRaw.(*datacommands.RemoveProxy) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StopProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to remove proxy (ID %d): RemoveProxy returned into failure state", command.ProxyID) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to remove proxy (ID %d): %s", command.ProxyID, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: command.ProxyID, + IsActive: hasAnyFailed, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + default: + commandType, commandRaw, err := commonbackend.Unmarshal(helper.socket) + + if err != nil { + return err + } + + switch commandType { + case "start": + command, ok := commandRaw.(*commonbackend.Start) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StartBackend(command.Arguments) + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "stop": + _, ok := commandRaw.(*commonbackend.Stop) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StopBackend() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: !ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "backendStatusRequest": + _, ok := commandRaw.(*commonbackend.BackendStatusRequest) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err := helper.Backend.GetBackendStatus() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "addProxy": + command, ok := commandRaw.(*commonbackend.AddProxy) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + id, ok, err := helper.Backend.StartProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: id, + IsActive: !hasAnyFailed, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "checkClientParameters": + command, ok := commandRaw.(*commonbackend.CheckClientParameters) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + resp := helper.Backend.CheckParametersForConnections(command) + resp.Type = "checkParametersResponse" + resp.InResponseTo = "checkClientParameters" + + byteData, err := commonbackend.Marshal(resp.Type, resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case "checkServerParameters": + command, ok := commandRaw.(*commonbackend.CheckServerParameters) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + resp := helper.Backend.CheckParametersForBackend(command.Arguments) + resp.Type = "checkParametersResponse" + resp.InResponseTo = "checkServerParameters" + + byteData, err := commonbackend.Marshal(resp.Type, resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + default: + log.Warn("Unsupported command recieved: %s", commandType) + } + } + } +} + +func NewHelper(backend BackendInterface) *BackendApplicationHelper { + socketPath, ok := os.LookupEnv("HERMES_API_SOCK") + + if !ok { + log.Warn("HERMES_API_SOCK is not defined! This will cause an issue unless the backend manually overwrites it") + } + + helper := &BackendApplicationHelper{ + Backend: backend, + SocketPath: socketPath, + } + + return helper +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/structure.go b/backend/sshappbackend/remote-code/backendutil_custom/structure.go new file mode 100644 index 0000000..96bcbf8 --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/structure.go @@ -0,0 +1,22 @@ +package backendutil_custom + +import ( + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" +) + +type BackendInterface interface { + StartBackend(arguments []byte) (bool, error) + StopBackend() (bool, error) + GetBackendStatus() (bool, error) + StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) + StopProxy(command *datacommands.RemoveProxy) (bool, error) + GetAllProxies() []uint16 + ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse + GetAllClientConnections(proxyID uint16) []uint16 + ResolveConnection(connectionID uint16) *datacommands.ProxyConnectionsResponse + CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse + CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse + HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) + HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) +} diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go index 8a3f1cc..0fd11ee 100644 --- a/backend/sshappbackend/remote-code/main.go +++ b/backend/sshappbackend/remote-code/main.go @@ -1,7 +1,120 @@ package main -import "fmt" +import ( + "os" + "sync" + + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/remote-code/backendutil_custom" + "github.com/charmbracelet/log" +) + +type TCPProxy struct { + proxyIDIndex uint16 + proxyIDLock sync.Mutex +} + +type UDPProxy struct { +} + +type SSHRemoteAppBackend struct { + connectionIDIndex uint16 + connectionIDLock sync.Mutex + + tcpProxies map[uint16]*TCPProxy + udpProxies map[uint16]*UDPProxy +} + +func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { + backend.tcpProxies = map[uint16]*TCPProxy{} + backend.udpProxies = map[uint16]*UDPProxy{} + + return true, nil +} + +func (backend *SSHRemoteAppBackend) StopBackend() (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetBackendStatus() (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) { + return 0, true, nil +} + +func (backend *SSHRemoteAppBackend) StopProxy(command *datacommands.RemoveProxy) (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetAllProxies() []uint16 { + return []uint16{} +} + +func (backend *SSHRemoteAppBackend) ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse { + return &datacommands.ProxyInformationResponse{} +} + +func (backend *SSHRemoteAppBackend) GetAllClientConnections(proxyID uint16) []uint16 { + return []uint16{} +} + +func (backend *SSHRemoteAppBackend) ResolveConnection(proxyID uint16) *datacommands.ProxyConnectionsResponse { + return &datacommands.ProxyConnectionsResponse{} +} + +func (backend *SSHRemoteAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + Message: "Valid!", + } +} + +func (backend *SSHRemoteAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + Message: "Valid!", + } +} + +func (backend *SSHRemoteAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + +} + +func (backend *SSHRemoteAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + +} func main() { - fmt.Println("remottuh code") + logLevel := os.Getenv("HERMES_LOG_LEVEL") + + if logLevel != "" { + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + + case "info": + log.SetLevel(log.InfoLevel) + + case "warn": + log.SetLevel(log.WarnLevel) + + case "error": + log.SetLevel(log.ErrorLevel) + + case "fatal": + log.SetLevel(log.FatalLevel) + } + } + + backend := &SSHRemoteAppBackend{} + + application := backendutil_custom.NewHelper(backend) + err := application.Start() + + if err != nil { + log.Fatalf("failed execution in application: %s", err.Error()) + } } diff --git a/go.mod b/go.mod index 98f0e12..b390542 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/go-playground/validator/v10 v10.23.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/pkg/sftp v1.13.7 github.com/urfave/cli/v2 v2.27.5 golang.org/x/crypto v0.31.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/term v0.28.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.11 @@ -50,7 +52,6 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/pkg/sftp v1.13.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -58,7 +59,6 @@ require ( github.com/ugorji/go/codec v1.2.12 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect golang.org/x/arch v0.12.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.29.0 // indirect From 62cc8b39ad67b507383e553af51a17d69a91a4aa Mon Sep 17 00:00:00 2001 From: imterah Date: Sun, 16 Feb 2025 18:11:01 -0500 Subject: [PATCH 08/24] chore: Cleanup code by switching to type switching instead of string switching. --- backend/backendutil/application.go | 79 ++++-------------- backend/build.sh | 8 +- .../backendutil_custom/application.go | 83 ++++--------------- 3 files changed, 37 insertions(+), 133 deletions(-) diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index c6797ce..f9fead0 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -1,7 +1,6 @@ package backendutil import ( - "fmt" "net" "os" @@ -35,21 +34,15 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") for { - commandType, commandRaw, err := commonbackend.Unmarshal(helper.socket) + _, commandRaw, err := commonbackend.Unmarshal(helper.socket) if err != nil { return err } - switch commandType { - case "start": - command, ok := commandRaw.(*commonbackend.Start) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StartBackend(command.Arguments) + switch command := commandRaw.(type) { + case *commonbackend.Start: + ok, err := helper.Backend.StartBackend(command.Arguments) var ( message string @@ -78,13 +71,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "backendStatusRequest": - _, ok := commandRaw.(*commonbackend.BackendStatusRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.BackendStatusRequest: ok, err := helper.Backend.GetBackendStatus() var ( @@ -114,14 +101,8 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "stop": - _, ok := commandRaw.(*commonbackend.Stop) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopBackend() + case *commonbackend.Stop: + ok, err := helper.Backend.StopBackend() var ( message string @@ -150,14 +131,8 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "addProxy": - command, ok := commandRaw.(*commonbackend.AddProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StartProxy(command) + case *commonbackend.AddProxy: + ok, err := helper.Backend.StartProxy(command) var hasAnyFailed bool if !ok { @@ -185,14 +160,8 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "removeProxy": - command, ok := commandRaw.(*commonbackend.RemoveProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopProxy(command) + case *commonbackend.RemoveProxy: + ok, err := helper.Backend.StopProxy(command) var hasAnyFailed bool if !ok { @@ -220,13 +189,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "proxyConnectionsRequest": - _, ok := commandRaw.(*commonbackend.ProxyConnectionsRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.ProxyConnectionsRequest: connections := helper.Backend.GetAllClientConnections() serverParams := &commonbackend.ProxyConnectionsResponse{ @@ -243,13 +206,7 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "checkClientParameters": - command, ok := commandRaw.(*commonbackend.CheckClientParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckClientParameters: resp := helper.Backend.CheckParametersForConnections(command) resp.Type = "checkParametersResponse" resp.InResponseTo = "checkClientParameters" @@ -263,13 +220,7 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "checkServerParameters": - command, ok := commandRaw.(*commonbackend.CheckServerParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckServerParameters: resp := helper.Backend.CheckParametersForBackend(command.Arguments) resp.Type = "checkParametersResponse" resp.InResponseTo = "checkServerParameters" @@ -283,6 +234,8 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } + default: + log.Warnf("Unsupported command recieved: %T", command) } } } diff --git a/backend/build.sh b/backend/build.sh index 413455a..cee4440 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -14,12 +14,12 @@ echo "building externalbackendlauncher" go build -ldflags="-s -w" -trimpath . popd > /dev/null -pushd sshappbackend/remote-code > /dev/null -echo "building sshappbackend/remote-code" -if [ ! -d bin ]; then - mkdir bin +if [ ! -d "sshappbackend/local-code/remote-bin" ]; then + mkdir "sshappbackend/local-code/remote-bin" fi +pushd sshappbackend/remote-code > /dev/null +echo "building sshappbackend/remote-code" # Disable dynamic linking by disabling CGo. # We need to make the remote code as generic as possible, so we do this echo " - building for arm64" diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go index 6d46a7b..e058648 100644 --- a/backend/sshappbackend/remote-code/backendutil_custom/application.go +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -1,7 +1,6 @@ package backendutil_custom import ( - "fmt" "net" "os" @@ -37,21 +36,15 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") for { - commandType, commandRaw, err := datacommands.Unmarshal(helper.socket) + _, commandRaw, err := datacommands.Unmarshal(helper.socket) if err != nil && err.Error() != "couldn't match command ID" { return err } - switch commandType { - case "proxyConnectionsRequest": - proxyConnectionRequest, ok := commandRaw.(*datacommands.ProxyConnectionsRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - connections := helper.Backend.GetAllClientConnections(proxyConnectionRequest.ProxyID) + switch command := commandRaw.(type) { + case *datacommands.ProxyConnectionsRequest: + connections := helper.Backend.GetAllClientConnections(command.ProxyID) serverParams := &datacommands.ProxyConnectionsResponse{ Type: "proxyConnectionsResponse", @@ -67,14 +60,8 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "removeProxy": - command, ok := commandRaw.(*datacommands.RemoveProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopProxy(command) + case *datacommands.RemoveProxy: + ok, err := helper.Backend.StopProxy(command) var hasAnyFailed bool if !ok { @@ -100,21 +87,15 @@ func (helper *BackendApplicationHelper) Start() error { helper.socket.Write(responseMarshalled) default: - commandType, commandRaw, err := commonbackend.Unmarshal(helper.socket) + _, commandRaw, err := commonbackend.Unmarshal(helper.socket) if err != nil { return err } - switch commandType { - case "start": - command, ok := commandRaw.(*commonbackend.Start) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StartBackend(command.Arguments) + switch command := commandRaw.(type) { + case *commonbackend.Start: + ok, err := helper.Backend.StartBackend(command.Arguments) var ( message string @@ -143,14 +124,8 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "stop": - _, ok := commandRaw.(*commonbackend.Stop) - - if !ok { - return fmt.Errorf("failed to typecast") - } - - ok, err = helper.Backend.StopBackend() + case *commonbackend.Stop: + ok, err := helper.Backend.StopBackend() var ( message string @@ -179,13 +154,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "backendStatusRequest": - _, ok := commandRaw.(*commonbackend.BackendStatusRequest) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.BackendStatusRequest: ok, err := helper.Backend.GetBackendStatus() var ( @@ -215,13 +184,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "addProxy": - command, ok := commandRaw.(*commonbackend.AddProxy) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.AddProxy: id, ok, err := helper.Backend.StartProxy(command) var hasAnyFailed bool @@ -247,13 +210,7 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) - case "checkClientParameters": - command, ok := commandRaw.(*commonbackend.CheckClientParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckClientParameters: resp := helper.Backend.CheckParametersForConnections(command) resp.Type = "checkParametersResponse" resp.InResponseTo = "checkClientParameters" @@ -267,13 +224,7 @@ func (helper *BackendApplicationHelper) Start() error { if _, err = helper.socket.Write(byteData); err != nil { return err } - case "checkServerParameters": - command, ok := commandRaw.(*commonbackend.CheckServerParameters) - - if !ok { - return fmt.Errorf("failed to typecast") - } - + case *commonbackend.CheckServerParameters: resp := helper.Backend.CheckParametersForBackend(command.Arguments) resp.Type = "checkParametersResponse" resp.InResponseTo = "checkServerParameters" @@ -288,7 +239,7 @@ func (helper *BackendApplicationHelper) Start() error { return err } default: - log.Warn("Unsupported command recieved: %s", commandType) + log.Warnf("Unsupported command recieved: %T", command) } } } From cf90ddb104a273566f64a813cbb82663c88d3339 Mon Sep 17 00:00:00 2001 From: imterah Date: Sun, 16 Feb 2025 19:12:17 -0500 Subject: [PATCH 09/24] chore: Strip unneeded components from code. --- backend/api/backendruntime/runtime.go | 116 +-------- backend/api/controllers/v1/backends/create.go | 2 - .../api/controllers/v1/proxies/connections.go | 4 +- backend/api/controllers/v1/proxies/create.go | 1 - backend/api/controllers/v1/proxies/remove.go | 1 - backend/api/controllers/v1/proxies/start.go | 1 - backend/api/controllers/v1/proxies/stop.go | 1 - backend/api/main.go | 12 +- backend/backendutil/application.go | 26 +- backend/commonbackend/constants.go | 15 -- backend/commonbackend/marshal.go | 2 +- backend/commonbackend/marshalling_test.go | 226 +++-------------- backend/commonbackend/unmarshal.go | 161 ++++++------ backend/externalbackendlauncher/main.go | 21 +- .../sshappbackend/datacommands/constants.go | 14 - backend/sshappbackend/datacommands/marshal.go | 2 +- .../datacommands/marshalling_test.go | 240 +++--------------- .../sshappbackend/datacommands/unmarshal.go | 117 ++++----- .../backendutil_custom/application.go | 28 +- 19 files changed, 228 insertions(+), 762 deletions(-) diff --git a/backend/api/backendruntime/runtime.go b/backend/api/backendruntime/runtime.go index a65c077..52e280d 100644 --- a/backend/api/backendruntime/runtime.go +++ b/backend/api/backendruntime/runtime.go @@ -15,8 +15,8 @@ import ( "github.com/charmbracelet/log" ) -func handleCommand(commandType string, command interface{}, sock net.Conn, rtcChan chan interface{}) error { - bytes, err := commonbackend.Marshal(commandType, command) +func handleCommand(command interface{}, sock net.Conn, rtcChan chan interface{}) error { + bytes, err := commonbackend.Marshal(command) if err != nil { log.Warnf("Failed to marshal message: %s", err.Error()) @@ -32,7 +32,7 @@ func handleCommand(commandType string, command interface{}, sock net.Conn, rtcCh return fmt.Errorf("failed to write message: %s", err.Error()) } - _, data, err := commonbackend.Unmarshal(sock) + data, err := commonbackend.Unmarshal(sock) if err != nil { log.Warnf("Failed to unmarshal message: %s", err.Error()) @@ -117,9 +117,7 @@ func (runtime *Runtime) goRoutineHandler() error { // To be safe here, we have to use the proper (yet annoying) facilities to prevent cross-talk, since we're in // a goroutine, and can't talk directly. This actually has benefits, as the OuterLoop should exit on its own, if we // encounter a critical error. - statusResponse, err := runtime.ProcessCommand(&commonbackend.BackendStatusRequest{ - Type: "backendStatusRequest", - }) + statusResponse, err := runtime.ProcessCommand(&commonbackend.BackendStatusRequest{}) if err != nil { log.Warnf("Failed to get response for backend (in backend runtime keep alive): %s", err.Error()) @@ -167,110 +165,14 @@ func (runtime *Runtime) goRoutineHandler() error { continue } - switch command := messageData.Message.(type) { - case *commonbackend.AddProxy: - err := handleCommand("addProxy", command, sock, messageData.Channel) + err := handleCommand(messageData.Message, sock, messageData.Channel) - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) + if err != nil { + log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } + if strings.HasPrefix(err.Error(), "failed to write message") { + break OuterLoop } - case *commonbackend.BackendStatusRequest: - err := handleCommand("backendStatusRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.CheckClientParameters: - err := handleCommand("checkClientParameters", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.CheckServerParameters: - err := handleCommand("checkServerParameters", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyConnectionsRequest: - err := handleCommand("proxyConnectionsRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyInstanceRequest: - err := handleCommand("proxyInstanceRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.ProxyStatusRequest: - err := handleCommand("proxyStatusRequest", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.RemoveProxy: - err := handleCommand("removeProxy", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.Start: - err := handleCommand("start", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - case *commonbackend.Stop: - err := handleCommand("stop", command, sock, messageData.Channel) - - if err != nil { - log.Warnf("failed to handle command in backend runtime instance: %s", err.Error()) - - if strings.HasPrefix(err.Error(), "failed to write message") { - break OuterLoop - } - } - default: - log.Warnf("Recieved unknown command type from channel: %T", command) - messageData.Channel <- fmt.Errorf("unknown command recieved") } runtime.messageBuffer[chanIndex] = nil diff --git a/backend/api/controllers/v1/backends/create.go b/backend/api/controllers/v1/backends/create.go index 0d8c614..287f314 100644 --- a/backend/api/controllers/v1/backends/create.go +++ b/backend/api/controllers/v1/backends/create.go @@ -126,7 +126,6 @@ func CreateBackend(c *gin.Context) { } backendParamCheckResponse, err := backend.ProcessCommand(&commonbackend.CheckServerParameters{ - Type: "checkServerParameters", Arguments: backendParameters, }) @@ -216,7 +215,6 @@ func CreateBackend(c *gin.Context) { } backendStartResponse, err := backend.ProcessCommand(&commonbackend.Start{ - Type: "start", Arguments: backendParameters, }) diff --git a/backend/api/controllers/v1/proxies/connections.go b/backend/api/controllers/v1/proxies/connections.go index f46c284..5eb3db9 100644 --- a/backend/api/controllers/v1/proxies/connections.go +++ b/backend/api/controllers/v1/proxies/connections.go @@ -118,9 +118,7 @@ func GetConnections(c *gin.Context) { return } - backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", - }) + backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{}) if err != nil { log.Warnf("Failed to get response for backend: %s", err.Error()) diff --git a/backend/api/controllers/v1/proxies/create.go b/backend/api/controllers/v1/proxies/create.go index 20ad144..28e8c95 100644 --- a/backend/api/controllers/v1/proxies/create.go +++ b/backend/api/controllers/v1/proxies/create.go @@ -142,7 +142,6 @@ func CreateProxy(c *gin.Context) { } backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, diff --git a/backend/api/controllers/v1/proxies/remove.go b/backend/api/controllers/v1/proxies/remove.go index 8087e29..e41fa84 100644 --- a/backend/api/controllers/v1/proxies/remove.go +++ b/backend/api/controllers/v1/proxies/remove.go @@ -111,7 +111,6 @@ func RemoveProxy(c *gin.Context) { } backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - Type: "removeProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, diff --git a/backend/api/controllers/v1/proxies/start.go b/backend/api/controllers/v1/proxies/start.go index d0cd5e0..1573382 100644 --- a/backend/api/controllers/v1/proxies/start.go +++ b/backend/api/controllers/v1/proxies/start.go @@ -101,7 +101,6 @@ func StartProxy(c *gin.Context) { } backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, diff --git a/backend/api/controllers/v1/proxies/stop.go b/backend/api/controllers/v1/proxies/stop.go index 1f5f525..820cec4 100644 --- a/backend/api/controllers/v1/proxies/stop.go +++ b/backend/api/controllers/v1/proxies/stop.go @@ -101,7 +101,6 @@ func StopProxy(c *gin.Context) { } backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - Type: "removeProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, diff --git a/backend/api/main.go b/backend/api/main.go index 2969e9b..c88f1ae 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -108,8 +108,7 @@ func apiEntrypoint(cCtx *cli.Context) error { return } - marshalledStartCommand, err := commonbackend.Marshal("start", &commonbackend.Start{ - Type: "start", + marshalledStartCommand, err := commonbackend.Marshal(&commonbackend.Start{ Arguments: backendParameters, }) @@ -123,7 +122,7 @@ func apiEntrypoint(cCtx *cli.Context) error { return } - _, backendResponse, err := commonbackend.Unmarshal(conn) + backendResponse, err := commonbackend.Unmarshal(conn) if err != nil { log.Errorf("Failed to get start command response for backend #%d: %s", backend.ID, err.Error()) @@ -152,8 +151,7 @@ func apiEntrypoint(cCtx *cli.Context) error { for _, proxy := range autoStartProxies { log.Infof("Starting up route #%d for backend #%d: %s", proxy.ID, backend.ID, proxy.Name) - marhalledCommand, err := commonbackend.Marshal("addProxy", &commonbackend.AddProxy{ - Type: "addProxy", + marhalledCommand, err := commonbackend.Marshal(&commonbackend.AddProxy{ SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, @@ -170,7 +168,7 @@ func apiEntrypoint(cCtx *cli.Context) error { continue } - _, backendResponse, err := commonbackend.Unmarshal(conn) + backendResponse, err := commonbackend.Unmarshal(conn) if err != nil { log.Errorf("Failed to get response for backend #%d and route #%d: %s", proxy.BackendID, proxy.ID, err.Error()) @@ -204,7 +202,6 @@ func apiEntrypoint(cCtx *cli.Context) error { } backendStartResponse, err := backendInstance.ProcessCommand(&commonbackend.Start{ - Type: "start", Arguments: backendParameters, }) @@ -257,7 +254,6 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Infof("Starting up route #%d for backend #%d: %s", proxy.ID, backend.ID, proxy.Name) backendResponse, err := backendInstance.ProcessCommand(&commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestinationPort, diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index f9fead0..7f134a2 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -34,7 +34,7 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") for { - _, commandRaw, err := commonbackend.Unmarshal(helper.socket) + commandRaw, err := commonbackend.Unmarshal(helper.socket) if err != nil { return err @@ -57,13 +57,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -87,13 +86,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -117,13 +115,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: !ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -144,7 +141,6 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: command.SourceIP, SourcePort: command.SourcePort, DestPort: command.DestPort, @@ -152,7 +148,7 @@ func (helper *BackendApplicationHelper) Start() error { IsActive: !hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -173,7 +169,6 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: command.SourceIP, SourcePort: command.SourcePort, DestPort: command.DestPort, @@ -181,7 +176,7 @@ func (helper *BackendApplicationHelper) Start() error { IsActive: hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -193,11 +188,10 @@ func (helper *BackendApplicationHelper) Start() error { connections := helper.Backend.GetAllClientConnections() serverParams := &commonbackend.ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: connections, } - byteData, err := commonbackend.Marshal(serverParams.Type, serverParams) + byteData, err := commonbackend.Marshal(serverParams) if err != nil { return err @@ -208,10 +202,9 @@ func (helper *BackendApplicationHelper) Start() error { } case *commonbackend.CheckClientParameters: resp := helper.Backend.CheckParametersForConnections(command) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkClientParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err @@ -222,10 +215,9 @@ func (helper *BackendApplicationHelper) Start() error { } case *commonbackend.CheckServerParameters: resp := helper.Backend.CheckParametersForBackend(command.Arguments) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkServerParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err diff --git a/backend/commonbackend/constants.go b/backend/commonbackend/constants.go index cdb68f2..6d5362b 100644 --- a/backend/commonbackend/constants.go +++ b/backend/commonbackend/constants.go @@ -1,16 +1,13 @@ package commonbackend type Start struct { - Type string // Will be 'start' always Arguments []byte } type Stop struct { - Type string // Will be 'stop' always } type AddProxy struct { - Type string // Will be 'addProxy' always SourceIP string SourcePort uint16 DestPort uint16 @@ -18,7 +15,6 @@ type AddProxy struct { } type RemoveProxy struct { - Type string // Will be 'removeProxy' always SourceIP string SourcePort uint16 DestPort uint16 @@ -26,7 +22,6 @@ type RemoveProxy struct { } type ProxyStatusRequest struct { - Type string // Will be 'proxyStatusRequest' always SourceIP string SourcePort uint16 DestPort uint16 @@ -34,7 +29,6 @@ type ProxyStatusRequest struct { } type ProxyStatusResponse struct { - Type string // Will be 'proxyStatusResponse' always SourceIP string SourcePort uint16 DestPort uint16 @@ -50,27 +44,22 @@ type ProxyInstance struct { } type ProxyInstanceResponse struct { - Type string // Will be 'proxyConnectionResponse' always Proxies []*ProxyInstance // List of connections } type ProxyInstanceRequest struct { - Type string // Will be 'proxyConnectionRequest' always } type BackendStatusResponse struct { - Type string // Will be 'backendStatusResponse' always IsRunning bool // True if running, false if not running StatusCode int // Either the 'Success' or 'Failure' constant Message string // String message from the client (ex. failed to dial TCP) } type BackendStatusRequest struct { - Type string // Will be 'backendStatusRequest' always } type ProxyConnectionsRequest struct { - Type string // Will be 'proxyConnectionsRequest' always } // Client's connection to a specific proxy @@ -83,12 +72,10 @@ type ProxyClientConnection struct { } type ProxyConnectionsResponse struct { - Type string // Will be 'proxyConnectionsResponse' always Connections []*ProxyClientConnection // List of connections } type CheckClientParameters struct { - Type string // Will be 'checkClientParameters' always SourceIP string SourcePort uint16 DestPort uint16 @@ -96,13 +83,11 @@ type CheckClientParameters struct { } type CheckServerParameters struct { - Type string // Will be 'checkServerParameters' always Arguments []byte } // Sent as a response to either CheckClientParameters or CheckBackendParameters type CheckParametersResponse struct { - Type string // Will be 'checkParametersResponse' always InResponseTo string // Will be either 'checkClientParameters' or 'checkServerParameters' IsValid bool // If true, valid, and if false, invalid Message string // String message from the client (ex. failed to unmarshal JSON: x is not defined) diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go index 4496494..7203ee3 100644 --- a/backend/commonbackend/marshal.go +++ b/backend/commonbackend/marshal.go @@ -84,7 +84,7 @@ func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { return proxyBlock, nil } -func Marshal(_ string, command interface{}) ([]byte, error) { +func Marshal(command interface{}) ([]byte, error) { switch command := command.(type) { case *Start: startCommandBytes := make([]byte, 1+2+len(command.Arguments)) diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go index e834041..c2b6375 100644 --- a/backend/commonbackend/marshalling_test.go +++ b/backend/commonbackend/marshalling_test.go @@ -11,11 +11,10 @@ var logLevel = os.Getenv("HERMES_LOG_LEVEL") func TestStart(t *testing.T) { commandInput := &Start{ - Type: "start", Arguments: []byte("Hello from automated testing"), } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -26,39 +25,27 @@ func TestStart(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Start) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) } } func TestStop(t *testing.T) { - commandInput := &Stop{ - Type: "stop", - } + commandInput := &Stop{} - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -69,39 +56,28 @@ func TestStop(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*Stop) + _, ok := commandUnmarshalledRaw.(*Stop) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } func TestAddConnection(t *testing.T) { commandInput := &AddProxy{ - Type: "addProxy", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -112,28 +88,18 @@ func TestAddConnection(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*AddProxy) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -157,14 +123,13 @@ func TestAddConnection(t *testing.T) { func TestRemoveConnection(t *testing.T) { commandInput := &RemoveProxy{ - Type: "removeProxy", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -175,28 +140,18 @@ func TestRemoveConnection(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -220,7 +175,6 @@ func TestRemoveConnection(t *testing.T) { func TestGetAllConnections(t *testing.T) { commandInput := &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: []*ProxyClientConnection{ { SourceIP: "127.0.0.1", @@ -246,7 +200,7 @@ func TestGetAllConnections(t *testing.T) { }, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -257,28 +211,18 @@ func TestGetAllConnections(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for commandIndex, originalConnection := range commandInput.Connections { remoteConnection := commandUnmarshalled.Connections[commandIndex] @@ -311,14 +255,13 @@ func TestGetAllConnections(t *testing.T) { func TestCheckClientParameters(t *testing.T) { commandInput := &CheckClientParameters{ - Type: "checkClientParameters", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -329,28 +272,18 @@ func TestCheckClientParameters(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckClientParameters) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -374,11 +307,10 @@ func TestCheckClientParameters(t *testing.T) { func TestCheckServerParameters(t *testing.T) { commandInput := &CheckServerParameters{ - Type: "checkServerParameters", Arguments: []byte("Hello from automated testing"), } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -389,28 +321,18 @@ func TestCheckServerParameters(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckServerParameters) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if !bytes.Equal(commandInput.Arguments, commandUnmarshalled.Arguments) { log.Fatalf("Arguments are not equal (orig: '%s', unmsh: '%s')", string(commandInput.Arguments), string(commandUnmarshalled.Arguments)) } @@ -418,13 +340,12 @@ func TestCheckServerParameters(t *testing.T) { func TestCheckParametersResponse(t *testing.T) { commandInput := &CheckParametersResponse{ - Type: "checkParametersResponse", InResponseTo: "checkClientParameters", IsValid: true, Message: "Hello from automated testing", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -435,28 +356,18 @@ func TestCheckParametersResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Printf("command type does not match up! (orig: %s, unmsh: %s)", commandType, commandInput.Type) - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*CheckParametersResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.InResponseTo != commandUnmarshalled.InResponseTo { t.Fail() log.Printf("InResponseTo's are not equal (orig: %s, unmsh: %s)", commandInput.InResponseTo, commandUnmarshalled.InResponseTo) @@ -474,11 +385,8 @@ func TestCheckParametersResponse(t *testing.T) { } func TestBackendStatusRequest(t *testing.T) { - commandInput := &BackendStatusRequest{ - Type: "backendStatusRequest", - } - - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandInput := &BackendStatusRequest{} + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -489,38 +397,27 @@ func TestBackendStatusRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusRequest) + _, ok := commandUnmarshalledRaw.(*BackendStatusRequest) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } func TestBackendStatusResponse(t *testing.T) { commandInput := &BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: true, StatusCode: StatusFailure, Message: "Hello from automated testing", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -531,28 +428,18 @@ func TestBackendStatusResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*BackendStatusResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.IsRunning != commandUnmarshalled.IsRunning { t.Fail() log.Printf("IsRunning's are not equal (orig: %t, unmsh: %t)", commandInput.IsRunning, commandUnmarshalled.IsRunning) @@ -571,14 +458,13 @@ func TestBackendStatusResponse(t *testing.T) { func TestProxyStatusRequest(t *testing.T) { commandInput := &ProxyStatusRequest{ - Type: "proxyStatusRequest", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -589,28 +475,18 @@ func TestProxyStatusRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -634,7 +510,6 @@ func TestProxyStatusRequest(t *testing.T) { func TestProxyStatusResponse(t *testing.T) { commandInput := &ProxyStatusResponse{ - Type: "proxyStatusResponse", SourceIP: "192.168.0.139", SourcePort: 19132, DestPort: 19132, @@ -642,7 +517,7 @@ func TestProxyStatusResponse(t *testing.T) { IsActive: true, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -653,28 +528,18 @@ func TestProxyStatusResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.SourceIP != commandUnmarshalled.SourceIP { t.Fail() log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) @@ -702,11 +567,9 @@ func TestProxyStatusResponse(t *testing.T) { } func TestProxyConnectionRequest(t *testing.T) { - commandInput := &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - } + commandInput := &ProxyInstanceRequest{} - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -717,32 +580,21 @@ func TestProxyConnectionRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) + _, ok := commandUnmarshalledRaw.(*ProxyInstanceRequest) if !ok { t.Fatal("failed typecast") } - - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } } func TestProxyConnectionResponse(t *testing.T) { commandInput := &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", Proxies: []*ProxyInstance{ { SourceIP: "192.168.0.168", @@ -765,7 +617,7 @@ func TestProxyConnectionResponse(t *testing.T) { }, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -776,28 +628,18 @@ func TestProxyConnectionResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for proxyIndex, originalProxy := range commandInput.Proxies { remoteProxy := commandUnmarshalled.Proxies[proxyIndex] diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go index b8500dd..8e338c6 100644 --- a/backend/commonbackend/unmarshal.go +++ b/backend/commonbackend/unmarshal.go @@ -142,11 +142,11 @@ func unmarshalIndividualProxyStruct(conn io.Reader) (*ProxyInstance, error) { }, nil } -func Unmarshal(conn io.Reader) (string, interface{}, error) { +func Unmarshal(conn io.Reader) (interface{}, error) { commandType := make([]byte, 1) if _, err := conn.Read(commandType); err != nil { - return "", nil, fmt.Errorf("couldn't read command") + return nil, fmt.Errorf("couldn't read command") } switch commandType[0] { @@ -154,28 +154,25 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { argumentsLength := make([]byte, 2) if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") + return nil, fmt.Errorf("couldn't read argument length") } arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") + return nil, fmt.Errorf("couldn't read arguments") } - return "start", &Start{ - Type: "start", + return &Start{ Arguments: arguments, }, nil case StopID: - return "stop", &Stop{ - Type: "stop", - }, nil + return &Stop{}, nil case AddProxyID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -185,31 +182,31 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string @@ -219,11 +216,10 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if protocolBytes[1] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "addProxy", &AddProxy{ - Type: "addProxy", + return &AddProxy{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -233,7 +229,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -243,31 +239,31 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string @@ -277,11 +273,10 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if protocolBytes[1] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "removeProxy", &RemoveProxy{ - Type: "removeProxy", + return &RemoveProxy{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -301,13 +296,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { break } - return "", nil, err + return nil, err } connections = append(connections, connection) if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") + return nil, fmt.Errorf("couldn't read delimiter") } if delimiter[0] == '\r' { @@ -321,15 +316,14 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } } - return "proxyConnectionsResponse", &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", + return &ProxyConnectionsResponse{ Connections: connections, }, errorReturn case CheckClientParametersID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -339,31 +333,31 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string @@ -373,11 +367,10 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if protocolBytes[1] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "checkClientParameters", &CheckClientParameters{ - Type: "checkClientParameters", + return &CheckClientParameters{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -387,24 +380,23 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { argumentsLength := make([]byte, 2) if _, err := conn.Read(argumentsLength); err != nil { - return "", nil, fmt.Errorf("couldn't read argument length") + return nil, fmt.Errorf("couldn't read argument length") } arguments := make([]byte, binary.BigEndian.Uint16(argumentsLength)) if _, err := conn.Read(arguments); err != nil { - return "", nil, fmt.Errorf("couldn't read arguments") + return nil, fmt.Errorf("couldn't read arguments") } - return "checkServerParameters", &CheckServerParameters{ - Type: "checkServerParameters", + return &CheckServerParameters{ Arguments: arguments, }, nil case CheckParametersResponseID: checkMethodByte := make([]byte, 1) if _, err := conn.Read(checkMethodByte); err != nil { - return "", nil, fmt.Errorf("couldn't read check method byte") + return nil, fmt.Errorf("couldn't read check method byte") } var checkMethod string @@ -414,19 +406,19 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if checkMethodByte[0] == CheckServerParametersID { checkMethod = "checkServerParameters" } else { - return "", nil, fmt.Errorf("invalid check method recieved") + return nil, fmt.Errorf("invalid check method recieved") } isValid := make([]byte, 1) if _, err := conn.Read(isValid); err != nil { - return "", nil, fmt.Errorf("couldn't read isValid byte") + return nil, fmt.Errorf("couldn't read isValid byte") } messageLengthBytes := make([]byte, 2) if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") + return nil, fmt.Errorf("couldn't read message length") } messageLength := binary.BigEndian.Uint16(messageLengthBytes) @@ -436,14 +428,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { messageBytes := make([]byte, messageLength) if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") + return nil, fmt.Errorf("couldn't read message") } message = string(messageBytes) } - return "checkParametersResponse", &CheckParametersResponse{ - Type: "checkParametersResponse", + return &CheckParametersResponse{ InResponseTo: checkMethod, IsValid: isValid[0] == 1, Message: message, @@ -452,19 +443,19 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { isRunning := make([]byte, 1) if _, err := conn.Read(isRunning); err != nil { - return "", nil, fmt.Errorf("couldn't read isRunning field") + return nil, fmt.Errorf("couldn't read isRunning field") } statusCode := make([]byte, 1) if _, err := conn.Read(statusCode); err != nil { - return "", nil, fmt.Errorf("couldn't read status code field") + return nil, fmt.Errorf("couldn't read status code field") } messageLengthBytes := make([]byte, 2) if _, err := conn.Read(messageLengthBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message length") + return nil, fmt.Errorf("couldn't read message length") } messageLength := binary.BigEndian.Uint16(messageLengthBytes) @@ -474,27 +465,24 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { messageBytes := make([]byte, messageLength) if _, err := conn.Read(messageBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read message") + return nil, fmt.Errorf("couldn't read message") } message = string(messageBytes) } - return "backendStatusResponse", &BackendStatusResponse{ - Type: "backendStatusResponse", + return &BackendStatusResponse{ IsRunning: isRunning[0] == 1, StatusCode: int(statusCode[0]), Message: message, }, nil case BackendStatusRequestID: - return "backendStatusRequest", &BackendStatusRequest{ - Type: "backendStatusRequest", - }, nil + return &BackendStatusRequest{}, nil case ProxyStatusRequestID: ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -504,31 +492,31 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string @@ -538,11 +526,10 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if protocolBytes[1] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } - return "proxyStatusRequest", &ProxyStatusRequest{ - Type: "proxyStatusRequest", + return &ProxyStatusRequest{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -552,7 +539,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVersion := make([]byte, 1) if _, err := conn.Read(ipVersion); err != nil { - return "", nil, fmt.Errorf("couldn't read ip version") + return nil, fmt.Errorf("couldn't read ip version") } var ipSize uint8 @@ -562,31 +549,31 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVersion[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version recieved") + return nil, fmt.Errorf("invalid IP version recieved") } ip := make(net.IP, ipSize) if _, err := conn.Read(ip); err != nil { - return "", nil, fmt.Errorf("couldn't read source IP") + return nil, fmt.Errorf("couldn't read source IP") } sourcePort := make([]byte, 2) if _, err := conn.Read(sourcePort); err != nil { - return "", nil, fmt.Errorf("couldn't read source port") + return nil, fmt.Errorf("couldn't read source port") } destPort := make([]byte, 2) if _, err := conn.Read(destPort); err != nil { - return "", nil, fmt.Errorf("couldn't read destination port") + return nil, fmt.Errorf("couldn't read destination port") } protocolBytes := make([]byte, 1) if _, err := conn.Read(protocolBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read protocol") + return nil, fmt.Errorf("couldn't read protocol") } var protocol string @@ -596,17 +583,16 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if protocolBytes[1] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol") + return nil, fmt.Errorf("invalid protocol") } isActive := make([]byte, 1) if _, err := conn.Read(isActive); err != nil { - return "", nil, fmt.Errorf("couldn't read isActive field") + return nil, fmt.Errorf("couldn't read isActive field") } - return "proxyStatusResponse", &ProxyStatusResponse{ - Type: "proxyStatusResponse", + return &ProxyStatusResponse{ SourceIP: ip.String(), SourcePort: binary.BigEndian.Uint16(sourcePort), DestPort: binary.BigEndian.Uint16(destPort), @@ -614,9 +600,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { IsActive: isActive[0] == 1, }, nil case ProxyInstanceRequestID: - return "proxyInstanceRequest", &ProxyInstanceRequest{ - Type: "proxyInstanceRequest", - }, nil + return &ProxyInstanceRequest{}, nil case ProxyInstanceResponseID: proxies := []*ProxyInstance{} delimiter := make([]byte, 1) @@ -631,13 +615,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { break } - return "", nil, err + return nil, err } proxies = append(proxies, proxy) if _, err := conn.Read(delimiter); err != nil { - return "", nil, fmt.Errorf("couldn't read delimiter") + return nil, fmt.Errorf("couldn't read delimiter") } if delimiter[0] == '\r' { @@ -651,15 +635,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } } - return "proxyInstanceResponse", &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", + return &ProxyInstanceResponse{ Proxies: proxies, }, errorReturn case ProxyConnectionsRequestID: - return "proxyConnectionsRequest", &ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", - }, nil + return &ProxyConnectionsRequest{}, nil } - return "", nil, fmt.Errorf("couldn't match command ID") + return nil, fmt.Errorf("couldn't match command ID") } diff --git a/backend/externalbackendlauncher/main.go b/backend/externalbackendlauncher/main.go index d33d6cd..c196866 100644 --- a/backend/externalbackendlauncher/main.go +++ b/backend/externalbackendlauncher/main.go @@ -112,11 +112,10 @@ func entrypoint(cCtx *cli.Context) error { defer sock.Close() startCommand := &commonbackend.Start{ - Type: "start", Arguments: backendParameters, } - startMarshalledCommand, err := commonbackend.Marshal("start", startCommand) + startMarshalledCommand, err := commonbackend.Marshal(startCommand) if err != nil { log.Errorf("failed to generate start command: %s", err.Error()) @@ -128,18 +127,13 @@ func entrypoint(cCtx *cli.Context) error { continue } - commandType, commandRaw, err := commonbackend.Unmarshal(sock) + commandRaw, err := commonbackend.Unmarshal(sock) if err != nil { log.Errorf("failed to read from/unmarshal from socket: %s", err.Error()) continue } - if commandType != "backendStatusResponse" { - log.Errorf("recieved commandType '%s', expecting 'backendStatusResponse'", commandType) - continue - } - command, ok := commandRaw.(*commonbackend.BackendStatusResponse) if !ok { @@ -168,14 +162,13 @@ func entrypoint(cCtx *cli.Context) error { log.Infof("initializing proxy %s:%d -> remote:%d", proxy.SourceIP, proxy.SourcePort, proxy.DestPort) proxyAddCommand := &commonbackend.AddProxy{ - Type: "addProxy", SourceIP: proxy.SourceIP, SourcePort: proxy.SourcePort, DestPort: proxy.DestPort, Protocol: proxy.Protocol, } - marshalledProxyCommand, err := commonbackend.Marshal("addProxy", proxyAddCommand) + marshalledProxyCommand, err := commonbackend.Marshal(proxyAddCommand) if err != nil { log.Errorf("failed to generate start command: %s", err.Error()) @@ -189,7 +182,7 @@ func entrypoint(cCtx *cli.Context) error { continue } - commandType, commandRaw, err := commonbackend.Unmarshal(sock) + commandRaw, err := commonbackend.Unmarshal(sock) if err != nil { log.Errorf("failed to read from/unmarshal from socket: %s", err.Error()) @@ -197,12 +190,6 @@ func entrypoint(cCtx *cli.Context) error { continue } - if commandType != "proxyStatusResponse" { - log.Errorf("recieved commandType '%s', expecting 'proxyStatusResponse'", commandType) - hasAnyFailed = true - continue - } - command, ok := commandRaw.(*commonbackend.ProxyStatusResponse) if !ok { diff --git a/backend/sshappbackend/datacommands/constants.go b/backend/sshappbackend/datacommands/constants.go index 7b7620b..9630d43 100644 --- a/backend/sshappbackend/datacommands/constants.go +++ b/backend/sshappbackend/datacommands/constants.go @@ -1,57 +1,47 @@ package datacommands type ProxyStatusRequest struct { - Type string ProxyID uint16 } type ProxyStatusResponse struct { - Type string ProxyID uint16 IsActive bool } type RemoveProxy struct { - Type string ProxyID uint16 } type ProxyInstanceResponse struct { - Type string Proxies []uint16 } type ProxyConnectionsRequest struct { - Type string ProxyID uint16 } type ProxyConnectionsResponse struct { - Type string Connections []uint16 } type TCPConnectionOpened struct { - Type string ProxyID uint16 ConnectionID uint16 } type TCPConnectionClosed struct { - Type string ProxyID uint16 ConnectionID uint16 } type TCPProxyData struct { - Type string ProxyID uint16 ConnectionID uint16 DataLength uint16 } type UDPProxyData struct { - Type string ProxyID uint16 ClientIP string ClientPort uint16 @@ -59,12 +49,10 @@ type UDPProxyData struct { } type ProxyInformationRequest struct { - Type string ProxyID uint16 } type ProxyInformationResponse struct { - Type string Exists bool SourceIP string SourcePort uint16 @@ -73,13 +61,11 @@ type ProxyInformationResponse struct { } type ProxyConnectionInformationRequest struct { - Type string ProxyID uint16 ConnectionID uint16 } type ProxyConnectionInformationResponse struct { - Type string Exists bool ClientIP string ClientPort uint16 diff --git a/backend/sshappbackend/datacommands/marshal.go b/backend/sshappbackend/datacommands/marshal.go index bacbda1..a2c13bf 100644 --- a/backend/sshappbackend/datacommands/marshal.go +++ b/backend/sshappbackend/datacommands/marshal.go @@ -16,7 +16,7 @@ const ( ) // Marshal takes a command (pointer to one of our structs) and converts it to a byte slice. -func Marshal(_ string, command interface{}) ([]byte, error) { +func Marshal(command interface{}) ([]byte, error) { switch cmd := command.(type) { // ProxyStatusRequest: 1 byte for the command ID + 2 bytes for the ProxyID. case *ProxyStatusRequest: diff --git a/backend/sshappbackend/datacommands/marshalling_test.go b/backend/sshappbackend/datacommands/marshalling_test.go index 81e7101..5b2e5ab 100644 --- a/backend/sshappbackend/datacommands/marshalling_test.go +++ b/backend/sshappbackend/datacommands/marshalling_test.go @@ -11,11 +11,10 @@ var logLevel = os.Getenv("HERMES_LOG_LEVEL") func TestProxyStatusRequest(t *testing.T) { commandInput := &ProxyStatusRequest{ - Type: "proxyStatusRequest", ProxyID: 19132, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -26,28 +25,18 @@ func TestProxyStatusRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -56,12 +45,11 @@ func TestProxyStatusRequest(t *testing.T) { func TestProxyStatusResponse(t *testing.T) { commandInput := &ProxyStatusResponse{ - Type: "proxyStatusResponse", ProxyID: 19132, IsActive: true, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -72,28 +60,18 @@ func TestProxyStatusResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -107,11 +85,10 @@ func TestProxyStatusResponse(t *testing.T) { func TestRemoveProxy(t *testing.T) { commandInput := &RemoveProxy{ - Type: "removeProxy", ProxyID: 19132, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -122,28 +99,18 @@ func TestRemoveProxy(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -152,11 +119,10 @@ func TestRemoveProxy(t *testing.T) { func TestProxyConnectionsRequest(t *testing.T) { commandInput := &ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", ProxyID: 19132, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -167,28 +133,18 @@ func TestProxyConnectionsRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -197,11 +153,10 @@ func TestProxyConnectionsRequest(t *testing.T) { func TestProxyConnectionsResponse(t *testing.T) { commandInput := &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: []uint16{12831, 9455, 64219, 12, 32}, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -212,28 +167,18 @@ func TestProxyConnectionsResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for connectionIndex, originalConnection := range commandInput.Connections { remoteConnection := commandUnmarshalled.Connections[connectionIndex] @@ -246,11 +191,10 @@ func TestProxyConnectionsResponse(t *testing.T) { func TestProxyInstanceResponse(t *testing.T) { commandInput := &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", Proxies: []uint16{12831, 9455, 64219, 12, 32}, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -261,28 +205,18 @@ func TestProxyInstanceResponse(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - for proxyIndex, originalProxy := range commandInput.Proxies { remoteProxy := commandUnmarshalled.Proxies[proxyIndex] @@ -295,12 +229,11 @@ func TestProxyInstanceResponse(t *testing.T) { func TestTCPConnectionOpened(t *testing.T) { commandInput := &TCPConnectionOpened{ - Type: "tcpConnectionOpened", ProxyID: 19132, ConnectionID: 25565, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -311,28 +244,18 @@ func TestTCPConnectionOpened(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionOpened) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -346,12 +269,11 @@ func TestTCPConnectionOpened(t *testing.T) { func TestTCPConnectionClosed(t *testing.T) { commandInput := &TCPConnectionClosed{ - Type: "tcpConnectionClosed", ProxyID: 19132, ConnectionID: 25565, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -362,28 +284,18 @@ func TestTCPConnectionClosed(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionClosed) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -397,13 +309,12 @@ func TestTCPConnectionClosed(t *testing.T) { func TestTCPProxyData(t *testing.T) { commandInput := &TCPProxyData{ - Type: "tcpProxyData", ProxyID: 19132, ConnectionID: 25565, DataLength: 1234, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -414,28 +325,18 @@ func TestTCPProxyData(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPProxyData) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -454,14 +355,13 @@ func TestTCPProxyData(t *testing.T) { func TestUDPProxyData(t *testing.T) { commandInput := &UDPProxyData{ - Type: "udpProxyData", ProxyID: 19132, ClientIP: "68.51.23.54", ClientPort: 28173, DataLength: 1234, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -472,28 +372,18 @@ func TestUDPProxyData(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*UDPProxyData) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -517,11 +407,10 @@ func TestUDPProxyData(t *testing.T) { func TestProxyInformationRequest(t *testing.T) { commandInput := &ProxyInformationRequest{ - Type: "proxyInformationRequest", ProxyID: 19132, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -532,28 +421,18 @@ func TestProxyInformationRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -562,7 +441,6 @@ func TestProxyInformationRequest(t *testing.T) { func TestProxyInformationResponseExists(t *testing.T) { commandInput := &ProxyInformationResponse{ - Type: "proxyInformationResponse", Exists: true, SourceIP: "192.168.0.139", SourcePort: 19132, @@ -570,7 +448,7 @@ func TestProxyInformationResponseExists(t *testing.T) { Protocol: "tcp", } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -581,28 +459,18 @@ func TestProxyInformationResponseExists(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.Exists != commandUnmarshalled.Exists { t.Fail() log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) @@ -631,11 +499,10 @@ func TestProxyInformationResponseExists(t *testing.T) { func TestProxyInformationResponseNoExist(t *testing.T) { commandInput := &ProxyInformationResponse{ - Type: "proxyInformationResponse", Exists: false, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -646,28 +513,18 @@ func TestProxyInformationResponseNoExist(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.Exists != commandUnmarshalled.Exists { t.Fail() log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) @@ -676,12 +533,11 @@ func TestProxyInformationResponseNoExist(t *testing.T) { func TestProxyConnectionInformationRequest(t *testing.T) { commandInput := &ProxyConnectionInformationRequest{ - Type: "proxyConnectionInformationRequest", ProxyID: 19132, ConnectionID: 25565, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if logLevel == "debug" { log.Printf("Generated array contents: %v", commandMarshalled) @@ -692,28 +548,18 @@ func TestProxyConnectionInformationRequest(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationRequest) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.ProxyID != commandUnmarshalled.ProxyID { t.Fail() log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) @@ -727,13 +573,12 @@ func TestProxyConnectionInformationRequest(t *testing.T) { func TestProxyConnectionInformationResponseExists(t *testing.T) { commandInput := &ProxyConnectionInformationResponse{ - Type: "proxyConnectionInformationResponse", Exists: true, ClientIP: "192.168.0.139", ClientPort: 19132, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -744,28 +589,18 @@ func TestProxyConnectionInformationResponseExists(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.Exists != commandUnmarshalled.Exists { t.Fail() log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) @@ -784,11 +619,10 @@ func TestProxyConnectionInformationResponseExists(t *testing.T) { func TestProxyConnectionInformationResponseNoExists(t *testing.T) { commandInput := &ProxyConnectionInformationResponse{ - Type: "proxyConnectionInformationResponse", Exists: false, } - commandMarshalled, err := Marshal(commandInput.Type, commandInput) + commandMarshalled, err := Marshal(commandInput) if err != nil { t.Fatal(err.Error()) @@ -799,28 +633,18 @@ func TestProxyConnectionInformationResponseNoExists(t *testing.T) { } buf := bytes.NewBuffer(commandMarshalled) - commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + commandUnmarshalledRaw, err := Unmarshal(buf) if err != nil { t.Fatal(err.Error()) } - if commandType != commandInput.Type { - t.Fail() - log.Print("command type does not match up!") - } - commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) if !ok { t.Fatal("failed typecast") } - if commandInput.Type != commandUnmarshalled.Type { - t.Fail() - log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) - } - if commandInput.Exists != commandUnmarshalled.Exists { t.Fail() log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) diff --git a/backend/sshappbackend/datacommands/unmarshal.go b/backend/sshappbackend/datacommands/unmarshal.go index 7e97c3f..d9d0523 100644 --- a/backend/sshappbackend/datacommands/unmarshal.go +++ b/backend/sshappbackend/datacommands/unmarshal.go @@ -9,11 +9,12 @@ import ( // Unmarshal reads from the provided connection and returns // the message type (as a string), the unmarshalled struct, or an error. -func Unmarshal(conn io.Reader) (string, interface{}, error) { +func Unmarshal(conn io.Reader) (interface{}, error) { // Every command starts with a 1-byte command ID. header := make([]byte, 1) + if _, err := io.ReadFull(conn, header); err != nil { - return "", nil, fmt.Errorf("couldn't read command ID: %w", err) + return nil, fmt.Errorf("couldn't read command ID: %w", err) } cmdID := header[0] @@ -23,13 +24,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyStatusRequest ProxyID: %w", err) + return nil, fmt.Errorf("couldn't read ProxyStatusRequest ProxyID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) - return "proxyStatusRequest", &ProxyStatusRequest{ - Type: "proxyStatusRequest", + return &ProxyStatusRequest{ ProxyID: proxyID, }, nil @@ -38,20 +38,19 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse ProxyID: %w", err) + return nil, fmt.Errorf("couldn't read ProxyStatusResponse ProxyID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) boolBuf := make([]byte, 1) if _, err := io.ReadFull(conn, boolBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse IsActive: %w", err) + return nil, fmt.Errorf("couldn't read ProxyStatusResponse IsActive: %w", err) } isActive := boolBuf[0] != 0 - return "proxyStatusResponse", &ProxyStatusResponse{ - Type: "proxyStatusResponse", + return &ProxyStatusResponse{ ProxyID: proxyID, IsActive: isActive, }, nil @@ -61,13 +60,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read RemoveProxy ProxyID: %w", err) + return nil, fmt.Errorf("couldn't read RemoveProxy ProxyID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) - return "removeProxy", &RemoveProxy{ - Type: "removeProxy", + return &RemoveProxy{ ProxyID: proxyID, }, nil @@ -76,13 +74,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionsRequest ProxyID: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionsRequest ProxyID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) - return "proxyConnectionsRequest", &ProxyConnectionsRequest{ - Type: "proxyConnectionsRequest", + return &ProxyConnectionsRequest{ ProxyID: proxyID, }, nil @@ -91,7 +88,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) } length := binary.BigEndian.Uint16(buf) @@ -108,8 +105,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { connections[connectionIndex] = binary.BigEndian.Uint16(buf) } - return "proxyConnectionsResponse", &ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", + return &ProxyConnectionsResponse{ Connections: connections, }, failedDuringReading @@ -118,7 +114,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) } length := binary.BigEndian.Uint16(buf) @@ -135,8 +131,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { proxies[connectionIndex] = binary.BigEndian.Uint16(buf) } - return "proxyInstanceResponse", &ProxyInstanceResponse{ - Type: "proxyInstanceResponse", + return &ProxyInstanceResponse{ Proxies: proxies, }, failedDuringReading @@ -145,14 +140,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2+2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read TCPConnectionOpened fields: %w", err) + return nil, fmt.Errorf("couldn't read TCPConnectionOpened fields: %w", err) } proxyID := binary.BigEndian.Uint16(buf[0:2]) connectionID := binary.BigEndian.Uint16(buf[2:4]) - return "tcpConnectionOpened", &TCPConnectionOpened{ - Type: "tcpConnectionOpened", + return &TCPConnectionOpened{ ProxyID: proxyID, ConnectionID: connectionID, }, nil @@ -162,14 +156,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2+2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read TCPConnectionClosed fields: %w", err) + return nil, fmt.Errorf("couldn't read TCPConnectionClosed fields: %w", err) } proxyID := binary.BigEndian.Uint16(buf[0:2]) connectionID := binary.BigEndian.Uint16(buf[2:4]) - return "tcpConnectionClosed", &TCPConnectionClosed{ - Type: "tcpConnectionClosed", + return &TCPConnectionClosed{ ProxyID: proxyID, ConnectionID: connectionID, }, nil @@ -179,15 +172,14 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2+2+2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read TCPProxyData fields: %w", err) + return nil, fmt.Errorf("couldn't read TCPProxyData fields: %w", err) } proxyID := binary.BigEndian.Uint16(buf[0:2]) connectionID := binary.BigEndian.Uint16(buf[2:4]) dataLength := binary.BigEndian.Uint16(buf[4:6]) - return "tcpProxyData", &TCPProxyData{ - Type: "tcpProxyData", + return &TCPProxyData{ ProxyID: proxyID, ConnectionID: connectionID, DataLength: dataLength, @@ -201,7 +193,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read UDPProxyData ProxyID/ConnectionID: %w", err) + return nil, fmt.Errorf("couldn't read UDPProxyData ProxyID/ConnectionID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) @@ -210,7 +202,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVerBuf := make([]byte, 1) if _, err := io.ReadFull(conn, ipVerBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read UDPProxyData IP version: %w", err) + return nil, fmt.Errorf("couldn't read UDPProxyData IP version: %w", err) } var ipSize int @@ -220,13 +212,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVerBuf[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version received: %v", ipVerBuf[0]) + return nil, fmt.Errorf("invalid IP version received: %v", ipVerBuf[0]) } // Read the IP bytes. ipBytes := make([]byte, ipSize) if _, err := io.ReadFull(conn, ipBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read UDPProxyData IP bytes: %w", err) + return nil, fmt.Errorf("couldn't read UDPProxyData IP bytes: %w", err) } clientIP := net.IP(ipBytes).String() @@ -234,7 +226,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read UDPProxyData ClientPort: %w", err) + return nil, fmt.Errorf("couldn't read UDPProxyData ClientPort: %w", err) } clientPort := binary.BigEndian.Uint16(portBuf) @@ -243,13 +235,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { dataLengthBuf := make([]byte, 2) if _, err := io.ReadFull(conn, dataLengthBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read UDPProxyData DataLength: %w", err) + return nil, fmt.Errorf("couldn't read UDPProxyData DataLength: %w", err) } dataLength := binary.BigEndian.Uint16(dataLengthBuf) - return "udpProxyData", &UDPProxyData{ - Type: "udpProxyData", + return &UDPProxyData{ ProxyID: proxyID, ClientIP: clientIP, ClientPort: clientPort, @@ -261,13 +252,12 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationRequest ProxyID: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationRequest ProxyID: %w", err) } proxyID := binary.BigEndian.Uint16(buf) - return "proxyInformationRequest", &ProxyInformationRequest{ - Type: "proxyInformationRequest", + return &ProxyInformationRequest{ ProxyID: proxyID, }, nil @@ -279,14 +269,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { boolBuf := make([]byte, 1) if _, err := io.ReadFull(conn, boolBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse Exists flag: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationResponse Exists flag: %w", err) } exists := boolBuf[0] != 0 if !exists { - return "proxyInformationResponse", &ProxyInformationResponse{ - Type: "proxyInformationResponse", + return &ProxyInformationResponse{ Exists: exists, }, nil } @@ -295,7 +284,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVerBuf := make([]byte, 1) if _, err := io.ReadFull(conn, ipVerBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP version: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationResponse IP version: %w", err) } var ipSize int @@ -305,14 +294,14 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { } else if ipVerBuf[0] == 6 { ipSize = IPv6Size } else { - return "", nil, fmt.Errorf("invalid IP version in ProxyInformationResponse: %v", ipVerBuf[0]) + return nil, fmt.Errorf("invalid IP version in ProxyInformationResponse: %v", ipVerBuf[0]) } // Read the source IP bytes. ipBytes := make([]byte, ipSize) if _, err := io.ReadFull(conn, ipBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP bytes: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationResponse IP bytes: %w", err) } sourceIP := net.IP(ipBytes).String() @@ -321,7 +310,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { portsBuf := make([]byte, 2+2) if _, err := io.ReadFull(conn, portsBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse ports: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationResponse ports: %w", err) } sourcePort := binary.BigEndian.Uint16(portsBuf[0:2]) @@ -331,19 +320,20 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { protoBuf := make([]byte, 1) if _, err := io.ReadFull(conn, protoBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse protocol: %w", err) + return nil, fmt.Errorf("couldn't read ProxyInformationResponse protocol: %w", err) } + var protocol string + if protoBuf[0] == TCP { protocol = "tcp" } else if protoBuf[0] == UDP { protocol = "udp" } else { - return "", nil, fmt.Errorf("invalid protocol value in ProxyInformationResponse: %d", protoBuf[0]) + return nil, fmt.Errorf("invalid protocol value in ProxyInformationResponse: %d", protoBuf[0]) } - return "proxyInformationResponse", &ProxyInformationResponse{ - Type: "proxyInformationResponse", + return &ProxyInformationResponse{ Exists: exists, SourceIP: sourceIP, SourcePort: sourcePort, @@ -356,14 +346,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { buf := make([]byte, 2+2) if _, err := io.ReadFull(conn, buf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationRequest fields: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationRequest fields: %w", err) } proxyID := binary.BigEndian.Uint16(buf[0:2]) connectionID := binary.BigEndian.Uint16(buf[2:4]) - return "proxyConnectionInformationRequest", &ProxyConnectionInformationRequest{ - Type: "proxyConnectionInformationRequest", + return &ProxyConnectionInformationRequest{ ProxyID: proxyID, ConnectionID: connectionID, }, nil @@ -374,14 +363,13 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { // Read Exists flag. boolBuf := make([]byte, 1) if _, err := io.ReadFull(conn, boolBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse Exists flag: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse Exists flag: %w", err) } exists := boolBuf[0] != 0 if !exists { - return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ - Type: "proxyConnectionInformationResponse", + return &ProxyConnectionInformationResponse{ Exists: exists, }, nil } @@ -390,11 +378,11 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipVerBuf := make([]byte, 1) if _, err := io.ReadFull(conn, ipVerBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP version: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP version: %w", err) } if ipVerBuf[0] != 4 && ipVerBuf[0] != 6 { - return "", nil, fmt.Errorf("invalid IP version in ProxyConnectionInformationResponse: %v", ipVerBuf[0]) + return nil, fmt.Errorf("invalid IP version in ProxyConnectionInformationResponse: %v", ipVerBuf[0]) } var ipSize int @@ -409,7 +397,7 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { ipBytes := make([]byte, ipSize) if _, err := io.ReadFull(conn, ipBytes); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP bytes: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP bytes: %w", err) } clientIP := net.IP(ipBytes).String() @@ -418,18 +406,17 @@ func Unmarshal(conn io.Reader) (string, interface{}, error) { portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { - return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse ClientPort: %w", err) + return nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse ClientPort: %w", err) } clientPort := binary.BigEndian.Uint16(portBuf) - return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ - Type: "proxyConnectionInformationResponse", + return &ProxyConnectionInformationResponse{ Exists: exists, ClientIP: clientIP, ClientPort: clientPort, }, nil default: - return "", nil, fmt.Errorf("unknown command id: %v", cmdID) + return nil, fmt.Errorf("unknown command id: %v", cmdID) } } diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go index e058648..5ad2246 100644 --- a/backend/sshappbackend/remote-code/backendutil_custom/application.go +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -36,7 +36,7 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") for { - _, commandRaw, err := datacommands.Unmarshal(helper.socket) + commandRaw, err := datacommands.Unmarshal(helper.socket) if err != nil && err.Error() != "couldn't match command ID" { return err @@ -47,11 +47,10 @@ func (helper *BackendApplicationHelper) Start() error { connections := helper.Backend.GetAllClientConnections(command.ProxyID) serverParams := &datacommands.ProxyConnectionsResponse{ - Type: "proxyConnectionsResponse", Connections: connections, } - byteData, err := datacommands.Marshal(serverParams.Type, serverParams) + byteData, err := datacommands.Marshal(serverParams) if err != nil { return err @@ -73,12 +72,11 @@ func (helper *BackendApplicationHelper) Start() error { } response := &datacommands.ProxyStatusResponse{ - Type: "proxyStatusResponse", ProxyID: command.ProxyID, IsActive: hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := datacommands.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -87,7 +85,7 @@ func (helper *BackendApplicationHelper) Start() error { helper.socket.Write(responseMarshalled) default: - _, commandRaw, err := commonbackend.Unmarshal(helper.socket) + commandRaw, err := commonbackend.Unmarshal(helper.socket) if err != nil { return err @@ -110,13 +108,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -140,13 +137,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: !ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -170,13 +166,12 @@ func (helper *BackendApplicationHelper) Start() error { } response := &commonbackend.BackendStatusResponse{ - Type: "backendStatusResponse", IsRunning: ok, StatusCode: statusCode, Message: message, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := commonbackend.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -197,12 +192,11 @@ func (helper *BackendApplicationHelper) Start() error { } response := &datacommands.ProxyStatusResponse{ - Type: "proxyStatusResponse", ProxyID: id, IsActive: !hasAnyFailed, } - responseMarshalled, err := commonbackend.Marshal(response.Type, response) + responseMarshalled, err := datacommands.Marshal(response) if err != nil { log.Error("failed to marshal response: %s", err.Error()) @@ -212,10 +206,9 @@ func (helper *BackendApplicationHelper) Start() error { helper.socket.Write(responseMarshalled) case *commonbackend.CheckClientParameters: resp := helper.Backend.CheckParametersForConnections(command) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkClientParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err @@ -226,10 +219,9 @@ func (helper *BackendApplicationHelper) Start() error { } case *commonbackend.CheckServerParameters: resp := helper.Backend.CheckParametersForBackend(command.Arguments) - resp.Type = "checkParametersResponse" resp.InResponseTo = "checkServerParameters" - byteData, err := commonbackend.Marshal(resp.Type, resp) + byteData, err := commonbackend.Marshal(resp) if err != nil { return err From 432d457ad7746fc3fe9247930cd173939fa4f6db Mon Sep 17 00:00:00 2001 From: imterah Date: Sun, 16 Feb 2025 21:51:33 -0500 Subject: [PATCH 10/24] feature: Adds remote implementation of code. --- .../sshappbackend/datacommands/constants.go | 1 + .../backendutil_custom/application.go | 43 +++ .../backendutil_custom/structure.go | 6 +- backend/sshappbackend/remote-code/main.go | 363 +++++++++++++++++- 4 files changed, 396 insertions(+), 17 deletions(-) diff --git a/backend/sshappbackend/datacommands/constants.go b/backend/sshappbackend/datacommands/constants.go index 9630d43..6385e98 100644 --- a/backend/sshappbackend/datacommands/constants.go +++ b/backend/sshappbackend/datacommands/constants.go @@ -1,5 +1,6 @@ package datacommands +// DO NOT USE type ProxyStatusRequest struct { ProxyID uint16 } diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go index 5ad2246..0e00da7 100644 --- a/backend/sshappbackend/remote-code/backendutil_custom/application.go +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -1,6 +1,7 @@ package backendutil_custom import ( + "io" "net" "os" @@ -33,6 +34,8 @@ func (helper *BackendApplicationHelper) Start() error { return err } + helper.Backend.OnSocketConnection(helper.socket) + log.Debug("Sucessfully connected") for { @@ -84,6 +87,46 @@ func (helper *BackendApplicationHelper) Start() error { } helper.socket.Write(responseMarshalled) + case *datacommands.ProxyInformationRequest: + response := helper.Backend.ResolveProxy(command.ProxyID) + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *datacommands.ProxyConnectionInformationRequest: + response := helper.Backend.ResolveConnection(command.ProxyID, command.ConnectionID) + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *datacommands.TCPConnectionClosed: + helper.Backend.OnTCPConnectionClosed(command.ProxyID, command.ConnectionID) + case *datacommands.TCPProxyData: + bytes := make([]byte, command.DataLength) + _, err := io.ReadFull(helper.socket, bytes) + + if err != nil { + log.Warn("failed to read TCP data") + } + + helper.Backend.HandleTCPMessage(command, bytes) + case *datacommands.UDPProxyData: + bytes := make([]byte, command.DataLength) + _, err := io.ReadFull(helper.socket, bytes) + + if err != nil { + log.Warn("failed to read TCP data") + } + + helper.Backend.HandleUDPMessage(command, bytes) default: commandRaw, err := commonbackend.Unmarshal(helper.socket) diff --git a/backend/sshappbackend/remote-code/backendutil_custom/structure.go b/backend/sshappbackend/remote-code/backendutil_custom/structure.go index 96bcbf8..65c5a23 100644 --- a/backend/sshappbackend/remote-code/backendutil_custom/structure.go +++ b/backend/sshappbackend/remote-code/backendutil_custom/structure.go @@ -1,6 +1,8 @@ package backendutil_custom import ( + "net" + "git.terah.dev/imterah/hermes/backend/commonbackend" "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" ) @@ -14,9 +16,11 @@ type BackendInterface interface { GetAllProxies() []uint16 ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse GetAllClientConnections(proxyID uint16) []uint16 - ResolveConnection(connectionID uint16) *datacommands.ProxyConnectionsResponse + ResolveConnection(proxyID, connectionID uint16) *datacommands.ProxyConnectionInformationResponse CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse + OnTCPConnectionClosed(proxyID, connectionID uint16) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) + OnSocketConnection(sock net.Conn) } diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go index 0fd11ee..a1d9bb7 100644 --- a/backend/sshappbackend/remote-code/main.go +++ b/backend/sshappbackend/remote-code/main.go @@ -1,7 +1,12 @@ package main import ( + "errors" + "fmt" + "net" "os" + "strconv" + "strings" "sync" "git.terah.dev/imterah/hermes/backend/commonbackend" @@ -11,19 +16,27 @@ import ( ) type TCPProxy struct { - proxyIDIndex uint16 - proxyIDLock sync.Mutex -} - -type UDPProxy struct { -} - -type SSHRemoteAppBackend struct { connectionIDIndex uint16 connectionIDLock sync.Mutex + proxyInformation *commonbackend.AddProxy + connections map[uint16]net.Conn + server net.Listener +} + +type UDPProxy struct { + server *net.UDPConn + proxyInformation *commonbackend.AddProxy +} + +type SSHRemoteAppBackend struct { + proxyIDIndex uint16 + proxyIDLock sync.Mutex + tcpProxies map[uint16]*TCPProxy udpProxies map[uint16]*UDPProxy + + sock net.Conn } func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { @@ -34,6 +47,20 @@ func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { } func (backend *SSHRemoteAppBackend) StopBackend() (bool, error) { + for tcpProxyIndex, tcpProxy := range backend.tcpProxies { + for _, tcpConnection := range tcpProxy.connections { + tcpConnection.Close() + } + + tcpProxy.server.Close() + delete(backend.tcpProxies, tcpProxyIndex) + } + + for udpProxyIndex, udpProxy := range backend.udpProxies { + udpProxy.server.Close() + delete(backend.udpProxies, udpProxyIndex) + } + return true, nil } @@ -42,49 +69,353 @@ func (backend *SSHRemoteAppBackend) GetBackendStatus() (bool, error) { } func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) { - return 0, true, nil + // Allocate a new proxy ID + backend.proxyIDLock.Lock() + proxyID := backend.proxyIDIndex + backend.proxyIDIndex++ + backend.proxyIDLock.Unlock() + + if command.Protocol == "tcp" { + backend.tcpProxies[proxyID] = &TCPProxy{ + connections: map[uint16]net.Conn{}, + proxyInformation: command, + } + + server, err := net.Listen("tcp", fmt.Sprintf(":%d", command.DestPort)) + + if err != nil { + return 0, false, fmt.Errorf("failed to open server: %s", err.Error()) + } + + backend.tcpProxies[proxyID].server = server + + go func() { + for { + conn, err := server.Accept() + + if err != nil { + log.Warnf("failed to accept connection: %s", err.Error()) + return + } + + go func() { + backend.tcpProxies[proxyID].connectionIDLock.Lock() + connectionID := backend.tcpProxies[proxyID].connectionIDIndex + backend.tcpProxies[proxyID].connectionIDIndex++ + backend.tcpProxies[proxyID].connectionIDLock.Unlock() + + dataBuf := make([]byte, 65535) + + onConnection := &datacommands.TCPConnectionOpened{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + connectionCommandMarshalled, err := datacommands.Marshal(onConnection) + + if err != nil { + log.Errorf("failed to marshal connection message: %s", err.Error()) + } + + backend.sock.Write(connectionCommandMarshalled) + + tcpData := &datacommands.TCPProxyData{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + for { + len, err := conn.Read(dataBuf) + + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } else if err.Error() != "EOF" { + log.Warnf("failed to read from sock: %s", err.Error()) + } + + conn.Close() + break + } + + tcpData.DataLength = uint16(len) + marshalledMessageCommand, err := datacommands.Marshal(tcpData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.sock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.sock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + + conn.Close() + break + } + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.sock.Write(disconnectionCommandMarshalled) + }() + } + }() + } else if command.Protocol == "udp" { + backend.udpProxies[proxyID] = &UDPProxy{ + proxyInformation: command, + } + + server, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(0, 0, 0, 0), + Port: int(command.DestPort), + }) + + if err != nil { + return 0, false, fmt.Errorf("failed to open server: %s", err.Error()) + } + + backend.udpProxies[proxyID].server = server + dataBuf := make([]byte, 65535) + + udpProxyData := &datacommands.UDPProxyData{ + ProxyID: proxyID, + } + + go func() { + for { + len, addr, err := server.ReadFromUDP(dataBuf) + + if err != nil { + log.Warnf("failed to read from UDP socket: %s", err.Error()) + continue + } + + udpProxyData.ClientIP = addr.IP.String() + udpProxyData.ClientPort = uint16(addr.Port) + udpProxyData.DataLength = uint16(len) + + marshalledMessageCommand, err := datacommands.Marshal(udpProxyData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + continue + } + + if _, err := backend.sock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + continue + } + + if _, err := backend.sock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + continue + } + } + }() + } + + return proxyID, true, nil } func (backend *SSHRemoteAppBackend) StopProxy(command *datacommands.RemoveProxy) (bool, error) { + tcpProxy, ok := backend.tcpProxies[command.ProxyID] + + if !ok { + udpProxy, ok := backend.udpProxies[command.ProxyID] + + if !ok { + return ok, fmt.Errorf("could not find proxy") + } + + udpProxy.server.Close() + delete(backend.udpProxies, command.ProxyID) + } else { + for _, tcpConnection := range tcpProxy.connections { + tcpConnection.Close() + } + + tcpProxy.server.Close() + delete(backend.tcpProxies, command.ProxyID) + } + return true, nil } func (backend *SSHRemoteAppBackend) GetAllProxies() []uint16 { - return []uint16{} + proxyList := make([]uint16, len(backend.tcpProxies)+len(backend.udpProxies)) + + currentPos := 0 + + for tcpProxy := range backend.tcpProxies { + proxyList[currentPos] = tcpProxy + currentPos += 1 + } + + for udpProxy := range backend.udpProxies { + proxyList[currentPos] = udpProxy + currentPos += 1 + } + + return proxyList } func (backend *SSHRemoteAppBackend) ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse { - return &datacommands.ProxyInformationResponse{} + var proxyInformation *commonbackend.AddProxy + response := &datacommands.ProxyInformationResponse{} + + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + udpProxy, ok := backend.udpProxies[proxyID] + + if !ok { + response.Exists = false + return response + } + + proxyInformation = udpProxy.proxyInformation + } else { + proxyInformation = tcpProxy.proxyInformation + } + + response.Exists = true + response.SourceIP = proxyInformation.SourceIP + response.SourcePort = proxyInformation.SourcePort + response.DestPort = proxyInformation.DestPort + response.Protocol = proxyInformation.Protocol + + return response } func (backend *SSHRemoteAppBackend) GetAllClientConnections(proxyID uint16) []uint16 { - return []uint16{} + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + return []uint16{} + } + + connectionsArray := make([]uint16, len(tcpProxy.connections)) + currentPos := 0 + + for connectionIndex := range tcpProxy.connections { + connectionsArray[currentPos] = connectionIndex + currentPos++ + } + + return connectionsArray } -func (backend *SSHRemoteAppBackend) ResolveConnection(proxyID uint16) *datacommands.ProxyConnectionsResponse { - return &datacommands.ProxyConnectionsResponse{} +func (backend *SSHRemoteAppBackend) ResolveConnection(proxyID, connectionID uint16) *datacommands.ProxyConnectionInformationResponse { + response := &datacommands.ProxyConnectionInformationResponse{} + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + response.Exists = false + return response + } + + connection, ok := tcpProxy.connections[connectionID] + + if !ok { + response.Exists = false + return response + } + + addr := connection.RemoteAddr().String() + ip := addr[:strings.LastIndex(addr, ":")] + port, err := strconv.Atoi(addr[strings.LastIndex(addr, ":")+1:]) + + if err != nil { + log.Warnf("failed to parse client port: %s", err.Error()) + response.Exists = false + + return response + } + + response.ClientIP = ip + response.ClientPort = uint16(port) + + return response } func (backend *SSHRemoteAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { return &commonbackend.CheckParametersResponse{ IsValid: true, - Message: "Valid!", } } func (backend *SSHRemoteAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { return &commonbackend.CheckParametersResponse{ IsValid: true, - Message: "Valid!", } } func (backend *SSHRemoteAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + tcpProxy, ok := backend.tcpProxies[message.ProxyID] + if !ok { + return + } + + connection, ok := tcpProxy.connections[message.ConnectionID] + + if !ok { + return + } + + connection.Write(data) } func (backend *SSHRemoteAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + udpProxy, ok := backend.udpProxies[message.ProxyID] + if !ok { + return + } + + udpProxy.server.WriteToUDP(data, &net.UDPAddr{ + IP: net.ParseIP(message.ClientIP), + Port: int(message.ClientPort), + }) +} + +func (backend *SSHRemoteAppBackend) OnTCPConnectionClosed(proxyID, connectionID uint16) { + tcpProxy, ok := backend.tcpProxies[proxyID] + + if !ok { + return + } + + connection, ok := tcpProxy.connections[connectionID] + + if !ok { + return + } + + connection.Close() + delete(tcpProxy.connections, connectionID) +} + +func (backend *SSHRemoteAppBackend) OnSocketConnection(sock net.Conn) { + backend.sock = sock } func main() { From 15176831e6588da81e3aed667e07a697023f4e7b Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Feb 2025 13:15:09 -0500 Subject: [PATCH 11/24] feature: Adds basic TCP support for SSHAppBackend. --- backend/backendutil/application.go | 16 +- .../sshappbackend/gaslighter/gaslighter.go | 30 ++ backend/sshappbackend/local-code/main.go | 435 +++++++++++++++++- .../backendutil_custom/application.go | 302 ++++++------ backend/sshappbackend/remote-code/main.go | 11 +- 5 files changed, 622 insertions(+), 172 deletions(-) create mode 100644 backend/sshappbackend/gaslighter/gaslighter.go diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index 7f134a2..afa3147 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -132,12 +132,12 @@ func (helper *BackendApplicationHelper) Start() error { ok, err := helper.Backend.StartProxy(command) var hasAnyFailed bool - if !ok { - log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) - hasAnyFailed = true - } else if err != nil { + if err != nil { log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) hasAnyFailed = true + } else if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true } response := &commonbackend.ProxyStatusResponse{ @@ -160,12 +160,12 @@ func (helper *BackendApplicationHelper) Start() error { ok, err := helper.Backend.StopProxy(command) var hasAnyFailed bool - if !ok { - log.Warnf("failed to remove proxy (%s:%d -> remote:%d): RemoveProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) - hasAnyFailed = true - } else if err != nil { + if err != nil { log.Warnf("failed to remove proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) hasAnyFailed = true + } else if !ok { + log.Warnf("failed to remove proxy (%s:%d -> remote:%d): RemoveProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true } response := &commonbackend.ProxyStatusResponse{ diff --git a/backend/sshappbackend/gaslighter/gaslighter.go b/backend/sshappbackend/gaslighter/gaslighter.go new file mode 100644 index 0000000..ecccec7 --- /dev/null +++ b/backend/sshappbackend/gaslighter/gaslighter.go @@ -0,0 +1,30 @@ +package gaslighter + +import "io" + +type Gaslighter struct { + Byte byte + HasGaslit bool + ProxiedReader io.Reader +} + +func (gaslighter *Gaslighter) Read(p []byte) (n int, err error) { + if gaslighter.HasGaslit { + return gaslighter.ProxiedReader.Read(p) + } + + if len(p) == 0 { + return 0, nil + } + + p[0] = gaslighter.Byte + gaslighter.HasGaslit = true + + if len(p) > 1 { + n, err := gaslighter.ProxiedReader.Read(p[1:]) + + return n + 1, err + } else { + return 1, nil + } +} diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index 98a2370..75044b2 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -5,19 +5,35 @@ import ( "crypto/md5" "encoding/hex" "encoding/json" + "errors" "fmt" + "io" + "math/rand/v2" + "net" "os" "strings" "sync" + "time" "git.terah.dev/imterah/hermes/backend/backendutil" "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" "github.com/charmbracelet/log" "github.com/go-playground/validator/v10" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) +type TCPProxy struct { + proxyInformation *commonbackend.AddProxy + connections map[uint16]net.Conn +} + +type UDPProxy struct { + proxyInformation *commonbackend.AddProxy +} + type SSHAppBackendData struct { IP string `json:"ip" validate:"required"` Port uint16 `json:"port" validate:"required"` @@ -27,14 +43,27 @@ type SSHAppBackendData struct { } type SSHAppBackend struct { - config *SSHAppBackendData - conn *ssh.Client - clients []*commonbackend.ProxyClientConnection - arrayPropMutex sync.Mutex + config *SSHAppBackendData + conn *ssh.Client + listener net.Listener + currentSock net.Conn + + tcpProxies map[uint16]*TCPProxy + udpProxies map[uint16]*UDPProxy + + // globalNonCriticalMessageLock: Locks all messages that don't need low-latency transmissions & high + // speed behind a lock. This ensures safety when it comes to handling messages correctly. + globalNonCriticalMessageLock sync.Mutex + // globalNonCriticalMessageChan: Channel for handling messages that need a reply / aren't critical. + globalNonCriticalMessageChan chan interface{} } func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { log.Info("SSHAppBackend is initializing...") + backend.globalNonCriticalMessageChan = make(chan interface{}) + backend.tcpProxies = map[uint16]*TCPProxy{} + backend.udpProxies = map[uint16]*UDPProxy{} + var backendData SSHAppBackendData if err := json.Unmarshal(configBytes, &backendData); err != nil { @@ -77,8 +106,8 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { backend.conn = conn - log.Info("SSHAppBackend has connected successfully.") - log.Info("Getting CPU architecture...") + log.Debug("SSHAppBackend has connected successfully.") + log.Debug("Getting CPU architecture...") session, err := backend.conn.NewSession() @@ -125,7 +154,7 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { return false, fmt.Errorf("CPU architecture not compiled/supported currently") } - log.Info("Checking if we need to copy the application...") + log.Debug("Checking if we need to copy the application...") var binary []byte needsToCopyBinary := true @@ -187,7 +216,8 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { } if needsToCopyBinary { - log.Info("Copying binary...") + log.Debug("Copying binary...") + sftpInstance, err := sftp.NewClient(conn) if err != nil { @@ -213,13 +243,13 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { var file *sftp.File if fileExists { - file, err = sftpInstance.Create("/tmp/sshappbackend.runtime") - } else { file, err = sftpInstance.OpenFile("/tmp/sshappbackend.runtime", os.O_WRONLY) + } else { + file, err = sftpInstance.Create("/tmp/sshappbackend.runtime") } if err != nil { - log.Warnf("Failed to create file: %s", err.Error()) + log.Warnf("Failed to create (or open) file: %s", err.Error()) conn.Close() backend.conn = nil return false, err @@ -234,7 +264,7 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { return false, err } - err = file.Chmod(775) + err = file.Chmod(0755) if err != nil { log.Warnf("Failed to change permissions on file: %s", err.Error()) @@ -243,12 +273,25 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { return false, err } - log.Info("Done copying file.") + log.Debug("Done copying file.") + sftpInstance.Close() } else { - log.Info("Skipping copying as there's a copy on disk already.") + log.Debug("Skipping copying as there's a copy on disk already.") } - log.Info("Starting process...") + log.Debug("Initializing Unix socket...") + + socketPath := fmt.Sprintf("/tmp/sock-%d.sock", rand.Uint()) + listener, err := conn.ListenUnix(socketPath) + + if err != nil { + log.Warnf("Failed to listen on socket: %s", err.Error()) + conn.Close() + backend.conn = nil + return false, err + } + + log.Debug("Starting process...") session, err = backend.conn.NewSession() @@ -259,10 +302,56 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { return false, err } + backend.listener = listener + session.Stdout = WriteLogger{} session.Stderr = WriteLogger{} - go session.Run("/tmp/sshappbackend.runtime") + go func() { + for { + err := session.Run(fmt.Sprintf("HERMES_LOG_LEVEL=\"%s\" HERMES_API_SOCK=\"%s\" /tmp/sshappbackend.runtime", os.Getenv("HERMES_LOG_LEVEL"), socketPath)) + + if err != nil && !errors.Is(err, &ssh.ExitError{}) && !errors.Is(err, &ssh.ExitMissingError{}) { + log.Errorf("Critically failed during execution of remote code: %s", err.Error()) + return + } else { + log.Warn("Remote code failed for an unknown reason. Restarting...") + } + } + }() + + go backend.sockServerHandler() + + log.Debug("Started process. Waiting for Unix socket connection...") + + for backend.currentSock == nil { + time.Sleep(10 * time.Millisecond) + } + + log.Debug("Detected connection. Sending initialization command...") + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&commonbackend.Start{ + Arguments: []byte{}, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*commonbackend.BackendStatusResponse) + + if !ok { + return false, fmt.Errorf("recieved invalid response type: %T", proxyStatusRaw) + } + + if proxyStatus.StatusCode == commonbackend.StatusFailure { + if proxyStatus.Message == "" { + return false, fmt.Errorf("failed to initialize backend in remote code") + } else { + return false, fmt.Errorf("failed to initialize backend in remote code: %s", proxyStatus.Message) + } + } + log.Info("SSHAppBackend has initialized successfully.") return true, nil @@ -283,15 +372,117 @@ func (backend *SSHAppBackend) GetBackendStatus() (bool, error) { } func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool, error) { + proxyStatusRaw, err := backend.SendNonCriticalMessage(command) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + return false, fmt.Errorf("recieved invalid response type: %T", proxyStatusRaw) + } + + if !proxyStatus.IsActive { + return false, fmt.Errorf("failed to initialize proxy in remote code") + } + + if command.Protocol == "tcp" { + backend.tcpProxies[proxyStatus.ProxyID] = &TCPProxy{ + proxyInformation: command, + } + + backend.tcpProxies[proxyStatus.ProxyID].connections = map[uint16]net.Conn{} + } else if command.Protocol == "udp" { + backend.udpProxies[proxyStatus.ProxyID] = &UDPProxy{ + proxyInformation: command, + } + } + return true, nil } func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, error) { + if command.Protocol == "tcp" { + for proxyIndex, proxy := range backend.tcpProxies { + if proxy.proxyInformation.DestPort != command.DestPort { + continue + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyIndex, + } + + for connectionIndex, connection := range proxy.connections { + connection.Close() + delete(proxy.connections, connectionIndex) + + onDisconnect.ConnectionID = connectionIndex + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.currentSock.Write(disconnectionCommandMarshalled) + } + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&datacommands.RemoveProxy{ + ProxyID: proxyIndex, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + log.Warn("Failed to stop proxy: typecast failed") + return true, fmt.Errorf("failed to stop proxy: typecast failed") + } + + if proxyStatus.IsActive { + log.Warn("Failed to stop proxy: still running") + return true, fmt.Errorf("failed to stop proxy: still running") + } + } + } else if command.Protocol == "udp" { + for proxyIndex, proxy := range backend.udpProxies { + if proxy.proxyInformation.DestPort != command.DestPort { + continue + } + + proxyStatusRaw, err := backend.SendNonCriticalMessage(&datacommands.RemoveProxy{ + ProxyID: proxyIndex, + }) + + if err != nil { + return false, err + } + + proxyStatus, ok := proxyStatusRaw.(*datacommands.ProxyStatusResponse) + + if !ok { + log.Warn("Failed to stop proxy: typecast failed") + return true, fmt.Errorf("failed to stop proxy: typecast failed") + } + + if proxyStatus.IsActive { + log.Warn("Failed to stop proxy: still running") + return true, fmt.Errorf("failed to stop proxy: still running") + } + + // TODO: finish code for UDP + } + } + return false, fmt.Errorf("could not find the proxy") } func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { - return backend.clients + return []*commonbackend.ProxyClientConnection{} } func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { @@ -322,6 +513,216 @@ func (backend *SSHAppBackend) CheckParametersForBackend(arguments []byte) *commo } } +func (backend *SSHAppBackend) OnTCPConnectionOpened(proxyID, connectionID uint16) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", backend.tcpProxies[proxyID].proxyInformation.SourceIP, backend.tcpProxies[proxyID].proxyInformation.SourcePort)) + + if err != nil { + log.Warnf("failed to dial sock: %s", err.Error()) + } + + go func() { + dataBuf := make([]byte, 65535) + + tcpData := &datacommands.TCPProxyData{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + for { + len, err := conn.Read(dataBuf) + + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } else if err.Error() != "EOF" { + log.Warnf("failed to read from sock: %s", err.Error()) + } + + conn.Close() + break + } + + tcpData.DataLength = uint16(len) + marshalledMessageCommand, err := datacommands.Marshal(tcpData) + + if err != nil { + log.Warnf("failed to marshal message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.currentSock.Write(marshalledMessageCommand); err != nil { + log.Warnf("failed to send marshalled message data: %s", err.Error()) + + conn.Close() + break + } + + if _, err := backend.currentSock.Write(dataBuf[:len]); err != nil { + log.Warnf("failed to send raw message data: %s", err.Error()) + + conn.Close() + break + } + } + + onDisconnect := &datacommands.TCPConnectionClosed{ + ProxyID: proxyID, + ConnectionID: connectionID, + } + + disconnectionCommandMarshalled, err := datacommands.Marshal(onDisconnect) + + if err != nil { + log.Errorf("failed to marshal disconnection message: %s", err.Error()) + } + + backend.currentSock.Write(disconnectionCommandMarshalled) + }() + + backend.tcpProxies[proxyID].connections[connectionID] = conn +} + +func (backend *SSHAppBackend) OnTCPConnectionClosed(proxyID, connectionID uint16) { + proxy, ok := backend.tcpProxies[proxyID] + + if !ok { + log.Warn("Could not find TCP proxy") + } + + connection, ok := proxy.connections[connectionID] + + if !ok { + log.Warn("Could not find connection in TCP proxy") + } + + connection.Close() + delete(proxy.connections, connectionID) +} + +func (backend *SSHAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + proxy, ok := backend.tcpProxies[message.ProxyID] + + if !ok { + log.Warn("Could not find TCP proxy") + } + + connection, ok := proxy.connections[message.ConnectionID] + + if !ok { + log.Warn("Could not find connection in TCP proxy") + } + + connection.Write(data) +} + +func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) {} + +func (backend *SSHAppBackend) SendNonCriticalMessage(iface interface{}) (interface{}, error) { + if backend.currentSock == nil { + return nil, fmt.Errorf("socket connection not initialized yet") + } + + bytes, err := datacommands.Marshal(iface) + + if err != nil && err.Error() == "unsupported command type" { + bytes, err = commonbackend.Marshal(iface) + + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + backend.globalNonCriticalMessageLock.Lock() + + if _, err := backend.currentSock.Write(bytes); err != nil { + backend.globalNonCriticalMessageLock.Unlock() + return nil, fmt.Errorf("failed to write message: %s", err.Error()) + } + + reply, ok := <-backend.globalNonCriticalMessageChan + + if !ok { + backend.globalNonCriticalMessageLock.Unlock() + return nil, fmt.Errorf("failed to get reply back: chan not OK") + } + + backend.globalNonCriticalMessageLock.Unlock() + return reply, nil +} + +func (backend *SSHAppBackend) sockServerHandler() { + for { + conn, err := backend.listener.Accept() + + if err != nil { + log.Warnf("Failed to accept remote connection: %s", err.Error()) + } + + log.Debug("Successfully connected.") + + backend.currentSock = conn + + commandID := make([]byte, 1) + + gaslighter := &gaslighter.Gaslighter{} + gaslighter.ProxiedReader = conn + + dataBuffer := make([]byte, 65535) + + var commandRaw interface{} + + for { + if _, err := conn.Read(commandID); err != nil { + log.Warnf("Failed to read command ID: %s", err.Error()) + return + } + + gaslighter.Byte = commandID[0] + gaslighter.HasGaslit = false + + if gaslighter.Byte > 100 { + commandRaw, err = datacommands.Unmarshal(gaslighter) + } else { + commandRaw, err = commonbackend.Unmarshal(gaslighter) + } + + if err != nil { + log.Warnf("Failed to parse command: %s", err.Error()) + } + + switch command := commandRaw.(type) { + case *datacommands.TCPConnectionOpened: + backend.OnTCPConnectionOpened(command.ProxyID, command.ConnectionID) + case *datacommands.TCPConnectionClosed: + backend.OnTCPConnectionClosed(command.ProxyID, command.ConnectionID) + case *datacommands.TCPProxyData: + if _, err := io.ReadFull(conn, dataBuffer[:command.DataLength]); err != nil { + log.Warnf("Failed to read entire data buffer: %s", err.Error()) + break + } + + backend.HandleTCPMessage(command, dataBuffer[:command.DataLength]) + case *datacommands.UDPProxyData: + if _, err := io.ReadFull(conn, dataBuffer[:command.DataLength]); err != nil { + log.Warnf("Failed to read entire data buffer: %s", err.Error()) + break + } + + backend.HandleUDPMessage(command, dataBuffer[:command.DataLength]) + default: + select { + case backend.globalNonCriticalMessageChan <- command: + default: + } + } + } + } +} + func main() { logLevel := os.Getenv("HERMES_LOG_LEVEL") diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go index 0e00da7..2747f28 100644 --- a/backend/sshappbackend/remote-code/backendutil_custom/application.go +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -8,6 +8,7 @@ import ( "git.terah.dev/imterah/hermes/backend/backendutil" "git.terah.dev/imterah/hermes/backend/commonbackend" "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" "github.com/charmbracelet/log" ) @@ -38,10 +39,28 @@ func (helper *BackendApplicationHelper) Start() error { log.Debug("Sucessfully connected") - for { - commandRaw, err := datacommands.Unmarshal(helper.socket) + gaslighter := &gaslighter.Gaslighter{} + gaslighter.ProxiedReader = helper.socket - if err != nil && err.Error() != "couldn't match command ID" { + commandID := make([]byte, 1) + + for { + if _, err := helper.socket.Read(commandID); err != nil { + return err + } + + gaslighter.Byte = commandID[0] + gaslighter.HasGaslit = false + + var commandRaw interface{} + + if gaslighter.Byte > 100 { + commandRaw, err = datacommands.Unmarshal(gaslighter) + } else { + commandRaw, err = commonbackend.Unmarshal(gaslighter) + } + + if err != nil { return err } @@ -127,155 +146,146 @@ func (helper *BackendApplicationHelper) Start() error { } helper.Backend.HandleUDPMessage(command, bytes) - default: - commandRaw, err := commonbackend.Unmarshal(helper.socket) + case *commonbackend.Start: + ok, err := helper.Backend.StartBackend(command.Arguments) + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.Stop: + ok, err := helper.Backend.StopBackend() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: !ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.BackendStatusRequest: + ok, err := helper.Backend.GetBackendStatus() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.AddProxy: + id, ok, err := helper.Backend.StartProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + ProxyID: id, + IsActive: !hasAnyFailed, + } + + responseMarshalled, err := datacommands.Marshal(response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case *commonbackend.CheckClientParameters: + resp := helper.Backend.CheckParametersForConnections(command) + resp.InResponseTo = "checkClientParameters" + + byteData, err := commonbackend.Marshal(resp) if err != nil { return err } - switch command := commandRaw.(type) { - case *commonbackend.Start: - ok, err := helper.Backend.StartBackend(command.Arguments) - - var ( - message string - statusCode int - ) - - if err != nil { - message = err.Error() - statusCode = commonbackend.StatusFailure - } else { - statusCode = commonbackend.StatusSuccess - } - - response := &commonbackend.BackendStatusResponse{ - IsRunning: ok, - StatusCode: statusCode, - Message: message, - } - - responseMarshalled, err := commonbackend.Marshal(response) - - if err != nil { - log.Error("failed to marshal response: %s", err.Error()) - continue - } - - helper.socket.Write(responseMarshalled) - case *commonbackend.Stop: - ok, err := helper.Backend.StopBackend() - - var ( - message string - statusCode int - ) - - if err != nil { - message = err.Error() - statusCode = commonbackend.StatusFailure - } else { - statusCode = commonbackend.StatusSuccess - } - - response := &commonbackend.BackendStatusResponse{ - IsRunning: !ok, - StatusCode: statusCode, - Message: message, - } - - responseMarshalled, err := commonbackend.Marshal(response) - - if err != nil { - log.Error("failed to marshal response: %s", err.Error()) - continue - } - - helper.socket.Write(responseMarshalled) - case *commonbackend.BackendStatusRequest: - ok, err := helper.Backend.GetBackendStatus() - - var ( - message string - statusCode int - ) - - if err != nil { - message = err.Error() - statusCode = commonbackend.StatusFailure - } else { - statusCode = commonbackend.StatusSuccess - } - - response := &commonbackend.BackendStatusResponse{ - IsRunning: ok, - StatusCode: statusCode, - Message: message, - } - - responseMarshalled, err := commonbackend.Marshal(response) - - if err != nil { - log.Error("failed to marshal response: %s", err.Error()) - continue - } - - helper.socket.Write(responseMarshalled) - case *commonbackend.AddProxy: - id, ok, err := helper.Backend.StartProxy(command) - var hasAnyFailed bool - - if !ok { - log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) - hasAnyFailed = true - } else if err != nil { - log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) - hasAnyFailed = true - } - - response := &datacommands.ProxyStatusResponse{ - ProxyID: id, - IsActive: !hasAnyFailed, - } - - responseMarshalled, err := datacommands.Marshal(response) - - if err != nil { - log.Error("failed to marshal response: %s", err.Error()) - continue - } - - helper.socket.Write(responseMarshalled) - case *commonbackend.CheckClientParameters: - resp := helper.Backend.CheckParametersForConnections(command) - resp.InResponseTo = "checkClientParameters" - - byteData, err := commonbackend.Marshal(resp) - - if err != nil { - return err - } - - if _, err = helper.socket.Write(byteData); err != nil { - return err - } - case *commonbackend.CheckServerParameters: - resp := helper.Backend.CheckParametersForBackend(command.Arguments) - resp.InResponseTo = "checkServerParameters" - - byteData, err := commonbackend.Marshal(resp) - - if err != nil { - return err - } - - if _, err = helper.socket.Write(byteData); err != nil { - return err - } - default: - log.Warnf("Unsupported command recieved: %T", command) + if _, err = helper.socket.Write(byteData); err != nil { + return err } + case *commonbackend.CheckServerParameters: + resp := helper.Backend.CheckParametersForBackend(command.Arguments) + resp.InResponseTo = "checkServerParameters" + + byteData, err := commonbackend.Marshal(resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + default: + log.Warnf("Unsupported command recieved: %T", command) } } } diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go index a1d9bb7..d56a7a3 100644 --- a/backend/sshappbackend/remote-code/main.go +++ b/backend/sshappbackend/remote-code/main.go @@ -36,6 +36,8 @@ type SSHRemoteAppBackend struct { tcpProxies map[uint16]*TCPProxy udpProxies map[uint16]*UDPProxy + isRunning bool + sock net.Conn } @@ -43,6 +45,8 @@ func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { backend.tcpProxies = map[uint16]*TCPProxy{} backend.udpProxies = map[uint16]*UDPProxy{} + backend.isRunning = true + return true, nil } @@ -61,11 +65,12 @@ func (backend *SSHRemoteAppBackend) StopBackend() (bool, error) { delete(backend.udpProxies, udpProxyIndex) } + backend.isRunning = false return true, nil } func (backend *SSHRemoteAppBackend) GetBackendStatus() (bool, error) { - return true, nil + return backend.isRunning, nil } func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) { @@ -104,6 +109,8 @@ func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) backend.tcpProxies[proxyID].connectionIDIndex++ backend.tcpProxies[proxyID].connectionIDLock.Unlock() + backend.tcpProxies[proxyID].connections[connectionID] = conn + dataBuf := make([]byte, 65535) onConnection := &datacommands.TCPConnectionOpened{ @@ -372,12 +379,14 @@ func (backend *SSHRemoteAppBackend) HandleTCPMessage(message *datacommands.TCPPr tcpProxy, ok := backend.tcpProxies[message.ProxyID] if !ok { + log.Warnf("could not find tcp proxy (ID %d)", message.ProxyID) return } connection, ok := tcpProxy.connections[message.ConnectionID] if !ok { + log.Warnf("could not find tcp proxy (ID %d) with connection ID (%d)", message.ProxyID, message.ConnectionID) return } From f8a4fe00a0126d75bb07490a1667d5e9a6a8e65d Mon Sep 17 00:00:00 2001 From: imterah Date: Wed, 19 Feb 2025 07:58:42 -0500 Subject: [PATCH 12/24] feature: Adds basic UDP support. --- backend/commonbackend/unmarshal.go | 10 +- backend/sshappbackend/local-code/main.go | 68 ++++++++++- .../local-code/porttranslation/translation.go | 112 ++++++++++++++++++ 3 files changed, 183 insertions(+), 7 deletions(-) create mode 100644 backend/sshappbackend/local-code/porttranslation/translation.go diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go index 8e338c6..6bb5af4 100644 --- a/backend/commonbackend/unmarshal.go +++ b/backend/commonbackend/unmarshal.go @@ -213,7 +213,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -270,7 +270,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -364,7 +364,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -523,7 +523,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -580,7 +580,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index 75044b2..abb0081 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -19,6 +19,7 @@ import ( "git.terah.dev/imterah/hermes/backend/commonbackend" "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" + "git.terah.dev/imterah/hermes/backend/sshappbackend/local-code/porttranslation" "github.com/charmbracelet/log" "github.com/go-playground/validator/v10" "github.com/pkg/sftp" @@ -32,6 +33,7 @@ type TCPProxy struct { type UDPProxy struct { proxyInformation *commonbackend.AddProxy + portTranslation *porttranslation.PortTranslation } type SSHAppBackendData struct { @@ -397,7 +399,56 @@ func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool, } else if command.Protocol == "udp" { backend.udpProxies[proxyStatus.ProxyID] = &UDPProxy{ proxyInformation: command, + portTranslation: &porttranslation.PortTranslation{}, } + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.UDPAddr = &net.UDPAddr{ + IP: net.ParseIP(command.SourceIP), + Port: int(command.SourcePort), + } + + udpMessageCommand := &datacommands.UDPProxyData{} + udpMessageCommand.ProxyID = proxyStatus.ProxyID + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.WriteFrom = func(ip string, port uint16, data []byte) { + udpMessageCommand.ClientIP = ip + udpMessageCommand.ClientPort = port + udpMessageCommand.DataLength = uint16(len(data)) + + marshalledCommand, err := datacommands.Marshal(udpMessageCommand) + + if err != nil { + log.Warnf("Failed to marshal UDP message header") + return + } + + if _, err := backend.currentSock.Write(marshalledCommand); err != nil { + log.Warnf("Failed to write UDP message header") + return + } + + if _, err := backend.currentSock.Write(data); err != nil { + log.Warnf("Failed to write UDP message") + return + } + } + + go func() { + for { + time.Sleep(3 * time.Minute) + + // Checks if the proxy still exists before continuing + _, ok := backend.udpProxies[proxyStatus.ProxyID] + + if !ok { + return + } + + // Then attempt to run cleanup tasks + log.Debug("Running UDP proxy cleanup tasks (invoking CleanupPorts() on portTranslation)") + backend.udpProxies[proxyStatus.ProxyID].portTranslation.CleanupPorts() + } + }() } return true, nil @@ -474,17 +525,20 @@ func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (boo return true, fmt.Errorf("failed to stop proxy: still running") } - // TODO: finish code for UDP + proxy.portTranslation.StopAllPorts() + delete(backend.udpProxies, proxyIndex) } } return false, fmt.Errorf("could not find the proxy") } +// TODO: implement! func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { return []*commonbackend.ProxyClientConnection{} } +// We don't have any parameter limitations, so we should be good. func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { return &commonbackend.CheckParametersResponse{ IsValid: true, @@ -617,7 +671,17 @@ func (backend *SSHAppBackend) HandleTCPMessage(message *datacommands.TCPProxyDat connection.Write(data) } -func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) {} +func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + proxy, ok := backend.udpProxies[message.ProxyID] + + if !ok { + log.Warn("Could not find UDP proxy") + } + + if _, err := proxy.portTranslation.WriteTo(message.ClientIP, message.ClientPort, data); err != nil { + log.Warnf("Failed to write to UDP: %s", err.Error()) + } +} func (backend *SSHAppBackend) SendNonCriticalMessage(iface interface{}) (interface{}, error) { if backend.currentSock == nil { diff --git a/backend/sshappbackend/local-code/porttranslation/translation.go b/backend/sshappbackend/local-code/porttranslation/translation.go new file mode 100644 index 0000000..b8c0454 --- /dev/null +++ b/backend/sshappbackend/local-code/porttranslation/translation.go @@ -0,0 +1,112 @@ +package porttranslation + +import ( + "fmt" + "net" + "sync" + "time" +) + +type connectionData struct { + udpConn *net.UDPConn + buf []byte + hasBeenAliveFor time.Time +} + +type PortTranslation struct { + UDPAddr *net.UDPAddr + WriteFrom func(ip string, port uint16, data []byte) + + newConnectionLock sync.Mutex + connections map[string]map[uint16]*connectionData +} + +func (translation *PortTranslation) CleanupPorts() { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + anyAreAlive := false + + for connectionPortIndex, connectionData := range connectionPorts { + if time.Now().Before(connectionData.hasBeenAliveFor.Add(3 * time.Minute)) { + anyAreAlive = true + continue + } + + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + if !anyAreAlive { + delete(translation.connections, connectionIPIndex) + } + } +} + +func (translation *PortTranslation) StopAllPorts() { + if translation.connections == nil { + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + for connectionPortIndex, connectionData := range connectionPorts { + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + delete(translation.connections, connectionIPIndex) + } + + translation.connections = nil +} + +func (translation *PortTranslation) WriteTo(ip string, port uint16, data []byte) (int, error) { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + } + + connectionPortData, ok := translation.connections[ip] + + if !ok { + translation.connections[ip] = map[uint16]*connectionData{} + connectionPortData = translation.connections[ip] + } + + connectionStruct, ok := connectionPortData[port] + + if !ok { + connectionPortData[port] = &connectionData{} + connectionStruct = connectionPortData[port] + + udpConn, err := net.DialUDP("udp", nil, translation.UDPAddr) + + if err != nil { + return 0, fmt.Errorf("failed to initialize UDP socket: %s", err.Error()) + } + + connectionStruct.udpConn = udpConn + connectionStruct.buf = make([]byte, 65535) + + go func() { + for { + n, err := udpConn.Read(connectionStruct.buf) + + if err != nil { + udpConn.Close() + delete(connectionPortData, port) + + return + } + + connectionStruct.hasBeenAliveFor = time.Now() + translation.WriteFrom(ip, port, connectionStruct.buf[:n]) + } + }() + } + + connectionStruct.hasBeenAliveFor = time.Now() + return connectionStruct.udpConn.Write(data) +} From 34b605c1b14b9fa1223760924391a0c094e3bc78 Mon Sep 17 00:00:00 2001 From: imterah Date: Thu, 20 Feb 2025 09:11:28 -0500 Subject: [PATCH 13/24] feature: Adds API manifest definitions, and implement GetAllClientConnections() --- Dockerfile | 1 + backend/backends.dev.json | 4 +++ backend/backends.prod.json | 4 +++ backend/sshappbackend/local-code/main.go | 41 ++++++++++++++++++++++-- 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 81710a7..ae9c525 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,4 +7,5 @@ WORKDIR /app COPY --from=build /build/backend/backends.prod.json /app/backends.json COPY --from=build /build/backend/api/api /app/hermes COPY --from=build /build/backend/sshbackend/sshbackend /app/sshbackend +COPY --from=build /build/backend/sshappbackend/local-code/sshappbackend /app/sshappbackend ENTRYPOINT ["/app/hermes", "--backends-path", "/app/backends.json"] diff --git a/backend/backends.dev.json b/backend/backends.dev.json index 9ff52ca..f314c69 100644 --- a/backend/backends.dev.json +++ b/backend/backends.dev.json @@ -3,6 +3,10 @@ "name": "ssh", "path": "./sshbackend/sshbackend" }, + { + "name": "sshapp", + "path": "./sshappbackend/local-code/sshappbackend" + }, { "name": "dummy", "path": "./dummybackend/dummybackend" diff --git a/backend/backends.prod.json b/backend/backends.prod.json index 9a9a09e..0ccfedc 100644 --- a/backend/backends.prod.json +++ b/backend/backends.prod.json @@ -2,5 +2,9 @@ { "name": "ssh", "path": "./sshbackend" + }, + { + "name": "sshapp", + "path": "./sshappbackend" } ] diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index abb0081..cb0c13d 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -533,9 +533,46 @@ func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (boo return false, fmt.Errorf("could not find the proxy") } -// TODO: implement! func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { - return []*commonbackend.ProxyClientConnection{} + connections := []*commonbackend.ProxyClientConnection{} + informationRequest := &datacommands.ProxyConnectionInformationRequest{} + + for proxyID, tcpProxy := range backend.tcpProxies { + informationRequest.ProxyID = proxyID + + for connectionID := range tcpProxy.connections { + informationRequest.ConnectionID = connectionID + + proxyStatusRaw, err := backend.SendNonCriticalMessage(informationRequest) + + if err != nil { + log.Warnf("Failed to get connection information for Proxy ID: %d, Connection ID: %d: %s", proxyID, connectionID, err.Error()) + return connections + } + + connectionStatus, ok := proxyStatusRaw.(*datacommands.ProxyConnectionInformationResponse) + + if !ok { + log.Warn("Failed to get connection response: typecast failed") + return connections + } + + if !connectionStatus.Exists { + log.Warnf("Connection with proxy ID: %d, Connection ID: %d is reported to not exist!", proxyID, connectionID) + tcpProxy.connections[connectionID].Close() + } + + connections = append(connections, &commonbackend.ProxyClientConnection{ + SourceIP: tcpProxy.proxyInformation.SourceIP, + SourcePort: tcpProxy.proxyInformation.SourcePort, + DestPort: tcpProxy.proxyInformation.DestPort, + ClientIP: connectionStatus.ClientIP, + ClientPort: connectionStatus.ClientPort, + }) + } + } + + return connections } // We don't have any parameter limitations, so we should be good. From 959718163ebf8af98ec9ddffdef79a41c760dcfa Mon Sep 17 00:00:00 2001 From: imterah Date: Sun, 16 Mar 2025 21:34:20 -0400 Subject: [PATCH 14/24] fix: Fix disconnect handler not working in production --- backend/sshbackend/main.go | 110 +++++++++++++++++++++++++++++++------ 1 file changed, 93 insertions(+), 17 deletions(-) diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index baf1994..ffdf546 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os" + "slices" "strconv" "strings" "sync" @@ -18,6 +19,32 @@ import ( "golang.org/x/crypto/ssh" ) +type ConnWithTimeout struct { + net.Conn + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (c *ConnWithTimeout) Read(b []byte) (int, error) { + err := c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Read(b) +} + +func (c *ConnWithTimeout) Write(b []byte) (int, error) { + err := c.Conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Write(b) +} + type SSHListener struct { SourceIP string SourcePort uint16 @@ -76,16 +103,34 @@ func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backendData.IP, backendData.Port), config) + addr := fmt.Sprintf("%s:%d", backendData.IP, backendData.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return false, err } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + return false, err + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + go backend.backendDisconnectHandler() + go backend.backendKeepaliveHandler() log.Info("SSHBackend has initialized successfully.") - go backend.backendDisconnectHandler() return true, nil } @@ -203,8 +248,7 @@ func (backend *SSHBackend) StartProxy(command *commonbackend.AddProxy) (bool, er // Splice out the clientInstance by clientIndex // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.clients = append(backend.clients[:clientIndex], backend.clients[clientIndex+1:]...) + backend.clients = slices.Delete(backend.clients, clientIndex, clientIndex+1) return } } @@ -288,10 +332,8 @@ func (backend *SSHBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, } // Splice out the proxy instance by proxyIndex - // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.proxies = append(backend.proxies[:proxyIndex], backend.proxies[proxyIndex+1:]...) + backend.proxies = slices.Delete(backend.proxies, proxyIndex, proxyIndex+1) return true, nil } } @@ -341,17 +383,31 @@ func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonba } } -func (backend *SSHBackend) backendDisconnectHandler() { +func (backend *SSHBackend) backendKeepaliveHandler() { for { if backend.conn != nil { - err := backend.conn.Wait() + _, _, err := backend.conn.SendRequest("keepalive@openssh.com", true, nil) - if err == nil || err.Error() != "EOF" { - continue + if err != nil { + log.Warn("Keepalive message failed!") + return } } - log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + time.Sleep(5 * time.Second) + } +} + +func (backend *SSHBackend) backendDisconnectHandler() { + for { + if backend.conn != nil { + backend.conn.Wait() + backend.conn.Close() + + log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + } else { + log.Info("Retrying reconnection in 5 seconds...") + } time.Sleep(5 * time.Second) @@ -376,14 +432,34 @@ func (backend *SSHBackend) backendDisconnectHandler() { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port), config) + addr := fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { - log.Errorf("Failed to connect to the server: %s", err.Error()) - return + log.Errorf("Failed to establish connection to the server: %s", err.Error()) + continue } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + log.Errorf("Failed to create SSH client connection: %s", err.Error()) + rawTCPConn.Close() + continue + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + go backend.backendKeepaliveHandler() log.Info("SSHBackend has reconnected successfully. Attempting to set up proxies again...") From 5c503f04218c6ca4be6beb778946aee57228d056 Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Mar 2025 20:27:40 -0400 Subject: [PATCH 15/24] fix: Make logging options more clear for the backend runtime's backend logs --- .prettierrc | 16 ---------------- backend/api/backendruntime/core.go | 4 ++-- backend/api/backendruntime/struct.go | 2 +- 3 files changed, 3 insertions(+), 19 deletions(-) delete mode 100644 .prettierrc diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 57562cf..0000000 --- a/.prettierrc +++ /dev/null @@ -1,16 +0,0 @@ -{ - "arrowParens": "avoid", - "bracketSpacing": true, - "htmlWhitespaceSensitivity": "css", - "insertPragma": false, - "jsxSingleQuote": false, - "printWidth": 80, - "proseWrap": "always", - "quoteProps": "as-needed", - "requirePragma": false, - "semi": true, - "singleQuote": false, - "tabWidth": 2, - "trailingComma": "all", - "useTabs": false -} \ No newline at end of file diff --git a/backend/api/backendruntime/core.go b/backend/api/backendruntime/core.go index e06a2a6..eac5934 100644 --- a/backend/api/backendruntime/core.go +++ b/backend/api/backendruntime/core.go @@ -6,10 +6,10 @@ var ( AvailableBackends []*Backend RunningBackends map[uint]*Runtime TempDir string - isDevelopmentMode bool + shouldLog bool ) func init() { RunningBackends = make(map[uint]*Runtime) - isDevelopmentMode = os.Getenv("HERMES_DEVELOPMENT_MODE") != "" + shouldLog = os.Getenv("HERMES_DEVELOPMENT_MODE") != "" || os.Getenv("HERMES_BACKEND_LOGGING_ENABLED") != "" || os.Getenv("HERMES_LOG_LEVEL") == "debug" } diff --git a/backend/api/backendruntime/struct.go b/backend/api/backendruntime/struct.go index 9d057ad..cd30182 100644 --- a/backend/api/backendruntime/struct.go +++ b/backend/api/backendruntime/struct.go @@ -42,7 +42,7 @@ type writeLogger struct { func (writer writeLogger) Write(p []byte) (n int, err error) { logSplit := strings.Split(string(p), "\n") - if isDevelopmentMode { + if shouldLog { for _, logLine := range logSplit { if logLine == "" { continue From 17b10c9b19b0c39de7b5249b819b8ff3fd65a232 Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Mar 2025 20:31:28 -0400 Subject: [PATCH 16/24] fix: Add system to detect duplicate running remote processes and kill them accordingly --- backend/sshbackend/main.go | 106 ++++++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 8 deletions(-) diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index ffdf546..a554fa3 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -53,24 +53,35 @@ type SSHListener struct { Listeners []net.Listener } +type SSHBackendData struct { + IP string `json:"ip" validate:"required"` + Port uint16 `json:"port" validate:"required"` + Username string `json:"username" validate:"required"` + PrivateKey string `json:"privateKey" validate:"required"` + DisablePIDCheck bool `json:"disablePIDCheck"` + ListenOnIPs []string `json:"listenOnIPs"` +} + type SSHBackend struct { config *SSHBackendData conn *ssh.Client clients []*commonbackend.ProxyClientConnection proxies []*SSHListener arrayPropMutex sync.Mutex -} - -type SSHBackendData struct { - IP string `json:"ip" validate:"required"` - Port uint16 `json:"port" validate:"required"` - Username string `json:"username" validate:"required"` - PrivateKey string `json:"privateKey" validate:"required"` - ListenOnIPs []string `json:"listenOnIPs"` + pid int + isReady bool + inReinitLoop bool } func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { log.Info("SSHBackend is initializing...") + + if backend.inReinitLoop { + for !backend.isReady { + time.Sleep(100 * time.Millisecond) + } + } + var backendData SSHBackendData if err := json.Unmarshal(bytes, &backendData); err != nil { @@ -127,6 +138,42 @@ func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { client := ssh.NewClient(c, chans, reqs) backend.conn = client + if !backendData.DisablePIDCheck { + if backend.pid != 0 { + session, err := client.NewSession() + + if err != nil { + return false, err + } + + err = session.Run(fmt.Sprintf("kill -9 %d", backend.pid)) + + if err != nil { + log.Warnf("Failed to kill process: %s", err.Error()) + } + } + + session, err := client.NewSession() + + if err != nil { + return false, err + } + + // Get the parent PID of the shell so we can kill it if we disconnect + output, err := session.Output("ps --no-headers -fp $$ | awk '{print $3}'") + + if err != nil { + return false, err + } + + // Strip the new line and convert to int + backend.pid, err = strconv.Atoi(string(output)[:len(output)-1]) + + if err != nil { + return false, err + } + } + go backend.backendDisconnectHandler() go backend.backendKeepaliveHandler() @@ -404,6 +451,9 @@ func (backend *SSHBackend) backendDisconnectHandler() { backend.conn.Wait() backend.conn.Close() + backend.isReady = false + backend.inReinitLoop = true + log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") } else { log.Info("Retrying reconnection in 5 seconds...") @@ -459,6 +509,46 @@ func (backend *SSHBackend) backendDisconnectHandler() { client := ssh.NewClient(c, chans, reqs) backend.conn = client + if !backend.config.DisablePIDCheck { + if backend.pid != 0 { + session, err := client.NewSession() + + if err != nil { + log.Warnf("Failed to create SSH command session: %s", err.Error()) + return + } + + err = session.Run(fmt.Sprintf("kill -9 %d", backend.pid)) + + if err != nil { + log.Warnf("Failed to kill process: %s", err.Error()) + } + } + + session, err := client.NewSession() + + if err != nil { + log.Warnf("Failed to create SSH command session: %s", err.Error()) + return + } + + // Get the parent PID of the shell so we can kill it if we disconnect + output, err := session.Output("ps --no-headers -fp $$ | awk '{print $3}'") + + if err != nil { + log.Warnf("Failed to execute command to fetch PID: %s", err.Error()) + return + } + + // Strip the new line and convert to int + backend.pid, err = strconv.Atoi(string(output)[:len(output)-1]) + + if err != nil { + log.Warnf("Failed to parse PID: %s", err.Error()) + return + } + } + go backend.backendKeepaliveHandler() log.Info("SSHBackend has reconnected successfully. Attempting to set up proxies again...") From 7dee159d5f46f0e4baf2b915b45f3167882ecbab Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Mar 2025 20:33:59 -0400 Subject: [PATCH 17/24] chore: Change license name --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 8914588..a085e23 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2024, Greyson +Copyright (c) 2024, Tera Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: From 83f80af4055d242c033b1b8c87db1d6751873676 Mon Sep 17 00:00:00 2001 From: imterah Date: Wed, 19 Mar 2025 00:34:36 +0000 Subject: [PATCH 18/24] chore: Delete unmaintained CHANGELOG --- CHANGELOG.md | 49 ------------------------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index c1fa49f..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,49 +0,0 @@ -# Changelog - -## [v1.1.2](https://github.com/imterah/nextnet/tree/v1.1.2) (2024-09-29) - -## [v1.1.1](https://github.com/imterah/nextnet/tree/v1.1.1) (2024-09-29) - -## [v1.1.0](https://github.com/imterah/nextnet/tree/v1.1.0) (2024-09-22) - -**Fixed bugs:** - -- Desktop app fails to build on macOS w/ `nix-shell` [\#1](https://github.com/imterah/nextnet/issues/1) - -**Merged pull requests:** - -- chore\(deps\): bump find-my-way from 8.1.0 to 8.2.2 in /api [\#17](https://github.com/imterah/nextnet/pull/17) -- chore\(deps\): bump axios from 1.6.8 to 1.7.4 in /lom [\#16](https://github.com/imterah/nextnet/pull/16) -- chore\(deps\): bump micromatch from 4.0.5 to 4.0.8 in /lom [\#15](https://github.com/imterah/nextnet/pull/15) -- chore\(deps\): bump braces from 3.0.2 to 3.0.3 in /lom [\#13](https://github.com/imterah/nextnet/pull/13) -- chore\(deps-dev\): bump braces from 3.0.2 to 3.0.3 in /api [\#11](https://github.com/imterah/nextnet/pull/11) -- chore\(deps\): bump ws from 8.17.0 to 8.17.1 in /api [\#10](https://github.com/imterah/nextnet/pull/10) - -## [v1.0.1](https://github.com/imterah/nextnet/tree/v1.0.1) (2024-05-18) - -**Merged pull requests:** - -- Adds public key authentication [\#6](https://github.com/imterah/nextnet/pull/6) -- Add support for eslint [\#5](https://github.com/imterah/nextnet/pull/5) - -## [v1.0.0](https://github.com/imterah/nextnet/tree/v1.0.0) (2024-05-10) - -## [v0.1.1](https://github.com/imterah/nextnet/tree/v0.1.1) (2024-05-05) - -## [v0.1.0](https://github.com/imterah/nextnet/tree/v0.1.0) (2024-05-05) - -**Implemented enhancements:** - -- \(potentially\) Migrate nix shell to nix flake [\#2](https://github.com/imterah/nextnet/issues/2) - -**Closed issues:** - -- add precommit hooks [\#3](https://github.com/imterah/nextnet/issues/3) - -**Merged pull requests:** - -- Reimplements PassyFire as a possible backend [\#4](https://github.com/imterah/nextnet/pull/4) - - - -\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* From 1cefe64f8837509f891c8960279a5dfd6a4848f4 Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Mar 2025 20:36:00 -0400 Subject: [PATCH 19/24] chore: Remove Node.JS from the nix shell --- shell.nix | 1 - 1 file changed, 1 deletion(-) diff --git a/shell.nix b/shell.nix index 17c5989..ea1aa8e 100644 --- a/shell.nix +++ b/shell.nix @@ -3,7 +3,6 @@ }: pkgs.mkShell { buildInputs = with pkgs; [ # api/ - nodejs go gopls ]; From 71d53990de34fb59eb18356e884e06556bb02d82 Mon Sep 17 00:00:00 2001 From: imterah Date: Tue, 18 Mar 2025 20:51:01 -0400 Subject: [PATCH 20/24] chore: Fix sample code to remove the deprecated LOM, and add JWT secrets --- README.md | 2 +- docker-compose.yml | 18 +----------------- prod-docker.env | 3 +-- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 01d5885..323855d 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ 1. Copy and change the default password (or username & db name too) from the template file `prod-docker.env`: ```bash - sed "s/POSTGRES_PASSWORD=hermes/POSTGRES_PASSWORD=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" prod-docker.env > .env + sed -e "s/POSTGRES_PASSWORD=hermes/POSTGRES_PASSWORD=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" -e "s/JWT_SECRET=hermes/JWT_SECRET=$(head -c 500 /dev/random | sha512sum | cut -d " " -f 1)/g" prod-docker.env > .env ``` 2. Build the docker stack: `docker compose --env-file .env up -d` diff --git a/docker-compose.yml b/docker-compose.yml index 7ac3334..a549035 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,27 +6,12 @@ services: environment: DATABASE_URL: postgresql://${POSTGRES_USERNAME}:${POSTGRES_PASSWORD}@nextnet-postgres:5432/${POSTGRES_DB}?schema=nextnet HERMES_POSTGRES_DSN: postgres://${POSTGRES_USERNAME}:${POSTGRES_PASSWORD}@nextnet-postgres:5432/${POSTGRES_DB} + HERMES_JWT_SECRET: ${JWT_SECRET} HERMES_DATABASE_BACKEND: postgresql depends_on: - db ports: - 3000:3000 - - # WARN: The LOM is deprecated and likely broken currently. - # - # NOTE: For this to work correctly, the nextnet-api must be version > 0.1.1 - # or have a version with backported username support, incl. logins - lom: - image: ghcr.io/imterah/hermes-lom:latest - container_name: hermes-lom - restart: always - ports: - - 2222:2222 - depends_on: - - api - volumes: - - ssh_key_data:/app/keys - db: image: postgres:17.2 container_name: nextnet-postgres @@ -37,7 +22,6 @@ services: POSTGRES_USER: ${POSTGRES_USERNAME} volumes: - postgres_data:/var/lib/postgresql/data - volumes: postgres_data: ssh_key_data: diff --git a/prod-docker.env b/prod-docker.env index fe5da79..954c20f 100644 --- a/prod-docker.env +++ b/prod-docker.env @@ -1,5 +1,4 @@ -# These are default values, please change these! - POSTGRES_USERNAME=hermes POSTGRES_PASSWORD=hermes POSTGRES_DB=hermes +JWT_SECRET=hermes From d56a8eb7bf69af685a00113d030239e496d155bb Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 21 Mar 2025 12:59:51 -0400 Subject: [PATCH 21/24] feature: Change state management from global variables to object passing This restructures dbcore (now the db package) and jwtcore (now the jwt package) to use a single struct. There is now a state package, which contains a struct with the full application state. After this, instead of initializing the API routes directly in the main function, the state object gets passed, and the API routes get initialized with their accompanying code. One fix done to reduce memory usage and increase speed is that the validator object is now persistent across requests, instead of recreating it each time. This should speed things up slightly, and improve memory usage. One additional chore done is that the database models have been moved to be a seperate file from the DB initialization itself. --- backend/api/backup.go | 298 --------------- backend/api/controllers/v1/backends/create.go | 357 +++++++++--------- backend/api/controllers/v1/backends/lookup.go | 187 ++++----- backend/api/controllers/v1/backends/remove.go | 175 ++++----- .../api/controllers/v1/proxies/connections.go | 226 +++++------ backend/api/controllers/v1/proxies/create.go | 268 ++++++------- backend/api/controllers/v1/proxies/lookup.go | 249 ++++++------ backend/api/controllers/v1/proxies/remove.go | 218 ++++++----- backend/api/controllers/v1/proxies/start.go | 208 +++++----- backend/api/controllers/v1/proxies/stop.go | 194 +++++----- backend/api/controllers/v1/users/create.go | 251 ++++++------ backend/api/controllers/v1/users/login.go | 267 ++++++------- backend/api/controllers/v1/users/lookup.go | 169 ++++----- backend/api/controllers/v1/users/refresh.go | 185 ++++----- backend/api/controllers/v1/users/remove.go | 159 ++++---- backend/api/db/db.go | 77 ++++ backend/api/db/models.go | 66 ++++ backend/api/dbcore/db.go | 142 ------- backend/api/jwt/jwt.go | 107 ++++++ backend/api/jwtcore/jwt.go | 117 ------ backend/api/main.go | 114 +++--- backend/api/permissions/permission_nodes.go | 4 +- backend/api/state/state.go | 24 ++ 23 files changed, 1901 insertions(+), 2161 deletions(-) delete mode 100644 backend/api/backup.go create mode 100644 backend/api/db/db.go create mode 100644 backend/api/db/models.go delete mode 100644 backend/api/dbcore/db.go create mode 100644 backend/api/jwt/jwt.go delete mode 100644 backend/api/jwtcore/jwt.go create mode 100644 backend/api/state/state.go diff --git a/backend/api/backup.go b/backend/api/backup.go deleted file mode 100644 index d5a9b86..0000000 --- a/backend/api/backup.go +++ /dev/null @@ -1,298 +0,0 @@ -package main - -import ( - "compress/gzip" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "os" - "strings" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "github.com/charmbracelet/log" - "github.com/go-playground/validator/v10" - "github.com/urfave/cli/v2" - "gorm.io/gorm" -) - -// Data structures -type BackupBackend struct { - ID uint `json:"id" validate:"required"` - - Name string `json:"name" validate:"required"` - Description *string `json:"description"` - Backend string `json:"backend" validate:"required"` - BackendParameters string `json:"connectionDetails" validate:"required"` -} - -type BackupProxy struct { - ID uint `json:"id" validate:"required"` - BackendID uint `json:"destProviderID" validate:"required"` - - Name string `json:"name" validate:"required"` - Description *string `json:"description"` - Protocol string `json:"protocol" validate:"required"` - SourceIP string `json:"sourceIP" validate:"required"` - SourcePort uint16 `json:"sourcePort" validate:"required"` - DestinationPort uint16 `json:"destPort" validate:"required"` - AutoStart bool `json:"enabled" validate:"required"` -} - -type BackupPermission struct { - ID uint `json:"id" validate:"required"` - - PermissionNode string `json:"permission" validate:"required"` - HasPermission bool `json:"has" validate:"required"` - UserID uint `json:"userID" validate:"required"` -} - -type BackupUser struct { - ID uint `json:"id" validate:"required"` - - Email string `json:"email" validate:"required"` - Username *string `json:"username"` - Name string `json:"name" validate:"required"` - Password string `json:"password" validate:"required"` - IsBot *bool `json:"isRootServiceAccount"` - - Token *string `json:"rootToken" validate:"required"` -} - -type BackupData struct { - Backends []*BackupBackend `json:"destinationProviders" validate:"required"` - Proxies []*BackupProxy `json:"forwardRules" validate:"required"` - Permissions []*BackupPermission `json:"allPermissions" validate:"required"` - Users []*BackupUser `json:"users" validate:"required"` -} - -// From https://stackoverflow.com/questions/54461423/efficient-way-to-remove-all-non-alphanumeric-characters-from-large-text -// Strips all alphanumeric characters from a string -func stripAllAlphanumeric(s string) string { - var result strings.Builder - - for i := 0; i < len(s); i++ { - b := s[i] - if ('a' <= b && b <= 'z') || - ('A' <= b && b <= 'Z') || - ('0' <= b && b <= '9') { - result.WriteByte(b) - } else { - result.WriteByte('_') - } - } - - return result.String() -} - -func backupRestoreEntrypoint(cCtx *cli.Context) error { - log.Info("Decompressing backup...") - - backupFile, err := os.Open(cCtx.String("backup-path")) - - if err != nil { - return fmt.Errorf("failed to open backup: %s", err.Error()) - } - - reader, err := gzip.NewReader(backupFile) - - if err != nil { - return fmt.Errorf("failed to initialize Gzip (compression) reader: %s", err.Error()) - } - - backupDataBytes, err := io.ReadAll(reader) - - if err != nil { - return fmt.Errorf("failed to read backup contents: %s", err.Error()) - } - - log.Info("Decompressed backup. Cleaning up...") - - err = reader.Close() - - if err != nil { - return fmt.Errorf("failed to close Gzip reader: %s", err.Error()) - } - - err = backupFile.Close() - - if err != nil { - return fmt.Errorf("failed to close backup: %s", err.Error()) - } - - log.Info("Parsing backup into internal structures...") - - backupData := &BackupData{} - - err = json.Unmarshal(backupDataBytes, backupData) - - if err != nil { - return fmt.Errorf("failed to parse backup: %s", err.Error()) - } - - if err := validator.New().Struct(backupData); err != nil { - return fmt.Errorf("failed to validate backup: %s", err.Error()) - } - - log.Info("Initializing database and opening it...") - - err = dbcore.InitializeDatabase(&gorm.Config{}) - - if err != nil { - log.Fatalf("Failed to initialize database: %s", err) - } - - log.Info("Running database migrations...") - - if err := dbcore.DoDatabaseMigrations(dbcore.DB); err != nil { - return fmt.Errorf("Failed to run database migrations: %s", err) - } - - log.Info("Restoring database...") - bestEffortOwnerUIDFromBackup := -1 - - log.Info("Attempting to find user to use as owner of resources...") - - for _, user := range backupData.Users { - foundUser := false - failedAdministrationCheck := false - - for _, permission := range backupData.Permissions { - if permission.UserID != user.ID { - continue - } - - foundUser = true - - if !strings.HasPrefix(permission.PermissionNode, "routes.") && permission.PermissionNode != "permissions.see" && !permission.HasPermission { - log.Infof("User with email '%s' and ID of '%d' failed administration check (lacks all permissions required). Attempting to find better user", user.Email, user.ID) - failedAdministrationCheck = true - - break - } - } - - if !foundUser { - log.Warnf("User with email '%s' and ID of '%d' lacks any permissions!", user.Email, user.ID) - continue - } - - if failedAdministrationCheck { - continue - } - - log.Infof("Using user with email '%s', and ID of '%d'", user.Email, user.ID) - bestEffortOwnerUIDFromBackup = int(user.ID) - - break - } - - if bestEffortOwnerUIDFromBackup == -1 { - log.Warnf("Could not find Administrative level user to use as the owner of resources. Using user with email '%s', and ID of '%d'", backupData.Users[0].Email, backupData.Users[0].ID) - bestEffortOwnerUIDFromBackup = int(backupData.Users[0].ID) - } - - var bestEffortOwnerUID uint - - for _, user := range backupData.Users { - log.Infof("Migrating user with email '%s' and ID of '%d'", user.Email, user.ID) - tokens := make([]dbcore.Token, 0) - permissions := make([]dbcore.Permission, 0) - - if user.Token != nil { - tokens = append(tokens, dbcore.Token{ - Token: *user.Token, - DisableExpiry: true, - CreationIPAddr: "127.0.0.1", // We don't know the creation IP address... - }) - } - - for _, permission := range backupData.Permissions { - if permission.UserID != user.ID { - continue - } - - permissions = append(permissions, dbcore.Permission{ - PermissionNode: permission.PermissionNode, - HasPermission: permission.HasPermission, - }) - } - - username := "" - - if user.Username == nil { - username = strings.ToLower(stripAllAlphanumeric(user.Email)) - log.Warnf("User with ID of '%d' doesn't have a username. Derived username from email is '%s' (email is '%s')", user.ID, username, user.Email) - } else { - username = *user.Username - } - - userDatabase := &dbcore.User{ - Email: user.Email, - Username: username, - Name: user.Name, - Password: base64.StdEncoding.EncodeToString([]byte(user.Password)), - IsBot: user.IsBot, - - Tokens: tokens, - Permissions: permissions, - } - - if err := dbcore.DB.Create(userDatabase).Error; err != nil { - log.Errorf("Failed to create user: %s", err.Error()) - continue - } - - if uint(bestEffortOwnerUIDFromBackup) == user.ID { - bestEffortOwnerUID = userDatabase.ID - } - } - - for _, backend := range backupData.Backends { - log.Infof("Migrating backend ID '%d' with name '%s'", backend.ID, backend.Name) - - backendDatabase := &dbcore.Backend{ - UserID: bestEffortOwnerUID, - Name: backend.Name, - Description: backend.Description, - Backend: backend.Backend, - BackendParameters: base64.StdEncoding.EncodeToString([]byte(backend.BackendParameters)), - } - - if err := dbcore.DB.Create(backendDatabase).Error; err != nil { - log.Errorf("Failed to create backend: %s", err.Error()) - continue - } - - log.Infof("Migrating proxies for backend ID '%d'", backend.ID) - - for _, proxy := range backupData.Proxies { - if proxy.BackendID != backend.ID { - continue - } - - log.Infof("Migrating proxy ID '%d' with name '%s'", proxy.ID, proxy.Name) - - proxyDatabase := &dbcore.Proxy{ - BackendID: backendDatabase.ID, - UserID: bestEffortOwnerUID, - - Name: proxy.Name, - Description: proxy.Description, - Protocol: proxy.Protocol, - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestinationPort: proxy.DestinationPort, - AutoStart: proxy.AutoStart, - } - - if err := dbcore.DB.Create(proxyDatabase).Error; err != nil { - log.Errorf("Failed to create proxy: %s", err.Error()) - } - } - } - - log.Info("Successfully upgraded to Hermes from NextNet.") - - return nil -} diff --git a/backend/api/controllers/v1/backends/create.go b/backend/api/controllers/v1/backends/create.go index 287f314..314dc3e 100644 --- a/backend/api/controllers/v1/backends/create.go +++ b/backend/api/controllers/v1/backends/create.go @@ -7,13 +7,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendCreationRequest struct { @@ -24,131 +23,114 @@ type BackendCreationRequest struct { BackendParameters interface{} `json:"connectionDetails" validate:"required"` } -func CreateBackend(c *gin.Context) { - var req BackendCreationRequest +func SetupCreateBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/create", func(c *gin.Context) { + var req BackendCreationRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) + user, err := state.JWT.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "backends.add") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "backends.add") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var backendParameters []byte - return - } + switch parameters := req.BackendParameters.(type) { + case string: + backendParameters = []byte(parameters) + case map[string]interface{}: + backendParameters, err = json.Marshal(parameters) - var backendParameters []byte + if err != nil { + log.Warnf("Failed to marshal JSON recieved as BackendParameters: %s", err.Error()) - switch parameters := req.BackendParameters.(type) { - case string: - backendParameters = []byte(parameters) - case map[string]interface{}: - backendParameters, err = json.Marshal(parameters) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to prepare parameters", + }) + + return + } + default: + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid type for connectionDetails (recieved %T)", parameters), + }) + + return + } + + var backendRuntimeFilePath string + + for _, runtime := range backendruntime.AvailableBackends { + if runtime.Name == req.Backend { + backendRuntimeFilePath = runtime.Path + } + } + + if backendRuntimeFilePath == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Unsupported backend recieved", + }) + + return + } + + backend := backendruntime.NewBackend(backendRuntimeFilePath) + err = backend.Start() if err != nil { - log.Warnf("Failed to marshal JSON recieved as BackendParameters: %s", err.Error()) + log.Warnf("Failed to start backend: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to prepare parameters", + "error": "Failed to start backend", }) return } - default: - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Invalid type for connectionDetails (recieved %T)", parameters), + + backendParamCheckResponse, err := backend.ProcessCommand(&commonbackend.CheckServerParameters{ + Arguments: backendParameters, }) - return - } - - var backendRuntimeFilePath string - - for _, runtime := range backendruntime.AvailableBackends { - if runtime.Name == req.Backend { - backendRuntimeFilePath = runtime.Path - } - } - - if backendRuntimeFilePath == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Unsupported backend recieved", - }) - - return - } - - backend := backendruntime.NewBackend(backendRuntimeFilePath) - err = backend.Start() - - if err != nil { - log.Warnf("Failed to start backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to start backend", - }) - - return - } - - backendParamCheckResponse, err := backend.ProcessCommand(&commonbackend.CheckServerParameters{ - Arguments: backendParameters, - }) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - err = backend.Stop() - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendParamCheckResponse.(type) { - case *commonbackend.CheckParametersResponse: - if responseMessage.InResponseTo != "checkServerParameters" { - log.Errorf("Got illegal response to CheckServerParameters: %s", responseMessage.InResponseTo) + log.Warnf("Failed to get response for backend: %s", err.Error()) err = backend.Stop() @@ -163,107 +145,126 @@ func CreateBackend(c *gin.Context) { return } - if !responseMessage.IsValid { + switch responseMessage := backendParamCheckResponse.(type) { + case *commonbackend.CheckParametersResponse: + if responseMessage.InResponseTo != "checkServerParameters" { + log.Errorf("Got illegal response to CheckServerParameters: %s", responseMessage.InResponseTo) + + err = backend.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + } + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", + }) + + return + } + + if !responseMessage.IsValid { + err = backend.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + } + + var errorMessage string + + if responseMessage.Message == "" { + errorMessage = "Unkown error while trying to parse connectionDetails" + } else { + errorMessage = fmt.Sprintf("Invalid backend parameters: %s", responseMessage.Message) + } + + c.JSON(http.StatusBadRequest, gin.H{ + "error": errorMessage, + }) + + return + } + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + } + + log.Info("Passed backend checks successfully") + + backendInDatabase := &db.Backend{ + UserID: user.ID, + Name: req.Name, + Description: req.Description, + Backend: req.Backend, + BackendParameters: base64.StdEncoding.EncodeToString(backendParameters), + } + + if result := state.DB.DB.Create(&backendInDatabase); result.Error != nil { + log.Warnf("Failed to create backend: %s", result.Error.Error()) + err = backend.Stop() if err != nil { log.Warnf("Failed to stop backend: %s", err.Error()) } - var errorMessage string - - if responseMessage.Message == "" { - errorMessage = "Unkown error while trying to parse connectionDetails" - } else { - errorMessage = fmt.Sprintf("Invalid backend parameters: %s", responseMessage.Message) - } - - c.JSON(http.StatusBadRequest, gin.H{ - "error": errorMessage, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add backend into database", }) return } - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) - } - log.Info("Passed backend checks successfully") - - backendInDatabase := &dbcore.Backend{ - UserID: user.ID, - Name: req.Name, - Description: req.Description, - Backend: req.Backend, - BackendParameters: base64.StdEncoding.EncodeToString(backendParameters), - } - - if result := dbcore.DB.Create(&backendInDatabase); result.Error != nil { - log.Warnf("Failed to create backend: %s", result.Error.Error()) - - err = backend.Stop() - - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add backend into database", + backendStartResponse, err := backend.ProcessCommand(&commonbackend.Start{ + Arguments: backendParameters, }) - return - } - - backendStartResponse, err := backend.ProcessCommand(&commonbackend.Start{ - Arguments: backendParameters, - }) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - err = backend.Stop() - if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) - } + log.Warnf("Failed to get response for backend: %s", err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendStartResponse.(type) { - case *commonbackend.BackendStatusResponse: - if !responseMessage.IsRunning { err = backend.Stop() if err != nil { - log.Warnf("Failed to start backend: %s", err.Error()) + log.Warnf("Failed to stop backend: %s", err.Error()) } - var errorMessage string - - if responseMessage.Message == "" { - errorMessage = "Unkown error while trying to start the backend" - } else { - errorMessage = fmt.Sprintf("Failed to start backend: %s", responseMessage.Message) - } - - c.JSON(http.StatusBadRequest, gin.H{ - "error": errorMessage, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", }) return } - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) - } - backendruntime.RunningBackends[backendInDatabase.ID] = backend + switch responseMessage := backendStartResponse.(type) { + case *commonbackend.BackendStatusResponse: + if !responseMessage.IsRunning { + err = backend.Stop() - c.JSON(http.StatusOK, gin.H{ - "success": true, + if err != nil { + log.Warnf("Failed to start backend: %s", err.Error()) + } + + var errorMessage string + + if responseMessage.Message == "" { + errorMessage = "Unkown error while trying to start the backend" + } else { + errorMessage = fmt.Sprintf("Failed to start backend: %s", responseMessage.Message) + } + + c.JSON(http.StatusBadRequest, gin.H{ + "error": errorMessage, + }) + + return + } + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + } + + backendruntime.RunningBackends[backendInDatabase.ID] = backend + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/controllers/v1/backends/lookup.go b/backend/api/controllers/v1/backends/lookup.go index 9819df7..6cbb386 100644 --- a/backend/api/controllers/v1/backends/lookup.go +++ b/backend/api/controllers/v1/backends/lookup.go @@ -7,12 +7,11 @@ import ( "strings" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendLookupRequest struct { @@ -38,95 +37,80 @@ type LookupResponse struct { Data []*SanitizedBackend `json:"data"` } -func LookupBackend(c *gin.Context) { - var req BackendLookupRequest +func SetupLookupBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/lookup", func(c *gin.Context) { + var req BackendLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "backends.visible") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - backends := []dbcore.Backend{} - queryString := []string{} - queryParameters := []interface{}{} + user, err := state.JWT.GetUserFromJWT(req.Token) - if req.BackendID != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.BackendID) - } + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - if req.Description != nil { - queryString = append(queryString, "description = ?") - queryParameters = append(queryParameters, req.Description) - } + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) - if req.Backend != nil { - queryString = append(queryString, "is_bot = ?") - queryParameters = append(queryParameters, req.Backend) - } + return + } + } - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&backends).Error; err != nil { - log.Warnf("Failed to get backends: %s", err.Error()) + if !permissions.UserHasPermission(user, "backends.visible") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get backends", - }) + return + } - return - } + backends := []db.Backend{} + queryString := []string{} + queryParameters := []interface{}{} - sanitizedBackends := make([]*SanitizedBackend, len(backends)) - hasSecretVisibility := permissions.UserHasPermission(user, "backends.secretVis") + if req.BackendID != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.BackendID) + } - for backendIndex, backend := range backends { - foundBackend, ok := backendruntime.RunningBackends[backend.ID] + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } - if !ok { - log.Warnf("Failed to get backend #%d controller", backend.ID) + if req.Description != nil { + queryString = append(queryString, "description = ?") + queryParameters = append(queryParameters, req.Description) + } + + if req.Backend != nil { + queryString = append(queryString, "is_bot = ?") + queryParameters = append(queryParameters, req.Backend) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&backends).Error; err != nil { + log.Warnf("Failed to get backends: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ "error": "Failed to get backends", @@ -135,29 +119,46 @@ func LookupBackend(c *gin.Context) { return } - sanitizedBackends[backendIndex] = &SanitizedBackend{ - BackendID: backend.ID, - OwnerID: backend.UserID, - Name: backend.Name, - Description: backend.Description, - Backend: backend.Backend, - Logs: foundBackend.Logs, - } + sanitizedBackends := make([]*SanitizedBackend, len(backends)) + hasSecretVisibility := permissions.UserHasPermission(user, "backends.secretVis") - if backend.UserID == user.ID || hasSecretVisibility { - backendParametersBytes, err := base64.StdEncoding.DecodeString(backend.BackendParameters) + for backendIndex, backend := range backends { + foundBackend, ok := backendruntime.RunningBackends[backend.ID] - if err != nil { - log.Warnf("Failed to decode base64 backend parameters: %s", err.Error()) + if !ok { + log.Warnf("Failed to get backend #%d controller", backend.ID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get backends", + }) + + return } - backendParameters := string(backendParametersBytes) - sanitizedBackends[backendIndex].BackendParameters = &backendParameters - } - } + sanitizedBackends[backendIndex] = &SanitizedBackend{ + BackendID: backend.ID, + OwnerID: backend.UserID, + Name: backend.Name, + Description: backend.Description, + Backend: backend.Backend, + Logs: foundBackend.Logs, + } - c.JSON(http.StatusOK, &LookupResponse{ - Success: true, - Data: sanitizedBackends, + if backend.UserID == user.ID || hasSecretVisibility { + backendParametersBytes, err := base64.StdEncoding.DecodeString(backend.BackendParameters) + + if err != nil { + log.Warnf("Failed to decode base64 backend parameters: %s", err.Error()) + } + + backendParameters := string(backendParametersBytes) + sanitizedBackends[backendIndex].BackendParameters = &backendParameters + } + } + + c.JSON(http.StatusOK, &LookupResponse{ + Success: true, + Data: sanitizedBackends, + }) }) } diff --git a/backend/api/controllers/v1/backends/remove.go b/backend/api/controllers/v1/backends/remove.go index f9dbd8c..338ccbd 100644 --- a/backend/api/controllers/v1/backends/remove.go +++ b/backend/api/controllers/v1/backends/remove.go @@ -5,12 +5,11 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type BackendRemovalRequest struct { @@ -18,106 +17,108 @@ type BackendRemovalRequest struct { BackendID uint `json:"id" validate:"required"` } -func RemoveBackend(c *gin.Context) { - var req BackendRemovalRequest +func SetupRemoveBackend(state *state.State) { + state.Engine.POST("/api/v1/backends/remove", func(c *gin.Context) { + var req BackendRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "backends.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - var backend *dbcore.Backend - backendRequest := dbcore.DB.Where("id = ?", req.BackendID).Find(&backend) - - if backendRequest.Error != nil { - log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if backend exists", - }) - - return - } - - backendExists := backendRequest.RowsAffected > 0 - - if !backendExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Backend doesn't exist", - }) - - return - } - - if err := dbcore.DB.Delete(backend).Error; err != nil { - log.Warnf("failed to delete backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to delete backend", - }) - - return - } - - backendInstance, ok := backendruntime.RunningBackends[req.BackendID] - - if ok { - err = backendInstance.Stop() + user, err := state.JWT.GetUserFromJWT(req.Token) if err != nil { - log.Warnf("Failed to stop backend: %s", err.Error()) + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Backend deleted, but failed to stop", + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "backends.remove") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", }) - delete(backendruntime.RunningBackends, req.BackendID) return } - delete(backendruntime.RunningBackends, req.BackendID) - } + var backend *db.Backend + backendRequest := state.DB.DB.Where("id = ?", req.BackendID).Find(&backend) - c.JSON(http.StatusOK, gin.H{ - "success": true, + if backendRequest.Error != nil { + log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if backend exists", + }) + + return + } + + backendExists := backendRequest.RowsAffected > 0 + + if !backendExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Backend doesn't exist", + }) + + return + } + + if err := state.DB.DB.Delete(backend).Error; err != nil { + log.Warnf("failed to delete backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to delete backend", + }) + + return + } + + backendInstance, ok := backendruntime.RunningBackends[req.BackendID] + + if ok { + err = backendInstance.Stop() + + if err != nil { + log.Warnf("Failed to stop backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Backend deleted, but failed to stop", + }) + + delete(backendruntime.RunningBackends, req.BackendID) + return + } + + delete(backendruntime.RunningBackends, req.BackendID) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/controllers/v1/proxies/connections.go b/backend/api/controllers/v1/proxies/connections.go index 5eb3db9..7ea42fb 100644 --- a/backend/api/controllers/v1/proxies/connections.go +++ b/backend/api/controllers/v1/proxies/connections.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ConnectionsRequest struct { @@ -37,127 +36,130 @@ type ConnectionsResponse struct { Data []*SanitizedConnection `json:"data"` } -func GetConnections(c *gin.Context) { - var req ConnectionsRequest +func SetupGetConnections(state *state.State) { + state.Engine.POST("/api/v1/forward/connections", func(c *gin.Context) { + var req ConnectionsRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - if !permissions.UserHasPermission(user, "routes.visibleConn") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - var proxy dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.Id).First(&proxy) + user, err := state.JWT.GetUserFromJWT(req.Token) - if proxyRequest.Error != nil { - log.Warnf("failed to find proxy: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find forward entry", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "No forward entry found", - }) - - return - } - - backendRuntime, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{}) - - if err != nil { - log.Warnf("Failed to get response for backend: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get status response from backend", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyConnectionsResponse: - sanitizedConnections := []*SanitizedConnection{} - - for _, connection := range responseMessage.Connections { - if connection.SourceIP == proxy.SourceIP && connection.SourcePort == proxy.SourcePort && proxy.DestinationPort == proxy.DestinationPort { - sanitizedConnections = append(sanitizedConnections, &SanitizedConnection{ - ClientIP: connection.ClientIP, - Port: connection.ClientPort, - - ConnectionDetails: &ConnectionDetailsForConnection{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - }, + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return } } - c.JSON(http.StatusOK, &ConnectionsResponse{ - Success: true, - Data: sanitizedConnections, - }) - default: - log.Warnf("Got illegal response type for backend: %T", responseMessage) + if !permissions.UserHasPermission(user, "routes.visibleConn") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got illegal response type", - }) - } + return + } + + var proxy db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.Id).First(&proxy) + + if proxyRequest.Error != nil { + log.Warnf("failed to find proxy: %s", proxyRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find forward entry", + }) + + return + } + + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "No forward entry found", + }) + + return + } + + backendRuntime, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backendRuntime.ProcessCommand(&commonbackend.ProxyConnectionsRequest{}) + + if err != nil { + log.Warnf("Failed to get response for backend: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get status response from backend", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyConnectionsResponse: + sanitizedConnections := []*SanitizedConnection{} + + for _, connection := range responseMessage.Connections { + if connection.SourceIP == proxy.SourceIP && connection.SourcePort == proxy.SourcePort && proxy.DestinationPort == proxy.DestinationPort { + sanitizedConnections = append(sanitizedConnections, &SanitizedConnection{ + ClientIP: connection.ClientIP, + Port: connection.ClientPort, + + ConnectionDetails: &ConnectionDetailsForConnection{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + }, + }) + } + } + + c.JSON(http.StatusOK, &ConnectionsResponse{ + Success: true, + Data: sanitizedConnections, + }) + default: + log.Warnf("Got illegal response type for backend: %T", responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got illegal response type", + }) + } + }) } diff --git a/backend/api/controllers/v1/proxies/create.go b/backend/api/controllers/v1/proxies/create.go index 28e8c95..d790c49 100644 --- a/backend/api/controllers/v1/proxies/create.go +++ b/backend/api/controllers/v1/proxies/create.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyCreationRequest struct { @@ -26,150 +25,153 @@ type ProxyCreationRequest struct { AutoStart *bool `json:"autoStart"` } -func CreateProxy(c *gin.Context) { - var req ProxyCreationRequest +func SetupCreateProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/create", func(c *gin.Context) { + var req ProxyCreationRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - if !permissions.UserHasPermission(user, "routes.add") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - - if req.Protocol != "tcp" && req.Protocol != "udp" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Protocol must be either 'tcp' or 'udp'", - }) - - return - } - - var backend dbcore.Backend - backendRequest := dbcore.DB.Where("id = ?", req.ProviderID).First(&backend) - - if backendRequest.Error != nil { - log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if backend exists", - }) - } - - backendExists := backendRequest.RowsAffected > 0 - - if !backendExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Could not find backend", - }) - } - - autoStart := false - - if req.AutoStart != nil { - autoStart = *req.AutoStart - } - - proxy := &dbcore.Proxy{ - UserID: user.ID, - BackendID: req.ProviderID, - Name: req.Name, - Description: req.Description, - Protocol: req.Protocol, - SourceIP: req.SourceIP, - SourcePort: req.SourcePort, - DestinationPort: req.DestinationPort, - AutoStart: autoStart, - } - - if result := dbcore.DB.Create(proxy); result.Error != nil { - log.Warnf("failed to create proxy: %s", result.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add forward rule to database", - }) - - return - } - - if autoStart { - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "id": proxy.ID, + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.add") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", }) return } - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if !responseMessage.IsActive { - log.Warnf("Failed to start proxy for backend #%d", proxy.BackendID) - } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - } - } + if req.Protocol != "tcp" && req.Protocol != "udp" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Protocol must be either 'tcp' or 'udp'", + }) - c.JSON(http.StatusOK, gin.H{ - "success": true, - "id": proxy.ID, + return + } + + var backend db.Backend + backendRequest := state.DB.DB.Where("id = ?", req.ProviderID).First(&backend) + + if backendRequest.Error != nil { + log.Warnf("failed to find if backend exists or not: %s", backendRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if backend exists", + }) + } + + backendExists := backendRequest.RowsAffected > 0 + + if !backendExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Could not find backend", + }) + } + + autoStart := false + + if req.AutoStart != nil { + autoStart = *req.AutoStart + } + + proxy := &db.Proxy{ + UserID: user.ID, + BackendID: req.ProviderID, + Name: req.Name, + Description: req.Description, + Protocol: req.Protocol, + SourceIP: req.SourceIP, + SourcePort: req.SourcePort, + DestinationPort: req.DestinationPort, + AutoStart: autoStart, + } + + if result := state.DB.DB.Create(proxy); result.Error != nil { + log.Warnf("failed to create proxy: %s", result.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add forward rule to database", + }) + + return + } + + if autoStart { + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "id": proxy.ID, + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, + }) + + if err != nil { + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyStatusResponse: + if !responseMessage.IsActive { + log.Warnf("Failed to start proxy for backend #%d", proxy.BackendID) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "id": proxy.ID, + }) }) } diff --git a/backend/api/controllers/v1/proxies/lookup.go b/backend/api/controllers/v1/proxies/lookup.go index 3b0d4c3..bf2c3ea 100644 --- a/backend/api/controllers/v1/proxies/lookup.go +++ b/backend/api/controllers/v1/proxies/lookup.go @@ -5,12 +5,11 @@ import ( "net/http" "strings" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyLookupRequest struct { @@ -43,141 +42,143 @@ type ProxyLookupResponse struct { Data []*SanitizedProxy `json:"data"` } -func LookupProxy(c *gin.Context) { - var req ProxyLookupRequest +func SetupLookupProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/lookup", func(c *gin.Context) { + var req ProxyLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.visible") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) + + return + } + + if req.Protocol != nil { + if *req.Protocol != "tcp" && *req.Protocol != "udp" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Protocol specified in body must either be 'tcp' or 'udp'", + }) + + return + } + } + + proxies := []db.Proxy{} + + queryString := []string{} + queryParameters := []interface{}{} + + if req.Id != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.Id) + } + + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.Description != nil { + queryString = append(queryString, "description = ?") + queryParameters = append(queryParameters, req.Description) + } + + if req.SourceIP != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.SourcePort != nil { + queryString = append(queryString, "source_port = ?") + queryParameters = append(queryParameters, req.SourcePort) + } + + if req.DestinationPort != nil { + queryString = append(queryString, "destination_port = ?") + queryParameters = append(queryParameters, req.DestinationPort) + } + + if req.ProviderID != nil { + queryString = append(queryString, "backend_id = ?") + queryParameters = append(queryParameters, req.ProviderID) + } + + if req.AutoStart != nil { + queryString = append(queryString, "auto_start = ?") + queryParameters = append(queryParameters, req.AutoStart) + } + + if req.Protocol != nil { + queryString = append(queryString, "protocol = ?") + queryParameters = append(queryParameters, req.Protocol) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&proxies).Error; err != nil { + log.Warnf("failed to get proxies: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Failed to get proxies", }) return } - } - if !permissions.UserHasPermission(user, "routes.visible") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + sanitizedProxies := make([]*SanitizedProxy, len(proxies)) - return - } - - if req.Protocol != nil { - if *req.Protocol != "tcp" && *req.Protocol != "udp" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Protocol specified in body must either be 'tcp' or 'udp'", - }) - - return + for proxyIndex, proxy := range proxies { + sanitizedProxies[proxyIndex] = &SanitizedProxy{ + Id: proxy.ID, + Name: proxy.Name, + Description: proxy.Description, + Protcol: proxy.Protocol, + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestinationPort: proxy.DestinationPort, + ProviderID: proxy.BackendID, + AutoStart: proxy.AutoStart, + } } - } - proxies := []dbcore.Proxy{} - - queryString := []string{} - queryParameters := []interface{}{} - - if req.Id != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.Id) - } - - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } - - if req.Description != nil { - queryString = append(queryString, "description = ?") - queryParameters = append(queryParameters, req.Description) - } - - if req.SourceIP != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } - - if req.SourcePort != nil { - queryString = append(queryString, "source_port = ?") - queryParameters = append(queryParameters, req.SourcePort) - } - - if req.DestinationPort != nil { - queryString = append(queryString, "destination_port = ?") - queryParameters = append(queryParameters, req.DestinationPort) - } - - if req.ProviderID != nil { - queryString = append(queryString, "backend_id = ?") - queryParameters = append(queryParameters, req.ProviderID) - } - - if req.AutoStart != nil { - queryString = append(queryString, "auto_start = ?") - queryParameters = append(queryParameters, req.AutoStart) - } - - if req.Protocol != nil { - queryString = append(queryString, "protocol = ?") - queryParameters = append(queryParameters, req.Protocol) - } - - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&proxies).Error; err != nil { - log.Warnf("failed to get proxies: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get proxies", + c.JSON(http.StatusOK, &ProxyLookupResponse{ + Success: true, + Data: sanitizedProxies, }) - - return - } - - sanitizedProxies := make([]*SanitizedProxy, len(proxies)) - - for proxyIndex, proxy := range proxies { - sanitizedProxies[proxyIndex] = &SanitizedProxy{ - Id: proxy.ID, - Name: proxy.Name, - Description: proxy.Description, - Protcol: proxy.Protocol, - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestinationPort: proxy.DestinationPort, - ProviderID: proxy.BackendID, - AutoStart: proxy.AutoStart, - } - } - - c.JSON(http.StatusOK, &ProxyLookupResponse{ - Success: true, - Data: sanitizedProxies, }) } diff --git a/backend/api/controllers/v1/proxies/remove.go b/backend/api/controllers/v1/proxies/remove.go index e41fa84..304c5c7 100644 --- a/backend/api/controllers/v1/proxies/remove.go +++ b/backend/api/controllers/v1/proxies/remove.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyRemovalRequest struct { @@ -19,134 +18,133 @@ type ProxyRemovalRequest struct { ID uint `validate:"required" json:"id"` } -func RemoveProxy(c *gin.Context) { - var req ProxyRemovalRequest +func SetupRemoveProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/remove", func(c *gin.Context) { + var req ProxyRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.remove") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "routes.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - return - } + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - if err := dbcore.DB.Delete(proxy).Error; err != nil { - log.Warnf("failed to delete proxy: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to delete forward rule", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get response from backend. Proxy was still successfully deleted", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if responseMessage.IsActive { c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to stop proxy. Proxy was still successfully deleted", + "error": "Failed to find if forward rule exists", }) return } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + if err := state.DB.DB.Delete(proxy).Error; err != nil { + log.Warnf("failed to delete proxy: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to delete forward rule", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + if err != nil { + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get response from backend. Proxy was still successfully deleted", + }) + + return + } + + switch responseMessage := backendResponse.(type) { + case *commonbackend.ProxyStatusResponse: + if responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to stop proxy. Proxy was still successfully deleted", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was still successfully deleted", + }) + } }) } diff --git a/backend/api/controllers/v1/proxies/start.go b/backend/api/controllers/v1/proxies/start.go index 1573382..1680ddf 100644 --- a/backend/api/controllers/v1/proxies/start.go +++ b/backend/api/controllers/v1/proxies/start.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyStartRequest struct { @@ -19,124 +18,119 @@ type ProxyStartRequest struct { ID uint `validate:"required" json:"id"` } -func StartProxy(c *gin.Context) { - var req ProxyStartRequest +func SetupStartProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/start", func(c *gin.Context) { + var req ProxyStartRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.start") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - if !permissions.UserHasPermission(user, "routes.start") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - switch responseMessage := backendResponse.(type) { - case error: - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", - }) - - return - case *commonbackend.ProxyStatusResponse: - if !responseMessage.IsActive { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to start proxy", + "error": "Missing permissions", }) return } - break - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if forward rule exists", + }) + + return + } + + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.AddProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + switch responseMessage := backendResponse.(type) { + case error: + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + case *commonbackend.ProxyStatusResponse: + if !responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to start proxy", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was likely still successfully started", + }) + } }) } diff --git a/backend/api/controllers/v1/proxies/stop.go b/backend/api/controllers/v1/proxies/stop.go index 820cec4..27d63ce 100644 --- a/backend/api/controllers/v1/proxies/stop.go +++ b/backend/api/controllers/v1/proxies/stop.go @@ -5,13 +5,12 @@ import ( "net/http" "git.terah.dev/imterah/hermes/backend/api/backendruntime" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type ProxyStopRequest struct { @@ -19,124 +18,119 @@ type ProxyStopRequest struct { ID uint `validate:"required" json:"id"` } -func StopProxy(c *gin.Context) { - var req ProxyStopRequest +func SetupStopProxy(state *state.State) { + state.Engine.POST("/api/v1/forward/stop", func(c *gin.Context) { + var req ProxyStartRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } + return + } - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user, err := jwtcore.GetUserFromJWT(req.Token) - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + if !permissions.UserHasPermission(user, "routes.stop") { c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Missing permissions", }) return } - } - if !permissions.UserHasPermission(user, "routes.stop") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) + var proxy *db.Proxy + proxyRequest := state.DB.DB.Where("id = ?", req.ID).Find(&proxy) - return - } + if proxyRequest.Error != nil { + log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - var proxy *dbcore.Proxy - proxyRequest := dbcore.DB.Where("id = ?", req.ID).Find(&proxy) - - if proxyRequest.Error != nil { - log.Warnf("failed to find if proxy exists or not: %s", proxyRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if forward rule exists", - }) - - return - } - - proxyExists := proxyRequest.RowsAffected > 0 - - if !proxyExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Forward rule doesn't exist", - }) - - return - } - - backend, ok := backendruntime.RunningBackends[proxy.BackendID] - - if !ok { - log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Couldn't fetch backend runtime", - }) - - return - } - - backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ - SourceIP: proxy.SourceIP, - SourcePort: proxy.SourcePort, - DestPort: proxy.DestinationPort, - Protocol: proxy.Protocol, - }) - - if err != nil { - log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get response from backend", - }) - - return - } - - switch responseMessage := backendResponse.(type) { - case *commonbackend.ProxyStatusResponse: - if responseMessage.IsActive { c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to stop proxy", + "error": "Failed to find if forward rule exists", }) return } - default: - log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Got invalid response from backend. Proxy was still successfully deleted", + proxyExists := proxyRequest.RowsAffected > 0 + + if !proxyExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Forward rule doesn't exist", + }) + + return + } + + backend, ok := backendruntime.RunningBackends[proxy.BackendID] + + if !ok { + log.Warnf("Couldn't fetch backend runtime from backend ID #%d", proxy.BackendID) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Couldn't fetch backend runtime", + }) + + return + } + + backendResponse, err := backend.ProcessCommand(&commonbackend.RemoveProxy{ + SourceIP: proxy.SourceIP, + SourcePort: proxy.SourcePort, + DestPort: proxy.DestinationPort, + Protocol: proxy.Protocol, }) - return - } + switch responseMessage := backendResponse.(type) { + case error: + log.Warnf("Failed to get response for backend #%d: %s", proxy.BackendID, responseMessage.Error()) - c.JSON(http.StatusOK, gin.H{ - "success": true, + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get response from backend", + }) + case *commonbackend.ProxyStatusResponse: + if responseMessage.IsActive { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to stop proxy", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + } + default: + log.Errorf("Got illegal response type for backend #%d: %T", proxy.BackendID, responseMessage) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Got invalid response from backend. Proxy was likely still successfully stopped", + }) + } }) } diff --git a/backend/api/controllers/v1/users/create.go b/backend/api/controllers/v1/users/create.go index e92858a..f39aa95 100644 --- a/backend/api/controllers/v1/users/create.go +++ b/backend/api/controllers/v1/users/create.go @@ -7,11 +7,9 @@ import ( "net/http" "strings" - "github.com/go-playground/validator/v10" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" permissionHelper "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" @@ -22,142 +20,141 @@ type UserCreationRequest struct { Email string `validate:"required"` Password string `validate:"required"` Username string `validate:"required"` - - // TODO: implement support - ExistingUserToken string `json:"token"` - IsBot bool + IsBot bool } -func CreateUser(c *gin.Context) { - if !signupEnabled && !unsafeSignup { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Signing up is not enabled at this time.", - }) +func SetupCreateUser(state *state.State) { + state.Engine.POST("/api/v1/users/create", func(c *gin.Context) { + if !signupEnabled && !unsafeSignup { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Signing up is not enabled at this time.", + }) - return - } - - var req UserCreationRequest - - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - var user *dbcore.User - userRequest := dbcore.DB.Where("email = ? OR username = ?", req.Email, req.Username).Find(&user) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if userExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User already exists", - }) - - return - } - - passwordHashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) - - if err != nil { - log.Warnf("Failed to generate password for client upon signup: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate password hash", - }) - - return - } - - permissions := []dbcore.Permission{} - - for _, permission := range permissionHelper.DefaultPermissionNodes { - permissionEnabledState := false - - if unsafeSignup || strings.HasPrefix(permission, "routes.") || permission == "permissions.see" { - permissionEnabledState = true + return } - permissions = append(permissions, dbcore.Permission{ - PermissionNode: permission, - HasPermission: permissionEnabledState, - }) - } + var req UserCreationRequest - tokenRandomData := make([]byte, 80) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - if _, err := rand.Read(tokenRandomData); err != nil { - log.Warnf("Failed to read random data to use as token: %s", err.Error()) + return + } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - return - } + return + } - user = &dbcore.User{ - Email: req.Email, - Username: req.Username, - Name: req.Name, - IsBot: &req.IsBot, - Password: base64.StdEncoding.EncodeToString(passwordHashed), - Permissions: permissions, - Tokens: []dbcore.Token{ - { - Token: base64.StdEncoding.EncodeToString(tokenRandomData), - DisableExpiry: forceNoExpiryTokens, - CreationIPAddr: c.ClientIP(), + var user *db.User + userRequest := state.DB.DB.Where("email = ? OR username = ?", req.Email, req.Username).Find(&user) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User already exists", + }) + + return + } + + passwordHashed, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + + if err != nil { + log.Warnf("Failed to generate password for client upon signup: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate password hash", + }) + + return + } + + permissions := []db.Permission{} + + for _, permission := range permissionHelper.DefaultPermissionNodes { + permissionEnabledState := false + + if unsafeSignup || strings.HasPrefix(permission, "routes.") || permission == "permissions.see" { + permissionEnabledState = true + } + + permissions = append(permissions, db.Permission{ + PermissionNode: permission, + HasPermission: permissionEnabledState, + }) + } + + tokenRandomData := make([]byte, 80) + + if _, err := rand.Read(tokenRandomData); err != nil { + log.Warnf("Failed to read random data to use as token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + user = &db.User{ + Email: req.Email, + Username: req.Username, + Name: req.Name, + IsBot: &req.IsBot, + Password: base64.StdEncoding.EncodeToString(passwordHashed), + Permissions: permissions, + Tokens: []db.Token{ + { + Token: base64.StdEncoding.EncodeToString(tokenRandomData), + DisableExpiry: forceNoExpiryTokens, + CreationIPAddr: c.ClientIP(), + }, }, - }, - } + } - if result := dbcore.DB.Create(&user); result.Error != nil { - log.Warnf("Failed to create user: %s", result.Error.Error()) + if result := state.DB.DB.Create(&user); result.Error != nil { + log.Warnf("Failed to create user: %s", result.Error.Error()) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add user into database", + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add user into database", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, + "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, - "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) } diff --git a/backend/api/controllers/v1/users/login.go b/backend/api/controllers/v1/users/login.go index b6fd586..ea4f2f3 100644 --- a/backend/api/controllers/v1/users/login.go +++ b/backend/api/controllers/v1/users/login.go @@ -6,11 +6,10 @@ import ( "fmt" "net/http" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" "golang.org/x/crypto/bcrypt" ) @@ -21,137 +20,139 @@ type UserLoginRequest struct { Password string `validate:"required"` } -func LoginUser(c *gin.Context) { - var req UserLoginRequest +func SetupLoginUser(state *state.State) { + state.Engine.POST("/api/v1/users/login", func(c *gin.Context) { + var req UserLoginRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) + + return + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + if req.Email == nil && req.Username == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Missing both email and username in body", + }) + + return + } + + userFindRequestArguments := make([]interface{}, 1) + userFindRequest := "" + + if req.Email != nil { + userFindRequestArguments[0] = &req.Email + userFindRequest += "email = ?" + } + + if req.Username != nil { + userFindRequestArguments[0] = &req.Username + userFindRequest += "username = ?" + } + + var user *db.User + userRequest := state.DB.DB.Where(userFindRequest, userFindRequestArguments...).Find(&user) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User not found", + }) + + return + } + + decodedPassword := make([]byte, base64.StdEncoding.DecodedLen(len(user.Password))) + _, err := base64.StdEncoding.Decode(decodedPassword, []byte(user.Password)) + + if err != nil { + log.Warnf("failed to decode password in database: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse database result for password", + }) + + return + } + + err = bcrypt.CompareHashAndPassword(decodedPassword, []byte(req.Password)) + + if err != nil { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Invalid password", + }) + + return + } + + tokenRandomData := make([]byte, 80) + + if _, err := rand.Read(tokenRandomData); err != nil { + log.Warnf("Failed to read random data to use as token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + token := &db.Token{ + UserID: user.ID, + + Token: base64.StdEncoding.EncodeToString(tokenRandomData), + DisableExpiry: forceNoExpiryTokens, + CreationIPAddr: c.ClientIP(), + } + + if result := state.DB.DB.Create(&token); result.Error != nil { + log.Warnf("Failed to create user: %s", result.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to add refresh token into database", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, + "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - if req.Email == nil && req.Username == nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Missing both email and username in body", - }) - - return - } - - userFindRequestArguments := make([]interface{}, 1) - userFindRequest := "" - - if req.Email != nil { - userFindRequestArguments[0] = &req.Email - userFindRequest += "email = ?" - } - - if req.Username != nil { - userFindRequestArguments[0] = &req.Username - userFindRequest += "username = ?" - } - - var user *dbcore.User - userRequest := dbcore.DB.Where(userFindRequest, userFindRequestArguments...).Find(&user) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "User not found", - }) - - return - } - - decodedPassword := make([]byte, base64.StdEncoding.DecodedLen(len(user.Password))) - _, err := base64.StdEncoding.Decode(decodedPassword, []byte(user.Password)) - - if err != nil { - log.Warnf("failed to decode password in database: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse database result for password", - }) - - return - } - - err = bcrypt.CompareHashAndPassword(decodedPassword, []byte(req.Password)) - - if err != nil { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Invalid password", - }) - - return - } - - tokenRandomData := make([]byte, 80) - - if _, err := rand.Read(tokenRandomData); err != nil { - log.Warnf("Failed to read random data to use as token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - token := &dbcore.Token{ - UserID: user.ID, - - Token: base64.StdEncoding.EncodeToString(tokenRandomData), - DisableExpiry: forceNoExpiryTokens, - CreationIPAddr: c.ClientIP(), - } - - if result := dbcore.DB.Create(&token); result.Error != nil { - log.Warnf("Failed to create user: %s", result.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to add refresh token into database", - }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, - "refreshToken": base64.StdEncoding.EncodeToString(tokenRandomData), }) } diff --git a/backend/api/controllers/v1/users/lookup.go b/backend/api/controllers/v1/users/lookup.go index 04782e5..f5c14fc 100644 --- a/backend/api/controllers/v1/users/lookup.go +++ b/backend/api/controllers/v1/users/lookup.go @@ -5,12 +5,11 @@ import ( "net/http" "strings" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserLookupRequest struct { @@ -35,102 +34,104 @@ type LookupResponse struct { Data []*SanitizedUsers `json:"data"` } -func LookupUser(c *gin.Context) { - var req UserLookupRequest +func SetupLookupUser(state *state.State) { + state.Engine.POST("/api/v1/users/lookup", func(c *gin.Context) { + var req UserLookupRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + } + + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) + + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + users := []db.User{} + queryString := []string{} + queryParameters := []interface{}{} + + if !permissions.UserHasPermission(user, "users.lookup") { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, user.ID) + } else if permissions.UserHasPermission(user, "users.lookup") && req.UID != nil { + queryString = append(queryString, "id = ?") + queryParameters = append(queryParameters, req.UID) + } + + if req.Name != nil { + queryString = append(queryString, "name = ?") + queryParameters = append(queryParameters, req.Name) + } + + if req.Email != nil { + queryString = append(queryString, "email = ?") + queryParameters = append(queryParameters, req.Email) + } + + if req.IsBot != nil { + queryString = append(queryString, "is_bot = ?") + queryParameters = append(queryParameters, req.IsBot) + } + + if err := state.DB.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&users).Error; err != nil { + log.Warnf("Failed to get users: %s", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", + "error": "Failed to get users", }) return } - } - users := []dbcore.User{} - queryString := []string{} - queryParameters := []interface{}{} + sanitizedUsers := make([]*SanitizedUsers, len(users)) - if !permissions.UserHasPermission(user, "users.lookup") { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, user.ID) - } else if permissions.UserHasPermission(user, "users.lookup") && req.UID != nil { - queryString = append(queryString, "id = ?") - queryParameters = append(queryParameters, req.UID) - } + for userIndex, user := range users { + isBot := false - if req.Name != nil { - queryString = append(queryString, "name = ?") - queryParameters = append(queryParameters, req.Name) - } + if user.IsBot != nil { + isBot = *user.IsBot + } - if req.Email != nil { - queryString = append(queryString, "email = ?") - queryParameters = append(queryParameters, req.Email) - } + sanitizedUsers[userIndex] = &SanitizedUsers{ + UID: user.ID, + Name: user.Name, + Email: user.Email, + Username: user.Username, + IsBot: isBot, + } + } - if req.IsBot != nil { - queryString = append(queryString, "is_bot = ?") - queryParameters = append(queryParameters, req.IsBot) - } - - if err := dbcore.DB.Where(strings.Join(queryString, " AND "), queryParameters...).Find(&users).Error; err != nil { - log.Warnf("Failed to get users: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get users", + c.JSON(http.StatusOK, &LookupResponse{ + Success: true, + Data: sanitizedUsers, }) - - return - } - - sanitizedUsers := make([]*SanitizedUsers, len(users)) - - for userIndex, user := range users { - isBot := false - - if user.IsBot != nil { - isBot = *user.IsBot - } - - sanitizedUsers[userIndex] = &SanitizedUsers{ - UID: user.ID, - Name: user.Name, - Email: user.Email, - Username: user.Username, - IsBot: isBot, - } - } - - c.JSON(http.StatusOK, &LookupResponse{ - Success: true, - Data: sanitizedUsers, }) } diff --git a/backend/api/controllers/v1/users/refresh.go b/backend/api/controllers/v1/users/refresh.go index 59fcf80..ce9fbaf 100644 --- a/backend/api/controllers/v1/users/refresh.go +++ b/backend/api/controllers/v1/users/refresh.go @@ -5,113 +5,114 @@ import ( "net/http" "time" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserRefreshRequest struct { Token string `validate:"required"` } -func RefreshUserToken(c *gin.Context) { - var req UserRefreshRequest +func SetupRefreshUserToken(state *state.State) { + state.Engine.POST("/api/v1/users/refresh", func(c *gin.Context) { + var req UserRefreshRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), + }) - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - var tokenInDatabase *dbcore.Token - tokenRequest := dbcore.DB.Where("token = ?", req.Token).Find(&tokenInDatabase) - - if tokenRequest.Error != nil { - log.Warnf("failed to find if token exists or not: %s", tokenRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if token exists", - }) - - return - } - - tokenExists := tokenRequest.RowsAffected > 0 - - if !tokenExists { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Token not found", - }) - - return - } - - // First, we check to make sure that the key expiry is disabled before checking if the key is expired. - // Then, we check if the IP addresses differ, or if it has been 7 days since the token has been created. - if !tokenInDatabase.DisableExpiry && (c.ClientIP() != tokenInDatabase.CreationIPAddr || time.Now().Before(tokenInDatabase.CreatedAt.Add((24*7)*time.Hour))) { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Token has expired", - }) - - tx := dbcore.DB.Delete(tokenInDatabase) - - if tx.Error != nil { - log.Warnf("Failed to delete expired token from database: %s", tx.Error.Error()) + return } - return - } + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - // Get the user to check if the user exists before doing anything - var user *dbcore.User - userRequest := dbcore.DB.Where("id = ?", tokenInDatabase.UserID).Find(&user) + return + } - if tokenRequest.Error != nil { - log.Warnf("failed to find if token user or not: %s", userRequest.Error.Error()) + var tokenInDatabase *db.Token + tokenRequest := state.DB.DB.Where("token = ?", req.Token).Find(&tokenInDatabase) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find user", + if tokenRequest.Error != nil { + log.Warnf("failed to find if token exists or not: %s", tokenRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if token exists", + }) + + return + } + + tokenExists := tokenRequest.RowsAffected > 0 + + if !tokenExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Token not found", + }) + + return + } + + // First, we check to make sure that the key expiry is disabled before checking if the key is expired. + // Then, we check if the IP addresses differ, or if it has been 7 days since the token has been created. + if !tokenInDatabase.DisableExpiry && (c.ClientIP() != tokenInDatabase.CreationIPAddr || time.Now().Before(tokenInDatabase.CreatedAt.Add((24*7)*time.Hour))) { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Token has expired", + }) + + tx := state.DB.DB.Delete(tokenInDatabase) + + if tx.Error != nil { + log.Warnf("Failed to delete expired token from database: %s", tx.Error.Error()) + } + + return + } + + // Get the user to check if the user exists before doing anything + var user *db.User + userRequest := state.DB.DB.Where("id = ?", tokenInDatabase.UserID).Find(&user) + + if tokenRequest.Error != nil { + log.Warnf("failed to find if token user or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find user", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "User not found", + }) + + return + } + + jwt, err := state.JWT.Generate(user.ID) + + if err != nil { + log.Warnf("Failed to generate JWT: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate refresh token", + }) + + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "token": jwt, }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "User not found", - }) - - return - } - - jwt, err := jwtcore.Generate(user.ID) - - if err != nil { - log.Warnf("Failed to generate JWT: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate refresh token", - }) - - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "token": jwt, }) } diff --git a/backend/api/controllers/v1/users/remove.go b/backend/api/controllers/v1/users/remove.go index 5446755..59e460c 100644 --- a/backend/api/controllers/v1/users/remove.go +++ b/backend/api/controllers/v1/users/remove.go @@ -4,12 +4,11 @@ import ( "fmt" "net/http" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" "git.terah.dev/imterah/hermes/backend/api/permissions" + "git.terah.dev/imterah/hermes/backend/api/state" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) type UserRemovalRequest struct { @@ -17,89 +16,91 @@ type UserRemovalRequest struct { UID *uint `json:"uid"` } -func RemoveUser(c *gin.Context) { - var req UserRemovalRequest +func SetupRemoveUser(state *state.State) { + state.Engine.POST("/api/v1/users/remove", func(c *gin.Context) { + var req UserRemovalRequest - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), - }) - - return - } - - if err := validator.New().Struct(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), - }) - - return - } - - user, err := jwtcore.GetUserFromJWT(req.Token) - - if err != nil { - if err.Error() == "token is expired" || err.Error() == "user does not exist" { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - - return - } else { - log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to parse token", - }) - - return - } - } - - uid := user.ID - - if req.UID != nil { - uid = *req.UID - - if uid != user.ID && !permissions.UserHasPermission(user, "users.remove") { - c.JSON(http.StatusForbidden, gin.H{ - "error": "Missing permissions", - }) - - return - } - } - - // Make sure the user exists first if we have a custom UserID - - if uid != user.ID { - var customUser *dbcore.User - userRequest := dbcore.DB.Where("id = ?", uid).Find(customUser) - - if userRequest.Error != nil { - log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) - - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to find if user exists", - }) - - return - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { + if err := c.BindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "error": "User doesn't exist", + "error": fmt.Sprintf("Failed to parse body: %s", err.Error()), }) return } - } - dbcore.DB.Select("Tokens", "Permissions", "Proxys", "Backends").Where("id = ?", uid).Delete(user) + if err := state.Validator.Struct(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Failed to validate body: %s", err.Error()), + }) - c.JSON(http.StatusOK, gin.H{ - "success": true, + return + } + + user, err := state.JWT.GetUserFromJWT(req.Token) + + if err != nil { + if err.Error() == "token is expired" || err.Error() == "user does not exist" { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + + return + } else { + log.Warnf("Failed to get user from the provided JWT token: %s", err.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to parse token", + }) + + return + } + } + + uid := user.ID + + if req.UID != nil { + uid = *req.UID + + if uid != user.ID && !permissions.UserHasPermission(user, "users.remove") { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Missing permissions", + }) + + return + } + } + + // Make sure the user exists first if we have a custom UserID + + if uid != user.ID { + var customUser *db.User + userRequest := state.DB.DB.Where("id = ?", uid).Find(customUser) + + if userRequest.Error != nil { + log.Warnf("failed to find if user exists or not: %s", userRequest.Error.Error()) + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to find if user exists", + }) + + return + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "User doesn't exist", + }) + + return + } + } + + state.DB.DB.Select("Tokens", "Permissions", "Proxys", "Backends").Where("id = ?", uid).Delete(user) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) }) } diff --git a/backend/api/db/db.go b/backend/api/db/db.go new file mode 100644 index 0000000..295bff8 --- /dev/null +++ b/backend/api/db/db.go @@ -0,0 +1,77 @@ +package db + +import ( + "fmt" + "os" + + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type DB struct { + DB *gorm.DB +} + +func New(backend, params string) (*DB, error) { + var err error + + dialector, err := initDialector(backend, params) + + if err != nil { + return nil, fmt.Errorf("failed to initialize physical database: %s", err) + } + + database, err := gorm.Open(dialector) + + if err != nil { + return nil, fmt.Errorf("failed to open database: %s", err) + } + + return &DB{DB: database}, nil +} + +func (db *DB) DoMigrations() error { + if err := db.DB.AutoMigrate(&Proxy{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Backend{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Permission{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&Token{}); err != nil { + return err + } + + if err := db.DB.AutoMigrate(&User{}); err != nil { + return err + } + + return nil +} + +func initDialector(backend, params string) (gorm.Dialector, error) { + switch backend { + case "sqlite": + if params == "" { + return nil, fmt.Errorf("sqlite database file not specified") + } + + return sqlite.Open(params), nil + case "postgresql": + if params == "" { + return nil, fmt.Errorf("postgres DSN not specified") + } + + return postgres.Open(params), nil + case "": + return nil, fmt.Errorf("no database backend specified in environment variables") + default: + return nil, fmt.Errorf("unknown database backend specified: %s", os.Getenv(backend)) + } +} diff --git a/backend/api/db/models.go b/backend/api/db/models.go new file mode 100644 index 0000000..290cd6e --- /dev/null +++ b/backend/api/db/models.go @@ -0,0 +1,66 @@ +package db + +import ( + "gorm.io/gorm" +) + +type Backend struct { + gorm.Model + + UserID uint + + Name string + Description *string + Backend string + BackendParameters string + + Proxies []Proxy +} + +type Proxy struct { + gorm.Model + + BackendID uint + UserID uint + + Name string + Description *string + Protocol string + SourceIP string + SourcePort uint16 + DestinationPort uint16 + AutoStart bool +} + +type Permission struct { + gorm.Model + + PermissionNode string + HasPermission bool + UserID uint +} + +type Token struct { + gorm.Model + + UserID uint + + Token string + DisableExpiry bool + CreationIPAddr string +} + +type User struct { + gorm.Model + + Email string `gorm:"unique"` + Username string `gorm:"unique"` + Name string + Password string + IsBot *bool + + Permissions []Permission + OwnedProxies []Proxy + OwnedBackends []Backend + Tokens []Token +} diff --git a/backend/api/dbcore/db.go b/backend/api/dbcore/db.go deleted file mode 100644 index b5e0676..0000000 --- a/backend/api/dbcore/db.go +++ /dev/null @@ -1,142 +0,0 @@ -package dbcore - -import ( - "fmt" - "os" - - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -type Backend struct { - gorm.Model - - UserID uint - - Name string - Description *string - Backend string - BackendParameters string - - Proxies []Proxy -} - -type Proxy struct { - gorm.Model - - BackendID uint - UserID uint - - Name string - Description *string - Protocol string - SourceIP string - SourcePort uint16 - DestinationPort uint16 - AutoStart bool -} - -type Permission struct { - gorm.Model - - PermissionNode string - HasPermission bool - UserID uint -} - -type Token struct { - gorm.Model - - UserID uint - - Token string - DisableExpiry bool - CreationIPAddr string -} - -type User struct { - gorm.Model - - Email string `gorm:"unique"` - Username string `gorm:"unique"` - Name string - Password string - IsBot *bool - - Permissions []Permission - OwnedProxies []Proxy - OwnedBackends []Backend - Tokens []Token -} - -var DB *gorm.DB - -func InitializeDatabaseDialector() (gorm.Dialector, error) { - databaseBackend := os.Getenv("HERMES_DATABASE_BACKEND") - - switch databaseBackend { - case "sqlite": - filePath := os.Getenv("HERMES_SQLITE_FILEPATH") - - if filePath == "" { - return nil, fmt.Errorf("sqlite database file not specified (missing HERMES_SQLITE_FILEPATH)") - } - - return sqlite.Open(filePath), nil - case "postgresql": - postgresDSN := os.Getenv("HERMES_POSTGRES_DSN") - - if postgresDSN == "" { - return nil, fmt.Errorf("postgres DSN not specified (missing HERMES_POSTGRES_DSN)") - } - - return postgres.Open(postgresDSN), nil - case "": - return nil, fmt.Errorf("no database backend specified in environment variables (missing HERMES_DATABASE_BACKEND)") - default: - return nil, fmt.Errorf("unknown database backend specified: %s", os.Getenv(databaseBackend)) - } -} - -func InitializeDatabase(config *gorm.Config) error { - var err error - - dialector, err := InitializeDatabaseDialector() - - if err != nil { - return fmt.Errorf("failed to initialize physical database: %s", err) - } - - DB, err = gorm.Open(dialector, config) - - if err != nil { - return fmt.Errorf("failed to open database: %s", err) - } - - return nil -} - -func DoDatabaseMigrations(db *gorm.DB) error { - if err := db.AutoMigrate(&Proxy{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Backend{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Permission{}); err != nil { - return err - } - - if err := db.AutoMigrate(&Token{}); err != nil { - return err - } - - if err := db.AutoMigrate(&User{}); err != nil { - return err - } - - return nil -} diff --git a/backend/api/jwt/jwt.go b/backend/api/jwt/jwt.go new file mode 100644 index 0000000..40e011b --- /dev/null +++ b/backend/api/jwt/jwt.go @@ -0,0 +1,107 @@ +package jwt + +import ( + "errors" + "fmt" + "strconv" + "time" + + "git.terah.dev/imterah/hermes/backend/api/db" + "github.com/golang-jwt/jwt/v5" +) + +var ( + DevelopmentModeTimings = time.Duration(60*24) * time.Minute + NormalModeTimings = time.Duration(3) * time.Minute +) + +type JWTCore struct { + Key []byte + Database *db.DB + TimeMultiplier time.Duration +} + +func New(key []byte, database *db.DB, timeMultiplier time.Duration) *JWTCore { + jwtCore := &JWTCore{ + Key: key, + Database: database, + TimeMultiplier: timeMultiplier, + } + + return jwtCore +} + +func (jwtCore *JWTCore) Parse(tokenString string, options ...jwt.ParserOption) (*jwt.Token, error) { + return jwt.Parse(tokenString, jwtCore.jwtKeyCallback, options...) +} + +func (jwtCore *JWTCore) GetUserFromJWT(token string) (*db.User, error) { + if jwtCore.Database == nil { + return nil, fmt.Errorf("database is not initialized") + } + + parsedJWT, err := jwtCore.Parse(token) + + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, fmt.Errorf("token is expired") + } else { + return nil, err + } + } + + audience, err := parsedJWT.Claims.GetAudience() + + if err != nil { + return nil, err + } + + if len(audience) < 1 { + return nil, fmt.Errorf("audience is too small") + } + + uid, err := strconv.Atoi(audience[0]) + + if err != nil { + return nil, err + } + + user := &db.User{} + userRequest := jwtCore.Database.DB.Preload("Permissions").Where("id = ?", uint(uid)).Find(&user) + + if userRequest.Error != nil { + return user, fmt.Errorf("failed to find if user exists or not: %s", userRequest.Error.Error()) + } + + userExists := userRequest.RowsAffected > 0 + + if !userExists { + return user, fmt.Errorf("user does not exist") + } + + return user, nil +} + +func (jwtCore *JWTCore) Generate(uid uint) (string, error) { + currentJWTTime := jwt.NewNumericDate(time.Now()) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(jwtCore.TimeMultiplier)), + IssuedAt: currentJWTTime, + NotBefore: currentJWTTime, + // Convert the user ID to a string, and then set it as the audience parameters only value (there's only 1 user per key) + Audience: []string{strconv.Itoa(int(uid))}, + }) + + signedToken, err := token.SignedString(jwtCore.Key) + + if err != nil { + return "", err + } + + return signedToken, nil +} + +func (jwtCore *JWTCore) jwtKeyCallback(*jwt.Token) (any, error) { + return jwtCore.Key, nil +} diff --git a/backend/api/jwtcore/jwt.go b/backend/api/jwtcore/jwt.go deleted file mode 100644 index ed2ec56..0000000 --- a/backend/api/jwtcore/jwt.go +++ /dev/null @@ -1,117 +0,0 @@ -package jwtcore - -import ( - "encoding/base64" - "errors" - "fmt" - "os" - "strconv" - "time" - - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "github.com/golang-jwt/jwt/v5" -) - -var ( - JWTKey []byte - developmentMode bool -) - -func SetupJWT() error { - var err error - jwtDataString := os.Getenv("HERMES_JWT_SECRET") - - if jwtDataString == "" { - return fmt.Errorf("JWT secret isn't set (missing HERMES_JWT_SECRET)") - } - - if os.Getenv("HERMES_JWT_BASE64_ENCODED") != "" { - JWTKey, err = base64.StdEncoding.DecodeString(jwtDataString) - - if err != nil { - return fmt.Errorf("failed to decode base64 JWT: %s", err.Error()) - } - } else { - JWTKey = []byte(jwtDataString) - } - - if os.Getenv("HERMES_DEVELOPMENT_MODE") != "" { - developmentMode = true - } - - return nil -} - -func Parse(tokenString string, options ...jwt.ParserOption) (*jwt.Token, error) { - return jwt.Parse(tokenString, JWTKeyCallback, options...) -} - -func GetUserFromJWT(token string) (*dbcore.User, error) { - parsedJWT, err := Parse(token) - - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - return nil, fmt.Errorf("token is expired") - } else { - return nil, err - } - } - - audience, err := parsedJWT.Claims.GetAudience() - - if err != nil { - return nil, err - } - - if len(audience) < 1 { - return nil, fmt.Errorf("audience is too small") - } - - uid, err := strconv.Atoi(audience[0]) - - if err != nil { - return nil, err - } - - user := &dbcore.User{} - userRequest := dbcore.DB.Preload("Permissions").Where("id = ?", uint(uid)).Find(&user) - - if userRequest.Error != nil { - return user, fmt.Errorf("failed to find if user exists or not: %s", userRequest.Error.Error()) - } - - userExists := userRequest.RowsAffected > 0 - - if !userExists { - return user, fmt.Errorf("user does not exist") - } - - return user, nil -} - -func Generate(uid uint) (string, error) { - timeMultiplier := 3 - - if developmentMode { - timeMultiplier = 60 * 24 - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(timeMultiplier) * time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Audience: []string{strconv.Itoa(int(uid))}, - }) - - signedToken, err := token.SignedString(JWTKey) - - if err != nil { - return "", err - } - - return signedToken, nil -} - -func JWTKeyCallback(*jwt.Token) (interface{}, error) { - return JWTKey, nil -} diff --git a/backend/api/main.go b/backend/api/main.go index c88f1ae..f7ba678 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -9,18 +9,19 @@ import ( "path" "path/filepath" "strings" + "time" "git.terah.dev/imterah/hermes/backend/api/backendruntime" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/backends" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/proxies" "git.terah.dev/imterah/hermes/backend/api/controllers/v1/users" - "git.terah.dev/imterah/hermes/backend/api/dbcore" - "git.terah.dev/imterah/hermes/backend/api/jwtcore" + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/jwt" + "git.terah.dev/imterah/hermes/backend/api/state" "git.terah.dev/imterah/hermes/backend/commonbackend" "github.com/charmbracelet/log" "github.com/gin-gonic/gin" "github.com/urfave/cli/v2" - "gorm.io/gorm" ) func apiEntrypoint(cCtx *cli.Context) error { @@ -34,7 +35,26 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Info("Hermes is initializing...") log.Debug("Initializing database and opening it...") - err := dbcore.InitializeDatabase(&gorm.Config{}) + databaseBackendName := os.Getenv("HERMES_DATABASE_BACKEND") + var databaseBackendParams string + + if databaseBackendName == "sqlite" { + databaseBackendParams = os.Getenv("HERMES_SQLITE_FILEPATH") + + if databaseBackendParams == "" { + log.Fatal("HERMES_SQLITE_FILEPATH is not set") + } + } + + if databaseBackendName == "postgres" { + databaseBackendParams = os.Getenv("HERMES_POSTGRES_DSN") + + if databaseBackendParams == "" { + log.Fatal("HERMES_POSTGRES_DSN is not set") + } + } + + dbInstance, err := db.New(databaseBackendName, databaseBackendParams) if err != nil { log.Fatalf("Failed to initialize database: %s", err) @@ -42,16 +62,38 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Debug("Running database migrations...") - if err := dbcore.DoDatabaseMigrations(dbcore.DB); err != nil { + if err := dbInstance.DoMigrations(); err != nil { return fmt.Errorf("Failed to run database migrations: %s", err) } log.Debug("Initializing the JWT subsystem...") - if err := jwtcore.SetupJWT(); err != nil { - return fmt.Errorf("Failed to initialize the JWT subsystem: %s", err.Error()) + jwtDataString := os.Getenv("HERMES_JWT_SECRET") + var jwtKey []byte + var jwtValidityTimeDuration time.Duration + + if jwtDataString == "" { + log.Fatalf("HERMES_JWT_SECRET is not set") } + if os.Getenv("HERMES_JWT_BASE64_ENCODED") != "" { + jwtKey, err = base64.StdEncoding.DecodeString(jwtDataString) + + if err != nil { + log.Fatalf("Failed to decode base64 JWT: %s", err.Error()) + } + } else { + jwtKey = []byte(jwtDataString) + } + + if developmentMode { + jwtValidityTimeDuration = jwt.DevelopmentModeTimings + } else { + jwtValidityTimeDuration = jwt.NormalModeTimings + } + + jwtInstance := jwt.New(jwtKey, dbInstance, jwtValidityTimeDuration) + log.Debug("Initializing the backend subsystem...") backendMetadataPath := cCtx.String("backends-path") @@ -76,9 +118,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Debug("Enumerating backends...") - backendList := []dbcore.Backend{} + backendList := []db.Backend{} - if err := dbcore.DB.Find(&backendList).Error; err != nil { + if err := dbInstance.DB.Find(&backendList).Error; err != nil { return fmt.Errorf("Failed to enumerate backends: %s", err.Error()) } @@ -141,9 +183,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Warnf("Backend #%d has reinitialized! Starting up auto-starting proxies...", backend.ID) - autoStartProxies := []dbcore.Proxy{} + autoStartProxies := []db.Proxy{} - if err := dbcore.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { + if err := dbInstance.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { log.Errorf("Failed to query proxies to autostart: %s", err.Error()) return } @@ -243,9 +285,9 @@ func apiEntrypoint(cCtx *cli.Context) error { log.Infof("Successfully initialized backend #%d", backend.ID) - autoStartProxies := []dbcore.Proxy{} + autoStartProxies := []db.Proxy{} - if err := dbcore.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { + if err := dbInstance.DB.Where("backend_id = ? AND auto_start = true", backend.ID).Find(&autoStartProxies).Error; err != nil { log.Errorf("Failed to query proxies to autostart: %s", err.Error()) continue } @@ -309,23 +351,25 @@ func apiEntrypoint(cCtx *cli.Context) error { engine.SetTrustedProxies(nil) } + state := state.New(dbInstance, jwtInstance, engine) + // Initialize routes - engine.POST("/api/v1/users/create", users.CreateUser) - engine.POST("/api/v1/users/login", users.LoginUser) - engine.POST("/api/v1/users/refresh", users.RefreshUserToken) - engine.POST("/api/v1/users/remove", users.RemoveUser) - engine.POST("/api/v1/users/lookup", users.LookupUser) + users.SetupCreateUser(state) + users.SetupLoginUser(state) + users.SetupRefreshUserToken(state) + users.SetupRemoveUser(state) + users.SetupLookupUser(state) - engine.POST("/api/v1/backends/create", backends.CreateBackend) - engine.POST("/api/v1/backends/remove", backends.RemoveBackend) - engine.POST("/api/v1/backends/lookup", backends.LookupBackend) + backends.SetupCreateBackend(state) + backends.SetupRemoveBackend(state) + backends.SetupLookupBackend(state) - engine.POST("/api/v1/forward/create", proxies.CreateProxy) - engine.POST("/api/v1/forward/lookup", proxies.LookupProxy) - engine.POST("/api/v1/forward/remove", proxies.RemoveProxy) - engine.POST("/api/v1/forward/start", proxies.StartProxy) - engine.POST("/api/v1/forward/stop", proxies.StopProxy) - engine.POST("/api/v1/forward/connections", proxies.GetConnections) + proxies.SetupCreateProxy(state) + proxies.SetupRemoveProxy(state) + proxies.SetupLookupProxy(state) + proxies.SetupStartProxy(state) + proxies.SetupStopProxy(state) + proxies.SetupGetConnections(state) log.Infof("Listening on '%s'", listeningAddress) err = engine.Run(listeningAddress) @@ -362,22 +406,6 @@ func main() { app := &cli.App{ Name: "hermes", Usage: "port forwarding across boundaries", - Commands: []*cli.Command{ - { - Name: "import", - Usage: "imports from legacy NextNet/Hermes source", - Aliases: []string{"i"}, - Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "backup-path", - Aliases: []string{"bp"}, - Usage: "path to the backup file", - Required: true, - }, - }, - Action: backupRestoreEntrypoint, - }, - }, Flags: []cli.Flag{ &cli.StringFlag{ Name: "backends-path", diff --git a/backend/api/permissions/permission_nodes.go b/backend/api/permissions/permission_nodes.go index d682d69..f13e80a 100644 --- a/backend/api/permissions/permission_nodes.go +++ b/backend/api/permissions/permission_nodes.go @@ -1,6 +1,6 @@ package permissions -import "git.terah.dev/imterah/hermes/backend/api/dbcore" +import "git.terah.dev/imterah/hermes/backend/api/db" var DefaultPermissionNodes []string = []string{ "routes.add", @@ -27,7 +27,7 @@ var DefaultPermissionNodes []string = []string{ "users.edit", } -func UserHasPermission(user *dbcore.User, node string) bool { +func UserHasPermission(user *db.User, node string) bool { for _, permission := range user.Permissions { if permission.PermissionNode == node && permission.HasPermission { return true diff --git a/backend/api/state/state.go b/backend/api/state/state.go new file mode 100644 index 0000000..754ff56 --- /dev/null +++ b/backend/api/state/state.go @@ -0,0 +1,24 @@ +package state + +import ( + "git.terah.dev/imterah/hermes/backend/api/db" + "git.terah.dev/imterah/hermes/backend/api/jwt" + "github.com/gin-gonic/gin" + "github.com/go-playground/validator/v10" +) + +type State struct { + DB *db.DB + JWT *jwt.JWTCore + Engine *gin.Engine + Validator *validator.Validate +} + +func New(db *db.DB, jwt *jwt.JWTCore, engine *gin.Engine) *State { + return &State{ + DB: db, + JWT: jwt, + Engine: engine, + Validator: validator.New(), + } +} From b93bf456b55ee3b52c28a6681edec01b3f26f541 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 21 Mar 2025 13:17:08 -0400 Subject: [PATCH 22/24] fix: Fixes 100% CPU usage in the backend runtime This makes the backend runtime not constantly search for messages to be processed. Instead, it only wakes up when it needs to be woken up via goroutines. --- backend/api/backendruntime/runtime.go | 13 +++++++++++++ backend/api/backendruntime/struct.go | 13 ++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/backend/api/backendruntime/runtime.go b/backend/api/backendruntime/runtime.go index 52e280d..8d5ca7a 100644 --- a/backend/api/backendruntime/runtime.go +++ b/backend/api/backendruntime/runtime.go @@ -15,6 +15,9 @@ import ( "github.com/charmbracelet/log" ) +// TODO TODO TODO(imterah): +// This code is a mess. This NEEDS to be rearchitected and refactored to work better. Or at the very least, this code needs to be documented heavily. + func handleCommand(command interface{}, sock net.Conn, rtcChan chan interface{}) error { bytes, err := commonbackend.Marshal(command) @@ -160,6 +163,9 @@ func (runtime *Runtime) goRoutineHandler() error { OuterLoop: for { + _ = <-runtime.startProcessingNotification + runtime.isRuntimeCurrentlyProcessing = true + for chanIndex, messageData := range runtime.messageBuffer { if messageData == nil { continue @@ -177,6 +183,8 @@ func (runtime *Runtime) goRoutineHandler() error { runtime.messageBuffer[chanIndex] = nil } + + runtime.isRuntimeCurrentlyProcessing = false } sock.Close() @@ -235,6 +243,7 @@ func (runtime *Runtime) Start() error { runtime.messageBuffer = make([]*messageForBuf, 10) runtime.messageBufferLock = sync.Mutex{} + runtime.startProcessingNotification = make(chan bool) runtime.processRestartNotification = make(chan bool, 1) runtime.logger = &writeLogger{ @@ -322,6 +331,10 @@ SchedulingLoop: schedulingAttempts++ } + if !runtime.isRuntimeCurrentlyProcessing { + runtime.startProcessingNotification <- true + } + // Fetch response and close Channel response, ok := <-commandChannel diff --git a/backend/api/backendruntime/struct.go b/backend/api/backendruntime/struct.go index cd30182..cd4b3b8 100644 --- a/backend/api/backendruntime/struct.go +++ b/backend/api/backendruntime/struct.go @@ -16,15 +16,18 @@ type Backend struct { type messageForBuf struct { Channel chan interface{} + // TODO(imterah): could this be refactored to just be a []byte instead? Look into this Message interface{} } type Runtime struct { - isRuntimeRunning bool - logger *writeLogger - currentProcess *exec.Cmd - currentListener net.Listener - processRestartNotification chan bool + isRuntimeRunning bool + isRuntimeCurrentlyProcessing bool + startProcessingNotification chan bool + logger *writeLogger + currentProcess *exec.Cmd + currentListener net.Listener + processRestartNotification chan bool messageBufferLock sync.Mutex messageBuffer []*messageForBuf From 75b12f205318f79939340b1d1f4269efd21ae409 Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 21 Mar 2025 13:24:59 -0400 Subject: [PATCH 23/24] fix: Avoid recreating validator in SSHBackend and SSHAppBackend --- backend/sshappbackend/local-code/main.go | 15 +++++++++++++-- backend/sshbackend/main.go | 14 ++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index cb0c13d..34a85d0 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -26,6 +26,8 @@ import ( "golang.org/x/crypto/ssh" ) +var validatorInstance *validator.Validate + type TCPProxy struct { proxyInformation *commonbackend.AddProxy connections map[uint16]net.Conn @@ -62,6 +64,11 @@ type SSHAppBackend struct { func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { log.Info("SSHAppBackend is initializing...") + + if validatorInstance == nil { + validatorInstance = validator.New() + } + backend.globalNonCriticalMessageChan = make(chan interface{}) backend.tcpProxies = map[uint16]*TCPProxy{} backend.udpProxies = map[uint16]*UDPProxy{} @@ -72,7 +79,7 @@ func (backend *SSHAppBackend) StartBackend(configBytes []byte) (bool, error) { return false, err } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return false, err } @@ -585,6 +592,10 @@ func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *co func (backend *SSHAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { var backendData SSHAppBackendData + if validatorInstance == nil { + validatorInstance = validator.New() + } + if err := json.Unmarshal(arguments, &backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, @@ -592,7 +603,7 @@ func (backend *SSHAppBackend) CheckParametersForBackend(arguments []byte) *commo } } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, Message: fmt.Sprintf("failed validation of parameters: %s", err.Error()), diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index a554fa3..d46b330 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -19,6 +19,8 @@ import ( "golang.org/x/crypto/ssh" ) +var validatorInstance *validator.Validate + type ConnWithTimeout struct { net.Conn ReadTimeout time.Duration @@ -76,6 +78,10 @@ type SSHBackend struct { func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { log.Info("SSHBackend is initializing...") + if validatorInstance == nil { + validatorInstance = validator.New() + } + if backend.inReinitLoop { for !backend.isReady { time.Sleep(100 * time.Millisecond) @@ -88,7 +94,7 @@ func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { return false, err } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return false, err } @@ -411,6 +417,10 @@ func (backend *SSHBackend) CheckParametersForConnections(clientParameters *commo func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { var backendData SSHBackendData + if validatorInstance == nil { + validatorInstance = validator.New() + } + if err := json.Unmarshal(arguments, &backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, @@ -418,7 +428,7 @@ func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonba } } - if err := validator.New().Struct(&backendData); err != nil { + if err := validatorInstance.Struct(&backendData); err != nil { return &commonbackend.CheckParametersResponse{ IsValid: false, Message: fmt.Sprintf("failed validation of parameters: %s", err.Error()), From 8e9c7f120fd6732e719fb86f4ed4637909fc99eb Mon Sep 17 00:00:00 2001 From: imterah Date: Fri, 21 Mar 2025 13:39:51 -0400 Subject: [PATCH 24/24] fix: Fix regression where Postgres DSN wouldn't be detected There was a typo where databaseBackendName == "postgresql" was "postgres" instead on accident. --- backend/api/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index f7ba678..1438af8 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -44,14 +44,14 @@ func apiEntrypoint(cCtx *cli.Context) error { if databaseBackendParams == "" { log.Fatal("HERMES_SQLITE_FILEPATH is not set") } - } - - if databaseBackendName == "postgres" { + } else if databaseBackendName == "postgresql" { databaseBackendParams = os.Getenv("HERMES_POSTGRES_DSN") if databaseBackendParams == "" { log.Fatal("HERMES_POSTGRES_DSN is not set") } + } else { + log.Fatalf("Unsupported database backend: %s", databaseBackendName) } dbInstance, err := db.New(databaseBackendName, databaseBackendParams)