From b1fbb7490f13d8395572f3b26ad426b17f53be19 Mon Sep 17 00:00:00 2001
From: Pei Zhang
Date: Wed, 4 Sep 2024 10:20:37 -0700
Subject: [PATCH] support dist.reduce_scatter_tensor (#7950)

---
 test/pjrt/test_collective_ops_tpu.py     | 39 +++++++++++++++++++++++
 torch_xla/csrc/cross_replica_reduces.cpp | 40 ++++++++++++++++++++++++
 torch_xla/csrc/init_python_bindings.cpp  |  2 ++
 torch_xla/distributed/xla_backend.py     | 28 +++++++++++++++++
 4 files changed, 109 insertions(+)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index f5afb60df36..95aa5f5586e 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -221,6 +221,31 @@ def callable(input):
       # output is list of tensors
       return pytree.tree_map(lambda x: x.cpu(), output)
 
+  @staticmethod
+  def _reduce_scatter(use_dynamo: bool):
+    met.clear_all()
+    dist.init_process_group("xla", init_method='xla://')
+    device = xm.xla_device()
+
+    def callable(output, input):
+      dist.reduce_scatter_tensor(output, input)
+      return output
+
+    # See https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3766-L3814
+    # for example input and output tensors.
+    tensor_in = torch.arange(
+        xr.world_size() * 2, dtype=torch.float, device=device)
+    tensor_out = torch.zeros(2, dtype=torch.float, device=device)
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
+    output = f(tensor_out, tensor_in)
+    torch_xla.sync()
+    if not use_dynamo:
+      assert 'xla::ReduceScatter' in met.counter_names(
+      ) or 'xla::ReduceScatterOut' in met.counter_names()
+    else:
+      assert 'xla::reduce_scatter_tensor' in met.counter_names()
+    return output.cpu()
+
   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
   def test_all_reduce(self, use_dynamo):
     results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -248,6 +273,20 @@ def test_all_gather(self, use_dynamo):
     for index, val in results.items():
       torch.testing.assert_close(val, expected)
 
+  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
+  def test_reduce_scatter(self, use_dynamo):
+    results = pjrt.run_multiprocess(self._reduce_scatter, use_dynamo=use_dynamo)
+    expected = [
+        torch.tensor([
+            2 * i * tpu.num_expected_global_devices(),
+            (2 * i + 1) * tpu.num_expected_global_devices()
+        ],
+                     dtype=torch.float)
+        for i in range(tpu.num_expected_global_devices())
+    ]
+    for index, val in results.items():
+      torch.testing.assert_close(val, expected[index])
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 31e2a484b6d..9f8248f0def 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -406,6 +406,46 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
   return reduce_result;
 }
 
+// Wrapper around BuildReduceScatter adapting the upstream
+// dist.reduce_scatter_tensor(). This path is only used by dynamo; it is called
+// from https://github.com/pytorch/pytorch/blob/85fa01969719dab91eac3e02dd193c7d20d0e87f/torch/distributed/_functional_collectives.py#L1039
+// The function signature must match
+// https://github.com/pytorch/pytorch/blob/85fa01969719dab91eac3e02dd193c7d20d0e87f/torch/csrc/distributed/c10d/Functional.cpp#L356
+// for the call to dispatch here.
+at::Tensor reduce_scatter_tensor(const at::Tensor& input, std::string reduce_op,
+                                 int64_t group_size, std::string group_name) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  auto self = bridge::GetXlaTensor(input);
+  std::vector<int64_t> all_groups(group_size);
+  std::iota(all_groups.begin(), all_groups.end(), 0);
+  int64_t shard_count = group_size;
+  AllReduceType all_reduce_type;
+  if (reduce_op == "sum") {
+    all_reduce_type = AllReduceType::kSum;
+  } else if (reduce_op == "min") {
+    all_reduce_type = AllReduceType::kMin;
+  } else if (reduce_op == "max") {
+    all_reduce_type = AllReduceType::kMax;
+  } else if (reduce_op == "mul") {
+    all_reduce_type = AllReduceType::kMul;
+  } else if (reduce_op == "or") {
+    all_reduce_type = AllReduceType::kOr;
+  } else if (reduce_op == "and") {
+    all_reduce_type = AllReduceType::kAnd;
+  } else {
+    throw std::invalid_argument("Invalid string for AllReduceType: " +
+                                reduce_op);
+  }
+  // The reduce dim is limited to the first dim due to the fixed function signature.
+  XLATensorPtr output = tensor_methods::reduce_scatter(
+      self, all_reduce_type, 1.0, 0, shard_count, {all_groups});
+  return bridge::AtenFromXlaTensor(output);
+}
+
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
+  m.impl("reduce_scatter_tensor", reduce_scatter_tensor);
+}
+
 ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 4689787bc56..6be2bd62872 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -370,6 +370,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
     const std::shared_ptr<torch::lazy::Value>& token, double scale,
     int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::reduce_scatter(
@@ -385,6 +386,7 @@ std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
     const std::shared_ptr<torch::lazy::Value>& token, double scale,
     int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr out = bridge::GetXlaTensor(output);
   torch::lazy::Value new_token;
   new_token = tensor_methods::reduce_scatter_out(
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 4c0c173fdcf..9c6c268a9de 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -170,6 +170,34 @@ def reduce_scatter_coalesced(self, output_tensors, input_tensors_list, opts):
 
     return _ret_work(output_tensors)
 
+  # Call site: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3856
+  def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
+    """
+    Reduces, then scatters a flattened tensor to all processes in a group.
+
+    Args:
+      output_tensor (Tensor): Output tensor.
+      input_tensor (Tensor): Input tensor whose size is the output tensor size times the world size.
+      opts: reduce-scatter options carrying the distributed reduce op (ReduceOp).
+
+    Returns:
+      Async work handle, if async_op is set to True.
+      None, if not async_op or if not part of the group.
+    """
+    reduce_type = self._get_reduce_type(opts.reduceOp)
+    groups = self._mesh
+    shard_count = len(groups[0]) if groups else self.size()
+    xm.reduce_scatter(
+        reduce_type,
+        input_tensor,
+        scatter_dim=0,
+        shard_count=shard_count,
+        scale=1.0,
+        groups=groups,
+        output=output_tensor,
+        pin_layout=False)
+    return _ret_work(output_tensor)
+
   # Call site:
   # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683
   def barrier(self, opts):
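Note (illustrative only, not part of the patch): a minimal end-to-end sketch of how the new dist.reduce_scatter_tensor support would be exercised from user code, mirroring the test above. The _mp_fn name and the xmp.spawn launcher are assumptions about the surrounding harness, not something this PR adds.

# Illustrative sketch; assumes a multi-process XLA setup (e.g. one process per TPU core).
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  world_size = xr.world_size()
  # Each rank contributes world_size * 2 elements; reduce_scatter_tensor sums
  # across ranks and leaves each rank with its own 2-element shard.
  tensor_in = torch.arange(world_size * 2, dtype=torch.float, device=device)
  tensor_out = torch.zeros(2, dtype=torch.float, device=device)
  dist.reduce_scatter_tensor(tensor_out, tensor_in)
  torch_xla.sync()
  print(index, tensor_out.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn)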