
Commit 4570715

fix: use ggml_nn_attention in vae

1 parent 53b415f commit 4570715

File tree: 2 files changed (+27, -5 lines)


ggml_extend.hpp (+24)
@@ -661,6 +661,30 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
     return {q, k, v};
 }
 
+// q: [N * n_head, n_token, d_head]
+// k: [N * n_head, n_k, d_head]
+// v: [N * n_head, d_head, n_k]
+// return: [N * n_head, n_token, d_head]
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
+                                                        struct ggml_tensor* q,
+                                                        struct ggml_tensor* k,
+                                                        struct ggml_tensor* v,
+                                                        bool mask = false) {
+#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
+    struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false);  // [N * n_head, n_token, d_head]
+#else
+    float d_head = (float)q->ne[0];
+    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, n_token, n_k]
+    kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
+    if (mask) {
+        kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
+    }
+    kq = ggml_soft_max_inplace(ctx, kq);
+    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, n_token, d_head]
+#endif
+    return kqv;
+}
+
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
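
Note (not part of the commit): the non-flash branch of ggml_nn_attention above is plain scaled dot-product attention, kqv = softmax(q·kᵀ / sqrt(d_head))·v, with v passed in transposed ([d_head, n_k] per head) so the second ggml_mul_mat can consume it directly. Below is a minimal single-head reference in plain C++ for comparison; the function name, loop structure, and row-major shapes (q [n_token, d_head], k [n_k, d_head], v [d_head, n_k]) are illustrative assumptions, not code from the repository.

#include <algorithm>
#include <cmath>
#include <vector>

// Single-head reference: out[i][d] = sum_j softmax_j(q[i]·k[j] / sqrt(d_head)) * v[d][j]
std::vector<float> nn_attention_ref(const std::vector<float>& q,  // n_token * d_head, row-major
                                    const std::vector<float>& k,  // n_k * d_head, row-major
                                    const std::vector<float>& v,  // d_head * n_k (transposed, as in the diff)
                                    int n_token, int n_k, int d_head,
                                    bool mask = false) {
    const float scale = 1.0f / std::sqrt((float)d_head);
    std::vector<float> kq((size_t)n_token * n_k);

    for (int i = 0; i < n_token; i++) {
        // scores for query i, with the optional causal mask of ggml_diag_mask_inf_inplace
        float max_val = -INFINITY;
        for (int j = 0; j < n_k; j++) {
            float dot = 0.0f;
            for (int d = 0; d < d_head; d++) {
                dot += q[i * d_head + d] * k[j * d_head + d];
            }
            float s = dot * scale;
            if (mask && j > i) {
                s = -INFINITY;  // masked positions vanish after the softmax
            }
            kq[i * n_k + j] = s;
            max_val = std::max(max_val, s);
        }
        // row-wise softmax (subtract the max for numerical stability)
        float sum = 0.0f;
        for (int j = 0; j < n_k; j++) {
            kq[i * n_k + j] = std::exp(kq[i * n_k + j] - max_val);
            sum += kq[i * n_k + j];
        }
        for (int j = 0; j < n_k; j++) {
            kq[i * n_k + j] /= sum;
        }
    }

    // kqv = kq · vᵀ, giving [n_token, d_head]
    std::vector<float> kqv((size_t)n_token * d_head);
    for (int i = 0; i < n_token; i++) {
        for (int d = 0; d < d_head; d++) {
            float acc = 0.0f;
            for (int j = 0; j < n_k; j++) {
                acc += kq[i * n_k + j] * v[d * n_k + j];
            }
            kqv[i * d_head + d] = acc;
        }
    }
    return kqv;
}

The SD_USE_FLASH_ATTENTION branch computes the same quantity in a single fused ggml_flash_attn call instead of materializing the kq matrix.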

vae.hpp (+3, -5)
@@ -99,12 +99,10 @@ class AttnBlock : public UnaryBlock {
         k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
         k = ggml_reshape_3d(ctx, k, c, h * w, n);               // [N, h * w, in_channels]
 
-        auto v = v_proj->forward(ctx, h_);                      // [N, in_channels, h, w]
-        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));   // [N, h, w, in_channels]
-        v = ggml_reshape_3d(ctx, v, c, h * w, n);               // [N, h * w, in_channels]
+        auto v = v_proj->forward(ctx, h_);                      // [N, in_channels, h, w]
+        v = ggml_reshape_3d(ctx, v, h * w, c, n);               // [N, in_channels, h * w]
 
-        // h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
-        h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);
+        h_ = ggml_nn_attention(ctx, q, k, v, false);            // [N, h * w, in_channels]
 
         h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
         h_ = ggml_reshape_4d(ctx, h_, w, h, c, n);              // [N, in_channels, h, w]
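
Note (not part of the commit): ggml_nn_attention expects v transposed, [N * n_head, d_head, n_k], while q and k stay [.., n_token/n_k, d_head]. AttnBlock effectively uses a single head with d_head = in_channels and n_k = h * w, and the [N, in_channels, h, w] output of v_proj already stores each channel as a contiguous run of h * w values, so a single ggml_reshape_3d yields the expected [N, in_channels, h * w] layout and the old ggml_permute + ggml_cont copy becomes unnecessary. A tiny plain-C++ indexing sketch follows (the toy sizes and the v_index helper are hypothetical, for illustration only):

#include <cstddef>
#include <cstdio>

// Flat offset into a buffer viewed as [N, d_head, n_k] (comment order,
// i.e. ggml ne = {n_k, d_head, N}): the spatial position is the fastest axis.
static size_t v_index(size_t n, size_t channel, size_t pos,
                      size_t d_head, size_t n_k) {
    return (n * d_head + channel) * n_k + pos;
}

int main() {
    const size_t N = 1, c = 3, h = 2, w = 2;  // toy sizes
    float v[N * c * h * w];                   // stands in for the v_proj output [N, in_channels, h, w]
    for (size_t i = 0; i < N * c * h * w; i++) {
        v[i] = (float)i;
    }
    // Viewing the same buffer as [N, in_channels, h * w] moves no data:
    // element (n = 0, channel = 1, spatial pos = 2) is the same float either way.
    std::printf("%.1f\n", v[v_index(0, 1, 2, c, h * w)]);  // prints 6.0
    return 0;
}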
