Skip to content

Commit 7e5f109

Browse files
committed
Decouple request ID middleware from logging middleware
1 parent 535e2a9 commit 7e5f109

File tree

9 files changed

+155
-124
lines changed

9 files changed

+155
-124
lines changed

authority/provisioner/webhook.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"time"
1616

1717
"github.com/pkg/errors"
18-
"github.com/smallstep/certificates/logging"
18+
"github.com/smallstep/certificates/internal/requestid"
1919
"github.com/smallstep/certificates/templates"
2020
"github.com/smallstep/certificates/webhook"
2121
"go.step.sm/linkedca"
@@ -171,9 +171,8 @@ retry:
171171
return nil, err
172172
}
173173

174-
requestID, ok := logging.GetRequestID(ctx)
175-
if ok {
176-
req.Header.Set("X-Request-ID", requestID)
174+
if requestID, ok := requestid.FromContext(ctx); ok {
175+
req.Header.Set("X-Request-Id", requestID)
177176
}
178177

179178
secret, err := base64.StdEncoding.DecodeString(w.Secret)

authority/provisioner/webhook_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
"testing"
1818
"time"
1919

20-
"github.com/smallstep/certificates/logging"
20+
"github.com/smallstep/certificates/internal/requestid"
2121
"github.com/smallstep/certificates/webhook"
2222
"github.com/stretchr/testify/assert"
2323
"github.com/stretchr/testify/require"
@@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
101101
}
102102
}
103103

104-
// withRequestID is a helper that calls into [logging.WithRequestID] and returns
105-
// a new context with the requestID added to the provided context.
104+
// withRequestID is a helper that calls into [requestid.NewContext] and returns
105+
// a new context with the requestID added.
106106
func withRequestID(ctx context.Context, requestID string) context.Context {
107-
return logging.WithRequestID(ctx, requestID)
107+
return requestid.NewContext(ctx, requestID)
108108
}
109109

110110
func TestWebhookController_Enrich(t *testing.T) {

ca/ca.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/smallstep/certificates/cas/apiv1"
3030
"github.com/smallstep/certificates/db"
3131
"github.com/smallstep/certificates/internal/metrix"
32+
"github.com/smallstep/certificates/internal/requestid"
3233
"github.com/smallstep/certificates/logging"
3334
"github.com/smallstep/certificates/monitoring"
3435
"github.com/smallstep/certificates/scep"
@@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
329330
}
330331

331332
// Add logger if configured
333+
var legacyTraceHeader string
332334
if len(cfg.Logger) > 0 {
333335
logger, err := logging.New("ca", cfg.Logger)
334336
if err != nil {
335337
return nil, err
336338
}
339+
legacyTraceHeader = logger.GetTraceHeader()
337340
handler = logger.Middleware(handler)
338341
insecureHandler = logger.Middleware(insecureHandler)
339342
}
340343

344+
// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
345+
handler = requestid.New(legacyTraceHeader).Middleware(handler)
346+
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)
347+
341348
// Create context with all the necessary values.
342349
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
343350

errs/errors_test.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package errs
22

33
import (
44
"fmt"
5-
"reflect"
65
"testing"
6+
7+
"github.com/stretchr/testify/assert"
78
)
89

910
func TestError_MarshalJSON(t *testing.T) {
@@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) {
2728
Err: tt.fields.Err,
2829
}
2930
got, err := e.MarshalJSON()
30-
if (err != nil) != tt.wantErr {
31-
t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
31+
if tt.wantErr {
32+
assert.Error(t, err)
33+
assert.Empty(t, got)
3234
return
3335
}
34-
if !reflect.DeepEqual(got, tt.want) {
35-
t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want)
36-
}
36+
37+
assert.NoError(t, err)
38+
assert.Equal(t, tt.want, got)
3739
})
3840
}
3941
}
@@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) {
5456
for _, tt := range tests {
5557
t.Run(tt.name, func(t *testing.T) {
5658
e := new(Error)
57-
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
58-
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
59-
}
60-
//nolint:govet // best option
61-
if !reflect.DeepEqual(tt.expected, e) {
62-
t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e)
59+
err := e.UnmarshalJSON(tt.args.data)
60+
if tt.wantErr {
61+
assert.Error(t, err)
62+
return
6363
}
64+
65+
assert.NoError(t, err)
66+
assert.Equal(t, tt.expected, e)
6467
})
6568
}
6669
}

