Skip to content

Commit b0940f0

Browse files
committed
make stable-diffusion.h a pure c header file
This reverts commit 27887b6.
1 parent 983e552 commit b0940f0

File tree

5 files changed

+109
-103
lines changed

5 files changed

+109
-103
lines changed

examples/cli/main.cpp

+2-97
Original file line numberDiff line numberDiff line change
@@ -9,82 +9,13 @@
99
// #include "preprocessing.hpp"
1010
#include "stable-diffusion.h"
1111

12-
#define STB_IMAGE_IMPLEMENTATION
12+
// #define STB_IMAGE_IMPLEMENTATION
1313
#include "stb_image.h"
1414

1515
#define STB_IMAGE_WRITE_IMPLEMENTATION
1616
#define STB_IMAGE_WRITE_STATIC
1717
#include "stb_image_write.h"
1818

19-
#ifdef _WIN32 // code for windows
20-
#include <windows.h>
21-
std::vector<std::string> get_files_from_dir(const std::string& dir) {
22-
23-
std::vector<std::string> files;
24-
25-
WIN32_FIND_DATA findFileData;
26-
HANDLE hFind;
27-
28-
char currentDirectory[MAX_PATH];
29-
GetCurrentDirectory(MAX_PATH, currentDirectory);
30-
31-
char directoryPath[MAX_PATH]; // this is absolute path
32-
sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str());
33-
34-
// Find the first file in the directory
35-
hFind = FindFirstFile(directoryPath, &findFileData);
36-
37-
// Check if the directory was found
38-
if (hFind == INVALID_HANDLE_VALUE) {
39-
printf("Unable to find directory %s.\n", dir.c_str());
40-
return files;
41-
}
42-
43-
// Loop through all files in the directory
44-
do {
45-
// Check if the found file is a regular file (not a directory)
46-
if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
47-
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
48-
}
49-
} while (FindNextFile(hFind, &findFileData) != 0);
50-
51-
// Close the handle
52-
FindClose(hFind);
53-
54-
55-
sort(files.begin(), files.end());
56-
57-
return files;
58-
}
59-
#else // UNIX
60-
#include <dirent.h>
61-
#include <sys/stat.h>
62-
63-
std::vector<std::string> get_files_from_dir(const std::string &dir){
64-
65-
std::vector<std::string> files;
66-
67-
DIR* dp = opendir(dir.c_str());
68-
69-
if (dp != nullptr) {
70-
struct dirent* entry;
71-
72-
while ((entry = readdir(dp)) != nullptr) {
73-
std::string fname = dir + "/" + entry->d_name;
74-
if (!is_directory(fname))
75-
files.push_back(fname);
76-
}
77-
closedir(dp);
78-
}
79-
80-
sort(files.begin(), files.end());
81-
82-
return files;
83-
84-
}
85-
86-
#endif
87-
8819
const char* rng_type_to_str[] = {
8920
"std_default",
9021
"cuda",
@@ -759,32 +690,6 @@ int main(int argc, const char* argv[]) {
759690
false);
760691
}
761692
}
762-
std::vector<sd_image_t*> input_id_images;
763-
if (params.stacked_id_embeddings_path.size() > 0 && params.input_id_images_path.size() > 0) {
764-
std::vector<std::string> img_files = get_files_from_dir(params.input_id_images_path);
765-
for(std::string img_file : img_files){
766-
int c = 0;
767-
int width, height;
768-
input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
769-
if (input_image_buffer == NULL) {
770-
printf("PhotoMaker load image from '%s' failed\n", img_file.c_str());
771-
return 1;
772-
}else{
773-
printf("PhotoMaker loaded image from '%s'\n", img_file.c_str());
774-
}
775-
sd_image_t* input_image = NULL;
776-
input_image = new sd_image_t{(uint32_t)width,
777-
(uint32_t)height,
778-
3,
779-
input_image_buffer};
780-
input_image = preprocess_id_image(input_image);
781-
if(input_image == NULL){
782-
printf("preprocess input id image from '%s' failed\n", img_file.c_str());
783-
return 1;
784-
}
785-
input_id_images.push_back(input_image);
786-
}
787-
}
788693
results = txt2img(sd_ctx,
789694
params.prompt.c_str(),
790695
params.negative_prompt.c_str(),
@@ -800,7 +705,7 @@ int main(int argc, const char* argv[]) {
800705
params.control_strength,
801706
params.style_ratio,
802707
params.normalize_input,
803-
input_id_images);
708+
params.input_id_images_path.c_str());
804709
} else {
805710
sd_image_t input_image = {(uint32_t)params.width,
806711
(uint32_t)params.height,

stable-diffusion.cpp

+38-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#include "unet.hpp"
1717
#include "vae.hpp"
1818

19+
#define STB_IMAGE_IMPLEMENTATION
20+
#include "stb_image.h"
21+
1922
// #define STB_IMAGE_WRITE_IMPLEMENTATION
2023
// #define STB_IMAGE_WRITE_STATIC
2124
// #include "stb_image_write.h"
@@ -1633,14 +1636,44 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16331636
float control_strength,
16341637
float style_ratio,
16351638
bool normalize_input,
1636-
std::vector<sd_image_t*> &input_id_images) {
1639+
const char* input_id_images_path_c_str) {
16371640
LOG_DEBUG("txt2img %dx%d", width, height);
16381641
if (sd_ctx == NULL) {
16391642
return NULL;
16401643
}
16411644
// LOG_DEBUG("%s %s %f %d %d %d", prompt_c_str, negative_prompt_c_str, cfg_scale, sample_steps, seed, batch_count);
16421645
std::string prompt(prompt_c_str);
16431646
std::string negative_prompt(negative_prompt_c_str);
1647+
std::string input_id_images_path(input_id_images_path_c_str);
1648+
1649+
1650+
// preprocess input id images
1651+
std::vector<sd_image_t*> input_id_images;
1652+
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
1653+
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
1654+
for(std::string img_file : img_files){
1655+
int c = 0;
1656+
int width, height;
1657+
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
1658+
if (input_image_buffer == NULL) {
1659+
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
1660+
continue;
1661+
}else{
1662+
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
1663+
}
1664+
sd_image_t* input_image = NULL;
1665+
input_image = new sd_image_t{(uint32_t)width,
1666+
(uint32_t)height,
1667+
3,
1668+
input_image_buffer};
1669+
input_image = preprocess_id_image(input_image);
1670+
if(input_image == NULL){
1671+
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
1672+
continue;
1673+
}
1674+
input_id_images.push_back(input_image);
1675+
}
1676+
}
16441677

16451678
// extract and remove lora
16461679
auto result_pair = extract_and_remove_lora(prompt);
@@ -1741,6 +1774,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17411774
sd_ctx->sd->stacked_id = false;
17421775
}
17431776
}
1777+
for (sd_image_t* img : input_id_images) {
1778+
free(img->data);
1779+
}
1780+
input_id_images.clear();
17441781

