Merge branch 'pytorch:master' into dev
guyao committed Sep 14, 2024
2 parents fd49904 + 37894b0 commit 4cd7569
Showing 13 changed files with 44 additions and 33 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = 'be7eef5742089e328152908b8662e83e34bf73c1'
xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'

http_archive(
name = "xla",
@@ -139,4 +139,4 @@ xla_workspace0()
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")
load("@tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure")
nccl_configure(name = "local_config_nccl")
nccl_configure(name = "local_config_nccl")
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -16,7 +16,6 @@
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"block_diag",
"bucketize",
"byte",
"cat",
"cauchy",
5 changes: 5 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -762,6 +762,11 @@ def fix_dim(p):
new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
return self.reshape(new_shape)

@op(torch.ops.aten.bucketize)
def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
assert boundaries[0] < boundaries[-1], "boundaries must contain a strictly increasing sequence"
return_type = jnp.int32 if out_int32 else jnp.int64
return jnp.digitize(input, boundaries, right=not right).astype(return_type)

@op(torch.ops.aten.convolution)
def _aten_convolution(
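
Note on the lowering added above: torch.bucketize and jnp.digitize interpret the `right` flag with opposite conventions (torch's default returns a lower-bound index, NumPy/JAX's default an upper-bound one), which is why the lowering passes `right=not right`. Below is a minimal standalone sketch of that equivalence, using the example values from the torch.bucketize documentation:

import jax.numpy as jnp

boundaries = jnp.array([1, 3, 5, 7, 9])
values = jnp.array([3, 6, 9])

# torch.bucketize(values, boundaries) returns [1, 3, 4]; inverting the flag
# makes jnp.digitize reproduce it.
print(jnp.digitize(values, boundaries, right=True))   # [1 3 4]

# torch.bucketize(values, boundaries, right=True) returns [2, 3, 5].
print(jnp.digitize(values, boundaries, right=False))  # [2 3 5]
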
4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240801'
_date = '20240913'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
_jax_version = f'0.4.32.dev{_date}'
_jax_version = f'0.4.33.dev{_date}'


def _get_build_mode():
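
With the bumped date, the pinned nightly artifacts resolve to the strings shown below; this small snippet just evaluates the f-strings from the lines above, so the URL layout is copied verbatim rather than assumed:

_date = '20240913'
_libtpu_version = f'0.1.dev{_date}'     # -> '0.1.dev20240913'
_jax_version = f'0.4.33.dev{_date}'     # -> '0.4.33.dev20240913'
_libtpu_storage_path = (
    'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/'
    f'libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl')
print(_libtpu_version, _jax_version)
print(_libtpu_storage_path)
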
2 changes: 1 addition & 1 deletion test/tpu/run_tests.sh
@@ -13,7 +13,6 @@ python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
python3 test/spmd/test_fsdp_v2.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/test_fp8.py
python3 test/test_grad_checkpoint.py
@@ -53,6 +52,7 @@ if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
python3 examples/eager/train_decoder_only_eager_with_compile.py
python3 examples/eager/train_decoder_only_eager_multi_process.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
fi

# Test `tpu-info` CLI compatibility
9 changes: 4 additions & 5 deletions torch_xla/__init__.py
@@ -66,14 +66,13 @@ def _setup_libtpu_flags():
# improves device memory usage.
flags = _set_missing_flags(
flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),))
# This flag enables FlashAttention HLO pass that pattern matches attention

# This flag enables FlashAttention HLO pass that pattern matches attention
# and rewrites it as flash attention. This pattern matching is causing
# issues for our standard dot product attention. Turning it off till
# we fix the issue with pattern matching.
flags = _set_missing_flags(
flags, (('xla_tpu_enable_flash_attention', 'false'),)
)
flags = _set_missing_flags(flags,
(('xla_tpu_enable_flash_attention', 'false'),))

if tpu.version() == 5:
default_v5_flags = {
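
The reformatting above keeps the existing behavior: a default is only applied when the user has not already set that flag (the flag list presumably originates from the LIBTPU_INIT_ARGS environment variable). A rough, hypothetical sketch of the contract assumed for _set_missing_flags, not the actual torch_xla implementation:

# Hypothetical sketch: append a default "--name=value" entry only if the flag
# is not already present, so explicit user settings win over these defaults.
def set_missing_flags_sketch(flags, sets):
    for name, default in sets:
        if not any(name in flag for flag in flags):
            flags.append(f'--{name}={default}')
    return flags

flags = ['--xla_tpu_enable_flash_attention=true']  # e.g. user-provided
flags = set_missing_flags_sketch(
    flags, (('xla_tpu_enable_flash_attention', 'false'),))
print(flags)  # the explicit 'true' is preserved; no default is appended
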
12 changes: 6 additions & 6 deletions torch_xla/csrc/runtime/BUILD
@@ -301,8 +301,8 @@ cc_library(
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@xla//xla:status",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/rpc:profiler_server_impl",
"@tsl//tsl/profiler/rpc/client:capture_profile",
"@xla//xla/tsl/profiler/rpc:profiler_server_impl",
"@xla//xla/tsl/profiler/rpc/client:capture_profile",
"@com_google_absl//absl/container:flat_hash_map",

# TODO: We get missing symbol errors without these deps. Why aren't they
@@ -311,7 +311,7 @@ cc_library(
"@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl",
"@tsl//tsl/profiler/protobuf:profiler_service_proto_cc_impl",
"@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc_impl",
"@tsl//tsl/profiler/rpc/client:profiler_client",
"@xla//xla/tsl/profiler/rpc/client:profiler_client",
],
)

@@ -463,7 +463,7 @@ ptxla_cc_test(
":xla_util",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/lib/core:status_test_util",
"@xla//xla/tsl/lib/core:status_test_util",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status_matchers",
"@xla//xla:shape_util",
@@ -479,7 +479,7 @@ ptxla_cc_test(
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@tsl//tsl/lib/core:status_test_util",
"@xla//xla/tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
@@ -504,7 +504,7 @@ ptxla_cc_test(
":computation_client",
":ifrt_computation_client",
":tensor_source",
"@tsl//tsl/lib/core:status_test_util",
"@xla//xla/tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
28 changes: 18 additions & 10 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -68,15 +68,17 @@ torch::lazy::hash_t hash_comp_env(
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(platform_version.c_str()));
// Include global devices in the hash, ensuring order is consistent.
xla::ifrt::DeviceList::Devices ifrt_devices;
xla::ifrt::BasicDeviceList::Devices ifrt_devices;
for (auto& device : ordered_devices) {
std::string device_str(device->ToString());
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(device_str.c_str()));
ifrt_devices.push_back(device);
}

xla::ifrt::DeviceList device_list(std::move(ifrt_devices));
tsl::RCReference<xla::ifrt::DeviceList> device_list =
xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices));

auto topology_desc = client->GetTopologyForDevices(device_list);
if (topology_desc.ok()) {
// Some backends support a topology description which provides a better
@@ -205,7 +207,8 @@ std::vector<ComputationClient::DataPtr> IfrtComputationClient::GetDataShards(

for (auto array : arrays) {
shards.push_back(std::make_shared<IfrtData>(
IfrtDeviceToString(array->sharding().devices()[0]), array));
IfrtDeviceToString(array->sharding().devices()->devices().front()),
array));
}
} else {
shards.push_back(data);
@@ -232,9 +235,12 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards(
shard_shapes.push_back(ifrt_shard->buffer->shape());
}
xla::ifrt::Shape ifrt_shape(shape.dimensions());
xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(),
client_->addressable_devices().end()});
XLA_CHECK_EQ(shard_shapes.size(), devices_list.size());
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});

XLA_CHECK_EQ(shard_shapes.size(), devices_list->size());
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
ifrt_shape, shard_shapes);
@@ -318,8 +324,10 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice(
shard_shapes.push_back(ifrt_shard->buffer->shape());
}
xla::ifrt::Shape ifrt_shape(shape.dimensions());
xla::ifrt::DeviceList devices_list({client_->addressable_devices().begin(),
client_->addressable_devices().end()});
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
ifrt_shape, shard_shapes);
@@ -346,7 +354,7 @@ ComputationClient::DataPtr IfrtComputationClient::CopyToDevice(

tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
const std::shared_ptr<IfrtData> handle) {
if (handle->buffer->sharding().devices().size() == 1) {
if (handle->buffer->sharding().devices()->size() == 1) {
return handle->buffer;
}

@@ -383,7 +391,7 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
computations = Compile(std::move(instances));

XLA_CHECK_EQ(handle->buffer->sharding().devices().size(),
XLA_CHECK_EQ(handle->buffer->sharding().devices()->size(),
GetLocalDevices().size());

torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -223,7 +223,7 @@ class IfrtComputationClient : public ComputationClient {
ss << " Data Shape: " << shape().ToString() << "\n";
ss << " OpSharding: "
<< xla::HloSharding::FromProto(*sharding_)->ToString() << "\n";
ss << " NumShards: " << buffer->sharding().devices().size() << "\n";
ss << " NumShards: " << buffer->sharding().devices()->size() << "\n";
} else {
ss << "XLAData: \n";
ss << " Data Device: " << device() << "\n";
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client_test.cc
@@ -9,7 +9,6 @@
#include "absl/status/status.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/test.h"
@@ -19,6 +18,7 @@
#include "xla/literal_util.h"
#include "xla/statusor.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"

namespace torch_xla {
namespace runtime {
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client_test.cc
@@ -10,7 +10,6 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/test.h"
@@ -20,6 +19,7 @@
#include "xla/literal_util.h"
#include "xla/statusor.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"

namespace torch_xla {
namespace runtime {
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/profiler.cc
@@ -4,11 +4,11 @@
#include "absl/status/status.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "tsl/profiler/lib/profiler_factory.h"
#include "tsl/profiler/rpc/client/capture_profile.h"
#include "tsl/profiler/rpc/profiler_server.h"
#include "xla/backends/profiler/plugin/plugin_tracer.h"
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"

namespace torch_xla {
namespace runtime {
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/xla_util_test.cc
@@ -9,13 +9,13 @@

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/protobuf/error_codes.pb.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla_util.h"

namespace torch_xla {
