diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3721b0c3768..bfee38d8120 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1765,6 +1765,12 @@ void InitXlaModuleBindings(py::module m) {
       "_xla_get_rng_seed",
       [](const std::string& device) { return GetRngSeed(device); },
       py::arg("device") = "");
+  m.def(
+      "_xla_set_virtual_topology",
+      [](std::string& topology) {
+        torch_xla::runtime::SetVirtualTopology(topology);
+      },
+      py::arg("topology") = "");
   m.def(
       "_xla_set_should_alias_with_buffer_donor_config",
       [](bool should_alias, const std::string& device_str) {
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 19c5711338d..8c74870c3e9 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -24,6 +24,7 @@ cc_library(
         ":env_vars",
         ":ifrt_computation_client",
         ":pjrt_computation_client",
+        ":pjrt_compilation_client",
         "@tsl//tsl/platform:stacktrace",
     ],
 )
@@ -134,6 +135,43 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "pjrt_compilation_client",
+    srcs = [
+        "pjrt_compilation_client.cc",
+    ],
+    hdrs = [
+        "pjrt_compilation_client.h",
+    ],
+    deps = [
+        ":computation_client",
+        ":debug_macros",
+        ":env_hash",
+        ":env_vars",
+        ":operation_manager",
+        ":profiler",
+        ":stablehlo_helper",
+        ":tensor_source",
+        ":tf_logging",
+        ":xla_coordinator",
+        "//torch_xla/csrc:thread_pool",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+        "@tsl//tsl/platform:env",
+        "@tsl//tsl/platform/cloud:gcs_file_system",
+        "@tsl//tsl/profiler/lib:traceme",
+        "@xla//xla:literal",
+        "@xla//xla:shape_util",
+        "@xla//xla/client:xla_computation",
+        "@xla//xla/pjrt:pjrt_client",
+        "@xla//xla/pjrt:pjrt_api",
+        "@xla//xla/pjrt:pjrt_c_api_client",
+        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
+        "@xla//xla/pjrt/distributed",
+    ],
+)
+
 cc_library(
     name = "cache",
     hdrs = ["cache.h"],
diff --git a/torch_xla/csrc/runtime/pjrt_compilation_client.cc b/torch_xla/csrc/runtime/pjrt_compilation_client.cc
new file mode 100644
index 00000000000..b9ba0f0236b
--- /dev/null
+++ b/torch_xla/csrc/runtime/pjrt_compilation_client.cc
@@ -0,0 +1,679 @@
+#include "torch_xla/csrc/runtime/pjrt_compilation_client.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/ascii.h"
+#include "absl/synchronization/blocking_counter.h"
+#include "absl/types/span.h"
+#include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/env_hash.h"
+#include "torch_xla/csrc/runtime/env_vars.h"
+#include "torch_xla/csrc/runtime/operation_manager.h"
+#include "torch_xla/csrc/runtime/profiler.h"
+#include "torch_xla/csrc/runtime/stablehlo_helper.h"
+#include "torch_xla/csrc/runtime/tensor_source.h"
+#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
+#include "torch_xla/csrc/thread_pool.h"
+#include "tsl/profiler/lib/traceme.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/pjrt/pjrt_api.h"
+#include "xla/pjrt/pjrt_c_api_client.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/protobuf_util.h"
+#include "xla/shape.h"
+
+using xla::internal::XlaBuilderFriend;
+
+namespace torch_xla {
+namespace runtime {
+
+namespace {
+
+// Builds a map from the device's global ordinal to its index in the `devices`
+// array.
+std::unordered_map<int, int> build_index_map(
+    const std::vector<std::string>& devices) {
+  std::unordered_map<int, int> device_index;
+  for (int i = 0; i < devices.size(); ++i) {
+    std::vector<std::string> device_spec = absl::StrSplit(devices[i], ':');
+    XLA_CHECK_EQ(device_spec.size(), 2)
+        << "Invalid device specification: " << devices[i];
+    int global_ordinal = std::stoi(device_spec[1]);
+    device_index[global_ordinal] = i;
+  }
+  return device_index;
+}
+
+torch::lazy::hash_t hash_comp_env() {
+  // TODO(piz): since the client is nullptr, we can't retrieve all the
+  // information that PjRtComputationClient can. Think about a way to
+  // construct the hash.
+  torch::lazy::hash_t hash = hash::HashXlaEnvVars();
+  return hash;
+}
+
+}  // namespace
+
+std::string PjRtCompilationClient::PjRtDeviceToString(
+    xla::PjRtDevice* const device) const {
+  std::string platform =
+      absl::AsciiStrToUpper(device->client()->platform_name());
+  int ordinal = global_ordinals_.at(device->id());
+  std::string str = absl::StrFormat("%s:%d", platform, ordinal);
+  return str;
+}
+
+std::vector<std::string> PjRtCompilationClient::PjRtDevicesToString(
+    absl::Span<xla::PjRtDevice* const> devices) const {
+  std::vector<std::string> strs;
+  strs.reserve(devices.size());
+
+  for (auto* device : devices) {
+    strs.push_back(PjRtDeviceToString(device));
+  }
+
+  return strs;
+}
+
+PjRtCompilationClient::PjRtCompilationClient(
+    std::string& virtual_topology_str) {
+  std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
+
+  auto tpu_library_path = sys_util::GetEnvString(
+      env::kEnvTpuLibraryPath,
+      sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
+  XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
+  absl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
+  XLA_CHECK_OK(tpu_status);
+
+  absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {};
+
+  // TODO(piz): we need to specify correct replicas and partitions
+  absl::StatusOr<std::unique_ptr<xla::PjRtTopologyDescription>> topo =
+      xla::GetCApiTopology("tpu", virtual_topology_str, create_options);
+  XLA_CHECK_OK(topo.status());
+  this->virtual_topology = std::move(topo.value());
+
+  // Parse the virtual topology string ("<type>:<AxBxC>") into a device count.
+  // TODO(piz): this is a temporary solution to convert the topology into
+  // devices. Fix this for the SPMD case.
+  std::string device_topology;
+  size_t pos = virtual_topology_str.find(':');
+  if (pos != std::string::npos) {
+    device_topology = virtual_topology_str.substr(pos + 1);
+  }
+
+  size_t pre_pos = 0;
+  int device_count = 1;
+  do {
+    pos = device_topology.find('x', pre_pos);
+    int topo_dim = std::stoi(device_topology.substr(pre_pos, pos - pre_pos));
+    device_count *= topo_dim;
+    pre_pos = pos + 1;
+  } while (pos != std::string::npos);
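+
+  // Illustrative example (assumed format): virtual_topology_str "tpu:2x2x1"
+  // yields device_topology "2x2x1" and device_count 2 * 2 * 1 = 4, so the
+  // loop below creates virtual devices TPU:0 ... TPU:3.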
+  for (int i = 0; i < device_count; i++) {
+    this->client_addressable_devices.push_back(device_type + ":" +
+                                               std::to_string(i));
+    this->client_devices.push_back(device_type + ":" + std::to_string(i));
+  }
+  client_addressable_device_count = this->client_addressable_devices.size();
+  client_device_count = this->client_devices.size();
+
+  // PjRtDevice IDs are not guaranteed to be dense, so we need to track
+  // a device's global ordinal separately from its device ID. Order the
+  // devices by increasing ID to assign global ordinals.
+  for (size_t i = 0; i < this->client_device_count; i++) {
+    global_ordinals_[i] = global_ordinals_.size();
+  }
+
+  comp_env_hash_ = hash_comp_env();
+
+  auto tracked_devices = GetLocalDevices();
+  tracked_devices.emplace_back(spmd_device_str);
+  operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
+}
+
+PjRtCompilationClient::~PjRtCompilationClient() {
+  // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
+  // tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
+  client_ = nullptr;
+  coordinator_ = nullptr;
+}
+
+bool PjRtCompilationClient::CoordinatorInitialized() const {
+  return coordinator_ != nullptr;
+}
+
+void PjRtCompilationClient::InitializeCoordinator(int global_rank,
+                                                  int world_size,
+                                                  std::string master_addr,
+                                                  std::string port) {
+  XLA_CHECK(coordinator_ == nullptr)
+      << "Can only initialize the XlaCoordinator once.";
+  coordinator_ = std::make_unique<XlaCoordinator>(global_rank, world_size,
+                                                  master_addr, port);
+}
+
+XlaCoordinator& PjRtCompilationClient::GetCoordinator() {
+  XLA_CHECK(coordinator_ != nullptr)
+      << "XlaCoordinator has not been initialized";
+  return *coordinator_;
+}
+
+void PjRtCompilationClient::PjRtData::Assign(
+    const torch::lazy::BackendData& data) {
+  const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(data);
+  if (&pjrt_data != this) {
+    buffer = pjrt_data.buffer;
+  }
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::CreateDataPlaceholder(
+    std::string device, xla::Shape shape,
+    std::optional<xla::OpSharding> sharding) {
+  if (sharding.has_value()) {
+    return std::make_shared<PjRtShardedData>(
+        std::move(device), std::move(shape), std::move(*sharding));
+  }
+
+  return std::make_shared<PjRtData>(std::move(device), std::move(shape));
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::CreateData(
+    std::string device, xla::Shape shape, std::shared_ptr<Buffer> buffer) {
+  return std::make_shared<PjRtData>(std::move(device), std::move(shape),
+                                    buffer);
+}
+
+std::vector<ComputationClient::DataPtr> PjRtCompilationClient::GetDataShards(
+    ComputationClient::DataPtr data) {
+  tsl::profiler::TraceMe activity("PjRtCompilationClient::GetDataShards",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  std::vector<ComputationClient::DataPtr> shards;
+  if (PjRtShardedData* sharded_data =
+          dynamic_cast<PjRtShardedData*>(data.get())) {
+    for (auto shard : sharded_data->shards) {
+      shards.push_back(std::make_shared<PjRtData>(
+          shard->device(), shard->shape(), shard->buffer));
+    }
+  } else {
+    shards.push_back(data);
+  }
+  return shards;
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::GetDataShard(
+    ComputationClient::DataPtr data, size_t index) {
+  tsl::profiler::TraceMe activity("PjRtCompilationClient::GetDataShard",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  if (PjRtShardedData* sharded_data =
+          dynamic_cast<PjRtShardedData*>(data.get())) {
+    XLA_CHECK_LT(index, sharded_data->shards.size())
+        << "GetDataShard out of range with index: " << index
+        << " and num of shards: " << sharded_data->shards.size();
+    std::shared_ptr<PjRtData> shard = sharded_data->shards[index];
+    return std::make_shared<PjRtData>(shard->device(), shard->shape(),
+                                      shard->buffer);
+  } else {
+    return data;
+  }
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::WrapDataShards(
+    absl::Span<const DataPtr> shards, std::string device, xla::Shape shape,
+    xla::OpSharding sharding) {
+  XLA_CHECK_EQ(shards.size(), client_addressable_devices.size());
+  std::vector<std::shared_ptr<PjRtData>> pjrt_data_shards;
+  pjrt_data_shards.reserve(shards.size());
+  for (auto& shard : shards) {
+    XLA_CHECK(shard != nullptr);
+    auto pjrt_shard = dynamic_cast<PjRtData*>(shard.get());
+    pjrt_data_shards.push_back(std::make_shared<PjRtData>(
+        pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer));
+  }
+  return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards,
+                                           sharding);
+}
+
+std::optional<xla::OpSharding> PjRtCompilationClient::GetDataSharding(
+    DataPtr handle) {
+  if (auto sharded_data = dynamic_cast<PjRtShardedData*>(handle.get())) {
+    return sharded_data->GetSharding();
+  }
+  return std::optional<xla::OpSharding>();
+}
+
+std::vector<ComputationClient::DataPtr>
+PjRtCompilationClient::TransferToDevice(
+    absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
+  std::vector<ComputationClient::DataPtr> datas;
+  datas.reserve(tensors.size());
+  int64_t total_size = 0;
+  for (auto& tensor : tensors) {
+    total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());
+    // AOT: there is no physical device to transfer to, so wrap each tensor
+    // in a shape-only host Buffer that stands in for the device buffer.
+    std::vector<xla::Shape> tuple_shape;
+    absl::Span<const bool> dynamic_dimensions;
+    xla::Shape shape(tensor->primitive_type(), tensor->dimensions(),
+                     dynamic_dimensions, tuple_shape);
+    std::shared_ptr<Buffer> buffer = std::make_shared<Buffer>(shape);
+    ComputationClient::DataPtr data =
+        std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
+    datas.push_back(data);
+  }
+  OutboundDataMetric()->AddSample(total_size);
+  CreateDataHandlesCounter()->AddValue(datas.size());
+
+  return datas;
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::TransferShardsToDevice(
+    absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
+    std::string device, xla::Shape shape, xla::OpSharding sharding) {
+  tsl::profiler::TraceMe activity(
+      "PjRtCompilationClient::TransferShardsToDevice",
+      tsl::profiler::TraceMeLevel::kInfo);
+  // TODO(jonbolin): Consider using CopyToDevice when sharding is REPLICATED.
+  // We are opting out of CopyToDevice for now due to the synchronization
+  // issues observed in ShardingUtil::InputHandler, but because CopyToDevice
+  // directly copies buffers between devices using ICI, it can be much faster
+  // than transferring from the host to each device.
+  auto data_shards = TransferToDevice(tensor_shards);
+  std::vector<std::shared_ptr<PjRtData>> pjrt_data_shards;
+  for (auto& shard : data_shards) {
+    auto pjrt_shard = dynamic_cast<PjRtData*>(shard.get());
+    pjrt_data_shards.push_back(std::make_shared<PjRtData>(
+        pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer));
+  }
+  return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards,
+                                           sharding);
+}
+
+ComputationClient::DataPtr PjRtCompilationClient::CopyToDevice(
+    ComputationClient::DataPtr data, std::string dst) {
+  tsl::profiler::TraceMe activity("PjRtCompilationClient::CopyToDevice",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(data.get());
+  XLA_CHECK(pjrt_data->HasValue()) << "Can't copy invalid device data.";
+
+  xla::PjRtDevice* dst_device = StringToPjRtDevice(dst);
+  XLA_CHECK(dst_device->IsAddressable()) << dst << " is not addressable.";
+  return std::make_shared<PjRtData>(dst, pjrt_data->shape(),
+                                    pjrt_data->buffer);
+}
+
+std::shared_ptr<PjRtCompilationClient::PjRtData>
+PjRtCompilationClient::ReplicateShardedData(
+    const ComputationClient::DataPtr& handle) {
+  if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
+    return unsharded_data;
+  } else if (auto sharded_data =
+                 std::dynamic_pointer_cast<PjRtShardedData>(handle)) {
+    XLA_COUNTER("ReplicateShardedData", 1);
+    TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle()
+               << ", shape=" << sharded_data->shape() << ")";
+    if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
+      // Data is replicated, return the first shard
+      return sharded_data->shards[0];
+    }
+    xla::XlaBuilder builder("ReplicateShardedData");
+    xla::Shape shape = sharded_data->shape();
+    builder.SetSharding(sharded_data->GetSharding());
+
+    // Perform a simple identity calculation to reassemble the input as
+    // replicated output.
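+    // Compiling x + 0 with the input sharding taken from the data and a
+    // replicated output sharding makes the SPMD partitioner emit the
+    // collectives needed to reassemble the shards.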
+    xla::XlaOp x = xla::Parameter(&builder, 0, shape, "p0");
+    builder.SetSharding(xla::HloSharding::Replicate().ToProto());
+    xla::XlaOp scalar_zero_op = xla::ConvertElementType(
+        xla::ConstantR0(&builder, 0), shape.element_type());
+    xla::XlaOp y = xla::Add(x, scalar_zero_op);
+    auto instruction = XlaBuilderFriend::GetInstruction(y);
+    *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();
+
+    xla::XlaComputation computation =
+        ConsumeValue(builder.Build(/*remove_dynamic_dimensions=*/false));
+    xla::ProgramShape program_shape =
+        ConsumeValue(computation.GetProgramShape());
+
+    std::string device = GetDefaultDevice();
+    std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
+        instances;
+    instances.push_back({std::move(computation), device,
+                         GetCompilationDevices(device, {}), &shape,
+                         /*should_wrap_parameter=*/false,
+                         /*is_sharded=*/true,
+                         /*allow_spmd_sharding_propagation_to_output=*/false});
+    std::vector<
+        std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
+        computations = Compile(std::move(instances));
+
+    torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
+        execute_options;
+    auto sharded_results =
+        ExecuteReplicated(*computations.front(), {sharded_data},
+                          GetLocalDevices(), execute_options);
+    XLA_CHECK(sharded_results.size() > 0)
+        << "empty ExecuteReplicated results returned.";
+    XLA_CHECK(sharded_results.size() == 1)
+        << "Wrong number of outputs, expected: 1, actual: "
+        << sharded_results.size();
+    return std::dynamic_pointer_cast<PjRtShardedData>(sharded_results[0])
+        ->shards[0];
+  }
+
+  XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got "
+              << handle->ToString();
+}
+
+std::vector<ComputationClient::DataPtr> PjRtCompilationClient::ReshardData(
+    absl::Span<const ComputationClient::DataPtr> handles,
+    absl::Span<const xla::OpSharding> shardings) {
+  tsl::profiler::TraceMe activity("ReshardData",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  XLA_COUNTER("ReshardData", 1);
+  XLA_CHECK_EQ(handles.size(), shardings.size())
+      << "input handles and shardings must have the same length.";
+  XLA_CHECK(UseVirtualDevice())
+      << "Resharding is only supported in SPMD mode.";
+
+  // Perform a simple identity calculation to reshard.
+  xla::XlaBuilder builder("ReshardData");
+
+  std::vector<xla::Shape> shapes;
+  shapes.reserve(handles.size());
+  std::vector<xla::HloSharding> hlo_shardings;
+  hlo_shardings.reserve(handles.size());
+  std::vector<xla::XlaOp> param_ops;
+  param_ops.reserve(handles.size());
+  for (int i = 0; i < handles.size(); ++i) {
+    PjRtShardedData* sharded_data =
+        dynamic_cast<PjRtShardedData*>(handles[i].get());
+    XLA_CHECK_NE(sharded_data, nullptr)
+        << "Resharding requires PjRtShardedData on SPMD virtual device, "
+        << "current device: " << handles[i]->device();
+    shapes.push_back(sharded_data->shape());
+
+    const xla::OpSharding& sharding = shardings[i];
+    XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN)
+        << "Resharding by UNKNOWN sharding type is not allowed.";
+
+    hlo_shardings.push_back(
+        ConsumeValue(xla::HloSharding::FromProto(sharding)));
+
+    xla::OpSharding fallback_sharding;
+    fallback_sharding.set_type(xla::OpSharding::REPLICATED);
+    xla::XlaScopedShardingAssignment assign(
+        &builder,
+        sharded_data->GetSharding().type() == xla::OpSharding::UNKNOWN
+            ? fallback_sharding
+            : sharded_data->GetSharding());
+    param_ops.push_back(
+        xla::Parameter(&builder, i, shapes[i], absl::StrCat("p.", i)));
+  }
+
+  xla::XlaOp root;
+  {
+    xla::Shape shapes_tuple = xla::ShapeUtil::MakeTupleShape(shapes);
+    XLA_CHECK_EQ(shapes_tuple.tuple_shapes_size(), hlo_shardings.size());
+    xla::HloSharding new_shardings_tuple =
+        xla::HloSharding::Tuple(shapes_tuple, hlo_shardings);
+    xla::XlaScopedShardingAssignment assign(&builder,
+                                            new_shardings_tuple.ToProto());
+    root = xla::Tuple(&builder, param_ops);
+  }
+
+  xla::XlaComputation xla_computation = ConsumeValue(builder.Build(root));
+  xla::ProgramShape program_shape =
+      ConsumeValue(xla_computation.GetProgramShape());
+
+  std::string device = GetDefaultDevice();
+  std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
+      instances;
+  instances.push_back({std::move(xla_computation), device,
+                       GetCompilationDevices(device, {}),
+                       &program_shape.result(),
+                       /*should_wrap_parameter=*/false,
+                       /*is_sharded=*/true,
+                       /*allow_spmd_sharding_propagation_to_output=*/false});
+  std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>
+      computation = Compile(std::move(instances)).front();
+
+  torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
+      execute_options;
+  auto resharded_results = ExecuteReplicated(
+      *computation, handles, GetLocalDevices(), execute_options);
+  return resharded_results;
+}
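+
+// The buffer-access entry points below are stubs: AOT data handles wrap
+// shape-only host buffers, so there is no device memory to expose or read
+// back.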
+std::uintptr_t PjRtCompilationClient::UnsafeBufferPointer(
+    const DataPtr handle) {
+  TF_VLOG(3) << "UnsafeBufferPointer is not implemented for the"
+                " compilation-only client";
+  return 0;
+}
+
+std::shared_ptr<xla::PjRtBuffer> PjRtCompilationClient::GetPjRtBuffer(
+    const DataPtr handle) {
+  TF_LOG(ERROR) << "AOT compilation is unable to get buffer data from device";
+  return std::shared_ptr<xla::PjRtBuffer>(nullptr);
+}
+
+std::vector<xla::Literal> PjRtCompilationClient::TransferFromDevice(
+    absl::Span<const DataPtr> handles) {
+  TF_LOG(ERROR) << "AOT compilation is unable to run computation and transfer "
+                   "data from device";
+  std::vector<xla::Literal> literals;
+  return literals;
+}
+
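+// Compilation is the core of the AOT client: it runs against the virtual
+// topology alone, with no PjRtClient or physical device attached, and
+// produces an unloaded executable that can be serialized for later loading
+// on real hardware.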
+std::vector<ComputationClient::ComputationPtr> PjRtCompilationClient::Compile(
+    std::vector<ComputationClient::CompileInstance> instances) {
+  metrics::TimedSection timed(CompileMetric());
+  tsl::profiler::TraceMe activity("PjRtCompilationClient::Compile",
+                                  tsl::profiler::TraceMeLevel::kInfo);
+  std::vector<ComputationClient::ComputationPtr> computations;
+
+  for (auto& instance : instances) {
+    xla::CompileOptions compile_options;
+    if (instance.is_sharded) {
+      // TODO(yeounoh) multi-host, multi-slice configurations
+      compile_options.executable_build_options.set_use_spmd_partitioning(true);
+
+      // We can override the compiler's default behavior to replicate the
+      // outputs. Setting this to true wraps the sharded outputs in
+      // PjRtShardedData.
+      compile_options.executable_build_options
+          .set_allow_spmd_sharding_propagation_to_output(
+              {instance.allow_spmd_sharding_propagation_to_output});
+
+      int num_partitions = client_device_count;
+      compile_options.executable_build_options.set_num_partitions(
+          num_partitions);
+      compile_options.executable_build_options.set_num_replicas(1);
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+      compile_options.executable_build_options.set_use_auto_spmd_partitioning(
+          instance.use_auto_spmd_partitioning);
+      TF_VLOG(3) << "Auto SPMD partitioning "
+                 << (instance.use_auto_spmd_partitioning ? "enabled!"
+                                                         : "disabled.");
+      if (!instance.auto_spmd_mesh_shape.empty()) {
+        compile_options.executable_build_options
+            .set_auto_spmd_partitioning_mesh_shape(
+                instance.auto_spmd_mesh_shape);
+        TF_VLOG(3) << "auto_spmd_partitioning_mesh_shape="
+                   << absl::StrJoin(compile_options.executable_build_options
+                                        .auto_spmd_partitioning_mesh_shape(),
+                                    ",");
+      }
+      if (!instance.auto_spmd_mesh_ids.empty()) {
+        compile_options.executable_build_options
+            .set_auto_spmd_partitioning_mesh_ids(instance.auto_spmd_mesh_ids);
+        TF_VLOG(3) << "auto_spmd_partitioning_mesh_ids="
+                   << absl::StrJoin(compile_options.executable_build_options
+                                        .auto_spmd_partitioning_mesh_ids(),
+                                    ",");
+      }
+
+      // TODO(244391366) verify this is correct for the collectives ops
+      xla::DeviceAssignment device_assignment(1, client_device_count);
+      // DeviceAssignment values must be the PjRtDevice ID, so we need to
+      // unwind the global ordinal mapping.
+      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+        device_assignment(0, global_ordinal) = device_id;
+      }
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    } else {
+      // TODO(wcromar): set compile_options.argument_layouts, enable strict
+      // shapes
+      compile_options.executable_build_options.set_num_partitions(1);
+      compile_options.executable_build_options.set_num_replicas(
+          client_device_count);
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+
+      xla::DeviceAssignment device_assignment(client_device_count, 1);
+      // DeviceAssignment values must be the PjRtDevice ID, so we need to
+      // unwind the global ordinal mapping.
+      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+        device_assignment(global_ordinal, 0) = device_id;
+      }
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    }
+
+    // Compile against the virtual topology directly. Do not move
+    // virtual_topology out of the member: a second Compile call would then
+    // see a null topology.
+    std::unique_ptr<xla::PjRtExecutable> executable =
+        ConsumeValue(PjRtCompile(compile_options, instance.computation,
+                                 *this->virtual_topology));
+    const auto& hlo_modules = ConsumeValue(executable->GetHloModules());
+    xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
+    std::shared_ptr<PjRtUnloadedComputation> pjrt_computation =
+        std::make_shared<PjRtUnloadedComputation>(
+            std::move(xla::XlaComputation(hlo_modules[0]->ToProto())),
+            instance.devices, std::move(executable));
+    computations.push_back(pjrt_computation);
+    CreateCompileHandlesCounter()->AddValue(1);
+  }
+
+  return computations;
+}
+
+std::string PjRtCompilationClient::SerializeComputation(
+    const ComputationPtr computation) {
+  // AOT uses PjRtUnloadedComputation, which doesn't need a client
+  const PjRtUnloadedComputation& pjrt_computation =
+      dynamic_cast<const PjRtUnloadedComputation&>(*computation);
+  return ConsumeValue(pjrt_computation.executable->SerializeExecutable());
+}
+
+ComputationClient::ComputationPtr
+PjRtCompilationClient::DeserializeComputation(const std::string& serialized) {
+  TF_LOG(ERROR) << __FUNCTION__ << " is not defined for AOT compilation";
+  return nullptr;
+}
+
+torch::lazy::hash_t PjRtCompilationClient::HashCompilationEnv() {
+  // TODO(jonbolin): Incorporate CompileOptions into the hash. These are
+  // deterministically generated at the moment, so they don't need to be
+  // included. It will require a small refactor, so punting on this for now.
+  return comp_env_hash_;
+}
+
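+// The execution entry points below are stubs: the AOT client compiles and
+// serializes programs but never runs them.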
+std::vector<ComputationClient::DataPtr>
+PjRtCompilationClient::ExecuteComputation(
+    const ComputationClient::Computation& computation,
+    absl::Span<const ComputationClient::DataPtr> arguments,
+    const std::string& device, const ExecuteComputationOptions& options) {
+  TF_LOG(ERROR) << __FUNCTION__ << " is not supported for AOT compilation";
+  return std::vector<ComputationClient::DataPtr>();
+}
+
+std::vector<ComputationClient::DataPtr>
+PjRtCompilationClient::ExecuteReplicated(
+    const ComputationClient::Computation& computation,
+    absl::Span<const ComputationClient::DataPtr> arguments,
+    absl::Span<const std::string> devices,
+    const ExecuteReplicatedOptions& options) {
+  TF_LOG(ERROR) << __FUNCTION__ << " is not supported for AOT compilation";
+  std::vector<ComputationClient::DataPtr> data_handles;
+  return data_handles;
+}
+
+size_t PjRtCompilationClient::GetNumDevices() const {
+  return this->client_addressable_device_count;
+}
+
+std::string PjRtCompilationClient::GetDefaultDevice() const {
+  return this->client_addressable_devices[0];
+}
+
+std::vector<std::string> PjRtCompilationClient::GetLocalDevices() const {
+  return this->client_addressable_devices;
+}
+
+std::vector<std::string> PjRtCompilationClient::GetAllDevices() const {
+  return this->client_devices;
+}
+
+int PjRtCompilationClient::GetNumProcesses() const {
+  TF_LOG(ERROR) << __FUNCTION__ << " is not defined for AOT compilation";
+  return 1;
+};
+
+const absl::flat_hash_map<
+    std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
+PjRtCompilationClient::GetDeviceAttributes(const std::string& device) {
+  return PjRtCompilationClient::StringToPjRtDevice(device)->Attributes();
+}
+
+void PjRtCompilationClient::SetReplicationDevices(
+    std::shared_ptr<std::vector<std::string>> devices) {
+  replication_devices_ = std::move(devices);
+}
+
+std::shared_ptr<std::vector<std::string>>
+PjRtCompilationClient::GetReplicationDevices() {
+  return replication_devices_;
+}
+
+xla::PjRtDevice* PjRtCompilationClient::StringToPjRtDevice(
+    const std::string& device) {
+  XLA_CHECK(string_to_device_.find(device) != string_to_device_.end())
+      << "Unknown device " << device;
+  xla::PjRtDevice* pjrt_device = string_to_device_[device];
+  return pjrt_device;
+}
+
+void PjRtCompilationClient::WaitDeviceOps(
+    absl::Span<const std::string> devices) {
+  TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
+  operation_manager_.WaitForDevices(
+      devices.empty()
+          ? GetLocalDevices()
+          : std::vector<std::string>(devices.begin(), devices.end()));
+}
+
+std::map<std::string, Metrics> PjRtCompilationClient::GetMetrics() const {
+  // TODO(jonbolin): Add any PJRT-client-specific metrics here
+  return {};
+}
+
+ComputationClient::MemoryInfo PjRtCompilationClient::GetMemoryInfo(
+    const std::string& device) {
+  XLA_CHECK_NE(device, spmd_device_str)
+      << "MemoryInfo not supported for SPMD virtual device.";
+  xla::PjRtDevice* pjrt_device =
+      PjRtCompilationClient::StringToPjRtDevice(device);
+  tsl::AllocatorStats stats = pjrt_device->GetAllocatorStats().value();
+
+  return {
+      stats.bytes_in_use,
+      *stats.bytes_limit,
+  };
+}
+
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_compilation_client.h b/torch_xla/csrc/runtime/pjrt_compilation_client.h
new file mode 100644
index 00000000000..7acafe686e7
--- /dev/null
+++ b/torch_xla/csrc/runtime/pjrt_compilation_client.h
@@ -0,0 +1,297 @@
+#ifndef XLA_CLIENT_PJRT_COMPILATION_CLIENT_H_
+#define XLA_CLIENT_PJRT_COMPILATION_CLIENT_H_
+
+#include <cstdint>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/operation_manager.h"
+#include "torch_xla/csrc/runtime/util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/threadpool.h"
+#include "xla/client/xla_computation.h"
+#include "xla/literal.h"
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/shape.h"
+
+namespace torch_xla {
+namespace runtime {
+
+// A host-side stand-in for a device buffer: AOT compilation never allocates
+// device memory, so only the shape is tracked.
+struct Buffer {
+  xla::Shape shape;
+  Buffer(xla::Shape shape) : shape(shape) {}
+};
+
+class PjRtCompilationClient : public ComputationClient {
+ public:
+  PjRtCompilationClient(std::string& virtual_topology_str);
+  ~PjRtCompilationClient();
+
+  DataPtr CreateDataPlaceholder(
+      std::string device, xla::Shape shape,
+      std::optional<xla::OpSharding> sharding = std::nullopt) override;
+
+  static DataPtr CreateData(std::string device, xla::Shape shape,
+                            std::shared_ptr<Buffer> buffer);
+
+  std::vector<DataPtr> GetDataShards(DataPtr data) override;
+
+  DataPtr GetDataShard(DataPtr data, size_t index) override;
+
+  DataPtr WrapDataShards(absl::Span<const DataPtr> shards, std::string device,
+                         xla::Shape shape, xla::OpSharding sharding) override;
+
+  std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;
+
+  std::vector<DataPtr> TransferToDevice(
+      absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
+
+  // Reshard and return data sharded by `sharding` spec. This is a no-op if
+  // the input sharding spec is identical to the target `sharding` spec.
+  // TODO(yeounoh) replace ReplicateShardedData with this.
+  std::vector<DataPtr> ReshardData(
+      absl::Span<const DataPtr> handles,
+      absl::Span<const xla::OpSharding> shardings) override;
+
+  std::vector<xla::Literal> TransferFromDevice(
+      absl::Span<const DataPtr> handles) override;
+
+  std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;
+
+  std::shared_ptr<xla::PjRtBuffer> GetPjRtBuffer(
+      const DataPtr handle) override;
+
+  DataPtr TransferShardsToDevice(
+      absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
+      std::string device, xla::Shape shape, xla::OpSharding sharding) override;
+
+  DataPtr CopyToDevice(DataPtr data, std::string dst) override;
+
+  std::vector<ComputationPtr> Compile(
+      std::vector<CompileInstance> instances) override;
+
+  std::string SerializeComputation(const ComputationPtr computation) override;
+
+  ComputationPtr DeserializeComputation(
+      const std::string& serialized) override;
+
+  std::vector<DataPtr> ExecuteComputation(
+      const Computation& computation, absl::Span<const DataPtr> arguments,
+      const std::string& device,
+      const ExecuteComputationOptions& options) override;
+
+  std::vector<DataPtr> ExecuteReplicated(
+      const Computation& computation, absl::Span<const DataPtr> arguments,
+      absl::Span<const std::string> devices,
+      const ExecuteReplicatedOptions& options) override;
+
+  size_t GetNumDevices() const override;
+
+  std::string GetDefaultDevice() const override;
+
+  torch_xla::DeviceType GetDeviceType() const override {
+    return torch_xla::DeviceType(absl::AsciiStrToUpper(this->platform_name));
+  };
+
+  xla::PjRtPlatformId GetPlatformID() const override {
+    return this->platform_id;
+  }
+
+  absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
+      int local_device_id) const override {
+    return client_->LookupAddressableDevice(
+        xla::PjRtLocalDeviceId(local_device_id));
+  }
+
+  std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
+    return 0;
+  }
+
+  std::vector<std::string> GetLocalDevices() const override;
+
+  std::vector<std::string> GetAllDevices() const override;
+
+  torch::lazy::hash_t HashCompilationEnv() override;
+
+  int GetProcessIndex() const override { return client_->process_index(); };
+
+  int GetNumProcesses() const override;
+
+  const absl::flat_hash_map<
+      std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
+  GetDeviceAttributes(const std::string& device) override;
+
+  void SetReplicationDevices(
+      std::shared_ptr<std::vector<std::string>> devices) override;
+
+  std::shared_ptr<std::vector<std::string>> GetReplicationDevices() override;
+
+  void WaitDeviceOps(absl::Span<const std::string> devices) override;
+
+  std::map<std::string, Metrics> GetMetrics() const override;
+
+  void InitializeCoordinator(int global_rank, int world_size,
+                             std::string master_addr,
+                             std::string port) override;
+
+  XlaCoordinator& GetCoordinator() override;
+
+  bool CoordinatorInitialized() const override;
+
+  MemoryInfo GetMemoryInfo(const std::string& device) override;
+
+  std::string PjRtDeviceToString(xla::PjRtDevice* const device) const override;
+  std::vector<std::string> PjRtDevicesToString(
+      absl::Span<xla::PjRtDevice* const> devices) const;
+
+ private:
+  std::unique_ptr<xla::PjRtClient> client_;
+  std::unique_ptr<XlaCoordinator> coordinator_;
+
+  // Fake attributes for AOT compilation; no physical device backs them.
+  std::string platform_name;
+  xla::PjRtPlatformId platform_id;
+  std::vector<std::string> client_addressable_devices;
+  size_t client_addressable_device_count;
+  std::vector<std::string> client_devices;
+  size_t client_device_count;
+  std::unique_ptr<xla::PjRtTopologyDescription> virtual_topology;
+
+  // global_ordinals_ tracks a map from PjRtDeviceId to the device's
+  // dense global ordinal.
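+  // For example, sparse PjRt device IDs {0, 2, 5} would map to dense
+  // ordinals {0, 1, 2}.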
+  std::unordered_map<int, int> global_ordinals_;
+  std::unordered_map<std::string, xla::PjRtDevice* const> string_to_device_;
+  std::shared_ptr<std::vector<std::string>> replication_devices_;
+  OperationManager operation_manager_;
+  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
+      tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
+  torch::lazy::hash_t comp_env_hash_;
+
+  xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
+
+  struct PjRtData : public Data {
+    PjRtData(std::string device, xla::Shape device_shape)
+        : Data(std::move(device), std::move(device_shape)) {}
+
+    PjRtData(std::string device, xla::Shape device_shape,
+             std::shared_ptr<Buffer> buffer)
+        : Data(std::move(device), std::move(device_shape)), buffer(buffer) {}
+
+    Handle GetHandle() override {
+      XLA_CHECK(HasValue())
+          << "buffer with shape " << shape().ToString() << " on device "
+          << device() << (buffer == nullptr ? " is null" : " is deleted");
+      return reinterpret_cast<std::uintptr_t>(buffer.get());
+    };
+    void Assign(const torch::lazy::BackendData& data) override;
+    bool HasValue() const override { return buffer != nullptr; };
+
+    bool HasSharding() const override { return false; }
+
+    xla::OpSharding GetSharding() const override {
+      XLA_CHECK(false) << "GetSharding should not be called on PjRtData, check "
+                          "HasSharding first";
+      return xla::OpSharding();
+    }
+
+    std::string ToString() const override {
+      std::stringstream ss;
+      ss << "XLAData: \n";
+      ss << "  Data Device: " << device() << "\n";
+      ss << "  Data Shape: " << shape().ToString() << "\n";
+      ss << "  Data Handle: ";
+      if (HasValue()) {
+        ss << reinterpret_cast<std::uintptr_t>(buffer.get()) << "\n";
+      } else {
+        ss << (buffer == nullptr ? "None" : "Deleted") << "\n";
+      }
+      return ss.str();
+    }
+
+    std::shared_ptr<Buffer> buffer;
+  };
+
+  struct PjRtShardedData : public Data {
+    PjRtShardedData(std::string device, xla::Shape shape) = delete;
+
+    PjRtShardedData(std::string device, xla::Shape shape,
+                    xla::OpSharding sharding)
+        : Data(std::move(device), std::move(shape)), sharding(sharding) {}
+
+    PjRtShardedData(std::string device, xla::Shape shape,
+                    std::vector<std::shared_ptr<PjRtData>> shards,
+                    xla::OpSharding sharding)
+        : Data(std::move(device), std::move(shape)),
+          shards(shards),
+          sharding(sharding) {}
+
+    Handle GetHandle() override {
+      // Always returns `Handle` of the first shard.
+      return shards[0]->GetHandle();
+    }
+
+    void Assign(const torch::lazy::BackendData& data) override {
+      const PjRtShardedData& pjrt_sharded_data =
+          dynamic_cast<const PjRtShardedData&>(data);
+      if (&pjrt_sharded_data != this) {
+        shards = std::move(pjrt_sharded_data.shards);
+      }
+    }
+
+    bool HasValue() const override {
+      if (shards.empty()) {
+        return false;
+      }
+
+      for (auto& shard : shards) {
+        if (!shard->HasValue()) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    std::string ToString() const override {
+      std::stringstream ss;
+      ss << "XLAShardedData: \n";
+      ss << "  Data Device: " << device() << "\n";
+      ss << "  Data Shape: " << shape().ToString() << "\n";
+      ss << "  OpSharding: "
+         << xla::HloSharding::FromProto(sharding)->ToString() << "\n";
+      ss << "  NumShards: " << shards.size() << "\n";
+      return ss.str();
+    }
+
+    bool HasSharding() const override { return true; }
+
+    xla::OpSharding GetSharding() const override { return sharding; }
+
+    std::vector<std::shared_ptr<PjRtData>> shards;
+    xla::OpSharding sharding;
+  };
+
+  struct PjRtUnloadedComputation : public Computation {
+    PjRtUnloadedComputation(xla::XlaComputation computation,
+                            std::vector<std::string> devices,
+                            std::unique_ptr<xla::PjRtExecutable> executable)
+        : Computation(std::move(computation), std::move(devices)),
+          executable(std::move(executable)) {
+      output_shardings_ = this->executable->GetOutputShardings();
+    }
+
+    std::unique_ptr<xla::PjRtExecutable> executable;
+    std::optional<std::vector<xla::OpSharding>> output_shardings_;
+  };
+
+  // Use XLA replication to re-assemble the sharded data.
+  std::shared_ptr<PjRtData> ReplicateShardedData(const DataPtr& handle);
+};
+
+}  // namespace runtime
+}  // namespace torch_xla
+#endif  // XLA_CLIENT_PJRT_COMPILATION_CLIENT_H_
diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc
index feb2a0844c6..95e1c2abae0 100644
--- a/torch_xla/csrc/runtime/runtime.cc
+++ b/torch_xla/csrc/runtime/runtime.cc
@@ -4,6 +4,7 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/ifrt_computation_client.h"
+#include "torch_xla/csrc/runtime/pjrt_compilation_client.h"
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "tsl/platform/stacktrace_handler.h"
 
@@ -12,6 +13,8 @@ namespace runtime {
 
 std::atomic<bool> g_computation_client_initialized(false);
 
+std::string aot_topology = "";
+
 ComputationClient* GetComputationClient() {
   static std::unique_ptr<ComputationClient> client = []() {
     if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
@@ -22,7 +25,10 @@ ComputationClient* GetComputationClient() {
     static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
     if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
-      if (use_ifrt) {
+      if (aot_topology != "") {
+        // AOT mode: the compilation client has no underlying PjRt client.
+        client = std::make_unique<PjRtCompilationClient>(aot_topology);
+      } else if (use_ifrt) {
         client = std::make_unique<IfrtComputationClient>();
       } else {
         client = std::make_unique<PjRtComputationClient>();
@@ -40,6 +46,13 @@ ComputationClient* GetComputationClient() {
   return client.get();
 }
 
+void SetVirtualTopology(const std::string& topology) {
+  XLA_CHECK(!g_computation_client_initialized)
+      << "AOT topology must be specified before computation client "
+         "initialization.";
+  aot_topology = topology;
+}
+
 ComputationClient* GetComputationClientIfInitialized() {
   return g_computation_client_initialized ? GetComputationClient() : nullptr;
 }
diff --git a/torch_xla/csrc/runtime/runtime.h b/torch_xla/csrc/runtime/runtime.h
index 64c31be86d7..02fa114acba 100644
--- a/torch_xla/csrc/runtime/runtime.h
+++ b/torch_xla/csrc/runtime/runtime.h
@@ -11,6 +11,8 @@ ComputationClient* GetComputationClient();
 
 ComputationClient* GetComputationClientIfInitialized();
 
+void SetVirtualTopology(const std::string& topology);
+
 // Run the XRT local service, this will block the caller until the server
 // is stopped.
 void RunLocalService(uint64_t service_port);
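
Illustrative usage (not part of the diff): the intended AOT flow is to register
a virtual topology before the computation client is first created, then compile
and serialize without real hardware attached. The "tpu:2x2x1" topology string
and the elided CompileInstance setup are assumptions based on the constructor
parsing and Compile() above.

    // Sketch: compile ahead of time against a 2x2x1 virtual TPU topology.
    torch_xla::runtime::SetVirtualTopology("tpu:2x2x1");  // before first use
    auto* client = torch_xla::runtime::GetComputationClient();
    std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
        instances;
    // ... populate `instances` with the XlaComputation(s) to compile ...
    auto computations = client->Compile(std::move(instances));
    std::string blob = client->SerializeComputation(computations.front());
    // `blob` can later be deserialized and loaded by a regular PjRt client.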