diff --git a/backend/backendutil/application.go b/backend/backendutil/application.go index 802e8c7..c6797ce 100644 --- a/backend/backendutil/application.go +++ b/backend/backendutil/application.go @@ -18,7 +18,7 @@ type BackendApplicationHelper struct { func (helper *BackendApplicationHelper) Start() error { log.Debug("BackendApplicationHelper is starting") - err := configureAndLaunchBackgroundProfilingTasks() + err := ConfigureProfiling() if err != nil { return err diff --git a/backend/backendutil/profiling_disabled.go b/backend/backendutil/profiling_disabled.go index d93cbdd..8538407 100644 --- a/backend/backendutil/profiling_disabled.go +++ b/backend/backendutil/profiling_disabled.go @@ -4,6 +4,6 @@ package backendutil var endProfileFunc func() -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { return nil } diff --git a/backend/backendutil/profiling_enabled.go b/backend/backendutil/profiling_enabled.go index a3be527..6fcb189 100644 --- a/backend/backendutil/profiling_enabled.go +++ b/backend/backendutil/profiling_enabled.go @@ -15,7 +15,7 @@ import ( "golang.org/x/exp/rand" ) -func configureAndLaunchBackgroundProfilingTasks() error { +func ConfigureProfiling() error { profilingMode, err := os.ReadFile("/tmp/hermes.backendlauncher.profilebackends") if err != nil && errors.Is(err, os.ErrNotExist) { diff --git a/backend/build.sh b/backend/build.sh index aa5d2ca..413455a 100755 --- a/backend/build.sh +++ b/backend/build.sh @@ -20,14 +20,16 @@ if [ ! -d bin ]; then mkdir bin fi +# Disable dynamic linking by disabling CGo. +# We need to make the remote code as generic as possible, so we do this echo " - building for arm64" -GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . +CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm64 . echo " - building for arm" -GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . +CGO_ENABLED=0 GOOS=linux GOARCH=arm go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-arm . echo " - building for amd64" -GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . +CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-amd64 . echo " - building for i386" -GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . +CGO_ENABLED=0 GOOS=linux GOARCH=386 go build -ldflags="-s -w" -trimpath -o ../local-code/remote-bin/rt-386 . popd > /dev/null pushd sshappbackend/local-code > /dev/null diff --git a/backend/commonbackend/marshal.go b/backend/commonbackend/marshal.go index 6baf02e..4496494 100644 --- a/backend/commonbackend/marshal.go +++ b/backend/commonbackend/marshal.go @@ -84,37 +84,19 @@ func marshalIndividualProxyStruct(conn *ProxyInstance) ([]byte, error) { return proxyBlock, nil } -func Marshal(commandType string, command interface{}) ([]byte, error) { - switch commandType { - case "start": - startCommand, ok := command.(*Start) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - startCommandBytes := make([]byte, 1+2+len(startCommand.Arguments)) +func Marshal(_ string, command interface{}) ([]byte, error) { + switch command := command.(type) { + case *Start: + startCommandBytes := make([]byte, 1+2+len(command.Arguments)) startCommandBytes[0] = StartID - binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(startCommand.Arguments))) - copy(startCommandBytes[3:], startCommand.Arguments) + binary.BigEndian.PutUint16(startCommandBytes[1:3], uint16(len(command.Arguments))) + copy(startCommandBytes[3:], command.Arguments) return startCommandBytes, nil - case "stop": - _, ok := command.(*Stop) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *Stop: return []byte{StopID}, nil - case "addProxy": - addConnectionCommand, ok := command.(*AddProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(addConnectionCommand.SourceIP) + case *AddProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -134,14 +116,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { copy(addConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], addConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], addConnectionCommand.DestPort) + binary.BigEndian.PutUint16(addConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(addConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if addConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if addConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -150,14 +132,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { addConnectionBytes[6+len(ipBytes)] = protocol return addConnectionBytes, nil - case "removeProxy": - removeConnectionCommand, ok := command.(*RemoveProxy) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(removeConnectionCommand.SourceIP) + case *RemoveProxy: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -175,14 +151,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[0] = RemoveProxyID removeConnectionBytes[1] = ipVer copy(removeConnectionBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], removeConnectionCommand.SourcePort) - binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], removeConnectionCommand.DestPort) + binary.BigEndian.PutUint16(removeConnectionBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(removeConnectionBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if removeConnectionCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if removeConnectionCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -191,17 +167,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { removeConnectionBytes[6+len(ipBytes)] = protocol return removeConnectionBytes, nil - case "proxyConnectionsResponse": - allConnectionsCommand, ok := command.(*ProxyConnectionsResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - connectionsArray := make([][]byte, len(allConnectionsCommand.Connections)) + case *ProxyConnectionsResponse: + connectionsArray := make([][]byte, len(command.Connections)) totalSize := 0 - for connIndex, conn := range allConnectionsCommand.Connections { + for connIndex, conn := range command.Connections { connectionsArray[connIndex] = marshalIndividualConnectionStruct(conn) totalSize += len(connectionsArray[connIndex]) + 1 } @@ -223,14 +193,8 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "checkClientParameters": - checkClientCommand, ok := command.(*CheckClientParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(checkClientCommand.SourceIP) + case *CheckClientParameters: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -248,14 +212,14 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[0] = CheckClientParametersID checkClientBytes[1] = ipVer copy(checkClientBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], checkClientCommand.SourcePort) - binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], checkClientCommand.DestPort) + binary.BigEndian.PutUint16(checkClientBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(checkClientBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if checkClientCommand.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if checkClientCommand.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") @@ -264,31 +228,19 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { checkClientBytes[6+len(ipBytes)] = protocol return checkClientBytes, nil - case "checkServerParameters": - checkServerCommand, ok := command.(*CheckServerParameters) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - serverCommandBytes := make([]byte, 1+2+len(checkServerCommand.Arguments)) + case *CheckServerParameters: + serverCommandBytes := make([]byte, 1+2+len(command.Arguments)) serverCommandBytes[0] = CheckServerParametersID - binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(checkServerCommand.Arguments))) - copy(serverCommandBytes[3:], checkServerCommand.Arguments) + binary.BigEndian.PutUint16(serverCommandBytes[1:3], uint16(len(command.Arguments))) + copy(serverCommandBytes[3:], command.Arguments) return serverCommandBytes, nil - case "checkParametersResponse": - checkParametersCommand, ok := command.(*CheckParametersResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *CheckParametersResponse: var checkMethod uint8 - if checkParametersCommand.InResponseTo == "checkClientParameters" { + if command.InResponseTo == "checkClientParameters" { checkMethod = CheckClientParametersID - } else if checkParametersCommand.InResponseTo == "checkServerParameters" { + } else if command.InResponseTo == "checkServerParameters" { checkMethod = CheckServerParametersID } else { return nil, fmt.Errorf("invalid mode recieved (must be either checkClientParameters or checkServerParameters)") @@ -296,68 +248,50 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { var isValid uint8 - if checkParametersCommand.IsValid { + if command.IsValid { isValid = 1 } - checkResponseBytes := make([]byte, 3+2+len(checkParametersCommand.Message)) + checkResponseBytes := make([]byte, 3+2+len(command.Message)) checkResponseBytes[0] = CheckParametersResponseID checkResponseBytes[1] = checkMethod checkResponseBytes[2] = isValid - binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(checkParametersCommand.Message))) + binary.BigEndian.PutUint16(checkResponseBytes[3:5], uint16(len(command.Message))) - if len(checkParametersCommand.Message) != 0 { - copy(checkResponseBytes[5:], []byte(checkParametersCommand.Message)) + if len(command.Message) != 0 { + copy(checkResponseBytes[5:], []byte(command.Message)) } return checkResponseBytes, nil - case "backendStatusResponse": - backendStatusResponse, ok := command.(*BackendStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusResponse: var isRunning uint8 - if backendStatusResponse.IsRunning { + if command.IsRunning { isRunning = 1 } else { isRunning = 0 } - statusResponseBytes := make([]byte, 3+2+len(backendStatusResponse.Message)) + statusResponseBytes := make([]byte, 3+2+len(command.Message)) statusResponseBytes[0] = BackendStatusResponseID statusResponseBytes[1] = isRunning - statusResponseBytes[2] = byte(backendStatusResponse.StatusCode) + statusResponseBytes[2] = byte(command.StatusCode) - binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(backendStatusResponse.Message))) + binary.BigEndian.PutUint16(statusResponseBytes[3:5], uint16(len(command.Message))) - if len(backendStatusResponse.Message) != 0 { - copy(statusResponseBytes[5:], []byte(backendStatusResponse.Message)) + if len(command.Message) != 0 { + copy(statusResponseBytes[5:], []byte(command.Message)) } return statusResponseBytes, nil - case "backendStatusRequest": - _, ok := command.(*BackendStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *BackendStatusRequest: statusRequestBytes := make([]byte, 1) statusRequestBytes[0] = BackendStatusRequestID return statusRequestBytes, nil - case "proxyStatusRequest": - proxyStatusRequest, ok := command.(*ProxyStatusRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusRequest.SourceIP) + case *ProxyStatusRequest: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -370,37 +304,31 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusRequestBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1) - proxyStatusRequestBytes[0] = ProxyStatusRequestID - proxyStatusRequestBytes[1] = ipVer + commandBytes[0] = ProxyStatusRequestID + commandBytes[1] = ipVer - copy(proxyStatusRequestBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusRequest.SourcePort) - binary.BigEndian.PutUint16(proxyStatusRequestBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusRequest.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusRequest.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusRequest.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusRequestBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol - return proxyStatusRequestBytes, nil - case "proxyStatusResponse": - proxyStatusResponse, ok := command.(*ProxyStatusResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - sourceIP := net.ParseIP(proxyStatusResponse.SourceIP) + return commandBytes, nil + case *ProxyStatusResponse: + sourceIP := net.ParseIP(command.SourceIP) var ipVer uint8 var ipBytes []byte @@ -413,50 +341,44 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { ipVer = IPv4 } - proxyStatusResponseBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) + commandBytes := make([]byte, 1+1+len(ipBytes)+2+2+1+1) - proxyStatusResponseBytes[0] = ProxyStatusResponseID - proxyStatusResponseBytes[1] = ipVer + commandBytes[0] = ProxyStatusResponseID + commandBytes[1] = ipVer - copy(proxyStatusResponseBytes[2:2+len(ipBytes)], ipBytes) + copy(commandBytes[2:2+len(ipBytes)], ipBytes) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[2+len(ipBytes):4+len(ipBytes)], proxyStatusResponse.SourcePort) - binary.BigEndian.PutUint16(proxyStatusResponseBytes[4+len(ipBytes):6+len(ipBytes)], proxyStatusResponse.DestPort) + binary.BigEndian.PutUint16(commandBytes[2+len(ipBytes):4+len(ipBytes)], command.SourcePort) + binary.BigEndian.PutUint16(commandBytes[4+len(ipBytes):6+len(ipBytes)], command.DestPort) var protocol uint8 - if proxyStatusResponse.Protocol == "tcp" { + if command.Protocol == "tcp" { protocol = TCP - } else if proxyStatusResponse.Protocol == "udp" { + } else if command.Protocol == "udp" { protocol = UDP } else { return nil, fmt.Errorf("invalid protocol") } - proxyStatusResponseBytes[6+len(ipBytes)] = protocol + commandBytes[6+len(ipBytes)] = protocol var isActive uint8 - if proxyStatusResponse.IsActive { + if command.IsActive { isActive = 1 } else { isActive = 0 } - proxyStatusResponseBytes[7+len(ipBytes)] = isActive + commandBytes[7+len(ipBytes)] = isActive - return proxyStatusResponseBytes, nil - case "proxyInstanceResponse": - proxyConectionResponse, ok := command.(*ProxyInstanceResponse) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - - proxyArray := make([][]byte, len(proxyConectionResponse.Proxies)) + return commandBytes, nil + case *ProxyInstanceResponse: + proxyArray := make([][]byte, len(command.Proxies)) totalSize := 0 - for proxyIndex, proxy := range proxyConectionResponse.Proxies { + for proxyIndex, proxy := range command.Proxies { var err error proxyArray[proxyIndex], err = marshalIndividualProxyStruct(proxy) @@ -485,23 +407,11 @@ func Marshal(commandType string, command interface{}) ([]byte, error) { connectionCommandArray[totalSize] = '\n' return connectionCommandArray, nil - case "proxyInstanceRequest": - _, ok := command.(*ProxyInstanceRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyInstanceRequest: return []byte{ProxyInstanceRequestID}, nil - case "proxyConnectionsRequest": - _, ok := command.(*ProxyConnectionsRequest) - - if !ok { - return nil, fmt.Errorf("failed to typecast") - } - + case *ProxyConnectionsRequest: return []byte{ProxyConnectionsRequestID}, nil } - return nil, fmt.Errorf("couldn't match command name") + return nil, fmt.Errorf("couldn't match command type") } diff --git a/backend/commonbackend/marshalling_test.go b/backend/commonbackend/marshalling_test.go index 1c93f94..e834041 100644 --- a/backend/commonbackend/marshalling_test.go +++ b/backend/commonbackend/marshalling_test.go @@ -9,7 +9,7 @@ import ( var logLevel = os.Getenv("HERMES_LOG_LEVEL") -func TestStartCommandMarshalSupport(t *testing.T) { +func TestStart(t *testing.T) { commandInput := &Start{ Type: "start", Arguments: []byte("Hello from automated testing"), @@ -53,7 +53,7 @@ func TestStartCommandMarshalSupport(t *testing.T) { } } -func TestStopCommandMarshalSupport(t *testing.T) { +func TestStop(t *testing.T) { commandInput := &Stop{ Type: "stop", } @@ -92,7 +92,7 @@ func TestStopCommandMarshalSupport(t *testing.T) { } } -func TestAddConnectionCommandMarshalSupport(t *testing.T) { +func TestAddConnection(t *testing.T) { commandInput := &AddProxy{ Type: "addProxy", SourceIP: "192.168.0.139", @@ -155,7 +155,7 @@ func TestAddConnectionCommandMarshalSupport(t *testing.T) { } } -func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { +func TestRemoveConnection(t *testing.T) { commandInput := &RemoveProxy{ Type: "removeProxy", SourceIP: "192.168.0.139", @@ -218,7 +218,7 @@ func TestRemoveConnectionCommandMarshalSupport(t *testing.T) { } } -func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { +func TestGetAllConnections(t *testing.T) { commandInput := &ProxyConnectionsResponse{ Type: "proxyConnectionsResponse", Connections: []*ProxyClientConnection{ @@ -309,7 +309,7 @@ func TestGetAllConnectionsCommandMarshalSupport(t *testing.T) { } } -func TestCheckClientParametersMarshalSupport(t *testing.T) { +func TestCheckClientParameters(t *testing.T) { commandInput := &CheckClientParameters{ Type: "checkClientParameters", SourceIP: "192.168.0.139", @@ -372,7 +372,7 @@ func TestCheckClientParametersMarshalSupport(t *testing.T) { } } -func TestCheckServerParametersMarshalSupport(t *testing.T) { +func TestCheckServerParameters(t *testing.T) { commandInput := &CheckServerParameters{ Type: "checkServerParameters", Arguments: []byte("Hello from automated testing"), @@ -416,7 +416,7 @@ func TestCheckServerParametersMarshalSupport(t *testing.T) { } } -func TestCheckParametersResponseMarshalSupport(t *testing.T) { +func TestCheckParametersResponse(t *testing.T) { commandInput := &CheckParametersResponse{ Type: "checkParametersResponse", InResponseTo: "checkClientParameters", @@ -473,7 +473,7 @@ func TestCheckParametersResponseMarshalSupport(t *testing.T) { } } -func TestBackendStatusRequestMarshalSupport(t *testing.T) { +func TestBackendStatusRequest(t *testing.T) { commandInput := &BackendStatusRequest{ Type: "backendStatusRequest", } @@ -512,7 +512,7 @@ func TestBackendStatusRequestMarshalSupport(t *testing.T) { } } -func TestBackendStatusResponseMarshalSupport(t *testing.T) { +func TestBackendStatusResponse(t *testing.T) { commandInput := &BackendStatusResponse{ Type: "backendStatusResponse", IsRunning: true, @@ -569,7 +569,7 @@ func TestBackendStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyStatusRequestMarshalSupport(t *testing.T) { +func TestProxyStatusRequest(t *testing.T) { commandInput := &ProxyStatusRequest{ Type: "proxyStatusRequest", SourceIP: "192.168.0.139", @@ -632,7 +632,7 @@ func TestProxyStatusRequestMarshalSupport(t *testing.T) { } } -func TestProxyStatusResponseMarshalSupport(t *testing.T) { +func TestProxyStatusResponse(t *testing.T) { commandInput := &ProxyStatusResponse{ Type: "proxyStatusResponse", SourceIP: "192.168.0.139", @@ -701,7 +701,7 @@ func TestProxyStatusResponseMarshalSupport(t *testing.T) { } } -func TestProxyConnectionRequestMarshalSupport(t *testing.T) { +func TestProxyConnectionRequest(t *testing.T) { commandInput := &ProxyInstanceRequest{ Type: "proxyInstanceRequest", } @@ -740,7 +740,7 @@ func TestProxyConnectionRequestMarshalSupport(t *testing.T) { } } -func TestProxyConnectionResponseMarshalSupport(t *testing.T) { +func TestProxyConnectionResponse(t *testing.T) { commandInput := &ProxyInstanceResponse{ Type: "proxyInstanceResponse", Proxies: []*ProxyInstance{ diff --git a/backend/sshappbackend/datacommands/constants.go b/backend/sshappbackend/datacommands/constants.go new file mode 100644 index 0000000..7b7620b --- /dev/null +++ b/backend/sshappbackend/datacommands/constants.go @@ -0,0 +1,103 @@ +package datacommands + +type ProxyStatusRequest struct { + Type string + ProxyID uint16 +} + +type ProxyStatusResponse struct { + Type string + ProxyID uint16 + IsActive bool +} + +type RemoveProxy struct { + Type string + ProxyID uint16 +} + +type ProxyInstanceResponse struct { + Type string + Proxies []uint16 +} + +type ProxyConnectionsRequest struct { + Type string + ProxyID uint16 +} + +type ProxyConnectionsResponse struct { + Type string + Connections []uint16 +} + +type TCPConnectionOpened struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type TCPConnectionClosed struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type TCPProxyData struct { + Type string + ProxyID uint16 + ConnectionID uint16 + DataLength uint16 +} + +type UDPProxyData struct { + Type string + ProxyID uint16 + ClientIP string + ClientPort uint16 + DataLength uint16 +} + +type ProxyInformationRequest struct { + Type string + ProxyID uint16 +} + +type ProxyInformationResponse struct { + Type string + Exists bool + SourceIP string + SourcePort uint16 + DestPort uint16 + Protocol string // Will be either 'tcp' or 'udp' +} + +type ProxyConnectionInformationRequest struct { + Type string + ProxyID uint16 + ConnectionID uint16 +} + +type ProxyConnectionInformationResponse struct { + Type string + Exists bool + ClientIP string + ClientPort uint16 +} + +const ( + ProxyStatusRequestID = iota + 100 + ProxyStatusResponseID + RemoveProxyID + ProxyInstanceResponseID + ProxyConnectionsRequestID + ProxyConnectionsResponseID + TCPConnectionOpenedID + TCPConnectionClosedID + TCPProxyDataID + UDPProxyDataID + ProxyInformationRequestID + ProxyInformationResponseID + ProxyConnectionInformationRequestID + ProxyConnectionInformationResponseID +) diff --git a/backend/sshappbackend/datacommands/marshal.go b/backend/sshappbackend/datacommands/marshal.go new file mode 100644 index 0000000..bacbda1 --- /dev/null +++ b/backend/sshappbackend/datacommands/marshal.go @@ -0,0 +1,323 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "net" +) + +// Example size and protocol constants — adjust as needed. +const ( + IPv4Size = 4 + IPv6Size = 16 + + TCP = 1 + UDP = 2 +) + +// Marshal takes a command (pointer to one of our structs) and converts it to a byte slice. +func Marshal(_ string, command interface{}) ([]byte, error) { + switch cmd := command.(type) { + // ProxyStatusRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyStatusRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyStatusRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyStatusResponse: 1 byte for the command ID, 2 bytes for ProxyID, and 1 byte for IsActive. + case *ProxyStatusResponse: + buf := make([]byte, 1+2+1) + + buf[0] = ProxyStatusResponseID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + if cmd.IsActive { + buf[3] = 1 + } else { + buf[3] = 0 + } + + return buf, nil + + // RemoveProxy: 1 byte for the command ID + 2 bytes for the ProxyID. + case *RemoveProxy: + buf := make([]byte, 1+2) + + buf[0] = RemoveProxyID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsRequest: 1 byte for the command ID + 2 bytes for the ProxyID. + case *ProxyConnectionsRequest: + buf := make([]byte, 1+2) + + buf[0] = ProxyConnectionsRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Connections + 2 bytes for each + // number in the Connection array. + case *ProxyConnectionsResponse: + buf := make([]byte, 1+((len(cmd.Connections)+1)*2)) + + buf[0] = ProxyConnectionsResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Connections))) + + for connectionIndex, connection := range cmd.Connections { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // ProxyConnectionsResponse: 1 byte for the command ID + 2 bytes length of the Proxies + 2 bytes for each + // number in the Proxies array. + case *ProxyInstanceResponse: + buf := make([]byte, 1+((len(cmd.Proxies)+1)*2)) + + buf[0] = ProxyInstanceResponseID + binary.BigEndian.PutUint16(buf[1:], uint16(len(cmd.Proxies))) + + for connectionIndex, connection := range cmd.Proxies { + binary.BigEndian.PutUint16(buf[3+(connectionIndex*2):], connection) + } + + return buf, nil + + // TCPConnectionOpened: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionOpened: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionOpenedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPConnectionClosed: 1 byte for the command ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *TCPConnectionClosed: + buf := make([]byte, 1+2+2) + + buf[0] = TCPConnectionClosedID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case *TCPProxyData: + buf := make([]byte, 1+2+2+2) + + buf[0] = TCPProxyDataID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + binary.BigEndian.PutUint16(buf[5:], cmd.DataLength) + + return buf, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case *UDPProxyData: + ip := net.ParseIP(cmd.ClientIP) + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 2 + // ProxyID + 1 + // IP version + len(ipBytes) + // client IP bytes + 2 + // ClientPort + 2 // DataLength + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = UDPProxyDataID + offset++ + + binary.BigEndian.PutUint16(buf[offset:], cmd.ProxyID) + offset += 2 + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DataLength) + + return buf, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case *ProxyInformationRequest: + buf := make([]byte, 1+2) + buf[0] = ProxyInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + return buf, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + // (For simplicity, this marshaller always writes the IP and port info even if !Exists.) + case *ProxyInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.SourceIP) + + if ip == nil { + return nil, fmt.Errorf("invalid source IP: %v", cmd.SourceIP) + } + + var ipVer uint8 + var ipBytes []byte + + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.SourceIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 + // SourcePort + 2 + // DestPort + 1 // Protocol + + buf := make([]byte, totalSize) + + offset := 0 + buf[offset] = ProxyInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.SourcePort) + offset += 2 + + binary.BigEndian.PutUint16(buf[offset:], cmd.DestPort) + offset += 2 + + // Encode protocol as 1 byte. + switch cmd.Protocol { + case "tcp": + buf[offset] = TCP + case "udp": + buf[offset] = UDP + default: + return nil, fmt.Errorf("invalid protocol: %v", cmd.Protocol) + } + + // offset++ (not needed since we are at the end) + return buf, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case *ProxyConnectionInformationRequest: + buf := make([]byte, 1+2+2) + + buf[0] = ProxyConnectionInformationRequestID + binary.BigEndian.PutUint16(buf[1:], cmd.ProxyID) + binary.BigEndian.PutUint16(buf[3:], cmd.ConnectionID) + + return buf, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + (if exists:) + // 1 byte IP version + IP bytes + 2 bytes ClientPort. + // This marshaller only writes the rest of the data if Exists. + case *ProxyConnectionInformationResponse: + if !cmd.Exists { + buf := make([]byte, 1+1) + buf[0] = ProxyConnectionInformationResponseID + buf[1] = 0 /* false */ + + return buf, nil + } + + ip := net.ParseIP(cmd.ClientIP) + + if ip == nil { + return nil, fmt.Errorf("invalid client IP: %v", cmd.ClientIP) + } + + var ipVer uint8 + var ipBytes []byte + if ip4 := ip.To4(); ip4 != nil { + ipBytes = ip4 + ipVer = 4 + } else if ip16 := ip.To16(); ip16 != nil { + ipBytes = ip16 + ipVer = 6 + } else { + return nil, fmt.Errorf("unable to detect IP version for: %v", cmd.ClientIP) + } + + totalSize := 1 + // id + 1 + // Exists flag + 1 + // IP version + len(ipBytes) + + 2 // ClientPort + + buf := make([]byte, totalSize) + offset := 0 + buf[offset] = ProxyConnectionInformationResponseID + offset++ + + // We already handle this above + buf[offset] = 1 /* true */ + offset++ + + buf[offset] = ipVer + offset++ + + copy(buf[offset:], ipBytes) + offset += len(ipBytes) + + binary.BigEndian.PutUint16(buf[offset:], cmd.ClientPort) + + return buf, nil + + default: + return nil, fmt.Errorf("unsupported command type") + } +} diff --git a/backend/sshappbackend/datacommands/marshalling_test.go b/backend/sshappbackend/datacommands/marshalling_test.go new file mode 100644 index 0000000..81e7101 --- /dev/null +++ b/backend/sshappbackend/datacommands/marshalling_test.go @@ -0,0 +1,828 @@ +package datacommands + +import ( + "bytes" + "log" + "os" + "testing" +) + +var logLevel = os.Getenv("HERMES_LOG_LEVEL") + +func TestProxyStatusRequest(t *testing.T) { + commandInput := &ProxyStatusRequest{ + Type: "proxyStatusRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyStatusResponse(t *testing.T) { + commandInput := &ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: 19132, + IsActive: true, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyStatusResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.IsActive != commandUnmarshalled.IsActive { + t.Fail() + log.Printf("IsActive's are not equal (orig: '%t', unmsh: '%t')", commandInput.IsActive, commandUnmarshalled.IsActive) + } +} + +func TestRemoveProxy(t *testing.T) { + commandInput := &RemoveProxy{ + Type: "removeProxy", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*RemoveProxy) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsRequest(t *testing.T) { + commandInput := &ProxyConnectionsRequest{ + Type: "proxyConnectionsRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyConnectionsResponse(t *testing.T) { + commandInput := &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionsResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for connectionIndex, originalConnection := range commandInput.Connections { + remoteConnection := commandUnmarshalled.Connections[connectionIndex] + + if originalConnection != remoteConnection { + t.Fail() + log.Printf("(in #%d) SourceIP's are not equal (orig: %d, unmsh: %d)", connectionIndex, originalConnection, connectionIndex) + } + } +} + +func TestProxyInstanceResponse(t *testing.T) { + commandInput := &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: []uint16{12831, 9455, 64219, 12, 32}, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInstanceResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + for proxyIndex, originalProxy := range commandInput.Proxies { + remoteProxy := commandUnmarshalled.Proxies[proxyIndex] + + if originalProxy != remoteProxy { + t.Fail() + log.Printf("(in #%d) Proxy IDs are not equal (orig: %d, unmsh: %d)", proxyIndex, originalProxy, remoteProxy) + } + } +} + +func TestTCPConnectionOpened(t *testing.T) { + commandInput := &TCPConnectionOpened{ + Type: "tcpConnectionOpened", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionOpened) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPConnectionClosed(t *testing.T) { + commandInput := &TCPConnectionClosed{ + Type: "tcpConnectionClosed", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPConnectionClosed) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestTCPProxyData(t *testing.T) { + commandInput := &TCPProxyData{ + Type: "tcpProxyData", + ProxyID: 19132, + ConnectionID: 25565, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*TCPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestUDPProxyData(t *testing.T) { + commandInput := &UDPProxyData{ + Type: "udpProxyData", + ProxyID: 19132, + ClientIP: "68.51.23.54", + ClientPort: 28173, + DataLength: 1234, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*UDPProxyData) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("ClientIP's are not equal (orig: '%s', unmsh: '%s')", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: '%d', unmsh: '%d')", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } + + if commandInput.DataLength != commandUnmarshalled.DataLength { + t.Fail() + log.Printf("DataLength's are not equal (orig: '%d', unmsh: '%d')", commandInput.DataLength, commandUnmarshalled.DataLength) + } +} + +func TestProxyInformationRequest(t *testing.T) { + commandInput := &ProxyInformationRequest{ + Type: "proxyInformationRequest", + ProxyID: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } +} + +func TestProxyInformationResponseExists(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: true, + SourceIP: "192.168.0.139", + SourcePort: 19132, + DestPort: 19132, + Protocol: "tcp", + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.SourceIP != commandUnmarshalled.SourceIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.SourceIP, commandUnmarshalled.SourceIP) + } + + if commandInput.SourcePort != commandUnmarshalled.SourcePort { + t.Fail() + log.Printf("SourcePort's are not equal (orig: %d, unmsh: %d)", commandInput.SourcePort, commandUnmarshalled.SourcePort) + } + + if commandInput.DestPort != commandUnmarshalled.DestPort { + t.Fail() + log.Printf("DestPort's are not equal (orig: %d, unmsh: %d)", commandInput.DestPort, commandUnmarshalled.DestPort) + } + + if commandInput.Protocol != commandUnmarshalled.Protocol { + t.Fail() + log.Printf("Protocols are not equal (orig: %s, unmsh: %s)", commandInput.Protocol, commandUnmarshalled.Protocol) + } +} + +func TestProxyInformationResponseNoExist(t *testing.T) { + commandInput := &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} + +func TestProxyConnectionInformationRequest(t *testing.T) { + commandInput := &ProxyConnectionInformationRequest{ + Type: "proxyConnectionInformationRequest", + ProxyID: 19132, + ConnectionID: 25565, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + if err != nil { + t.Fatal(err.Error()) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationRequest) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.ProxyID != commandUnmarshalled.ProxyID { + t.Fail() + log.Printf("ProxyID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ProxyID, commandUnmarshalled.ProxyID) + } + + if commandInput.ConnectionID != commandUnmarshalled.ConnectionID { + t.Fail() + log.Printf("ConnectionID's are not equal (orig: '%d', unmsh: '%d')", commandInput.ConnectionID, commandUnmarshalled.ConnectionID) + } +} + +func TestProxyConnectionInformationResponseExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: true, + ClientIP: "192.168.0.139", + ClientPort: 19132, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } + + if commandInput.ClientIP != commandUnmarshalled.ClientIP { + t.Fail() + log.Printf("SourceIP's are not equal (orig: %s, unmsh: %s)", commandInput.ClientIP, commandUnmarshalled.ClientIP) + } + + if commandInput.ClientPort != commandUnmarshalled.ClientPort { + t.Fail() + log.Printf("ClientPort's are not equal (orig: %d, unmsh: %d)", commandInput.ClientPort, commandUnmarshalled.ClientPort) + } +} + +func TestProxyConnectionInformationResponseNoExists(t *testing.T) { + commandInput := &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: false, + } + + commandMarshalled, err := Marshal(commandInput.Type, commandInput) + + if err != nil { + t.Fatal(err.Error()) + } + + if logLevel == "debug" { + log.Printf("Generated array contents: %v", commandMarshalled) + } + + buf := bytes.NewBuffer(commandMarshalled) + commandType, commandUnmarshalledRaw, err := Unmarshal(buf) + + if err != nil { + t.Fatal(err.Error()) + } + + if commandType != commandInput.Type { + t.Fail() + log.Print("command type does not match up!") + } + + commandUnmarshalled, ok := commandUnmarshalledRaw.(*ProxyConnectionInformationResponse) + + if !ok { + t.Fatal("failed typecast") + } + + if commandInput.Type != commandUnmarshalled.Type { + t.Fail() + log.Printf("Types are not equal (orig: %s, unmsh: %s)", commandInput.Type, commandUnmarshalled.Type) + } + + if commandInput.Exists != commandUnmarshalled.Exists { + t.Fail() + log.Printf("Exists's are not equal (orig: '%t', unmsh: '%t')", commandInput.Exists, commandUnmarshalled.Exists) + } +} diff --git a/backend/sshappbackend/datacommands/unmarshal.go b/backend/sshappbackend/datacommands/unmarshal.go new file mode 100644 index 0000000..7e97c3f --- /dev/null +++ b/backend/sshappbackend/datacommands/unmarshal.go @@ -0,0 +1,435 @@ +package datacommands + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +// Unmarshal reads from the provided connection and returns +// the message type (as a string), the unmarshalled struct, or an error. +func Unmarshal(conn io.Reader) (string, interface{}, error) { + // Every command starts with a 1-byte command ID. + header := make([]byte, 1) + if _, err := io.ReadFull(conn, header); err != nil { + return "", nil, fmt.Errorf("couldn't read command ID: %w", err) + } + + cmdID := header[0] + switch cmdID { + // ProxyStatusRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyStatusRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyStatusRequest", &ProxyStatusRequest{ + Type: "proxyStatusRequest", + ProxyID: proxyID, + }, nil + + // ProxyStatusResponse: 1 byte ID + 2 bytes ProxyID + 1 byte IsActive. + case ProxyStatusResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyStatusResponse IsActive: %w", err) + } + + isActive := boolBuf[0] != 0 + + return "proxyStatusResponse", &ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: proxyID, + IsActive: isActive, + }, nil + + // RemoveProxy: 1 byte ID + 2 bytes ProxyID. + case RemoveProxyID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read RemoveProxy ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "removeProxy", &RemoveProxy{ + Type: "removeProxy", + ProxyID: proxyID, + }, nil + + // ProxyConnectionsRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyConnectionsRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyConnectionsRequest", &ProxyConnectionsRequest{ + Type: "proxyConnectionsRequest", + ProxyID: proxyID, + }, nil + + // ProxyConnectionsResponse: 1 byte ID + 2 bytes Connections length + 2 bytes for each Connection in Connections. + case ProxyConnectionsResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + connections := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range connections { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + connections[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return "proxyConnectionsResponse", &ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: connections, + }, failedDuringReading + + // ProxyInstanceResponse: 1 byte ID + 2 bytes Proxies length + 2 bytes for each Proxy in Proxies. + case ProxyInstanceResponseID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionsResponse length: %w", err) + } + + length := binary.BigEndian.Uint16(buf) + proxies := make([]uint16, length) + + var failedDuringReading error + + for connectionIndex := range proxies { + if _, err := io.ReadFull(conn, buf); err != nil { + failedDuringReading = fmt.Errorf("couldn't read ProxyConnectionsResponse with position of %d: %w", connectionIndex, err) + break + } + + proxies[connectionIndex] = binary.BigEndian.Uint16(buf) + } + + return "proxyInstanceResponse", &ProxyInstanceResponse{ + Type: "proxyInstanceResponse", + Proxies: proxies, + }, failedDuringReading + + // TCPConnectionOpened: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionOpenedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPConnectionOpened fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "tcpConnectionOpened", &TCPConnectionOpened{ + Type: "tcpConnectionOpened", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPConnectionClosed: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case TCPConnectionClosedID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPConnectionClosed fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "tcpConnectionClosed", &TCPConnectionClosed{ + Type: "tcpConnectionClosed", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // TCPProxyData: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + 2 bytes DataLength. + case TCPProxyDataID: + buf := make([]byte, 2+2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read TCPProxyData fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + dataLength := binary.BigEndian.Uint16(buf[4:6]) + + return "tcpProxyData", &TCPProxyData{ + Type: "tcpProxyData", + ProxyID: proxyID, + ConnectionID: connectionID, + DataLength: dataLength, + }, nil + + // UDPProxyData: + // Format: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID + + // 1 byte IP version + IP bytes + 2 bytes ClientPort + 2 bytes DataLength. + case UDPProxyDataID: + // Read 2 bytes ProxyID + 2 bytes ConnectionID. + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData ProxyID/ConnectionID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version received: %v", ipVerBuf[0]) + } + + // Read the IP bytes. + ipBytes := make([]byte, ipSize) + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData IP bytes: %w", err) + } + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + // Read DataLength. + dataLengthBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, dataLengthBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read UDPProxyData DataLength: %w", err) + } + + dataLength := binary.BigEndian.Uint16(dataLengthBuf) + + return "udpProxyData", &UDPProxyData{ + Type: "udpProxyData", + ProxyID: proxyID, + ClientIP: clientIP, + ClientPort: clientPort, + DataLength: dataLength, + }, nil + + // ProxyInformationRequest: 1 byte ID + 2 bytes ProxyID. + case ProxyInformationRequestID: + buf := make([]byte, 2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationRequest ProxyID: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf) + + return "proxyInformationRequest", &ProxyInformationRequest{ + Type: "proxyInformationRequest", + ProxyID: proxyID, + }, nil + + // ProxyInformationResponse: + // Format: 1 byte ID + 1 byte Exists + + // 1 byte IP version + IP bytes + 2 bytes SourcePort + 2 bytes DestPort + 1 byte Protocol. + case ProxyInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return "proxyInformationResponse", &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP version: %w", err) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else if ipVerBuf[0] == 6 { + ipSize = IPv6Size + } else { + return "", nil, fmt.Errorf("invalid IP version in ProxyInformationResponse: %v", ipVerBuf[0]) + } + + // Read the source IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse IP bytes: %w", err) + } + + sourceIP := net.IP(ipBytes).String() + + // Read SourcePort and DestPort. + portsBuf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, portsBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse ports: %w", err) + } + + sourcePort := binary.BigEndian.Uint16(portsBuf[0:2]) + destPort := binary.BigEndian.Uint16(portsBuf[2:4]) + + // Read protocol. + protoBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, protoBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyInformationResponse protocol: %w", err) + } + var protocol string + if protoBuf[0] == TCP { + protocol = "tcp" + } else if protoBuf[0] == UDP { + protocol = "udp" + } else { + return "", nil, fmt.Errorf("invalid protocol value in ProxyInformationResponse: %d", protoBuf[0]) + } + + return "proxyInformationResponse", &ProxyInformationResponse{ + Type: "proxyInformationResponse", + Exists: exists, + SourceIP: sourceIP, + SourcePort: sourcePort, + DestPort: destPort, + Protocol: protocol, + }, nil + + // ProxyConnectionInformationRequest: 1 byte ID + 2 bytes ProxyID + 2 bytes ConnectionID. + case ProxyConnectionInformationRequestID: + buf := make([]byte, 2+2) + + if _, err := io.ReadFull(conn, buf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationRequest fields: %w", err) + } + + proxyID := binary.BigEndian.Uint16(buf[0:2]) + connectionID := binary.BigEndian.Uint16(buf[2:4]) + + return "proxyConnectionInformationRequest", &ProxyConnectionInformationRequest{ + Type: "proxyConnectionInformationRequest", + ProxyID: proxyID, + ConnectionID: connectionID, + }, nil + + // ProxyConnectionInformationResponse: + // Format: 1 byte ID + 1 byte Exists + 1 byte IP version + IP bytes + 2 bytes ClientPort. + case ProxyConnectionInformationResponseID: + // Read Exists flag. + boolBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, boolBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse Exists flag: %w", err) + } + + exists := boolBuf[0] != 0 + + if !exists { + return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: exists, + }, nil + } + + // Read IP version. + ipVerBuf := make([]byte, 1) + + if _, err := io.ReadFull(conn, ipVerBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP version: %w", err) + } + + if ipVerBuf[0] != 4 && ipVerBuf[0] != 6 { + return "", nil, fmt.Errorf("invalid IP version in ProxyConnectionInformationResponse: %v", ipVerBuf[0]) + } + + var ipSize int + + if ipVerBuf[0] == 4 { + ipSize = IPv4Size + } else { + ipSize = IPv6Size + } + + // Read IP bytes. + ipBytes := make([]byte, ipSize) + + if _, err := io.ReadFull(conn, ipBytes); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse IP bytes: %w", err) + } + + clientIP := net.IP(ipBytes).String() + + // Read ClientPort. + portBuf := make([]byte, 2) + + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", nil, fmt.Errorf("couldn't read ProxyConnectionInformationResponse ClientPort: %w", err) + } + + clientPort := binary.BigEndian.Uint16(portBuf) + + return "proxyConnectionInformationResponse", &ProxyConnectionInformationResponse{ + Type: "proxyConnectionInformationResponse", + Exists: exists, + ClientIP: clientIP, + ClientPort: clientPort, + }, nil + default: + return "", nil, fmt.Errorf("unknown command id: %v", cmdID) + } +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/application.go b/backend/sshappbackend/remote-code/backendutil_custom/application.go new file mode 100644 index 0000000..6d46a7b --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/application.go @@ -0,0 +1,310 @@ +package backendutil_custom + +import ( + "fmt" + "net" + "os" + + "git.terah.dev/imterah/hermes/backend/backendutil" + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "github.com/charmbracelet/log" +) + +type BackendApplicationHelper struct { + Backend BackendInterface + SocketPath string + + socket net.Conn +} + +func (helper *BackendApplicationHelper) Start() error { + log.Debug("BackendApplicationHelper is starting") + err := backendutil.ConfigureProfiling() + + if err != nil { + return err + } + + log.Debug("Currently waiting for Unix socket connection...") + + helper.socket, err = net.Dial("unix", helper.SocketPath) + + if err != nil { + return err + } + + log.Debug("Sucessfully connected") + + for { + commandType, commandRaw, err := datacommands.Unmarshal(helper.socket) + + if err != nil && err.Error() != "couldn't match command ID" { + return err + } + + switch commandType { + case "proxyConnectionsRequest": + proxyConnectionRequest, ok := commandRaw.(*datacommands.ProxyConnectionsRequest) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + connections := helper.Backend.GetAllClientConnections(proxyConnectionRequest.ProxyID) + + serverParams := &datacommands.ProxyConnectionsResponse{ + Type: "proxyConnectionsResponse", + Connections: connections, + } + + byteData, err := datacommands.Marshal(serverParams.Type, serverParams) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case "removeProxy": + command, ok := commandRaw.(*datacommands.RemoveProxy) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StopProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to remove proxy (ID %d): RemoveProxy returned into failure state", command.ProxyID) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to remove proxy (ID %d): %s", command.ProxyID, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: command.ProxyID, + IsActive: hasAnyFailed, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + default: + commandType, commandRaw, err := commonbackend.Unmarshal(helper.socket) + + if err != nil { + return err + } + + switch commandType { + case "start": + command, ok := commandRaw.(*commonbackend.Start) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StartBackend(command.Arguments) + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "stop": + _, ok := commandRaw.(*commonbackend.Stop) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err = helper.Backend.StopBackend() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: !ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "backendStatusRequest": + _, ok := commandRaw.(*commonbackend.BackendStatusRequest) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + ok, err := helper.Backend.GetBackendStatus() + + var ( + message string + statusCode int + ) + + if err != nil { + message = err.Error() + statusCode = commonbackend.StatusFailure + } else { + statusCode = commonbackend.StatusSuccess + } + + response := &commonbackend.BackendStatusResponse{ + Type: "backendStatusResponse", + IsRunning: ok, + StatusCode: statusCode, + Message: message, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "addProxy": + command, ok := commandRaw.(*commonbackend.AddProxy) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + id, ok, err := helper.Backend.StartProxy(command) + var hasAnyFailed bool + + if !ok { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): StartProxy returned into failure state", command.SourceIP, command.SourcePort, command.DestPort) + hasAnyFailed = true + } else if err != nil { + log.Warnf("failed to add proxy (%s:%d -> remote:%d): %s", command.SourceIP, command.SourcePort, command.DestPort, err.Error()) + hasAnyFailed = true + } + + response := &datacommands.ProxyStatusResponse{ + Type: "proxyStatusResponse", + ProxyID: id, + IsActive: !hasAnyFailed, + } + + responseMarshalled, err := commonbackend.Marshal(response.Type, response) + + if err != nil { + log.Error("failed to marshal response: %s", err.Error()) + continue + } + + helper.socket.Write(responseMarshalled) + case "checkClientParameters": + command, ok := commandRaw.(*commonbackend.CheckClientParameters) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + resp := helper.Backend.CheckParametersForConnections(command) + resp.Type = "checkParametersResponse" + resp.InResponseTo = "checkClientParameters" + + byteData, err := commonbackend.Marshal(resp.Type, resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + case "checkServerParameters": + command, ok := commandRaw.(*commonbackend.CheckServerParameters) + + if !ok { + return fmt.Errorf("failed to typecast") + } + + resp := helper.Backend.CheckParametersForBackend(command.Arguments) + resp.Type = "checkParametersResponse" + resp.InResponseTo = "checkServerParameters" + + byteData, err := commonbackend.Marshal(resp.Type, resp) + + if err != nil { + return err + } + + if _, err = helper.socket.Write(byteData); err != nil { + return err + } + default: + log.Warn("Unsupported command recieved: %s", commandType) + } + } + } +} + +func NewHelper(backend BackendInterface) *BackendApplicationHelper { + socketPath, ok := os.LookupEnv("HERMES_API_SOCK") + + if !ok { + log.Warn("HERMES_API_SOCK is not defined! This will cause an issue unless the backend manually overwrites it") + } + + helper := &BackendApplicationHelper{ + Backend: backend, + SocketPath: socketPath, + } + + return helper +} diff --git a/backend/sshappbackend/remote-code/backendutil_custom/structure.go b/backend/sshappbackend/remote-code/backendutil_custom/structure.go new file mode 100644 index 0000000..96bcbf8 --- /dev/null +++ b/backend/sshappbackend/remote-code/backendutil_custom/structure.go @@ -0,0 +1,22 @@ +package backendutil_custom + +import ( + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" +) + +type BackendInterface interface { + StartBackend(arguments []byte) (bool, error) + StopBackend() (bool, error) + GetBackendStatus() (bool, error) + StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) + StopProxy(command *datacommands.RemoveProxy) (bool, error) + GetAllProxies() []uint16 + ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse + GetAllClientConnections(proxyID uint16) []uint16 + ResolveConnection(connectionID uint16) *datacommands.ProxyConnectionsResponse + CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse + CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse + HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) + HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) +} diff --git a/backend/sshappbackend/remote-code/main.go b/backend/sshappbackend/remote-code/main.go index 8a3f1cc..0fd11ee 100644 --- a/backend/sshappbackend/remote-code/main.go +++ b/backend/sshappbackend/remote-code/main.go @@ -1,7 +1,120 @@ package main -import "fmt" +import ( + "os" + "sync" + + "git.terah.dev/imterah/hermes/backend/commonbackend" + "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" + "git.terah.dev/imterah/hermes/backend/sshappbackend/remote-code/backendutil_custom" + "github.com/charmbracelet/log" +) + +type TCPProxy struct { + proxyIDIndex uint16 + proxyIDLock sync.Mutex +} + +type UDPProxy struct { +} + +type SSHRemoteAppBackend struct { + connectionIDIndex uint16 + connectionIDLock sync.Mutex + + tcpProxies map[uint16]*TCPProxy + udpProxies map[uint16]*UDPProxy +} + +func (backend *SSHRemoteAppBackend) StartBackend(byte []byte) (bool, error) { + backend.tcpProxies = map[uint16]*TCPProxy{} + backend.udpProxies = map[uint16]*UDPProxy{} + + return true, nil +} + +func (backend *SSHRemoteAppBackend) StopBackend() (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetBackendStatus() (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) StartProxy(command *commonbackend.AddProxy) (uint16, bool, error) { + return 0, true, nil +} + +func (backend *SSHRemoteAppBackend) StopProxy(command *datacommands.RemoveProxy) (bool, error) { + return true, nil +} + +func (backend *SSHRemoteAppBackend) GetAllProxies() []uint16 { + return []uint16{} +} + +func (backend *SSHRemoteAppBackend) ResolveProxy(proxyID uint16) *datacommands.ProxyInformationResponse { + return &datacommands.ProxyInformationResponse{} +} + +func (backend *SSHRemoteAppBackend) GetAllClientConnections(proxyID uint16) []uint16 { + return []uint16{} +} + +func (backend *SSHRemoteAppBackend) ResolveConnection(proxyID uint16) *datacommands.ProxyConnectionsResponse { + return &datacommands.ProxyConnectionsResponse{} +} + +func (backend *SSHRemoteAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + Message: "Valid!", + } +} + +func (backend *SSHRemoteAppBackend) CheckParametersForBackend(arguments []byte) *commonbackend.CheckParametersResponse { + return &commonbackend.CheckParametersResponse{ + IsValid: true, + Message: "Valid!", + } +} + +func (backend *SSHRemoteAppBackend) HandleTCPMessage(message *datacommands.TCPProxyData, data []byte) { + +} + +func (backend *SSHRemoteAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + +} func main() { - fmt.Println("remottuh code") + logLevel := os.Getenv("HERMES_LOG_LEVEL") + + if logLevel != "" { + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + + case "info": + log.SetLevel(log.InfoLevel) + + case "warn": + log.SetLevel(log.WarnLevel) + + case "error": + log.SetLevel(log.ErrorLevel) + + case "fatal": + log.SetLevel(log.FatalLevel) + } + } + + backend := &SSHRemoteAppBackend{} + + application := backendutil_custom.NewHelper(backend) + err := application.Start() + + if err != nil { + log.Fatalf("failed execution in application: %s", err.Error()) + } } diff --git a/go.mod b/go.mod index 98f0e12..b390542 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/go-playground/validator/v10 v10.23.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/pkg/sftp v1.13.7 github.com/urfave/cli/v2 v2.27.5 golang.org/x/crypto v0.31.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/term v0.28.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.11 @@ -50,7 +52,6 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/pkg/sftp v1.13.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -58,7 +59,6 @@ require ( github.com/ugorji/go/codec v1.2.12 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect golang.org/x/arch v0.12.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.29.0 // indirect