diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index f11d9e6ab..ceb89daa2 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -121,7 +121,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
         torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
     )
     output = torch.empty_like(tensor_list[0])
-    torch.distributed.reduce_scatter(output, tensor_list)
+    torch.distributed.reduce_scatter(
+        output, tensor_list, group=get_model_parallel_group()
+    )
 
     # reconvert to original Bf16/Fp16 dtype
     if get_fp32_allreduce():
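
Context for the fix: without an explicit `group`, `torch.distributed.reduce_scatter` operates over the default WORLD process group, so the reduction would incorrectly mix in ranks outside the model-parallel group (e.g. data-parallel peers). Below is a minimal, self-contained sketch of the corrected call pattern; the 2-way group layout and the `demo` function are illustrative assumptions for this sketch, not GPT-NeoX's actual topology code.

```python
# Illustrative sketch (not part of the patch): reduce_scatter confined to a
# model-parallel subgroup. Assumes one process per GPU, world_size divisible
# by mp_size, launched via e.g. `torchrun --nproc_per_node=4 demo.py`.
import torch
import torch.distributed as dist


def demo(seq_dim: int = 0, mp_size: int = 2):
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Hypothetical 2-way model parallelism: ranks [0,1], [2,3], ... each form
    # a model-parallel group. new_group must be called on every rank, in the
    # same order, even for groups the rank does not belong to.
    groups = [
        dist.new_group(list(range(start, start + mp_size)))
        for start in range(0, world_size, mp_size)
    ]
    mp_group = groups[rank // mp_size]

    # Split the sequence dimension into one equal chunk per model-parallel rank.
    input_ = torch.ones(4 * mp_size, 8, device="cuda")
    tensor_list = list(
        torch.split(input_, input_.shape[seq_dim] // mp_size, seq_dim)
    )
    output = torch.empty_like(tensor_list[0])

    # Each rank receives the sum of one chunk, reduced only over its
    # model-parallel peers -- the behavior the patch restores by passing
    # the model-parallel group explicitly.
    dist.reduce_scatter(output, tensor_list, group=mp_group)

    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```

With the default (world) group, every rank in the job would participate in the collective, which both sums across unrelated replicas and can deadlock when different model-parallel groups reach the call at different times; pinning the collective to `get_model_parallel_group()` avoids both.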