// Hostess is a small reverse proxy: it accepts HTTP and WebSocket requests on
// a local address and replays them against a fixed upstream address while
// presenting a fixed Host header.
package main

import (
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"

	"github.com/gorilla/websocket"
)

// isLoggingEnabled toggles the per-request access log written by logRequest.
var isLoggingEnabled bool

// logRequest writes one access-log line (method, path, response status) when
// isLoggingEnabled is set; otherwise it does nothing.
func logRequest(method, path string, respStatus int) {
	if isLoggingEnabled {
		log.Printf("%s %s - %d", method, path, respStatus)
	}
}

// Labels used in log messages to tell the two WebSocket legs apart.
const (
	serverSideName = "server's (us)"
	clientSideName = "client's (remote server)"
)

// proxy bundles the upstream target and the clients used to reach it.
type proxy struct {
	sourceIP   string // upstream base URL, e.g. "http://192.168.2.10"
	hostHeader string // Host presented to the upstream
	client     *http.Client
	upgrader   *websocket.Upgrader
	dialer     *websocket.Dialer
}

// newProxy builds a proxy that forwards to sourceIP while presenting
// hostHeader as the Host.
func newProxy(sourceIP, hostHeader string) *proxy {
	return &proxy{
		sourceIP:   sourceIP,
		hostHeader: hostHeader,
		client:     &http.Client{},
		upgrader:   &websocket.Upgrader{},
		dialer: &websocket.Dialer{
			// The address gorilla asks for (derived from the ws:// URL, i.e.
			// hostHeader) is deliberately ignored: the TCP connection always
			// goes to sourceIP.
			NetDial: func(network, _ string) (net.Conn, error) {
				spoofedAddr := upstreamDialAddr(sourceIP)
				fmt.Printf("spoofed address (in dialer): %s\n", spoofedAddr)
				return net.Dial(network, spoofedAddr)
			},
		},
	}
}

// upstreamDialAddr turns the upstream base URL into a host:port dial address.
// Everything after the last "/" is taken as the host; when it carries no port,
// the default port for the URL's scheme (80 for http, 443 for https) is added.
func upstreamDialAddr(sourceIP string) string {
	addr := sourceIP[strings.LastIndex(sourceIP, "/")+1:]
	if strings.Contains(addr, ":") {
		return addr
	}
	switch {
	case strings.HasPrefix(sourceIP, "http://"):
		addr += ":80"
	case strings.HasPrefix(sourceIP, "https://"):
		addr += ":443"
	}
	return addr
}

// upstreamWebSocketURL builds the ws:// or wss:// URL to dial. The scheme
// mirrors sourceIP's scheme, the host is hostHeader, and requestURI (path plus
// query) is appended so the upstream sees the endpoint the client asked for.
func upstreamWebSocketURL(sourceIP, hostHeader, requestURI string) string {
	scheme := ""
	switch {
	case strings.HasPrefix(sourceIP, "http://"):
		scheme = "ws://"
	case strings.HasPrefix(sourceIP, "https://"):
		scheme = "wss://"
	}
	return scheme + hostHeader + requestURI
}

// isHandshakeHeader reports whether key is a header that belongs to a single
// WebSocket handshake. The Upgrader generates its own values for these, so the
// upstream's copies must not be replayed to our client.
func isHandshakeHeader(key string) bool {
	switch http.CanonicalHeaderKey(key) {
	case "Upgrade", "Connection", "Sec-Websocket-Accept", "Sec-Websocket-Extensions":
		return true
	}
	return false
}

// relayMessages copies messages from src to dst until a read or write fails,
// then closes dst. Closing dst makes the opposite relay's read fail, which in
// turn closes src, so both legs are torn down.
func relayMessages(src, dst *websocket.Conn, srcName, dstName, path string) {
	for {
		messageType, message, err := src.ReadMessage()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				log.Printf("failed to read %s WebSocket for '%s': %s", srcName, path, err.Error())
			}
			break
		}
		if err := dst.WriteMessage(messageType, message); err != nil {
			log.Printf("failed to write %s WebSocket for '%s': %s", dstName, path, err.Error())
			break
		}
	}
	if err := dst.Close(); err != nil {
		log.Printf("failed to close %s WebSocket for '%s': %s", dstName, path, err.Error())
	}
}

// handle dispatches a request to the WebSocket or the plain HTTP path.
func (p *proxy) handle(w http.ResponseWriter, r *http.Request) {
	if websocket.IsWebSocketUpgrade(r) {
		p.serveWebSocket(w, r)
		return
	}
	p.serveHTTP(w, r)
}

