Skip to content

Commit 7be65fa

Browse files
authored
feat: add progress callback (#170)
1 parent d164236 commit 7be65fa

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

clip.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
891891
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
892892
return false;
893893
}
894+
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
895+
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
896+
return true;
897+
}
894898
struct ggml_init_params params;
895899
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
896900
params.mem_buffer = NULL;

lora.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct LoraModel : public GGMLModule {
3333
return model_loader.get_params_mem_size(NULL);
3434
}
3535

36+
3637
bool load_from_file() {
3738
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
3839

@@ -55,6 +56,7 @@ struct LoraModel : public GGMLModule {
5556
auto real = lora_tensors[name];
5657
*dst_tensor = real;
5758
}
59+
5860
return true;
5961
};
6062

@@ -64,6 +66,7 @@ struct LoraModel : public GGMLModule {
6466
dry_run = false;
6567
model_loader.load_tensors(on_new_tensor_cb, backend);
6668

69+
6770
LOG_DEBUG("finished loaded lora");
6871
return true;
6972
}

stable-diffusion.h

+2
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@ enum sd_log_level_t {
9292
};
9393

9494
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
95+
typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
9596

9697
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
98+
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
9799
SD_API int32_t get_num_physical_cores();
98100
SD_API const char* sd_get_system_info();
99101

util.cpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ int32_t get_num_physical_cores() {
161161
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
162162
}
163163

164+
static sd_progress_cb_t sd_progress_cb = NULL;
165+
void* sd_progress_cb_data = NULL;
166+
164167
std::u32string utf8_to_utf32(const std::string& utf8_str) {
165168
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
166169
return converter.from_bytes(utf8_str);
@@ -205,6 +208,10 @@ std::string path_join(const std::string& p1, const std::string& p2) {
205208
}
206209

207210
void pretty_progress(int step, int steps, float time) {
211+
if (sd_progress_cb) {
212+
sd_progress_cb(step,steps,time, sd_progress_cb_data);
213+
return;
214+
}
208215
if (step == 0) {
209216
return;
210217
}
@@ -248,8 +255,9 @@ std::string trim(const std::string& s) {
248255
return rtrim(ltrim(s));
249256
}
250257

251-
static sd_log_cb_t sd_log_cb = NULL;
252-
void* sd_log_cb_data = NULL;
258+
static sd_log_cb_t sd_log_cb = NULL;
259+
void* sd_log_cb_data = NULL;
260+
253261

254262
#define LOG_BUFFER_SIZE 1024
255263

@@ -286,7 +294,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) {
286294
sd_log_cb = cb;
287295
sd_log_cb_data = data;
288296
}
289-
297+
void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
298+
sd_progress_cb = cb;
299+
sd_progress_cb_data = data;
300+
}
290301
const char* sd_get_system_info() {
291302
static char buffer[1024];
292303
std::stringstream ss;

vae.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
/*================================================== AutoEncoderKL ===================================================*/
88

9-
#define VAE_GRAPH_SIZE 10240
9+
#define VAE_GRAPH_SIZE 20480
1010

1111
class ResnetBlock : public UnaryBlock {
1212
protected:

0 commit comments

Comments
 (0)