Skip to content

Commit 795f3ac

Browse files
committed
Read prompts from file + images support
1 parent ea3e015 commit 795f3ac

File tree

3 files changed

+83
-46
lines changed

3 files changed

+83
-46
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

llm_bench/README.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ Generation options:
6363
- `--chat`: specify to call chat API instead of raw completions
6464
- `--stream`: stream the result back. Enabling this gives "time to first token" and "time per token" metrics
6565
- (optional) `--logprobs`: corresponds to `logprobs` API parameter. For some providers, it's needed for output token counting in streaming mode.
66-
- `--max-tokens-jitter`: how much to adjust randomly the setting of `-o` at each request. When using "fixed concurrency" mode it's useful to avoid all workers implicitly synchronizing and causing periodic traffic bursts.
6766

6867
### Writing results
6968

@@ -76,6 +75,16 @@ When comparing multiple configurations, it's useful to aggregate results togethe
7675

7776
The typical workflow would be to run benchmark several times appending to the same CSV file. The resulting file can be imported into a spreadsheet or pandas for further analysis.
7877

78+
### Custom prompts
79+
80+
Sometimes it's necessary to replay exact prompts, for example in the case of embedding images.
81+
`--prompt-text` option can be used in this case to specify a file with .jsonl extension (starting with an ampersand, e.g. `@prompt.jsonl`.).
82+
jsonl files will be read line-by-line and will be randomly chosen for each request. Each line has to have a valid JSON object with 'prompt' and optional 'images' keys. For example:
83+
```
84+
{"prompt": "<image>What color is the cat?", images: ["_64_DATA]}
85+
{"prompt": "<image>What color is the dog?", images: ["_64_DATA]}
86+
```
87+
7988
## Examples
8089

8190
Maintain fixed 8 requests concurrency against local deployment:

llm_bench/load_test.py

+72-45
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,13 @@ def __init__(self, model, parsed_options):
229229
self.parsed_options = parsed_options
230230

231231
@abc.abstractmethod
232-
def get_url(self):
233-
...
232+
def get_url(self): ...
234233

235234
@abc.abstractmethod
236-
def format_payload(self, prompt, max_tokens):
237-
...
235+
def format_payload(self, prompt, max_tokens, images): ...
238236

239237
@abc.abstractmethod
240-
def parse_output_json(self, json, prompt):
241-
...
238+
def parse_output_json(self, json, prompt): ...
242239

243240

244241
class OpenAIProvider(BaseProvider):
@@ -248,17 +245,32 @@ def get_url(self):
248245
else:
249246
return "/v1/completions"
250247

251-
def format_payload(self, prompt, max_tokens):
248+
def format_payload(self, prompt, max_tokens, images):
252249
data = {
253250
"model": self.model,
254251
"max_tokens": max_tokens,
255252
"stream": self.parsed_options.stream,
256253
"temperature": self.parsed_options.temperature,
257254
}
258255
if self.parsed_options.chat:
259-
data["messages"] = [{"role": "user", "content": prompt}]
256+
if images is None:
257+
data["messages"] = [{"role": "user", "content": prompt}]
258+
else:
259+
image_urls = []
260+
for image in images:
261+
image_urls.append(
262+
{"type": "image_url", "image_url": {"url": image}}
263+
)
264+
data["messages"] = [
265+
{
266+
"role": "user",
267+
"content": [{"type": "text", "text": prompt}, *image_urls],
268+
}
269+
]
260270
else:
261271
data["prompt"] = prompt
272+
if images is not None:
273+
data["images"] = images
262274
if self.parsed_options.logprobs is not None:
263275
data["logprobs"] = self.parsed_options.logprobs
264276
return data
@@ -286,16 +298,16 @@ def parse_output_json(self, data, prompt):
286298

287299

288300
class FireworksProvider(OpenAIProvider):
289-
def format_payload(self, prompt, max_tokens):
290-
data = super().format_payload(prompt, max_tokens)
301+
def format_payload(self, prompt, max_tokens, images):
302+
data = super().format_payload(prompt, max_tokens, images)
291303
data["min_tokens"] = max_tokens
292304
data["prompt_cache_max_len"] = 0
293305
return data
294306

295307

296308
class VllmProvider(OpenAIProvider):
297-
def format_payload(self, prompt, max_tokens):
298-
data = super().format_payload(prompt, max_tokens)
309+
def format_payload(self, prompt, max_tokens, images):
310+
data = super().format_payload(prompt, max_tokens, images)
299311
data["ignore_eos"] = True
300312
return data
301313

@@ -305,8 +317,8 @@ def get_url(self):
305317
assert not self.parsed_options.chat, "Chat is not supported"
306318
return "/"
307319

308-
def format_payload(self, prompt, max_tokens):
309-
data = super().format_payload(prompt, max_tokens)
320+
def format_payload(self, prompt, max_tokens, images):
321+
data = super().format_payload(prompt, max_tokens, images)
310322
data["ignore_eos"] = True
311323
data["stream_tokens"] = data.pop("stream")
312324
return data
@@ -325,7 +337,8 @@ def get_url(self):
325337
assert not self.parsed_options.stream, "Stream is not supported"
326338
return f"/v2/models/{self.model}/infer"
327339

