Skip to content

Commit 4bdffd2

Browse files
authored
Implemented Origin checks for websockets
1 parent edfc9ab commit 4bdffd2

File tree

3 files changed

+63
-33
lines changed

3 files changed

+63
-33
lines changed

cmd/arduino-app-cli/daemon/daemon.go

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,32 @@ func NewDaemonCmd(cfg config.Configuration, version string) *cobra.Command {
6969
func httpHandler(ctx context.Context, cfg config.Configuration, daemonPort, version string) {
7070
slog.Info("Starting HTTP server", slog.String("address", ":"+daemonPort))
7171

72+
corsConfig := cors.Config{
73+
Origins: []string{
74+
"wails://wails",
75+
"wails://wails.localhost:*",
76+
"http://wails.localhost",
77+
"http://wails.localhost:*",
78+
"http://localhost:*",
79+
"https://localhost:*",
80+
},
81+
Methods: []string{
82+
http.MethodGet,
83+
http.MethodPost,
84+
http.MethodPut,
85+
http.MethodOptions,
86+
http.MethodDelete,
87+
http.MethodPatch,
88+
},
89+
RequestHeaders: []string{
90+
"Accept",
91+
"Authorization",
92+
"Content-Type",
93+
},
94+
MaxAgeInSeconds: 86400,
95+
ResponseHeaders: []string{},
96+
}
97+
7298
apiSrv := api.NewHTTPRouter(
7399
servicelocator.GetDockerClient(),
74100
version,
@@ -83,40 +109,21 @@ func httpHandler(ctx context.Context, cfg config.Configuration, daemonPort, vers
83109
servicelocator.GetBrickService(),
84110
servicelocator.GetAppIDProvider(),
85111
cfg,
112+
corsConfig.Origins,
86113
)
87114

88-
corsMiddlware, err := cors.NewMiddleware(
89-
cors.Config{
90-
Origins: []string{
91-
"wails://wails", "wails://wails.localhost:*",
92-
"http://wails.localhost", "http://wails.localhost:*",
93-
"http://localhost:*", "https://localhost:*",
94-
},
95-
Methods: []string{
96-
http.MethodGet,
97-
http.MethodPost,
98-
http.MethodPut,
99-
http.MethodOptions,
100-
http.MethodDelete,
101-
http.MethodPatch,
102-
},
103-
RequestHeaders: []string{
104-
"Accept",
105-
"Authorization",
106-
"Content-Type",
107-
},
108-
MaxAgeInSeconds: 86400,
109-
ResponseHeaders: []string{},
110-
},
111-
)
115+
// Wrap the API server with CORS middleware
116+
corsMiddlware, err := cors.NewMiddleware(corsConfig)
112117
if err != nil {
113118
panic(err)
114119
}
120+
apiSrv = corsMiddlware.Wrap(apiSrv)
115121

122+
// Start the HTTP server
116123
address := "127.0.0.1:" + daemonPort
117124
httpSrv := http.Server{
118125
Addr: address,
119-
Handler: httprecover.RecoverPanic(corsMiddlware.Wrap(apiSrv)),
126+
Handler: httprecover.RecoverPanic(apiSrv),
120127
ReadHeaderTimeout: 60 * time.Second,
121128
}
122129
go func() {

internal/api/api.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ func NewHTTPRouter(
3333
brickService *bricks.Service,
3434
idProvider *app.IDProvider,
3535
cfg config.Configuration,
36+
allowedOrigins []string,
3637
) http.Handler {
3738
mux := http.NewServeMux()
3839
mux.Handle("GET /debug/", http.DefaultServeMux) // pprof endpoints
@@ -70,7 +71,7 @@ func NewHTTPRouter(
7071

7172
mux.Handle("GET /v1/docs/", http.StripPrefix("/v1/docs/", handlers.DocsServer(docsFS)))
7273

73-
mux.Handle("GET /v1/monitor/ws", handlers.HandleMonitorWS())
74+
mux.Handle("GET /v1/monitor/ws", handlers.HandleMonitorWS(allowedOrigins))
7475

7576
return mux
7677
}

internal/api/handlers/monitor.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log/slog"
88
"net"
99
"net/http"
10+
"strings"
1011
"time"
1112

1213
"github.com/gorilla/websocket"
@@ -15,11 +16,6 @@ import (
1516
"github.com/arduino/arduino-app-cli/pkg/render"
1617
)
1718

18-
var upgrader = websocket.Upgrader{
19-
ReadBufferSize: 1024,
20-
WriteBufferSize: 1024,
21-
}
22-
2319
func monitorStream(mon net.Conn, ws *websocket.Conn) {
2420
logWebsocketError := func(msg string, err error) {
2521
// Do not log simple close or interruption errors
@@ -72,7 +68,32 @@ func monitorStream(mon net.Conn, ws *websocket.Conn) {
7268
}()
7369
}
7470

75-
func HandleMonitorWS() http.HandlerFunc {
71+
func checkOrigin(origin string, allowedOrigins []string) bool {
72+
for _, allowed := range allowedOrigins {
73+
if strings.HasSuffix(allowed, "*") {
74+
// String ends with *, match the prefix
75+
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
76+
return true
77+
}
78+
} else {
79+
// Exact match
80+
if allowed == origin {
81+
return true
82+
}
83+
}
84+
}
85+
return false
86+
}
87+
88+
func HandleMonitorWS(allowedOrigins []string) http.HandlerFunc {
89+
upgrader := websocket.Upgrader{
90+
ReadBufferSize: 1024,
91+
WriteBufferSize: 1024,
92+
CheckOrigin: func(r *http.Request) bool {
93+
return checkOrigin(r.Header.Get("Origin"), allowedOrigins)
94+
},
95+
}
96+
7697
return func(w http.ResponseWriter, r *http.Request) {
7798
// Connect to monitor
7899
mon, err := net.DialTimeout("tcp", "127.0.0.1:7500", time.Second)
@@ -88,7 +109,8 @@ func HandleMonitorWS() http.HandlerFunc {
88109
// Remember to close monitor connection if websocket upgrade fails.
89110
mon.Close()
90111

91-
render.EncodeResponse(w, http.StatusInternalServerError, map[string]string{"error": "Failed to upgrade connection"})
112+
slog.Error("Failed to upgrade connection", slog.String("error", err.Error()))
113+
render.EncodeResponse(w, http.StatusInternalServerError, map[string]string{"error": "Failed to upgrade connection: " + err.Error()})
92114
return
93115
}
94116

0 commit comments

Comments
 (0)