@@ -771,7 +771,8 @@ class StableDiffusionGGML {
771
771
sample_method_t method,
772
772
const std::vector<float >& sigmas,
773
773
int start_merge_step,
774
- SDCondition id_cond) {
774
+ SDCondition id_cond,
775
+ size_t batch_num = 0 ) {
775
776
size_t steps = sigmas.size () - 1 ;
776
777
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
777
778
// print_ggml_tensor(noise);
@@ -894,6 +895,9 @@ class StableDiffusionGGML {
894
895
pretty_progress (step, (int )steps, (t1 - t0) / 1000000 .f );
895
896
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
896
897
}
898
+
899
+ send_result_step_callback (denoised, batch_num, step);
900
+
897
901
return denoised;
898
902
};
899
903
@@ -1007,6 +1011,47 @@ class StableDiffusionGGML {
1007
1011
ggml_tensor* decode_first_stage (ggml_context* work_ctx, ggml_tensor* x) {
1008
1012
return compute_first_stage (work_ctx, x, true );
1009
1013
}
1014
+
1015
+ sd_result_cb_t result_cb = nullptr ;
1016
+ void * result_cb_data = nullptr ;
1017
+
1018
+ void send_result_callback (ggml_context* work_ctx, ggml_tensor* x, size_t number) {
1019
+ if (result_cb == nullptr ) {
1020
+ return ;
1021
+ }
1022
+
1023
+ struct ggml_tensor * img = decode_first_stage (work_ctx, x);
1024
+ auto image_data = sd_tensor_to_image (img);
1025
+
1026
+ result_cb (number, image_data, result_cb_data);
1027
+ }
1028
+
1029
+ sd_result_step_cb_t result_step_cb = nullptr ;
1030
+ void * result_step_cb_data = nullptr ;
1031
+
1032
+ void send_result_step_callback (ggml_tensor* x, size_t number, size_t step) {
1033
+ if (result_step_cb == nullptr ) {
1034
+ return ;
1035
+ }
1036
+
1037
+ struct ggml_init_params params {};
1038
+ params.mem_size = static_cast <size_t >(10 * 1024 ) * 1024 ;
1039
+ params.mem_buffer = nullptr ;
1040
+ params.no_alloc = false ;
1041
+
1042
+ struct ggml_context * work_ctx = ggml_init (params);
1043
+ if (!work_ctx) {
1044
+ return ;
1045
+ }
1046
+
1047
+ struct ggml_tensor * result = ggml_dup_tensor (work_ctx, x);
1048
+ copy_ggml_tensor (result, x);
1049
+
1050
+ struct ggml_tensor * img = decode_first_stage (work_ctx, result);
1051
+ result_step_cb (number, step, sd_tensor_to_image (img), result_step_cb_data);
1052
+
1053
+ ggml_free (work_ctx);
1054
+ }
1010
1055
};
1011
1056
1012
1057
/* ================================================= SD API ==================================================*/
@@ -1093,6 +1138,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
1093
1138
free (sd_ctx);
1094
1139
}
1095
1140
1141
+ void sd_ctx_set_result_callback (sd_ctx_t * sd_ctx, sd_result_cb_t cb, void * data) {
1142
+ sd_ctx->sd ->result_cb = cb;
1143
+ sd_ctx->sd ->result_cb_data = data;
1144
+ }
1145
+
1146
+ void sd_ctx_set_result_step_callback (sd_ctx_t * sd_ctx, sd_result_step_cb_t cb, void * data) {
1147
+ sd_ctx->sd ->result_step_cb = cb;
1148
+ sd_ctx->sd ->result_step_cb_data = data;
1149
+ }
1150
+
1096
1151
sd_image_t * generate_image (sd_ctx_t * sd_ctx,
1097
1152
struct ggml_context * work_ctx,
1098
1153
ggml_tensor* init_latent,
@@ -1320,11 +1375,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1320
1375
sample_method,
1321
1376
sigmas,
1322
1377
start_merge_step,
1323
- id_cond);
1378
+ id_cond,
1379
+ b);
1324
1380
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
1325
1381
// print_ggml_tensor(x_0);
1326
1382
int64_t sampling_end = ggml_time_ms ();
1327
1383
LOG_INFO (" sampling completed, taking %.2fs" , (sampling_end - sampling_start) * 1 .0f / 1000 );
1384
+
1385
+ if (sd_ctx->sd ->result_cb != nullptr ) {
1386
+ sd_ctx->sd ->send_result_callback (work_ctx, x_0, b);
1387
+ continue ;
1388
+ }
1389
+
1328
1390
final_latents.push_back (x_0);
1329
1391
}
1330
1392
@@ -1334,6 +1396,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1334
1396
int64_t t3 = ggml_time_ms ();
1335
1397
LOG_INFO (" generating %" PRId64 " latent images completed, taking %.2fs" , final_latents.size (), (t3 - t1) * 1 .0f / 1000 );
1336
1398
1399
+ if (sd_ctx->sd ->result_cb != nullptr ) {
1400
+ return nullptr ;
1401
+ }
1402
+
1337
1403
// Decode to image
1338
1404
LOG_INFO (" decoding %zu latents" , final_latents.size ());
1339
1405
std::vector<struct ggml_tensor *> decoded_images; // collect decoded images
0 commit comments