Skip to content

Commit 2a18bfc

Browse files
authored
transport: refactor to split ClientStream from ServerStream from common Stream functionality (#7802)
1 parent 70e8931 commit 2a18bfc

12 files changed

+410
-340
lines changed

internal/transport/client_stream.go

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
*
3+
* Copyright 2024 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package transport
20+
21+
import (
22+
"sync/atomic"
23+
24+
"google.golang.org/grpc/metadata"
25+
"google.golang.org/grpc/status"
26+
)
27+
28+
// ClientStream implements streaming functionality for a gRPC client.
29+
type ClientStream struct {
30+
*Stream // Embed for common stream functionality.
31+
32+
ct ClientTransport
33+
done chan struct{} // closed at the end of stream to unblock writers.
34+
doneFunc func() // invoked at the end of stream.
35+
36+
headerChan chan struct{} // closed to indicate the end of header metadata.
37+
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
38+
// headerValid indicates whether a valid header was received. Only
39+
// meaningful after headerChan is closed (always call waitOnHeader() before
40+
// reading its value).
41+
headerValid bool
42+
header metadata.MD // the received header metadata
43+
noHeaders bool // set if the client never received headers (set only after the stream is done).
44+
45+
bytesReceived uint32 // indicates whether any bytes have been received on this stream
46+
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
47+
48+
status *status.Status // the status error received from the server
49+
}
50+
51+
// BytesReceived indicates whether any bytes have been received on this stream.
52+
func (s *ClientStream) BytesReceived() bool {
53+
return atomic.LoadUint32(&s.bytesReceived) == 1
54+
}
55+
56+
// Unprocessed indicates whether the server did not process this stream --
57+
// i.e. it sent a refused stream or GOAWAY including this stream ID.
58+
func (s *ClientStream) Unprocessed() bool {
59+
return atomic.LoadUint32(&s.unprocessed) == 1
60+
}
61+
62+
func (s *ClientStream) waitOnHeader() {
63+
select {
64+
case <-s.ctx.Done():
65+
// Close the stream to prevent headers/trailers from changing after
66+
// this function returns.
67+
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
68+
// headerChan could possibly not be closed yet if closeStream raced
69+
// with operateHeaders; wait until it is closed explicitly here.
70+
<-s.headerChan
71+
case <-s.headerChan:
72+
}
73+
}
74+
75+
// RecvCompress returns the compression algorithm applied to the inbound
76+
// message. It is empty string if there is no compression applied.
77+
func (s *ClientStream) RecvCompress() string {
78+
s.waitOnHeader()
79+
return s.recvCompress
80+
}
81+
82+
// Done returns a channel which is closed when it receives the final status
83+
// from the server.
84+
func (s *ClientStream) Done() <-chan struct{} {
85+
return s.done
86+
}
87+
88+
// Header returns the header metadata of the stream. Acquires the key-value
89+
// pairs of header metadata once it is available. It blocks until i) the
90+
// metadata is ready or ii) there is no header metadata or iii) the stream is
91+
// canceled/expired.
92+
func (s *ClientStream) Header() (metadata.MD, error) {
93+
s.waitOnHeader()
94+
95+
if !s.headerValid || s.noHeaders {
96+
return nil, s.status.Err()
97+
}
98+
99+
return s.header.Copy(), nil
100+
}
101+
102+
// TrailersOnly blocks until a header or trailers-only frame is received and
103+
// then returns true if the stream was trailers-only. If the stream ends
104+
// before headers are received, returns true, nil.
105+
func (s *ClientStream) TrailersOnly() bool {
106+
s.waitOnHeader()
107+
return s.noHeaders
108+
}
109+
110+
// Status returns the status received from the server.
111+
// Status can be read safely only after the stream has ended,
112+
// that is, after Done() is closed.
113+
func (s *ClientStream) Status() *status.Status {
114+
return s.status
115+
}

internal/transport/handler_server.go

+17-15
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
225225
}
226226
}
227227

228-
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
228+
func (ht *serverHandlerTransport) WriteStatus(s *ServerStream, st *status.Status) error {
229229
ht.writeStatusMu.Lock()
230230
defer ht.writeStatusMu.Unlock()
231231

@@ -289,14 +289,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
289289

290290
// writePendingHeaders sets common and custom headers on the first
291291
// write call (Write, WriteHeader, or WriteStatus)
292-
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
292+
func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
293293
ht.writeCommonHeaders(s)
294294
ht.writeCustomHeaders(s)
295295
}
296296

297297
// writeCommonHeaders sets common headers on the first write
298298
// call (Write, WriteHeader, or WriteStatus).
299-
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
299+
func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
300300
h := ht.rw.Header()
301301
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
302302
h.Set("Content-Type", ht.contentType)
@@ -317,7 +317,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
317317

318318
// writeCustomHeaders sets custom headers set on the stream via SetHeader
319319
// on the first write call (Write, WriteHeader, or WriteStatus)
320-
func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
320+
func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
321321
h := ht.rw.Header()
322322

323323
s.hdrMu.Lock()
@@ -333,7 +333,7 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
333333
s.hdrMu.Unlock()
334334
}
335335

