Commit b1fbb74

support dist.reduce_scatter_tensor (#7950)

zpcore committed Sep 4, 2024
1 parent f45775a commit b1fbb74

Showing 4 changed files with 109 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -221,6 +221,31 @@ def callable(input):
    # output is list of tensors
    return pytree.tree_map(lambda x: x.cpu(), output)

  @staticmethod
  def _reduce_scatter(use_dynamo: bool):
    met.clear_all()
    dist.init_process_group("xla", init_method='xla://')
    device = xm.xla_device()

    def callable(output, input):
      dist.reduce_scatter_tensor(output, input)
      return output

    # check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3766-L3814
    # for input and output tensor example
    tensor_in = torch.arange(
        xr.world_size() * 2, dtype=torch.float, device=device)
    tensor_out = torch.zeros(2, dtype=torch.float, device=device)
    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
    output = f(tensor_out, tensor_in)
    torch_xla.sync()
    if not use_dynamo:
      assert 'xla::ReduceScatter' in met.counter_names(
      ) or 'xla::ReduceScatterOut' in met.counter_names()
    else:
      assert 'xla::reduce_scatter_tensor' in met.counter_names()
    return output.cpu()

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_reduce(self, use_dynamo):
    results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -248,6 +273,20 @@ def test_all_gather(self, use_dynamo):
    for index, val in results.items():
      torch.testing.assert_close(val, expected)

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_reduce_scatter(self, use_dynamo):
    results = pjrt.run_multiprocess(self._reduce_scatter, use_dynamo=use_dynamo)
    expected = [
        torch.tensor([
            2 * i * tpu.num_expected_global_devices(),
            (2 * i + 1) * tpu.num_expected_global_devices()
        ],
                     dtype=torch.float)
        for i in range(tpu.num_expected_global_devices())
    ]
    for index, val in results.items():
      torch.testing.assert_close(val, expected[index])


if __name__ == '__main__':
  absltest.main()
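
For reference, the expected values in test_reduce_scatter above follow directly from the reduce-then-scatter semantics: every rank feeds the same torch.arange(world_size * 2) into dist.reduce_scatter_tensor, the default sum reduction multiplies each element by the world size, and rank i keeps the two elements at positions 2*i and 2*i+1. A host-only sketch of that arithmetic (not part of the commit; the device count of 4 is illustrative):

import torch

world_size = 4  # illustrative device count
full_input = torch.arange(world_size * 2, dtype=torch.float)  # identical on every rank
reduced = world_size * full_input  # element-wise sum across the identical inputs
expected_per_rank = [reduced[2 * i:2 * i + 2] for i in range(world_size)]
# rank 0 -> [0., 4.], rank 1 -> [8., 12.], rank 2 -> [16., 20.], rank 3 -> [24., 28.]
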
40 changes: 40 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -406,6 +406,46 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
  return reduce_result;
}

// Wrapper around BuildReduceScatter to adapt the upstream
// dist.reduce_scatter_tensor(). This path is only used by dynamo, called from
// https://github.com/pytorch/pytorch/blob/85fa01969719dab91eac3e02dd193c7d20d0e87f/torch/distributed/_functional_collectives.py#L1039
// The function signature must match
// https://github.com/pytorch/pytorch/blob/85fa01969719dab91eac3e02dd193c7d20d0e87f/torch/csrc/distributed/c10d/Functional.cpp#L356
// for dispatch to work.
at::Tensor reduce_scatter_tensor(const at::Tensor& input, std::string reduce_op,
                                 int64_t group_size, std::string group_name) {
  TORCH_LAZY_FN_COUNTER("xla::");
  auto self = bridge::GetXlaTensor(input);
  std::vector<int64_t> all_groups(group_size);
  std::iota(all_groups.begin(), all_groups.end(), 0);
  int64_t shard_count = group_size;
  AllReduceType all_reduce_type;
  if (reduce_op == "sum") {
    all_reduce_type = AllReduceType::kSum;
  } else if (reduce_op == "min") {
    all_reduce_type = AllReduceType::kMin;
  } else if (reduce_op == "max") {
    all_reduce_type = AllReduceType::kMax;
  } else if (reduce_op == "mul") {
    all_reduce_type = AllReduceType::kMul;
  } else if (reduce_op == "or") {
    all_reduce_type = AllReduceType::kOr;
  } else if (reduce_op == "and") {
    all_reduce_type = AllReduceType::kAnd;
  } else {
    throw std::invalid_argument("Invalid string for AllReduceType: " +
                                reduce_op);
  }
  // reduce dim is limited to the first dim due to the fixed function signature.
  XLATensorPtr output = tensor_methods::reduce_scatter(
      self, all_reduce_type, 1.0, 0, shard_count, {all_groups});
  return bridge::AtenFromXlaTensor(output);
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("reduce_scatter_tensor", reduce_scatter_tensor);
}

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
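
The TORCH_LIBRARY_IMPL registration above is what the traced path reaches: under torch.compile, dist.reduce_scatter_tensor is rewritten by the upstream functional-collectives layer into the _c10d_functional.reduce_scatter_tensor op (see the two links in the comment), and dispatch on the XLA tensor key then lands in the wrapper above. A rough Python-level sketch of that call, approximately as the traced graph would issue it rather than something to call by hand (the group name is illustrative and normally comes from the resolved process group; upstream also inserts its own wait on the result):

import torch

def traced_reduce_scatter(inp: torch.Tensor, group_size: int,
                          group_name: str) -> torch.Tensor:
  # "sum" maps to AllReduceType::kSum in the wrapper; the scatter dimension is
  # fixed to 0 and the scale to 1.0.
  return torch.ops._c10d_functional.reduce_scatter_tensor(
      inp, "sum", group_size, group_name)
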
2 changes: 2 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -370,6 +370,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
    const std::shared_ptr<torch::lazy::Value>& token, double scale,
    int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  XLATensorPtr result;
  torch::lazy::Value new_token;
  std::tie(result, new_token) = tensor_methods::reduce_scatter(
@@ -385,6 +386,7 @@ std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
    const std::shared_ptr<torch::lazy::Value>& token, double scale,
    int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  XLATensorPtr out = bridge::GetXlaTensor(output);
  torch::lazy::Value new_token;
  new_token = tensor_methods::reduce_scatter_out(
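
The two TORCH_LAZY_FN_COUNTER_TIMED_TRACING additions above are what make the eager-mode lowering show up under the xla:: counter namespace that the new test checks. A short sketch of inspecting those counters (assuming the eager reduce-scatter path has already run on an XLA device):

import torch_xla.debug.metrics as met

met.clear_all()
# ... run dist.reduce_scatter_tensor on XLA tensors in eager (non-dynamo) mode ...
hit = ('xla::ReduceScatter' in met.counter_names() or
       'xla::ReduceScatterOut' in met.counter_names())
print(hit)
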
28 changes: 28 additions & 0 deletions torch_xla/distributed/xla_backend.py
@@ -170,6 +170,34 @@ def reduce_scatter_coalesced(self, output_tensors, input_tensors_list, opts):

    return _ret_work(output_tensors)

  # Call site: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3856
  def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
    """Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
      output_tensor (Tensor): Output tensor.
      input_tensor (Tensor): Input tensor whose size is the output tensor size
        times the world size.
      opts: distributed reduce op (ReduceOp).

    Returns:
      An async work handle, if async_op is set to True.
      None, if not async_op or if not part of the group.
    """
    reduce_type = self._get_reduce_type(opts.reduceOp)
    groups = self._mesh
    shard_count = len(groups[0]) if groups else self.size()
    xm.reduce_scatter(
        reduce_type,
        input_tensor,
        scatter_dim=0,
        shard_count=shard_count,
        scale=1.0,
        groups=groups,
        output=output_tensor,
        pin_layout=False)
    return _ret_work(output_tensor)

  # Call site:
  # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683
  def barrier(self, opts):
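
Putting the pieces together: in eager mode, dist.reduce_scatter_tensor with the xla backend reaches _reduce_scatter_base above, which forwards to xm.reduce_scatter and therefore shows up under the xla::ReduceScatter / xla::ReduceScatterOut counters timed in init_python_bindings.cpp; under torch.compile, the functional-collective path hits the xla::reduce_scatter_tensor wrapper instead. These are exactly the two cases the new test asserts. A minimal eager-mode sketch distilled from that test (assumes an XLA/TPU runtime and a multiprocess launch such as pjrt.run_multiprocess):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def reduce_scatter_demo():
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  world_size = xr.world_size()
  tensor_in = torch.arange(world_size * 2, dtype=torch.float, device=device)
  tensor_out = torch.zeros(2, dtype=torch.float, device=device)
  dist.reduce_scatter_tensor(tensor_out, tensor_in)  # default op is sum
  torch_xla.sync()
  return tensor_out.cpu()
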
