Skip to content

Commit

Permalink
Update openxla pin to 11 July 2024 (#7639)
Browse files (browse the repository at this point in the history)
  • Branch information
bhavya01 committed Jul 12, 2024
1 parent ac26a97 commit 8043050
Show file tree
Hide file tree
Showing 13 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = '8533a6869ae02fb3b15a8a12739a982fc3c9f6e7'
xla_hash = 'db472b8c3d83bc27b3c67067802b9a37ef7542e3'

http_archive(
name = "xla",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240628'
_date = '20240711'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.31.dev{_date}'
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/dl_convertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) {
DLDevice dlDevice;
dlDevice.device_type = DLDeviceTypeForDevice(device);
dlDevice.device_id = device.local_hardware_id();
dlDevice.device_id = device.local_hardware_id().value();
return dlDevice;
}

Expand Down Expand Up @@ -148,7 +148,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter;
dt.device = DLDeviceForDevice(*pjrt_buffer->device());
dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value();
dt.ndim = pjrt_buffer->dimensions().size();
dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());

Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ void BuildProfilerSubmodule(py::module* m) {
return py::none();
})
.def("set_metadata", &xla::profiler::TraceMeWrapper::SetMetadata)
.def_static("is_enabled", &xla::profiler::TraceMeWrapper::IsEnabled);
.def_static("is_enabled", &tsl::profiler::TraceMe::Active);

py::class_<torch::lazy::ScopePusher,
std::unique_ptr<torch::lazy::ScopePusher>>
Expand Down Expand Up @@ -1726,9 +1726,9 @@ void InitXlaModuleBindings(py::module m) {
return bridge::AtenDeviceToXlaDevice(device_str).ordinal();
});
m.def("_xla_get_device_attributes", [](const std::string& device_str) {
const absl::flat_hash_map<
std::string, runtime::ComputationClient::DeviceAttribute>& attributes =
runtime::GetComputationClient()->GetDeviceAttributes(
const absl::flat_hash_map<std::string,
runtime::ComputationClient::DeviceAttribute>
attributes = runtime::GetComputationClient()->GetDeviceAttributes(
bridge::AtenDeviceToXlaDevice(device_str).toString());

py::dict dict;
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ cc_library(
"@xla//xla/pjrt/distributed",
"@xla//xla/python/ifrt",
"@xla//xla/python/pjrt_ifrt",
"@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"@xla//xla/python/ifrt:attribute_map",
],
)

Expand Down Expand Up @@ -377,6 +379,7 @@ cc_library(
deps = [
"@xla//xla:statusor",
"@xla//xla/service:platform_util",
"@com_google_absl//absl/base:log_severity",
],
)

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class ComputationClient {
std::variant<std::string, bool, int64_t, std::vector<int64_t>, float>;

virtual const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
GetDeviceAttributes(const std::string& device) = 0;

virtual void SetReplicationDevices(
Expand Down
7 changes: 5 additions & 2 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/xla_sharding.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -633,9 +635,10 @@ int IfrtComputationClient::GetNumProcesses() const {
};

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
IfrtComputationClient::GetDeviceAttributes(const std::string& device) {
return IfrtComputationClient::StringToIfrtDevice(device)->Attributes();
return xla::ifrt::ToPjRtDeviceAttributeMap(
IfrtComputationClient::StringToIfrtDevice(device)->Attributes());
}

void IfrtComputationClient::SetReplicationDevices(
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class IfrtComputationClient : public ComputationClient {
int GetNumProcesses() const override;

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
GetDeviceAttributes(const std::string& device) override;

void SetReplicationDevices(
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ int PjRtComputationClient::GetNumProcesses() const {
};

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
PjRtComputationClient::GetDeviceAttributes(const std::string& device) {
return PjRtComputationClient::StringToPjRtDevice(device)->Attributes();
}
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class PjRtComputationClient : public ComputationClient {
int GetNumProcesses() const override;

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
GetDeviceAttributes(const std::string& device) override;

void SetReplicationDevices(
Expand Down
12 changes: 7 additions & 5 deletions torch_xla/csrc/runtime/tf_logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <sstream>

#include "absl/base/log_severity.h"
#include "tsl/platform/logging.h"

namespace torch_xla {
Expand All @@ -22,11 +23,12 @@ namespace internal {
return vmodule_activated; \
})(lvl, __FILE__))

#define TF_VLOG(level) \
TF_PREDICT_TRUE(!TF_VLOG_IS_ON(level)) \
? (void)0 \
: ::tsl::internal::Voidifier() & \
::tsl::internal::LogMessage(__FILE__, __LINE__, ::tsl::INFO)
#define TF_VLOG(level) \
TF_PREDICT_TRUE(!TF_VLOG_IS_ON(level)) \
? (void)0 \
: ::tsl::internal::Voidifier() & \
::tsl::internal::LogMessage(__FILE__, __LINE__, \
absl::LogSeverity::kInfo)

struct ErrorSink : public std::basic_ostringstream<char> {};

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/xla_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void ReportComputationError(
}
ss << "StackTrace:\n" << tsl::CurrentStackTrace() << "\n";
ss << "Status: " << status << "\n";
XLA_LOG_LINES(tsl::ERROR, ss.str());
XLA_LOG_LINES(ERROR, ss.str());
throw std::runtime_error(status.ToString());
}

Expand Down
2 changes: 0 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ TORCH_LIBRARY_FRAGMENT(xla, m) {

namespace {

using tsl::ERROR;
using tsl::INFO;
using xla::internal::XlaBuilderFriend;

static bool use_auto_sharding = false;
Expand Down

0 comments on commit 8043050

Please sign in to comment.