Skip to content

Commit 22acc49

Browse files
authored
Workflow Templates (#1594)
1 parent ce4afea commit 22acc49

File tree

9 files changed

+201
-8
lines changed

9 files changed

+201
-8
lines changed

serving/docs/adapters.md

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ DEL models/{modelName}/adapters/{adapterName} - Delete adapter
5454

5555
The final option for working with adapters is through the [DJL Serving workflows system](workflows.md).
5656
You can use the adapter `WorkflowFunction` to create and call an adapted version of a model within the workflow.
57+
For the simple model + adapter case, you can also directly use the adapter [workflow template](workflow_templates.md).
5758
With our workflows, multiple workflows sharing models will be de-duplicated.
5859
So, the effect of having multiple adapters can easily be achieved by having one workflow for each adapter.
5960
This system can be used on [Amazon SageMaker Multi-Model Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html).

serving/docs/management_api.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ curl -v -X POST "http://localhost:8080/models?url=https%3A%2F%2Fresources.djl.ai
7272

7373
`POST /workflows`
7474

75-
* url - Workflow url.
75+
* url - Workflow url
76+
* template - A workflow template to use
7677
* engine - the name of the engine used to load the model. DJL will try to infer the engine if not specified.
7778
* device - the device to load the model on. DJL will pick the optimal device if not specified. The device value can be:
7879
* CPU device: cpu or simply -1
@@ -82,6 +83,8 @@ curl -v -X POST "http://localhost:8080/models?url=https%3A%2F%2Fresources.djl.ai
8283
* max_worker - the maximum number of worker processes. The default is the same as the setting for `min_worker`.
8384
* synchronous - whether the workers are created synchronously. The default value is true.
8485

86+
Either a url or [template](workflow_templates.md) is required.
87+
8588
```bash
8689
curl -X POST "http://localhost:8080/workflows?url=https%3A%2F%2Fresources.djl.ai%2Ftest-workflows%2Fmlp.zip"
8790

serving/docs/workflow_templates.md

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Workflow Templates
2+
3+
Workflow templates are a tool to make it easier to register similar workflows.
4+
5+
## Registering a workflow with a template
6+
7+
To register a new workflow using a template, the template must first be registered (see below).
8+
Then, you can use the management API to register the workflow.
9+
To do so, call the workflow registration endpoint (`POST /workflows`) and specify the template using the `template` query parameter.
10+
Then, specify all template replacement values using additional query parameters.
11+
12+
## Available templates
13+
14+
### adapter
15+
16+
The adapter template is used for a simple workflow that reflects an [adapter model](adapters.md).
17+
18+
Parameters:
19+
20+
- template=`adapter`
21+
- adapter - The adapter name
22+
- url - The adapter URL
23+
- model - The model URL
24+
25+
Example:
26+
27+
`POST /workflows?template=adapter&adapter={adapterName}&url={adapterUrl}&model={modelUrl}`
28+
29+
## Adding new templates
30+
31+
To add a new template, begin by creating the template JSON file.
32+
This mostly matches the standard format of a [workflow](workflows.md).
33+
However, your template can indicate variable sections of the template to be replaced.
34+
This is done by prefixing the name to replace with a `$`.
35+
So, a parameter `param` would replace the value `$param` within the template.
36+
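This replacement is a direct textual substitution, so if your parameter is a string, the template must still surround `$param` with quotation marks. Because every occurrence of `$` followed by a parameter name is replaced, avoid parameter names where one is a prefix of another (for example `url` and `url2`), since `$url` would also match inside `$url2`.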
37+
38+
Then, you must register your new workflow template.
39+
There are two options for doing this.
40+
First, you can add it to the classpath as a resource with path `workflowTemplates/{workflowTemplateName}.json`.
41+
Alternatively, you can register it from a plugin by calling `WorkflowTemplates.register(..)`.
42+
Once this is done, you will be able to begin creating workflows with your template.

serving/src/main/java/ai/djl/serving/http/LoadModelRequest.java

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
class LoadModelRequest {
2323

2424
static final String URL = "url";
25+
static final String TEMPLATE = "template";
2526
static final String DEVICE = "device";
2627
static final String MAX_WORKER = "max_worker";
2728
static final String MIN_WORKER = "min_worker";

serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java

+20-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import ai.djl.serving.workflow.BadWorkflowException;
2828
import ai.djl.serving.workflow.Workflow;
2929
import ai.djl.serving.workflow.WorkflowDefinition;
30+
import ai.djl.serving.workflow.WorkflowTemplates;
3031
import ai.djl.util.JsonUtils;
32+
import ai.djl.util.Pair;
3133

3234
import io.netty.channel.ChannelHandlerContext;
3335
import io.netty.handler.codec.http.FullHttpRequest;
@@ -47,6 +49,7 @@
4749
import java.util.Map;
4850
import java.util.concurrent.CompletableFuture;
4951
import java.util.regex.Pattern;
52+
import java.util.stream.Collectors;
5053

5154
/** A class handling inbound HTTP requests to the management API. */
5255
public class ManagementRequestHandler extends HttpRequestHandler {
@@ -258,8 +261,9 @@ private void handleRegisterModel(
258261
private void handleRegisterWorkflow(
259262
final ChannelHandlerContext ctx, QueryStringDecoder decoder) {
260263
String workflowUrl = NettyUtils.getParameter(decoder, LoadModelRequest.URL, null);
261-
if (workflowUrl == null) {
262-
throw new BadRequestException("Parameter url is required.");
264+
String workflowTemplate = NettyUtils.getParameter(decoder, LoadModelRequest.TEMPLATE, null);
265+
if (workflowUrl == null && workflowTemplate == null) {
266+
throw new BadRequestException("Either parameter url or template is required.");
263267
}
264268

265269
boolean synchronous =
@@ -269,8 +273,20 @@ private void handleRegisterWorkflow(
269273
try {
270274
final ModelManager modelManager = ModelManager.getInstance();
271275

272-
URI uri = URI.create(workflowUrl);
273-
Workflow workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
276+
Workflow workflow;
277+
if (workflowTemplate != null) { // Workflow from template
278+
Map<String, String> templateReplacements = // NOPMD
279+
decoder.parameters().entrySet().stream()
280+
.filter(e -> e.getValue().size() == 1)
281+
.map(e -> new Pair<>(e.getKey(), e.getValue().get(0)))
282+
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
283+
workflow =
284+
WorkflowTemplates.template(workflowTemplate, templateReplacements)
285+
.toWorkflow();
286+
} else { // Workflow from URL
287+
URI uri = URI.create(workflowUrl);
288+
workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
289+
}
274290
String workflowName = workflow.getName();
275291

276292
CompletableFuture<Void> f =

serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java

+30-3
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@
3131
import com.google.gson.JsonParseException;
3232
import com.google.gson.annotations.SerializedName;
3333

34+
import java.io.BufferedReader;
3435
import java.io.IOException;
3536
import java.io.InputStream;
3637
import java.io.InputStreamReader;
3738
import java.io.Reader;
39+
import java.io.StringReader;
3840
import java.lang.reflect.Constructor;
3941
import java.lang.reflect.Method;
4042
import java.lang.reflect.Type;
@@ -49,6 +51,7 @@
4951
import java.util.Map.Entry;
5052
import java.util.Objects;
5153
import java.util.concurrent.ConcurrentHashMap;
54+
import java.util.stream.Collectors;
5255

5356
/**
5457
* This class is for parsing the JSON or YAML definition for a {@link Workflow}.
@@ -93,16 +96,21 @@ public static WorkflowDefinition parse(Path path) throws IOException {
9396
/**
9497
* Parses a new {@link WorkflowDefinition} from an input stream.
9598
*
96-
* @param name the workflow name
99+
* @param name the workflow name (null for no name)
97100
* @param uri the uri of the file
98101
* @return the parsed {@link WorkflowDefinition}
99102
* @throws IOException if read from uri failed
100103
*/
101104
public static WorkflowDefinition parse(String name, URI uri) throws IOException {
105+
return parse(name, uri, null);
106+
}
107+
108+
static WorkflowDefinition parse(String name, URI uri, Map<String, String> templateReplacements)
109+
throws IOException {
102110
String type = FilenameUtils.getFileExtension(Objects.requireNonNull(uri.toString()));
103111
try (InputStream is = uri.toURL().openStream();
104112
Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
105-
WorkflowDefinition wd = parse(type, reader);
113+
WorkflowDefinition wd = parse(type, reader, templateReplacements);
106114
if (name != null) {
107115
wd.name = name;
108116
}
@@ -113,7 +121,26 @@ public static WorkflowDefinition parse(String name, URI uri) throws IOException
113121
}
114122
}
115123

116-
private static WorkflowDefinition parse(String type, Reader input) {
124+
private static WorkflowDefinition parse(
125+
String type, Reader input, Map<String, String> templateReplacements) {
126+
if (templateReplacements != null) {
127+
String updatedInput =
128+
new BufferedReader(input)
129+
.lines()
130+
.map(
131+
l -> {
132+
for (Entry<String, String> replacement :
133+
templateReplacements.entrySet()) {
134+
l =
135+
l.replace(
136+
"$" + replacement.getKey(),
137+
replacement.getValue());
138+
}
139+
return l;
140+
})
141+
.collect(Collectors.joining("\n"));
142+
input = new StringReader(updatedInput);
143+
}
117144
if ("yml".equalsIgnoreCase(type) || "yaml".equalsIgnoreCase(type)) {
118145
try {
119146
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.serving.workflow;
14+
15+
import ai.djl.util.ClassLoaderUtils;
16+
17+
import java.io.IOException;
18+
import java.net.URI;
19+
import java.net.URISyntaxException;
20+
import java.net.URL;
21+
import java.util.Map;
22+
import java.util.concurrent.ConcurrentHashMap;
23+
24+
/** A class for managing and using {@link WorkflowDefinition} templates. */
25+
public final class WorkflowTemplates {
26+
27+
private static final Map<String, URI> TEMPLATES = new ConcurrentHashMap<>();
28+
29+
private WorkflowTemplates() {}
30+
31+
/**
32+
* Registers a new workflow template.
33+
*
34+
* @param name the template name
35+
* @param template the template location
36+
*/
37+
public static void register(String name, URI template) {
38+
TEMPLATES.put(name, template);
39+
}
40+
41+
/**
42+
* Constructs a {@link WorkflowDefinition} using a registered template.
43+
*
44+
* @param templateName the template name
45+
* @param templateReplacements a map of replacements to be applied to the template
46+
* @return the new {@link WorkflowDefinition} based off the template
47+
* @throws IOException if it fails to load the template file for parsing
48+
*/
49+
public static WorkflowDefinition template(
50+
String templateName, Map<String, String> templateReplacements) throws IOException {
51+
URI uri = TEMPLATES.get(templateName);
52+
53+
if (uri == null) {
54+
URL fromResource =
55+
ClassLoaderUtils.getResource("workflowTemplates/" + templateName + ".json");
56+
if (fromResource != null) {
57+
try {
58+
uri = fromResource.toURI();
59+
} catch (URISyntaxException ignored) {
60+
}
61+
}
62+
}
63+
64+
if (uri == null) {
65+
throw new IllegalArgumentException(
66+
"The workflow template " + templateName + " could not be found");
67+
}
68+
69+
return WorkflowDefinition.parse(null, uri, templateReplacements);
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"name": "$adapter",
3+
"version": "0.1",
4+
"models": {
5+
"m": "$model"
6+
},
7+
"configs": {
8+
"adapters": {
9+
"$adapter": {
10+
"model": "m",
11+
"src": "$url"
12+
}
13+
}
14+
},
15+
"workflow": {
16+
"out": ["adapter", "$adapter", "in"]
17+
}
18+
}

serving/src/test/java/ai/djl/serving/ModelServerTest.java

+14
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ public void testAdapterWorkflows()
319319

320320
testAdapterWorkflowPredict(channel, "adapter1", "a1");
321321
testAdapterWorkflowPredict(channel, "adapter2", "a2");
322+
testRegisterAdapterWorkflowTemplate(channel);
322323

323324
channel.close().sync();
324325

@@ -925,6 +926,19 @@ private void testAdapterWorkflowPredict(Channel channel, String workflow, String
925926
assertEquals(result, adapter + "testAWP");
926927
}
927928

929+
private void testRegisterAdapterWorkflowTemplate(Channel channel) throws InterruptedException {
930+
logTestFunction();
931+
String adapterUrl = "dummy";
932+
String modelUrl = URLEncoder.encode("src/test/resources/adaptecho", StandardCharsets.UTF_8);
933+
934+
String url =
935+
"/workflows?template=adapter&adapter=a&url=" + adapterUrl + "&model=" + modelUrl;
936+
request(channel, HttpMethod.POST, url);
937+
938+
StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
939+
assertEquals(resp.getStatus(), "Workflow \"a\" registered.");
940+
}
941+
928942
private void testAdapterInvoke(Channel channel) throws InterruptedException {
929943
logTestFunction();
930944
String url = "/invocations?model_name=adaptecho&adapter=adaptable";

0 commit comments

Comments
 (0)