feature: Adds basic UDP support.
This commit is contained in:
parent
15176831e6
commit
f8a4fe00a0
3 changed files with 183 additions and 7 deletions
|
@ -213,7 +213,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) {
|
|||
|
||||
if protocolBytes[0] == TCP {
|
||||
protocol = "tcp"
|
||||
} else if protocolBytes[1] == UDP {
|
||||
} else if protocolBytes[0] == UDP {
|
||||
protocol = "udp"
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid protocol")
|
||||
|
@ -270,7 +270,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) {
|
|||
|
||||
if protocolBytes[0] == TCP {
|
||||
protocol = "tcp"
|
||||
} else if protocolBytes[1] == UDP {
|
||||
} else if protocolBytes[0] == UDP {
|
||||
protocol = "udp"
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid protocol")
|
||||
|
@ -364,7 +364,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) {
|
|||
|
||||
if protocolBytes[0] == TCP {
|
||||
protocol = "tcp"
|
||||
} else if protocolBytes[1] == UDP {
|
||||
} else if protocolBytes[0] == UDP {
|
||||
protocol = "udp"
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid protocol")
|
||||
|
@ -523,7 +523,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) {
|
|||
|
||||
if protocolBytes[0] == TCP {
|
||||
protocol = "tcp"
|
||||
} else if protocolBytes[1] == UDP {
|
||||
} else if protocolBytes[0] == UDP {
|
||||
protocol = "udp"
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid protocol")
|
||||
|
@ -580,7 +580,7 @@ func Unmarshal(conn io.Reader) (interface{}, error) {
|
|||
|
||||
if protocolBytes[0] == TCP {
|
||||
protocol = "tcp"
|
||||
} else if protocolBytes[1] == UDP {
|
||||
} else if protocolBytes[0] == UDP {
|
||||
protocol = "udp"
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid protocol")
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"git.terah.dev/imterah/hermes/backend/commonbackend"
|
||||
"git.terah.dev/imterah/hermes/backend/sshappbackend/datacommands"
|
||||
"git.terah.dev/imterah/hermes/backend/sshappbackend/gaslighter"
|
||||
"git.terah.dev/imterah/hermes/backend/sshappbackend/local-code/porttranslation"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/pkg/sftp"
|
||||
|
@ -32,6 +33,7 @@ type TCPProxy struct {
|
|||
|
||||
type UDPProxy struct {
|
||||
proxyInformation *commonbackend.AddProxy
|
||||
portTranslation *porttranslation.PortTranslation
|
||||
}
|
||||
|
||||
type SSHAppBackendData struct {
|
||||
|
@ -397,7 +399,56 @@ func (backend *SSHAppBackend) StartProxy(command *commonbackend.AddProxy) (bool,
|
|||
} else if command.Protocol == "udp" {
|
||||
backend.udpProxies[proxyStatus.ProxyID] = &UDPProxy{
|
||||
proxyInformation: command,
|
||||
portTranslation: &porttranslation.PortTranslation{},
|
||||
}
|
||||
|
||||
backend.udpProxies[proxyStatus.ProxyID].portTranslation.UDPAddr = &net.UDPAddr{
|
||||
IP: net.ParseIP(command.SourceIP),
|
||||
Port: int(command.SourcePort),
|
||||
}
|
||||
|
||||
udpMessageCommand := &datacommands.UDPProxyData{}
|
||||
udpMessageCommand.ProxyID = proxyStatus.ProxyID
|
||||
|
||||
backend.udpProxies[proxyStatus.ProxyID].portTranslation.WriteFrom = func(ip string, port uint16, data []byte) {
|
||||
udpMessageCommand.ClientIP = ip
|
||||
udpMessageCommand.ClientPort = port
|
||||
udpMessageCommand.DataLength = uint16(len(data))
|
||||
|
||||
marshalledCommand, err := datacommands.Marshal(udpMessageCommand)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("Failed to marshal UDP message header")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := backend.currentSock.Write(marshalledCommand); err != nil {
|
||||
log.Warnf("Failed to write UDP message header")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := backend.currentSock.Write(data); err != nil {
|
||||
log.Warnf("Failed to write UDP message")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(3 * time.Minute)
|
||||
|
||||
// Checks if the proxy still exists before continuing
|
||||
_, ok := backend.udpProxies[proxyStatus.ProxyID]
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Then attempt to run cleanup tasks
|
||||
log.Debug("Running UDP proxy cleanup tasks (invoking CleanupPorts() on portTranslation)")
|
||||
backend.udpProxies[proxyStatus.ProxyID].portTranslation.CleanupPorts()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
@ -474,17 +525,20 @@ func (backend *SSHAppBackend) StopProxy(command *commonbackend.RemoveProxy) (boo
|
|||
return true, fmt.Errorf("failed to stop proxy: still running")
|
||||
}
|
||||
|
||||
// TODO: finish code for UDP
|
||||
proxy.portTranslation.StopAllPorts()
|
||||
delete(backend.udpProxies, proxyIndex)
|
||||
}
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("could not find the proxy")
|
||||
}
|
||||
|
||||
// TODO: implement!
|
||||
func (backend *SSHAppBackend) GetAllClientConnections() []*commonbackend.ProxyClientConnection {
|
||||
return []*commonbackend.ProxyClientConnection{}
|
||||
}
|
||||
|
||||
// We don't have any parameter limitations, so we should be good.
|
||||
func (backend *SSHAppBackend) CheckParametersForConnections(clientParameters *commonbackend.CheckClientParameters) *commonbackend.CheckParametersResponse {
|
||||
return &commonbackend.CheckParametersResponse{
|
||||
IsValid: true,
|
||||
|
@ -617,7 +671,17 @@ func (backend *SSHAppBackend) HandleTCPMessage(message *datacommands.TCPProxyDat
|
|||
connection.Write(data)
|
||||
}
|
||||
|
||||
func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) {}
|
||||
func (backend *SSHAppBackend) HandleUDPMessage(message *datacommands.UDPProxyData, data []byte) {
|
||||
proxy, ok := backend.udpProxies[message.ProxyID]
|
||||
|
||||
if !ok {
|
||||
log.Warn("Could not find UDP proxy")
|
||||
}
|
||||
|
||||
if _, err := proxy.portTranslation.WriteTo(message.ClientIP, message.ClientPort, data); err != nil {
|
||||
log.Warnf("Failed to write to UDP: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (backend *SSHAppBackend) SendNonCriticalMessage(iface interface{}) (interface{}, error) {
|
||||
if backend.currentSock == nil {
|
||||
|
|
112
backend/sshappbackend/local-code/porttranslation/translation.go
Normal file
112
backend/sshappbackend/local-code/porttranslation/translation.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package porttranslation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type connectionData struct {
|
||||
udpConn *net.UDPConn
|
||||
buf []byte
|
||||
hasBeenAliveFor time.Time
|
||||
}
|
||||
|
||||
type PortTranslation struct {
|
||||
UDPAddr *net.UDPAddr
|
||||
WriteFrom func(ip string, port uint16, data []byte)
|
||||
|
||||
newConnectionLock sync.Mutex
|
||||
connections map[string]map[uint16]*connectionData
|
||||
}
|
||||
|
||||
func (translation *PortTranslation) CleanupPorts() {
|
||||
if translation.connections == nil {
|
||||
translation.connections = map[string]map[uint16]*connectionData{}
|
||||
return
|
||||
}
|
||||
|
||||
for connectionIPIndex, connectionPorts := range translation.connections {
|
||||
anyAreAlive := false
|
||||
|
||||
for connectionPortIndex, connectionData := range connectionPorts {
|
||||
if time.Now().Before(connectionData.hasBeenAliveFor.Add(3 * time.Minute)) {
|
||||
anyAreAlive = true
|
||||
continue
|
||||
}
|
||||
|
||||
connectionData.udpConn.Close()
|
||||
delete(connectionPorts, connectionPortIndex)
|
||||
}
|
||||
|
||||
if !anyAreAlive {
|
||||
delete(translation.connections, connectionIPIndex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (translation *PortTranslation) StopAllPorts() {
|
||||
if translation.connections == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for connectionIPIndex, connectionPorts := range translation.connections {
|
||||
for connectionPortIndex, connectionData := range connectionPorts {
|
||||
connectionData.udpConn.Close()
|
||||
delete(connectionPorts, connectionPortIndex)
|
||||
}
|
||||
|
||||
delete(translation.connections, connectionIPIndex)
|
||||
}
|
||||
|
||||
translation.connections = nil
|
||||
}
|
||||
|
||||
func (translation *PortTranslation) WriteTo(ip string, port uint16, data []byte) (int, error) {
|
||||
if translation.connections == nil {
|
||||
translation.connections = map[string]map[uint16]*connectionData{}
|
||||
}
|
||||
|
||||
connectionPortData, ok := translation.connections[ip]
|
||||
|
||||
if !ok {
|
||||
translation.connections[ip] = map[uint16]*connectionData{}
|
||||
connectionPortData = translation.connections[ip]
|
||||
}
|
||||
|
||||
connectionStruct, ok := connectionPortData[port]
|
||||
|
||||
if !ok {
|
||||
connectionPortData[port] = &connectionData{}
|
||||
connectionStruct = connectionPortData[port]
|
||||
|
||||
udpConn, err := net.DialUDP("udp", nil, translation.UDPAddr)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to initialize UDP socket: %s", err.Error())
|
||||
}
|
||||
|
||||
connectionStruct.udpConn = udpConn
|
||||
connectionStruct.buf = make([]byte, 65535)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
n, err := udpConn.Read(connectionStruct.buf)
|
||||
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
delete(connectionPortData, port)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
connectionStruct.hasBeenAliveFor = time.Now()
|
||||
translation.WriteFrom(ip, port, connectionStruct.buf[:n])
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
connectionStruct.hasBeenAliveFor = time.Now()
|
||||
return connectionStruct.udpConn.Write(data)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue