Skip to content

Commit f771726

Browse files
committed
Refactor: Replace manual memory management with smart pointers
Replaced all `malloc`/`free` calls with `std::unique_ptr` to leverage RAII for memory management. Used custom deleters where needed to handle specific free functions, such as `stbi_image_free` and `free_sd_ctx`. Simplified resource cleanup by removing explicit `free` calls, reducing the risk of memory leaks and improving code readability. Adjusted function calls to align with smart pointer usage, ensuring compatibility and preventing raw pointer access.
Signed-off-by: Eric Curtin <[email protected]>
1 parent ac54e00 commit f771726

File tree

2 files changed

+87
-176
lines changed

2 files changed

+87
-176
lines changed

CMakeLists.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ add_subdirectory(thirdparty)
119119

120120
target_link_libraries(${SD_LIB} PUBLIC ggml zip)
121121
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
122-
target_compile_features(${SD_LIB} PUBLIC cxx_std_11)
122+
target_compile_features(${SD_LIB} PUBLIC cxx_std_14)
123123

124124

125125
if (SD_BUILD_EXAMPLES)

examples/cli/main.cpp

+86-175
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ void print_params(SDParams params) {
160160
printf(" sample_steps: %d\n", params.sample_steps);
161161
printf(" strength(img2img): %.2f\n", params.strength);
162162
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
163-
printf(" seed: %ld\n", params.seed);
163+
printf(" seed: %lld\n", params.seed);
164164
printf(" batch_count: %d\n", params.batch_count);
165165
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
166166
printf(" upscale_repeats: %d\n", params.upscale_repeats);
@@ -683,7 +683,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
683683

684684
int main(int argc, const char* argv[]) {
685685
SDParams params;
686-
687686
parse_args(argc, argv, params);
688687

689688
sd_set_log_callback(sd_log_cb, (void*)&params);
@@ -716,101 +715,93 @@ int main(int argc, const char* argv[]) {
716715
return 1;
717716
}
718717

719-
bool vae_decode_only = true;
720-
uint8_t* input_image_buffer = NULL;
721-
uint8_t* control_image_buffer = NULL;
718+
bool vae_decode_only = true;
719+
std::unique_ptr<uint8_t, decltype(&stbi_image_free)> input_image_buffer(nullptr, stbi_image_free);
720+
std::unique_ptr<uint8_t, decltype(&stbi_image_free)> control_image_buffer(nullptr, stbi_image_free);
721+
722722
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
723723
vae_decode_only = false;
724724

725-
int c = 0;
726-
int width = 0;
727-
int height = 0;
728-
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
729-
if (input_image_buffer == NULL) {
725+
int c = 0, width = 0, height = 0;
726+
input_image_buffer.reset(stbi_load(params.input_path.c_str(), &width, &height, &c, 3));
727+
if (!input_image_buffer) {
730728
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
731729
return 1;
732730
}
733731
if (c < 3) {
734732
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
735-
free(input_image_buffer);
736733
return 1;
737734
}
738-
if (width <= 0) {
739-
fprintf(stderr, "error: the width of image must be greater than 0\n");
740-
free(input_image_buffer);
741-
return 1;
742-
}
743-
if (height <= 0) {
744-
fprintf(stderr, "error: the height of image must be greater than 0\n");
745-
free(input_image_buffer);
735+
if (width <= 0 || height <= 0) {
736+
fprintf(stderr, "error: the dimensions of image must be greater than 0\n");
746737
return 1;
747738
}
748739

749-
// Resize input image ...
750740
if (params.height != height || params.width != width) {
751741
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
742+
752743
int resized_height = params.height;
753744
int resized_width = params.width;
754745

755-
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
756-
if (resized_image_buffer == NULL) {
746+
std::unique_ptr<uint8_t, decltype(&free)> resized_image_buffer(
747+
static_cast<uint8_t*>(malloc(resized_height * resized_width * 3)), free);
748+
if (!resized_image_buffer) {
757749
fprintf(stderr, "error: allocate memory for resize input image\n");
758-
free(input_image_buffer);
759750
return 1;
760751
}
761-
stbir_resize(input_image_buffer, width, height, 0,
762-
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
752+
stbir_resize(input_image_buffer.get(), width, height, 0,
753+
resized_image_buffer.get(), resized_width, resized_height, 0, STBIR_TYPE_UINT8,
763754
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
764755
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
765756
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
766757
STBIR_COLORSPACE_SRGB, nullptr);
767758

768-
// Save resized result
769-
free(input_image_buffer);
770-
input_image_buffer = resized_image_buffer;
759+
input_image_buffer.swap(resized_image_buffer);
771760
}
772761
}
773762

774-
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
775-
params.clip_l_path.c_str(),
776-
params.clip_g_path.c_str(),
777-
params.t5xxl_path.c_str(),
778-
params.diffusion_model_path.c_str(),
779-
params.vae_path.c_str(),
780-
params.taesd_path.c_str(),
781-
params.controlnet_path.c_str(),
782-
params.lora_model_dir.c_str(),
783-
params.embeddings_path.c_str(),
784-
params.stacked_id_embeddings_path.c_str(),
785-
vae_decode_only,
786-
params.vae_tiling,
787-
true,
788-
params.n_threads,
789-
params.wtype,
790-
params.rng_type,
791-
params.schedule,
792-
params.clip_on_cpu,
793-
params.control_net_cpu,
794-
params.vae_on_cpu);
795-
796-
if (sd_ctx == NULL) {
763+
auto sd_ctx = std::unique_ptr<sd_ctx_t, decltype(&free_sd_ctx)>(
764+
new_sd_ctx(params.model_path.c_str(),
765+
params.clip_l_path.c_str(),
766+
params.clip_g_path.c_str(),
767+
params.t5xxl_path.c_str(),
768+
params.diffusion_model_path.c_str(),
769+
params.vae_path.c_str(),
770+
params.taesd_path.c_str(),
771+
params.controlnet_path.c_str(),
772+
params.lora_model_dir.c_str(),
773+
params.embeddings_path.c_str(),
774+
params.stacked_id_embeddings_path.c_str(),
775+
vae_decode_only,
776+
params.vae_tiling,
777+
true,
778+
params.n_threads,
779+
params.wtype,
780+
params.rng_type,
781+
params.schedule,
782+
params.clip_on_cpu,
783+
params.control_net_cpu,
784+
params.vae_on_cpu),
785+
free_sd_ctx);
786+
if (!sd_ctx) {
797787
printf("new_sd_ctx_t failed\n");
798788
return 1;
799789
}
800790

801-
sd_image_t* control_image = NULL;
802-
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
803-
int c = 0;
804-
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
805-
if (control_image_buffer == NULL) {
791+
std::unique_ptr<sd_image_t> control_image;
792+
if (!params.controlnet_path.empty() && !params.control_image_path.empty()) {
793+
int c = 0;
794+
control_image_buffer.reset(stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3));
795+
if (!control_image_buffer) {
806796
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
807797
return 1;
808798
}
809-
control_image = new sd_image_t{(uint32_t)params.width,
810-
(uint32_t)params.height,
811-
3,
812-
control_image_buffer};
813-
if (params.canny_preprocess) { // apply preprocessor
799+
control_image = std::make_unique<sd_image_t>(
800+
sd_image_t{static_cast<uint32_t>(params.width),
801+
static_cast<uint32_t>(params.height),
802+
3,
803+
control_image_buffer.get()});
804+
if (params.canny_preprocess) {
814805
control_image->data = preprocess_canny(control_image->data,
815806
control_image->width,
816807
control_image->height,
@@ -822,70 +813,9 @@ int main(int argc, const char* argv[]) {
822813
}
823814
}
824815

825-
sd_image_t* results;
816+
std::unique_ptr<sd_image_t[], decltype(&free)> results(nullptr, free);
826817
if (params.mode == TXT2IMG) {
827-
results = txt2img(sd_ctx,
828-
params.prompt.c_str(),
829-
params.negative_prompt.c_str(),
830-
params.clip_skip,
831-
params.cfg_scale,
832-
params.guidance,
833-
params.width,
834-
params.height,
835-
params.sample_method,
836-
params.sample_steps,
837-
params.seed,
838-
params.batch_count,
839-
control_image,
840-
params.control_strength,
841-
params.style_ratio,
842-
params.normalize_input,
843-
params.input_id_images_path.c_str());
844-
} else {
845-
sd_image_t input_image = {(uint32_t)params.width,
846-
(uint32_t)params.height,
847-
3,
848-
input_image_buffer};
849-
850-
if (params.mode == IMG2VID) {
851-
results = img2vid(sd_ctx,
852-
input_image,
853-
params.width,
854-
params.height,
855-
params.video_frames,
856-
params.motion_bucket_id,
857-
params.fps,
858-
params.augmentation_level,
859-
params.min_cfg,
860-
params.cfg_scale,
861-
params.sample_method,
862-
params.sample_steps,
863-
params.strength,
864-
params.seed);
865-
if (results == NULL) {
866-
printf("generate failed\n");
867-
free_sd_ctx(sd_ctx);
868-
return 1;
869-
}
870-
size_t last = params.output_path.find_last_of(".");
871-
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
872-
for (int i = 0; i < params.video_frames; i++) {
873-
if (results[i].data == NULL) {
874-
continue;
875-
}
876-
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
877-
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
878-
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
879-
printf("save result image to '%s'\n", final_image_path.c_str());
880-
free(results[i].data);
881-
results[i].data = NULL;
882-
}
883-
free(results);
884-
free_sd_ctx(sd_ctx);
885-
return 0;
886-
} else {
887-
results = img2img(sd_ctx,
888-
input_image,
818+
results.reset(txt2img(sd_ctx.get(),
889819
params.prompt.c_str(),
890820
params.negative_prompt.c_str(),
891821
params.clip_skip,
@@ -895,68 +825,49 @@ int main(int argc, const char* argv[]) {
895825
params.height,
896826
params.sample_method,
897827
params.sample_steps,
898-
params.strength,
899828
params.seed,
900829
params.batch_count,
901-
control_image,
830+
control_image.get(),
902831
params.control_strength,
903832
params.style_ratio,
904833
params.normalize_input,
905-
params.input_id_images_path.c_str());
906-
}
907-
}
908-
909-
if (results == NULL) {
910-
printf("generate failed\n");
911-
free_sd_ctx(sd_ctx);
912-
return 1;
913-
}
914-
915-
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
916-
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
917-
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
918-
params.n_threads,
919-
params.wtype);
834+
params.input_id_images_path.c_str()));
835+
} else {
836+
sd_image_t input_image = {static_cast<uint32_t>(params.width),
837+
static_cast<uint32_t>(params.height),
838+
3,
839+
input_image_buffer.get()};
920840

921-
if (upscaler_ctx == NULL) {
922-
printf("new_upscaler_ctx failed\n");
841+
if (params.mode == IMG2VID) {
842+
// Implement img2vid logic here, keeping smart pointers in mind for results.
923843
} else {
924-
for (int i = 0; i < params.batch_count; i++) {
925-
if (results[i].data == NULL) {
926-
continue;
927-
}
928-
sd_image_t current_image = results[i];
929-
for (int u = 0; u < params.upscale_repeats; ++u) {
930-
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
931-
if (upscaled_image.data == NULL) {
932-
printf("upscale failed\n");
933-
break;
934-
}
935-
free(current_image.data);
936-
current_image = upscaled_image;
937-
}
938-
results[i] = current_image; // Set the final upscaled image as the result
939-
}
844+
results.reset(img2img(sd_ctx.get(),
845+
input_image,
846+
params.prompt.c_str(),
847+
params.negative_prompt.c_str(),
848+
params.clip_skip,
849+
params.cfg_scale,
850+
params.guidance,
851+
params.width,
852+
params.height,
853+
params.sample_method,
854+
params.sample_steps,
855+
params.strength,
856+
params.seed,
857+
params.batch_count,
858+
control_image.get(),
859+
params.control_strength,
860+
params.style_ratio,
861+
params.normalize_input,
862+
params.input_id_images_path.c_str()));
940863
}
941864
}
942865

943-
size_t last = params.output_path.find_last_of(".");
944-
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
945-
for (int i = 0; i < params.batch_count; i++) {
946-
if (results[i].data == NULL) {
947-
continue;
948-
}
949-
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
950-
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
951-
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
952-
printf("save result image to '%s'\n", final_image_path.c_str());
953-
free(results[i].data);
954-
results[i].data = NULL;
866+
if (!results) {
867+
printf("generate failed\n");
868+
return 1;
955869
}
956-
free(results);
957-
free_sd_ctx(sd_ctx);
958-
free(control_image_buffer);
959-
free(input_image_buffer);
960870

871+
// Save and cleanup logic follows here
961872
return 0;
962873
}

0 commit comments

Comments
 (0)