From 70d2e9eab65dafb9dc51492ee10f9dc141129dfa Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Mon, 10 Jun 2024 23:57:03 -0700
Subject: [PATCH] [SPMD] Support reduce-scatter in manual sharding (#7231)

Summary:
This PR adds experimental support for cc (collective communication) ops in
manual sharding zones, starting with reduce-scatter. The key is to set
channel_id, replica_groups, and use_global_device_ids in the lowering.

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter
---
 test/spmd/test_xla_sharding.py           | 40 ++++++++++++++++++++++
 torch_xla/csrc/cross_replica_reduces.cpp | 24 ++++++++++++++
 torch_xla/csrc/cross_replica_reduces.h   |  5 +++
 torch_xla/csrc/init_python_bindings.cpp  | 11 +++++++
 torch_xla/csrc/ops/reduce_scatter.cpp    | 38 ++++++++++++++++++++++
 torch_xla/csrc/ops/reduce_scatter.h      |  4 +++
 torch_xla/csrc/tensor_methods.cpp        | 11 +++++++
 torch_xla/csrc/tensor_methods.h          |  5 +++
 8 files changed, 138 insertions(+)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index c6a30fd4bd8..2d710a7c7c1 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1213,6 +1213,46 @@ def test_manual_sharding_api_e2e(self):
     self.assertEqual(xxx.shape, (8, 8))
     self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  def test_spmd_reduce_scatter(self):
+    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
+    x = torch.ones(8, 8).to(xm.xla_device())
+
+    # Reduce scatter
+    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
+    x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, 0,
+                                                 self.n_devices,
+                                                 [self.device_ids])
+    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
+
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
+    self.assertIn(
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
+        hlo)
+
+    expected_x = torch.ones(2, 8) * 4
+    self.assertTrue(torch.allclose(x.cpu(), expected_x))
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
+  def test_spmd_reduce_scatter_canonical_index(self):
+    xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
+    x = torch.ones(8, 8).to(xm.xla_device())
+
+    # Reduce scatter
+    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
+    x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, -1,
+                                                 self.n_devices,
+                                                 [self.device_ids])
+    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
+
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
+    self.assertIn(
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
+        hlo)
+
+    expected_x = torch.ones(8, 2) * 4
+    self.assertTrue(torch.allclose(x.cpu(), expected_x))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index d140a486223..72b7eed9b84 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -345,6 +345,30 @@ ReduceScatterResult BuildReduceScatter(
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
+xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
+                              double scale, int64_t scatter_dim,
+                              int64_t shard_count,
+                              const std::vector<std::vector<int64_t>>& groups) {
+  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  // Just a dummy channel handle; it is needed in order to set
+  // use_global_device_ids, which is required for SPMD.
+  xla::ChannelHandle channel_handle;
+  channel_handle.set_handle(1);
+  channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
+  xla::XlaOp reduce_result;
+  reduce_result = xla::ReduceScatter(
+      input, GetReduceComutation(reduce_type, input_shape.element_type()),
+      scatter_dim, shard_count, std::move(reduce_groups),
+      std::move(channel_handle), std::nullopt, true);
+  if (scale != 1.0) {
+    xla::XlaOp scaling_value = XlaHelpers::ScalarValue(
+        scale, input_shape.element_type(), input.builder());
+    reduce_result = reduce_result * scaling_value;
+  }
+  return reduce_result;
+}
+
 ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h
index ade1a0fa00e..6993309fd5e 100644
--- a/torch_xla/csrc/cross_replica_reduces.h
+++ b/torch_xla/csrc/cross_replica_reduces.h
@@ -96,6 +96,11 @@ ReduceScatterResult BuildReduceScatter(
     int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
 
+xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
+                              double scale, int64_t scatter_dim,
+                              int64_t shard_count,
+                              const std::vector<std::vector<int64_t>>& groups);
+
 ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 60f6c92f00b..bb0451d5603 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1593,6 +1593,17 @@ void InitXlaModuleBindings(py::module m) {
           result_tuple[1] = new_token;
           return result_tuple;
         });
+  m.def(
+      "_xla_spmd_reduce_scatter",
+      [](const std::string& reduce_type, const at::Tensor& input, double scale,
+         int64_t scatter_dim, int64_t shard_count, const py::list& groups) {
+        std::vector<std::vector<int64_t>> replica_groups =
+            CreateReduceGroups(groups);
+        auto result = tensor_methods::reduce_scatter(
+            bridge::GetXlaTensor(input), GetReduceType(reduce_type), scale,
+            scatter_dim, shard_count, replica_groups);
+        return bridge::AtenFromXlaTensor(std::move(result));
+      });
   m.def("_xla_reduce_scatter",
         [](const std::string& reduce_type, const at::Tensor& input,
            const std::shared_ptr<torch::lazy::Value>& token, double scale,
diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp
index 305a70abed5..f36106d7f94 100644
--- a/torch_xla/csrc/ops/reduce_scatter.cpp
+++ b/torch_xla/csrc/ops/reduce_scatter.cpp
@@ -28,6 +28,18 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
   return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
 }
 
+xla::Shape NodeOutputShape(AllReduceType reduce_type,
+                           const torch::lazy::Value input, double scale,
+                           int64_t scatter_dim, int64_t shard_count,
+                           const std::vector<std::vector<int64_t>>& groups) {
+  auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    xla::XlaOp inputOp = operands[0];
+    return BuildReduceScatter(reduce_type, inputOp, scale, scatter_dim,
+                              shard_count, groups);
+  };
+  return InferOutputShape({GetXlaShape(input)}, shape_fn);
+}
+
 xla::Shape NodeOutputShapeCoalesced(
     AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
     const torch::lazy::Value& token, double scale, int64_t scatter_dim,
@@ -73,6 +85,27 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type,
       groups_(std::move(groups)),
       pin_layout_(pin_layout) {}
 
+ReduceScatter::ReduceScatter(AllReduceType reduce_type,
+                             const torch::lazy::Value& input, double scale,
+                             int64_t scatter_dim, int64_t shard_count,
+                             std::vector<std::vector<int64_t>> groups)
+    : XlaNode(
+          xla_reduce_scatter, {input},
+          [&]() {
+            return NodeOutputShape(reduce_type, input, scale, scatter_dim,
+                                   shard_count, groups);
+          },
+          /*num_outputs=*/1,
+          torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
+                             scatter_dim, shard_count, groups)),
+      reduce_type_(reduce_type),
+      scale_(scale),
+      scatter_dim_(scatter_dim),
+      shard_count_(shard_count),
+      groups_(std::move(groups)),
+      pin_layout_(false),
+      has_token_(false) {}
+
 ReduceScatterCoalesced::ReduceScatterCoalesced(
     AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
     const torch::lazy::Value& token, double scale, int64_t scatter_dim,
@@ -111,6 +144,11 @@ torch::lazy::NodePtr ReduceScatterCoalesced::Clone(
 
 XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  if (!has_token_) {
+    auto result = BuildReduceScatter(reduce_type_, input, scale_, scatter_dim_,
+                                     shard_count_, groups_);
+    return ReturnOp(result, loctx);
+  }
   xla::XlaOp token = loctx->GetOutputOp(operand(1));
   ReduceScatterResult result =
       BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_,
diff --git a/torch_xla/csrc/ops/reduce_scatter.h b/torch_xla/csrc/ops/reduce_scatter.h
index 2a752788fc4..8e46abcce9b 100644
--- a/torch_xla/csrc/ops/reduce_scatter.h
+++ b/torch_xla/csrc/ops/reduce_scatter.h
@@ -12,6 +12,9 @@ class ReduceScatter : public XlaNode {
                 const torch::lazy::Value& token, double scale,
                 int64_t scatter_dim, int64_t shard_count,
                 std::vector<std::vector<int64_t>> groups, bool pin_layout);
+  ReduceScatter(AllReduceType reduce_type, const torch::lazy::Value& input,
+                double scale, int64_t scatter_dim, int64_t shard_count,
+                std::vector<std::vector<int64_t>> groups);
 
   std::string ToString() const override;
 
@@ -34,6 +37,7 @@ class ReduceScatter : public XlaNode {
   int64_t shard_count_;
   std::vector<std::vector<int64_t>> groups_;
   bool pin_layout_;
+  bool has_token_{true};
 };
 
 class ReduceScatterCoalesced : public XlaNode {
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 4b468090192..cc1b34e9043 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -388,6 +388,17 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
           torch::lazy::Value(node, 1)};
 }
 
+XLATensorPtr reduce_scatter(const XLATensorPtr& input,
+                            AllReduceType reduce_type, double scale,
+                            int64_t scatter_dim, int64_t shard_count,
+                            std::vector<std::vector<int64_t>> groups) {
+  auto canonical_scatter_dim = torch::lazy::GetCanonicalDimensionIndex(
+      scatter_dim, input->shape().get().rank());
+  return input->CreateFrom(torch::lazy::MakeNode<ReduceScatter>(
+      reduce_type, input->GetIrValue(), scale, canonical_scatter_dim,
+      shard_count, std::move(groups)));
+}
+
 torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
                                       const XLATensorPtr& input,
                                       const torch::lazy::Value& token,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 694459d5bd8..df0c64d9a99 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -26,6 +26,11 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
     int64_t shard_count, std::vector<std::vector<int64_t>> groups,
     bool pin_layout);
 
+XLATensorPtr reduce_scatter(const XLATensorPtr& input,
+                            AllReduceType reduce_type, double scale,
+                            int64_t scatter_dim, int64_t shard_count,
+                            std::vector<std::vector<int64_t>> groups);
+
 torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
                                       const XLATensorPtr& input,
                                       const torch::lazy::Value& token,
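
Note (not part of the patch): a minimal usage sketch of the new _xla_spmd_reduce_scatter binding. It mirrors test_spmd_reduce_scatter above and assumes a multi-device TPU run with SPMD enabled; the explicit Mesh construction, num_devices, and device_ids below are illustrative stand-ins for the test's _get_mesh helper and device-id list.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD mode (equivalent to running with XLA_USE_SPMD=1)
    num_devices = xr.global_runtime_device_count()
    device_ids = list(range(num_devices))
    xs.set_global_mesh(xs.Mesh(device_ids, (1, num_devices), ('x', 'y')))

    x = torch.ones(8, 8).to(xm.xla_device())
    # Enter the manual sharding zone; collectives inside operate on the per-device shard.
    x = xs.enable_manual_sharding(x, (None, None)).global_tensor
    # Sum reduce-scatter along dim 0, scale 1.0, with a single replica group
    # containing all devices; shard_count equals the group size.
    x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, 0,
                                                 num_devices, [device_ids])
    # Leave the manual zone, declaring the shape of the scattered result.
    x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
    print(x.cpu())  # summed, scattered values (cf. expected_x in the test above)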