From 8a2bc20d63e6381ae1ca417a23e632ee24a8fab1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= <stephduh@live.fr>
Date: Wed, 5 Mar 2025 01:59:11 +0100
Subject: [PATCH] conditioner: support sdxl embeddings

---
 conditioner.hpp | 60 ++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 44 insertions(+), 16 deletions(-)

diff --git a/conditioner.hpp b/conditioner.hpp
index 8d1ec31b..6e9acdb1 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -51,7 +51,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 
     std::string trigger_word = "img";  // should be user settable
     std::string embd_dir;
-    int32_t num_custom_embeddings = 0;
+    int32_t num_custom_embeddings   = 0;
+    int32_t num_custom_embeddings_2 = 0;
     std::vector<uint8_t> token_embed_custom;
     std::vector<std::string> readed_embeddings;
 
@@ -131,28 +132,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         params.no_alloc               = false;
         struct ggml_context* embd_ctx = ggml_init(params);
         struct ggml_tensor* embd      = NULL;
-        int64_t hidden_size           = text_model->model.hidden_size;
+        struct ggml_tensor* embd2     = NULL;
         auto on_load                  = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
-            if (tensor_storage.ne[0] != hidden_size) {
-                LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
-                return false;
+            if (tensor_storage.ne[0] != text_model->model.hidden_size) {
+                if (text_model2) {
+                    if (tensor_storage.ne[0] == text_model2->model.hidden_size) {
+                        embd2       = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model2->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                        *dst_tensor = embd2;
+                    } else {
+                        LOG_DEBUG("embedding wrong hidden size, got %i, expected %i or %i", tensor_storage.ne[0], text_model->model.hidden_size, text_model2->model.hidden_size);
+                        return false;
+                    }
+                } else {
+                    LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model->model.hidden_size);
+                    return false;
+                }
+            } else {
+                embd        = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+                *dst_tensor = embd;
             }
-            embd        = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
-            *dst_tensor = embd;
             return true;
         };
         model_loader.load_tensors(on_load, NULL);
         readed_embeddings.push_back(embd_name);
-        token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
-        memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
-               embd->data,
-               ggml_nbytes(embd));
-        for (int i = 0; i < embd->ne[1]; i++) {
-            bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
-            // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
-            num_custom_embeddings++;
+        if (embd) {
+            int64_t hidden_size = text_model->model.hidden_size;
+            token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
+            memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
+                   embd->data,
+                   ggml_nbytes(embd));
+            for (int i = 0; i < embd->ne[1]; i++) {
+                bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
+                // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+                num_custom_embeddings++;
+            }
+            LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
+        }
+        if (embd2) {
+            int64_t hidden_size = text_model2->model.hidden_size;
+            token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd2));
+            memcpy((void*)(token_embed_custom.data() + num_custom_embeddings_2 * hidden_size * ggml_type_size(embd2->type)),
+                   embd2->data,
+                   ggml_nbytes(embd2));
+            for (int i = 0; i < embd2->ne[1]; i++) {
+                bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings_2);
+                // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+                num_custom_embeddings_2++;
+            }
+            LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2);
         }
-        LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
         return true;
     }