Skip to content

Commit 036ba9e

Browse files
committed Apr 14, 2024 ·
feat: enable controlnet and photo maker for img2img mode
1 parent ec82d52 commit 036ba9e

File tree

5 files changed

+257
-258
lines changed

5 files changed

+257
-258
lines changed
 

‎examples/cli/main.cpp

+45 -35
Original file line number | Diff line number | Diff line change
@@ -656,13 +656,16 @@ int main(int argc, const char* argv[]) {
656656
return 1;
657657
}
658658

659-
bool vae_decode_only = true;
660-
uint8_t* input_image_buffer = NULL;
659+
bool vae_decode_only = true;
660+
uint8_t* input_image_buffer = NULL;
661+
uint8_t* control_image_buffer = NULL;
661662
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
662663
vae_decode_only = false;
663664

664665
int c = 0;
665-
input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3);
666+
int width = 0;
667+
int height = 0;
668+
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
666669
if (input_image_buffer == NULL) {
667670
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
668671
return 1;
@@ -672,29 +675,30 @@ int main(int argc, const char* argv[]) {
672675
free(input_image_buffer);
673676
return 1;
674677
}
675-
if (params.width <= 0) {
678+
if (width <= 0) {
676679
fprintf(stderr, "error: the width of image must be greater than 0\n");
677680
free(input_image_buffer);
678681
return 1;
679682
}
680-
if (params.height <= 0) {
683+
if (height <= 0) {
681684
fprintf(stderr, "error: the height of image must be greater than 0\n");
682685
free(input_image_buffer);
683686
return 1;
684687
}
685688

686689
// Resize input image ...
687-
if (params.height % 64 != 0 || params.width % 64 != 0) {
688-
int resized_height = params.height + (64 - params.height % 64);
689-
int resized_width = params.width + (64 - params.width % 64);
690+
if (params.height != height || params.width != width) {
691+
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
692+
int resized_height = params.height;
693+
int resized_width = params.width;
690694

691695
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
692696
if (resized_image_buffer == NULL) {
693697
fprintf(stderr, "error: allocate memory for resize input image\n");
694698
free(input_image_buffer);
695699
return 1;
696700
}
697-
stbir_resize(input_image_buffer, params.width, params.height, 0,
701+
stbir_resize(input_image_buffer, width, height, 0,
698702
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
699703
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
700704
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
@@ -704,8 +708,6 @@ int main(int argc, const char* argv[]) {
704708
// Save resized result
705709
free(input_image_buffer);
706710
input_image_buffer = resized_image_buffer;
707-
params.height = resized_height;
708-
params.width = resized_width;
709711
}
710712
}
711713

@@ -732,31 +734,32 @@ int main(int argc, const char* argv[]) {
732734
return 1;
733735
}
734736

737+
sd_image_t* control_image = NULL;
738+
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
739+
int c = 0;
740+
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
741+
if (control_image_buffer == NULL) {
742+
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
743+
return 1;
744+
}
745+
control_image = new sd_image_t{(uint32_t)params.width,
746+
(uint32_t)params.height,
747+
3,
748+
control_image_buffer};
749+
if (params.canny_preprocess) { // apply preprocessor
750+
control_image->data = preprocess_canny(control_image->data,
751+
control_image->width,
752+
control_image->height,
753+
0.08f,
754+
0.08f,
755+
0.8f,
756+
1.0f,
757+
false);
758+
}
759+
}
760+
735761
sd_image_t* results;
736762
if (params.mode == TXT2IMG) {
737-
sd_image_t* control_image = NULL;
738-
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
739-
int c = 0;
740-
input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
741-
if (input_image_buffer == NULL) {
742-
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
743-
return 1;
744-
}
745-
control_image = new sd_image_t{(uint32_t)params.width,
746-
(uint32_t)params.height,
747-
3,
748-
input_image_buffer};
749-
if (params.canny_preprocess) { // apply preprocessor
750-
control_image->data = preprocess_canny(control_image->data,
751-
control_image->width,
752-
control_image->height,
753-
0.08f,
754-
0.08f,
755-
0.8f,
756-
1.0f,
757-
false);
758-
}
759-
}
760763
results = txt2img(sd_ctx,
761764
params.prompt.c_str(),
762765
params.negative_prompt.c_str(),
@@ -828,7 +831,12 @@ int main(int argc, const char* argv[]) {
828831
params.sample_steps,
829832
params.strength,
830833
params.seed,
831-
params.batch_count);
834+
params.batch_count,
835+
control_image,
836+
params.control_strength,
837+
params.style_ratio,
838+
params.normalize_input,
839+
params.input_id_images_path.c_str());
832840
}
833841
}
834842

@@ -881,6 +889,8 @@ int main(int argc, const char* argv[]) {
881889
}
882890
free(results);
883891
free_sd_ctx(sd_ctx);
892+
free(control_image_buffer);
893+
free(input_image_buffer);
884894

885895
return 0;
886896
}

‎ggml_extend.hpp

+3 -4
Original file line number | Diff line number | Diff line change
@@ -752,10 +752,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
752752
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
753753
}
754754

755-
756-
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context * ctx) {
755+
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
757756
size_t num = 0;
758-
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
757+
for (ggml_tensor* t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
759758
num++;
760759
}
761760
return num;
@@ -851,7 +850,7 @@ struct GGMLModule {
851850
}
852851

853852
public:
854-
virtual std::string get_desc() = 0;
853+
virtual std::string get_desc() = 0;
855854

856855
GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
857856
: backend(backend), wtype(wtype) {

‎stable-diffusion.cpp

+202-217
Large diffs are not rendered by default.

‎stable-diffusion.h

+6 -1
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,12 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
160160
int sample_steps,
161161
float strength,
162162
int64_t seed,
163-
int batch_count);
163+
int batch_count,
164+
const sd_image_t* control_cond,
165+
float control_strength,
166+
float style_strength,
167+
bool normalize_input,
168+
const char* input_id_images_path);
164169

165170
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
166171
sd_image_t init_image,

‎tae.hpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ struct TinyAutoEncoder : public GGMLModule {
201201
}
202202

203203
bool load_from_file(const std::string& file_path) {
204-
LOG_INFO("loading taesd from '%s'", file_path.c_str());
204+
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
205205
alloc_params_buffer();
206206
std::map<std::string, ggml_tensor*> taesd_tensors;
207207
taesd.get_param_tensors(taesd_tensors);

0 commit comments

Comments
 (0)
Please sign in to comment.