From 1fb09bc68b86eefb851459bacb182ca50f5bde37 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Wed, 14 Aug 2024 19:23:07 -0700
Subject: [PATCH] Add fp8e4m3fn support (#7842)

---
 test/test_fp8.py               | 83 ++++++++++++++++++----------
 torch_xla/csrc/dtype.cpp       |  4 ++
 torch_xla/csrc/tensor_util.cpp | 54 ++++++++++++++++++++++
 3 files changed, 103 insertions(+), 38 deletions(-)

diff --git a/test/test_fp8.py b/test/test_fp8.py
index 9471a426c17..078d5002a60 100644
--- a/test/test_fp8.py
+++ b/test/test_fp8.py
@@ -4,46 +4,53 @@
 import torch
 import torch_xla
 import unittest
-
-
-class Fp8Test(unittest.TestCase):
-
-  def test_fp8(self):
-    device = torch_xla.device()
-    fp8_types = [torch.float8_e5m2]
-    for dtype in fp8_types:
-      t = torch.rand(2, 2).to(dtype)
-      xla_t = t.to(device)
-      torch_t = xla_t.cpu()
-      self.assertEqual(xla_t.dtype, dtype)
-      self.assertEqual(torch_t.dtype, dtype)
-      # Need to cast to float32 since allclose doesn't work with fp8.
-      self.assertTrue(
-          torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
-
-  def test_fp8_matmul(self):
-    device = torch_xla.device()
-    fp8_types = [torch.float8_e5m2]
-    for dtype in fp8_types:
-      t = torch.rand(3, 2).to(dtype)
-      w = torch.rand(2, 5).to(dtype)
-      torch_matmul = torch.matmul(t, w)
-      xla_t = t.to(device)
-      xla_w = w.to(device)
-      xla_matmul = torch.matmul(xla_t, xla_w)
-      xla_matmul = xla_matmul.cpu()
-      # Need to cast to float32 since allclose doesn't work with fp8.
-      self.assertTrue(
-          torch.allclose(
-              xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
-
-  def test_fp8_hlo(self):
-    device = torch_xla.device()
-    x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
-    w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
+from absl.testing import parameterized
+
+device = torch_xla.device()
+
+dtype_parameters = [
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
+
+class Fp8Test(parameterized.TestCase):
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8(self, dtype):
+    t = torch.rand(2, 2).to(dtype)
+    xla_t = t.to(device)
+    torch_t = xla_t.cpu()
+    self.assertEqual(xla_t.dtype, dtype)
+    self.assertEqual(torch_t.dtype, dtype)
+    # Need to cast to float32 since allclose doesn't work with fp8.
+    self.assertTrue(
+        torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_matmul(self, dtype):
+    t = torch.rand(3, 2).to(dtype)
+    w = torch.rand(2, 5).to(dtype)
+    torch_matmul = torch.matmul(t, w)
+    xla_t = t.to(device)
+    xla_w = w.to(device)
+    xla_matmul = torch.matmul(xla_t, xla_w)
+    xla_matmul = xla_matmul.cpu()
+    # Need to cast to float32 since allclose doesn't work with fp8.
+    self.assertTrue(
+        torch.allclose(
+            xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_hlo(self, dtype):
+    x = torch.randn((3, 5)).to(dtype).to(device)
+    w = torch.randn((5, 8)).to(dtype).to(device)
     output = torch.matmul(x, w)
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
-    self.assertTrue(re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None)
+    exmy_str = str(dtype).split('_')[-1]
+    self.assertTrue(
+        re.search(rf'f8{exmy_str}.*dot.*f8{exmy_str}.*f8{exmy_str}', hlo)
+        is not None)
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index 0d7056cc702..58f7497fe20 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -9,6 +9,8 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
   switch (xla_type) {
     case xla::PrimitiveType::BF16:
       return at::ScalarType::BFloat16;
+    case xla::PrimitiveType::F8E4M3FN:
+      return at::ScalarType::Float8_e4m3fn;
     case xla::PrimitiveType::F8E5M2:
       return at::ScalarType::Float8_e5m2;
     case xla::PrimitiveType::F16:
@@ -51,6 +53,8 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
       return xla::PrimitiveType::BF16;
     case at::ScalarType::Half:
       return xla::PrimitiveType::F16;
+    case at::ScalarType::Float8_e4m3fn:
+      return xla::PrimitiveType::F8E4M3FN;
     case at::ScalarType::Float8_e5m2:
       return xla::PrimitiveType::F8E5M2;
     case at::ScalarType::Bool:
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 119f3e870c6..8bc2d9ab142 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -63,6 +63,22 @@ struct Caster {
   }
 };
 
+template <>
+struct Caster<at::Float8_e4m3fn> {
+  template <typename D>
+  D cast(const at::Float8_e4m3fn& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
+template <>
+struct Caster<tsl::float8_e4m3fn> {
+  template <typename D>
+  D cast(const tsl::float8_e4m3fn& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
 template <>
 struct Caster {
   template <typename D>
@@ -201,6 +217,15 @@ template <>
 struct NeedCast {
   static constexpr bool value = true;
 };
+
+template <>
+struct NeedCast<at::Float8_e4m3fn> {
+  static constexpr bool value = true;
+};
+template <>
+struct NeedCast<tsl::float8_e4m3fn> {
+  static constexpr bool value = true;
+};
 template <>
 struct NeedCast {
   static constexpr bool value = true;
 };
@@ -274,6 +299,18 @@ void CopyData(tsl::bfloat16* dest,
   CheckedMemcpy(dest, source, n);
 }
 template <>
+void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
+    at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
+    const CopyCasted&) {
+  CheckedMemcpy(dest, source, n);
+}
+template <>
+void CopyData<tsl::float8_e4m3fn, at::Float8_e4m3fn>(
+    tsl::float8_e4m3fn* dest, const at::Float8_e4m3fn* source, int64_t n,
+    const CopyCasted&) {
+  CheckedMemcpy(dest, source, n);
+}
+template <>
 void CopyData(at::Float8_e5m2* dest,
               const tsl::float8_e5m2* source, int64_t n,
               const CopyCasted&) {
   CheckedMemcpy(dest, source, n);
 }
@@ -451,6 +488,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
       TensorToBuffer(tensor, dest_shape, dest_buffer,
                      dest_buffer_size, device);
       break;
+    case xla::PrimitiveType::F8E4M3FN:
+      TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
+                                                dest_buffer_size, device);
+      break;
     case xla::PrimitiveType::F8E5M2:
       TensorToBuffer(tensor, dest_shape, dest_buffer,
                      dest_buffer_size, device);
       break;
@@ -578,6 +619,10 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
                                 dest_element_type);
     case at::ScalarType::Half:
       return XlaLiteralToTensor(literal, dest_element_type);
+    case at::ScalarType::Float8_e4m3fn:
+      return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
+                                                          dest_element_type);
+
     case at::ScalarType::Float8_e5m2:
       return XlaLiteralToTensor(literal,
                                 dest_element_type);
@@ -611,6 +656,11 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
       TensorToBufferSType(tensor, dest_shape, dest_buffer,
                           dest_buffer_size, device);
       break;
+    case at::ScalarType::Float8_e4m3fn:
+      TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
+                                             dest_buffer_size, device);
+      break;
+
     case at::ScalarType::Float8_e5m2:
       TensorToBufferSType(tensor, dest_shape, dest_buffer,
                           dest_buffer_size, device);
       break;
@@ -674,6 +724,10 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
     case xla::PrimitiveType::BF16:
       return XlaLiteralToTensorHelper(literal,
                                       dest_element_type);
+    case xla::PrimitiveType::F8E4M3FN:
+      return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
+                                                          dest_element_type);
+
     case xla::PrimitiveType::F8E5M2:
       return XlaLiteralToTensorHelper(literal,
                                       dest_element_type);