// serveWebSocket dials the upstream WebSocket, upgrades the incoming request
// and relays messages in both directions.
//
// TODO: Currently incomplete
func (p *proxy) serveWebSocket(w http.ResponseWriter, r *http.Request) {
	fmt.Println("detected WebSocket upgrade request. doing initialization tasks...")

	// Copied out so the relay goroutines do not hold on to the request after
	// this handler has returned.
	path := r.URL.Path
	remoteAddr := upstreamWebSocketURL(p.sourceIP, p.hostHeader, r.URL.RequestURI())

	// I'd like to preserve more headers, but Cookies is the only easy thing to implement...
	newClientHeaders := make(http.Header)
	if cookie := r.Header.Get("Cookie"); cookie != "" {
		newClientHeaders.Set("Cookie", cookie)
	}

	fmt.Printf("spoofed address (in handler): %s\n", remoteAddr)
	wsClient, resp, err := p.dialer.Dial(remoteAddr, newClientHeaders)
	if err != nil {
		log.Printf("failed to initialize client's (remote server) WebSocket for '%s': %s", path, err.Error())
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		logRequest(r.Method, path, http.StatusInternalServerError)
		return
	}

	// Forward the upstream's handshake response headers (e.g. Set-Cookie), but
	// not the ones the Upgrader writes itself.
	responseHeader := make(http.Header)
	for key, values := range resp.Header {
		if isHandshakeHeader(key) {
			continue
		}
		for _, value := range values {
			responseHeader.Add(key, value)
		}
	}

	wsServer, err := p.upgrader.Upgrade(w, r, responseHeader)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// nothing more may be written to w here.
		log.Printf("failed to initialize server's (us) WebSocket for '%s': %s", path, err.Error())
		if closeErr := wsClient.Close(); closeErr != nil {
			log.Printf("failed to close client's (remote server) WebSocket for '%s': %s", path, closeErr.Error())
		}
		logRequest(r.Method, path, http.StatusInternalServerError)
		return
	}

	go relayMessages(wsServer, wsClient, serverSideName, clientSideName, path)
	go relayMessages(wsClient, wsServer, clientSideName, serverSideName, path)

	logRequest(r.Method, path, http.StatusSwitchingProtocols)
}

// serveHTTP replays a plain HTTP request against the upstream and streams the
// response back to the caller.
func (p *proxy) serveHTTP(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path

	// RequestURI keeps the original escaping and the query string. The request
	// context ties the upstream call to the lifetime of the incoming request.
	req, err := http.NewRequestWithContext(r.Context(), r.Method, p.sourceIP+r.URL.RequestURI(), r.Body)
	if err != nil {
		log.Printf("failed to construct request for '%s': %s", path, err.Error())
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		logRequest(r.Method, path, http.StatusInternalServerError)
		return
	}

	req.Host = p.hostHeader
	// NewRequest cannot infer the length of r.Body; pass the known length on so
	// the body is not needlessly re-sent with chunked encoding.
	req.ContentLength = r.ContentLength

	for key, values := range r.Header {
		if key == "Host" {
			continue
		}
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	resp, err := p.client.Do(req)
	if err != nil {
		log.Printf("failed to handle response for '%s': %s", path, err.Error())
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		logRequest(r.Method, path, http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	for key, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}

	/*
		if resp.StatusCode >= 300 && resp.StatusCode <= 399 {
			existingLocationHeaderUnparsed := resp.Header.Get("Location")

			if existingLocationHeaderUnparsed != "" {
				w.Header().Set("Location", existingLocationHeaderUnparsed)
			}
		}
	*/

	w.WriteHeader(resp.StatusCode)

	// io.Copy streams until EOF, so it works whether or not the upstream sent a
	// Content-Length, and it never drops bytes delivered together with EOF.
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("failed to either read or finish reading response for '%s': %s", path, err.Error())
	}

	logRequest(r.Method, path, resp.StatusCode)
}

// main wires up the proxy with its fixed configuration and serves it.
func main() {
	listenAddr := ":8080"
	sourceIP := "http://192.168.2.10"
	hostHeader := "immich.hofers.cloud"

	isLoggingEnabled = true

	http.HandleFunc("/", newProxy(sourceIP, hostHeader).handle)

	log.Printf("Hostess is listening on %s", listenAddr)
	log.Fatal(http.ListenAndServe(listenAddr, nil))
}