@@ -160,7 +160,7 @@ void print_params(SDParams params) {
160
160
printf (" sample_steps: %d\n " , params.sample_steps );
161
161
printf (" strength(img2img): %.2f\n " , params.strength );
162
162
printf (" rng: %s\n " , rng_type_to_str[params.rng_type ]);
163
- printf (" seed: %ld \n " , params.seed );
163
+ printf (" seed: %lld \n " , params.seed );
164
164
printf (" batch_count: %d\n " , params.batch_count );
165
165
printf (" vae_tiling: %s\n " , params.vae_tiling ? " true" : " false" );
166
166
printf (" upscale_repeats: %d\n " , params.upscale_repeats );
@@ -683,7 +683,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
683
683
684
684
int main (int argc, const char * argv[]) {
685
685
SDParams params;
686
-
687
686
parse_args (argc, argv, params);
688
687
689
688
sd_set_log_callback (sd_log_cb, (void *)&params);
@@ -716,101 +715,93 @@ int main(int argc, const char* argv[]) {
716
715
return 1 ;
717
716
}
718
717
719
- bool vae_decode_only = true ;
720
- uint8_t * input_image_buffer = NULL ;
721
- uint8_t * control_image_buffer = NULL ;
718
+ bool vae_decode_only = true ;
719
+ std::unique_ptr<uint8_t , decltype (&stbi_image_free)> input_image_buffer (nullptr , stbi_image_free);
720
+ std::unique_ptr<uint8_t , decltype (&stbi_image_free)> control_image_buffer (nullptr , stbi_image_free);
721
+
722
722
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
723
723
vae_decode_only = false ;
724
724
725
- int c = 0 ;
726
- int width = 0 ;
727
- int height = 0 ;
728
- input_image_buffer = stbi_load (params.input_path .c_str (), &width, &height, &c, 3 );
729
- if (input_image_buffer == NULL ) {
725
+ int c = 0 , width = 0 , height = 0 ;
726
+ input_image_buffer.reset (stbi_load (params.input_path .c_str (), &width, &height, &c, 3 ));
727
+ if (!input_image_buffer) {
730
728
fprintf (stderr, " load image from '%s' failed\n " , params.input_path .c_str ());
731
729
return 1 ;
732
730
}
733
731
if (c < 3 ) {
734
732
fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
735
- free (input_image_buffer);
736
733
return 1 ;
737
734
}
738
- if (width <= 0 ) {
739
- fprintf (stderr, " error: the width of image must be greater than 0\n " );
740
- free (input_image_buffer);
741
- return 1 ;
742
- }
743
- if (height <= 0 ) {
744
- fprintf (stderr, " error: the height of image must be greater than 0\n " );
745
- free (input_image_buffer);
735
+ if (width <= 0 || height <= 0 ) {
736
+ fprintf (stderr, " error: the dimensions of image must be greater than 0\n " );
746
737
return 1 ;
747
738
}
748
739
749
- // Resize input image ...
750
740
if (params.height != height || params.width != width) {
751
741
printf (" resize input image from %dx%d to %dx%d\n " , width, height, params.width , params.height );
742
+
752
743
int resized_height = params.height ;
753
744
int resized_width = params.width ;
754
745
755
- uint8_t * resized_image_buffer = (uint8_t *)malloc (resized_height * resized_width * 3 );
756
- if (resized_image_buffer == NULL ) {
746
+ std::unique_ptr<uint8_t , decltype (&free)> resized_image_buffer (
747
+ static_cast <uint8_t *>(malloc (resized_height * resized_width * 3 )), free);
748
+ if (!resized_image_buffer) {
757
749
fprintf (stderr, " error: allocate memory for resize input image\n " );
758
- free (input_image_buffer);
759
750
return 1 ;
760
751
}
761
- stbir_resize (input_image_buffer, width, height, 0 ,
762
- resized_image_buffer, resized_width, resized_height, 0 , STBIR_TYPE_UINT8,
752
+ stbir_resize (input_image_buffer. get () , width, height, 0 ,
753
+ resized_image_buffer. get () , resized_width, resized_height, 0 , STBIR_TYPE_UINT8,
763
754
3 /* RGB channel*/ , STBIR_ALPHA_CHANNEL_NONE, 0 ,
764
755
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
765
756
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
766
757
STBIR_COLORSPACE_SRGB, nullptr );
767
758
768
- // Save resized result
769
- free (input_image_buffer);
770
- input_image_buffer = resized_image_buffer;
759
+ input_image_buffer.swap (resized_image_buffer);
771
760
}
772
761
}
773
762
774
- sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
775
- params.clip_l_path .c_str (),
776
- params.clip_g_path .c_str (),
777
- params.t5xxl_path .c_str (),
778
- params.diffusion_model_path .c_str (),
779
- params.vae_path .c_str (),
780
- params.taesd_path .c_str (),
781
- params.controlnet_path .c_str (),
782
- params.lora_model_dir .c_str (),
783
- params.embeddings_path .c_str (),
784
- params.stacked_id_embeddings_path .c_str (),
785
- vae_decode_only,
786
- params.vae_tiling ,
787
- true ,
788
- params.n_threads ,
789
- params.wtype ,
790
- params.rng_type ,
791
- params.schedule ,
792
- params.clip_on_cpu ,
793
- params.control_net_cpu ,
794
- params.vae_on_cpu );
795
-
796
- if (sd_ctx == NULL ) {
763
+ auto sd_ctx = std::unique_ptr<sd_ctx_t , decltype (&free_sd_ctx)>(
764
+ new_sd_ctx (params.model_path .c_str (),
765
+ params.clip_l_path .c_str (),
766
+ params.clip_g_path .c_str (),
767
+ params.t5xxl_path .c_str (),
768
+ params.diffusion_model_path .c_str (),
769
+ params.vae_path .c_str (),
770
+ params.taesd_path .c_str (),
771
+ params.controlnet_path .c_str (),
772
+ params.lora_model_dir .c_str (),
773
+ params.embeddings_path .c_str (),
774
+ params.stacked_id_embeddings_path .c_str (),
775
+ vae_decode_only,
776
+ params.vae_tiling ,
777
+ true ,
778
+ params.n_threads ,
779
+ params.wtype ,
780
+ params.rng_type ,
781
+ params.schedule ,
782
+ params.clip_on_cpu ,
783
+ params.control_net_cpu ,
784
+ params.vae_on_cpu ),
785
+ free_sd_ctx);
786
+ if (!sd_ctx) {
797
787
printf (" new_sd_ctx_t failed\n " );
798
788
return 1 ;
799
789
}
800
790
801
- sd_image_t * control_image = NULL ;
802
- if (params.controlnet_path .size () > 0 && params.control_image_path .size () > 0 ) {
803
- int c = 0 ;
804
- control_image_buffer = stbi_load (params.control_image_path .c_str (), &params.width , &params.height , &c, 3 );
805
- if (control_image_buffer == NULL ) {
791
+ std::unique_ptr< sd_image_t > control_image;
792
+ if (! params.controlnet_path .empty () && ! params.control_image_path .empty () ) {
793
+ int c = 0 ;
794
+ control_image_buffer. reset ( stbi_load (params.control_image_path .c_str (), &params.width , &params.height , &c, 3 ) );
795
+ if (! control_image_buffer) {
806
796
fprintf (stderr, " load image from '%s' failed\n " , params.control_image_path .c_str ());
807
797
return 1 ;
808
798
}
809
- control_image = new sd_image_t {(uint32_t )params.width ,
810
- (uint32_t )params.height ,
811
- 3 ,
812
- control_image_buffer};
813
- if (params.canny_preprocess ) { // apply preprocessor
799
+ control_image = std::make_unique<sd_image_t >(
800
+ sd_image_t {static_cast <uint32_t >(params.width ),
801
+ static_cast <uint32_t >(params.height ),
802
+ 3 ,
803
+ control_image_buffer.get ()});
804
+ if (params.canny_preprocess ) {
814
805
control_image->data = preprocess_canny (control_image->data ,
815
806
control_image->width ,
816
807
control_image->height ,
@@ -822,70 +813,9 @@ int main(int argc, const char* argv[]) {
822
813
}
823
814
}
824
815
825
- sd_image_t * results;
816
+ std::unique_ptr< sd_image_t [], decltype (&free)> results ( nullptr , free) ;
826
817
if (params.mode == TXT2IMG) {
827
- results = txt2img (sd_ctx,
828
- params.prompt .c_str (),
829
- params.negative_prompt .c_str (),
830
- params.clip_skip ,
831
- params.cfg_scale ,
832
- params.guidance ,
833
- params.width ,
834
- params.height ,
835
- params.sample_method ,
836
- params.sample_steps ,
837
- params.seed ,
838
- params.batch_count ,
839
- control_image,
840
- params.control_strength ,
841
- params.style_ratio ,
842
- params.normalize_input ,
843
- params.input_id_images_path .c_str ());
844
- } else {
845
- sd_image_t input_image = {(uint32_t )params.width ,
846
- (uint32_t )params.height ,
847
- 3 ,
848
- input_image_buffer};
849
-
850
- if (params.mode == IMG2VID) {
851
- results = img2vid (sd_ctx,
852
- input_image,
853
- params.width ,
854
- params.height ,
855
- params.video_frames ,
856
- params.motion_bucket_id ,
857
- params.fps ,
858
- params.augmentation_level ,
859
- params.min_cfg ,
860
- params.cfg_scale ,
861
- params.sample_method ,
862
- params.sample_steps ,
863
- params.strength ,
864
- params.seed );
865
- if (results == NULL ) {
866
- printf (" generate failed\n " );
867
- free_sd_ctx (sd_ctx);
868
- return 1 ;
869
- }
870
- size_t last = params.output_path .find_last_of (" ." );
871
- std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
872
- for (int i = 0 ; i < params.video_frames ; i++) {
873
- if (results[i].data == NULL ) {
874
- continue ;
875
- }
876
- std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + " .png" : dummy_name + " .png" ;
877
- stbi_write_png (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
878
- results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
879
- printf (" save result image to '%s'\n " , final_image_path.c_str ());
880
- free (results[i].data );
881
- results[i].data = NULL ;
882
- }
883
- free (results);
884
- free_sd_ctx (sd_ctx);
885
- return 0 ;
886
- } else {
887
- results = img2img (sd_ctx,
888
- input_image,
818
+ results.reset (txt2img (sd_ctx.get (),
889
819
params.prompt .c_str (),
890
820
params.negative_prompt .c_str (),
891
821
params.clip_skip ,
@@ -895,68 +825,49 @@ int main(int argc, const char* argv[]) {
895
825
params.height ,
896
826
params.sample_method ,
897
827
params.sample_steps ,
898
- params.strength ,
899
828
params.seed ,
900
829
params.batch_count ,
901
- control_image,
830
+ control_image. get () ,
902
831
params.control_strength ,
903
832
params.style_ratio ,
904
833
params.normalize_input ,
905
- params.input_id_images_path .c_str ());
906
- }
907
- }
908
-
909
- if (results == NULL ) {
910
- printf (" generate failed\n " );
911
- free_sd_ctx (sd_ctx);
912
- return 1 ;
913
- }
914
-
915
- int upscale_factor = 4 ; // unused for RealESRGAN_x4plus_anime_6B.pth
916
- if (params.esrgan_path .size () > 0 && params.upscale_repeats > 0 ) {
917
- upscaler_ctx_t * upscaler_ctx = new_upscaler_ctx (params.esrgan_path .c_str (),
918
- params.n_threads ,
919
- params.wtype );
834
+ params.input_id_images_path .c_str ()));
835
+ } else {
836
+ sd_image_t input_image = {static_cast <uint32_t >(params.width ),
837
+ static_cast <uint32_t >(params.height ),
838
+ 3 ,
839
+ input_image_buffer.get ()};
920
840
921
- if (upscaler_ctx == NULL ) {
922
- printf ( " new_upscaler_ctx failed \n " );
841
+ if (params. mode == IMG2VID ) {
842
+ // Implement img2vid logic here, keeping smart pointers in mind for results.
923
843
} else {
924
- for (int i = 0 ; i < params.batch_count ; i++) {
925
- if (results[i].data == NULL ) {
926
- continue ;
927
- }
928
- sd_image_t current_image = results[i];
929
- for (int u = 0 ; u < params.upscale_repeats ; ++u) {
930
- sd_image_t upscaled_image = upscale (upscaler_ctx, current_image, upscale_factor);
931
- if (upscaled_image.data == NULL ) {
932
- printf (" upscale failed\n " );
933
- break ;
934
- }
935
- free (current_image.data );
936
- current_image = upscaled_image;
937
- }
938
- results[i] = current_image; // Set the final upscaled image as the result
939
- }
844
+ results.reset (img2img (sd_ctx.get (),
845
+ input_image,
846
+ params.prompt .c_str (),
847
+ params.negative_prompt .c_str (),
848
+ params.clip_skip ,
849
+ params.cfg_scale ,
850
+ params.guidance ,
851
+ params.width ,
852
+ params.height ,
853
+ params.sample_method ,
854
+ params.sample_steps ,
855
+ params.strength ,
856
+ params.seed ,
857
+ params.batch_count ,
858
+ control_image.get (),
859
+ params.control_strength ,
860
+ params.style_ratio ,
861
+ params.normalize_input ,
862
+ params.input_id_images_path .c_str ()));
940
863
}
941
864
}
942
865
943
- size_t last = params.output_path .find_last_of (" ." );
944
- std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
945
- for (int i = 0 ; i < params.batch_count ; i++) {
946
- if (results[i].data == NULL ) {
947
- continue ;
948
- }
949
- std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + " .png" : dummy_name + " .png" ;
950
- stbi_write_png (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
951
- results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
952
- printf (" save result image to '%s'\n " , final_image_path.c_str ());
953
- free (results[i].data );
954
- results[i].data = NULL ;
866
+ if (!results) {
867
+ printf (" generate failed\n " );
868
+ return 1 ;
955
869
}
956
- free (results);
957
- free_sd_ctx (sd_ctx);
958
- free (control_image_buffer);
959
- free (input_image_buffer);
960
870
871
+ // Save and cleanup logic follows here
961
872
return 0 ;
962
873
}
0 commit comments