17451782

17461783
t0 = ggml_time_ms();

stable-diffusion.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
145145
float control_strength,
146146
float style_strength,
147147
bool normalize_input,
148-
std::vector<sd_image_t*> &input_id_images);
148+
const char* input_id_images_path);
149149

150150
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
151151
sd_image_t init_image,
@@ -176,8 +176,6 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
176176
float strength,
177177
int64_t seed);
178178

179-
sd_image_t *preprocess_id_image(sd_image_t * img);
180-
181179
typedef struct upscaler_ctx_t upscaler_ctx_t;
182180

183181
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,

util.cpp

+64-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
#include "ggml/ggml.h"
2626
#include "stable-diffusion.h"
2727

28-
#define STB_IMAGE_RESIZE_IMPLEMENTATION
29-
#include "stb_image_resize.h"
28+
#define STB_IMAGE_RESIZE_IMPLEMENTATION
29+
#include "stb_image_resize.h"
3030

3131

3232
bool ends_with(const std::string& str, const std::string& ending) {
@@ -92,6 +92,45 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
9292
}
9393
}
9494

95+
std::vector<std::string> get_files_from_dir(const std::string& dir) {
96+
97+
std::vector<std::string> files;
98+
99+
WIN32_FIND_DATA findFileData;
100+
HANDLE hFind;
101+
102+
char currentDirectory[MAX_PATH];
103+
GetCurrentDirectory(MAX_PATH, currentDirectory);
104+
105+
char directoryPath[MAX_PATH]; // this is absolute path
106+
sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str());
107+
108+
// Find the first file in the directory
109+
hFind = FindFirstFile(directoryPath, &findFileData);
110+
111+
// Check if the directory was found
112+
if (hFind == INVALID_HANDLE_VALUE) {
113+
printf("Unable to find directory.\n");
114+
return files;
115+
}
116+
117+
// Loop through all files in the directory
118+
do {
119+
// Check if the found file is a regular file (not a directory)
120+
if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
121+
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
122+
}
123+
} while (FindNextFile(hFind, &findFileData) != 0);
124+
125+
// Close the handle
126+
FindClose(hFind);
127+
128+
129+
sort(files.begin(), files.end());
130+
131+
return files;
132+
}
133+
95134
#else // Unix
96135
#include <dirent.h>
97136
#include <sys/stat.h>
@@ -126,6 +165,29 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
126165
return "";
127166
}
128167

168+
std::vector<std::string> get_files_from_dir(const std::string &dir){
169+
170+
std::vector<std::string> files;
171+
172+
DIR* dp = opendir(dir.c_str());
173+
174+
if (dp != nullptr) {
175+
struct dirent* entry;
176+
177+
while ((entry = readdir(dp)) != nullptr) {
178+
std::string fname = dir + "/" + entry->d_name;
179+
if (!is_directory(fname))
180+
files.push_back(fname);
181+
}
182+
closedir(dp);
183+
}
184+
185+
sort(files.begin(), files.end());
186+
187+
return files;
188+
189+
}
190+
129191
#endif
130192

131193
// get_num_physical_cores is copy from

util.h

+4
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ bool file_exists(const std::string& filename);
1818
bool is_directory(const std::string& path);
1919
std::string get_full_path(const std::string& dir, const std::string& filename);
2020

21+
std::vector<std::string> get_files_from_dir(const std::string &dir);
22+
2123
std::u32string utf8_to_utf32(const std::string& utf8_str);
2224
std::string utf32_to_utf8(const std::u32string& utf32_str);
2325
std::u32string unicode_value_to_utf32(int unicode_value);
2426

27+
sd_image_t *preprocess_id_image(sd_image_t * img);
28+
2529
//std::string sd_basename(const std::string& path);
2630

2731
typedef struct {

0 commit comments

Comments
 (0)