328-
def format_payload(self, prompt, max_tokens):
340+
def format_payload(self, prompt, max_tokens, images):
341+
assert images is None, "images are not supported"
329342
# matching latest TRT-LLM example, your model configuration might be different
330343
data = {
331344
"inputs": [
@@ -394,7 +407,8 @@ def get_url(self):
394407
stream_suffix = "_stream" if self.parsed_options.stream else ""
395408
return f"/v2/models/{self.model}/generate{stream_suffix}"
396409

397-
def format_payload(self, prompt, max_tokens):
410+
def format_payload(self, prompt, max_tokens, images):
411+
assert images is None, "images are not supported"
398412
data = {
399413
"text_input": prompt,
400414
"max_tokens": max_tokens,
@@ -433,7 +447,8 @@ def get_url(self):
433447
stream_suffix = "_stream" if self.parsed_options.stream else ""
434448
return f"/generate{stream_suffix}"
435449

436-
def format_payload(self, prompt, max_tokens):
450+
def format_payload(self, prompt, max_tokens, images):
451+
assert images is None, "images are not supported"
437452
data = {
438453
"inputs": prompt,
439454
"parameters": {
@@ -458,12 +473,12 @@ def parse_output_json(self, data, prompt):
458473
# non-streaming response
459474
return ChunkMetadata(
460475
text=data["generated_text"],
461-
logprob_tokens=len(data["details"]["tokens"])
462-
if "details" in data
463-
else None,
464-
usage_tokens=data["details"]["generated_tokens"]
465-
if "details" in data
466-
else None,
476+
logprob_tokens=(
477+
len(data["details"]["tokens"]) if "details" in data else None
478+
),
479+
usage_tokens=(
480+
data["details"]["generated_tokens"] if "details" in data else None
481+
),
467482
prompt_usage_tokens=None,
468483
)
469484

@@ -486,8 +501,12 @@ def _load_curl_like_data(text):
486501
"""
487502
if text.startswith("@"):
488503
try:
489-
with open(text[1:], "r") as f:
490-
return f.read()
504+
if text.endswith(".jsonl"):
505+
with open(text[1:], "r") as f:
506+
return [json.loads(line) for line in f]
507+
else:
508+
with open(text[1:], "r") as f:
509+
return f.read()
491510
except Exception as e:
492511
raise ValueError(f"Failed to read file {text[1:]}") from e
493512
else:
@@ -575,11 +594,11 @@ def _on_start(self):
575594
self.stream = self.environment.parsed_options.stream
576595
prompt_chars = self.environment.parsed_options.prompt_chars
577596
if self.environment.parsed_options.prompt_text:
578-
self.prompt = _load_curl_like_data(
597+
self.input = _load_curl_like_data(
579598
self.environment.parsed_options.prompt_text
580599
)
581600
elif prompt_chars:
582-
self.prompt = (
601+
self.input = (
583602
prompt_prefix * (prompt_chars // len(prompt_prefix) + 1) + prompt
584603
)[:prompt_chars]
585604
else:
@@ -591,7 +610,7 @@ def _on_start(self):
591610
assert (
592611
self.environment.parsed_options.prompt_tokens >= min_prompt_len
593612
), f"Minimal prompt length is {min_prompt_len}"
594-
self.prompt = (
613+
self.input = (
595614
prompt_prefix
596615
* (self.environment.parsed_options.prompt_tokens - min_prompt_len)
597616
+ prompt
@@ -621,7 +640,7 @@ def _on_start(self):
621640
)
622641
if self.tokenizer:
623642
self.prompt_tokenizer_tokens = len(
624-
self.tokenizer.encode(self._get_prompt())
643+
self.tokenizer.encode(self._get_input()[0])
625644
)
626645
else:
627646
self.prompt_tokenizer_tokens = None
@@ -646,24 +665,32 @@ def _on_start(self):
646665

647666
self.first_done = False
648667

649-
def _get_prompt(self):
650-
if not self.environment.parsed_options.prompt_randomize:
651-
return self.prompt
652-
# single letters are single tokens
653-
return (
654-
" ".join(
655-
chr(ord("a") + random.randint(0, 25))
656-
for _ in range(prompt_random_tokens)
668+
def _get_input(self):
669+
def _maybe_randomize(prompt):
670+
if not self.environment.parsed_options.prompt_randomize:
671+
return prompt
672+
# single letters are single tokens
673+
return (
674+
" ".join(
675+
chr(ord("a") + random.randint(0, 25))
676+
for _ in range(prompt_random_tokens)
677+
)
678+
+ " "
679+
+ prompt
657680
)
658-
+ " "
659-
+ self.prompt
660-
)
681+
682+
if isinstance(self.input, str):
683+
return _maybe_randomize(self.input), None
684+
else:
685+
item = self.input[random.randint(0, len(self.input) - 1)]
686+
assert "prompt" in item
687+
return _maybe_randomize(item["prompt"]), item.get("images", None)
661688

662689
@task
663690
def generate_text(self):
664691
max_tokens = self.max_tokens_sampler.sample()
665-
prompt = self._get_prompt()
666-
data = self.provider_formatter.format_payload(prompt, max_tokens)
692+
prompt, images = self._get_input()
693+
data = self.provider_formatter.format_payload(prompt, max_tokens, images)
667694
t_start = time.perf_counter()
668695

669696
with self.client.post(
@@ -944,9 +971,9 @@ def _(environment, **kw):
944971

945972
entries = copy.copy(InitTracker.logging_params)
946973
if environment.parsed_options.qps is not None:
947-
entries[
948-
"concurrency"
949-
] = f"QPS {environment.parsed_options.qps} {environment.parsed_options.qps_distribution}"
974+
entries["concurrency"] = (
975+
f"QPS {environment.parsed_options.qps} {environment.parsed_options.qps_distribution}"
976+
)
950977
else:
951978
entries["concurrency"] = InitTracker.users
952979
for metric_name in [

0 commit comments

Comments
 (0)