Support dist.all_gather related collective ops (#7860)
zpcore committed Aug 27, 2024
1 parent 98a9372 commit f9a706e
Showing 5 changed files with 158 additions and 2 deletions.
124 changes: 123 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
@@ -1,13 +1,20 @@
import numpy as np
from typing import List
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.utils._pytree as pytree
from absl.testing import absltest, parameterized
from unittest import mock
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met
from torch_xla._internal import pjrt, tpu


class TestCollectiveOpsTpu(parameterized.TestCase):
# Test for collective ops from xla_model
class TestXMCollectiveOpsTpu(parameterized.TestCase):

@staticmethod
def _broadcast(sync):
@@ -132,5 +139,120 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


# Test for collective ops from torch.distributed
class TestDistCollectiveOpsTpu(parameterized.TestCase):

# TODO(zpcore): fix the openxla dynamo issue for inplace copy
@staticmethod
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
return gm.forward

@staticmethod
def _all_reduce(use_dynamo: bool):
met.clear_all()

def callable(input):
dist.all_reduce(input, dist.ReduceOp.SUM)
return input

dist.init_process_group("xla", init_method='xla://')
device = xm.xla_device()
input = torch.tensor([xr.global_ordinal()],
dtype=torch.float,
device=device)

f = torch.compile(
callable, backend=TestDistCollectiveOpsTpu.my_compiler
) if use_dynamo else callable
f(input)
torch_xla.sync()
if not use_dynamo:
assert 'xla::AllReduceInPlace' in met.counter_names(
) or 'xla::AllReduce' in met.counter_names()
else:
assert 'xla::all_reduce' in met.counter_names()
return input.cpu()

@staticmethod
def _all_gather_into_tensor(use_dynamo: bool):
met.clear_all()

def callable(output, input):
dist.all_gather_into_tensor(output_tensor, input, None)
return output_tensor

dist.init_process_group("xla", init_method='xla://')
device = xm.xla_device()
input = torch.tensor([xr.global_ordinal()],
dtype=torch.float,
device=device)
output_tensor = torch.empty((1, xr.world_size()), device=device)
f = torch.compile(callable, backend='openxla') if use_dynamo else callable
f(output_tensor, input)
torch_xla.sync()
if not use_dynamo:
assert 'xla::AllGather' in met.counter_names(
) or 'xla::AllGatherOut' in met.counter_names()
else:
assert 'xla::all_gather_into_tensor' in met.counter_names()
return output_tensor.cpu()

@staticmethod
def _all_gather(use_dynamo: bool):
met.clear_all()
dist.init_process_group("xla", init_method='xla://')
device = xm.xla_device()

def callable(input):
output_tensor = [
torch.tensor([0], dtype=torch.float).to(device)
for _ in range(xr.world_size())
]
dist.all_gather(output_tensor, input, None)
return output_tensor

input = torch.tensor([xr.global_ordinal()],
dtype=torch.float,
device=device)

f = torch.compile(callable, backend='openxla') if use_dynamo else callable
output = f(input)
torch_xla.sync()
if not use_dynamo:
assert 'xla::AllGather' in met.counter_names(
) or 'xla::AllGatherOut' in met.counter_names()
else:
assert 'xla::all_gather_into_tensor' in met.counter_names()
# output is list of tensors
return pytree.tree_map(lambda x: x.cpu(), output)

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_reduce(self, use_dynamo):
results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
expected = torch.tensor([sum(range(tpu.num_expected_global_devices()))],
dtype=torch.float)
for index, val in results.items():
torch.testing.assert_close(val, expected)

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_gather_into_tensor(self, use_dynamo):
results = pjrt.run_multiprocess(
self._all_gather_into_tensor, use_dynamo=use_dynamo)
expected = torch.arange(
tpu.num_expected_global_devices(), dtype=torch.float).unsqueeze(0)
for index, val in results.items():
torch.testing.assert_close(val, expected)

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_gather(self, use_dynamo):
results = pjrt.run_multiprocess(self._all_gather, use_dynamo=use_dynamo)
expected = [
torch.tensor([i], dtype=torch.float)
for i in range(tpu.num_expected_global_devices())
]
for index, val in results.items():
torch.testing.assert_close(val, expected)


if __name__ == '__main__':
absltest.main()
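
As background for reading these tests, here is a minimal sketch of the harness they share: pjrt.run_multiprocess spawns one process per local device and returns a dict keyed by global ordinal, which is why each test iterates results.items() and checks every per-process value against the same expected tensor. The _report_rank helper below is illustrative only and not part of this commit.

import torch_xla.runtime as xr
from torch_xla._internal import pjrt


def _report_rank():
  # Runs once in each spawned process; return values are collected per ordinal.
  return xr.global_ordinal()


if __name__ == '__main__':
  results = pjrt.run_multiprocess(_report_rank)
  print(results)  # e.g. {0: 0, 1: 1, 2: 2, 3: 3} on a four-device host
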
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py
@@ -39,6 +39,7 @@
XLA_LIB = Library("xla", "DEF")

from . import xla_model as this_module

xrt_world_size = deprecated(this_module, torch_xla.runtime.world_size,
'xrt_world_size() will be removed in release 2.6.')
get_ordinal = deprecated(
@@ -464,7 +465,6 @@ def all_reduce(
torch_xla._XLAC._xla_all_reduce_inplace(reduce_type, inputs, scale, groups,
pin_layout)
results = inputs

return results[0] if isinstance(inputs, torch.Tensor) else results


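
For context on the return statement above (the diff itself only removes a blank line), here is a small sketch of the two call forms xm.all_reduce supports; the tensor shapes are illustrative. Passing a single tensor returns the reduced tensor, while passing a list reduces the tensors in place and returns the same list.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t = torch.ones(4, device=device)
reduced = xm.all_reduce(xm.REDUCE_SUM, t)       # single tensor in, reduced tensor out

ts = [torch.ones(2, device=device) for _ in range(2)]
same_list = xm.all_reduce(xm.REDUCE_SUM, ts)    # list is reduced in place and returned
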
16 changes: 16 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -254,6 +254,22 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
}

// function signature should match torch/csrc/distributed/c10d/Functional.cpp
at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size,
std::string group_name) {
TORCH_LAZY_FN_COUNTER("xla::");
auto self_tensor = bridge::GetXlaTensor(self);
std::vector<int64_t> all_groups(group_size);
std::iota(all_groups.begin(), all_groups.end(), 0);
auto result = tensor_methods::all_gather(self_tensor, 0, group_size,
{all_groups}, true);
return bridge::AtenFromXlaTensor(result);
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
m.impl("all_gather_into_tensor", all_gather_into_tensor);
}

AllGatherResultCoalesced BuildAllGatherCoalesced(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
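
The TORCH_LIBRARY_IMPL block above registers the XLA lowering for the traceable _c10d_functional::all_gather_into_tensor op. Below is a hedged sketch of invoking that op directly from Python; the group name string is illustrative, since the kernel in this commit derives the replica group from group_size rather than from the name, and upstream traced code normally pairs the call with _c10d_functional.wait_tensor, which is outside this diff.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

dist.init_process_group('xla', init_method='xla://')
inp = torch.tensor([float(xr.global_ordinal())], device=xm.xla_device())

# Dispatches to the XLA all_gather_into_tensor kernel registered above; the
# group name argument is illustrative and is not consulted by this kernel.
out = torch.ops._c10d_functional.all_gather_into_tensor(inp, xr.world_size(), 'default')
torch_xla.sync()
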
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -298,6 +298,7 @@ void AllReduceInPlace(const std::string& reduce_type,
const std::vector<at::Tensor>& tensors, double scale,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale,
@@ -308,6 +309,7 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
double scale,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
auto result = tensor_methods::all_reduce(bridge::GetXlaTensor(input),
GetReduceType(reduce_type), scale,
replica_groups, pin_layout);
@@ -430,6 +432,7 @@ std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(
at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
auto result =
tensor_methods::all_gather(bridge::GetXlaTensor(input), dim, shard_count,
replica_groups, pin_layout);
@@ -441,6 +444,7 @@ std::shared_ptr<torch::lazy::Value> AllGatherOut(
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr out = bridge::GetXlaTensor(output);
torch::lazy::Value new_token;
new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input),
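
The TORCH_LAZY_FN_COUNTER_TIMED_TRACING calls added above are what make the xla::AllReduce / xla::AllGather counters asserted in the tests visible from Python. A small sketch of inspecting them:

import torch_xla.debug.metrics as met

# After running one of the collectives exercised by the tests above, the
# counters added in this file show up in the Python-side metrics:
print('xla::AllReduceInPlace' in met.counter_names())
print(met.counter_value('xla::AllGather'))   # hit count for the lowering, or None
print(met.short_metrics_report())            # compact summary including counters
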
14 changes: 14 additions & 0 deletions torch_xla/distributed/xla_backend.py
@@ -46,6 +46,16 @@ def __init__(self, prefix_store, rank, size, timeout):
def getBackendName(self):
return 'xla'

# PyTorch's process group cannot retrieve the group name at the Python level, although it
# should already be supported at the C++ level: https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
# For now we manually set the group name property as a temporary workaround.
def _set_group_name(self, name: str) -> None:
self._group_name = name

@property
def group_name(self):
return self._group_name

def _get_reduce_type(self, reduce_op):
if reduce_op == dist.ReduceOp.SUM:
return xm.REDUCE_SUM
@@ -71,6 +81,10 @@ def allreduce(self, tensors, all_reduce_options):
xm.all_reduce(reduce_type, tensors, groups=self._mesh, pin_layout=False)
return _ret_work(tensors)

# Method called by dist.all_gather_into_tensor under eager mode.
def _allgather_base(self, output_tensor, input_tensor, opts):
return self.allgather(output_tensor, input_tensor, opts)

def allgather(self, output_tensors_list, input_tensors, opts=None):
for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
is_scalar = (input_tensor.dim() == 0)
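
As a usage note for the new eager-mode hook, here is a brief sketch that mirrors the non-dynamo test above (same tensor shapes); each call is routed through ProcessGroupXla._allgather_base.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

dist.init_process_group('xla', init_method='xla://')
device = xm.xla_device()

inp = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device)
out = torch.empty((1, xr.world_size()), device=device)
dist.all_gather_into_tensor(out, inp)   # routed to ProcessGroupXla._allgather_base
torch_xla.sync()
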
