diff --git a/.github/workflows/_build_torch_with_cuda.yml b/.github/workflows/_build_torch_with_cuda.yml
index e9defd40eb5..2be3eabe017 100644
--- a/.github/workflows/_build_torch_with_cuda.yml
+++ b/.github/workflows/_build_torch_with_cuda.yml
@@ -7,9 +7,9 @@ on:
         type: string
         description: Base image for builds
       torch-commit:
-      required: true
-      type: string
-      description: torch-commit
+        required: true
+        type: string
+        description: torch-commit
       runner:
         required: false
         type: string
diff --git a/test/test_triton.py b/test/test_triton.py
new file mode 100644
index 00000000000..d35a0c6d144
--- /dev/null
+++ b/test/test_triton.py
@@ -0,0 +1,69 @@
+import logging
+import sys
+import torch
+from torch import nn as nn
+import unittest
+
+import torch_xla.experimental.triton as xla_triton
+import torch_xla
+from torch_xla import runtime as xr
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def add_kernel(
+    x_ptr,  # *Pointer* to first input vector.
+    y_ptr,  # *Pointer* to second input vector.
+    output_ptr,  # *Pointer* to output vector.
+    n_elements,  # Size of the vector.
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+    # NOTE: `constexpr` so it can be used as a shape value.
+):
+  # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
+  # There are multiple 'programs' processing different data. We identify which program
+  # we are here:
+  pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+  # This program will process inputs that are offset from the initial data.
+  # For instance, if you had a vector of length 256 and block_size of 64, the programs
+  # would each access the elements [0:64, 64:128, 128:192, 192:256].
+  # Note that offsets is a list of pointers:
+  block_start = pid * BLOCK_SIZE
+  offsets = block_start + tl.arange(0, BLOCK_SIZE)
+  # Create a mask to guard memory operations against out-of-bounds accesses.
+  mask = offsets < n_elements
+  # Load x and y from DRAM, masking out any extra elements in case the input is not a
+  # multiple of the block size.
+  x = tl.load(x_ptr + offsets, mask=mask)
+  y = tl.load(y_ptr + offsets, mask=mask)
+  output = x + y
+  # Write x + y back to DRAM.
+  tl.store(output_ptr + offsets, output, mask=mask)
+
+
+class TritonTest(unittest.TestCase):
+
+  @unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.")
+  def test_gpu_custom_call_triton_add(self):
+    size = 16
+
+    x = torch.arange(size, dtype=torch.int64).to("xla")
+    y = torch.arange(size, dtype=torch.int64).to("xla")
+    output = torch.empty_like(x)
+    block_size = 8
+    grid = (triton.cdiv(size, block_size),)
+    payload = xla_triton.triton_call(
+        x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
+    output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
+                                                  [output.shape], [torch.int64])
+    output_torch = x + y
+    self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu()))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  torch.set_default_dtype(torch.float32)
+  torch.manual_seed(42)
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index a2aadc0c633..b85c57fba4e 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -289,6 +289,7 @@ ptxla_cc_library(
         "@xla//xla/service:hlo_verifier",
         "@xla//xla/service:sharding_propagation",
         "@xla//xla/service/spmd:spmd_partitioner",
+        "@xla//xla/service:custom_call_target_registry",
     ],
 )

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 170de7bb5be..fea968dc1a1 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -69,6 +69,7 @@
 #include "tsl/profiler/lib/traceme.h"
 #include "xla/pjrt/distributed/distributed.h"
 #include "xla/python/profiler/internal/traceme_wrapper.h"
+#include "xla/service/custom_call_target_registry.h"
 #include "xla/service/hlo_parser.h"

 namespace torch_xla {
@@ -202,6 +203,24 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }

+std::vector<at::Tensor> XlaCustomCall(
+    const std::vector<at::Tensor>& inputs, const std::string& payload,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<py::object>& output_dtypes, bool is_tpu) {
+  std::vector<at::ScalarType> dtypes;
+  dtypes.reserve(output_dtypes.size());
+  for (auto& dtype : output_dtypes) {
+    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
+  }
+
+  if (is_tpu) {
+    return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
+        bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes));
+  }
+  return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call(
+      bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes));
+}
+
 std::vector<std::pair<int64_t, int64_t>> CreateSourceTargetPairs(
     const py::list& pairs) {
   std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
@@ -2401,16 +2420,22 @@ void InitXlaModuleBindings(py::module m) {
            const std::vector<std::vector<int64_t>>& output_shapes,
            const std::vector<py::object>& output_dtypes)
             -> std::vector<at::Tensor> {
-          std::vector<at::ScalarType> dtypes;
-          dtypes.reserve(output_dtypes.size());
-          for (auto& dtype : output_dtypes) {
-            dtypes.push_back(
-                reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-          }
-
-          auto xtensors = tensor_methods::tpu_custom_call(
-              bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes);
-          return bridge::AtenFromXlaTensors(xtensors);
+          return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
+                               /*is_tpu=*/true);
+        });
+  m.def("_xla_gpu_custom_call",
+        [](const std::vector<at::Tensor>& inputs, const std::string& payload,
+           const std::vector<std::vector<int64_t>>& output_shapes,
+           const std::vector<py::object>& output_dtypes)
+            -> std::vector<at::Tensor> {
+          return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
+                               /*is_tpu=*/false);
+        });
m.def("_xla_register_custom_call_target", + [](const std::string& fn_name, const py::capsule& function_ptr, + const std::string& platform) { + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + fn_name, function_ptr.get_pointer(), platform); }); m.def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, diff --git a/torch_xla/csrc/ops/gpu_custom_call.cpp b/torch_xla/csrc/ops/gpu_custom_call.cpp new file mode 100644 index 00000000000..40804001fe5 --- /dev/null +++ b/torch_xla/csrc/ops/gpu_custom_call.cpp @@ -0,0 +1,37 @@ +#include "torch_xla/csrc/ops/gpu_custom_call.h" + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { + +GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs, + xla::Shape output_shape, + const std::string& payload) + : XlaNode(xla_gpu_custom_call, inputs, std::move(output_shape), + /*num_outputs=*/output_shape.tuple_shapes_size(), + torch::lazy::MHash(payload)), + payload_(payload) {} + +torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands, xla_shape(), payload_); +} + +XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const { + std::vector inputs; + inputs.reserve(operands().size()); + for (auto& operand : operands()) { + inputs.push_back(loctx->GetOutputOp(operand)); + } + auto output = BuildGpuCustomCall(inputs, xla_shape(), payload_); + return ReturnOps(output, loctx); +} + +std::string GpuCustomCall::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", " << payload_; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/gpu_custom_call.h b/torch_xla/csrc/ops/gpu_custom_call.h new file mode 100644 index 00000000000..fa08d62be67 --- /dev/null +++ b/torch_xla/csrc/ops/gpu_custom_call.h @@ -0,0 +1,25 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ +#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +class GpuCustomCall : public XlaNode { + public: + // Make a GPU custom call with payload, e.g., Triton. 
+  GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
+                const std::string& payload);
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+ private:
+  std::string payload_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp
index e1aa70d56d6..d61c3cbc839 100644
--- a/torch_xla/csrc/ops/xla_ops.cpp
+++ b/torch_xla/csrc/ops/xla_ops.cpp
@@ -37,5 +37,6 @@ const OpKindWrapper xla_unselect("xla::unselect");
 const OpKindWrapper xla_update_slice("xla::update_slice");
 const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
 const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
+const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");

 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h
index fff50fe6bc3..8d8d7874364 100644
--- a/torch_xla/csrc/ops/xla_ops.h
+++ b/torch_xla/csrc/ops/xla_ops.h
@@ -62,6 +62,7 @@ extern const OpKindWrapper xla_unselect;
 extern const OpKindWrapper xla_update_slice;
 extern const OpKindWrapper xla_custom_sharding;
 extern const OpKindWrapper xla_tpu_custom_call;
+extern const OpKindWrapper xla_gpu_custom_call;

 }  // namespace torch_xla

diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7baa951c9a6..a557f55690d 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -59,6 +59,7 @@
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/generic_slice.h"
 #include "torch_xla/csrc/ops/get_dimensions_size.h"
+#include "torch_xla/csrc/ops/gpu_custom_call.h"
 #include "torch_xla/csrc/ops/hardtanh_backward.h"
 #include "torch_xla/csrc/ops/index_ops.h"
 #include "torch_xla/csrc/ops/index_select.h"
@@ -566,6 +567,39 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }

+std::vector<XLATensorPtr> gpu_custom_call(
+    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes) {
+  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
+
+  std::vector<torch::lazy::Value> values;
+  values.reserve(inputs.size());
+  for (const auto& input : inputs) {
+    values.push_back(input->GetIrValue());
+  }
+
+  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
+  std::vector<xla::Shape> output_xla_shapes;
+  output_xla_shapes.reserve(output_shapes.size());
+  for (size_t i = 0; i < output_shapes.size(); ++i) {
+    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
+        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
+        output_shapes[i]));
+  }
+
+  auto node = torch::lazy::MakeNode<GpuCustomCall>(
+      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+
+  std::vector<XLATensorPtr> outputs;
+  outputs.reserve(output_shapes.size());
+  for (size_t i = 0; i < output_shapes.size(); ++i) {
+    outputs.push_back(
+        inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i]));
+  }
+  return outputs;
+}
+
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 1d565dd351a..11df2c6eb74 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -91,6 +91,11 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);

+std::vector<XLATensorPtr> gpu_custom_call(
+    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes);
+
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 0954c2fa3ac..33c5492b46b 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1272,11 +1272,35 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
                                output_shape);
 }

+std::vector<xla::XlaOp> BuildGpuCustomCall(
+    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
+    const std::string& payload) {
+  std::vector<xla::Shape> input_shapes;
+  input_shapes.reserve(inputs.size());
+  for (const auto& input : inputs) {
+    input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
+  }
+
+  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
+  xla::XlaOp outputs = xla::CustomCallWithLayout(
+      inputs[0].builder(),
+      /*call_target_name=*/"triton_kernel_call", inputs, output_shape,
+      input_shapes, payload, false, {}, nullptr,
+      xla::CustomCallSchedule::SCHEDULE_NONE,
+      xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
+  std::vector<xla::XlaOp> result;
+  int num_outputs = output_shape.tuple_shapes_size();
+  result.reserve(num_outputs);
+  for (int i = 0; i < num_outputs; ++i) {
+    result.push_back(xla::GetTupleElement(outputs, i));
+  }
+  return result;
+}
+
 std::vector<xla::XlaOp> BuildTpuCustomCall(
     const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
     const std::string& payload) {
   XLA_CHECK(output_shape.IsTuple()) << "output_shape is not a tuple";
-
   // We need to enforce the default C-order (major-to-minor) layouts for inputs
   // to Mosaic and outputs from Mosaic.
   std::vector<xla::Shape> input_shapes;
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index d0e9afca9fa..400c8a51731 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -162,6 +162,10 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
 xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                     xla::XlaOp iou_threshold);

+std::vector<xla::XlaOp> BuildGpuCustomCall(
+    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
+    const std::string& payload);
+
 }  // namespace torch_xla

 #endif  // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_
diff --git a/torch_xla/experimental/triton.py b/torch_xla/experimental/triton.py
new file mode 100644
index 00000000000..a4b361b26ec
--- /dev/null
+++ b/torch_xla/experimental/triton.py
@@ -0,0 +1,211 @@
+"""Module for calling Triton kernels from Pytorch/XLA.
+
+Reference: https://github.com/jax-ml/jax-triton/blob/main/jax_triton/triton_lib.py
+
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Dict, Tuple, Union
+import zlib
+import torch
+
+import numpy as np
+import triton
+import triton.language as tl
+from jax._src.lib import gpu_triton as lib_triton
+import torch_xla
+
+# Register target corresponding to gpu custom call using the
+# implementation provided by jaxlib.
+torch_xla._XLAC._xla_register_custom_call_target(
+    'triton_kernel_call', lib_triton._cuda_triton.get_custom_call(), 'CUDA')
+
+Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
+GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]
+
+NUM_WARPS = 4
+NUM_STAGES = 3
+NUM_CTAS = 1
+
+
+def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
+  if callable(grid):
+    grid = grid(metaparams)
+  if isinstance(grid, int):
+    grid = (grid,)
+  elif len(grid) > 3:
+    raise ValueError("`grid` should have three or fewer dimensions.")
+  return tuple(grid) + (1,) * (3 - len(grid))
+
+
+_TORCH_TO_TRITON_TYPE_MAP = {
+    torch.bfloat16: "bf16",
+    torch.float64: "fp64",
+    torch.float32: "fp32",
+    torch.float16: "fp16",
+    # Triton has 'fp8' as well, which Jax doesn't support yet.
+    torch.int64: "i64",
+    torch.int32: "i32",
+    torch.int16: "i16",
+    torch.int8: "i8",
+    torch.uint64: "u64",
+    torch.uint32: "u32",
+    torch.uint16: "u16",
+    torch.uint8: "u8",
+    # Triton defines a 'B' type, which is an alias for both i1 and bool.
+    torch.bool: "B",
+}
+
+
+def get_triton_type(obj: Any) -> str:
+  if torch.is_tensor(obj):
+    return f"*{_TORCH_TO_TRITON_TYPE_MAP[obj.dtype]}"
+  if isinstance(obj, tl.constexpr):
+    obj = obj.value
+  if isinstance(obj, int):
+    if -(2**31) <= obj < 2**31:
+      return "i32"
+    elif 2**31 <= obj < 2**32:
+      return "u32"
+    elif -(2**63) <= obj < 2**63:
+      return "i64"
+    elif 2**63 <= obj < 2**64:
+      return "u64"
+    else:
+      raise ValueError(f"integer overflow representing {obj}")
+  if isinstance(obj, float):
+    return "fp64"
+  if isinstance(obj, np.float32):
+    return "fp32"
+  if isinstance(obj, bool):
+    return "B"
+  if isinstance(obj, str):
+    return "str"
+  raise NotImplementedError(
+      f"could not compute type name for {obj}: {type(obj)}")
+
+
+def get_or_create_triton_kernel(
+    fn,
+    compiled_kernel,
+    args,
+    dump,
+) -> Tuple[lib_triton.TritonKernel, Any]:
+  # Extract the compilation parameters and compiled ptx from the
+  # compiled triton kernel.
+  ttir = compiled_kernel.asm['ttir']
+  ptx = compiled_kernel.asm['ptx']
+  if dump:
+    print(ptx)
+
+  shared_mem_bytes = compiled_kernel.metadata["shared"]
+  kernel_name = compiled_kernel.metadata["name"]
+  cluster_dims = compiled_kernel.metadata["cluster_dims"]
+  compute_capability = lib_triton.get_compute_capability(0)
+  kernel = lib_triton.TritonKernel(
+      kernel_name,
+      NUM_WARPS,
+      shared_mem_bytes,
+      ptx,
+      ttir,
+      compute_capability,
+      *cluster_dims,
+  )
+
+  specialization_attr = fn._get_config(*args)  # pylint: disable=protected-access
+  return kernel, specialization_attr
+
+
+def triton_kernel_call_lowering(
+    array_args,
+    fn,
+    compiled_kernel,
+    scalar_args,
+    grid,
+    debug,
+    **metaparams,
+):
+  args = list(array_args)
+  arg_dtypes = list(map(get_triton_type, array_args))
+  for idx, dtype, v in scalar_args:
+    args.insert(idx, v)
+    arg_dtypes.insert(idx, dtype)
+
+  if not isinstance(fn, triton.JITFunction):
+    raise ValueError("`kernel` must be a Triton `JITFunction`.")
+
+  # TODO: Add support for autotuner and heuristic functions.
+  config = triton.Config(
+      {},
+      num_warps=NUM_WARPS,
+      num_stages=NUM_STAGES,
+      num_ctas=NUM_CTAS,
+  )
+  config_metaparams = {**metaparams, **config.kwargs}
+  config_grid = normalize_grid(grid, config_metaparams)
+
+  kernel, specialization_attr = get_or_create_triton_kernel(
+      fn,
+      compiled_kernel,
+      args,
+      dump=debug,
+  )
+
+  kernel_params = []
+  for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
+    if isinstance(arg, torch.Tensor):
+      kernel_params.append(
+          lib_triton.create_array_parameter(
+              0,
+              16 if (i in specialization_attr.divisible_by_16) else 0,
+          ))
+    elif i not in specialization_attr.equal_to_1:
+      kernel_params.append(lib_triton.create_scalar_parameter(arg, dtype))
+
+  kernel_call = lib_triton.TritonKernelCall(
+      kernel,
+      config_grid[0],
+      config_grid[1],
+      config_grid[2],
+      kernel_params,
+  )
+
+  call_proto = kernel_call.to_proto("triton_kernel", b"")
+  return zlib.compress(call_proto)
+
+
+def triton_call(
+    *args: Union[torch.Tensor, bool, int, float, np.float32],
+    kernel: triton.JITFunction,
+    grid: GridOrLambda,
+    debug: bool = False,
+    **metaparams: Any,
+) -> Any:
+  array_args = []
+  scalar_args = []
+  for i, arg in enumerate(args):
+    if isinstance(arg, (bool, int, float)):
+      scalar_args.append((i, get_triton_type(arg), arg))
+    elif isinstance(arg, np.float32):
+      scalar_args.append((i, get_triton_type(arg), float(arg)))
+    else:
+      array_args.append(arg)
+
+  compiled_kernel = kernel.run(*args, grid=grid, warmup=True, **metaparams)
+  return triton_kernel_call_lowering(array_args, kernel, compiled_kernel,
+                                     scalar_args, grid, debug, **metaparams)
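
Outside the diff itself, here is a minimal usage sketch of the Python surface this change adds (`torch_xla.experimental.triton.triton_call` plus `torch_xla._XLAC._xla_gpu_custom_call`). It assumes a CUDA PJRT device and the jaxlib Triton bindings are importable; the kernel body, sizes, and dtype are illustrative, and the callable `grid` form is the one `normalize_grid` resolves against the metaparameters.

```python
import torch
import torch_xla
import torch_xla.experimental.triton as xla_triton
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
  # Same element-wise add as in test/test_triton.py, condensed.
  pid = tl.program_id(axis=0)
  offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  tl.store(output_ptr + offsets, x + y, mask=mask)


size = 1024
x = torch.randn(size, dtype=torch.float32).to("xla")
y = torch.randn(size, dtype=torch.float32).to("xla")
out = torch.empty_like(x)

# The grid may be a callable of the metaparameters (here BLOCK_SIZE), which
# triton_call forwards to normalize_grid.
payload = xla_triton.triton_call(
    x, y, out, size,
    kernel=add_kernel,
    grid=lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']),),
    BLOCK_SIZE=128)

# The payload is lowered through the new xla::gpu_custom_call IR node and
# executed via the "triton_kernel_call" custom-call target registered above.
result = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload, [out.shape],
                                              [torch.float32])[0]
```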