From 9c38291c87109c0174f138b787f7d77234fdafe0 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Mon, 25 Sep 2023 18:35:53 +0000
Subject: [PATCH 01/19] Dynamic PJRT plugin API

---
 torch_xla/_internal/pjrt.py                  | 20 ++++-------
 torch_xla/_internal/tpu.py                   | 22 ++++++++++++
 torch_xla/core/xla_model.py                  | 11 +++---
 torch_xla/csrc/init_python_bindings.cpp      |  4 +++
 torch_xla/csrc/runtime/computation_client.cc | 11 ++++++
 torch_xla/csrc/runtime/computation_client.h  |  4 +++
 torch_xla/csrc/runtime/initialize_pjrt.cc    | 15 ++++++--
 torch_xla/csrc/runtime/runtime.cc            |  1 +
 torch_xla/experimental/plugins.py            | 36 ++++++++++++++++++++
 torch_xla/runtime.py                         | 10 ++----
 10 files changed, 106 insertions(+), 28 deletions(-)
 create mode 100644 torch_xla/experimental/plugins.py

diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py
index b92bba55679..3cfc5d15a84 100644
--- a/torch_xla/_internal/pjrt.py
+++ b/torch_xla/_internal/pjrt.py
@@ -15,6 +15,7 @@
 from torch_xla._internal import tpu, gpu, neuron
 from torch_xla import runtime
 import torch_xla.utils.utils as xu
+from torch_xla.experimental import plugins
 
 R = TypeVar('R')
 
@@ -96,8 +97,7 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
   """
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')
 
-  if runtime.device_type() == 'TPU':
-    tpu.configure_one_chip_topology()
+  plugins.default().configure_single_process()
 
   xm.set_replication(xm.xla_device(), [])
 
@@ -109,10 +109,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))
 
-  if runtime.device_type() == 'TPU':
-    tpu.configure_topology(local_rank, local_world_size)
-  elif runtime.device_type() == 'NEURON':
-    neuron.initialize_env(local_rank)
+  plugins.default().configure_multiprocess(local_rank, local_world_size)
 
   devices = xm.get_xla_supported_devices()
   xm.set_replication(xm.xla_device(), devices)
@@ -138,14 +135,7 @@ def run_multiprocess(fn: Callable[..., R],
     Dict of the form {device_ordinal: return_value}, where return_value is the
     result of calling `fn`.
""" - if runtime.device_type() == 'TPU': - num_processes = tpu.num_local_processes() - elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'): - num_processes = gpu.num_local_processes() - elif runtime.device_type() == 'NEURON': - num_processes = neuron.num_local_processes() - else: - num_processes = 1 + num_processes = plugins.default().physical_chip_count() with concurrent.futures.ProcessPoolExecutor( max_workers=num_processes, @@ -161,6 +151,8 @@ def run_multiprocess(fn: Callable[..., R], itertools.chain.from_iterable( result.items() for result in process_results)) + plugins.default().shutdown() + return _merge_replica_results(replica_results) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 385566b1d35..2cc58303b2c 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -16,6 +16,8 @@ import torch_xla.utils.utils as xu import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm +from torch_xla.experimental import plugins + _GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1' _ACCELERATOR_TYPE_TO_HOST_BOUNDS = { @@ -319,3 +321,23 @@ def _spmd_find_master_ip(current_worker_hostname: str) -> str: if proc == 0: return str(ip_address(ip)) raise RuntimeError('Could not find IP of host running process 0') + + +class TpuPlugin(plugins.DevicePlugin): + def library_path(self): + return os.getenv('TPU_LIBRARY_PATH') + + def host_index(self): + return worker_id() + + def configure_single_process(self): + return configure_one_chip_topology() + + def configure_multiprocess(self, local_rank, local_world_size): + return configure_topology(local_rank, local_world_size) + + def local_process_count(self): + return num_available_chips() + + +# plugins.register_plugin('tpu', TpuPlugin()) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 268b590b2e3..2f031f94370 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -81,7 +81,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None): """Returns a list of supported devices of a given kind. Args: - devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`, + devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`, `NEURON` or `CPU` (the 'GPU' XLA device is currently not implemented). max_devices (int, optional): The maximum number of devices to be returned of that kind. @@ -191,7 +191,7 @@ def xla_device(n=None, devkind=None): n (int, optional): The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of `devkind` will be returned. - devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU` + devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU` `NEURON`, `ROCM` or `CPU`. Returns: @@ -215,7 +215,10 @@ def _xla_real_device(device): return _DEVICES.value[int(m.group(1))] -def xla_real_devices(devices): +def xla_real_devices(devices: Optional[List[torch.device]] = None): + if not devices: + devices = get_xla_supported_devices() + return [_xla_real_device(device) for device in devices] @@ -227,7 +230,7 @@ def xla_device_hw(device): real device. Returns: - A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`) + A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`) of the given device. 
""" real_device = _xla_real_device(device) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 03c5fa95ce1..5a166d34993 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2317,6 +2317,10 @@ void InitXlaModuleBindings(py::module m) { return retlist; }); // -------------Dynamo Integration API End------------------------- + m.def("_register_pjrt_plugin", + [](std::string name, std::string library_path) { + runtime::ComputationClient::RegisterPjRtPlugin(name, library_path); + }); } } // namespace diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc index b2feb2e25dc..c9eceec543c 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -53,6 +53,17 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) { return std::stoi(device.substr(pos + 1)); } +std::unordered_map pjrt_plugins_; + +void ComputationClient::RegisterPjRtPlugin(std::string name, std::string library_path) { + TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; + pjrt_plugins_[name] = library_path; +} + +std::unordered_map& ComputationClient::GetPjRtPlugins() { + return pjrt_plugins_; +} + metrics::Metric* ComputationClient::TransferToServerMetric() { static metrics::Metric* metric = new metrics::Metric("TransferToServerTime", metrics::MetricFnTime); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 97f633a39f4..04c21f52a8b 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -374,6 +374,10 @@ class ComputationClient { // after the last ':' character of the device string. 
static int64_t GetDeviceOrdinal(const std::string& device); + static void RegisterPjRtPlugin(std::string name, std::string library_path); + + static std::unordered_map& GetPjRtPlugins(); + protected: static constexpr auto spmd_device_str = "SPMD:0"; diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 4e5f6ba7a1d..abb5e1a5261 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -122,11 +122,20 @@ InitializePjRt(const std::string& device_type) { .status()); client = std::move(xla::GetCApiClient("NEURON").value()); } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); + auto plugins = ComputationClient::GetPjRtPlugins(); + if (plugins.find(device_type) != plugins.end()) { + TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; + XLA_CHECK_OK(pjrt::LoadPjrtPlugin( + device_type, plugins[device_type])); + tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); + XLA_CHECK_OK(init_status); + client_ = std::move(xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); + } } - XLA_CHECK(client.get() != nullptr); + + XLA_CHECK(client.get() != nullptr) << + absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); return {std::move(client), std::move(coordinator)}; } diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index feb2a0844c6..de7c47759a5 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -5,6 +5,7 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/tf_logging.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py new file mode 100644 index 00000000000..ac6b84b9558 --- /dev/null +++ b/torch_xla/experimental/plugins.py @@ -0,0 +1,36 @@ +import torch_xla +import torch_xla.runtime as xr + +class DevicePlugin: + """Base class for device plugings. + + Default implementations assume a single device and local process. 
+ """ + + def library_path(self) -> str: + raise NotImplementedError() + + def host_index(self) -> int: + return 0 + + def configure_single_process(self): + raise NotImplementedError() + + def configure_multiprocess(self, local_rank, local_world_size): + pass + + def physical_chip_count(): + return 1 + + def shutdown(): + pass + + +_plugin_registry = {} + +def default() -> DevicePlugin: + return _plugin_registry[xr.device_type()] + +def register_plugin(name: str, device_plugin: DevicePlugin): + _plugin_registry[name.upper()] = device_plugin + torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path()) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 30cc9aae6d9..7734820ff46 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -10,6 +10,7 @@ import torch_xla.core.xla_model as xm import torch_xla.utils.utils as xu import torch_xla._internal.tpu as tpu +from torch_xla.experimental import plugins R = TypeVar('R') FN = TypeVar('FN') @@ -59,7 +60,7 @@ def device_type() -> Optional[str]: """ _maybe_select_default_device() pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str) - return pjrt_device.split('_')[0] if pjrt_device else pjrt_device + return pjrt_device.split('_')[0].upper() if pjrt_device else pjrt_device def using_pjrt() -> bool: @@ -198,12 +199,7 @@ def process_count() -> int: @requires_pjrt def host_index() -> int: - if device_type() == 'TPU': - return tpu.worker_id() - - # TODO: Update this when we support multi-host GPU - return 0 - + plugins.default.host_index() # API below will be used to query physcial device attribute. @requires_pjrt From 5936932175425b62ed2b074d450914b41b205d6c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 26 Sep 2023 22:33:05 +0000 Subject: [PATCH 02/19] docstrings --- torch_xla/experimental/plugins.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index ac6b84b9558..13967585e34 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -8,21 +8,36 @@ class DevicePlugin: """ def library_path(self) -> str: + """Path to PJRT plugin binary.""" raise NotImplementedError() def host_index(self) -> int: + """Index of the current host.""" return 0 def configure_single_process(self): + """Configure this process to run with world_size 1 for debugging.""" raise NotImplementedError() def configure_multiprocess(self, local_rank, local_world_size): + """Configure device topology for running in a multiprocess context. + + This is called when processes are being initialized by `xmp.spawn` or + `torchrun`. Typically, each process should be assigned a different physical + device from the host. + """ pass def physical_chip_count(): + """The number of physical chips available on this host. + + This is the number of processes we expect to be created by `xmp.spawn` or + for `torchrun`. 
+ """ return 1 def shutdown(): + """Performs any necessary cleanup for this device.""" pass From 2c4dcd427511fe22140ed42d9bb996a52b652458 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 5 Dec 2023 19:40:30 +0000 Subject: [PATCH 03/19] formatting --- torch_xla/csrc/runtime/computation_client.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc index c9eceec543c..ea78c3c4adf 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -55,12 +55,14 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) { std::unordered_map pjrt_plugins_; -void ComputationClient::RegisterPjRtPlugin(std::string name, std::string library_path) { +void ComputationClient::RegisterPjRtPlugin(std::string name, + std::string library_path) { TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; pjrt_plugins_[name] = library_path; } -std::unordered_map& ComputationClient::GetPjRtPlugins() { +std::unordered_map& +ComputationClient::GetPjRtPlugins() { return pjrt_plugins_; } From db4825c8a878f92dea23ce2504702d171cb0a2d6 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 11 Dec 2023 18:31:10 +0000 Subject: [PATCH 04/19] fix tpu plugin --- torch_xla/__init__.py | 4 ++++ torch_xla/_internal/tpu.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 7191b5d5bb9..fa2f102582b 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -160,3 +160,7 @@ def _init_xla_lazy_backend(): torch._dynamo.config.automatic_dynamic_shapes = False from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo + +from .experimental import plugins + +plugins.register_plugin('tpu', tpu.TpuPlugin()) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 2cc58303b2c..7945002e632 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -325,7 +325,11 @@ def _spmd_find_master_ip(current_worker_hostname: str) -> str: class TpuPlugin(plugins.DevicePlugin): def library_path(self): - return os.getenv('TPU_LIBRARY_PATH') + libtpu_path = os.getenv('TPU_LIBRARY_PATH') or os.getenv('PTXLA_TPU_LIBRARY_PATH') + if not libtpu_path: + raise EnvironmentError('libtpu not found') + + return libtpu_path def host_index(self): return worker_id() From faaa6bebfb3740d479b4dc4dd803fd40387a6665 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 19:38:34 +0000 Subject: [PATCH 05/19] expose library path directly --- torch_xla/csrc/runtime/computation_client.cc | 8 +++++--- torch_xla/csrc/runtime/computation_client.h | 3 ++- torch_xla/csrc/runtime/initialize_pjrt.cc | 12 +++++++----- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc index ea78c3c4adf..aa55a1eaeaa 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -61,9 +61,11 @@ void ComputationClient::RegisterPjRtPlugin(std::string name, pjrt_plugins_[name] = library_path; } -std::unordered_map& -ComputationClient::GetPjRtPlugins() { - return pjrt_plugins_; +std::optional ComputationClient::GetPjRtPluginPath( + const std::string& device_type) { + auto plugin_path = pjrt_plugins_.find(device_type); + return plugin_path != pjrt_plugins_.end() ? 
std::optional(plugin_path->second) + : std::nullopt; } metrics::Metric* ComputationClient::TransferToServerMetric() { diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 04c21f52a8b..1213d1fef97 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -376,7 +376,8 @@ class ComputationClient { static void RegisterPjRtPlugin(std::string name, std::string library_path); - static std::unordered_map& GetPjRtPlugins(); + static std::optional GetPjRtPluginPath( + const std::string& device_type); protected: static constexpr auto spmd_device_str = "SPMD:0"; diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index abb5e1a5261..5454729f101 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -122,14 +122,16 @@ InitializePjRt(const std::string& device_type) { .status()); client = std::move(xla::GetCApiClient("NEURON").value()); } else { - auto plugins = ComputationClient::GetPjRtPlugins(); - if (plugins.find(device_type) != plugins.end()) { + std::optional plugin_path = + ComputationClient::GetPjRtPluginPath( + absl::AsciiStrToLower(device_type)); + if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - device_type, plugins[device_type])); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin(device_type, *plugin_path).status()); tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); XLA_CHECK_OK(init_status); - client_ = std::move(xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); + client_ = std::move( + xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); } } From a6686a4e8bf3df8025545ab2169c55dc1760ca05 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 19:52:20 +0000 Subject: [PATCH 06/19] don't register TPU plugin as a default --- torch_xla/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index fa2f102582b..7191b5d5bb9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -160,7 +160,3 @@ def _init_xla_lazy_backend(): torch._dynamo.config.automatic_dynamic_shapes = False from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo - -from .experimental import plugins - -plugins.register_plugin('tpu', tpu.TpuPlugin()) From 664f11f054098de8f9bec8cb3c970e91cec2b470 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 19:52:29 +0000 Subject: [PATCH 07/19] unit test --- test/pjrt/test_dynamic_plugin_tpu.py | 24 +++++++++++++++++++++++ torch_xla/csrc/runtime/initialize_pjrt.cc | 5 ++--- 2 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 test/pjrt/test_dynamic_plugin_tpu.py diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py new file mode 100644 index 00000000000..f8f7f13cc60 --- /dev/null +++ b/test/pjrt/test_dynamic_plugin_tpu.py @@ -0,0 +1,24 @@ +import os + +from absl.testing import absltest +import torch_xla.core.xla_model as xm +from torch_xla.experimental import plugins +import torch_xla.runtime as xr +from torch_xla._internal import tpu + + +class TestDynamicTpuPlugin(absltest.TestCase): + @classmethod + def setUpClass(xls): + # TODO python API + os.environ['XLA_DYNAMIC_PLUGINS'] = '1' + + # HACK: use lower case "tpu" so we don't collide with default libtpu case + xr.set_device_type('tpu') + plugins.register_plugin('tpu', 
tpu.TpuPlugin()) + + def test_dynamic_plugin_api(self): + self.assertNotEmpty(xm.get_xla_supported_devices('TPU')) + +if __name__ == '__main__': + absltest.main() diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 5454729f101..36cb69528f2 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -123,11 +123,10 @@ InitializePjRt(const std::string& device_type) { client = std::move(xla::GetCApiClient("NEURON").value()); } else { std::optional plugin_path = - ComputationClient::GetPjRtPluginPath( - absl::AsciiStrToLower(device_type)); + ComputationClient::GetPjRtPluginPath(device_type); if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin(device_type, *plugin_path).status()); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path).status()); tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); XLA_CHECK_OK(init_status); client_ = std::move( From 98d45f96b6a8b04ad61feb60734520364b375c9e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 21:12:06 +0000 Subject: [PATCH 08/19] Add switch for dynamic plugins --- test/pjrt/test_dynamic_plugin_tpu.py | 7 +++---- torch_xla/core/xla_env_vars.py | 1 + torch_xla/csrc/runtime/env_vars.cc | 1 + torch_xla/csrc/runtime/env_vars.h | 1 + torch_xla/csrc/runtime/initialize_pjrt.cc | 13 ++++++++++++- torch_xla/experimental/plugins.py | 11 +++++++++++ 6 files changed, 29 insertions(+), 5 deletions(-) diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py index f8f7f13cc60..55b2ac07db5 100644 --- a/test/pjrt/test_dynamic_plugin_tpu.py +++ b/test/pjrt/test_dynamic_plugin_tpu.py @@ -10,12 +10,11 @@ class TestDynamicTpuPlugin(absltest.TestCase): @classmethod def setUpClass(xls): - # TODO python API - os.environ['XLA_DYNAMIC_PLUGINS'] = '1' + plugins.use_dynamic_plugins() # HACK: use lower case "tpu" so we don't collide with default libtpu case - xr.set_device_type('tpu') - plugins.register_plugin('tpu', tpu.TpuPlugin()) + xr.set_device_type('TPU') + plugins.register_plugin('TPU', tpu.TpuPlugin()) def test_dynamic_plugin_api(self): self.assertNotEmpty(xm.get_xla_supported_devices('TPU')) diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py index eb79ff14310..2d256c77a54 100644 --- a/torch_xla/core/xla_env_vars.py +++ b/torch_xla/core/xla_env_vars.py @@ -17,6 +17,7 @@ PJRT_SELECT_DEFAULT_DEVICE = 'PJRT_SELECT_DEFAULT_DEVICE' PJRT_LOCAL_PROCESS_RANK = 'PJRT_LOCAL_PROCESS_RANK' PJRT_LOCAL_PROCESS_COUNT = 'PJRT_LOCAL_PROCESS_COUNT' +PJRT_DYNAMIC_PLUGINS = 'PJRT_DYNAMIC_PLUGINS' TPU_CHIPS_PER_PROCESS_BOUNDS = 'TPU_CHIPS_PER_PROCESS_BOUNDS' TPU_PROCESS_BOUNDS = 'TPU_PROCESS_BOUNDS' TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES' diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index 733574a4818..f774d578ca6 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -23,6 +23,7 @@ const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK"; const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC"; const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE"; const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION"; +const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS"; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/env_vars.h 
b/torch_xla/csrc/runtime/env_vars.h index e7e1ef81964..3affac2031e 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -33,6 +33,7 @@ extern const char* const kEnvPjRtLocalRank; extern const char* const kEnvPjrtAllocatorCudaAsync; extern const char* const kEnvPjrtAllocatorPreallocate; extern const char* const kEnvPjrtAllocatorFraction; +extern const char* const kEnvPjrtDynamicPlugins; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 36cb69528f2..86dad4ca924 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -40,7 +40,18 @@ InitializePjRt(const std::string& device_type) { std::unique_ptr client; std::unique_ptr coordinator; - if (device_type == "CPU") { + if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) { + std::optional plugin_path = + ComputationClient::GetPjRtPluginPath(device_type); + if (plugin_path) { + TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; + XLA_CHECK_OK(pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path).status()); + tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); + XLA_CHECK_OK(init_status); + client_ = std::move( + xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); + } + } else if (device_type == "CPU") { TF_VLOG(1) << "Initializing PjRt CPU client..."; bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 13967585e34..9b074c53ab5 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -1,4 +1,7 @@ +import os + import torch_xla +import torch_xla.core.xla_env_vars as xenv import torch_xla.runtime as xr class DevicePlugin: @@ -43,6 +46,14 @@ def shutdown(): _plugin_registry = {} +def use_dynamic_plugins(): + if torch_xla._XLAC._xla_runtime_is_initialized() and os.environ.get( + xenv.PJRT_DEVICE) != "1": + raise RuntimeError( + "Can't enable dynamic plugins after XLA runtime is initialized") + + os.environ[xenv.PJRT_DYNAMIC_PLUGINS] = "1" + def default() -> DevicePlugin: return _plugin_registry[xr.device_type()] From 5b430da9c26fb2a8305333c3e082e2bf62be32c5 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 21:27:29 +0000 Subject: [PATCH 09/19] fix spawning --- test/pjrt/test_dynamic_plugin_tpu.py | 22 +++++++++++++++------- torch_xla/_internal/tpu.py | 2 +- torch_xla/experimental/plugins.py | 4 ++-- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py index 55b2ac07db5..8cc30fcb2c8 100644 --- a/test/pjrt/test_dynamic_plugin_tpu.py +++ b/test/pjrt/test_dynamic_plugin_tpu.py @@ -1,23 +1,31 @@ -import os +import concurrent.futures from absl.testing import absltest import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp from torch_xla.experimental import plugins import torch_xla.runtime as xr from torch_xla._internal import tpu +plugins.register_plugin('TPU', tpu.TpuPlugin()) +plugins.use_dynamic_plugins() class TestDynamicTpuPlugin(absltest.TestCase): @classmethod - def setUpClass(xls): - plugins.use_dynamic_plugins() - - # HACK: use lower case "tpu" so we don't collide with default libtpu case + def setUpClass(cls): xr.set_device_type('TPU') - 
plugins.register_plugin('TPU', tpu.TpuPlugin()) + + @staticmethod + def _assert_tpus_exist(index = 0): + del index + assert len(xm.get_xla_supported_devices('TPU')) > 0 def test_dynamic_plugin_api(self): - self.assertNotEmpty(xm.get_xla_supported_devices('TPU')) + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + executor.submit(self._assert_tpus_exist).result() + + def test_spawn(self): + xmp.spawn(self._assert_tpus_exist) if __name__ == '__main__': absltest.main() diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 7945002e632..05511382533 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -340,7 +340,7 @@ def configure_single_process(self): def configure_multiprocess(self, local_rank, local_world_size): return configure_topology(local_rank, local_world_size) - def local_process_count(self): + def physical_chip_count(self): return num_available_chips() diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 9b074c53ab5..f69cb92afb4 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -31,7 +31,7 @@ def configure_multiprocess(self, local_rank, local_world_size): """ pass - def physical_chip_count(): + def physical_chip_count(self): """The number of physical chips available on this host. This is the number of processes we expect to be created by `xmp.spawn` or @@ -39,7 +39,7 @@ def physical_chip_count(): """ return 1 - def shutdown(): + def shutdown(self): """Performs any necessary cleanup for this device.""" pass From 3cf42d0f56f6175bbc5064df9be8924ffba69bb6 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 21:54:25 +0000 Subject: [PATCH 10/19] remove comment --- torch_xla/_internal/tpu.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 05511382533..993cd710831 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -342,6 +342,3 @@ def configure_multiprocess(self, local_rank, local_world_size): def physical_chip_count(self): return num_available_chips() - - -# plugins.register_plugin('tpu', TpuPlugin()) From 63f08f586fb9d438b5de66370c5534d07f27e68e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 22:09:38 +0000 Subject: [PATCH 11/19] clean up and unbreak stuff --- torch_xla/_internal/pjrt.py | 25 ++++++++++++++++++++----- torch_xla/csrc/runtime/runtime.cc | 1 - torch_xla/experimental/plugins.py | 8 ++++---- torch_xla/runtime.py | 11 +++++++++-- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 3cfc5d15a84..a44157f7b86 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -97,7 +97,10 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]: """ os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1') - plugins.default().configure_single_process() + if plugins.using_dynamic_plugins(): + plugins.default().configure_single_process() + elif runtime.device_type() == 'TPU': + tpu.configure_one_chip_topology() xm.set_replication(xm.xla_device(), []) @@ -109,7 +112,12 @@ def initialize_multiprocess(local_rank: int, local_world_size: int): os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank)) os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size)) - plugins.default().configure_multiprocess(local_rank, local_world_size) + if plugins.using_dynamic_plugins(): + 
plugins.default().configure_multiprocess(local_rank, local_world_size) + elif runtime.device_type() == 'TPU': + tpu.configure_topology(local_rank, local_world_size) + elif runtime.device_type() == 'NEURON': + neuron.initialize_env(local_rank) devices = xm.get_xla_supported_devices() xm.set_replication(xm.xla_device(), devices) @@ -135,7 +143,16 @@ def run_multiprocess(fn: Callable[..., R], Dict of the form {device_ordinal: return_value}, where return_value is the result of calling `fn`. """ - num_processes = plugins.default().physical_chip_count() + if plugins.using_dynamic_plugins(): + num_processes = plugins.default().physical_chip_count() + elif runtime.device_type() == 'TPU': + num_processes = tpu.num_local_processes() + elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'): + num_processes = gpu.num_local_processes() + elif runtime.device_type() == 'NEURON': + num_processes = neuron.num_local_processes() + else: + num_processes = 1 with concurrent.futures.ProcessPoolExecutor( max_workers=num_processes, @@ -151,8 +168,6 @@ def run_multiprocess(fn: Callable[..., R], itertools.chain.from_iterable( result.items() for result in process_results)) - plugins.default().shutdown() - return _merge_replica_results(replica_results) diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index de7c47759a5..feb2a0844c6 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -5,7 +5,6 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" -#include "torch_xla/csrc/runtime/tf_logging.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index f69cb92afb4..a2d12d3c496 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -3,6 +3,7 @@ import torch_xla import torch_xla.core.xla_env_vars as xenv import torch_xla.runtime as xr +import torch_xla.utils.utils as xu class DevicePlugin: """Base class for device plugings. @@ -39,10 +40,6 @@ def physical_chip_count(self): """ return 1 - def shutdown(self): - """Performs any necessary cleanup for this device.""" - pass - _plugin_registry = {} @@ -54,6 +51,9 @@ def use_dynamic_plugins(): os.environ[xenv.PJRT_DYNAMIC_PLUGINS] = "1" +def using_dynamic_plugins(): + return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool, False) + def default() -> DevicePlugin: return _plugin_registry[xr.device_type()] diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 7734820ff46..bc98a5b4c17 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -60,7 +60,7 @@ def device_type() -> Optional[str]: """ _maybe_select_default_device() pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str) - return pjrt_device.split('_')[0].upper() if pjrt_device else pjrt_device + return pjrt_device.split('_')[0] if pjrt_device else pjrt_device def using_pjrt() -> bool: @@ -199,7 +199,14 @@ def process_count() -> int: @requires_pjrt def host_index() -> int: - plugins.default.host_index() + if plugins.using_dynamic_plugins(): + return plugins.default().host_index() + elif device_type() == 'TPU': + return tpu.worker_id() + + # TODO: Update this when we support multi-host GPU + return 0 + # API below will be used to query physcial device attribute. 
@requires_pjrt From 802b4f9d1921fb05d6276f840387b6bffb98600e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 22:12:18 +0000 Subject: [PATCH 12/19] formatting --- test/pjrt/test_dynamic_plugin_tpu.py | 5 ++++- torch_xla/_internal/tpu.py | 5 +++-- torch_xla/csrc/runtime/initialize_pjrt.cc | 13 ++++++++----- torch_xla/experimental/plugins.py | 5 +++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py index 8cc30fcb2c8..e68b660a064 100644 --- a/test/pjrt/test_dynamic_plugin_tpu.py +++ b/test/pjrt/test_dynamic_plugin_tpu.py @@ -10,13 +10,15 @@ plugins.register_plugin('TPU', tpu.TpuPlugin()) plugins.use_dynamic_plugins() + class TestDynamicTpuPlugin(absltest.TestCase): + @classmethod def setUpClass(cls): xr.set_device_type('TPU') @staticmethod - def _assert_tpus_exist(index = 0): + def _assert_tpus_exist(index=0): del index assert len(xm.get_xla_supported_devices('TPU')) > 0 @@ -27,5 +29,6 @@ def test_dynamic_plugin_api(self): def test_spawn(self): xmp.spawn(self._assert_tpus_exist) + if __name__ == '__main__': absltest.main() diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 993cd710831..c4b3d80a423 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -18,7 +18,6 @@ import torch_xla.core.xla_model as xm from torch_xla.experimental import plugins - _GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1' _ACCELERATOR_TYPE_TO_HOST_BOUNDS = { # v2 @@ -324,8 +323,10 @@ def _spmd_find_master_ip(current_worker_hostname: str) -> str: class TpuPlugin(plugins.DevicePlugin): + def library_path(self): - libtpu_path = os.getenv('TPU_LIBRARY_PATH') or os.getenv('PTXLA_TPU_LIBRARY_PATH') + libtpu_path = os.getenv('TPU_LIBRARY_PATH') or os.getenv( + 'PTXLA_TPU_LIBRARY_PATH') if not libtpu_path: raise EnvironmentError('libtpu not found') diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 86dad4ca924..4ab8e41439e 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -45,7 +45,9 @@ InitializePjRt(const std::string& device_type) { ComputationClient::GetPjRtPluginPath(device_type); if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path).status()); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path) + .status()); tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); XLA_CHECK_OK(init_status); client_ = std::move( @@ -137,7 +139,9 @@ InitializePjRt(const std::string& device_type) { ComputationClient::GetPjRtPluginPath(device_type); if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path).status()); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path) + .status()); tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); XLA_CHECK_OK(init_status); client_ = std::move( @@ -145,9 +149,8 @@ InitializePjRt(const std::string& device_type) { } } - - XLA_CHECK(client.get() != nullptr) << - absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); + XLA_CHECK(client.get() != nullptr) + << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); return {std::move(client), std::move(coordinator)}; } 
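
When dynamic plugins are enabled, the pieces introduced so far fit together as follows: `run_multiprocess` sizes its process pool from the registered plugin's `physical_chip_count()`, `initialize_multiprocess` calls `configure_multiprocess(local_rank, local_world_size)` in each child process, and `_run_singleprocess` calls `configure_single_process()`. The sketch below shows what a plugin for a hypothetical four-chip device might implement to satisfy that contract; the `FourChipPlugin` class, the `MYDEV_VISIBLE_CHIPS` variable, the chip count, and the library path are illustrative assumptions, not part of any real device runtime.

    import os

    from torch_xla.experimental import plugins


    class FourChipPlugin(plugins.DevicePlugin):
      """Hypothetical plugin for a host with four local chips."""

      def library_path(self) -> str:
        # Illustrative path to the device's PJRT C API binary.
        return '/opt/mydev/lib/pjrt_mydev.so'

      def physical_chip_count(self):
        # xmp.spawn / run_multiprocess start one process per chip.
        return 4

      def configure_multiprocess(self, local_rank, local_world_size):
        # Pin each spawned process to a single chip.
        os.environ.setdefault('MYDEV_VISIBLE_CHIPS', str(local_rank))

      def configure_single_process(self):
        # Debug path used by _run_singleprocess: expose only chip 0.
        os.environ.setdefault('MYDEV_VISIBLE_CHIPS', '0')
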
diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index a2d12d3c496..90b0ff2b8b0 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -5,6 +5,7 @@ import torch_xla.runtime as xr import torch_xla.utils.utils as xu + class DevicePlugin: """Base class for device plugings. @@ -43,6 +44,7 @@ def physical_chip_count(self): _plugin_registry = {} + def use_dynamic_plugins(): if torch_xla._XLAC._xla_runtime_is_initialized() and os.environ.get( xenv.PJRT_DEVICE) != "1": @@ -51,12 +53,15 @@ def use_dynamic_plugins(): os.environ[xenv.PJRT_DYNAMIC_PLUGINS] = "1" + def using_dynamic_plugins(): return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool, False) + def default() -> DevicePlugin: return _plugin_registry[xr.device_type()] + def register_plugin(name: str, device_plugin: DevicePlugin): _plugin_registry[name.upper()] = device_plugin torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path()) From 8b2e6e5dda0b5126a5f6601b0fab67f0e1e39da3 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 12 Dec 2023 22:15:51 +0000 Subject: [PATCH 13/19] add test to TPU CI --- test/tpu/xla_test_job.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index e727953ddc4..06fe993b51c 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -57,6 +57,7 @@ spec: python3 /src/pytorch/xla/test/test_autocast.py python3 /src/pytorch/xla/test/dynamo/test_dynamo.py python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py + python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py volumeMounts: - mountPath: /dev/shm name: dshm From e9a4632a9e0d4eedca5d393a8fc3d810b703521b Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 14 Dec 2023 18:17:23 +0000 Subject: [PATCH 14/19] find profiler plugin --- torch_xla/csrc/runtime/BUILD | 2 +- torch_xla/csrc/runtime/initialize_pjrt.cc | 9 ++++----- torch_xla/csrc/runtime/profiler.cc | 6 ++++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 664d26e7674..d8f5a410466 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -292,7 +292,7 @@ cc_library( srcs = ["profiler.cc"], hdrs = ["profiler.h"], deps = [ - ":debug_macros", + ":tf_logging", ":profiler_backends", "@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs", "@xla//xla/backends/profiler/plugin:plugin_tracer", diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 4ab8e41439e..0a53cd54c31 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -45,13 +45,12 @@ InitializePjRt(const std::string& device_type) { ComputationClient::GetPjRtPluginPath(device_type); if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path) - .status()); - tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); - XLA_CHECK_OK(init_status); + const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( + absl::AsciiStrToLower(device_type), *plugin_path); + XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); client_ = std::move( xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); + profiler::RegisterProfilerForPlugin(c_api); } } else if (device_type == "CPU") { TF_VLOG(1) << "Initializing PjRt CPU client..."; diff --git a/torch_xla/csrc/runtime/profiler.cc 
b/torch_xla/csrc/runtime/profiler.cc index a2ea89be16d..a365c1aede3 100644 --- a/torch_xla/csrc/runtime/profiler.cc +++ b/torch_xla/csrc/runtime/profiler.cc @@ -1,7 +1,7 @@ #include "torch_xla/csrc/runtime/profiler.h" #include "absl/container/flat_hash_map.h" -#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/tf_logging.h" #include "tsl/platform/status.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/rpc/client/capture_profile.h" @@ -58,7 +58,9 @@ tsl::Status Trace( void RegisterProfilerForPlugin(const PJRT_Api* c_api) { const PLUGIN_Profiler_Api* profiler_api = FindProfilerApi(c_api); - XLA_CHECK(profiler_api); + if (!profiler_api) { + TF_LOG(WARNING) << "Profiler API not found for PJRT plugin"; + } tsl::profiler::ProfilerFactory create_func = [profiler_api](const tensorflow::ProfileOptions& options) { From 06a1094a8f6f6beed3627e247ea65135ce020760 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 14 Dec 2023 21:53:09 +0000 Subject: [PATCH 15/19] make test case name clearer --- test/pjrt/test_dynamic_plugin_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py index e68b660a064..d218fb0c9af 100644 --- a/test/pjrt/test_dynamic_plugin_tpu.py +++ b/test/pjrt/test_dynamic_plugin_tpu.py @@ -22,7 +22,7 @@ def _assert_tpus_exist(index=0): del index assert len(xm.get_xla_supported_devices('TPU')) > 0 - def test_dynamic_plugin_api(self): + def test_single_process(self): with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: executor.submit(self._assert_tpus_exist).result() From 27a59e8006d18c0bad3d3a5a8811a49cb2fa924c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 19:05:31 +0000 Subject: [PATCH 16/19] fix some merging issues --- torch_xla/csrc/runtime/initialize_pjrt.cc | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index 0a53cd54c31..ab394d5a97c 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -48,8 +48,7 @@ InitializePjRt(const std::string& device_type) { const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( absl::AsciiStrToLower(device_type), *plugin_path); XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); - client_ = std::move( - xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); + client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value(); profiler::RegisterProfilerForPlugin(c_api); } } else if (device_type == "CPU") { @@ -133,22 +132,9 @@ InitializePjRt(const std::string& device_type) { "libneuronpjrt.so")) .status()); client = std::move(xla::GetCApiClient("NEURON").value()); - } else { - std::optional plugin_path = - ComputationClient::GetPjRtPluginPath(device_type); - if (plugin_path) { - TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type), *plugin_path) - .status()); - tsl::Status init_status = pjrt::InitializePjrtPlugin(device_type); - XLA_CHECK_OK(init_status); - client_ = std::move( - xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value()); - } } - XLA_CHECK(client.get() != nullptr) + XLA_CHECK(client) << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); return {std::move(client), std::move(coordinator)}; From eff43449318488ae9609e93a92c4aadedef7c834 Mon Sep 17 00:00:00 2001 
From: Will Cromar Date: Mon, 18 Dec 2023 19:14:08 +0000 Subject: [PATCH 17/19] move PJRT plugin registry to `initialize_pjrt` --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- torch_xla/csrc/runtime/computation_client.cc | 15 --------------- torch_xla/csrc/runtime/computation_client.h | 5 ----- torch_xla/csrc/runtime/initialize_pjrt.cc | 18 ++++++++++++++++-- torch_xla/csrc/runtime/initialize_pjrt.h | 2 ++ 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5a166d34993..466fbeadd35 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -41,6 +41,7 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" @@ -2319,7 +2320,7 @@ void InitXlaModuleBindings(py::module m) { // -------------Dynamo Integration API End------------------------- m.def("_register_pjrt_plugin", [](std::string name, std::string library_path) { - runtime::ComputationClient::RegisterPjRtPlugin(name, library_path); + runtime::RegisterPjRtPlugin(name, library_path); }); } } // namespace diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc index aa55a1eaeaa..b2feb2e25dc 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -53,21 +53,6 @@ int64_t ComputationClient::GetDeviceOrdinal(const std::string& device) { return std::stoi(device.substr(pos + 1)); } -std::unordered_map pjrt_plugins_; - -void ComputationClient::RegisterPjRtPlugin(std::string name, - std::string library_path) { - TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; - pjrt_plugins_[name] = library_path; -} - -std::optional ComputationClient::GetPjRtPluginPath( - const std::string& device_type) { - auto plugin_path = pjrt_plugins_.find(device_type); - return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second) - : std::nullopt; -} - metrics::Metric* ComputationClient::TransferToServerMetric() { static metrics::Metric* metric = new metrics::Metric("TransferToServerTime", metrics::MetricFnTime); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 1213d1fef97..97f633a39f4 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -374,11 +374,6 @@ class ComputationClient { // after the last ':' character of the device string. 
static int64_t GetDeviceOrdinal(const std::string& device); - static void RegisterPjRtPlugin(std::string name, std::string library_path); - - static std::optional GetPjRtPluginPath( - const std::string& device_type); - protected: static constexpr auto spmd_device_str = "SPMD:0"; diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc index ab394d5a97c..6d75b2377b5 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.cc +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -14,6 +14,8 @@ namespace torch_xla { namespace runtime { +std::unordered_map pjrt_plugins_; + namespace { xla::GpuAllocatorConfig GetGpuAllocatorConfig() { @@ -33,16 +35,28 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { return allocator_config; } +std::optional GetPjRtPluginPath( + const std::string& device_type) { + auto plugin_path = pjrt_plugins_.find(device_type); + return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second) + : std::nullopt; +} + } // namespace +void RegisterPjRtPlugin(std::string name, + std::string library_path) { + TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; + pjrt_plugins_[name] = library_path; +} + std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type) { std::unique_ptr client; std::unique_ptr coordinator; if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) { - std::optional plugin_path = - ComputationClient::GetPjRtPluginPath(device_type); + std::optional plugin_path = GetPjRtPluginPath(device_type); if (plugin_path) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h index 012927fe474..79ea9af5419 100644 --- a/torch_xla/csrc/runtime/initialize_pjrt.h +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -6,6 +6,8 @@ namespace torch_xla { namespace runtime { +void RegisterPjRtPlugin(std::string name, std::string library_path); + std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type); From 8a48e90a92d5b9fc5fb1e1f8774816e343aadf05 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 19:22:58 +0000 Subject: [PATCH 18/19] `initialize_pjrt` -> `pjrt_backend` --- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/runtime/BUILD | 10 +++++----- torch_xla/csrc/runtime/ifrt_computation_client.cc | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- .../runtime/{initialize_pjrt.cc => pjrt_registry.cc} | 0 .../runtime/{initialize_pjrt.h => pjrt_registry.h} | 0 6 files changed, 8 insertions(+), 8 deletions(-) rename torch_xla/csrc/runtime/{initialize_pjrt.cc => pjrt_registry.cc} (100%) rename torch_xla/csrc/runtime/{initialize_pjrt.h => pjrt_registry.h} (100%) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 466fbeadd35..edd7c37aaac 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -41,7 +41,7 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/initialize_pjrt.h" +#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 
d8f5a410466..2e28268cb8c 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -83,7 +83,7 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":initialize_pjrt", + ":pjrt_registry", ":operation_manager", ":stablehlo_helper", ":tf_logging", @@ -114,7 +114,7 @@ cc_library( ":debug_macros", ":env_hash", ":env_vars", - ":initialize_pjrt", + ":pjrt_registry", ":operation_manager", ":profiler", ":stablehlo_helper", @@ -194,9 +194,9 @@ cc_test( ) cc_library( - name = "initialize_pjrt", - srcs = ["initialize_pjrt.cc"], - hdrs = ["initialize_pjrt.h"], + name = "pjrt_registry", + srcs = ["pjrt_registry.cc"], + hdrs = ["pjrt_registry.h"], deps = [ ":debug_macros", ":env_hash", diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 605826e6e6a..4fa8790b08c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -11,7 +11,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/initialize_pjrt.h" +#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1ca27518282..88eca78cca1 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -13,7 +13,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/initialize_pjrt.h" +#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/pjrt_registry.cc similarity index 100% rename from torch_xla/csrc/runtime/initialize_pjrt.cc rename to torch_xla/csrc/runtime/pjrt_registry.cc diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/pjrt_registry.h similarity index 100% rename from torch_xla/csrc/runtime/initialize_pjrt.h rename to torch_xla/csrc/runtime/pjrt_registry.h From 867a8f1550c59f77a7cc1dbb34ee57d0451fd883 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 19:30:28 +0000 Subject: [PATCH 19/19] formatting --- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- torch_xla/csrc/runtime/pjrt_registry.cc | 10 ++++------ torch_xla/csrc/runtime/pjrt_registry.h | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index edd7c37aaac..5ebb14514b1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -41,10 +41,10 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" +#include 
"torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 88eca78cca1..5cb0fe158d5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -13,8 +13,8 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_hash.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/operation_manager.h" +#include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tensor_source.h" diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 6d75b2377b5..877fc18ec1a 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -35,8 +35,7 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { return allocator_config; } -std::optional GetPjRtPluginPath( - const std::string& device_type) { +std::optional GetPjRtPluginPath(const std::string& device_type) { auto plugin_path = pjrt_plugins_.find(device_type); return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second) : std::nullopt; @@ -44,8 +43,7 @@ std::optional GetPjRtPluginPath( } // namespace -void RegisterPjRtPlugin(std::string name, - std::string library_path) { +void RegisterPjRtPlugin(std::string name, std::string library_path) { TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; pjrt_plugins_[name] = library_path; } @@ -148,8 +146,8 @@ InitializePjRt(const std::string& device_type) { client = std::move(xla::GetCApiClient("NEURON").value()); } - XLA_CHECK(client) - << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, device_type); + XLA_CHECK(client) << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, + device_type); return {std::move(client), std::move(coordinator)}; } diff --git a/torch_xla/csrc/runtime/pjrt_registry.h b/torch_xla/csrc/runtime/pjrt_registry.h index 79ea9af5419..4cb7b70a661 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.h +++ b/torch_xla/csrc/runtime/pjrt_registry.h @@ -11,7 +11,7 @@ void RegisterPjRtPlugin(std::string name, std::string library_path); std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type); -} +} // namespace runtime } // namespace torch_xla #endif // XLA_CLIENT_INITIALIZE_PJRT_H_