Skip to content

Commit def9438

Browse files
committed
Improve handling of bad JSON protobuf bodies
1 parent 2ca5c01 commit def9438

File tree

4 files changed

+90
-77
lines changed

4 files changed

+90
-77
lines changed

api/read/read.go

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"google.golang.org/protobuf/encoding/protojson"
1111
"google.golang.org/protobuf/proto"
1212

13-
"github.com/smallstep/certificates/api/render"
1413
"github.com/smallstep/certificates/errs"
1514
)
1615

@@ -24,62 +23,55 @@ func JSON(r io.Reader, v interface{}) error {
2423
}
2524

2625
// ProtoJSON reads JSON from the request body and stores it in the value
27-
// pointed by v.
26+
// pointed to by v.
2827
func ProtoJSON(r io.Reader, m proto.Message) error {
2928
data, err := io.ReadAll(r)
3029
if err != nil {
3130
return errs.BadRequestErr(err, "error reading request body")
3231
}
33-
return protojson.Unmarshal(data, m)
34-
}
35-
36-
// ProtoJSONWithCheck reads JSON from the request body and stores it in the value
37-
// pointed to by m. Returns false if an error was written; true if not.
38-
// TODO(hs): refactor this after the API flow changes are in (or before if that works)
39-
func ProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
40-
data, err := io.ReadAll(r)
41-
if err != nil {
42-
var wrapper = struct {
43-
Status int `json:"code"`
44-
Message string `json:"message"`
45-
}{
46-
Status: http.StatusBadRequest,
47-
Message: err.Error(),
48-
}
49-
errData, err := json.Marshal(wrapper)
50-
if err != nil {
51-
panic(err)
52-
}
53-
w.Header().Set("Content-Type", "application/json")
54-
w.WriteHeader(http.StatusBadRequest)
55-
w.Write(errData)
56-
return false
57-
}
5832
if err := protojson.Unmarshal(data, m); err != nil {
5933
if errors.Is(err, proto.Error) {
60-
var wrapper = struct {
61-
Type string `json:"type"`
62-
Detail string `json:"detail"`
63-
Message string `json:"message"`
64-
}{
65-
Type: "badRequest",
66-
Detail: "bad request",
67-
Message: err.Error(),
68-
}
69-
errData, err := json.Marshal(wrapper)
70-
if err != nil {
71-
panic(err)
72-
}
73-
w.Header().Set("Content-Type", "application/json")
74-
w.WriteHeader(http.StatusBadRequest)
75-
w.Write(errData)
76-
return false
34+
return newBadProtoJSONError(err)
7735
}
36+
}
37+
return err
38+
}
7839

79-
// fallback to the default error writer
80-
render.Error(w, err)
81-
return false
40+
// BadProtoJSONError is an error type that is used when a proto
41+
// message cannot be unmarshaled. Usually this is caused by an error
42+
// in the request body.
43+
type BadProtoJSONError struct {
44+
err error
45+
Type string `json:"type"`
46+
Detail string `json:"detail"`
47+
Message string `json:"message"`
48+
}
49+
50+
// newBadProtoJSONError returns a new instance of BadProtoJSONError
51+
// This error type is always caused by an error in the request body.
52+
func newBadProtoJSONError(err error) *BadProtoJSONError {
53+
return &BadProtoJSONError{
54+
err: err,
55+
Type: "badRequest",
56+
Detail: "bad request",
57+
Message: err.Error(),
58+
}
59+
}
60+
61+
// Error implements the error interface
62+
func (e *BadProtoJSONError) Error() string {
63+
return e.err.Error()
64+
}
65+
66+
// Render implements render.RenderableError for BadProtoError
67+
func (e *BadProtoJSONError) Render(w http.ResponseWriter) {
68+
69+
errData, err := json.Marshal(e)
70+
if err != nil {
71+
panic(err)
8272
}
8373

84-
return true
74+
w.Header().Set("Content-Type", "application/json")
75+
w.WriteHeader(http.StatusBadRequest)
76+
w.Write(errData)
8577
}

authority/admin/api/policy.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
8080
}
8181

8282
var newPolicy = new(linkedca.Policy)
83-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
83+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
84+
render.Error(w, err)
8485
return
8586
}
8687

@@ -120,7 +121,8 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
120121
}
121122

122123
var newPolicy = new(linkedca.Policy)
123-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
124+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
125+
render.Error(w, err)
124126
return
125127
}
126128

@@ -195,7 +197,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
195197
}
196198

197199
var newPolicy = new(linkedca.Policy)
198-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
200+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
201+
render.Error(w, err)
199202
return
200203
}
201204

@@ -228,7 +231,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
228231
}
229232

230233
var newPolicy = new(linkedca.Policy)
231-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
234+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
235+
render.Error(w, err)
232236
return
233237
}
234238

@@ -297,7 +301,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
297301
}
298302

299303
var newPolicy = new(linkedca.Policy)
300-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
304+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
305+
render.Error(w, err)
301306
return
302307
}
303308

@@ -324,7 +329,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
324329
}
325330

