Add fp8e4m3fn support (#7842)
lsy323 committed Aug 15, 2024
1 parent 41bf6da commit 1fb09bc
Showing 3 changed files with 103 additions and 38 deletions.
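Note: this commit extends PyTorch/XLA's float8 coverage from torch.float8_e5m2 to also include torch.float8_e4m3fn. A minimal usage sketch, distilled from the tests added below (it assumes an XLA device is available and that the backend supports fp8 matmul):

import torch
import torch_xla

device = torch_xla.device()

# fp8e4m3fn operands created on the host, then moved to the XLA device.
t = torch.rand(3, 2).to(torch.float8_e4m3fn)
w = torch.rand(2, 5).to(torch.float8_e4m3fn)
xla_out = torch.matmul(t.to(device), w.to(device)).cpu()

# Compare against the CPU result in float32, since torch.allclose has no fp8 kernel.
reference = torch.matmul(t, w)
print(torch.allclose(xla_out.to(torch.float32), reference.to(torch.float32)))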
83 changes: 45 additions & 38 deletions test/test_fp8.py
@@ -4,46 +4,53 @@
import torch
import torch_xla
import unittest


class Fp8Test(unittest.TestCase):

  def test_fp8(self):
    device = torch_xla.device()
    fp8_types = [torch.float8_e5m2]
    for dtype in fp8_types:
      t = torch.rand(2, 2).to(dtype)
      xla_t = t.to(device)
      torch_t = xla_t.cpu()
      self.assertEqual(xla_t.dtype, dtype)
      self.assertEqual(torch_t.dtype, dtype)
      # Need to cast to float32 since allclose doesn't work with fp8.
      self.assertTrue(
          torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))

  def test_fp8_matmul(self):
    device = torch_xla.device()
    fp8_types = [torch.float8_e5m2]
    for dtype in fp8_types:
      t = torch.rand(3, 2).to(dtype)
      w = torch.rand(2, 5).to(dtype)
      torch_matmul = torch.matmul(t, w)
      xla_t = t.to(device)
      xla_w = w.to(device)
      xla_matmul = torch.matmul(xla_t, xla_w)
      xla_matmul = xla_matmul.cpu()
      # Need to cast to float32 since allclose doesn't work with fp8.
      self.assertTrue(
          torch.allclose(
              xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))

  def test_fp8_hlo(self):
    device = torch_xla.device()
    x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
    w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
from absl.testing import parameterized

device = torch_xla.device()

dtype_parameters = [
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]


class Fp8Test(parameterized.TestCase):

  @parameterized.parameters(*dtype_parameters)
  def test_fp8(self, dtype):
    t = torch.rand(2, 2).to(dtype)
    xla_t = t.to(device)
    torch_t = xla_t.cpu()
    self.assertEqual(xla_t.dtype, dtype)
    self.assertEqual(torch_t.dtype, dtype)
    # Need to cast to float32 since allclose doesn't work with fp8.
    self.assertTrue(
        torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))

  @parameterized.parameters(*dtype_parameters)
  def test_fp8_matmul(self, dtype):
    t = torch.rand(3, 2).to(dtype)
    w = torch.rand(2, 5).to(dtype)
    torch_matmul = torch.matmul(t, w)
    xla_t = t.to(device)
    xla_w = w.to(device)
    xla_matmul = torch.matmul(xla_t, xla_w)
    xla_matmul = xla_matmul.cpu()
    # Need to cast to float32 since allclose doesn't work with fp8.
    self.assertTrue(
        torch.allclose(
            xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))

  @parameterized.parameters(*dtype_parameters)
  def test_fp8_hlo(self, dtype):
    x = torch.randn((3, 5)).to(dtype).to(device)
    w = torch.randn((5, 8)).to(dtype).to(device)
    output = torch.matmul(x, w)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
    self.assertTrue(re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None)
    exmy_str = str(dtype).split('_')[-1]
    self.assertTrue(
        re.search(rf'f8{exmy_str}.*dot.*f8{exmy_str}.*f8{exmy_str}', hlo)
        is not None)


if __name__ == '__main__':
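A note on the regex used in test_fp8_hlo above: the string form of a torch fp8 dtype ends with its exponent/mantissa encoding, and XLA's primitive type names reuse the same suffix, so the expected HLO type name can be derived from the dtype itself. A small illustration (plain Python, no XLA device needed):

import torch

for dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
  # str(dtype) is e.g. 'torch.float8_e4m3fn'; the suffix after the last
  # underscore ('e4m3fn') matches XLA's primitive type name 'f8e4m3fn'.
  exmy_str = str(dtype).split('_')[-1]
  print(dtype, '->', f'f8{exmy_str}')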
4 changes: 4 additions & 0 deletions torch_xla/csrc/dtype.cpp
@@ -9,6 +9,8 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
switch (xla_type) {
case xla::PrimitiveType::BF16:
return at::ScalarType::BFloat16;
case xla::PrimitiveType::F8E4M3FN:
return at::ScalarType::Float8_e4m3fn;
case xla::PrimitiveType::F8E5M2:
return at::ScalarType::Float8_e5m2;
case xla::PrimitiveType::F16:
@@ -51,6 +53,8 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
return xla::PrimitiveType::BF16;
case at::ScalarType::Half:
return xla::PrimitiveType::F16;
case at::ScalarType::Float8_e4m3fn:
return xla::PrimitiveType::F8E4M3FN;
case at::ScalarType::Float8_e5m2:
return xla::PrimitiveType::F8E5M2;
case at::ScalarType::Bool:
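The two new cases above make the torch/XLA dtype mapping symmetric for float8_e4m3fn: TorchTypeFromXlaType maps F8E4M3FN back to at::ScalarType::Float8_e4m3fn, and XlaTypeFromTorchType does the reverse. From Python this shows up as the dtype surviving a device round trip; a small sketch mirroring test_fp8:

import torch
import torch_xla

device = torch_xla.device()

t = torch.rand(2, 2).to(torch.float8_e4m3fn)
xla_t = t.to(device)  # host -> device goes through the torch-to-XLA dtype mapping
back = xla_t.cpu()    # device -> host goes through the XLA-to-torch dtype mapping
print(xla_t.dtype, back.dtype)  # both should report torch.float8_e4m3fn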
54 changes: 54 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
@@ -63,6 +63,22 @@ struct Caster<tsl::bfloat16> {
}
};

template <>
struct Caster<at::Float8_e4m3fn> {
template <typename D>
D cast(const at::Float8_e4m3fn& value) const {
return static_cast<D>(static_cast<float>(value));
}
};

template <>
struct Caster<tsl::float8_e4m3fn> {
template <typename D>
D cast(const tsl::float8_e4m3fn& value) const {
return static_cast<D>(static_cast<float>(value));
}
};

template <>
struct Caster<at::Float8_e5m2> {
template <typename D>
@@ -201,6 +217,15 @@ template <>
struct NeedCast<at::BFloat16> {
static constexpr bool value = true;
};

template <>
struct NeedCast<tsl::float8_e4m3fn> {
static constexpr bool value = true;
};
template <>
struct NeedCast<at::Float8_e4m3fn> {
static constexpr bool value = true;
};
template <>
struct NeedCast<tsl::float8_e5m2> {
static constexpr bool value = true;
@@ -274,6 +299,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
}
template <>
void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
const CopyCasted&) {
CheckedMemcpy<at::Float8_e4m3fn, tsl::float8_e4m3fn>(dest, source, n);
}
template <>
void CopyData<tsl::float8_e4m3fn, at::Float8_e4m3fn>(
tsl::float8_e4m3fn* dest, const at::Float8_e4m3fn* source, int64_t n,
const CopyCasted&) {
CheckedMemcpy<tsl::float8_e4m3fn, at::Float8_e4m3fn>(dest, source, n);
}
template <>
void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
const tsl::float8_e5m2* source,
int64_t n, const CopyCasted&) {
@@ -451,6 +488,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
TensorToBuffer<SType, double>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F8E4M3FN:
TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F8E5M2:
TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
@@ -578,6 +619,10 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
dest_element_type);
case at::ScalarType::Half:
return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
case at::ScalarType::Float8_e4m3fn:
return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
dest_element_type);

case at::ScalarType::Float8_e5m2:
return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
dest_element_type);
@@ -611,6 +656,11 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case at::ScalarType::Float8_e4m3fn:
TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;

case at::ScalarType::Float8_e5m2:
TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
@@ -674,6 +724,10 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
case xla::PrimitiveType::BF16:
return XlaLiteralToTensorHelper<tsl::bfloat16>(literal,
dest_element_type);
case xla::PrimitiveType::F8E4M3FN:
return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
dest_element_type);

case xla::PrimitiveType::F8E5M2:
return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
dest_element_type);
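The tensor_util.cpp changes above wire float8_e4m3fn into the host/device data paths: the Caster, NeedCast, and CopyData specializations handle element conversion and raw copies, and the new switch cases route TensorToBuffer and XlaLiteralToTensor through tsl::float8_e4m3fn. A quick way to exercise those paths end to end is a round trip with a value check; a sketch, again comparing in float32 because torch.allclose does not accept fp8 inputs:

import torch
import torch_xla

device = torch_xla.device()

t = torch.rand(2, 2).to(torch.float8_e4m3fn)
round_tripped = t.to(device).cpu()  # host -> device -> host copy

# Values should round-trip unchanged; compare in float32.
print(torch.allclose(t.to(torch.float32), round_tripped.to(torch.float32)))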
