Skip to content

Commit ef5c3f7

Browse files
committed
feat: add support for prompts longer than 77 tokens
1 parent b7870a0 commit ef5c3f7

File tree

4 files changed

+133
-77
lines changed

4 files changed

+133
-77
lines changed

clip.hpp

+60-22
Original file line numberDiff line numberDiff line change
@@ -558,11 +558,14 @@ class CLIPEmbeddings : public GGMLBlock {
558558
auto token_embed_weight = params["token_embedding.weight"];
559559
auto position_embed_weight = params["position_embedding.weight"];
560560

561-
GGML_ASSERT(input_ids->ne[0] <= position_embed_weight->ne[0]);
561+
GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
562+
input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
563+
auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids);
564+
token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
562565

563566
// token_embedding + position_embedding
564567
auto x = ggml_add(ctx,
565-
ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids),
568+
token_embedding,
566569
position_embed_weight); // [N, n_token, embed_dim]
567570
return x;
568571
}
@@ -700,7 +703,7 @@ class CLIPTextModel : public GGMLBlock {
700703
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
701704

702705
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
703-
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
706+
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
704707
if (return_pooled || with_final_ln) {
705708
x = final_layer_norm->forward(ctx, x);
706709
}
@@ -889,7 +892,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
889892
return false;
890893
}
891894
struct ggml_init_params params;
892-
params.mem_size = 32 * 1024; // max for custom embeddings 32 KB
895+
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
893896
params.mem_buffer = NULL;
894897
params.no_alloc = false;
895898
struct ggml_context* embd_ctx = ggml_init(params);
@@ -924,9 +927,21 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
924927
struct ggml_tensor* embeddings,
925928
size_t max_token_idx = 0,
926929
bool return_pooled = false) {
930+
size_t N = input_ids->ne[1];
931+
size_t n_token = input_ids->ne[0];
932+
if (input_ids != NULL && input_ids->ne[0] > text_model.n_token) {
933+
GGML_ASSERT(input_ids->ne[0] % text_model.n_token == 0);
934+
input_ids = ggml_reshape_2d(ctx, input_ids, text_model.n_token, input_ids->ne[0] / text_model.n_token);
935+
}
936+
if (input_ids2 != NULL && input_ids2->ne[0] > text_model2.n_token) {
937+
GGML_ASSERT(input_ids2->ne[0] % text_model2.n_token == 0);
938+
input_ids2 = ggml_reshape_2d(ctx, input_ids2, text_model2.n_token, input_ids2->ne[0] / text_model2.n_token);
939+
}
940+
927941
if (return_pooled) {
928942
return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled);
929943
}
944+
930945
auto hidden_states = text_model.forward(ctx, input_ids, embeddings); // [N, n_token, hidden_size]
931946
// LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
932947
if (version == VERSION_XL) {
@@ -952,6 +967,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
952967

953968
hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 2, 0, 3));
954969
}
970+
hidden_states = ggml_reshape_3d(ctx, hidden_states, hidden_states->ne[0], n_token, N);
955971
// LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
956972
return hidden_states;
957973
}
@@ -1057,26 +1073,48 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
10571073
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
10581074
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
10591075
}
1060-
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
1061-
weights.insert(weights.begin(), 1.0);
10621076

