From 940bee453fb27a023b360469487af2a8831966d6 Mon Sep 17 00:00:00 2001
From: avizon-aws
Date: Mon, 30 Sep 2024 13:41:01 -0700
Subject: [PATCH] Enable cross entropy loss for xla autocast with FP32
 precision (#7992) (#8094)

---
 test/test_bf16_autocast.py       | 29 +++++++++++++++++++++++++++++
 torch_xla/csrc/autocast_mode.cpp |  2 +-
 2 files changed, 30 insertions(+), 1 deletion(-)
 create mode 100644 test/test_bf16_autocast.py

diff --git a/test/test_bf16_autocast.py b/test/test_bf16_autocast.py
new file mode 100644
index 00000000000..d5facd802cd
--- /dev/null
+++ b/test/test_bf16_autocast.py
@@ -0,0 +1,29 @@
+import os
+import re
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import unittest
+
+device = xm.xla_device()
+
+
+class TestAutocastXla(unittest.TestCase):
+
+  def test_cross_entropy_loss(self):
+    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
+    target = torch.randn(16, 10).to(torch.bfloat16).to(device)
+    with torch.autocast("xla"):
+      loss = torch.nn.CrossEntropyLoss()(data, target)
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
+      self.assertTrue(
+          re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None)
+
+      self.assertTrue(
+          re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None)
+
+      self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp
index b0a8ddb745a..b151d6460c0 100644
--- a/torch_xla/csrc/autocast_mode.cpp
+++ b/torch_xla/csrc/autocast_mode.cpp
@@ -95,7 +95,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
   KERNEL_XLA(hinge_embedding_loss, fp32)
   // KERNEL_XLA(poisson_nll_loss, fp32)
   KERNEL_XLA(smooth_l1_loss, fp32)
-  // KERNEL_XLA(cross_entropy_loss, fp32)
+  KERNEL_XLA(cross_entropy_loss, fp32)
   KERNEL_XLA(l1_loss, fp32)
   // KERNEL_XLA(huber_loss, fp32)
   KERNEL_XLA(margin_ranking_loss, fp32)