From b3f0ea18f67260455b4a816596f6464b1a48f682 Mon Sep 17 00:00:00 2001 From: zpcore Date: Tue, 16 Jul 2024 20:43:43 +0000 Subject: [PATCH 1/2] update debug_single_process to false --- test/test_train_mp_imagenet.py | 4 ++-- torch_xla/torch_xla.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 8697f3a1bde..3fd0776f65c 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -374,5 +374,5 @@ def _mp_fn(index, flags): if __name__ == '__main__': # if running with torchrun, nprocs argument will be omitted. - debug_single_process = True if FLAGS.num_cores == 1 else False - torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=True) + do_debug = True if FLAGS.num_cores == 1 else False + torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=do_debug) diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index de2ab4c4926..016715d322a 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -87,6 +87,7 @@ def manual_seed(seed, device=None): xm.set_rng_state(seed, device) +# TODO(wcromar): Update args to type ParamSpec. def launch( fn: Callable, args: Tuple = (), From 3f58a187b5f5d8657f298e270fea5a6f9e73867f Mon Sep 17 00:00:00 2001 From: zpcore Date: Tue, 16 Jul 2024 23:19:07 +0000 Subject: [PATCH 2/2] update naming --- test/test_train_mp_imagenet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 3fd0776f65c..cb8f43fc76c 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -374,5 +374,6 @@ def _mp_fn(index, flags): if __name__ == '__main__': # if running with torchrun, nprocs argument will be omitted. - do_debug = True if FLAGS.num_cores == 1 else False - torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=do_debug) + debug_single_process = True if FLAGS.num_cores == 1 else False + torch_xla.launch( + _mp_fn, args=(FLAGS,), debug_single_process=debug_single_process)