Commit ac54e00

feat: add sd3.5 support (leejet#445)
1 parent 14206fd commit ac54e00

13 files changed, +250 −127 lines

README.md

+12 −10

@@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
 
 - Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
 - Super lightweight and without external dependencies
-- SD1.x, SD2.x, SDXL and SD3 support
+- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support
 - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
 - [Flux-dev/Flux-schnell Support](./docs/flux.md)
 
@@ -197,23 +197,24 @@ usage: ./bin/sd [arguments]
 arguments:
   -h, --help                         show this help message and exit
   -M, --mode [MODEL]                 run mode (txt2img or img2img or convert, default: txt2img)
-  -t, --threads N                    number of threads to use during computation (default: -1).
+  -t, --threads N                    number of threads to use during computation (default: -1)
                                      If threads <= 0, then threads will be set to the number of CPU physical cores
   -m, --model [MODEL]                path to full model
   --diffusion-model                  path to the standalone diffusion model
   --clip_l                           path to the clip-l text encoder
-  --t5xxl                            path to the the t5xxl text encoder.
+  --clip_g                           path to the clip-g text encoder
+  --t5xxl                            path to the t5xxl text encoder
   --vae [VAE]                        path to vae
   --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
   --control-net [CONTROL_PATH]       path to control net model
-  --embd-dir [EMBEDDING_PATH]        path to embeddings.
-  --stacked-id-embd-dir [DIR]        path to PHOTOMAKER stacked id embeddings.
-  --input-id-images-dir [DIR]        path to PHOTOMAKER input id images dir.
+  --embd-dir [EMBEDDING_PATH]        path to embeddings
+  --stacked-id-embd-dir [DIR]        path to PHOTOMAKER stacked id embeddings
+  --input-id-images-dir [DIR]        path to PHOTOMAKER input id images dir
   --normalize-input                  normalize PHOTOMAKER input id images
-  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
+  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
   --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)
   --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
-                                     If not specified, the default is the type of the weight file.
+                                     If not specified, the default is the type of the weight file
   --lora-model-dir [DIR]             lora model directory
   -i, --init-img [IMAGE]             path to the input image, required by img2img
   --control-image [IMAGE]            path to image condition, control net
@@ -232,13 +233,13 @@ arguments:
   --steps STEPS                      number of sample steps (default: 20)
   --rng {std_default, cuda}          RNG (default: cuda)
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
-  -b, --batch-count COUNT            number of images to generate.
+  -b, --batch-count COUNT            number of images to generate
   --schedule {discrete, karras, exponential, ays, gits}  Denoiser sigma schedule (default: discrete)
   --clip-skip N                      ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
                                      <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
   --vae-tiling                       process vae in tiles to reduce memory usage
   --vae-on-cpu                       keep vae in cpu (for low vram)
-  --clip-on-cpu                      keep clip in cpu (for low vram).
+  --clip-on-cpu                      keep clip in cpu (for low vram)
   --control-net-cpu                  keep controlnet in cpu (for low vram)
   --canny                            apply canny preprocessor (edge detection)
   --color                            Colors the logging tags according to level
@@ -253,6 +254,7 @@ arguments:
 # ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
 # ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
 # ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
+# ./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
 ```
 
 Using formats of different precisions will yield results of varying quality.
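
The new SD3.5 example above loads all weights at their stored precision. A hypothetical variant, not part of this commit, combines it with the `--type` option documented in the arguments hunk to convert the weights to q8_0 at load time and reduce memory use; whether that precision is acceptable for SD3.5 is not claimed here:

```
# hypothetical: SD3.5 Large with weights converted to q8_0 via --type (not from this commit)
./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors --type q8_0 -H 1024 -W 1024 -p "a lovely cat" --cfg-scale 4.5 --sampling-method euler -v
```

As the unchanged README line above notes, different precisions yield results of varying quality.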

assets/sd3.5_large.png

1.81 MB

conditioner.hpp

+2 −2

@@ -1001,8 +1001,8 @@ struct FluxCLIPEmbedder : public Conditioner {
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl");
+        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
     }
 
     void alloc_params_buffer() {

denoiser.hpp

+38 −41

@@ -49,7 +49,7 @@ struct ExponentialSchedule : SigmaSchedule {
         // Calculate step size
         float log_sigma_min = std::log(sigma_min);
         float log_sigma_max = std::log(sigma_max);
-        float step = (log_sigma_max - log_sigma_min) / (n - 1);
+        float step          = (log_sigma_max - log_sigma_min) / (n - 1);
 
         // Fill sigmas with exponential values
         for (uint32_t i = 0; i < n; ++i) {
@@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
 
 /*
  * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
-*/
+ */
 struct GITSSchedule : SigmaSchedule {
     std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
         if (sigma_max <= 0.0f) {
@@ -221,7 +221,7 @@ struct GITSSchedule : SigmaSchedule {
         // Calculate the index based on the coefficient
         int index = static_cast<int>((coeff - 0.80f) / 0.05f);
         // Ensure the index is within bounds
-        index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
+        index     = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
         const std::vector<std::vector<float>>& selected_noise = *GITS_NOISE[index];
 
         if (n <= 20) {
@@ -823,24 +823,24 @@ static void sample_k_diffusion(sample_method_t method,
         } break;
         case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
         {
-            int max_order = 4;
+            int max_order       = 4;
             ggml_tensor* x_next = x;
             std::vector<ggml_tensor*> buffer_model;
 
             for (int i = 0; i < steps; i++) {
-                float sigma = sigmas[i];
+                float sigma      = sigmas[i];
                 float sigma_next = sigmas[i + 1];
 
                 ggml_tensor* x_cur = x_next;
-                float* vec_x_cur = (float*)x_cur->data;
-                float* vec_x_next = (float*)x_next->data;
+                float* vec_x_cur   = (float*)x_cur->data;
+                float* vec_x_next  = (float*)x_next->data;
 
                 // Denoising step
                 ggml_tensor* denoised = model(x_cur, sigma, i + 1);
-                float* vec_denoised = (float*)denoised->data;
+                float* vec_denoised   = (float*)denoised->data;
                 // d_cur = (x_cur - denoised) / sigma
                 struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
-                float* vec_d_cur = (float*)d_cur->data;
+                float* vec_d_cur          = (float*)d_cur->data;
 
                 for (int j = 0; j < ggml_nelements(d_cur); j++) {
                     vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma;
@@ -857,34 +857,31 @@ static void sample_k_diffusion(sample_method_t method,
                         break;
 
                     case 2: // Use one history point
-                        {
-                            float* vec_d_prev1 = (float*)buffer_model.back()->data;
-                            for (int j = 0; j < ggml_nelements(x_next); j++) {
-                                vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
-                            }
+                    {
+                        float* vec_d_prev1 = (float*)buffer_model.back()->data;
+                        for (int j = 0; j < ggml_nelements(x_next); j++) {
+                            vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
                         }
-                        break;
+                    } break;
 
                     case 3: // Use two history points
-                        {
-                            float* vec_d_prev1 = (float*)buffer_model.back()->data;
-                            float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
-                            for (int j = 0; j < ggml_nelements(x_next); j++) {
-                                vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
-                            }
+                    {
+                        float* vec_d_prev1 = (float*)buffer_model.back()->data;
+                        float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
+                        for (int j = 0; j < ggml_nelements(x_next); j++) {
+                            vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
                         }
-                        break;
+                    } break;
 
                     case 4: // Use three history points
-                        {
-                            float* vec_d_prev1 = (float*)buffer_model.back()->data;
-                            float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
-                            float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
-                            for (int j = 0; j < ggml_nelements(x_next); j++) {
-                                vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
-                            }
+                    {
+                        float* vec_d_prev1 = (float*)buffer_model.back()->data;
+                        float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
+                        float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
+                        for (int j = 0; j < ggml_nelements(x_next); j++) {
+                            vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
                         }
-                        break;
+                    } break;
                 }
 
                 // Manage buffer_model
@@ -906,27 +903,27 @@ static void sample_k_diffusion(sample_method_t method,
             ggml_tensor* x_next = x;
 
             for (int i = 0; i < steps; i++) {
-                float sigma = sigmas[i];
+                float sigma  = sigmas[i];
                 float t_next = sigmas[i + 1];
 
                 // Denoising step
-                ggml_tensor* denoised = model(x, sigma, i + 1);
-                float* vec_denoised = (float*)denoised->data;
+                ggml_tensor* denoised = model(x, sigma, i + 1);
+                float* vec_denoised   = (float*)denoised->data;
                 struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x);
-                float* vec_d_cur = (float*)d_cur->data;
-                float* vec_x = (float*)x->data;
+                float* vec_d_cur          = (float*)d_cur->data;
+                float* vec_x              = (float*)x->data;
 
                 // d_cur = (x - denoised) / sigma
                 for (int j = 0; j < ggml_nelements(d_cur); j++) {
                     vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma;
                 }
 
-                int order = std::min(max_order, i + 1);
-                float h_n = t_next - sigma;
+                int order   = std::min(max_order, i + 1);
+                float h_n   = t_next - sigma;
                 float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n;
 
                 switch (order) {
-                    case 1: // First Euler step
+                    case 1:  // First Euler step
                         for (int j = 0; j < ggml_nelements(x_next); j++) {
                             vec_x[j] += vec_d_cur[j] * h_n;
                         }
@@ -941,7 +938,7 @@ static void sample_k_diffusion(sample_method_t method,
                    }
 
                    case 3: {
-                        float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
+                        float h_n_2        = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
                         float* vec_d_prev1 = (float*)buffer_model.back()->data;
                         float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
                         for (int j = 0; j < ggml_nelements(x_next); j++) {
@@ -951,8 +948,8 @@ static void sample_k_diffusion(sample_method_t method,
                    }
 
                    case 4: {
-                        float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
-                        float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
+                        float h_n_2        = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
+                        float h_n_3        = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
                         float* vec_d_prev1 = (float*)buffer_model.back()->data;
                         float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
                         float* vec_d_prev3 = (buffer_model.size() > 2) ? (float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2;
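
For reference, the history-based updates in the iPNDM hunks above are the standard explicit Adams–Bashforth multistep formulas applied to the derivative estimate $d_i = (x_i - \text{denoised}_i)/\sigma_i$ computed in the loop (notation chosen for this note, not taken from the source):

$$x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, d_i \qquad \text{(order 1, Euler)}$$

$$x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, \frac{3 d_i - d_{i-1}}{2} \qquad \text{(order 2)}$$

$$x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, \frac{23 d_i - 16 d_{i-1} + 5 d_{i-2}}{12} \qquad \text{(order 3)}$$

$$x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, \frac{55 d_i - 59 d_{i-1} + 37 d_{i-2} - 9 d_{i-3}}{24} \qquad \text{(order 4)}$$

These coefficients match the `case 2`/`case 3`/`case 4` bodies shown above; the hunks in this file only reformat the sampler code, the numerical update itself is unchanged.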

docs/sd3.md

+20 −0

@@ -0,0 +1,20 @@
+# How to Use
+
+## Download weights
+
+- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
+- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors
+- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors
+- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors
+
+
+## Run
+
+### SD3.5 Large
+For example:
+
+```
+.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
+```
+
+![](../assets/sd3.5_large.png)
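
The new docs page only shows a Windows invocation. On a Linux or macOS build, the equivalent command, mirroring the SD3.5 example added to the README in this same commit, would be:

```
# same flags as the Windows example in docs/sd3.md; the binary path depends on your build directory
./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
```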
