
Commit 7ce63e7

stduhpf and leejet authored
feat: flexible model architecture for dit models (Flux & SD3) (#490)
* Refactor: wtype per tensor
* Fix default args
* refactor: fix flux
* Refactor photomaker v2 support
* unet: refactor the refactoring
* Refactor: fix controlnet and tae
* refactor: upscaler
* Refactor: fix runtime type override
* upscaler: use fp16 again
* Refactor: Flexible sd3 arch
* Refactor: Flexible Flux arch
* format code

---------

Co-authored-by: leejet <[email protected]>
1 parent 4570715 commit 7ce63e7

21 files changed: +317 -271 lines changed
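The heart of the change is visible in every hunk below: instead of one runner-wide `wtype`, each block's `init_params` now receives the loader's `tensor_types` map plus a name prefix, and resolves each weight's type by its full checkpoint name, falling back to a default when the name is absent. A minimal sketch of that lookup pattern, with a stand-in enum and a hypothetical map (the real map is built by the model loader from the checkpoint):

    #include <cstdio>
    #include <map>
    #include <string>

    enum ggml_type { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0 };  // stand-in for ggml's enum

    // Resolve a tensor's weight type by full name (prefix + local name),
    // falling back to a default when the checkpoint has no such tensor.
    static ggml_type lookup_type(const std::map<std::string, ggml_type>& tensor_types,
                                 const std::string& prefix, const std::string& name,
                                 ggml_type fallback) {
        auto it = tensor_types.find(prefix + name);
        return it != tensor_types.end() ? it->second : fallback;
    }

    int main() {
        std::map<std::string, ggml_type> tensor_types = {
            {"te.token_embedding.weight", GGML_TYPE_Q8_0},  // hypothetical entry
        };
        // Present in the map: keeps its on-disk (possibly quantized) type.
        std::printf("%d\n", (int)lookup_type(tensor_types, "te.", "token_embedding.weight", GGML_TYPE_F32));
        // Absent: falls back to F32, as the hunks below do for position embeddings.
        std::printf("%d\n", (int)lookup_type(tensor_types, "te.", "position_embedding.weight", GGML_TYPE_F32));
        return 0;
    }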

clip.hpp (+27 -19)
@@ -545,9 +545,12 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t vocab_size;
     int64_t num_positions;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type token_wtype    = (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+
+        params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }
 
 public:
@@ -591,11 +594,14 @@ class CLIPVisionEmbeddings : public GGMLBlock {
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type patch_wtype    = GGML_TYPE_F16;  // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
+        enum ggml_type class_wtype    = GGML_TYPE_F32;  // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["patch_embedding.weight"]    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim);
-        params["class_embedding"]           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+        params["patch_embedding.weight"]    = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
+        params["class_embedding"]           = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }
 
 public:
@@ -651,9 +657,10 @@ enum CLIPVersion {
 
 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
+            enum ggml_type wtype      = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
+            params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
         }
     }
 
@@ -798,9 +805,9 @@ class CLIPProjection : public UnaryBlock {
     int64_t out_features;
     bool transpose_weight;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
         if (transpose_weight) {
-            LOG_ERROR("transpose_weight");
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
@@ -861,12 +868,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
     CLIPTextModelRunner(ggml_backend_t backend,
-                        ggml_type wtype,
+                        std::map<std::string, enum ggml_type>& tensor_types,
+                        const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         int clip_skip_value = 1,
                         bool with_final_ln  = true)
-        : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) {
-        model.init(params_ctx, wtype);
+        : GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) {
+        model.init(params_ctx, tensor_types, prefix);
     }
 
     std::string get_desc() {
@@ -908,13 +916,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
     struct ggml_tensor* embeddings = NULL;
 
     if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) {
-        auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
-                                                    wtype,
-                                                    model.hidden_size,
-                                                    num_custom_embeddings);
+        auto token_embed_weight = model.get_token_embed_weight();
+        auto custom_embeddings  = ggml_new_tensor_2d(compute_ctx,
+                                                     token_embed_weight->type,
+                                                     model.hidden_size,
+                                                     num_custom_embeddings);
         set_backend_tensor_data(custom_embeddings, custom_embeddings_data);
 
-        auto token_embed_weight = model.get_token_embed_weight();
         // concatenate custom embeddings
         embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
     }
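Note the last hunk: the scratch tensor for custom embeddings now takes its type from the resolved token embedding weight (token_embed_weight->type) instead of the old runner-wide wtype, so it always matches the tensor it is concatenated with. A hedged sketch of that "allocate like the reference tensor" idiom against the ggml API (make_like_2d is an illustrative helper of ours, not part of stable-diffusion.cpp):

    #include "ggml.h"

    // Allocate a 2D tensor sharing the reference tensor's element type and
    // width, mirroring how custom_embeddings follows token_embed_weight above.
    static struct ggml_tensor* make_like_2d(struct ggml_context* ctx,
                                            const struct ggml_tensor* ref,
                                            int64_t ne1) {
        return ggml_new_tensor_2d(ctx, ref->type, ref->ne[0], ne1);
    }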

common.hpp (+9 -5)
@@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock {
     int64_t dim_in;
     int64_t dim_out;
 
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"]   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        enum ggml_type wtype      = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
+        enum ggml_type bias_wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
+        params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
+        params["proj.bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
     }
 
 public:
@@ -438,8 +440,10 @@ class SpatialTransformer : public GGMLBlock {
 
 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
+        enum ggml_type wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+        params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
 
     float get_alpha() {

conditioner.hpp (+20 -24)
@@ -46,7 +46,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     SDVersion version    = VERSION_SD1;
     PMVersion pm_version = PM_VERSION_1;
     CLIPTokenizer tokenizer;
-    ggml_type wtype;
     std::shared_ptr<CLIPTextModelRunner> text_model;
     std::shared_ptr<CLIPTextModelRunner> text_model2;
 
@@ -57,25 +56,25 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::vector<std::string> readed_embeddings;
 
     FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
-                                      ggml_type wtype,
+                                      std::map<std::string, enum ggml_type>& tensor_types,
                                       const std::string& embd_dir,
                                       SDVersion version = VERSION_SD1,
                                       PMVersion pv      = PM_VERSION_1,
                                       int clip_skip     = -1)
-        : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
+        : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
         if (clip_skip <= 0) {
             clip_skip = 1;
             if (version == VERSION_SD2 || version == VERSION_SDXL) {
                 clip_skip = 2;
             }
         }
         if (version == VERSION_SD1) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
         } else if (version == VERSION_SD2) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
         } else if (version == VERSION_SDXL) {
-            text_model  = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
-            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+            text_model  = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
         }
     }
 
@@ -138,14 +137,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
             return false;
         }
-        embd        = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+        embd        = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
         *dst_tensor = embd;
         return true;
     };
     model_loader.load_tensors(on_load, NULL);
     readed_embeddings.push_back(embd_name);
     token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
-    memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)),
+    memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
            embd->data,
            ggml_nbytes(embd));
     for (int i = 0; i < embd->ne[1]; i++) {
@@ -590,9 +589,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 struct FrozenCLIPVisionEmbedder : public GGMLRunner {
     CLIPVisionModelProjection vision_model;
 
-    FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
-        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) {
-        vision_model.init(params_ctx, wtype);
+    FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
+        vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
     }
 
     std::string get_desc() {
@@ -627,7 +626,6 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 };
 
 struct SD3CLIPEmbedder : public Conditioner {
-    ggml_type wtype;
     CLIPTokenizer clip_l_tokenizer;
     CLIPTokenizer clip_g_tokenizer;
     T5UniGramTokenizer t5_tokenizer;
@@ -636,15 +634,15 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<T5Runner> t5;
 
     SD3CLIPEmbedder(ggml_backend_t backend,
-                    ggml_type wtype,
+                    std::map<std::string, enum ggml_type>& tensor_types,
                     int clip_skip = -1)
-        : wtype(wtype), clip_g_tokenizer(0) {
+        : clip_g_tokenizer(0) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
-        clip_g = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
-        t5     = std::make_shared<T5Runner>(backend, wtype);
+        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+        clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }
 
     void set_clip_skip(int clip_skip) {
@@ -974,21 +972,19 @@ struct SD3CLIPEmbedder : public Conditioner {
 };
 
 struct FluxCLIPEmbedder : public Conditioner {
-    ggml_type wtype;
     CLIPTokenizer clip_l_tokenizer;
    T5UniGramTokenizer t5_tokenizer;
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<T5Runner> t5;
 
     FluxCLIPEmbedder(ggml_backend_t backend,
-                     ggml_type wtype,
-                     int clip_skip = -1)
-        : wtype(wtype) {
+                     std::map<std::string, enum ggml_type>& tensor_types,
+                     int clip_skip = -1) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true);
-        t5     = std::make_shared<T5Runner>(backend, wtype);
+        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
+        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }
 
     void set_clip_skip(int clip_skip) {
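The string literals passed above ("cond_stage_model...", "text_encoders.clip_l...", "text_encoders.t5xxl...") are the checkpoint-name prefixes; nested blocks append their own path segments and parameter names before the map lookup. A tiny illustration of how a full tensor name is assembled (the intermediate "embeddings." segment is an assumption about the block hierarchy, not something shown in this diff):

    #include <iostream>
    #include <string>

    int main() {
        // Runner prefix as passed to CLIPTextModelRunner above.
        std::string prefix = "text_encoders.clip_l.transformer.text_model";
        // Assumed path contributed by nested blocks down to CLIPEmbeddings.
        std::string local  = "embeddings.token_embedding.weight";
        std::cout << prefix << "." << local << "\n";
        // -> text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight
        return 0;
    }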

control.hpp (+3 -3)
@@ -317,10 +317,10 @@ struct ControlNet : public GGMLRunner {
     bool guided_hint_cached = false;
 
     ControlNet(ggml_backend_t backend,
-               ggml_type wtype,
+               std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_SD1)
-        : GGMLRunner(backend, wtype), control_net(version) {
-        control_net.init(params_ctx, wtype);
+        : GGMLRunner(backend), control_net(version) {
+        control_net.init(params_ctx, tensor_types, "");
     }
 
     ~ControlNet() {

diffusion_model.hpp (+7 -9)
@@ -31,10 +31,10 @@ struct UNetModel : public DiffusionModel {
     UNetModelRunner unet;
 
     UNetModel(ggml_backend_t backend,
-              ggml_type wtype,
+              std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_SD1,
               bool flash_attn   = false)
-        : unet(backend, wtype, version, flash_attn) {
+        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -83,9 +83,8 @@ struct MMDiTModel : public DiffusionModel {
     MMDiTRunner mmdit;
 
     MMDiTModel(ggml_backend_t backend,
-               ggml_type wtype,
-               SDVersion version = VERSION_SD3_2B)
-        : mmdit(backend, wtype, version) {
+               std::map<std::string, enum ggml_type>& tensor_types)
+        : mmdit(backend, tensor_types, "model.diffusion_model") {
     }
 
     void alloc_params_buffer() {
@@ -133,10 +132,9 @@ struct FluxModel : public DiffusionModel {
     Flux::FluxRunner flux;
 
     FluxModel(ggml_backend_t backend,
-              ggml_type wtype,
-              SDVersion version = VERSION_FLUX_DEV,
-              bool flash_attn   = false)
-        : flux(backend, wtype, version, flash_attn) {
+              std::map<std::string, enum ggml_type>& tensor_types,
+              bool flash_attn = false)
+        : flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
     }
 
     void alloc_params_buffer() {
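After these hunks all three wrappers share one construction shape: backend plus the loader's tensor_types map, with the "model.diffusion_model" prefix fixed inside (FluxModel and MMDiTModel also drop their version parameters). A compilable sketch of that shape with stand-in types, since the real ggml_backend_t and runner classes live in the library:

    #include <map>
    #include <string>

    enum ggml_type { GGML_TYPE_F32 };     // stand-in for ggml's enum
    using ggml_backend_sketch_t = void*;  // stand-in for ggml_backend_t

    // Mirrors FluxModel above: no wtype, no version, fixed internal prefix.
    struct FluxModelSketch {
        FluxModelSketch(ggml_backend_sketch_t backend,
                        std::map<std::string, ggml_type>& tensor_types,
                        bool flash_attn = false) {
            const std::string prefix = "model.diffusion_model";
            (void)backend; (void)tensor_types; (void)flash_attn; (void)prefix;
        }
    };

    int main() {
        std::map<std::string, ggml_type> tensor_types;  // filled by the model loader
        FluxModelSketch flux(nullptr, tensor_types, /*flash_attn=*/false);
        return 0;
    }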

esrgan.hpp (+3 -4)
@@ -142,10 +142,9 @@ struct ESRGAN : public GGMLRunner {
     int scale     = 4;
     int tile_size = 128;  // avoid cuda OOM for 4gb VRAM
 
-    ESRGAN(ggml_backend_t backend,
-           ggml_type wtype)
-        : GGMLRunner(backend, wtype) {
-        rrdb_net.init(params_ctx, wtype);
+    ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+        : GGMLRunner(backend) {
+        rrdb_net.init(params_ctx, tensor_types, "");
     }
 
     std::string get_desc() {

examples/cli/main.cpp (+1 -2)
@@ -1010,8 +1010,7 @@ int main(int argc, const char* argv[]) {
     int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
     if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
         upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
-                                                        params.n_threads,
-                                                        params.wtype);
+                                                        params.n_threads);
 
         if (upscaler_ctx == NULL) {
             printf("new_upscaler_ctx failed\n");
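With the wtype argument gone, the upscaler derives tensor types from the ESRGAN file itself (the commit notes "upscaler: use fp16 again"). A hedged usage sketch against the public stable-diffusion.h API; the model path is illustrative:

    #include "stable-diffusion.h"

    int run_upscale(sd_image_t input, int n_threads) {
        // Two-argument form introduced by this commit: path and thread count only.
        upscaler_ctx_t* ctx = new_upscaler_ctx("RealESRGAN_x4plus_anime_6B.pth", n_threads);
        if (ctx == NULL) return 1;
        sd_image_t out = upscale(ctx, input, 4);  // factor unused for this model, per the comment above
        free_upscaler_ctx(ctx);
        (void)out;  // caller would save and free out.data
        return 0;
    }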