336-
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, _ *Options) error {
336+
func (ht *serverHandlerTransport) Write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *Options) error {
337337
// Always take a reference because otherwise there is no guarantee the data will
338338
// be available after this function returns. This is what callers to Write
339339
// expect.
@@ -357,7 +357,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSl
357357
return nil
358358
}
359359

360-
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
360+
func (ht *serverHandlerTransport) WriteHeader(s *ServerStream, md metadata.MD) error {
361361
if err := s.SetHeader(md); err != nil {
362362
return err
363363
}
@@ -385,7 +385,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
385385
return err
386386
}
387387

388-
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
388+
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
389389
// With this transport type there will be exactly 1 stream: this HTTP request.
390390
var cancel context.CancelFunc
391391
if ht.timeoutSet {
@@ -408,16 +408,18 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
408408

409409
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
410410
req := ht.req
411-
s := &Stream{
412-
id: 0, // irrelevant
413-
ctx: ctx,
414-
requestRead: func(int) {},
411+
s := &ServerStream{
412+
Stream: &Stream{
413+
id: 0, // irrelevant
414+
ctx: ctx,
415+
requestRead: func(int) {},
416+
buf: newRecvBuffer(),
417+
method: req.URL.Path,
418+
recvCompress: req.Header.Get("grpc-encoding"),
419+
contentSubtype: ht.contentSubtype,
420+
},
415421
cancel: cancel,
416-
buf: newRecvBuffer(),
417422
st: ht,
418-
method: req.URL.Path,
419-
recvCompress: req.Header.Get("grpc-encoding"),
420-
contentSubtype: ht.contentSubtype,
421423
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
422424
}
423425
s.trReader = &transportReader{

internal/transport/handler_server_test.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
274274

275275
func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
276276
st := newHandleStreamTest(t)
277-
handleStream := func(s *Stream) {
277+
handleStream := func(s *ServerStream) {
278278
if want := "/service/foo.bar"; s.method != want {
279279
t.Errorf("stream method = %q; want %q", s.method, want)
280280
}
@@ -313,7 +313,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
313313
st.ht.WriteStatus(s, status.New(codes.OK, ""))
314314
}
315315
st.ht.HandleStreams(
316-
context.Background(), func(s *Stream) { go handleStream(s) },
316+
context.Background(), func(s *ServerStream) { go handleStream(s) },
317317
)
318318
wantHeader := http.Header{
319319
"Date": nil,
@@ -342,11 +342,11 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
342342
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
343343
st := newHandleStreamTest(t)
344344

345-
handleStream := func(s *Stream) {
345+
handleStream := func(s *ServerStream) {
346346
st.ht.WriteStatus(s, status.New(statusCode, msg))
347347
}
348348
st.ht.HandleStreams(
349-
context.Background(), func(s *Stream) { go handleStream(s) },
349+
context.Background(), func(s *ServerStream) { go handleStream(s) },
350350
)
351351
wantHeader := http.Header{
352352
"Date": nil,
@@ -379,7 +379,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
379379
if err != nil {
380380
t.Fatal(err)
381381
}
382-
runStream := func(s *Stream) {
382+
runStream := func(s *ServerStream) {
383383
defer bodyw.Close()
384384
select {
385385
case <-s.ctx.Done():
@@ -395,7 +395,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
395395
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
396396
}
397397
ht.HandleStreams(
398-
context.Background(), func(s *Stream) { go runStream(s) },
398+
context.Background(), func(s *ServerStream) { go runStream(s) },
399399
)
400400
wantHeader := http.Header{
401401
"Date": nil,
@@ -412,7 +412,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
412412
// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
413413
// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
414414
func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
415-
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
415+
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
416416
if want := "/service/foo.bar"; s.method != want {
417417
t.Errorf("stream method = %q; want %q", s.method, want)
418418
}
@@ -433,7 +433,7 @@ func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
433433
// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
434434
// following "WriteStatus" does not panic writing to closed "writes" channel.
435435
func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
436-
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
436+
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
437437
if want := "/service/foo.bar"; s.method != want {
438438
t.Errorf("stream method = %q; want %q", s.method, want)
439439
}
@@ -444,10 +444,10 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
444444
})
445445
}
446446

447-
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
447+
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
448448
st := newHandleStreamTest(t)
449449
st.ht.HandleStreams(
450-
context.Background(), func(s *Stream) { go handleStream(st, s) },
450+
context.Background(), func(s *ServerStream) { go handleStream(st, s) },
451451
)
452452
}
453453

@@ -476,11 +476,11 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
476476
}
477477

478478
hst := newHandleStreamTest(t)
479-
handleStream := func(s *Stream) {
479+
handleStream := func(s *ServerStream) {
480480
hst.ht.WriteStatus(s, st)
481481
}
482482
hst.ht.HandleStreams(
483-
context.Background(), func(s *Stream) { go handleStream(s) },
483+
context.Background(), func(s *ServerStream) { go handleStream(s) },
484484
)
485485
wantHeader := http.Header{
486486
"Date": nil,

0 commit comments

Comments
 (0)