diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 24a16d84e5a..1c71e1be46e 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -319,7 +319,8 @@ def test_all_to_all_single(self, use_dynamo): self._all_to_all_single, use_dynamo=use_dynamo) expected = torch.arange( tpu.num_expected_global_devices(), dtype=torch.float) - # Note: all_to_all xla op does not honor the order of the all_to_all. + # Note: the XLA AllToAll op does not guarantee the ordering of all_to_all + # results, so each rank's output may be permuted. for _, val in results.items(): self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))