Dynamic PJRT plugin registration API #5644
@@ -0,0 +1,34 @@
import concurrent.futures

from absl.testing import absltest
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import plugins
import torch_xla.runtime as xr
from torch_xla._internal import tpu

plugins.register_plugin('TPU', tpu.TpuPlugin())
plugins.use_dynamic_plugins()


class TestDynamicTpuPlugin(absltest.TestCase):

  @classmethod
  def setUpClass(cls):
    xr.set_device_type('TPU')

  @staticmethod
  def _assert_tpus_exist(index=0):
    del index
Review comment: what's this for?
    assert len(xm.get_xla_supported_devices('TPU')) > 0

  def test_single_process(self):
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
Review comment: I guess the difference between test_dynamic_plugin_api and test_spawn is that the former tests single-processing and the latter tests multi-processing?
Reply: Yeah, I wrote …
      executor.submit(self._assert_tpus_exist).result()

  def test_spawn(self):
    xmp.spawn(self._assert_tpus_exist)


if __name__ == '__main__':
  absltest.main()
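The `tpu.TpuPlugin` registered above is not part of this diff. Judging only by the methods the runtime calls on `plugins.default()` in the changes below, a plugin object plausibly needs no more than the following shape; this class, its names, and its return values are illustrative assumptions, not the actual `TpuPlugin`:

class ExampleDevicePlugin:
  """Hypothetical sketch of an object accepted by plugins.register_plugin()."""

  def physical_chip_count(self) -> int:
    # run_multiprocess uses this to decide how many local processes to spawn.
    return 1

  def configure_single_process(self) -> None:
    # Called by _run_singleprocess before the XLA device is initialized.
    pass

  def configure_multiprocess(self, local_rank: int, local_world_size: int) -> None:
    # Called once per spawned process by initialize_multiprocess.
    pass

Registering such a plugin would mirror the test above: `plugins.register_plugin(...)`, then `plugins.use_dynamic_plugins()`, then `xr.set_device_type(...)`.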
@@ -15,6 +15,7 @@
 from torch_xla._internal import tpu, gpu, neuron
 from torch_xla import runtime
 import torch_xla.utils.utils as xu
+from torch_xla.experimental import plugins
 
 R = TypeVar('R')
 
@@ -96,7 +97,9 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
   """
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')
 
-  if runtime.device_type() == 'TPU':
+  if plugins.using_dynamic_plugins():
+    plugins.default().configure_single_process()
Review comment: Should we throw a warning or something when people configure …
Reply: You still select the device type with …
+  elif runtime.device_type() == 'TPU':
     tpu.configure_one_chip_topology()
 
   xm.set_replication(xm.xla_device(), [])
@@ -109,7 +112,9 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))
 
-  if runtime.device_type() == 'TPU':
+  if plugins.using_dynamic_plugins():
+    plugins.default().configure_multiprocess(local_rank, local_world_size)
+  elif runtime.device_type() == 'TPU':
     tpu.configure_topology(local_rank, local_world_size)
   elif runtime.device_type() == 'NEURON':
     neuron.initialize_env(local_rank)
@@ -138,7 +143,9 @@ def run_multiprocess(fn: Callable[..., R],
     Dict of the form {device_ordinal: return_value}, where
     return_value is the result of calling `fn`.
   """
-  if runtime.device_type() == 'TPU':
+  if plugins.using_dynamic_plugins():
+    num_processes = plugins.default().physical_chip_count()
+  elif runtime.device_type() == 'TPU':
     num_processes = tpu.num_local_processes()
   elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     num_processes = gpu.num_local_processes()
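The dispatch added above touches only a small surface of `torch_xla.experimental.plugins`: `using_dynamic_plugins()`, `default()`, and the `register_plugin()`/`use_dynamic_plugins()` calls from the test. A minimal registry along these lines would be enough to back that surface; the internals here are an assumption for illustration and may differ from the module added in this PR:

from torch_xla import runtime

# Hypothetical plugin registry; not the actual implementation from this PR.
_plugin_registry = {}
_use_dynamic_plugins = False

def register_plugin(name, plugin):
  # Associate a device type name (e.g. 'TPU') with a plugin instance.
  _plugin_registry[name.upper()] = plugin

def use_dynamic_plugins():
  # Opt in to the dynamic code paths in the runtime bootstrapping above.
  global _use_dynamic_plugins
  _use_dynamic_plugins = True

def using_dynamic_plugins():
  return _use_dynamic_plugins

def default():
  # Look up the plugin registered for the currently selected device type.
  return _plugin_registry[runtime.device_type()]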
Review comment: Are you going to put this in our init file eventually?
Reply: Yeah. I'm avoiding any changes to the default behavior while this is WIP.
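For context, registering the built-in plugin from the package `__init__` (rather than from the test, as in this PR) could look roughly like the snippet below; this is purely a sketch of the plan discussed here, not code from this change:

# Hypothetical future wiring in the torch_xla package __init__; not part of this PR.
from torch_xla._internal import tpu
from torch_xla.experimental import plugins

plugins.register_plugin('TPU', tpu.TpuPlugin())  # same registration as the new test
plugins.use_dynamic_plugins()  # would change default behavior, hence deferred while WIP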