Skip to content

Commit

Permalink
Add model parallel group to reduce scatter (#1281)
Browse files Browse the repository at this point in the history
  • Loading branch information
bclyang committed Sep 15, 2024
1 parent d79c533 commit f281210
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion megatron/mpu/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
)
output = torch.empty_like(tensor_list[0])
torch.distributed.reduce_scatter(output, tensor_list)
torch.distributed.reduce_scatter(
output, tensor_list, group=get_model_parallel_group()
)

# reconvert to original Bf16/Fp16 dtype
if get_fp32_allreduce():
Expand Down

0 comments on commit f281210

Please sign in to comment.