Files
flatrender/services/gateway/internal/ws/ws.go
T

127 lines
3.4 KiB
Go
Raw Normal View History

package ws
import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	mw "github.com/flatrender/gateway/internal/middleware"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)
// upgrader switches incoming HTTP requests to the WebSocket protocol.
//
// It offers the "flatrender.v1" subprotocol to clients that request it during
// the handshake. CheckOrigin approves every Origin header, so pages on any
// site may open a connection; access control relies on the JWT that
// RenderProgressProxy verifies before calling Upgrade.
// NOTE(review): consider an Origin allow-list if tokens could reach
// third-party pages.
var upgrader = websocket.Upgrader{
	Subprotocols: []string{"flatrender.v1"},
	CheckOrigin:  func(*http.Request) bool { return true },
}
// RenderProgressProxy returns a gin handler that bridges a client WebSocket to
// the render service's WebSocket for one render job. Messages are relayed in
// both directions: progress events from upstream, any client messages to
// upstream.
//
// Connection: wss://gateway/ws/v1/render/{job_id}?token={jwt}
//
// Request handling, in order:
//   - job_id must parse as a UUID, otherwise 400 {"code":"bad_request"}.
//   - The caller must present an HMAC-signed JWT whose "sub" claim is a UUID,
//     otherwise 401 {"code":"unauthorized"}. The token comes from the "token"
//     query parameter or an "Authorization: Bearer" header.
//   - The client connection is upgraded. The gateway then dials
//     {renderUpstreamWS}/ws/v1/render/{job_id}?user_id={sub} and forwards the
//     same bearer token in the Authorization header.
//   - If that dial fails, the client receives one JSON message of type
//     "error" (code UPSTREAM_UNAVAILABLE) followed by a 1011 close frame. The
//     client is expected to fall back to REST polling.
//   - Otherwise data frames are piped both ways until either side closes or
//     fails. A close frame is then sent to both ends, relaying the peer's
//     close code where RFC 6455 allows it.
//
// The gateway itself does not check that sub owns job_id.
// NOTE(review): presumably the render service enforces ownership using the
// user_id and token it receives — confirm.
//
// renderUpstreamWS is the base ws:// or wss:// URL of the render service. The
// path is appended directly, so it presumably has no trailing slash.
// jwtSecret is the HMAC key used to verify tokens.
func RenderProgressProxy(renderUpstreamWS string, jwtSecret string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Work with the parsed UUID from here on. uuid.Parse also accepts
		// "{...}" and "urn:uuid:..." forms, which must not be forwarded verbatim.
		jobID, err := uuid.Parse(c.Param("job_id"))
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"code": "bad_request", "message": "invalid job_id"})
			return
		}

		tokenStr := handshakeToken(c)
		if tokenStr == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"code": "unauthorized", "message": "missing token"})
			return
		}
		userID, err := authenticateJWT(tokenStr, jwtSecret)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"code": "unauthorized", "message": "invalid token"})
			return
		}

		clientConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error.
			log.Printf("ws upgrade error: %v", err)
			return
		}
		defer clientConn.Close()

		// jobID and userID format as canonical UUIDs, so nothing here needs
		// URL escaping.
		upstreamURL := fmt.Sprintf("%s/ws/v1/render/%s?user_id=%s", renderUpstreamWS, jobID, userID)
		upstreamConn, _, err := websocket.DefaultDialer.DialContext(c.Request.Context(), upstreamURL, http.Header{
			"Authorization": []string{"Bearer " + tokenStr},
		})
		if err != nil {
			log.Printf("ws upstream dial for job %s: %v", jobID, err)
			// Upstream WS not available: tell the client why, then close.
			// Best effort — the client may already be gone.
			_ = clientConn.WriteJSON(gin.H{
				"type":    "error",
				"code":    "UPSTREAM_UNAVAILABLE",
				"message": "render service WebSocket unavailable; use REST polling fallback",
			})
			_ = clientConn.WriteMessage(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "upstream unavailable"))
			return
		}
		defer upstreamConn.Close()

		// Run one pump per direction. Each pump is the sole reader of its
		// source and the sole data writer of its destination, as
		// gorilla/websocket requires. The channel has room for both results,
		// so the pump that fails second never blocks after this handler has
		// stopped listening.
		errCh := make(chan error, 2)
		go pumpFrames(upstreamConn, clientConn, errCh) // client → upstream
		go pumpFrames(clientConn, upstreamConn, errCh) // upstream → client

		// The first failure ends the session. WriteControl may run alongside
		// the other pump's writes. Its errors are ignored because a peer may
		// already have closed. The deferred Close calls then unblock the
		// remaining pump.
		closeMsg := closePayloadFor(<-errCh)
		deadline := time.Now().Add(time.Second)
		_ = clientConn.WriteControl(websocket.CloseMessage, closeMsg, deadline)
		_ = upstreamConn.WriteControl(websocket.CloseMessage, closeMsg, deadline)
	}
}