326331
var newPolicy = new(linkedca.Policy)
327-
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
332+
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
333+
render.Error(w, err)
328334
return
329335
}
330336

authority/admin/api/policy_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
167167
statusCode: 409,
168168
}
169169
},
170-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
170+
"fail/read.ProtoJSON": func(t *testing.T) test {
171171
ctx := context.Background()
172172
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
173173
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
@@ -410,7 +410,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
410410
statusCode: 404,
411411
}
412412
},
413-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
413+
"fail/read.ProtoJSON": func(t *testing.T) test {
414414
policy := &linkedca.Policy{
415415
X509: &linkedca.X509Policy{
416416
Allow: &linkedca.X509Names{
@@ -871,7 +871,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
871871
statusCode: 409,
872872
}
873873
},
874-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
874+
"fail/read.ProtoJSON": func(t *testing.T) test {
875875
prov := &linkedca.Provisioner{
876876
Name: "provName",
877877
}
@@ -1060,7 +1060,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
10601060
statusCode: 404,
10611061
}
10621062
},
1063-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
1063+
"fail/read.ProtoJSON": func(t *testing.T) test {
10641064
policy := &linkedca.Policy{
10651065
X509: &linkedca.X509Policy{
10661066
Allow: &linkedca.X509Names{
@@ -1472,7 +1472,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
14721472
statusCode: 409,
14731473
}
14741474
},
1475-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
1475+
"fail/read.ProtoJSON": func(t *testing.T) test {
14761476
prov := &linkedca.Provisioner{
14771477
Name: "provName",
14781478
}
@@ -1637,7 +1637,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
16371637
statusCode: 404,
16381638
}
16391639
},
1640-
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
1640+
"fail/read.ProtoJSON": func(t *testing.T) test {
16411641
policy := &linkedca.Policy{
16421642
X509: &linkedca.X509Policy{
16431643
Allow: &linkedca.X509Names{

authority/admin/api/provisioner_test.go

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,21 @@ import (
88
"io"
99
"net/http"
1010
"net/http/httptest"
11+
"strings"
1112
"testing"
1213
"time"
1314

1415
"github.com/go-chi/chi"
1516
"github.com/google/go-cmp/cmp"
1617
"github.com/google/go-cmp/cmp/cmpopts"
18+
"google.golang.org/protobuf/encoding/protojson"
19+
"google.golang.org/protobuf/types/known/timestamppb"
20+
21+
"go.step.sm/linkedca"
22+
1723
"github.com/smallstep/assert"
1824
"github.com/smallstep/certificates/authority/admin"
1925
"github.com/smallstep/certificates/authority/provisioner"
20-
"go.step.sm/linkedca"
21-
"google.golang.org/protobuf/encoding/protojson"
22-
"google.golang.org/protobuf/types/known/timestamppb"
2326
)
2427

2528
func TestHandler_GetProvisioner(t *testing.T) {
@@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
335338
return test{
336339
ctx: context.Background(),
337340
body: body,
338-
statusCode: 500,
339-
err: &admin.Error{ // TODO(hs): this probably needs a better error
340-
Type: "",
341-
Status: 500,
342-
Detail: "",
343-
Message: "",
341+
statusCode: 400,
342+
err: &admin.Error{
343+
Type: "badRequest",
344+
Status: 400,
345+
Detail: "bad request",
346+
Message: "proto: syntax error (line 1:2): invalid value !",
344347
},
345348
}
346349
},
@@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
423426
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
424427

425428
assert.Equals(t, tc.err.Type, adminErr.Type)
426-
assert.Equals(t, tc.err.Message, adminErr.Message)
427429
assert.Equals(t, tc.err.Detail, adminErr.Detail)
428430
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
431+
432+
if strings.HasPrefix(tc.err.Message, "proto:") {
433+
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
434+
} else {
435+
assert.Equals(t, tc.err.Message, adminErr.Message)
436+
}
437+
429438
return
430439
}
431440

@@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
616625
return test{
617626
ctx: context.Background(),
618627
body: body,
619-
statusCode: 500,
620-
err: &admin.Error{ // TODO(hs): this probably needs a better error
621-
Type: "",
622-
Status: 500,
623-
Detail: "",
624-
Message: "",
628+
statusCode: 400,
629+
err: &admin.Error{
630+
Type: "badRequest",
631+
Status: 400,
632+
Detail: "bad request",
633+
Message: "proto: syntax error (line 1:2): invalid value !",
625634
},
626635
}
627636
},
@@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
10741083
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
10751084

10761085
assert.Equals(t, tc.err.Type, adminErr.Type)
1077-
assert.Equals(t, tc.err.Message, adminErr.Message)
10781086
assert.Equals(t, tc.err.Detail, adminErr.Detail)
10791087
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
1088+
1089+
if strings.HasPrefix(tc.err.Message, "proto:") {
1090+
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
1091+
} else {
1092+
assert.Equals(t, tc.err.Message, adminErr.Message)
1093+
}
1094+
10801095
return
10811096
}
10821097

0 commit comments

Comments
 (0)