Skip to content

Commit 8f4ab9a

Browse files
Authored Dec 28, 2024
feat: support Inpaint models (#511)
1 parent cc92a6a commit 8f4ab9a

11 files changed

+382
-63
lines changed
 

‎conditioner.hpp

+13-13
Original file line numberDiff line numberDiff line change
@@ -61,54 +61,54 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6161
SDVersion version = VERSION_SD1,
6262
PMVersion pv = PM_VERSION_1,
6363
int clip_skip = -1)
64-
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
64+
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
6565
if (clip_skip <= 0) {
6666
clip_skip = 1;
67-
if (version == VERSION_SD2 || version == VERSION_SDXL) {
67+
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
6868
clip_skip = 2;
6969
}
7070
}
71-
if (version == VERSION_SD1) {
71+
if (sd_version_is_sd1(version)) {
7272
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
73-
} else if (version == VERSION_SD2) {
73+
} else if (sd_version_is_sd2(version)) {
7474
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
75-
} else if (version == VERSION_SDXL) {
75+
} else if (sd_version_is_sdxl(version)) {
7676
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
7777
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
7878
}
7979
}
8080

8181
void set_clip_skip(int clip_skip) {
8282
text_model->set_clip_skip(clip_skip);
83-
if (version == VERSION_SDXL) {
83+
if (sd_version_is_sdxl(version)) {
8484
text_model2->set_clip_skip(clip_skip);
8585
}
8686
}
8787

8888
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
8989
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
90-
if (version == VERSION_SDXL) {
90+
if (sd_version_is_sdxl(version)) {
9191
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
9292
}
9393
}
9494

9595
void alloc_params_buffer() {
9696
text_model->alloc_params_buffer();
97-
if (version == VERSION_SDXL) {
97+
if (sd_version_is_sdxl(version)) {
9898
text_model2->alloc_params_buffer();
9999
}
100100
}
101101

102102
void free_params_buffer() {
103103
text_model->free_params_buffer();
104-
if (version == VERSION_SDXL) {
104+
if (sd_version_is_sdxl(version)) {
105105
text_model2->free_params_buffer();
106106
}
107107
}
108108

109109
size_t get_params_buffer_size() {
110110
size_t buffer_size = text_model->get_params_buffer_size();
111-
if (version == VERSION_SDXL) {
111+
if (sd_version_is_sdxl(version)) {
112112
buffer_size += text_model2->get_params_buffer_size();
113113
}
114114
return buffer_size;
@@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
402402
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
403403
struct ggml_tensor* input_ids2 = NULL;
404404
size_t max_token_idx = 0;
405-
if (version == VERSION_SDXL) {
405+
if (sd_version_is_sdxl(version)) {
406406
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
407407
if (it != chunk_tokens.end()) {
408408
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
427427
false,
428428
&chunk_hidden_states1,
429429
work_ctx);
430-
if (version == VERSION_SDXL) {
430+
if (sd_version_is_sdxl(version)) {
431431
text_model2->compute(n_threads,
432432
input_ids2,
433433
0,
@@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
486486
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
487487

488488
ggml_tensor* vec = NULL;
489-
if (version == VERSION_SDXL) {
489+
if (sd_version_is_sdxl(version)) {
490490
int out_dim = 256;
491491
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
492492
// [0:1280]

‎control.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock {
3434

3535
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
37-
if (version == VERSION_SD2) {
37+
if (sd_version_is_sd2(version)) {
3838
context_dim = 1024;
3939
num_head_channels = 64;
4040
num_heads = -1;
41-
} else if (version == VERSION_SDXL) {
41+
} else if (sd_version_is_sdxl(version)) {
4242
context_dim = 2048;
4343
attention_resolutions = {4, 2};
4444
channel_mult = {1, 2, 4};
@@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
5858
// time_embed_1 is nn.SiLU()
5959
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
6060

61-
if (version == VERSION_SDXL || version == VERSION_SVD) {
61+
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
6262
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
6363
// label_emb_1 is nn.SiLU()
6464
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

‎diffusion_model.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {
133133

134134
// Construct the Flux diffusion-model wrapper.
// `version` selects the Flux variant (e.g. VERSION_FLUX_FILL enables the
// inpaint/Fill input layout inside FluxRunner); `flash_attn` toggles the
// flash-attention code path. Both default to the plain Flux configuration,
// so existing callers remain valid.
FluxModel(ggml_backend_t backend,
          std::map<std::string, enum ggml_type>& tensor_types,
          SDVersion version = VERSION_FLUX,
          bool flash_attn = false)
    : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}
139140

140141
void alloc_params_buffer() {
@@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
174175
struct ggml_tensor** output = NULL,
175176
struct ggml_context* output_ctx = NULL,
176177
std::vector<int> skip_layers = std::vector<int>()) {
177-
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
178+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
178179
}
179180
};
180181

‎examples/cli/main.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct SDParams {
8585
std::string lora_model_dir;
8686
std::string output_path = "output.png";
8787
std::string input_path;
88+
std::string mask_path;
8889
std::string control_image_path;
8990

9091
std::string prompt;
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
148149
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
149150
printf(" output_path: %s\n", params.output_path.c_str());
150151
printf(" init_img: %s\n", params.input_path.c_str());
152+
printf(" mask_img: %s\n", params.mask_path.c_str());
151153
printf(" control_image: %s\n", params.control_image_path.c_str());
152154
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
153155
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
384386
break;
385387
}
386388
params.input_path = argv[i];
389+
} else if (arg == "--mask") {
390+
if (++i >= argc) {
391+
invalid_arg = true;
392+
break;
393+
}
394+
params.mask_path = argv[i];
387395
} else if (arg == "--control-image") {
388396
if (++i >= argc) {
389397
invalid_arg = true;
@@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
803811
bool vae_decode_only = true;
804812
uint8_t* input_image_buffer = NULL;
805813
uint8_t* control_image_buffer = NULL;
814+
uint8_t* mask_image_buffer = NULL;
815+
806816
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
807817
vae_decode_only = false;
808818

@@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
907917
}
908918
}
909919

920+
if (params.mask_path != "") {
921+
int c = 0;
922+
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
923+
} else {
924+
std::vector<uint8_t> arr(params.width * params.height, 255);
925+
mask_image_buffer = arr.data();
926+
}
927+
sd_image_t mask_image = {(uint32_t)params.width,
928+
(uint32_t)params.height,
929+
1,
930+
mask_image_buffer};
931+
910932
sd_image_t* results;
911933
if (params.mode == TXT2IMG) {
912934
results = txt2img(sd_ctx,
@@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
976998
} else {
977999
results = img2img(sd_ctx,
9781000
input_image,
1001+
mask_image,
9791002
params.prompt.c_str(),
9801003
params.negative_prompt.c_str(),
9811004
params.clip_skip,

‎flux.hpp

+33-7
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ namespace Flux {
490490

491491
struct FluxParams {
492492
int64_t in_channels = 64;
493+
int64_t out_channels = 64;
493494
int64_t vec_in_dim = 768;
494495
int64_t context_in_dim = 4096;
495496
int64_t hidden_size = 3072;
@@ -642,8 +643,7 @@ namespace Flux {
642643
Flux() {}
643644
Flux(FluxParams params)
644645
: params(params) {
645-
int64_t out_channels = params.in_channels;
646-
int64_t pe_dim = params.hidden_size / params.num_heads;
646+
int64_t pe_dim = params.hidden_size / params.num_heads;
647647

648648
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
649649
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -669,7 +669,7 @@ namespace Flux {
669669
params.flash_attn));
670670
}
671671

672-
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
672+
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
673673
}
674674

675675
struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -789,6 +789,7 @@ namespace Flux {
789789
struct ggml_tensor* x,
790790
struct ggml_tensor* timestep,
791791
struct ggml_tensor* context,
792+
struct ggml_tensor* c_concat,
792793
struct ggml_tensor* y,
793794
struct ggml_tensor* guidance,
794795
struct ggml_tensor* pe,
@@ -797,6 +798,7 @@ namespace Flux {
797798
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
798799
// timestep: (N,) tensor of diffusion timesteps
799800
// context: (N, L, D)
801+
// c_concat: NULL, or for (N,C+M, H, W) for Fill
800802
// y: (N, adm_in_channels) tensor of class labels
801803
// guidance: (N,)
802804
// pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {
806808

807809
int64_t W = x->ne[0];
808810
int64_t H = x->ne[1];
811+
int64_t C = x->ne[2];
809812
int64_t patch_size = 2;
810813
int pad_h = (patch_size - H % patch_size) % patch_size;
811814
int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
814817
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
815818
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
816819

820+
if (c_concat != NULL) {
821+
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
822+
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
823+
824+
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
825+
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
826+
827+
masked = patchify(ctx, masked, patch_size);
828+
mask = patchify(ctx, mask, patch_size);
829+
830+
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
831+
}
832+
817833
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
818834

819835
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -834,12 +850,16 @@ namespace Flux {
834850
FluxRunner(ggml_backend_t backend,
835851
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836852
const std::string prefix = "",
853+
SDVersion version = VERSION_FLUX,
837854
bool flash_attn = false)
838855
: GGMLRunner(backend) {
839856
flux_params.flash_attn = flash_attn;
840857
flux_params.guidance_embed = false;
841858
flux_params.depth = 0;
842859
flux_params.depth_single_blocks = 0;
860+
if (version == VERSION_FLUX_FILL) {
861+
flux_params.in_channels = 384;
862+
}
843863
for (auto pair : tensor_types) {
844864
std::string tensor_name = pair.first;
845865
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
@@ -886,14 +906,18 @@ namespace Flux {
886906
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
887907
struct ggml_tensor* timesteps,
888908
struct ggml_tensor* context,
909+
struct ggml_tensor* c_concat,
889910
struct ggml_tensor* y,
890911
struct ggml_tensor* guidance,
891912
std::vector<int> skip_layers = std::vector<int>()) {
892913
GGML_ASSERT(x->ne[3] == 1);
893914
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
894915

895-
x = to_backend(x);
896-
context = to_backend(context);
916+
x = to_backend(x);
917+
context = to_backend(context);
918+
if (c_concat != NULL) {
919+
c_concat = to_backend(c_concat);
920+
}
897921
y = to_backend(y);
898922
timesteps = to_backend(timesteps);
899923
if (flux_params.guidance_embed) {
@@ -913,6 +937,7 @@ namespace Flux {
913937
x,
914938
timesteps,
915939
context,
940+
c_concat,
916941
y,
917942
guidance,
918943
pe,
@@ -927,6 +952,7 @@ namespace Flux {
927952
struct ggml_tensor* x,
928953
struct ggml_tensor* timesteps,
929954
struct ggml_tensor* context,
955+
struct ggml_tensor* c_concat,
930956
struct ggml_tensor* y,
931957
struct ggml_tensor* guidance,
932958
struct ggml_tensor** output = NULL,
@@ -938,7 +964,7 @@ namespace Flux {
938964
// y: [N, adm_in_channels] or [1, adm_in_channels]
939965
// guidance: [N, ]
940966
auto get_graph = [&]() -> struct ggml_cgraph* {
941-
return build_graph(x, timesteps, context, y, guidance, skip_layers);
967+
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
942968
};
943969

944970
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -978,7 +1004,7 @@ namespace Flux {
9781004
struct ggml_tensor* out = NULL;
9791005

9801006
int t0 = ggml_time_ms();
981-
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
1007+
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
9821008
int t1 = ggml_time_ms();
9831009

9841010
print_ggml_tensor(out);

‎ggml_extend.hpp

+36-1
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
290290
}
291291
}
292292

293+
// Copy a single-channel 8-bit mask image into a float32 ggml tensor.
// When `scale` is true (the default) byte values are mapped from
// [0, 255] to [0, 1]; otherwise the raw byte value is stored as a float.
// The output tensor must be F32 with exactly one channel (ne[2] == 1).
__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
                                         struct ggml_tensor* output,
                                         bool scale = true) {
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
    GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);
    for (int64_t iy = 0; iy < height; iy++) {
        // channels == 1, so each row is exactly `width` bytes.
        const uint8_t* row = image_data + iy * width;
        for (int64_t ix = 0; ix < width; ix++) {
            float value = (float)row[ix];
            if (scale) {
                value /= 255.f;
            }
            ggml_tensor_set_f32(output, value, ix, iy);
        }
    }
}
310+
311+
// Apply an inpainting mask to an image tensor.
// For every pixel: where the mask is near-white (value >= 254.5/255, i.e.
// the region to be inpainted, assuming the mask was scaled to [0,1] by
// sd_mask_to_tensor — confirm with callers), the output is set to the
// neutral value 0.5; elsewhere the original pixel value is kept unchanged
// (the keep-factor multiplies the pixel's offset from 0.5 by exactly 1).
// `output` must be an F32 tensor; `mask` is read at (ix, iy) per pixel.
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                     struct ggml_tensor* mask,
                                     struct ggml_tensor* output) {
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
    GGML_ASSERT(output->type == GGML_TYPE_F32);
    // Iterate rows in the outer loop: ne[0] (width) is the contiguous
    // dimension, so scanning ix innermost is cache-friendly. The original
    // nested iy inside ix, striding through memory; each (ix, iy, k) write
    // is independent, so swapping the loop order preserves behavior.
    for (int64_t iy = 0; iy < height; iy++) {
        for (int64_t ix = 0; ix < width; ix++) {
            float m = ggml_tensor_get_f32(mask, ix, iy);
            // 1.0f below the near-white cutoff (keep pixel), 0.0f at/above it.
            float keep = (float)(m < 254.5 / 255);
            for (int64_t k = 0; k < channels; k++) {
                float value = keep * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
                ggml_tensor_set_f32(output, value, ix, iy, k);
            }
        }
    }
}
328+
293329
__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
294330
struct ggml_tensor* output,
295331
int idx,
@@ -1144,7 +1180,6 @@ struct GGMLRunner {
11441180
}
11451181
#endif
11461182
ggml_backend_graph_compute(backend, gf);
1147-
11481183
#ifdef GGML_PERF
11491184
ggml_graph_print(gf);
11501185
#endif

0 commit comments

Comments
 (0)
Please sign in to comment.