// handshakeToken extracts the bearer token presented on a WebSocket handshake.
// The "token" query parameter wins, since browser WebSocket clients cannot
// attach custom headers. Otherwise an "Authorization: Bearer <token>" header
// is used. It returns "" when neither supplies a token.
func handshakeToken(c *gin.Context) string {
	if tok := c.Query("token"); tok != "" {
		return tok
	}
	if hdr := c.GetHeader("Authorization"); strings.HasPrefix(hdr, "Bearer ") {
		return strings.TrimPrefix(hdr, "Bearer ")
	}
	return ""
}

// authenticateJWT verifies tokenStr as an HMAC-signed JWT keyed by secret and
// returns its "sub" claim parsed as a UUID. It returns an error for any of
// these: a parse or signature failure, a non-HMAC algorithm, an invalid
// token, or a missing or non-UUID subject.
func authenticateJWT(tokenStr, secret string) (uuid.UUID, error) {
	token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted, so a token cannot choose another algorithm
		// family for the shared secret.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(secret), nil
	})
	if err != nil {
		return uuid.Nil, fmt.Errorf("parsing token: %w", err)
	}
	if !token.Valid {
		return uuid.Nil, errors.New("token not valid")
	}
	sub, err := token.Claims.GetSubject()
	if err != nil {
		return uuid.Nil, fmt.Errorf("reading sub claim: %w", err)
	}
	userID, err := uuid.Parse(sub)
	if err != nil {
		return uuid.Nil, fmt.Errorf("parsing sub claim %q: %w", sub, err)
	}
	return userID, nil
}

// pumpFrames copies messages from src to dst, preserving each message's type.
// It stops at the first read or write error and sends that error on errCh.
// Control frames (ping/pong/close) are handled by each Conn's own handlers
// and are not copied.
func pumpFrames(dst, src *websocket.Conn, errCh chan<- error) {
	for {
		mt, msg, err := src.ReadMessage()
		if err != nil {
			errCh <- err
			return
		}
		if err := dst.WriteMessage(mt, msg); err != nil {
			errCh <- err
			return
		}
	}
}

// closePayloadFor builds the close-frame payload sent to both peers when the
// session ends with err. A close received from a peer is relayed with its code
// and reason. Codes RFC 6455 forbids on the wire (1006, 1015) and non-close
// failures (I/O errors) become 1001 "going away".
func closePayloadFor(err error) []byte {
	var ce *websocket.CloseError
	if errors.As(err, &ce) &&
		ce.Code != websocket.CloseAbnormalClosure &&
		ce.Code != websocket.CloseTLSHandshake {
		return websocket.FormatCloseMessage(ce.Code, ce.Text)
	}
	return websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
}
// The middleware package is imported as mw, but no other code in this file
// references it. This blank reference stops the compiler from rejecting the
// import as unused.
// NOTE(review): remove it together with the import, or replace it with a real
// use, once the intent is settled.
var (
	_ = mw.CtxUserID
)