Problem & Motivation
Flash attention v3 is significantly faster than flash attention v2 at long contexts. Instructions for installing and configuring it are on the Transformer Engine (TE) website:
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L199-L203
With that configured, regular flash attention v2 is still used on pre-Hopper GPUs:
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L460-L463
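A minimal sketch of that distinction (assumes a CUDA build of PyTorch and a visible GPU): flash attention v3 kernels target Hopper (compute capability 9.0), so anything older keeps using flash attention v2 via the fallback linked above.

```python
import torch

# Compute capability (major, minor) of the current GPU.
major, minor = torch.cuda.get_device_capability()
if (major, minor) >= (9, 0):
    print("Hopper or newer: flash attention v3 kernels can be used")
else:
    print("pre-Hopper GPU: Transformer Engine falls back to flash attention v2")
```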
We can add the installation steps to our Docker container setup.
BioNeMo Framework Version
v2.4.1
Category
Model/Training
Proposed Solution
Follow the TE installation steps. Check that models still pass partial convergence tests, and verify that the long-context speedup shows up when using the "flash" backend on H100 or newer (a rough timing sketch follows).
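Rough timing sketch, not a full benchmark: it assumes an H100 (or newer), bf16 inputs, and a container with Transformer Engine installed. The idea would be to run it once in an image built with flash-attn v2 and once in an image with flash-attn v3 and compare the long-context forward time; shapes and iteration counts are illustrative only.

```python
import os
os.environ["NVTE_FUSED_ATTN"] = "0"   # keep the cuDNN fused backend out of the comparison

import time
import torch
import transformer_engine.pytorch as te


def mean_forward_time(seq_len=32768, heads=16, head_dim=128, iters=10):
    attn = te.DotProductAttention(num_attention_heads=heads, kv_channels=head_dim)
    # Default qkv_format is "sbhd": [sequence, batch, heads, head_dim].
    q, k, v = (
        torch.randn(seq_len, 1, heads, head_dim, dtype=torch.bfloat16, device="cuda")
        for _ in range(3)
    )
    for _ in range(3):  # warmup
        attn(q, k, v)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        attn(q, k, v)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


print(f"mean attention forward time: {mean_forward_time():.4f} s")
```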
Expected Benefits
Flash attention v3 should be about as fast as the cuDNN fused attention implementation, which is up to 50% faster than flash attention v2 at 1M context length with context parallelism on transformer models. This would also benefit other architectures, like evo2, that use attention in some layers.
Code Example
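A hypothetical smoke test, assuming a container where Transformer Engine and flash-attn v3 are both installed: force the flash backend, have TE print which backend it actually selects, and run one long-context forward pass in bf16. The environment variables are TE's backend-selection knobs; the shape is illustrative.

```python
import os

os.environ["NVTE_FLASH_ATTN"] = "1"   # allow the FlashAttention backend
os.environ["NVTE_FUSED_ATTN"] = "0"   # disable cuDNN fused attention for this check
os.environ["NVTE_DEBUG"] = "1"        # print backend-selection info
os.environ["NVTE_DEBUG_LEVEL"] = "2"

import torch
import transformer_engine.pytorch as te

seq_len, batch, heads, head_dim = 16384, 1, 16, 128
attn = te.DotProductAttention(num_attention_heads=heads, kv_channels=head_dim)

# Default qkv_format is "sbhd": [sequence, batch, heads, head_dim].
q, k, v = (
    torch.randn(seq_len, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
)
out = attn(q, k, v)
print("output:", tuple(out.shape), out.dtype)
```

On an H100 with flash-attn v3 installed, the debug output should show the FlashAttention backend being selected; on a pre-Hopper GPU it should show flash attention v2 being used instead.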