Support fp8e5m2 dtype (#7740)
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
lsy323 and Siyuan Liu committed Aug 6, 2024
1 parent dd04c58 commit 1ed2626
Showing 5 changed files with 108 additions and 0 deletions.
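At a glance, this change lets a torch.float8_e5m2 tensor be moved to an XLA device and used in computation. A minimal sketch of the enabled workflow, assuming an XLA device (e.g. a TPU) is attached and a PyTorch build that ships the float8 dtypes; tensor shapes are illustrative:

import torch
import torch_xla

device = torch_xla.device()
# Build e5m2 tensors on the host, then move them to the XLA device.
x = torch.randn(4, 8).to(torch.float8_e5m2).to(device)
w = torch.randn(8, 16).to(torch.float8_e5m2).to(device)
y = torch.matmul(x, w)
# Upcast before inspecting; many ops (printing, comparisons) are limited for fp8.
print(y.cpu().to(torch.float32))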
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -184,6 +184,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_test "$CDIR/test_data_type.py"
run_test "$CDIR/test_fp8.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
51 changes: 51 additions & 0 deletions test/test_fp8.py
@@ -0,0 +1,51 @@
import os
import re
import sys

import torch
import torch_xla
import unittest


class Fp8Test(unittest.TestCase):

  def test_fp8(self):
    device = torch_xla.device()
    fp8_types = [torch.float8_e5m2]
    for dtype in fp8_types:
      t = torch.rand(2, 2).to(dtype)
      xla_t = t.to(device)
      torch_t = xla_t.cpu()
      self.assertEqual(xla_t.dtype, dtype)
      self.assertEqual(torch_t.dtype, dtype)
      # Need to cast to float32 since allclose doesn't work with fp8.
      self.assertTrue(
          torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))

  def test_fp8_matmul(self):
    device = torch_xla.device()
    fp8_types = [torch.float8_e5m2]
    for dtype in fp8_types:
      t = torch.rand(3, 2).to(dtype)
      w = torch.rand(2, 5).to(dtype)
      torch_matmul = torch.matmul(t, w)
      xla_t = t.to(device)
      xla_w = w.to(device)
      xla_matmul = torch.matmul(xla_t, xla_w)
      xla_matmul = xla_matmul.cpu()
      # Need to cast to float32 since allclose doesn't work with fp8.
      self.assertTrue(
          torch.allclose(
              xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))

  def test_fp8_hlo(self):
    device = torch_xla.device()
    x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
    w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
    output = torch.matmul(x, w)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
    self.assertTrue(re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
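The tests upcast to float32 before comparing because comparison kernels such as allclose are not implemented for fp8 (per the comments above) and the format itself is very coarse. A quick way to see the e5m2 characteristics (1 sign, 5 exponent, 2 mantissa bits), assuming a PyTorch build that ships the float8 dtypes:

import torch

info = torch.finfo(torch.float8_e5m2)
# Expected roughly: bits=8, max=57344.0, min=-57344.0, eps=0.25
print(info.bits, info.max, info.min, info.eps)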
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -15,6 +15,7 @@ python3 test/spmd/test_fsdp_v2.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/test_fp8.py
python3 test/test_grad_checkpoint.py
python3 test/dynamo/test_dynamo.py
python3 test/dynamo/test_dynamo_dynamic_shape.py
4 changes: 4 additions & 0 deletions torch_xla/csrc/dtype.cpp
@@ -9,6 +9,8 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
  switch (xla_type) {
    case xla::PrimitiveType::BF16:
      return at::ScalarType::BFloat16;
    case xla::PrimitiveType::F8E5M2:
      return at::ScalarType::Float8_e5m2;
    case xla::PrimitiveType::F16:
      return at::ScalarType::Half;
    case xla::PrimitiveType::F32:
@@ -49,6 +51,8 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
      return xla::PrimitiveType::BF16;
    case at::ScalarType::Half:
      return xla::PrimitiveType::F16;
    case at::ScalarType::Float8_e5m2:
      return xla::PrimitiveType::F8E5M2;
    case at::ScalarType::Bool:
      return xla::PrimitiveType::PRED;
    case at::ScalarType::Byte:
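The two switch additions above map at::ScalarType::Float8_e5m2 to xla::PrimitiveType::F8E5M2 and back, so the dtype survives the Torch-to-XLA boundary. From Python, the mapping shows up in the lowered HLO; a sketch along the lines of test_fp8_hlo above, assuming an XLA device is attached:

import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(3, 5).to(torch.float8_e5m2).to(device)
w = torch.randn(5, 8).to(torch.float8_e5m2).to(device)
out = torch.matmul(x, w)
# The dot and its operands should carry the f8e5m2 element type.
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))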
51 changes: 51 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
@@ -62,6 +62,23 @@ struct Caster<tsl::bfloat16> {
    return static_cast<D>(static_cast<float>(value));
  }
};

template <>
struct Caster<at::Float8_e5m2> {
  template <typename D>
  D cast(const at::Float8_e5m2& value) const {
    return static_cast<D>(static_cast<float>(value));
  }
};

template <>
struct Caster<tsl::float8_e5m2> {
  template <typename D>
  D cast(const tsl::float8_e5m2& value) const {
    return static_cast<D>(static_cast<float>(value));
  }
};

template <>
struct Caster<at::Half> {
  template <typename D>
@@ -185,6 +202,14 @@ struct NeedCast<at::BFloat16> {
  static constexpr bool value = true;
};
template <>
struct NeedCast<tsl::float8_e5m2> {
  static constexpr bool value = true;
};
template <>
struct NeedCast<at::Float8_e5m2> {
  static constexpr bool value = true;
};
template <>
struct NeedCast<xla::half> {
  static constexpr bool value = true;
};
@@ -248,6 +273,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
                                           int64_t n, const CopyCasted&) {
  CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
}
template <>
void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
                                                 const tsl::float8_e5m2* source,
                                                 int64_t n, const CopyCasted&) {
  CheckedMemcpy<at::Float8_e5m2, tsl::float8_e5m2>(dest, source, n);
}
template <>
void CopyData<tsl::float8_e5m2, at::Float8_e5m2>(tsl::float8_e5m2* dest,
                                                 const at::Float8_e5m2* source,
                                                 int64_t n, const CopyCasted&) {
  CheckedMemcpy<tsl::float8_e5m2, at::Float8_e5m2>(dest, source, n);
}

std::vector<int64_t> GetIterationDimensions(const xla::Shape& shape) {
  // We want to favor the most minor dimension as core iteration dimension, as
@@ -414,6 +451,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
      TensorToBuffer<SType, double>(tensor, dest_shape, dest_buffer,
                                    dest_buffer_size, device);
      break;
    case xla::PrimitiveType::F8E5M2:
      TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
                                              dest_buffer_size, device);
      break;
    case xla::PrimitiveType::PRED:
      TensorToBuffer<SType, bool>(tensor, dest_shape, dest_buffer,
                                  dest_buffer_size, device);
@@ -537,6 +578,9 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
                                                     dest_element_type);
    case at::ScalarType::Half:
      return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
    case at::ScalarType::Float8_e5m2:
      return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
                                                        dest_element_type);
    case at::ScalarType::ComplexFloat:
      return XlaLiteralToTensor<SType, c10::complex<float>>(literal,
                                                            dest_element_type);
@@ -567,6 +611,10 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
      TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
                                        dest_buffer_size, device);
      break;
    case at::ScalarType::Float8_e5m2:
      TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
                                           dest_buffer_size, device);
      break;
    case at::ScalarType::Half:
      TensorToBufferSType<at::Half>(tensor, dest_shape, dest_buffer,
                                    dest_buffer_size, device);
@@ -626,6 +674,9 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
    case xla::PrimitiveType::BF16:
      return XlaLiteralToTensorHelper<tsl::bfloat16>(literal,
                                                     dest_element_type);
    case xla::PrimitiveType::F8E5M2:
      return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
                                                        dest_element_type);
    case xla::PrimitiveType::F16:
      return XlaLiteralToTensorHelper<xla::half>(literal, dest_element_type);
    case xla::PrimitiveType::F32:
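The tensor_util.cpp additions supply the host-side plumbing: Caster and NeedCast specializations for converting between at::Float8_e5m2 and tsl::float8_e5m2, CopyData specializations that memcpy between the two same-layout representations, and F8E5M2 branches for writing tensors into XLA literals and reading them back. A round-trip sketch that exercises those paths (values compared after upcasting, as in the tests above; assumes an XLA device is attached):

import torch
import torch_xla

device = torch_xla.device()
t = torch.rand(2, 2).to(torch.float8_e5m2)
xla_t = t.to(device)   # host tensor -> device data (TensorToBuffer path)
back = xla_t.cpu()     # device literal -> host tensor (XlaLiteralToTensor path)
assert back.dtype == torch.float8_e5m2
assert torch.allclose(t.to(torch.float32), back.to(torch.float32))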
