Skip to content

Commit 3ef4896

Browse files
committed
feat: allow for user-defined client context
1 parent fe11d78 commit 3ef4896

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

cmd/aws-lambda-rie/handlers.go

+9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package main
55

66
import (
77
"bytes"
8+
"encoding/base64"
89
"fmt"
910
"io/ioutil"
1011
"math"
@@ -81,6 +82,13 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i
8182
return
8283
}
8384

85+
rawClientContext, err := base64.StdEncoding.DecodeString(r.Header.Get("X-Amz-Client-Context"))
86+
if err != nil {
87+
log.Errorf("Failed to decode X-Amz-Client-Context: %s", err)
88+
w.WriteHeader(500)
89+
return
90+
}
91+
8492
initDuration := ""
8593
inv := GetenvWithDefault("AWS_LAMBDA_FUNCTION_TIMEOUT", "300")
8694
timeoutDuration, _ := time.ParseDuration(inv + "s")
@@ -114,6 +122,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i
114122
TraceID: r.Header.Get("X-Amzn-Trace-Id"),
115123
LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"),
116124
Payload: bytes.NewReader(bodyBytes),
125+
ClientContext: string(rawClientContext),
117126
}
118127
fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion)
119128

test/integration/local_lambda/test_end_to_end.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from subprocess import Popen, PIPE
55
from unittest import TestCase, main
66
from pathlib import Path
7+
import base64
8+
import json
79
import time
810
import os
911
import requests
@@ -62,12 +64,14 @@ def run_command(self, cmd):
6264

6365
def sleep_1s(self):
6466
time.sleep(SLEEP_TIME)
65-
66-
def invoke_function(self):
67+
68+
def invoke_function(self, json={}, headers={}):
6769
return requests.post(
68-
f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", json={}
70+
f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations",
71+
json=json,
72+
headers=headers,
6973
)
70-
74+
7175
@contextmanager
7276
def create_container(self, param, image):
7377
try:
@@ -234,6 +238,24 @@ def test_port_override(self):
234238
self.assertEqual(b'"My lambda ran succesfully"', r.content)
235239

236240

241+
def test_custom_client_context(self):
242+
image, rie, image_name = self.tagged_name("custom_client_context")
243+
244+
params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.custom_client_context_handler"
245+
246+
with self.create_container(params, image):
247+
r = self.invoke_function(headers={
248+
"X-Amz-Client-Context": base64.b64encode(json.dumps({
249+
"custom": {
250+
"foo": "bar",
251+
"baz": 123,
252+
}
253+
}).encode('utf8')).decode('utf8'),
254+
})
255+
content = json.loads(r.content)
256+
self.assertEqual("bar", content["foo"])
257+
self.assertEqual(123, content["baz"])
258+
237259

238260
if __name__ == "__main__":
239261
main()

test/integration/testdata/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ def check_remaining_time_handler(event, context):
4141
# Wait 1s to see if the remaining time changes
4242
time.sleep(1)
4343
return context.get_remaining_time_in_millis()
44+
45+
46+
def custom_client_context_handler(event, context):
47+
return context.client_context.custom

0 commit comments

Comments
 (0)