Skip to content

Commit 26e3c0b

Browse files
committed
Fix crash when cfg is 1 (skip the APG delta computation when there is no unconditioned output)
1 parent bdb2709 commit 26e3c0b

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

stable-diffusion.cpp

+29-27
Original file line number | Diff line number | Diff line change
@@ -798,7 +798,7 @@ class StableDiffusionGGML {
798798
SDCondition id_cond,
799799
sd_slg_params_t slg_params = {NULL, 0, 0, 0, 0},
800800
sd_apg_params_t apg_params = {1, 0, 0},
801-
ggml_tensor* noise_mask = nullptr) {
801+
ggml_tensor* noise_mask = nullptr) {
802802
std::vector<int> skip_layers(slg_params.skip_layers, slg_params.skip_layers + slg_params.skip_layers_count);
803803

804804
LOG_DEBUG("Sample");
@@ -959,39 +959,41 @@ class StableDiffusionGGML {
959959
float diff_norm = 0;
960960
float cond_norm_sq = 0;
961961
float dot = 0;
962-
for (int i = 0; i < ne_elements; i++) {
963-
float delta = positive_data[i] - negative_data[i];
964-
if (apg_params.momentum != 0) {
965-
delta += apg_params.momentum * apg_momentum_buffer[i];
966-
apg_momentum_buffer[i] = delta;
962+
if (has_unconditioned) {
963+
for (int i = 0; i < ne_elements; i++) {
964+
float delta = positive_data[i] - negative_data[i];
965+
if (apg_params.momentum != 0) {
966+
delta += apg_params.momentum * apg_momentum_buffer[i];
967+
apg_momentum_buffer[i] = delta;
968+
}
969+
if (apg_params.norm_treshold > 0) {
970+
diff_norm += delta * delta;
971+
}
972+
if (apg_params.eta != 1.0f) {
973+
cond_norm_sq += positive_data[i] * positive_data[i];
974+
dot += positive_data[i] * delta;
975+
}
976+
deltas[i] = delta;
967977
}
968978
if (apg_params.norm_treshold > 0) {
969-
diff_norm += delta * delta;
979+
diff_norm = std::sqrtf(diff_norm);
980+
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
970981
}
971982
if (apg_params.eta != 1.0f) {
972-
cond_norm_sq += positive_data[i] * positive_data[i];
973-
dot += positive_data[i] * delta;
983+
dot *= apg_scale_factor;
984+
// pre-normalize (avoids one square root and ne_elements extra divs)
985+
dot /= cond_norm_sq;
974986
}
975-
deltas[i] = delta;
976-
}
977-
if (apg_params.norm_treshold > 0) {
978-
diff_norm = std::sqrtf(diff_norm);
979-
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
980-
}
981-
if (apg_params.eta != 1.0f) {
982-
dot *= apg_scale_factor;
983-
// pre-normalize (avoids one square root and ne_elements extra divs)
984-
dot /= cond_norm_sq;
985-
}
986987

987-
for (int i = 0; i < ne_elements; i++) {
988-
deltas[i] *= apg_scale_factor;
989-
if (apg_params.eta != 1.0f) {
990-
float apg_parallel = dot * positive_data[i];
991-
float apg_orthogonal = deltas[i] - apg_parallel;
988+
for (int i = 0; i < ne_elements; i++) {
989+
deltas[i] *= apg_scale_factor;
990+
if (apg_params.eta != 1.0f) {
991+
float apg_parallel = dot * positive_data[i];
992+
float apg_orthogonal = deltas[i] - apg_parallel;
992993

993-
// tweak deltas
994-
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
994+
// tweak deltas
995+
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
996+
}
995997
}
996998
}
997999

0 commit comments

Comments
 (0)