
Triton #6798

Merged · 67 commits · Jun 7, 2024

Changes from 47 commits (67 commits total)

Commits
6dccf0a
Update infra_triggers.tf
ManfeiBai Oct 4, 2023
9828123
Skeleton trition support
bhavya01 Mar 20, 2024
99bf48d
Merge branch 'master' into triton
bhavya01 Mar 20, 2024
b89e558
Fix bugs
bhavya01 Mar 21, 2024
64189bd
Fix custom call invocation
bhavya01 Mar 21, 2024
0c208ef
Refactor to include gpu custom call and create triton dir
bhavya01 Mar 22, 2024
b553ba7
Lint fixes
bhavya01 Mar 22, 2024
c5129e6
python lint fix
bhavya01 Mar 22, 2024
48e7127
Updated base image for CI
bhavya01 Mar 27, 2024
e04fc97
Update github workflow gcr image
bhavya01 Mar 28, 2024
37bf127
Merge branch 'master' into custom
bhavya01 Mar 28, 2024
6061895
Remove xrt build and test file
bhavya01 Mar 28, 2024
f59ddbf
Add temporary test to run triton kernel
bhavya01 Mar 28, 2024
158aed4
Fix tests
bhavya01 Mar 28, 2024
87b92c5
Update payload for xla gpu custom call
bhavya01 Mar 29, 2024
847ccc5
Update gpu runner
bhavya01 Mar 29, 2024
eca6d52
Merge branch 'master' into triton
bhavya01 Apr 4, 2024
2348ca3
Extract payload from triton kernel programatically
bhavya01 Apr 12, 2024
110c8c6
Merge branch 'master' into triton
bhavya01 Apr 12, 2024
a226150
Lint fixes
bhavya01 Apr 12, 2024
4c1f4f5
Only build triton files for GPU
bhavya01 Apr 12, 2024
431f822
build pytorch for ampere gpus
bhavya01 Apr 13, 2024
4bade16
c++ lint fix
bhavya01 Apr 13, 2024
1c5b47d
Python lint fix
bhavya01 Apr 13, 2024
3138a92
Fix torch cuda arch list
bhavya01 Apr 13, 2024
3f00cfd
Use a bigger machine for CI build
bhavya01 Apr 13, 2024
e729cfb
Add triton test to run_tests.sh
bhavya01 Apr 13, 2024
8e304c0
Update triton env variable
bhavya01 Apr 15, 2024
27bdc3a
Set up a separate CI for triton tests
bhavya01 Apr 15, 2024
9a3ef84
Fix github workflow to add _triton.yml
bhavya01 Apr 15, 2024
ade444d
Rebuild torch xla for triton tests
bhavya01 Apr 15, 2024
cb0bb85
Create a separate CI tab for triton tests
bhavya01 Apr 16, 2024
015b1ad
Separate build and test phase for triton
bhavya01 Apr 16, 2024
a18028a
Fix flags for docker run container
bhavya01 Apr 16, 2024
993ee92
Update triton.yml to output docker image
bhavya01 Apr 16, 2024
a87b782
Add a python binding to register custom calls and remove jax files
bhavya01 May 10, 2024
bf05d1b
Fix lint
bhavya01 May 10, 2024
4582fe8
Merge main
bhavya01 May 10, 2024
9680167
Merge master
bhavya01 May 10, 2024
a7b94c6
Merge master after updating
bhavya01 May 10, 2024
e14636a
Update CI to use cuda plugin
bhavya01 May 10, 2024
256d819
Install jaxlib while setting up triton tests
bhavya01 May 10, 2024
c616e64
Install triton package while running triton tests
bhavya01 May 10, 2024
60b8d18
Experimental: Build pytorch with cuda
bhavya01 May 13, 2024
2bde624
Revert build pytorch with CUDA
bhavya01 May 14, 2024
e6c4e0a
Merge branch 'master' into triton
bhavya01 May 14, 2024
14ee545
Remove ansible path for triton CI
bhavya01 May 14, 2024
25acb26
Style fixes
bhavya01 May 20, 2024
6b0ac18
[Experimental] test new CI
bhavya01 May 28, 2024
4d97150
[Experimental] Set XLA_CUDA=0 for cuda arch in ansible
bhavya01 May 28, 2024
e079049
[Experimental] Update CI to build pytorch cuda with ansible
bhavya01 May 29, 2024
d9c89b6
Update CI
bhavya01 May 30, 2024
7a6c809
Fix CI workflow file
bhavya01 May 30, 2024
6b1954d
Fix CI workflow
bhavya01 May 30, 2024
21797a6
Fix the wheels installed for tests requiring torch cuda
bhavya01 May 30, 2024
e6e89d3
Add compute_capability=8.6 for xla cuda plugin
bhavya01 May 31, 2024
ac45fe1
update TORCH_CUDA_ARCH_LIST
bhavya01 May 31, 2024
f828fbb
Experimental build torch and torch_xla cuda wheels
bhavya01 May 31, 2024
ac56c00
Merge branch 'master' into triton
bhavya01 May 31, 2024
c3b8653
Update build_and_test.yml
bhavya01 May 31, 2024
a1168c6
Update dlpack test to only use one device
bhavya01 May 31, 2024
39551a2
Remove compute capability 8.6 from cuda plugin
bhavya01 May 31, 2024
35e0869
Remove triton.sh
bhavya01 May 31, 2024
f95d898
Default empty torch_cuda_arch_list in ansible config
bhavya01 May 31, 2024
291104d
Merge branch 'master' into triton
bhavya01 Jun 5, 2024
f5c9b1a
Revert CI changes
bhavya01 Jun 6, 2024
5b23969
Revert CI changes pt2
bhavya01 Jun 6, 2024
1 change: 1 addition & 0 deletions .circleci/common.sh
@@ -99,6 +99,7 @@ function install_deps_pytorch_xla() {
  if ls $CUBLAS_PATTERN 1> /dev/null 2>&1; then
    sudo ln -s $CUBLAS_PATTERN /usr/local/cuda/include
  fi
  pip install --no-deps triton==2.3.0
}

function build_torch_xla() {
36 changes: 36 additions & 0 deletions .circleci/triton.sh
@@ -0,0 +1,36 @@
#!/bin/bash
Collaborator:
Why do we need a new script to build torch-xla?

Collaborator Author:
Triton only works on GPUs with compute capability >= 7.0, and the existing CI GPUs are at 5.2, so we need to rebuild PyTorch with CUDA support for the newer GPUs.
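
A minimal sketch of the capability check behind this reasoning, assuming a CUDA-enabled PyTorch build on the runner (the 7.0 threshold is Triton's documented minimum):

import torch

# Query the compute capability of the first visible GPU, e.g. (5, 2) on the
# old CI runners versus (8, 6) on the A10G GPUs of the g5 runners.
major, minor = torch.cuda.get_device_capability(0)
if (major, minor) < (7, 0):
    raise RuntimeError(
        f"Triton kernels need compute capability >= 7.0, found {major}.{minor}")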

Collaborator:
Please add a comment at the top of the script explaining why it is needed. @will-cromar, can you review this part? Appreciate it.

Collaborator (@will-cromar, May 22, 2024):
Please don't add any new build scripts. These should have been deleted a long time ago, but nobody has had time to do a full refactor. There's already an ansible setting cuda_compute_capabilities to update the compute capabilities, which is set here:

ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps


set -ex

source .circleci/common.sh
PYTORCH_DIR=/tmp/pytorch
XLA_DIR=$PYTORCH_DIR/xla
clone_pytorch $PYTORCH_DIR $XLA_DIR

# Use bazel cache
USE_CACHE=1

pushd $PYTORCH_DIR
checkout_torch_pin_if_available

if ! install_deps_pytorch_xla $XLA_DIR $USE_CACHE; then
  exit 1
fi

apply_patches

python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

export PATH=$PATH:/usr/local/cuda-12.1/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
export USE_CUDA=1
export TORCH_CUDA_ARCH_LIST='8.6'
python setup.py install

XLA_DIR=$PYTORCH_DIR/xla
export TF_CUDA_COMPUTE_CAPABILITIES="compute_86"
export XLA_CUDA=1
build_torch_xla $XLA_DIR
96 changes: 96 additions & 0 deletions .github/workflows/triton.yml
@@ -0,0 +1,96 @@
on:
Collaborator:
Same, can we augment the existing GPU flow to have Triton as well?

Collaborator Author:
The problem with using the existing GPU flow is that it builds PyTorch without the CUDA dependency, which doesn't work with Triton because Triton uses PyTorch to detect whether a GPU device is present.

It would be great if we could use this for now while I figure out how to merge the two CI workflows.
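
A minimal sketch of the failure mode described above, assuming a CPU-only PyTorch wheel is installed on a machine that does have a GPU:

import torch

# With a CPU-only torch build, the CUDA runtime bindings are absent, so any
# GPU probe that goes through torch reports no device at all.
print(torch.version.cuda)         # None on a CPU-only build
print(torch.cuda.is_available())  # False, even with a physical GPU attached
# Triton consults torch for the active CUDA device, so kernel launches fail here.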

Collaborator:
Same comment as for triton.sh. @will-cromar

Collaborator:
Is it feasible to merge your new tests into the existing workflow and upgrade the GPUs we use there? It's probably time we use more modern GPUs for all of our CI tests anyway.

Collaborator:
You can upgrade the runner here:

runner: linux.8xlarge.nvidia.gpu

Collaborator:
Let me check with the PyTorch folks first.

  pull_request:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'torch_xla/experimental/torch_triton.py'
  push:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'torch_xla/experimental/torch_triton.py'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

jobs:
  build-triton:
    runs-on: linux.24xlarge
    timeout-minutes: 300
    outputs:
      docker-image: ${{ steps.upload-docker-image.outputs.docker-image }}
    env:
      DOCKER_IMAGE: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
      ECR_DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base
      WORKDIR: /triton_dir
    steps:
      - name: Setup Linux
        uses: pytorch/test-infra/.github/actions/setup-linux@main
      - name: Setup SSH (Click me for login details)
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
          instructions: |
            Tests are done inside the container, to start an interactive session run:
              docker exec -it $(docker container ps --format '{{.ID}}') bash
      - name: Checkout repo
        uses: actions/checkout@v3
      - name: Download docker image from GCR
        shell: bash
        run: docker pull "${DOCKER_IMAGE}"
      - name: Start the container
        shell: bash
        run: |
          pid=$(docker run --privileged -t -d -w "${WORKDIR}" "${DOCKER_IMAGE}")
          docker cp "${GITHUB_WORKSPACE}/." "$pid:$WORKDIR"
          echo "pid=${pid}" >> "${GITHUB_ENV}"
      - name: Build and Test
        shell: bash
        run: |
          docker exec --privileged "${pid}" bash -c ".circleci/triton.sh"
      - name: Push built docker image to ECR
        id: upload-docker-image
        shell: bash
        run: |
          export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:triton-${GITHUB_SHA}"
          time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}"
          time docker push "${COMMIT_DOCKER_IMAGE}"
          echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}"
      - name: Teardown Linux
        uses: pytorch/test-infra/.github/actions/teardown-linux@main
        if: always()
  test-triton:
    runs-on: linux.g5.4xlarge.nvidia.gpu
    timeout-minutes: 300
    needs: build-triton
    env:
      DOCKER_IMAGE: ${{ needs.build-triton.outputs.docker-image }}
      WORKDIR: /triton_dir
    steps:
      - name: Setup Linux
        uses: pytorch/test-infra/.github/actions/setup-linux@main
      - name: Setup SSH (Click me for login details)
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
          instructions: |
            Tests are done inside the container, to start an interactive session run:
              docker exec -it $(docker container ps --format '{{.ID}}') bash
      - name: Download and run docker image from GCR
        shell: bash
        run: |
          echo "DOCKER_IMAGE: ${DOCKER_IMAGE}"
          docker pull "${DOCKER_IMAGE}"
          pid=$(docker run --shm-size=16g ${GPU_FLAG:-} -t -d -w "$WORKDIR" "${DOCKER_IMAGE}")
          echo "pid=${pid}" >> "${GITHUB_ENV}"
      - name: Test
        shell: bash
        run: |
          docker exec --privileged "${pid}" bash -c 'TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python test/test_triton.py'
      - name: Teardown Linux
        uses: pytorch/test-infra/.github/actions/teardown-linux@main
        if: always()
68 changes: 68 additions & 0 deletions test/test_triton.py
@@ -0,0 +1,68 @@
import logging
import sys
import unittest

import torch
from torch import nn as nn

import torch_xla.experimental.torch_triton as torch_triton
import torch_xla
from torch_xla import runtime as xr

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
  # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
  # There are multiple 'programs' processing different data. We identify which program
  # we are here:
  pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
  # This program will process inputs that are offset from the initial data.
  # For instance, if you had a vector of length 256 and block_size of 64, the programs
  # would each access the elements [0:64, 64:128, 128:192, 192:256].
  # Note that offsets is a list of pointers:
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # Create a mask to guard memory operations against out-of-bounds accesses.
  mask = offsets < n_elements
  # Load x and y from DRAM, masking out any extra elements in case the input is not a
  # multiple of the block size.
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  # Write x + y back to DRAM.
  tl.store(output_ptr + offsets, output, mask=mask)


class TritonTest(unittest.TestCase):

  @unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.")
  def test_gpu_custom_call_triton_add(self):
    size = 16

    x = torch.arange(size, dtype=torch.int64).to("xla")
    y = torch.arange(size, dtype=torch.int64).to("xla")
    output = torch.empty_like(x)
    block_size = 8
    grid = (triton.cdiv(size, block_size),)
    payload = torch_triton.triton_call(
        x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
    output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                                  [output.shape], [torch.int64])
    output_torch = x + y
    self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu()))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
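
For reference, the Test step in .github/workflows/triton.yml runs this file directly inside the CUDA container; the same invocation should work locally, assuming a CUDA-enabled torch/torch_xla build and the CUDA toolkit installed at /usr/local/cuda:

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python test/test_triton.py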
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
@@ -287,6 +287,7 @@ ptxla_cc_library(
        "@xla//xla/service:hlo_verifier",
        "@xla//xla/service:sharding_propagation",
        "@xla//xla/service/spmd:spmd_partitioner",
        "@xla//xla/service:custom_call_target_registry",
    ],
)
23 changes: 23 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -66,6 +66,7 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/python/profiler/internal/traceme_wrapper.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/hlo_parser.h"

namespace torch_xla {
@@ -2356,6 +2357,28 @@ void InitXlaModuleBindings(py::module m) {
              bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes);
          return bridge::AtenFromXlaTensors(xtensors);
        });
  m.def("_xla_gpu_custom_call",
        [](const std::vector<at::Tensor>& inputs, const std::string& payload,
           const std::vector<std::vector<int64_t>>& output_shapes,
           const std::vector<py::object>& output_dtypes)
            -> std::vector<at::Tensor> {
          std::vector<at::ScalarType> dtypes;
          dtypes.reserve(output_dtypes.size());
          for (auto& dtype : output_dtypes) {
            dtypes.push_back(
                reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
          }

          auto xtensors = tensor_methods::gpu_custom_call(
              bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes);
          return bridge::AtenFromXlaTensors(xtensors);
        });
  m.def("_xla_register_custom_call_target",
        [](const std::string& fn_name, const py::capsule& function_ptr,
           const std::string& platform) {
          XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
              fn_name, function_ptr.get_pointer(), platform);
        });
  m.def("_set_xla_custom_op_name_prefix",
        [](const at::Tensor& input, const std::string& op_name_prefix,
           size_t max_call_stack_depth) -> bool {
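
For reference, a hypothetical usage sketch of the _xla_register_custom_call_target binding added above. The target name, the source of the capsule, and the platform string are assumptions for illustration and are not taken from this diff:

import torch_xla

def register_gpu_custom_call(handler_capsule):
    # handler_capsule: a PyCapsule wrapping a C function with XLA's GPU
    # custom-call signature. In this PR it would come from the Triton
    # kernel-call library; the exact provider is not shown in this diff.
    torch_xla._XLAC._xla_register_custom_call_target(
        "triton_kernel_call",  # assumed name referenced by the HLO custom call
        handler_capsule,       # PyCapsule holding the handler's function pointer
        "CUDA")                # platform string expected by the registry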
37 changes: 37 additions & 0 deletions torch_xla/csrc/ops/gpu_custom_call.cpp
@@ -0,0 +1,37 @@
#include "torch_xla/csrc/ops/gpu_custom_call.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {

GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs,
                             xla::Shape output_shape,
                             const std::string& payload)
    : XlaNode(xla_gpu_custom_call, inputs, std::move(output_shape),
              /*num_outputs=*/output_shape.tuple_shapes_size(),
              torch::lazy::MHash(payload)),
      payload_(payload) {}

torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const {
  return torch::lazy::MakeNode<GpuCustomCall>(operands, xla_shape(), payload_);
}

XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const {
  std::vector<xla::XlaOp> inputs;
  inputs.reserve(operands().size());
  for (auto& operand : operands()) {
    inputs.push_back(loctx->GetOutputOp(operand));
  }
  auto output = BuildGpuCustomCall(inputs, xla_shape(), payload_);
  return ReturnOps(output, loctx);
}

std::string GpuCustomCall::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", " << payload_;
  return ss.str();
}

}  // namespace torch_xla
26 changes: 26 additions & 0 deletions torch_xla/csrc/ops/gpu_custom_call.h
@@ -0,0 +1,26 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
// TODO: Merge GPU and TPU custom call.
class GpuCustomCall : public XlaNode {
 public:
  // Make a GPU custom call with payload, e.g., Triton.
  GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
                const std::string& payload);

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  std::string ToString() const override;

 private:
  std::string payload_;
};

}  // namespace torch_xla
#endif // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.cpp
@@ -37,5 +37,6 @@ const OpKindWrapper xla_unselect("xla::unselect");
const OpKindWrapper xla_update_slice("xla::update_slice");
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");

} // namespace torch_xla
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.h
@@ -62,6 +62,7 @@ extern const OpKindWrapper xla_unselect;
extern const OpKindWrapper xla_update_slice;
extern const OpKindWrapper xla_custom_sharding;
extern const OpKindWrapper xla_tpu_custom_call;
extern const OpKindWrapper xla_gpu_custom_call;

} // namespace torch_xla

34 changes: 34 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -58,6 +58,7 @@
#include "torch_xla/csrc/ops/generic.h"
#include "torch_xla/csrc/ops/generic_slice.h"
#include "torch_xla/csrc/ops/get_dimensions_size.h"
#include "torch_xla/csrc/ops/gpu_custom_call.h"
#include "torch_xla/csrc/ops/hardtanh_backward.h"
#include "torch_xla/csrc/ops/index_ops.h"
#include "torch_xla/csrc/ops/index_select.h"
@@ -565,6 +566,39 @@ void custom_sharding_(
  input->SetShardingSpec(*sharding_spec);
}

std::vector<XLATensorPtr> gpu_custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes) {
  XLA_CHECK(inputs.size() > 0) << "inputs are empty";

  std::vector<torch::lazy::Value> values;
  values.reserve(inputs.size());
  for (const auto& input : inputs) {
    values.push_back(input->GetIrValue());
  }

  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
  std::vector<xla::Shape> output_xla_shapes;
  output_xla_shapes.reserve(output_shapes.size());
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
        output_shapes[i]));
  }

  auto node = torch::lazy::MakeNode<GpuCustomCall>(
      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);

  std::vector<XLATensorPtr> outputs;
  outputs.reserve(output_shapes.size());
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    outputs.push_back(
        inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i]));
  }
  return outputs;
}

std::vector<XLATensorPtr> tpu_custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
    const std::vector<std::vector<int64_t>>& output_shapes,
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -91,6 +91,11 @@ void custom_sharding_(
    const std::shared_ptr<XLATensor::ShardingSpec>& spec,
    const CustomSharding::Type& type = CustomSharding::Type::kSharding);

std::vector<XLATensorPtr> gpu_custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes);

std::vector<XLATensorPtr> tpu_custom_call(
    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
    const std::vector<std::vector<int64_t>>& output_shapes,