From 86d8967c724ca5b836751ecdbec69b46e29dcbc3 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Wed, 12 Jun 2024 16:15:44 +0000
Subject: [PATCH] Register custom call with CUDA plugin (#7219)

---
 .github/workflows/_build_plugin.yml              |  2 +-
 .github/workflows/_build_torch_with_cuda.yml     |  5 +-
 .../workflows/_test_requiring_torch_cuda.yml     | 26 +++++++-
 .github/workflows/build_and_test.yml             |  2 +-
 torch_xla/csrc/BUILD                             |  4 ++
 torch_xla/csrc/init_python_bindings.cpp          | 64 +++++++++++++++++--
 torch_xla/csrc/runtime/BUILD                     |  1 +
 .../csrc/runtime/pjrt_computation_client.cc      | 10 +++
 .../csrc/runtime/pjrt_computation_client.h       |  3 +
 9 files changed, 103 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/_build_plugin.yml b/.github/workflows/_build_plugin.yml
index 69b93fd5b81..441dbc6a327 100644
--- a/.github/workflows/_build_plugin.yml
+++ b/.github/workflows/_build_plugin.yml
@@ -39,7 +39,7 @@ jobs:
         shell: bash
         run: |
           cd pytorch/xla/infra/ansible
-          ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
+          ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5,8.6 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
       - name: Upload wheel
         uses: actions/upload-artifact@v4
         with:
diff --git a/.github/workflows/_build_torch_with_cuda.yml b/.github/workflows/_build_torch_with_cuda.yml
index 2be3eabe017..296e79b7dfb 100644
--- a/.github/workflows/_build_torch_with_cuda.yml
+++ b/.github/workflows/_build_torch_with_cuda.yml
@@ -20,7 +20,6 @@ jobs:
     runs-on: ${{ inputs.runner }}
     container:
       image: ${{ inputs.dev-image }}
-      options: "--gpus all --shm-size 16g"
     env:
       _GLIBCXX_USE_CXX11_ABI: 0
     steps:
@@ -34,8 +33,6 @@ jobs:
         run: |
           echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
           echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
-      - name: Check GPU
-        run: nvidia-smi
       - name: Checkout PyTorch Repo
         uses: actions/checkout@v4
         with:
@@ -47,7 +44,7 @@ jobs:
         shell: bash
         run: |
           cd pytorch
-          USE_CUDA=1 python setup.py bdist_wheel
+          TORCH_CUDA_ARCH_LIST="5.2;8.6" USE_CUDA=1 python setup.py bdist_wheel
       - name: Upload wheel
         uses: actions/upload-artifact@v4
         with:
diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml
index a56b85616a5..466d72b84fe 100644
--- a/.github/workflows/_test_requiring_torch_cuda.yml
+++ b/.github/workflows/_test_requiring_torch_cuda.yml
@@ -30,10 +30,17 @@ on:

 jobs:
   test:
-    runs-on: ${{ inputs.runner }}
     container:
       image: ${{ inputs.dev-image }}
       options: "--gpus all --shm-size 16g"
+    strategy:
+      matrix:
+        include:
+          - run_python_tests: 'python_tests'
+            runner: ${{ inputs.runner }}
+          - run_triton_tests: 'triton_tests'
+            runner: 'linux.g5.4xlarge.nvidia.gpu'
+    runs-on: ${{ matrix.runner }}
     timeout-minutes: ${{ inputs.timeout-minutes }}
     env:
       USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }}
@@ -98,9 +105,24 @@ jobs:
         uses: actions/checkout@v4
         with:
           path: pytorch/xla
-      - name: Test
+      - name: Extra CI deps
+        shell: bash
+        run: |
+          set -x
+          pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+          pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
+          pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+          pip install --no-deps triton==2.3.0
+        if: ${{ matrix.run_triton_tests }}
+      - name: Python Tests
         shell: bash
         run: |
           set -xue
           PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py -v
           PJRT_DEVICE=CUDA python pytorch/xla/test/dynamo/test_dynamo.py -v
+        if: ${{ matrix.run_python_tests }}
+      - name: Triton Tests
+        shell: bash
+        run: |
+          PJRT_DEVICE=CUDA TRITON_PTXAS_PATH=/usr/local/cuda-12.1/bin/ptxas python pytorch/xla/test/test_triton.py
+        if: ${{ matrix.run_triton_tests }}
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 7cb8dd9adff..dcdd9a25177 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -42,7 +42,7 @@ jobs:
       # note that to build a torch wheel with CUDA enabled, we do not need a GPU runner.
       dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
       torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
-      runner: linux.8xlarge.nvidia.gpu
+      runner: linux.24xlarge

   build-cuda-plugin:
     name: "Build XLA CUDA plugin"
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index b85c57fba4e..52a802da958 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -269,6 +269,7 @@ ptxla_cc_library(
         ":tensor",
         ":version",
         "//torch_xla/csrc/runtime",
+        "//torch_xla/csrc/runtime:pjrt_computation_client",
         "//torch_xla/csrc/runtime:metrics",
         "//torch_xla/csrc/runtime:metrics_analysis",
         "//torch_xla/csrc/runtime:metrics_reader",
@@ -290,6 +291,9 @@ ptxla_cc_library(
         "@xla//xla/service:sharding_propagation",
         "@xla//xla/service/spmd:spmd_partitioner",
         "@xla//xla/service:custom_call_target_registry",
+        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
+        "@xla//xla/pjrt/c:pjrt_c_api_wrapper_impl",
+        "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
     ],
 )

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 153915459cb..0c5e4f299ad 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -45,9 +45,11 @@
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/xla_ops.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/metrics_analysis.h"
 #include "torch_xla/csrc/runtime/metrics_reader.h"
+#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/pjrt_registry.h"
 #include "torch_xla/csrc/runtime/profiler.h"
 #include "torch_xla/csrc/runtime/runtime.h"
@@ -67,7 +69,10 @@
 #include "torch_xla/csrc/xla_sharding_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/profiler/lib/traceme.h"
+#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
+#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
 #include "xla/pjrt/distributed/distributed.h"
+#include "xla/pjrt/pjrt_api.h"
 #include "xla/python/profiler/internal/traceme_wrapper.h"
 #include "xla/service/custom_call_target_registry.h"
 #include "xla/service/hlo_parser.h"
@@ -2464,12 +2469,59 @@ void InitXlaModuleBindings(py::module m) {
             return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
                                  /*is_tpu=*/false);
           });
-  m.def("_xla_register_custom_call_target",
-        [](const std::string& fn_name, const py::capsule& function_ptr,
-           const std::string& platform) {
-          XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
-              fn_name, function_ptr.get_pointer(), platform);
-        });
+  m.def("_xla_register_custom_call_target", [](const std::string& fn_name,
+                                               const py::capsule& function_ptr,
+                                               const std::string& platform) {
+    if (runtime::sys_util::GetEnvBool("XLA_USE_IFRT", false) ||
+        platform != "CUDA") {
+      XLA_ERROR() << "Custom call targets can only be registered for "
+                     "PJRT CUDA runtime."
+                  << std::endl;
+      return;
+    }
+    if (runtime::sys_util::GetEnvBool(runtime::env::kEnvPjrtDynamicPlugins,
+                                      false)) {
+      runtime::PjRtComputationClient* client =
+          dynamic_cast<runtime::PjRtComputationClient*>(
+              runtime::GetComputationClient());
+      if (!client) {
+        return;
+      }
+      const PJRT_Api* pjrt_api = client->GetPjRtCApiIfAvailable();
+      if (!pjrt_api) {
+        return;
+      }
+      // See openxla reference:
+      // https://github.com/openxla/xla/blob/b604c8d87df842002a7a8de79a434026329fbcb2/xla/pjrt/c/pjrt_c_api_gpu_test.cc#L414
+      const PJRT_Extension_Base* next =
+          reinterpret_cast<const PJRT_Extension_Base*>(
+              pjrt_api->extension_start);
+      while (next != nullptr &&
+             next->type !=
+                 PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
+        next = next->next;
+      }
+      if (next == nullptr) {
+        return;
+      }
+      PJRT_Gpu_Register_Custom_Call_Args args;
+      args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
+      args.function_name = fn_name.c_str();
+      args.function_name_size = fn_name.size();
+      args.api_version = 0;
+      args.custom_call_function =
+          reinterpret_cast<void*>(function_ptr.get_pointer());
+      PJRT_Error* error =
+          reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(
+              &args);
+      if (error) {
+        XLA_ERROR() << error->status << std::endl;
+      }
+    } else {
+      XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
+          fn_name, function_ptr.get_pointer(), platform);
+    }
+  });
   m.def("_set_xla_custom_op_name_prefix",
         [](const at::Tensor& input, const std::string& op_name_prefix,
            size_t max_call_stack_depth) -> bool {
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 623ffebfd56..01f717bfd34 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -127,6 +127,7 @@ cc_library(
         "@xla//xla/client:xla_computation",
         "@xla//xla/pjrt:pjrt_client",
         "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
+        "@xla//xla/pjrt:pjrt_c_api_client",
         "@xla//xla/pjrt/distributed",
     ],
 )
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 7f1b35cb310..ec1848b9b06 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -25,6 +25,7 @@
 #include "xla/client/xla_computation.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/pjrt/pjrt_c_api_client.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/protobuf_util.h"
@@ -985,5 +986,14 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
   };
 }

+const PJRT_Api* PjRtComputationClient::GetPjRtCApiIfAvailable() const {
+  // dynamic_cast will return a nullptr if the client is not PjRtCApiClient.
+  auto* c_api_client = dynamic_cast<xla::PjRtCApiClient*>(client_.get());
+  if (c_api_client) {
+    return c_api_client->pjrt_c_api();
+  }
+  return nullptr;
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 1d31107e6b9..1b08b55c601 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -16,6 +16,7 @@
 #include "tsl/platform/threadpool.h"
 #include "xla/client/xla_computation.h"
 #include "xla/literal.h"
+#include "xla/pjrt/pjrt_api.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/shape.h"
@@ -152,6 +153,8 @@ class PjRtComputationClient : public ComputationClient {
   std::vector<std::string> PjRtDevicesToString(
       absl::Span<xla::PjRtDevice* const> devices) const;

+  const PJRT_Api* GetPjRtCApiIfAvailable() const;
+
  private:
   std::unique_ptr<xla::PjRtClient> client_;
   std::unique_ptr<XlaCoordinator> coordinator_;
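For reference, a minimal usage sketch of the new `_xla_register_custom_call_target` binding follows; it is not part of the patch. It assumes a shared library `libmy_kernels.so` exporting a symbol `my_custom_call` that implements XLA's GPU custom-call calling convention; the library name, symbol name, and capsule construction are illustrative only.

    # Illustrative sketch only -- not part of this patch. Assumes libmy_kernels.so
    # exports `my_custom_call`, a function with XLA's GPU custom-call signature.
    import ctypes

    import torch_xla

    # PyCapsule_New(void* pointer, const char* name, destructor) -> capsule object.
    ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
    ctypes.pythonapi.PyCapsule_New.argtypes = [
        ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
    ]

    lib = ctypes.CDLL("libmy_kernels.so")  # hypothetical library
    fn_ptr = ctypes.cast(lib.my_custom_call, ctypes.c_void_p).value

    # The binding expects a py::capsule wrapping the raw function pointer.
    capsule = ctypes.pythonapi.PyCapsule_New(fn_ptr, None, None)

    # With a dynamically loaded CUDA plugin (runtime::env::kEnvPjrtDynamicPlugins set),
    # this goes through the PJRT_Gpu_Register_Custom_Call extension; otherwise it
    # falls back to XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM.
    torch_xla._XLAC._xla_register_custom_call_target("my_custom_call", capsule, "CUDA")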