@@ -8,14 +8,19 @@ import (
88
99 "github.com/go-chi/chi/v5"
1010 "github.com/prometheus/client_golang/prometheus"
11+ cm "github.com/prometheus/client_model/go"
12+ "github.com/stretchr/testify/assert"
1113 "github.com/stretchr/testify/require"
1214
1315 "github.com/coder/coder/v2/coderd/httpmw"
1416 "github.com/coder/coder/v2/coderd/tracing"
17+ "github.com/coder/coder/v2/testutil"
18+ "github.com/coder/websocket"
1519)
1620
1721func TestPrometheus (t * testing.T ) {
1822 t .Parallel ()
23+
1924 t .Run ("All" , func (t * testing.T ) {
2025 t .Parallel ()
2126 req := httptest .NewRequest ("GET" , "/" , nil )
@@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) {
2934 require .NoError (t , err )
3035 require .Greater (t , len (metrics ), 0 )
3136 })
37+
38+ t .Run ("Concurrent" , func (t * testing.T ) {
39+ t .Parallel ()
40+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
41+ defer cancel ()
42+
43+ reg := prometheus .NewRegistry ()
44+ promMW := httpmw .Prometheus (reg )
45+
46+ // Create a test handler to simulate a WebSocket connection
47+ testHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
48+ conn , err := websocket .Accept (rw , r , nil )
49+ if ! assert .NoError (t , err , "failed to accept websocket" ) {
50+ return
51+ }
52+ defer conn .Close (websocket .StatusGoingAway , "" )
53+ })
54+
55+ wrappedHandler := promMW (testHandler )
56+
57+ r := chi .NewRouter ()
58+ r .Use (tracing .StatusWriterMiddleware , promMW )
59+ r .Get ("/api/v2/build/{build}/logs" , func (rw http.ResponseWriter , r * http.Request ) {
60+ wrappedHandler .ServeHTTP (rw , r )
61+ })
62+
63+ srv := httptest .NewServer (r )
64+ defer srv .Close ()
65+ // nolint: bodyclose
66+ conn , _ , err := websocket .Dial (ctx , srv .URL + "/api/v2/build/1/logs" , nil )
67+ require .NoError (t , err , "failed to dial WebSocket" )
68+ defer conn .Close (websocket .StatusNormalClosure , "" )
69+
70+ metrics , err := reg .Gather ()
71+ require .NoError (t , err )
72+ require .Greater (t , len (metrics ), 0 )
73+ metricLabels := getMetricLabels (metrics )
74+
75+ concurrentWebsockets , ok := metricLabels ["coderd_api_concurrent_websockets" ]
76+ require .True (t , ok , "coderd_api_concurrent_websockets metric not found" )
77+ require .Equal (t , "/api/v2/build/{build}/logs" , concurrentWebsockets ["path" ])
78+ })
79+
80+ t .Run ("UserRoute" , func (t * testing.T ) {
81+ t .Parallel ()
82+ reg := prometheus .NewRegistry ()
83+ promMW := httpmw .Prometheus (reg )
84+
85+ r := chi .NewRouter ()
86+ r .With (promMW ).Get ("/api/v2/users/{user}" , func (w http.ResponseWriter , r * http.Request ) {})
87+
88+ req := httptest .NewRequest ("GET" , "/api/v2/users/john" , nil )
89+
90+ sw := & tracing.StatusWriter {ResponseWriter : httptest .NewRecorder ()}
91+
92+ r .ServeHTTP (sw , req )
93+
94+ metrics , err := reg .Gather ()
95+ require .NoError (t , err )
96+ require .Greater (t , len (metrics ), 0 )
97+ metricLabels := getMetricLabels (metrics )
98+
99+ reqProcessed , ok := metricLabels ["coderd_api_requests_processed_total" ]
100+ require .True (t , ok , "coderd_api_requests_processed_total metric not found" )
101+ require .Equal (t , "/api/v2/users/{user}" , reqProcessed ["path" ])
102+ require .Equal (t , "GET" , reqProcessed ["method" ])
103+
104+ concurrentRequests , ok := metricLabels ["coderd_api_concurrent_requests" ]
105+ require .True (t , ok , "coderd_api_concurrent_requests metric not found" )
106+ require .Equal (t , "/api/v2/users/{user}" , concurrentRequests ["path" ])
107+ require .Equal (t , "GET" , concurrentRequests ["method" ])
108+ })
109+ }
110+
111+ func getMetricLabels (metrics []* cm.MetricFamily ) map [string ]map [string ]string {
112+ metricLabels := map [string ]map [string ]string {}
113+ for _ , metricFamily := range metrics {
114+ metricName := metricFamily .GetName ()
115+ metricLabels [metricName ] = map [string ]string {}
116+ for _ , metric := range metricFamily .GetMetric () {
117+ for _ , labelPair := range metric .GetLabel () {
118+ metricLabels [metricName ][labelPair .GetName ()] = labelPair .GetValue ()
119+ }
120+ }
121+ }
122+ return metricLabels
32123}
0 commit comments