diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 8697f3a1bde..cb8f43fc76c 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -375,4 +375,5 @@ def _mp_fn(index, flags): if __name__ == '__main__': # if running with torchrun, nprocs argument will be omitted. debug_single_process = True if FLAGS.num_cores == 1 else False - torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=True) + torch_xla.launch( + _mp_fn, args=(FLAGS,), debug_single_process=debug_single_process) diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index de2ab4c4926..016715d322a 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -87,6 +87,7 @@ def manual_seed(seed, device=None): xm.set_rng_state(seed, device) +# TODO(wcromar): Update args to type ParamSpec. def launch( fn: Callable, args: Tuple = (),