diff --git a/backend/sshbackend/main.go b/backend/sshbackend/main.go index baf1994..ffdf546 100644 --- a/backend/sshbackend/main.go +++ b/backend/sshbackend/main.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os" + "slices" "strconv" "strings" "sync" @@ -18,6 +19,32 @@ import ( "golang.org/x/crypto/ssh" ) +type ConnWithTimeout struct { + net.Conn + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (c *ConnWithTimeout) Read(b []byte) (int, error) { + err := c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Read(b) +} + +func (c *ConnWithTimeout) Write(b []byte) (int, error) { + err := c.Conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + + if err != nil { + return 0, err + } + + return c.Conn.Write(b) +} + type SSHListener struct { SourceIP string SourcePort uint16 @@ -76,16 +103,34 @@ func (backend *SSHBackend) StartBackend(bytes []byte) (bool, error) { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backendData.IP, backendData.Port), config) + addr := fmt.Sprintf("%s:%d", backendData.IP, backendData.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return false, err } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + return false, err + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + go backend.backendDisconnectHandler() + go backend.backendKeepaliveHandler() log.Info("SSHBackend has initialized successfully.") - go backend.backendDisconnectHandler() return true, nil } @@ -203,8 +248,7 @@ func (backend *SSHBackend) StartProxy(command *commonbackend.AddProxy) (bool, er // Splice out the clientInstance by clientIndex // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.clients = append(backend.clients[:clientIndex], backend.clients[clientIndex+1:]...) + backend.clients = slices.Delete(backend.clients, clientIndex, clientIndex+1) return } } @@ -288,10 +332,8 @@ func (backend *SSHBackend) StopProxy(command *commonbackend.RemoveProxy) (bool, } // Splice out the proxy instance by proxyIndex - // TODO: change approach. It works but it's a bit wonky imho - // I asked AI to do this as it's a relatively simple task and I forgot how to do this effectively - backend.proxies = append(backend.proxies[:proxyIndex], backend.proxies[proxyIndex+1:]...) + backend.proxies = slices.Delete(backend.proxies, proxyIndex, proxyIndex+1) return true, nil } } @@ -341,17 +383,31 @@ func (backend *SSHBackend) CheckParametersForBackend(arguments []byte) *commonba } } -func (backend *SSHBackend) backendDisconnectHandler() { +func (backend *SSHBackend) backendKeepaliveHandler() { for { if backend.conn != nil { - err := backend.conn.Wait() + _, _, err := backend.conn.SendRequest("keepalive@openssh.com", true, nil) - if err == nil || err.Error() != "EOF" { - continue + if err != nil { + log.Warn("Keepalive message failed!") + return } } - log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + time.Sleep(5 * time.Second) + } +} + +func (backend *SSHBackend) backendDisconnectHandler() { + for { + if backend.conn != nil { + backend.conn.Wait() + backend.conn.Close() + + log.Info("Disconnected from the remote SSH server. Attempting to reconnect in 5 seconds...") + } else { + log.Info("Retrying reconnection in 5 seconds...") + } time.Sleep(5 * time.Second) @@ -376,14 +432,34 @@ func (backend *SSHBackend) backendDisconnectHandler() { }, } - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port), config) + addr := fmt.Sprintf("%s:%d", backend.config.IP, backend.config.Port) + timeout := time.Duration(10 * time.Second) + + rawTCPConn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { - log.Errorf("Failed to connect to the server: %s", err.Error()) - return + log.Errorf("Failed to establish connection to the server: %s", err.Error()) + continue } - backend.conn = conn + connWithTimeout := &ConnWithTimeout{ + Conn: rawTCPConn, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + c, chans, reqs, err := ssh.NewClientConn(connWithTimeout, addr, config) + + if err != nil { + log.Errorf("Failed to create SSH client connection: %s", err.Error()) + rawTCPConn.Close() + continue + } + + client := ssh.NewClient(c, chans, reqs) + backend.conn = client + + go backend.backendKeepaliveHandler() log.Info("SSHBackend has reconnected successfully. Attempting to set up proxies again...")