Skip to content

Commit 32487ec

Browse files
openpgp/packet: ensure that first partial packet is 512 bytes
This requirement is from RFC 4880 4.2.2.4. Also simplify the partialLengthWriter loop. The old code worked but was written in a confusing way, with a loop whose terminating condition didn't make sense and was never true in practice. Rewrite it to more clearly do a set of partial writes of decreasing size. Fixes golang/go#32474 Change-Id: Ia53ceb39a34f1d6f2ea7c60190d52948bb0db59b Reviewed-on: https://go-review.googlesource.com/c/crypto/+/181121 Run-TryBot: Ian Lance Taylor <iant@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
1 parent 2aa609c commit 32487ec

File tree

2 files changed

+87
-18
lines changed

2 files changed

+87
-18
lines changed

openpgp/packet/packet.go

+50-17
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"crypto/rsa"
1515
"io"
1616
"math/big"
17+
"math/bits"
1718

1819
"golang.org/x/crypto/cast5"
1920
"golang.org/x/crypto/openpgp/errors"
@@ -100,33 +101,65 @@ func (r *partialLengthReader) Read(p []byte) (n int, err error) {
100101
type partialLengthWriter struct {
101102
w io.WriteCloser
102103
lengthByte [1]byte
104+
sentFirst bool
105+
buf []byte
103106
}
104107

108+
// RFC 4880 4.2.2.4: the first partial length MUST be at least 512 octets long.
109+
const minFirstPartialWrite = 512
110+
105111
func (w *partialLengthWriter) Write(p []byte) (n int, err error) {
112+
off := 0
113+
if !w.sentFirst {
114+
if len(w.buf) > 0 || len(p) < minFirstPartialWrite {
115+
off = len(w.buf)
116+
w.buf = append(w.buf, p...)
117+
if len(w.buf) < minFirstPartialWrite {
118+
return len(p), nil
119+
}
120+
p = w.buf
121+
w.buf = nil
122+
}
123+
w.sentFirst = true
124+
}
125+
126+
power := uint8(30)
106127
for len(p) > 0 {
107-
for power := uint(14); power < 32; power-- {
108-
l := 1 << power
109-
if len(p) >= l {
110-
w.lengthByte[0] = 224 + uint8(power)
111-
_, err = w.w.Write(w.lengthByte[:])
112-
if err != nil {
113-
return
114-
}
115-
var m int
116-
m, err = w.w.Write(p[:l])
117-
n += m
118-
if err != nil {
119-
return
120-
}
121-
p = p[l:]
122-
break
128+
l := 1 << power
129+
if len(p) < l {
130+
power = uint8(bits.Len32(uint32(len(p)))) - 1
131+
l = 1 << power
132+
}
133+
w.lengthByte[0] = 224 + power
134+
_, err = w.w.Write(w.lengthByte[:])
135+
if err == nil {
136+
var m int
137+
m, err = w.w.Write(p[:l])
138+
n += m
139+
}
140+
if err != nil {
141+
if n < off {
142+
return 0, err
123143
}
144+
return n - off, err
124145
}
146+
p = p[l:]
125147
}
126-
return
148+
return n - off, nil
127149
}
128150

129151
func (w *partialLengthWriter) Close() error {
152+
if len(w.buf) > 0 {
153+
// In this case we can't send a 512 byte packet.
154+
// Just send what we have.
155+
p := w.buf
156+
w.sentFirst = true
157+
w.buf = nil
158+
if _, err := w.Write(p); err != nil {
159+
return err
160+
}
161+
}
162+
130163
w.lengthByte[0] = 0
131164
_, err := w.w.Write(w.lengthByte[:])
132165
if err != nil {

openpgp/packet/packet_test.go

+37-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,21 @@ func TestPartialLengths(t *testing.T) {
232232
t.Errorf("error from write: %s", err)
233233
}
234234
}
235-
w.Close()
235+
if err := w.Close(); err != nil {
236+
t.Fatal(err)
237+
}
238+
239+
// The first packet should be at least 512 bytes.
240+
first, err := buf.ReadByte()
241+
if err != nil {
242+
t.Fatal(err)
243+
}
244+
if plen := 1 << (first & 0x1f); plen < 512 {
245+
t.Errorf("first packet too short: got %d want at least %d", plen, 512)
246+
}
247+
if err := buf.UnreadByte(); err != nil {
248+
t.Fatal(err)
249+
}
236250

237251
want := (maxChunkSize * (maxChunkSize + 1)) / 2
238252
copyBuf := bytes.NewBuffer(nil)
@@ -253,3 +267,25 @@ func TestPartialLengths(t *testing.T) {
253267
}
254268
}
255269
}
270+
271+
func TestPartialLengthsShortWrite(t *testing.T) {
272+
buf := bytes.NewBuffer(nil)
273+
w := &partialLengthWriter{
274+
w: noOpCloser{buf},
275+
}
276+
data := bytes.Repeat([]byte("a"), 510)
277+
if _, err := w.Write(data); err != nil {
278+
t.Fatal(err)
279+
}
280+
if err := w.Close(); err != nil {
281+
t.Fatal(err)
282+
}
283+
copyBuf := bytes.NewBuffer(nil)
284+
r := &partialLengthReader{buf, 0, true}
285+
if _, err := io.Copy(copyBuf, r); err != nil {
286+
t.Fatal(err)
287+
}
288+
if !bytes.Equal(copyBuf.Bytes(), data) {
289+
t.Errorf("got %q want %q", buf.Bytes(), data)
290+
}
291+
}

0 commit comments

Comments
 (0)