Commit 0985386

update prompt template (#3372)
1 parent 3182443 commit 0985386

10 files changed: +62 -30 lines changed

examples/usecases/llm_diffusion_serving_app/docker/build_image.sh

+1 -1
@@ -20,7 +20,7 @@ echo "ROOT_DIR: $ROOT_DIR"
 
 # Build docker image for the application
 docker_build_cmd="DOCKER_BUILDKIT=1 \
-    docker buildx build \
+    docker buildx build --load \
     --platform=linux/amd64 \
     --file ${EXAMPLE_DIR}/Dockerfile \
     --build-arg BASE_IMAGE=\"${BASE_IMAGE}\" \

examples/usecases/llm_diffusion_serving_app/docker/client_app.py

+46 -16
@@ -27,6 +27,15 @@
 st.session_state.gen_captions = st.session_state.get("gen_captions", [])
 st.session_state.llm_prompts = st.session_state.get("llm_prompts", None)
 st.session_state.llm_time = st.session_state.get("llm_time", 0)
+st.session_state.num_images = st.session_state.get("num_images", 2)
+st.session_state.max_new_tokens = st.session_state.get("max_new_tokens", 100)
+
+
+def update_max_tokens():
+    # Update the max_new_tokens input value in session state and UI
+    # The prompts generated are description which are around 50 tokens per prompt
+    st.session_state.max_new_tokens = 50 * st.session_state.num_images
+
 
 with st.sidebar:
     st.title("Image Generation with Llama, SDXL, torch.compile and OpenVINO")
@@ -76,13 +85,23 @@ def get_model_status(model_name):
 )
 
 # Client App Parameters
-num_images = st.sidebar.number_input(
-    "Number of images to generate", min_value=1, max_value=8, value=2, step=1
+# Default value is set via session_state variables for num_images and max_new_tokens
+st.sidebar.number_input(
+    "Number of images to generate",
+    min_value=1,
+    max_value=8,
+    step=1,
+    key="num_images",
+    on_change=update_max_tokens,
 )
 
 st.subheader("LLM Model parameters")
-max_new_tokens = st.sidebar.number_input(
-    "max_new_tokens", min_value=30, max_value=250, value=40, step=5
+st.sidebar.number_input(
+    "max_new_tokens",
+    min_value=100,
+    max_value=1250,
+    step=10,
+    key="max_new_tokens",
 )
 
 temperature = st.sidebar.number_input(
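Both widgets above are now backed by st.session_state rather than local variables: each number_input is bound to a key, and changing the image count runs update_max_tokens so the token budget scales with it. A minimal, self-contained sketch of that pattern (independent of the rest of client_app.py; the filename and the max(100, ...) clamp, which only keeps the value inside the widget's bounds, are illustrative additions):

import streamlit as st

# Seed defaults once, the same way client_app.py does near its top.
st.session_state.num_images = st.session_state.get("num_images", 2)
st.session_state.max_new_tokens = st.session_state.get("max_new_tokens", 100)


def update_max_tokens():
    # Roughly 50 tokens per generated prompt; clamp to stay at or above the widget's min_value.
    st.session_state.max_new_tokens = max(100, 50 * st.session_state.num_images)


st.sidebar.number_input(
    "Number of images to generate",
    min_value=1,
    max_value=8,
    step=1,
    key="num_images",
    on_change=update_max_tokens,
)
st.sidebar.number_input(
    "max_new_tokens",
    min_value=100,
    max_value=1250,
    step=10,
    key="max_new_tokens",
)
st.sidebar.write(f"Current token budget: {st.session_state.max_new_tokens}")

Run it with: streamlit run sketch.py (any filename works).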
@@ -159,11 +178,19 @@ def sd_response_postprocess(response):
 
 
 def preprocess_llm_input(user_prompt, num_images=2):
-    template = """ Below is an instruction that describes a task. Write a response that appropriately completes the request.
-    Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact.
-    Give the output in square brackets seperated by semicolon.
+    template = """ Generate expanded and descriptive prompts for a image generation model based on the user input.
+    Each prompt should build upon the original concept, adding layers of detail and context to create a more vivid and engaging scene for image generation.
+    Format each prompt distinctly within square brackets.
+    Ensure that each prompt is a standalone description that significantly elaborates on the original input as shown in the example below:
+    Example: For the input 'A futuristic cityscape with flying cars,' generate:
+    [A futuristic cityscape with sleek, silver flying cars zipping through the sky, set against a backdrop of towering skyscrapers and neon-lit streets.]
+    [A futuristic cityscape at dusk, with flying cars of various colors and shapes flying in formation.]
+    [A futuristic cityscape at night, with flying cars illuminated by the city's vibrant nightlife.]
+
+    Aim for a tone that is rich in imagination and visual appeal, capturing the essence of the scene with depth and creativity.
     Do not generate text beyond the specified output format. Do not explain your response.
-    ### Response:
+    Generate {} similar detailed prompts for the user's input: {}.
+    Organize the output such that each prompt is within square brackets. Refer to example above.
     """
 
     prompt_template_with_user_input = template.format(num_images, user_prompt)
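The two {} placeholders in the new template are positional: the first receives the requested prompt count, the second the raw user prompt, via the unchanged template.format(num_images, user_prompt) call. A tiny illustration using a trimmed-down, hypothetical stand-in for the template:

# Stand-in template, not the one committed above; it only shows the placeholder order.
template = (
    "Generate expanded and descriptive prompts for an image generation model.\n"
    "Format each prompt distinctly within square brackets.\n"
    "Generate {} similar detailed prompts for the user's input: {}.\n"
)
num_images = 3
user_prompt = "A lighthouse on a stormy coast"
print(template.format(num_images, user_prompt))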
@@ -206,7 +233,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
         {
             "prompt_template": prompt_template_with_user_input,
             "user_prompt": user_prompt,
-            "max_new_tokens": max_new_tokens,
+            "max_new_tokens": st.session_state.max_new_tokens,
             "temperature": temperature,
             "top_k": top_k,
             "top_p": top_p,
@@ -260,7 +287,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
 )
 
 user_prompt = st.text_input("Enter a prompt for image generation:")
-include_user_prompt = st.checkbox("Include orginal prompt", value=False)
+include_user_prompt = st.checkbox("Include original prompt", value=False)
 
 prompt_container = st.container()
 status_container = st.container()
@@ -287,15 +314,18 @@ def display_prompts():
         llm_start_time = time.time()
 
         st.session_state.llm_prompts = [user_prompt]
-        if num_images > 1:
+        if st.session_state.num_images > 1:
             prompt_template_with_user_input = preprocess_llm_input(
-                user_prompt, num_images
+                user_prompt, st.session_state.num_images
             )
             llm_prompts = generate_llm_model_response(
                 prompt_template_with_user_input, user_prompt
             )
             st.session_state.llm_prompts = postprocess_llm_response(
-                llm_prompts, user_prompt, num_images, include_user_prompt
+                llm_prompts,
+                user_prompt,
+                st.session_state.num_images,
+                include_user_prompt,
             )
 
         st.session_state.llm_time = time.time() - llm_start_time
@@ -306,11 +336,11 @@ def display_prompts():
         prompt_container.write(
             "Enter Image Generation Prompt and Click Generate Prompts !"
         )
-    elif len(st.session_state.llm_prompts) < num_images:
+    elif len(st.session_state.llm_prompts) < st.session_state.num_images:
         prompt_container.warning(
             f"""Insufficient prompts. Regenerate prompts !
-            Num Images Requested: {num_images}, Prompts Generated: {len(st.session_state.llm_prompts)}
-            {f"Consider increasing the max_new_tokens parameter !" if num_images > 4 else ""}""",
+            Num Images Requested: {st.session_state.num_images}, Prompts Generated: {len(st.session_state.llm_prompts)}
+            {f"Consider increasing the max_new_tokens parameter !" if st.session_state.num_images > 4 else ""}""",
             icon="⚠️",
         )
     else:

examples/usecases/llm_diffusion_serving_app/docker/llm/llm_handler.py

+15 -13
@@ -27,6 +27,7 @@ def __init__(self):
         self.user_prompt = []
         self.prompt_template = ""
 
+    @timed
     def initialize(self, ctx):
         self.context = ctx
         self.manifest = ctx.manifest
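initialize() now carries the same @timed decorator that postprocess() gains further down. The decorator itself is not defined in this diff (it comes from the handler's existing imports); the sketch below is only an illustration of the general shape of such a timing wrapper, not the actual implementation:

import functools
import logging
import time

logger = logging.getLogger(__name__)


def timed(func):
    # Log the wall-clock duration of each call to the wrapped handler method.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            logger.info("%s took %.3f s", func.__name__, time.perf_counter() - start)

    return wrapper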
@@ -48,7 +49,7 @@ def initialize(self, ctx):
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(model_dir)
 
-        # Get backend for model-confil.yaml. Defaults to "openvino"
+        # Get backend for model-config.yaml. Defaults to "openvino"
         compile_options = {}
         pt2_config = ctx.model_yaml_config.get("pt2", {})
         compile_options = {
@@ -115,21 +116,22 @@ def inference(self, input_data):
 
         return generated_text
 
+    @timed
     def postprocess(self, generated_text):
-        logger.info(f"LLM Generated Output: {generated_text}")
-        # Initialize with user prompt
+        # Remove input prompt from generated_text
+        generated_text = generated_text.replace(self.prompt_template, "", 1)
+        # Clean up LLM output
+        generated_text = generated_text.replace("\n", " ").replace("  ", " ").strip()
+
+        logger.info(f"LLM Generated Output without input prompt: {generated_text}")
         prompt_list = [self.user_prompt]
         try:
-            logger.info("Parsing LLM Generated Output to extract prompts within []...")
-            response_match = re.search(r"\[(.*?)\]", generated_text)
-            # Extract the result if match is found
-            if response_match:
-                # Split the extracted string by semicolon and strip any leading/trailing spaces
-                response_list = response_match.group(1)
-                extracted_prompts = [item.strip() for item in response_list.split(";")]
-                prompt_list.extend(extracted_prompts)
-            else:
-                logger.warning("No match found in the generated output text !!!")
+            # Use regular expressions to find strings within square brackets
+            pattern = re.compile(r"\[.*?\]")
+            matches = pattern.findall(generated_text)
+            # Clean up the matches and remove square brackets
+            extracted_prompts = [match.strip("[]").strip() for match in matches]
+            prompt_list.extend(extracted_prompts)
         except Exception as e:
             logger.error(f"An error occurred while parsing the generated text: {e}")