diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 691e875657d..c7cd994db0a 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2002,15 +2002,26 @@ def _aten_where(condition, x, y): # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None -@op(torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout) +@op(torch.ops.aten.to.dtype) def _aten_to_dtype( - a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None + a, dtype, non_blocking=False, copy=False, memory_format=None ): if dtype: jaxdtype = mappings.t2j_dtype(dtype) return a.astype(jaxdtype) +@op(torch.ops.aten.to.dtype_layout) +def _aten_to_dtype_layout( + a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None +): + return _aten_to_dtype( + a, + dtype, + non_blocking=non_blocking, + copy=copy, + memory_format=memory_format) + # aten.to.device