
Commit f88143f

Green-Sky and FSSRepo committed Sep 3, 2024

repair flash attention in _ext

this does not fix the currently broken fa behind the define, which is only used by VAE

Co-authored-by: FSSRepo <FSSRepo@users.noreply.github.com>

1 parent 14206fd · commit f88143f

File tree

1 file changed: +29 −5 lines changed
 

ggml_extend.hpp

+29 −5
@@ -735,13 +735,35 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    bool use_flash_attn = false;
-    ggml_tensor* kqv = NULL;
+    LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head);
+
+    bool use_flash_attn = true;
+    // L_k == n_context AND l_k == n_token ????
+    use_flash_attn = use_flash_attn && L_k % 256 == 0;
+    use_flash_attn = use_flash_attn && d_head % 64 == 0;  // why
+
+    if (mask != nullptr) {
+        // TODO: figure out if we can bend t5 to work too
+        use_flash_attn = use_flash_attn && mask->ne[2] == 1;
+        use_flash_attn = use_flash_attn && mask->ne[3] == 1;
+    }
+
+    // TODO: more pad or disable for funny tensor shapes
+
+    ggml_tensor* kqv = nullptr;
     if (use_flash_attn) {
+        LOG_DEBUG("using flash attention");
+
+        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
-        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
+        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
+        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+
+        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
@@ -757,10 +779,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
+
+        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
+        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);                 // [N, L_q, n_head, d_head]
     }
 
-    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);   // [N, n_head, L_q, d_head]
-    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, L_q, n_head, d_head]
+    kqv = ggml_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;

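In short, the repaired path only takes flash attention when the tensor shapes fit the kernel's constraints, and otherwise falls back to the naive matmul + softmax path. A minimal standalone sketch of that eligibility check follows; the helper name can_use_flash_attn is illustrative only, as the commit inlines these checks directly in ggml_nn_attention_ext.

```cpp
#include "ggml.h"

// Sketch of the gate introduced by this commit (helper name is hypothetical):
// flash attention is used only when the KV length is a multiple of 256, the
// head dimension is a multiple of 64, and any mask is a plain 2D mask
// (ne[2] == ne[3] == 1), since per-head/per-batch masks such as T5's are not
// handled yet.
static bool can_use_flash_attn(int64_t L_k, int64_t d_head, const struct ggml_tensor* mask) {
    bool ok = (L_k % 256 == 0) && (d_head % 64 == 0);
    if (mask != nullptr) {
        ok = ok && mask->ne[2] == 1 && mask->ne[3] == 1;
    }
    return ok;
}
```

On top of this gate, the flash path in the hunk above casts k and v to F16 before calling ggml_flash_attn_ext and pins the result precision to F32 via ggml_flash_attn_ext_set_prec, then views the output back into [d_head, n_head, L_k] before the shared cont/reshape at the end of the function.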