@@ -18,30 +18,36 @@ package handlers
18
18
19
19
import (
20
20
"context"
21
+ "encoding/json"
21
22
"errors"
22
23
"io"
23
24
24
25
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
26
+ "github.com/go-logr/logr"
25
27
"google.golang.org/grpc/codes"
26
28
"google.golang.org/grpc/status"
27
29
"sigs.k8s.io/controller-runtime/pkg/log"
28
30
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
29
31
)
30
32
31
- func NewServer () * Server {
32
- return & Server {}
33
+ func NewServer (streaming bool ) * Server {
34
+ return & Server {streaming : streaming }
33
35
}
34
36
35
37
// Server implements the Envoy external processing server.
36
38
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto
37
- type Server struct {}
39
+ type Server struct {
40
+ streaming bool
41
+ }
38
42
39
43
func (s * Server ) Process (srv extProcPb.ExternalProcessor_ProcessServer ) error {
40
44
ctx := srv .Context ()
41
45
logger := log .FromContext (ctx )
42
46
loggerVerbose := logger .V (logutil .VERBOSE )
43
47
loggerVerbose .Info ("Processing" )
44
48
49
+ reader , writer := io .Pipe ()
50
+
45
51
for {
46
52
select {
47
53
case <- ctx .Done ():
@@ -61,12 +67,19 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
61
67
}
62
68
63
69
var resp * extProcPb.ProcessingResponse
70
+ var requestBody []byte
64
71
var err error
65
72
switch v := req .Request .(type ) {
66
73
case * extProcPb.ProcessingRequest_RequestHeaders :
67
- resp , err = s .HandleRequestHeaders (req .GetRequestHeaders ())
74
+ if ! s .streaming {
75
+ // If streaming, then headers are handled when processing request body.
76
+ resp , err = s .HandleRequestHeaders (req .GetRequestHeaders ())
77
+ } else {
78
+ loggerVerbose .Info ("Received headers, passing off header processing until body arrives..." )
79
+ }
68
80
case * extProcPb.ProcessingRequest_RequestBody :
69
- resp , err = s .HandleRequestBody (ctx , req .GetRequestBody ())
81
+ loggerVerbose .Info ("Incoming body chunk" , "body" , string (v .RequestBody .Body ), "EoS" , v .RequestBody .EndOfStream )
82
+ resp , requestBody , err = s .processRequestBody (ctx , req .GetRequestBody (), writer , reader , logger )
70
83
case * extProcPb.ProcessingRequest_RequestTrailers :
71
84
resp , err = s .HandleRequestTrailers (req .GetRequestTrailers ())
72
85
case * extProcPb.ProcessingRequest_ResponseHeaders :
@@ -83,10 +96,84 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
83
96
return status .Errorf (status .Code (err ), "failed to handle request: %v" , err )
84
97
}
85
98
86
- loggerVerbose .Info ("Response generated" , "response" , resp )
87
- if err := srv .Send (resp ); err != nil {
88
- logger .V (logutil .DEFAULT ).Error (err , "Send failed" )
89
- return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
99
+ if resp != nil {
100
+ loggerVerbose .Info ("Response generated" , "response" , resp )
101
+ if err := srv .Send (resp ); err != nil {
102
+ logger .V (logutil .DEFAULT ).Error (err , "Send failed" )
103
+ return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
104
+ }
105
+
106
+ if s .streaming {
107
+ bodyResp := & extProcPb.ProcessingResponse {
108
+ Response : & extProcPb.ProcessingResponse_RequestBody {
109
+ RequestBody : & extProcPb.BodyResponse {
110
+ Response : & extProcPb.CommonResponse {
111
+ BodyMutation : & extProcPb.BodyMutation {
112
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
113
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
114
+ Body : requestBody ,
115
+ EndOfStream : true ,
116
+ },
117
+ },
118
+ },
119
+ },
120
+ },
121
+ },
122
+ }
123
+ loggerVerbose .Info ("Response generated" , "response" , bodyResp )
124
+ if err := srv .Send (bodyResp ); err != nil {
125
+ logger .V (logutil .DEFAULT ).Error (err , "Send failed" )
126
+ return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
127
+ }
128
+ }
90
129
}
91
130
}
92
131
}
132
+
133
+ func (s * Server ) processRequestBody (ctx context.Context , body * extProcPb.HttpBody , bufferWriter * io.PipeWriter , bufferReader * io.PipeReader , logger logr.Logger ) (* extProcPb.ProcessingResponse , []byte , error ) {
134
+ loggerVerbose := logger .V (logutil .VERBOSE )
135
+
136
+ var requestBody map [string ]interface {}
137
+ if s .streaming {
138
+ // In the stream case, we can receive multiple request bodies.
139
+ // To buffer the full message, we create a goroutine with a writer.Write()
140
+ // call, which will block until the corresponding reader reads from it.
141
+ // We do not read until we receive the EndofStream signal, and then
142
+ // decode the entire JSON body.
143
+ if ! body .EndOfStream {
144
+ go func () {
145
+ loggerVerbose .Info ("Writing to stream buffer" )
146
+ _ , err := bufferWriter .Write (body .Body )
147
+ if err != nil {
148
+ logger .V (logutil .DEFAULT ).Error (err , "Error populating writer" )
149
+ }
150
+ }()
151
+
152
+ return nil , nil , nil
153
+ }
154
+
155
+ if body .EndOfStream {
156
+ loggerVerbose .Info ("Flushing stream buffer" )
157
+ decoder := json .NewDecoder (bufferReader )
158
+ if err := decoder .Decode (& requestBody ); err != nil {
159
+ logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
160
+ }
161
+ bufferReader .Close ()
162
+ }
163
+ } else {
164
+ if err := json .Unmarshal (body .GetBody (), & requestBody ); err != nil {
165
+ return nil , nil , err
166
+ }
167
+ }
168
+
169
+ requestBodyResp , err := s .HandleRequestBody (ctx , requestBody )
170
+ if err != nil {
171
+ return nil , nil , err
172
+ }
173
+
174
+ requestBodyBytes , err := json .Marshal (requestBody )
175
+ if err != nil {
176
+ return nil , nil , err
177
+ }
178
+ return requestBodyResp , requestBodyBytes , nil
179
+ }
0 commit comments