[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292

Open: 38 commits to be merged into base: main.

Commits (38):
d7bf8be  first draft of domino forward pass (xrsrke, Jan 29, 2025)
3803b19  support the backward pass (xrsrke, Jan 30, 2025)
d765fd5  the first draft for bwd overlapping (xrsrke, Jan 31, 2025)
9924608  add backward pass overlapping (xrsrke, Feb 3, 2025)
d6bc8da  fix some ops dont execute in the bwd pass (xrsrke, Feb 4, 2025)
93b2f10  fix can't find an ops in fwd (xrsrke, Feb 5, 2025)
31db05d  partially overlapping bwd pass (xrsrke, Feb 5, 2025)
23f2108  fix stream not sync (xrsrke, Feb 10, 2025)
3e3ae8c  exp2a1c7c2_like_exp2a1c1_domini_llama3_3b_with_tp8_and_seqlen4096_and… (xrsrke, Feb 21, 2025)
c261488  refactor (xrsrke, Feb 21, 2025)
841c7d6  add tests and more refactoring (xrsrke, Feb 21, 2025)
8a0f993  add domino config, fix breaks in _RowLinearAsyncCommunication (xrsrke, Feb 24, 2025)
a61d2df  add bwd.layer_mlp_x_batch_1 as async op (xrsrke, Feb 25, 2025)
06e17bc  add cuda stream sync after attn_output0[work] (xrsrke, Feb 25, 2025)
8d44942  wait default_stream instead of current_stream (xrsrke, Feb 25, 2025)
aa77e6c  put torch.cuda.synchronize() everywhere (xrsrke, Feb 25, 2025)
76b5f9a  only bwd.layer_attn_{}_batch_0 as non async (xrsrke, Feb 25, 2025)
fe7ee7e  exp7a7_like_exp7a6_but_remove_fwd_pass_cuda_syncronization (xrsrke, Feb 25, 2025)
e0a9bd0  remove torch.cuda.synchronize in WaitComm.backward (xrsrke, Feb 26, 2025)
a772ff0  add back torch.cuda.synchronize in WaitComm.backward and small refactors (xrsrke, Feb 26, 2025)
543ef56  add ctx.comm_stream.wait_stream(torch.cuda.default_stream()) to WaitC… (xrsrke, Feb 27, 2025)
36c9980  exp7a10_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remov… (xrsrke, Feb 27, 2025)
613eb16  remove comments and add typing (xrsrke, Feb 28, 2025)
600f01a  remove explicite async_op arg (xrsrke, Feb 28, 2025)
320e55d  Merge remote-tracking branch 'origin/main' into domino_revert_from_fi… (xrsrke, Mar 5, 2025)
29a8914  pass stream amanger to llama's modules (xrsrke, Mar 7, 2025)
75abb32  move domino's assert args to config (xrsrke, Mar 7, 2025)
da4220c  add retrieving async distributed handle from comm bucket instead of r… (xrsrke, Mar 7, 2025)
d7a636f  small refactor (xrsrke, Mar 7, 2025)
d3d8c10  add CudaStreamManager.init_default_comm_stream and fix domino test (xrsrke, Mar 7, 2025)
74d415c  removing op_name in the forward pass by adding OpNameContext (xrsrke, Mar 7, 2025)
08a4472  add CudaStreamManager as context (xrsrke, Mar 8, 2025)
684b1b9  small refactor (xrsrke, Mar 8, 2025)
f8e8b1f  Reverting repository to commit 74d415c1c02b9463214fb46db060c0efbfa5a0e4 (xrsrke, Mar 10, 2025)
61ff007  add todos (xrsrke, Mar 10, 2025)
9039ce2  add todo (xrsrke, Mar 10, 2025)
62fb3b2  add todos (xrsrke, Mar 10, 2025)
7c7b6f7  add todo and undo torch_nn (xrsrke, Mar 13, 2025)
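
For orientation, the core idea these commits implement is to split each micro-batch into two halves and overlap the tensor-parallel all-reduce of one half with the computation of the other half on a dedicated CUDA stream. The sketch below is a conceptual illustration only, not the PR's code; `layer`, `tp_group`, and `column_parallel_block` are placeholder names, and an initialized NCCL process group is assumed.

```python
# Conceptual sketch of the DoMiNo-style overlap (not the PR's implementation).
# Assumes torch.distributed is initialized with the NCCL backend and `tp_group`
# is the tensor-parallel process group.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def column_parallel_block(layer, batch0, batch1, tp_group):
    out0 = layer(batch0)                                  # compute half 0 on the default stream
    comm_stream.wait_stream(torch.cuda.current_stream())  # comm waits until out0 is ready
    with torch.cuda.stream(comm_stream):
        work0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    out1 = layer(batch1)                                  # overlaps with half 0's all-reduce
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        work1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    work0.wait()  # with NCCL this blocks the current CUDA stream, not the host
    work1.wait()
    return out0, out1
```

Most of the later commits (the WaitComm and CudaStreamManager changes, the torch.cuda.synchronize experiments) appear to be about getting exactly these stream synchronization points right.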
examples/domino/domino_config.yaml (new file, +105, -0)
@@ -0,0 +1,105 @@
checkpoints:
  checkpoint_interval: 10000
  checkpoints_path: checkpoints
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  load_lr_scheduler: false
  load_optimizer: false
  save_final_state: true
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder:
      - /fsx/elie_bakouch/data/fw-edu-dedup
    num_loading_workers: 0
    seed: 8
  name: stable phase
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: nanotron_domino
  run: domino_config
  seed: 6
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.041666666666666664
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 32
    num_hidden_layers: 15
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 2
    rms_norm_eps: 1.0e-05
    rope_interleaved: false
    rope_scaling:
      factor: 32.0
      high_freq_factor: 4.0
      low_freq_factor: 1.0
      original_max_position_embeddings: 4096
      rope_type: llama3
    rope_theta: 500000.0
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.00005
    lr_decay_starting_step: 50000
    lr_decay_steps: 10000
    lr_decay_style: linear
    lr_warmup_steps: 1000
    lr_warmup_style: linear
    min_decay_lr: 0
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 1
parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  recompute_layer: false
  tp: 8
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
  tp_recompute_allgather: false
  domino:
    num_input_batches: 2
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Llama-3.2-3B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 2
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 8
  sequence_length: 4096
  train_steps: 15000
  val_check_interval: -1
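
As a quick check of the DoMiNo-related knobs in this config, the snippet below parses the file with plain PyYAML and mirrors the constraints enforced in parallelism_config.py; it is only an illustration, and the file path is taken from the diff header above.

```python
# Minimal sketch: sanity-check the DoMiNo-related settings with plain PyYAML.
import yaml

with open("examples/domino/domino_config.yaml") as f:
    cfg = yaml.safe_load(f)

parallelism = cfg["parallelism"]
assert parallelism["tp"] > 1, "DoMiNo needs tensor parallelism (tp > 1)"
assert parallelism["tp_mode"] == "ALL_REDUCE"
assert parallelism["tp_linear_async_communication"] is False
assert parallelism["domino"]["num_input_batches"] == 2

print("DoMiNo config:", parallelism["domino"])
```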
src/nanotron/config/parallelism_config.py (+34, -0)
@@ -11,6 +11,23 @@
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode


@dataclass
class DominoArgs:
    """
    Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping
    https://arxiv.org/abs/2409.15241
    """

    # NOTE: if the number of input batches is 1, this is equivalent to non-domino mode,
    # so set num_input_batches > 1 to enable domino mode.
    num_input_batches: int

    def __post_init__(self):
        assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1"
        assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino"


@dataclass
class ParallelismArgs:
    """Arguments related to TP/PP/DP
@@ -37,6 +54,7 @@ class ParallelismArgs:
    tp_recompute_allgather: bool = True

    expert_parallel_size: int = 1
    domino: Optional[DominoArgs] = None

    def __post_init__(self):
        # Conservative defaults
@@ -51,3 +69,19 @@ def __post_init__(self):
        self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
        if isinstance(self.tp_mode, str):
            self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]

        if self.is_domino_enabled is True:
            assert self.tp > 1, "Domino requires TP > 1"
            # NOTE: since DoMiNo overlaps the communication, it does not matter whether
            # it is all_reduce or reduce_scatter; we have only supported and tested
            # all_reduce so far, but in principle it should work with reduce_scatter as well.
            assert (
                self.tp_linear_async_communication is False
            ), "Domino requires TP linear async communication to be False"
            # TODO: support REDUCE_SCATTER mode for Domino
            assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE"

[Review comment from a Member]: Add a new ticket in the tracker to add support for REDUCE_SCATTER, please.

    @property
    def is_domino_enabled(self) -> bool:
        return True if self.domino else False
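
A small usage sketch of the validation added above; the import path follows this PR's file layout (src/nanotron/config/parallelism_config.py) and is otherwise an assumption.

```python
# Minimal sketch of how DominoArgs validation behaves (import path assumed from this diff).
from nanotron.config.parallelism_config import DominoArgs

domino = DominoArgs(num_input_batches=2)  # accepted: exactly 2 batches is supported today

try:
    DominoArgs(num_input_batches=1)       # rejected: 1 batch is just non-domino mode
except AssertionError as err:
    print(f"rejected as expected: {err}")
```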
src/nanotron/constants.py (+4, -0)
@@ -10,3 +10,7 @@

CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"


### FOR COMMUNICATION ###
CUDA_STREAM_COMM_NAME = "comm_stream_{}"
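
The constant is a format string, presumably keyed by a device or rank index. The helper below is hypothetical (get_comm_stream and _streams are not part of the PR) and only illustrates how such a name template could key a per-device communication stream; the PR's CudaStreamManager is the real implementation.

```python
# Hypothetical illustration of the naming scheme; not the PR's CudaStreamManager.
import torch
from nanotron.constants import CUDA_STREAM_COMM_NAME

_streams: dict[str, torch.cuda.Stream] = {}

def get_comm_stream(device_index: int) -> torch.cuda.Stream:
    """Return (creating it lazily) the communication stream for one CUDA device."""
    name = CUDA_STREAM_COMM_NAME.format(device_index)  # e.g. "comm_stream_0"
    if name not in _streams:
        _streams[name] = torch.cuda.Stream(device=device_index)
    return _streams[name]
```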
src/nanotron/helpers.py (+2, -0)
@@ -482,7 +482,9 @@ def get_profiler(config: Config):
            on_trace_ready=on_trace_ready,
            # record_shapes=True,
            # profile_memory=True,
            with_flops=True,
            with_stack=True,
            with_modules=True,
        )
    else:
        prof = contextlib.nullcontext()
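
For reference, a standalone sketch of what the newly enabled profiler flags record; the model and training step below are placeholders, and a CUDA device is assumed.

```python
# Standalone sketch of the profiler options enabled above (placeholder model and step).
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_flops=True,     # estimate FLOPs for supported ops (matmuls, convs, ...)
    with_stack=True,     # record source stacks so traces map back to code
    with_modules=True,   # attribute ops to the owning nn.Module hierarchy
) as prof:
    model(x).sum().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```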
src/nanotron/models/base.py (+1, -1)
@@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]:
        Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
        """
        return []

    def get_named_params_without_weight_decay(self) -> List[str]:
        """Return a list of named parameters that should not have weight decay applied to them."""
        return []