Skip to content

Commit 03694ce

Browse files
committed
Support full duplex streaming
1 parent b40de04 commit 03694ce

File tree

5 files changed

+170
-59
lines changed

5 files changed

+170
-59
lines changed

cmd/body-based-routing/main.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ var (
4242
"grpcHealthPort",
4343
9003,
4444
"The port used for gRPC liveness and readiness probes")
45+
streaming = flag.Bool(
46+
"streaming", false, "Enables streaming support for Envoy full-duplex streaming mode")
4547
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
4648

4749
setupLog = ctrl.Log.WithName("setup")
@@ -82,7 +84,7 @@ func run() error {
8284
ctx := ctrl.SetupSignalHandler()
8385

8486
// Setup runner.
85-
serverRunner := &runserver.ExtProcServerRunner{GrpcPort: *grpcPort}
87+
serverRunner := runserver.NewDefaultExtProcServerRunner(*streaming)
8688

8789
// Register health server.
8890
if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), *grpcHealthPort); err != nil {

pkg/body-based-routing/handlers/request.go

+22-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package handlers
1818

1919
import (
2020
"context"
21-
"encoding/json"
2221
"fmt"
2322

2423
basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -28,14 +27,9 @@ import (
2827
)
2928

3029
// HandleRequestBody handles request bodies.
31-
func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*eppb.ProcessingResponse, error) {
30+
func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) (*eppb.ProcessingResponse, error) {
3231
logger := log.FromContext(ctx)
3332

34-
var data map[string]any
35-
if err := json.Unmarshal(body.GetBody(), &data); err != nil {
36-
return nil, err
37-
}
38-
3933
modelVal, ok := data["model"]
4034
if !ok {
4135
logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter")
@@ -56,6 +50,27 @@ func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*e
5650
}, fmt.Errorf("the model parameter value %v is not a string", modelVal)
5751
}
5852

53+
if s.streaming {
54+
return &eppb.ProcessingResponse{
55+
Response: &eppb.ProcessingResponse_RequestHeaders{
56+
RequestHeaders: &eppb.HeadersResponse{
57+
Response: &eppb.CommonResponse{
58+
ClearRouteCache: true,
59+
HeaderMutation: &eppb.HeaderMutation{
60+
SetHeaders: []*basepb.HeaderValueOption{
61+
{
62+
Header: &basepb.HeaderValue{
63+
Key: "X-Gateway-Model-Name",
64+
RawValue: []byte(modelStr),
65+
},
66+
},
67+
},
68+
},
69+
},
70+
},
71+
},
72+
}, nil
73+
}
5974
return &eppb.ProcessingResponse{
6075
Response: &eppb.ProcessingResponse_RequestBody{
6176
RequestBody: &eppb.BodyResponse{

pkg/body-based-routing/handlers/request_test.go

+41-38
Original file line numberDiff line numberDiff line change
@@ -27,46 +27,20 @@ import (
2727
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2828
)
2929

30-
const (
31-
bodyWithModel = `
32-
{
33-
"model": "foo",
34-
"prompt": "Tell me a joke"
35-
}
36-
`
37-
bodyWithModelNoStr = `
38-
{
39-
"model": 1,
40-
"prompt": "Tell me a joke"
41-
}
42-
`
43-
bodyWithoutModel = `
44-
{
45-
"prompt": "Tell me a joke"
46-
}
47-
`
48-
)
49-
5030
func TestHandleRequestBody(t *testing.T) {
5131
ctx := logutil.NewTestLoggerIntoContext(context.Background())
5232

5333
tests := []struct {
54-
name string
55-
body *extProcPb.HttpBody
56-
want *extProcPb.ProcessingResponse
57-
wantErr bool
34+
name string
35+
body map[string]any
36+
streaming bool
37+
want *extProcPb.ProcessingResponse
38+
wantErr bool
5839
}{
59-
{
60-
name: "malformed body",
61-
body: &extProcPb.HttpBody{
62-
Body: []byte("malformed json"),
63-
},
64-
wantErr: true,
65-
},
6640
{
6741
name: "model not found",
68-
body: &extProcPb.HttpBody{
69-
Body: []byte(bodyWithoutModel),
42+
body: map[string]any{
43+
"prompt": "Tell me a joke",
7044
},
7145
want: &extProcPb.ProcessingResponse{
7246
Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -76,15 +50,17 @@ func TestHandleRequestBody(t *testing.T) {
7650
},
7751
{
7852
name: "model is not string",
79-
body: &extProcPb.HttpBody{
80-
Body: []byte(bodyWithModelNoStr),
53+
body: map[string]any{
54+
"model": 1,
55+
"prompt": "Tell me a joke",
8156
},
8257
wantErr: true,
8358
},
8459
{
8560
name: "success",
86-
body: &extProcPb.HttpBody{
87-
Body: []byte(bodyWithModel),
61+
body: map[string]any{
62+
"model": "foo",
63+
"prompt": "Tell me a joke",
8864
},
8965
want: &extProcPb.ProcessingResponse{
9066
Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -107,11 +83,38 @@ func TestHandleRequestBody(t *testing.T) {
10783
},
10884
},
10985
},
86+
{
87+
name: "success-with-streaming",
88+
body: map[string]any{
89+
"model": "foo",
90+
"prompt": "Tell me a joke",
91+
},
92+
streaming: true,
93+
want: &extProcPb.ProcessingResponse{
94+
Response: &extProcPb.ProcessingResponse_RequestHeaders{
95+
RequestHeaders: &extProcPb.HeadersResponse{
96+
Response: &extProcPb.CommonResponse{
97+
ClearRouteCache: true,
98+
HeaderMutation: &extProcPb.HeaderMutation{
99+
SetHeaders: []*basepb.HeaderValueOption{
100+
{
101+
Header: &basepb.HeaderValue{
102+
Key: "X-Gateway-Model-Name",
103+
RawValue: []byte("foo"),
104+
},
105+
},
106+
},
107+
},
108+
},
109+
},
110+
},
111+
},
112+
},
110113
}
111114

112115
for _, test := range tests {
113116
t.Run(test.name, func(t *testing.T) {
114-
server := &Server{}
117+
server := &Server{streaming: test.streaming}
115118
resp, err := server.HandleRequestBody(ctx, test.body)
116119
if err != nil {
117120
if !test.wantErr {

pkg/body-based-routing/handlers/server.go

+98-9
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,36 @@ package handlers
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"errors"
2223
"io"
2324

2425
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
26+
"github.com/go-logr/logr"
2527
"google.golang.org/grpc/codes"
2628
"google.golang.org/grpc/status"
2729
"sigs.k8s.io/controller-runtime/pkg/log"
2830
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2931
)
3032

31-
func NewServer() *Server {
32-
return &Server{}
33+
func NewServer(streaming bool) *Server {
34+
return &Server{streaming: streaming}
3335
}
3436

3537
// Server implements the Envoy external processing server.
3638
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto
37-
type Server struct{}
39+
type Server struct {
40+
streaming bool
41+
}
3842

3943
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
4044
ctx := srv.Context()
4145
logger := log.FromContext(ctx)
4246
loggerVerbose := logger.V(logutil.VERBOSE)
4347
loggerVerbose.Info("Processing")
4448

49+
reader, writer := io.Pipe()
50+
4551
for {
4652
select {
4753
case <-ctx.Done():
@@ -61,12 +67,21 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
6167
}
6268

6369
var resp *extProcPb.ProcessingResponse
70+
var requestBody []byte
6471
var err error
6572
switch v := req.Request.(type) {
6673
case *extProcPb.ProcessingRequest_RequestHeaders:
67-
resp, err = s.HandleRequestHeaders(req.GetRequestHeaders())
74+
if !s.streaming {
75+
// If streaming, then headers are handled when processing request body.
76+
resp, err = s.HandleRequestHeaders(req.GetRequestHeaders())
77+
} else {
78+
loggerVerbose.Info("Received headers, passing off header processing until body arrives...")
79+
}
6880
case *extProcPb.ProcessingRequest_RequestBody:
69-
resp, err = s.HandleRequestBody(ctx, req.GetRequestBody())
81+
loggerVerbose.Info("Incoming body chunk", "body", string(v.RequestBody.Body), "EoS", v.RequestBody.EndOfStream)
82+
resp, requestBody, err = s.processRequestBody(ctx, req.GetRequestBody(), writer, reader, logger)
83+
case *extProcPb.ProcessingRequest_RequestTrailers:
84+
resp, err = s.HandleRequestTrailers(req.GetRequestTrailers())
7085
case *extProcPb.ProcessingRequest_ResponseHeaders:
7186
resp, err = s.HandleResponseHeaders(req.GetResponseHeaders())
7287
case *extProcPb.ProcessingRequest_ResponseBody:
@@ -81,10 +96,84 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
8196
return status.Errorf(status.Code(err), "failed to handle request: %v", err)
8297
}
8398

84-
loggerVerbose.Info("Response generated", "response", resp)
85-
if err := srv.Send(resp); err != nil {
86-
logger.V(logutil.DEFAULT).Error(err, "Send failed")
87-
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
99+
if resp != nil {
100+
loggerVerbose.Info("Response generated", "response", resp)
101+
if err := srv.Send(resp); err != nil {
102+
logger.V(logutil.DEFAULT).Error(err, "Send failed")
103+
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
104+
}
105+
106+
if s.streaming {
107+
bodyResp := &extProcPb.ProcessingResponse{
108+
Response: &extProcPb.ProcessingResponse_RequestBody{
109+
RequestBody: &extProcPb.BodyResponse{
110+
Response: &extProcPb.CommonResponse{
111+
BodyMutation: &extProcPb.BodyMutation{
112+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
113+
StreamedResponse: &extProcPb.StreamedBodyResponse{
114+
Body: requestBody,
115+
EndOfStream: true,
116+
},
117+
},
118+
},
119+
},
120+
},
121+
},
122+
}
123+
loggerVerbose.Info("Response generated", "response", bodyResp)
124+
if err := srv.Send(bodyResp); err != nil {
125+
logger.V(logutil.DEFAULT).Error(err, "Send failed")
126+
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
127+
}
128+
}
88129
}
89130
}
90131
}
132+
133+
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, bufferWriter *io.PipeWriter, bufferReader *io.PipeReader, logger logr.Logger) (*extProcPb.ProcessingResponse, []byte, error) {
134+
loggerVerbose := logger.V(logutil.VERBOSE)
135+
136+
var requestBody map[string]interface{}
137+
if s.streaming {
138+
// In the stream case, we can receive multiple request bodies.
139+
// To buffer the full message, we create a goroutine with a writer.Write()
140+
// call, which will block until the corresponding reader reads from it.
141+
// We do not read until we receive the EndofStream signal, and then
142+
// decode the entire JSON body.
143+
if !body.EndOfStream {
144+
go func() {
145+
loggerVerbose.Info("Writing to stream buffer")
146+
_, err := bufferWriter.Write(body.Body)
147+
if err != nil {
148+
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
149+
}
150+
}()
151+
152+
return nil, nil, nil
153+
}
154+
155+
if body.EndOfStream {
156+
loggerVerbose.Info("Flushing stream buffer")
157+
decoder := json.NewDecoder(bufferReader)
158+
if err := decoder.Decode(&requestBody); err != nil {
159+
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
160+
}
161+
bufferReader.Close()
162+
}
163+
} else {
164+
if err := json.Unmarshal(body.GetBody(), &requestBody); err != nil {
165+
return nil, nil, err
166+
}
167+
}
168+
169+
requestBodyResp, err := s.HandleRequestBody(ctx, requestBody)
170+
if err != nil {
171+
return nil, nil, err
172+
}
173+
174+
requestBodyBytes, err := json.Marshal(requestBody)
175+
if err != nil {
176+
return nil, nil, err
177+
}
178+
return requestBodyResp, requestBodyBytes, nil
179+
}

pkg/body-based-routing/server/runserver.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,19 @@ import (
3232

3333
// ExtProcServerRunner provides methods to manage an external process server.
3434
type ExtProcServerRunner struct {
35-
GrpcPort int
35+
GrpcPort int
36+
Streaming bool
3637
}
3738

3839
// Default values for CLI flags in main
3940
const (
4041
DefaultGrpcPort = 9002 // default for --grpcPort
4142
)
4243

43-
func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
44+
func NewDefaultExtProcServerRunner(streaming bool) *ExtProcServerRunner {
4445
return &ExtProcServerRunner{
45-
GrpcPort: DefaultGrpcPort,
46+
GrpcPort: DefaultGrpcPort,
47+
Streaming: streaming,
4648
}
4749
}
4850

@@ -60,7 +62,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable {
6062
srv := grpc.NewServer(grpc.Creds(creds))
6163
extProcPb.RegisterExternalProcessorServer(
6264
srv,
63-
handlers.NewServer(),
65+
handlers.NewServer(r.Streaming),
6466
)
6567

6668
// Forward to the gRPC runnable.

0 commit comments

Comments
 (0)