[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292
Open

xrsrke wants to merge 38 commits into main from xrsrke/exp7a13b0_domino_revert_from_fix_stream_not_sync_exp2a1c7_and_commit_23f2_and_75_percent_bwd_overlapping_with_cuda_stream_sync_bwd_but_remove_stream_manager_ctx
Commits (38, all authored by xrsrke):

d7bf8be  first draft of domino forward pass
3803b19  support the backward pass
d765fd5  the first draft for bwd overlapping
9924608  add backward pass overlapping
d6bc8da  fix some ops dont execute in the bwd pass
93b2f10  fix can't find an ops in fwd
31db05d  partially overlapping bwd pass
23f2108  fix stream not sync
3e3ae8c  exp2a1c7c2_like_exp2a1c1_domini_llama3_3b_with_tp8_and_seqlen4096_and…
c261488  refactor
841c7d6  add tests and more refactoring
8a0f993  add domino config, fix breaks in _RowLinearAsyncCommunication
a61d2df  add bwd.layer_mlp_x_batch_1 as async op
06e17bc  - add cuda stream sync after attn_output0[work]
8d44942  wait default_stream instead of current_stream
aa77e6c  put torch.cuda.synchronize() everywhere
76b5f9a  only bwd.layer_attn_{}_batch_0 as non async
fe7ee7e  exp7a7_like_exp7a6_but_remove_fwd_pass_cuda_syncronization
e0a9bd0  remove torch.cuda.synchronize in WaitComm.backward
a772ff0  add back torch.cuda.synchronize in WaitComm.backward and small refactors
543ef56  add ctx.comm_stream.wait_stream(torch.cuda.default_stream()) to WaitC…
36c9980  exp7a10_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remov…
613eb16  remove comments and add typing
600f01a  remove explicite async_op arg
320e55d  Merge remote-tracking branch 'origin/main' into domino_revert_from_fi…
29a8914  pass stream amanger to llama's modules
75abb32  move domino's assert args to config
da4220c  add retrieving async distributed handle from comm bucket instead of r…
d7a636f  small refactor
d3d8c10  add CudaStreamManager.init_default_comm_stream and fix domino test
74d415c  removing op_name in the forward pass by adding OpNameContext
08a4472  add CudaStreamManager as context
684b1b9  small refactor
f8e8b1f  Reverting repository to commit 74d415c1c02b9463214fb46db060c0efbfa5a0e4
61ff007  add todos
9039ce2  add todo
62fb3b2  add todos
7c7b6f7  add todo and undo torch_nn
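The commit trail above sketches the DoMiNo idea behind the PR title: split each layer's input into two micro-batches and launch the tensor-parallel all-reduce of one batch on a separate CUDA stream while the other batch is still computing, so most of the communication is hidden behind compute. The snippet below is a minimal illustration of that overlap pattern, not the PR's actual code; domino_forward, module_fn, tp_group, and comm_stream are placeholder names introduced here for the sketch.

import torch
import torch.distributed as dist

def domino_forward(module_fn, hidden_states, tp_group, comm_stream):
    # Split the input into two micro-batches (cf. domino.num_input_batches: 2).
    batch0, batch1 = hidden_states.chunk(2, dim=0)

    out0 = module_fn(batch0)
    # Hand batch 0's all-reduce to the comm stream so it overlaps with
    # batch 1's compute, which keeps running on the default stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        handle0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    out1 = module_fn(batch1)
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        handle1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    # Block until the overlapped communication has finished before the
    # results are consumed (cf. the WaitComm / wait_stream commits above).
    handle0.wait()
    handle1.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    return torch.cat([out0, out1], dim=0)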
New file added in this diff (+105 lines): an example DoMiNo training config (Llama-style model, tp=8).

checkpoints:
  checkpoint_interval: 10000
  checkpoints_path: checkpoints
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  load_lr_scheduler: false
  load_optimizer: false
  save_final_state: true
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder:
      - /fsx/elie_bakouch/data/fw-edu-dedup
    num_loading_workers: 0
    seed: 8
  name: stable phase
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: nanotron_domino
  run: domino_config
  seed: 6
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.041666666666666664
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 32
    num_hidden_layers: 15
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 2
    rms_norm_eps: 1.0e-05
    rope_interleaved: false
    rope_scaling:
      factor: 32.0
      high_freq_factor: 4.0
      low_freq_factor: 1.0
      original_max_position_embeddings: 4096
      rope_type: llama3
    rope_theta: 500000.0
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.00005
    lr_decay_starting_step: 50000
    lr_decay_steps: 10000
    lr_decay_style: linear
    lr_warmup_steps: 1000
    lr_warmup_style: linear
    min_decay_lr: 0
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 1
parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  recompute_layer: false
  tp: 8
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
  tp_recompute_allgather: false
  domino:
    num_input_batches: 2
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Llama-3.2-3B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 2
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 8
  sequence_length: 4096
  train_steps: 15000
  val_check_interval: -1
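For illustration, the domino block above could be represented by a small dataclass like the one below. This is a hypothetical sketch rather than the PR's actual config class, though the field name and the validation are motivated by the "add domino config" and "move domino's assert args to config" commits.

from dataclasses import dataclass

@dataclass
class DominoArgs:
    # Number of micro-batches each input is split into; with 2, one batch's
    # tensor-parallel communication can overlap with the other batch's compute.
    num_input_batches: int = 2

    def __post_init__(self):
        # Illustrative check only: overlapping requires at least two batches.
        assert self.num_input_batches >= 2, "DoMiNo needs at least two input batches"

# Matches the YAML above: parallelism.domino.num_input_batches = 2
args = DominoArgs(num_input_batches=2)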
Review comment: Please add a new ticket to the tracker to add support for REDUCE_SCATTER.
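For context, the config above uses tp_mode: ALL_REDUCE, where every rank receives the full summed output; with REDUCE_SCATTER each rank would keep only its shard of the sum. The snippet below is a hypothetical illustration of that difference, not nanotron's implementation; row_linear_output and partial_out are made-up names.

import torch
import torch.distributed as dist

def row_linear_output(partial_out, tp_group, mode="ALL_REDUCE"):
    world_size = dist.get_world_size(tp_group)
    if mode == "ALL_REDUCE":
        # Every rank ends up with the full summed output.
        dist.all_reduce(partial_out, group=tp_group)
        return partial_out
    # REDUCE_SCATTER: every rank keeps only a 1/world_size slice of the sum.
    out_shard = torch.empty(
        partial_out.shape[0] // world_size, *partial_out.shape[1:],
        dtype=partial_out.dtype, device=partial_out.device,
    )
    dist.reduce_scatter_tensor(out_shard, partial_out, group=tp_group)
    return out_shard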