diff --git a/test/test_data_type.py b/test/test_data_type.py
index b5cf463d5d0..9b7f55ff148 100644
--- a/test/test_data_type.py
+++ b/test/test_data_type.py
@@ -55,6 +55,16 @@ def test_datatype_f32_div_f64(self):
     assert t2.dtype == torch.float
     assert 'f64' not in hlo_text
 
+  def test_datatype_U16_32_64(self):
+
+    def _dtype_round_trip(dtype):
+      t = torch.randint(0, 128, (2, 4), dtype=dtype).to(xm.xla_device())
+      return t.cpu().dtype
+
+    for dtype in [torch.uint16, torch.uint32, torch.uint64]:
+      dtype2 = _dtype_round_trip(dtype)
+      self.assertTrue(dtype == dtype2)
+
 
 if __name__ == '__main__':
   print(f'XLA_USE_BF16: {os.getenv("XLA_USE_BF16")}')
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index d4ed1b413a6..e26e4e27fe9 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -127,10 +127,16 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
       return xla::PrimitiveType::U8;
     case at::ScalarType::Char:
       return xla::PrimitiveType::S8;
+    case at::ScalarType::UInt16:
+      return xla::PrimitiveType::U16;
     case at::ScalarType::Short:
       return xla::PrimitiveType::S16;
+    case at::ScalarType::UInt32:
+      return xla::PrimitiveType::U32;
     case at::ScalarType::Int:
       return xla::PrimitiveType::S32;
+    case at::ScalarType::UInt64:
+      return xla::PrimitiveType::U64;
     case at::ScalarType::Long:
       return xla::PrimitiveType::S64;
     case at::ScalarType::ComplexFloat:
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 2e4f280ba66..9c00e212ccc 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -517,10 +517,16 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
       return XlaLiteralToTensor<uint8_t>(literal, dest_element_type);
     case at::ScalarType::Char:
       return XlaLiteralToTensor<int8_t>(literal, dest_element_type);
+    case at::ScalarType::UInt16:
+      return XlaLiteralToTensor<uint16_t>(literal, dest_element_type);
     case at::ScalarType::Short:
       return XlaLiteralToTensor<int16_t>(literal, dest_element_type);
+    case at::ScalarType::UInt32:
+      return XlaLiteralToTensor<uint32_t>(literal, dest_element_type);
     case at::ScalarType::Int:
       return XlaLiteralToTensor<int32_t>(literal, dest_element_type);
+    case at::ScalarType::UInt64:
+      return XlaLiteralToTensor<uint64_t>(literal, dest_element_type);
     case at::ScalarType::Long:
       return XlaLiteralToTensor<int64_t>(literal, dest_element_type);
     case at::ScalarType::Float: