diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index 1931c3295f9..68c2c6a94d6 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -131,6 +131,15 @@ with env: print(type(res)) # outputs XLATensor2 ``` +You can also enable the environment globally with +```python +import torch_xla2 + +torch_xla2.enable_globally() +``` + +Then everything afterwards is run with XLA. + ## What is happening behind the scene: When a torch op is executed inside of `env` context manager, we can swap out the diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index 982d2b6ce44..a723f647ca8 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -7,18 +7,16 @@ """ import torch -from torch.utils import _pytree as pytree import torchvision import torchvision.transforms as transforms -import torch_xla2.tensor - - -xla_env = torch_xla2.default_env() # PyTorch TensorBoard support -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime +#from torch.utils.tensorboard import SummaryWriter +#from datetime import datetime +# NOTE: add these lines to make it run on TPUs! +import torch_xla2 +torch_xla2.enable_globally() transform = transforms.Compose( [transforms.ToTensor(), @@ -83,7 +81,6 @@ def forward(self, x): model = GarmentClassifier() -model = xla_env.to_xla(model) loss_fn = torch.nn.CrossEntropyLoss() @@ -102,7 +99,7 @@ def forward(self, x): # Optimizers specified in the torch.optim package optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) -def train_one_epoch(epoch_index, tb_writer): +def train_one_epoch(epoch_index, tb_writer=None): running_loss = 0. last_loss = 0. @@ -112,7 +109,6 @@ def train_one_epoch(epoch_index, tb_writer): for i, data in enumerate(training_loader): # Every data instance is an input + label pair # NEW: Move model to XLA device - data = xla_env.to_xla(data) inputs, labels = data # Zero your gradients for every batch! @@ -135,7 +131,7 @@ def train_one_epoch(epoch_index, tb_writer): last_loss = running_loss / 1000 # loss per batch print(' batch {} loss: {}'.format(i + 1, last_loss)) tb_x = epoch_index * len(training_loader) + i + 1 - tb_writer.add_scalar('Loss/train', last_loss, tb_x) + #tb_writer.add_scalar('Loss/train', last_loss, tb_x) running_loss = 0. return last_loss @@ -143,8 +139,8 @@ def train_one_epoch(epoch_index, tb_writer): # Initializing in a separate cell so we can easily add more epochs to the same run -timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') -writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) +#timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') +#writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) epoch_number = 0 EPOCHS = 2 best_vloss = 1_000_000. 
@@ -156,7 +152,7 @@ def train_one_epoch(epoch_index, tb_writer): model.train(True) - avg_loss = train_one_epoch(epoch_number, writer) + avg_loss = train_one_epoch(epoch_number) running_vloss = 0.0 # Set the model to evaluation mode, disabling dropout and using population @@ -167,7 +163,6 @@ def train_one_epoch(epoch_index, tb_writer): with torch.no_grad(): for i, vdata in enumerate(validation_loader): # NOTE: move to XLA device - vinputs, vlabels = xla_env.to_xla(vdata) voutputs = model(vinputs) # call model's forward vloss = loss_fn(voutputs, vlabels) running_vloss += vloss diff --git a/experimental/torch_xla2/test/test_mutations.py b/experimental/torch_xla2/test/test_mutations.py index 51088143175..f5385a445c6 100644 --- a/experimental/torch_xla2/test/test_mutations.py +++ b/experimental/torch_xla2/test/test_mutations.py @@ -14,16 +14,14 @@ def test_add(self): x = torch.tensor([1, 2, 3], dtype=torch.int32) y = torch.tensor([4, 5, 6], dtype=torch.int32) x.add_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32)) + self.assertEqual(x, torch.tensor([5, 7, 9], dtype=torch.int32)) def test_sub(self): with self.env: x = torch.tensor([1, 2, 3], dtype=torch.int32) y = torch.tensor([4, 5, 6], dtype=torch.int32) x.sub_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32)) + self.assertEqual(x, torch.tensor([-3, -3, -3], dtype=torch.int32)) def test_mul(self): with self.env: @@ -31,8 +29,7 @@ def test_mul(self): y = torch.tensor([4, 5, 6], dtype=torch.int32) x.mul_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32)) + self.assertEqual(x, torch.tensor([4, 10, 18], dtype=torch.int32)) if __name__ == '__main__': diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 1efd67bf312..ed2614455ed 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -391,6 +391,7 @@ def setUpClass(cls): def setUp(self): self.env = tensor.Environment() + torch.manual_seed(0) # Replaces all values in the input torch_tensor that are less than the given threshold # with the threshold value itself. 
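The README and `basic_training.py` changes above replace the explicit `xla_env.to_xla(...)` calls with a single `torch_xla2.enable_globally()` near the imports. Below is a minimal sketch of that new flow; only `enable_globally`, `disable_globally`, `XLATensor2`, and the `jax:0` device come from this patch, while the tiny model and shapes are made up for illustration.

```python
# Minimal sketch of the new global mode (assumes a JAX backend, CPU or TPU,
# is available; the model and shapes are illustrative only).
import torch
import torch_xla2

torch_xla2.enable_globally()      # every torch op below now dispatches to JAX

model = torch.nn.Linear(4, 2)     # parameters are created as XLATensor2
x = torch.randn(8, 4)             # factory functions are intercepted too
loss = model(x).sum()
loss.backward()                   # autograd runs on top of the JAX-backed tensors
print(type(x), x.device)          # XLATensor2 ... jax:0

torch_xla2.disable_globally()     # restore stock PyTorch behaviour
```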
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index e193d9b3304..54af0eccab4 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -11,6 +11,7 @@ __all__ = [ 'default_env', 'extract_jax', + 'enable_globally', ] from jax._src import xla_bridge @@ -61,3 +62,22 @@ def jax_func(states, inputs): return env.t2j_iso(res) return states, jax_func + +def enable_globally(): + global env + env = default_env().__enter__() + return env + +def disable_globally(): + global env + default_env().__exit__(None, None, None) + + +torch.utils.rename_privateuse1_backend('jax') +unsupported_dtype = [torch.quint8] +torch.utils.generate_methods_for_privateuse1_backend( + for_tensor=True, for_module=True, for_storage=True, + unsupported_dtype=unsupported_dtype) + +import jax +torch._register_device_module('jax', jax) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py index 2f971f13fa4..119f3b44d7e 100644 --- a/experimental/torch_xla2/torch_xla2/config.py +++ b/experimental/torch_xla2/torch_xla2/config.py @@ -10,4 +10,8 @@ class Configuration: # Flash attention use_tpu_flash_attention: bool = False - shmap_flash_attention: bool = False \ No newline at end of file + shmap_flash_attention: bool = False + + # device + treat_cuda_as_jax_device: bool = True + use_torch_native_for_cpu_tensor: bool = False diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py index 92934af576d..e4357ef5a51 100644 --- a/experimental/torch_xla2/torch_xla2/interop.py +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -91,7 +91,7 @@ def _jax_view(t: TorchValue) -> JaxValue: assert isinstance(t, tensor.XLATensor2) return t.jax() if isinstance(t, type(torch.int32)): - return tensor.j2t_dtype(t) + return tensor.t2j_dtype(t) # torch.nn.Module needs special handling if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ade10ceeca2..32bd18dec8a 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2,6 +2,7 @@ import sys from typing import Optional, Sequence +import functools import jax from jax import numpy as jnp @@ -38,6 +39,7 @@ torch.ops.aten.squeeze_: torch.ops.aten.squeeze, torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p, torch.ops.aten.clamp_: torch.ops.aten.clamp, + torch.ops.aten.random_: torch.ops.aten.uniform, } @@ -55,6 +57,18 @@ def op(*aten, **kwargs): def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) + continue + + if isinstance(a, torch._ops.OpOverloadPacket): + opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name + elif isinstance(a, torch._ops.OpOverload): + opname = a.name() + else: + raise RuntimeError(f'oops {a}') + + torchfunc = functools.partial(interop.call_jax, func) + # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor + torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func) return func return inner @@ -80,14 +94,13 @@ def _aten_add(x, y, *, alpha=1): return x + y * alpha -@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False) +@op(torch.ops.aten.copy_, is_jax_function=False) def 
_aten_copy(x, y, memory_format=None): x._elem = y._elem.astype(x._elem.dtype) return x @op(torch.ops.aten.clone) -@op(torch.ops.aten.clone.default) def _aten_clone(x, memory_format=None): return x @@ -433,6 +446,8 @@ def _aten__to_copy(self, **kwargs): return jnp.copy(self) + + @op(torch.ops.aten.empty) @op_base.convert_dtype() def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs): @@ -465,7 +480,6 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): @op(torch.ops.aten.empty_permuted) -@op(torch.ops.aten.empty_permuted.default) @op_base.convert_dtype() def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): # Ignore the physical layout, @@ -474,7 +488,6 @@ def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): @op(torch.ops.aten.empty_strided) -@op(torch.ops.aten.empty_strided.default) @op_base.convert_dtype() def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): # Ignore stride, since JAX and torch tensor doesn't share the same memory. @@ -540,7 +553,6 @@ def permute(t, dims): @op(torch.ops.aten.unsqueeze) @op(torch.ops.aten.unsqueeze_copy) -@op(torch.ops.aten.unsqueeze.default) def _aten_unsqueeze(self, dim): if dim < 0: dim += self.ndim + 1 @@ -1618,7 +1630,6 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim): def _aten_any(self, dim=None, keepdim=False): return _with_reduction_scalar(jnp.any, self, dim, keepdim) - # aten.arange @op(torch.ops.aten.arange.start_step) @op(torch.ops.aten.arange.start) @@ -1960,7 +1971,6 @@ def _aten_ge(self, other): @op(torch.ops.aten.glu) -@op(torch.ops.aten.glu.default) def _aten_glu(x, dim=-1): return jax.nn.glu(x, dim) @@ -2110,6 +2120,38 @@ def _aten_prod(self, dim=None, keepdim=False): # aten.randperm +# randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) +@op(torch.ops.aten.randperm, needs_env=True) +def _aten_randperm( + n, *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + env=None): + """ + Generates a random permutation of integers from 0 to n-1. + + Args: + n: The upper bound (exclusive) of the permutation range. + generator: A PRNGKey used as the random key. If None, a new key is created. + dtype: The desired data type of the output array. Default is jnp.int64. + layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). + device: The desired device on which to place the output array (e.g., jax.devices()[0]). + pin_memory: Whether to pin the output array's memory to the host. + + Returns: + A DeviceArray containing a random permutation of integers from 0 to n-1. 
+ """ + if dtype: + dtype = mappings.t2j_dtype(dtype) + else: + dtype = jnp.int64.dtype + key = env.get_and_rotate_prng_key(generator) + indices = jnp.arange(n, dtype=dtype) + permutation = jax.random.permutation(key, indices) + return permutation # aten.reflection_pad3d @@ -2467,6 +2509,12 @@ def _aten_normal(self, mean=0, std=1, generator=None, env=None): res = _randn(*shape, generator=generator, env=env) return res * std + mean +# TODO: not clear what this function should actually do +# https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940 +@op(torch.ops.aten.lift_fresh) +def _aten_lift_fresh(self): + return self + @op(torch.ops.aten.uniform, needs_env=True) def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})' @@ -2476,7 +2524,7 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): #func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor -@op(torch.ops.aten.randint, torch.ops.aten.randint.generator, needs_env=True) +@op(torch.ops.aten.randint, needs_env=True) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint( *args, diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index 6928951dfb1..da0a911685a 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -1,4 +1,5 @@ """Tensor constructor overrides""" +import collections.abc import functools from typing import Optional, Sequence import numpy as np @@ -10,13 +11,26 @@ import torch from torch_xla2.ops.ops_registry import register_torch_function_op -from torch_xla2.ops import op_base, mappings +from torch_xla2.ops import op_base, mappings, jaten +import torch_xla2.tensor def register_function(torch_func, **kwargs): return functools.partial(register_torch_function_op, torch_func, **kwargs) +@register_function(torch.as_tensor, is_jax_function=False, needs_env=True) +@op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements +def _as_tensor(data, dtype=None, device=None, env=None): + if isinstance(data, torch.Tensor): + return env._to_copy(data, dtype, device) + if isinstance(data, np.ndarray): + jax_res = jnp.asarray(data) + else: + jax_res = _tensor(data, dtype=dtype) + return torch_xla2.tensor.XLATensor2(jax_res, env) + + @register_function(torch.tensor) @op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): @@ -175,4 +189,79 @@ def _sparse_mm(mat1, mat2, reduce='sum'): @register_function(torch.isclose) def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.isclose(input, other, rtol, atol, equal_nan) \ No newline at end of file + return jnp.isclose(input, other, rtol, atol, equal_nan) + +@register_function(torch.ones) +def _ones(*size: int, dtype=None, **kwargs): + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.ones(size, dtype=dtype) + + +@register_function(torch.zeros, is_jax_function=False) +def _zeros(*size: int, dtype=None, **kwargs): + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.zeros(size, 
dtype=dtype) + + +@register_function(torch.eye) +@op_base.convert_dtype() +def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): + return jnp.eye(n, m, dtype=dtype) + + +@register_function(torch.full) +@op_base.convert_dtype() +def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) + + +@register_function(torch.empty) +@op_base.convert_dtype() +def empty(*size: Sequence[int], dtype=None, **kwargs): + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jnp.empty(size, dtype=dtype) + +@register_function(torch.arange, is_jax_function=False) +def arange( + start, end=None, step=None, + out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, + pin_memory=None, +): + if end is None: + end = start + start = 0 + if step is None: + step = 1 + return torch.ops.aten.arange(start, end, step, dtype=dtype) + +@register_function(torch.empty_strided, is_jax_function=False) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False): + return empty(size, dtype=dtype) + + +@register_function(torch.rand, is_jax_function=False) +def rand( + *size, **kwargs +): + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.rand(size, **kwargs) + +@register_function(torch.randn, is_jax_function=False) +def randn( + *size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False +): + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.randn(size, generator=generator, dtype=dtype) + +@register_function(torch.randint, is_jax_function=False) +def randint( + *args, **kwargs +): + return torch.ops.aten.randint(*args, **kwargs) diff --git a/experimental/torch_xla2/torch_xla2/ops/mappings.py b/experimental/torch_xla2/torch_xla2/ops/mappings.py index 36154075eb5..3832b7feceb 100644 --- a/experimental/torch_xla2/torch_xla2/ops/mappings.py +++ b/experimental/torch_xla2/torch_xla2/ops/mappings.py @@ -4,6 +4,7 @@ import torch import torch.func import torch.utils.dlpack as torchdl +import torch.utils._mode_utils as mode_utils def t2j(t): @@ -38,14 +39,15 @@ def t2j(t): def j2t(x): - try: - dl = jaxdl.to_dlpack(x) - res = torchdl.from_dlpack(dl) - except Exception: - res = torch.from_numpy(numpy.asarray(x)) - if x.dtype == jnp.bool_: - res = res.to(torch.bool) - return res + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + try: + dl = jaxdl.to_dlpack(x) + res = torchdl.from_dlpack(dl) + except Exception: + res = torch.from_numpy(numpy.asarray(x)) + if x.dtype == jnp.bool_: + res = res.to(torch.bool) + return res TORCH_DTYPE_TO_JAX = { # NO_MAPPING : jnp.float0.dtype (signless scalar int), @@ -86,7 +88,7 @@ def j2t(x): def t2j_dtype(dtype): if dtype not in TORCH_DTYPE_TO_JAX: - raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,') + raise RuntimeError(f'Attempting to convert unknown type: {dtype} to jax type,') return TORCH_DTYPE_TO_JAX[dtype] diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 8c52192f53c..2c4176a361d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -55,7 +55,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs): if not dtype and 
use_default_dtype: dtype = torch.get_default_dtype() - jax_dtype = mappings.t2j_dtype(dtype) + if isinstance(dtype, torch.dtype): + jax_dtype = mappings.t2j_dtype(dtype) + else: + jax_dtype = dtype return func(*args, dtype=jax_dtype, **kwargs) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 1e2cfb6445c..f503f705f45 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -1,6 +1,6 @@ import sys import contextlib -from typing import Optional +from typing import Optional, Any import jax import jax.numpy as jnp import numpy @@ -83,6 +83,8 @@ def __init__(self, elem: jax.Array, env: 'Environment'): def __str__(self): return "XLATensor2({} {})".format(str(type(self._elem)), str(self._elem)) + __repr__ = __str__ + def __jax_array__(self): return self._elem @@ -145,26 +147,19 @@ def dtype(self): def dim(self): return self.ndim + @property + def device(self): + return torch.device('jax:0') + @property + def jax_device(self): + return self._elem.device -# TODO: slice of slice should also be another slice -class SliceView(XLATensor2): - - def __init__(self, orig_tensor, slice): - self._orig_tensor = orig_tensor - self._slice = slice - - def numpy(self): - return self._orig_tensor.numpy()[self._slice] - - def jax(self): - return self._orig_tensor.jax()[self._slice] - - def torch(self): - return self._orig_tensor.torch()[self.slice] + def tolist(self): + return self._elem.tolist() + - def mutate(self, slice, new_content): - self._orig_tensor._elem = self._orig_tensor.at[slice].set(new_content) + @@ -233,6 +228,21 @@ def _name_of_func(func): return func.__name__ +# Constructors that don't take other tensor as input +TENSOR_CONSTRUCTORS = { + torch.ones, + torch.zeros, + torch.empty, + torch.empty_strided, + torch.tensor, + torch.arange, + torch.eye, + torch.randn, + torch.rand, + torch.randint, +} + + class Environment(contextlib.ContextDecorator): """This class holds a set of configurations and "globals" needed @@ -260,6 +270,25 @@ def __init__(self, configuration=None): self._mesh = None self.config = configuration or config.Configuration() + self._jax_devices = set(['jax', 'jax_cpu', 'xla']) + + def get_as_jax_device(self, device: Any): + if device is None: + return jax.devices()[0] + + if isinstance(device, torch.device): + device = str(device) + if self.config.use_torch_native_for_cpu_tensor and device.startswith('cpu'): + return None + + if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'): + return None + + if device in ('jax_cpu', 'cpu'): + return jax.devices('cpu')[0] + return jax.devices()[0] + + def load_ops(self): from torch_xla2.ops import jaten, jtorch, jc10d, ops_registry self._ops.update(ops_registry.all_aten_ops) @@ -278,17 +307,85 @@ def load_ops(self): needs_env=False ) + def _to_copy(self, the_tensor, new_dtype, new_device): + if isinstance(the_tensor, XLATensor2): + arr = the_tensor.jax() + if new_dtype is not None and new_dtype != arr.dtype: + arr = arr.astype(mappings.t2j_dtype(new_dtype)) + if new_device is not None: + jax_device = self.get_as_jax_device(new_device) + if jax_device: + arr = jax.device_put(arr, jax_device) + else: + # converting to a non-jax device: let torch native handle it + torch_tensor = j2t(arr) if isinstance(the_tensor, XLATensor2) else arr + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return torch_tensor.to(new_device) + else: + if new_dtype is not None and new_dtype != the_tensor.dtype: + with 
mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + the_tensor = the_tensor.to(new_dtype) + jax_device = self.get_as_jax_device(new_device) + if jax_device: + with jax.default_device(jax_device): + arr = t2j(the_tensor) + else: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return the_tensor.to(new_device) + + return XLATensor2(arr, self) + + def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None): # Always use the default `randint` to get the next seed - with mode_utils.no_dispatch(): + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): next_key = torch.randint( 0, 2**32, (), dtype=torch.uint32, generator=generator).numpy() return jax.random.key(next_key) + def _handle_tensor_constructor(self, func, args, kwargs): + device = kwargs.get('device') + jax_device = self.get_as_jax_device(device) + if jax_device is None: + # let torch handle it + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return func(*args, **kwargs) + + with jax.default_device(jax_device): + op = self._ops.get(func) + res = op.func(*args, **kwargs) + if isinstance(res, jax.Array): + res = XLATensor2(res, self) + return res + + def _torch_Tensor_to(self, args, kwargs): + the_tensor = args[0] + args = args[1:] + if len(args) >= 1 and isinstance(args[0], torch.Tensor): + dtype = args[0].dtype + device = args[0].device + return self._to_copy(the_tensor, dtype, device) + device = kwargs.get('device') + dtype = kwargs.get('dtype') + # args like pin_memory etc that we will ignore + args = list(filter(lambda x: not isinstance(x, bool), args)) + if len(args) >= 2: + device, dtype, *_ = args + elif len(args) == 1 and isinstance(args[0], torch.dtype): + dtype = args[0] + elif len(args) == 1: + device = args[0] + return self._to_copy(the_tensor, dtype, device) + + def dispatch(self, func, types, args, kwargs): kwargs = kwargs or {} + if func in TENSOR_CONSTRUCTORS: + return self._handle_tensor_constructor(func, args, kwargs) + if func in (torch.Tensor.to, torch.ops.aten._to_copy, torch.ops.aten._to_copy.default): + return self._torch_Tensor_to(args, kwargs) # If the func don't act on XLATensor2, and is not a tensor constructor, # We should skip and let torch handle it.
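`Environment.get_as_jax_device`, together with the new `treat_cuda_as_jax_device` and `use_torch_native_for_cpu_tensor` flags in `config.py`, decides whether a requested device is served by JAX or handed back to native torch; both `_handle_tensor_constructor` and `_torch_Tensor_to` route through it. The sketch below is my reading of that code under the default flags, not output produced by the patch.

```python
# Hedged sketch of device routing with the defaults from config.py:
# treat_cuda_as_jax_device=True, use_torch_native_for_cpu_tensor=False.
import torch
import torch_xla2

torch_xla2.enable_globally()

a = torch.zeros(2, 2, device='cuda')  # 'cuda' is served by JAX (treat_cuda_as_jax_device)
b = torch.ones(2, 2, device='jax')    # the newly registered 'jax' backend
c = torch.ones(2, 2)                  # no device -> jax.devices()[0]
print(type(a), type(b), type(c))      # expected: XLATensor2 three times

d = c.to('cpu')                       # with the defaults this stays JAX-backed,
                                      # just placed on the JAX CPU device

# Assumption: flipping the flag before constructing routes CPU tensors back
# to native torch (get_as_jax_device then returns None for 'cpu').
torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = True
e = torch.ones(2, 2, device='cpu')    # plain torch.Tensor, untouched by the env
```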
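The constructor overrides added in `jtorch.py` (`_ones`, `_zeros`, `empty`, `rand`, `randn`) unpack a single iterable size argument so that both torch calling conventions keep working. A small illustrative check, assuming the global environment is enabled:

```python
# Illustrative check of the size normalization in the new jtorch overrides:
# both calling conventions are accepted and unpacked to the same shape.
import torch
import torch_xla2

torch_xla2.enable_globally()

x = torch.ones(2, 3)        # varargs form
y = torch.ones((2, 3))      # single-iterable form, unpacked by the override
print(x.shape, y.shape)     # both should report a (2, 3) shape
```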
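Randomized lowerings such as `uniform`, `randint`, `normal`, and the new `randperm` all obtain their key from `Environment.get_and_rotate_prng_key`, which draws from torch's own generator, so seeding torch should make them reproducible; that is presumably why `torch.manual_seed(0)` was added to the `test_ops.py` setup. A hypothetical check:

```python
# Hypothetical reproducibility check: get_and_rotate_prng_key derives the JAX
# key from torch.randint on torch's default generator, so manual_seed pins it.
import torch
import torch_xla2

torch_xla2.enable_globally()

torch.manual_seed(0)
a = torch.randn(4)
torch.manual_seed(0)
b = torch.randn(4)
print(a)
print(b)    # expected to print the same values as `a`
```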