@@ -798,7 +798,7 @@ class StableDiffusionGGML {
798
798
SDCondition id_cond,
799
799
sd_slg_params_t slg_params = {NULL , 0 , 0 , 0 , 0 },
800
800
sd_apg_params_t apg_params = {1 , 0 , 0 },
801
- ggml_tensor* noise_mask = nullptr ) {
801
+ ggml_tensor* noise_mask = nullptr ) {
802
802
std::vector<int > skip_layers (slg_params.skip_layers , slg_params.skip_layers + slg_params.skip_layers_count );
803
803
804
804
LOG_DEBUG (" Sample" );
@@ -959,39 +959,41 @@ class StableDiffusionGGML {
959
959
float diff_norm = 0 ;
960
960
float cond_norm_sq = 0 ;
961
961
float dot = 0 ;
962
- for (int i = 0 ; i < ne_elements; i++) {
963
- float delta = positive_data[i] - negative_data[i];
964
- if (apg_params.momentum != 0 ) {
965
- delta += apg_params.momentum * apg_momentum_buffer[i];
966
- apg_momentum_buffer[i] = delta;
962
+ if (has_unconditioned) {
963
+ for (int i = 0 ; i < ne_elements; i++) {
964
+ float delta = positive_data[i] - negative_data[i];
965
+ if (apg_params.momentum != 0 ) {
966
+ delta += apg_params.momentum * apg_momentum_buffer[i];
967
+ apg_momentum_buffer[i] = delta;
968
+ }
969
+ if (apg_params.norm_treshold > 0 ) {
970
+ diff_norm += delta * delta;
971
+ }
972
+ if (apg_params.eta != 1 .0f ) {
973
+ cond_norm_sq += positive_data[i] * positive_data[i];
974
+ dot += positive_data[i] * delta;
975
+ }
976
+ deltas[i] = delta;
967
977
}
968
978
if (apg_params.norm_treshold > 0 ) {
969
- diff_norm += delta * delta;
979
+ diff_norm = std::sqrtf (diff_norm);
980
+ apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
970
981
}
971
982
if (apg_params.eta != 1 .0f ) {
972
- cond_norm_sq += positive_data[i] * positive_data[i];
973
- dot += positive_data[i] * delta;
983
+ dot *= apg_scale_factor;
984
+ // pre-normalize (avoids one square root and ne_elements extra divs)
985
+ dot /= cond_norm_sq;
974
986
}
975
- deltas[i] = delta;
976
- }
977
- if (apg_params.norm_treshold > 0 ) {
978
- diff_norm = std::sqrtf (diff_norm);
979
- apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
980
- }
981
- if (apg_params.eta != 1 .0f ) {
982
- dot *= apg_scale_factor;
983
- // pre-normalize (avoids one square root and ne_elements extra divs)
984
- dot /= cond_norm_sq;
985
- }
986
987
987
- for (int i = 0 ; i < ne_elements; i++) {
988
- deltas[i] *= apg_scale_factor;
989
- if (apg_params.eta != 1 .0f ) {
990
- float apg_parallel = dot * positive_data[i];
991
- float apg_orthogonal = deltas[i] - apg_parallel;
988
+ for (int i = 0 ; i < ne_elements; i++) {
989
+ deltas[i] *= apg_scale_factor;
990
+ if (apg_params.eta != 1 .0f ) {
991
+ float apg_parallel = dot * positive_data[i];
992
+ float apg_orthogonal = deltas[i] - apg_parallel;
992
993
993
- // tweak deltas
994
- deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
994
+ // tweak deltas
995
+ deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
996
+ }
995
997
}
996
998
}
997
999
0 commit comments