Dynamic PJRT plugin registration API #5644
Conversation
Heads up @jzhoulon @aws-kingrj, I'm working on a new way for external packages to register PJRT plugins with torch_xla. When this API is finalized, we can move plugin registration (something like …).
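For a concrete sense of the shape this takes, here is a sketch based on the register_plugin API in this PR; the NeuronPlugin class, the module path, and the library path are hypothetical placeholders:

import torch_xla.experimental.plugins as plugins  # module path assumed

# Hypothetical plugin for an external package. library_path() points at
# the package's PJRT plugin shared object.
class NeuronPlugin(plugins.DevicePlugin):
  def library_path(self) -> str:
    return '/opt/mypackage/lib/pjrt_plugin.so'  # hypothetical path

plugins.register_plugin('NEURON', NeuronPlugin())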
Leaving this as draft for now until I rebase after #5677, but this PR is largely ready for comments.
assert len(xm.get_xla_supported_devices('TPU')) > 0

def test_dynamic_plugin_api(self):
  with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
I guess the difference between test_dynamic_plugin_api and test_spawn is that the former tests single-processing and the latter tests multi-processing?
Yeah, I wrote test_dynamic_plugin_api before the other. I'll change the name to something like test_single_process to be clearer.
def register_plugin(name: str, device_plugin: DevicePlugin):
  _plugin_registry[name.upper()] = device_plugin
  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
I wonder what library_path we should use for GPU. IIUC, GPU doesn't involve the libtpu library.
None right now. GPU support is statically linked in. When that moves to a plugin (say libsegpu.so), it will be the path to that binary.
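For comparison, a rough sketch of how the TPU plugin might expose its binary; the lookup below is an illustrative assumption, not this PR's exact code:

import os

class TpuPlugin(plugins.DevicePlugin):
  def library_path(self) -> str:
    # Resolve the libtpu shared object; the env var fallback here is an
    # assumption for illustration only.
    return os.environ.get('TPU_LIBRARY_PATH', '/lib/libtpu.so')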
@staticmethod
def _assert_tpus_exist(index=0):
  del index
What's this for?
index is required for spawn, but we don't need it. I just explicitly delete it to mark it unused.
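For context, xmp.spawn always calls the target as fn(index, ...), so the parameter has to exist even when the body never uses it; a minimal sketch:

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # spawn passes each process its index as the first argument; deleting it
  # marks it as intentionally unused.
  del index

if __name__ == '__main__':
  xmp.spawn(_mp_fn)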
@@ -96,7 +97,9 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
  """
  os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')

  if runtime.device_type() == 'TPU':
    if plugins.using_dynamic_plugins():
      plugins.default().configure_single_process()
Should we throw a warning or something when people configure PJRT_DEVICE while also registering the plugin in code?
You still select the device type with PJRT_DEVICE. Plugins will just let you register new device types when we clean up all of the hardcoded strings.
@@ -53,6 +53,21 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) {
  return std::stoi(device.substr(pos + 1));
}

std::unordered_map<std::string, std::string> pjrt_plugins_;
I thought the user can only register one plugin? What's the use case for registering multiple?
We'll register TPU and GPU as default options, and then other packages will add plugins on top of those. JAX is also using Python entry points to register available plugins automatically, which we may also want to do.
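As a sketch of what automatic registration via entry points could look like (the 'torch_xla.plugins' group name is hypothetical, and entry_points(group=...) needs Python 3.10+):

from importlib.metadata import entry_points

import torch_xla.experimental.plugins as plugins  # module path assumed

def register_installed_plugins():
  # Installed packages advertise a DevicePlugin subclass via a Python entry
  # point; load and register each one automatically, as JAX does.
  for ep in entry_points(group='torch_xla.plugins'):  # hypothetical group
    plugins.register_plugin(ep.name, ep.load()())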
import torch_xla.runtime as xr
from torch_xla._internal import tpu

plugins.register_plugin('TPU', tpu.TpuPlugin())
Are you going to put this in our init file eventually?
Yeah. I'm avoiding any changes to the default behavior while this is WIP.
LGTM! Thanks.
First pass at implementing a common API for device plugins. The eventual goal is to remove any cases where we have to hard-code the device type in our build, allowing truly dynamic plugins through the PJRT plugin API.
- DevicePlugin API, including a sample implementation for TPU
- Enable dynamic plugins with PJRT_DYNAMIC_PLUGINS=1 or plugins.use_dynamic_plugins
- Register a plugin with plugin.register_plugin and enable plugins; you can then use the same device by name by setting PJRT_DEVICE (see the integration test in this PR for an example, and the sketch below)

Future work:

- Remove hardcoded XlaDeviceType so plugins don't have to register their device strings in this repository
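Putting the pieces together, usage looks roughly like this (condensed from the integration test mentioned above; module paths assumed):

import os

import torch_xla.core.xla_model as xm
import torch_xla.experimental.plugins as plugins
from torch_xla._internal import tpu

os.environ['PJRT_DYNAMIC_PLUGINS'] = '1'  # or: plugins.use_dynamic_plugins()
plugins.register_plugin('TPU', tpu.TpuPlugin())

# Select the registered device type by name, same as before.
os.environ['PJRT_DEVICE'] = 'TPU'
assert len(xm.get_xla_supported_devices('TPU')) > 0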