diff --git a/test/test_bf16_autocast.py b/test/test_bf16_autocast.py
new file mode 100644
index 00000000000..d5facd802cd
--- /dev/null
+++ b/test/test_bf16_autocast.py
@@ -0,0 +1,29 @@
+import os
+import re
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import unittest
+
+device = xm.xla_device()
+
+
+class TestAutocastXla(unittest.TestCase):
+
+  def test_cross_entropy_loss(self):
+    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
+    target = torch.randn(16, 10).to(torch.bfloat16).to(device)
+    with torch.autocast("xla"):
+      loss = torch.nn.CrossEntropyLoss()(data, target)
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
+      self.assertTrue(
+          re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None)
+
+      self.assertTrue(
+          re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None)
+
+      self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp
index b0a8ddb745a..b151d6460c0 100644
--- a/torch_xla/csrc/autocast_mode.cpp
+++ b/torch_xla/csrc/autocast_mode.cpp
@@ -95,7 +95,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
   KERNEL_XLA(hinge_embedding_loss, fp32)
   // KERNEL_XLA(poisson_nll_loss, fp32)
   KERNEL_XLA(smooth_l1_loss, fp32)
-  // KERNEL_XLA(cross_entropy_loss, fp32)
+  KERNEL_XLA(cross_entropy_loss, fp32)
   KERNEL_XLA(l1_loss, fp32)
   // KERNEL_XLA(huber_loss, fp32)
   KERNEL_XLA(margin_ranking_loss, fp32)
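
Note (not part of the patch): un-commenting KERNEL_XLA(cross_entropy_loss, fp32) registers cross_entropy_loss under the fp32 cast policy for the AutocastXLA dispatch key, so autocast upcasts bf16 inputs to float32 before running the op. Below is a minimal sketch of the user-visible effect, assuming a working torch_xla install with an XLA device; the torch.nn.functional.cross_entropy call and the integer labels are illustrative choices, not taken from the patch:

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# bf16 logits plus integer class labels, a common mixed-precision setup.
logits = torch.randn(16, 10, dtype=torch.bfloat16, device=device)
labels = torch.randint(0, 10, (16,), device=device)

with torch.autocast("xla"):
  loss = F.cross_entropy(logits, labels)

# Under the fp32 cast policy the op runs and returns in float32, so the
# loss dtype is no longer bfloat16 (assumed behavior of the fp32 policy).
assert loss.dtype == torch.float32

This fp32 upcast is also what the new test's HLO regex checks assert: the convert ops to f32 and the f32 exponential and log ops in the lowered computation.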