Dynamic PJRT plugin registration API #5644

Merged (19 commits, Dec 18, 2023)
34 changes: 34 additions & 0 deletions test/pjrt/test_dynamic_plugin_tpu.py
@@ -0,0 +1,34 @@
import concurrent.futures

from absl.testing import absltest
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import plugins
import torch_xla.runtime as xr
from torch_xla._internal import tpu

plugins.register_plugin('TPU', tpu.TpuPlugin())
Collaborator: Are you going to put this in our init file eventually?

Author: Yeah. I'm avoiding any changes to the default behavior while this is WIP.
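
For illustration, a sketch of what such a registration could eventually look like from a vendor package's own `__init__.py`. Only the `torch_xla.experimental.plugins` API exercised in this PR is assumed; the device name, environment variable, and paths below are hypothetical.

```python
# Hypothetical sketch: a vendor package registering its own PJRT plugin at
# import time. The 'MY_ACCEL' device name, the environment variable, and the
# library paths are made up, and the remaining DevicePlugin hooks are assumed
# to have usable base-class defaults.
import os

from torch_xla.experimental import plugins


class MyAcceleratorPlugin(plugins.DevicePlugin):

  def library_path(self):
    # Path to this device's PJRT C API shared library (hypothetical).
    return os.getenv('MY_ACCEL_PJRT_LIBRARY_PATH',
                     '/opt/my_accel/lib/pjrt_c_api_my_accel.so')


plugins.use_dynamic_plugins()
plugins.register_plugin('MY_ACCEL', MyAcceleratorPlugin())
```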

plugins.use_dynamic_plugins()


class TestDynamicTpuPlugin(absltest.TestCase):

@classmethod
def setUpClass(cls):
xr.set_device_type('TPU')

@staticmethod
def _assert_tpus_exist(index=0):
del index
Collaborator: What's this for?

Author: `index` is required by `spawn`, but we don't need it here; I explicitly delete it to mark it as unused.
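
A minimal sketch of that calling convention, independent of this PR: `xmp.spawn` passes each worker its process index as the first positional argument, so a target that does not need it simply discards it.

```python
# Minimal sketch of the xmp.spawn calling convention: the target function
# receives its process index as the first positional argument.
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index):
  del index  # Required by the spawn signature, intentionally unused here.
  print('hello from a spawned process')


if __name__ == '__main__':
  xmp.spawn(_worker)
```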

assert len(xm.get_xla_supported_devices('TPU')) > 0

def test_single_process(self):
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
Collaborator: I guess the difference between test_dynamic_plugin_api and test_spawn is that the former tests single-process execution and the latter tests multiprocessing?

Author: Yeah, I wrote test_dynamic_plugin_api before the other. I'll rename it to something like test_single_process to be clearer.

executor.submit(self._assert_tpus_exist).result()

def test_spawn(self):
xmp.spawn(self._assert_tpus_exist)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -57,6 +57,7 @@ spec:
python3 /src/pytorch/xla/test/test_autocast.py
python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
volumeMounts:
- mountPath: /dev/shm
name: dshm
13 changes: 10 additions & 3 deletions torch_xla/_internal/pjrt.py
@@ -15,6 +15,7 @@
from torch_xla._internal import tpu, gpu, neuron
from torch_xla import runtime
import torch_xla.utils.utils as xu
from torch_xla.experimental import plugins

R = TypeVar('R')

@@ -96,7 +97,9 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
"""
os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')

if runtime.device_type() == 'TPU':
if plugins.using_dynamic_plugins():
plugins.default().configure_single_process()
Collaborator: Should we throw a warning or something when people configure PJRT_DEVICE while also registering the plugin in code?

Author: You still select the device type with PJRT_DEVICE. Plugins will just let you register new device types when we clean up all of the hardcoded strings.
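
A hedged sketch of that split, using only APIs added or exercised in this PR: registration defines how a device type is loaded, while PJRT_DEVICE (or `xr.set_device_type`, as in the test above) still selects it.

```python
# Sketch: registration and selection stay separate. Registering the plugin
# defines how the 'TPU' device type is loaded and configured; PJRT_DEVICE
# still chooses which registered type to use.
from torch_xla._internal import tpu
from torch_xla.experimental import plugins
import torch_xla.runtime as xr

plugins.use_dynamic_plugins()
plugins.register_plugin('TPU', tpu.TpuPlugin())

# Selection is unchanged; this is effectively the same as PJRT_DEVICE=TPU.
xr.set_device_type('TPU')
```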

elif runtime.device_type() == 'TPU':
tpu.configure_one_chip_topology()

xm.set_replication(xm.xla_device(), [])
@@ -109,7 +112,9 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))

if runtime.device_type() == 'TPU':
if plugins.using_dynamic_plugins():
plugins.default().configure_multiprocess(local_rank, local_world_size)
elif runtime.device_type() == 'TPU':
tpu.configure_topology(local_rank, local_world_size)
elif runtime.device_type() == 'NEURON':
neuron.initialize_env(local_rank)
@@ -138,7 +143,9 @@ def run_multiprocess(fn: Callable[..., R],
Dict of the form {device_ordinal: return_value}, where
return_value is the result of calling `fn`.
"""
if runtime.device_type() == 'TPU':
if plugins.using_dynamic_plugins():
num_processes = plugins.default().physical_chip_count()
elif runtime.device_type() == 'TPU':
num_processes = tpu.num_local_processes()
elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
num_processes = gpu.num_local_processes()
24 changes: 24 additions & 0 deletions torch_xla/_internal/tpu.py
@@ -16,6 +16,7 @@
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
from torch_xla.experimental import plugins

_GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1'
_ACCELERATOR_TYPE_TO_HOST_BOUNDS = {
@@ -319,3 +320,26 @@ def _spmd_find_master_ip(current_worker_hostname: str) -> str:
if proc == 0:
return str(ip_address(ip))
raise RuntimeError('Could not find IP of host running process 0')


class TpuPlugin(plugins.DevicePlugin):

def library_path(self):
libtpu_path = os.getenv('TPU_LIBRARY_PATH') or os.getenv(
'PTXLA_TPU_LIBRARY_PATH')
if not libtpu_path:
raise EnvironmentError('libtpu not found')

return libtpu_path

def host_index(self):
return worker_id()

def configure_single_process(self):
return configure_one_chip_topology()

def configure_multiprocess(self, local_rank, local_world_size):
return configure_topology(local_rank, local_world_size)

def physical_chip_count(self):
return num_available_chips()
1 change: 1 addition & 0 deletions torch_xla/core/xla_env_vars.py
@@ -17,6 +17,7 @@
PJRT_SELECT_DEFAULT_DEVICE = 'PJRT_SELECT_DEFAULT_DEVICE'
PJRT_LOCAL_PROCESS_RANK = 'PJRT_LOCAL_PROCESS_RANK'
PJRT_LOCAL_PROCESS_COUNT = 'PJRT_LOCAL_PROCESS_COUNT'
PJRT_DYNAMIC_PLUGINS = 'PJRT_DYNAMIC_PLUGINS'
TPU_CHIPS_PER_PROCESS_BOUNDS = 'TPU_CHIPS_PER_PROCESS_BOUNDS'
TPU_PROCESS_BOUNDS = 'TPU_PROCESS_BOUNDS'
TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
11 changes: 7 additions & 4 deletions torch_xla/core/xla_model.py
@@ -81,7 +81,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
"""Returns a list of supported devices of a given kind.

Args:
devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`,
devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`,
`NEURON` or `CPU` (the 'GPU' XLA device is currently not implemented).
max_devices (int, optional): The maximum number of devices to be returned of
that kind.
@@ -191,7 +191,7 @@ def xla_device(n=None, devkind=None):
n (int, optional): The specific instance (ordinal) to be returned. If
specified, the specific XLA device instance will be returned. Otherwise
the first device of `devkind` will be returned.
devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
`NEURON`, `ROCM` or `CPU`.

Returns:
@@ -215,7 +215,10 @@ def _xla_real_device(device):
return _DEVICES.value[int(m.group(1))]


def xla_real_devices(devices):
def xla_real_devices(devices: Optional[List[torch.device]] = None):
if not devices:
devices = get_xla_supported_devices()

return [_xla_real_device(device) for device in devices]


@@ -227,7 +230,7 @@ def xla_device_hw(device):
real device.

Returns:
A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
of the given device.
"""
real_device = _xla_real_device(device)
5 changes: 5 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -44,6 +44,7 @@
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/metrics_analysis.h"
#include "torch_xla/csrc/runtime/metrics_reader.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/sys_util.h"
@@ -2317,6 +2318,10 @@ void InitXlaModuleBindings(py::module m) {
return retlist;
});
// -------------Dynamo Integration API End-------------------------
m.def("_register_pjrt_plugin",
[](std::string name, std::string library_path) {
runtime::RegisterPjRtPlugin(name, library_path);
});
}
} // namespace

12 changes: 6 additions & 6 deletions torch_xla/csrc/runtime/BUILD
@@ -83,7 +83,7 @@ cc_library(
":computation_client",
":debug_macros",
":env_vars",
":initialize_pjrt",
":pjrt_registry",
":operation_manager",
":stablehlo_helper",
":tf_logging",
@@ -114,7 +114,7 @@ cc_library(
":debug_macros",
":env_hash",
":env_vars",
":initialize_pjrt",
":pjrt_registry",
":operation_manager",
":profiler",
":stablehlo_helper",
@@ -194,9 +194,9 @@ cc_test(
)

cc_library(
name = "initialize_pjrt",
srcs = ["initialize_pjrt.cc"],
hdrs = ["initialize_pjrt.h"],
name = "pjrt_registry",
srcs = ["pjrt_registry.cc"],
hdrs = ["pjrt_registry.h"],
deps = [
":debug_macros",
":env_hash",
@@ -292,7 +292,7 @@ cc_library(
srcs = ["profiler.cc"],
hdrs = ["profiler.h"],
deps = [
":debug_macros",
":tf_logging",
":profiler_backends",
"@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
"@xla//xla/backends/profiler/plugin:plugin_tracer",
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.cc
@@ -23,6 +23,7 @@ const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK";
const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC";
const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE";
const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION";
const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS";

} // namespace env
} // namespace runtime
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.h
@@ -33,6 +33,7 @@ extern const char* const kEnvPjRtLocalRank;
extern const char* const kEnvPjrtAllocatorCudaAsync;
extern const char* const kEnvPjrtAllocatorPreallocate;
extern const char* const kEnvPjrtAllocatorFraction;
extern const char* const kEnvPjrtDynamicPlugins;

} // namespace env
} // namespace runtime
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -11,7 +11,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_hash.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/initialize_pjrt.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -13,8 +13,8 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_hash.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/initialize_pjrt.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
torch_xla/csrc/runtime/pjrt_registry.cc (renamed from initialize_pjrt.cc)
@@ -14,6 +14,8 @@
namespace torch_xla {
namespace runtime {

std::unordered_map<std::string, std::string> pjrt_plugins_;

namespace {

xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
@@ -33,14 +35,35 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
return allocator_config;
}

std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
auto plugin_path = pjrt_plugins_.find(device_type);
return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
: std::nullopt;
}

} // namespace

void RegisterPjRtPlugin(std::string name, std::string library_path) {
TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
pjrt_plugins_[name] = library_path;
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type) {
std::unique_ptr<xla::PjRtClient> client;
std::unique_ptr<XlaCoordinator> coordinator;

if (device_type == "CPU") {
if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
std::optional<std::string> plugin_path = GetPjRtPluginPath(device_type);
if (plugin_path) {
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
absl::AsciiStrToLower(device_type), *plugin_path);
XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value();
profiler::RegisterProfilerForPlugin(c_api);
}
} else if (device_type == "CPU") {
TF_VLOG(1) << "Initializing PjRt CPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
@@ -121,12 +144,10 @@ InitializePjRt(const std::string& device_type) {
"libneuronpjrt.so"))
.status());
client = std::move(xla::GetCApiClient("NEURON").value());
} else {
XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
device_type);
}

XLA_CHECK(client.get() != nullptr);
XLA_CHECK(client) << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
device_type);

return {std::move(client), std::move(coordinator)};
}
torch_xla/csrc/runtime/pjrt_registry.h (renamed from initialize_pjrt.h)
@@ -6,10 +6,12 @@
namespace torch_xla {
namespace runtime {

void RegisterPjRtPlugin(std::string name, std::string library_path);

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type);

}
} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_INITIALIZE_PJRT_H_
6 changes: 4 additions & 2 deletions torch_xla/csrc/runtime/profiler.cc
@@ -1,7 +1,7 @@
#include "torch_xla/csrc/runtime/profiler.h"

#include "absl/container/flat_hash_map.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "tsl/platform/status.h"
#include "tsl/profiler/lib/profiler_factory.h"
#include "tsl/profiler/rpc/client/capture_profile.h"
@@ -58,7 +58,9 @@ tsl::Status Trace(

void RegisterProfilerForPlugin(const PJRT_Api* c_api) {
const PLUGIN_Profiler_Api* profiler_api = FindProfilerApi(c_api);
XLA_CHECK(profiler_api);
if (!profiler_api) {
TF_LOG(WARNING) << "Profiler API not found for PJRT plugin";
}

tsl::profiler::ProfilerFactory create_func =
[profiler_api](const tensorflow::ProfileOptions& options) {