internal/requestid/requestid.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package requestid
2+
3+
import (
4+
"context"
5+
"net/http"
6+
7+
"github.com/rs/xid"
8+
)
9+
10+
const (
11+
// requestIDHeader is the header name used for propagating request IDs. If
12+
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
13+
// header. It'll always be used in response and set to the request ID.
14+
requestIDHeader = "X-Request-Id"
15+
16+
// defaultTraceHeader is the default Smallstep tracing header that's currently
17+
// in use. It is used as a fallback to retrieve a request ID from, if the
18+
// "X-Request-Id" request header is not set.
19+
defaultTraceHeader = "X-Smallstep-Id"
20+
)
21+
22+
type Handler struct {
23+
legacyTraceHeader string
24+
}
25+
26+
// New creates a new request ID [handler]. It takes a trace header,
27+
// which is used keep the legacy behavior intact, which relies on the
28+
// X-Smallstep-Id header instead of X-Request-Id.
29+
func New(legacyTraceHeader string) *Handler {
30+
if legacyTraceHeader == "" {
31+
legacyTraceHeader = defaultTraceHeader
32+
}
33+
34+
return &Handler{legacyTraceHeader: legacyTraceHeader}
35+
}
36+
37+
// Middleware wraps an [http.Handler] with request ID extraction
38+
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
39+
// header if not set. If both are not set, a new request ID is generated.
40+
// In all cases, the request ID is added to the request context, and
41+
// set to be reflected in the response.
42+
func (h *Handler) Middleware(next http.Handler) http.Handler {
43+
fn := func(w http.ResponseWriter, req *http.Request) {
44+
requestID := req.Header.Get(requestIDHeader)
45+
if requestID == "" {
46+
requestID = req.Header.Get(h.legacyTraceHeader)
47+
}
48+
49+
if requestID == "" {
50+
requestID = newRequestID()
51+
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
52+
}
53+
54+
// immediately set the request ID to be reflected in the response
55+
w.Header().Set(requestIDHeader, requestID)
56+
57+
// continue down the handler chain
58+
ctx := NewContext(req.Context(), requestID)
59+
next.ServeHTTP(w, req.WithContext(ctx))
60+
}
61+
return http.HandlerFunc(fn)
62+
}
63+
64+
// newRequestID creates a new request ID using github.com/rs/xid.
65+
func newRequestID() string {
66+
return xid.New().String()
67+
}
68+
69+
type requestIDKey struct{}
70+
71+
// NewContext returns a new context with the given request ID added to the
72+
// context.
73+
func NewContext(ctx context.Context, requestID string) context.Context {
74+
return context.WithValue(ctx, requestIDKey{}, requestID)
75+
}
76+
77+
// FromContext returns the request ID from the context if it exists and
78+
// is not the empty value.
79+
func FromContext(ctx context.Context) (string, bool) {
80+
v, ok := ctx.Value(requestIDKey{}).(string)
81+
return v, ok && v != ""
82+
}
Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package logging
1+
package requestid
22

33
import (
44
"net/http"
@@ -10,33 +10,33 @@ import (
1010
)
1111

1212
func newRequest(t *testing.T) *http.Request {
13+
t.Helper()
1314
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
1415
require.NoError(t, err)
1516
return r
1617
}
1718

18-
func TestRequestID(t *testing.T) {
19+
func Test_Middleware(t *testing.T) {
1920
requestWithID := newRequest(t)
2021
requestWithID.Header.Set("X-Request-Id", "reqID")
2122
requestWithoutID := newRequest(t)
2223
requestWithEmptyHeader := newRequest(t)
2324
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
2425
requestWithSmallstepID := newRequest(t)
2526
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")
26-
2727
tests := []struct {
28-
name string
29-
headerName string
30-
handler http.HandlerFunc
31-
req *http.Request
28+
name string
29+
traceHeader string
30+
next http.HandlerFunc
31+
req *http.Request
3232
}{
3333
{
34-
name: "default-request-id",
35-
headerName: defaultTraceIDHeader,
36-
handler: func(w http.ResponseWriter, r *http.Request) {
34+
name: "default-request-id",
35+
traceHeader: defaultTraceHeader,
36+
next: func(w http.ResponseWriter, r *http.Request) {
3737
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
3838
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
39-
reqID, ok := GetRequestID(r.Context())
39+
reqID, ok := FromContext(r.Context())
4040
if assert.True(t, ok) {
4141
assert.Equal(t, "reqID", reqID)
4242
}
@@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) {
4545
req: requestWithID,
4646
},
4747
{
48-
name: "no-request-id",
49-
headerName: "X-Request-Id",
50-
handler: func(w http.ResponseWriter, r *http.Request) {
48+
name: "no-request-id",
49+
traceHeader: "X-Request-Id",
50+
next: func(w http.ResponseWriter, r *http.Request) {
5151
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
5252
value := r.Header.Get("X-Request-Id")
5353
assert.NotEmpty(t, value)
54-
reqID, ok := GetRequestID(r.Context())
54+
reqID, ok := FromContext(r.Context())
5555
if assert.True(t, ok) {
5656
assert.Equal(t, value, reqID)
5757
}
@@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) {
6060
req: requestWithoutID,
6161
},
6262
{
63-
name: "empty-header-name",
64-
headerName: "",
65-
handler: func(w http.ResponseWriter, r *http.Request) {
63+
name: "empty-header",
64+
traceHeader: "",
65+
next: func(w http.ResponseWriter, r *http.Request) {
6666
assert.Empty(t, r.Header.Get("X-Request-Id"))
6767
value := r.Header.Get("X-Smallstep-Id")
6868
assert.NotEmpty(t, value)
69-
reqID, ok := GetRequestID(r.Context())
69+
reqID, ok := FromContext(r.Context())
7070
if assert.True(t, ok) {
7171
assert.Equal(t, value, reqID)
7272
}
@@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) {
7575
req: requestWithEmptyHeader,
7676
},
7777
{
78-
name: "fallback-header-name",
79-
headerName: defaultTraceIDHeader,
80-
handler: func(w http.ResponseWriter, r *http.Request) {
78+
name: "fallback-header-name",
79+
traceHeader: defaultTraceHeader,
80+
next: func(w http.ResponseWriter, r *http.Request) {
8181
assert.Empty(t, r.Header.Get("X-Request-Id"))
8282
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
83-
reqID, ok := GetRequestID(r.Context())
83+
reqID, ok := FromContext(r.Context())
8484
if assert.True(t, ok) {
8585
assert.Equal(t, "smallstepID", reqID)
8686
}
@@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) {
9191
}
9292
for _, tt := range tests {
9393
t.Run(tt.name, func(t *testing.T) {
94-
h := RequestID(tt.headerName)
95-
h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req)
94+
handler := New(tt.traceHeader).Middleware(tt.next)
95+
96+
w := httptest.NewRecorder()
97+
handler.ServeHTTP(w, tt.req)
98+
assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
9699
})
97100
}
98101
}

0 commit comments

Comments
 (0)