Support dist.all_to_all_single #8064

Merged: 3 commits, Sep 25, 2024
39 changes: 38 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
@@ -139,7 +139,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


-@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
                 "Dynamo not supported on TPU v2/v3")
class TestDistCollectiveOpsTpu(parameterized.TestCase):
  """Test for collective ops from torch.distributed"""
@@ -246,6 +246,32 @@ def callable(output, input):
    assert 'xla::reduce_scatter_tensor' in met.counter_names()
    return output.cpu()

  @staticmethod
  def _all_to_all_single(use_dynamo: bool):
    met.clear_all()
    dist.init_process_group("xla", init_method='xla://')
    device = xm.xla_device()

    def callable(output, input):
      dist.all_to_all_single(output, input)
      return output

    # See https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
    # for an example of the expected input and output tensors.
    tensor_in = torch.tensor(
        [xr.local_ordinal()] * tpu.num_expected_global_devices(),
        dtype=torch.float,
        device=device)
    tensor_out = torch.zeros_like(tensor_in)
    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
    output = f(tensor_out, tensor_in)
    torch_xla.sync()
    if not use_dynamo:
      assert 'xla::AllToAll' in met.counter_names()
    else:
      assert 'xla::all_to_all_single' in met.counter_names()
    return output.cpu()

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_reduce(self, use_dynamo):
    results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -287,6 +313,17 @@ def test_reduce_scatter(self, use_dynamo):
    for index, val in results.items():
      torch.testing.assert_close(val, expected[index])

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_to_all_single(self, use_dynamo):
    results = pjrt.run_multiprocess(
        self._all_to_all_single, use_dynamo=use_dynamo)
    expected = torch.arange(
        tpu.num_expected_global_devices(), dtype=torch.float)
    # Note: the XLA AllToAll op does not guarantee the ordering of the
    # all_to_all, so the result may not follow rank order.
    for _, val in results.items():
      self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))


if __name__ == '__main__':
  absltest.main()
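For reference, here is a minimal standalone sketch (not part of this PR) of the eager, non-dynamo path that the new test exercises. It assumes a multi-process XLA setup with one process per device; `torch_xla.launch` is used as the launcher here (older releases would use `xmp.spawn` instead), and the expected output is illustrative, up to the ordering caveat noted in the test.

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # noqa: F401, registers the 'xla' backend


def _mp_fn(index):
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()
  world_size = xr.world_size()

  # Each rank contributes one element per peer; with uniform splits,
  # all_to_all_single gives every rank the r-th element from rank r,
  # i.e. roughly tensor([0., 1., ..., N-1.]) on each process.
  tensor_in = torch.full((world_size,),
                         float(xr.global_ordinal()),
                         dtype=torch.float,
                         device=device)
  tensor_out = torch.zeros_like(tensor_in)
  dist.all_to_all_single(tensor_out, tensor_in)
  torch_xla.sync()
  print(index, tensor_out.cpu())


if __name__ == "__main__":
  torch_xla.launch(_mp_fn, args=())
```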
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
@@ -359,7 +359,6 @@ def test_barrier(self):
'reduce',
'allreduce_coalesced',
'alltoall',
-'alltoall_base',
'gather',
'scatter',
'recv_anysource',
39 changes: 39 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -11,6 +11,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_methods.h"
@@ -309,6 +310,44 @@ AllGatherResultCoalesced BuildAllGatherCoalesced(
  return {result, token_handler.GetNewToken(result[0])};
}

at::Tensor all_to_all_single(const at::Tensor& input,
                             std::vector<int64_t> output_split_sizes,
                             std::vector<int64_t> input_split_sizes,
                             std::string group_name) {
  // This is essentially a copy of the code in
  // init_python_bindings.cpp:_xla_all_to_all.
  TORCH_LAZY_FN_COUNTER("xla::");
  if (output_split_sizes.size() != 0 && input_split_sizes.size() != 0) {
    for (size_t i = 0; i < input_split_sizes.size(); i++) {
      if (input_split_sizes[i] != 1)
        throw std::runtime_error(
            "torch_xla does not support arbitrary split sizes for all_to_all");
    }
  }
  bool pin_layout = false;
  const torch::lazy::Value& token =
      GetAllReduceToken(bridge::GetCurrentDevice());
  int64_t split_count = runtime::GetComputationClient()->GetAllDevices().size();
  std::vector<int64_t> all_groups(split_count);
  std::iota(all_groups.begin(), all_groups.end(), 0);
  XLATensorPtr result_ptr;
  torch::lazy::Value new_token;
  std::tie(result_ptr, new_token) =
      tensor_methods::all_to_all(bridge::GetXlaTensor(input), token, 0, 0,
                                 split_count, {all_groups}, pin_layout);
  at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr));

  at::Tensor result_with_grad = torch::autograd::make_variable(
      result, /*requires_grad=*/input.requires_grad());
  SetAllReduceToken(bridge::GetCurrentDevice(),
                    std::make_shared<torch::lazy::Value>(new_token));
  return result_with_grad;
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_to_all_single", all_to_all_single);
}

CollectivePermuteResult BuildCollectivePermute(
    xla::XlaOp input, xla::XlaOp token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
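To make the split-size restriction above concrete, the following single-process reference sketch (an illustration, not code from this PR) shows what all_to_all_single computes when the splits are uniform: each rank's input is cut into equal chunks along dim 0, and rank r collects chunk r from every rank. XLA's AllToAll only supports this equal-partition case, which is why the shim rejects arbitrary split sizes.

```python
import torch


def all_to_all_single_reference(inputs):
  """Simulates all_to_all_single across len(inputs) 'ranks' with uniform splits."""
  n = len(inputs)
  # Split each rank's input into n equal chunks along dim 0.
  chunks = [t.chunk(n, dim=0) for t in inputs]
  # Rank r's output concatenates chunk r from every rank, in rank order.
  return [
      torch.cat([chunks[src][r] for src in range(n)], dim=0) for r in range(n)
  ]


# Four "ranks", each holding a vector filled with its own ordinal, as in the test:
ins = [torch.full((4,), float(r)) for r in range(4)]
outs = all_to_all_single_reference(ins)
# Every rank ends up with tensor([0., 1., 2., 3.]).
```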
1 change: 1 addition & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -494,6 +494,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
    const at::Tensor& input, const std::shared_ptr<torch::lazy::Value>& token,
    int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  XLATensorPtr result;
  torch::lazy::Value new_token;
  std::tie(result, new_token) = tensor_methods::all_to_all(
15 changes: 13 additions & 2 deletions torch_xla/distributed/xla_backend.py
@@ -215,8 +215,19 @@ def allreduce_coalesced(self, *args):
  def alltoall(self, *args):
    raise NotImplementedError

-  def alltoall_base(self, *args):
-    raise NotImplementedError
  # Handles the non-dynamo path when torch.distributed.all_to_all_single is
  # called; see https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996
  # for the call site. Note that for PyTorch the split/concat dimension is
  # always 0, and XLA's alltoall cannot take per-rank split sizes.
  def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
                    opts):
    assert (output_split_sizes is None or len(output_split_sizes) == 0) and \
        (input_split_sizes is None or len(input_split_sizes) == 0), \
        "XLA doesn't support specifying non-empty output_split_sizes and input_split_sizes"
    split_count = xr.world_size()
    result = xm.all_to_all(input, 0, 0, split_count, pin_layout=False)
    output.copy_(result)
    return _ret_work(output)

  def gather(self, *args):
    raise NotImplementedError
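For completeness, a short sketch (an assumption about the dispatch flow, not code from this PR) of the dynamo path: under torch.compile with the openxla backend, dist.all_to_all_single is captured as the functional collective _c10d_functional.all_to_all_single, which now resolves to the XLA kernel registered in cross_replica_reduces.cpp above, while an uncompiled call lands in this backend's alltoall_base. The fragment below is meant to be run inside the same multi-process setup as the earlier example.

```python
import torch
import torch.distributed as dist
import torch_xla  # noqa: F401, registers the 'openxla' dynamo backend


def exchange(output, input):
  dist.all_to_all_single(output, input)
  return output


# Compiled: traced to _c10d_functional.all_to_all_single and lowered via the
# XLA TORCH_LIBRARY_IMPL registration; uncompiled: ProcessGroupXla.alltoall_base.
compiled_exchange = torch.compile(exchange, backend="openxla")
```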