[SPMD] Support reduce-scatter in manual sharding (#7231)
Summary:
This PR adds experimental support for collective communication (cc) ops inside manual sharding zones, starting with reduce-scatter. The key is to set channel_id, replica_groups, and use_global_device_ids when lowering the op.

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter
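
For context, a minimal usage sketch of the new binding, mirroring the new test. The mesh setup and the xs/xm/xr module aliases are assumptions following the test file's conventions, not part of this diff:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Assumed setup: SPMD mode on TPU, all devices arranged in a 1 x N mesh.
num_devices = xr.global_runtime_device_count()
device_ids = list(range(num_devices))
xs.set_global_mesh(xs.Mesh(device_ids, (1, num_devices)))

x = torch.ones(8, 8).to(xm.xla_device())

# Enter the manual sharding zone, reduce-scatter over dim 0, then exit.
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, 0,
                                             num_devices, [device_ids])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
# Each shard holds the cross-device sum; the global shape is now
# (8 // num_devices, 8).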
alanwaketan committed Jun 11, 2024
1 parent ac371fb commit 70d2e9e
Showing 8 changed files with 138 additions and 0 deletions.
40 changes: 40 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1213,6 +1213,46 @@ def test_manual_sharding_api_e2e(self):
    self.assertEqual(xxx.shape, (8, 8))
    self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))

  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
  def test_spmd_reduce_scatter(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # Reduce scatter
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, 0,
                                                 self.n_devices,
                                                 [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(2, 8) * 4
    self.assertTrue(torch.allclose(x.cpu(), expected_x))

  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
  def test_spmd_reduce_scatter_canonical_index(self):
    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
    x = torch.ones(8, 8).to(xm.xla_device())

    # Reduce scatter
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, -1,
                                                 self.n_devices,
                                                 [self.device_ids])
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
    self.assertIn(
        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
        hlo)

    expected_x = torch.ones(8, 2) * 4
    self.assertTrue(torch.allclose(x.cpu(), expected_x))


if __name__ == '__main__':
  test = unittest.main()
24 changes: 24 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -345,6 +345,30 @@ ReduceScatterResult BuildReduceScatter(
  return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
                              double scale, int64_t scatter_dim,
                              int64_t shard_count,
                              const std::vector<std::vector<int64_t>>& groups) {
  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  // A dummy channel handle: a channel_id is required in order to set
  // use_global_device_ids, which in turn is required for SPMD.
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(1);
  channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  xla::XlaOp reduce_result;
  reduce_result = xla::ReduceScatter(
      input, GetReduceComutation(reduce_type, input_shape.element_type()),
      scatter_dim, shard_count, std::move(reduce_groups),
      std::move(channel_handle), /*layout=*/std::nullopt,
      /*use_global_device_ids=*/true);
  if (scale != 1.0) {
    xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
        scale, input_shape.element_type(), input.builder());
    reduce_result = reduce_result * scaling_value;
  }
  return reduce_result;
}

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
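A note on the scale argument: it is applied as a post-multiply on the reduce-scatter output, so a caller can turn the sum into a mean by passing 1.0 / num_devices. A hypothetical variation of the sketch above, not exercised by the tests in this PR:

# Hypothetical: mean instead of sum, via the scale argument.
x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x,
                                             1.0 / num_devices, 0,
                                             num_devices, [device_ids])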
5 changes: 5 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -96,6 +96,11 @@ ReduceScatterResult BuildReduceScatter(
    int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
double scale, int64_t scatter_dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups);

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
11 changes: 11 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1593,6 +1593,17 @@ void InitXlaModuleBindings(py::module m) {
          result_tuple[1] = new_token;
          return result_tuple;
        });
  m.def(
      "_xla_spmd_reduce_scatter",
      [](const std::string& reduce_type, const at::Tensor& input, double scale,
         int64_t scatter_dim, int64_t shard_count, const py::list& groups) {
        std::vector<std::vector<int64_t>> replica_groups =
            CreateReduceGroups(groups);
        auto result = tensor_methods::reduce_scatter(
            bridge::GetXlaTensor(input), GetReduceType(reduce_type), scale,
            scatter_dim, shard_count, replica_groups);
        return bridge::AtenFromXlaTensor(std::move(result));
      });
  m.def("_xla_reduce_scatter",
        [](const std::string& reduce_type, const at::Tensor& input,
           const std::shared_ptr<torch::lazy::Value>& token, double scale,
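The Python-visible signature of the new binding, as read off the lambda (parameter names are illustrative; pybind11 binds them positionally here):

# torch_xla._XLAC._xla_spmd_reduce_scatter(
#     reduce_type: str, input: torch.Tensor, scale: float,
#     scatter_dim: int, shard_count: int, groups: list) -> torch.Tensor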
38 changes: 38 additions & 0 deletions torch_xla/csrc/ops/reduce_scatter.cpp
@@ -28,6 +28,18 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
  return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
}

xla::Shape NodeOutputShape(AllReduceType reduce_type,
                           const torch::lazy::Value input, double scale,
                           int64_t scatter_dim, int64_t shard_count,
                           const std::vector<std::vector<int64_t>>& groups) {
  auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    xla::XlaOp inputOp = operands[0];
    return BuildReduceScatter(reduce_type, inputOp, scale, scatter_dim,
                              shard_count, groups);
  };
  return InferOutputShape({GetXlaShape(input)}, shape_fn);
}

xla::Shape NodeOutputShapeCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
@@ -73,6 +85,27 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type,
      groups_(std::move(groups)),
      pin_layout_(pin_layout) {}

ReduceScatter::ReduceScatter(AllReduceType reduce_type,
                             const torch::lazy::Value& input, double scale,
                             int64_t scatter_dim, int64_t shard_count,
                             std::vector<std::vector<int64_t>> groups)
    : XlaNode(
          xla_reduce_scatter, {input},
          [&]() {
            return NodeOutputShape(reduce_type, input, scale, scatter_dim,
                                   shard_count, groups);
          },
          /*num_outputs=*/1,
          torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
                             scatter_dim, shard_count, groups)),
      reduce_type_(reduce_type),
      scale_(scale),
      scatter_dim_(scatter_dim),
      shard_count_(shard_count),
      groups_(std::move(groups)),
      pin_layout_(false),
      has_token_(false) {}

ReduceScatterCoalesced::ReduceScatterCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
@@ -111,6 +144,11 @@ torch::lazy::NodePtr ReduceScatterCoalesced::Clone(

XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  if (!has_token_) {
    // Token-less SPMD variant: lower directly, without threading a token.
    auto result = BuildReduceScatter(reduce_type_, input, scale_, scatter_dim_,
                                     shard_count_, groups_);
    return ReturnOp(result, loctx);
  }
  xla::XlaOp token = loctx->GetOutputOp(operand(1));
  ReduceScatterResult result =
      BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_,
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/reduce_scatter.h
@@ -12,6 +12,9 @@ class ReduceScatter : public XlaNode {
                const torch::lazy::Value& token, double scale,
                int64_t scatter_dim, int64_t shard_count,
                std::vector<std::vector<int64_t>> groups, bool pin_layout);
  ReduceScatter(AllReduceType reduce_type, const torch::lazy::Value& input,
                double scale, int64_t scatter_dim, int64_t shard_count,
                std::vector<std::vector<int64_t>> groups);

  std::string ToString() const override;

@@ -34,6 +37,7 @@
  int64_t shard_count_;
  std::vector<std::vector<int64_t>> groups_;
  bool pin_layout_;
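  // Defaults to true; the token-less SPMD constructor sets this to false,
  // and Lower() branches on it.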
  bool has_token_{true};
};

class ReduceScatterCoalesced : public XlaNode {
11 changes: 11 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -388,6 +388,17 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
          torch::lazy::Value(node, 1)};
}

XLATensorPtr reduce_scatter(const XLATensorPtr& input,
                            AllReduceType reduce_type, double scale,
                            int64_t scatter_dim, int64_t shard_count,
                            std::vector<std::vector<int64_t>> groups) {
  auto canonical_scatter_dim = torch::lazy::GetCanonicalDimensionIndex(
      scatter_dim, input->shape().get().rank());
  return input->CreateFrom(torch::lazy::MakeNode<ReduceScatter>(
      reduce_type, input->GetIrValue(), scale, canonical_scatter_dim,
      shard_count, std::move(groups)));
}

torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
const XLATensorPtr& input,
const torch::lazy::Value& token,
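Because GetCanonicalDimensionIndex normalizes negative dimensions the way Python indexing does, scatter_dim = -1 scatters over the last dimension; for the rank-2 inputs in the tests this lowers to dimensions={1} in the reduce-scatter HLO. Continuing the sketch above:

# scatter_dim = -1 canonicalizes to the last dimension (dim 1 for rank 2).
x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, -1,
                                             num_devices, [device_ids])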
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -26,6 +26,11 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout);

XLATensorPtr reduce_scatter(const XLATensorPtr& input,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups);

torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
const XLATensorPtr& input,
const torch::lazy::Value& token,
