diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 7d60ff77..2a3613ad 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -17,6 +17,8 @@ #define STB_IMAGE_WRITE_STATIC #include "stb_image_write.h" +#include "stb_image_resize.h" + const char* rng_type_to_str[] = { "std_default", "cuda", @@ -663,21 +665,47 @@ int main(int argc, const char* argv[]) { fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); return 1; } - if (c != 3) { - fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c); + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); free(input_image_buffer); return 1; } - if (params.width <= 0 || params.width % 64 != 0) { - fprintf(stderr, "error: the width of image must be a multiple of 64\n"); + if (params.width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0\n"); free(input_image_buffer); return 1; } - if (params.height <= 0 || params.height % 64 != 0) { - fprintf(stderr, "error: the height of image must be a multiple of 64\n"); + if (params.height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); free(input_image_buffer); return 1; } + + // Resize input image ... + if (params.height % 64 != 0 || params.width % 64 != 0) { + int resized_height = params.height + (64 - params.height % 64); + int resized_width = params.width + (64 - params.width % 64); + + uint8_t *resized_image_buffer = (uint8_t *)malloc(resized_height * resized_width * 3); + if (resized_image_buffer == NULL) { + fprintf(stderr, "error: allocate memory for resize input image\n"); + free(input_image_buffer); + return 1; + } + stbir_resize(input_image_buffer, params.width, params.height, 0, + resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, + 3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, + STBIR_FILTER_BOX, STBIR_FILTER_BOX, + STBIR_COLORSPACE_SRGB, nullptr + ); + + // Save resized result + free(input_image_buffer); + input_image_buffer = resized_image_buffer; + params.height = resized_height; + params.width = resized_width; + } } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),