
Commit fe1519e
Authored Aug 18, 2023
client: fix ClientStream.Header() behavior (#6557)
1 parent 8a2c220 · commit fe1519e

File tree

7 files changed: +110 -70 lines


binarylog/binarylog_end2end_test.go

+38

@@ -31,10 +31,12 @@ import (
     "github.com/golang/protobuf/proto"
     "google.golang.org/grpc"
     "google.golang.org/grpc/binarylog"
+    "google.golang.org/grpc/codes"
     "google.golang.org/grpc/credentials/insecure"
     "google.golang.org/grpc/grpclog"
     iblog "google.golang.org/grpc/internal/binarylog"
     "google.golang.org/grpc/internal/grpctest"
+    "google.golang.org/grpc/internal/stubserver"
     "google.golang.org/grpc/metadata"
     "google.golang.org/grpc/status"
 
@@ -1059,3 +1061,39 @@ func (s) TestServerBinaryLogFullDuplexError(t *testing.T) {
         t.Fatal(err)
     }
 }
+
+// TestCanceledStatus ensures a server that responds with a Canceled status has
+// its trailers logged appropriately and is not treated as a canceled RPC.
+func (s) TestCanceledStatus(t *testing.T) {
+    defer testSink.clear()
+
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer cancel()
+
+    const statusMsgWant = "server returned Canceled"
+    ss := &stubserver.StubServer{
+        UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
+            grpc.SetTrailer(ctx, metadata.Pairs("key", "value"))
+            return nil, status.Error(codes.Canceled, statusMsgWant)
+        },
+    }
+    if err := ss.Start(nil); err != nil {
+        t.Fatalf("Error starting endpoint server: %v", err)
+    }
+    defer ss.Stop()
+
+    if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled {
+        t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err)
+    }
+
+    got := testSink.logEntries(true)
+    last := got[len(got)-1]
+    if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER ||
+        last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) ||
+        last.GetTrailer().GetStatusMessage() != statusMsgWant ||
+        len(last.GetTrailer().GetMetadata().GetEntry()) != 1 ||
+        last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" ||
+        string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" {
+        t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got)
+    }
+}

internal/transport/http2_client.go

+15 -16

@@ -1505,30 +1505,28 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
         return
     }
 
-    isHeader := false
-
-    // If headerChan hasn't been closed yet
-    if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
-        s.headerValid = true
-        if !endStream {
-            // HEADERS frame block carries a Response-Headers.
-            isHeader = true
+    // For headers, set them in s.header and close headerChan. For trailers or
+    // trailers-only, closeStream will set the trailers and close headerChan as
+    // needed.
+    if !endStream {
+        // If headerChan hasn't been closed yet (expected, given we checked it
+        // above, but something else could have potentially closed the whole
+        // stream).
+        if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
+            s.headerValid = true
             // These values can be set without any synchronization because
             // stream goroutine will read it only after seeing a closed
             // headerChan which we'll close after setting this.
             s.recvCompress = recvCompress
             if len(mdata) > 0 {
                 s.header = mdata
             }
-        } else {
-            // HEADERS frame block carries a Trailers-Only.
-            s.noHeaders = true
+            close(s.headerChan)
         }
-        close(s.headerChan)
     }
 
     for _, sh := range t.statsHandlers {
-        if isHeader {
+        if !endStream {
             inHeader := &stats.InHeader{
                 Client:     true,
                 WireLength: int(frame.Header().Length),
@@ -1554,9 +1552,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
         statusGen = status.New(rawStatusCode, grpcMessage)
     }
 
-    // if client received END_STREAM from server while stream was still active, send RST_STREAM
-    rst := s.getState() == streamActive
-    t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true)
+    // If client received END_STREAM from server while stream was still active,
+    // send RST_STREAM.
+    rstStream := s.getState() == streamActive
+    t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true)
 }
 
 // readServerPreface reads and handles the initial settings frame from the
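The rewritten operateHeaders hinges entirely on the END_STREAM flag: a HEADERS block without END_STREAM carries Response-Headers and unblocks Header() callers, while one with END_STREAM is handed to closeStream as trailers (Trailers-Only when no headers ever arrived). A minimal standalone sketch of that classification, using hypothetical names (headerKind, sawResponseHeaders) rather than the transport's real types:

package main

import "fmt"

// headerKind mirrors the decision shape of the patched operateHeaders, with
// hypothetical names: no END_STREAM means Response-Headers; END_STREAM means
// Trailers, or Trailers-Only if no Response-Headers were seen on the stream.
func headerKind(endStream, sawResponseHeaders bool) string {
    switch {
    case !endStream:
        return "response-headers"
    case sawResponseHeaders:
        return "trailers"
    default:
        return "trailers-only"
    }
}

func main() {
    fmt.Println(headerKind(false, false)) // response-headers: Header() callers are unblocked
    fmt.Println(headerKind(true, true))   // trailers: closeStream records them
    fmt.Println(headerKind(true, false))  // trailers-only: the status reaches the caller via Header()'s error
}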

internal/transport/transport.go

+1 -9

@@ -43,10 +43,6 @@ import (
     "google.golang.org/grpc/tap"
 )
 
-// ErrNoHeaders is used as a signal that a trailers only response was received,
-// and is not a real error.
-var ErrNoHeaders = errors.New("stream has no headers")
-
 const logLevel = 2
 
 type bufferPool struct {
@@ -390,14 +386,10 @@ func (s *Stream) Header() (metadata.MD, error) {
     }
     s.waitOnHeader()
 
-    if !s.headerValid {
+    if !s.headerValid || s.noHeaders {
         return nil, s.status.Err()
     }
 
-    if s.noHeaders {
-        return nil, ErrNoHeaders
-    }
-
     return s.header.Copy(), nil
 }
 
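With ErrNoHeaders removed, Stream.Header() reports a trailers-only response the same way it reports any failed stream: by returning the stream's status error. From application code, the effect is that ClientStream.Header() now yields the RPC status directly. A hedged caller-side sketch, where the target address and method path are placeholders and the StreamDesc is a generic bidirectional descriptor:

package main

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/status"
)

func main() {
    // Placeholder target and method; adjust for a real service.
    conn, err := grpc.Dial("localhost:50051",
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatalf("dial: %v", err)
    }
    defer conn.Close()

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    desc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true}
    stream, err := conn.NewStream(ctx, desc, "/example.Service/Method")
    if err != nil {
        log.Fatalf("NewStream: %v", err)
    }

    // A trailers-only error response now surfaces here as the RPC's status
    // error rather than an internal "no headers" sentinel.
    md, err := stream.Header()
    if err != nil {
        log.Printf("no headers; status code=%v msg=%q",
            status.Code(err), status.Convert(err).Message())
        return
    }
    log.Printf("headers: %v", md)
}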

rpc_util.go

+5 -2

@@ -867,15 +867,18 @@ func Errorf(c codes.Code, format string, a ...any) error {
     return status.Errorf(c, format, a...)
 }
 
+var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error())
+var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
+
 // toRPCErr converts an error into an error from the status package.
 func toRPCErr(err error) error {
     switch err {
     case nil, io.EOF:
         return err
     case context.DeadlineExceeded:
-        return status.Error(codes.DeadlineExceeded, err.Error())
+        return errContextDeadline
     case context.Canceled:
-        return status.Error(codes.Canceled, err.Error())
+        return errContextCanceled
     case io.ErrUnexpectedEOF:
         return status.Error(codes.Internal, err.Error())
     }
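Hoisting the two status.Error calls into package-level variables means toRPCErr now returns the identical error value every time it converts a context cancellation or deadline expiry, which is what lets clientStream.finish (below) dispatch on plain error identity instead of re-inspecting status codes. A small self-contained sketch of that identity property, using local stand-ins for the unexported sentinels:

package main

import (
    "context"
    "fmt"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// Local stand-ins for grpc's unexported errContextCanceled/errContextDeadline.
var (
    errCanceled = status.Error(codes.Canceled, context.Canceled.Error())
    errDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
)

// toRPCErr mimics the shape of the patched converter: sentinel inputs map to
// sentinel outputs, so callers can compare the result with == or a switch.
func toRPCErr(err error) error {
    switch err {
    case context.Canceled:
        return errCanceled
    case context.DeadlineExceeded:
        return errDeadline
    }
    return err
}

func main() {
    a := toRPCErr(context.Canceled)
    b := status.Error(codes.Canceled, context.Canceled.Error())

    fmt.Println(a == errCanceled) // true: same value, so identity comparison works
    fmt.Println(a == b)           // false: freshly built status errors are distinct values
}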

stream.go

+30 -35

@@ -789,23 +789,23 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())
 
 func (cs *clientStream) Header() (metadata.MD, error) {
     var m metadata.MD
-    noHeader := false
     err := cs.withRetry(func(a *csAttempt) error {
         var err error
         m, err = a.s.Header()
-        if err == transport.ErrNoHeaders {
-            noHeader = true
-            return nil
-        }
         return toRPCErr(err)
     }, cs.commitAttemptLocked)
 
+    if m == nil && err == nil {
+        // The stream ended with success. Finish the clientStream.
+        err = io.EOF
+    }
+
     if err != nil {
         cs.finish(err)
         return nil, err
     }
 
-    if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && !noHeader {
+    if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil {
         // Only log if binary log is on and header has not been logged, and
         // there is actually headers to log.
         logEntry := &binarylog.ServerHeader{
@@ -821,6 +821,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
             binlog.Log(cs.ctx, logEntry)
         }
     }
+
     return m, nil
 }
 
@@ -929,24 +930,6 @@ func (cs *clientStream) RecvMsg(m any) error {
     if err != nil || !cs.desc.ServerStreams {
         // err != nil or non-server-streaming indicates end of stream.
         cs.finish(err)
-
-        if len(cs.binlogs) != 0 {
-            // finish will not log Trailer. Log Trailer here.
-            logEntry := &binarylog.ServerTrailer{
-                OnClientSide: true,
-                Trailer:      cs.Trailer(),
-                Err:          err,
-            }
-            if logEntry.Err == io.EOF {
-                logEntry.Err = nil
-            }
-            if peer, ok := peer.FromContext(cs.Context()); ok {
-                logEntry.PeerAddr = peer.Addr
-            }
-            for _, binlog := range cs.binlogs {
-                binlog.Log(cs.ctx, logEntry)
-            }
-        }
     }
     return err
 }
@@ -1002,18 +985,30 @@ func (cs *clientStream) finish(err error) {
             }
         }
     }
+
     cs.mu.Unlock()
-    // For binary logging. only log cancel in finish (could be caused by RPC ctx
-    // canceled or ClientConn closed). Trailer will be logged in RecvMsg.
-    //
-    // Only one of cancel or trailer needs to be logged. In the cases where
-    // users don't call RecvMsg, users must have already canceled the RPC.
-    if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled {
-        c := &binarylog.Cancel{
-            OnClientSide: true,
-        }
-        for _, binlog := range cs.binlogs {
-            binlog.Log(cs.ctx, c)
+    // Only one of cancel or trailer needs to be logged.
+    if len(cs.binlogs) != 0 {
+        switch err {
+        case errContextCanceled, errContextDeadline, ErrClientConnClosing:
+            c := &binarylog.Cancel{
+                OnClientSide: true,
+            }
+            for _, binlog := range cs.binlogs {
+                binlog.Log(cs.ctx, c)
+            }
+        default:
+            logEntry := &binarylog.ServerTrailer{
+                OnClientSide: true,
+                Trailer:      cs.Trailer(),
+                Err:          err,
+            }
+            if peer, ok := peer.FromContext(cs.Context()); ok {
+                logEntry.PeerAddr = peer.Addr
+            }
+            for _, binlog := range cs.binlogs {
+                binlog.Log(cs.ctx, logEntry)
+            }
         }
     }
     if err == nil {
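The trailer logging that previously lived in RecvMsg now happens exactly once, in finish(), and the final error is classified by identity: the cached context sentinels and ErrClientConnClosing mean the client side gave up (log a Cancel event), while everything else, including a Canceled status sent by the server, is logged as a server trailer. A compact sketch of that dispatch, with plain strings standing in for the binarylog entry types and local stand-ins for the sentinels:

package main

import (
    "context"
    "errors"
    "fmt"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// Local stand-ins for grpc's errContextCanceled, errContextDeadline, and
// ErrClientConnClosing.
var (
    errCanceled    = status.Error(codes.Canceled, context.Canceled.Error())
    errDeadline    = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
    errConnClosing = errors.New("grpc: the client connection is closing")
)

// finalLogEvent mirrors the shape of the new finish() logic: only client-side
// cancellation maps to a Cancel entry; every other terminal error, including a
// Canceled status returned by the server, is logged as a server trailer.
func finalLogEvent(err error) string {
    switch err {
    case errCanceled, errDeadline, errConnClosing:
        return "Cancel"
    default:
        return "ServerTrailer"
    }
}

func main() {
    fmt.Println(finalLogEvent(errCanceled))                                    // Cancel
    fmt.Println(finalLogEvent(status.Error(codes.Canceled, "sent by server"))) // ServerTrailer
    fmt.Println(finalLogEvent(nil))                                            // ServerTrailer (successful RPC)
}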

test/end2end_test.go

+4 -5

@@ -6328,12 +6328,11 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) {
             return &testpb.SimpleResponse{}, nil
         },
         FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
-            for {
-                _, err := stream.Recv()
-                if err == io.EOF {
-                    return nil
-                }
+            _, err := stream.Recv()
+            if err == io.EOF {
+                return nil
             }
+            return status.Errorf(codes.Unknown, "expected client to call CloseSend")
         },
     }
 

test/retry_test.go

+17 -3

@@ -211,6 +211,11 @@ func (s) TestRetryStreaming(t *testing.T) {
             return nil
         }
     }
+    sHdr := func() serverOp {
+        return func(stream testgrpc.TestService_FullDuplexCallServer) error {
+            return stream.SendHeader(metadata.Pairs("test_header", "test_value"))
+        }
+    }
     sRes := func(b byte) serverOp {
         return func(stream testgrpc.TestService_FullDuplexCallServer) error {
             msg := res(b)
@@ -222,7 +227,7 @@
     }
     sErr := func(c codes.Code) serverOp {
         return func(stream testgrpc.TestService_FullDuplexCallServer) error {
-            return status.New(c, "").Err()
+            return status.New(c, "this is a test error").Err()
         }
     }
     sCloseSend := func() serverOp {
@@ -270,7 +275,7 @@
     }
     cErr := func(c codes.Code) clientOp {
         return func(stream testgrpc.TestService_FullDuplexCallClient) error {
-            want := status.New(c, "").Err()
+            want := status.New(c, "this is a test error").Err()
             if c == codes.OK {
                 want = io.EOF
             }
@@ -309,6 +314,11 @@
     cHdr := func() clientOp {
         return func(stream testgrpc.TestService_FullDuplexCallClient) error {
             _, err := stream.Header()
+            if err == io.EOF {
+                // The stream ended successfully; convert to nil to avoid
+                // erroring the test case.
+                err = nil
+            }
             return err
         }
     }
@@ -362,9 +372,13 @@
             sReq(1), sRes(3), sErr(codes.Unavailable),
         },
         clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)},
+    }, {
+        desc:      "Retry via ClientStream.Header()",
+        serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1)},
+        clientOps: []clientOp{cReq(1), cHdr() /* this should cause a retry */, cErr(codes.OK)},
     }, {
         desc:      "No retry after header",
-        serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
+        serverOps: []serverOp{sReq(1), sHdr(), sErr(codes.Unavailable)},
         clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)},
     }, {
         desc: "No retry after context",
