@@ -1081,6 +1081,14 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
1081
1081
free (sd_ctx);
1082
1082
}
1083
1083
1084
+ static sd_result_cb_t sd_result_cb = NULL ;
1085
+ void * sd_result_cb_data = NULL ;
1086
+
1087
+ void sd_set_result_callback (sd_result_cb_t cb, void * data) {
1088
+ sd_result_cb = cb;
1089
+ sd_result_cb_data = data;
1090
+ }
1091
+
1084
1092
sd_image_t * generate_image (sd_ctx_t * sd_ctx,
1085
1093
struct ggml_context * work_ctx,
1086
1094
ggml_tensor* init_latent,
@@ -1313,6 +1321,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1313
1321
// print_ggml_tensor(x_0);
1314
1322
int64_t sampling_end = ggml_time_ms ();
1315
1323
LOG_INFO (" sampling completed, taking %.2fs" , (sampling_end - sampling_start) * 1 .0f / 1000 );
1324
+
1325
+ if (sd_result_cb != NULL ) {
1326
+ struct ggml_tensor * img = sd_ctx->sd ->decode_first_stage (work_ctx, x_0);
1327
+ sd_result_cb (b + 1 , sd_tensor_to_image (img), sd_result_cb_data);
1328
+ continue ;
1329
+ }
1330
+
1316
1331
final_latents.push_back (x_0);
1317
1332
}
1318
1333
@@ -1321,6 +1336,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1321
1336
}
1322
1337
int64_t t3 = ggml_time_ms ();
1323
1338
LOG_INFO (" generating %" PRId64 " latent images completed, taking %.2fs" , final_latents.size (), (t3 - t1) * 1 .0f / 1000 );
1339
+
1340
+ if (sd_result_cb != NULL ) {
1341
+ return NULL ;
1342
+ }
1324
1343
1325
1344
// Decode to image
1326
1345
LOG_INFO (" decoding %zu latents" , final_latents.size ());
0 commit comments