diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 675eb91b58a..ed4b9173fa5 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -209,18 +209,6 @@ only be enabled for debugging. * ```XLA_SAVE_HLO_FILE```: If set, the path to a local file where, in case of compilation/execution error, the offending HLO graph will be saved. -* ```XLA_GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/XLA_ software tries to - fuse together many _PyTorch_ operations into a single computation graph, but sometimes, either - for debugging, or in case the _PyTorch_ code have a very dynamic nature (in shapes or graph - terms), it is better to force the execution in _OpByOp_ mode (every IR node is lowered into - a separate _XLA_ computation, and chain-executed). This environment variable, if set to 1, - enables _OpByOp_ during the "get tensors" operation (the operation used by _PyTorch/XLA_ to - fetch intermediate values back from the _TPU_ device into _PyTorch_ CPU tensors). - -* ```XLA_SYNC_TENSORS_OPBYOP```: The same as _XLA_GET_TENSORS_OPBYOP_ but for "sync tensors" - operation (the operation used at the end of a step, to flush pending IR computations and - materialize them into _TPU_ device data). - * ```XLA_SYNC_WAIT```: Forces the XLA tensor sync operation to wait for its completion, before moving to the next step. diff --git a/configuration.yaml b/configuration.yaml index d9b78766462..7d0a86e38aa 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -104,26 +104,6 @@ variables: type: bool default_value: false feature_variables: - XLA_GET_TENSORS_OPBYOP: - description: - - Enables pure OpByOp dispatch. The PyTorch/XLA software tries to fuse - together many PyTorch operations into a single computation graph, but - sometimes, either for debugging, or in case the PyTorch code have a - very dynamic nature (in shapes or graph terms), it is better to force - the execution in OpByOp mode (every IR node is lowered into a - separate XLA computation, and chain-executed). This environment - variable, if set to true, enables OpByOp during the "get tensors" - operation (the operation used by PyTorch/XLA to fetch intermediate - values back from the TPU device into PyTorch CPU tensors). - type: bool - default_value: false - XLA_SYNC_TENSORS_OPBYOP: - description: - - The same as XLA_GET_TENSORS_OPBYOP but for "sync tensors" operation - (the operation used at the end of a step, to flush pending IR - computations and materialize them into TPU device data). 
- type: bool - default_value: false XLA_SYNC_WAIT: description: - Forces the XLA tensor sync operation to wait for its completion, diff --git a/test/test_operations.py b/test/test_operations.py index 0083f39c839..6a1be16ede3 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1855,42 +1855,6 @@ def test(self): self.assertEqual(len(report), 0) -class TestAsyncScalar(test_utils.XlaTestCase): - - def test_rng_seed_transfer(self): - xla_device = xm.xla_device() - async_mode = xu.getenv_as('XLA_TRANSFER_SCALAR_ASYNC', bool, defval=False) - # mark_step to clear the rng seed - xm.mark_step() - - transfer_to_server_async_metric = met.metric_data("TransferToServerAsync") - async_transfer_count = 0 if transfer_to_server_async_metric == None else transfer_to_server_async_metric[ - 0] - t1 = torch.randn(3, 3, device=xla_device) - xm.mark_step() - if async_mode: - assert met.metric_data( - "TransferToServerAsync")[0] == async_transfer_count + 1 - else: - assert met.metric_data("TransferToServerAsync") == None - - def test_scalar_transfer(self): - xla_device = xm.xla_device() - async_mode = xu.getenv_as('XLA_TRANSFER_SCALAR_ASYNC', bool, defval=False) - - transfer_to_server_async_metric = met.metric_data("TransferToServerAsync") - async_transfer_count = 0 if transfer_to_server_async_metric == None else transfer_to_server_async_metric[ - 0] - t1 = torch.randn(3, 3).to(xla_device) - t2 = t1 / 0.5 - t3 = t2.cpu() - if async_mode: - assert met.metric_data( - "TransferToServerAsync")[0] == async_transfer_count + 1 - else: - assert met.metric_data("TransferToServerAsync") == None - - class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 7071c80bcc3..717a7760410 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -48,7 +48,6 @@ ptxla_cc_library( "matrix.cpp", "nll_loss.cpp", "nms_op.cpp", - "op_by_op_executor.cpp", "pooling.cpp", "random.cpp", "reduction.cpp", @@ -88,7 +87,6 @@ ptxla_cc_library( "matrix.h", "nll_loss.h", "nms_op.h", - "op_by_op_executor.h", "pooling.h", "random.h", "reduction.h", diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp deleted file mode 100644 index 0eba02a03da..00000000000 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include "torch_xla/csrc/op_by_op_executor.h" - -#include -#include - -#include -#include - -#include "absl/strings/str_cat.h" -#include "torch_xla/csrc/device.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/device_data.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/xla_util.h" -#include "torch_xla/csrc/tensor_util.h" -#include "torch_xla/csrc/torch_util.h" -#include "xla/client/xla_builder.h" - -namespace torch_xla { -namespace { - -absl::optional GetOutputIndex(bool is_device_data, size_t index) { - // The output of every result of an op-by-op computation is wrapped into a - // tuple, so we need to use the index to extract it. Device data instead is - // already unwrapped, so we need to pass an empty index so that TF/XRT code - // uses the result buffer directly. 
- if (is_device_data) { - return absl::nullopt; - } - return index; -} - -const xla::Shape& GetParameterShape(const torch::lazy::Output& operand, - const xla::Shape& input_shape) { - // See comment in GetOutputIndex() about device data WRT computation outpout - // shape handling. - const DeviceData* device_data = DeviceData::Cast(operand.node); - return device_data != nullptr - ? input_shape - : xla::ShapeUtil::GetTupleElementShape(input_shape, operand.index); -} - -torch::lazy::hash_t ComputeNodeKey( - const torch::lazy::Node* node, - absl::Span input_shapes, - const torch::lazy::hash_t& seed) { - torch::lazy::hash_t key = seed; - const auto& operands = node->operands(); - for (size_t i = 0; i < operands.size(); ++i) { - key = torch::lazy::HashCombine(key, torch::lazy::Hash(GetParameterShape( - operands[i], *input_shapes[i]))); - } - const XlaNode* casted = dynamic_cast(node); - key = torch::lazy::HashCombine(key, torch::lazy::Hash(casted->xla_shape())); - return torch::lazy::HashCombine(key, casted->node_hash()); -} - -xla::XlaComputation BuildNodeComputation( - const torch::lazy::Node* node, - absl::Span input_shapes, - const torch::lazy::BackendDevice& device) { - LoweringContext loctx("BuildNodeComputation", device); - const auto& operands = node->operands(); - for (size_t i = 0; i < operands.size(); ++i) { - xla::XlaOp param = xla::Parameter( - loctx.builder(), i, GetParameterShape(operands[i], *input_shapes[i]), - absl::StrCat("p", i)); - loctx.AssignOutputOp(operands[i], param); - } - for (auto& xla_op : loctx.LowerNode(node)) { - loctx.AddResult(xla_op); - } - return ConsumeValue(loctx.BuildXla()); -} - -torch::lazy::hash_t GetNodesKeySeed(const std::string& device, - absl::Span devices) { - return torch::lazy::MHash(device, torch::lazy::Hash(devices)); -} - -} // namespace - -OpByOpExecutor::OpByOpExecutor(size_t compile_cache_size) - : compile_cache_(compile_cache_size) {} - -std::vector -OpByOpExecutor::BuildOps(c10::ArrayRef roots, - const std::string& device, - absl::Span devices) { - std::vector root_nodes; - root_nodes.reserve(roots.size()); - for (auto& root : roots) { - root_nodes.push_back(root.node.get()); - } - auto post_order = torch::lazy::Util::ComputePostOrder(root_nodes); - TORCH_LAZY_VALUE_METRIC("OpByOpGraphSize", post_order.size()); - TF_VLOG(5) << "TensorsGraphSize=" << post_order.size(); - - std::unordered_map node_to_index; - node_to_index.reserve(post_order.size()); - for (size_t i = 0; i < post_order.size(); ++i) { - node_to_index[post_order[i]] = i; - } - - auto compilation_devices = - runtime::GetComputationClient()->GetCompilationDevices(device, devices); - torch::lazy::hash_t nodes_key_seed = - GetNodesKeySeed(device, compilation_devices); - torch::lazy::BackendDevice exec_device = ParseDeviceString(device); - std::vector cache_keys; - std::unordered_map, - torch::lazy::HashReducer> - compile_indices; - std::unordered_map - cache_keys_instance; - std::list compile_shapes; - std::vector device_data_ops(post_order.size()); - std::vector ops_shapes(post_order.size()); - std::vector compile_instances; - std::vector chained_exec_ops( - post_order.size()); - for (size_t i = 0; i < post_order.size(); ++i) { - const torch::lazy::Node* node = post_order[i]; - runtime::ComputationClient::ExecuteChainedOp& cxop = chained_exec_ops[i]; - const auto backend_data = - torch::lazy::getBackend()->GetComputationDataFromNode(node); - if (backend_data != nullptr) { - cxop.device_data = UnwrapXlaData(backend_data); - ops_shapes[i] = &cxop.device_data->shape(); - 
device_data_ops[i] = true; - } else { - std::vector op_input_shapes; - for (auto& operand : node->operands()) { - size_t op_index = node_to_index.at(operand.node); - cxop.inputs.push_back( - {op_index, - GetOutputIndex(device_data_ops[op_index], operand.index)}); - op_input_shapes.push_back(ops_shapes[op_index]); - } - - torch::lazy::hash_t cache_key = - ComputeNodeKey(node, op_input_shapes, nodes_key_seed); - cxop.computation = compile_cache_.Get(cache_key); - if (cxop.computation == nullptr) { - TORCH_LAZY_COUNTER("OpByOpCompileCacheMiss", 1); - - // Within a single IR graph, there can be many duplicated IR nodes, so - // make sure we do not issue an XLA compilation for each one of those. - auto& cache_key_indices = compile_indices[cache_key]; - cache_key_indices.push_back(i); - if (cache_key_indices.size() == 1) { - cache_keys.push_back(cache_key); - cache_keys_instance[cache_key] = compile_instances.size(); - - xla::XlaComputation computation = - BuildNodeComputation(node, op_input_shapes, exec_device); - xla::ProgramShape program_shape = - ConsumeValue(computation.GetProgramShape()); - compile_shapes.push_back(MakeShapeWithDeviceLayout( - program_shape.result(), - static_cast(exec_device.type()))); - compile_instances.push_back({std::move(computation), device, - compilation_devices, - &compile_shapes.back()}); - ops_shapes[i] = &compile_shapes.back(); - } else { - ops_shapes[i] = - compile_instances[cache_keys_instance.at(cache_key)].output_shape; - } - } else { - ops_shapes[i] = &cxop.computation->program_shape().result(); - } - } - } - // Fixup the requested outputs (roots) within the chained ops vector. - for (size_t i = 0; i < roots.size(); ++i) { - size_t op_index = node_to_index.at(roots[i].node.get()); - chained_exec_ops[op_index].outputs.push_back( - {i, GetOutputIndex(device_data_ops[op_index], roots[i].index)}); - } - - // If we missed the cache for certain ops, compile them now and fixup the - // chained ops vector. 
- if (!compile_instances.empty()) { - TF_VLOG(3) << "Compiling " << compile_instances.size() - << " computations on device " << device; - auto computation_ptrs = - runtime::GetComputationClient()->Compile(std::move(compile_instances)); - TF_VLOG(3) << "Compiling " << computation_ptrs.size() - << " computations on device " << device << " done!"; - for (size_t i = 0; i < computation_ptrs.size(); ++i) { - compile_cache_.Add(cache_keys[i], computation_ptrs[i]); - for (auto index : compile_indices[cache_keys[i]]) { - chained_exec_ops[index].computation = computation_ptrs[i]; - } - } - } - return chained_exec_ops; -} - -std::vector OpByOpExecutor::Execute( - c10::ArrayRef roots, const std::string& device, - absl::Span devices) { - auto chained_exec_ops = BuildOps(roots, device, devices); - return WrapXlaData(runtime::GetComputationClient()->ExecuteChained( - chained_exec_ops, device)); -} - -OpByOpExecutor::AsyncTask OpByOpExecutor::ExecuteAsync( - c10::ArrayRef roots, const std::string& device, - absl::Span devices) { - std::vector roots_vector(roots.begin(), roots.end()); - std::vector devices_vector(devices.begin(), devices.end()); - auto taskfn = [this, roots = std::move(roots_vector), - devices = std::move(devices_vector), device]() -> AsyncResult { - return Execute(roots, device, devices); - }; - - AsyncTask async = AsyncTask(std::move(taskfn)); - return async.Schedule(); -} - -OpByOpExecutor* OpByOpExecutor::Get() { - static const int64_t compile_cache_size = - runtime::sys_util::GetEnvInt("SPLIT_EXECUTOR_CACHE_SIZE", 2048); - static OpByOpExecutor* split_executor = - new OpByOpExecutor(compile_cache_size); - return split_executor; -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/op_by_op_executor.h b/torch_xla/csrc/op_by_op_executor.h deleted file mode 100644 index 90afc6da511..00000000000 --- a/torch_xla/csrc/op_by_op_executor.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_OP_BY_OP_EXECUTOR_H_ -#define XLA_TORCH_XLA_CSRC_OP_BY_OP_EXECUTOR_H_ - -#include -#include - -#include "absl/types/span.h" -#include "torch_xla/csrc/ir.h" -#include "torch_xla/csrc/runtime/async_task.h" -#include "torch_xla/csrc/runtime/cache.h" -#include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/util.h" -#include "xla/types.h" - -namespace torch_xla { - -// The OpByOpExecutor class is a singleton accessible via its Get() API that -// allows to run an IR graph is per-IR-node isolation mode. Instead of lowering -// the whole IR graph in a single XLA computation, the single IR nodes are -// lowered and executed independently. 
-class OpByOpExecutor { - public: - using AsyncResult = std::vector; - using AsyncTask = runtime::util::AsyncTask; - - static OpByOpExecutor* Get(); - - std::vector BuildOps( - c10::ArrayRef roots, const std::string& device, - absl::Span devices); - - std::vector Execute( - c10::ArrayRef roots, const std::string& device, - absl::Span devices); - - AsyncTask ExecuteAsync(c10::ArrayRef roots, - const std::string& device, - absl::Span devices); - - private: - using CompileCache = - runtime::util::Cache; - - explicit OpByOpExecutor(size_t compile_cache_size); - - CompileCache compile_cache_; -}; - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_OP_BY_OP_EXECUTOR_H_ diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 74f9d73659e..4578df715e8 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -240,35 +240,6 @@ class ComputationClient { struct ExecuteReplicatedOptions : public ClientExecuteOptions {}; - struct ExecuteParallelOptions : public ClientExecuteOptions {}; - - // Describes an operation to be fed to the ExecuteChained() API. - // If the device_data member is not nullptr, this operation is a device data - // input. Otherwise computation must not be nullptr, and represents the - // computation to be executed. The indices of the inputs to the computation, - // are coming from the inputs member. Since the operations fed to - // ExecuteChained() are a valid post-order, the op_index indices listed within - // the inputs member must be lower of the index of the current - // ExecuteChainedOp within the post-order. If the outputs member has values, - // the result of this ExecuteChainedOp will become an output of the - // ExecuteChained() API, with the output_index output of this ExecuteChainedOp - // feeding the result_index result. - struct ExecuteChainedOp { - struct Input { - size_t op_index; - absl::optional output_index; - }; - struct Output { - size_t result_index; - absl::optional output_index; - }; - - DataPtr device_data; - ComputationPtr computation; - std::vector outputs; - std::vector inputs; - }; - struct MemoryInfo { int64_t kb_free = 0; int64_t kb_total = 0; @@ -281,15 +252,6 @@ class ComputationClient { virtual DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) = 0; - // Create DataPtr that only has dummy information which can be filled in - // later. - virtual std::vector CreateAsyncDatas( - absl::Span tensors) = 0; - - // Lock the DataPtr - virtual std::vector - LockAsyncDatas(absl::Span datas) = 0; - // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input // wrapped inside a vector. @@ -311,12 +273,6 @@ class ComputationClient { virtual std::vector TransferToServer( absl::Span tensors) = 0; - // Transfers local tensor values to the TPU devices and fetches the handles. - // Update the handles associated with DataPtrs passed instead of creating new - // datas. - virtual void TransferToServer(absl::Span tensors, - absl::Span datas) = 0; - // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToServer( @@ -362,36 +318,6 @@ class ComputationClient { absl::Span devices, const ExecuteReplicatedOptions& options) = 0; - // Executes the computations in parallel. Each computation must target a - // different device, and the the common device of arguments[i] must match - // devices[i]. 
The computations[i] computation is fed with arguments[i] - // arguments. - // Returns a vector of vectors of device side Data object, with result[i] - // being the return value of computations[i]. If options.explode_tuple is - // true, the output tuples will be decomposed into their single elements. - virtual std::vector> ExecuteParallel( - absl::Span computations, - const std::vector>& arguments, - absl::Span devices, - const ExecuteParallelOptions& options) = 0; - - // Executes a serie of operations, whose results are input of other - // operations. The ops is a valid post-order for the execution, which means - // that the inputs of op at index I, will have to be coming from ops at index - // lower than I. It returns a vector of device data shared pointers, one for - // every ExecuteChainedOp marked with is_result=true, in the order they appear - // within the ops post-order. - virtual std::vector ExecuteChained( - absl::Span ops, const std::string& device) = 0; - - virtual std::vector> DeconstructTuple( - absl::Span tuples) = 0; - - // Returns a unique string which identifies the resource domain of a given - // device. Within a resource domain, handles to device memory or compiled - // computations can be used for all devices part of such domain. - virtual std::string GetResourceDomain(const std::string& device) const = 0; - virtual std::string GetDefaultDevice() const = 0; virtual size_t GetNumDevices() const = 0; @@ -416,8 +342,6 @@ class ComputationClient { virtual std::shared_ptr> GetReplicationDevices() = 0; - virtual void SetRngSeed(size_t seed) = 0; - virtual std::map GetMetrics() const = 0; virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index d594b73202e..9a729793bf0 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -87,52 +87,10 @@ class PjRtComputationClient : public ComputationClient { void WaitDeviceOps(const std::vector& devices) override; - // NOT IMPLEMENTED - - void TransferToServer(absl::Span tensors, - absl::Span datas) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::vector CreateAsyncDatas( - absl::Span tensors) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::vector LockAsyncDatas( - absl::Span datas) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::vector> DeconstructTuple( - absl::Span tuples) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::vector> ExecuteParallel( - absl::Span computations, - const std::vector>& arguments, - absl::Span devices, - const ExecuteParallelOptions& options) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::vector ExecuteChained(absl::Span ops, - const std::string& device) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - - std::string GetResourceDomain(const std::string& device) const override { - // TODO(wcromar): return a meaningful value - return "getresourcedomainplaceholder"; - }; - - void SetRngSeed(size_t seed) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - std::map GetMetrics() const override; + // NOT IMPLEMENTED + MemoryInfo GetMemoryInfo(const std::string& device) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index e8645af7228..23ce2df136b 100644 --- 
a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -22,7 +22,6 @@ #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" -#include "torch_xla/csrc/op_by_op_executor.h" #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" #include "torch_xla/csrc/ops/cast.h" #include "torch_xla/csrc/ops/device_data.h" diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 12f703e32d1..49ef42d60ba 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -569,41 +569,6 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape, } } -void TransferToServerAsync(std::shared_ptr async, - const std::vector& devices) { - TORCH_LAZY_TIMED("TransferToServerAsync"); - - std::vector async_xla_datas = - runtime::GetComputationClient()->CreateAsyncDatas(async->source_tensors); - async->handle_unlockers = - runtime::GetComputationClient()->LockAsyncDatas(async_xla_datas); - async->async_datas = WrapXlaData(async_xla_datas); - auto mwait = std::make_shared(/*num_wait=*/1); - auto update_data = [async, async_xla_datas]() { - try { - runtime::GetComputationClient()->TransferToServer(async->source_tensors, - async_xla_datas); - } catch (...) { - // There are two paths of discovery of an exception happening on an - // asynchronous task. One happens if the creator of the asynchronous task - // explicitly waits for completion, in which case the exception will be - // thrown from the Wait() API. Re-throwing the exception below makes sure - // this will be captured by the completer function created below, and - // surfaced by the Wait() API. But we also need to surface the exception - // even in case the caller does not wait, and that is accomplished by - // setting the unlockers status. In that case the exception will be - // surfaced when the user tries to acquire the device locks the next time. - std::exception_ptr exptr = std::current_exception(); - for (auto& unlocker : async->handle_unlockers) { - unlocker.SetStatus(exptr); - } - throw; - } - }; - runtime::env::ScheduleIoClosure( - runtime::util::MultiWait::Completer(mwait, std::move(update_data))); -} - torch::lazy::BackendDataPtr TensorToXlaData( const at::Tensor& tensor, const xla::Shape& shape, const torch::lazy::BackendDevice& device) { @@ -621,46 +586,20 @@ torch::lazy::BackendDataPtr TensorToXlaData( replicated_data, local_devices, sharding_spec)); } - static const bool transfer_async = - runtime::sys_util::GetEnvBool("XLA_TRANSFER_SCALAR_ASYNC", false); - if (transfer_async && tensor.dim() == 0 && tensor.numel() == 1) { - std::shared_ptr async = std::make_shared(); - auto populate_mwait = - std::make_shared(/*num_wait=*/1); - auto populate_fn = - [&](const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, - dest_buffer_size, device); - populate_mwait->Done(); - }; - - async->source_tensors.emplace_back(shape, device.toString(), - std::move(populate_fn)); - TransferToServerAsync(async, {device.toString()}); - XLA_CHECK_EQ(async->async_datas.size(), 1); - // Tensor is a reference and can be inplace updated between this function - // returned and populate_fn being called. Need to wait for populate_fn to be - // called. 
- populate_mwait->Wait(); - return async->async_datas.front(); - } else { - auto populate_fn = - [&](const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; + auto populate_fn = + [&](const runtime::ComputationClient::TensorSource& source_tensor, + void* dest_buffer, size_t dest_buffer_size) { + PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, + dest_buffer_size, device); + }; - std::vector source_tensors; - source_tensors.emplace_back(shape, device.toString(), - std::move(populate_fn)); + std::vector source_tensors; + source_tensors.emplace_back(shape, device.toString(), std::move(populate_fn)); - auto handles = - runtime::GetComputationClient()->TransferToServer(source_tensors); - XLA_CHECK_EQ(handles.size(), 1); - return WrapXlaData(handles.front()); - } + auto handles = + runtime::GetComputationClient()->TransferToServer(source_tensors); + XLA_CHECK_EQ(handles.size(), 1); + return WrapXlaData(handles.front()); } template @@ -849,7 +788,7 @@ torch::lazy::BackendDataPtr TensorToXlaData( std::vector CreateTensorsData( const std::vector& tensors, - const std::vector& devices, bool transfer_async) { + const std::vector& devices) { TORCH_LAZY_TIMED("TensorToData"); XLA_CHECK_EQ(tensors.size(), devices.size()); @@ -861,9 +800,6 @@ std::vector CreateTensorsData( if (devices[0] == "SPMD:0") { // When running in SPMD mode, tensors here in the unsharded // CreateTensorsData should be implicitly replicated to all devices. - // This case should always apply when using SPMD regardless - // of transfer_async's value, since SPMD requires PjRt and all transfers - // are asynchronous in PjRt. std::vector local_devices = runtime::GetComputationClient()->GetLocalDevices(); std::vector handles; @@ -880,48 +816,22 @@ std::vector CreateTensorsData( return WrapXlaData(handles); } - if (transfer_async) { - std::shared_ptr async = std::make_shared(); - auto populate_mwait = - std::make_shared(tensors.size()); - for (size_t i = 0; i < tensors.size(); ++i) { - torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); - xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - populate_mwait->Done(); - }; - async->source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); - } - TransferToServerAsync(async, devices); - // Tensors is a vector reference and can be inplace updated between this - // function returned and populate_fn being called. Need to wait for - // populate_fn to be called. 
- populate_mwait->Wait(); - return async->async_datas; - } else { - std::vector source_tensors; - for (size_t i = 0; i < tensors.size(); ++i) { - torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); - xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); - } - return WrapXlaData( - runtime::GetComputationClient()->TransferToServer(source_tensors)); + std::vector source_tensors; + for (size_t i = 0; i < tensors.size(); ++i) { + torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); + xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); + auto populate_fn = + [&, i, device]( + const runtime::ComputationClient::TensorSource& source_tensor, + void* dest_buffer, size_t dest_buffer_size) { + PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, + dest_buffer_size, device); + }; + source_tensors.emplace_back(std::move(shape), devices[i], + std::move(populate_fn)); } + return WrapXlaData( + runtime::GetComputationClient()->TransferToServer(source_tensors)); } std::vector CreateTensorsData( diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 1d8f6acc070..5071cba8da7 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -56,7 +56,7 @@ torch::lazy::hash_t TensorHash(const at::Tensor& tensor); // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice std::vector CreateTensorsData( const std::vector& tensors, - const std::vector& devices, bool transfer_async = false); + const std::vector& devices); // Shard and transfer tensors to devices using `PjRtComputationClient`. // The client's data transfer to device is asynchronous. 
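With the `XLA_TRANSFER_SCALAR_ASYNC` path gone, every host-to-device transfer in `TensorToXlaData` and `CreateTensorsData` now goes through the single synchronous `TransferToServer` call. Below is a minimal sketch of how the coverage from the deleted `TestAsyncScalar` tests could be replaced, assuming the usual `torch_xla` imports from `test_operations.py`; the helper name is hypothetical:

```python
# Minimal sketch: after this change the async scalar path no longer exists,
# so the TransferToServerAsync metric should never be populated.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def check_scalar_transfer_is_synchronous():
  device = xm.xla_device()
  t1 = torch.randn(3, 3).to(device)
  t2 = t1 / 0.5          # The 0.5 scalar used to exercise the async transfer.
  t2.cpu()               # Materialize the result back on the host.
  assert met.metric_data("TransferToServerAsync") is None
```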
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 3fbc6f2d709..c2972ccff8f 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -30,7 +30,6 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/layout_manager.h" -#include "torch_xla/csrc/op_by_op_executor.h" #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" #include "torch_xla/csrc/ops/cast.h" #include "torch_xla/csrc/ops/device_data.h" @@ -349,24 +348,15 @@ void XLAGraphExecutor::SyncTensorsGraph(std::vector* tensors, << " tensor(s)"; tsl::profiler::TraceMe activity("SyncTensorsGraph", tsl::profiler::TraceMeLevel::kInfo); - static const bool op_by_op = - runtime::sys_util::GetEnvBool("XLA_SYNC_TENSORS_OPBYOP", false); SyncTensorsConfig config; config.sync_ltc_data = sync_ltc_data; if (warm_up_cache_only) { config.force_ltc_data = false; } - if (op_by_op) { - OpByOpAsync async = SyncTensorsGraphOpByOp(tensors, devices, config); - if (wait) { - async.Wait(); - } - } else { - auto async = - SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only); - if (wait && async != nullptr && !warm_up_cache_only) { - async->mwait.Wait(); - } + auto async = + SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only); + if (wait && async != nullptr && !warm_up_cache_only) { + async->mwait.Wait(); } } @@ -419,9 +409,7 @@ std::vector XLAGraphExecutor::GetTensors( std::vector* tensors) { TF_VLOG(4) << "Trying to get the value of " << tensors->size() << " tensor(s)"; - static const bool op_by_op = - runtime::sys_util::GetEnvBool("XLA_GET_TENSORS_OPBYOP", false); - return op_by_op ? GetTensorsOpByOp(tensors) : GetTensorsFused(tensors); + return GetTensorsFused(tensors); } torch::lazy::hash_t XLAGraphExecutor::GetGraphHash( @@ -567,11 +555,6 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( } } } - // Mix the hash with the resource domain hashes as compile handles are only - // valid within a domain (usually a single host). - coll.hash = torch::lazy::MHash( - coll.hash, runtime::GetComputationClient()->GetResourceDomain( - coll.device.toString())); if (!at_tensors.empty()) { TORCH_LAZY_COUNTER("SyncTensorsToData", at_tensors.size()); // Create data handles with shardings. If a tensor has a @@ -818,32 +801,6 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( return result_backend_data; } -std::vector XLAGraphExecutor::GetTensorsOpByOp( - std::vector* tensors) { - SyncTensorsConfig config; - config.force_ltc_data = false; - SyncTensorCollection coll = CollectSyncTensors(*tensors, config); - std::vector async_tensors_data; - if (!coll.indices.empty()) { - DebugUtil::SaveTensorsGraphInfo("GetTensorsOpByOp", *tensors, - &coll.indices); - - std::vector roots = - CollectRoots(*tensors, coll.indices); - TensorCollectionBarrier(&coll); - async_tensors_data = - OpByOpExecutor::Get()->Execute(roots, coll.device.toString(), {}); - } - - std::vector tensors_data = - GatherTensorsXlaData(*tensors, coll.indices, async_tensors_data); - std::vector literals = - runtime::GetComputationClient()->TransferFromServer( - UnwrapXlaData(tensors_data)); - - return FetchTensors(tensors, literals, &coll.indices); -} - std::vector XLAGraphExecutor::GetTensorsFused( std::vector* tensors) { SyncTensorsConfig config; @@ -882,68 +839,6 @@ std::vector XLAGraphExecutor::GetTensorsFused( async != nullptr ? 
&async->indices : nullptr); } -XLAGraphExecutor::OpByOpAsync XLAGraphExecutor::SyncTensorsGraphOpByOp( - std::vector* tensors, absl::Span devices, - const SyncTensorsConfig& config) { - struct Async { - explicit Async(SyncTensorCollection coll, - std::vector tensors_data, - std::vector roots, - absl::Span devices) - : coll(std::move(coll)), - tensors_data(std::move(tensors_data)), - roots(std::move(roots)), - devices(devices.begin(), devices.end()) {} - - SyncTensorCollection coll; - std::vector tensors_data; - std::vector roots; - std::vector devices; - }; - - SyncTensorCollection coll = CollectSyncTensors(*tensors, config); - DebugUtil::SaveTensorsGraphInfo("SyncTensorsGraphOpByOp", *tensors, - &coll.indices); - - std::vector roots = CollectRoots(*tensors, coll.indices); - std::vector ir_values; - std::vector tensor_data_vec; - ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values, - tensor_data_vec); - auto tensors_data = - SetTensorData(tensors, coll.config, coll.indices, tensor_data_vec); - TensorCollectionBarrier(&coll); - auto async = std::make_shared(std::move(coll), std::move(tensors_data), - std::move(roots), devices); - auto syncfn = [async]() -> int { - try { - TF_VLOG(3) << "Executing (OpByOp) IR graph hash " - << torch::lazy::HashToString(async->coll.hash) << " on device " - << async->coll.device << " ..."; - std::vector results = - OpByOpExecutor::Get()->Execute( - async->roots, async->coll.device.toString(), async->devices); - TF_VLOG(3) << "Executing (OpByOp) IR graph hash " - << torch::lazy::HashToString(async->coll.hash) << " on device " - << async->coll.device << " done!"; - - for (size_t i = 0; i < results.size(); ++i) { - if (async->tensors_data[i] != nullptr) { - async->tensors_data[i]->Assign(*results[i]); - } - } - } catch (...) { - for (auto& unlocker : async->coll.unlocker) { - unlocker.SetStatus(std::current_exception()); - } - throw; - } - return 0; - }; - OpByOpAsync async_op(std::move(syncfn)); - return async_op.Schedule(); -} - std::vector XLAGraphExecutor::GatherTensorsXlaData( const std::vector& tensors, absl::Span indices, absl::Span tensors_data) { diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 309f534c7cb..0798853ecf0 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -119,7 +119,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // the tensors must be on the same device. If wait is true, the sync operation // will be run synchronously. The devices argument, if not empty, tells the // devices which should be participating into the replicated computation. - // We don't use the upstream one given we have OpbyOp mode. void SyncTensorsGraph(std::vector* tensors, absl::Span devices, bool wait, bool sync_ltc_data, bool warm_up_cache_only = false); @@ -146,7 +145,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // Retrieves the PyTorch CPU tensors behind the XLA tensors IR operations. // All the tensors must be on the same device. - // We don't use the GetTensors given we have OpByOp mode. std::vector GetTensors(std::vector* tensors); // We don't use the upstream GetGraphHash as XLATensorPtr is used instead. @@ -260,18 +258,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // Override to enable SPMD. void TensorCollectionBarrier(SyncTensorCollection* coll) final; - // Implementation of the GetTensors() API using the op-by-op executor. 
- std::vector GetTensorsOpByOp(std::vector* tensors); - // We don't use upstream GetTensorsFused as we have xla::Literal. std::vector GetTensorsFused(std::vector* tensors); - // Runs an asynchronous syn operation using the op-by-op executor. - using OpByOpAsync = runtime::util::AsyncTask; - OpByOpAsync SyncTensorsGraphOpByOp(std::vector* tensors, - absl::Span devices, - const SyncTensorsConfig& config); - // Gathers the XLA device data for all the input tensors, after an // asynchronous operation. // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
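With `XLA_SYNC_TENSORS_OPBYOP` and `XLA_GET_TENSORS_OPBYOP` removed, per-step debugging falls back to the knobs that remain in TROUBLESHOOTING.md, chiefly `XLA_SYNC_WAIT` and the metrics report. A hedged sketch of such a workflow follows; it assumes `XLA_SYNC_WAIT` is read once by the runtime (so it should be set before the first device sync) and uses only the standard `torch_xla` debug helpers (`xm.mark_step`, `xm.wait_device_ops`, `met.metrics_report`):

```python
# Sketch of a post-removal debugging flow: instead of forcing op-by-op
# execution, make every sync blocking and inspect the metrics report.
import os
os.environ.setdefault("XLA_SYNC_WAIT", "1")  # Set before the first sync.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def debug_step(model, inputs):
  device = xm.xla_device()
  output = model.to(device)(inputs.to(device))
  xm.mark_step()               # Flush pending IR; blocks with XLA_SYNC_WAIT=1.
  xm.wait_device_ops()         # Wait for any remaining device work.
  print(met.metrics_report())  # Counters/metrics replace the op-by-op trace.
  return output
```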