Register custom call with CUDA plugin (#7219)
bhavya01 committed Jun 12, 2024
1 parent c65e87b commit 86d8967
Showing 9 changed files with 103 additions and 14 deletions.
.github/workflows/_build_plugin.yml: 2 changes (1 addition, 1 deletion)
@@ -39,7 +39,7 @@ jobs:
shell: bash
run: |
cd pytorch/xla/infra/ansible
ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5,8.6 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
- name: Upload wheel
uses: actions/upload-artifact@v4
with:
.github/workflows/_build_torch_with_cuda.yml: 5 changes (1 addition, 4 deletions)
@@ -20,7 +20,6 @@ jobs:
runs-on: ${{ inputs.runner }}
container:
image: ${{ inputs.dev-image }}
options: "--gpus all --shm-size 16g"
env:
_GLIBCXX_USE_CXX11_ABI: 0
steps:
@@ -34,8 +33,6 @@
run: |
echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
- name: Check GPU
run: nvidia-smi
- name: Checkout PyTorch Repo
uses: actions/checkout@v4
with:
@@ -47,7 +44,7 @@
shell: bash
run: |
cd pytorch
USE_CUDA=1 python setup.py bdist_wheel
TORCH_CUDA_ARCH_LIST="5.2;8.6" USE_CUDA=1 python setup.py bdist_wheel
- name: Upload wheel
uses: actions/upload-artifact@v4
with:
.github/workflows/_test_requiring_torch_cuda.yml: 26 changes (24 additions, 2 deletions)
@@ -30,10 +30,17 @@ on:

jobs:
test:
runs-on: ${{ inputs.runner }}
container:
image: ${{ inputs.dev-image }}
options: "--gpus all --shm-size 16g"
strategy:
matrix:
include:
- run_python_tests: 'python_tests'
runner: ${{ inputs.runner }}
- run_triton_tests: 'triton_tests'
runner: 'linux.g5.4xlarge.nvidia.gpu'
runs-on: ${{ matrix.runner }}
timeout-minutes: ${{ inputs.timeout-minutes }}
env:
USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }}
@@ -98,9 +105,24 @@ jobs:
uses: actions/checkout@v4
with:
path: pytorch/xla
- name: Test
- name: Extra CI deps
shell: bash
run: |
set -x
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --no-deps triton==2.3.0
if: ${{ matrix.run_triton_tests }}
- name: Python Tests
shell: bash
run: |
set -xue
PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py -v
PJRT_DEVICE=CUDA python pytorch/xla/test/dynamo/test_dynamo.py -v
if: ${{ matrix.run_python_tests }}
- name: Triton Tests
shell: bash
run: |
PJRT_DEVICE=CUDA TRITON_PTXAS_PATH=/usr/local/cuda-12.1/bin/ptxas python pytorch/xla/test/test_triton.py
if: ${{ matrix.run_triton_tests }}
.github/workflows/build_and_test.yml: 2 changes (1 addition, 1 deletion)
@@ -42,7 +42,7 @@ jobs:
# note that to build a torch wheel with CUDA enabled, we do not need a GPU runner.
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
runner: linux.8xlarge.nvidia.gpu
runner: linux.24xlarge

build-cuda-plugin:
name: "Build XLA CUDA plugin"
torch_xla/csrc/BUILD: 4 changes (4 additions, 0 deletions)
@@ -269,6 +269,7 @@ ptxla_cc_library(
":tensor",
":version",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:pjrt_computation_client",
"//torch_xla/csrc/runtime:metrics",
"//torch_xla/csrc/runtime:metrics_analysis",
"//torch_xla/csrc/runtime:metrics_reader",
@@ -290,6 +291,9 @@ ptxla_cc_library(
"@xla//xla/service:sharding_propagation",
"@xla//xla/service/spmd:spmd_partitioner",
"@xla//xla/service:custom_call_target_registry",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_wrapper_impl",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
],
)

torch_xla/csrc/init_python_bindings.cpp: 64 changes (58 additions, 6 deletions)
@@ -45,9 +45,11 @@
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/metrics_analysis.h"
#include "torch_xla/csrc/runtime/metrics_reader.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_registry.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/runtime.h"
@@ -67,7 +69,10 @@
#include "torch_xla/csrc/xla_sharding_util.h"
#include "tsl/platform/env.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/python/profiler/internal/traceme_wrapper.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/hlo_parser.h"
@@ -2464,12 +2469,59 @@ void InitXlaModuleBindings(py::module m) {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/false);
});
m.def("_xla_register_custom_call_target",
[](const std::string& fn_name, const py::capsule& function_ptr,
const std::string& platform) {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
fn_name, function_ptr.get_pointer(), platform);
});
m.def("_xla_register_custom_call_target", [](const std::string& fn_name,
const py::capsule& function_ptr,
const std::string& platform) {
if (runtime::sys_util::GetEnvBool("XLA_USE_IFRT", false) ||
platform != "CUDA") {
XLA_ERROR() << "Custom call targets can only be registered for "
"PJRT CUDA runtime."
<< std::endl;
return;
}
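// Two registration paths follow: with a dynamically loaded PJRT plugin the
// target is registered through the PJRT C API's GPU custom-call extension;
// otherwise it is added to the in-process XLA custom call registry via
// XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM.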
if (runtime::sys_util::GetEnvBool(runtime::env::kEnvPjrtDynamicPlugins,
false)) {
runtime::PjRtComputationClient* client =
dynamic_cast<runtime::PjRtComputationClient*>(
runtime::GetComputationClient());
if (!client) {
return;
}
const PJRT_Api* pjrt_api = client->GetPjRtCApiIfAvailable();
if (!pjrt_api) {
return;
}
// See openxla reference:
// https://github.com/openxla/xla/blob/b604c8d87df842002a7a8de79a434026329fbcb2/xla/pjrt/c/pjrt_c_api_gpu_test.cc#L414
const PJRT_Extension_Base* next =
reinterpret_cast<const PJRT_Extension_Base*>(
pjrt_api->extension_start);
while (next != nullptr &&
next->type !=
PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
next = next->next;
}
if (next == nullptr) {
return;
}
PJRT_Gpu_Register_Custom_Call_Args args;
args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
args.function_name = fn_name.c_str();
args.function_name_size = fn_name.size();
args.api_version = 0;
args.custom_call_function =
reinterpret_cast<void*>(function_ptr.get_pointer());
PJRT_Error* error =
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(
&args);
if (error) {
XLA_ERROR() << error->status << std::endl;
}
} else {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
fn_name, function_ptr.get_pointer(), platform);
}
});
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
size_t max_call_stack_depth) -> bool {
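For reference, the binding added above is exposed on the torch_xla._XLAC extension module with the signature (fn_name, function_ptr, platform). A minimal usage sketch follows; load_custom_call_capsule() is a hypothetical helper that returns a PyCapsule wrapping the device-side custom-call function pointer, since this commit only defines the registration binding itself.

# Hypothetical usage sketch (not part of this commit).
# load_custom_call_capsule() is an assumed helper returning a PyCapsule that
# wraps the GPU custom call's function pointer.
import torch_xla

capsule = load_custom_call_capsule()
torch_xla._XLAC._xla_register_custom_call_target("my_gpu_kernel", capsule, "CUDA")

On the dynamic-plugin path this forwards to the PJRT GPU custom-call extension shown above; otherwise it falls back to XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM.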
torch_xla/csrc/runtime/BUILD: 1 change (1 addition, 0 deletions)
@@ -127,6 +127,7 @@ cc_library(
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/distributed",
],
)
torch_xla/csrc/runtime/pjrt_computation_client.cc: 10 changes (10 additions, 0 deletions)
@@ -25,6 +25,7 @@
#include "xla/client/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/protobuf_util.h"
@@ -985,5 +986,14 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
};
}

const PJRT_Api* PjRtComputationClient::GetPjRtCApiIfAvailable() const {
// dynamic_cast will return a nullptr if the client is not PjRtCApiClient.
auto* c_api_client = dynamic_cast<xla::PjRtCApiClient*>(client_.get());
if (c_api_client) {
return c_api_client->pjrt_c_api();
}
return nullptr;
}

} // namespace runtime
} // namespace torch_xla
torch_xla/csrc/runtime/pjrt_computation_client.h: 3 changes (3 additions, 0 deletions)
@@ -16,6 +16,7 @@
#include "tsl/platform/threadpool.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/shape.h"
@@ -152,6 +153,8 @@ class PjRtComputationClient : public ComputationClient {
std::vector<std::string> PjRtDevicesToString(
absl::Span<xla::PjRtDevice* const> devices) const;

const PJRT_Api* GetPjRtCApiIfAvailable() const;

private:
std::unique_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
