From 67dcc377beb0b451f2b31b67856661761854069f Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Sat, 7 Feb 2026 08:56:24 -0300 Subject: [PATCH] feat: support for canceling the ongoing generation --- include/stable-diffusion.h | 8 +++++++ src/stable-diffusion.cpp | 49 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 75027f8f8..d1d8651db 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -414,6 +414,14 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); +enum sd_cancel_mode_t { + SD_CANCEL_ALL, + SD_CANCEL_NEW_LATENTS, + SD_CANCEL_RESET +}; + +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode); + SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 88102ff61..cf6e7a991 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -23,6 +23,8 @@ #include "latent-preview.h" #include "name_conversion.h" +#include + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -106,6 +108,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { /*=============================================== StableDiffusionGGML ================================================*/ +static_assert(std::atomic::is_always_lock_free, + "sd_cancel_mode_t must be lock-free"); + class StableDiffusionGGML { public: ggml_backend_t backend = nullptr; // general backend @@ -171,6 +176,20 @@ class StableDiffusionGGML { ggml_backend_free(backend); } + std::atomic cancellation_flag; + + void set_cancel_flag(enum sd_cancel_mode_t flag) { + cancellation_flag.store(flag, std::memory_order_release); + } + + void reset_cancel_flag() { + set_cancel_flag(SD_CANCEL_RESET); + } + + enum sd_cancel_mode_t get_cancel_flag() { + return cancellation_flag.load(std::memory_order_acquire); + } + void init_backend() { backend = sd_get_default_backend(); } @@ -1593,6 +1612,12 @@ class StableDiffusionGGML { SamplePreviewContext preview = prepare_sample_preview_context(); auto denoise = [&](const sd::Tensor& x, float sigma, int step) -> sd::Tensor { + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); + if (cancel_flag != SD_CANCEL_RESET) { + LOG_DEBUG("cancelling generation"); + return {}; + } + if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); } @@ -2427,6 +2452,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) { + if (sd_ctx && sd_ctx->sd) { + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { + mode = SD_CANCEL_ALL; + } + sd_ctx->sd->set_cancel_flag(mode); + } +} + SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { return false; @@ -3169,6 +3203,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, int64_t t0 = ggml_time_ms(); for (size_t i = 0; i < final_latents.size(); i++) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling latent decodings"); + break; + } int64_t t1 = ggml_time_ms(); sd::Tensor image = sd_ctx->sd->decode_first_stage(final_latents[i]); if (image.empty()) { @@ -3336,6 +3374,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s return nullptr; } + sd_ctx->sd->reset_cancel_flag(); + int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_img_gen_params); @@ -3371,6 +3411,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s std::vector> final_latents; int64_t denoise_start = ggml_time_ms(); for (int b = 0; b < request.batch_count; b++) { + sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag(); + if (cancel == SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation"); + break; + } + int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = request.seed + b; LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); @@ -3823,6 +3869,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { return nullptr; } + + sd_ctx->sd->reset_cancel_flag(); + if (num_frames_out != nullptr) { *num_frames_out = 0; }