From 8043050f29b557ce753e929b8fca506f21c93d5f Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 12 Jul 2024 05:48:28 +0000 Subject: [PATCH] Update openxla pin to 11 July 2024 (#7639) --- WORKSPACE | 2 +- setup.py | 2 +- torch_xla/csrc/dl_convertor.cpp | 4 ++-- torch_xla/csrc/init_python_bindings.cpp | 8 ++++---- torch_xla/csrc/runtime/BUILD | 3 +++ torch_xla/csrc/runtime/computation_client.h | 2 +- torch_xla/csrc/runtime/ifrt_computation_client.cc | 7 +++++-- torch_xla/csrc/runtime/ifrt_computation_client.h | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.h | 2 +- torch_xla/csrc/runtime/tf_logging.h | 12 +++++++----- torch_xla/csrc/runtime/xla_util.cc | 2 +- torch_xla/csrc/xla_sharding_util.cpp | 2 -- 13 files changed, 28 insertions(+), 22 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 4c7a9cbc392..1e1f0e36810 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,7 +50,7 @@ new_local_repository( # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. -xla_hash = '8533a6869ae02fb3b15a8a12739a982fc3c9f6e7' +xla_hash = 'db472b8c3d83bc27b3c67067802b9a37ef7542e3' http_archive( name = "xla", diff --git a/setup.py b/setup.py index 9fdee39bee6..709d9ea0452 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240628' +_date = '20240711' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.31.dev{_date}' diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index bd1e616ab72..a2310f61d35 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -56,7 +56,7 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { DLDevice dlDevice; dlDevice.device_type = DLDeviceTypeForDevice(device); - dlDevice.device_id = device.local_hardware_id(); + dlDevice.device_id = device.local_hardware_id().value(); return dlDevice; } @@ -148,7 +148,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { pack->tensor.manager_ctx = pack.get(); pack->tensor.deleter = DLPackTensorDeleter; dt.device = DLDeviceForDevice(*pjrt_buffer->device()); - dt.device.device_id = pjrt_buffer->device()->local_hardware_id(); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); dt.ndim = pjrt_buffer->dimensions().size(); dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type()); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a495a4dd561..3721b0c3768 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -904,7 +904,7 @@ void BuildProfilerSubmodule(py::module* m) { return py::none(); }) .def("set_metadata", &xla::profiler::TraceMeWrapper::SetMetadata) - .def_static("is_enabled", &xla::profiler::TraceMeWrapper::IsEnabled); + .def_static("is_enabled", &tsl::profiler::TraceMe::Active); py::class_> @@ -1726,9 +1726,9 @@ void InitXlaModuleBindings(py::module m) { return bridge::AtenDeviceToXlaDevice(device_str).ordinal(); }); m.def("_xla_get_device_attributes", [](const std::string& device_str) { - const absl::flat_hash_map< - std::string, runtime::ComputationClient::DeviceAttribute>& attributes = - runtime::GetComputationClient()->GetDeviceAttributes( + const absl::flat_hash_map + attributes = runtime::GetComputationClient()->GetDeviceAttributes( bridge::AtenDeviceToXlaDevice(device_str).toString()); py::dict dict; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 01f717bfd34..19c5711338d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -92,6 +92,8 @@ cc_library( "@xla//xla/pjrt/distributed", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/ifrt:attribute_map", ], ) @@ -377,6 +379,7 @@ cc_library( deps = [ "@xla//xla:statusor", "@xla//xla/service:platform_util", + "@com_google_absl//absl/base:log_severity", ], ) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 105d19f64c0..f296b7f45c0 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -378,7 +378,7 @@ class ComputationClient { std::variant, float>; virtual const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> GetDeviceAttributes(const std::string& device) = 0; virtual void SetReplicationDevices( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 86a27da4c02..36d9b7bb857 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -24,10 +24,12 @@ #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" @@ -633,9 +635,10 @@ int IfrtComputationClient::GetNumProcesses() const { }; const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> IfrtComputationClient::GetDeviceAttributes(const std::string& device) { - return IfrtComputationClient::StringToIfrtDevice(device)->Attributes(); + return xla::ifrt::ToPjRtDeviceAttributeMap( + IfrtComputationClient::StringToIfrtDevice(device)->Attributes()); } void IfrtComputationClient::SetReplicationDevices( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 59664d045e8..5c3316cf997 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -109,7 +109,7 @@ class IfrtComputationClient : public ComputationClient { int GetNumProcesses() const override; const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> GetDeviceAttributes(const std::string& device) override; void SetReplicationDevices( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 5df49c4e5da..866591b13b5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -938,7 +938,7 @@ int PjRtComputationClient::GetNumProcesses() const { }; const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> PjRtComputationClient::GetDeviceAttributes(const std::string& device) { return PjRtComputationClient::StringToPjRtDevice(device)->Attributes(); } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 1b08b55c601..2bf6eaa30de 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -127,7 +127,7 @@ class PjRtComputationClient : public ComputationClient { int GetNumProcesses() const override; const absl::flat_hash_map< - std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> GetDeviceAttributes(const std::string& device) override; void SetReplicationDevices( diff --git a/torch_xla/csrc/runtime/tf_logging.h b/torch_xla/csrc/runtime/tf_logging.h index 14eeb8bde80..5d413252c58 100644 --- a/torch_xla/csrc/runtime/tf_logging.h +++ b/torch_xla/csrc/runtime/tf_logging.h @@ -3,6 +3,7 @@ #include +#include "absl/base/log_severity.h" #include "tsl/platform/logging.h" namespace torch_xla { @@ -22,11 +23,12 @@ namespace internal { return vmodule_activated; \ })(lvl, __FILE__)) -#define TF_VLOG(level) \ - TF_PREDICT_TRUE(!TF_VLOG_IS_ON(level)) \ - ? (void)0 \ - : ::tsl::internal::Voidifier() & \ - ::tsl::internal::LogMessage(__FILE__, __LINE__, ::tsl::INFO) +#define TF_VLOG(level) \ + TF_PREDICT_TRUE(!TF_VLOG_IS_ON(level)) \ + ? (void)0 \ + : ::tsl::internal::Voidifier() & \ + ::tsl::internal::LogMessage(__FILE__, __LINE__, \ + absl::LogSeverity::kInfo) struct ErrorSink : public std::basic_ostringstream {}; diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc index 36199a0a9a6..7a3658c6857 100644 --- a/torch_xla/csrc/runtime/xla_util.cc +++ b/torch_xla/csrc/runtime/xla_util.cc @@ -93,7 +93,7 @@ void ReportComputationError( } ss << "StackTrace:\n" << tsl::CurrentStackTrace() << "\n"; ss << "Status: " << status << "\n"; - XLA_LOG_LINES(tsl::ERROR, ss.str()); + XLA_LOG_LINES(ERROR, ss.str()); throw std::runtime_error(status.ToString()); } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index dd3eb566d06..e6a10c1740b 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -49,8 +49,6 @@ TORCH_LIBRARY_FRAGMENT(xla, m) { namespace { -using tsl::ERROR; -using tsl::INFO; using xla::internal::XlaBuilderFriend; static bool use_auto_sharding = false;