Skip to content

Commit 74a6e59

Browse files
committed
Add tests for ProtoJSON and bad proto messages
1 parent bddd08d commit 74a6e59

File tree

2 files changed

+125
-7
lines changed

2 files changed

+125
-7
lines changed

api/read/read.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
)
1717

1818
// JSON reads JSON from the request body and stores it in the value
19-
// pointed by v.
19+
// pointed to by v.
2020
func JSON(r io.Reader, v interface{}) error {
2121
if err := json.NewDecoder(r).Decode(v); err != nil {
2222
return errs.BadRequestErr(err, "error decoding json")
@@ -34,9 +34,7 @@ func ProtoJSON(r io.Reader, m proto.Message) error {
3434

3535
switch err := protojson.Unmarshal(data, m); {
3636
case errors.Is(err, proto.Error):
37-
// trim the proto prefix for the message
38-
s := strings.TrimSpace(strings.TrimPrefix(err.Error(), "proto:"))
39-
return badProtoJSONError(s)
37+
return badProtoJSONError(err.Error())
4038
default:
4139
return err
4240
}
@@ -59,9 +57,10 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) {
5957
Detail string `json:"detail"`
6058
Message string `json:"message"`
6159
}{
62-
Type: "badRequest",
63-
Detail: "bad request",
64-
Message: e.Error(),
60+
Type: "badRequest",
61+
Detail: "bad request",
62+
// trim the proto prefix for the message
63+
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
6564
}
6665
render.JSONStatus(w, v, http.StatusBadRequest)
6766
}

api/read/read_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
package read
22

33
import (
4+
"encoding/json"
5+
"errors"
46
"io"
7+
"io/ioutil"
8+
"net/http"
9+
"net/http/httptest"
510
"reflect"
611
"strings"
712
"testing"
13+
"testing/iotest"
14+
15+
"github.com/stretchr/testify/assert"
16+
"google.golang.org/protobuf/proto"
17+
"google.golang.org/protobuf/reflect/protoreflect"
18+
19+
"go.step.sm/linkedca"
820

921
"github.com/smallstep/certificates/errs"
1022
)
@@ -44,3 +56,110 @@ func TestJSON(t *testing.T) {
4456
})
4557
}
4658
}
59+
60+
func TestProtoJSON(t *testing.T) {
61+
62+
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
63+
64+
type args struct {
65+
r io.Reader
66+
m proto.Message
67+
}
68+
tests := []struct {
69+
name string
70+
args args
71+
wantErr bool
72+
}{
73+
{
74+
name: "fail/io.ReadAll",
75+
args: args{
76+
r: iotest.ErrReader(errors.New("read error")),
77+
m: p,
78+
},
79+
wantErr: true,
80+
},
81+
{
82+
name: "fail/proto",
83+
args: args{
84+
r: strings.NewReader(`{?}`),
85+
m: p,
86+
},
87+
wantErr: true,
88+
},
89+
{
90+
name: "ok",
91+
args: args{
92+
r: strings.NewReader(`{"x509":{}}`),
93+
m: p,
94+
},
95+
wantErr: false,
96+
},
97+
}
98+
for _, tt := range tests {
99+
t.Run(tt.name, func(t *testing.T) {
100+
err := ProtoJSON(tt.args.r, tt.args.m)
101+
if (err != nil) != tt.wantErr {
102+
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
103+
}
104+
105+
if tt.wantErr {
106+
switch err.(type) {
107+
case badProtoJSONError:
108+
assert.Contains(t, err.Error(), "syntax error")
109+
case *errs.Error:
110+
var ee *errs.Error
111+
if errors.As(err, &ee) {
112+
assert.Equal(t, http.StatusBadRequest, ee.Status)
113+
}
114+
}
115+
return
116+
}
117+
118+
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
119+
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
120+
})
121+
}
122+
}
123+
124+
func Test_badProtoJSONError_Render(t *testing.T) {
125+
tests := []struct {
126+
name string
127+
e badProtoJSONError
128+
expected string
129+
}{
130+
{
131+
name: "bad proto normal space",
132+
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
133+
expected: "syntax error (line 1:2): invalid value ?",
134+
},
135+
{
136+
name: "bad proto non breaking space",
137+
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
138+
expected: "syntax error (line 1:2): invalid value ?",
139+
},
140+
}
141+
for _, tt := range tests {
142+
t.Run(tt.name, func(t *testing.T) {
143+
144+
w := httptest.NewRecorder()
145+
tt.e.Render(w)
146+
res := w.Result()
147+
defer res.Body.Close()
148+
149+
data, err := ioutil.ReadAll(res.Body)
150+
assert.NoError(t, err)
151+
152+
v := struct {
153+
Type string `json:"type"`
154+
Detail string `json:"detail"`
155+
Message string `json:"message"`
156+
}{}
157+
158+
assert.NoError(t, json.Unmarshal(data, &v))
159+
assert.Equal(t, "badRequest", v.Type)
160+
assert.Equal(t, "bad request", v.Detail)
161+
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
162+
163+
})
164+
}
165+
}

0 commit comments

Comments
 (0)