diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 95aa5f5586e..1c71e1be46e 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -139,7 +139,7 @@ def test_all_to_all(self, pin_layout):
                                              list(range(world_size))]])
 
 
-@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
+@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
                  "Dynamo not supported on TPU v2/v3")
 class TestDistCollectiveOpsTpu(parameterized.TestCase):
   """Test for collective ops from torch.distributed"""
@@ -246,6 +246,32 @@ def callable(output, input):
     assert 'xla::reduce_scatter_tensor' in met.counter_names()
     return output.cpu()
 
+  @staticmethod
+  def _all_to_all_single(use_dynamo: bool):
+    met.clear_all()
+    dist.init_process_group("xla", init_method='xla://')
+    device = xm.xla_device()
+
+    def callable(output, input):
+      dist.all_to_all_single(output, input)
+      return output
+
+    # See https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
+    # for an example of the expected input and output tensors.
+    tensor_in = torch.tensor(
+        [xr.local_ordinal()] * tpu.num_expected_global_devices(),
+        dtype=torch.float,
+        device=device)
+    tensor_out = torch.zeros_like(tensor_in)
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
+    output = f(tensor_out, tensor_in)
+    torch_xla.sync()
+    if not use_dynamo:
+      assert 'xla::AllToAll' in met.counter_names()
+    else:
+      assert 'xla::all_to_all_single' in met.counter_names()
+    return output.cpu()
+
   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
   def test_all_reduce(self, use_dynamo):
     results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -287,6 +313,17 @@ def test_reduce_scatter(self, use_dynamo):
     for index, val in results.items():
       torch.testing.assert_close(val, expected[index])
 
+  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
+  def test_all_to_all_single(self, use_dynamo):
+    results = pjrt.run_multiprocess(
+        self._all_to_all_single, use_dynamo=use_dynamo)
+    expected = torch.arange(
+        tpu.num_expected_global_devices(), dtype=torch.float)
+    # Note: the XLA AllToAll op does not guarantee ordering by rank, so the
+    # received values may not be in rank order; compare sorted values instead.
+    for _, val in results.items():
+      self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 214c5f08afa..f70b83dbddc 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -359,7 +359,6 @@ def test_barrier(self):
         'reduce',
         'allreduce_coalesced',
         'alltoall',
-        'alltoall_base',
         'gather',
         'scatter',
         'recv_anysource',
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 9f8248f0def..d798358189a 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -11,6 +11,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/layout_manager.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_methods.h"
@@ -309,6 +310,44 @@ AllGatherResultCoalesced BuildAllGatherCoalesced(
   return {result, token_handler.GetNewToken(result[0])};
 }
 
+at::Tensor all_to_all_single(const at::Tensor& input,
+                             std::vector<int64_t> output_split_sizes,
+                             std::vector<int64_t> input_split_sizes,
+                             std::string group_name) {
+  // This is essentially a copy of the logic in
+  // init_python_bindings.cpp:_xla_all_to_all.
+  TORCH_LAZY_FN_COUNTER("xla::");
+  if (output_split_sizes.size() != 0 && input_split_sizes.size() != 0) {
+    for (size_t i = 0; i < input_split_sizes.size(); i++) {
+      if (input_split_sizes[i] != 1)
+        throw std::runtime_error(
+            "torch_xla does not support arbitrary split sizes for all_to_all");
+    }
+  }
+  bool pin_layout = false;
+  const torch::lazy::Value& token =
+      GetAllReduceToken(bridge::GetCurrentDevice());
+  int64_t split_count = runtime::GetComputationClient()->GetAllDevices().size();
+  std::vector<int64_t> all_groups(split_count);
+  std::iota(all_groups.begin(), all_groups.end(), 0);
+  XLATensorPtr result_ptr;
+  torch::lazy::Value new_token;
+  std::tie(result_ptr, new_token) =
+      tensor_methods::all_to_all(bridge::GetXlaTensor(input), token, 0, 0,
+                                 split_count, {all_groups}, pin_layout);
+  at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr));
+
+  at::Tensor result_with_grad = torch::autograd::make_variable(
+      result, /*requires_grad=*/input.requires_grad());
+  SetAllReduceToken(bridge::GetCurrentDevice(),
+                    std::make_shared<torch::lazy::Value>(new_token));
+  return result_with_grad;
+}
+
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
+  m.impl("all_to_all_single", all_to_all_single);
+}
+
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index bb8c75c2322..62c4cc9fc9e 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -494,6 +494,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
     const at::Tensor& input, const std::shared_ptr<torch::lazy::Value>& token,
     int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
     const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::all_to_all(
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 9c6c268a9de..7222a7bf3dc 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -215,8 +215,19 @@ def allreduce_coalesced(self, *args):
   def alltoall(self, *args):
     raise NotImplementedError
 
-  def alltoall_base(self, *args):
-    raise NotImplementedError
+  # Handles the non-dynamo path when torch.distributed.all_to_all_single is
+  # called; see https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996.
+  # Note: for PyTorch the split/concat dimension is always 0, and the XLA
+  # AllToAll op does not support per-rank split sizes.
+  def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
+                    opts):
+    assert (output_split_sizes is None or len(output_split_sizes) == 0) and \
+      (input_split_sizes is None or len(input_split_sizes) == 0), \
+      "XLA doesn't support specifying non-empty output_split_sizes and input_split_sizes"
+    split_count = xr.world_size()
+    result = xm.all_to_all(input, 0, 0, split_count, pin_layout=False)
+    output.copy_(result)
+    return _ret_work(output)
 
   def gather(self, *args):
     raise NotImplementedError
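
For context, a minimal usage sketch of the path this diff enables, mirroring the test above. The launcher setup, tensor contents, and function name (_mp_fn) are illustrative assumptions, not part of the change; split sizes are left unspecified because the XLA backend rejects per-rank splits.

# Minimal sketch (assumed setup): run one process per XLA device, e.g. via
# torchrun or torch_xla's multiprocessing launcher.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(index):
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()
  world_size = xr.world_size()
  # Each rank sends `world_size` copies of its ordinal; after the exchange
  # every rank holds one value from each peer (possibly not in rank order).
  tensor_in = torch.full((world_size,), xr.global_ordinal(),
                         dtype=torch.float, device=device)
  tensor_out = torch.zeros_like(tensor_in)
  dist.all_to_all_single(tensor_out, tensor_in)  # eager path -> alltoall_base
  torch_xla.sync()
  print(index, tensor_out.cpu())

Wrapping the call in torch.compile(..., backend='openxla'), as the dynamo test case does, routes it through the new _c10d_functional all_to_all_single registration in cross_replica_reduces.cpp instead of alltoall_base.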