
Commit aac6bc4

fmt
1 parent b5918a2 commit aac6bc4

File tree

2 files changed: +225 -0 lines changed


examples/large_models/inferentia2/llama2/streamer/inf2_handler.py (+1)

@@ -134,6 +134,7 @@ def inference(self, tokenized_input):
         micro_batch_idx = self.handle.get_micro_batch_idx()
         micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
         for new_text in self.output_streamer:
+            logger.debug("send response stream")
             send_intermediate_predict_response(
                 new_text[: len(micro_batch_req_id_map)],
                 micro_batch_req_id_map,
New file (+224)

@@ -0,0 +1,224 @@
+import logging
+import os
+import pathlib
+from threading import Thread
+
+import torch_neuronx
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers_neuronx.config import NeuronConfig
+from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
+from transformers_neuronx.module import save_pretrained_split
+
+from ts.context import Context
+from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
+from ts.handler_utils.micro_batching import MicroBatching
+from ts.handler_utils.utils import import_class, send_intermediate_predict_response
+from ts.torch_handler.base_handler import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class BaseNeuronXContinuousBatchingHandler(BaseHandler):
+    def __init__(self):
+        super().__init__()
+
+        self.max_new_tokens = 25
+        self.max_length = 100
+        self.tokenizer = None
+        self.model_class = None
+        self.tokenizer_class = None
+        self.output_streamer = None
+        # enable micro batching
+        self.micro_batching_handle = MicroBatching(self)
+
+    def initialize(self, ctx: Context):
+        ctx.cache = {}
+        model_dir = ctx.system_properties.get("model_dir")
+        handler_config = ctx.model_yaml_config.get("handler", {})
+
+        # micro batching initialization
+        micro_batching_parallelism = handler_config.get("micro_batching", {}).get(
+            "parallelism", None
+        )
+        if micro_batching_parallelism:
+            logger.info(
+                f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
+            )
+            self.micro_batching_handle.parallelism = micro_batching_parallelism
+
+        micro_batch_size = handler_config.get("micro_batching", {}).get(
+            "micro_batch_size", 1
+        )
+        logger.info(f"Setting micro batching size: {micro_batch_size}")
+
+        self.micro_batching_handle.micro_batch_size = micro_batch_size
+
+        model_checkpoint_dir = handler_config.get("model_checkpoint_dir", "")
+
+        model_checkpoint_path = pathlib.Path(model_dir).joinpath(model_checkpoint_dir)
+        model_path = pathlib.Path(model_dir).joinpath(
+            handler_config.get("model_path", "")
+        )
+
+        if not model_checkpoint_path.exists():
+            # Load and save the CPU model
+            model_cpu = AutoModelForCausalLM.from_pretrained(
+                str(model_path), low_cpu_mem_usage=True
+            )
+            save_pretrained_split(model_cpu, model_checkpoint_path)
+            # Load and save tokenizer for the model
+            tokenizer = AutoTokenizer.from_pretrained(
+                str(model_path), return_tensors="pt", padding_side="left"
+            )
+            tokenizer.save_pretrained(model_checkpoint_path)
+
+        os.environ["NEURONX_CACHE"] = "on"
+        os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache"
+        os.environ[
+            "NEURON_CC_FLAGS"
+        ] = "-O1 --model-type=transformer --enable-mixed-precision-accumulation"
+
+        self.max_length = int(handler_config.get("max_length", self.max_length))
+        self.max_new_tokens = int(
+            handler_config.get("max_new_tokens", self.max_new_tokens)
+        )
+        self.batch_size = int(handler_config.get("batch_size", self.batch_size))
+
+        # settings for model compilation and loading
+        amp = handler_config.get("amp", "fp32")
+        tp_degree = handler_config.get("tp_degree", 6)
+
+        # allocate "tp_degree" number of neuron cores to the worker process
+        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
+        try:
+            num_neuron_cores_available = (
+                torch_neuronx.xla_impl.data_parallel.device_count()
+            )
+            assert num_neuron_cores_available >= int(tp_degree)
+        except (RuntimeError, AssertionError) as error:
+            logger.error(
+                "Required number of neuron cores for tp_degree "
+                + str(tp_degree)
+                + " are not available: "
+                + str(error)
+            )
+
+            raise error
+        self._set_class(ctx)
+        self.tokenizer = self.tokenizer_class.from_pretrained(
+            model_checkpoint_path, return_tensors="pt", padding_side="left"
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+        neuron_config = NeuronConfig()
+        kwargs = dict(
+            tp_degree=tp_degree,
+            amp=amp,
+            batch_size=self.micro_batching_handle.micro_batch_size,
+            n_positions=[self.max_length],
+            context_length_estimate=handler_config.get(
+                "context_length_estimate", [self.max_length]
+            ),
+            neuron_config=neuron_config,
+        )
+        self.model = self.model_class.from_pretrained(model_checkpoint_path, **kwargs)
+        logger.info("Starting to compile the model")
+        self.model.to_neuron()
+        logger.info("Model has been successfully compiled")
+
+        model_config = AutoConfig.from_pretrained(model_checkpoint_path)
+        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
+        self.output_streamer = TextIteratorStreamerBatch(
+            self.tokenizer,
+            batch_size=self.micro_batching_handle.micro_batch_size,
+            skip_special_tokens=True,
+        )
+
+        logger.info("Model %s loaded successfully", ctx.model_name)
+        self.initialized = True
+
+    def preprocess(self, requests):
+        inputs = []
+        for req in requests:
+            data = req.get("data") or req.get("body")
+            if isinstance(data, (bytes, bytearray)):
+                data = data.decode("utf-8")
+
+            prompt = data.get("prompt")
+            inputs.append(prompt)
+
+        # Ensure the compiled model can handle the input received
+        if len(inputs) > self.micro_batching_handle.micro_batch_size:
+            raise ValueError(
+                f"Model is compiled for batch size {self.micro_batching_handle.micro_batch_size} but received input of size {len(inputs)}"
+            )
+
+        # Pad input to match compiled model batch size
+        inputs.extend([""] * (self.handle.micro_batch_size - len(inputs)))
+
+        return self.tokenizer(inputs, return_tensors="pt", padding=True)
+
+    def inference(self, inputs):
+        generation_kwargs = dict(
+            inputs,
+            streamer=self.output_streamer,
+            max_new_tokens=self.max_new_tokens,
+        )
+        self.model.reset_generation()
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        micro_batch_idx = self.handle.get_micro_batch_idx()
+        micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
+        for new_text in self.output_streamer:
+            send_intermediate_predict_response(
+                new_text[: len(micro_batch_req_id_map)],
+                micro_batch_req_id_map,
+                "Intermediate Prediction success",
+                200,
+                self.context,
+            )
+
+        thread.join()
+
+        return [""] * len(micro_batch_req_id_map)
+
+    def postprocess(self, inference_output):
+        return inference_output
+
+    def get_micro_batch_req_id_map(self, micro_batch_idx: int):
+        start_idx = micro_batch_idx * self.handle.micro_batch_size
+        micro_batch_req_id_map = {
+            index: self.context.request_ids[batch_index]
+            for index, batch_index in enumerate(
+                range(start_idx, start_idx + self.handle.micro_batch_size)
+            )
+            if batch_index in self.context.request_ids
+        }
+
+        return micro_batch_req_id_map
+
+    def _set_class(self, ctx):
+        handler_config = ctx.model_yaml_config.get("handler", {})
+        model_class_name = handler_config.get("model_class_name", None)
+
+        assert (
+            model_class_name
+        ), "model_class_name not found in the section of handler in model config yaml file"
+        model_module_prefix = handler_config.get("model_module_prefix", None)
+        self.model_class = import_class(
+            class_name=model_class_name,
+            module_prefix=model_module_prefix,
+        )
+
+        tokenizer_class_name = handler_config.get("tokenizer_class_name", None)
+        assert (
+            tokenizer_class_name
+        ), "tokenizer_class_name not found in the section of handler in model config yaml file"
+
+        tokenizer_module_prefix = handler_config.get("tokenizer_module_prefix", None)
+
+        self.tokenizer_class = import_class(
+            class_name=tokenizer_class_name, module_prefix=tokenizer_module_prefix
+        )
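For reference, a standalone sketch of the request-id bookkeeping performed by get_micro_batch_req_id_map above: each micro batch maps its slot indices back to the live request ids held in the context. The sizes and request ids below are made-up illustrative values, not part of the commit.

# Illustrative only: mirrors the dict comprehension in get_micro_batch_req_id_map.
# micro_batch_size, micro_batch_idx and request_ids are hypothetical values.
micro_batch_size = 2
micro_batch_idx = 1  # second micro batch
request_ids = {0: "req-a", 1: "req-b", 2: "req-c"}  # stands in for self.context.request_ids

start_idx = micro_batch_idx * micro_batch_size  # 2
req_id_map = {
    index: request_ids[batch_index]
    for index, batch_index in enumerate(range(start_idx, start_idx + micro_batch_size))
    if batch_index in request_ids
}
print(req_id_map)  # {0: 'req-c'} -> slot 3 carries no live request, so it is dropped

Only the first len(req_id_map) streamed texts are sent back, which is why inference slices new_text[: len(micro_batch_req_id_map)] before calling send_intermediate_predict_response.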
