diff --git a/WORKSPACE b/WORKSPACE
index 0991771bbfd..585891e149b 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
 # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum
 # and update the sha256 with the result.
 
-xla_hash = 'be7eef5742089e328152908b8662e83e34bf73c1'
+xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'
 
 http_archive(
     name = "xla",
@@ -139,4 +139,4 @@ xla_workspace0()
 load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
 cuda_configure(name = "local_config_cuda")
 load("@tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure")
-nccl_configure(name = "local_config_nccl")
\ No newline at end of file
+nccl_configure(name = "local_config_nccl")
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 5dc485c425f..2b88ae324fa 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -16,7 +16,6 @@
     "_upsample_bilinear2d_aa",
     "bincount",  # NOTE: dtype for int input torch gives float. This is weird.
     "block_diag",
-    "bucketize",
     "byte",
     "cat",
     "cauchy",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 113cfb6c7d2..5e094d39927 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -762,6 +762,11 @@ def fix_dim(p):
   new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
   return self.reshape(new_shape)
 
+@op(torch.ops.aten.bucketize)
+def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
+  assert boundaries[0] < boundaries[-1], "boundaries must contain a strictly increasing sequence"
+  return_type = jnp.int32 if out_int32 else jnp.int64
+  return jnp.digitize(input, boundaries, right=not right).astype(return_type)
 
 @op(torch.ops.aten.convolution)
 def _aten_convolution(
diff --git a/setup.py b/setup.py
index 647d03a88a4..778daf0cf1c 100644
--- a/setup.py
+++ b/setup.py
@@ -64,10 +64,10 @@
 
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_date = '20240801'
+_date = '20240913'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
-_jax_version = f'0.4.32.dev{_date}'
+_jax_version = f'0.4.33.dev{_date}'
 
 
 def _get_build_mode():
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index d7a8ba1c2a6..52d1de5b150 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -13,7 +13,6 @@ python3 test/spmd/test_xla_spmd_python_api_interaction.py
 python3 test/spmd/test_xla_auto_sharding.py
 python3 test/spmd/test_fsdp_v2.py
 XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
-XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
 python3 test/test_autocast.py
 python3 test/test_fp8.py
 python3 test/test_grad_checkpoint.py
@@ -53,6 +52,7 @@ if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
   python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
   python3 examples/eager/train_decoder_only_eager_with_compile.py
   python3 examples/eager/train_decoder_only_eager_multi_process.py
+  XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
 fi
 
 # Test `tpu-info` CLI compatibility
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 3cefe486417..87e37b5e231 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -66,14 +66,13 @@ def _setup_libtpu_flags():
   # improves device memory usage.
   flags = _set_missing_flags(
       flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),))
-
-  # This flag enables FlashAttention HLO pass that pattern matches attention 
+
+  # This flag enables FlashAttention HLO pass that pattern matches attention
   # and rewrites it as flash attention. This pattern matching is causing
   # issues for our standard dot product attention. Turning it off till
   # we fix the issue with pattern matching.
-  flags = _set_missing_flags(
-      flags, (('xla_tpu_enable_flash_attention', 'false'),)
-  )
+  flags = _set_missing_flags(flags,
+                             (('xla_tpu_enable_flash_attention', 'false'),))
 
   if tpu.version() == 5:
     default_v5_flags = {
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 51b44cb6812..2a348849aba 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -301,8 +301,8 @@ cc_library(
         "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
         "@xla//xla:status",
         "@tsl//tsl/profiler/lib:profiler_factory",
-        "@tsl//tsl/profiler/rpc:profiler_server_impl",
-        "@tsl//tsl/profiler/rpc/client:capture_profile",
+        "@xla//xla/tsl/profiler/rpc:profiler_server_impl",
+        "@xla//xla/tsl/profiler/rpc/client:capture_profile",
         "@com_google_absl//absl/container:flat_hash_map",
 
         # TODO: We get missing symbol errors without these deps. Why aren't they
@@ -311,7 +311,7 @@ cc_library(
         "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl",
         "@tsl//tsl/profiler/protobuf:profiler_service_proto_cc_impl",
         "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc_impl",
-        "@tsl//tsl/profiler/rpc/client:profiler_client",
+        "@xla//xla/tsl/profiler/rpc/client:profiler_client",
     ],
 )
 
@@ -463,7 +463,7 @@ ptxla_cc_test(
         ":xla_util",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
-        "@tsl//tsl/lib/core:status_test_util",
+        "@xla//xla/tsl/lib/core:status_test_util",
         "@tsl//tsl/platform:errors",
         "@tsl//tsl/platform:status_matchers",
         "@xla//xla:shape_util",
@@ -479,7 +479,7 @@ ptxla_cc_test(
         ":computation_client",
         ":pjrt_computation_client",
         ":tensor_source",
-        "@tsl//tsl/lib/core:status_test_util",
+        "@xla//xla/tsl/lib/core:status_test_util",
         "@tsl//tsl/platform:env",
         "@tsl//tsl/platform:errors",
         "@tsl//tsl/platform:logging",
@@ -504,7 +504,7 @@ ptxla_cc_test(
         ":computation_client",
         ":ifrt_computation_client",
         ":tensor_source",
-        "@tsl//tsl/lib/core:status_test_util",
+        "@xla//xla/tsl/lib/core:status_test_util",
         "@tsl//tsl/platform:env",
         "@tsl//tsl/platform:errors",
         "@tsl//tsl/platform:logging",
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index 645a82937ec..7b9edd50d6b 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -68,7 +68,7 @@ torch::lazy::hash_t hash_comp_env(
   hash = torch::lazy::HashCombine(
       hash, torch::lazy::StringHash(platform_version.c_str()));
   // Include global devices in the hash, ensuring order is consistent.
-  xla::ifrt::DeviceList::Devices ifrt_devices;
+  xla::ifrt::BasicDeviceList::Devices ifrt_devices;
   for (auto& device : ordered_devices) {
     std::string device_str(device->ToString());
     hash = torch::lazy::HashCombine(
@@ -76,7 +76,9 @@ torch::lazy::hash_t hash_comp_env(
     ifrt_devices.push_back(device);
   }
-  xla::ifrt::DeviceList device_list(std::move(ifrt_devices));
+  tsl::RCReference<xla::ifrt::DeviceList> device_list =
+      xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices));
+
   auto topology_desc = client->GetTopologyForDevices(device_list);
   if (topology_desc.ok()) {
     // Some backends support a topology description which provides a better
@@ -205,7 +207,8 @@ std::vector<ComputationClient::DataPtr> IfrtComputationClient::GetDataShards(
     for (auto array : arrays) {
       shards.push_back(std::make_shared<IfrtData>(
-          IfrtDeviceToString(array->sharding().devices()[0]), array));
+          IfrtDeviceToString(array->sharding().devices()->devices().front()),
+          array));
     }
   } else {
     shards.push_back(data);
   }
@@ -232,9 +235,12 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards(
     shard_shapes.push_back(ifrt_shard->buffer->shape());
   }
   xla::ifrt::Shape ifrt_shape(shape.dimensions());
-  xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(),
-                                      client_->addressable_devices().end()});
-  XLA_CHECK_EQ(shard_shapes.size(), devices_list.size());
+  tsl::RCReference<xla::ifrt::DeviceList> devices_list =
+      xla::ifrt::BasicDeviceList::Create(
+          {client_->addressable_devices().begin(),
+           client_->addressable_devices().end()});
+
+  XLA_CHECK_EQ(shard_shapes.size(), devices_list->size());
   std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
       xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
                                           ifrt_shape, shard_shapes);
@@ -318,8 +324,10 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice(
     shard_shapes.push_back(ifrt_shard->buffer->shape());
   }
   xla::ifrt::Shape ifrt_shape(shape.dimensions());
-  xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(),
-                                      client_->addressable_devices().end()});
+  tsl::RCReference<xla::ifrt::DeviceList> devices_list =
+      xla::ifrt::BasicDeviceList::Create(
+          {client_->addressable_devices().begin(),
+           client_->addressable_devices().end()});
   std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
       xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
                                           ifrt_shape, shard_shapes);
@@ -346,7 +354,7 @@ ComputationClient::DataPtr IfrtComputationClient::CopyToDevice(
 
 tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
     const std::shared_ptr<IfrtData> handle) {
-  if (handle->buffer->sharding().devices().size() == 1) {
+  if (handle->buffer->sharding().devices()->size() == 1) {
     return handle->buffer;
   }
 
@@ -383,7 +391,7 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
 
   std::shared_ptr> computations = Compile(std::move(instances));
-  XLA_CHECK_EQ(handle->buffer->sharding().devices().size(),
+  XLA_CHECK_EQ(handle->buffer->sharding().devices()->size(),
                GetLocalDevices().size());
 
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index ec8c314bc3d..f712a30f221 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -223,7 +223,7 @@ class IfrtComputationClient : public ComputationClient {
       ss << "  Data Shape: " << shape().ToString() << "\n";
       ss << "  OpSharding: "
          << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n";
-      ss << "  NumShards: " << buffer->sharding().devices().size() << "\n";
+      ss << "  NumShards: " << buffer->sharding().devices()->size() << "\n";
     } else {
       ss << "XLAData: \n";
       ss << "  Data Device: " << device() << "\n";
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cc b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc
index 475bee607d4..4027f902e6b 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cc
@@ -9,7 +9,6 @@
 #include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/test.h"
@@ -19,6 +18,7 @@
 #include "xla/literal_util.h"
 #include "xla/statusor.h"
 #include "xla/tests/literal_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 
 namespace torch_xla {
 namespace runtime {
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
index 65adaa950a2..be42a5e4dbf 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
@@ -10,7 +10,6 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/test.h"
@@ -20,6 +19,7 @@
 #include "xla/literal_util.h"
 #include "xla/statusor.h"
 #include "xla/tests/literal_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 
 namespace torch_xla {
 namespace runtime {
diff --git a/torch_xla/csrc/runtime/profiler.cc b/torch_xla/csrc/runtime/profiler.cc
index fa9e1c633d0..f73264044ba 100644
--- a/torch_xla/csrc/runtime/profiler.cc
+++ b/torch_xla/csrc/runtime/profiler.cc
@@ -4,11 +4,11 @@
 #include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "tsl/profiler/lib/profiler_factory.h"
-#include "tsl/profiler/rpc/client/capture_profile.h"
-#include "tsl/profiler/rpc/profiler_server.h"
 #include "xla/backends/profiler/plugin/plugin_tracer.h"
 #include "xla/backends/profiler/plugin/profiler_c_api.h"
 #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
+#include "xla/tsl/profiler/rpc/client/capture_profile.h"
+#include "xla/tsl/profiler/rpc/profiler_server.h"
 
 namespace torch_xla {
 namespace runtime {
diff --git a/torch_xla/csrc/runtime/xla_util_test.cc b/torch_xla/csrc/runtime/xla_util_test.cc
index caba578ac73..6119ef73092 100644
--- a/torch_xla/csrc/runtime/xla_util_test.cc
+++ b/torch_xla/csrc/runtime/xla_util_test.cc
@@ -9,13 +9,13 @@
 #include "absl/status/status.h"
 #include "absl/types/span.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/protobuf/error_codes.pb.h"
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla_util.h"
 
 namespace torch_xla {
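
Note (not part of the patch): the `right=not right` inversion in the new `_aten_bucketize` lowering works because `torch.bucketize(..., right=False)` returns the first index i with input <= boundaries[i], which corresponds to numpy/JAX `digitize` semantics with `right=True`. A minimal standalone sketch of that mapping, assuming torch and jax are installed; the values are illustrative only:

# Standalone sanity check for the bucketize -> digitize mapping used above.
import jax.numpy as jnp
import numpy as np
import torch

boundaries = [1, 3, 5, 7, 9]
values = [0, 1, 2, 5, 8, 9, 10]

# torch.bucketize(right=False): first index i such that value <= boundaries[i].
torch_out = torch.bucketize(torch.tensor(values), torch.tensor(boundaries))
# Same bucket indices via jnp.digitize with the inverted `right` flag.
jax_out = jnp.digitize(jnp.array(values), jnp.array(boundaries), right=True)

np.testing.assert_array_equal(torch_out.numpy(), np.asarray(jax_out))
print(torch_out)  # tensor([0, 0, 1, 2, 4, 4, 5])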