From 167ee4b1525853af3e00be8478b2912b3081692d Mon Sep 17 00:00:00 2001 From: Kevin Oberlies Date: Sat, 2 Dec 2023 18:13:03 -0800 Subject: [PATCH] Implement abs kernel, and use broken unary operation for all the compiler errors --- dfdx-core/src/tensor/webgpu/device.rs | 67 ++++++- dfdx-core/src/tensor_ops/abs/abs.wgsl | 40 ++++ dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs | 30 +-- .../tensor_ops/accurate_gelu/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/add/webgpu_kernel.rs | 27 +-- .../src/tensor_ops/clamp/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/div/webgpu_kernel.rs | 31 +-- dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/fast_gelu/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs | 31 +-- .../src/tensor_ops/nans_to/webgpu_kernel.rs | 40 ++-- .../src/tensor_ops/negate/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs | 63 +++--- .../src/tensor_ops/recip/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/relu/webgpu_kernel.rs | 35 +--- .../src/tensor_ops/sigmoid/webgpu_kernel.rs | 29 +-- dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs | 29 +-- .../src/tensor_ops/sqrt/webgpu_kernel.rs | 28 +-- .../src/tensor_ops/square/webgpu_kernel.rs | 35 +--- dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs | 30 +-- .../src/tensor_ops/tanh/webgpu_kernel.rs | 30 +-- dfdx-core/src/tensor_ops/utilities/device.rs | 6 +- dfdx-core/src/tensor_ops/utilities/mod.rs | 2 + .../tensor_ops/utilities/webgpu_kernels.rs | 188 ++++++++++++++++++ 26 files changed, 424 insertions(+), 561 deletions(-) create mode 100644 dfdx-core/src/tensor_ops/abs/abs.wgsl create mode 100644 dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 3cba06c7f..4eb62c97e 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -1,6 +1,7 @@ use wgpu::{ + util::{BufferInitDescriptor, DeviceExt}, Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue, - RequestDeviceError, + RequestDeviceError, ShaderModule, ShaderModuleDescriptor, }; use crate::{ @@ -14,10 +15,16 @@ use crate::{ #[cfg(feature = "no-std")] use spin::Mutex; +use core::any::TypeId; #[cfg(not(feature = "no-std"))] use std::sync::Mutex; -use std::{marker::PhantomData, sync::Arc, vec::Vec}; +use std::{ + collections::HashMap, + marker::PhantomData, + sync::{Arc, RwLock}, + vec::Vec, +}; use super::allocate::round_to_buffer_alignment; @@ -40,12 +47,16 @@ impl Buffer { self.size } + pub(crate) fn len(&self) -> usize { + self.size / std::mem::size_of::() + } + #[allow(unused)] pub(crate) fn capacity(&self) -> usize { self.data.size() as usize } - pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { + pub(crate) fn copy_to_device(&self, dev: &Device, queue: &Queue, slice: &[E]) { let slice = unsafe { std::slice::from_raw_parts( slice.as_ptr() as *const u8, @@ -102,6 +113,7 @@ pub struct Webgpu { pub(crate) queue: Arc, pub(crate) cache: Arc>, + pub(crate) cs_cache: Arc>>>, } impl From for Error { @@ -147,18 +159,19 @@ impl Webgpu { queue, cache: Default::default(), + cs_cache: Default::default(), }) } } impl Webgpu { - pub(crate) unsafe fn alloc_empty(&self, len: usize) -> Result { + pub(crate) fn alloc_empty(&self, len: usize) -> Result { let data = self.cache.try_pop::(len).map_or_else( || Buffer { data: self.dev.create_buffer(&BufferDescriptor { label: None, size: round_to_buffer_alignment((len * std::mem::size_of::()) as u64), - usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST, mapped_at_creation: false, }), size: len * std::mem::size_of::(), @@ -168,6 +181,50 @@ impl Webgpu { Ok(data) } + pub(crate) fn alloc_init(&self, init: &[E]) -> Result { + let data = self.cache.try_pop::(init.len()).map_or_else( + || { + let contents = unsafe { + std::slice::from_raw_parts( + init.as_ptr() as *const u8, + init.len() * std::mem::size_of::(), + ) + }; + Buffer { + data: self.dev.create_buffer_init(&BufferInitDescriptor { + label: None, + usage: BufferUsages::STORAGE + | BufferUsages::COPY_SRC + | BufferUsages::COPY_DST, + contents, + }), + size: init.len() * std::mem::size_of::(), + } + }, + |bfr| { + bfr.copy_to_device::(&self.dev, &self.queue, init); + bfr + }, + ); + Ok(data) + } + + pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { + self.cs_cache.read().unwrap().contains_key(&name) + } + + pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { + let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(source.into()), + })); + self.cs_cache.write().unwrap().insert(name, module); + } + + pub(crate) fn get_shader_module(&self, name: TypeId) -> Option> { + self.cs_cache.read().unwrap().get(&name).cloned() + } + // #[allow(unused)] // pub(crate) unsafe fn get_workspace(&self, len: usize) -> Result, Error> { // let num_bytes_required = len * std::mem::size_of::(); diff --git a/dfdx-core/src/tensor_ops/abs/abs.wgsl b/dfdx-core/src/tensor_ops/abs/abs.wgsl new file mode 100644 index 000000000..da980ac0a --- /dev/null +++ b/dfdx-core/src/tensor_ops/abs/abs.wgsl @@ -0,0 +1,40 @@ +struct AbsKernelOp {}; + +@group(0) +@binding(0) +var op: AbsKernelOp; + +@group(0) +@binding(1) +var inp: array; + +@group(0) +@binding(2) +var out: array; + +@group(0) +@binding(3) +var inp_grad: array; + +@group(0) +@binding(4) +var out_grad: array; + +@compute +@workgroup_size(1) +fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3) { + var x = if arrayLength(inp) > 0 { &inp[global_id] } else { &out[global_id] }; + *x = abs(*x); +} + +@compute +@workgroup_size(1) +fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3) { + // Not needed for Abs, but if we can figure out a template system, we can leave it all in. + // let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 }; + // let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 }; + var dx: f32; + dx = sign(inp[global_id]); + + inp_grad[global_id] += dx * out_grad[global_id]; +} diff --git a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs index c993ee91d..a5d9059e1 100644 --- a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use super::AbsKernelOp; +use crate::tensor_ops::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = include_str!("abs.wgsl"); -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::AbsKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::AbsKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs index 080a857de..de36720c9 100644 --- a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::AccurateGeLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::AccurateGeLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::AccurateGeLUKernelOp, + f32, + WGSL, + "gelu_fwd_f32", + "gelu_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs index 91becc551..204a43573 100644 --- a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs @@ -1,34 +1,15 @@ +use super::{BinaryAddKernelOp as Binary, ScalarAddKernelOp as Scalar}; use std::borrow::Cow; use crate::prelude::{ ops::{BinaryKernel, UnaryKernel}, + webgpu_kernels::webgpu_unary, Dtype, Webgpu, }; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarAddKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarAddKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(Scalar, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs index df700d206..82485d5aa 100644 --- a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ClampKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ClampKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::ClampKernelOp, + f32, + WGSL, + "clamp_fwd_f32", + "clamp_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs index a59bb5c88..a352702fc 100644 --- a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::CosKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::CosKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs index 3a15ef7e4..158a05fda 100644 --- a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs @@ -1,34 +1,11 @@ +use super::{BinaryDivKernelOp as Binary, ScalarDivKernelOp as Scalar}; use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - Dtype, Webgpu, -}; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarDivKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarDivKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs index 4f552b49c..8670fc23f 100644 --- a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ExpKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ExpKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::ExpKernelOp, f32, WGSL, "exp_fwd_f32", "exp_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs index cbdce3d90..c05b3df56 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::FastGeLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::FastGeLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::FastGeLUKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs index 64694c6f6..bf08f71fe 100644 --- a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::LnKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::LnKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::LnKernelOp, f32, WGSL, "ln_fwd_f32", "ln_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs index 240ba571a..6b778db45 100644 --- a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs @@ -1,34 +1,11 @@ +use super::{BinaryMulKernelOp as Binary, ScalarMulKernelOp as Scalar}; use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - Dtype, Webgpu, -}; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +const WGSL: &str = "TODO"; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarMulKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ScalarMulKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_mul_fwd", "scalar_mul_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs index 58cc8c363..5721f4a06 100644 --- a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs @@ -1,28 +1,12 @@ -use std::borrow::Cow; - -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; - -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::NansToKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::NansToKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +use super::NansToKernelOp; +use crate::prelude::webgpu_kernels::webgpu_unary; + +const WGSL: &str = "TODO"; + +webgpu_unary!( + NansToKernelOp, + f32, + WGSL, + "nans_to_fwd_f32", + "nans_to_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs index 4794d9063..2e16ebad7 100644 --- a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::NegateKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::NegateKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::NegateKernelOp, f32, WGSL, "negate_fwd_f32", "negate_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs index 0cf6b43df..d21a15b68 100644 --- a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs @@ -1,33 +1,22 @@ use std::borrow::Cow; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; - -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::PowfKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::PowfKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} - -impl UnaryKernel for Webgpu { +use crate::prelude::{ops::UnaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; + +const WGSL: &str = "TODO"; + +webgpu_unary!( + super::PowfKernelOp, + f32, + WGSL, + "powf_fwd_f32", + "powf_bwd_f32" +); + +// TODO: Conflicting implementations of trait `UnaryKernel` for type `Webgpu`: +impl UnaryKernel for Webgpu +where + Self: UnaryKernel, f32>, +{ const BACKWARD_WITHOUT_INP: bool = false; const BACKWARD_WITHOUT_DATA: bool = false; @@ -35,19 +24,25 @@ impl UnaryKernel for Webgpu { fn forward( &self, op: super::PowiKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() + inp: Cow>, + ) -> Result, crate::prelude::Error> { + self.forward(super::PowfKernelOp(op.0 as f32), inp) } fn backward( &self, op: super::PowiKernelOp, - inp: &impl crate::prelude::Tensorlike, + inp: &impl crate::prelude::Tensorlike, grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, + out: &impl crate::prelude::Tensorlike, grad_out: &Self::Vec, ) -> Result<(), crate::prelude::Error> { - todo!() + self.backward( + super::PowfKernelOp(op.0 as f32), + inp, + grad_inp, + out, + grad_out, + ) } } diff --git a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs index ca8fd312f..d3a14fe05 100644 --- a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::RecipKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::RecipKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, "recip_fwd_f32", "recip_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs index 6da7d6b9f..c986c73d9 100644 --- a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::ReLUKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::ReLUKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::ReLUKernelOp, + f32, + WGSL, + "relu_fwd_f32", + "relu_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs index f6e5c7420..377c44536 100644 --- a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SigmoidKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SigmoidKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs index 024c13824..befc2da09 100644 --- a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs @@ -1,28 +1,5 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SinKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SinKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(super::SinKernelOp, f32, WGSL, "sin_fwd_f32", "sin_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs index 0701ee08f..65def0b25 100644 --- a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, "sqrt_fwd_f32", "sqrt_bwd_f32"); - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SqrtKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SqrtKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} diff --git a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs index 522eae179..f16523a27 100644 --- a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs @@ -1,28 +1,11 @@ -use std::borrow::Cow; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::SquareKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::SquareKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!( + super::SquareKernelOp, + f32, + WGSL, + "square_fwd_f32", + "square_bwd_f32" +); diff --git a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs index 8d5e943e9..0d476c949 100644 --- a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs @@ -1,34 +1,12 @@ use std::borrow::Cow; -use crate::prelude::{ - ops::{BinaryKernel, UnaryKernel}, - Dtype, Webgpu, -}; +use super::{BinarySubKernelOp as Binary, ScalarSubKernelOp as Scalar}; -impl UnaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; +use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::ScalarSubKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +const WGSL: &str = "TODO"; - fn backward( - &self, - op: super::ScalarSubKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs index 51e661b2f..aa8742d8b 100644 --- a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs @@ -1,28 +1,6 @@ -use std::borrow::Cow; +use super::TanhKernelOp; +use crate::prelude::webgpu_kernels::webgpu_unary; -use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; +const WGSL: &str = "TODO"; -impl UnaryKernel for Webgpu { - const BACKWARD_WITHOUT_INP: bool = false; - - const BACKWARD_WITHOUT_DATA: bool = false; - - fn forward( - &self, - op: super::TanhKernelOp, - inp: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::TanhKernelOp, - inp: &impl crate::prelude::Tensorlike, - grad_inp: &mut Self::Vec, - out: &impl crate::prelude::Tensorlike, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +webgpu_unary!(TanhKernelOp, f32, WGSL, "tanh_fwd_f32", "tanh_bwd_f32"); diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 277be7a69..49fd5f903 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -137,5 +137,7 @@ impl Device for crate::tensor::Webgpu {} impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] impl Device for crate::tensor::Webgpu {} -#[cfg(feature = "webgpu")] -impl Device for crate::tensor::Webgpu {} + +// TODO: How can we implement this for f64 when WGSL doesn't support f64 yet? +// #[cfg(feature = "webgpu")] +// impl Device for crate::tensor::Webgpu {} diff --git a/dfdx-core/src/tensor_ops/utilities/mod.rs b/dfdx-core/src/tensor_ops/utilities/mod.rs index e23565bc4..a0a3a2b3e 100644 --- a/dfdx-core/src/tensor_ops/utilities/mod.rs +++ b/dfdx-core/src/tensor_ops/utilities/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod cuda_kernels; mod device; pub(crate) mod ops; pub(crate) mod reduction_utils; +#[cfg(feature = "webgpu")] +pub(crate) mod webgpu_kernels; pub use backward::Backward; pub use device::Device; diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs new file mode 100644 index 000000000..2a3652fa6 --- /dev/null +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -0,0 +1,188 @@ +use crate::{ + shapes::{Dtype, Shape}, + tensor::*, + tensor_ops::ops::{BinaryKernel, UnaryKernel}, +}; +use core::any::TypeId; +use std::{borrow::Cow, sync::Arc, vec::Vec}; + +pub(crate) trait UnaryOpWebgpuKernel { + const DF_USES_FX: bool; + const HAS_CONST_DF: bool; + + /// Compiled by build.rs + const WGSL_SRC: &'static str; + + /// Unique name for the kernel + const MODULE_NAME: &'static str; + + /// Name of function in the .wgsl file + const FWD_FN_NAME: &'static str; + + /// Name of function in the .wgsl file + const BWD_FN_NAME: &'static str; + + const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME]; +} + +macro_rules! webgpu_unary { + ($Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = false; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = stringify!($Op); + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; + (df(f(x)) $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = true; + const HAS_CONST_DF: bool = false; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = $Fwd; + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; + (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = true; + const WGSL_SRC: &'static str = $Wgsl; + const MODULE_NAME: &'static str = $Fwd; + const FWD_FN_NAME: &'static str = $Fwd; + const BWD_FN_NAME: &'static str = $Bwd; + } + }; +} + +pub(crate) use webgpu_unary; +use wgpu::ComputePipelineDescriptor; + +impl + 'static> UnaryKernel for Webgpu { + const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX; + const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF; + + fn forward( + &self, + op: K, + inp: Cow>, + ) -> Result, Error> { + if !self.shader_module_loaded(TypeId::of::()) { + self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + } + + let cs_module = self + .get_shader_module(TypeId::of::()) + .expect("shader module not loaded"); + let pipeline = self + .dev + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &cs_module, + entry_point: K::FWD_FN_NAME, + }); + let bind_group_layout = pipeline.get_bind_group_layout(0); + let op_storage = self.alloc_init::(&[op])?; + let numel = inp.data.len::(); + let storage = self.alloc_empty::(numel)?; + let empty = self.alloc_empty::(0)?; + + match inp { + Cow::Borrowed(inp) => { + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + op_storage.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + inp.data.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + storage.as_entire_buffer_binding(), + ), + }, + ], + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(numel as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(self.build_tensor(inp.shape, inp.strides, storage)) + } + Cow::Owned(mut inp) => { + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + op_storage.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + empty.as_entire_buffer_binding(), + ), + }, + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::Buffer( + Arc::make_mut(&mut inp.data).as_entire_buffer_binding(), + ), + }, + ], + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(numel as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(inp) + } + } + } + + fn backward( + &self, + op: K, + inp: &impl Tensorlike, + grad_inp: &mut Self::Vec, + out: &impl Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), Error> { + todo!() + } +}