@@ -29,7 +29,7 @@ def __init__(self):
         self.tokenizer_class = None
         self.output_streamer = None
         # enable micro batching
-        self.micro_batching_handle = MicroBatching(self)
+        self.handle = MicroBatching(self)
 
     def initialize(self, ctx: Context):
         ctx.cache = {}
@@ -43,12 +43,12 @@ def initialize(self, ctx: Context):
             logger.info(
                 f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
             )
-            self.micro_batching_handle.parallelism = micro_batching_parallelism
+            self.handle.parallelism = micro_batching_parallelism
 
         micro_batch_size = micro_batch_config.get("micro_batch_size", 1)
         logger.info(f"Setting micro batching size: {micro_batch_size}")
 
-        self.micro_batching_handle.micro_batch_size = micro_batch_size
+        self.handle.micro_batch_size = micro_batch_size
 
         model_checkpoint_dir = handler_config.get("model_checkpoint_dir", "")
 
@@ -111,7 +111,7 @@ def initialize(self, ctx: Context):
         kwargs = dict(
             tp_degree=tp_degree,
             amp=amp,
-            batch_size=self.micro_batching_handle.micro_batch_size,
+            batch_size=self.handle.micro_batch_size,
             n_positions=[self.max_length],
             context_length_estimate=handler_config.get(
                 "context_length_estimate", [self.max_length]
@@ -127,7 +127,7 @@ def initialize(self, ctx: Context):
         self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
         self.output_streamer = TextIteratorStreamerBatch(
             self.tokenizer,
-            batch_size=self.micro_batching_handle.micro_batch_size,
+            batch_size=self.handle.micro_batch_size,
             skip_special_tokens=True,
         )
 
@@ -145,15 +145,13 @@ def preprocess(self, requests):
             inputs.append(prompt)
 
         # Ensure the compiled model can handle the input received
-        if len(inputs) > self.micro_batching_handle.micro_batch_size:
+        if len(inputs) > self.handle.micro_batch_size:
             raise ValueError(
-                f"Model is compiled for batch size {self.micro_batching_handle.micro_batch_size} but received input of size {len(inputs)}"
+                f"Model is compiled for batch size {self.handle.micro_batch_size} but received input of size {len(inputs)}"
             )
 
         # Pad input to match compiled model batch size
-        inputs.extend(
-            [""] * (self.micro_batching_handle.micro_batch_size - len(inputs))
-        )
+        inputs.extend([""] * (self.handle.micro_batch_size - len(inputs)))
 
         return self.tokenizer(inputs, return_tensors="pt", padding=True)
 
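The padding step above keeps every request batch aligned with the batch size the model was compiled for. A minimal, self-contained sketch of that check-and-pad behavior (not part of the PR) is shown below; `pad_to_compiled_batch` is a hypothetical helper name and `micro_batch_size` stands in for `self.handle.micro_batch_size`.

# Hypothetical stand-alone version of the check-and-pad step in preprocess();
# `micro_batch_size` stands in for self.handle.micro_batch_size.
def pad_to_compiled_batch(inputs: list[str], micro_batch_size: int) -> list[str]:
    if len(inputs) > micro_batch_size:
        raise ValueError(
            f"Model is compiled for batch size {micro_batch_size} "
            f"but received input of size {len(inputs)}"
        )
    # Pad with empty prompts so the tensor shapes match the compiled graph
    return inputs + [""] * (micro_batch_size - len(inputs))


print(pad_to_compiled_batch(["hello"], 4))  # ['hello', '', '', '']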
@@ -167,7 +165,7 @@ def inference(self, inputs):
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
 
-        micro_batch_idx = self.micro_batching_handle.get_micro_batch_idx()
+        micro_batch_idx = self.handle.get_micro_batch_idx()
         micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
         for new_text in self.output_streamer:
             send_intermediate_predict_response(
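For context, this inference path runs generation on a background thread and consumes decoded text from the streamer on the main thread, forwarding each chunk as an intermediate response. Below is a rough, self-contained sketch of that producer/consumer pattern, assuming made-up stand-ins: `ToyStreamer` for `TextIteratorStreamerBatch` and `fake_generate` for `model.generate` (the real streamer yields per-request batches rather than single strings).

# A self-contained sketch (not from the PR) of the streaming pattern in inference().
from queue import Queue
from threading import Thread


class ToyStreamer:
    """Blocking iterator fed by the generation thread."""

    _END = object()

    def __init__(self):
        self._queue: Queue = Queue()

    def put(self, text: str):
        self._queue.put(text)

    def end(self):
        self._queue.put(self._END)

    def __iter__(self):
        # Yield items until the generation thread signals completion
        while (item := self._queue.get()) is not self._END:
            yield item


def fake_generate(streamer: ToyStreamer):
    # Stand-in for self.model.generate(..., streamer=self.output_streamer)
    for token in ["Hello", " world", "!"]:
        streamer.put(token)
    streamer.end()


streamer = ToyStreamer()
thread = Thread(target=fake_generate, kwargs={"streamer": streamer})
thread.start()
for new_text in streamer:
    # In the handler, each chunk is sent via send_intermediate_predict_response(...)
    print(new_text, end="", flush=True)
thread.join()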
@@ -186,13 +184,11 @@ def postprocess(self, inference_output):
         return inference_output
 
     def get_micro_batch_req_id_map(self, micro_batch_idx: int):
-        start_idx = micro_batch_idx * self.micro_batching_handle.micro_batch_size
+        start_idx = micro_batch_idx * self.handle.micro_batch_size
         micro_batch_req_id_map = {
             index: self.context.request_ids[batch_index]
             for index, batch_index in enumerate(
-                range(
-                    start_idx, start_idx + self.micro_batching_handle.micro_batch_size
-                )
+                range(start_idx, start_idx + self.handle.micro_batch_size)
             )
             if batch_index in self.context.request_ids
         }
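`get_micro_batch_req_id_map` translates positions inside a micro batch back to the TorchServe request ids of the full batch, skipping padded slots that have no request. A small worked example of the same dict comprehension follows; the `request_ids` values are fabricated for illustration and stand in for `self.context.request_ids`.

# Worked example of the mapping built by get_micro_batch_req_id_map();
# `request_ids` is a made-up stand-in for self.context.request_ids.
micro_batch_size = 2
request_ids = {0: "req-a", 1: "req-b", 2: "req-c"}  # 3 live requests, 4th slot is padding

def micro_batch_req_id_map(micro_batch_idx: int) -> dict[int, str]:
    start_idx = micro_batch_idx * micro_batch_size
    return {
        index: request_ids[batch_index]
        for index, batch_index in enumerate(
            range(start_idx, start_idx + micro_batch_size)
        )
        if batch_index in request_ids
    }

print(micro_batch_req_id_map(0))  # {0: 'req-a', 1: 'req-b'}
print(micro_batch_req_id_map(1))  # {0: 'req-c'} -- the padded slot is dropped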