From 4722a99d303f347d6088d95867d007c75ca6dd78 Mon Sep 17 00:00:00 2001
From: Don Isaac
Date: Thu, 25 Jan 2024 14:46:50 -0500
Subject: [PATCH] feat(wgpu): add to_dtype kernel (#906)

* feat(wgpu): add to_dtype kernel

* fix: add WebGPUNativeType

* style: clippy fix

---------

Co-authored-by: Corey Lowman
---
 dfdx-core/src/tensor/webgpu/device.rs         | 30 ++++++
 dfdx-core/src/tensor/webgpu/mod.rs            |  2 +
 dfdx-core/src/tensor/webgpu/types.rs          | 56 +++++++++++
 .../src/tensor_ops/to_dtype/to_dtype.wgsl     | 16 +++
 .../src/tensor_ops/to_dtype/webgpu_kernel.rs  | 99 ++++++++++++++++++-
 .../tensor_ops/utilities/webgpu_kernels.rs    | 45 +++++----
 6 files changed, 228 insertions(+), 20 deletions(-)
 create mode 100644 dfdx-core/src/tensor/webgpu/types.rs
 create mode 100644 dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl

diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs
index 5e7b2c01..1c23989b 100644
--- a/dfdx-core/src/tensor/webgpu/device.rs
+++ b/dfdx-core/src/tensor/webgpu/device.rs
@@ -247,6 +247,36 @@ impl Webgpu {
     pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<wgpu::ShaderModule>> {
         self.cs_cache.read().get(&name).cloned()
     }
+
+    /// Submit a command buffer to the GPU.
+    ///
+    /// Note: does not block until completion. If you need to wait, call
+    /// `self.dev.poll(Maintain::WaitForSubmissionIndex(idx))` with the
+    /// returned [`wgpu::SubmissionIndex`].
+    pub(crate) fn submit_commands<F>(
+        &self,
+        label: Option<&str>,
+        command_builder: F,
+    ) -> wgpu::SubmissionIndex
+    where
+        F: FnOnce(&mut wgpu::CommandEncoder),
+    {
+        let mut encoder = self
+            .dev
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                label: label.clone(),
+            });
+
+        if let Some(label) = label {
+            encoder.push_debug_group(label);
+        }
+        command_builder(&mut encoder);
+        if label.is_some() {
+            encoder.pop_debug_group();
+        }
+
+        let cmd = [encoder.finish()];
+        self.queue.submit(cmd)
+    }

     // #[allow(unused)]
     // pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
diff --git a/dfdx-core/src/tensor/webgpu/mod.rs b/dfdx-core/src/tensor/webgpu/mod.rs
index 666ce53e..b22a5619 100644
--- a/dfdx-core/src/tensor/webgpu/mod.rs
+++ b/dfdx-core/src/tensor/webgpu/mod.rs
@@ -1,8 +1,10 @@
 mod allocate;
 mod device;
+mod types;

 pub use device::Buffer;
 pub use device::Webgpu;
+pub use types::*;

 #[cfg(test)]
 mod tests {
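Note on usage: `submit_commands` is deliberately fire-and-forget; the queue
accepts the work and the call returns. A minimal sketch of a crate-internal
call site that needs to block on completion, waiting on the returned index as
the doc comment describes (the label string, `device` binding, and closure
body here are placeholders, not code from this patch):

    let idx = device.submit_commands(Some("example-pass"), |encoder| {
        // record compute passes / buffer copies on `encoder` here
    });
    // only when a synchronous wait is actually required:
    device.dev.poll(wgpu::Maintain::WaitForSubmissionIndex(idx));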
webgpu_type { + ($RustTy:ty) => { + impl WebgpuNativeType for $RustTy { + const NAME: &'static str = stringify!($RustTy); + } + }; + ($RustTy:ty, $WgpuTy:expr) => { + impl WebgpuNativeType for $RustTy { + const NAME: &'static str = $WgpuTy; + } + }; +} + +/* +see: +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16 +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64 +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16 + */ +#[cfg(feature = "f16")] +webgpu_type!(half::f16, "f16"); +webgpu_type!(f32); +// todo: only enable when f64 feature is enabled +#[cfg(feature = "f64")] +webgpu_type!(f64); + +#[cfg(feature = "i16")] +webgpu_type!(i16); +webgpu_type!(i32); + +webgpu_type!(u32); +webgpu_type!(bool); + +pub(crate) trait HasGlslType { + const TYPE: &'static str; +} + +impl HasGlslType for f32 { + const TYPE: &'static str = "float"; +} + +impl HasGlslType for f64 { + const TYPE: &'static str = "double"; +} diff --git a/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl b/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl new file mode 100644 index 00000000..67d8fa43 --- /dev/null +++ b/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl @@ -0,0 +1,16 @@ +alias T = __SRC__; +alias U = __DST__; + +@group(0) @binding(0) +var in: array; + +@group(0) @binding(1) +var out: array; + +@compute @workgroup_size(1, 1, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3 +) { + let i = global_id.x; + out[i] = U(in[i]); +} diff --git a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs index 111b930e..4f2be7f9 100644 --- a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs @@ -1,9 +1,102 @@ -use crate::prelude::{Unit, Webgpu}; +use crate::{ + prelude::Storage, + tensor::webgpu::{Webgpu, WebgpuNativeType}, + tensor_ops::utilities::webgpu_kernels::webgpu_params, +}; +use num_traits::AsPrimitive; +use wgpu; -impl super::ToDtypeKernel for Webgpu { +/// kernel template +const KERNEL: &'static str = include_str!("./to_dtype.wgsl"); + +const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor { + label: Some("to-dtype"), + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], +}; + +impl, E2: WebgpuNativeType> super::ToDtypeKernel + for Webgpu +{ fn forward( inp: crate::prelude::Tensor, ) -> Result, crate::prelude::Error> { - todo!() + let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME); + let label = Some(module_name.as_str()); + let device = inp.device; + + let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC); + let shader_source: String = KERNEL + .replace("__SRC__", E1::NAME) + .replace("__DST__", E2::NAME); + + // TODO: support WGSL shaders in device shader cache + let source = wgpu::ShaderSource::Wgsl(shader_source.into()); + let shader_module = device + .dev + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some(shader_name), + source, + }); + let 
diff --git a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs
index 111b930e..4f2be7f9 100644
--- a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs
@@ -1,9 +1,102 @@
-use crate::prelude::{Unit, Webgpu};
+use crate::{
+    prelude::Storage,
+    tensor::webgpu::{Webgpu, WebgpuNativeType},
+    tensor_ops::utilities::webgpu_kernels::webgpu_params,
+};
+use num_traits::AsPrimitive;
+use wgpu;

-impl<E1: Unit, E2: Unit> super::ToDtypeKernel<E1, E2> for Webgpu {
+/// kernel template
+const KERNEL: &'static str = include_str!("./to_dtype.wgsl");
+
+const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor {
+    label: Some("to-dtype"),
+    entries: &[
+        wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        wgpu::BindGroupLayoutEntry {
+            binding: 1,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+    ],
+};
+
+impl<E1: WebgpuNativeType + AsPrimitive<E2>, E2: WebgpuNativeType> super::ToDtypeKernel<E1, E2>
+    for Webgpu
+{
     fn forward<S: crate::prelude::Shape>(
         inp: crate::prelude::Tensor<S, E1, Self>,
     ) -> Result<crate::prelude::Tensor<S, E2, Self>, crate::prelude::Error> {
-        todo!()
+        let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME);
+        let label = Some(module_name.as_str());
+        let device = inp.device;
+
+        let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC);
+        let shader_source: String = KERNEL
+            .replace("__SRC__", E1::NAME)
+            .replace("__DST__", E2::NAME);
+
+        // TODO: support WGSL shaders in device shader cache
+        let source = wgpu::ShaderSource::Wgsl(shader_source.into());
+        let shader_module = device
+            .dev
+            .create_shader_module(wgpu::ShaderModuleDescriptor {
+                label: label.clone(),
+                source,
+            });
+        let pipeline_layout = device
+            .dev
+            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+                label: label.clone(),
+                bind_group_layouts: &[&layout],
+                // todo: these are useful and we should use them if the adapter supports them
+                push_constant_ranges: &[],
+            });
+
+        let pipeline = device
+            .dev
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: label.clone(),
+                layout: Some(&pipeline_layout),
+                module: &shader_module,
+                // must match the `fn main` entry point in to_dtype.wgsl
+                entry_point: "main",
+            });
+
+        let numel = inp.shape.num_elements();
+        let shape = inp.shape;
+        let strides = shape.strides();
+        let output = unsafe { device.alloc_empty::<E2>(numel) }?;
+
+        let params: wgpu::BindGroup = webgpu_params!(device, pipeline; inp.data, output);
+
+        let _idx = device.submit_commands(label.clone(), |encoder| {
+            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label: label.clone(),
+                ..Default::default()
+            });
+            // TODO: should the pipeline/bind group be set before the pass is created, or before submission?
+            pass.set_pipeline(&pipeline);
+            pass.set_bind_group(0, &params, &[]);
+            // one invocation per element, with a (1, 1, 1) workgroup size
+            pass.dispatch_workgroups(numel as u32, 1, 1);
+        });
+
+        // note: no need to sync here; the buffer can remain on the gpu until to_array or to_vec
+        // gets called, and those functions sync the device before mapping the buffer
+        Ok(device.build_tensor(shape, strides, output))
     }
 }
diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
index 72cb3105..4f9ffd33 100644
--- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
+++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
@@ -6,6 +6,33 @@ use crate::{
 use core::any::TypeId;
 use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec};

+use wgpu::{
+    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
+};
+
+/// Creates a [`wgpu::BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
+macro_rules! webgpu_params {
+    ($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
+        {
+            let bindings = [$($x.as_entire_binding()),+];
+            let entries: Vec<_> = bindings
+                .into_iter()
+                .enumerate()
+                .map(|(i, binding)| wgpu::BindGroupEntry {
+                    binding: i as u32,
+                    resource: binding,
+                })
+                .collect();
+            $self.dev.create_bind_group(&::wgpu::BindGroupDescriptor {
+                label: None,
+                layout: &($pipeline).get_bind_group_layout(0),
+                entries: &entries,
+            })
+        }
+    }
+}
+pub(crate) use webgpu_params;
+
 pub(crate) trait UnaryOpWebgpuKernel<E> {
     const DF_USES_FX: bool;
     const HAS_CONST_DF: bool;
@@ -49,6 +76,7 @@ macro_rules! webgpu_unary {
         }
     };
 }
+pub(crate) use webgpu_unary;

 /// Zero-sized marker type for forward pass TypeId
 #[derive(Debug, Default)]
 pub(crate) struct Forward<E: Dtype, K> {
     _phantom: PhantomData<(E, K)>,
 }

 /// Zero-sized marker type for backward pass TypeId
 #[derive(Debug, Default)]
@@ -62,23 +90,6 @@ pub(crate) struct Backward<E: Dtype, K> {
     _phantom: PhantomData<(E, K)>,
 }

-pub(crate) trait HasGlslType {
-    const TYPE: &'static str;
-}
-
-impl HasGlslType for f32 {
-    const TYPE: &'static str = "float";
-}
-
-impl HasGlslType for f64 {
-    const TYPE: &'static str = "double";
-}
-
-pub(crate) use webgpu_unary;
-use wgpu::{
-    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
-};
-
 impl<E: Dtype + HasGlslType, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu {
     const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
     const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF;
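With the kernel in place, the conversion is reachable through dfdx's existing
`to_dtype` tensor op. A hypothetical end-to-end call, assuming the `webgpu`
feature is enabled and that the Webgpu device builds via `Default` and
constructs tensors the same way the Cpu/Cuda devices do (unverified against
this patch alone):

    use dfdx_core::prelude::*;

    let dev: Webgpu = Default::default();
    let x = dev.tensor([1.5f32, 2.5, 3.5]);
    // dispatches the "convert_f32_to_u32" pipeline added above
    let y = x.to_dtype::<u32>();
    assert_eq!(y.as_vec(), vec![1, 2, 3]);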