From 0ff871d0ccc6fba15ce49f265f80d4fd5d054268 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 26 Sep 2024 21:51:56 +0000 Subject: [PATCH] Pin update --- WORKSPACE | 3 ++- setup.py | 2 +- torch_xla/csrc/BUILD | 1 + torch_xla/csrc/runtime/ifrt_computation_client.h | 1 + torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- torch_xla/csrc/xla_sharding_util.cpp | 2 +- 6 files changed, 7 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 585891e149b..423d1f8cd08 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,7 +50,8 @@ new_local_repository( # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. -xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4' +#xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4' +xla_hash = '06bbcd1a798cd49bb811674fbed8823dfef51cc4' http_archive( name = "xla", diff --git a/setup.py b/setup.py index 778daf0cf1c..e10ff470127 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240913' +_date = '20240926' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl' _jax_version = f'0.4.33.dev{_date}' diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 89fefda457f..1287ffbde98 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -151,6 +151,7 @@ ptxla_cc_library( "@xla//xla/client/lib:slicing", "@xla//xla/client/lib:sorting", "@xla//xla/client/lib:svd", + "@xla//xla/hlo/pass:hlo_pass_pipeline", "@xla//xla/stream_executor:dnn", "@tsl//tsl/platform:errors", "@tsl//tsl/profiler/lib:traceme", diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f712a30f221..fd34021393d 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -19,6 +19,7 @@ #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/shape.h" diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 280a733bebe..74403b88040 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -1025,7 +1025,7 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name, args.function_name = fn_name.c_str(); args.function_name_size = fn_name.size(); args.api_version = 0; - args.custom_call_function = function_ptr; + args.handler_execute = function_ptr; PJRT_Error* error = reinterpret_cast(next)->custom_call(&args); if (error) { diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index e6a10c1740b..c48eba0e970 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -23,9 +23,9 @@ #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/protobuf_util.h" #include "xla/service/hlo_parser.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_verifier.h" #include "xla/service/sharding_propagation.h" #include "xla/service/spmd/spmd_partitioner.h"