From 92c8fe500c054220ab3d397b949ef7b3a58a612b Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sun, 26 Nov 2023 12:18:03 -0800 Subject: [PATCH 01/16] Removed some of the more low level commands in favor of a wrapper struct Also added tests for higher code coverage. --- dfdx-core/src/tensor/webgpu/allocate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 49162381..803da854 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -52,7 +52,7 @@ impl Webgpu { } } -impl ZerosTensor for Webgpu { +impl> ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); @@ -63,7 +63,7 @@ impl ZerosTensor for Webgpu { } } -impl ZeroFillStorage for Webgpu { +impl> ZeroFillStorage for Webgpu { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> { storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]); From 8a35784a1e5c8a5ce59e643c365a0225bde43c6c Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 12:52:00 -0800 Subject: [PATCH 02/16] AtomicPtr unsound fix --- dfdx-core/src/tensor/webgpu/allocate.rs | 22 +++++++++++----------- dfdx-core/src/tensor/webgpu/device.rs | 18 ++++++++++-------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 803da854..22d4ae30 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -17,7 +17,7 @@ pub(crate) fn round_to_buffer_alignment(size: u64) -> u64 { } impl Webgpu { - fn tensor_from_host_buf( + fn tensor_from_host_buf( &self, shape: S, buf: Vec, @@ -28,7 +28,7 @@ impl Webgpu { Ok(self.build_tensor(shape, shape.strides(), buffer)) } - pub(crate) fn build_tensor( + pub(crate) fn build_tensor( &self, shape: S, strides: S::Concrete, @@ -52,7 +52,7 @@ impl Webgpu { } } -impl> ZerosTensor for Webgpu { +impl + bytemuck::Pod> ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); @@ -63,7 +63,7 @@ impl> ZerosTensor for Webgpu { } } -impl> ZeroFillStorage for Webgpu { +impl + bytemuck::Pod> ZeroFillStorage for Webgpu { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> { storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]); @@ -71,7 +71,7 @@ impl> ZeroFillStorage for Webgpu { } } -impl OnesTensor for Webgpu { +impl OnesTensor for Webgpu { fn try_ones_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let buf = vec![E::ONE; shape.num_elements()]; @@ -79,7 +79,7 @@ impl OnesTensor for Webgpu { } } -impl TriangleTensor for Webgpu +impl TriangleTensor for Webgpu where Cpu: TriangleTensor, { @@ -110,7 +110,7 @@ where } } -impl OneFillStorage for Webgpu { +impl OneFillStorage for Webgpu { fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> { let len = storage.size() as usize / std::mem::size_of::(); let buf = vec![E::ONE; len]; @@ -122,7 +122,7 @@ impl OneFillStorage for Webgpu { } } -impl SampleTensor for Webgpu +impl SampleTensor for Webgpu where Cpu: SampleTensor, { @@ -168,7 +168,7 @@ where } } -impl CopySlice for Webgpu { +impl CopySlice for Webgpu { fn copy_from(dst: &mut Tensor, src: &[E]) { assert_eq!( dst.data.size() as usize, @@ -192,7 +192,7 @@ impl CopySlice for Webgpu { } } 
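// [editorial sketch — not part of the patch] The allocation paths above size
// every buffer through `round_to_buffer_alignment`, whose body the diff
// context elides. A minimal version, assuming it only needs to round a byte
// count up to wgpu's `COPY_BUFFER_ALIGNMENT` (4 bytes), could look like:
pub(crate) fn round_to_buffer_alignment(size: u64) -> u64 {
    const ALIGN: u64 = wgpu::COPY_BUFFER_ALIGNMENT; // 4 in current wgpu
    (size + ALIGN - 1) / ALIGN * ALIGN
}
// Rounding up keeps `Queue::write_buffer` and buffer-to-buffer copies valid
// even when `len * size_of::<E>()` is not a multiple of the copy alignment.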
-impl TensorFromVec for Webgpu { +impl TensorFromVec for Webgpu { fn try_tensor_from_vec( &self, src: Vec, @@ -208,7 +208,7 @@ impl TensorFromVec for Webgpu { } } -impl TensorToArray for Webgpu +impl TensorToArray for Webgpu where Cpu: TensorToArray + Storage, { diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 3cba06c7..e5717148 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -21,6 +21,8 @@ use std::{marker::PhantomData, sync::Arc, vec::Vec}; use super::allocate::round_to_buffer_alignment; +static CONSTRUCTOR_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + #[derive(Debug)] pub struct Buffer { pub(crate) data: wgpu::Buffer, @@ -45,13 +47,13 @@ impl Buffer { self.data.size() as usize } - pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { - let slice = unsafe { - std::slice::from_raw_parts( - slice.as_ptr() as *const u8, - slice.len() * std::mem::size_of::(), - ) - }; + pub(crate) fn copy_to_device( + &self, + dev: &Device, + queue: &Queue, + slice: &[E], + ) { + let slice = bytemuck::cast_slice(slice); queue.write_buffer(&self.data, 0, slice); queue.submit(std::iter::empty()); dev.poll(Maintain::Wait); @@ -308,7 +310,7 @@ impl Synchronize for Webgpu { } } -impl Storage for Webgpu { +impl Storage for Webgpu { type Vec = CachableBuffer; fn try_alloc_len(&self, len: usize) -> Result { From d867cd83824830a46645510e56d6b7b8f411fe7e Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 14:06:20 -0800 Subject: [PATCH 03/16] Partial implementation of `Device` for Webgpu --- dfdx-core/src/tensor/webgpu/allocate.rs | 22 +++++++++++----------- dfdx-core/src/tensor/webgpu/device.rs | 11 +++-------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 22d4ae30..0c06f65f 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -17,7 +17,7 @@ pub(crate) fn round_to_buffer_alignment(size: u64) -> u64 { } impl Webgpu { - fn tensor_from_host_buf( + fn tensor_from_host_buf( &self, shape: S, buf: Vec, @@ -28,7 +28,7 @@ impl Webgpu { Ok(self.build_tensor(shape, shape.strides(), buffer)) } - pub(crate) fn build_tensor( + pub(crate) fn build_tensor( &self, shape: S, strides: S::Concrete, @@ -52,7 +52,7 @@ impl Webgpu { } } -impl + bytemuck::Pod> ZerosTensor for Webgpu { +impl> ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); @@ -63,7 +63,7 @@ impl + bytemuck::Pod> ZerosTensor for Webgpu } } -impl + bytemuck::Pod> ZeroFillStorage for Webgpu { +impl> ZeroFillStorage for Webgpu { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> { storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]); @@ -71,7 +71,7 @@ impl + bytemuck::Pod> ZeroFillStorage for Web } } -impl OnesTensor for Webgpu { +impl OnesTensor for Webgpu { fn try_ones_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let buf = vec![E::ONE; shape.num_elements()]; @@ -79,7 +79,7 @@ impl OnesTensor for Webgpu { } } -impl TriangleTensor for Webgpu +impl TriangleTensor for Webgpu where Cpu: TriangleTensor, { @@ -110,7 +110,7 @@ where } } -impl OneFillStorage for Webgpu { +impl OneFillStorage for Webgpu { fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> { let len = storage.size() as usize / 
std::mem::size_of::(); let buf = vec![E::ONE; len]; @@ -122,7 +122,7 @@ impl OneFillStorage for Webgpu { } } -impl SampleTensor for Webgpu +impl SampleTensor for Webgpu where Cpu: SampleTensor, { @@ -168,7 +168,7 @@ where } } -impl CopySlice for Webgpu { +impl CopySlice for Webgpu { fn copy_from(dst: &mut Tensor, src: &[E]) { assert_eq!( dst.data.size() as usize, @@ -192,7 +192,7 @@ impl CopySlice for Webgpu { } } -impl TensorFromVec for Webgpu { +impl TensorFromVec for Webgpu { fn try_tensor_from_vec( &self, src: Vec, @@ -208,7 +208,7 @@ impl TensorFromVec for Webgpu { } } -impl TensorToArray for Webgpu +impl TensorToArray for Webgpu where Cpu: TensorToArray + Storage, { diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index e5717148..ee9959dc 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -47,13 +47,8 @@ impl Buffer { self.data.size() as usize } - pub(crate) fn copy_to_device( - &self, - dev: &Device, - queue: &Queue, - slice: &[E], - ) { - let slice = bytemuck::cast_slice(slice); + pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { + let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, self.size()) }; queue.write_buffer(&self.data, 0, slice); queue.submit(std::iter::empty()); dev.poll(Maintain::Wait); @@ -310,7 +305,7 @@ impl Synchronize for Webgpu { } } -impl Storage for Webgpu { +impl Storage for Webgpu { type Vec = CachableBuffer; fn try_alloc_len(&self, len: usize) -> Result { From 9694c1c221279ed006a455b3ebeb875396d610c7 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 14:44:34 -0800 Subject: [PATCH 04/16] Remove foolish Mutex --- dfdx-core/src/tensor/webgpu/device.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index ee9959dc..ccea8121 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -21,8 +21,6 @@ use std::{marker::PhantomData, sync::Arc, vec::Vec}; use super::allocate::round_to_buffer_alignment; -static CONSTRUCTOR_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); - #[derive(Debug)] pub struct Buffer { pub(crate) data: wgpu::Buffer, From 7ef20dc71f1a8361346435e6bbdc6ffbcc63dd8e Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 15:49:51 -0800 Subject: [PATCH 05/16] Add Mutex back, since evidently it was causing issues. Hopefully I can figure out a way to remove it again. --- dfdx-core/src/tensor/webgpu/device.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index ccea8121..3cba06c7 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -46,7 +46,12 @@ impl Buffer { } pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { - let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, self.size()) }; + let slice = unsafe { + std::slice::from_raw_parts( + slice.as_ptr() as *const u8, + slice.len() * std::mem::size_of::(), + ) + }; queue.write_buffer(&self.data, 0, slice); queue.submit(std::iter::empty()); dev.poll(Maintain::Wait); From f5a2b3de07a0c03a378b696b09ce2c54e4d1141d Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Mon, 27 Nov 2023 16:06:55 -0800 Subject: [PATCH 06/16] Removed `num_traits::Num` requirement from Zeros. 
Had to figure out a way to store zeros in place --- dfdx-core/src/tensor/webgpu/allocate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 0c06f65f..49162381 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -52,7 +52,7 @@ impl Webgpu { } } -impl> ZerosTensor for Webgpu { +impl ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); @@ -63,7 +63,7 @@ impl> ZerosTensor for Webgpu { } } -impl> ZeroFillStorage for Webgpu { +impl ZeroFillStorage for Webgpu { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> { storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]); From 1e2d1ecbb94567214afb832b20dc5db24c3a4898 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 2 Dec 2023 18:13:03 -0800 Subject: [PATCH 07/16] Implement abs kernel, and use broken unary operation for all the compiler errors --- dfdx-core/src/tensor/webgpu/device.rs | 67 ++++++- dfdx-core/src/tensor_ops/abs/abs.wgsl | 40 ++++ dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs | 30 +-- .../tensor_ops/accurate_gelu/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/add/webgpu_kernel.rs | 27 +-- .../src/tensor_ops/clamp/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/div/webgpu_kernel.rs | 31 +-- dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/fast_gelu/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs | 31 +-- .../src/tensor_ops/nans_to/webgpu_kernel.rs | 40 ++-- .../src/tensor_ops/negate/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs | 63 +++--- .../src/tensor_ops/recip/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/relu/webgpu_kernel.rs | 35 +--- .../src/tensor_ops/sigmoid/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/sqrt/webgpu_kernel.rs | 28 +-- .../src/tensor_ops/square/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs | 30 +-- .../src/tensor_ops/tanh/webgpu_kernel.rs | 30 +-- dfdx-core/src/tensor_ops/utilities/device.rs | 6 +- dfdx-core/src/tensor_ops/utilities/mod.rs | 2 + .../tensor_ops/utilities/webgpu_kernels.rs | 188 ++++++++++++++++++ 26 files changed, 424 insertions(+), 561 deletions(-) create mode 100644 dfdx-core/src/tensor_ops/abs/abs.wgsl create mode 100644 dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 3cba06c7..4eb62c97 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -1,6 +1,7 @@ use wgpu::{ + util::{BufferInitDescriptor, DeviceExt}, Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue, - RequestDeviceError, + RequestDeviceError, ShaderModule, ShaderModuleDescriptor, }; use crate::{ @@ -14,10 +15,16 @@ use crate::{ #[cfg(feature = "no-std")] use spin::Mutex; +use core::any::TypeId; #[cfg(not(feature = "no-std"))] use std::sync::Mutex; -use std::{marker::PhantomData, sync::Arc, vec::Vec}; +use std::{ + collections::HashMap, + marker::PhantomData, + sync::{Arc, RwLock}, + vec::Vec, +}; use super::allocate::round_to_buffer_alignment; @@ -40,12 +47,16 @@ impl Buffer { self.size } + 
pub(crate) fn len(&self) -> usize { + self.size / std::mem::size_of::() + } + #[allow(unused)] pub(crate) fn capacity(&self) -> usize { self.data.size() as usize } - pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { + pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { let slice = unsafe { std::slice::from_raw_parts( slice.as_ptr() as *const u8, @@ -102,6 +113,7 @@ pub struct Webgpu { pub(crate) queue: Arc, pub(crate) cache: Arc>, + pub(crate) cs_cache: Arc>>>, } impl From for Error { @@ -147,18 +159,19 @@ impl Webgpu { queue, cache: Default::default(), + cs_cache: Default::default(), }) } } impl Webgpu { - pub(crate) unsafe fn alloc_empty(&self, len: usize) -> Result { + pub(crate) fn alloc_empty(&self, len: usize) -> Result { let data = self.cache.try_pop::(len).map_or_else( || Buffer { data: self.dev.create_buffer(&BufferDescriptor { label: None, size: round_to_buffer_alignment((len * std::mem::size_of::()) as u64), - usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST, mapped_at_creation: false, }), size: len * std::mem::size_of::(), @@ -168,6 +181,50 @@ impl Webgpu { Ok(data) } + pub(crate) fn alloc_init(&self, init: &[E]) -> Result { + let data = self.cache.try_pop::(init.len()).map_or_else( + || { + let contents = unsafe { + std::slice::from_raw_parts( + init.as_ptr() as *const u8, + init.len() * std::mem::size_of::(), + ) + }; + Buffer { + data: self.dev.create_buffer_init(&BufferInitDescriptor { + label: None, + usage: BufferUsages::STORAGE + | BufferUsages::COPY_SRC + | BufferUsages::COPY_DST, + contents, + }), + size: init.len() * std::mem::size_of::(), + } + }, + |bfr| { + bfr.copy_to_device::(&self.dev, &self.queue, init); + bfr + }, + ); + Ok(data) + } + + pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { + self.cs_cache.read().unwrap().contains_key(&name) + } + + pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { + let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(source.into()), + })); + self.cs_cache.write().unwrap().insert(name, module); + } + + pub(crate) fn get_shader_module(&self, name: TypeId) -> Option> { + self.cs_cache.read().unwrap().get(&name).cloned() + } + // #[allow(unused)] // pub(crate) unsafe fn get_workspace(&self, len: usize) -> Result, Error> { // let num_bytes_required = len * std::mem::size_of::(); diff --git a/dfdx-core/src/tensor_ops/abs/abs.wgsl b/dfdx-core/src/tensor_ops/abs/abs.wgsl new file mode 100644 index 00000000..da980ac0 --- /dev/null +++ b/dfdx-core/src/tensor_ops/abs/abs.wgsl @@ -0,0 +1,40 @@ +struct AbsKernelOp {}; + +@group(0) +@binding(0) +var op: AbsKernelOp; + +@group(0) +@binding(1) +var inp: array; + +@group(0) +@binding(2) +var out: array; + +@group(0) +@binding(3) +var inp_grad: array; + +@group(0) +@binding(4) +var out_grad: array; + +@compute +@workgroup_size(1) +fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3) { + var x = if arrayLength(inp) > 0 { &inp[global_id] } else { &out[global_id] }; + *x = abs(*x); +} + +@compute +@workgroup_size(1) +fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3) { + // Not needed for Abs, but if we can figure out a template system, we can leave it all in. 
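    // For y = |x|, dy/dx = sign(x), so the backward pass scales the
    // incoming gradient by the sign of the forward input (chain rule).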
+ // let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 }; + // let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 }; + var dx: f32; + dx = sign(inp[global_id]); + + inp_grad[global_id] += dx * out_grad[global_id]; +} diff --git a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs index c993ee91..a5d9059e 100644 --- a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use super::AbsKernelOp; +use crate::tensor_ops::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = include_str!("abs.wgsl"); -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::AbsKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::AbsKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs index 080a857d..de36720c 100644 --- a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::AccurateGeLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::AccurateGeLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::AccurateGeLUKernelOp, + f32, + WGSL, + "gelu_fwd_f32", + "gelu_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs index 91becc55..204a4357 100644 --- a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs @@ -1,34 +1,15 @@ +use super::{BinaryAddKernelOp as Binary, ScalarAddKernelOp as Scalar}; use std::borrow::Cow; use crate::prelude::{ ops::{BinaryKernel, UnaryKernel}, + webgpu_kernels::webgpu_unary, Dtype, Webgpu, }; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarAddKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarAddKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(Scalar, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs 
b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs index df700d20..82485d5a 100644 --- a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ClampKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ClampKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::ClampKernelOp, + f32, + WGSL, + "clamp_fwd_f32", + "clamp_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs index a59bb5c8..a352702f 100644 --- a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::CosKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::CosKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs index 3a15ef7e..158a05fd 100644 --- a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs @@ -1,34 +1,11 @@ +use super::{BinaryDivKernelOp as Binary, ScalarDivKernelOp as Scalar}; use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - Dtype, Webgpu, -}; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarDivKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarDivKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs index 4f552b49..8670fc23 100644 --- a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = 
"TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ExpKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ExpKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::ExpKernelOp, f32, WGSL, "exp_fwd_f32", "exp_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs index cbdce3d9..c05b3df5 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::FastGeLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::FastGeLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::FastGeLUKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs index 64694c6f..bf08f71f 100644 --- a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::LnKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::LnKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::LnKernelOp, f32, WGSL, "ln_fwd_f32", "ln_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs index 240ba571..6b778db4 100644 --- a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs @@ -1,34 +1,11 @@ +use super::{BinaryMulKernelOp as Binary, ScalarMulKernelOp as Scalar}; use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - Dtype, Webgpu, -}; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarMulKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarMulKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: 
&Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_mul_fwd", "scalar_mul_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs index 58cc8c36..5721f4a0 100644 --- a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs @@ -1,28 +1,12 @@ -use std::borrow::Cow; - -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; - -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::NansToKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::NansToKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +use super::NansToKernelOp; +use crate::prelude::webgpu_kernels::webgpu_unary; + +const WGSL: &str = "TODO"; + +webgpu_unary!( + NansToKernelOp, + f32, + WGSL, + "nans_to_fwd_f32", + "nans_to_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs index 4794d906..2e16ebad 100644 --- a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::NegateKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::NegateKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::NegateKernelOp, f32, WGSL, "negate_fwd_f32", "negate_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs index 0cf6b43d..d21a15b6 100644 --- a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs @@ -1,33 +1,22 @@ use std::borrow::Cow; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; - -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::PowfKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::PowfKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} - -impl UnaryKernel for Webgpu { +use crate::prelude::{ops::UnaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; + +const WGSL: &str = "TODO"; + +webgpu_unary!( + super::PowfKernelOp, + f32, + WGSL, + "powf_fwd_f32", + "powf_bwd_f32" +); + +// TODO: Conflicting implementations of trait `UnaryKernel` for type `Webgpu`: +impl UnaryKernel for Webgpu +where + Self: UnaryKernel, f32>, +{ const BACKWARD_WITHOUT_INP: bool 
= false; const BACKWARD_WITHOUT_DATA: bool = false; @@ -35,19 +24,25 @@ impl UnaryKernel for Webgpu { fn forward( &self, op: super::PowiKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() + inp: Cow>, + ) -> Result, crate::prelude::Error> { + self.forward(super::PowfKernelOp(op.0 as f32), inp) } fn backward( &self, op: super::PowiKernelOp, - inp: &impl crate::prelude::Tensorlike, + inp: &impl crate::prelude::Tensorlike, grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, + out: &impl crate::prelude::Tensorlike, grad_out: &Self::Vec, ) -> Result<(), crate::prelude::Error> { - todo!() + self.backward( + super::PowfKernelOp(op.0 as f32), + inp, + grad_inp, + out, + grad_out, + ) } } diff --git a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs index ca8fd312..d3a14fe0 100644 --- a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::RecipKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::RecipKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, "recip_fwd_f32", "recip_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs index 6da7d6b9..c986c73d 100644 --- a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ReLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ReLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::ReLUKernelOp, + f32, + WGSL, + "relu_fwd_f32", + "relu_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs index f6e5c742..377c4453 100644 --- a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SigmoidKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SigmoidKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl 
crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs index 024c1382..befc2da0 100644 --- a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SinKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SinKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::SinKernelOp, f32, WGSL, "sin_fwd_f32", "sin_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs index 0701ee08..65def0b2 100644 --- a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, "sqrt_fwd_f32", "sqrt_bwd_f32"); - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SqrtKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SqrtKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} diff --git a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs index 522eae17..f16523a2 100644 --- a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SquareKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SquareKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::SquareKernelOp, + f32, + WGSL, + "square_fwd_f32", + "square_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs index 8d5e943e..0d476c94 100644 --- a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs @@ -1,34 +1,12 @@ use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - 
Dtype, Webgpu, -}; +use super::{BinarySubKernelOp as Binary, ScalarSubKernelOp as Scalar}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarSubKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +const WGSL: &str = "TODO"; - fn backward( - &self, - op: super::ScalarSubKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs index 51e661b2..aa8742d8 100644 --- a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use super::TanhKernelOp; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::TanhKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::TanhKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(TanhKernelOp, f32, WGSL, "tanh_fwd_f32", "tanh_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 388eea7f..ac77772b 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -137,5 +137,7 @@ impl Device for crate::tensor::Webgpu {} impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] impl Device for crate::tensor::Webgpu {} -#[cfg(feature = "webgpu")] -impl Device for crate::tensor::Webgpu {} + +// TODO: How can we implement this for f64 when WGSL doesn't support f64 yet? 
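// (editorial note) WGSL defines no 64-bit float type at all — f32 is the
// widest float it has — so a Webgpu `Device<f64>` impl would require software
// emulation (e.g. double-double arithmetic) in every kernel.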
+// #[cfg(feature = "webgpu")] +// impl Device for crate::tensor::Webgpu {} diff --git a/dfdx-core/src/tensor_ops/utilities/mod.rs b/dfdx-core/src/tensor_ops/utilities/mod.rs index e23565bc..a0a3a2b3 100644 --- a/dfdx-core/src/tensor_ops/utilities/mod.rs +++ b/dfdx-core/src/tensor_ops/utilities/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod cuda_kernels; mod device; pub(crate) mod ops; pub(crate) mod reduction_utils; +#[cfg(feature = "webgpu")] +pub(crate) mod webgpu_kernels; pub use backward::Backward; pub use device::Device; diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs new file mode 100644 index 00000000..2a3652fa --- /dev/null +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -0,0 +1,188 @@ +use crate::{ + shapes::{Dtype, Shape}, + tensor::*, + tensor_ops::ops::{BinaryKernel, UnaryKernel}, +}; +use core::any::TypeId; +use std::{borrow::Cow, sync::Arc, vec::Vec}; + +pub(crate) trait UnaryOpWebgpuKernel { + const DF_USES_FX: bool; + const HAS_CONST_DF: bool; + + /// Compiled by build.rs + const WGSL_SRC: &'static str; + + /// Unique name for the kernel + const MODULE_NAME: &'static str; + + /// Name of function in the .wgsl file + const FWD_FN_NAME: &'static str; + + /// Name of function in the .wgsl file + const BWD_FN_NAME: &'static str; + + const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME]; +} + +macro_rules! webgpu_unary { + ($Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = false; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = stringify!($Op); + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; + (df(f(x)) $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = true; + const HAS_CONST_DF: bool = false; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = $Fwd; + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; + (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = true; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = $Fwd; + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; +} + +pub(crate) use webgpu_unary; +use wgpu::ComputePipelineDescriptor; + +impl + 'static> UnaryKernel for Webgpu { + const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX; + const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF; + + fn forward( + &self, + op: K, + inp: Cow>, + ) -> Result, Error> { + if !self.shader_module_loaded(TypeId::of::()) { + self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + } + + let cs_module = self + .get_shader_module(TypeId::of::()) + .expect("shader module not loaded"); + let pipeline = self + .dev + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &cs_module, + entry_point: K::FWD_FN_NAME, + }); + let bind_group_layout = pipeline.get_bind_group_layout(0); + let op_storage = self.alloc_init::(&[op])?; + let numel = inp.data.len::(); + let storage = self.alloc_empty::(numel)?; + let empty = self.alloc_empty::(0)?; + 
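        // (editorial note) `empty` is a zero-sized placeholder buffer: a bind
        // group needs a resource for every binding, so the in-place
        // (Cow::Owned) path below points the unused input binding at this
        // empty buffer while the owned tensor storage serves as both input
        // and output.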
+ match inp { + Cow::Borrowed(inp) => { + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + op_storage.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + inp.data.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + storage.as_entire_buffer_binding(), + ), + }, + ], + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(numel as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(self.build_tensor(inp.shape, inp.strides, storage)) + } + Cow::Owned(mut inp) => { + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + op_storage.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + empty.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + Arc::make_mut(&mut inp.data).as_entire_buffer_binding(), + ), + }, + ], + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(numel as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(inp) + } + } + } + + fn backward( + &self, + op: K, + inp: &impl Tensorlike, + grad_inp: &mut Self::Vec, + out: &impl Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), Error> { + todo!() + } +} From bd75762e1716e19c4f2bdf5c56198abcee6dd154 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 2 Dec 2023 18:16:21 -0800 Subject: [PATCH 08/16] cargo fmt --- dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs | 8 +++++++- dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs | 8 +++++++- dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs | 1 - 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs index c05b3df5..4abc2738 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs @@ -2,4 +2,10 @@ use crate::prelude::webgpu_kernels::webgpu_unary; const WGSL: &str = "TODO"; -webgpu_unary!(super::FastGeLUKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); +webgpu_unary!( + super::FastGeLUKernelOp, + f32, + WGSL, + "sigmoid_fwd_f32", + "sigmoid_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs index 2e16ebad..be7690f6 100644 --- a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs @@ -2,4 +2,10 @@ use crate::prelude::webgpu_kernels::webgpu_unary; const WGSL: &str = 
"TODO"; -webgpu_unary!(super::NegateKernelOp, f32, WGSL, "negate_fwd_f32", "negate_bwd_f32"); +webgpu_unary!( + super::NegateKernelOp, + f32, + WGSL, + "negate_fwd_f32", + "negate_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs index 65def0b2..e2b0b032 100644 --- a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs @@ -3,4 +3,3 @@ use crate::prelude::webgpu_kernels::webgpu_unary; const WGSL: &str = "TODO"; webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, "sqrt_fwd_f32", "sqrt_bwd_f32"); - From fde8d50f57bcc2f81cc03db18a1df9b3086038b4 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 2 Dec 2023 18:23:55 -0800 Subject: [PATCH 09/16] disable f16, since we don't support it yet --- dfdx-core/src/tensor_ops/utilities/device.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index ac77772b..00fa9502 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -131,10 +131,11 @@ impl Device for crate::tensor::Cuda {} #[cfg(feature = "cuda")] impl Device for crate::tensor::Cuda {} -#[cfg(all(feature = "webgpu", feature = "f16"))] -impl Device for crate::tensor::Webgpu {} -#[cfg(all(feature = "webgpu", feature = "f16"))] -impl Device> for crate::tensor::Webgpu {} +// TODO: How can we implement this for f16 when WGSL doesn't support f16 yet? +// #[cfg(all(feature = "webgpu", feature = "f16"))] +// impl Device for crate::tensor::Webgpu {} +// #[cfg(all(feature = "webgpu", feature = "f16"))] +// impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] impl Device for crate::tensor::Webgpu {} From e3f21130920eb8576069c8c2749c3d02a63fee0f Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 2 Dec 2023 18:36:40 -0800 Subject: [PATCH 10/16] no-std --- dfdx-core/src/tensor/webgpu/device.rs | 28 ++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 4eb62c97..086eb4e4 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -13,18 +13,13 @@ use crate::{ }; #[cfg(feature = "no-std")] -use spin::Mutex; +use spin::{Mutex, RwLock}; use core::any::TypeId; #[cfg(not(feature = "no-std"))] -use std::sync::Mutex; +use std::sync::{Mutex, RwLock}; -use std::{ - collections::HashMap, - marker::PhantomData, - sync::{Arc, RwLock}, - vec::Vec, -}; +use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec}; use super::allocate::round_to_buffer_alignment; @@ -209,22 +204,37 @@ impl Webgpu { Ok(data) } + #[cfg(not(feature = "no-std"))] pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { self.cs_cache.read().unwrap().contains_key(&name) } + #[cfg(feature = "no-std")] + pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { + self.cs_cache.read().contains_key(&name) + } + pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor { label: None, source: wgpu::ShaderSource::Wgsl(source.into()), })); + #[cfg(not(feature = "no-std"))] self.cs_cache.write().unwrap().insert(name, module); + #[cfg(feature = "no-std")] + self.cs_cache.write().insert(name, module); } + #[cfg(not(feature = "no-std"))] pub(crate) fn get_shader_module(&self, 
name: TypeId) -> Option> { self.cs_cache.read().unwrap().get(&name).cloned() } + #[cfg(feature = "no-std")] + pub(crate) fn get_shader_module(&self, name: TypeId) -> Option> { + self.cs_cache.read().get(&name).cloned() + } + // #[allow(unused)] // pub(crate) unsafe fn get_workspace(&self, len: usize) -> Result, Error> { // let num_bytes_required = len * std::mem::size_of::(); @@ -369,7 +379,7 @@ impl Storage for Webgpu { type Vec = CachableBuffer; fn try_alloc_len(&self, len: usize) -> Result { - let data = unsafe { self.alloc_empty::(len) }?; + let data = self.alloc_empty::(len)?; Ok(CachableBuffer { dev: self.dev.clone(), queue: self.queue.clone(), From b69934047624001968892582f6db9957e0ee66e7 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Wed, 13 Dec 2023 17:21:22 -0800 Subject: [PATCH 11/16] Added test for abs on webgpu. Also added `backward` implementation, though I won't be able to test that until I fix `mean`. --- dfdx-core/src/tensor_ops/abs/abs.wgsl | 22 ++- dfdx-core/src/tensor_ops/abs/mod.rs | 13 ++ .../tensor_ops/utilities/webgpu_kernels.rs | 174 +++++++++++++----- 3 files changed, 161 insertions(+), 48 deletions(-) diff --git a/dfdx-core/src/tensor_ops/abs/abs.wgsl b/dfdx-core/src/tensor_ops/abs/abs.wgsl index da980ac0..cd812370 100644 --- a/dfdx-core/src/tensor_ops/abs/abs.wgsl +++ b/dfdx-core/src/tensor_ops/abs/abs.wgsl @@ -1,8 +1,11 @@ -struct AbsKernelOp {}; +// TODO: We need to figure out how to represent empty structs in wgsl +// struct AbsKernelOp { +// empty: u32, +// } @group(0) @binding(0) -var op: AbsKernelOp; +var op: array; @group(0) @binding(1) @@ -23,8 +26,13 @@ var out_grad: array; @compute @workgroup_size(1) fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3) { - var x = if arrayLength(inp) > 0 { &inp[global_id] } else { &out[global_id] }; - *x = abs(*x); + // let length: u32 = arrayLength(&inp); + // if (length > 1) { + // out[global_id.x] = abs(inp[global_id.x]); + // } else { + // out[global_id.x] = abs(out[global_id.x]); + // } + out[global_id.x] = abs(inp[global_id.x]); } @compute @@ -33,8 +41,8 @@ fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3) { // Not needed for Abs, but if we can figure out a template system, we can leave it all in. 
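    // `global_id` is a vec3<u32>; arrays must be indexed with its scalar
    // `.x` component.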
// let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 }; // let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 }; - var dx: f32; - dx = sign(inp[global_id]); + var dx: f32; + dx = sign(inp[global_id.x]); - inp_grad[global_id] += dx * out_grad[global_id]; + inp_grad[global_id.x] += dx * out_grad[global_id.x]; } diff --git a/dfdx-core/src/tensor_ops/abs/mod.rs b/dfdx-core/src/tensor_ops/abs/mod.rs index 45c7794d..7f5dbe9c 100644 --- a/dfdx-core/src/tensor_ops/abs/mod.rs +++ b/dfdx-core/src/tensor_ops/abs/mod.rs @@ -57,4 +57,17 @@ mod tests { let g = r.mean().backward(); assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); } + + #[cfg(feature = "webgpu")] + #[test] + fn test_webgpu_abs() { + let dev: Webgpu = Default::default(); + let x = dev + .tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); + let r = x.leaky_trace().abs(); + assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); + // TODO: Add mean back in + // let g = r.mean().backward(); + // assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); + } } diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs index 2a3652fa..619afb1c 100644 --- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -90,32 +90,29 @@ impl + 'static> UnaryKernel for Webgpu let numel = inp.data.len::(); let storage = self.alloc_empty::(numel)?; let empty = self.alloc_empty::(0)?; + let mut entries = vec![]; + // WGSL doesn't support empty structs, so don't bind the empty buffer + if std::mem::size_of::() > 0 { + entries.push(wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer(op_storage.as_entire_buffer_binding()), + }); + } match inp { Cow::Borrowed(inp) => { + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(inp.data.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer(storage.as_entire_buffer_binding()), + }); let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { label: None, layout: &bind_group_layout, - entries: &[ - wgpu::BindGroupEntry { - binding: 0, - resource: wgpu::BindingResource::Buffer( - op_storage.as_entire_buffer_binding(), - ), - }, - wgpu::BindGroupEntry { - binding: 0, - resource: wgpu::BindingResource::Buffer( - inp.data.as_entire_buffer_binding(), - ), - }, - wgpu::BindGroupEntry { - binding: 0, - resource: wgpu::BindingResource::Buffer( - storage.as_entire_buffer_binding(), - ), - }, - ], + entries: &entries, }); let mut encoder = self .dev @@ -133,29 +130,20 @@ impl + 'static> UnaryKernel for Webgpu Ok(self.build_tensor(inp.shape, inp.strides, storage)) } Cow::Owned(mut inp) => { + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(empty.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer( + Arc::make_mut(&mut inp.data).as_entire_buffer_binding(), + ), + }); let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { label: None, layout: &bind_group_layout, - entries: &[ - wgpu::BindGroupEntry { - binding: 0, - resource: wgpu::BindingResource::Buffer( - op_storage.as_entire_buffer_binding(), - ), - }, - wgpu::BindGroupEntry { - binding: 0, - resource: wgpu::BindingResource::Buffer( - empty.as_entire_buffer_binding(), - ), - }, - wgpu::BindGroupEntry { - binding: 0, - resource: 
wgpu::BindingResource::Buffer( - Arc::make_mut(&mut inp.data).as_entire_buffer_binding(), - ), - }, - ], + entries: &entries, }); let mut encoder = self .dev @@ -183,6 +171,110 @@ impl + 'static> UnaryKernel for Webgpu out: &impl Tensorlike, grad_out: &Self::Vec, ) -> Result<(), Error> { - todo!() + if !self.shader_module_loaded(TypeId::of::()) { + self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + } + + let cs_module = self + .get_shader_module(TypeId::of::()) + .expect("shader module not loaded"); + let pipeline = self + .dev + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &cs_module, + entry_point: K::BWD_FN_NAME, + }); + let bind_group_layout = pipeline.get_bind_group_layout(0); + let op_storage = self.alloc_init::(&[op])?; + let numel = inp.len(); + let storage = self.alloc_empty::(numel)?; + let empty_inp = self.alloc_empty::(0)?; + let empty_out = self.alloc_empty::(0)?; + let mut entries = vec![]; + // WGSL doesn't support empty structs, so don't bind the empty buffer + if std::mem::size_of::() > 0 { + entries.push(wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer(op_storage.as_entire_buffer_binding()), + }); + } + match (inp.data(), out.data()) { + (None, None) => { + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(empty_inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer(empty_out.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 3, + resource: wgpu::BindingResource::Buffer(grad_inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 4, + resource: wgpu::BindingResource::Buffer(grad_out.as_entire_buffer_binding()), + }); + } + (None, Some(out)) => { + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(empty_inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer(out.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 3, + resource: wgpu::BindingResource::Buffer(grad_inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 4, + resource: wgpu::BindingResource::Buffer(grad_out.as_entire_buffer_binding()), + }); + } + (Some(inp), None) => { + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer(empty_out.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 3, + resource: wgpu::BindingResource::Buffer(grad_inp.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 4, + resource: wgpu::BindingResource::Buffer(grad_out.as_entire_buffer_binding()), + }); + } + _ => unreachable!(), + }; + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &entries, + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(numel as u32, 1, 1); + } + 
self.queue.submit(Some(encoder.finish())); + Ok(()) } } From e25553a1dd8e3153648860041d2802f19d1badff Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Wed, 13 Dec 2023 17:31:19 -0800 Subject: [PATCH 12/16] cargo fmt --- dfdx-core/src/tensor_ops/abs/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor_ops/abs/mod.rs b/dfdx-core/src/tensor_ops/abs/mod.rs index 7f5dbe9c..e2ce6982 100644 --- a/dfdx-core/src/tensor_ops/abs/mod.rs +++ b/dfdx-core/src/tensor_ops/abs/mod.rs @@ -62,8 +62,7 @@ mod tests { #[test] fn test_webgpu_abs() { let dev: Webgpu = Default::default(); - let x = dev - .tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); + let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); let r = x.leaky_trace().abs(); assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); // TODO: Add mean back in From 7c686a166d40f2d4017d9c85eede8579522d5403 Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 23 Dec 2023 19:32:46 -0800 Subject: [PATCH 13/16] Managed to get built spirv working as long as we go through the non-passthrough route. Can't get sum_to working until wgpu supports atomic operations. Which is super unfortunate. Maybe I'll work on that soon... --- dfdx-core/Cargo.toml | 12 +- dfdx-core/build.rs | 52 ++++++++ dfdx-core/src/tensor/error.rs | 3 + dfdx-core/src/tensor/webgpu/allocate.rs | 4 +- dfdx-core/src/tensor/webgpu/device.rs | 31 ++++- dfdx-core/src/tensor_ops/abs/abs.bwd.glsl | 28 +++++ dfdx-core/src/tensor_ops/abs/abs.fwd.glsl | 22 ++++ dfdx-core/src/tensor_ops/abs/abs.wgsl | 48 ------- dfdx-core/src/tensor_ops/abs/mod.rs | 12 -- dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs | 23 +++- .../tensor_ops/accurate_gelu/webgpu_kernel.rs | 10 +- dfdx-core/src/tensor_ops/add/webgpu_kernel.rs | 4 +- .../src/tensor_ops/clamp/webgpu_kernel.rs | 10 +- dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs | 4 +- dfdx-core/src/tensor_ops/div/webgpu_kernel.rs | 4 +- dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs | 4 +- .../src/tensor_ops/fast_gelu/webgpu_kernel.rs | 10 +- dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs | 4 +- dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs | 4 +- .../src/tensor_ops/nans_to/webgpu_kernel.rs | 10 +- .../src/tensor_ops/negate/webgpu_kernel.rs | 10 +- dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs | 10 +- .../src/tensor_ops/recip/webgpu_kernel.rs | 4 +- .../src/tensor_ops/relu/webgpu_kernel.rs | 10 +- .../src/tensor_ops/sigmoid/webgpu_kernel.rs | 4 +- dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs | 4 +- .../src/tensor_ops/sqrt/webgpu_kernel.rs | 4 +- .../src/tensor_ops/square/webgpu_kernel.rs | 10 +- dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs | 4 +- .../src/tensor_ops/sum_to/sum_to.fwd.glsl | 108 ++++++++++++++++ .../src/tensor_ops/sum_to/webgpu_kernel.rs | 117 +++++++++++++++++- .../src/tensor_ops/tanh/webgpu_kernel.rs | 4 +- .../tensor_ops/utilities/reduction_utils.rs | 6 +- .../tensor_ops/utilities/webgpu_kernels.rs | 102 ++++++++------- 34 files changed, 487 insertions(+), 209 deletions(-) create mode 100644 dfdx-core/src/tensor_ops/abs/abs.bwd.glsl create mode 100644 dfdx-core/src/tensor_ops/abs/abs.fwd.glsl delete mode 100644 dfdx-core/src/tensor_ops/abs/abs.wgsl create mode 100644 dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index 0eeac1f4..5309ef7c 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -38,7 +38,8 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis gemm = { version = "0.16.14", default-features = false, optional = true, 
features = ["rayon"] } rayon = { version = "1.7.0", optional = true } libm = { workspace = true } -wgpu = { version = "0.18.0", optional = true } +wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true } +naga = { version = "0.14.1", optional = true } futures-lite = { version = "2.0.1", optional = true } thingbuf = { version = "0.1.4", optional = true } @@ -62,7 +63,14 @@ fast-alloc = ["std"] cuda = ["dep:cudarc", "dep:glob"] cudnn = ["cuda", "cudarc?/cudnn"] -webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"] +webgpu = [ + "dep:wgpu", + "dep:futures-lite", + "dep:thingbuf", + "dep:naga", + "dep:glob", + "wgpu/expose-ids", +] f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"] diff --git a/dfdx-core/build.rs b/dfdx-core/build.rs index 1048382a..76d33682 100644 --- a/dfdx-core/build.rs +++ b/dfdx-core/build.rs @@ -9,6 +9,9 @@ fn main() { #[cfg(feature = "cuda")] cuda::build_ptx(); + + #[cfg(feature = "webgpu")] + webgpu::build_spv(); } fn maybe_enable_nightly() { @@ -210,3 +213,52 @@ mod cuda { } } } + +#[cfg(feature = "webgpu")] +mod webgpu { + pub fn build_spv() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + let kernel_paths: Vec = glob::glob("src/**/*.glsl") + .unwrap() + .map(|p| p.unwrap()) + .collect(); + for path in &kernel_paths { + println!("cargo:rerun-if-changed={}", path.display()); + } + + kernel_paths + .iter() + .for_each(|p| println!("cargo:rerun-if-changed={}", p.display())); + + let children = kernel_paths + .iter() + .map(|p| { + // TODO: we need to build this for both float and double + let out_path: std::path::PathBuf = out_dir.clone().into(); + let base = p.file_stem().unwrap(); + let new_name = format!("{}.float.spv", base.to_str().unwrap()); + let out_file = &out_path.join(new_name); + eprintln!("out_file: {:?}", out_file); + std::process::Command::new("glslc") + .args(["-std=460core"]) + .args(["-fshader-stage=compute"]) + .args(["-DTYPENAME=float"]) + .args(["-o", &out_file.as_os_str().to_str().unwrap()]) + .arg(p) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.") + }) + .collect::>(); + for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) { + let output = child.wait_with_output().expect("glslc failed to run. 
Ensure that you have shaderc installed and that `glslc` is in your PATH."); + assert!( + output.status.success(), + "glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + } +} diff --git a/dfdx-core/src/tensor/error.rs b/dfdx-core/src/tensor/error.rs index 906c474c..6ab35ac6 100644 --- a/dfdx-core/src/tensor/error.rs +++ b/dfdx-core/src/tensor/error.rs @@ -22,6 +22,9 @@ pub enum Error { #[cfg(feature = "webgpu")] WebgpuRequestDeviceError(wgpu::RequestDeviceError), + + #[cfg(feature = "webgpu")] + WebgpuSourceLoadError, } impl std::fmt::Display for Error { diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 49162381..4e4a2692 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -22,7 +22,7 @@ impl Webgpu { shape: S, buf: Vec, ) -> Result, Error> { - let buffer = unsafe { self.alloc_empty::(buf.len()) }?; + let buffer = self.alloc_empty::(buf.len())?; buffer.copy_to_device::(&self.dev, &self.queue, &buf); Ok(self.build_tensor(shape, shape.strides(), buffer)) @@ -56,7 +56,7 @@ impl ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); - let data = unsafe { self.alloc_empty::(shape.num_elements()) }?; + let data = self.alloc_empty::(shape.num_elements())?; data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]); Ok(self.build_tensor(shape, strides, data)) diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 086eb4e4..23d73060 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -1,10 +1,12 @@ use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue, - RequestDeviceError, ShaderModule, ShaderModuleDescriptor, + util::{make_spirv, make_spirv_raw, BufferInitDescriptor, DeviceExt}, + Adapter, BufferDescriptor, BufferUsages, Device, DeviceDescriptor, Features, Instance, + InstanceDescriptor, Maintain, Queue, RequestDeviceError, ShaderModule, ShaderModuleDescriptor, + ShaderModuleDescriptorSpirV, }; use crate::{ + prelude::webgpu_kernels::HasGlslType, shapes::{Shape, Unit}, tensor::{ cache::TensorCache, cpu::Cpu, Cache, Error, NoneTape, RandomU64, Storage, Synchronize, @@ -141,8 +143,13 @@ impl Webgpu { let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default())) .ok_or(Error::WebgpuAdapterNotFound)?; let adapter = Arc::new(adapter); + let descriptor = DeviceDescriptor { + label: None, + features: Features::default() | Features::SPIRV_SHADER_PASSTHROUGH, + limits: Default::default(), + }; let (dev, queue) = - futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?; + futures_lite::future::block_on(adapter.request_device(&descriptor, None))?; let dev = Arc::new(dev); let queue = Arc::new(queue); @@ -214,10 +221,22 @@ impl Webgpu { self.cs_cache.read().contains_key(&name) } - pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { + pub(crate) fn load_shader_module(&self, name: TypeId, source: &[u8]) + where + E: HasGlslType, + { + // TODO: Get raw SpirV working. I am guessing that is how we are going + // to have to implement atomic stuff with `wgpu`. 
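+    // (A guess at why, for whoever picks this up: the sum_to GLSL added in
+    // this patch leans on GL_EXT_shader_atomic_float, and the normal
+    // create_shader_module path runs the SPIR-V through naga, which has to
+    // understand every instruction, while the passthrough path hands the
+    // compiled module straight to the driver. Features::SPIRV_SHADER_PASSTHROUGH
+    // is already requested at device creation above with this in mind.)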
+    //
+    // let module = Arc::new(unsafe {
+    //     self.dev.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
+    //         label: None,
+    //         source: make_spirv_raw(source),
+    //     })
+    // });
         let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
             label: None,
-            source: wgpu::ShaderSource::Wgsl(source.into()),
+            source: make_spirv(source),
         }));
         #[cfg(not(feature = "no-std"))]
         self.cs_cache.write().unwrap().insert(name, module);
diff --git a/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl b/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl
new file mode 100644
index 00000000..cd765ec8
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl
@@ -0,0 +1,28 @@
+#version 460 core
+
+#extension GL_ARB_compute_shader: enable
+#extension GL_ARB_shader_storage_buffer_object: enable
+
+layout(local_size_x = 128) in;
+
+layout(std430, binding = 1) buffer inpBlock {
+    TYPENAME inp[];
+};
+
+layout(std430, binding = 2) buffer outpBlock {
+    TYPENAME outp[];
+};
+
+layout(std430, binding = 3) buffer input_gradBlock {
+    TYPENAME input_grad[];
+};
+
+layout(std430, binding = 4) buffer output_gradBlock {
+    TYPENAME output_grad[];
+};
+
+void main() {
+    TYPENAME dx = sign(inp[gl_GlobalInvocationID.x]);
+
+    input_grad[gl_GlobalInvocationID.x] += dx * output_grad[gl_GlobalInvocationID.x];
+}
diff --git a/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl b/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl
new file mode 100644
index 00000000..00f4c5a8
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl
@@ -0,0 +1,22 @@
+#version 460 core
+
+#extension GL_ARB_compute_shader: enable
+#extension GL_ARB_shader_storage_buffer_object: enable
+
+layout(local_size_x = 128) in;
+
+layout(std430, binding = 1) buffer inpBlock {
+    TYPENAME inp[];
+};
+
+layout(std430, binding = 2) buffer outpBlock{
+    TYPENAME outp[];
+};
+
+void main() {
+    if (inp.length() == 0) {
+        outp[gl_GlobalInvocationID.x] = abs(outp[gl_GlobalInvocationID.x]);
+    } else {
+        outp[gl_GlobalInvocationID.x] = abs(inp[gl_GlobalInvocationID.x]);
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/abs/abs.wgsl b/dfdx-core/src/tensor_ops/abs/abs.wgsl
deleted file mode 100644
index cd812370..00000000
--- a/dfdx-core/src/tensor_ops/abs/abs.wgsl
+++ /dev/null
@@ -1,48 +0,0 @@
-// TODO: We need to figure out how to represent empty structs in wgsl
-// struct AbsKernelOp {
-//     empty: u32,
-// }
-
-@group(0)
-@binding(0)
-var op: array;
-
-@group(0)
-@binding(1)
-var inp: array;
-
-@group(0)
-@binding(2)
-var out: array;
-
-@group(0)
-@binding(3)
-var inp_grad: array;
-
-@group(0)
-@binding(4)
-var out_grad: array;
-
-@compute
-@workgroup_size(1)
-fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3) {
-    // let length: u32 = arrayLength(&inp);
-    // if (length > 1) {
-    //     out[global_id.x] = abs(inp[global_id.x]);
-    // } else {
-    //     out[global_id.x] = abs(out[global_id.x]);
-    // }
-    out[global_id.x] = abs(inp[global_id.x]);
-}
-
-@compute
-@workgroup_size(1)
-fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3) {
-    // Not needed for Abs, but if we can figure out a template system, we can leave it all in.
- // let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 }; - // let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 }; - var dx: f32; - dx = sign(inp[global_id.x]); - - inp_grad[global_id.x] += dx * out_grad[global_id.x]; -} diff --git a/dfdx-core/src/tensor_ops/abs/mod.rs b/dfdx-core/src/tensor_ops/abs/mod.rs index e2ce6982..45c7794d 100644 --- a/dfdx-core/src/tensor_ops/abs/mod.rs +++ b/dfdx-core/src/tensor_ops/abs/mod.rs @@ -57,16 +57,4 @@ mod tests { let g = r.mean().backward(); assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); } - - #[cfg(feature = "webgpu")] - #[test] - fn test_webgpu_abs() { - let dev: Webgpu = Default::default(); - let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); - let r = x.leaky_trace().abs(); - assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); - // TODO: Add mean back in - // let g = r.mean().backward(); - // assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); - } } diff --git a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs index a5d9059e..130c9e3f 100644 --- a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs @@ -1,6 +1,25 @@ use super::AbsKernelOp; use crate::tensor_ops::webgpu_kernels::webgpu_unary; -const WGSL: &str = include_str!("abs.wgsl"); +const GLSL_FWD: &str = include_str!("abs.fwd.glsl"); +const GLSL_BWD: &str = include_str!("abs.bwd.glsl"); +const SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv")); +const SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.float.spv")); -webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32"); +webgpu_unary!(AbsKernelOp, f32, SPV_FWD, SPV_BWD); + +#[cfg(test)] +mod tests { + use crate::{tensor::*, tensor_ops::*, tests::*}; + + #[test] + fn test_webgpu_abs() { + let dev: Webgpu = Default::default(); + let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); + let r = x.leaky_trace().abs(); + assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); + // TODO: Add mean back in + // let g = r.mean().backward(); + // assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); + } +} diff --git a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs index de36720c..50a14d55 100644 --- a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::AccurateGeLUKernelOp, - f32, - WGSL, - "gelu_fwd_f32", - "gelu_bwd_f32" -); +webgpu_unary!(super::AccurateGeLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs index 204a4357..4caf717b 100644 --- a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs @@ -7,9 +7,9 @@ use crate::prelude::{ Dtype, Webgpu, }; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(Scalar, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32"); +webgpu_unary!(Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs index 82485d5a..ced36298 100644 --- a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs +++ 
b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::ClampKernelOp, - f32, - WGSL, - "clamp_fwd_f32", - "clamp_bwd_f32" -); +webgpu_unary!(super::ClampKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs index a352702f..88737c1b 100644 --- a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32"); +webgpu_unary!(super::CosKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs index 158a05fd..2ba21757 100644 --- a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs @@ -3,9 +3,9 @@ use std::borrow::Cow; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs index 8670fc23..5438fd96 100644 --- a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::ExpKernelOp, f32, WGSL, "exp_fwd_f32", "exp_bwd_f32"); +webgpu_unary!(super::ExpKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs index 4abc2738..438e3066 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::FastGeLUKernelOp, - f32, - WGSL, - "sigmoid_fwd_f32", - "sigmoid_bwd_f32" -); +webgpu_unary!(super::FastGeLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs index bf08f71f..e0d574a6 100644 --- a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::LnKernelOp, f32, WGSL, "ln_fwd_f32", "ln_bwd_f32"); +webgpu_unary!(super::LnKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs index 6b778db4..f3176043 100644 --- a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs @@ -3,9 +3,9 @@ use std::borrow::Cow; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_mul_fwd", "scalar_mul_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); 
impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs index 5721f4a0..a86c5ab3 100644 --- a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs @@ -1,12 +1,6 @@ use super::NansToKernelOp; use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - NansToKernelOp, - f32, - WGSL, - "nans_to_fwd_f32", - "nans_to_bwd_f32" -); +webgpu_unary!(NansToKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs index be7690f6..0cce0688 100644 --- a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::NegateKernelOp, - f32, - WGSL, - "negate_fwd_f32", - "negate_bwd_f32" -); +webgpu_unary!(super::NegateKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs index d21a15b6..83af122e 100644 --- a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs @@ -2,15 +2,9 @@ use std::borrow::Cow; use crate::prelude::{ops::UnaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::PowfKernelOp, - f32, - WGSL, - "powf_fwd_f32", - "powf_bwd_f32" -); +webgpu_unary!(super::PowfKernelOp, f32, WGSL, WGSL); // TODO: Conflicting implementations of trait `UnaryKernel` for type `Webgpu`: impl UnaryKernel for Webgpu diff --git a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs index d3a14fe0..7f31f22d 100644 --- a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, "recip_fwd_f32", "recip_bwd_f32"); +webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs index c986c73d..d1917a44 100644 --- a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::ReLUKernelOp, - f32, - WGSL, - "relu_fwd_f32", - "relu_bwd_f32" -); +webgpu_unary!(super::ReLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs index 377c4453..bbe904eb 100644 --- a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); +webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs index 
befc2da0..ee975e34 100644 --- a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::SinKernelOp, f32, WGSL, "sin_fwd_f32", "sin_bwd_f32"); +webgpu_unary!(super::SinKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs index e2b0b032..86fb7809 100644 --- a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, "sqrt_fwd_f32", "sqrt_bwd_f32"); +webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs index f16523a2..c5a1805a 100644 --- a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::SquareKernelOp, - f32, - WGSL, - "square_fwd_f32", - "square_bwd_f32" -); +webgpu_unary!(super::SquareKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs index 0d476c94..d2a86789 100644 --- a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs @@ -4,9 +4,9 @@ use super::{BinarySubKernelOp as Binary, ScalarSubKernelOp as Scalar}; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl b/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl new file mode 100644 index 00000000..a8647f4b --- /dev/null +++ b/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl @@ -0,0 +1,108 @@ +#version 460 core + +#extension GL_EXT_shader_atomic_float: enable +#extension SPV_EXT_shader_atomic_float_add: enable +#extension GL_ARB_compute_shader: enable +#extension GL_ARB_shader_storage_buffer_object: enable +#extension ARB_shader_atomic_counter_ops: enable +#extension VK_EXT_shader_atomic_float: enable + +layout(local_size_x = 128) in; + +layout(std430, binding = 1) buffer inpBlock { + TYPENAME inp[]; +}; + +layout(std430, binding = 2) buffer outpBlock { + TYPENAME outp[]; +}; + +layout(std430, binding = 3) buffer params { + uint chunk_len; + TYPENAME elems_per_thread; +}; + +layout(std430, binding = 4) buffer dimsBlock { + uint dims[]; +}; + +layout(std430, binding = 5) buffer stridesBlock { + uint strides[]; +}; + +uint next_power_of_two(uint v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v++; + return v; +} + +uint get_strided_index(uint idx) { + uint strided_i = 0; + for (uint d = 0; d < dims.length(); d++) { + uint dim_idx = dims.length() - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +void chunk_sum( + uint chunk_len, + TYPENAME data +) { + TYPENAME 
buf[1024];
+    // TODO: for the tree reduction below, this scratch buffer has to be
+    // workgroup-shared (a module-scope `shared` declaration); a
+    // function-local array is private to each invocation.
+
+    // assumes that threads where i >= numel have already exited
+    uint i = gl_GlobalInvocationID.x;
+    // index of this invocation inside its workgroup (threadIdx.x in the CUDA
+    // kernel this mirrors)
+    uint block_i = gl_LocalInvocationID.x;
+
+    // Fall back to atomicAdd if chunk_len is small to reduce overhead
+    if (chunk_len <= 2) {
+        atomicAdd(outp[i / chunk_len], data);
+        return;
+    }
+    buf[block_i] = data;
+
+    uint chunk_i = i % chunk_len;
+    uint chunk_start = max(int(block_i - chunk_i), 0);
+    uint chunk_end = min(uint(block_i + chunk_len - chunk_i), gl_WorkGroupSize.x);
+
+    chunk_i = block_i - chunk_start;
+
+    uint max_chunk_len = min(chunk_end - chunk_start, gl_WorkGroupSize.x);
+    uint incr = next_power_of_two(max_chunk_len) >> 1;
+
+    barrier();
+
+    // Uses sequential addressing as discussed in
+    // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
+    for (; incr > 0; incr >>= 1) {
+        uint block_i_2 = block_i + incr;
+
+        if (block_i_2 < chunk_end && chunk_i < incr) {
+            // This is sound because barrier() and the conditions above
+            // ensure that no data races occur
+            buf[block_i] += buf[block_i_2];
+        }
+
+        barrier();
+    }
+
+    if (block_i == chunk_start) {
+        atomicAdd(outp[i / chunk_len], buf[block_i]);
+    }
+}
+
+void main() {
+    if (gl_GlobalInvocationID.x >= inp.length()) {
+        return;
+    }
+
+    uint inp_idx = get_strided_index(gl_GlobalInvocationID.x);
+
+    chunk_sum(chunk_len, inp[inp_idx] * elems_per_thread);
+}
diff --git a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
index 29247ea7..131e2b75 100644
--- a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
@@ -1,6 +1,31 @@
-use crate::prelude::{Dtype, Webgpu};
+use core::any::TypeId;
 
-impl super::SumKernel for Webgpu {
+use wgpu::ComputePipelineDescriptor;
+
+use crate::{
+    prelude::{
+        webgpu_kernels::{Forward, HasGlslType},
+        Dtype, Webgpu,
+    },
+    tensor_ops::reduction_utils::*,
+};
+
+struct WebgpuSumKernel;
+
+trait HasWebgpuKernel {
+    const MOD: &'static str;
+    const FNS: &'static [&'static str];
+}
+
+impl HasWebgpuKernel for Webgpu {
+    const MOD: &'static str = "sum_f32";
+    const FNS: &'static [&'static str] = &["sum_to_fwd_f32", "sum_to_bwd_f32"];
+}
+
+impl super::SumKernel for Webgpu
+where
+    Self: HasWebgpuKernel,
+{
     fn forward(
         &self,
         dst: Dst,
@@ -9,7 +34,72 @@ impl super::SumKernel for Webgpu {
     where
         Src: crate::prelude::ReduceShapeTo,
     {
-        todo!()
+        if !self.shader_module_loaded(TypeId::of::>()) {
+            self.load_shader_module::(
+                TypeId::of::>(),
+                include_bytes!(concat!(env!("OUT_DIR"), "/sum_to.fwd.float.spv")),
+            );
+        }
+
+        let cs_module = self
+            .get_shader_module(TypeId::of::>())
+            .expect("shader module not loaded");
+        let pipeline = self
+            .dev
+            .create_compute_pipeline(&ComputePipelineDescriptor {
+                label: None,
+                layout: None,
+                module: &cs_module,
+                entry_point: "main",
+            });
+
+        let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides);
+        let num_dims = dims.len();
+
+        let mut info = Vec::with_capacity(num_dims * 2);
+        info.extend(dims);
+        info.extend(strides);
+        let info_buffer = self.alloc_empty::(num_dims * 2)?;
+        info_buffer.copy_to_device(&self.dev, &self.queue, &info);
+
+        let elems_per_thread = E::from_usize(reduction_elems_per_thread::<_, Src>(
+            inp.shape.concrete(),
+            inp.strides,
+            Ax::as_array(),
+        ))
+        .unwrap();
+
+        let physical_numel = inp.data.len::();
+        let physical_num_blocks = (physical_numel + 128 - 1) / 128;
+        let (dst_physical_numel, dst_strides) =
+            reduction_output_strides::(inp.strides, dst);
+        let chunk_len = physical_numel
/ dst_physical_numel; + + let bind_group_layout = pipeline.get_bind_group_layout(0); + let storage = self.alloc_empty::(dst_physical_numel)?; + let mut entries = Vec::new(); + + todo!("add buffers to entries, but we need to get atomic operations working"); + + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &entries, + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(physical_num_blocks as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(self.build_tensor(dst, dst_strides, storage)) } fn backward( @@ -25,3 +115,24 @@ impl super::SumKernel for Webgpu { todo!() } } + +#[cfg(test)] +mod tests { + use crate::prelude::*; + use crate::tensor_ops::*; + use crate::tests::*; + + #[ignore] + #[test] + fn test_sum_1d() { + let dev: Webgpu = Webgpu::default(); + let t = dev.tensor([1.0, 2.0, 3.0]); + let r = t.leaky_trace().sum::(); + let e = 6.0f64; + assert_close_to_literal!(r, e); + // TODO: Add exp back in + // NOTE: .exp() to make sure its using result grad properly + // let g = r.exp().backward(); + // assert_close_to_literal!(g.get(&t), [e.exp(); 3]); + } +} diff --git a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs index aa8742d8..d29f2a07 100644 --- a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs @@ -1,6 +1,6 @@ use super::TanhKernelOp; use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(TanhKernelOp, f32, WGSL, "tanh_fwd_f32", "tanh_bwd_f32"); +webgpu_unary!(TanhKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs b/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs index fefb9811..378a6655 100644 --- a/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs +++ b/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs @@ -46,7 +46,7 @@ pub(crate) fn index_for_reductions( /// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions /// so that a cuda kernel called for each physical element of the input tensor will place elements /// to be reduced with each other next to each other in memory. -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] pub(crate) fn permute_for_reductions(dims: I, strides: I) -> (Vec, Vec) where I: IntoIterator, @@ -74,7 +74,7 @@ where /// Returns the physical number of elements and strides of dst so that broadcasted dimensions in /// src are also broadcasted in dst -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] #[inline(always)] pub(crate) fn reduction_output_strides( src_strides: Src::Concrete, @@ -101,7 +101,7 @@ pub(crate) fn reduction_output_strides( } /// Gives the product of all dimensions that are being reduced and are broadcasted. 
-#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] #[inline(always)] pub(crate) fn reduction_elems_per_thread, S: Shape>( dims: S::Concrete, diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs index 619afb1c..1579134c 100644 --- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -4,64 +4,80 @@ use crate::{ tensor_ops::ops::{BinaryKernel, UnaryKernel}, }; use core::any::TypeId; -use std::{borrow::Cow, sync::Arc, vec::Vec}; +use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec}; pub(crate) trait UnaryOpWebgpuKernel { const DF_USES_FX: bool; const HAS_CONST_DF: bool; - /// Compiled by build.rs - const WGSL_SRC: &'static str; + // /// Unique name for the kernel + // const MODULE_NAME: &'static str; - /// Unique name for the kernel - const MODULE_NAME: &'static str; + /// Glsl source code for the forward pass + const GLSL_FWD_SPV: &'static [u8]; - /// Name of function in the .wgsl file - const FWD_FN_NAME: &'static str; - - /// Name of function in the .wgsl file - const BWD_FN_NAME: &'static str; - - const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME]; + /// Glsl source code for the backward pass + const GLSL_BWD_SPV: &'static [u8]; } macro_rules! webgpu_unary { - ($Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + ($Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = false; const HAS_CONST_DF: bool = false; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = stringify!($Op); - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = stringify!($Op); + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; - (df(f(x)) $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + (df(f(x)) $Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = true; const HAS_CONST_DF: bool = false; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = $Fwd; - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = $Fwd; + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; - (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + (const_df() $Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = false; const HAS_CONST_DF: bool = true; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = $Fwd; - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = $Fwd; + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; } +/// Zero-sized marker type for forward pass TypeId +#[derive(Debug, Default)] +pub(crate) struct Forward { + _phantom: PhantomData<(E, K)>, +} + +/// Zero-sized marker type for backward pass TypeId +#[derive(Debug, Default)] +pub(crate) struct Backward { + _phantom: PhantomData<(E, K)>, +} + +pub(crate) trait HasGlslType { + const TYPE: &'static str; +} + +impl HasGlslType for f32 { + const TYPE: &'static str = "float"; +} + 
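+// Presumably these names line up with the `-DTYPENAME=...` define that
+// build.rs hands to glslc, which also tags the compiled artifacts
+// (e.g. `abs.fwd.float.spv`).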
+impl HasGlslType for f64 { + const TYPE: &'static str = "double"; +} + pub(crate) use webgpu_unary; use wgpu::ComputePipelineDescriptor; -impl + 'static> UnaryKernel for Webgpu { +impl + 'static> UnaryKernel for Webgpu { const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX; const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF; @@ -70,27 +86,28 @@ impl + 'static> UnaryKernel for Webgpu op: K, inp: Cow>, ) -> Result, Error> { - if !self.shader_module_loaded(TypeId::of::()) { - self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + if !self.shader_module_loaded(TypeId::of::>()) { + self.load_shader_module::(TypeId::of::>(), K::GLSL_FWD_SPV); } let cs_module = self - .get_shader_module(TypeId::of::()) - .expect("shader module not loaded"); + .get_shader_module(TypeId::of::>()) + .ok_or(Error::WebgpuSourceLoadError)?; let pipeline = self .dev .create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &cs_module, - entry_point: K::FWD_FN_NAME, + entry_point: "main", }); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.data.len::(); + let num_blocks = (numel + 128 - 1) / 128; let storage = self.alloc_empty::(numel)?; let empty = self.alloc_empty::(0)?; - let mut entries = vec![]; + let mut entries = Vec::new(); // WGSL doesn't support empty structs, so don't bind the empty buffer if std::mem::size_of::() > 0 { entries.push(wgpu::BindGroupEntry { @@ -124,7 +141,7 @@ impl + 'static> UnaryKernel for Webgpu }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &binding_group, &[]); - cpass.dispatch_workgroups(numel as u32, 1, 1); + cpass.dispatch_workgroups(num_blocks as u32, 1, 1); } self.queue.submit(Some(encoder.finish())); Ok(self.build_tensor(inp.shape, inp.strides, storage)) @@ -155,7 +172,7 @@ impl + 'static> UnaryKernel for Webgpu }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &binding_group, &[]); - cpass.dispatch_workgroups(numel as u32, 1, 1); + cpass.dispatch_workgroups(num_blocks as u32, 1, 1); } self.queue.submit(Some(encoder.finish())); Ok(inp) @@ -171,28 +188,27 @@ impl + 'static> UnaryKernel for Webgpu out: &impl Tensorlike, grad_out: &Self::Vec, ) -> Result<(), Error> { - if !self.shader_module_loaded(TypeId::of::()) { - self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + if !self.shader_module_loaded(TypeId::of::>()) { + self.load_shader_module::(TypeId::of::>(), K::GLSL_BWD_SPV); } let cs_module = self - .get_shader_module(TypeId::of::()) - .expect("shader module not loaded"); + .get_shader_module(TypeId::of::>()) + .ok_or(Error::WebgpuSourceLoadError)?; let pipeline = self .dev .create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &cs_module, - entry_point: K::BWD_FN_NAME, + entry_point: "main", }); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.len(); - let storage = self.alloc_empty::(numel)?; let empty_inp = self.alloc_empty::(0)?; let empty_out = self.alloc_empty::(0)?; - let mut entries = vec![]; + let mut entries = Vec::new(); // WGSL doesn't support empty structs, so don't bind the empty buffer if std::mem::size_of::() > 0 { entries.push(wgpu::BindGroupEntry { From afa3a1acfe089ccdaae8f73ea6c51f8c9357b37a Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Wed, 27 Dec 2023 16:36:04 -0800 Subject: [PATCH 14/16] Have the code work correctly, almost got sum_to working, too Weird magic number issue that I can't figure out... 
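A guess at the magic number, untested: `include_bytes!` only gives the
included data 1-byte alignment, and `#[repr(align(32))] struct
Aligned(&'static [u8]);` aligns the wrapper holding the *reference*, not the
bytes it points at, while SPIR-V consumers read the module as little-endian
u32 words that must start with the magic 0x07230203. A sketch of the usual
workaround; `AlignedSpv` and `SUM_FWD_SPV` are illustrative names, not part
of this patch:

    // The wrapper owns the bytes, so the alignment applies to the data itself.
    #[repr(align(4))]
    struct AlignedSpv<Bytes: ?Sized>(Bytes);

    static SUM_FWD_SPV: &AlignedSpv<[u8]> =
        &AlignedSpv(*include_bytes!(concat!(env!("OUT_DIR"), "/sum_to.fwd.float.spv")));

    // &SUM_FWD_SPV.0 is then a 4-byte-aligned &[u8] to hand to the loader.

If the first word printed by the debug println in this patch already isn't
0x07230203, though, the .spv file on disk is the thing to inspect, not the
alignment.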
--- dfdx-core/build.rs | 51 +++--- dfdx-core/src/tensor/webgpu/device.rs | 20 +-- dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs | 13 +- .../src/tensor_ops/sum_to/webgpu_kernel.rs | 150 ++++++++++++++++-- .../tensor_ops/utilities/webgpu_kernels.rs | 127 ++++++++++++++- 5 files changed, 304 insertions(+), 57 deletions(-) diff --git a/dfdx-core/build.rs b/dfdx-core/build.rs index 76d33682..dacc19f5 100644 --- a/dfdx-core/build.rs +++ b/dfdx-core/build.rs @@ -233,32 +233,35 @@ mod webgpu { let children = kernel_paths .iter() .map(|p| { - // TODO: we need to build this for both float and double - let out_path: std::path::PathBuf = out_dir.clone().into(); - let base = p.file_stem().unwrap(); - let new_name = format!("{}.float.spv", base.to_str().unwrap()); - let out_file = &out_path.join(new_name); - eprintln!("out_file: {:?}", out_file); - std::process::Command::new("glslc") - .args(["-std=460core"]) - .args(["-fshader-stage=compute"]) - .args(["-DTYPENAME=float"]) - .args(["-o", &out_file.as_os_str().to_str().unwrap()]) - .arg(p) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .spawn() - .expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.") + ["float", "double"].iter().map(|ty| { + // TODO: we need to build this for both float and double + let out_path: std::path::PathBuf = out_dir.clone().into(); + let base = p.file_stem().unwrap(); + let new_name = format!("{}.{ty}.spv", base.to_str().unwrap()); + let out_file = &out_path.join(new_name); + std::process::Command::new("glslc") + .args(["-std=460core"]) + .args(["-fshader-stage=compute"]) + .args([format!("-DTYPENAME={ty}")]) + .args(["-o", &out_file.as_os_str().to_str().unwrap()]) + .arg(p) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.") + }).collect::>() }) .collect::>(); - for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) { - let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH."); - assert!( - output.status.success(), - "glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ); + for (kernel_path, childs) in kernel_paths.iter().zip(children.into_iter()) { + for child in childs { + let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH."); + assert!( + output.status.success(), + "glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } } } } diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 23d73060..5e7b2c01 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -225,19 +225,13 @@ impl Webgpu { where E: HasGlslType, { - // TODO: Get raw SpirV working. I am guessing that is how we are going - // to have to implement atomic stuff with `wgpu`. 
-    //
-    // let module = Arc::new(unsafe {
-    //     self.dev.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
-    //         label: None,
-    //         source: make_spirv_raw(source),
-    //     })
-    // });
-    let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
-        label: None,
-        source: make_spirv(source),
-    }));
+    let module = Arc::new(unsafe {
+        self.dev
+            .create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
+                label: None,
+                source: make_spirv_raw(source),
+            })
+    });
     #[cfg(not(feature = "no-std"))]
     self.cs_cache.write().unwrap().insert(name, module);
     #[cfg(feature = "no-std")]
diff --git a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs
index 130c9e3f..4274c95f 100644
--- a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs
@@ -1,12 +1,13 @@
 use super::AbsKernelOp;
 use crate::tensor_ops::webgpu_kernels::webgpu_unary;
 
-const GLSL_FWD: &str = include_str!("abs.fwd.glsl");
-const GLSL_BWD: &str = include_str!("abs.bwd.glsl");
-const SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv"));
-const SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.float.spv"));
+const F32_SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv"));
+const F32_SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.float.spv"));
+const F64_SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.double.spv"));
+const F64_SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.double.spv"));
 
-webgpu_unary!(AbsKernelOp, f32, SPV_FWD, SPV_BWD);
+webgpu_unary!(AbsKernelOp, f32, F32_SPV_FWD, F32_SPV_BWD);
+webgpu_unary!(AbsKernelOp, f64, F64_SPV_FWD, F64_SPV_BWD);
 
 #[cfg(test)]
 mod tests {
@@ -15,7 +16,7 @@ mod tests {
     #[test]
     fn test_webgpu_abs() {
         let dev: Webgpu = Default::default();
-        let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
+        let x = dev.tensor([-2.0f32, -1.0, 0.0, 1.0, 2.0]);
         let r = x.leaky_trace().abs();
         assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]);
         // TODO: Add mean back in
diff --git a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
index 131e2b75..003e5447 100644
--- a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs
@@ -1,6 +1,9 @@
 use core::any::TypeId;
+use std::{sync::Arc, vec::Vec};
 
-use wgpu::ComputePipelineDescriptor;
+use wgpu::{
+    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
+};
 
 use crate::{
     prelude::{
@@ -14,12 +17,32 @@ struct WebgpuSumKernel;
 
 trait HasWebgpuKernel {
     const MOD: &'static str;
-    const FNS: &'static [&'static str];
+
+    const FWD_SOURCE: Aligned;
+    const BWD_SOURCE: Aligned;
 }
 
+// NOTE: this aligns the wrapper holding the reference; the bytes produced by
+// `include_bytes!` themselves are still only 1-byte aligned.
+#[repr(align(32))]
+struct Aligned(&'static [u8]);
+
 impl HasWebgpuKernel for Webgpu {
     const MOD: &'static str = "sum_f32";
-    const FNS: &'static [&'static str] = &["sum_to_fwd_f32", "sum_to_bwd_f32"];
+
+    const FWD_SOURCE: Aligned = Aligned(include_bytes!(concat!(
+        env!("OUT_DIR"),
+        "/sum_to.fwd.float.spv"
+    )));
+    const BWD_SOURCE: Aligned = Aligned(b"TODO");
+}
+
+impl HasWebgpuKernel for Webgpu {
+    const MOD: &'static str = "sum_f64";
+
+    const FWD_SOURCE: Aligned = Aligned(include_bytes!(concat!(
+        env!("OUT_DIR"),
+        "/sum_to.fwd.double.spv"
+    )));
+    const BWD_SOURCE: Aligned = Aligned(b"TODO");
 }
 
 impl super::SumKernel for Webgpu
 where
@@ -34,7 +57,33 @@ where
     where
         Src: crate::prelude::ReduceShapeTo,
     {
+        todo!("Sum kernel has weird magic
number problem"); + // TODO: Remove this, make it work with magic number + println!( + "{:0x}", + u32::from_le_bytes([ + Self::FWD_SOURCE.0[0], + Self::FWD_SOURCE.0[1], + Self::FWD_SOURCE.0[2], + Self::FWD_SOURCE.0[3], + ]) + ); if !self.shader_module_loaded(TypeId::of::>()) { self.load_shader_module::( TypeId::of::>(), - include_bytes!(concat!(env!("OUT_DIR"), "/sum_to.fwd.float.spv")), + Self::FWD_SOURCE.0, ); } let cs_module = self .get_shader_module(TypeId::of::>()) .expect("shader module not loaded"); + let pipeline_layout = create_pipeline_layout_fwd(&self.dev); let pipeline = self .dev .create_compute_pipeline(&ComputePipelineDescriptor { label: None, - layout: None, + layout: Some(&pipeline_layout), module: &cs_module, entry_point: "main", }); @@ -56,11 +91,13 @@ where let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); let num_dims = dims.len(); - let mut info = Vec::with_capacity(num_dims * 2); - info.extend(dims); - info.extend(strides); - let info_buffer = self.alloc_empty::(num_dims * 2)?; - info_buffer.copy_to_device(&self.dev, &self.queue, &info); + let mut info = Vec::with_capacity(num_dims); + info.extend(dims.into_iter().map(|d| d as u32)); + let dims_buffer = self.alloc_init::(&info)?; + + let mut info = Vec::with_capacity(num_dims); + info.extend(strides.into_iter().map(|d| d as u32)); + let strides_buffer = self.alloc_init::(&info)?; let elems_per_thread = E::from_usize(reduction_elems_per_thread::<_, Src>( inp.shape.concrete(), @@ -75,11 +112,32 @@ where reduction_output_strides::(inp.strides, dst); let chunk_len = physical_numel / dst_physical_numel; + let params_buffer = self.alloc_init::<(u32, E)>(&[(chunk_len as u32, elems_per_thread)])?; + let bind_group_layout = pipeline.get_bind_group_layout(0); let storage = self.alloc_empty::(dst_physical_numel)?; let mut entries = Vec::new(); - todo!("add buffers to entries, but we need to get atomic operations working"); + entries.push(wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Buffer(inp.data.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 2, + resource: wgpu::BindingResource::Buffer(storage.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 3, + resource: wgpu::BindingResource::Buffer(params_buffer.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 4, + resource: wgpu::BindingResource::Buffer(dims_buffer.as_entire_buffer_binding()), + }); + entries.push(wgpu::BindGroupEntry { + binding: 5, + resource: wgpu::BindingResource::Buffer(strides_buffer.as_entire_buffer_binding()), + }); let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { label: None, @@ -116,6 +174,76 @@ where } } +fn create_pipeline_layout_fwd(dev: &Device) -> PipelineLayout { + let entries = vec![ + // input + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // output + wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // params + wgpu::BindGroupLayoutEntry { + binding: 3, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true 
},
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        // dims
+        wgpu::BindGroupLayoutEntry {
+            binding: 4,
+            visibility: ShaderStages::COMPUTE,
+            ty: BindingType::Buffer {
+                ty: BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        // strides
+        wgpu::BindGroupLayoutEntry {
+            binding: 5,
+            visibility: ShaderStages::COMPUTE,
+            ty: BindingType::Buffer {
+                ty: BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+    ];
+
+    let binding_group_layout = dev.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+        label: None,
+        entries: &entries,
+    });
+    dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+        label: None,
+        bind_group_layouts: &[&binding_group_layout],
+        push_constant_ranges: &[],
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use crate::prelude::*;
diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
index 1579134c..f940604d 100644
--- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
+++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
@@ -75,7 +75,7 @@ impl HasGlslType for f64 {
 }
 
 pub(crate) use webgpu_unary;
-use wgpu::ComputePipelineDescriptor;
+use wgpu::{BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages, BindingType};
 
 impl + 'static> UnaryKernel for Webgpu {
     const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
@@ -93,11 +93,12 @@ impl + 'static> UnaryKernel
             .get_shader_module(TypeId::of::>())
             .ok_or(Error::WebgpuSourceLoadError)?;
+        let pipeline_layout = create_pipeline_layout_fwd::(&self.dev);
         let pipeline = self
             .dev
             .create_compute_pipeline(&ComputePipelineDescriptor {
                 label: None,
-                layout: None,
+                layout: Some(&pipeline_layout),
                 module: &cs_module,
                 entry_point: "main",
             });
@@ -195,11 +196,12 @@ impl + 'static> UnaryKernel
             .get_shader_module(TypeId::of::>())
             .ok_or(Error::WebgpuSourceLoadError)?;
+        let pipeline_layout = create_pipeline_layout_bwd::(&self.dev);
         let pipeline = self
             .dev
             .create_compute_pipeline(&ComputePipelineDescriptor {
                 label: None,
-                layout: None,
+                layout: Some(&pipeline_layout),
                 module: &cs_module,
                 entry_point: "main",
             });
@@ -294,3 +296,122 @@ impl + 'static> UnaryKernel
+
+fn create_pipeline_layout_bwd(dev: &Device) -> PipelineLayout {
+    let mut entries = Vec::new();
+    // op
+    if std::mem::size_of::() > 0 {
+        entries.push(wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: ShaderStages::COMPUTE,
+            ty: BindingType::Buffer {
+                ty: BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        });
+    }
+    // input
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 1,
+        visibility: ShaderStages::COMPUTE,
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: true },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+    // output
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 2,
+        visibility: ShaderStages::COMPUTE,
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: true },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+    // grad_input
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 3,
+        visibility: ShaderStages::COMPUTE,
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: false },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+    // grad_output
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 4,
+        visibility:
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: true },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+
+    let binding_group_layout = dev.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+        label: None,
+        entries: &entries,
+    });
+
+    dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+        label: None,
+        bind_group_layouts: &[&binding_group_layout],
+        push_constant_ranges: &[],
+    })
+}
+
+fn create_pipeline_layout_fwd<K>(dev: &Device) -> PipelineLayout {
+    let mut entries = Vec::new();
+    if std::mem::size_of::<K>() > 0 {
+        entries.push(wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: ShaderStages::COMPUTE,
+            ty: BindingType::Buffer {
+                ty: BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        });
+    }
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 1,
+        visibility: ShaderStages::COMPUTE,
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: true },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+    entries.push(wgpu::BindGroupLayoutEntry {
+        binding: 2,
+        visibility: ShaderStages::COMPUTE,
+        ty: BindingType::Buffer {
+            ty: BufferBindingType::Storage { read_only: false },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    });
+
+    let binding_group_layout = dev.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+        label: None,
+        entries: &entries,
+    });
+
+    dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+        label: None,
+        bind_group_layouts: &[&binding_group_layout],
+        push_constant_ranges: &[],
+    })
+}

From 29668c633aca733e92969889f031ca3bac4608df Mon Sep 17 00:00:00 2001
From: Kevin Oberlies
Date: Wed, 27 Dec 2023 16:38:21 -0800
Subject: [PATCH 15/16] Cargo fmt

---
 dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
index f940604d..72cb3105 100644
--- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
+++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
@@ -75,7 +75,9 @@ impl HasGlslType for f64 {
 }
 
 pub(crate) use webgpu_unary;
-use wgpu::{BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages, BindingType};
+use wgpu::{
+    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
+};
 
 impl + 'static> UnaryKernel for Webgpu {
     const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;

From 5a2e4adb3ca07cc7324adba4c5f7ff985add1433 Mon Sep 17 00:00:00 2001
From: Kevin Oberlies
Date: Wed, 27 Dec 2023 16:46:52 -0800
Subject: [PATCH 16/16] Do we need to skip webgpu features?
--- .github/workflows/cargo-check-features.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cargo-check-features.yml b/.github/workflows/cargo-check-features.yml index 6e622e25..7370649f 100644 --- a/.github/workflows/cargo-check-features.yml +++ b/.github/workflows/cargo-check-features.yml @@ -11,9 +11,9 @@ jobs: matrix: config: - toolchain: stable - command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn + command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn,webgpu - toolchain: nightly - command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn + command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn,webgpu steps: - uses: actions/checkout@v2
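
A note on how the explicit pipeline layouts from the earlier patches are consumed: the hunks above elide the actual dispatch code, so what follows is a minimal, hypothetical sketch of encoding one compute pass against such a pipeline. The helper name, its parameter list, and the 256-thread workgroup size are illustrative assumptions rather than code from these patches; only the wgpu calls themselves (create_command_encoder, begin_compute_pass, set_bind_group, dispatch_workgroups, Queue::submit) are the library's standard API.

use wgpu::{BindGroup, ComputePipeline, Device, Queue};

// Hypothetical helper: encodes a single compute pass over `num_elems`
// elements. `bind_group` must have been created against the same
// `BindGroupLayout` that the pipeline's explicit `PipelineLayout` wraps.
fn dispatch_compute(
    dev: &Device,
    queue: &Queue,
    pipeline: &ComputePipeline,
    bind_group: &BindGroup,
    num_elems: usize,
) {
    let mut encoder =
        dev.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // Scope the pass so it is dropped (ended) before `encoder.finish()`.
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(pipeline);
        // Group 0 corresponds to the single layout registered in
        // `bind_group_layouts: &[&binding_group_layout]` above.
        pass.set_bind_group(0, bind_group, &[]);
        // Assumes the shader declares a workgroup size of 256
        // (`local_size_x = 256` in GLSL); round the element count up.
        pass.dispatch_workgroups(((num_elems + 255) / 256) as u32, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
}

Passing layout: Some(&pipeline_layout) instead of layout: None also changes validation: wgpu checks the bind group against the explicit layout rather than one inferred from the shader module, which appears to be why the op buffer can be omitted conditionally at binding 0 in create_pipeline_layout_fwd/_bwd when the op type is zero-sized.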