Commit f5f3b81: updated model-config

1 parent 91db569, commit f5f3b81

File tree: 1 file changed (+11 -15 lines)

ts/torch_handler/distributed/base_neuronx_microbatching_handler.py (+11 -15)
@@ -29,7 +29,7 @@ def __init__(self):
         self.tokenizer_class = None
         self.output_streamer = None
         # enable micro batching
-        self.micro_batching_handle = MicroBatching(self)
+        self.handle = MicroBatching(self)

     def initialize(self, ctx: Context):
         ctx.cache = {}
@@ -43,12 +43,12 @@ def initialize(self, ctx: Context):
             logger.info(
                 f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
             )
-            self.micro_batching_handle.parallelism = micro_batching_parallelism
+            self.handle.parallelism = micro_batching_parallelism

         micro_batch_size = micro_batch_config.get("micro_batch_size", 1)
         logger.info(f"Setting micro batching size: {micro_batch_size}")

-        self.micro_batching_handle.micro_batch_size = micro_batch_size
+        self.handle.micro_batch_size = micro_batch_size

         model_checkpoint_dir = handler_config.get("model_checkpoint_dir", "")

@@ -111,7 +111,7 @@ def initialize(self, ctx: Context):
         kwargs = dict(
             tp_degree=tp_degree,
             amp=amp,
-            batch_size=self.micro_batching_handle.micro_batch_size,
+            batch_size=self.handle.micro_batch_size,
             n_positions=[self.max_length],
             context_length_estimate=handler_config.get(
                 "context_length_estimate", [self.max_length]
@@ -127,7 +127,7 @@ def initialize(self, ctx: Context):
         self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
         self.output_streamer = TextIteratorStreamerBatch(
             self.tokenizer,
-            batch_size=self.micro_batching_handle.micro_batch_size,
+            batch_size=self.handle.micro_batch_size,
             skip_special_tokens=True,
         )

@@ -145,15 +145,13 @@ def preprocess(self, requests):
             inputs.append(prompt)

         # Ensure the compiled model can handle the input received
-        if len(inputs) > self.micro_batching_handle.micro_batch_size:
+        if len(inputs) > self.handle.micro_batch_size:
             raise ValueError(
-                f"Model is compiled for batch size {self.micro_batching_handle.micro_batch_size} but received input of size {len(inputs)}"
+                f"Model is compiled for batch size {self.handle.micro_batch_size} but received input of size {len(inputs)}"
             )

         # Pad input to match compiled model batch size
-        inputs.extend(
-            [""] * (self.micro_batching_handle.micro_batch_size - len(inputs))
-        )
+        inputs.extend([""] * (self.handle.micro_batch_size - len(inputs)))

         return self.tokenizer(inputs, return_tensors="pt", padding=True)

@@ -167,7 +165,7 @@ def inference(self, inputs):
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()

-        micro_batch_idx = self.micro_batching_handle.get_micro_batch_idx()
+        micro_batch_idx = self.handle.get_micro_batch_idx()
         micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
         for new_text in self.output_streamer:
             send_intermediate_predict_response(
@@ -186,13 +184,11 @@ def postprocess(self, inference_output):
         return inference_output

     def get_micro_batch_req_id_map(self, micro_batch_idx: int):
-        start_idx = micro_batch_idx * self.micro_batching_handle.micro_batch_size
+        start_idx = micro_batch_idx * self.handle.micro_batch_size
         micro_batch_req_id_map = {
             index: self.context.request_ids[batch_index]
             for index, batch_index in enumerate(
-                range(
-                    start_idx, start_idx + self.micro_batching_handle.micro_batch_size
-                )
+                range(start_idx, start_idx + self.handle.micro_batch_size)
             )
             if batch_index in self.context.request_ids
         }
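
The preprocess hunk, beyond the rename, shows how the handler protects a model compiled for a fixed batch size: reject anything larger than the micro batch, then pad the input list with empty prompts up to that size before tokenizing. Below is a minimal standalone sketch of that guard-and-pad step in plain Python; the function name and the example micro_batch_size value are mine, and the real handler pulls micro_batch_size from the model config and then tokenizes the padded list.

# Sketch of the guard-and-pad logic from preprocess(); pad_to_micro_batch and
# the example value 4 are illustrative, not part of the handler.
def pad_to_micro_batch(inputs: list[str], micro_batch_size: int) -> list[str]:
    # Ensure the compiled model can handle the input received
    if len(inputs) > micro_batch_size:
        raise ValueError(
            f"Model is compiled for batch size {micro_batch_size} "
            f"but received input of size {len(inputs)}"
        )
    # Pad input to match compiled model batch size
    return inputs + [""] * (micro_batch_size - len(inputs))


print(pad_to_micro_batch(["hello", "world"], 4))  # ['hello', 'world', '', '']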

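The inference hunk keeps the streaming pattern unchanged: model.generate runs on a background thread while the handler iterates over a streamer and forwards each chunk with send_intermediate_predict_response. The sketch below shows the same thread-plus-iterator-streamer pattern using the stock Hugging Face TextIteratorStreamer and gpt2; those two are my substitutions, since the handler actually uses TorchServe's batched TextIteratorStreamerBatch and a Neuron-compiled model, so read it as an illustration of the pattern rather than of this handler.

# Thread + iterator-streamer pattern, as used by inference(). gpt2 and the stock
# TextIteratorStreamer stand in for the handler's Neuron model and batched streamer.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Micro batching lets a handler", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=20)

# generate() produces tokens in the background; the main loop consumes them as
# they arrive, which is where the handler would send intermediate responses.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()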
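
The last hunk, get_micro_batch_req_id_map, slices the full TorchServe batch into windows of micro_batch_size: micro batch i covers global positions i * micro_batch_size up to (i + 1) * micro_batch_size - 1, and every position that still has a request id is remapped to a micro-batch-local index. A small self-contained sketch of that comprehension follows; the request_ids dict here is a made-up stand-in for self.context.request_ids.

# Standalone sketch of get_micro_batch_req_id_map(); request_ids is a stand-in
# for self.context.request_ids.
def get_micro_batch_req_id_map(micro_batch_idx: int, micro_batch_size: int, request_ids: dict) -> dict:
    start_idx = micro_batch_idx * micro_batch_size
    return {
        index: request_ids[batch_index]
        for index, batch_index in enumerate(
            range(start_idx, start_idx + micro_batch_size)
        )
        if batch_index in request_ids
    }


# With micro_batch_size=2, micro batch 1 covers global positions 2 and 3;
# only position 2 has a live request id, so it maps to local index 0.
request_ids = {0: "req-a", 1: "req-b", 2: "req-c"}
print(get_micro_batch_req_id_map(1, 2, request_ids))  # {0: 'req-c'}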