@@ -51,7 +51,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
51
51
52
52
std::string trigger_word = " img" ; // should be user settable
53
53
std::string embd_dir;
54
- int32_t num_custom_embeddings = 0 ;
54
+ int32_t num_custom_embeddings = 0 ;
55
+ int32_t num_custom_embeddings_2 = 0 ;
55
56
std::vector<uint8_t > token_embed_custom;
56
57
std::vector<std::string> readed_embeddings;
57
58
@@ -131,28 +132,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
131
132
params.no_alloc = false ;
132
133
struct ggml_context * embd_ctx = ggml_init (params);
133
134
struct ggml_tensor * embd = NULL ;
134
- int64_t hidden_size = text_model-> model . hidden_size ;
135
+ struct ggml_tensor * embd2 = NULL ;
135
136
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
136
- if (tensor_storage.ne [0 ] != hidden_size) {
137
- LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], hidden_size);
138
- return false ;
137
+ if (tensor_storage.ne [0 ] != text_model->model .hidden_size ) {
138
+ if (text_model2) {
139
+ if (tensor_storage.ne [0 ] == text_model2->model .hidden_size ) {
140
+ embd2 = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , text_model2->model .hidden_size , tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
141
+ *dst_tensor = embd2;
142
+ } else {
143
+ LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i or %i" , tensor_storage.ne [0 ], text_model->model .hidden_size , text_model2->model .hidden_size );
144
+ return false ;
145
+ }
146
+ } else {
147
+ LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], text_model->model .hidden_size );
148
+ return false ;
149
+ }
150
+ } else {
151
+ embd = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , text_model->model .hidden_size , tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
152
+ *dst_tensor = embd;
139
153
}
140
- embd = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
141
- *dst_tensor = embd;
142
154
return true ;
143
155
};
144
156
model_loader.load_tensors (on_load, NULL );
145
157
readed_embeddings.push_back (embd_name);
146
- token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd));
147
- memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings * hidden_size * ggml_type_size (embd->type )),
148
- embd->data ,
149
- ggml_nbytes (embd));
150
- for (int i = 0 ; i < embd->ne [1 ]; i++) {
151
- bpe_tokens.push_back (text_model->model .vocab_size + num_custom_embeddings);
152
- // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
153
- num_custom_embeddings++;
158
+ if (embd) {
159
+ int64_t hidden_size = text_model->model .hidden_size ;
160
+ token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd));
161
+ memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings * hidden_size * ggml_type_size (embd->type )),
162
+ embd->data ,
163
+ ggml_nbytes (embd));
164
+ for (int i = 0 ; i < embd->ne [1 ]; i++) {
165
+ bpe_tokens.push_back (text_model->model .vocab_size + num_custom_embeddings);
166
+ // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
167
+ num_custom_embeddings++;
168
+ }
169
+ LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i" , embd_name.c_str (), num_custom_embeddings);
170
+ }
171
+ if (embd2) {
172
+ int64_t hidden_size = text_model2->model .hidden_size ;
173
+ token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd2));
174
+ memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings_2 * hidden_size * ggml_type_size (embd2->type )),
175
+ embd2->data ,
176
+ ggml_nbytes (embd2));
177
+ for (int i = 0 ; i < embd2->ne [1 ]; i++) {
178
+ bpe_tokens.push_back (text_model2->model .vocab_size + num_custom_embeddings_2);
179
+ // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
180
+ num_custom_embeddings_2++;
181
+ }
182
+ LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i (text model 2)" , embd_name.c_str (), num_custom_embeddings_2);
154
183
}
155
- LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i" , embd_name.c_str (), num_custom_embeddings);
156
184
return true ;
157
185
}
158
186
0 commit comments