@@ -558,11 +558,14 @@ class CLIPEmbeddings : public GGMLBlock {
         auto token_embed_weight    = params["token_embedding.weight"];
         auto position_embed_weight = params["position_embedding.weight"];
 
-        GGML_ASSERT(input_ids->ne[0] <= position_embed_weight->ne[0]);
+        GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
+        input_ids            = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
+        auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids);
+        token_embedding      = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
 
         // token_embedding + position_embedding
         auto x = ggml_add(ctx,
-                          ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids),
+                          token_embedding,
                           position_embed_weight);  // [N, n_token, embed_dim]
         return x;
     }
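The hunk above changes CLIPEmbeddings::forward so that ggml_get_rows is applied to a 3-D view of input_ids, letting several fixed-length token chunks be embedded in one pass before the positional table is added per chunk. The assert now requires input_ids->ne[0] to equal the position table's length exactly, since every chunk reuses the same position rows. A minimal standalone sketch of that lookup-and-add indexing, with toy sizes and no ggml at all (all names and dimensions here are illustrative assumptions, not taken from the repo):

```cpp
// Standalone sketch of the lookup-then-add pattern used in the hunk above.
// Shapes follow the diff: input_ids is [n_token, N] (N chunks of n_token ids),
// token_embed is [embed_dim, vocab], position_embed is [embed_dim, n_token].
#include <cstdio>
#include <vector>

int main() {
    const int embed_dim = 4;   // assumption: toy size (CLIP ViT-L uses 768)
    const int vocab     = 16;
    const int n_token   = 3;   // assumption: the real model uses 77
    const int N         = 2;   // two chunks, i.e. a "long prompt"

    std::vector<float> token_embed(vocab * embed_dim);
    std::vector<float> position_embed(n_token * embed_dim);
    for (size_t i = 0; i < token_embed.size(); i++)    token_embed[i]    = 0.01f * i;
    for (size_t i = 0; i < position_embed.size(); i++) position_embed[i] = 0.001f * i;

    std::vector<int> input_ids = {1, 5, 2,   1, 7, 2};  // flat [n_token, N]

    // x[N][n_token][embed_dim] = get_rows(token_embed, input_ids) + position_embed
    std::vector<float> x(N * n_token * embed_dim);
    for (int b = 0; b < N; b++) {
        for (int t = 0; t < n_token; t++) {
            int id = input_ids[b * n_token + t];
            for (int d = 0; d < embed_dim; d++) {
                x[(b * n_token + t) * embed_dim + d] =
                    token_embed[id * embed_dim + d] + position_embed[t * embed_dim + d];
            }
        }
    }
    printf("x[0][0][0] = %f\n", x[0]);
    return 0;
}
```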
@@ -700,7 +703,7 @@ class CLIPTextModel : public GGMLBlock {
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings);  // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -893,7 +896,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
             return true;
         }
         struct ggml_init_params params;
-        params.mem_size               = 1 * 1024 * 1024;  // 1MB
+        params.mem_size               = 10 * 1024 * 1024;  // max for custom embeddings 10 MB
         params.mem_buffer             = NULL;
         params.no_alloc               = false;
         struct ggml_context* embd_ctx = ggml_init(params);
@@ -928,9 +931,21 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
                                 struct ggml_tensor* embeddings,
                                 size_t max_token_idx = 0,
                                 bool return_pooled   = false) {
+        size_t N       = input_ids->ne[1];
+        size_t n_token = input_ids->ne[0];
+        if (input_ids != NULL && input_ids->ne[0] > text_model.n_token) {
+            GGML_ASSERT(input_ids->ne[0] % text_model.n_token == 0);
+            input_ids = ggml_reshape_2d(ctx, input_ids, text_model.n_token, input_ids->ne[0] / text_model.n_token);
+        }
+        if (input_ids2 != NULL && input_ids2->ne[0] > text_model2.n_token) {
+            GGML_ASSERT(input_ids2->ne[0] % text_model2.n_token == 0);
+            input_ids2 = ggml_reshape_2d(ctx, input_ids2, text_model2.n_token, input_ids2->ne[0] / text_model2.n_token);
+        }
+
         if (return_pooled) {
             return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled);
         }
+
         auto hidden_states = text_model.forward(ctx, input_ids, embeddings);  // [N, n_token, hidden_size]
         // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
         if (version == VERSION_XL) {
@@ -956,6 +971,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
 
             hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 2, 0, 3));
         }
+        hidden_states = ggml_reshape_3d(ctx, hidden_states, hidden_states->ne[0], n_token, N);
         // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
         return hidden_states;
     }
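The two hunks above are the graph-side half of long-prompt support: FrozenCLIPEmbedderWithCustomWords::forward now accepts a flat id tensor whose length is a whole multiple of the per-chunk n_token, views it as [n_token, N], runs the text model over the chunks, and reshapes the hidden states back into one sequence of N * n_token tokens. A standalone sketch of that split/encode/merge bookkeeping, with a dummy per-chunk encoder standing in for text_model.forward (all names and sizes are assumptions):

```cpp
// Sketch of the long-prompt split/merge done in the diff, without ggml.
// A flat list of token ids (length = N * n_token) is processed in n_token-sized
// chunks, and the per-chunk hidden states are concatenated back into one sequence.
#include <cassert>
#include <cstdio>
#include <vector>

static const int n_token    = 3;  // assumption: real CLIP uses 77
static const int hidden_dim = 2;  // assumption: toy size instead of 768/1280

// Dummy stand-in for text_model.forward(): one hidden vector per token.
std::vector<float> encode_chunk(const std::vector<int>& ids) {
    std::vector<float> h(ids.size() * hidden_dim);
    for (size_t t = 0; t < ids.size(); t++)
        for (int d = 0; d < hidden_dim; d++)
            h[t * hidden_dim + d] = ids[t] + 0.1f * d;
    return h;
}

int main() {
    std::vector<int> input_ids = {1, 5, 2, 1, 7, 2};  // flat, length N * n_token
    assert(input_ids.size() % n_token == 0);          // mirrors the GGML_ASSERT above
    size_t N = input_ids.size() / n_token;            // number of chunks

    std::vector<float> hidden_states;                 // [N * n_token, hidden_dim]
    for (size_t b = 0; b < N; b++) {
        std::vector<int> chunk(input_ids.begin() + b * n_token,
                               input_ids.begin() + (b + 1) * n_token);
        std::vector<float> h = encode_chunk(chunk);
        hidden_states.insert(hidden_states.end(), h.begin(), h.end());
    }
    printf("hidden_states: %zu tokens x %d dims\n",
           hidden_states.size() / hidden_dim, hidden_dim);
    return 0;
}
```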
@@ -1061,26 +1077,48 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
             tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
             weights.insert(weights.end(), curr_tokens.size(), curr_weight);
         }
-        tokens.insert(tokens.begin(), BOS_TOKEN_ID);
-        weights.insert(weights.begin(), 1.0);
 
-        if (max_length > 0) {
-            if (tokens.size() > max_length - 1) {
-                tokens.resize(max_length - 1);
-                weights.resize(max_length - 1);
-                tokens.push_back(EOS_TOKEN_ID);
-                weights.push_back(1.0);
-            } else {
-                tokens.push_back(EOS_TOKEN_ID);
-                weights.push_back(1.0);
-                if (padding) {
-                    int pad_token_id = PAD_TOKEN_ID;
-                    if (version == VERSION_2_x) {
-                        pad_token_id = 0;
-                    }
-                    tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
-                    weights.insert(weights.end(), max_length - weights.size(), 1.0);
+        if (max_length > 0 && padding) {
+            size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
+            if (n == 0) {
+                n = 1;
+            }
+            size_t length = max_length * n;
+            LOG_DEBUG("token length: %llu", length);
+            std::vector<int> new_tokens;
+            std::vector<float> new_weights;
+            new_tokens.push_back(BOS_TOKEN_ID);
+            new_weights.push_back(1.0);
+            int token_idx = 0;
+            for (int i = 1; i < length; i++) {
+                if (token_idx >= tokens.size()) {
+                    break;
+                }
+                if (i % max_length == 0) {
+                    new_tokens.push_back(BOS_TOKEN_ID);
+                    new_weights.push_back(1.0);
+                } else if (i % max_length == max_length - 1) {
+                    new_tokens.push_back(EOS_TOKEN_ID);
+                    new_weights.push_back(1.0);
+                } else {
+                    new_tokens.push_back(tokens[token_idx]);
+                    new_weights.push_back(weights[token_idx]);
+                    token_idx++;
+                }
+            }
+
+            new_tokens.push_back(EOS_TOKEN_ID);
+            new_weights.push_back(1.0);
+            tokens  = new_tokens;
+            weights = new_weights;
+
+            if (padding) {
+                int pad_token_id = PAD_TOKEN_ID;
+                if (version == VERSION_2_x) {
+                    pad_token_id = 0;
                 }
+                tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
+                weights.insert(weights.end(), length - weights.size(), 1.0);
             }
         }
 
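On the tokenizer side, the old code prepended BOS and truncated everything past max_length - 1; the new code instead lays the prompt out over n = ceil(count / (max_length - 2)) chunks of max_length tokens, each framed by BOS and EOS, and pads only the final chunk. A standalone sketch of the resulting layout, with assumed special-token ids and a deliberately small max_length so the chunks are easy to read:

```cpp
// Standalone sketch of the chunked layout the new tokenize code produces:
// every max_length-sized chunk is BOS + (max_length - 2) prompt tokens + EOS,
// and the tail chunk is padded. The token-id constants are assumptions.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int BOS = 49406, EOS = 49407, PAD = 49407;  // assumed CLIP special ids
    const size_t max_length = 7;                      // assumption: real value is 77

    std::vector<int> prompt_tokens = {10, 11, 12, 13, 14, 15, 16, 17};  // 8 ids

    size_t n = (size_t)std::ceil(prompt_tokens.size() * 1.0 / (max_length - 2));
    if (n == 0) n = 1;
    size_t length = max_length * n;

    std::vector<int> tokens;
    tokens.push_back(BOS);
    size_t token_idx = 0;
    for (size_t i = 1; i < length && token_idx < prompt_tokens.size(); i++) {
        if (i % max_length == 0) {
            tokens.push_back(BOS);                    // start of the next chunk
        } else if (i % max_length == max_length - 1) {
            tokens.push_back(EOS);                    // end of the current chunk
        } else {
            tokens.push_back(prompt_tokens[token_idx++]);
        }
    }
    tokens.push_back(EOS);
    tokens.insert(tokens.end(), length - tokens.size(), PAD);  // pad the last chunk

    for (size_t i = 0; i < tokens.size(); i++)
        printf("%d%s", tokens[i], (i + 1) % max_length == 0 ? "\n" : " ");
    return 0;
}
```

With max_length = 7 and 8 prompt tokens this prints two chunks, BOS 10 11 12 13 14 EOS and BOS 15 16 17 EOS PAD PAD, which mirrors the layout the patched tokenize produces for its 77-token chunks.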