Dynamic PJRT plugin registration API #5644

Merged: 19 commits into master, Dec 18, 2023

Conversation

@will-cromar (Collaborator) commented on Sep 25, 2023

First pass at implementing a common API for device plugins. The eventual goal is to remove any cases where we have to hard-code the device type in our build, allowing truly dynamic plugins through the PJRT plugin API.

  • Add DevicePlugin API, including a sample implementation for TPU
  • Enable dynamic plugins with PJRT_DYNAMIC_PLUGINS=1 or plugins.use_dynamic_plugins
  • Add a registration mechanism for PJRT plugins. If you register a valid plugin with plugins.register_plugin and enable dynamic plugins, you can then select that device by name by setting PJRT_DEVICE (see the integration test in this PR for an example, and the sketch after this list).
    • Completely new devices will probably not work yet. We still rely too much on parsing hard-coded strings in the internals.
  • Default behavior doesn't change at all for now.
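
As a rough end-to-end sketch of that flow (illustrative only; it assumes the plugins module is importable as torch_xla.experimental.plugins, which may differ from the final layout in this PR):

# Hedged sketch of the registration flow described above.
import os

from torch_xla.experimental import plugins  # assumed module path
from torch_xla._internal import tpu

plugins.use_dynamic_plugins()      # enable dynamic plugins (or set PJRT_DYNAMIC_PLUGINS=1)
plugins.register_plugin('TPU', tpu.TpuPlugin())
os.environ['PJRT_DEVICE'] = 'TPU'  # devices are still selected by name

The same pattern is what an external package would use for its own device type once the hard-coded strings are cleaned up.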

Future work:

  • Remove or update XlaDeviceType so plugins don't have to register their device strings in this repository
  • Move GPU client into PJRT Plugin
  • Automatically register TPU and GPU plugins and remove hard-coded PJRT client initialization.

@will-cromar changed the title from "[WIP] Dynamic PJRT plugin API" to "[WIP] Dynamic PJRT plugin registration API" on Sep 25, 2023
@will-cromar (Collaborator, Author):

Heads up @jzhoulon @aws-kingrj, I'm working on a new way for external packages to register PJRT plugins with torch_xla. No action is required from you right now. I'll keep Neuron and XPU working within this repository while we develop the idea.

When this API is finalized, we can move plugin registration (something like TpuPlugin in this PR) into your respective packages.

@will-cromar changed the title from "[WIP] Dynamic PJRT plugin registration API" to "Dynamic PJRT plugin registration API" on Dec 14, 2023
@will-cromar (Collaborator, Author):

Leaving this as draft for now until I rebase after #5677, but this PR is largely ready for comments.

    assert len(xm.get_xla_supported_devices('TPU')) > 0

  def test_dynamic_plugin_api(self):
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
Collaborator:

I guess the difference between test_dynamic_plugin_api and test_spawn is that the former tests single-processing and the latter tests multi-processing?

@will-cromar (Collaborator, Author):

Yeah, I wrote test_dynamic_plugin_api before the other one. I'll change the name to something like test_single_process to be clearer.
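
For context, the contrast in question looks roughly like this (a sketch only; the helper and the renamed test are illustrative, not the exact code in this PR):

# Illustrative sketch of the two test shapes being compared.
import concurrent.futures

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _assert_tpus_exist(index=0):
  del index  # spawn passes a process index; unused here
  assert len(xm.get_xla_supported_devices('TPU')) > 0


def test_single_process():
  # Runs the check in a single worker process.
  with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
    executor.submit(_assert_tpus_exist).result()


def test_spawn():
  # Runs the check through the multiprocessing path, one process per replica.
  xmp.spawn(_assert_tpus_exist)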


def register_plugin(name: str, device_plugin: DevicePlugin):
  _plugin_registry[name.upper()] = device_plugin
  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
Collaborator:

I wonder what library_path we should use for GPU. If I understand correctly, GPU doesn't involve the libtpu library.

@will-cromar (Collaborator, Author):

None right now. GPU support is statically linked in. When that moves to a plugin (say libsegpu.so), it will be the path to that binary.
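
For illustration, a plugin that ships its own PJRT binary might look something like this (a hedged sketch; the class and the bundled path are hypothetical, and the .so name just follows the comment above):

# Hypothetical sketch: a plugin whose library_path points at a PJRT C API
# shared object bundled with its own package instead of libtpu.
import os

from torch_xla.experimental import plugins  # assumed module path


class GpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    # Resolve the plugin binary shipped inside this (hypothetical) package.
    return os.path.join(os.path.dirname(__file__), 'lib', 'libsegpu.so')


plugins.register_plugin('GPU', GpuPlugin())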


  @staticmethod
  def _assert_tpus_exist(index=0):
    del index
Collaborator:

What's this for?

@will-cromar (Collaborator, Author):

index is required for spawn, but we don't need it. I just explicitly delete it to mark it unused

@@ -96,7 +97,9 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
   """
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')
 
-  if runtime.device_type() == 'TPU':
+  if plugins.using_dynamic_plugins():
+    plugins.default().configure_single_process()
Collaborator:

Should we throw a warning or something when people configure PJRT_DEVICE while also registering the plugin in code?

@will-cromar (Collaborator, Author):

You still select the device type with PJRT_DEVICE. Plugins will just let you register new device types when we clean up all of the hardcoded strings.
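
In other words, a plausible reading of the interaction (a sketch of the idea, not necessarily this PR's exact helpers):

# Sketch: PJRT_DEVICE still names the device type; the registry only decides
# which DevicePlugin backs that name.
import os

_plugin_registry = {}  # filled by register_plugin, keyed by upper-cased name


def default():
  # Hypothetical version of plugins.default(): return the plugin registered
  # for the currently selected device type.
  return _plugin_registry[os.environ['PJRT_DEVICE'].upper()]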

@@ -53,6 +53,21 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) {
   return std::stoi(device.substr(pos + 1));
 }
 
+std::unordered_map<std::string, std::string> pjrt_plugins_;
Collaborator:

I thought the user can only register one plugin? What's the use case for registering multiple?

@will-cromar (Collaborator, Author):

We'll register TPU and GPU as default options, and then other packages will add plugins on top of those. JAX is also using Python entry points to register available plugins automatically, which we may also want to do.
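
A hedged sketch of the entry-point idea (the group name and the zero-argument factory convention are hypothetical; entry_points(group=...) needs Python 3.10+):

# Sketch: discover and register plugins advertised by installed packages.
from importlib import metadata

from torch_xla.experimental import plugins  # assumed module path


def register_installed_plugins():
  # Each entry point is assumed to resolve to a zero-arg DevicePlugin factory.
  for entry in metadata.entry_points(group='torch_xla.plugins'):
    plugins.register_plugin(entry.name, entry.load()())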

import torch_xla.runtime as xr
from torch_xla._internal import tpu

plugins.register_plugin('TPU', tpu.TpuPlugin())
Collaborator:

Are you going to put this in our init file eventually?

@will-cromar (Collaborator, Author):

Yeah. I'm avoiding any changes to the default behavior while this is WIP.

@will-cromar marked this pull request as ready for review on December 18, 2023 at 19:32
@vanbasten23 (Collaborator) left a comment:

LGTM! Thanks.

@will-cromar merged commit ad14582 into master on Dec 18, 2023
20 checks passed
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Jan 3, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024