From 7ee1bab6bc07bc39133749deae97fe6e7004cd33 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Tue, 11 Jun 2024 00:27:44 +0000 Subject: [PATCH 1/9] Add build trigger for 2.4.0-rc1 release --- .../artifacts.auto.tfvars | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index c2617739f45..2e8db7cc7b3 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -33,6 +33,64 @@ nightly_builds = [ # Built on push to specific tag. versioned_builds = [ + # Remove libtpu from PyPI builds + { + git_tag = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + accelerator = "tpu" + python_version = "3.8" + bundle_libtpu = "0" + }, + { + git_tag = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + accelerator = "tpu" + python_version = "3.9" + bundle_libtpu = "0" + }, + { + git_tag = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "0" + }, + { + git_tag = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + accelerator = "tpu" + python_version = "3.11" + bundle_libtpu = "0" + }, + # Bundle libtpu for Kaggle + { + git_tag = "v2.4.0-rc1" + package_version = "2.4.0-rc1+libtpu" + pytorch_git_rev = "v2.4.0-rc1" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "1" + }, + { + git_tag = "v2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + accelerator = "cuda" + cuda_version = "12.1" + python_version = "3.8" + }, + { + git_tag = "v2.4.0-rc1" + pytorch_git_rev = "v2.4.0-rc1" + package_version = "2.4.0-rc1" + accelerator = "cuda" + cuda_version = "12.1" + python_version = "3.10" + }, # Remove libtpu from PyPI builds { git_tag = "v2.3.0" From 62425ebb108993a9f7587e4b74b2dddc7438abc8 Mon Sep 17 00:00:00 2001 From: jonb377 Date: Wed, 12 Jun 2024 14:20:53 -0700 Subject: [PATCH 2/9] Use dest_offsets directly in LoadPlanner (#7243) --- test/spmd/test_xla_distributed_checkpoint.py | 28 +++++++++++++++++++ .../distributed_checkpoint/planners.py | 5 ---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index a035a3f11bd..593dba0769b 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -172,6 +172,34 @@ def test_resharding_different_device_mesh(self): save_planner=SPMDSavePlanner(), load_planner=SPMDLoadPlanner()) + @unittest.skipIf(xr.global_runtime_device_count() == 1, + "Multiple devices needed to change mesh") + def test_resharding_transpose_device_mesh(self): + dim = self.n_devices // 2 + model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim)) + model2 = self._get_sharded_model(mesh_shape=(self.n_devices // dim, dim)) + self._save_and_restore( + model1, + model2, + save_planner=SPMDSavePlanner(), + load_planner=SPMDLoadPlanner()) + + @unittest.skipIf(xr.global_runtime_device_count() == 1, + "Multiple devices needed to change mesh") + def test_padded_tensor(self): + # Use a linear layer with shape not divisible by the number of devices. 
+ model1 = torch.nn.Linear(127, 63).to('xla') + model2 = torch.nn.Linear(127, 63).to('xla') + mesh = xs.Mesh(range(self.n_devices), (self.n_devices,)) + # Transpose the sharding to induce resharding in the restore path + xs.mark_sharding(model1.weight, mesh, (0, None)) + xs.mark_sharding(model2.weight, mesh, (None, 0)) + self._save_and_restore( + model1, + model2, + save_planner=SPMDSavePlanner(), + load_planner=SPMDLoadPlanner()) + @unittest.skipUnless('CHKPT_PATH' in os.environ, 'CHKPT_PATH must be set for multihost checkpoint') def test_multihost_checkpoint(self): diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py index c417872c2f2..32fe987a97d 100644 --- a/torch_xla/experimental/distributed_checkpoint/planners.py +++ b/torch_xla/experimental/distributed_checkpoint/planners.py @@ -282,11 +282,6 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): lengths and offsets into the global tensor. """ offsets = read_item.dest_offsets - index = read_item.dest_index - if index.fqn in self.sharded_state_dict: - # Update offsets to index into the shard rather than the global tensor - shard = self._local_shards[index.fqn][index.index] - offsets = torch.Size(d - i.start for d, i in zip(offsets, shard.indices)) return narrow_tensor_by_index(tensor, offsets, read_item.lengths) def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: From da5843c6509326aebd2e926b1d3019461edd9de8 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Wed, 12 Jun 2024 18:03:36 -0700 Subject: [PATCH 3/9] Backport xla pin update (0612) (#7261) --- WORKSPACE | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 155cc74731b..faed0ceb57b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,7 +50,7 @@ new_local_repository( # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. 
-xla_hash = 'b604c8d87df842002a7a8de79a434026329fbcb2' +xla_hash = 'bf2dc9fe056bd7140e5f29a2ae6db15a26dd5443' http_archive( name = "xla", diff --git a/setup.py b/setup.py index 7b6ecb4af6f..b150a2b782e 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240605' +_date = '20240612' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.29.dev{_date}' From 5622efc41ed5f4dd8e806b8f9894e026ddf57dfc Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 13 Jun 2024 10:01:10 -0700 Subject: [PATCH 4/9] Backport: Bump jaxversion corresponding to libtpu 0612 (#7262) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b150a2b782e..098a5eb20cb 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ _date = '20240612' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' -_jax_version = f'0.4.29.dev{_date}' +_jax_version = f'0.4.30.dev{_date}' def _get_build_mode(): From f00bd0b95a8ca7742dfba9e48d6aae630863fb0e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 14 Jun 2024 10:51:48 -0700 Subject: [PATCH 5/9] Cherry-pick: PJRT plugin defaults (#7249 and #7268) (#7270) Co-authored-by: Aman Gupta <4409685+aman2930@users.noreply.github.com> --- torch_xla/__init__.py | 4 +++- torch_xla/_internal/tpu.py | 11 +++++++++-- torch_xla/csrc/init_python_bindings.cpp | 7 +++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 5bfb7b8991b..983bd46d679 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -77,6 +77,7 @@ def _setup_default_env(): os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA') # This is used for ML Framework Telemetry. 
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', __version__) + os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') if tpu.version() == 4: os.environ.setdefault('TPU_MEGACORE', 'megacore_dense') @@ -212,7 +213,8 @@ def _init_xla_lazy_backend(): from .experimental import plugins from ._internal import neuron, xpu # Additional built-in plugins -if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1': +if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS', + '0' if _XLAC._has_cuda_support() else '1') == '1': plugins.use_dynamic_plugins() plugins.register_installed_plugins() diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 241bf469d4a..8a42665012a 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -17,6 +17,7 @@ import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm from torch_xla.experimental import plugins +from torch_xla.version import __version__ _GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1' _ACCELERATOR_TYPE_TO_HOST_BOUNDS = { @@ -342,10 +343,16 @@ def configure_multiprocess(self, local_rank, local_world_size): return configure_topology(local_rank, local_world_size) def physical_chip_count(self): - return num_available_chips() + # HACK: We may reduce the number of processes we spawn depending on TPU + # topology settings + return num_local_processes() def client_create_options(self): return { 'max_inflight_computations': - xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4) + xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4), + 'ml_framework_name': + 'PyTorch/XLA', + 'ml_framework_version': + __version__ } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a6e3196c5d2..3fba13773b3 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2423,6 +2423,13 @@ void InitXlaModuleBindings(py::module m) { return XlaCustomCall(inputs, payload, output_shapes, output_dtypes, /*is_tpu=*/true); }); + m.def("_has_cuda_support", []() { +#ifdef GOOGLE_CUDA + return true; +#else + return false; +#endif + }); m.def("_xla_gpu_custom_call", [](const std::vector& inputs, const std::string& payload, const std::vector>& output_shapes, From fda5828ea06af462bbd2bebcff69427d2589471f Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:17:49 -0700 Subject: [PATCH 6/9] Avoid log spam (#7278) (#7284) --- torch_xla/csrc/runtime/pjrt_registry.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 52b06d89cb4..e92dcf7dd44 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -82,6 +82,9 @@ InitializePjRt(const std::string& device_type) { if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; + // Init the absl logging to avoid the log spam. 
+ absl::InitializeLog(); + std::shared_ptr kv_store = nullptr; if (plugin->requires_xla_coordinator()) { int local_process_rank = sys_util::GetEnvInt( From 72961330b8f8e5e8ddf29289a0e29a4d94e8051f Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:18:43 -0700 Subject: [PATCH 7/9] Revert the mul change (#7271) (#7285) --- test/test_operations.py | 2 ++ torch_xla/csrc/aten_xla_type.cpp | 13 +++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 938817a6fd2..6fb0b79d78d 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2099,6 +2099,8 @@ def test(f, xshape, ishapes): for xshape, i0shape, i1shape in cases[f2]: test(f2, xshape, (i0shape, i1shape)) + @unittest.skipIf( + True, "skip since https://github.com/pytorch/xla/pull/7130 is reverted") def test_inplace_mul_scalar_different_dtype(self): # This tests whether the returned output data-type agrees on PyTorch # and XLA sides. diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index dc30734756d..3459b8935e8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2221,14 +2221,11 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, at::Tensor XLANativeFunctions::mul(const at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&, - std::optional); - return OpConfig::From(static_cast(tensor_methods::mul)) - .add_input(self) - .add_input(other) - .cast_inputs_to_common_dtype() - .use_opmathtype_for_compute() - .run(); + return DoBinaryOp(self, other, + [&](const XLATensorPtr& xself, const XLATensorPtr& xother, + at::ScalarType dtype) { + return tensor_methods::mul(xself, xother, dtype); + }); } at::Tensor XLANativeFunctions::mul(const at::Tensor& self, From a901eb8e86d88239402ef5312868c7940ae0609f Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:48:03 -0700 Subject: [PATCH 8/9] [backport] Use np.prod instead of np.product (#7301) (#7310) Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com> --- test/spmd/test_sharding_strategies.py | 4 ++-- torch_xla/distributed/spmd/xla_sharding.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index 849b45c7c98..4f869961f09 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -67,9 +67,9 @@ num_devices = xr.global_runtime_device_count() -assert np.product(dcn_parallelism) * np.product( +assert np.prod(dcn_parallelism) * np.prod( ici_parallelism) == num_devices, f"Number of devices {num_devices} \ - does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}" + does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}" # Use HybridMesh to optimize multislice topology mesh = xs.HybridMesh( diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 916fb56f7c9..73e41e01fba 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -268,7 +268,7 @@ def _create_device_mesh_for_nd_torus( indices = itertools.combinations( range(len(assignable_physical_mesh)), num_axes) for 
c_axes, c_indices in zip(axes, indices): - if np.product(c_axes) == logical_axis_size: + if np.prod(c_axes) == logical_axis_size: assignment[logical_axis_index] = c_indices # Zero the assigned physical axes. assignable_physical_mesh = [ From 279346205d43c1b5c1ec299556b5eac33070d5ca Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 18 Jun 2024 16:44:01 -0700 Subject: [PATCH 9/9] [backport][Fori_loop|While_loop] Enable while_loop/fori_loop, add test case (#7157) (#7306) Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com> --- docs/fori_loop.md | 116 +++++-------- test/run_tests.sh | 2 +- ...while_loop_simple_add_dispatch_in_torch.py | 106 ------------ test/test_while_loop.py | 116 +++++++++++++ test/tpu/run_tests.sh | 2 +- torch_xla/csrc/init_python_bindings.cpp | 35 ++-- torch_xla/experimental/fori_loop.py | 157 +++++++++++++----- 7 files changed, 287 insertions(+), 247 deletions(-) delete mode 100644 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py create mode 100644 test/test_while_loop.py diff --git a/docs/fori_loop.md b/docs/fori_loop.md index 0c9f85af399..c29e32e28b3 100644 --- a/docs/fori_loop.md +++ b/docs/fori_loop.md @@ -1,114 +1,72 @@ -# Fori_loop -`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation -like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps -of each iteration. `fori_loop` might help memory utilization and might help faster compilation. +# `While_loop` optimize memory utilization and compilation -User could use `fori_loop` like this: -```python -from torch_xla.experimental.fori_loop import fori_loop -res = fori_loop(upper, lower, /*user defined*/body_fun, init) -``` - -current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too. +
-For detailed implementation: -- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop), -like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the -native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only. +### `while_loop` +`while_loop` replace pure python `while` loop, PyTorch supported `while_loop` by +[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). +PyTorch/XLA provide experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`. -- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan), -like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator. -This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference. - -# while_loop -`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in -[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69). -PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While. - -User could use `while_loop` like this: +#### Usage: ```python import torch_xla.experimental.fori_loop from torch._higher_order_ops.while_loop import while_loop -res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init) +result = while_loop(cond_fn, body_fn, init) ``` -current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too. - +- `cond_fn`: User-defined condition function. +- `body_fn`: User-defined loop body function. +- `init`: Initial values (tuple or list). -# [WIP]scan -like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd. -`scan` is WIP. - - -# Simple user guide -User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times: - -### simple example with pure python for loop -```bash -# python ->>> import torch ->>> init = torch.tensor([0], dtype=torch.int32) ->>> one_value = torch.ones(1, dtype=torch.int32) ->>> ->>> for i in range(10): -... init = init + one_value -... 
->>> init -tensor([10], dtype=torch.int32) -``` - -### simple example with `while_loop`: +#### simple example with `while_loop`: ```bash # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla >>> import torch_xla.experimental.fori_loop ->>> from torch_xla.experimental.fori_loop import fori_loop >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm ->>> import torch_xla.core.xla_builder as xb >>> >>> device = xm.xla_device() >>> ->>> def cond_fn(init, limit_value): -... return limit_value[0] >= init[0] +>>> def cond_fn(iteri, x): +... return iteri > 0 ... ->>> def body_fn(init, limit_value): -... one_value = torch.ones(1, dtype=torch.int32, device=device) -... return (torch.add(init, one_value), limit_value.clone()) +>>> def body_fn(iteri, x): +... return iteri - 1, torch.add(x, 1) ... ->>> init = torch.tensor([0], dtype=torch.int32, device=device) ->>> limit_value = torch.tensor([10], dtype=torch.int32, device=device) ->>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value)) ->>> res_ +>>> init_val = torch.tensor(3, device=device) +>>> iteri = torch.tensor(10, device=device) +>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val)) +>>> res FunctionalTensor(lvl=0, value=\ -tensor([11], device='xla:0', dtype=torch.int32)) +tensor(13, device='xla:0')) ``` -### simple example with `fori_loop`: +
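The loop body is not limited to scalar ops: it can also run a small module, as long as `body_fn` returns values with the same shapes and types as its inputs. Below is a rough, illustrative sketch that mirrors the `test_while_loop_simple_linear_inside_loop` case added in `test/test_while_loop.py`; it assumes a TPU/XLA device and runs with gradients disabled, as that test does:

```python
# Sketch: while_loop whose body applies a small nn.Module (inference only).
import torch
import torch_xla.experimental.fori_loop  # registers the XLA backend for while_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

device = xm.xla_device()
torch.set_grad_enabled(False)  # gradients disabled, as in the test


class SimpleLinear(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(2, 2)

  def forward(self, iteri, x):

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      # body_fn must return values with the same shapes/types as its inputs
      return iteri - 1, self.linear(x)

    return while_loop(cond_fn, body_fn, (iteri, x))


model = SimpleLinear().to(device)
iteri = torch.tensor(10, dtype=torch.int32, device=device)
x = torch.randn(2, 2, dtype=torch.float32, device=device)
_, res = model(iteri, x)  # res: the linear layer applied 10 times to x
```

Because the loop body closes over the layer's weight and bias, those parameters flow through the `additional_inputs` path of the XLA lowering rather than the carried inputs.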
+ +## Control group test case +For better compare difference between `pure python while loop` and `while_loop`, there is one test case called pure python `while` loop with similar logic: cumulative plus 1 for ten times: + +### Control group example with pure python `while` loop ```bash # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla ->>> import torch_xla.experimental.fori_loop ->>> from torch_xla.experimental.fori_loop import fori_loop ->>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm ->>> import torch_xla.core.xla_builder as xb >>> >>> device = xm.xla_device() >>> ->>> lower = torch.tensor([2], dtype=torch.int32, device=device) ->>> upper = torch.tensor([52], dtype=torch.int32, device=device) ->>> plus_value = torch.tensor([1], dtype=torch.int32, device=device) ->>> init_val = torch.tensor([1], dtype=torch.int32, device=device) +>>> init_val = torch.tensor(1, device=device) +>>> iteri = torch.tensor(50, device=device) >>> ->>> def body_fun(*argus): -... plus_value, init_val = argus -... return plus_value, torch.add(plus_value, init_val) +>>> while iteri > 0: +... init_val = init_val + 1 +... iteri -= 1 ... ->>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val) ->>> res_ -tensor([51], device='xla:0', dtype=torch.int32) +>>> init_val +tensor(51, device='xla:0') ``` -For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3 + + +PyTorch/XLA would include `while_loop` support in 2.4 with test case, support for `fori_loop` would be added after 2.4. 
For `while_loop`, currently we only should force define `body_fn` with same `input` and `output(return args)` shape diff --git a/test/run_tests.sh b/test/run_tests.sh index 4a298f01ee5..26d3c82303e 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -203,7 +203,7 @@ function run_xla_op_tests1 { function run_xla_op_tests2 { run_downcast_bf16 "$CDIR/test_data_type.py" run_test "$CDIR/pjrt/test_dtypes.py" - run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py" + run_test "$CDIR/test_while_loop.py" run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU } diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py deleted file mode 100644 index a76197cc736..00000000000 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import unittest -from typing import Callable, Dict, List - -import torch -import torch_xla -# We need to import the underlying implementation function to register with the dispatcher -import torch_xla.experimental.fori_loop -from torch_xla.experimental.fori_loop import fori_loop -from torch._higher_order_ops.while_loop import while_loop -import torch_xla.core.xla_model as xm -import torch_xla.core.xla_builder as xb - - -def _fake_while_loop(cond_fn, body_fn, operands): - # operands need to be more than one here - while cond_fn(*operands): - operands = body_fn(*operands) - return operands - - -def _fake_fori_loop(lower, upper, body_fun, *init_val): - (plus_value, init_val) = init_val - for i in range((upper - lower)[0]): - plus_value, init_val = body_fun(plus_value, init_val) - return init_val - - -class WhileLoopTest(unittest.TestCase): - - def test_while_loop_tpu_subtraction(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] <= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() - return (torch.sub(init, one_value), two_value) - - init = torch.tensor([10], dtype=torch.int32, device=device) - limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_while_loop_tpu_addition(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] >= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - return (torch.add(init, one_value), limit_value.clone()) - - # TODO(@manfei): init and limit_value has to be torch.tensor. 
- init = torch.tensor([0], dtype=torch.int32, device=device) - limit_value = torch.tensor([10], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_while_loop_tpu_subtraction_nested(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] <= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() - return (torch.sub(torch.sub(init, one_value), one_value), two_value) - - init = torch.tensor([10], dtype=torch.int32, device=device) - limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_fori_loop_tpu_addition(self): - - xm.mark_step() - device = xm.xla_device() - - lower = torch.tensor([2], dtype=torch.int32, device=device) - upper = torch.tensor([52], dtype=torch.int32, device=device) - plus_value = torch.tensor([1], dtype=torch.int32, device=device) - init_val = torch.tensor([1], dtype=torch.int32, device=device) - - def body_fun(*argus): - plus_value, init_val = argus - return plus_value, torch.add(plus_value, init_val) - - _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val) - expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val) - self.assertEqual(expected, actual) - - -if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/test_while_loop.py b/test/test_while_loop.py new file mode 100644 index 00000000000..e8ea617b0f9 --- /dev/null +++ b/test/test_while_loop.py @@ -0,0 +1,116 @@ +import os +import unittest +from typing import Callable, Dict, List + +import torch +import torch_xla +# We need to import the underlying implementation function to register with the dispatcher +import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import fori_loop +from torch._higher_order_ops.while_loop import while_loop +import torch_xla.core.xla_model as xm +import torch_xla.core.xla_builder as xb +import torch_xla.utils.utils as xu +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + + +def _fake_while_loop(cond_fn, body_fn, operands): + # operands need to be more than one here + while cond_fn(*operands): + operands = body_fn(*operands) + return operands + + +class WhileLoopTest(unittest.TestCase): + + def test_while_loop_addition(self): + device = xm.xla_device() + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, torch.add(x, 1) + + init_val = torch.tensor(3, dtype=torch.int32, device=device) + iteri = torch.tensor(10, device=device) + _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) + _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + def test_while_loop_addition_nested(self): + device = xm.xla_device() + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, torch.add(torch.add(x, 1), 1) + + init_val = torch.tensor(2, dtype=torch.int32, device=device) + iteri = torch.tensor(10, device=device) + _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) + 
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + def test_while_loop_simple_linear_inside_loop(self): + device = xm.xla_device() + torch.set_grad_enabled(False) + + class SimpleLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, iteri, x): + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, self.linear(x) + + return while_loop(cond_fn, body_fn, (iteri, x)) + + def forward_without_while_loop_op(self, iteri, x): + while (iteri > 0): + x = self.linear(x) + iteri -= 1 + return iteri, x + + linear_model = SimpleLinear() + linear_model.to(device) + l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device) + iteri = torch.tensor(10, dtype=torch.int32, device=device) + _, res_with_loop = linear_model(iteri, l_in_0) + _, res_without_loop = linear_model.forward_without_while_loop_op( + iteri, l_in_0) + + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + # ====== fori_loop ====== + @unittest.skip("Fori_loop is not supported now due to unstable result.") + def test_fori_loop_addition(self): + device = xm.xla_device() + + lower = torch.tensor(0, device=device) + upper = torch.tensor(50, device=device) + init_val = torch.tensor(1, dtype=torch.int32, device=device) + + def body_fun(x): + return torch.add(x, 1) + + _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val)) + + # === expected === + for i in range(upper - lower): + init_val = torch.add(init_val, 1) + res_without_loop = init_val + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 6d74c40d4e3..ddd439d1c60 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -20,7 +20,7 @@ python3 test/dynamo/test_dynamo.py python3 test/spmd/test_spmd_debugging.py python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py -python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +python3 test/test_while_loop.py python3 test/test_pallas.py python3 test/test_pallas_spmd.py python3 test/test_input_output_aliases.py diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 3fba13773b3..db3e54d8163 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -933,22 +933,9 @@ class PyLoweringContext { } // Builds a HLO graph given a set of output tensors, and add unused parameters - // needed in xlacomputation. + // needed in xlacomputation for fori_loop/while_loop. 
void BuildForiLoop(std::vector tensors, - std::vector input_arguments = {}) { - if (GetNameString() == "condctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // hard-code parameter_idx to 2 to skip existing upper/lower arguments - int64_t parameter_idx = 2; - for (at::Tensor input_argument : input_arguments) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - } - + std::vector additional_inputs_list = {}) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = GetXlaTensors(tensors, /*want_all=*/true); @@ -966,6 +953,24 @@ class PyLoweringContext { torch::lazy::Output(ir_value.node.get(), ir_value.index)); lowering_ctx.AddResult(root); } + + // add dummy parameter to cond/body xlacomputation's input for xla::while + // requriement + if ((GetNameString() == "condctx") or + (GetNameString() == "bodyctx" && additional_inputs_list.size() != 0)) { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameter_idx = + local_builder->GetProgramShape()->parameters_size(); + int64_t additional_inputs_list_size = additional_inputs_list.size(); + for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } + computation = ConsumeValue(lowering_ctx.BuildXla()); // wrap inputs of cond/body_computation diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bf32a712f3e..e41709084e2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -10,80 +10,141 @@ from torch._ops import HigherOrderOperator import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op +from torch._higher_order_ops.while_loop import while_loop as torch_while_loop +from torch._higher_order_ops.utils import _has_potential_branch_input_mutation -def fori_loop(lower, upper, user_body_func, *init_val): +def fori_loop(lower, upper, body_fun, *input_value): device = xm.xla_device() + if (upper < lower): + print("ERROR: upper should be a larger number than lower") + iteri = upper - lower - def cond_fn(upper, lower, *init_val): - return lower[0] < upper[0] + def cond_fn(iteri, *input_value): + return iteri > 0 - def body_fn(upper, lower, *init_val): - one_value_i = torch.ones(1, dtype=torch.int32, device=device) - res_list = list(user_body_func(*init_val)) - res_list.insert(0, lower) - res_list.insert(0, torch.sub(upper, one_value_i)) - return res_list + def new_body_fn(iteri, *input_value): + return iteri - 1, body_fun(*input_value) + + inputs = (iteri,) + input_value + res = _xla_while_loop_wrapper( + cond_fn, new_body_fn, inputs, (), fake_tensor=True) - res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) return res @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): - # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') - # cond_fn&body_fn: callable - # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) +def while_loop(cond_fn, body_fn, 
carried_inputs, additional_inputs=None): if additional_inputs is None: additional_inputs = tuple() - return _xla_while_loop( - cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) + return _xla_while_loop_wrapper(cond_fn, body_fn, carried_inputs, + additional_inputs) + + +def _xla_while_loop_wrapper(cond_fn, + body_fn, + carried_inputs, + additional_inputs=None, + fake_tensor=False): + + def new_body_fn(*carried_inputs): + res = list(body_fn(*carried_inputs)) + if additional_inputs: + res = [ + res[0], + ] + list(additional_inputs) + res[1:] + else: + res = res + return res + return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, + additional_inputs, fake_tensor) -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): - # untuple carried_inputs from while_loop - carried_inputs = carried_inputs[0] - # fake carried_inputs to split formal code + +def _xla_while_loop(cond_fn, + body_fn, + carried_inputs, + additional_inputs=None, + fake_tensor=False): + + # ====== fake_carried_inputs ====== fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) - fake_carried_inputs = tuple(fake_carried_inputs) - # trans fake_carried_inputs from list(tensor) to list(xla::op) - kwargs = {} - if type(fake_carried_inputs) is tuple: - shapes = xb.tensor_shape(fake_carried_inputs) + # ====== additional_inputs_list_cond ====== + fake_additiona_args = [] + for additional_input in additional_inputs: + device = additional_input.device + fake_additiona_args.append( + torch.randint( + 10, additional_input.size(), + dtype=additional_input.dtype).to(device)) + + # ====== inputs_list ====== + # specify body_fn_inputs/cond_fn_inputs, and add caught additional_inputs into fn_inputs + if additional_inputs or fake_tensor: + # replace inputs(carried_inputs[1:]) with fake tensors to fix missed arguments problem + body_fn_inputs = [ + carried_inputs[0], + ] + fake_carried_inputs[1:] + list(additional_inputs) + cond_fn_inputs = carried_inputs + additional_inputs else: - shapes = xb.tensor_shape((fake_carried_inputs)) - builder = xb.create_builder('test_while') - params = [] - for shape in shapes: - p = xb.mkparam(builder, len(params), shape) - params.append(p) + body_fn_inputs = carried_inputs + cond_fn_inputs = carried_inputs + + # due to `xla::While` requirement, body xlacomputation inputs/outputs, cond xlacomputation and init need to be the same shape and type; + # and carried_inputs contain (iter, values), additional_inputs contain (weights/bias) + # based on generated body xlacomputation outputs: (iter, weights/bias, values) + # we create expected order for cond/body xlacomputation generation to compare and match: (iter, weights/bias, values) + dummy_inputs_list = [ + fake_carried_inputs[0], + ] + fake_additiona_args + fake_carried_inputs[1:] + + # ====== body_fn ====== + body_result = body_fn(*body_fn_inputs) + body_ctx = torch_xla._XLAC.lowering.LoweringContext() + body_ctx.set_name_string("bodyctx") - # generate cond_fn xlacomputation - cond_result = cond_fn(*fake_carried_inputs) + # ====== body xlacomputation ====== + body_ctx.buildforiloop(list(body_result), dummy_inputs_list) + body_hlo = body_ctx.hlo() + body_computation = xb.computation_from_module_proto("bodycomputation", + body_hlo) + + # ====== cond_fn ====== + cond_result = cond_fn(*cond_fn_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() 
cond_ctx.set_name_string("condctx") - cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:])) + + # ====== cond xlacomputation ====== + cond_ctx.buildforiloop([cond_result], dummy_inputs_list) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs) - body_ctx = torch_xla._XLAC.lowering.LoweringContext() - body_ctx.set_name_string("bodyctx") - body_ctx.buildforiloop(list(body_result), []) - body_hlo = body_ctx.hlo() - body_computation = xb.computation_from_module_proto("bodycomputation", - body_hlo) + # ====== xla::while ====== + iter_value = carried_inputs[0] + input_and_outputs_value = carried_inputs[1:] + total_inputs = tuple([ + iter_value, + ]) + tuple(additional_inputs) + tuple(input_and_outputs_value) + + kwargs = {} + if type(total_inputs) is tuple: + shapes = xb.tensor_shape(total_inputs) + else: + shapes = xb.tensor_shape((total_inputs)) + builder = xb.create_builder('while_loop') + params = [] + for shape in shapes: + p = xb.mkparam(builder, len(params), shape) + params.append(p) - # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( 'While', (input_tuple.op,), @@ -94,6 +155,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (carried_inputs), computation) + (total_inputs), computation) + + # unwrapper result without additional_inputs for original order + additional_inputs_len = len(additional_inputs) + 1 + final_res = [ + result[0], + ] + result[additional_inputs_len:] - return result \ No newline at end of file + return final_res
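For reference, the `fori_loop` wrapper defined at the top of this file lowers to the same `while_loop` path by counting `upper - lower` iterations down to zero. The intended call pattern, taken from the `test_fori_loop_addition` case in `test/test_while_loop.py` (currently skipped as unstable), looks roughly like this sketch; it assumes a TPU/XLA device and is illustrative only:

```python
# Sketch of the intended fori_loop usage; the corresponding test is skipped
# as unstable in this release, so treat the result as illustrative only.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.fori_loop import fori_loop

device = xm.xla_device()

lower = torch.tensor(0, device=device)
upper = torch.tensor(50, device=device)
init_val = torch.tensor(1, dtype=torch.int32, device=device)


def body_fun(x):
  return torch.add(x, 1)


# fori_loop rewrites this into while_loop with iteri = upper - lower,
# cond_fn: iteri > 0, body_fn: (iteri - 1, body_fun(x)); the first returned
# element is the final iterator value, the second the carried value.
_, res = fori_loop(lower, upper, body_fun, init_val)
```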