From 9c9cbf56616c9cbb5d02e6656bb2ad0267bef2df Mon Sep 17 00:00:00 2001 From: Avi Singhal Date: Wed, 4 Sep 2024 17:21:51 +0000 Subject: [PATCH 1/3] uncommented cross entropy loss for xla autocast --- torch_xla/csrc/autocast_mode.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 91ff7999e10..193721db1e3 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -92,7 +92,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { KERNEL_XLA(hinge_embedding_loss, fp32) // KERNEL_XLA(poisson_nll_loss, fp32) KERNEL_XLA(smooth_l1_loss, fp32) - // KERNEL_XLA(cross_entropy_loss, fp32) + KERNEL_XLA(cross_entropy_loss, fp32) KERNEL_XLA(l1_loss, fp32) // KERNEL_XLA(huber_loss, fp32) KERNEL_XLA(margin_ranking_loss, fp32) From 2635e98e638aa807feb66fcec63f981ae830c86d Mon Sep 17 00:00:00 2001 From: Avi Singhal Date: Wed, 11 Sep 2024 23:31:36 +0000 Subject: [PATCH 2/3] added unit test for cross entropy loss autocast fp32 --- test/test_bf16_autocast.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 test/test_bf16_autocast.py diff --git a/test/test_bf16_autocast.py b/test/test_bf16_autocast.py new file mode 100644 index 00000000000..00d4d3482ff --- /dev/null +++ b/test/test_bf16_autocast.py @@ -0,0 +1,30 @@ +import os +import re +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import unittest + +device = xm.xla_device() + + +class TestAutocastXla(unittest.TestCase): + def test_cross_entropy_loss(self): + data = torch.randn(16, 10).to(torch.bfloat16).to(device) + target = torch.randn(16, 10).to(torch.bfloat16).to(device) + with torch.autocast("xla"): + loss = torch.nn.CrossEntropyLoss()(data, target) + hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) + self.assertTrue( + re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None + ) + + self.assertTrue( + re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None + ) + + self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None) + + +if __name__ == "__main__": + unittest.main() From 330158e93ae43b8295b3932ecf5e5e4dd5084ec7 Mon Sep 17 00:00:00 2001 From: Avi Singhal Date: Wed, 11 Sep 2024 23:40:20 +0000 Subject: [PATCH 3/3] fixed formatting --- test/test_bf16_autocast.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/test/test_bf16_autocast.py b/test/test_bf16_autocast.py index 00d4d3482ff..d5facd802cd 100644 --- a/test/test_bf16_autocast.py +++ b/test/test_bf16_autocast.py @@ -9,22 +9,21 @@ class TestAutocastXla(unittest.TestCase): - def test_cross_entropy_loss(self): - data = torch.randn(16, 10).to(torch.bfloat16).to(device) - target = torch.randn(16, 10).to(torch.bfloat16).to(device) - with torch.autocast("xla"): - loss = torch.nn.CrossEntropyLoss()(data, target) - hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) - self.assertTrue( - re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None - ) - self.assertTrue( - re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None - ) + def test_cross_entropy_loss(self): + data = torch.randn(16, 10).to(torch.bfloat16).to(device) + target = torch.randn(16, 10).to(torch.bfloat16).to(device) + with torch.autocast("xla"): + loss = torch.nn.CrossEntropyLoss()(data, target) + hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) + self.assertTrue( + re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None) - self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None) + self.assertTrue( + re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None) + + self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None) if __name__ == "__main__": - unittest.main() + unittest.main()