Skip to content

Commit f2344ef

Browse files
committed
Support full duplex streaming
1 parent 53cb18f commit f2344ef

File tree

6 files changed

+176
-67
lines changed

6 files changed

+176
-67
lines changed

cmd/body-based-routing/main.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@ import (
4444
var (
4545
grpcPort = flag.Int(
4646
"grpcPort",
47-
runserver.DefaultGrpcPort,
47+
9004,
4848
"The gRPC port used for communicating with Envoy proxy")
4949
grpcHealthPort = flag.Int(
5050
"grpcHealthPort",
5151
9005,
5252
"The port used for gRPC liveness and readiness probes")
5353
metricsPort = flag.Int(
5454
"metricsPort", 9090, "The metrics port")
55+
streaming = flag.Bool(
56+
"streaming", false, "Enables streaming support for Envoy full-duplex streaming mode")
5557
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
5658

5759
setupLog = ctrl.Log.WithName("setup")
@@ -92,7 +94,7 @@ func run() error {
9294
ctx := ctrl.SetupSignalHandler()
9395

9496
// Setup runner.
95-
serverRunner := &runserver.ExtProcServerRunner{GrpcPort: *grpcPort}
97+
serverRunner := runserver.NewDefaultExtProcServerRunner(*grpcPort, *streaming)
9698

9799
// Register health server.
98100
if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), *grpcHealthPort); err != nil {

pkg/body-based-routing/handlers/request.go

+24-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package handlers
1818

1919
import (
2020
"context"
21-
"encoding/json"
2221
"fmt"
2322

2423
basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -29,14 +28,9 @@ import (
2928
)
3029

3130
// HandleRequestBody handles request bodies.
32-
func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*eppb.ProcessingResponse, error) {
31+
func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) (*eppb.ProcessingResponse, error) {
3332
logger := log.FromContext(ctx)
3433

35-
var data map[string]any
36-
if err := json.Unmarshal(body.GetBody(), &data); err != nil {
37-
return nil, err
38-
}
39-
4034
modelVal, ok := data["model"]
4135
if !ok {
4236
metrics.RecordModelNotInBodyCounter()
@@ -60,6 +54,29 @@ func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*e
6054
}
6155

6256
metrics.RecordSuccessCounter()
57+
58+
if s.streaming {
59+
return &eppb.ProcessingResponse{
60+
Response: &eppb.ProcessingResponse_RequestHeaders{
61+
RequestHeaders: &eppb.HeadersResponse{
62+
Response: &eppb.CommonResponse{
63+
ClearRouteCache: true,
64+
HeaderMutation: &eppb.HeaderMutation{
65+
SetHeaders: []*basepb.HeaderValueOption{
66+
{
67+
Header: &basepb.HeaderValue{
68+
Key: "X-Gateway-Model-Name",
69+
RawValue: []byte(modelStr),
70+
},
71+
},
72+
},
73+
},
74+
},
75+
},
76+
},
77+
}, nil
78+
}
79+
6380
return &eppb.ProcessingResponse{
6481
Response: &eppb.ProcessingResponse_RequestBody{
6582
RequestBody: &eppb.BodyResponse{

pkg/body-based-routing/handlers/request_test.go

+41-38
Original file line numberDiff line numberDiff line change
@@ -31,47 +31,21 @@ import (
3131
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3232
)
3333

34-
const (
35-
bodyWithModel = `
36-
{
37-
"model": "foo",
38-
"prompt": "Tell me a joke"
39-
}
40-
`
41-
bodyWithModelNoStr = `
42-
{
43-
"model": 1,
44-
"prompt": "Tell me a joke"
45-
}
46-
`
47-
bodyWithoutModel = `
48-
{
49-
"prompt": "Tell me a joke"
50-
}
51-
`
52-
)
53-
5434
func TestHandleRequestBody(t *testing.T) {
5535
metrics.Register()
5636
ctx := logutil.NewTestLoggerIntoContext(context.Background())
5737

5838
tests := []struct {
59-
name string
60-
body *extProcPb.HttpBody
61-
want *extProcPb.ProcessingResponse
62-
wantErr bool
39+
name string
40+
body map[string]any
41+
streaming bool
42+
want *extProcPb.ProcessingResponse
43+
wantErr bool
6344
}{
64-
{
65-
name: "malformed body",
66-
body: &extProcPb.HttpBody{
67-
Body: []byte("malformed json"),
68-
},
69-
wantErr: true,
70-
},
7145
{
7246
name: "model not found",
73-
body: &extProcPb.HttpBody{
74-
Body: []byte(bodyWithoutModel),
47+
body: map[string]any{
48+
"prompt": "Tell me a joke",
7549
},
7650
want: &extProcPb.ProcessingResponse{
7751
Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -81,15 +55,17 @@ func TestHandleRequestBody(t *testing.T) {
8155
},
8256
{
8357
name: "model is not string",
84-
body: &extProcPb.HttpBody{
85-
Body: []byte(bodyWithModelNoStr),
58+
body: map[string]any{
59+
"model": 1,
60+
"prompt": "Tell me a joke",
8661
},
8762
wantErr: true,
8863
},
8964
{
9065
name: "success",
91-
body: &extProcPb.HttpBody{
92-
Body: []byte(bodyWithModel),
66+
body: map[string]any{
67+
"model": "foo",
68+
"prompt": "Tell me a joke",
9369
},
9470
want: &extProcPb.ProcessingResponse{
9571
Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -112,11 +88,38 @@ func TestHandleRequestBody(t *testing.T) {
11288
},
11389
},
11490
},
91+
{
92+
name: "success-with-streaming",
93+
body: map[string]any{
94+
"model": "foo",
95+
"prompt": "Tell me a joke",
96+
},
97+
streaming: true,
98+
want: &extProcPb.ProcessingResponse{
99+
Response: &extProcPb.ProcessingResponse_RequestHeaders{
100+
RequestHeaders: &extProcPb.HeadersResponse{
101+
Response: &extProcPb.CommonResponse{
102+
ClearRouteCache: true,
103+
HeaderMutation: &extProcPb.HeaderMutation{
104+
SetHeaders: []*basepb.HeaderValueOption{
105+
{
106+
Header: &basepb.HeaderValue{
107+
Key: "X-Gateway-Model-Name",
108+
RawValue: []byte("foo"),
109+
},
110+
},
111+
},
112+
},
113+
},
114+
},
115+
},
116+
},
117+
},
115118
}
116119

117120
for _, test := range tests {
118121
t.Run(test.name, func(t *testing.T) {
119-
server := &Server{}
122+
server := &Server{streaming: test.streaming}
120123
resp, err := server.HandleRequestBody(ctx, test.body)
121124
if err != nil {
122125
if !test.wantErr {

pkg/body-based-routing/handlers/server.go

+96-9
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,36 @@ package handlers
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"errors"
2223
"io"
2324

2425
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
26+
"github.com/go-logr/logr"
2527
"google.golang.org/grpc/codes"
2628
"google.golang.org/grpc/status"
2729
"sigs.k8s.io/controller-runtime/pkg/log"
2830
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2931
)
3032

31-
func NewServer() *Server {
32-
return &Server{}
33+
func NewServer(streaming bool) *Server {
34+
return &Server{streaming: streaming}
3335
}
3436

3537
// Server implements the Envoy external processing server.
3638
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto
37-
type Server struct{}
39+
type Server struct {
40+
streaming bool
41+
}
3842

3943
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
4044
ctx := srv.Context()
4145
logger := log.FromContext(ctx)
4246
loggerVerbose := logger.V(logutil.VERBOSE)
4347
loggerVerbose.Info("Processing")
4448

49+
reader, writer := io.Pipe()
50+
4551
for {
4652
select {
4753
case <-ctx.Done():
@@ -61,12 +67,19 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
6167
}
6268

6369
var resp *extProcPb.ProcessingResponse
70+
var requestBody []byte
6471
var err error
6572
switch v := req.Request.(type) {
6673
case *extProcPb.ProcessingRequest_RequestHeaders:
67-
resp, err = s.HandleRequestHeaders(req.GetRequestHeaders())
74+
if !s.streaming {
75+
// If streaming, then headers are handled when processing request body.
76+
resp, err = s.HandleRequestHeaders(req.GetRequestHeaders())
77+
} else {
78+
loggerVerbose.Info("Received headers, passing off header processing until body arrives...")
79+
}
6880
case *extProcPb.ProcessingRequest_RequestBody:
69-
resp, err = s.HandleRequestBody(ctx, req.GetRequestBody())
81+
loggerVerbose.Info("Incoming body chunk", "body", string(v.RequestBody.Body), "EoS", v.RequestBody.EndOfStream)
82+
resp, requestBody, err = s.processRequestBody(ctx, req.GetRequestBody(), writer, reader, logger)
7083
case *extProcPb.ProcessingRequest_RequestTrailers:
7184
resp, err = s.HandleRequestTrailers(req.GetRequestTrailers())
7285
case *extProcPb.ProcessingRequest_ResponseHeaders:
@@ -83,10 +96,84 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
8396
return status.Errorf(status.Code(err), "failed to handle request: %v", err)
8497
}
8598

86-
loggerVerbose.Info("Response generated", "response", resp)
87-
if err := srv.Send(resp); err != nil {
88-
logger.V(logutil.DEFAULT).Error(err, "Send failed")
89-
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
99+
if resp != nil {
100+
loggerVerbose.Info("Response generated", "response", resp)
101+
if err := srv.Send(resp); err != nil {
102+
logger.V(logutil.DEFAULT).Error(err, "Send failed")
103+
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
104+
}
105+
106+
if s.streaming {
107+
bodyResp := &extProcPb.ProcessingResponse{
108+
Response: &extProcPb.ProcessingResponse_RequestBody{
109+
RequestBody: &extProcPb.BodyResponse{
110+
Response: &extProcPb.CommonResponse{
111+
BodyMutation: &extProcPb.BodyMutation{
112+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
113+
StreamedResponse: &extProcPb.StreamedBodyResponse{
114+
Body: requestBody,
115+
EndOfStream: true,
116+
},
117+
},
118+
},
119+
},
120+
},
121+
},
122+
}
123+
loggerVerbose.Info("Response generated", "response", bodyResp)
124+
if err := srv.Send(bodyResp); err != nil {
125+
logger.V(logutil.DEFAULT).Error(err, "Send failed")
126+
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
127+
}
128+
}
90129
}
91130
}
92131
}
132+
133+
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, bufferWriter *io.PipeWriter, bufferReader *io.PipeReader, logger logr.Logger) (*extProcPb.ProcessingResponse, []byte, error) {
134+
loggerVerbose := logger.V(logutil.VERBOSE)
135+
136+
var requestBody map[string]interface{}
137+
if s.streaming {
138+
// In the stream case, we can receive multiple request bodies.
139+
// To buffer the full message, we create a goroutine with a writer.Write()
140+
// call, which will block until the corresponding reader reads from it.
141+
// We do not read until we receive the EndofStream signal, and then
142+
// decode the entire JSON body.
143+
if !body.EndOfStream {
144+
go func() {
145+
loggerVerbose.Info("Writing to stream buffer")
146+
_, err := bufferWriter.Write(body.Body)
147+
if err != nil {
148+
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
149+
}
150+
}()
151+
152+
return nil, nil, nil
153+
}
154+
155+
if body.EndOfStream {
156+
loggerVerbose.Info("Flushing stream buffer")
157+
decoder := json.NewDecoder(bufferReader)
158+
if err := decoder.Decode(&requestBody); err != nil {
159+
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
160+
}
161+
bufferReader.Close()
162+
}
163+
} else {
164+
if err := json.Unmarshal(body.GetBody(), &requestBody); err != nil {
165+
return nil, nil, err
166+
}
167+
}
168+
169+
requestBodyResp, err := s.HandleRequestBody(ctx, requestBody)
170+
if err != nil {
171+
return nil, nil, err
172+
}
173+
174+
requestBodyBytes, err := json.Marshal(requestBody)
175+
if err != nil {
176+
return nil, nil, err
177+
}
178+
return requestBodyResp, requestBodyBytes, nil
179+
}

pkg/body-based-routing/server/runserver.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,14 @@ import (
3434
type ExtProcServerRunner struct {
3535
GrpcPort int
3636
SecureServing bool
37+
Streaming bool
3738
}
3839

39-
// Default values for CLI flags in main
40-
const (
41-
DefaultGrpcPort = 9004 // default for --grpcPort
42-
)
43-
44-
func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
40+
func NewDefaultExtProcServerRunner(port int, streaming bool) *ExtProcServerRunner {
4541
return &ExtProcServerRunner{
46-
GrpcPort: DefaultGrpcPort,
42+
GrpcPort: port,
4743
SecureServing: true,
44+
Streaming: streaming,
4845
}
4946
}
5047

@@ -65,7 +62,10 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable {
6562
srv = grpc.NewServer()
6663
}
6764

68-
extProcPb.RegisterExternalProcessorServer(srv, handlers.NewServer())
65+
extProcPb.RegisterExternalProcessorServer(
66+
srv,
67+
handlers.NewServer(r.Streaming),
68+
)
6969

7070
// Forward to the gRPC runnable.
7171
return runnable.GRPCServer("ext-proc", srv, r.GrpcPort).Start(ctx)

0 commit comments

Comments
 (0)