1063-
if (max_length > 0) {
1064-
if (tokens.size() > max_length - 1) {
1065-
tokens.resize(max_length - 1);
1066-
weights.resize(max_length - 1);
1067-
tokens.push_back(EOS_TOKEN_ID);
1068-
weights.push_back(1.0);
1069-
} else {
1070-
tokens.push_back(EOS_TOKEN_ID);
1071-
weights.push_back(1.0);
1072-
if (padding) {
1073-
int pad_token_id = PAD_TOKEN_ID;
1074-
if (version == VERSION_2_x) {
1075-
pad_token_id = 0;
1076-
}
1077-
tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
1078-
weights.insert(weights.end(), max_length - weights.size(), 1.0);
1077+
if (max_length > 0 && padding) {
1078+
size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
1079+
if (n == 0) {
1080+
n = 1;
1081+
}
1082+
size_t length = max_length * n;
1083+
LOG_DEBUG("token length: %llu", length);
1084+
std::vector<int> new_tokens;
1085+
std::vector<float> new_weights;
1086+
new_tokens.push_back(BOS_TOKEN_ID);
1087+
new_weights.push_back(1.0);
1088+
int token_idx = 0;
1089+
for (int i = 1; i < length; i++) {
1090+
if (token_idx >= tokens.size()) {
1091+
break;
1092+
}
1093+
if (i % max_length == 0) {
1094+
new_tokens.push_back(BOS_TOKEN_ID);
1095+
new_weights.push_back(1.0);
1096+
} else if (i % max_length == max_length - 1) {
1097+
new_tokens.push_back(EOS_TOKEN_ID);
1098+
new_weights.push_back(1.0);
1099+
} else {
1100+
new_tokens.push_back(tokens[token_idx]);
1101+
new_weights.push_back(weights[token_idx]);
1102+
token_idx++;
1103+
}
1104+
}
1105+
1106+
new_tokens.push_back(EOS_TOKEN_ID);
1107+
new_weights.push_back(1.0);
1108+
tokens = new_tokens;
1109+
weights = new_weights;
1110+
1111+
if (padding) {
1112+
int pad_token_id = PAD_TOKEN_ID;
1113+
if (version == VERSION_2_x) {
1114+
pad_token_id = 0;
10791115
}
1116+
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
1117+
weights.insert(weights.end(), length - weights.size(), 1.0);
10801118
}
10811119
}
10821120

ggml_extend.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1231,9 +1231,9 @@ class MultiheadAttention : public GGMLBlock {
12311231
q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head]
12321232

12331233
struct ggml_tensor* k = k_proj->forward(ctx, x);
1234-
k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
1235-
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
1236-
k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head); // [N * n_head, n_token, d_head]
1234+
k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
1235+
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
1236+
k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head]
12371237

12381238
struct ggml_tensor* v = v_proj->forward(ctx, x);
12391239
v = ggml_reshape_4d(ctx, v, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
@@ -1245,7 +1245,7 @@ class MultiheadAttention : public GGMLBlock {
12451245
kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N);
12461246
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
12471247

1248-
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, n_token * N); // [N * n_token, d_head * n_head]
1248+
x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N * n_token, d_head * n_head]
12491249

12501250
x = out_proj->forward(ctx, x);
12511251
return x;

stable-diffusion.cpp

+67-49
Original file line numberDiff line numberDiff line change
@@ -451,65 +451,83 @@ class StableDiffusionGGML {
451451
int height,
452452
bool force_zero_embeddings = false) {
453453
cond_stage_model->set_clip_skip(clip_skip);
454-
auto tokens_and_weights = cond_stage_model->tokenize(text, true);
455-
std::vector<int>& tokens = tokens_and_weights.first;
456-
std::vector<float>& weights = tokens_and_weights.second;
457-
int64_t t0 = ggml_time_ms();
458-
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
459-
struct ggml_tensor* pooled = NULL;
460-
461-
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
462-
struct ggml_tensor* input_ids2 = NULL;
463-
size_t max_token_idx = 0;
464-
if (version == VERSION_XL) {
465-
auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID);
466-
if (it != tokens.end()) {
467-
std::fill(std::next(it), tokens.end(), 0);
468-
}
454+
auto tokens_and_weights = cond_stage_model->tokenize(text, true);
455+
std::vector<int>& tokens = tokens_and_weights.first;
456+
std::vector<float>& weights = tokens_and_weights.second;
457+
int64_t t0 = ggml_time_ms();
458+
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
459+
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size]
460+
struct ggml_tensor* pooled = NULL;
461+
std::vector<float> hidden_states_vec;
462+
463+
size_t chunk_len = 77;
464+
size_t chunk_count = tokens.size() / chunk_len;
465+
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
466+
std::vector<int> chunk_tokens(tokens.begin() + chunk_idx * chunk_len,
467+
tokens.begin() + (chunk_idx + 1) * chunk_len);
468+
std::vector<float> chunk_weights(weights.begin() + chunk_idx * chunk_len,
469+
weights.begin() + (chunk_idx + 1) * chunk_len);
470+
471+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
472+
struct ggml_tensor* input_ids2 = NULL;
473+
size_t max_token_idx = 0;
474+
if (version == VERSION_XL) {
475+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), EOS_TOKEN_ID);
476+
if (it != chunk_tokens.end()) {
477+
std::fill(std::next(it), chunk_tokens.end(), 0);
478+
}
469479

