Skip to content

Commit 91b63be

Browse files
committed
work
1 parent 14206fd commit 91b63be

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

stable-diffusion.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,14 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
10811081
free(sd_ctx);
10821082
}
10831083

1084+
static sd_result_cb_t sd_result_cb = NULL;
1085+
void* sd_result_cb_data = NULL;
1086+
1087+
void sd_set_result_callback(sd_result_cb_t cb, void* data) {
1088+
sd_result_cb = cb;
1089+
sd_result_cb_data = data;
1090+
}
1091+
10841092
sd_image_t* generate_image(sd_ctx_t* sd_ctx,
10851093
struct ggml_context* work_ctx,
10861094
ggml_tensor* init_latent,
@@ -1313,6 +1321,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13131321
// print_ggml_tensor(x_0);
13141322
int64_t sampling_end = ggml_time_ms();
13151323
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
1324+
1325+
if (sd_result_cb != NULL) {
1326+
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0);
1327+
sd_result_cb(b + 1, sd_tensor_to_image(img), sd_result_cb_data);
1328+
continue;
1329+
}
1330+
13161331
final_latents.push_back(x_0);
13171332
}
13181333

@@ -1321,6 +1336,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13211336
}
13221337
int64_t t3 = ggml_time_ms();
13231338
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);
1339+
1340+
if (sd_result_cb != NULL) {
1341+
return NULL;
1342+
}
13241343

13251344
// Decode to image
13261345
LOG_INFO("decoding %zu latents", final_latents.size());

stable-diffusion.h

+2
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,11 @@ enum sd_log_level_t {
107107

108108
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
109109
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
110+
typedef void (*sd_result_cb_t)(size_t number, uint8_t* image_data, void* data);
110111

111112
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
112113
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
114+
SD_API void sd_set_result_callback(sd_result_cb_t cb, void* data);
113115
SD_API int32_t get_num_physical_cores();
114116
SD_API const char* sd_get_system_info();
115117

0 commit comments

Comments
 (0)