
[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292

Open

wants to merge 38 commits into base: main
Conversation

@xrsrke (Member) commented Mar 10, 2025

Reproducing the paper "Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping" https://arxiv.org/abs/2409.15241

The losses match after 20B tokens (2M batch size, 20k steps, FineWeb dataset), with 75% of the tensor-parallel communication hidden:

[image: loss curves]

The first PR is ready for review (I split the work into two PRs). Remaining work for the next PR:

  • inter-layer overlapping (currently we only overlap communication within a layer, as sketched below); overlapping across layers as well would let us hide almost all of the communication
  • create a fixed buffer for concatenating the hidden states (1st image)
  • double-check whether there is CUDA stream-switching overhead (4.3.1 in the 2nd image)
  • minimize kernel launch overhead (CUDA graphs? 4.3.2 in the 2nd image)
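For reviewers, here is a minimal sketch of the intra-layer overlapping scheme this PR implements (the `attn`, `mlp`, and `tp_group` names are placeholders, not the actual nanotron modules):

```python
import torch
import torch.distributed as dist

def domino_layer_forward(x, attn, mlp, tp_group):
    # Split the batch in two; each half's tensor-parallel all-reduce is
    # launched asynchronously and hidden behind the other half's compute.
    x0, x1 = x.chunk(2, dim=0)

    out0 = attn(x0)
    work_attn0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    out1 = attn(x1)                      # overlaps work_attn0's communication
    work_attn1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    work_attn0.wait()
    out0 = mlp(out0)                     # overlaps work_attn1's communication
    work_mlp0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    work_attn1.wait()
    out1 = mlp(out1)                     # overlaps work_mlp0's communication
    work_mlp1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    # The last all-reduce has no compute left in this layer to hide behind,
    # hence roughly 3 out of 4 (75%) of the communication is hidden per layer.
    work_mlp0.wait()
    work_mlp1.wait()
    return torch.cat([out0, out1], dim=0)
```

Hiding that last all-reduce as well would require overlapping it with the next layer's compute, which is the inter-layer item above.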

Profiling results:

[first profiling image]

[second profiling image]

/fsx/phuc/new_workspace/experiments/nanotron_domino/profilings/exp7a11_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remove_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_remove_explicite_async_op_arg_and_commit_600f01/20250228-160428/ip-26-0-161-142_51797.1740758749919300440.pt.trace.json

xrsrke and others added 30 commits January 29, 2025 12:47
…_mbs2_and_gbs_300k_and_input_splitting_and_commit_23f2_but_remove_call_is_async_comm_twice_and_keep_not_async_bwd.layer_mlp_1__and_bwd.layer_attn_0
- execute backward comm in a separate stream (see the sketch below)
    - make the comm stream in the backward pass wait for the compute stream before running the backward comm
- make WaitComm's compute stream wait for the comm stream
…omm, and remove torch.cuda.synchronize() in WaitComm
…e_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_commit_543ef56
…x_stream_not_sync_exp2a1c7_and_commit_23f2_and_75_percent_bwd_overlapping_with_cuda_stream_sync_bwd
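The stream ordering described in the commit message above, roughly sketched (the helper names and `bucket.pop` are assumptions, not the exact nanotron code):

```python
import torch
import torch.distributed as dist

compute_stream = torch.cuda.current_stream()
comm_stream = torch.cuda.Stream()

def launch_bwd_all_reduce(grad, tp_group, bucket, op_name):
    # The comm stream must not start before the producing compute kernels
    # have been enqueued on the compute stream.
    comm_stream.wait_stream(compute_stream)
    with torch.cuda.stream(comm_stream):
        work = dist.all_reduce(grad, group=tp_group, async_op=True)
    bucket.add(op_name, work)

def wait_comm(bucket, op_name):
    # WaitComm: synchronize on the NCCL work, then make the compute stream
    # wait for the comm stream before the reduced gradient is consumed.
    bucket.pop(op_name).wait()
    compute_stream.wait_stream(comm_stream)
```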
@xrsrke xrsrke changed the title Xrsrke/exp7a13b0 domino revert from fix stream not sync exp2a1c7 and commit 23f2 and 75 percent bwd overlapping with cuda stream sync bwd but remove stream manager ctx [Feature] Hide 75% of the communication in tensor parallelism using DoMiNo Mar 10, 2025
Comment on lines +787 to +788
hidden_states0 = self.input_layernorm(hidden_states0)
hidden_states1 = self.input_layernorm(hidden_states1)
Member

Following up on #285 (comment):
I think we still need to add a TODO comment here, because ideally we want to interleave (overlap) this layernorm with some other op (either the following fwd, the backward, or both).
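For illustration (the `op_name` variable and the returned `work0` handle are assumptions about code outside this hunk, and `layer.attn` returning an async handle is not what the current diff does), the second half-batch's layernorm could be moved so it overlaps the first half-batch's attention all-reduce:

```python
def interleaved_layernorm_attn(layer, hidden_states0, hidden_states1, op_name):
    # hypothetical reordering: batch 1's layernorm overlaps batch 0's async
    # attention all-reduce instead of running back-to-back with batch 0's layernorm
    hidden_states0 = layer.input_layernorm(hidden_states0)
    attn_output0, work0 = layer.attn(hidden_states0, op_name=op_name)  # async all-reduce in flight
    hidden_states1 = layer.input_layernorm(hidden_states1)             # overlaps work0
    work0.wait()
    return attn_output0, hidden_states1
```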

Member

It would also be nice to add more comments in this Domino class about what is overlapped (either at the top of the fwd, or before each op being overlapped).
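Something along these lines at the top of the fwd, for example (illustrative only, based on my reading of the schedule):

```python
# Overlap schedule for one decoder layer (input split into two half-batches):
#
#   compute stream            | communication hidden behind it
#   --------------------------|-----------------------------------
#   attn(batch 0)             | -
#   attn(batch 1)             | all-reduce of attn(batch 0)
#   mlp(batch 0)              | all-reduce of attn(batch 1)
#   mlp(batch 1)              | all-reduce of mlp(batch 0)
#   (next layer / exposed)    | all-reduce of mlp(batch 1)  <- the unhidden ~25%
```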

@@ -687,51 +701,39 @@ def forward(
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
# output, work = self.o_proj(attention_output, op_name=op_name)
Member

Clean up this leftover commented-out line.

self.tp_linear_async_communication is False
), "Domino requires TP linear async communication to be False"
# TODO: support REDUCE_SCATTER mode for Domino
assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE"
Member

Please add a new ticket in the tracker for REDUCE_SCATTER support.
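For that ticket, my understanding is that REDUCE_SCATTER mode would swap the async all-reduce for an async reduce-scatter, roughly like this (sketch only; tensor shapes and names are assumptions):

```python
import torch
import torch.distributed as dist

def async_reduce_scatter_output(full_out: torch.Tensor, tp_group):
    # Each TP rank keeps only its reduced shard of the row-parallel output
    # instead of the fully all-reduced tensor; the returned handle lets
    # Domino overlap this communication the same way it overlaps all-reduce.
    tp_size = dist.get_world_size(tp_group)
    shard = torch.empty(
        (full_out.shape[0] // tp_size, *full_out.shape[1:]),
        dtype=full_out.dtype, device=full_out.device,
    )
    work = dist.reduce_scatter_tensor(shard, full_out, group=tp_group, async_op=True)
    return shard, work
```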

Comment on lines 5 to 6
from torch import nn
from torch.nn.parallel import DistributedDataParallel
Member

This is unrelated to this PR, right? I'm refactoring this engine, so I can take care of this change.

BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"

_operation_context = threading.local()
Member

is this necessary?

Member Author

BWD_ATTN_OP_NAME

Because we reference these names in many places in the code, I want to keep them consistent: if we ever rename an op, we don't have to manually update every other place.
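i.e. something like this (illustrative helper, not code in this PR):

```python
BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"  # single source of truth for the format

def bwd_attn_op_name(layer_idx: int, micro_batch_idx: int) -> str:
    # every producer/consumer builds the name through the same template,
    # so renaming it is a one-line change
    return BWD_ATTN_OP_NAME.format(layer_idx, micro_batch_idx)

assert bwd_attn_op_name(3, 1) == "bwd.layer_attn_3_batch_1"
```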

performs all-reduce asynchronously in tensor parallelism
"""
NON_ASYNC_HANDLE_IDX = [
# "fwd.layer_mlp_{}_batch_1",
Member

cleanup?

Comment on lines 14 to 17
"""
Determine whether a module (e.g., mlp, attention)
performs all-reduce asynchronously in tensor parallelism
"""
Member

Please continue the description of this function: how do we determine it? What do we check?
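For instance, if the check is just a match of the op name against the NON_ASYNC_HANDLE_IDX templates, the docstring could say so explicitly. A sketch of what I understand the logic to be (the list contents below are made up for the example):

```python
# example contents only; the real list lives next to the op-name templates
NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1"]

def is_domino_async_comm(op_name: str) -> bool:
    """
    Determine whether a module (e.g., mlp, attention) performs its
    tensor-parallel all-reduce asynchronously: an op is synchronous only if
    its name matches one of the templates in NON_ASYNC_HANDLE_IDX.
    """
    for template in NON_ASYNC_HANDLE_IDX:
        prefix, _, suffix = template.partition("{}")
        if op_name.startswith(prefix) and op_name.endswith(suffix):
            return False
    return True

assert is_domino_async_comm("bwd.layer_attn_0_batch_0") is True
assert is_domino_async_comm("bwd.layer_mlp_2_batch_1") is False
```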

Comment on lines +10 to +24
class AsyncCommBucket:
"""
Store asynchronous communication operations.
"""

def __init__(self):
self._async_op: Dict[int, "dist.Work"] = {}
self._copy_async_op: Dict[int, "dist.Work"] = {}

def add(self, op_name: int, work: "dist.Work"):
assert op_name not in self._async_op, f"Operation with name: {op_name} already exists"
assert work is not None
self._async_op[op_name] = work
self._copy_async_op[op_name] = work

Member

are we sure we don't have an equivalent of this class in torch? o.O
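For context, the intended usage as I read it (illustrative; the `wait(op_name)` method is assumed here since it is not shown in this hunk):

```python
import torch.distributed as dist

def launch_and_register(hidden_states, tp_group, bucket: AsyncCommBucket, op_name: str):
    # producer side: launch the TP all-reduce without blocking the compute stream
    work = dist.all_reduce(hidden_states, group=tp_group, async_op=True)
    bucket.add(op_name, work)

def consume(bucket: AsyncCommBucket, op_name: str):
    # consumer side: block only when the reduced tensor is actually needed
    bucket.wait(op_name)  # assumed method; not part of this hunk
```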


not_finished = []
for k, v in self._copy_async_op.items():
assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!"
Member

I don't like the mention of Domino here; this CommBucket should be independent of Domino.
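One way to decouple it (a sketch of the suggestion, not code that exists in the PR): inject the async-op predicate when the bucket is constructed, so the bucket never imports anything Domino-specific:

```python
from typing import Callable, Dict

import torch.distributed as dist

class AsyncCommBucket:
    """Stores async communication handles; knows nothing about Domino."""

    def __init__(self, is_async_op: Callable[[str], bool]):
        self._is_async_op = is_async_op          # policy is injected by the caller
        self._async_op: Dict[str, "dist.Work"] = {}
        self._copy_async_op: Dict[str, "dist.Work"] = {}

    def add(self, op_name: str, work: "dist.Work"):
        assert op_name not in self._async_op, f"Operation with name: {op_name} already exists"
        assert work is not None
        self._async_op[op_name] = work
        self._copy_async_op[op_name] = work

    def check_all_executed_async(self):
        # is_domino_async_comm then becomes just one possible predicate
        for op_name in self._copy_async_op:
            assert self._is_async_op(op_name), f"Operation with name {op_name} wasn't executed asynchronously!"
```

Domino would then construct it as `AsyncCommBucket(is_async_op=is_domino_async_comm)`.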