470-
max_token_idx = std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);
480+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
471481

472-
input_ids2 = vector_to_ggml_tensor_i32(work_ctx, tokens);
482+
input_ids2 = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
473483

474-
// for (int i = 0; i < tokens.size(); i++) {
475-
// printf("%d ", tokens[i]);
476-
// }
477-
// printf("\n");
478-
}
484+
// for (int i = 0; i < chunk_tokens.size(); i++) {
485+
// printf("%d ", chunk_tokens[i]);
486+
// }
487+
// printf("\n");
488+
}
479489

480-
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &hidden_states, work_ctx);
481-
if (version == VERSION_XL) {
482-
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx);
483-
}
484-
// if (pooled != NULL) {
485-
// print_ggml_tensor(hidden_states);
486-
// print_ggml_tensor(pooled);
487-
// }
490+
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &chunk_hidden_states, work_ctx);
491+
if (version == VERSION_XL && chunk_idx == 0) {
492+
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx);
493+
}
494+
// if (pooled != NULL) {
495+
// print_ggml_tensor(chunk_hidden_states);
496+
// print_ggml_tensor(pooled);
497+
// }
488498

489-
int64_t t1 = ggml_time_ms();
490-
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
491-
ggml_tensor* result = ggml_dup_tensor(work_ctx, hidden_states);
492-
{
493-
float original_mean = ggml_tensor_mean(hidden_states);
494-
for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
495-
for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
496-
for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
497-
float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
498-
value *= weights[i1];
499-
ggml_tensor_set_f32(result, value, i0, i1, i2);
499+
int64_t t1 = ggml_time_ms();
500+
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
501+
ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states);
502+
{
503+
float original_mean = ggml_tensor_mean(chunk_hidden_states);
504+
for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) {
505+
for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) {
506+
for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) {
507+
float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
508+
value *= chunk_weights[i1];
509+
ggml_tensor_set_f32(result, value, i0, i1, i2);
510+
}
500511
}
501512
}
513+
float new_mean = ggml_tensor_mean(result);
514+
ggml_tensor_scale(result, (original_mean / new_mean));
502515
}
503-
float new_mean = ggml_tensor_mean(result);
504-
ggml_tensor_scale(result, (original_mean / new_mean));
505-
}
506-
if (force_zero_embeddings) {
507-
float* vec = (float*)result->data;
508-
for (int i = 0; i < ggml_nelements(result); i++) {
509-
vec[i] = 0;
516+
if (force_zero_embeddings) {
517+
float* vec = (float*)result->data;
518+
for (int i = 0; i < ggml_nelements(result); i++) {
519+
vec[i] = 0;
520+
}
510521
}
522+
hidden_states_vec.insert(hidden_states_vec.end(), (float*)result->data, ((float*)result->data) + ggml_nelements(result));
511523
}
512524

525+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
526+
hidden_states = ggml_reshape_2d(work_ctx,
527+
hidden_states,
528+
chunk_hidden_states->ne[0],
529+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
530+
513531
ggml_tensor* vec = NULL;
514532
if (version == VERSION_XL) {
515533
int out_dim = 256;
@@ -547,7 +565,7 @@ class StableDiffusionGGML {
547565
GGML_ASSERT(offset == ggml_nbytes(vec));
548566
}
549567
// print_ggml_tensor(result);
550-
return {result, vec};
568+
return {hidden_states, vec};
551569
}
552570

553571
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> get_svd_condition(ggml_context* work_ctx,

util.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,14 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo
266266
level_str = "ERROR";
267267
}
268268

269-
static char log_buffer[LOG_BUFFER_SIZE];
269+
static char log_buffer[LOG_BUFFER_SIZE + 1];
270270

271271
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line);
272272

273273
if (written >= 0 && written < LOG_BUFFER_SIZE) {
274274
vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
275-
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer) - 1);
276275
}
276+
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer));
277277

278278
if (sd_log_cb) {
279279
sd_log_cb(level, log_buffer, sd_log_cb_data);

0 commit comments

Comments (0)