@@ -735,13 +735,35 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
-    bool use_flash_attn = false;
-    ggml_tensor* kqv = NULL;
+    LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head);
+
+    bool use_flash_attn = true;
+    // L_k == n_context AND l_k == n_token ????
+    use_flash_attn = use_flash_attn && L_k % 256 == 0;
+    use_flash_attn = use_flash_attn && d_head % 64 == 0;  // why
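+    // presumably these guards mirror what the ggml flash attention kernels
+    // support: only certain head sizes, and a K/V length padded to the
+    // kernel's granularity; anything else falls back to the soft-max path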
+
+    if (mask != nullptr) {
+        // TODO: figure out if we can bend t5 to work too
+        use_flash_attn = use_flash_attn && mask->ne[2] == 1;
+        use_flash_attn = use_flash_attn && mask->ne[3] == 1;
+    }
+
+    // TODO: more pad or disable for funny tensor shapes
+
+    ggml_tensor* kqv = nullptr;
     if (use_flash_attn) {
+        LOG_DEBUG("using flash attention");
+
+        k = ggml_cast(ctx, k, GGML_TYPE_F16);
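+        // the flash attention op seems to expect K (and V, below) in F16
+        // on most backends, hence the explicit casts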
+
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
-        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
+        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
+        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+
+        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
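+        // accumulate in F32 for precision; the op already returns heads
+        // interleaved per token, so a view suffices here instead of the
+        // reshape + permute needed on the fallback path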
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
@@ -757,10 +779,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
+
+        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
+        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);                 // [N, L_q, n_head, d_head]
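+        // the reshape + permute now apply only to this fallback branch;
+        // the flash attention branch gets the same layout via ggml_view_3d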
     }
 
-    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);   // [N, n_head, L_q, d_head]
-    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, L_q, n_head, d_head]
+    kqv = ggml_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;