Skip to content

Commit 23b5c52

Browse files
committed
Support full duplex streaming
1 parent: 53cb18f; commit: 23b5c52

File tree

7 files changed: +303 -122 lines changed

cmd/body-based-routing/main.go

+4 -2
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,16 @@ import (
4444
var (
4545
grpcPort = flag.Int(
4646
"grpcPort",
47-
runserver.DefaultGrpcPort,
47+
9004,
4848
"The gRPC port used for communicating with Envoy proxy")
4949
grpcHealthPort = flag.Int(
5050
"grpcHealthPort",
5151
9005,
5252
"The port used for gRPC liveness and readiness probes")
5353
metricsPort = flag.Int(
5454
"metricsPort", 9090, "The metrics port")
55+
streaming = flag.Bool(
56+
"streaming", false, "Enables streaming support for Envoy full-duplex streaming mode")
5557
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
5658

5759
setupLog = ctrl.Log.WithName("setup")
@@ -92,7 +94,7 @@ func run() error {
9294
ctx := ctrl.SetupSignalHandler()
9395

9496
// Setup runner.
95-
serverRunner := &runserver.ExtProcServerRunner{GrpcPort: *grpcPort}
97+
serverRunner := runserver.NewDefaultExtProcServerRunner(*grpcPort, *streaming)
9698

9799
// Register health server.
98100
if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), *grpcHealthPort); err != nil {

pkg/body-based-routing/handlers/request.go

+94 -33
Original file line number | Diff line number | Diff line change
@@ -23,55 +23,93 @@ import (
2323

2424
basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2525
eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
26+
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2627
"sigs.k8s.io/controller-runtime/pkg/log"
2728
"sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/metrics"
2829
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2930
)
3031

32+
const modelHeader = "X-Gateway-Model-Name"
33+
3134
// HandleRequestBody handles request bodies.
32-
func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*eppb.ProcessingResponse, error) {
35+
func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([]*eppb.ProcessingResponse, error) {
3336
logger := log.FromContext(ctx)
37+
var ret []*eppb.ProcessingResponse
3438

35-
var data map[string]any
36-
if err := json.Unmarshal(body.GetBody(), &data); err != nil {
39+
requestBodyBytes, err := json.Marshal(data)
40+
if err != nil {
3741
return nil, err
3842
}
3943

4044
modelVal, ok := data["model"]
4145
if !ok {
4246
metrics.RecordModelNotInBodyCounter()
4347
logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter")
44-
return &eppb.ProcessingResponse{
45-
Response: &eppb.ProcessingResponse_RequestBody{
46-
RequestBody: &eppb.BodyResponse{},
47-
},
48-
}, nil
48+
if s.streaming {
49+
ret = append(ret, &eppb.ProcessingResponse{
50+
Response: &eppb.ProcessingResponse_RequestHeaders{
51+
RequestHeaders: &eppb.HeadersResponse{},
52+
},
53+
})
54+
ret = addStreamedBodyResponse(ret, requestBodyBytes)
55+
return ret, nil
56+
} else {
57+
ret = append(ret, &eppb.ProcessingResponse{
58+
Response: &eppb.ProcessingResponse_RequestBody{
59+
RequestBody: &eppb.BodyResponse{},
60+
},
61+
})
62+
}
63+
return ret, nil
4964
}
5065

5166
modelStr, ok := modelVal.(string)
5267
if !ok {
5368
metrics.RecordModelNotParsedCounter()
5469
logger.V(logutil.DEFAULT).Info("Model parameter value is not a string")
55-
return &eppb.ProcessingResponse{
56-
Response: &eppb.ProcessingResponse_RequestBody{
57-
RequestBody: &eppb.BodyResponse{},
58-
},
59-
}, fmt.Errorf("the model parameter value %v is not a string", modelVal)
70+
return nil, fmt.Errorf("the model parameter value %v is not a string", modelVal)
6071
}
6172

6273
metrics.RecordSuccessCounter()
63-
return &eppb.ProcessingResponse{
64-
Response: &eppb.ProcessingResponse_RequestBody{
65-
RequestBody: &eppb.BodyResponse{
66-
Response: &eppb.CommonResponse{
67-
// Necessary so that the new headers are used in the routing decision.
68-
ClearRouteCache: true,
69-
HeaderMutation: &eppb.HeaderMutation{
70-
SetHeaders: []*basepb.HeaderValueOption{
71-
{
72-
Header: &basepb.HeaderValue{
73-
Key: "X-Gateway-Model-Name",
74-
RawValue: []byte(modelStr),
74+
75+
if s.streaming {
76+
ret = append(ret, &eppb.ProcessingResponse{
77+
Response: &eppb.ProcessingResponse_RequestHeaders{
78+
RequestHeaders: &eppb.HeadersResponse{
79+
Response: &eppb.CommonResponse{
80+
ClearRouteCache: true,
81+
HeaderMutation: &eppb.HeaderMutation{
82+
SetHeaders: []*basepb.HeaderValueOption{
83+
{
84+
Header: &basepb.HeaderValue{
85+
Key: modelHeader,
86+
RawValue: []byte(modelStr),
87+
},
88+
},
89+
},
90+
},
91+
},
92+
},
93+
},
94+
})
95+
ret = addStreamedBodyResponse(ret, requestBodyBytes)
96+
return ret, nil
97+
}
98+
99+
return []*eppb.ProcessingResponse{
100+
{
101+
Response: &eppb.ProcessingResponse_RequestBody{
102+
RequestBody: &eppb.BodyResponse{
103+
Response: &eppb.CommonResponse{
104+
// Necessary so that the new headers are used in the routing decision.
105+
ClearRouteCache: true,
106+
HeaderMutation: &eppb.HeaderMutation{
107+
SetHeaders: []*basepb.HeaderValueOption{
108+
{
109+
Header: &basepb.HeaderValue{
110+
Key: modelHeader,
111+
RawValue: []byte(modelStr),
112+
},
75113
},
76114
},
77115
},
@@ -82,20 +120,43 @@ func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*e
82120
}, nil
83121
}
84122

123+
func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBytes []byte) []*eppb.ProcessingResponse {
124+
return append(responses, &extProcPb.ProcessingResponse{
125+
Response: &extProcPb.ProcessingResponse_RequestBody{
126+
RequestBody: &extProcPb.BodyResponse{
127+
Response: &extProcPb.CommonResponse{
128+
BodyMutation: &extProcPb.BodyMutation{
129+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
130+
StreamedResponse: &extProcPb.StreamedBodyResponse{
131+
Body: requestBodyBytes,
132+
EndOfStream: true,
133+
},
134+
},
135+
},
136+
},
137+
},
138+
},
139+
})
140+
}
141+
85142
// HandleRequestHeaders handles request headers.
86-
func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) (*eppb.ProcessingResponse, error) {
87-
return &eppb.ProcessingResponse{
88-
Response: &eppb.ProcessingResponse_RequestHeaders{
89-
RequestHeaders: &eppb.HeadersResponse{},
143+
func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) {
144+
return []*eppb.ProcessingResponse{
145+
{
146+
Response: &eppb.ProcessingResponse_RequestHeaders{
147+
RequestHeaders: &eppb.HeadersResponse{},
148+
},
90149
},
91150
}, nil
92151
}
93152

94153
// HandleRequestTrailers handles request trailers.
95-
func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) (*eppb.ProcessingResponse, error) {
96-
return &eppb.ProcessingResponse{
97-
Response: &eppb.ProcessingResponse_RequestTrailers{
98-
RequestTrailers: &eppb.TrailersResponse{},
154+
func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) ([]*eppb.ProcessingResponse, error) {
155+
return []*eppb.ProcessingResponse{
156+
{
157+
Response: &eppb.ProcessingResponse_RequestTrailers{
158+
RequestTrailers: &eppb.TrailersResponse{},
159+
},
99160
},
100161
}, nil
101162
}

pkg/body-based-routing/handlers/request_test.go

+105 -51
Original file line number | Diff line number | Diff line change
@@ -31,78 +31,132 @@ import (
3131
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3232
)
3333

34-
const (
35-
bodyWithModel = `
36-
{
37-
"model": "foo",
38-
"prompt": "Tell me a joke"
39-
}
40-
`
41-
bodyWithModelNoStr = `
42-
{
43-
"model": 1,
44-
"prompt": "Tell me a joke"
45-
}
46-
`
47-
bodyWithoutModel = `
48-
{
49-
"prompt": "Tell me a joke"
50-
}
51-
`
52-
)
53-
5434
func TestHandleRequestBody(t *testing.T) {
5535
metrics.Register()
5636
ctx := logutil.NewTestLoggerIntoContext(context.Background())
5737

5838
tests := []struct {
59-
name string
60-
body *extProcPb.HttpBody
61-
want *extProcPb.ProcessingResponse
62-
wantErr bool
39+
name string
40+
body map[string]any
41+
streaming bool
42+
want []*extProcPb.ProcessingResponse
43+
wantErr bool
6344
}{
6445
{
65-
name: "malformed body",
66-
body: &extProcPb.HttpBody{
67-
Body: []byte("malformed json"),
46+
name: "model not found",
47+
body: map[string]any{
48+
"prompt": "Tell me a joke",
49+
},
50+
want: []*extProcPb.ProcessingResponse{
51+
{
52+
Response: &extProcPb.ProcessingResponse_RequestBody{
53+
RequestBody: &extProcPb.BodyResponse{},
54+
},
55+
},
6856
},
69-
wantErr: true,
7057
},
7158
{
72-
name: "model not found",
73-
body: &extProcPb.HttpBody{
74-
Body: []byte(bodyWithoutModel),
59+
name: "model not found with streaming",
60+
body: map[string]any{
61+
"prompt": "Tell me a joke",
7562
},
76-
want: &extProcPb.ProcessingResponse{
77-
Response: &extProcPb.ProcessingResponse_RequestBody{
78-
RequestBody: &extProcPb.BodyResponse{},
63+
want: []*extProcPb.ProcessingResponse{
64+
{
65+
Response: &extProcPb.ProcessingResponse_RequestBody{
66+
RequestBody: &extProcPb.BodyResponse{},
67+
},
68+
},
69+
{
70+
Response: &extProcPb.ProcessingResponse_RequestBody{
71+
RequestBody: &extProcPb.BodyResponse{
72+
Response: &extProcPb.CommonResponse{
73+
BodyMutation: &extProcPb.BodyMutation{
74+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
75+
StreamedResponse: &extProcPb.StreamedBodyResponse{
76+
Body: []byte{},
77+
EndOfStream: true,
78+
},
79+
},
80+
},
81+
},
82+
},
83+
},
7984
},
8085
},
8186
},
8287
{
8388
name: "model is not string",
84-
body: &extProcPb.HttpBody{
85-
Body: []byte(bodyWithModelNoStr),
89+
body: map[string]any{
90+
"model": 1,
91+
"prompt": "Tell me a joke",
8692
},
8793
wantErr: true,
8894
},
8995
{
9096
name: "success",
91-
body: &extProcPb.HttpBody{
92-
Body: []byte(bodyWithModel),
97+
body: map[string]any{
98+
"model": "foo",
99+
"prompt": "Tell me a joke",
100+
},
101+
want: []*extProcPb.ProcessingResponse{
102+
{
103+
Response: &extProcPb.ProcessingResponse_RequestBody{
104+
RequestBody: &extProcPb.BodyResponse{
105+
Response: &extProcPb.CommonResponse{
106+
// Necessary so that the new headers are used in the routing decision.
107+
ClearRouteCache: true,
108+
HeaderMutation: &extProcPb.HeaderMutation{
109+
SetHeaders: []*basepb.HeaderValueOption{
110+
{
111+
Header: &basepb.HeaderValue{
112+
Key: "X-Gateway-Model-Name",
113+
RawValue: []byte("foo"),
114+
},
115+
},
116+
},
117+
},
118+
},
119+
},
120+
},
121+
},
93122
},
94-
want: &extProcPb.ProcessingResponse{
95-
Response: &extProcPb.ProcessingResponse_RequestBody{
96-
RequestBody: &extProcPb.BodyResponse{
97-
Response: &extProcPb.CommonResponse{
98-
// Necessary so that the new headers are used in the routing decision.
99-
ClearRouteCache: true,
100-
HeaderMutation: &extProcPb.HeaderMutation{
101-
SetHeaders: []*basepb.HeaderValueOption{
102-
{
103-
Header: &basepb.HeaderValue{
104-
Key: "X-Gateway-Model-Name",
105-
RawValue: []byte("foo"),
123+
},
124+
{
125+
name: "success-with-streaming",
126+
body: map[string]any{
127+
"model": "foo",
128+
"prompt": "Tell me a joke",
129+
},
130+
streaming: true,
131+
want: []*extProcPb.ProcessingResponse{
132+
{
133+
Response: &extProcPb.ProcessingResponse_RequestHeaders{
134+
RequestHeaders: &extProcPb.HeadersResponse{
135+
Response: &extProcPb.CommonResponse{
136+
ClearRouteCache: true,
137+
HeaderMutation: &extProcPb.HeaderMutation{
138+
SetHeaders: []*basepb.HeaderValueOption{
139+
{
140+
Header: &basepb.HeaderValue{
141+
Key: "X-Gateway-Model-Name",
142+
RawValue: []byte("foo"),
143+
},
144+
},
145+
},
146+
},
147+
},
148+
},
149+
},
150+
},
151+
{
152+
Response: &extProcPb.ProcessingResponse_RequestBody{
153+
RequestBody: &extProcPb.BodyResponse{
154+
Response: &extProcPb.CommonResponse{
155+
BodyMutation: &extProcPb.BodyMutation{
156+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
157+
StreamedResponse: &extProcPb.StreamedBodyResponse{
158+
Body: []byte{},
159+
EndOfStream: true,
106160
},
107161
},
108162
},
@@ -116,7 +170,7 @@ func TestHandleRequestBody(t *testing.T) {
116170

117171
for _, test := range tests {
118172
t.Run(test.name, func(t *testing.T) {
119-
server := &Server{}
173+
server := &Server{streaming: test.streaming}
120174
resp, err := server.HandleRequestBody(ctx, test.body)
121175
if err != nil {
122176
if !test.wantErr {

0 commit comments

Comments
 (0)