From f8a4fe00a0126d75bb07490a1667d5e9a6a8e65d Mon Sep 17 00:00:00 2001 From: imterah Date: Wed, 19 Feb 2025 07:58:42 -0500 Subject: [PATCH] feature: Adds basic UDP support. --- backend/commonbackend/unmarshal.go | 10 +- backend/sshappbackend/local-code/main.go | 68 ++++++++++- .../local-code/porttranslation/translation.go | 112 ++++++++++++++++++ 3 files changed, 183 insertions(+), 7 deletions(-) create mode 100644 backend/sshappbackend/local-code/porttranslation/translation.go diff --git a/backend/commonbackend/unmarshal.go b/backend/commonbackend/unmarshal.go index 8e338c6..6bb5af4 100644 --- a/backend/commonbackend/unmarshal.go +++ b/backend/commonbackend/unmarshal.go @@ -213,7 +213,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -270,7 +270,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -364,7 +364,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -523,7 +523,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") @@ -580,7 +580,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) { if protocolBytes[0] == TCP { protocol = "tcp" - } else if protocolBytes[1] == UDP { + } else if protocolBytes[0] == UDP { protocol = "udp" } else { return nil, fmt.Errorf("invalid protocol") diff --git a/backend/sshappbackend/local-code/main.go b/backend/sshappbackend/local-code/main.go index 75044b2..abb0081 100644 --- a/backend/sshappbackend/local-code/main.go +++ b/backend/sshappbackend/local-code/main.go @@ -19,6 +19,7 @@ import ( "git.terah.dev/imterah/hermes/backend/commonbackend" "git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands" "git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter" + "git.terah.dev/imterah/hermes/backend/sshappbackend/local-code/porttranslation" "github.com/charmbracelet/log" "github.com/go-playground/validator/v10" "github.com/pkg/sftp" @@ -32,6 +33,7 @@ type TCPProxy struct { type UDPProxy struct { proxyInformation *commonbackend.AddProxy + portTranslation *porttranslation.PortTranslation } type SSHAppBackendData struct { @@ -397,7 +399,56 @@ func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool, } else if command.Protocol == "udp" { backend.udpProxies[proxyStatus.ProxyID] = &UDPProxy{ proxyInformation: command, + portTranslation: &porttranslation.PortTranslation{}, } + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.UDPAddr = &net.UDPAddr{ + IP: net.ParseIP(command.SourceIP), + Port: int(command.SourcePort), + } + + udpMessageCommand := &datacommands.UDPProxyData{} + udpMessageCommand.ProxyID = proxyStatus.ProxyID + + backend.udpProxies[proxyStatus.ProxyID].portTranslation.WriteFrom = func(ip string, port uint16, data []byte) { + udpMessageCommand.ClientIP = ip + udpMessageCommand.ClientPort = port + udpMessageCommand.DataLength = uint16(len(data)) + + marshalledCommand, err := datacommands.Marshal(udpMessageCommand) + + if err != nil { + log.Warnf("Failed to marshal UDP message header") + return + } + + if _, err := backend.currentSock.Write(marshalledCommand); err != nil { + log.Warnf("Failed to write UDP message header") + return + } + + if _, err := backend.currentSock.Write(data); err != nil { + log.Warnf("Failed to write UDP message") + return + } + } + + go func() { + for { + time.Sleep(3 * time.Minute) + + // Checks if the proxy still exists before continuing + _, ok := backend.udpProxies[proxyStatus.ProxyID] + + if !ok { + return + } + + // Then attempt to run cleanup tasks + log.Debug("Running UDP proxy cleanup tasks (invoking CleanupPorts() on portTranslation)") + backend.udpProxies[proxyStatus.ProxyID].portTranslation.CleanupPorts() + } + }() } return true, nil @@ -474,17 +525,20 @@ func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (boo return true, fmt.Errorf("failed to stop proxy: still running") } - // TODO: finish code for UDP + proxy.portTranslation.StopAllPorts() + delete(backend.udpProxies, proxyIndex) } } return false, fmt.Errorf("could not find the proxy") } +// TODO: implement! func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection { return []*commonbackend.ProxyClientConnection{} } +// We don't have any parameter limitations, so we should be good. func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse { return &commonbackend.CheckParametersResponse{ IsValid: true, @@ -617,7 +671,17 @@ func (backend *SSHAppBackend) HandleTCPMessage(message *datacommands.TCPProxyDat connection.Write(data) } -func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) {} +func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) { + proxy, ok := backend.udpProxies[message.ProxyID] + + if !ok { + log.Warn("Could not find UDP proxy") + } + + if _, err := proxy.portTranslation.WriteTo(message.ClientIP, message.ClientPort, data); err != nil { + log.Warnf("Failed to write to UDP: %s", err.Error()) + } +} func (backend *SSHAppBackend) SendNonCriticalMessage(iface interface{}) (interface{}, error) { if backend.currentSock == nil { diff --git a/backend/sshappbackend/local-code/porttranslation/translation.go b/backend/sshappbackend/local-code/porttranslation/translation.go new file mode 100644 index 0000000..b8c0454 --- /dev/null +++ b/backend/sshappbackend/local-code/porttranslation/translation.go @@ -0,0 +1,112 @@ +package porttranslation + +import ( + "fmt" + "net" + "sync" + "time" +) + +type connectionData struct { + udpConn *net.UDPConn + buf []byte + hasBeenAliveFor time.Time +} + +type PortTranslation struct { + UDPAddr *net.UDPAddr + WriteFrom func(ip string, port uint16, data []byte) + + newConnectionLock sync.Mutex + connections map[string]map[uint16]*connectionData +} + +func (translation *PortTranslation) CleanupPorts() { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + anyAreAlive := false + + for connectionPortIndex, connectionData := range connectionPorts { + if time.Now().Before(connectionData.hasBeenAliveFor.Add(3 * time.Minute)) { + anyAreAlive = true + continue + } + + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + if !anyAreAlive { + delete(translation.connections, connectionIPIndex) + } + } +} + +func (translation *PortTranslation) StopAllPorts() { + if translation.connections == nil { + return + } + + for connectionIPIndex, connectionPorts := range translation.connections { + for connectionPortIndex, connectionData := range connectionPorts { + connectionData.udpConn.Close() + delete(connectionPorts, connectionPortIndex) + } + + delete(translation.connections, connectionIPIndex) + } + + translation.connections = nil +} + +func (translation *PortTranslation) WriteTo(ip string, port uint16, data []byte) (int, error) { + if translation.connections == nil { + translation.connections = map[string]map[uint16]*connectionData{} + } + + connectionPortData, ok := translation.connections[ip] + + if !ok { + translation.connections[ip] = map[uint16]*connectionData{} + connectionPortData = translation.connections[ip] + } + + connectionStruct, ok := connectionPortData[port] + + if !ok { + connectionPortData[port] = &connectionData{} + connectionStruct = connectionPortData[port] + + udpConn, err := net.DialUDP("udp", nil, translation.UDPAddr) + + if err != nil { + return 0, fmt.Errorf("failed to initialize UDP socket: %s", err.Error()) + } + + connectionStruct.udpConn = udpConn + connectionStruct.buf = make([]byte, 65535) + + go func() { + for { + n, err := udpConn.Read(connectionStruct.buf) + + if err != nil { + udpConn.Close() + delete(connectionPortData, port) + + return + } + + connectionStruct.hasBeenAliveFor = time.Now() + translation.WriteFrom(ip, port, connectionStruct.buf[:n]) + } + }() + } + + connectionStruct.hasBeenAliveFor = time.Now() + return connectionStruct.udpConn.Write(data) +}