
[Feature] Flash attention v3 support #735

Open
jstjohn opened this issue Mar 8, 2025 · 0 comments
jstjohn (Collaborator) commented Mar 8, 2025

Problem & Motivation

Flash attention v3 (FA3) is significantly faster than flash attention v2 at long context lengths. There are instructions for installing and configuring it in the TransformerEngine (TE) source:

https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L199-L203

With FA3 installed and enabled, TE still falls back to regular flash attention v2 on pre-Hopper GPUs:
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L460-L463

We can add the installation steps to our Docker container setup.
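
For illustration, here is a minimal sketch of the kind of availability check the linked TE code describes. It assumes FA3 is built from the flash-attention repository's `hopper/` directory (which provides the `flash_attn_interface` module) and that the FA3 kernels require a Hopper-class (compute capability 9.0+) GPU; the helper name is hypothetical.

```python
import importlib.util

import torch


def flash_attn_3_is_usable() -> bool:
    """Best-effort check for Flash Attention v3 availability (hypothetical helper).

    Assumes FA3 was installed from the flash-attention repo's `hopper/` directory,
    which provides the `flash_attn_interface` module, and that the FA3 kernels need
    a Hopper (sm90) or newer GPU. On older GPUs TE falls back to flash attention v2,
    per the linked attention.py lines.
    """
    if importlib.util.find_spec("flash_attn_interface") is None:
        return False  # FA3 not installed; FA2 (flash_attn) may still be available
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9  # Hopper is compute capability 9.0


if __name__ == "__main__":
    print("Flash Attention v3 usable:", flash_attn_3_is_usable())
```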

BioNeMo Framework Version

v2.4.1

Category

Model/Training

Proposed Solution

Follow the TE installation steps, confirm that models still pass partial convergence tests, and verify that the expected long-context speedup appears when using the "flash" backend on H100 or newer GPUs.
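
As a rough sketch of how a benchmark could force the flash backend, the snippet below sets TransformerEngine's attention-backend environment toggles before TE is imported. The variable names (`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_DEBUG`, `NVTE_DEBUG_LEVEL`) come from TE's attention module rather than from BioNeMo, so treat them as assumptions to verify against the TE version in the container.

```python
import os

# TransformerEngine backend toggles (assumed from TE's attention.py, not a BioNeMo API).
# They must be set before transformer_engine is imported.
os.environ["NVTE_FLASH_ATTN"] = "1"   # allow the FlashAttention backend
os.environ["NVTE_FUSED_ATTN"] = "0"   # disable cuDNN fused attention so "flash" is selected
# Optional: ask TE to log which attention backend it actually picks.
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "2"

import transformer_engine.pytorch as te  # noqa: E402  (import after env setup)

# ... build the model and run the partial convergence / long-context speed test here ...
```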

Expected Benefits

FA3 should be roughly as fast as the cuDNN attention implementation, which is up to 50% faster than flash attention v2 at 1M context length when using context parallelism on transformer models. Other architectures that use attention in some layers, such as evo2, would also benefit.

Code Example
