Skip to content

Commit 3585d0f

Browse files
ring-c authored and cmdr2 committed
Merged image preview callbacks from PR leejet#416 by @ring-c
1 parent ac54e00 commit 3585d0f

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

stable-diffusion.cpp

+68-2
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,8 @@ class StableDiffusionGGML {
771771
sample_method_t method,
772772
const std::vector<float>& sigmas,
773773
int start_merge_step,
774-
SDCondition id_cond) {
774+
SDCondition id_cond,
775+
size_t batch_num = 0) {
775776
size_t steps = sigmas.size() - 1;
776777
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
777778
// print_ggml_tensor(noise);
@@ -894,6 +895,9 @@ class StableDiffusionGGML {
894895
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
895896
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
896897
}
898+
899+
send_result_step_callback(denoised, batch_num, step);
900+
897901
return denoised;
898902
};
899903

@@ -1007,6 +1011,47 @@ class StableDiffusionGGML {
10071011
// Thin wrapper: forwards to compute_first_stage() with its third argument set to true.
// NOTE(review): the flag presumably selects the decode direction — confirm against
// compute_first_stage's signature.
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
    const bool decode = true;
    return compute_first_stage(work_ctx, x, decode);
}
1014+
1015+
sd_result_cb_t result_cb = nullptr;
1016+
void* result_cb_data = nullptr;
1017+
1018+
void send_result_callback(ggml_context* work_ctx, ggml_tensor* x, size_t number) {
1019+
if (result_cb == nullptr) {
1020+
return;
1021+
}
1022+
1023+
struct ggml_tensor* img = decode_first_stage(work_ctx, x);
1024+
auto image_data = sd_tensor_to_image(img);
1025+
1026+
result_cb(number, image_data, result_cb_data);
1027+
}
1028+
1029+
sd_result_step_cb_t result_step_cb = nullptr;
1030+
void* result_step_cb_data = nullptr;
1031+
1032+
void send_result_step_callback(ggml_tensor* x, size_t number, size_t step) {
1033+
if (result_step_cb == nullptr) {
1034+
return;
1035+
}
1036+
1037+
struct ggml_init_params params {};
1038+
params.mem_size = static_cast<size_t>(10 * 1024) * 1024;
1039+
params.mem_buffer = nullptr;
1040+
params.no_alloc = false;
1041+
1042+
struct ggml_context* work_ctx = ggml_init(params);
1043+
if (!work_ctx) {
1044+
return;
1045+
}
1046+
1047+
struct ggml_tensor* result = ggml_dup_tensor(work_ctx, x);
1048+
copy_ggml_tensor(result, x);
1049+
1050+
struct ggml_tensor* img = decode_first_stage(work_ctx, result);
1051+
result_step_cb(number, step, sd_tensor_to_image(img), result_step_cb_data);
1052+
1053+
ggml_free(work_ctx);
1054+
}
10101055
};
10111056

10121057
/*================================================= SD API ==================================================*/
@@ -1093,6 +1138,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
10931138
free(sd_ctx);
10941139
}
10951140

1141+
// Installs (or, with cb == nullptr, clears) the per-image result callback on a context.
//
// sd_ctx: context to configure; a null context, or one whose `sd` member is null,
//         is ignored instead of being dereferenced.
// cb:     function called for each sampled image; nullptr disables the hook.
// data:   opaque pointer stored alongside cb and passed back to it unchanged.
//
// While a callback is installed, generate_image() hands every sampled latent to
// it and returns nullptr instead of collecting the images itself.
void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;
    }
    sd_ctx->sd->result_cb      = cb;
    sd_ctx->sd->result_cb_data = data;
}
1145+
1146+
// Installs (or, with cb == nullptr, clears) the per-step preview callback on a context.
//
// sd_ctx: context to configure; a null context, or one whose `sd` member is null,
//         is ignored instead of being dereferenced.
// cb:     function called with a decoded preview during sampling; nullptr
//         disables the hook.
// data:   opaque pointer stored alongside cb and passed back to it unchanged.
void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;
    }
    sd_ctx->sd->result_step_cb      = cb;
    sd_ctx->sd->result_step_cb_data = data;
}
1150+
10961151
sd_image_t* generate_image(sd_ctx_t* sd_ctx,
10971152
struct ggml_context* work_ctx,
10981153
ggml_tensor* init_latent,
@@ -1320,11 +1375,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13201375
sample_method,
13211376
sigmas,
13221377
start_merge_step,
1323-
id_cond);
1378+
id_cond,
1379+
b);
13241380
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
13251381
// print_ggml_tensor(x_0);
13261382
int64_t sampling_end = ggml_time_ms();
13271383
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
1384+
1385+
if (sd_ctx->sd->result_cb != nullptr) {
1386+
sd_ctx->sd->send_result_callback(work_ctx, x_0, b);
1387+
continue;
1388+
}
1389+
13281390
final_latents.push_back(x_0);
13291391
}
13301392

@@ -1334,6 +1396,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13341396
int64_t t3 = ggml_time_ms();
13351397
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);
13361398

1399+
if (sd_ctx->sd->result_cb != nullptr) {
1400+
return nullptr;
1401+
}
1402+
13371403
// Decode to image
13381404
LOG_INFO("decoding %zu latents", final_latents.size());
13391405
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images

stable-diffusion.h

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ enum sd_log_level_t {
107107

108108
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
109109
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
110+
// Callback receiving one finished image.
//   number:     index of the image within the batch being generated.
//   image_data: pixel buffer produced from the decoded latent.
//               NOTE(review): buffer ownership is not specified here — confirm who frees it.
//   data:       the opaque pointer supplied when the callback was registered.
typedef void (*sd_result_cb_t)(size_t number, uint8_t* image_data, void* data);
111+
// Callback receiving a decoded preview while sampling is in progress.
//   number:     index of the image within the batch being generated.
//   step:       sampling step the preview was taken at.
//   image_data: pixel buffer produced from the decoded intermediate latent.
//               NOTE(review): buffer ownership is not specified here — confirm who frees it.
//   data:       the opaque pointer supplied when the callback was registered.
typedef void (*sd_result_step_cb_t)(size_t number, size_t step, uint8_t* image_data, void* data);
110112

111113
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
112114
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
@@ -145,6 +147,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
145147
bool keep_vae_on_cpu);
146148

147149
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
150+
// Registers `cb` (with user pointer `data`) as the per-image result callback of `sd_ctx`.
// Passing a null `cb` leaves the hook disabled.
SD_API void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data);
151+
// Registers `cb` (with user pointer `data`) as the per-step preview callback of `sd_ctx`.
// Passing a null `cb` leaves the hook disabled.
SD_API void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data);
148152

149153
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
150154
const char* prompt,

0 commit comments

Comments
 (0)